diff --git a/.env.example b/.env.example index d87fa74..9e97462 100644 --- a/.env.example +++ b/.env.example @@ -40,5 +40,9 @@ LOG_TO_STDOUT=True RATELIMIT_ENABLED=True RATELIMIT_STORAGE_URL=redis://localhost:6379/1 -# Testing -TESTING=False +# SSH CA +# Path to CA private key file (alternative to SSH_CA_PRIVATE_KEY env var) +SSH_CA_KEY_PATH=/path/to/ca-users +# Or set the key content directly (takes priority over SSH_CA_KEY_PATH): +# SSH_CA_PRIVATE_KEY= + diff --git a/client/gatehouse-cli.py b/client/gatehouse-cli.py new file mode 100755 index 0000000..2060c5a --- /dev/null +++ b/client/gatehouse-cli.py @@ -0,0 +1,528 @@ +#!/usr/bin/python3 +import base64 +import os +import sys +import webbrowser +import requests +import argparse +import jwt +import json +import datetime +import pytz +from http.server import BaseHTTPRequestHandler, HTTPServer +from urllib.parse import urlparse, parse_qsl +from dotenv import load_dotenv +from sshkey_tools.cert import SSHCertificate +import logging +import coloredlogs +import subprocess + +# Load environment variables from the .env file +load_dotenv() + +# Get the API_URL from the environment variables +SIGN_URL = os.getenv("SIGN_URL", "http://localhost:5000") +LISTENER_HOST_NAME = "127.0.0.1" +LISTENER_SERVER_PORT = 8250 +CACHE_FILE = os.path.expanduser('~/.gatehouse/token_cache.json') +os.makedirs(os.path.dirname(CACHE_FILE), exist_ok=True) +CERT_FILE_PATH = "/tmp/ssh-cert" +CHALLENGE_FILE_PATH = "/tmp/challenge.txt" +CHALLENGE_SIG_FILE_PATH = "/tmp/challenge.txt.sig" + +# Configure logger +logger = logging.getLogger(__name__) +coloredlogs.install(level='DEBUG', logger=logger, fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s') + +token = "" + +def auth_headers(content_type="application/json"): + """Return auth headers using the current cached token.""" + return {"Authorization": f"Bearer {token}", "Content-Type": content_type} + + +class MyServer(BaseHTTPRequestHandler): + def do_GET(self): + """Handle GET requests and process token reception.""" + global server_done, token + + self.send_response(200) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write(bytes("OIDC Workflow Tool", "utf-8")) + self.wfile.write(bytes("

The token has been received

", "utf-8")) + self.wfile.write(bytes("

You may now close this window.

", "utf-8")) + self.wfile.write(bytes("", "utf-8")) + + parsed_url = urlparse(self.path) + query_data = dict(parse_qsl(parsed_url.query)) + received_token = query_data.get('token') + + if received_token: + token = received_token + server_done = True + logger.info("Token received") + save_token_to_cache(token) + + def log_message(self, format, *args): + """Log messages using the logger instead of stdout.""" + logger.info("%s - %s" % (self.client_address[0], format % args)) + + +def load_token_from_cache(): + """Load the token from the cache file.""" + if os.path.exists(CACHE_FILE): + with open(CACHE_FILE, 'r') as f: + data = json.load(f) + if 'token' in data: + return data['token'] + return None + +def save_token_to_cache(token): + """Save the token to the cache file.""" + with open(CACHE_FILE, 'w') as f: + json.dump({'token': token}, f) + +def clear_token_cache(): + """Remove the cached token file.""" + if os.path.exists(CACHE_FILE): + os.remove(CACHE_FILE) + logger.info("Cached token removed.") + else: + logger.info("No cached token found.") + +def decode_and_validate_token(token): + """Decode the JWT and validate its claims. + + Returns True if the token is a valid, non-expired JWT. + Returns False if the token is not a JWT (e.g. opaque session token) + or if it has expired — callers should then fall back to /auth/me. + """ + try: + decoded_token = jwt.decode(token, options={"verify_signature": False}) + except jwt.exceptions.DecodeError: + # Not a JWT — likely an opaque session token; let /auth/me handle it. + return False + except Exception as e: + logger.debug(f"Unexpected JWT decode error: {e}") + return False + + iat = decoded_token.get('iat') + exp = decoded_token.get('exp') + + if iat is None or exp is None: + logger.debug("JWT is missing 'iat' or 'exp' claims — treating as invalid.") + return False + + now = datetime.datetime.now(pytz.UTC) + exp_dt = datetime.datetime.fromtimestamp(exp, pytz.UTC) + iat_dt = datetime.datetime.fromtimestamp(iat, pytz.UTC) + + logger.debug(f"JWT iat={iat_dt.isoformat()} exp={exp_dt.isoformat()}") + + if exp_dt < now: + logger.debug("JWT has expired.") + return False + + if iat_dt > now: + logger.debug("JWT 'iat' is in the future — clock skew?") + + return True + +def request_token(): + global server_done, token + server_done = False + logger.info("Starting request_token process.") + + # Attempt to load the token from the cache + token = load_token_from_cache() + logger.debug("Token loaded from cache: %s", token) + + # Validate the cached token, if it exists + if token: + try: + if decode_and_validate_token(token): + logger.info("Cached token is valid. Using cached token.") + return token + except Exception: + pass + # Try opaque token via /auth/me + try: + r = requests.get( + f"{SIGN_URL}/api/v1/auth/me", + headers={"Authorization": f"Bearer {token}"}, + timeout=5, + ) + if r.status_code == 200: + logger.info("Cached session token is valid. Using cached token.") + return token + except Exception: + pass + logger.info("Cached token is expired or invalid, requesting a new token.") + token = "" + + # Prepare the redirect URL for the token request + redirect_url = f"http://{LISTENER_HOST_NAME}:{LISTENER_SERVER_PORT}/?token=" + logger.info("Redirect URL: %s", redirect_url) + + # Construct the token request URL + token_url = f"{SIGN_URL}/api/v1/token_please?redirect_url={redirect_url}" + logger.info("Token request URL: %s", token_url) + + # Start the web server to handle the token response + logger.debug("Starting the HTTP server on %s:%d", LISTENER_HOST_NAME, LISTENER_SERVER_PORT) + webServer = HTTPServer((LISTENER_HOST_NAME, LISTENER_SERVER_PORT), MyServer) + + # Open the web browser to initiate the token request + logger.info("Opening web browser to request token.") + webbrowser.open(token_url, new=2) + + # Wait for the server to handle the request and receive the token + logger.debug("Waiting for the token response...") + while not server_done: + webServer.handle_request() + logger.debug("Server handled a request, server_done status: %s", server_done) + + logger.info("Token received: %s", token) + return token + +def get_activated_ssh_key(): + """Retrieve the list of SSH keys and return the ID of a verified key.""" + try: + response = requests.get(f"{SIGN_URL}/api/v1/ssh/keys", headers=auth_headers()) + if response.status_code != 200: + logger.error(f"Failed to retrieve SSH keys: {response.status_code} - {response.text}") + exit(1) + + keys = response.json().get('data', {}).get('keys', []) + verified_keys = [k for k in keys if k['verified']] + + if not verified_keys: + logger.error("No verified SSH keys found for the user.") + exit(1) + + if len(verified_keys) > 1 and sys.stdout.isatty(): + print("\nMultiple verified SSH keys found. Please choose one:") + for i, k in enumerate(verified_keys): + print(f" [{i+1}] {k['id'][:8]}... fingerprint={k.get('fingerprint','?')} name={k.get('key_comment','?')}") + try: + choice = int(input("Enter number: ").strip()) - 1 + if 0 <= choice < len(verified_keys): + return verified_keys[choice]['id'] + except (ValueError, EOFError): + pass + logger.info("Invalid choice; using the most recently added key.") + + verified_keys.sort(key=lambda k: k.get('created_at', ''), reverse=True) + return verified_keys[0]['id'] + + except SystemExit: + raise + except Exception as e: + logger.error(f"Error while retrieving SSH keys: {e}") + exit(1) + + +def fetch_my_principals(): + """Fetch all principal names the current user is entitled to from the API. + For regular members: returns their assigned principals. + For org admins/owners: returns all principals in the org (they can sign for any). + """ + global token + response = requests.get( + f"{SIGN_URL}/api/v1/users/me/principals", + headers={"Authorization": f"Bearer {token}"}, + timeout=10, + ) + if response.status_code != 200: + logger.error(f"Failed to fetch principals from server: {response.status_code} - {response.text}") + exit(1) + + orgs = response.json().get("data", {}).get("orgs", []) + principal_names = [] + for org in orgs: + # Admins/owners get all principals; regular members get only their assigned ones + if org.get("is_admin"): + source = org.get("all_principals", []) + else: + source = org.get("my_principals", []) + for p in source: + if p["name"] not in principal_names: + principal_names.append(p["name"]) + + return principal_names + + +def request_certificate(): + CERT_ID = os.getenv("CERT_ID") or get_activated_ssh_key() + + principals = fetch_my_principals() + if not principals: + logger.error("You have no principals assigned. Contact your org admin.") + exit(1) + logger.info(f"Requesting certificate for principals: {', '.join(principals)}") + + headers = { + 'content-type': 'application/json', + "Authorization": "bearer " + token + } + + payload = { + 'cert_id': CERT_ID, + 'principals': principals, + } + + try: + response = requests.post(f"{SIGN_URL}/api/v1/ssh/sign", json=payload, headers=headers) + + if response.status_code == 201: + json_result = response.json().get('data', response.json()) + with open(CERT_FILE_PATH, 'w') as f: + f.write(json_result['certificate']) + logger.info(f"Certificate signed successfully, located at {CERT_FILE_PATH}") + logger.info(f"Valid for principals: {', '.join(json_result.get('principals', principals))}") + logger.info("You can login to your destination server with the following command") + logger.info(f"\tssh user@server -o CertificateFile={CERT_FILE_PATH}") + else: + logger.error("Error in response from server") + logger.error(f"Status code: {response.status_code}") + logger.error(f"Response text: {response.text}") + except Exception as e: + logger.error(f"Error during certificate signing: {e}") + +def generate_and_sign_challenge(ssh_key_file, key_id): + """Fetch a challenge from the server, sign it with the SSH key, and submit the signature.""" + logger.debug(f"generate_and_sign_challenge - {ssh_key_file} {key_id}") + + # Fetch challenge text + try: + response = requests.get(f"{SIGN_URL}/api/v1/ssh/keys/{key_id}/verify", headers=auth_headers()) + if response.status_code != 200: + logger.error(f"Server returned unexpected code {response.status_code}") + return False + resp_json = response.json() + data = resp_json.get('data', resp_json) + challenge_text = data.get('challenge_text', data.get('validationText', '')) + "\n" + except Exception as e: + logger.error(f"Unable to fetch SSH Key validation data: {e}") + return False + + # Sign the challenge + try: + for path in (CHALLENGE_FILE_PATH, CHALLENGE_SIG_FILE_PATH): + if os.path.exists(path): + os.remove(path) + + with open(CHALLENGE_FILE_PATH, 'w') as f: + f.write(challenge_text) + + subprocess.run( + ["ssh-keygen", "-Y", "sign", "-f", ssh_key_file, "-n", "file", CHALLENGE_FILE_PATH], + check=True, + ) + + with open(CHALLENGE_SIG_FILE_PATH, 'rb') as f: + signature = base64.b64encode(f.read()).decode('utf-8') + except Exception as e: + logger.error(f"Unable to sign the challenge response: {e}") + return False + + # Submit signature + try: + response = requests.post( + f"{SIGN_URL}/api/v1/ssh/keys/{key_id}/verify", + headers=auth_headers(), + json={"signature": signature}, + ) + if response.status_code == 200: + logger.info("SSH key verified successfully.") + else: + logger.error(f"Verification failed: {response.status_code} - {response.text}") + except Exception as e: + logger.error(f"Unable to submit the challenge response: {e}") + + return signature + +def remove_ssh_key(key_id=None): + """ + Remove an SSH key from the server. If key_id is None, list keys and prompt user to pick one. + """ + response = requests.get(f"{SIGN_URL}/api/v1/ssh/keys", headers=auth_headers()) + if response.status_code != 200: + logger.error(f"Failed to list SSH keys: {response.status_code} - {response.text}") + exit(1) + + keys = response.json().get('data', {}).get('keys', []) + if not keys: + logger.info("No SSH keys found for your user.") + return + + if key_id: + target = next((k for k in keys if k['id'] == key_id), None) + if not target: + logger.error(f"Key ID {key_id} not found in your profile.") + exit(1) + keys_to_delete = [target] + else: + print("\nYour SSH keys:") + for i, k in enumerate(keys): + verified = "✓ verified" if k['verified'] else "✗ unverified" + print(f" [{i+1}] {k['id']} {verified} {k.get('description', '')} (added {k['created_at'][:10]})") + print(" [a] Delete ALL keys") + print(" [q] Quit") + choice = input("\nEnter number to delete (or 'a' for all, 'q' to quit): ").strip().lower() + + if choice == 'q': + return + elif choice == 'a': + keys_to_delete = keys + else: + try: + idx = int(choice) - 1 + if idx < 0 or idx >= len(keys): + raise ValueError() + keys_to_delete = [keys[idx]] + except ValueError: + logger.error("Invalid selection.") + exit(1) + + for k in keys_to_delete: + del_response = requests.delete(f"{SIGN_URL}/api/v1/ssh/keys/{k['id']}", headers=auth_headers()) + if del_response.status_code == 200: + logger.info(f"Key {k['id']} removed successfully.") + else: + logger.error(f"Failed to remove key {k['id']}: {del_response.status_code} - {del_response.text}") + + +def add_ssh_key(ssh_key_file): + """Add an SSH key to the server and auto-verify it.""" + if hasattr(ssh_key_file, 'read'): + key_bytes = ssh_key_file.read() + key_path = ssh_key_file.name + elif isinstance(ssh_key_file, bytes): + key_bytes = ssh_key_file + key_path = None + else: + key_path = str(ssh_key_file) + with open(key_path, 'rb') as f: + key_bytes = f.read() + + ssh_key = key_bytes.decode('utf-8').strip() + payload = { + 'description': 'Added via gatehouse CLI tool', + 'key': ssh_key, + } + + response = requests.post(f"{SIGN_URL}/api/v1/ssh/keys", json=payload, headers=auth_headers()) + if response.status_code == 201: + ssh_key_id = response.json().get('data', {}).get('id') + logger.info(f"SSH key {ssh_key_id} added successfully") + if key_path: + private_key_path = key_path[:-4] if key_path.endswith('.pub') else key_path + generate_and_sign_challenge(private_key_path, ssh_key_id) + else: + logger.warning("No key file path available — skipping auto-verification. " + "Run with -k to enable automatic key verification.") + else: + logger.error(f"Failed to add SSH key: {response.status_code} - {response.text}") + +def checkCert(): + logger.info("Running cert check") + if not os.path.isfile(CERT_FILE_PATH): + logger.warning("Certificate does not exist, new certificate required") + return 1 + + try: + certificate = SSHCertificate.from_file(CERT_FILE_PATH) + except Exception: + logger.warning("Certificate file is invalid or corrupt, renewal required") + return 1 + + # Get the current datetime + now = datetime.datetime.now() + logger.debug(certificate + ) + + # Check if the date is in the past or future + if certificate.get("valid_before") > now: + # Expiry is in the future + if args.force: + return 0 + else: + logger.info("You have a valid SSH Certificate with the principals {} expiring at {}, not renewing. Use -f to force renewal".format(certificate.get("principals"), certificate.get("valid_before"))) + return 0 + else: + logger.warning("Certificate is not valid, renewal required") + return 1 + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Sign an SSH key via a web service') + parser.add_argument("-k", "--ssh-key", type=argparse.FileType('rb'), dest="sshkeyfile", help="Add an SSH Public Key to your user profile in gatehouse") + parser.add_argument("-f", "--force", action='store_true', default=False, help="Force the certificate renewal") + parser.add_argument("-a", "--add-key", action='store_true', default=False, help="Add SSH key to the server") + parser.add_argument("-c", "--check-cert", action='store_true', default=False, help="Check the certificate, if it's valid exit 0, if it's invalid exit 1") + parser.add_argument("-r", "--request-cert", action='store_true', default=False, help="Request that gatehouse sign a new certificate for you based on an SSH public key on file in your profile") + parser.add_argument("--clear-cache", action='store_true', default=False, help="Remove the cached authentication token") + parser.add_argument("--remove-key", nargs='?', const='', metavar='KEY_ID', help="Remove an SSH key from your profile. Omit KEY_ID to pick interactively.") + parser.add_argument("--list-keys", action='store_true', default=False, help="List SSH keys in your profile") + + args = parser.parse_args() + if not (args.check_cert or args.request_cert or args.add_key or args.clear_cache + or args.remove_key is not None or args.list_keys): + parser.error("At least one of --check-cert, --request-cert, --add-key, --list-keys, --remove-key, or --clear-cache must be provided.") + + + # Retrieve SSH key from environment variables if not provided via CLI + ssh_key_file = args.sshkeyfile if args.sshkeyfile else os.getenv('SSH_KEY_FILE') + + if args.check_cert: + logger.info("Only checking certificate") + exit(checkCert()) + + if args.clear_cache: + clear_token_cache() + exit(0) + + if args.remove_key is not None: + request_token() + remove_ssh_key(args.remove_key if args.remove_key else None) + exit(0) + + if args.list_keys: + request_token() + response = requests.get(f"{SIGN_URL}/api/v1/ssh/keys", headers=auth_headers()) + if response.status_code == 200: + keys = response.json().get('data', {}).get('keys', []) + if not keys: + print("No SSH keys found in your profile.") + else: + for k in keys: + verified = "✓ verified" if k.get('verified') else "✗ unverified" + print(f" {k['id']} {verified} {k.get('description', '')} (added {k['created_at'][:10]})") + else: + logger.error(f"Failed to list SSH keys: {response.status_code} - {response.text}") + exit(0) + + if args.add_key: + request_token() + + if not ssh_key_file: + logger.error("SSH key file is required to add SSH key") + exit(1) + + # If ssh_key_file is retrieved from the environment, it will be a string (file path), so open it + if isinstance(ssh_key_file, str): + with open(ssh_key_file, 'rb') as f: + ssh_key_file = f.read() + + add_ssh_key(ssh_key_file) + exit(0) + + + if args.request_cert: + request_token() + if args.force: + logger.info("Forcing renewal of certificate") + if args.force or checkCert() == 1: + request_certificate() + exit(0) diff --git a/config/base.py b/config/base.py index 4a2dfe6..73a4d7f 100644 --- a/config/base.py +++ b/config/base.py @@ -28,6 +28,11 @@ class BaseConfig: # Encryption key for sensitive data (client secrets, tokens, etc.) ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY", "dev-encryption-key-change-in-production") + + # Encryption key for CA private keys stored in the database. + # Must be set to a strong random secret in production. + # Any string is accepted — it is SHA-256 derived to a 32-byte Fernet key internally. + CA_ENCRYPTION_KEY = os.getenv("CA_ENCRYPTION_KEY", "dev-ca-encryption-key-change-in-production") # Session configuration for WebAuthn cross-origin support SESSION_COOKIE_SECURE = os.getenv("SESSION_COOKIE_SECURE", "True").lower() == "true" @@ -72,6 +77,13 @@ class BaseConfig: RATELIMIT_STORAGE_URL = os.getenv("RATELIMIT_STORAGE_URL", "redis://localhost:6379/1") RATELIMIT_DEFAULT = "100/hour" + # Per-endpoint auth rate limits (override via env vars for each environment) + RATELIMIT_AUTH_REGISTER = os.getenv("RATELIMIT_AUTH_REGISTER", "10 per minute; 50 per hour") + RATELIMIT_AUTH_LOGIN = os.getenv("RATELIMIT_AUTH_LOGIN", "20 per minute; 100 per hour") + RATELIMIT_AUTH_TOTP_VERIFY = os.getenv("RATELIMIT_AUTH_TOTP_VERIFY", "20 per minute; 100 per hour") + RATELIMIT_AUTH_FORGOT_PASSWORD = os.getenv("RATELIMIT_AUTH_FORGOT_PASSWORD", "5 per minute; 20 per hour") + RATELIMIT_AUTH_RESET_PASSWORD = os.getenv("RATELIMIT_AUTH_RESET_PASSWORD", "10 per minute; 30 per hour") + # Logging LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO") LOG_TO_STDOUT = os.getenv("LOG_TO_STDOUT", "False").lower() == "true" @@ -116,3 +128,11 @@ class BaseConfig: # Frontend URL (for OAuth callback redirects) FRONTEND_URL = os.getenv("FRONTEND_URL", "http://localhost:8080") + + # Email / SMTP + EMAIL_ENABLED = os.getenv("EMAIL_ENABLED", "False").lower() == "true" + SMTP_HOST = os.getenv("SMTP_HOST", "smtp.gmail.com") + SMTP_PORT = int(os.getenv("SMTP_PORT", "587")) + SMTP_USERNAME = os.getenv("SMTP_USERNAME", "") + SMTP_PASSWORD = os.getenv("SMTP_PASSWORD", "") + FROM_ADDRESS = os.getenv("FROM_ADDRESS", "noreply@gatehouse.local") diff --git a/config/development.py b/config/development.py index 840a84b..5436adf 100644 --- a/config/development.py +++ b/config/development.py @@ -25,3 +25,15 @@ class DevelopmentConfig(BaseConfig): "CORS_ORIGINS", "http://localhost:8080,http://localhost:3000,http://localhost:5173,https://ui.webauthn.local" ).split(",") + + # ── Email / SMTP ────────────────────────────────────────────────────────── + # Read from .env so real SMTP credentials work in dev. + # Set EMAIL_ENABLED=false in .env to disable; defaults to True if SMTP_HOST is set. + EMAIL_ENABLED = os.getenv("EMAIL_ENABLED", "True").lower() == "true" + SMTP_HOST = os.getenv("SMTP_HOST", "localhost") + SMTP_PORT = int(os.getenv("SMTP_PORT", "1025")) + SMTP_USERNAME = os.getenv("SMTP_USERNAME") or None + SMTP_PASSWORD = os.getenv("SMTP_PASSWORD") or None + SMTP_USE_TLS = os.getenv("SMTP_USE_TLS", "").lower() == "true" if os.getenv("SMTP_USE_TLS") else int(os.getenv("SMTP_PORT", "1025")) not in (25, 1025) + FROM_ADDRESS = os.getenv("FROM_ADDRESS", "noreply@gatehouse.local") + EMAIL_FROM = FROM_ADDRESS # alias diff --git a/config/testing.py b/config/testing.py index 4ecfff0..aa988a7 100644 --- a/config/testing.py +++ b/config/testing.py @@ -12,6 +12,9 @@ class TestingConfig(BaseConfig): # Explicitly set SECRET_KEY for testing SECRET_KEY = os.getenv("SECRET_KEY", "test-secret-key-for-testing") + # CA key encryption — use a fixed test key so tests are deterministic + CA_ENCRYPTION_KEY = os.getenv("CA_ENCRYPTION_KEY", "test-ca-encryption-key-fixed-for-tests") + # Use in-memory SQLite for testing SQLALCHEMY_DATABASE_URI = "sqlite:///:memory:" SQLALCHEMY_ECHO = False diff --git a/etc/ssh_ca.conf b/etc/ssh_ca.conf new file mode 100644 index 0000000..212de10 --- /dev/null +++ b/etc/ssh_ca.conf @@ -0,0 +1,30 @@ + +[default] +# Certificate validity period (in hours) +cert_validity_hours=8 + +# Maximum certificate validity allowed (in hours) +max_cert_validity_hours=720 + +# CA private key path (required for local encryption mode) +ca_key_path= + +# Certificate Field Limits +max_principals_per_cert=256 +max_key_id_length=255 + +# Verification challenge max age (in hours) +verification_challenge_max_age=24 + +# Cleanup: delete unverified SSH keys after this many days +auto_delete_unverified_days=30 + +[development] +ca_key_path=${SSH_CA_KEY_PATH} +cert_validity_hours=24 + +[production] +cert_validity_hours=8 + +[testing] +cert_validity_hours=8 diff --git a/gatehouse_app/__init__.py b/gatehouse_app/__init__.py index a4e1043..f17c784 100644 --- a/gatehouse_app/__init__.py +++ b/gatehouse_app/__init__.py @@ -2,6 +2,9 @@ import os import logging +from dotenv import load_dotenv +load_dotenv(dotenv_path=os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), '.env')) + # Test debug logging - this should appear when running `flask run --debug` _root_logger = logging.getLogger(__name__) _root_logger.debug("[TEST] Debug logging is working!") @@ -239,3 +242,6 @@ def initialize_oidc_jwks(app): app.logger.info(f"[OIDC] Signing key initialized: kid={signing_key.kid}") except Exception as e: app.logger.error(f"[OIDC] Failed to initialize JWKS: {e}") + +# Create default app instance for gunicorn/wsgi +app = create_app() diff --git a/gatehouse_app/api/oidc.py b/gatehouse_app/api/oidc.py index 5ec7700..43bf5b4 100644 --- a/gatehouse_app/api/oidc.py +++ b/gatehouse_app/api/oidc.py @@ -21,8 +21,12 @@ from gatehouse_app.extensions import db from gatehouse_app.extensions import bcrypt as flask_bcrypt from gatehouse_app.extensions import redis_client as _redis_client_ref # may be None until app init from gatehouse_app.models import User, OIDCClient -from gatehouse_app.models.organization import Organization -from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError +from gatehouse_app.models.organization.organization import Organization +from gatehouse_app.exceptions.auth_exceptions import ( + InvalidCredentialsError, + AccountSuspendedError, + AccountInactiveError, +) # --------------------------------------------------------------------------- # Helpers for Redis-backed OIDC pending state @@ -326,7 +330,7 @@ def oidc_complete(): 400: invalid request 401: invalid token """ - from gatehouse_app.models.session import Session as GHSession + from gatehouse_app.models.user.session import Session as GHSession from gatehouse_app.utils.constants import SessionStatus data = request.get_json(silent=True) or {} @@ -343,6 +347,20 @@ def oidc_complete(): user_id = str(gh_session.user_id) + # Check the user is still active (not suspended after session was issued) + from gatehouse_app.models.user.user import User as _User + from gatehouse_app.utils.constants import UserStatus + _complete_user = _User.query.filter_by(id=user_id, deleted_at=None).first() + if not _complete_user or _complete_user.status in ( + UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED, UserStatus.INACTIVE + ): + return api_response( + success=False, + message="Your account is not active or has been suspended.", + status=403, + error_type="ACCOUNT_SUSPENDED", + ) + # Retrieve stashed OIDC params (consume = True removes from Redis atomically) params = _fetch_oidc_params(oidc_session_id, consume=True) if not params: @@ -565,6 +583,28 @@ def oidc_authorize(): session["oidc_user_id"] = user_id logger.debug("[OIDC] User authentication successful: user_id=%s, email=%s", user_id, email) + except AccountSuspendedError: + logger.debug("[OIDC] User authentication failed: account suspended for email=%s", email) + return _show_login_page( + client_id=client_id, + redirect_uri=redirect_uri, + scope=scope, + state=state, + nonce=nonce, + response_type=response_type, + error="Your account has been suspended. Please contact an administrator.", + ) + except AccountInactiveError: + logger.debug("[OIDC] User authentication failed: account inactive for email=%s", email) + return _show_login_page( + client_id=client_id, + redirect_uri=redirect_uri, + scope=scope, + state=state, + nonce=nonce, + response_type=response_type, + error="Your account is not active. Please verify your email.", + ) except InvalidCredentialsError: logger.debug("[OIDC] User authentication failed: invalid credentials for email=%s", email) return _show_login_page( @@ -600,7 +640,34 @@ def oidc_authorize(): if not user: logger.debug("[OIDC] Redirecting with error: server_error (user not found)") return _redirect_with_error(redirect_uri, "server_error", "User not found", state) - + + # Check account is still active (user could have been suspended after session start) + from gatehouse_app.utils.constants import UserStatus as _UserStatus + if user.status in (_UserStatus.SUSPENDED, _UserStatus.COMPLIANCE_SUSPENDED): + session.pop("oidc_user_id", None) # clear stale session + logger.debug("[OIDC] User is suspended, clearing session and showing login error: user_id=%s", user_id) + return _show_login_page( + client_id=client_id, + redirect_uri=redirect_uri, + scope=scope, + state=state, + nonce=nonce, + response_type=response_type, + error="Your account has been suspended. Please contact an administrator.", + ) + if user.status == _UserStatus.INACTIVE: + session.pop("oidc_user_id", None) + logger.debug("[OIDC] User is inactive, clearing session and showing login error: user_id=%s", user_id) + return _show_login_page( + client_id=client_id, + redirect_uri=redirect_uri, + scope=scope, + state=state, + nonce=nonce, + response_type=response_type, + error="Your account is not active. Please verify your email.", + ) + logger.debug("[OIDC] Generating authorization code...") logger.debug("[OIDC] Authorization code params: client_id=%s, user_id=%s, redirect_uri=%s", client_id, user_id, redirect_uri) logger.debug("[OIDC] Authorization code params: scopes=%s, state=%s, nonce=%s", valid_scopes, state, nonce) diff --git a/gatehouse_app/api/v1/__init__.py b/gatehouse_app/api/v1/__init__.py index b97cd94..14534ed 100644 --- a/gatehouse_app/api/v1/__init__.py +++ b/gatehouse_app/api/v1/__init__.py @@ -5,4 +5,6 @@ from flask import Blueprint api_v1_bp = Blueprint("api_v1", __name__) # Import route modules to register them -from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth +from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth, departments, principals, ssh + +api_v1_bp.register_blueprint(ssh.ssh_bp) diff --git a/gatehouse_app/api/v1/auth.py b/gatehouse_app/api/v1/auth.py index ebffe8d..993f213 100644 --- a/gatehouse_app/api/v1/auth.py +++ b/gatehouse_app/api/v1/auth.py @@ -1,8 +1,10 @@ """Authentication endpoints.""" import json -from flask import request, session, g, jsonify +import logging +from flask import request, session, g, jsonify, current_app from marshmallow import ValidationError from gatehouse_app.api.v1 import api_v1_bp +from gatehouse_app.extensions import limiter from gatehouse_app.utils.response import api_response from gatehouse_app.schemas.auth_schema import ( RegisterSchema, @@ -23,6 +25,7 @@ from gatehouse_app.services.auth_service import AuthService from gatehouse_app.services.webauthn_service import WebAuthnService from gatehouse_app.services.user_service import UserService from gatehouse_app.services.mfa_policy_service import MfaPolicyService +from gatehouse_app.services.notification_service import NotificationService from gatehouse_app.utils.decorators import login_required from gatehouse_app.utils.constants import AuditAction from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError @@ -30,6 +33,7 @@ from gatehouse_app.exceptions.validation_exceptions import ConflictError, NotFou @api_v1_bp.route("/auth/register", methods=["POST"]) +@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_REGISTER"]) def register(): """ Register a new user. @@ -57,14 +61,66 @@ def register(): full_name=data.get("full_name"), ) + # Send verification email + try: + from gatehouse_app.models import EmailVerificationToken + verify_token = EmailVerificationToken.generate(user_id=user.id) + app_url = current_app.config.get("APP_URL", "http://localhost:8080") + verify_link = f"{app_url}/verify-email?token={verify_token.token}" + subject = "Verify your Gatehouse email address" + body = ( + f"Hi {user.full_name or user.email},\n\n" + f"Welcome to Gatehouse! Please verify your email address by clicking the link below (valid for 24 hours):\n" + f"{verify_link}\n\n" + f"Gatehouse Security Team" + ) + NotificationService._send_email(to_address=user.email, subject=subject, body=body) + except Exception as exc: + logging.getLogger(__name__).warning(f"Failed to send verification email on register: {exc}") + # Create session user_session = AuthService.create_session(user) + # ── Post-registration hints ───────────────────────────────────────── + from gatehouse_app.models.organization.org_invite_token import OrgInviteToken + from gatehouse_app.models.user.user import User as _User + from datetime import datetime, timezone as _tz + + now = datetime.now(_tz.utc) + pending_invites = OrgInviteToken.query.filter( + OrgInviteToken.email == user.email, + OrgInviteToken.accepted_at.is_(None), + OrgInviteToken.expires_at > now, + OrgInviteToken.deleted_at.is_(None), + ).all() + + # Determine if this is the very first user ever registered on this + # instance (exactly 1 active user means it must be this one). + total_users = _User.query.filter(_User.deleted_at.is_(None)).count() + is_first_user = total_users == 1 + + expires_str = user_session.expires_at.isoformat() + if expires_str[-1] != "Z": + expires_str += "Z" + return api_response( data={ "user": user.to_dict(), "token": user_session.token, - "expires_at": user_session.expires_at.isoformat() + "Z" if user_session.expires_at.isoformat()[-1] != "Z" else user_session.expires_at.isoformat(), + "expires_at": expires_str, + "is_first_user": is_first_user, + "pending_invites": [ + { + "token": inv.token, + "organization": { + "id": str(inv.organization_id), + "name": inv.organization.name, + }, + "role": inv.role, + "expires_at": inv.expires_at.isoformat(), + } + for inv in pending_invites + ], }, message="Registration successful", status=201, @@ -81,6 +137,7 @@ def register(): @api_v1_bp.route("/auth/login", methods=["POST"]) +@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_LOGIN"]) def login(): """ Login user. @@ -179,7 +236,9 @@ def login(): "organization_id": org.organization_id, "organization_name": org.organization_name, "status": org.status, + "effective_mode": org.effective_mode, "deadline_at": org.deadline_at, + "applied_at": org.applied_at, } for org in policy_result.compliance_summary.orgs ], @@ -189,6 +248,36 @@ def login(): if is_compliance_only: response_data["requires_mfa_enrollment"] = True + # ── Org-setup hint for org-less users ──────────────────────────────── + # If the user has no organisation memberships, surface any pending + # invitations so the UI can redirect straight to /org-setup instead of + # showing an empty dashboard. + user_orgs = user.get_organizations() + if not user_orgs: + from gatehouse_app.models.organization.org_invite_token import OrgInviteToken + from datetime import datetime, timezone as _tz + _now = datetime.now(_tz.utc) + pending_invites = OrgInviteToken.query.filter( + OrgInviteToken.email == user.email, + OrgInviteToken.accepted_at.is_(None), + OrgInviteToken.expires_at > _now, + OrgInviteToken.deleted_at.is_(None), + ).all() + response_data["pending_invites"] = [ + { + "token": inv.token, + "organization": { + "id": str(inv.organization_id), + "name": inv.organization.name, + }, + "role": inv.role, + "expires_at": inv.expires_at.isoformat(), + } + for inv in pending_invites + ] + # Flag so the UI knows to send this user through org-setup + response_data["requires_org_setup"] = True + return api_response( data=response_data, message="Login successful", @@ -239,8 +328,13 @@ def get_current_user(): data={ "user": user.to_dict(), "organizations": [ - {"id": org.id, "name": org.name, "slug": org.slug} - for org in user.get_organizations() + { + "id": membership.organization.id, + "name": membership.organization.name, + "slug": membership.organization.slug, + "role": membership.role, + } + for membership in user.organization_memberships ], }, message="User retrieved successfully", @@ -284,7 +378,7 @@ def revoke_session(session_id): 401: Not authenticated 404: Session not found """ - from gatehouse_app.models.session import Session + from gatehouse_app.models.user.session import Session # Ensure session belongs to current user user_session = Session.query.filter_by( @@ -392,6 +486,7 @@ def verify_totp_enrollment(): @api_v1_bp.route("/auth/totp/verify", methods=["POST"]) +@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_TOTP_VERIFY"]) def verify_totp(): """ Verify TOTP code during login. @@ -424,7 +519,7 @@ def verify_totp(): ) # Get user from database - from gatehouse_app.models.user import User + from gatehouse_app.models.user.user import User user = User.query.get(user_id) if not user: return api_response( @@ -434,6 +529,18 @@ def verify_totp(): error_type="AUTHENTICATION_ERROR", ) + # Check account suspension before completing TOTP verification + from gatehouse_app.utils.constants import UserStatus + if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED): + session.pop("totp_pending_user_id", None) + session.pop("webauthn_pending_user_id", None) + return api_response( + success=False, + message="Account is suspended. Contact an administrator.", + status=403, + error_type="ACCOUNT_SUSPENDED", + ) + # Verify TOTP code AuthService.authenticate_with_totp( user, @@ -475,7 +582,9 @@ def verify_totp(): "organization_id": org.organization_id, "organization_name": org.organization_name, "status": org.status, + "effective_mode": org.effective_mode, "deadline_at": org.deadline_at, + "applied_at": org.applied_at, } for org in policy_result.compliance_summary.orgs ], @@ -806,7 +915,7 @@ def begin_webauthn_login(): data = schema.load(request.json) # Find user by email - from gatehouse_app.models.user import User + from gatehouse_app.models.user.user import User user = User.query.filter_by( email=data["email"].lower(), deleted_at=None @@ -820,7 +929,18 @@ def begin_webauthn_login(): status=404, error_type="NOT_FOUND", ) - + + # Check account suspension before proceeding + from gatehouse_app.utils.constants import UserStatus + if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED): + logger.warning(f"WebAuthn login begin - suspended account attempt: {user.email}") + return api_response( + success=False, + message="Account is suspended. Contact an administrator.", + status=403, + error_type="ACCOUNT_SUSPENDED", + ) + # Check if user has any WebAuthn credentials if not user.has_webauthn_enabled(): logger.warning(f"WebAuthn login begin - no credentials for user: {user.email}") @@ -893,7 +1013,7 @@ def complete_webauthn_login(): data = schema.load(request.json) # Get user from database - from gatehouse_app.models.user import User + from gatehouse_app.models.user.user import User user = User.query.get(user_id) if not user: logger.error(f"WebAuthn login complete - user not found: {user_id}") @@ -903,7 +1023,19 @@ def complete_webauthn_login(): status=401, error_type="AUTHENTICATION_ERROR", ) - + + # Check account suspension before completing login + from gatehouse_app.utils.constants import UserStatus + if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED): + session.pop("webauthn_pending_user_id", None) + logger.warning(f"WebAuthn login complete - suspended account attempt: {user.email}") + return api_response( + success=False, + message="Account is suspended. Contact an administrator.", + status=403, + error_type="ACCOUNT_SUSPENDED", + ) + # Extract challenge from client data client_data = data.get("response", {}).get("clientDataJSON", "") @@ -962,7 +1094,9 @@ def complete_webauthn_login(): "organization_id": org.organization_id, "organization_name": org.organization_name, "status": org.status, + "effective_mode": org.effective_mode, "deadline_at": org.deadline_at, + "applied_at": org.applied_at, } for org in policy_result.compliance_summary.orgs ], @@ -1039,6 +1173,19 @@ def delete_webauthn_credential(credential_id): """ user = g.current_user + # First check that the specific credential actually belongs to this user. + # Only then check whether it is the last one — otherwise a user with zero + # credentials gets a misleading "Cannot delete the last passkey" error + # instead of a 404. + credential_exists = WebAuthnService.credential_belongs_to_user(credential_id, user) + if not credential_exists: + return api_response( + success=False, + message="Credential not found", + status=404, + error_type="NOT_FOUND", + ) + # Check if this is the last credential credential_count = user.get_webauthn_credential_count() if credential_count <= 1: @@ -1142,3 +1289,403 @@ def get_webauthn_status(): }, message="WebAuthn status retrieved successfully", ) + + +_pw_logger = logging.getLogger(__name__) + + +@api_v1_bp.route("/auth/forgot-password", methods=["POST"]) +@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_FORGOT_PASSWORD"]) +def forgot_password(): + """Request a password reset email. + + Always returns 200 to avoid leaking account existence. + + Request body: + email: User email address + + Returns: + 200: Password reset email sent (or silently no-op if email not found) + """ + from gatehouse_app.models import User, PasswordResetToken + + data = request.get_json() or {} + email = (data.get("email") or "").strip().lower() + + if not email: + return api_response( + success=False, + message="Email is required", + status=400, + error_type="VALIDATION_ERROR", + ) + + # Always return 200 — don't leak whether the email exists + user = User.query.filter_by(email=email, deleted_at=None).first() + if user: + try: + reset_token = PasswordResetToken.generate(user_id=user.id) + app_url = current_app.config.get("APP_URL", "http://localhost:8080") + reset_link = f"{app_url}/reset-password?token={reset_token.token}" + subject = "Reset your Gatehouse password" + body = ( + f"Hi {user.full_name or user.email},\n\n" + f"You requested a password reset for your Gatehouse account.\n\n" + f"Click the link below to reset your password (valid for 2 hours):\n" + f"{reset_link}\n\n" + f"If you did not request this, you can safely ignore this email.\n\n" + f"Gatehouse Security Team" + ) + NotificationService._send_email( + to_address=user.email, + subject=subject, + body=body, + ) + _pw_logger.info(f"Password reset token generated for user {user.id}") + except Exception as exc: + _pw_logger.exception(f"Error generating password reset token: {exc}") + + return api_response( + data={}, + message="If an account exists for this email, you will receive a password reset link shortly.", + ) + + +@api_v1_bp.route("/auth/reset-password", methods=["POST"]) +@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_RESET_PASSWORD"]) +def reset_password(): + """Reset a user's password using a reset token. + + Request body: + token: Password reset token from email + password: New password + password_confirm: Password confirmation + + Returns: + 200: Password reset successfully + 400: Invalid or expired token / validation error + """ + import bcrypt as _bcrypt + from gatehouse_app.extensions import bcrypt + from gatehouse_app.models import PasswordResetToken, AuthenticationMethod + from gatehouse_app.utils.constants import AuthMethodType + + data = request.get_json() or {} + token_value = (data.get("token") or "").strip() + new_password = data.get("password") or "" + password_confirm = data.get("password_confirm") or "" + + if not token_value or not new_password: + return api_response( + success=False, + message="Token and new password are required", + status=400, + error_type="VALIDATION_ERROR", + ) + + if new_password != password_confirm: + return api_response( + success=False, + message="Passwords do not match", + status=400, + error_type="VALIDATION_ERROR", + ) + + if len(new_password) < 8: + return api_response( + success=False, + message="Password must be at least 8 characters", + status=400, + error_type="VALIDATION_ERROR", + ) + + reset_token = PasswordResetToken.query.filter_by(token=token_value).first() + if not reset_token or not reset_token.is_valid: + return api_response( + success=False, + message="This password reset link is invalid or has expired.", + status=400, + error_type="INVALID_TOKEN", + ) + + try: + user = reset_token.user + # Update the password hash on the authentication method + auth_method = AuthenticationMethod.query.filter_by( + user_id=user.id, + method_type=AuthMethodType.PASSWORD, + deleted_at=None, + ).first() + if auth_method: + auth_method.password_hash = bcrypt.generate_password_hash(new_password).decode("utf-8") + from gatehouse_app.extensions import db + db.session.add(auth_method) + + reset_token.consume() + _pw_logger.info(f"Password reset for user {user.id}") + + return api_response( + data={}, + message="Your password has been reset. You can now sign in with your new password.", + ) + except Exception as exc: + _pw_logger.exception(f"Error resetting password: {exc}") + return api_response( + success=False, + message="An error occurred while resetting your password.", + status=500, + error_type="INTERNAL_ERROR", + ) + + +@api_v1_bp.route("/auth/verify-email", methods=["POST"]) +def verify_email(): + """Verify a user's email address using a verification token. + + Request body: + token: Email verification token + + Returns: + 200: Email verified successfully + 400: Invalid or expired token + """ + from gatehouse_app.models import EmailVerificationToken + + data = request.get_json() or {} + token_value = (data.get("token") or "").strip() + + if not token_value: + return api_response( + success=False, + message="Verification token is required", + status=400, + error_type="VALIDATION_ERROR", + ) + + verify_token = EmailVerificationToken.query.filter_by(token=token_value).first() + if not verify_token or not verify_token.is_valid: + return api_response( + success=False, + message="This verification link is invalid or has expired.", + status=400, + error_type="INVALID_TOKEN", + ) + + try: + user = verify_token.user + user.email_verified = True + from gatehouse_app.extensions import db + db.session.add(user) + verify_token.consume() + _pw_logger.info(f"Email verified for user {user.id}") + + return api_response( + data={}, + message="Your email has been verified. You can now sign in.", + ) + except Exception as exc: + _pw_logger.exception(f"Error verifying email: {exc}") + return api_response( + success=False, + message="An error occurred while verifying your email.", + status=500, + error_type="INTERNAL_ERROR", + ) + + +@api_v1_bp.route("/auth/resend-verification", methods=["POST"]) +def resend_verification(): + """Resend email verification link. + + Always returns 200 to avoid leaking account existence. + + Request body: + email: User email address + + Returns: + 200: Verification email sent (or silently no-op) + """ + from gatehouse_app.models import User, EmailVerificationToken + + data = request.get_json() or {} + email = (data.get("email") or "").strip().lower() + + if not email: + return api_response( + success=False, + message="Email is required", + status=400, + error_type="VALIDATION_ERROR", + ) + + user = User.query.filter_by(email=email, deleted_at=None).first() + if user and not user.email_verified: + try: + verify_token = EmailVerificationToken.generate(user_id=user.id) + app_url = current_app.config.get("APP_URL", "http://localhost:8080") + verify_link = f"{app_url}/verify-email?token={verify_token.token}" + subject = "Verify your Gatehouse email address" + body = ( + f"Hi {user.full_name or user.email},\n\n" + f"Please verify your email address by clicking the link below (valid for 24 hours):\n" + f"{verify_link}\n\n" + f"Gatehouse Security Team" + ) + NotificationService._send_email( + to_address=user.email, + subject=subject, + body=body, + ) + _pw_logger.info(f"Verification email sent for user {user.id}") + except Exception as exc: + _pw_logger.exception(f"Error sending verification email: {exc}") + + return api_response( + data={}, + message="If an account exists for this email and is not yet verified, you will receive a verification link shortly.", + ) + + +# ============================================================================= +# Account Activation (separate from email-verification) +# ============================================================================= + +@api_v1_bp.route("/auth/activate", methods=["POST"]) +def activate_account(): + """Activate a user account via a one-time activation code. + + Request body: + code – the activation_key from the welcome email + + Returns: + 200: Account activated, session token returned + 400: Missing code + 404: Invalid or already-used code + """ + import secrets + from gatehouse_app.models.user.user import User + from gatehouse_app.extensions import db + + data = request.get_json() or {} + code = (data.get("code") or "").strip() + if not code: + return api_response(success=False, message="Activation code is required", status=400, error_type="VALIDATION_ERROR") + + user = User.query.filter_by(activation_key=code, deleted_at=None).first() + if not user: + return api_response(success=False, message="Invalid or expired activation code", status=404, error_type="NOT_FOUND") + + user.activated = True + user.activation_key = None # one-time use + db.session.add(user) + db.session.commit() + + user_session = AuthService.create_session(user) + _pw_logger.info(f"Account activated for user {user.id}") + + return api_response( + data={ + "user": user.to_dict(), + "token": user_session.token, + "expires_at": user_session.expires_at.isoformat() + "Z" + if user_session.expires_at.isoformat()[-1] != "Z" + else user_session.expires_at.isoformat(), + }, + message="Account activated successfully", + ) + + +@api_v1_bp.route("/auth/resend-activation", methods=["POST"]) +def resend_activation(): + """Re-send an account activation email. + + Always returns 200 to avoid leaking whether an account exists. + + Request body: + email – user email address + """ + import secrets + from gatehouse_app.models.user.user import User + from gatehouse_app.extensions import db + + data = request.get_json() or {} + email = (data.get("email") or "").strip().lower() + if not email: + return api_response(success=False, message="Email is required", status=400, error_type="VALIDATION_ERROR") + + user = User.query.filter_by(email=email, deleted_at=None).first() + if user and not user.activated: + try: + code = secrets.token_urlsafe(32) + user.activation_key = code + db.session.add(user) + db.session.commit() + + app_url = current_app.config.get("APP_URL", current_app.config.get("FRONTEND_URL", "http://localhost:8080")) + activate_link = f"{app_url}/activate?code={code}" + subject = "Activate your Gatehouse account" + body = ( + f"Hi {user.full_name or user.email},\n\n" + f"Please activate your Gatehouse account by clicking the link below:\n" + f"{activate_link}\n\n" + f"If you did not create an account, you can safely ignore this email.\n\n" + f"Gatehouse Security Team" + ) + NotificationService._send_email(to_address=user.email, subject=subject, body=body) + _pw_logger.info(f"Activation email re-sent to {user.id}") + except Exception as exc: + _pw_logger.exception(f"Error re-sending activation email: {exc}") + + return api_response( + data={}, + message="If an unactivated account exists for this email, you will receive a new activation link shortly.", + ) + + +# ============================================================================= +# Token retrieval / redirect (for CLI / external tools) +# ============================================================================= + +@api_v1_bp.route("/auth/token", methods=["GET"]) +@login_required +def get_token(): + """Return the current session token, optionally redirecting to a URL. + + Query parameters: + redirect – optional URL to redirect to with the token appended as + a query param: ``?token=`` + + Returns: + 200: JSON ``{"token": ""}`` (no redirect given) + 302: Redirect to ``?token=`` + """ + from flask import redirect as flask_redirect + from urllib.parse import urlparse + + token = g.current_session.token + redirect_url = request.args.get("redirect", "").strip() + + if redirect_url: + # Validate redirect URL against allowed origins to prevent open-redirect + # token exfiltration attacks (CWE-601). + allowed_origins = set(current_app.config.get("CORS_ORIGINS", [])) + frontend_url = current_app.config.get("FRONTEND_URL", "") + if frontend_url: + parsed = urlparse(frontend_url) + allowed_origins.add(f"{parsed.scheme}://{parsed.netloc}") + + parsed_redirect = urlparse(redirect_url) + redirect_origin = f"{parsed_redirect.scheme}://{parsed_redirect.netloc}" + + if redirect_origin not in allowed_origins: + return api_response( + success=False, + message="Redirect URL is not allowed.", + status=400, + error_type="INVALID_REDIRECT", + ) + + sep = "&" if "?" in redirect_url else "?" + return flask_redirect(f"{redirect_url}{sep}token={token}", code=302) + + return api_response(data={"token": token}, message="Token retrieved") diff --git a/gatehouse_app/api/v1/departments.py b/gatehouse_app/api/v1/departments.py new file mode 100644 index 0000000..d305c66 --- /dev/null +++ b/gatehouse_app/api/v1/departments.py @@ -0,0 +1,699 @@ +"""Department endpoints.""" +from flask import g, request +from marshmallow import Schema, fields, validate, ValidationError +from sqlalchemy.orm.attributes import flag_modified + +from gatehouse_app.api.v1 import api_v1_bp +from gatehouse_app.utils.response import api_response +from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required +from gatehouse_app.models import Department, DepartmentMembership +from gatehouse_app.services.organization_service import OrganizationService +from gatehouse_app.services.user_service import UserService +from gatehouse_app.extensions import db + + +class DepartmentCreateSchema(Schema): + """Schema for creating a department.""" + name = fields.Str(required=True, validate=validate.Length(min=1, max=255)) + description = fields.Str(allow_none=True, validate=validate.Length(max=2000)) + + +class DepartmentUpdateSchema(Schema): + """Schema for updating a department.""" + name = fields.Str(validate=validate.Length(min=1, max=255)) + description = fields.Str(allow_none=True, validate=validate.Length(max=2000)) + + +class AddDepartmentMemberSchema(Schema): + """Schema for adding a member to a department.""" + email = fields.Email(required=True) + + +@api_v1_bp.route("/organizations//departments", methods=["GET"]) +@login_required +@full_access_required +def list_departments(org_id): + """ + List all departments in an organization. + + Args: + org_id: Organization ID + + Returns: + 200: List of departments + 401: Not authenticated + 403: Not a member + 404: Organization not found + """ + org = OrganizationService.get_organization_by_id(org_id) + + # Check if user is a member + if not org.is_member(g.current_user.id): + return api_response( + success=False, + message="You are not a member of this organization", + status=403, + error_type="AUTHORIZATION_ERROR", + ) + + departments = Department.query.filter_by( + organization_id=org_id, + deleted_at=None + ).all() + + return api_response( + data={ + "departments": [d.to_dict() for d in departments], + "count": len(departments), + }, + message="Departments retrieved successfully", + ) + + +@api_v1_bp.route("/organizations//departments", methods=["POST"]) +@login_required +@require_admin +@full_access_required +def create_department(org_id): + """ + Create a new department. + + Args: + org_id: Organization ID + + Request body: + name: Department name (required) + description: Optional description + + Returns: + 201: Department created successfully + 400: Validation error + 401: Not authenticated + 403: Not an admin + 404: Organization not found + 409: Department name already exists in org + """ + try: + org = OrganizationService.get_organization_by_id(org_id) + + schema = DepartmentCreateSchema() + data = schema.load(request.json or {}) + + # Check if department name already exists + existing = Department.query.filter_by( + organization_id=org_id, + name=data["name"], + deleted_at=None + ).first() + + if existing: + return api_response( + success=False, + message=f"Department '{data['name']}' already exists in this organization", + status=409, + error_type="CONFLICT", + ) + + # Create department + dept = Department( + organization_id=org_id, + name=data["name"], + description=data.get("description"), + ) + db.session.add(dept) + db.session.commit() + + return api_response( + data={"department": dept.to_dict()}, + message="Department created successfully", + status=201, + ) + + except ValidationError as e: + return api_response( + success=False, + message="Validation failed", + status=400, + error_type="VALIDATION_ERROR", + error_details=e.messages, + ) + + +@api_v1_bp.route("/organizations//departments/", methods=["GET"]) +@login_required +@full_access_required +def get_department(org_id, dept_id): + """ + Get a specific department. + + Args: + org_id: Organization ID + dept_id: Department ID + + Returns: + 200: Department data + 401: Not authenticated + 403: Not a member + 404: Organization or department not found + """ + org = OrganizationService.get_organization_by_id(org_id) + + if not org.is_member(g.current_user.id): + return api_response( + success=False, + message="You are not a member of this organization", + status=403, + error_type="AUTHORIZATION_ERROR", + ) + + dept = Department.query.filter_by( + id=dept_id, + organization_id=org_id, + deleted_at=None + ).first() + + if not dept: + return api_response( + success=False, + message="Department not found", + status=404, + error_type="NOT_FOUND", + ) + + return api_response( + data={"department": dept.to_dict()}, + message="Department retrieved successfully", + ) + + +@api_v1_bp.route("/organizations//departments/", methods=["PATCH"]) +@login_required +@require_admin +@full_access_required +def update_department(org_id, dept_id): + """ + Update a department. + + Args: + org_id: Organization ID + dept_id: Department ID + + Request body: + name: Optional new name + description: Optional new description + + Returns: + 200: Department updated successfully + 400: Validation error + 401: Not authenticated + 403: Not an admin + 404: Organization or department not found + 409: Name already exists + """ + try: + org = OrganizationService.get_organization_by_id(org_id) + + dept = Department.query.filter_by( + id=dept_id, + organization_id=org_id, + deleted_at=None + ).first() + + if not dept: + return api_response( + success=False, + message="Department not found", + status=404, + error_type="NOT_FOUND", + ) + + schema = DepartmentUpdateSchema() + data = schema.load(request.json or {}) + + # Check if new name already exists + if "name" in data and data["name"] != dept.name: + existing = Department.query.filter_by( + organization_id=org_id, + name=data["name"], + deleted_at=None + ).first() + if existing: + return api_response( + success=False, + message=f"Department '{data['name']}' already exists", + status=409, + error_type="CONFLICT", + ) + + # Update fields + for key, value in data.items(): + setattr(dept, key, value) + + db.session.commit() + + return api_response( + data={"department": dept.to_dict()}, + message="Department updated successfully", + ) + + except ValidationError as e: + return api_response( + success=False, + message="Validation failed", + status=400, + error_type="VALIDATION_ERROR", + error_details=e.messages, + ) + + +@api_v1_bp.route("/organizations//departments/", methods=["DELETE"]) +@login_required +@require_admin +@full_access_required +def delete_department(org_id, dept_id): + """ + Delete a department (soft delete). + + Args: + org_id: Organization ID + dept_id: Department ID + + Returns: + 200: Department deleted successfully + 401: Not authenticated + 403: Not an admin + 404: Organization or department not found + """ + org = OrganizationService.get_organization_by_id(org_id) + + dept = Department.query.filter_by( + id=dept_id, + organization_id=org_id, + deleted_at=None + ).first() + + if not dept: + return api_response( + success=False, + message="Department not found", + status=404, + error_type="NOT_FOUND", + ) + + # Soft delete + dept.deleted_at = db.func.now() + db.session.commit() + + return api_response( + message="Department deleted successfully", + ) + + +@api_v1_bp.route("/organizations//departments//members", methods=["GET"]) +@login_required +@full_access_required +def get_department_members(org_id, dept_id): + """ + Get all members of a department. + + Args: + org_id: Organization ID + dept_id: Department ID + + Returns: + 200: List of members + 401: Not authenticated + 403: Not a member + 404: Organization or department not found + """ + org = OrganizationService.get_organization_by_id(org_id) + + if not org.is_member(g.current_user.id): + return api_response( + success=False, + message="You are not a member of this organization", + status=403, + error_type="AUTHORIZATION_ERROR", + ) + + dept = Department.query.filter_by( + id=dept_id, + organization_id=org_id, + deleted_at=None + ).first() + + if not dept: + return api_response( + success=False, + message="Department not found", + status=404, + error_type="NOT_FOUND", + ) + + members = DepartmentMembership.query.filter_by( + department_id=dept_id, + deleted_at=None + ).all() + + members_data = [] + for member in members: + member_dict = member.to_dict() + member_dict["user"] = member.user.to_dict() + members_data.append(member_dict) + + return api_response( + data={ + "members": members_data, + "count": len(members_data), + }, + message="Members retrieved successfully", + ) + + +@api_v1_bp.route("/organizations//departments//members", methods=["POST"]) +@login_required +@require_admin +@full_access_required +def add_department_member(org_id, dept_id): + """ + Add a member to a department. + + Args: + org_id: Organization ID + dept_id: Department ID + + Request body: + email: User email to add + + Returns: + 201: Member added successfully + 400: Validation error + 401: Not authenticated + 403: Not an admin + 404: Organization, department, or user not found + 409: User already a member + """ + try: + org = OrganizationService.get_organization_by_id(org_id) + + dept = Department.query.filter_by( + id=dept_id, + organization_id=org_id, + deleted_at=None + ).first() + + if not dept: + return api_response( + success=False, + message="Department not found", + status=404, + error_type="NOT_FOUND", + ) + + schema = AddDepartmentMemberSchema() + data = schema.load(request.json or {}) + + # Find user by email + user = UserService.get_user_by_email(data["email"]) + if not user: + return api_response( + success=False, + message="User not found", + status=404, + error_type="NOT_FOUND", + ) + + # Check if already an active member + existing = DepartmentMembership.query.filter_by( + user_id=user.id, + department_id=dept_id, + deleted_at=None + ).first() + + if existing: + return api_response( + success=False, + message="User is already a member of this department", + status=409, + error_type="CONFLICT", + ) + + # Check for a previously soft-deleted row and resurrect it instead of inserting + soft_deleted = DepartmentMembership.query.filter( + DepartmentMembership.user_id == user.id, + DepartmentMembership.department_id == dept_id, + DepartmentMembership.deleted_at.isnot(None) + ).first() + + if soft_deleted: + soft_deleted.deleted_at = None + membership = soft_deleted + else: + membership = DepartmentMembership( + user_id=user.id, + department_id=dept_id, + ) + db.session.add(membership) + + db.session.commit() + + member_dict = membership.to_dict() + member_dict["user"] = user.to_dict() + + return api_response( + data={"member": member_dict}, + message="Member added successfully", + status=201, + ) + + except ValidationError as e: + return api_response( + success=False, + message="Validation failed", + status=400, + error_type="VALIDATION_ERROR", + error_details=e.messages, + ) + + +@api_v1_bp.route("/organizations//departments//members/", methods=["DELETE"]) +@login_required +@require_admin +@full_access_required +def remove_department_member(org_id, dept_id, user_id): + """ + Remove a member from a department. + + Args: + org_id: Organization ID + dept_id: Department ID + user_id: User ID to remove + + Returns: + 200: Member removed successfully + 401: Not authenticated + 403: Not an admin + 404: Organization, department, or member not found + """ + org = OrganizationService.get_organization_by_id(org_id) + + dept = Department.query.filter_by( + id=dept_id, + organization_id=org_id, + deleted_at=None + ).first() + + if not dept: + return api_response( + success=False, + message="Department not found", + status=404, + error_type="NOT_FOUND", + ) + + membership = DepartmentMembership.query.filter_by( + user_id=user_id, + department_id=dept_id, + deleted_at=None + ).first() + + if not membership: + return api_response( + success=False, + message="User is not a member of this department", + status=404, + error_type="NOT_FOUND", + ) + + # Soft delete + membership.deleted_at = db.func.now() + db.session.commit() + + return api_response( + message="Member removed successfully", + ) + + +@api_v1_bp.route("/organizations//departments//principals", methods=["GET"]) +@login_required +@full_access_required +def get_department_principals(org_id, dept_id): + """Get all principals linked to a department.""" + org = OrganizationService.get_organization_by_id(org_id) + + if not org.is_member(g.current_user.id): + return api_response( + success=False, + message="You are not a member of this organization", + status=403, + error_type="AUTHORIZATION_ERROR", + ) + + dept = Department.query.filter_by( + id=dept_id, + organization_id=org_id, + deleted_at=None + ).first() + + if not dept: + return api_response( + success=False, + message="Department not found", + status=404, + error_type="NOT_FOUND", + ) + + principals = dept.get_principals(active_only=True) + + return api_response( + data={ + "principals": [p.to_dict() for p in principals], + "count": len(principals), + }, + message="Principals retrieved successfully", + ) + + +# --------------------------------------------------------------------------- +# Department Certificate Policy +# --------------------------------------------------------------------------- + +@api_v1_bp.route("/organizations//departments//cert-policy", methods=["GET"]) +@login_required +@require_admin +@full_access_required +def get_dept_cert_policy(org_id, dept_id): + """Get the certificate issuance policy for a department (admin only).""" + from gatehouse_app.models.organization.department_cert_policy import DepartmentCertPolicy, STANDARD_EXTENSIONS + + dept = Department.query.filter_by( + id=dept_id, organization_id=org_id, deleted_at=None + ).first() + if not dept: + return api_response(success=False, message="Department not found", status=404, error_type="NOT_FOUND") + + policy = DepartmentCertPolicy.query.filter( + DepartmentCertPolicy.department_id == dept_id, + DepartmentCertPolicy.deleted_at.is_(None), + ).first() + + if policy: + data = policy.to_dict() + else: + # Return default (all standard extensions, no user expiry choice) + data = { + "department_id": str(dept_id), + "allow_user_expiry": False, + "default_expiry_hours": 1, + "max_expiry_hours": 24, + "allowed_extensions": list(STANDARD_EXTENSIONS), + "custom_extensions": [], + "all_extensions": list(STANDARD_EXTENSIONS), + "standard_extensions": list(STANDARD_EXTENSIONS), + } + + return api_response(data={"cert_policy": data}, message="Certificate policy retrieved") + + +@api_v1_bp.route("/organizations//departments//cert-policy", methods=["PUT"]) +@login_required +@require_admin +@full_access_required +def set_dept_cert_policy(org_id, dept_id): + """Create or update the certificate issuance policy for a department (admin only).""" + from gatehouse_app.models.organization.department_cert_policy import DepartmentCertPolicy, STANDARD_EXTENSIONS + + dept = Department.query.filter_by( + id=dept_id, organization_id=org_id, deleted_at=None + ).first() + if not dept: + return api_response(success=False, message="Department not found", status=404, error_type="NOT_FOUND") + + body = request.get_json() or {} + + # Validate expiry values + default_expiry = body.get("default_expiry_hours") + max_expiry = body.get("max_expiry_hours") + if default_expiry is not None: + try: + default_expiry = int(default_expiry) + if default_expiry < 1: + raise ValueError + except (ValueError, TypeError): + return api_response(success=False, message="default_expiry_hours must be a positive integer", status=400, error_type="VALIDATION_ERROR") + if max_expiry is not None: + try: + max_expiry = int(max_expiry) + if max_expiry < 1: + raise ValueError + except (ValueError, TypeError): + return api_response(success=False, message="max_expiry_hours must be a positive integer", status=400, error_type="VALIDATION_ERROR") + if default_expiry and max_expiry and default_expiry > max_expiry: + return api_response(success=False, message="default_expiry_hours cannot exceed max_expiry_hours", status=400, error_type="VALIDATION_ERROR") + + # Validate allowed_extensions — must be subset of STANDARD_EXTENSIONS + allowed_extensions = body.get("allowed_extensions") + if allowed_extensions is not None: + if not isinstance(allowed_extensions, list): + return api_response(success=False, message="allowed_extensions must be a list", status=400, error_type="VALIDATION_ERROR") + invalid_ext = [e for e in allowed_extensions if e not in STANDARD_EXTENSIONS] + if invalid_ext: + return api_response( + success=False, + message=f"Invalid standard extensions: {', '.join(invalid_ext)}. Valid: {', '.join(STANDARD_EXTENSIONS)}", + status=400, + error_type="VALIDATION_ERROR", + ) + + # Validate custom_extensions — plain strings + custom_extensions = body.get("custom_extensions") + if custom_extensions is not None: + if not isinstance(custom_extensions, list) or not all(isinstance(e, str) for e in custom_extensions): + return api_response(success=False, message="custom_extensions must be a list of strings", status=400, error_type="VALIDATION_ERROR") + + policy = DepartmentCertPolicy.query.filter( + DepartmentCertPolicy.department_id == dept_id, + DepartmentCertPolicy.deleted_at.is_(None), + ).first() + + if policy is None: + policy = DepartmentCertPolicy(department_id=dept_id) + db.session.add(policy) + + if "allow_user_expiry" in body: + policy.allow_user_expiry = bool(body["allow_user_expiry"]) + if default_expiry is not None: + policy.default_expiry_hours = default_expiry + if max_expiry is not None: + policy.max_expiry_hours = max_expiry + if allowed_extensions is not None: + policy.allowed_extensions = list(allowed_extensions) + flag_modified(policy, "allowed_extensions") + if custom_extensions is not None: + policy.custom_extensions = list(custom_extensions) + flag_modified(policy, "custom_extensions") + + db.session.commit() + + return api_response(data={"cert_policy": policy.to_dict()}, message="Certificate policy saved") + diff --git a/gatehouse_app/api/v1/external_auth.py b/gatehouse_app/api/v1/external_auth.py index bb2969d..a23101c 100644 --- a/gatehouse_app/api/v1/external_auth.py +++ b/gatehouse_app/api/v1/external_auth.py @@ -46,6 +46,33 @@ def _pop_oidc_bridge(oauth_state: str) -> str | None: pass return None + +def _store_cli_redirect(oauth_state: str, redirect_url: str) -> None: + """Store CLI redirect_url keyed by OAuth state (for /token_please flow).""" + try: + import gatehouse_app.extensions as _ext + rc = _ext.redis_client + if rc is not None: + rc.setex(f"oauth_cli_redirect:{oauth_state}", _OAUTH_BRIDGE_TTL, redirect_url) + except Exception: + pass + + +def _pop_cli_redirect(oauth_state: str) -> str | None: + """Retrieve and delete CLI redirect_url for the given OAuth state.""" + try: + import gatehouse_app.extensions as _ext + rc = _ext.redis_client + if rc is not None: + key = f"oauth_cli_redirect:{oauth_state}" + val = rc.get(key) + if val: + rc.delete(key) + return val.decode() if isinstance(val, bytes) else val + except Exception: + pass + return None + logger = logging.getLogger(__name__) @@ -69,6 +96,137 @@ def get_provider_type(provider: str) -> AuthMethodType: return PROVIDER_TYPE_MAP[provider_lower] +@api_v1_bp.route("/token_please", methods=["GET"]) +def token_please(): + """ + CLI token acquisition endpoint. + + Redirects the user's browser to the Gatehouse login page so they can + authenticate using any method (password, OAuth, passkey, TOTP, etc.). + On successful login the frontend delivers the session token directly to + the CLI's local callback server. + + This endpoint is designed for CLI clients that: + 1. Start a local HTTP server on LISTENER_SERVER_PORT (e.g. 8250) + 2. Open a browser to /api/v1/token_please?redirect_url=http://127.0.0.1:8250/?token= + 3. Wait for the browser to deliver the token to their local server + + Query parameters: + redirect_url: Local callback URL where the token will be appended + """ + import secrets + from urllib.parse import urlencode, quote + from flask import current_app, redirect as flask_redirect + + redirect_url = request.args.get("redirect_url", "").strip() + + if not redirect_url: + return api_response( + success=False, + message="redirect_url query parameter is required", + status=400, + error_type="MISSING_REDIRECT_URL", + ) + + # Validate redirect_url is localhost/127.0.0.1 (security: prevent open redirect) + from urllib.parse import urlparse as _urlparse + parsed = _urlparse(redirect_url) + if parsed.hostname not in ("localhost", "127.0.0.1"): + return api_response( + success=False, + message="redirect_url must point to localhost", + status=400, + error_type="INVALID_REDIRECT_URL", + ) + + # Store the CLI redirect URL in Redis keyed by a short-lived token so the + # frontend can retrieve it after login without it being visible in the URL. + cli_token = secrets.token_urlsafe(32) + try: + import gatehouse_app.extensions as _ext + rc = _ext.redis_client + if rc is not None: + rc.setex(f"cli_redirect:{cli_token}", _OAUTH_BRIDGE_TTL, redirect_url) + else: + logger.warning("Redis not available; passing cli_redirect directly in URL") + cli_token = None + except Exception: + cli_token = None + + frontend_url = current_app.config.get("FRONTEND_URL", "http://localhost:8080") + + if cli_token: + # Pass an opaque token; the frontend exchanges it for the real URL via + # GET /api/v1/cli/redirect-url?token= + login_url = f"{frontend_url}/login?cli_token={cli_token}" + else: + # Fallback: put the redirect URL directly (still localhost-only, validated above) + login_url = f"{frontend_url}/login?cli_redirect={quote(redirect_url, safe='')}" + + logger.info(f"CLI token_please: redirecting browser to Gatehouse login page") + return flask_redirect(login_url, code=302) + + +@api_v1_bp.route("/cli/redirect-url", methods=["GET"]) +def cli_redirect_url_lookup(): + """ + Exchange a short-lived cli_token for the CLI's local redirect URL. + + Called by the frontend LoginPage after it detects the cli_token query + param so it can obtain the actual CLI callback URL from Redis without + exposing it in the browser URL bar. + + Query parameters: + token: The cli_token issued by /token_please + + Returns: + 200: { "redirect_url": "http://127.0.0.1:8250/?token=" } + 400: Missing token + 404: Token not found or expired + """ + cli_token = request.args.get("token", "").strip() + if not cli_token: + return api_response( + success=False, + message="token query parameter is required", + status=400, + error_type="MISSING_TOKEN", + ) + + try: + import gatehouse_app.extensions as _ext + rc = _ext.redis_client + if rc is not None: + key = f"cli_redirect:{cli_token}" + val = rc.get(key) + if val is None: + return api_response( + success=False, + message="CLI token not found or expired", + status=404, + error_type="TOKEN_NOT_FOUND", + ) + # Keep the key alive until the login actually completes (consume on use + # would break multi-step auth like TOTP), so we leave it as-is. + redirect_url = val.decode() if isinstance(val, bytes) else val + return api_response(data={"redirect_url": redirect_url}) + except Exception as e: + logger.error(f"cli_redirect_url_lookup error: {e}") + return api_response( + success=False, + message="Internal error looking up CLI token", + status=500, + error_type="INTERNAL_ERROR", + ) + + return api_response( + success=False, + message="Redis not available", + status=503, + error_type="SERVICE_UNAVAILABLE", + ) + + # ============================================================================= # Provider Configuration Endpoints (Admin) # ============================================================================= @@ -83,7 +241,7 @@ def list_providers(): 200: List of providers with their configuration status 401: Not authenticated """ - from gatehouse_app.models.authentication_method import ApplicationProviderConfig + from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig from gatehouse_app.services.external_auth_service import ExternalProviderConfig # Check app-level provider configs (ApplicationProviderConfig) @@ -575,8 +733,6 @@ def initiate_oauth_authorize(provider: str): "state": "state_token" } """ - provider_type = get_provider_type(provider) - # Get query parameters - organization_id is now optional flow = request.args.get("flow", "login") redirect_uri = request.args.get("redirect_uri") @@ -592,7 +748,7 @@ def initiate_oauth_authorize(provider: str): ) try: - # Initiate flow - organization_id is now optional + provider_type = get_provider_type(provider) if flow == "login": auth_url, state = OAuthFlowService.initiate_login_flow( provider_type=provider_type, @@ -626,6 +782,13 @@ def initiate_oauth_authorize(provider: str): status=e.status_code, error_type=e.error_type, ) + except ExternalAuthError as e: + return api_response( + success=False, + message=e.message, + status=e.status_code, + error_type=e.error_type, + ) @api_v1_bp.route("/auth/external//callback", methods=["GET"]) @@ -666,8 +829,19 @@ def handle_oauth_callback(provider: str): frontend_url = current_app.config.get("FRONTEND_URL", "http://localhost:8080") frontend_callback = f"{frontend_url}/oauth/callback" + # Check if this is a CLI /token_please flow — retrieve stored redirect_url + cli_redirect_url = _pop_cli_redirect(state) if state else None + def redirect_error(message: str, error_type: str = "OAUTH_ERROR"): - """Redirect to frontend with error params.""" + """Redirect to frontend (or CLI) with error params.""" + if cli_redirect_url: + # CLI flow: return a plain error page instead of redirecting back + from flask import make_response + return make_response( + f"

Authentication Error

{message}

" + f"

You may close this window.

", + 400, + ) params = {"error": message, "error_type": error_type} if state: params["state"] = state @@ -706,8 +880,11 @@ def handle_oauth_callback(provider: str): # Recover oidc_session_id if this was triggered from an OIDC bridge flow oidc_session_id = _pop_oidc_bridge(state) + # Organization selection / creation flows are not supported in CLI mode + # (fall through to token redirect with whatever session we have) + # Organization selection needed (user belongs to multiple orgs) - if result.get("requires_org_selection"): + if result.get("requires_org_selection") and not cli_redirect_url: import json orgs = json.dumps(result.get("available_organizations", [])) params = { @@ -722,12 +899,20 @@ def handle_oauth_callback(provider: str): return flask_redirect(f"{frontend_callback}?{urlencode(params)}", code=302) # Organization creation needed (new user via OAuth with no org) - if result.get("requires_org_creation"): + if result.get("requires_org_creation") and not cli_redirect_url: + import json as _json + session_data = result.get("session", {}) + token = session_data.get("token", "") + expires_in = session_data.get("expires_in", 86400) + pending_invites = result.get("pending_invites", []) params = { "requires_org_creation": "1", "state": result["state"], "provider": provider, "flow": flow_type, + "token": token, + "expires_in": str(expires_in), + "pending_invites": _json.dumps(pending_invites), } if oidc_session_id: params["oidc_session_id"] = oidc_session_id @@ -751,6 +936,19 @@ def handle_oauth_callback(provider: str): user_info = result.get("user", {}) if user_info.get("email"): params["email"] = user_info["email"] + + # ── CLI /token_please flow: redirect to the CLI's local callback ───── + if cli_redirect_url: + # The CLI expects: http://127.0.0.1:8250/?token= + # cli_redirect_url already ends with "token=" so just append the value + cli_final_url = cli_redirect_url + token + logger.info( + f"CLI token_please success: provider={provider}, user={user_info.get('email')}, " + f"redirecting to CLI callback" + ) + return flask_redirect(cli_final_url, code=302) + + # ── Frontend flow ───────────────────────────────────────────────────── # Pass oidc_session_id through so the frontend can complete the OIDC flow if oidc_session_id: params["oidc_session_id"] = oidc_session_id @@ -1049,3 +1247,203 @@ def _get_provider_endpoints(provider_type: AuthMethodType): "UNSUPPORTED_PROVIDER", 400, ) + + +# ============================================================================= +# Admin: Application-level OAuth Provider Management +# ============================================================================= + +@api_v1_bp.route("/admin/oauth/providers", methods=["GET"]) +@login_required +def admin_list_app_providers(): + """List all application-level OAuth provider configurations (admin only). + + Returns: + 200: List of providers with client_id and enabled status + 401: Not authenticated + 403: Not an admin + """ + from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig + from gatehouse_app.models import OrganizationMember + from gatehouse_app.utils.constants import OrganizationRole + + # Verify caller is admin in any org + admin_memberships = OrganizationMember.query.filter( + OrganizationMember.user_id == g.current_user.id, + OrganizationMember.role.in_([OrganizationRole.OWNER, OrganizationRole.ADMIN]), + ).all() + + if not admin_memberships: + return api_response( + success=False, + message="Admin access required", + status=403, + error_type="FORBIDDEN", + ) + + PROVIDERS = [ + {"id": "google", "name": "Google"}, + {"id": "github", "name": "GitHub"}, + {"id": "microsoft", "name": "Microsoft"}, + ] + + db_configs = { + c.provider_type: c + for c in ApplicationProviderConfig.query.all() + } + + result = [] + for p in PROVIDERS: + cfg = db_configs.get(p["id"]) + result.append({ + "id": p["id"], + "name": p["name"], + "is_configured": cfg is not None, + "is_enabled": cfg.is_enabled if cfg else False, + "client_id": cfg.client_id if cfg else None, + }) + + return api_response( + data={"providers": result}, + message="OAuth providers retrieved successfully", + ) + + +@api_v1_bp.route("/admin/oauth/providers/", methods=["PUT"]) +@login_required +def admin_configure_app_provider(provider: str): + """Create or update an application-level OAuth provider config (admin only). + + Args: + provider: Provider type (google, github, microsoft) + + Request body: + client_id: OAuth client ID + client_secret: OAuth client secret (optional — omit to keep existing) + is_enabled: Whether the provider is enabled (default: true) + + Returns: + 200: Provider configuration updated + 400: Validation error + 401: Not authenticated + 403: Not an admin + """ + from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig + from gatehouse_app.models import OrganizationMember + from gatehouse_app.utils.constants import OrganizationRole + from gatehouse_app.extensions import db + + SUPPORTED = ["google", "github", "microsoft"] + if provider not in SUPPORTED: + return api_response( + success=False, + message=f"Unsupported provider. Must be one of: {', '.join(SUPPORTED)}", + status=400, + error_type="VALIDATION_ERROR", + ) + + # Verify caller is admin in any org + admin_memberships = OrganizationMember.query.filter( + OrganizationMember.user_id == g.current_user.id, + OrganizationMember.role.in_([OrganizationRole.OWNER, OrganizationRole.ADMIN]), + ).all() + + if not admin_memberships: + return api_response( + success=False, + message="Admin access required", + status=403, + error_type="FORBIDDEN", + ) + + data = request.json or {} + client_id = (data.get("client_id") or "").strip() + client_secret = (data.get("client_secret") or "").strip() + is_enabled = data.get("is_enabled", True) + + if not client_id: + return api_response( + success=False, + message="client_id is required", + status=400, + error_type="VALIDATION_ERROR", + ) + + cfg = ApplicationProviderConfig.query.filter_by(provider_type=provider).first() + if cfg: + cfg.client_id = client_id + if client_secret: + cfg.set_client_secret(client_secret) + cfg.is_enabled = bool(is_enabled) + db.session.commit() + else: + cfg = ApplicationProviderConfig( + provider_type=provider, + client_id=client_id, + is_enabled=bool(is_enabled), + ) + if client_secret: + cfg.set_client_secret(client_secret) + db.session.add(cfg) + db.session.commit() + + return api_response( + data={ + "provider": { + "id": provider, + "client_id": cfg.client_id, + "is_enabled": cfg.is_enabled, + } + }, + message=f"{provider.capitalize()} OAuth provider configured successfully", + ) + + +@api_v1_bp.route("/admin/oauth/providers/", methods=["DELETE"]) +@login_required +def admin_delete_app_provider(provider: str): + """Delete an application-level OAuth provider config (admin only). + + Args: + provider: Provider type (google, github, microsoft) + + Returns: + 200: Provider configuration deleted + 404: Provider not found + 401: Not authenticated + 403: Not an admin + """ + from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig + from gatehouse_app.models import OrganizationMember + from gatehouse_app.utils.constants import OrganizationRole + from gatehouse_app.extensions import db + + # Verify caller is admin in any org + admin_memberships = OrganizationMember.query.filter( + OrganizationMember.user_id == g.current_user.id, + OrganizationMember.role.in_([OrganizationRole.OWNER, OrganizationRole.ADMIN]), + ).all() + + if not admin_memberships: + return api_response( + success=False, + message="Admin access required", + status=403, + error_type="FORBIDDEN", + ) + + cfg = ApplicationProviderConfig.query.filter_by(provider_type=provider).first() + if not cfg: + return api_response( + success=False, + message=f"Provider '{provider}' is not configured", + status=404, + error_type="NOT_FOUND", + ) + + db.session.delete(cfg) + db.session.commit() + + return api_response( + message=f"{provider.capitalize()} OAuth provider configuration removed", + ) diff --git a/gatehouse_app/api/v1/organizations.py b/gatehouse_app/api/v1/organizations.py index 98c39e1..7d8f4d3 100644 --- a/gatehouse_app/api/v1/organizations.py +++ b/gatehouse_app/api/v1/organizations.py @@ -1,5 +1,5 @@ """Organization endpoints.""" -from flask import g, request +from flask import g, request, current_app from marshmallow import ValidationError from gatehouse_app.api.v1 import api_v1_bp from gatehouse_app.utils.response import api_response @@ -13,6 +13,72 @@ from gatehouse_app.schemas.organization_schema import ( from gatehouse_app.services.organization_service import OrganizationService from gatehouse_app.services.user_service import UserService from gatehouse_app.utils.constants import OrganizationRole +from gatehouse_app.extensions import db + + + +def _get_system_ca_dict(): + """Return a synthetic read-only CA dict for the config-file CA, or None. + + This is injected into the org CA list when no DB CA exists for a given + ca_type so that the admin UI correctly shows "configured" rather than + "Not configured" when a system-level CA key is present. + + The returned dict has ``is_system=True`` so the frontend can render it + as read-only (no delete / edit / generate buttons). + """ + import os + try: + from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config + from gatehouse_app.utils.crypto import compute_ssh_fingerprint + + # Check env var first (takes priority over file path) + priv_key = os.environ.get("SSH_CA_PRIVATE_KEY", "").strip() + pub_key = "" + + if not priv_key: + cfg = get_ssh_ca_config() + key_path = cfg.get_str("ca_key_path", "").strip() + if not key_path: + return None + pub_path = key_path + ".pub" + if not os.path.exists(pub_path): + return None + with open(pub_path) as f: + pub_key = f.read().strip() + else: + # Derive the public key from the private key + from sshkey_tools.keys import PrivateKey + pk = PrivateKey.from_string(priv_key) + pub_key = pk.public_key.to_string() + + fingerprint = compute_ssh_fingerprint(pub_key) + return { + "id": f"system-ca-{fingerprint[:16]}", + "organization_id": None, + "name": "System CA (config file)", + "description": ( + "Read-only — this CA is loaded from the server's SSH_CA_PRIVATE_KEY " + "environment variable or etc/ssh_ca.conf. Manage it on the server." + ), + # ca_type is set by the caller + "ca_type": "user", + "key_type": "unknown", + "public_key": pub_key, + "fingerprint": fingerprint, + "is_active": True, + "is_system": True, + "default_cert_validity_hours": 0, + "max_cert_validity_hours": 0, + "total_certs": 0, + "active_certs": 0, + "revoked_certs": 0, + "created_at": None, + "updated_at": None, + } + except Exception: + return None + @api_v1_bp.route("/organizations", methods=["POST"]) @@ -160,6 +226,10 @@ def delete_organization(org_id): """ Delete organization (soft delete). + The owner may only delete the organization if they are the *sole* remaining + member. If other active members exist they must first transfer ownership + (or remove all other members) before deleting the organization. + Args: org_id: Organization ID @@ -168,9 +238,26 @@ def delete_organization(org_id): 401: Not authenticated 403: Not the owner 404: Organization not found + 409: Organization still has other members — transfer ownership first """ org = OrganizationService.get_organization_by_id(org_id) + # Guard: block deletion while non-owner members still exist so ownership + # can be transferred rather than silently orphaning them. + active_member_count = org.get_member_count() + if active_member_count > 1: + return api_response( + success=False, + message=( + "This organization still has other members. " + "Please transfer ownership to another member or remove all " + "other members before deleting the organization." + ), + status=409, + error_type="ORG_HAS_MEMBERS", + error_details={"member_count": active_member_count}, + ) + OrganizationService.delete_organization( org=org, user_id=g.current_user.id, @@ -378,3 +465,1399 @@ def update_member_role(org_id, user_id): error_type="VALIDATION_ERROR", error_details=e.messages, ) + + +@api_v1_bp.route("/organizations//transfer-ownership", methods=["POST"]) +@login_required +@full_access_required +def transfer_organization_ownership(org_id): + """Transfer organization ownership from the current user to another member. + + Only the current OWNER of the organization may call this endpoint. + The caller will be demoted to ADMIN and the target user will be promoted to OWNER. + + Request body: + new_owner_user_id (str): UUID of the member to promote to OWNER. + + Returns: + 200: Ownership transferred successfully + 400: Validation error / missing fields + 403: Caller is not the OWNER of this org + 404: Organization or target member not found + 409: Target is already the OWNER + """ + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.utils.constants import OrganizationRole, AuditAction + from gatehouse_app.services.audit_service import AuditService + + caller = g.current_user + + data = request.get_json() or {} + new_owner_user_id = data.get("new_owner_user_id") + if not new_owner_user_id: + return api_response( + success=False, + message="new_owner_user_id is required", + status=400, + error_type="VALIDATION_ERROR", + ) + + if str(new_owner_user_id) == str(caller.id): + return api_response( + success=False, + message="You are already the owner of this organization.", + status=409, + error_type="CONFLICT", + ) + + # Fetch org (raises NotFound internally) + org = OrganizationService.get_organization_by_id(org_id) + + # Confirm caller is the current OWNER + caller_membership = OrganizationMember.query.filter_by( + organization_id=org.id, + user_id=caller.id, + deleted_at=None, + ).first() + if not caller_membership or caller_membership.role != OrganizationRole.OWNER: + return api_response( + success=False, + message="Only the organization owner can transfer ownership.", + status=403, + error_type="AUTHORIZATION_ERROR", + ) + + # Verify the target is an active member + target_membership = OrganizationMember.query.filter_by( + organization_id=org.id, + user_id=new_owner_user_id, + deleted_at=None, + ).first() + if not target_membership: + return api_response( + success=False, + message="Target user is not a member of this organization.", + status=404, + error_type="NOT_FOUND", + ) + + if target_membership.role == OrganizationRole.OWNER: + return api_response( + success=False, + message="Target user is already the owner.", + status=409, + error_type="CONFLICT", + ) + + # ── Atomic role swap ───────────────────────────────────────────────────── + # Demote caller → ADMIN, promote target → OWNER. + # Both updates go through OrganizationService so all hooks/auditing fire. + try: + demoted = OrganizationService.update_member_role( + org=org, + user_id=str(caller.id), + new_role=OrganizationRole.ADMIN, + updater_id=str(caller.id), + ) + promoted = OrganizationService.update_member_role( + org=org, + user_id=str(new_owner_user_id), + new_role=OrganizationRole.OWNER, + updater_id=str(caller.id), + ) + except Exception as exc: + from gatehouse_app.extensions import db as _db + _db.session.rollback() + return api_response( + success=False, + message=f"Failed to transfer ownership: {exc}", + status=500, + error_type="SERVER_ERROR", + ) + + AuditService.log_action( + action=AuditAction.ORG_OWNERSHIP_TRANSFERRED, + user_id=caller.id, + organization_id=org.id, + resource_type="organization", + resource_id=str(org.id), + description=( + f"Ownership of '{org.name}' transferred from {caller.email} " + f"to {target_membership.user.email if target_membership.user else new_owner_user_id}" + ), + metadata={ + "previous_owner_id": str(caller.id), + "previous_owner_email": caller.email, + "new_owner_id": str(new_owner_user_id), + "new_owner_email": ( + target_membership.user.email if target_membership.user else None + ), + }, + ) + + def _member_dict(m): + d = m.to_dict() + if m.user: + d["user"] = m.user.to_dict() + return d + + return api_response( + data={ + "previous_owner": _member_dict(demoted), + "new_owner": _member_dict(promoted), + }, + message=( + f"Ownership of '{org.name}' successfully transferred to " + f"{target_membership.user.email if target_membership.user else new_owner_user_id}." + ), + ) + + +@api_v1_bp.route("/organizations//audit-logs", methods=["GET"]) +@login_required +@require_admin +@full_access_required +def get_organization_audit_logs(org_id): + """ + Get audit logs for an organization. + + Query params: + page: Page number (default 1) + per_page: Results per page (default 50, max 200) + action: Filter by action type + + Returns: + 200: List of audit log entries + 401: Not authenticated + 403: Not a member / insufficient permissions + 404: Organization not found + """ + from gatehouse_app.models.auth.audit_log import AuditLog + + # Ensure org exists and user is a member (full_access_required handles this) + OrganizationService.get_organization_by_id(org_id) + + page = int(request.args.get("page", 1)) + per_page = min(int(request.args.get("per_page", 50)), 200) + action_filter = request.args.get("action") + + query = AuditLog.query.filter_by(organization_id=org_id) + if action_filter: + query = query.filter_by(action=action_filter) + + query = query.order_by(AuditLog.created_at.desc()) + total = query.count() + logs = query.offset((page - 1) * per_page).limit(per_page).all() + + def log_to_dict(log): + return { + "id": log.id, + "action": log.action.value if log.action else None, + "user_id": log.user_id, + "user_email": log.user.email if log.user else None, + "user": {"id": log.user.id, "email": log.user.email, "full_name": log.user.full_name} if log.user else None, + "organization_id": log.organization_id, + "resource_type": log.resource_type, + "resource_id": log.resource_id, + "ip_address": log.ip_address, + "user_agent": log.user_agent, + "request_id": log.request_id, + "description": log.description, + "success": log.success, + "error_message": log.error_message, + "metadata": log.extra_data, + "created_at": log.created_at.isoformat() if log.created_at else None, + "updated_at": log.updated_at.isoformat() if log.updated_at else None, + } + + return api_response( + data={ + "audit_logs": [log_to_dict(log) for log in logs], + "count": total, + "page": page, + "per_page": per_page, + "pages": (total + per_page - 1) // per_page, + }, + message="Audit logs retrieved successfully", + ) + + +# ============================================================================ +# Organization Invite Tokens +# ============================================================================ + +@api_v1_bp.route("/organizations//invites", methods=["POST"]) +@login_required +@require_admin +def create_org_invite(org_id): + """Create an invite token for an organization. + + Request body: + email: Email address to invite + role: Role to assign (default: member) + + Returns: + 201: Invite created + 400: Validation error + 403: Not an admin + 404: Organization not found + """ + from gatehouse_app.models import OrgInviteToken, Organization + from gatehouse_app.services.notification_service import NotificationService + from flask import current_app + + org = Organization.query.filter_by(id=org_id, deleted_at=None).first() + if not org: + return api_response(success=False, message="Organization not found", status=404) + + data = request.get_json() or {} + email = (data.get("email") or "").strip().lower() + role = (data.get("role") or "member").strip() + + if not email: + return api_response(success=False, message="Email is required", status=400, error_type="VALIDATION_ERROR") + + invite = OrgInviteToken.generate( + organization_id=org_id, + email=email, + role=role, + invited_by_id=g.current_user.id, + ) + + app_url = current_app.config.get("APP_URL", "http://localhost:8080") + invite_link = f"{app_url}/invite?token={invite.token}" + + email_sent = NotificationService._send_email( + to_address=email, + subject=f"You're invited to join {org.name} on Gatehouse", + body=( + f"You've been invited to join {org.name} on Gatehouse.\n\n" + f"Click the link below to accept the invitation (valid for 7 days):\n" + f"{invite_link}\n\n" + f"Gatehouse Security Team" + ), + ) + + # In dev mode email may not be configured — always log the link so it's findable + import logging + if not email_sent: + logging.getLogger(__name__).warning( + f"[INVITE LINK] Email not sent (EMAIL_ENABLED=False or SMTP down). " + f"Invite for {email} → {invite_link}" + ) + else: + logging.getLogger(__name__).info( + f"[INVITE] Email sent successfully to {email}" + ) + + response_data = { + "invite": { + "id": invite.id, + "email": invite.email, + "role": invite.role, + "expires_at": invite.expires_at.isoformat() + "Z", + # Only include invite_link when email delivery failed — signals frontend to show copy dialog + **({"invite_link": invite_link} if not email_sent else {}), + } + } + + return api_response( + data=response_data, + message="Invite sent successfully", + status=201, + ) + + +@api_v1_bp.route("/organizations//invites", methods=["GET"]) +@login_required +@require_admin +def list_org_invites(org_id): + """List pending invite tokens for an organization. + + Returns: + 200: List of invites + 403: Not an admin + 404: Organization not found + """ + from gatehouse_app.models import OrgInviteToken, Organization + + org = Organization.query.filter_by(id=org_id, deleted_at=None).first() + if not org: + return api_response(success=False, message="Organization not found", status=404) + + invites = ( + OrgInviteToken.query.filter_by(organization_id=org_id) + .filter(OrgInviteToken.accepted_at == None) + .filter(OrgInviteToken.deleted_at == None) + .all() + ) + + def invite_to_dict(inv): + return { + "id": inv.id, + "email": inv.email, + "role": inv.role, + "invited_by_id": inv.invited_by_id, + "created_at": inv.created_at.isoformat() + "Z", + "expires_at": inv.expires_at.isoformat() + "Z", + } + + return api_response( + data={"invites": [invite_to_dict(i) for i in invites]}, + message="Invites retrieved", + ) + + +@api_v1_bp.route("/organizations//invites/", methods=["DELETE"]) +@login_required +@require_admin +def cancel_org_invite(org_id, invite_id): + """Cancel (soft-delete) an organization invite. + + Returns: + 200: Invite cancelled + 403: Not an admin + 404: Invite not found + """ + from gatehouse_app.models import OrgInviteToken, Organization + + org = Organization.query.filter_by(id=org_id, deleted_at=None).first() + if not org: + return api_response(success=False, message="Organization not found", status=404) + + invite = OrgInviteToken.query.filter_by(id=invite_id, organization_id=org_id, deleted_at=None).first() + if not invite: + return api_response(success=False, message="Invite not found", status=404) + + # Soft delete the invite so it's no longer usable + invite.delete(soft=True) + + return api_response(data={}, message="Invite cancelled") + + +@api_v1_bp.route("/invites/", methods=["GET"]) +def get_invite(token): + """Get invite details by token. + + Returns: + 200: Invite details (org name, email) + 400: Invalid or expired token + """ + from gatehouse_app.models import OrgInviteToken, User + + invite = OrgInviteToken.query.filter_by(token=token).first() + if not invite or not invite.is_valid: + return api_response(success=False, message="This invitation link is invalid or has expired.", status=400, error_type="INVALID_TOKEN") + + user_exists = User.query.filter_by(email=invite.email, deleted_at=None).first() is not None + + return api_response( + data={ + "email": invite.email, + "organization": {"id": invite.organization_id, "name": invite.organization.name}, + "role": invite.role, + "user_exists": user_exists, + }, + message="Invite found", + ) + + +@api_v1_bp.route("/invites//accept", methods=["POST"]) +def accept_invite(token): + """Accept an organization invite. + + Creates the user account (if not already registered) and adds them + to the organization. + + Request body: + full_name: User's display name + password: Password for new account (if not already registered) + password_confirm: Password confirmation + + Returns: + 200: Invite accepted, returns user token + 400: Invalid/expired token or validation error + 409: Already a member + """ + from gatehouse_app.models import OrgInviteToken, User + from gatehouse_app.services.auth_service import AuthService + from gatehouse_app.services.organization_service import OrganizationService + from gatehouse_app.utils.constants import OrganizationRole + + invite = OrgInviteToken.query.filter_by(token=token).first() + if not invite or not invite.is_valid: + return api_response(success=False, message="This invitation link is invalid or has expired.", status=400, error_type="INVALID_TOKEN") + + data = request.get_json() or {} + full_name = data.get("full_name") or "" + password = data.get("password") or "" + password_confirm = data.get("password_confirm") or "" + + user = User.query.filter_by(email=invite.email, deleted_at=None).first() + + if not user: + # Register a new user + if not password: + return api_response(success=False, message="Password is required for new accounts.", status=400, error_type="VALIDATION_ERROR") + if password != password_confirm: + return api_response(success=False, message="Passwords do not match.", status=400, error_type="VALIDATION_ERROR") + if len(password) < 8: + return api_response(success=False, message="Password must be at least 8 characters.", status=400, error_type="VALIDATION_ERROR") + try: + user = AuthService.register_user(email=invite.email, password=password, full_name=full_name or None) + except Exception as exc: + return api_response(success=False, message=str(exc), status=400, error_type="REGISTRATION_ERROR") + + # Add to org + role_value = invite.role + try: + org_role = OrganizationRole(role_value) + except ValueError: + org_role = OrganizationRole.MEMBER + + try: + OrganizationService.add_member( + org=invite.organization, + user_id=user.id, + role=org_role, + inviter_id=invite.invited_by_id, + ) + except Exception: + from gatehouse_app.extensions import db + db.session.rollback() # Clear broken transaction so invite.accept() can commit + + invite.accept() + + has_webauthn = user.has_webauthn_enabled() + has_totp = user.has_totp_enabled() + + if has_webauthn: + from flask import session as flask_session + flask_session["webauthn_pending_user_id"] = user.id + return api_response( + data={"requires_webauthn": True}, + message="Passkey verification required. Please use your passkey to complete sign-in.", + ) + + if has_totp: + from flask import session as flask_session + flask_session["totp_pending_user_id"] = user.id + return api_response( + data={"requires_totp": True}, + message="TOTP code required. Please enter your 6-digit code from your authenticator app.", + ) + + user_session = AuthService.create_session(user) + + return api_response( + data={ + "user": user.to_dict(), + "token": user_session.token, + "expires_at": user_session.expires_at.isoformat() + "Z", + }, + message="Invitation accepted. Welcome!", + ) + + +# ============================================================================ +# Organization OIDC Clients +# ============================================================================ + +@api_v1_bp.route("/organizations//clients", methods=["GET"]) +@login_required +@require_admin +@full_access_required +def list_org_clients(org_id): + """List OIDC clients for an organization. + + Returns: + 200: List of OIDC clients + 403: Not an admin + 404: Organization not found + """ + from gatehouse_app.models import OIDCClient, Organization + + org = Organization.query.filter_by(id=org_id, deleted_at=None).first() + if not org: + return api_response(success=False, message="Organization not found", status=404) + + clients = OIDCClient.query.filter_by(organization_id=org_id, is_active=True).all() + + def client_to_dict(c): + return { + "id": c.id, + "name": c.name, + "client_id": c.client_id, + "redirect_uris": c.redirect_uris, + "scopes": c.scopes, + "grant_types": c.grant_types, + "is_active": c.is_active, + "created_at": c.created_at.isoformat() + "Z", + } + + return api_response( + data={"clients": [client_to_dict(c) for c in clients], "count": len(clients)}, + message="Clients retrieved successfully", + ) + + +@api_v1_bp.route("/organizations//clients", methods=["POST"]) +@login_required +@require_admin +def create_org_client(org_id): + """Create a new OIDC client for an organization. + + Request body: + name: Client name + redirect_uris: List of allowed redirect URIs (newline or comma separated string) + + Returns: + 201: Client created with client_id and client_secret + 403: Not an admin + 404: Organization not found + """ + import secrets as _secrets + from gatehouse_app.extensions import bcrypt + from gatehouse_app.models import OIDCClient, Organization + + org = Organization.query.filter_by(id=org_id, deleted_at=None).first() + if not org: + return api_response(success=False, message="Organization not found", status=404) + + data = request.get_json() or {} + name = (data.get("name") or "").strip() + redirect_uris_raw = data.get("redirect_uris") or [] + + if not name: + return api_response(success=False, message="Client name is required", status=400, error_type="VALIDATION_ERROR") + + if isinstance(redirect_uris_raw, str): + redirect_uris = [u.strip() for u in redirect_uris_raw.replace(",", "\n").splitlines() if u.strip()] + else: + redirect_uris = [u.strip() for u in redirect_uris_raw if isinstance(u, str) and u.strip()] + + if not redirect_uris: + return api_response(success=False, message="At least one redirect URI is required", status=400, error_type="VALIDATION_ERROR") + + client_id = _secrets.token_hex(16) + client_secret = _secrets.token_urlsafe(32) + + client = OIDCClient( + organization_id=org_id, + name=name, + client_id=client_id, + client_secret_hash=bcrypt.generate_password_hash(client_secret).decode("utf-8"), + redirect_uris=redirect_uris, + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + scopes=["openid", "profile", "email"], + is_active=True, + is_confidential=True, + ) + from gatehouse_app.extensions import db + db.session.add(client) + db.session.commit() + + return api_response( + data={ + "client": { + "id": client.id, + "name": client.name, + "client_id": client.client_id, + "client_secret": client_secret, # Only returned once + "redirect_uris": client.redirect_uris, + "scopes": client.scopes, + "created_at": client.created_at.isoformat() + "Z", + } + }, + message="OIDC client created successfully", + status=201, + ) + + +@api_v1_bp.route("/organizations//clients/", methods=["DELETE"]) +@login_required +@require_admin +def delete_org_client(org_id, client_id): + """Deactivate an OIDC client. + + Returns: + 200: Client deactivated + 403: Not an admin + 404: Client not found + """ + from gatehouse_app.models import OIDCClient + from gatehouse_app.extensions import db + + client = OIDCClient.query.filter_by(id=client_id, organization_id=org_id).first() + if not client: + return api_response(success=False, message="Client not found", status=404) + + client.is_active = False + db.session.commit() + + return api_response(data={}, message="Client deactivated successfully") + + +@api_v1_bp.route("/organizations//members//send-mfa-reminder", methods=["POST"]) +@login_required +@require_admin +def send_mfa_reminder(org_id, user_id): + """Send an MFA reminder email to a specific member. + + Returns: + 200: Reminder sent (or silently skipped if no deadline record) + 403: Not an admin + 404: Member not found + """ + from gatehouse_app.models import User, MfaPolicyCompliance, OrganizationSecurityPolicy + from gatehouse_app.services.notification_service import NotificationService + + user = User.query.filter_by(id=user_id, deleted_at=None).first() + if not user: + return api_response(success=False, message="User not found", status=404) + + compliance = MfaPolicyCompliance.query.filter_by( + user_id=user_id, organization_id=org_id + ).first() + policy = OrganizationSecurityPolicy.query.filter_by(organization_id=org_id).first() + + if compliance and policy and compliance.deadline_at: + NotificationService.send_mfa_deadline_reminder(user, compliance, policy) + else: + # No compliance deadline — send a generic nudge + NotificationService._send_email( + to_address=user.email, + subject="Reminder: Set up multi-factor authentication", + body=( + f"Hi {user.full_name or user.email},\n\n" + "Your organization administrator has asked you to set up " + "multi-factor authentication (MFA) on your Gatehouse account.\n\n" + "Please log in and configure MFA as soon as possible.\n\n" + "Gatehouse Security Team" + ), + ) + + return api_response(data={}, message="Reminder sent successfully") + + +# ============================================================================= +# System-wide Audit Log (admin view) + User self audit +# ============================================================================= + +def _audit_log_to_dict(log): + """Serialize an AuditLog record to a dict.""" + return { + "id": log.id, + "action": log.action.value if log.action else None, + "user_id": log.user_id, + "user": ( + {"id": log.user.id, "email": log.user.email, "full_name": log.user.full_name} + if log.user else None + ), + "organization_id": log.organization_id, + "resource_type": log.resource_type, + "resource_id": log.resource_id, + "ip_address": log.ip_address, + "user_agent": log.user_agent, + "request_id": log.request_id, + "description": log.description, + "success": log.success, + "error_message": log.error_message, + "metadata": log.extra_data, + "created_at": log.created_at.isoformat() if log.created_at else None, + "updated_at": log.updated_at.isoformat() if log.updated_at else None, + } + + +@api_v1_bp.route("/audit-logs", methods=["GET"]) +@login_required +def get_system_audit_logs(): + """ + Get all audit logs (system-wide). Any authenticated user can query + their own logs; org owners/admins also see org-scoped logs; this + endpoint returns ALL logs for users who own at least one org + (acting as an admin view). + + Query params: + page – page number (default 1) + per_page – results per page (default 50, max 200) + action – filter by AuditAction value + user_id – filter by user id + resource_type – filter by resource type + success – "true"/"false" + q – free-text search on description + """ + from gatehouse_app.models.auth.audit_log import AuditLog + from gatehouse_app.models.organization.organization_member import OrganizationMember + + current_user = g.current_user + page = max(1, int(request.args.get("page", 1))) + per_page = min(int(request.args.get("per_page", 50)), 200) + + # Check if the user is an admin or owner of any org to grant admin-level access + is_admin = OrganizationMember.query.filter( + OrganizationMember.user_id == current_user.id, + OrganizationMember.role.in_(["OWNER", "ADMIN"]), + OrganizationMember.deleted_at == None, + ).first() is not None + + query = AuditLog.query + + if not is_admin: + # Non-admins can only see their own logs + query = query.filter(AuditLog.user_id == current_user.id) + + # Optional filters + action_filter = request.args.get("action") + if action_filter: + query = query.filter(AuditLog.action == action_filter) + + user_id_filter = request.args.get("user_id") + if user_id_filter: + query = query.filter(AuditLog.user_id == user_id_filter) + + resource_type_filter = request.args.get("resource_type") + if resource_type_filter: + query = query.filter(AuditLog.resource_type == resource_type_filter) + + success_filter = request.args.get("success") + if success_filter is not None: + query = query.filter(AuditLog.success == (success_filter.lower() == "true")) + + q = request.args.get("q", "").strip() + if q: + query = query.filter(AuditLog.description.ilike(f"%{q}%")) + + query = query.order_by(AuditLog.created_at.desc()) + total = query.count() + logs = query.offset((page - 1) * per_page).limit(per_page).all() + + return api_response( + data={ + "audit_logs": [_audit_log_to_dict(log) for log in logs], + "count": total, + "page": page, + "per_page": per_page, + "pages": (total + per_page - 1) // per_page, + "is_admin_view": is_admin, + }, + message="Audit logs retrieved", + ) + + +@api_v1_bp.route("/auth/audit-logs", methods=["GET"]) +@login_required +def get_my_audit_logs(): + """ + Get audit logs for the currently authenticated user only. + + Query params: + page – page number (default 1) + per_page – results per page (default 50, max 200) + action – filter by AuditAction value + """ + from gatehouse_app.models.auth.audit_log import AuditLog + + current_user = g.current_user + page = max(1, int(request.args.get("page", 1))) + per_page = min(int(request.args.get("per_page", 50)), 200) + + query = AuditLog.query.filter(AuditLog.user_id == current_user.id) + + action_filter = request.args.get("action") + if action_filter: + query = query.filter(AuditLog.action == action_filter) + + query = query.order_by(AuditLog.created_at.desc()) + total = query.count() + logs = query.offset((page - 1) * per_page).limit(per_page).all() + + return api_response( + data={ + "audit_logs": [_audit_log_to_dict(log) for log in logs], + "count": total, + "page": page, + "per_page": per_page, + "pages": (total + per_page - 1) // per_page, + }, + message="Activity retrieved", + ) + + + +@api_v1_bp.route("/organizations//roles", methods=["GET"]) +@login_required +def list_organization_roles(org_id): + """List the available roles for an organization. + + Returns the canonical set of OrganizationRole values together with every + current member assigned to each role. + + Returns: + 200: roles list with member counts + 401: Not authenticated + 404: Organization not found + """ + from gatehouse_app.models.organization.organization import Organization + from gatehouse_app.models.organization.organization_member import OrganizationMember + + org = Organization.query.filter_by(id=org_id, deleted_at=None).first() + if not org: + return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND") + + # Load all active members grouped by role + members = OrganizationMember.query.filter_by(organization_id=org_id, deleted_at=None).all() + by_role: dict = {r.value: [] for r in OrganizationRole} + for m in members: + role_key = m.role.value if hasattr(m.role, "value") else str(m.role) + if role_key in by_role: + by_role[role_key].append({ + "user_id": m.user_id, + "email": m.user.email if m.user else None, + "full_name": m.user.full_name if m.user else None, + "joined_at": m.created_at.isoformat() if m.created_at else None, + }) + + roles = [ + { + "role": r.value, + "member_count": len(by_role[r.value]), + "members": by_role[r.value], + } + for r in OrganizationRole + ] + return api_response(data={"roles": roles, "organization_id": org_id}, message="Roles retrieved") + + +@api_v1_bp.route("/organizations//roles//members", methods=["POST"]) +@login_required +@require_admin +def assign_role_to_member(org_id, role_name): + """Assign a role to a user in the organization (admin/owner only). + + This is a convenience endpoint equivalent to PATCH + /organizations//members//role but driven by role name. + + Request body: + user_id – UUID of the member to assign + + Returns: + 200: Role assigned + 400: Invalid role / missing user_id + 403: Not an admin/owner + 404: Org or member not found + """ + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.extensions import db + + try: + new_role = OrganizationRole(role_name.lower()) + except ValueError: + valid = [r.value for r in OrganizationRole] + return api_response(success=False, message=f"Invalid role. Must be one of: {valid}", status=400, error_type="VALIDATION_ERROR") + + data = request.get_json() or {} + target_user_id = data.get("user_id") + if not target_user_id: + return api_response(success=False, message="user_id is required", status=400, error_type="VALIDATION_ERROR") + + membership = OrganizationMember.query.filter_by( + organization_id=org_id, user_id=target_user_id, deleted_at=None + ).first() + if not membership: + return api_response(success=False, message="Member not found in this organization", status=404, error_type="NOT_FOUND") + + membership.role = new_role + db.session.commit() + return api_response( + data={"user_id": target_user_id, "role": new_role.value}, + message=f"Role updated to {new_role.value}", + ) + + +@api_v1_bp.route("/organizations//roles//members/", methods=["DELETE"]) +@login_required +@require_admin +def remove_role_from_member(org_id, role_name, user_id): + """Demote a member to GUEST (effectively removing a named role). + + Removing a role downgrades the member to GUEST rather than removing them + from the organization entirely. Use the existing DELETE + /organizations//members/ endpoint to fully remove. + + Returns: + 200: Role removed (member demoted to GUEST) + 400: Invalid role name + 403: Not an admin/owner + 404: Org or member not found + """ + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.extensions import db + + try: + OrganizationRole(role_name.lower()) # validate the name + except ValueError: + valid = [r.value for r in OrganizationRole] + return api_response(success=False, message=f"Invalid role. Must be one of: {valid}", status=400, error_type="VALIDATION_ERROR") + + membership = OrganizationMember.query.filter_by( + organization_id=org_id, user_id=user_id, deleted_at=None + ).first() + if not membership: + return api_response(success=False, message="Member not found in this organization", status=404, error_type="NOT_FOUND") + + membership.role = OrganizationRole.GUEST + db.session.commit() + return api_response( + data={"user_id": user_id, "role": OrganizationRole.GUEST.value}, + message="Role removed; member demoted to GUEST", + ) + + +@api_v1_bp.route("/organizations//cas", methods=["GET"]) +@login_required +@require_admin +def list_org_cas(org_id): + """List all Certificate Authorities for an organization. + + If the system config-file CA is configured (via SSH_CA_PRIVATE_KEY env var + or ca_key_path in etc/ssh_ca.conf) and no DB CA exists for a given ca_type, + a synthetic read-only entry is injected so the UI correctly shows the + system CA as configured rather than "Not configured". + + Returns: + 200: List of CAs (private_key excluded) + 403: Not admin/owner + 404: Org not found + """ + from gatehouse_app.models.ssh_ca.ca import CA, CaType + from gatehouse_app.models.organization.organization import Organization + + org = Organization.query.filter_by(id=org_id, deleted_at=None).first() + if not org: + return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND") + + cas = CA.query.filter_by(organization_id=org_id, deleted_at=None).all() + ca_list = [ca.to_dict() for ca in cas] + + # Determine which ca_types are already covered by a DB CA + covered_types = {ca.ca_type for ca in cas} + + # Check whether a system config-file CA is available + system_ca_dict = _get_system_ca_dict() + if system_ca_dict: + # Inject a synthetic entry for each ca_type NOT covered by a real DB CA. + # The system CA only signs user certs (cert_type="user"), so we only + # inject it for the user slot. Host signing always needs a DB CA. + if CaType.USER not in covered_types: + ca_list.append({**system_ca_dict, "ca_type": "user"}) + + return api_response( + data={"cas": ca_list, "count": len(ca_list)}, + message="CAs retrieved", + ) + + +@api_v1_bp.route("/organizations//cas/", methods=["PATCH"]) +@login_required +@require_admin +def update_org_ca(org_id, ca_id): + """Update CA configuration (validity hours). + + Request body: + default_cert_validity_hours: Default validity in hours (optional) + max_cert_validity_hours: Maximum validity in hours (optional) + + Returns: + 200: CA updated successfully + 400: Validation error + 403: Not admin/owner + 404: Org or CA not found + """ + from gatehouse_app.models.ssh_ca.ca import CA + from gatehouse_app.models.organization.organization import Organization + from marshmallow import Schema, fields, validate, ValidationError + + org = Organization.query.filter_by(id=org_id, deleted_at=None).first() + if not org: + return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND") + + ca = CA.query.filter_by(id=ca_id, organization_id=org_id, deleted_at=None).first() + if not ca: + return api_response(success=False, message="CA not found", status=404, error_type="NOT_FOUND") + + try: + class CAUpdateSchema(Schema): + default_cert_validity_hours = fields.Int( + validate=validate.Range(min=1), + required=False + ) + max_cert_validity_hours = fields.Int( + validate=validate.Range(min=1), + required=False + ) + + schema = CAUpdateSchema() + data = schema.load(request.json or {}) + + # Validate that max >= default if both are provided + default_hours = data.get('default_cert_validity_hours', ca.default_cert_validity_hours) + max_hours = data.get('max_cert_validity_hours', ca.max_cert_validity_hours) + + if default_hours > max_hours: + return api_response( + success=False, + message="Default validity must be less than or equal to maximum validity", + status=400, + error_type="VALIDATION_ERROR", + ) + + # Update fields + if 'default_cert_validity_hours' in data: + ca.default_cert_validity_hours = data['default_cert_validity_hours'] + if 'max_cert_validity_hours' in data: + ca.max_cert_validity_hours = data['max_cert_validity_hours'] + + db.session.commit() + + return api_response( + data={"ca": ca.to_dict()}, + message="CA updated successfully", + ) + + except ValidationError as e: + return api_response( + success=False, + message="Validation failed", + status=400, + error_type="VALIDATION_ERROR", + error_details=e.messages, + ) + except Exception as e: + db.session.rollback() + return api_response( + success=False, + message="Failed to update CA", + status=500, + error_type="SERVER_ERROR", + ) + + +@api_v1_bp.route("/organizations//cas", methods=["POST"]) +@login_required +@require_admin +def create_org_ca(org_id): + """Create a new Certificate Authority for an organization. + + Request body: + name: CA display name (required) + description: Optional description + key_type: "ed25519" (default), "rsa", or "ecdsa" + default_cert_validity_hours: Default cert validity in hours (optional) + max_cert_validity_hours: Max cert validity in hours (optional) + + Returns: + 201: CA created successfully + 400: Validation error or name already taken + 403: Not admin/owner + 404: Org not found + """ + from gatehouse_app.models.ssh_ca.ca import CA, KeyType + from gatehouse_app.models.organization.organization import Organization + from gatehouse_app.utils.crypto import compute_ssh_fingerprint + from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key + from marshmallow import Schema, fields as ma_fields, validate, ValidationError as MaValidationError + from sshkey_tools.keys import Ed25519PrivateKey, RsaPrivateKey, EcdsaPrivateKey + + org = Organization.query.filter_by(id=org_id, deleted_at=None).first() + if not org: + return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND") + + class CreateCASchema(Schema): + name = ma_fields.Str(required=True, validate=validate.Length(min=1, max=255)) + description = ma_fields.Str(load_default=None, allow_none=True) + ca_type = ma_fields.Str(load_default="user", validate=validate.OneOf(["user", "host"])) + key_type = ma_fields.Str(load_default="ed25519", validate=validate.OneOf(["ed25519", "rsa", "ecdsa"])) + default_cert_validity_hours = ma_fields.Int(load_default=8, validate=validate.Range(min=1)) + max_cert_validity_hours = ma_fields.Int(load_default=720, validate=validate.Range(min=1)) + + try: + schema = CreateCASchema() + data = schema.load(request.get_json() or {}) + + # Check name uniqueness within org + existing = CA.query.filter_by( + organization_id=org_id, name=data["name"], deleted_at=None + ).first() + if existing: + return api_response( + success=False, + message="A CA with that name already exists in this organization", + status=400, + error_type="DUPLICATE_NAME", + ) + + # Enforce one CA per type per org + from gatehouse_app.models.ssh_ca.ca import CaType + ca_type_val = data["ca_type"] + existing_type = CA.query.filter_by( + organization_id=org_id, deleted_at=None + ).filter(CA.ca_type == CaType(ca_type_val)).first() + if existing_type: + type_label = "User" if ca_type_val == "user" else "Host" + return api_response( + success=False, + message=f"A {type_label} CA already exists for this organization. " + f"You can only have one {type_label} CA per organization.", + status=400, + error_type="DUPLICATE_CA_TYPE", + ) + + # Validate cross-field + if data["default_cert_validity_hours"] > data["max_cert_validity_hours"]: + return api_response( + success=False, + message="Default validity must be less than or equal to maximum validity", + status=400, + error_type="VALIDATION_ERROR", + ) + + # Generate key pair + key_type = data["key_type"] + if key_type == "ed25519": + private_key_obj = Ed25519PrivateKey.generate() + elif key_type == "rsa": + private_key_obj = RsaPrivateKey.generate(4096) + else: # ecdsa + private_key_obj = EcdsaPrivateKey.generate() + + private_key_pem = private_key_obj.to_string() + public_key_str = private_key_obj.public_key.to_string() + fingerprint = compute_ssh_fingerprint(public_key_str) + + # Encrypt the private key before storing in the database + encrypted_private_key = encrypt_ca_key(private_key_pem) + + ca = CA( + organization_id=org_id, + name=data["name"], + description=data["description"], + ca_type=CaType(ca_type_val), + key_type=KeyType(key_type), + private_key=encrypted_private_key, + public_key=public_key_str, + fingerprint=fingerprint, + default_cert_validity_hours=data["default_cert_validity_hours"], + max_cert_validity_hours=data["max_cert_validity_hours"], + is_active=True, + ) + db.session.add(ca) + try: + db.session.commit() + except Exception as commit_exc: + db.session.rollback() + # Surface unique-constraint violations (soft-deleted record with same name) as a + # user-friendly 400 instead of a 500. + exc_str = str(commit_exc).lower() + if "uix_org_ca_name" in exc_str or "unique" in exc_str: + return api_response( + success=False, + message=( + "A CA with that name already exists in this organization " + "(it may have been recently deleted — choose a different name)." + ), + status=400, + error_type="DUPLICATE_NAME", + ) + raise + + return api_response( + data={"ca": ca.to_dict()}, + message="CA created successfully", + status=201, + ) + + except MaValidationError as e: + return api_response( + success=False, + message="Validation failed", + status=400, + error_type="VALIDATION_ERROR", + error_details=e.messages, + ) + except Exception as e: + db.session.rollback() + current_app.logger.exception("Failed to create CA") + return api_response( + success=False, + message="Failed to create CA", + status=500, + error_type="SERVER_ERROR", + ) + + +@api_v1_bp.route("/organizations//cas/", methods=["DELETE"]) +@login_required +@require_admin +def delete_org_ca(org_id, ca_id): + """Soft-delete a Certificate Authority. + + Deactivates the CA so no new certificates can be signed with it. + Existing certificates remain valid until they expire. + + Returns: + 200: CA deleted successfully + 403: Not admin/owner + 404: Org or CA not found + """ + from gatehouse_app.models.ssh_ca.ca import CA + from gatehouse_app.models.organization.organization import Organization + from gatehouse_app.utils.constants import AuditAction + from gatehouse_app.models import AuditLog + + org = Organization.query.filter_by(id=org_id, deleted_at=None).first() + if not org: + return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND") + + ca = CA.query.filter_by(id=ca_id, organization_id=org_id, deleted_at=None).first() + if not ca: + return api_response(success=False, message="CA not found", status=404, error_type="NOT_FOUND") + + try: + ca_name = ca.name + ca_type = ca.ca_type.value if hasattr(ca.ca_type, "value") else str(ca.ca_type) + ca.is_active = False + ca.delete(soft=True) + + AuditLog.log( + action=AuditAction.CA_DELETED, + user_id=g.current_user.id, + resource_type="CA", + resource_id=ca_id, + organization_id=org_id, + ip_address=request.remote_addr, + description=f"CA '{ca_name}' ({ca_type}) deleted", + ) + + return api_response( + data={"ca_id": ca_id}, + message="CA deleted successfully", + ) + except Exception as e: + db.session.rollback() + current_app.logger.exception("Failed to delete CA") + return api_response( + success=False, + message="Failed to delete CA", + status=500, + error_type="SERVER_ERROR", + ) + + +@api_v1_bp.route("/organizations//cas//rotate", methods=["POST"]) +@login_required +@require_admin +def rotate_org_ca(org_id, ca_id): + """Rotate (replace) a CA's key pair. + + Generates a new key pair of the same or different type. The old public key + fingerprint is returned so admins can update TrustedUserCAKeys / known_hosts + on their servers. All previously-issued certificates remain valid until they + expire but no new certificates will be signed with the old key. + + Request body (all optional): + key_type: "ed25519" (default keeps current), "rsa", or "ecdsa" + reason: Human-readable reason for the rotation + + Returns: + 200: CA rotated — { ca, old_fingerprint } + 403: Not admin/owner + 404: Org or CA not found + """ + from gatehouse_app.models.ssh_ca.ca import CA, KeyType + from gatehouse_app.models.organization.organization import Organization + from gatehouse_app.utils.crypto import compute_ssh_fingerprint + from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key + from gatehouse_app.utils.constants import AuditAction + from gatehouse_app.models import AuditLog + from sshkey_tools.keys import Ed25519PrivateKey, RsaPrivateKey, EcdsaPrivateKey + + org = Organization.query.filter_by(id=org_id, deleted_at=None).first() + if not org: + return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND") + + ca = CA.query.filter_by(id=ca_id, organization_id=org_id, deleted_at=None).first() + if not ca: + return api_response(success=False, message="CA not found", status=404, error_type="NOT_FOUND") + + data = request.get_json() or {} + new_key_type = data.get("key_type") or (ca.key_type.value if hasattr(ca.key_type, "value") else str(ca.key_type)) + reason = data.get("reason", "Admin-initiated key rotation") + + if new_key_type not in ("ed25519", "rsa", "ecdsa"): + return api_response( + success=False, + message="Invalid key_type. Must be one of: ed25519, rsa, ecdsa", + status=400, + error_type="VALIDATION_ERROR", + ) + + try: + old_fingerprint = ca.fingerprint + + # Generate new key pair + if new_key_type == "ed25519": + private_key_obj = Ed25519PrivateKey.generate() + elif new_key_type == "rsa": + private_key_obj = RsaPrivateKey.generate(4096) + else: # ecdsa + private_key_obj = EcdsaPrivateKey.generate() + + new_private_key = private_key_obj.to_string() + new_public_key = private_key_obj.public_key.to_string() + new_fingerprint = compute_ssh_fingerprint(new_public_key) + + # Encrypt the new private key before storing + encrypted_new_private_key = encrypt_ca_key(new_private_key) + + ca.rotate_key( + new_private_key=encrypted_new_private_key, + new_public_key=new_public_key, + new_fingerprint=new_fingerprint, + reason=reason, + ) + ca.key_type = KeyType(new_key_type) + db.session.commit() + + AuditLog.log( + action=AuditAction.CA_KEY_ROTATED, + user_id=g.current_user.id, + resource_type="CA", + resource_id=ca_id, + organization_id=org_id, + ip_address=request.remote_addr, + description=( + f"CA '{ca.name}' key rotated. " + f"Old fingerprint: {old_fingerprint}, New fingerprint: {new_fingerprint}. " + f"Reason: {reason}" + ), + ) + + return api_response( + data={ + "ca": ca.to_dict(), + "old_fingerprint": old_fingerprint, + }, + message="CA key rotated successfully. Update TrustedUserCAKeys / known_hosts on your servers.", + ) + except Exception as e: + db.session.rollback() + current_app.logger.exception("Failed to rotate CA key") + return api_response( + success=False, + message="Failed to rotate CA key", + status=500, + error_type="SERVER_ERROR", + ) + diff --git a/gatehouse_app/api/v1/policies.py b/gatehouse_app/api/v1/policies.py index ddf16f2..c97465f 100644 --- a/gatehouse_app/api/v1/policies.py +++ b/gatehouse_app/api/v1/policies.py @@ -195,20 +195,46 @@ def get_org_mfa_compliance(org_id): limit = min(int(request.args.get("limit", 100)), 100) offset = int(request.args.get("offset", 0)) + page = int(request.args.get("page", 1)) + page_size = min(int(request.args.get("page_size", limit)), 100) + + effective_offset = offset if request.args.get("offset") else (page - 1) * page_size compliance_list = MfaPolicyService.get_org_compliance_list( organization_id=org_id, status=status, - limit=limit, - offset=offset, + limit=page_size, + offset=effective_offset, ) + def format_member(c): + """Normalize compliance record to UI-expected shape.""" + if isinstance(c, dict): + return { + "user_id": c.get("user_id"), + "user_email": c.get("email"), + "user_name": c.get("full_name"), + "status": c.get("status"), + "deadline_at": c.get("deadline_at"), + "compliant_at": c.get("compliant_at"), + "last_notified_at": c.get("notified_at"), + } + return { + "user_id": getattr(c, "user_id", None), + "user_email": getattr(c, "email", None), + "user_name": getattr(c, "full_name", None), + "status": getattr(c, "status", None), + "deadline_at": getattr(c, "deadline_at", None), + "compliant_at": getattr(c, "compliant_at", None), + "last_notified_at": getattr(c, "notified_at", None), + } + return api_response( data={ - "compliance": compliance_list, + "members": [format_member(c) for c in compliance_list], "count": len(compliance_list), - "limit": limit, - "offset": offset, + "page": page, + "page_size": page_size, }, message="Compliance records retrieved successfully", ) @@ -325,12 +351,10 @@ def get_my_mfa_compliance(): return api_response( data={ - "mfa_compliance": { - "overall_status": compliance_summary.overall_status, - "missing_methods": compliance_summary.missing_methods, - "deadline_at": compliance_summary.deadline_at, - "orgs": orgs, - } + "overall_status": compliance_summary.overall_status, + "missing_methods": compliance_summary.missing_methods, + "deadline_at": compliance_summary.deadline_at, + "orgs": orgs, }, message="MFA compliance retrieved successfully", ) \ No newline at end of file diff --git a/gatehouse_app/api/v1/principals.py b/gatehouse_app/api/v1/principals.py new file mode 100644 index 0000000..0da8315 --- /dev/null +++ b/gatehouse_app/api/v1/principals.py @@ -0,0 +1,779 @@ +"""Principal endpoints.""" +from flask import g, request +from marshmallow import Schema, fields, validate, ValidationError + +from gatehouse_app.api.v1 import api_v1_bp +from gatehouse_app.utils.response import api_response +from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required +from gatehouse_app.models import Principal, PrincipalMembership, Department, DepartmentPrincipal +from gatehouse_app.services.organization_service import OrganizationService +from gatehouse_app.services.user_service import UserService +from gatehouse_app.exceptions import OrganizationNotFoundError +from gatehouse_app.extensions import db + + +class PrincipalCreateSchema(Schema): + """Schema for creating a principal.""" + name = fields.Str(required=True, validate=validate.Length(min=1, max=255)) + description = fields.Str(allow_none=True, validate=validate.Length(max=2000)) + + +class PrincipalUpdateSchema(Schema): + """Schema for updating a principal.""" + name = fields.Str(validate=validate.Length(min=1, max=255)) + description = fields.Str(allow_none=True, validate=validate.Length(max=2000)) + + +class AddPrincipalMemberSchema(Schema): + """Schema for adding a member to a principal.""" + email = fields.Email(required=True) + + +class LinkPrincipalSchema(Schema): + """Schema for linking principal to department.""" + department_id = fields.Str(required=True) + + +@api_v1_bp.route("/organizations//principals", methods=["GET"]) +@login_required +@full_access_required +def list_principals(org_id): + """ + List all principals in an organization. + + Args: + org_id: Organization ID + + Returns: + 200: List of principals + 401: Not authenticated + 403: Not a member + 404: Organization not found + """ + org = OrganizationService.get_organization_by_id(org_id) + + if not org.is_member(g.current_user.id): + return api_response( + success=False, + message="You are not a member of this organization", + status=403, + error_type="AUTHORIZATION_ERROR", + ) + + principals = Principal.query.filter_by( + organization_id=org_id, + deleted_at=None + ).all() + + return api_response( + data={ + "principals": [p.to_dict() for p in principals], + "count": len(principals), + }, + message="Principals retrieved successfully", + ) + + +@api_v1_bp.route("/organizations//principals", methods=["POST"]) +@login_required +@require_admin +@full_access_required +def create_principal(org_id): + """ + Create a new principal. + + Args: + org_id: Organization ID + + Request body: + name: Principal name (required) + description: Optional description + + Returns: + 201: Principal created successfully + 400: Validation error + 401: Not authenticated + 403: Not an admin + 404: Organization not found + 409: Principal name already exists + """ + try: + org = OrganizationService.get_organization_by_id(org_id) + + schema = PrincipalCreateSchema() + data = schema.load(request.json or {}) + + # Check if principal name already exists + existing = Principal.query.filter_by( + organization_id=org_id, + name=data["name"], + deleted_at=None + ).first() + + if existing: + return api_response( + success=False, + message=f"Principal '{data['name']}' already exists", + status=409, + error_type="CONFLICT", + ) + + # Create principal + principal = Principal( + organization_id=org_id, + name=data["name"], + description=data.get("description"), + ) + db.session.add(principal) + db.session.commit() + + return api_response( + data={"principal": principal.to_dict()}, + message="Principal created successfully", + status=201, + ) + + except ValidationError as e: + return api_response( + success=False, + message="Validation failed", + status=400, + error_type="VALIDATION_ERROR", + error_details=e.messages, + ) + + +@api_v1_bp.route("/organizations//principals/", methods=["GET"]) +@login_required +@full_access_required +def get_principal(org_id, principal_id): + """ + Get a specific principal. + + Args: + org_id: Organization ID + principal_id: Principal ID + + Returns: + 200: Principal data + 401: Not authenticated + 403: Not a member + 404: Organization or principal not found + """ + org = OrganizationService.get_organization_by_id(org_id) + + if not org.is_member(g.current_user.id): + return api_response( + success=False, + message="You are not a member of this organization", + status=403, + error_type="AUTHORIZATION_ERROR", + ) + + principal = Principal.query.filter_by( + id=principal_id, + organization_id=org_id, + deleted_at=None + ).first() + + if not principal: + return api_response( + success=False, + message="Principal not found", + status=404, + error_type="NOT_FOUND", + ) + + return api_response( + data={"principal": principal.to_dict()}, + message="Principal retrieved successfully", + ) + + +@api_v1_bp.route("/organizations//principals/", methods=["PATCH"]) +@login_required +@require_admin +@full_access_required +def update_principal(org_id, principal_id): + """ + Update a principal. + + Args: + org_id: Organization ID + principal_id: Principal ID + + Request body: + name: Optional new name + description: Optional new description + + Returns: + 200: Principal updated successfully + 400: Validation error + 401: Not authenticated + 403: Not an admin + 404: Organization or principal not found + 409: Name already exists + """ + try: + org = OrganizationService.get_organization_by_id(org_id) + + principal = Principal.query.filter_by( + id=principal_id, + organization_id=org_id, + deleted_at=None + ).first() + + if not principal: + return api_response( + success=False, + message="Principal not found", + status=404, + error_type="NOT_FOUND", + ) + + schema = PrincipalUpdateSchema() + data = schema.load(request.json or {}) + + # Check if new name already exists + if "name" in data and data["name"] != principal.name: + existing = Principal.query.filter_by( + organization_id=org_id, + name=data["name"], + deleted_at=None + ).first() + if existing: + return api_response( + success=False, + message=f"Principal '{data['name']}' already exists", + status=409, + error_type="CONFLICT", + ) + + # Update fields + for key, value in data.items(): + setattr(principal, key, value) + + db.session.commit() + + return api_response( + data={"principal": principal.to_dict()}, + message="Principal updated successfully", + ) + + except ValidationError as e: + return api_response( + success=False, + message="Validation failed", + status=400, + error_type="VALIDATION_ERROR", + error_details=e.messages, + ) + + +@api_v1_bp.route("/organizations//principals/", methods=["DELETE"]) +@login_required +@require_admin +@full_access_required +def delete_principal(org_id, principal_id): + """ + Delete a principal (soft delete). + + Args: + org_id: Organization ID + principal_id: Principal ID + + Returns: + 200: Principal deleted successfully + 401: Not authenticated + 403: Not an admin + 404: Organization or principal not found + """ + org = OrganizationService.get_organization_by_id(org_id) + + principal = Principal.query.filter_by( + id=principal_id, + organization_id=org_id, + deleted_at=None + ).first() + + if not principal: + return api_response( + success=False, + message="Principal not found", + status=404, + error_type="NOT_FOUND", + ) + + # Soft delete + principal.deleted_at = db.func.now() + db.session.commit() + + return api_response( + message="Principal deleted successfully", + ) + + +@api_v1_bp.route("/organizations//principals//members", methods=["GET"]) +@login_required +@full_access_required +def get_principal_members(org_id, principal_id): + """ + Get all members (direct + via department) with access to a principal. + + Args: + org_id: Organization ID + principal_id: Principal ID + + Returns: + 200: List of members + 401: Not authenticated + 403: Not a member + 404: Organization or principal not found + """ + org = OrganizationService.get_organization_by_id(org_id) + + if not org.is_member(g.current_user.id): + return api_response( + success=False, + message="You are not a member of this organization", + status=403, + error_type="AUTHORIZATION_ERROR", + ) + + principal = Principal.query.filter_by( + id=principal_id, + organization_id=org_id, + deleted_at=None + ).first() + + if not principal: + return api_response( + success=False, + message="Principal not found", + status=404, + error_type="NOT_FOUND", + ) + + # Get direct members + direct_members = PrincipalMembership.query.filter_by( + principal_id=principal_id, + deleted_at=None + ).all() + + all_users = set() + for membership in direct_members: + if membership.user.deleted_at is None: + all_users.add(membership.user) + + # Get members via departments + dept_links = DepartmentPrincipal.query.filter_by( + principal_id=principal_id, + deleted_at=None + ).all() + + for link in dept_links: + dept = link.department + if dept.deleted_at is None: + dept_members = dept.get_members(active_only=True) + for dept_member in dept_members: + if dept_member.user.deleted_at is None: + all_users.add(dept_member.user) + + users_data = [u.to_dict() for u in all_users] + + return api_response( + data={ + "members": users_data, + "count": len(users_data), + }, + message="Members retrieved successfully", + ) + + +@api_v1_bp.route("/organizations//principals//members", methods=["POST"]) +@login_required +@require_admin +@full_access_required +def add_principal_member(org_id, principal_id): + """ + Add a direct member to a principal. + + Args: + org_id: Organization ID + principal_id: Principal ID + + Request body: + email: User email to add + + Returns: + 201: Member added successfully + 400: Validation error + 401: Not authenticated + 403: Not an admin + 404: Organization, principal, or user not found + 409: User already a member + """ + try: + org = OrganizationService.get_organization_by_id(org_id) + + principal = Principal.query.filter_by( + id=principal_id, + organization_id=org_id, + deleted_at=None + ).first() + + if not principal: + return api_response( + success=False, + message="Principal not found", + status=404, + error_type="NOT_FOUND", + ) + + schema = AddPrincipalMemberSchema() + data = schema.load(request.json or {}) + + # Find user by email + user = UserService.get_user_by_email(data["email"]) + if not user: + return api_response( + success=False, + message="User not found", + status=404, + error_type="NOT_FOUND", + ) + + # Check if already a member + existing = PrincipalMembership.query.filter_by( + user_id=user.id, + principal_id=principal_id, + deleted_at=None + ).first() + + if existing: + return api_response( + success=False, + message="User is already a member of this principal", + status=409, + error_type="CONFLICT", + ) + + soft_deleted = PrincipalMembership.query.filter( + PrincipalMembership.user_id == user.id, + PrincipalMembership.principal_id == principal_id, + PrincipalMembership.deleted_at.isnot(None) + ).first() + + if soft_deleted: + soft_deleted.deleted_at = None + membership = soft_deleted + else: + membership = PrincipalMembership( + user_id=user.id, + principal_id=principal_id, + ) + db.session.add(membership) + + db.session.commit() + + member_dict = membership.to_dict() + member_dict["user"] = user.to_dict() + + return api_response( + data={"member": member_dict}, + message="Member added successfully", + status=201, + ) + + except ValidationError as e: + return api_response( + success=False, + message="Validation failed", + status=400, + error_type="VALIDATION_ERROR", + error_details=e.messages, + ) + + +@api_v1_bp.route("/organizations//principals//members/", methods=["DELETE"]) +@login_required +@require_admin +@full_access_required +def remove_principal_member(org_id, principal_id, user_id): + """ + Remove a direct member from a principal. + + Args: + org_id: Organization ID + principal_id: Principal ID + user_id: User ID to remove + + Returns: + 200: Member removed successfully + 401: Not authenticated + 403: Not an admin + 404: Organization, principal, or member not found + """ + org = OrganizationService.get_organization_by_id(org_id) + + principal = Principal.query.filter_by( + id=principal_id, + organization_id=org_id, + deleted_at=None + ).first() + + if not principal: + return api_response( + success=False, + message="Principal not found", + status=404, + error_type="NOT_FOUND", + ) + + membership = PrincipalMembership.query.filter_by( + user_id=user_id, + principal_id=principal_id, + deleted_at=None + ).first() + + if not membership: + return api_response( + success=False, + message="User is not a member of this principal", + status=404, + error_type="NOT_FOUND", + ) + + # Soft delete + membership.deleted_at = db.func.now() + db.session.commit() + + return api_response( + message="Member removed successfully", + ) + + +@api_v1_bp.route("/organizations//principals//departments", methods=["GET"]) +@login_required +@full_access_required +def get_principal_departments(org_id, principal_id): + """ + Get all departments this principal is assigned to. + + Args: + org_id: Organization ID + principal_id: Principal ID + + Returns: + 200: List of departments + 401: Not authenticated + 403: Not a member + 404: Organization or principal not found + """ + org = OrganizationService.get_organization_by_id(org_id) + + if not org.is_member(g.current_user.id): + return api_response( + success=False, + message="You are not a member of this organization", + status=403, + error_type="AUTHORIZATION_ERROR", + ) + + principal = Principal.query.filter_by( + id=principal_id, + organization_id=org_id, + deleted_at=None + ).first() + + if not principal: + return api_response( + success=False, + message="Principal not found", + status=404, + error_type="NOT_FOUND", + ) + + depts = principal.get_departments(active_only=True) + + return api_response( + data={ + "departments": [d.to_dict() for d in depts], + "count": len(depts), + }, + message="Departments retrieved successfully", + ) + + +@api_v1_bp.route("/organizations//principals//departments/", methods=["POST"]) +@login_required +@require_admin +@full_access_required +def link_principal_to_department(org_id, principal_id, dept_id): + """ + Link a principal to a department. + + Args: + org_id: Organization ID + principal_id: Principal ID + dept_id: Department ID + + Returns: + 201: Principal linked successfully + 401: Not authenticated + 403: Not an admin + 404: Organization, principal, or department not found + 409: Already linked + """ + try: + org = OrganizationService.get_organization_by_id(org_id) + except OrganizationNotFoundError: + return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND") + + principal = Principal.query.filter_by( + id=principal_id, + organization_id=org_id, + deleted_at=None + ).first() + + if not principal: + return api_response( + success=False, + message="Principal not found", + status=404, + error_type="NOT_FOUND", + ) + + dept = Department.query.filter_by( + id=dept_id, + organization_id=org_id, + deleted_at=None + ).first() + + if not dept: + return api_response( + success=False, + message="Department not found", + status=404, + error_type="NOT_FOUND", + ) + + existing = DepartmentPrincipal.query.filter_by( + department_id=dept_id, + principal_id=principal_id, + deleted_at=None + ).first() + + if existing: + return api_response( + success=False, + message="Principal is already linked to this department", + status=409, + error_type="CONFLICT", + ) + + soft_deleted = DepartmentPrincipal.query.filter( + DepartmentPrincipal.department_id == dept_id, + DepartmentPrincipal.principal_id == principal_id, + DepartmentPrincipal.deleted_at.isnot(None), + ).first() + + try: + if soft_deleted: + soft_deleted.deleted_at = None + else: + link = DepartmentPrincipal( + department_id=dept_id, + principal_id=principal_id, + ) + db.session.add(link) + db.session.commit() + except Exception: + db.session.rollback() + return api_response( + success=False, + message="Failed to link principal to department", + status=500, + error_type="SERVER_ERROR", + ) + + return api_response( + data={ + "principal": principal.to_dict(), + "department": dept.to_dict(), + }, + message="Principal linked to department successfully", + status=201, + ) + + +@api_v1_bp.route("/organizations//principals//departments/", methods=["DELETE"]) +@login_required +@require_admin +@full_access_required +def unlink_principal_from_department(org_id, principal_id, dept_id): + """ + Unlink a principal from a department. + + Args: + org_id: Organization ID + principal_id: Principal ID + dept_id: Department ID + + Returns: + 200: Principal unlinked successfully + 401: Not authenticated + 403: Not an admin + 404: Organization, principal, department, or link not found + """ + org = OrganizationService.get_organization_by_id(org_id) + + principal = Principal.query.filter_by( + id=principal_id, + organization_id=org_id, + deleted_at=None + ).first() + + if not principal: + return api_response( + success=False, + message="Principal not found", + status=404, + error_type="NOT_FOUND", + ) + + dept = Department.query.filter_by( + id=dept_id, + organization_id=org_id, + deleted_at=None + ).first() + + if not dept: + return api_response( + success=False, + message="Department not found", + status=404, + error_type="NOT_FOUND", + ) + + link = DepartmentPrincipal.query.filter_by( + department_id=dept_id, + principal_id=principal_id, + deleted_at=None + ).first() + + if not link: + return api_response( + success=False, + message="Principal is not linked to this department", + status=404, + error_type="NOT_FOUND", + ) + + # Soft delete + link.deleted_at = db.func.now() + db.session.commit() + + return api_response( + message="Principal unlinked from department successfully", + ) diff --git a/gatehouse_app/api/v1/ssh.py b/gatehouse_app/api/v1/ssh.py new file mode 100644 index 0000000..8a54ae5 --- /dev/null +++ b/gatehouse_app/api/v1/ssh.py @@ -0,0 +1,1108 @@ +"""SSH Key and Certificate API routes.""" +from flask import Blueprint, request, jsonify, g +from sqlalchemy.exc import IntegrityError +from gatehouse_app.services.ssh_key_service import SSHKeyService +from gatehouse_app.services.ssh_ca_signing_service import ( + SSHCASigningService, + SSHCertificateSigningRequest, +) +from gatehouse_app.exceptions import ( + SSHKeyError, + SSHKeyNotFoundError, + SSHCertificateError, + ValidationError, + SSHKeyAlreadyExistsError, +) +from gatehouse_app.utils.constants import AuditAction +from gatehouse_app.models import AuditLog +from gatehouse_app.models.ssh_ca.certificate_audit_log import CertificateAuditLog +from gatehouse_app.utils.decorators import login_required +from gatehouse_app.utils.response import api_response + +ssh_bp = Blueprint('ssh', __name__, url_prefix='/ssh') +ssh_key_service = SSHKeyService() +ssh_ca_service = SSHCASigningService() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _get_org_ca_for_user(user, ca_type: str = "user"): + """Return the active DB CA of the given type for the user's first org, or None. + + Args: + user: The current user object. + ca_type: ``"user"`` (default) or ``"host"`` — selects the CA that signs + the corresponding certificate type. + """ + try: + from gatehouse_app.models.ssh_ca.ca import CA, CaType + org_ids = [m.organization_id for m in user.organization_memberships] + if not org_ids: + return None + return CA.query.filter( + CA.organization_id.in_(org_ids), + CA.ca_type == CaType(ca_type), + CA.is_active == True, # noqa: E712 + ).first() + except Exception: + return None + + +def _get_or_create_system_ca(): + """ + Return a CA DB record representing the config-file CA. + + This is used as the ``ca_id`` FK when persisting certificates that were + signed by the globally-configured CA key (not an org-specific DB CA). + The record is created on first use and has no ``organization_id``. + """ + from gatehouse_app.extensions import db + from gatehouse_app.models.ssh_ca.ca import CA, KeyType + from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config + from gatehouse_app.utils.crypto import compute_ssh_fingerprint + import os + + try: + existing = CA.query.filter_by(name="system-config-ca").first() + if existing: + return existing + + cfg = get_ssh_ca_config() + key_path = cfg.get_str("ca_key_path", "").strip() + pub_key_path = key_path + ".pub" + + if not os.path.exists(pub_key_path): + return None + + with open(pub_key_path) as f: + pub_key = f.read().strip() + + # Load private key for the record (encrypt before storing in DB) + priv_key = "" + if os.path.exists(key_path): + with open(key_path) as f: + raw_priv_key = f.read() + try: + from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key + priv_key = encrypt_ca_key(raw_priv_key) + except Exception: + priv_key = raw_priv_key # fallback: store as-is if encryption unavailable + + fingerprint = compute_ssh_fingerprint(pub_key) + + # Check by fingerprint in case it was created under a different name + existing_by_fp = CA.query.filter_by(fingerprint=fingerprint).first() + if existing_by_fp: + return existing_by_fp + + system_ca = CA( + name="system-config-ca", + description="Global CA loaded from etc/ssh_ca.conf (ca_key_path)", + key_type=KeyType.ED25519, + private_key=priv_key, + public_key=pub_key, + fingerprint=fingerprint, + is_active=True, + default_cert_validity_hours=24, + max_cert_validity_hours=720, + ) + # organization_id is nullable=False in schema — we need a dummy org or + # need to allow NULL. Use None; the DB constraint will tell us quickly. + # If the migration enforces NOT NULL we'll catch the error gracefully. + db.session.add(system_ca) + db.session.commit() + return system_ca + except Exception as exc: + import logging + logging.getLogger(__name__).warning( + f"Could not upsert system-config-ca: {exc}" + ) + try: + db.session.rollback() + except Exception: + pass + return None + + +def _persist_certificate(user_id, ssh_key_id, ca, signing_response, request_ip=None, cert_type_str='user', cert_identity=None): + """Save a signed certificate to the ssh_certificates table. + + Args: + user_id: UUID of the user + ssh_key_id: UUID of the SSH key that was signed + ca: CA model instance (may be None — cert still returned but not persisted) + signing_response: SSHCertificateSigningResponse + request_ip: Client IP address + cert_type_str: 'user' or 'host' (from the sign request) + cert_identity: Rich OpenSSH key_id string (e.g. "user@host (Name) [org:slug]"). + Falls back to str(ssh_key_id) when not provided. + + Returns: + SSHCertificate instance or None if persistence failed + """ + if ca is None: + return None + + try: + from gatehouse_app.extensions import db + from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus + from gatehouse_app.models.ssh_ca.ca import CertType + + try: + resolved_cert_type = CertType(cert_type_str) + except ValueError: + resolved_cert_type = CertType.USER + + cert_record = SSHCertificate( + ca_id=ca.id, + user_id=user_id, + ssh_key_id=ssh_key_id, + certificate=signing_response.certificate, + serial=signing_response.serial, + key_id=cert_identity or str(ssh_key_id), + cert_type=resolved_cert_type, + principals=signing_response.principals, + valid_after=signing_response.valid_after, + valid_before=signing_response.valid_before, + revoked=False, + status=CertificateStatus.ISSUED, + request_ip=request_ip, + ) + db.session.add(cert_record) + db.session.commit() + return cert_record + except Exception as exc: + import logging + logging.getLogger(__name__).warning( + f"Failed to persist certificate to DB: {exc}" + ) + try: + from gatehouse_app.extensions import db as _db + _db.session.rollback() + except Exception: + pass + return None + + + +def _get_merged_dept_cert_policy(user_id): + """Return a merged cert policy view for the given user across all their departments. + + Rules for merging when a user belongs to multiple departments: + - ``allow_user_expiry``: True only if ALL departments allow it. + - ``default_expiry_hours``: minimum across departments (most restrictive). + - ``max_expiry_hours``: minimum across departments (most restrictive). + - ``extensions``: intersection — only extensions allowed by ALL departments. + + Returns a plain dict with keys: + allow_user_expiry, default_expiry_hours, max_expiry_hours, extensions + Or None if the user has no department memberships or no policies are configured. + """ + from gatehouse_app.models.organization.department import DepartmentMembership + from gatehouse_app.models.organization.department_cert_policy import DepartmentCertPolicy, STANDARD_EXTENSIONS + + memberships = DepartmentMembership.query.filter_by(user_id=user_id, deleted_at=None).all() + dept_ids = [m.department_id for m in memberships if m.department and m.department.deleted_at is None] + if not dept_ids: + return None + + policies = DepartmentCertPolicy.query.filter( + DepartmentCertPolicy.department_id.in_(dept_ids), + DepartmentCertPolicy.deleted_at.is_(None), + ).all() + if not policies: + return None + + allow_user_expiry = all(p.allow_user_expiry for p in policies) + default_expiry_hours = min(p.default_expiry_hours for p in policies) + max_expiry_hours = min(p.max_expiry_hours for p in policies) + + # Intersection of all_extensions() across policies + ext_sets = [set(p.all_extensions()) for p in policies] + extensions = list(ext_sets[0].intersection(*ext_sets[1:])) + + return { + "allow_user_expiry": allow_user_expiry, + "default_expiry_hours": default_expiry_hours, + "max_expiry_hours": max_expiry_hours, + "extensions": extensions, + } + + +@ssh_bp.route('/dept-cert-policy', methods=['GET']) +@login_required +def get_my_dept_cert_policy(): + """Return the merged department certificate policy for the current user. + + Admins always get allow_user_expiry=True so the frontend shows the expiry + picker for them regardless of the member-facing toggle setting. + """ + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.models.organization.department_cert_policy import STANDARD_EXTENSIONS + from gatehouse_app.utils.constants import OrganizationRole + + user = g.current_user + user_id = user.id + + # Check if caller is an org admin/owner + is_org_admin = OrganizationMember.query.filter( + OrganizationMember.user_id == user_id, + OrganizationMember.role.in_(["OWNER", "ADMIN"]), + OrganizationMember.deleted_at == None, + ).first() is not None + + policy = _get_merged_dept_cert_policy(user_id) + if policy is None: + policy = { + "allow_user_expiry": is_org_admin, # admins default to True even without a dept policy + "default_expiry_hours": 1, + "max_expiry_hours": 24, + "extensions": list(STANDARD_EXTENSIONS), + } + elif is_org_admin: + # Override allow_user_expiry for admins — they can always pick + policy = {**policy, "allow_user_expiry": True} + + return api_response(data={"policy": policy}, message="Certificate policy retrieved") + + +@ssh_bp.route('/keys', methods=['GET']) +@login_required +def list_ssh_keys(): + """Get all SSH keys for current user.""" + user_id = g.current_user.id + + keys = ssh_key_service.get_user_ssh_keys(user_id) + return api_response( + data={ + 'keys': [k.to_dict() for k in keys], + 'count': len(keys), + }, + message="SSH keys retrieved successfully" + ) + + +@ssh_bp.route('/keys', methods=['POST']) +@login_required +def add_ssh_key(): + """Add a new SSH public key for current user.""" + user_id = g.current_user.id + + data = request.get_json() + if not data: + return api_response(success=False, message='No JSON data provided', status=400, error_type='BAD_REQUEST') + + public_key = data.get('public_key') or data.get('key') + description = data.get('description') + + if not public_key: + return api_response(success=False, message='public_key is required', status=400, error_type='BAD_REQUEST') + + try: + ssh_key = ssh_key_service.add_ssh_key( + user_id=user_id, + public_key=public_key, + description=description, + ) + + AuditLog.log( + action=AuditAction.SSH_KEY_ADDED, + user_id=user_id, + resource_type='SSHKey', + resource_id=ssh_key.id, + ip_address=request.remote_addr, + ) + + return api_response(success=True, message='SSH key added', data=ssh_key.to_dict(), status=201) + + except SSHKeyAlreadyExistsError as e: + return api_response(success=False, message=e.message, status=409, error_type='SSH_KEY_ALREADY_EXISTS') + except IntegrityError: + return api_response(success=False, message='SSH key already exists', status=409, error_type='SSH_KEY_ALREADY_EXISTS') + except SSHKeyError as e: + return api_response(success=False, message=str(e), status=400, error_type='SSH_KEY_ERROR') + except ValidationError as e: + return api_response(success=False, message=str(e), status=400, error_type='VALIDATION_ERROR') + + +@ssh_bp.route('/keys/', methods=['GET']) +@login_required +def get_ssh_key(key_id): + """Get a specific SSH key.""" + user_id = g.current_user.id + + try: + ssh_key = ssh_key_service.get_ssh_key(key_id) + + if ssh_key.user_id != user_id: + return api_response(success=False, message='Forbidden', status=403, error_type='FORBIDDEN') + + return api_response(success=True, message='SSH key retrieved', data=ssh_key.to_dict(), status=200) + + except SSHKeyNotFoundError: + return api_response(success=False, message='SSH key not found', status=404, error_type='NOT_FOUND') + + +@ssh_bp.route('/keys/', methods=['DELETE']) +@login_required +def delete_ssh_key(key_id): + """Delete an SSH key.""" + user_id = g.current_user.id + + try: + ssh_key = ssh_key_service.get_ssh_key(key_id) + + if ssh_key.user_id != user_id: + return api_response(success=False, message='Forbidden', status=403, error_type='FORBIDDEN') + + ssh_key_service.delete_ssh_key(key_id) + + AuditLog.log( + action=AuditAction.SSH_KEY_DELETED, + user_id=user_id, + resource_type='SSHKey', + resource_id=key_id, + ip_address=request.remote_addr, + ) + + return api_response(success=True, message='SSH key deleted', data={'status': 'deleted'}, status=200) + + except SSHKeyNotFoundError: + return api_response(success=False, message='SSH key not found', status=404, error_type='NOT_FOUND') + + +@ssh_bp.route('/keys//verify', methods=['GET', 'POST']) +@login_required +def verify_ssh_key(key_id): + """Generate or verify SSH key ownership challenge.""" + user_id = g.current_user.id + + try: + ssh_key = ssh_key_service.get_ssh_key(key_id) + + if ssh_key.user_id != user_id: + return api_response(success=False, message='Forbidden', status=403, error_type='FORBIDDEN') + + # GET — return a fresh challenge + if request.method == 'GET': + challenge = ssh_key_service.generate_verification_challenge(key_id) + return api_response(success=True, message='Challenge generated', data={ + 'challenge_text': challenge, + 'validationText': challenge, + 'key_id': key_id, + }, status=200) + + # POST — verify signature or generate challenge + data = request.get_json() or {} + action = data.get('action', 'verify_signature') + + if action == 'verify_signature': + signature = data.get('signature') + if not signature: + return api_response(success=False, message='signature is required', status=400, error_type='BAD_REQUEST') + + try: + verified = ssh_key_service.verify_ssh_key_ownership(key_id, signature) + + AuditLog.log( + action=AuditAction.SSH_KEY_VERIFIED, + user_id=user_id, + resource_type='SSHKey', + resource_id=key_id, + ip_address=request.remote_addr, + success=verified, + ) + + return api_response(success=True, message='Verification complete', data={'verified': verified}, status=200) + + except Exception as e: + AuditLog.log( + action=AuditAction.SSH_KEY_VALIDATION_FAILED, + user_id=user_id, + resource_type='SSHKey', + resource_id=key_id, + ip_address=request.remote_addr, + success=False, + error_message=str(e), + ) + return api_response(success=False, message=str(e), status=400, error_type='VERIFICATION_FAILED') + + else: # generate_challenge + challenge = ssh_key_service.generate_verification_challenge(key_id) + return api_response(success=True, message='Challenge generated', data={ + 'challenge_text': challenge, + 'challenge': challenge, + }, status=200) + + except SSHKeyNotFoundError: + return api_response(success=False, message='SSH key not found', status=404, error_type='NOT_FOUND') + + +@ssh_bp.route('/keys//update-description', methods=['PATCH']) +@login_required +def update_ssh_key_description(key_id): + """Update SSH key description.""" + user_id = g.current_user.id + + data = request.get_json() + if not data or 'description' not in data: + return api_response(success=False, message='description is required', status=400, error_type='BAD_REQUEST') + + try: + ssh_key = ssh_key_service.get_ssh_key(key_id) + + if ssh_key.user_id != user_id: + return api_response(success=False, message='Forbidden', status=403, error_type='FORBIDDEN') + + updated_key = ssh_key_service.update_ssh_key_description(key_id, data['description']) + + return api_response(success=True, message='Description updated', data=updated_key.to_dict(), status=200) + + except SSHKeyNotFoundError: + return api_response(success=False, message='SSH key not found', status=404, error_type='NOT_FOUND') + + +@ssh_bp.route('/sign', methods=['POST']) +@login_required +def sign_certificate(): + """Sign an SSH certificate for the current user.""" + user = g.current_user + user_id = user.id + + # ── Check account suspension ────────────────────────────────────────────── + from gatehouse_app.utils.constants import UserStatus + if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED): + return api_response( + success=False, + message="Your account is suspended. Contact an administrator.", + status=403, + error_type="ACCOUNT_SUSPENDED", + ) + + data = request.get_json() + if not data: + return api_response(success=False, message="No JSON data provided", status=400, error_type="BAD_REQUEST") + + requested_principals = data.get('principals') or [] + cert_type = data.get('cert_type', 'user') + key_id = data.get('key_id') or data.get('cert_id') + expiry_hours = data.get('expiry_hours') + + # ── Log the request ─────────────────────────────────────────────────────── + AuditLog.log( + action=AuditAction.SSH_CERT_REQUESTED, + user_id=user_id, + resource_type='SSHCertificate', + ip_address=request.remote_addr, + description=( + f'{user.email} requested a certificate' + + (f' for principals: {", ".join(requested_principals)}' if requested_principals else '') + ), + ) + + # ── Resolve which principals the user is allowed to use ────────────────── + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.models.organization.principal import Principal, PrincipalMembership + from gatehouse_app.models.organization.department import DepartmentMembership, DepartmentPrincipal + from gatehouse_app.utils.constants import OrganizationRole + + allowed_principal_names = set() + + memberships = OrganizationMember.query.filter_by(user_id=user_id).all() + for om in memberships: + org = om.organization + if not org or org.deleted_at is not None: + continue + role = om.role + if role in (OrganizationRole.ADMIN, OrganizationRole.OWNER): + # Admin/owner can use any principal in the org + for p in Principal.query.filter_by(organization_id=org.id, deleted_at=None).all(): + allowed_principal_names.add(p.name) + else: + # Direct memberships + for pm in PrincipalMembership.query.filter_by(user_id=user_id, deleted_at=None).all(): + if pm.principal and pm.principal.organization_id == org.id and pm.principal.deleted_at is None: + allowed_principal_names.add(pm.principal.name) + # Via department + for dm in DepartmentMembership.query.filter_by(user_id=user_id, deleted_at=None).all(): + if dm.department and dm.department.organization_id == org.id and dm.department.deleted_at is None: + for dp in DepartmentPrincipal.query.filter_by(department_id=dm.department_id, deleted_at=None).all(): + if dp.principal and dp.principal.deleted_at is None: + allowed_principal_names.add(dp.principal.name) + + # ── Determine final principals list ───────────────────────────────────── + if not requested_principals: + # Auto-resolve: use all principals the user is assigned to + principals = list(allowed_principal_names) + if not principals: + return api_response( + success=False, + message="You have no principals assigned. Ask an admin to add you to a principal.", + status=400, + error_type="NO_PRINCIPALS", + ) + else: + # Validate each requested principal is within the user's allowed set + invalid = [p for p in requested_principals if p not in allowed_principal_names] + if invalid: + return api_response( + success=False, + message=f"You are not authorised to request principals: {', '.join(invalid)}", + status=403, + error_type="UNAUTHORIZED_PRINCIPALS", + ) + principals = requested_principals + + # ── Key resolution ──────────────────────────────────────────────────────── + if not key_id: + verified_keys = ssh_key_service.get_user_verified_ssh_keys(user_id) + if not verified_keys: + return api_response( + success=False, + message="No verified SSH keys found. Verify a key before requesting a certificate.", + status=400, + error_type="NO_VERIFIED_KEYS", + ) + key_id = verified_keys[0].id + + try: + ssh_key = ssh_key_service.get_ssh_key(key_id) + except SSHKeyNotFoundError: + return api_response(success=False, message="SSH key not found", status=404, error_type="NOT_FOUND") + + if ssh_key.user_id != user_id: + return api_response(success=False, message="Forbidden", status=403, error_type="FORBIDDEN") + + if not ssh_key.verified: + return api_response( + success=False, + message="SSH key is not verified. Verify it before requesting a certificate.", + status=400, + error_type="KEY_NOT_VERIFIED", + ) + + db_ca = _get_org_ca_for_user(user, ca_type=cert_type) + if db_ca is None: + return api_response( + success=False, + message=( + "No active Certificate Authority is configured for your organization. " + "An admin must generate a CA on the Certificate Authorities page before " + "certificates can be issued." + ), + status=503, + error_type="CA_NOT_CONFIGURED", + ) + + # Determine if the caller is an org admin/owner (admins can always choose expiry) + is_org_admin = any( + om.role in (OrganizationRole.ADMIN, OrganizationRole.OWNER) + for om in memberships + if om.organization and om.organization.deleted_at is None + ) + + # ── Apply department certificate policy ─────────────────────────────────── + dept_policy = _get_merged_dept_cert_policy(user_id) + if dept_policy: + if is_org_admin: + # Admins can always choose their own expiry, but still capped at dept max + if expiry_hours is not None: + expiry_hours = min(int(expiry_hours), dept_policy["max_expiry_hours"]) + elif not dept_policy["allow_user_expiry"]: + # Regular members: ignore user-requested expiry; use dept default + expiry_hours = dept_policy["default_expiry_hours"] + else: + # Regular members allowed to pick, cap at dept maximum + if expiry_hours is not None: + expiry_hours = min(int(expiry_hours), dept_policy["max_expiry_hours"]) + policy_extensions = dept_policy["extensions"] + else: + policy_extensions = None # let signing service use its own defaults + + # ── Build rich key_id identity for the OpenSSH cert ───────────────────── + # This appears in `ssh-keygen -L -f cert.pub` as the Key ID field and + # is stored in the DB cert record so it's auditable. + org_slugs = sorted({ + om.organization.slug + for om in memberships + if om.organization and om.organization.deleted_at is None + and getattr(om.organization, 'slug', None) + }) + org_slug = org_slugs[0] if org_slugs else "unknown" + full_name = getattr(user, 'full_name', None) or getattr(user, 'name', None) or "unknown" + cert_identity = f"{user.email} ({full_name}) [org:{org_slug}]" + + signing_request = SSHCertificateSigningRequest( + ssh_public_key=ssh_key.payload, + principals=principals, + cert_type=cert_type, + key_id=cert_identity, + expiry_hours=int(expiry_hours) if expiry_hours else None, + extensions=policy_extensions, + ) + validation_errors = signing_request.validate() + if validation_errors: + return api_response( + success=False, + message="Invalid signing request", + status=400, + error_type="VALIDATION_ERROR", + error_details={"errors": validation_errors}, + ) + + try: + from gatehouse_app.utils.ca_key_encryption import decrypt_ca_key + ca_private_key_pem = decrypt_ca_key(db_ca.private_key) + response = ssh_ca_service.sign_certificate( + signing_request, ca_private_key=ca_private_key_pem, ca_obj=db_ca + ) + except SSHCertificateError as e: + AuditLog.log( + action=AuditAction.SSH_CERT_FAILED, + user_id=user_id, + resource_type='SSHCertificate', + ip_address=request.remote_addr, + success=False, + error_message=str(e), + ) + return api_response(success=False, message=str(e), status=400, error_type="SIGNING_FAILED") + except Exception as e: + AuditLog.log( + action=AuditAction.SSH_CERT_FAILED, + user_id=user_id, + resource_type='SSHCertificate', + ip_address=request.remote_addr, + success=False, + error_message=str(e), + ) + return api_response(success=False, message="Certificate signing failed", status=500, error_type="SERVER_ERROR") + + cert_record = _persist_certificate( + user_id=user_id, + ssh_key_id=key_id, + ca=db_ca, + signing_response=response, + request_ip=request.remote_addr, + cert_type_str=cert_type, + cert_identity=cert_identity, + ) + + AuditLog.log( + action=AuditAction.SSH_CERT_ISSUED, + user_id=user_id, + resource_type='SSHCertificate', + resource_id=cert_record.id if cert_record else key_id, + ip_address=request.remote_addr, + description=( + f'Certificate serial={response.serial} issued for {user.email}; ' + f'principals: {", ".join(principals)}' + ), + extra_data={ + 'serial': response.serial, + 'key_id': cert_identity, + 'principals': principals, + 'ca_id': str(db_ca.id), + 'ssh_key_id': str(key_id), + }, + ) + + if cert_record: + CertificateAuditLog.log( + certificate_id=cert_record.id, + action='issued', + user_id=user_id, + ip_address=request.remote_addr, + user_agent=request.headers.get('User-Agent'), + message=( + f'Certificate serial={response.serial} issued for {user.email}; ' + f'principals: {", ".join(principals)}' + ), + extra_data={ + 'serial': response.serial, + 'key_id': cert_identity, + 'principals': principals, + 'ca_id': str(db_ca.id), + 'ssh_key_id': str(key_id), + 'valid_after': response.valid_after.isoformat() if response.valid_after else None, + 'valid_before': response.valid_before.isoformat() if response.valid_before else None, + }, + success=True, + ) + + result = { + 'certificate': response.certificate, + 'serial': response.serial, + 'principals': response.principals, + 'valid_after': response.valid_after.isoformat() if response.valid_after else None, + 'valid_before': response.valid_before.isoformat() if response.valid_before else None, + } + if cert_record: + result['cert_id'] = str(cert_record.id) + + return api_response(data=result, message="Certificate signed successfully", status=201) + + +@ssh_bp.route('/certificates', methods=['GET']) +@login_required +def list_certificates(): + """List all SSH certificates issued for the current user.""" + user_id = g.current_user.id + + try: + from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate + certs = ( + SSHCertificate.query + .filter_by(user_id=user_id, deleted_at=None) + .order_by(SSHCertificate.created_at.desc()) + .all() + ) + return api_response( + data={ + 'certificates': [c.to_dict() for c in certs], + 'count': len(certs), + }, + message="Certificates retrieved successfully" + ) + except Exception as e: + return api_response( + success=False, + message=str(e), + status=500, + error_type='INTERNAL_ERROR' + ) + + +@ssh_bp.route('/certificates/', methods=['GET']) +@login_required +def get_certificate(cert_id): + """Get a specific issued certificate (metadata only).""" + user_id = g.current_user.id + + try: + from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate + cert = SSHCertificate.query.filter_by(id=cert_id, deleted_at=None).first() + if not cert: + return api_response(success=False, message='Certificate not found', status=404, error_type='NOT_FOUND') + if cert.user_id != user_id: + return api_response(success=False, message='Forbidden', status=403, error_type='FORBIDDEN') + data = cert.to_dict() + data['certificate'] = cert.certificate + return api_response(success=True, message='Certificate retrieved', data=data, status=200) + except Exception as e: + return api_response(success=False, message=str(e), status=500, error_type='INTERNAL_ERROR') + + +@ssh_bp.route('/certificates//revoke', methods=['POST']) +@login_required +def revoke_certificate(cert_id): + """Revoke an issued certificate.""" + user_id = g.current_user.id + + data = request.get_json() or {} + reason = data.get('reason', 'User requested revocation') + + try: + from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate + cert = SSHCertificate.query.filter_by(id=cert_id, deleted_at=None).first() + if not cert: + return api_response(success=False, message='Certificate not found', status=404, error_type='NOT_FOUND') + if cert.user_id != user_id: + return api_response(success=False, message='Forbidden', status=403, error_type='FORBIDDEN') + if cert.revoked: + return api_response(success=False, message='Certificate is already revoked', status=409, error_type='ALREADY_REVOKED') + + cert.revoke(reason=reason) + + AuditLog.log( + action=AuditAction.SSH_CERT_REVOKED, + user_id=user_id, + resource_type='SSHCertificate', + resource_id=cert_id, + ip_address=request.remote_addr, + description=f'Revoked: {reason}', + ) + + CertificateAuditLog.log( + certificate_id=cert_id, + action='revoked', + user_id=user_id, + ip_address=request.remote_addr, + user_agent=request.headers.get('User-Agent'), + message=f'Certificate revoked: {reason}', + success=True, + ) + + return api_response( + success=True, + message='Certificate revoked successfully', + data={'status': 'revoked', 'cert_id': cert_id, 'reason': reason}, + status=200, + ) + except Exception as e: + return api_response(success=False, message=str(e), status=500, error_type='INTERNAL_ERROR') + + +@ssh_bp.route('/ca/public-key', methods=['GET']) +@login_required +def get_ca_public_key(): + """ + Return the CA public key for this user's organization. + + Server admins should add this key to their host's ``TrustedUserCAKeys`` + directive so that certificates issued by gatehouse are trusted. + + Query parameters: + ca_type: 'user' (default) or 'host' — which CA's public key to return + format: 'openssh' (default) or 'text' — affects Content-Type only + + Returns: + { "public_key": "ssh-ed25519 AAAA...", + "fingerprint": "SHA256:...", + "ca_name": "..." } + """ + user = g.current_user + ca_type = request.args.get("ca_type", "user") + if ca_type not in ("user", "host"): + return api_response( + success=False, + message="ca_type must be 'user' or 'host'", + status=400, + error_type="BAD_REQUEST", + ) + + db_ca = _get_org_ca_for_user(user, ca_type=ca_type) + if db_ca: + return api_response( + data={ + 'public_key': db_ca.public_key, + 'fingerprint': db_ca.fingerprint, + 'ca_name': db_ca.name, + 'ca_type': ca_type, + 'source': 'db', + }, + message="CA public key retrieved successfully" + ) + + return api_response( + success=False, + message=( + f"No {ca_type} CA is configured for your organization. " + "An admin must generate one on the Certificate Authorities page." + ), + status=404, + error_type="CA_NOT_CONFIGURED", + ) + + +# --------------------------------------------------------------------------- +# CA Permissions +# --------------------------------------------------------------------------- + +@ssh_bp.route('/ca//permissions', methods=['GET']) +@login_required +def list_ca_permissions(ca_id): + """List permissions for a Certificate Authority. + + Returns: + 200: { ca_id, permissions: [...], open_to_all: bool } + 403: Not admin/owner + 404: CA not found + """ + from gatehouse_app.models.ssh_ca.ca import CA, CAPermission + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.utils.constants import OrganizationRole + + user = g.current_user + + ca = CA.query.filter_by(id=ca_id, deleted_at=None).first() + if not ca: + return api_response(success=False, message="CA not found", status=404, error_type="NOT_FOUND") + + # Verify user is admin/owner of the CA's org + if ca.organization_id: + membership = OrganizationMember.query.filter_by( + organization_id=ca.organization_id, + user_id=user.id, + deleted_at=None, + ).first() + if not membership or membership.role not in (OrganizationRole.ADMIN, OrganizationRole.OWNER): + return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN") + + perms = CAPermission.query.filter_by(ca_id=ca_id, deleted_at=None).all() + perm_list = [] + for p in perms: + d = p.to_dict() + d["user_email"] = p.user.email if p.user else None + perm_list.append(d) + + return api_response( + data={ + "ca_id": ca_id, + "permissions": perm_list, + "open_to_all": len(perms) == 0, + }, + message="CA permissions retrieved", + ) + + +@ssh_bp.route('/ca//permissions', methods=['POST']) +@login_required +def add_ca_permission(ca_id): + """Grant a user permission on a Certificate Authority. + + Request body: + user_id: UUID of the user to grant access + permission: "sign" or "admin" (default: "sign") + + Returns: + 201: Permission granted + 400: Validation error + 403: Not admin/owner + 404: CA or user not found + 409: Permission already exists + """ + from gatehouse_app.models.ssh_ca.ca import CA, CAPermission + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.models.user import User + from gatehouse_app.utils.constants import OrganizationRole, AuditAction + from gatehouse_app.models import AuditLog + from gatehouse_app.extensions import db + + user = g.current_user + + ca = CA.query.filter_by(id=ca_id, deleted_at=None).first() + if not ca: + return api_response(success=False, message="CA not found", status=404, error_type="NOT_FOUND") + + # Verify user is admin/owner of the CA's org + if ca.organization_id: + membership = OrganizationMember.query.filter_by( + organization_id=ca.organization_id, + user_id=user.id, + deleted_at=None, + ).first() + if not membership or membership.role not in (OrganizationRole.ADMIN, OrganizationRole.OWNER): + return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN") + + data = request.get_json() or {} + target_user_id = (data.get("user_id") or "").strip() + permission = data.get("permission", "sign") + + if not target_user_id: + return api_response(success=False, message="user_id is required", status=400, error_type="VALIDATION_ERROR") + if permission not in ("sign", "admin"): + return api_response( + success=False, + message="permission must be 'sign' or 'admin'", + status=400, + error_type="VALIDATION_ERROR", + ) + + target_user = User.query.filter_by(id=target_user_id, deleted_at=None).first() + if not target_user: + return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") + + # Check for duplicate + existing = CAPermission.query.filter_by( + ca_id=ca_id, user_id=target_user_id, deleted_at=None + ).first() + if existing: + # Update permission level if different + if existing.permission != permission: + existing.permission = permission + db.session.commit() + d = existing.to_dict() + d["user_email"] = target_user.email + return api_response( + data={"message": "Permission updated", "permission": d}, + message="Permission updated", + ) + return api_response( + success=False, + message="User already has this permission on the CA", + status=409, + error_type="DUPLICATE", + ) + + perm = CAPermission( + ca_id=ca_id, + user_id=target_user_id, + permission=permission, + ) + db.session.add(perm) + db.session.commit() + + AuditLog.log( + action=AuditAction.CA_UPDATED, + user_id=user.id, + resource_type="CAPermission", + resource_id=perm.id, + ip_address=request.remote_addr, + description=f"Granted '{permission}' on CA '{ca.name}' to user {target_user.email}", + ) + + d = perm.to_dict() + d["user_email"] = target_user.email + return api_response( + data={"message": "Permission granted", "permission": d}, + message="Permission granted", + status=201, + ) + + +@ssh_bp.route('/ca//permissions/', methods=['DELETE']) +@login_required +def remove_ca_permission(ca_id, target_user_id): + """Revoke a user's permission on a Certificate Authority. + + Returns: + 200: Permission revoked + 403: Not admin/owner + 404: CA or permission not found + """ + from gatehouse_app.models.ssh_ca.ca import CA, CAPermission + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.utils.constants import OrganizationRole, AuditAction + from gatehouse_app.models import AuditLog + from gatehouse_app.extensions import db + + user = g.current_user + + ca = CA.query.filter_by(id=ca_id, deleted_at=None).first() + if not ca: + return api_response(success=False, message="CA not found", status=404, error_type="NOT_FOUND") + + # Verify user is admin/owner of the CA's org + if ca.organization_id: + membership = OrganizationMember.query.filter_by( + organization_id=ca.organization_id, + user_id=user.id, + deleted_at=None, + ).first() + if not membership or membership.role not in (OrganizationRole.ADMIN, OrganizationRole.OWNER): + return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN") + + perm = CAPermission.query.filter_by( + ca_id=ca_id, user_id=target_user_id, deleted_at=None + ).first() + if not perm: + return api_response(success=False, message="Permission not found", status=404, error_type="NOT_FOUND") + + perm.delete(soft=True) + + AuditLog.log( + action=AuditAction.CA_UPDATED, + user_id=user.id, + resource_type="CAPermission", + resource_id=perm.id, + ip_address=request.remote_addr, + description=f"Revoked permission on CA '{ca.name}' from user {target_user_id}", + ) + + return api_response( + data={}, + message="Permission revoked", + ) + diff --git a/gatehouse_app/api/v1/users.py b/gatehouse_app/api/v1/users.py index 407e373..65c42c3 100644 --- a/gatehouse_app/api/v1/users.py +++ b/gatehouse_app/api/v1/users.py @@ -73,11 +73,51 @@ def delete_me(): """ Delete current user account (soft delete). + Blocked if the user is the sole owner of any organization that has other + active members — they must transfer ownership or dissolve those organizations + first. + Returns: 200: Account deleted successfully 401: Not authenticated + 409: User is sole owner of one or more organizations with other members """ - UserService.delete_user(g.current_user, soft=True) + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.utils.constants import OrganizationRole + + user = g.current_user + + # Find orgs where this user is the sole owner AND other members exist. + owned_memberships = OrganizationMember.query.filter_by( + user_id=user.id, + role=OrganizationRole.OWNER, + deleted_at=None, + ).all() + + blocked_orgs = [] + for membership in owned_memberships: + org = membership.organization + if org.deleted_at is not None: + continue + member_count = org.get_member_count() + if member_count > 1: + blocked_orgs.append(org.name) + + if blocked_orgs: + names = ", ".join(f'"{n}"' for n in blocked_orgs) + return api_response( + success=False, + message=( + f"You are the sole owner of {len(blocked_orgs)} organization" + f"{'s' if len(blocked_orgs) > 1 else ''}: {names}. " + "Transfer ownership or delete those organizations before deleting your account." + ), + status=409, + error_type="USER_IS_SOLE_OWNER", + error_details={"organizations": blocked_orgs}, + ) + + UserService.delete_user(user, soft=True) return api_response( message="Account deleted successfully", @@ -142,18 +182,686 @@ def change_password(): @full_access_required def get_my_organizations(): """ - Get all organizations current user is a member of. + Get all organizations current user is a member of, including the user's role. Returns: - 200: List of organizations + 200: List of organizations with role 401: Not authenticated """ - organizations = UserService.get_user_organizations(g.current_user) + from gatehouse_app.models.organization.organization_member import OrganizationMember + + user = g.current_user + memberships = OrganizationMember.query.filter_by( + user_id=user.id, + deleted_at=None, + ).all() + + orgs = [] + for membership in memberships: + org = membership.organization + if not org or org.deleted_at is not None: + continue + org_dict = org.to_dict() + org_dict["role"] = membership.role.value if hasattr(membership.role, "value") else str(membership.role) + orgs.append(org_dict) return api_response( data={ - "organizations": [org.to_dict() for org in organizations], - "count": len(organizations), + "organizations": orgs, + "count": len(orgs), }, message="Organizations retrieved successfully", ) + + +@api_v1_bp.route("/users/me/principals", methods=["GET"]) +@login_required +@full_access_required +def get_my_principals(): + """Return all principals the current user can sign certificates for. + + For each organization the user belongs to, returns: + - Their effective principals (direct membership + via department) + - Their role in that org (so the frontend can offer admin-mode selection) + - All principals in the org (admin/owner only — so they can pick any) + + Returns: + 200: { + orgs: [{ + org_id, org_name, role, + my_principals: [{id, name, description}], + all_principals: [{id, name, description}] # populated for admin/owner only + }] + } + """ + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.models.organization.principal import Principal, PrincipalMembership + from gatehouse_app.models.organization.department import DepartmentMembership, DepartmentPrincipal + from gatehouse_app.utils.constants import OrganizationRole + + user = g.current_user + user_id = user.id + + # Get all org memberships + memberships = OrganizationMember.query.filter_by( + user_id=user_id, + ).all() + + orgs_result = [] + for membership in memberships: + org = membership.organization + if not org or org.deleted_at is not None: + continue + + role = membership.role + is_admin = role in (OrganizationRole.ADMIN, OrganizationRole.OWNER) + + # Collect the user's effective principals for this org + # Track direct vs via-department separately + direct_principal_ids = set() + via_dept_principal_ids = set() + + # Direct memberships + direct = PrincipalMembership.query.filter_by( + user_id=user_id, + deleted_at=None, + ).all() + for pm in direct: + if pm.principal and pm.principal.organization_id == org.id and pm.principal.deleted_at is None: + direct_principal_ids.add(pm.principal_id) + + # Via department + dept_memberships = DepartmentMembership.query.filter_by( + user_id=user_id, + deleted_at=None, + ).all() + for dm in dept_memberships: + if dm.department and dm.department.organization_id == org.id and dm.department.deleted_at is None: + dept_principals = DepartmentPrincipal.query.filter_by( + department_id=dm.department_id, + deleted_at=None, + ).all() + for dp in dept_principals: + if dp.principal and dp.principal.deleted_at is None: + via_dept_principal_ids.add(dp.principal_id) + + effective_principal_ids = direct_principal_ids | via_dept_principal_ids + + # Fetch principal objects + my_principals = [] + if effective_principal_ids: + my_p = Principal.query.filter( + Principal.id.in_(list(effective_principal_ids)), + Principal.deleted_at == None, + ).all() + my_principals = [ + { + "id": p.id, + "name": p.name, + "description": p.description, + # direct=True means removable via API; False=inherited via department + "direct": p.id in direct_principal_ids, + } + for p in my_p + ] + + # For admins/owners: also return all principals in the org + all_principals = [] + if is_admin: + all_p = Principal.query.filter_by( + organization_id=org.id, + deleted_at=None, + ).all() + all_principals = [{"id": p.id, "name": p.name, "description": p.description} for p in all_p] + + orgs_result.append({ + "org_id": org.id, + "org_name": org.name, + "role": role.value if hasattr(role, "value") else role, + "is_admin": is_admin, + "my_principals": my_principals, + "all_principals": all_principals, + }) + + return api_response( + data={"orgs": orgs_result}, + message="Principals retrieved successfully", + ) + + +@api_v1_bp.route("/admin/users", methods=["GET"]) +@login_required +@full_access_required +def admin_list_users(): + """List all users the caller has admin rights to see. + + The caller must be an OWNER or ADMIN of at least one organization. + Returns users that share an organization with the caller and where the + caller holds admin/owner role in that organization. + + Query params: + q – optional search string (matched against name/email) + page – page number (default 1) + per_page – page size (default 50, max 200) + """ + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.models.user.user import User as _User + from gatehouse_app.extensions import db as _db + from sqlalchemy import or_ + + caller = g.current_user + + # Find orgs where caller is admin/owner + admin_memberships = OrganizationMember.query.filter( + OrganizationMember.user_id == caller.id, + OrganizationMember.role.in_(["OWNER", "ADMIN"]), + OrganizationMember.deleted_at == None, + ).all() + + if not admin_memberships: + return api_response( + success=False, + message="Admin or owner role required", + status=403, + error_type="AUTHORIZATION_ERROR", + ) + + admin_org_ids = [m.organization_id for m in admin_memberships] + + # Collect user IDs in those orgs + member_rows = OrganizationMember.query.filter( + OrganizationMember.organization_id.in_(admin_org_ids), + OrganizationMember.deleted_at == None, + ).all() + visible_user_ids = list({row.user_id for row in member_rows}) + + # Optional search + q = request.args.get("q", "").strip() + try: + page = max(1, int(request.args.get("page", 1))) + per_page = min(200, max(1, int(request.args.get("per_page", 50)))) + except ValueError: + page, per_page = 1, 50 + + query = _User.query.filter( + _User.id.in_(visible_user_ids), + _User.deleted_at == None, + ) + if q: + like = f"%{q}%" + query = query.filter(or_(_User.email.ilike(like), _User.full_name.ilike(like))) + + total = query.count() + users = query.order_by(_User.email).offset((page - 1) * per_page).limit(per_page).all() + + member_lookup: dict = {} + for row in member_rows: + if row.user_id not in member_lookup: + member_lookup[row.user_id] = { + "organization_id": row.organization_id, + "role": row.role.value if hasattr(row.role, "value") else row.role, + } + + users_data = [] + for u in users: + d = u.to_dict() + m = member_lookup.get(u.id, {}) + d["org_role"] = m.get("role", "member") + d["org_id"] = m.get("organization_id") + users_data.append(d) + + return api_response( + data={ + "users": users_data, + "count": total, + "page": page, + "per_page": per_page, + "pages": (total + per_page - 1) // per_page, + }, + message="Users retrieved successfully", + ) + + +@api_v1_bp.route("/admin/users/", methods=["GET"]) +@login_required +@full_access_required +def admin_get_user(user_id): + """Get a single user's profile (admin view with SSH keys).""" + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.models.user.user import User as _User + from gatehouse_app.models.ssh_ca.ssh_key import SSHKey + + caller = g.current_user + + target = _User.query.filter_by(id=user_id, deleted_at=None).first() + if not target: + return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") + + # Verify caller has admin access to a shared org + target_org_ids = {m.organization_id for m in target.organization_memberships if m.deleted_at is None} + has_access = OrganizationMember.query.filter( + OrganizationMember.user_id == caller.id, + OrganizationMember.organization_id.in_(target_org_ids), + OrganizationMember.role.in_(["OWNER", "ADMIN"]), + OrganizationMember.deleted_at == None, + ).first() is not None + + if not has_access: + return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR") + + ssh_keys = SSHKey.query.filter_by(user_id=user_id, deleted_at=None).all() + + return api_response( + data={ + "user": target.to_dict(), + "ssh_keys": [k.to_dict() for k in ssh_keys], + }, + message="User retrieved", + ) + + +@api_v1_bp.route("/admin/users//suspend", methods=["POST"]) +@login_required +@full_access_required +def admin_suspend_user(user_id): + """Suspend a user account (blocks CA issuance and login). + + The caller must be an OWNER or ADMIN of an organization the target user belongs to. + """ + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.models.user.user import User as _User + from gatehouse_app.extensions import db as _db + from gatehouse_app.utils.constants import UserStatus, AuditAction + from gatehouse_app.services.audit_service import AuditService + + caller = g.current_user + target = _User.query.filter_by(id=user_id, deleted_at=None).first() + if not target: + return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") + + if target.id == caller.id: + return api_response(success=False, message="Cannot suspend yourself", status=400, error_type="BAD_REQUEST") + + # Verify caller has admin access to a shared org + target_org_ids = {m.organization_id for m in target.organization_memberships if m.deleted_at is None} + admin_in_shared_org = OrganizationMember.query.filter( + OrganizationMember.user_id == caller.id, + OrganizationMember.organization_id.in_(target_org_ids), + OrganizationMember.role.in_(["OWNER", "ADMIN"]), + OrganizationMember.deleted_at == None, + ).first() + + if not admin_in_shared_org: + return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR") + + # ── Owner protection ────────────────────────────────────────────────────── + # An org owner cannot be suspended until they transfer ownership. + from gatehouse_app.utils.constants import OrganizationRole + owner_memberships = OrganizationMember.query.filter( + OrganizationMember.user_id == target.id, + OrganizationMember.role == OrganizationRole.OWNER, + OrganizationMember.deleted_at == None, + ).all() + if owner_memberships: + org_names = [ + m.organization.name + for m in owner_memberships + if m.organization and not m.organization.deleted_at + ] + return api_response( + success=False, + message=( + f"Cannot suspend an organization owner. " + f"{target.email} is the owner of: {', '.join(org_names)}. " + "Transfer ownership to another member first." + ), + status=403, + error_type="OWNER_PROTECTION", + ) + + if target.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED): + return api_response(success=False, message="User is already suspended", status=409, error_type="CONFLICT") + + target.status = UserStatus.SUSPENDED + _db.session.commit() + + AuditService.log_action( + action=AuditAction.USER_SUSPEND, + user_id=caller.id, + organization_id=admin_in_shared_org.organization_id, + resource_type="user", + resource_id=str(target.id), + description=f"Admin suspended user {target.email}", + metadata={"target_user_id": str(target.id), "target_email": target.email}, + ) + + return api_response(data={"user": target.to_dict()}, message="User suspended successfully") + + +@api_v1_bp.route("/admin/users//unsuspend", methods=["POST"]) +@login_required +@full_access_required +def admin_unsuspend_user(user_id): + """Restore a suspended user account to active status.""" + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.models.user.user import User as _User + from gatehouse_app.extensions import db as _db + from gatehouse_app.utils.constants import UserStatus, AuditAction + from gatehouse_app.services.audit_service import AuditService + + caller = g.current_user + target = _User.query.filter_by(id=user_id, deleted_at=None).first() + if not target: + return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") + + # Verify caller has admin access to a shared org + target_org_ids = {m.organization_id for m in target.organization_memberships if m.deleted_at is None} + admin_in_shared_org = OrganizationMember.query.filter( + OrganizationMember.user_id == caller.id, + OrganizationMember.organization_id.in_(target_org_ids), + OrganizationMember.role.in_(["OWNER", "ADMIN"]), + OrganizationMember.deleted_at == None, + ).first() + + if not admin_in_shared_org: + return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR") + + if target.status not in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED): + return api_response(success=False, message="User is not suspended", status=409, error_type="CONFLICT") + + target.status = UserStatus.ACTIVE + _db.session.commit() + + AuditService.log_action( + action=AuditAction.USER_UNSUSPEND, + user_id=caller.id, + organization_id=admin_in_shared_org.organization_id, + resource_type="user", + resource_id=str(target.id), + description=f"Admin unsuspended user {target.email}", + metadata={"target_user_id": str(target.id), "target_email": target.email}, + ) + + return api_response(data={"user": target.to_dict()}, message="User unsuspended successfully") + + +@api_v1_bp.route("/users/me/invites", methods=["GET"]) +@login_required +def get_my_pending_invites(): + """Return pending (unaccepted, non-expired) invitations for the current user's email.""" + from gatehouse_app.models.organization.org_invite_token import OrgInviteToken + from datetime import datetime, timezone + + user = g.current_user + now = datetime.now(timezone.utc) + + invites = OrgInviteToken.query.filter( + OrgInviteToken.email == user.email, + OrgInviteToken.accepted_at.is_(None), + OrgInviteToken.expires_at > now, + OrgInviteToken.deleted_at.is_(None), + ).all() + + return api_response( + data={ + "invites": [ + { + "token": i.token, + "organization": {"id": str(i.organization_id), "name": i.organization.name}, + "role": i.role, + "expires_at": i.expires_at.isoformat(), + } + for i in invites + ] + }, + message="Pending invitations retrieved", + ) + + +@api_v1_bp.route("/users/me/memberships", methods=["GET"]) +@login_required +def get_my_memberships(): + """Return the current user's department and principal memberships across all orgs. + + Returns: + 200: { + orgs: [{ + org_id, org_name, role, + departments: [{id, name, description}], + principals: [{id, name, description, via_department: bool}] + }] + } + """ + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.models.organization.department import DepartmentMembership, DepartmentPrincipal, Department + from gatehouse_app.models.organization.principal import Principal, PrincipalMembership + + user = g.current_user + + memberships = OrganizationMember.query.filter_by( + user_id=user.id, + deleted_at=None, + ).all() + + orgs_result = [] + for membership in memberships: + org = membership.organization + if not org or org.deleted_at is not None: + continue + + # Departments the user belongs to + dept_memberships = DepartmentMembership.query.filter_by( + user_id=user.id, + deleted_at=None, + ).all() + user_depts = [ + dm.department for dm in dept_memberships + if dm.department + and dm.department.organization_id == org.id + and dm.department.deleted_at is None + ] + + # Principals: direct + direct_pm = PrincipalMembership.query.filter_by( + user_id=user.id, + deleted_at=None, + ).all() + direct_principal_ids = { + pm.principal_id for pm in direct_pm + if pm.principal + and pm.principal.organization_id == org.id + and pm.principal.deleted_at is None + } + + # Principals: via department + via_dept_principal_ids = set() + for dept in user_depts: + for dp in DepartmentPrincipal.query.filter_by(department_id=dept.id, deleted_at=None).all(): + if dp.principal and dp.principal.deleted_at is None: + via_dept_principal_ids.add(dp.principal_id) + + all_principal_ids = direct_principal_ids | via_dept_principal_ids + principals_list = [] + if all_principal_ids: + for p in Principal.query.filter( + Principal.id.in_(list(all_principal_ids)), + Principal.deleted_at == None, + ).all(): + principals_list.append({ + "id": str(p.id), + "name": p.name, + "description": p.description, + "via_department": p.id not in direct_principal_ids, + }) + + role = membership.role + orgs_result.append({ + "org_id": str(org.id), + "org_name": org.name, + "role": role.value if hasattr(role, "value") else role, + "departments": [ + {"id": str(d.id), "name": d.name, "description": d.description} + for d in user_depts + ], + "principals": principals_list, + }) + + return api_response( + data={"orgs": orgs_result}, + message="Memberships retrieved", + ) + + +@api_v1_bp.route("/admin/users//delete", methods=["POST"]) +@login_required +@full_access_required +def admin_hard_delete_user(user_id): + """Permanently delete a user and ALL associated data (hard delete, irreversible). + + Required body: {"confirm": true} + + Pre-conditions: + - Caller is OWNER or ADMIN of a shared org with the target. + - Cannot delete yourself. + - Target must not be the OWNER of any active organization (transfer first). + + Side-effects: + - All active SSH certificates are revoked before deletion. + - The user row and all cascaded rows are hard-deleted from the database. + - An audit log entry is written by the *caller* (so it is not lost with the user). + """ + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.models.user.user import User as _User + from gatehouse_app.extensions import db as _db + from gatehouse_app.utils.constants import UserStatus, AuditAction, OrganizationRole + from gatehouse_app.services.audit_service import AuditService + + caller = g.current_user + data = request.get_json() or {} + + if not data.get("confirm"): + return api_response( + success=False, + message="Deletion requires explicit confirmation. Send {\"confirm\": true} to proceed.", + status=400, + error_type="CONFIRMATION_REQUIRED", + ) + + target = _User.query.filter_by(id=user_id).first() + if not target: + return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") + + if target.id == caller.id: + return api_response( + success=False, + message="Cannot delete your own account via this endpoint.", + status=400, + error_type="BAD_REQUEST", + ) + + # Caller must be OWNER/ADMIN of a shared org. + # Include soft-deleted memberships so that already-soft-deleted users can + # still be hard-deleted by an admin who shared an org with them. + target_org_ids = {m.organization_id for m in target.organization_memberships} + admin_in_shared_org = OrganizationMember.query.filter( + OrganizationMember.user_id == caller.id, + OrganizationMember.organization_id.in_(target_org_ids), + OrganizationMember.role.in_(["OWNER", "ADMIN"]), + OrganizationMember.deleted_at == None, + ).first() + if not admin_in_shared_org: + return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR") + + # Block deletion if target is an org owner — they must transfer first + owner_memberships = OrganizationMember.query.filter( + OrganizationMember.user_id == target.id, + OrganizationMember.role == OrganizationRole.OWNER, + OrganizationMember.deleted_at == None, + ).all() + if owner_memberships: + org_names = [ + m.organization.name + for m in owner_memberships + if m.organization and not m.organization.deleted_at + ] + return api_response( + success=False, + message=( + f"Cannot delete an organization owner. " + f"{target.email} is the owner of: {', '.join(org_names)}. " + "Transfer ownership to another member first." + ), + status=403, + error_type="OWNER_PROTECTION", + ) + + # ── Collect counts for audit metadata ──────────────────────────────────── + from gatehouse_app.models.ssh_ca.ssh_key import SSHKey + from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus + + ssh_key_count = SSHKey.query.filter_by(user_id=target.id, deleted_at=None).count() + active_cert_count = SSHCertificate.query.filter_by( + user_id=target.id, revoked=False + ).filter(SSHCertificate.deleted_at == None).count() + + # ── Revoke all active SSH certificates before deletion ─────────────────── + active_certs = SSHCertificate.query.filter_by( + user_id=target.id, revoked=False + ).filter(SSHCertificate.deleted_at == None).all() + for cert in active_certs: + try: + cert.revoke("account_deleted") + except Exception: + pass + + if active_certs: + try: + _db.session.flush() + except Exception: + pass + + # ── Hard delete ─────────────────────────────────────────────────────────── + target_email = target.email # capture before deletion + target_id_str = str(target.id) + + try: + _db.session.delete(target) # cascades to all child tables + _db.session.flush() + except Exception as exc: + _db.session.rollback() + import logging + logging.getLogger(__name__).error(f"Hard delete failed for {target_id_str}: {exc}") + return api_response( + success=False, + message="Failed to delete user account. Please try again.", + status=500, + error_type="SERVER_ERROR", + ) + + # ── Audit log (written as the caller so it survives the deletion) ───────── + AuditService.log_action( + action=AuditAction.USER_HARD_DELETE, + user_id=caller.id, + organization_id=admin_in_shared_org.organization_id, + resource_type="user", + resource_id=target_id_str, + description=f"Admin permanently deleted user account: {target_email}", + metadata={ + "deleted_user_id": target_id_str, + "deleted_user_email": target_email, + "ssh_keys_deleted": ssh_key_count, + "certs_revoked": active_cert_count, + }, + ) + + _db.session.commit() + + return api_response( + message=f"User account {target_email} has been permanently deleted.", + data={ + "deleted_user_id": target_id_str, + "deleted_user_email": target_email, + "ssh_keys_deleted": ssh_key_count, + "certs_revoked": active_cert_count, + }, + ) diff --git a/gatehouse_app/config/ssh_ca_config.py b/gatehouse_app/config/ssh_ca_config.py new file mode 100644 index 0000000..7b62a96 --- /dev/null +++ b/gatehouse_app/config/ssh_ca_config.py @@ -0,0 +1,234 @@ +"""SSH CA Configuration Manager. + +Handles loading and managing SSH CA configuration from etc/ssh_ca.conf +and environment variables. +""" +import os +import configparser +from pathlib import Path +from typing import Optional, Union + + +class SSHCAConfig: + """Configuration manager for SSH CA settings. + + Loads configuration from: + 1. etc/ssh_ca.conf file + 2. Environment variables (override config file) + 3. Application environment-specific defaults + + Example: + config = SSHCAConfig() + cert_hours = config.get_int('cert_validity_hours') + key_path = config.get_str('ca_key_path') + """ + + # Configuration file location (relative to project root) + DEFAULT_CONFIG_FILE = "etc/ssh_ca.conf" + + # Default values if config file is missing + DEFAULTS = { + 'cert_validity_hours': '8', + 'max_cert_validity_hours': '720', + 'ca_key_path': '', + 'max_principals_per_cert': '256', + 'max_key_id_length': '255', + 'verification_challenge_max_age': '24', + 'auto_delete_unverified_days': '30', + } + + def __init__(self, config_file: Optional[str] = None, environment: Optional[str] = None): + """Initialize SSH CA configuration. + + Args: + config_file: Path to config file (default: etc/ssh_ca.conf) + environment: Environment name (development, production, testing) + Default: value of FLASK_ENV or 'development' + """ + self.config = configparser.ConfigParser() + + # Determine environment + if environment is None: + environment = os.environ.get('FLASK_ENV', 'development') + self.environment = environment + + # Load config file + if config_file is None: + # Try to find config file relative to this module + module_dir = Path(__file__).parent.parent.parent + config_file = module_dir / self.DEFAULT_CONFIG_FILE + + self.config_file = config_file + self._load_config() + + def _load_config(self): + """Load configuration from file and apply environment-specific overrides.""" + # Set defaults + self.config['default'] = self.DEFAULTS.copy() + + # Load config file if it exists + if Path(self.config_file).exists(): + self.config.read(self.config_file) + + # Apply environment-specific configuration + if self.environment in self.config: + for key, value in self.config[self.environment].items(): + self.config['default'][key] = value + + def get_str(self, key: str, default: Optional[str] = None) -> str: + """Get a string configuration value. + + First checks environment variables (SSH_CA_), then config file. + + Args: + key: Configuration key + default: Default value if not found + + Returns: + Configuration value as string + """ + env_key = f"SSH_CA_{key.upper()}" + + # Check environment variable first + if env_key in os.environ: + return os.environ[env_key] + + # Check config file + if key in self.config['default']: + value = self.config['default'][key] + # Handle environment variable substitution + return os.path.expandvars(value) + + # Return default + if default is not None: + return default + + return self.DEFAULTS.get(key, '') + + def get_int(self, key: str, default: Optional[int] = None) -> int: + """Get an integer configuration value. + + Args: + key: Configuration key + default: Default value if not found + + Returns: + Configuration value as integer + + Raises: + ValueError: If value cannot be converted to integer + """ + str_value = self.get_str(key) + if not str_value: + if default is not None: + return default + raise ValueError(f"No value found for {key}") + + try: + return int(str_value) + except ValueError: + if default is not None: + return default + raise ValueError(f"Configuration {key}={str_value} is not a valid integer") + + def get_bool(self, key: str, default: Optional[bool] = None) -> bool: + """Get a boolean configuration value. + + Args: + key: Configuration key + default: Default value if not found + + Returns: + Configuration value as boolean + """ + str_value = self.get_str(key) + if not str_value: + if default is not None: + return default + return False + + return str_value.lower() in ('true', '1', 'yes', 'on') + + def get_list(self, key: str, delimiter: str = ',', default: Optional[list] = None) -> list: + """Get a comma-separated list configuration value. + + Args: + key: Configuration key + delimiter: Delimiter between items (default: comma) + default: Default value if not found + + Returns: + Configuration value as list of strings + """ + str_value = self.get_str(key) + if not str_value: + if default is not None: + return default + return [] + + return [item.strip() for item in str_value.split(delimiter) if item.strip()] + + def validate_config(self) -> list: + """Validate SSH CA configuration. + + Returns: + List of validation error messages (empty if valid) + """ + errors = [] + + # Check cert validity hours + try: + validity = self.get_int('cert_validity_hours') + max_validity = self.get_int('max_cert_validity_hours') + if validity > max_validity: + errors.append( + f"cert_validity_hours ({validity}) > max_cert_validity_hours ({max_validity})" + ) + except ValueError as e: + errors.append(f"Invalid cert validity hours: {e}") + + # Check principals limit + max_principals = self.get_int('max_principals_per_cert') + if max_principals > 256: + errors.append(f"max_principals_per_cert ({max_principals}) exceeds SSH limit of 256") + + # Check ca_key_path is set + if not self.get_str('ca_key_path', '').strip(): + errors.append("ca_key_path is not set") + + return errors + + def to_dict(self) -> dict: + """Export current configuration as dictionary. + """ + return dict(self.config['default']) + + def __repr__(self): + """String representation of configuration.""" + return f"" + + +# Global configuration instance +_config_instance = None + + +def get_ssh_ca_config() -> SSHCAConfig: + """Get the global SSH CA configuration instance. + + This function uses a singleton pattern to ensure only one + configuration instance is created and reused. + + Returns: + SSHCAConfig instance + """ + global _config_instance + if _config_instance is None: + _config_instance = SSHCAConfig() + return _config_instance + + +def reset_config_instance(): + """Reset the global configuration instance. + """ + global _config_instance + _config_instance = None diff --git a/gatehouse_app/exceptions/__init__.py b/gatehouse_app/exceptions/__init__.py index bd8e6f8..06c3f34 100644 --- a/gatehouse_app/exceptions/__init__.py +++ b/gatehouse_app/exceptions/__init__.py @@ -19,6 +19,21 @@ from gatehouse_app.exceptions.validation_exceptions import ( OrganizationNotFoundError, UserNotFoundError, ) +from gatehouse_app.exceptions.ssh_exceptions import ( + SSHCAError, + SSHKeyError, + SSHKeyNotFoundError, + SSHKeyAlreadyExistsError, + SSHKeyNotVerifiedError, + SSHCertificateError, + SSHCertificateNotFoundError, + CAError, + CANotFoundError, + PrincipalError, + PrincipalNotFoundError, + DepartmentError, + DepartmentNotFoundError, +) __all__ = [ "BaseAPIException", @@ -37,4 +52,18 @@ __all__ = [ "EmailAlreadyExistsError", "OrganizationNotFoundError", "UserNotFoundError", + "SSHCAError", + "SSHKeyError", + "SSHKeyNotFoundError", + "SSHKeyAlreadyExistsError", + "SSHKeyNotVerifiedError", + "SSHCertificateError", + "SSHCertificateNotFoundError", + "CAError", + "CANotFoundError", + "PrincipalError", + "PrincipalNotFoundError", + "DepartmentError", + "DepartmentNotFoundError", ] + diff --git a/gatehouse_app/exceptions/base.py b/gatehouse_app/exceptions/base.py index 3f1f42b..f7cfb0e 100644 --- a/gatehouse_app/exceptions/base.py +++ b/gatehouse_app/exceptions/base.py @@ -16,9 +16,10 @@ class BaseAPIException(Exception): message: Custom error message error_details: Additional error details dictionary """ - super().__init__() + super().__init__(self.message) if message: self.message = message + super().__init__(message) # update args so str(e) works self.error_details = error_details or {} def to_dict(self): diff --git a/gatehouse_app/exceptions/ssh_exceptions.py b/gatehouse_app/exceptions/ssh_exceptions.py new file mode 100644 index 0000000..141c1e9 --- /dev/null +++ b/gatehouse_app/exceptions/ssh_exceptions.py @@ -0,0 +1,93 @@ +"""SSH-specific exceptions.""" +from gatehouse_app.exceptions.base import BaseAPIException + + +class SSHCAError(BaseAPIException): + """Base exception for SSH CA operations.""" + + status_code = 500 + error_type = "SSH_CA_ERROR" + + +class SSHKeyError(BaseAPIException): + """Exception for SSH key operations.""" + + status_code = 400 + error_type = "SSH_KEY_ERROR" + + +class SSHKeyNotFoundError(BaseAPIException): + """SSH key not found.""" + + status_code = 404 + error_type = "SSH_KEY_NOT_FOUND" + + +class SSHKeyAlreadyExistsError(BaseAPIException): + """SSH key already exists (duplicate fingerprint).""" + + status_code = 409 + error_type = "SSH_KEY_ALREADY_EXISTS" + + +class SSHKeyNotVerifiedError(BaseAPIException): + """SSH key has not been verified.""" + + status_code = 400 + error_type = "SSH_KEY_NOT_VERIFIED" + + +class SSHCertificateError(BaseAPIException): + """Exception for SSH certificate operations.""" + + status_code = 400 + error_type = "SSH_CERT_ERROR" + + +class SSHCertificateNotFoundError(BaseAPIException): + """SSH certificate not found.""" + + status_code = 404 + error_type = "SSH_CERT_NOT_FOUND" + + +class CAError(BaseAPIException): + """Exception for Certificate Authority operations.""" + + status_code = 400 + error_type = "CA_ERROR" + + +class CANotFoundError(BaseAPIException): + """Certificate Authority not found.""" + + status_code = 404 + error_type = "CA_NOT_FOUND" + + +class PrincipalError(BaseAPIException): + """Exception for principal operations.""" + + status_code = 400 + error_type = "PRINCIPAL_ERROR" + + +class PrincipalNotFoundError(BaseAPIException): + """Principal not found.""" + + status_code = 404 + error_type = "PRINCIPAL_NOT_FOUND" + + +class DepartmentError(BaseAPIException): + """Exception for department operations.""" + + status_code = 400 + error_type = "DEPARTMENT_ERROR" + + +class DepartmentNotFoundError(BaseAPIException): + """Department not found.""" + + status_code = 404 + error_type = "DEPARTMENT_NOT_FOUND" diff --git a/gatehouse_app/middleware/security_headers.py b/gatehouse_app/middleware/security_headers.py index eb1b6a4..ea03dc4 100644 --- a/gatehouse_app/middleware/security_headers.py +++ b/gatehouse_app/middleware/security_headers.py @@ -1,5 +1,6 @@ """Security headers middleware.""" -from flask import request +import os +from flask import current_app, request class SecurityHeadersMiddleware: @@ -34,13 +35,22 @@ class SecurityHeadersMiddleware: ) # Content Security Policy + try: + flask_env = current_app.config.get("ENV") or os.environ.get("FLASK_ENV", "production") + if flask_env == "development": + connect_src = "connect-src 'self' http://localhost:5000 http://127.0.0.1:5000" + else: + connect_src = "connect-src 'self'" + except RuntimeError: + connect_src = "connect-src 'self'" + response.headers["Content-Security-Policy"] = ( "default-src 'self'; " "script-src 'self' 'unsafe-inline'; " "style-src 'self' 'unsafe-inline'; " "img-src 'self' data: https:; " "font-src 'self' data:; " - "connect-src 'self'" + + connect_src ) # Referrer Policy diff --git a/gatehouse_app/models/__init__.py b/gatehouse_app/models/__init__.py index 99ef6fb..f15764c 100644 --- a/gatehouse_app/models/__init__.py +++ b/gatehouse_app/models/__init__.py @@ -1,43 +1,149 @@ -"""Models package.""" -from gatehouse_app.models.base import BaseModel -from gatehouse_app.models.user import User -from gatehouse_app.models.organization import Organization -from gatehouse_app.models.organization_member import OrganizationMember -from gatehouse_app.models.authentication_method import ( +"""Models package. + +Sub-packages +------------ +models.user — User, Session +models.organization — Organization, OrganizationMember, Department, + DepartmentMembership, DepartmentPrincipal, + DepartmentCertPolicy, Principal, PrincipalMembership, + OrgInviteToken +models.auth — AuthenticationMethod, ApplicationProviderConfig, + OrganizationProviderOverride, OAuthState, + AuditLog, PasswordResetToken, EmailVerificationToken +models.oidc — OIDCClient, OIDCAuthCode, OIDCRefreshToken, OIDCSession, + OIDCTokenMetadata, OIDCAuditLog, OidcJwksKey +models.ssh_ca — CA, KeyType, CertType, CaType, CAPermission, + SSHKey, SSHCertificate, CertificateStatus, + CertificateAuditLog +models.security — OrganizationSecurityPolicy, UserSecurityPolicy, + MfaPolicyCompliance + +All names are re-exported here so that existing code using the flat import +style (``from gatehouse_app.models import X``) or the old per-file style +(``from gatehouse_app.models.user import User``) continue to work unchanged. +""" + +# ── Base ────────────────────────────────────────────────────────────────────── +from gatehouse_app.models.base import BaseModel # noqa: F401 + +# ── User ────────────────────────────────────────────────────────────────────── +from gatehouse_app.models.user.user import User # noqa: F401 +from gatehouse_app.models.user.session import Session # noqa: F401 + +# ── Organization ────────────────────────────────────────────────────────────── +from gatehouse_app.models.organization.organization import Organization # noqa: F401 +from gatehouse_app.models.organization.organization_member import ( # noqa: F401 + OrganizationMember, +) +from gatehouse_app.models.organization.department import ( # noqa: F401 + Department, + DepartmentMembership, + DepartmentPrincipal, +) +from gatehouse_app.models.organization.department_cert_policy import ( # noqa: F401 + DepartmentCertPolicy, + STANDARD_EXTENSIONS, +) +from gatehouse_app.models.organization.principal import ( # noqa: F401 + Principal, + PrincipalMembership, +) +from gatehouse_app.models.organization.org_invite_token import OrgInviteToken # noqa: F401 + +# ── Auth ────────────────────────────────────────────────────────────────────── +from gatehouse_app.models.auth.authentication_method import ( # noqa: F401 AuthenticationMethod, ApplicationProviderConfig, OrganizationProviderOverride, OAuthState, ) -from gatehouse_app.models.session import Session -from gatehouse_app.models.audit_log import AuditLog -from gatehouse_app.models.oidc_client import OIDCClient -from gatehouse_app.models.oidc_authorization_code import OIDCAuthCode -from gatehouse_app.models.oidc_refresh_token import OIDCRefreshToken -from gatehouse_app.models.oidc_session import OIDCSession -from gatehouse_app.models.oidc_token_metadata import OIDCTokenMetadata -from gatehouse_app.models.oidc_audit_log import OIDCAuditLog -from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy -from gatehouse_app.models.user_security_policy import UserSecurityPolicy -from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance +from gatehouse_app.models.auth.audit_log import AuditLog # noqa: F401 +from gatehouse_app.models.auth.password_reset_token import PasswordResetToken # noqa: F401 +from gatehouse_app.models.auth.email_verification_token import ( # noqa: F401 + EmailVerificationToken, +) + +# ── OIDC ────────────────────────────────────────────────────────────────────── +from gatehouse_app.models.oidc.oidc_client import OIDCClient # noqa: F401 +from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode # noqa: F401 +from gatehouse_app.models.oidc.oidc_refresh_token import OIDCRefreshToken # noqa: F401 +from gatehouse_app.models.oidc.oidc_session import OIDCSession # noqa: F401 +from gatehouse_app.models.oidc.oidc_token_metadata import OIDCTokenMetadata # noqa: F401 +from gatehouse_app.models.oidc.oidc_audit_log import OIDCAuditLog # noqa: F401 +from gatehouse_app.models.oidc.oidc_jwks_key import OidcJwksKey # noqa: F401 + +# ── SSH / CA ────────────────────────────────────────────────────────────────── +from gatehouse_app.models.ssh_ca.ca import ( # noqa: F401 + CA, + KeyType, + CertType, + CaType, + CAPermission, +) +from gatehouse_app.models.ssh_ca.ssh_key import SSHKey # noqa: F401 +from gatehouse_app.models.ssh_ca.ssh_certificate import ( # noqa: F401 + SSHCertificate, + CertificateStatus, +) +from gatehouse_app.models.ssh_ca.certificate_audit_log import ( # noqa: F401 + CertificateAuditLog, +) + +# ── Security ────────────────────────────────────────────────────────────────── +from gatehouse_app.models.security.organization_security_policy import ( # noqa: F401 + OrganizationSecurityPolicy, +) +from gatehouse_app.models.security.user_security_policy import ( # noqa: F401 + UserSecurityPolicy, +) +from gatehouse_app.models.security.mfa_policy_compliance import ( # noqa: F401 + MfaPolicyCompliance, +) __all__ = [ + # Base "BaseModel", + # User "User", + "Session", + # Organization "Organization", "OrganizationMember", + "Department", + "DepartmentMembership", + "DepartmentPrincipal", + "DepartmentCertPolicy", + "STANDARD_EXTENSIONS", + "Principal", + "PrincipalMembership", + "OrgInviteToken", + # Auth "AuthenticationMethod", "ApplicationProviderConfig", "OrganizationProviderOverride", "OAuthState", - "Session", "AuditLog", + "PasswordResetToken", + "EmailVerificationToken", + # OIDC "OIDCClient", "OIDCAuthCode", "OIDCRefreshToken", "OIDCSession", "OIDCTokenMetadata", "OIDCAuditLog", + "OidcJwksKey", + # SSH / CA + "CA", + "KeyType", + "CertType", + "CaType", + "CAPermission", + "SSHKey", + "SSHCertificate", + "CertificateStatus", + "CertificateAuditLog", + # Security "OrganizationSecurityPolicy", "UserSecurityPolicy", "MfaPolicyCompliance", diff --git a/gatehouse_app/models/auth/__init__.py b/gatehouse_app/models/auth/__init__.py new file mode 100644 index 0000000..e28b467 --- /dev/null +++ b/gatehouse_app/models/auth/__init__.py @@ -0,0 +1,20 @@ +"""Auth subpackage — authentication methods, tokens, and audit logs.""" +from gatehouse_app.models.auth.authentication_method import ( + AuthenticationMethod, + ApplicationProviderConfig, + OrganizationProviderOverride, + OAuthState, +) +from gatehouse_app.models.auth.audit_log import AuditLog +from gatehouse_app.models.auth.password_reset_token import PasswordResetToken +from gatehouse_app.models.auth.email_verification_token import EmailVerificationToken + +__all__ = [ + "AuthenticationMethod", + "ApplicationProviderConfig", + "OrganizationProviderOverride", + "OAuthState", + "AuditLog", + "PasswordResetToken", + "EmailVerificationToken", +] diff --git a/gatehouse_app/models/audit_log.py b/gatehouse_app/models/auth/audit_log.py similarity index 92% rename from gatehouse_app/models/audit_log.py rename to gatehouse_app/models/auth/audit_log.py index 3e3cea1..849f915 100644 --- a/gatehouse_app/models/audit_log.py +++ b/gatehouse_app/models/auth/audit_log.py @@ -26,14 +26,13 @@ class AuditLog(BaseModel): extra_data = db.Column(db.JSON, nullable=True) description = db.Column(db.Text, nullable=True) - # Success/failure + # Outcome success = db.Column(db.Boolean, default=True, nullable=False) error_message = db.Column(db.Text, nullable=True) # Relationships user = db.relationship("User", back_populates="audit_logs") - # Indexes for common queries __table_args__ = ( db.Index("idx_audit_user_action", "user_id", "action"), db.Index("idx_audit_resource", "resource_type", "resource_id"), @@ -45,9 +44,8 @@ class AuditLog(BaseModel): return f"" @classmethod - def log(cls, action, user_id=None, **kwargs): - """ - Create an audit log entry. + def log(cls, action, user_id=None, **kwargs) -> "AuditLog": + """Create an audit log entry. Args: action: AuditAction enum value diff --git a/gatehouse_app/models/authentication_method.py b/gatehouse_app/models/auth/authentication_method.py similarity index 74% rename from gatehouse_app/models/authentication_method.py rename to gatehouse_app/models/auth/authentication_method.py index 3766d52..3fbd7c0 100644 --- a/gatehouse_app/models/authentication_method.py +++ b/gatehouse_app/models/auth/authentication_method.py @@ -1,4 +1,4 @@ -"""Authentication method model.""" +"""Authentication method model — user credentials and OAuth provider config.""" from datetime import datetime, timedelta, timezone import secrets from gatehouse_app.extensions import db @@ -35,7 +35,6 @@ class AuthenticationMethod(BaseModel): # Relationships user = db.relationship("User", back_populates="authentication_methods") - # Ensure unique provider combinations __table_args__ = ( db.Index("idx_user_method", "user_id", "method_type"), db.UniqueConstraint( @@ -45,13 +44,15 @@ class AuthenticationMethod(BaseModel): def __repr__(self): """String representation of AuthenticationMethod.""" - return f"" + return ( + f"" + ) - def is_password(self): + def is_password(self) -> bool: """Check if this is a password authentication method.""" return self.method_type == AuthMethodType.PASSWORD - def is_oauth(self): + def is_oauth(self) -> bool: """Check if this is an OAuth authentication method.""" return self.method_type in [ AuthMethodType.GOOGLE, @@ -59,32 +60,32 @@ class AuthenticationMethod(BaseModel): AuthMethodType.MICROSOFT, ] - def is_totp(self): + def is_totp(self) -> bool: """Check if this is a TOTP authentication method.""" return self.method_type == AuthMethodType.TOTP - def is_webauthn(self): + def is_webauthn(self) -> bool: """Check if this is a WebAuthn authentication method.""" return self.method_type == AuthMethodType.WEBAUTHN def to_dict(self, exclude=None): """Convert to dictionary, excluding sensitive fields.""" exclude = exclude or [] - # Always exclude password hash and TOTP secrets - exclude.append("password_hash") - exclude.append("totp_secret") - exclude.append("totp_backup_codes") + # Always exclude credential material + for field in ("password_hash", "totp_secret", "totp_backup_codes"): + if field not in exclude: + exclude.append(field) return super().to_dict(exclude=exclude) def to_webauthn_dict(self): """Convert WebAuthn credential to public dictionary. - + Returns: - Dictionary with safe-to-expose credential information. + Dictionary with safe-to-expose credential information, or None. """ if not self.is_webauthn() or not self.provider_data: return None - + data = self.provider_data return { "id": data.get("credential_id"), @@ -98,26 +99,26 @@ class AuthenticationMethod(BaseModel): class ApplicationProviderConfig(BaseModel): """Application-wide OAuth provider configuration. - - This model stores OAuth provider credentials at the application level, - allowing users to authenticate without needing to specify an organization first. + + Stores OAuth provider credentials at the application level, allowing users + to authenticate without needing to specify an organization first. """ __tablename__ = "application_provider_configs" # Provider identification provider_type = db.Column(db.String(50), nullable=False, unique=True, index=True) - - # OAuth credentials (encrypted) + + # OAuth credentials (client_secret encrypted at rest) client_id = db.Column(db.String(255), nullable=False) client_secret_encrypted = db.Column(db.String(512), nullable=True) - + # Provider status is_enabled = db.Column(db.Boolean, default=True, nullable=False) - + # Default redirect URL default_redirect_url = db.Column(db.String(2048), nullable=True) - + # Provider-specific settings (JSON) additional_config = db.Column(db.JSON, nullable=True) @@ -126,28 +127,34 @@ class ApplicationProviderConfig(BaseModel): "OrganizationProviderOverride", back_populates="application_config", foreign_keys="OrganizationProviderOverride.provider_type", - primaryjoin="ApplicationProviderConfig.provider_type==OrganizationProviderOverride.provider_type", - cascade="all, delete-orphan" + primaryjoin=( + "ApplicationProviderConfig.provider_type" + "==OrganizationProviderOverride.provider_type" + ), + cascade="all, delete-orphan", ) def __repr__(self): """String representation of ApplicationProviderConfig.""" - return f"" + return ( + f"" + ) - def set_client_secret(self, plaintext_secret: str): + def set_client_secret(self, plaintext_secret: str) -> None: """Encrypt and store client secret. - + Args: plaintext_secret: The plaintext OAuth client secret """ if plaintext_secret: self.client_secret_encrypted = encrypt(plaintext_secret) - def get_client_secret(self) -> str: + def get_client_secret(self) -> str | None: """Decrypt and return client secret. - + Returns: - The plaintext OAuth client secret + The plaintext OAuth client secret, or None if not set. """ if self.client_secret_encrypted: return decrypt(self.client_secret_encrypted) @@ -156,37 +163,38 @@ class ApplicationProviderConfig(BaseModel): def to_dict(self, exclude=None): """Convert to dictionary, excluding sensitive fields.""" exclude = exclude or [] - # Always exclude encrypted client secret - exclude.append("client_secret_encrypted") + if "client_secret_encrypted" not in exclude: + exclude.append("client_secret_encrypted") return super().to_dict(exclude=exclude) class OrganizationProviderOverride(BaseModel): """Organization-specific OAuth configuration overrides. - - This model allows organizations to override application-level OAuth settings - for enterprise SSO scenarios or custom provider configurations. + + Allows organizations to override application-level OAuth settings for + enterprise SSO scenarios or custom provider configurations. """ __tablename__ = "organization_provider_overrides" - # References organization_id = db.Column( - db.String(36), db.ForeignKey("organizations.id"), - nullable=False, index=True + db.String(36), + db.ForeignKey("organizations.id"), + nullable=False, + index=True, ) provider_type = db.Column(db.String(50), nullable=False, index=True) - - # Override OAuth credentials (encrypted, nullable - only if overriding) + + # Override OAuth credentials (encrypted, nullable — only set when overriding) client_id = db.Column(db.String(255), nullable=True) client_secret_encrypted = db.Column(db.String(512), nullable=True) - + # Provider status is_enabled = db.Column(db.Boolean, default=True, nullable=False) - + # Redirect URL override redirect_url_override = db.Column(db.String(2048), nullable=True) - + # Provider-specific settings override (JSON) additional_config = db.Column(db.JSON, nullable=True) @@ -196,37 +204,33 @@ class OrganizationProviderOverride(BaseModel): "ApplicationProviderConfig", back_populates="organization_overrides", foreign_keys=[provider_type], - primaryjoin="ApplicationProviderConfig.provider_type==OrganizationProviderOverride.provider_type", - viewonly=True + primaryjoin=( + "ApplicationProviderConfig.provider_type" + "==OrganizationProviderOverride.provider_type" + ), + viewonly=True, ) - # Unique constraint on (organization_id, provider_type) __table_args__ = ( db.UniqueConstraint( - "organization_id", "provider_type", - name="uix_org_provider_type" + "organization_id", "provider_type", name="uix_org_provider_type" ), ) def __repr__(self): """String representation of OrganizationProviderOverride.""" - return f"" + return ( + f"" + ) - def set_client_secret(self, plaintext_secret: str): - """Encrypt and store client secret override. - - Args: - plaintext_secret: The plaintext OAuth client secret - """ + def set_client_secret(self, plaintext_secret: str) -> None: + """Encrypt and store client secret override.""" if plaintext_secret: self.client_secret_encrypted = encrypt(plaintext_secret) - def get_client_secret(self) -> str: - """Decrypt and return client secret override. - - Returns: - The plaintext OAuth client secret - """ + def get_client_secret(self) -> str | None: + """Decrypt and return client secret override.""" if self.client_secret_encrypted: return decrypt(self.client_secret_encrypted) return None @@ -234,53 +238,52 @@ class OrganizationProviderOverride(BaseModel): def to_dict(self, exclude=None): """Convert to dictionary, excluding sensitive fields.""" exclude = exclude or [] - # Always exclude encrypted client secret - exclude.append("client_secret_encrypted") + if "client_secret_encrypted" not in exclude: + exclude.append("client_secret_encrypted") return super().to_dict(exclude=exclude) class OAuthState(BaseModel): """OAuth flow state tracking. - - This model tracks OAuth authentication flow state, including PKCE parameters - and organization context (which is now optional to support login flows where - the organization isn't known until after authentication). + + Tracks OAuth authentication flow state, including PKCE parameters and + organization context (which is optional to support login flows where the + organization isn't known until after authentication). """ __tablename__ = "oauth_states" # OAuth state parameter (unique, used for CSRF protection) state = db.Column(db.String(64), unique=True, nullable=False, index=True) - + # Flow type: "login", "register", "link" flow_type = db.Column(db.String(50), nullable=False) - + # Provider type provider_type = db.Column(db.String(50), nullable=False) - - # User context (optional - not set for login/register flows) + + # User context (optional — not set for login/register flows) user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True) - - # Organization context (NOW OPTIONAL - for SSO discovery or post-auth) + + # Organization context (optional — for SSO discovery or post-auth) organization_id = db.Column( - db.String(36), db.ForeignKey("organizations.id"), - nullable=True, index=True + db.String(36), db.ForeignKey("organizations.id"), nullable=True, index=True ) - + # PKCE parameters nonce = db.Column(db.String(128), nullable=True) code_verifier = db.Column(db.String(128), nullable=True) code_challenge = db.Column(db.String(128), nullable=True) - + # OAuth parameters redirect_uri = db.Column(db.String(2048), nullable=True) - + # Post-auth redirect (for frontend routing) return_url = db.Column(db.String(2048), nullable=True) - + # Additional state data extra_data = db.Column(db.JSON, nullable=True) - + # Expiration and usage tracking expires_at = db.Column(db.DateTime, nullable=False, index=True) used = db.Column(db.Boolean, default=False, nullable=False) @@ -291,7 +294,10 @@ class OAuthState(BaseModel): def __repr__(self): """String representation of OAuthState.""" - return f"" + return ( + f"" + ) @classmethod def create_state( @@ -306,10 +312,10 @@ class OAuthState(BaseModel): code_challenge: str = None, nonce: str = None, extra_data: dict = None, - lifetime_seconds: int = 600 - ): - """Create a new OAuth state with auto-generated state parameter. - + lifetime_seconds: int = 600, + ) -> "OAuthState": + """Create a new OAuth state with an auto-generated state parameter. + Args: flow_type: Type of flow ("login", "register", "link") provider_type: OAuth provider type @@ -322,13 +328,13 @@ class OAuthState(BaseModel): nonce: OpenID Connect nonce extra_data: Additional state data lifetime_seconds: How long the state is valid (default 10 minutes) - + Returns: New OAuthState instance """ state = secrets.token_urlsafe(32) expires_at = datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds) - + oauth_state = cls( state=state, flow_type=flow_type, @@ -342,31 +348,30 @@ class OAuthState(BaseModel): nonce=nonce, extra_data=extra_data, expires_at=expires_at, - used=False + used=False, ) oauth_state.save() return oauth_state def is_valid(self) -> bool: """Check if the OAuth state is still valid. - + Returns: - True if state hasn't expired and hasn't been used + True if state hasn't expired and hasn't been used. """ now = datetime.now(timezone.utc) - # Make expires_at timezone-aware if it's naive (database returns naive datetimes) expires_at = self.expires_at if expires_at.tzinfo is None: expires_at = expires_at.replace(tzinfo=timezone.utc) return not self.used and expires_at > now - def mark_used(self): + def mark_used(self) -> None: """Mark the state as used to prevent replay attacks.""" self.used = True self.save() @classmethod - def cleanup_expired(cls): + def cleanup_expired(cls) -> None: """Remove expired OAuth states.""" now = datetime.now(timezone.utc) cls.query.filter(cls.expires_at < now).delete() @@ -375,6 +380,7 @@ class OAuthState(BaseModel): def to_dict(self, exclude=None): """Convert to dictionary, excluding sensitive fields.""" exclude = exclude or [] - # Exclude code_verifier as it's sensitive - exclude.append("code_verifier") + # code_verifier must never be exposed + if "code_verifier" not in exclude: + exclude.append("code_verifier") return super().to_dict(exclude=exclude) diff --git a/gatehouse_app/models/auth/email_verification_token.py b/gatehouse_app/models/auth/email_verification_token.py new file mode 100644 index 0000000..9f40682 --- /dev/null +++ b/gatehouse_app/models/auth/email_verification_token.py @@ -0,0 +1,68 @@ +"""Email verification token model.""" +import secrets +from datetime import datetime, timezone, timedelta + +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel + + +class EmailVerificationToken(BaseModel): + """Single-use token for verifying a user's email address.""" + + __tablename__ = "email_verification_tokens" + + user_id = db.Column( + db.String(36), + db.ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + token = db.Column(db.String(128), unique=True, nullable=False, index=True) + expires_at = db.Column(db.DateTime, nullable=False) + used_at = db.Column(db.DateTime, nullable=True) + + user = db.relationship( + "User", + backref=db.backref("email_verification_tokens", cascade="all, delete-orphan"), + ) + + @classmethod + def generate(cls, user_id: str, ttl_hours: int = 24) -> "EmailVerificationToken": + """Create a new verification token for a user. + + Any existing unused tokens for this user are invalidated first. + """ + cls.query.filter_by(user_id=user_id, used_at=None).delete() + db.session.flush() + + token_value = secrets.token_urlsafe(48) + instance = cls( + user_id=user_id, + token=token_value, + expires_at=datetime.now(timezone.utc) + timedelta(hours=ttl_hours), + ) + db.session.add(instance) + db.session.commit() + return instance + + @property + def is_valid(self) -> bool: + """Return True if the token has not been used and has not expired.""" + if self.used_at is not None: + return False + now = datetime.now(timezone.utc) + expires = self.expires_at + if expires.tzinfo is None: + expires = expires.replace(tzinfo=timezone.utc) + return now < expires + + def consume(self) -> None: + """Mark the token as used.""" + self.used_at = datetime.now(timezone.utc) + db.session.commit() + + def __repr__(self) -> str: + return ( + f"" + ) diff --git a/gatehouse_app/models/auth/password_reset_token.py b/gatehouse_app/models/auth/password_reset_token.py new file mode 100644 index 0000000..53072ef --- /dev/null +++ b/gatehouse_app/models/auth/password_reset_token.py @@ -0,0 +1,69 @@ +"""Password reset token model.""" +import secrets +from datetime import datetime, timezone, timedelta + +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel + + +class PasswordResetToken(BaseModel): + """Single-use token for resetting a user's password.""" + + __tablename__ = "password_reset_tokens" + + user_id = db.Column( + db.String(36), + db.ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + token = db.Column(db.String(128), unique=True, nullable=False, index=True) + expires_at = db.Column(db.DateTime, nullable=False) + used_at = db.Column(db.DateTime, nullable=True) + + user = db.relationship( + "User", + backref=db.backref("password_reset_tokens", cascade="all, delete-orphan"), + ) + + @classmethod + def generate(cls, user_id: str, ttl_hours: int = 2) -> "PasswordResetToken": + """Create a new password reset token for a user. + + Any existing unused tokens for this user are invalidated first. + """ + # Invalidate any existing unused tokens for this user + cls.query.filter_by(user_id=user_id, used_at=None).delete() + db.session.flush() + + token_value = secrets.token_urlsafe(48) + instance = cls( + user_id=user_id, + token=token_value, + expires_at=datetime.now(timezone.utc) + timedelta(hours=ttl_hours), + ) + db.session.add(instance) + db.session.commit() + return instance + + @property + def is_valid(self) -> bool: + """Return True if the token has not been used and has not expired.""" + if self.used_at is not None: + return False + now = datetime.now(timezone.utc) + expires = self.expires_at + if expires.tzinfo is None: + expires = expires.replace(tzinfo=timezone.utc) + return now < expires + + def consume(self) -> None: + """Mark the token as used.""" + self.used_at = datetime.now(timezone.utc) + db.session.commit() + + def __repr__(self) -> str: + return ( + f"" + ) diff --git a/gatehouse_app/models/base.py b/gatehouse_app/models/base.py index 0a63735..d1ce1c4 100644 --- a/gatehouse_app/models/base.py +++ b/gatehouse_app/models/base.py @@ -82,7 +82,10 @@ class BaseModel(db.Model): if column.name not in exclude: value = getattr(self, column.name) if isinstance(value, datetime): - result[column.name] = value.isoformat() + if value.tzinfo is None: + result[column.name] = value.isoformat() + "Z" + else: + result[column.name] = value.astimezone(timezone.utc).isoformat().replace("+00:00", "Z") else: result[column.name] = value return result diff --git a/gatehouse_app/models/oidc/__init__.py b/gatehouse_app/models/oidc/__init__.py new file mode 100644 index 0000000..2cb08da --- /dev/null +++ b/gatehouse_app/models/oidc/__init__.py @@ -0,0 +1,18 @@ +"""OIDC subpackage — clients, tokens, sessions, and audit logs.""" +from gatehouse_app.models.oidc.oidc_client import OIDCClient +from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode +from gatehouse_app.models.oidc.oidc_refresh_token import OIDCRefreshToken +from gatehouse_app.models.oidc.oidc_session import OIDCSession +from gatehouse_app.models.oidc.oidc_token_metadata import OIDCTokenMetadata +from gatehouse_app.models.oidc.oidc_audit_log import OIDCAuditLog +from gatehouse_app.models.oidc.oidc_jwks_key import OidcJwksKey + +__all__ = [ + "OIDCClient", + "OIDCAuthCode", + "OIDCRefreshToken", + "OIDCSession", + "OIDCTokenMetadata", + "OIDCAuditLog", + "OidcJwksKey", +] diff --git a/gatehouse_app/models/oidc_audit_log.py b/gatehouse_app/models/oidc/oidc_audit_log.py similarity index 68% rename from gatehouse_app/models/oidc_audit_log.py rename to gatehouse_app/models/oidc/oidc_audit_log.py index 39b21a5..c0ae557 100644 --- a/gatehouse_app/models/oidc_audit_log.py +++ b/gatehouse_app/models/oidc/oidc_audit_log.py @@ -1,5 +1,4 @@ """OIDC Audit Log model for comprehensive OIDC event tracking.""" -from datetime import datetime from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel @@ -7,8 +6,7 @@ from gatehouse_app.models.base import BaseModel class OIDCAuditLog(BaseModel): """OIDC Audit Log model for comprehensive OIDC event tracking. - This model logs all OIDC-related events for security, compliance, - and debugging purposes. + Logs all OIDC-related events for security, compliance, and debugging. """ __tablename__ = "oidc_audit_logs" @@ -46,16 +44,29 @@ class OIDCAuditLog(BaseModel): def __repr__(self): """String representation of OIDCAuditLog.""" status = "success" if self.success else "failed" - return f"" + return ( + f"" + ) @classmethod - def log_event(cls, event_type, client_id=None, user_id=None, success=True, - error_code=None, error_description=None, ip_address=None, - user_agent=None, request_id=None, event_metadata=None): + def log_event( + cls, + event_type: str, + client_id: str = None, + user_id: str = None, + success: bool = True, + error_code: str = None, + error_description: str = None, + ip_address: str = None, + user_agent: str = None, + request_id: str = None, + event_metadata: dict = None, + ) -> "OIDCAuditLog": """Log an OIDC event. Args: - event_type: Type of event (e.g., "authorization_request", "token_issue") + event_type: Type of event (e.g., "authorization_request") client_id: The OIDC client ID user_id: The user ID success: Whether the event was successful @@ -86,9 +97,19 @@ class OIDCAuditLog(BaseModel): return log @classmethod - def log_authorization_request(cls, client_id, user_id, redirect_uri, scope, - ip_address=None, user_agent=None, request_id=None, - success=True, error_code=None, error_description=None): + def log_authorization_request( + cls, + client_id: str, + user_id: str, + redirect_uri: str, + scope, + ip_address: str = None, + user_agent: str = None, + request_id: str = None, + success: bool = True, + error_code: str = None, + error_description: str = None, + ) -> "OIDCAuditLog": """Log an authorization request event.""" return cls.log_event( event_type="authorization_request", @@ -100,15 +121,19 @@ class OIDCAuditLog(BaseModel): ip_address=ip_address, user_agent=user_agent, request_id=request_id, - event_metadata={ - "redirect_uri": redirect_uri, - "scope": scope, - } + event_metadata={"redirect_uri": redirect_uri, "scope": scope}, ) @classmethod - def log_token_issue(cls, client_id, user_id, token_type, - ip_address=None, user_agent=None, request_id=None): + def log_token_issue( + cls, + client_id: str, + user_id: str, + token_type: str, + ip_address: str = None, + user_agent: str = None, + request_id: str = None, + ) -> "OIDCAuditLog": """Log a token issuance event.""" return cls.log_event( event_type="token_issue", @@ -118,12 +143,20 @@ class OIDCAuditLog(BaseModel): ip_address=ip_address, user_agent=user_agent, request_id=request_id, - event_metadata={"token_type": token_type} + event_metadata={"token_type": token_type}, ) @classmethod - def log_token_revocation(cls, client_id, user_id, token_type, reason=None, - ip_address=None, user_agent=None, request_id=None): + def log_token_revocation( + cls, + client_id: str, + user_id: str, + token_type: str, + reason: str = None, + ip_address: str = None, + user_agent: str = None, + request_id: str = None, + ) -> "OIDCAuditLog": """Log a token revocation event.""" return cls.log_event( event_type="token_revocation", @@ -133,15 +166,19 @@ class OIDCAuditLog(BaseModel): ip_address=ip_address, user_agent=user_agent, request_id=request_id, - event_metadata={ - "token_type": token_type, - "reason": reason, - } + event_metadata={"token_type": token_type, "reason": reason}, ) @classmethod - def log_authentication_failure(cls, client_id, error_code, error_description, - ip_address=None, user_agent=None, request_id=None): + def log_authentication_failure( + cls, + client_id: str, + error_code: str, + error_description: str, + ip_address: str = None, + user_agent: str = None, + request_id: str = None, + ) -> "OIDCAuditLog": """Log an authentication failure event.""" return cls.log_event( event_type="authentication_failure", @@ -155,7 +192,7 @@ class OIDCAuditLog(BaseModel): ) @classmethod - def get_events_for_user(cls, user_id, limit=100): + def get_events_for_user(cls, user_id: str, limit: int = 100) -> list: """Get audit events for a user. Args: @@ -165,13 +202,15 @@ class OIDCAuditLog(BaseModel): Returns: List of OIDCAuditLog instances """ - return cls.query.filter_by(user_id=user_id, deleted_at=None)\ - .order_by(cls.created_at.desc())\ - .limit(limit)\ + return ( + cls.query.filter_by(user_id=user_id, deleted_at=None) + .order_by(cls.created_at.desc()) + .limit(limit) .all() + ) @classmethod - def get_events_for_client(cls, client_id, limit=100): + def get_events_for_client(cls, client_id: str, limit: int = 100) -> list: """Get audit events for a client. Args: @@ -181,14 +220,22 @@ class OIDCAuditLog(BaseModel): Returns: List of OIDCAuditLog instances """ - return cls.query.filter_by(client_id=client_id, deleted_at=None)\ - .order_by(cls.created_at.desc())\ - .limit(limit)\ + return ( + cls.query.filter_by(client_id=client_id, deleted_at=None) + .order_by(cls.created_at.desc()) + .limit(limit) .all() + ) @classmethod - def get_failed_events(cls, client_id=None, user_id=None, start_date=None, - end_date=None, limit=100): + def get_failed_events( + cls, + client_id: str = None, + user_id: str = None, + start_date=None, + end_date=None, + limit: int = 100, + ) -> list: """Get failed audit events. Args: @@ -210,22 +257,8 @@ class OIDCAuditLog(BaseModel): query = query.filter(cls.created_at >= start_date) if end_date: query = query.filter(cls.created_at <= end_date) - return query.order_by(cls.created_at.desc()).limit(limit).all() def to_dict(self, exclude=None): """Convert to dictionary.""" return super().to_dict(exclude=exclude) - - -# Add relationship back to User model -from gatehouse_app.models.user import User -User.oidc_audit_logs = db.relationship( - "OIDCAuditLog", back_populates="user", cascade="all, delete-orphan" -) - -# Add relationship back to OIDCClient model -from gatehouse_app.models.oidc_client import OIDCClient -OIDCClient.audit_logs = db.relationship( - "OIDCAuditLog", back_populates="client", cascade="all, delete-orphan" -) diff --git a/gatehouse_app/models/oidc_authorization_code.py b/gatehouse_app/models/oidc/oidc_authorization_code.py similarity index 65% rename from gatehouse_app/models/oidc_authorization_code.py rename to gatehouse_app/models/oidc/oidc_authorization_code.py index 640078e..3884592 100644 --- a/gatehouse_app/models/oidc_authorization_code.py +++ b/gatehouse_app/models/oidc/oidc_authorization_code.py @@ -1,14 +1,14 @@ -"""OIDC Authorization Code model for auth code flow.""" +"""OIDC Authorization Code model for the authorization code grant flow.""" from datetime import datetime, timedelta, timezone from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel class OIDCAuthCode(BaseModel): - """OIDC Authorization Code model for authorization code flow. + """OIDC Authorization Code model for the authorization code grant flow. - Authorization codes are single-use, short-lived codes used in the - authorization code grant flow. The code is hashed for security. + Authorization codes are single-use, short-lived codes. The code itself is + hashed before storage so that a database breach cannot replay codes. """ __tablename__ = "oidc_authorization_codes" @@ -26,9 +26,9 @@ class OIDCAuthCode(BaseModel): # Request parameters redirect_uri = db.Column(db.String(512), nullable=False) - scope = db.Column(db.JSON, nullable=True) # Requested scopes - nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation - code_verifier = db.Column(db.String(255), nullable=True) # For PKCE + scope = db.Column(db.JSON, nullable=True) + nonce = db.Column(db.String(255), nullable=True) + code_verifier = db.Column(db.String(255), nullable=True) # Status tracking expires_at = db.Column(db.DateTime, nullable=False, index=True) @@ -39,37 +39,48 @@ class OIDCAuthCode(BaseModel): ip_address = db.Column(db.String(45), nullable=True) user_agent = db.Column(db.Text, nullable=True) - # Relationships + # Relationships — back_populates declared on User and OIDCClient client = db.relationship("OIDCClient", back_populates="authorization_codes") user = db.relationship("User", back_populates="oidc_auth_codes") def __repr__(self): """String representation of OIDCAuthCode.""" - return f"" + return ( + f"" + ) - def is_expired(self): + def is_expired(self) -> bool: """Check if the authorization code has expired.""" - # Handle both timezone-aware and timezone-naive expires_at values expires_at = self.expires_at if expires_at.tzinfo is None: - # Make naive datetime timezone-aware (UTC) expires_at = expires_at.replace(tzinfo=timezone.utc) return datetime.now(timezone.utc) > expires_at - def is_valid(self): + def is_valid(self) -> bool: """Check if the authorization code is valid for use.""" return not self.is_used and not self.is_expired() and self.deleted_at is None - def mark_as_used(self): + def mark_as_used(self) -> None: """Mark the authorization code as used.""" self.is_used = True self.used_at = datetime.now(timezone.utc) db.session.commit() @classmethod - def create_code(cls, client_id, user_id, code_hash, redirect_uri, scope=None, - nonce=None, code_verifier=None, ip_address=None, user_agent=None, - lifetime_seconds=600): + def create_code( + cls, + client_id: str, + user_id: str, + code_hash: str, + redirect_uri: str, + scope=None, + nonce: str = None, + code_verifier: str = None, + ip_address: str = None, + user_agent: str = None, + lifetime_seconds: int = 600, + ) -> "OIDCAuthCode": """Create a new authorization code. Args: @@ -79,7 +90,7 @@ class OIDCAuthCode(BaseModel): redirect_uri: The redirect URI scope: Requested scopes nonce: OIDC nonce - code_verifier: PKCE code verifier + code_verifier: PKCE code verifier (stored hashed server-side) ip_address: Client IP address user_agent: Client user agent lifetime_seconds: Code lifetime in seconds (default 10 minutes) @@ -106,20 +117,7 @@ class OIDCAuthCode(BaseModel): def to_dict(self, exclude=None): """Convert to dictionary, excluding sensitive fields.""" exclude = exclude or [] - # Always exclude code hash - exclude.append("code_hash") - exclude.append("code_verifier") + for field in ("code_hash", "code_verifier"): + if field not in exclude: + exclude.append(field) return super().to_dict(exclude=exclude) - - -# Add relationship back to User model -from gatehouse_app.models.user import User -User.oidc_auth_codes = db.relationship( - "OIDCAuthCode", back_populates="user", cascade="all, delete-orphan" -) - -# Add relationship back to OIDCClient model -from gatehouse_app.models.oidc_client import OIDCClient -OIDCClient.authorization_codes = db.relationship( - "OIDCAuthCode", back_populates="client", cascade="all, delete-orphan" -) diff --git a/gatehouse_app/models/oidc_client.py b/gatehouse_app/models/oidc/oidc_client.py similarity index 63% rename from gatehouse_app/models/oidc_client.py rename to gatehouse_app/models/oidc/oidc_client.py index a446983..03c0b18 100644 --- a/gatehouse_app/models/oidc_client.py +++ b/gatehouse_app/models/oidc/oidc_client.py @@ -17,10 +17,10 @@ class OIDCClient(BaseModel): client_secret_hash = db.Column(db.String(255), nullable=False) # OAuth/OIDC configuration - redirect_uris = db.Column(db.JSON, nullable=False) # List of allowed redirect URIs - grant_types = db.Column(db.JSON, nullable=False) # List of allowed grant types - response_types = db.Column(db.JSON, nullable=False) # List of allowed response types - scopes = db.Column(db.JSON, nullable=False) # List of allowed scopes + redirect_uris = db.Column(db.JSON, nullable=False) # Allowed redirect URIs + grant_types = db.Column(db.JSON, nullable=False) # Allowed grant types + response_types = db.Column(db.JSON, nullable=False) # Allowed response types + scopes = db.Column(db.JSON, nullable=False) # Allowed scopes # Client metadata logo_uri = db.Column(db.String(512), nullable=True) @@ -41,6 +41,23 @@ class OIDCClient(BaseModel): # Relationships organization = db.relationship("Organization", back_populates="oidc_clients") + # OIDC sub-resource relationships (declared here, not monkey-patched elsewhere) + authorization_codes = db.relationship( + "OIDCAuthCode", back_populates="client", cascade="all, delete-orphan" + ) + refresh_tokens = db.relationship( + "OIDCRefreshToken", back_populates="client", cascade="all, delete-orphan" + ) + oidc_sessions = db.relationship( + "OIDCSession", back_populates="client", cascade="all, delete-orphan" + ) + token_metadata = db.relationship( + "OIDCTokenMetadata", back_populates="client", cascade="all, delete-orphan" + ) + audit_logs = db.relationship( + "OIDCAuditLog", back_populates="client", cascade="all, delete-orphan" + ) + def __repr__(self): """String representation of OIDCClient.""" return f"" @@ -48,22 +65,22 @@ class OIDCClient(BaseModel): def to_dict(self, exclude=None): """Convert to dictionary, excluding sensitive fields.""" exclude = exclude or [] - # Always exclude client secret - exclude.append("client_secret_hash") + if "client_secret_hash" not in exclude: + exclude.append("client_secret_hash") return super().to_dict(exclude=exclude) - def has_grant_type(self, grant_type): + def has_grant_type(self, grant_type) -> bool: """Check if client supports a specific grant type.""" return grant_type in self.grant_types - def has_response_type(self, response_type): + def has_response_type(self, response_type) -> bool: """Check if client supports a specific response type.""" return response_type in self.response_types - def is_redirect_uri_allowed(self, redirect_uri): + def is_redirect_uri_allowed(self, redirect_uri: str) -> bool: """Check if a redirect URI is allowed for this client.""" return redirect_uri in self.redirect_uris - def has_scope(self, scope): + def has_scope(self, scope: str) -> bool: """Check if client is allowed to request a specific scope.""" return scope in self.scopes diff --git a/gatehouse_app/models/oidc/oidc_jwks_key.py b/gatehouse_app/models/oidc/oidc_jwks_key.py new file mode 100644 index 0000000..f8fa982 --- /dev/null +++ b/gatehouse_app/models/oidc/oidc_jwks_key.py @@ -0,0 +1,76 @@ +"""OIDC JWKS Key model for persisting signing keys.""" +from datetime import datetime, timezone +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel + + +class OidcJwksKey(BaseModel): + """OIDC JWKS Key model for persisting JSON Web Key Set signing keys. + + Stores RSA/ECDSA key pairs used for signing OIDC tokens. Multiple keys can + be stored to support key rotation scenarios. + + Attributes: + kid: Unique key ID used in JWT ``kid`` header + key_type: Type of key (e.g., "RSA", "EC") + private_key: PEM-encoded private key (never exposed in API responses) + public_key: PEM-encoded public key + algorithm: Signing algorithm (e.g., "RS256", "ES256") + is_active: Whether this key is currently used for signing/verification + is_primary: Whether this is the primary signing key + expires_at: Optional expiry for key rotation enforcement + """ + + __tablename__ = "oidc_jwks_keys" + + # Override the default UUID id with integer primary key for JWKS key sets + id = db.Column(db.Integer, primary_key=True) + + expires_at = db.Column(db.DateTime, nullable=True) + + # Key identification and type + kid = db.Column(db.String(255), unique=True, nullable=False, index=True) + key_type = db.Column(db.String(50), nullable=False) # e.g., "RSA", "EC" + algorithm = db.Column(db.String(50), nullable=False) # e.g., "RS256", "ES256" + + # Key material (PEM-encoded) — private_key must never be returned by API + private_key = db.Column(db.Text, nullable=False) + public_key = db.Column(db.Text, nullable=False) + + # Key status + is_active = db.Column(db.Boolean, default=True, nullable=False) + is_primary = db.Column(db.Boolean, default=False, nullable=False) + + def __repr__(self): + """String representation of OidcJwksKey.""" + return ( + f"" + ) + + def to_dict(self, exclude_private_key: bool = True): + """Convert model to dictionary. + + Args: + exclude_private_key: If True (default), excludes the private key. + + Returns: + Dictionary representation of the model + """ + exclude = ["private_key"] if exclude_private_key else [] + return super().to_dict(exclude=exclude) + + @classmethod + def get_active_keys(cls) -> list: + """Get all active keys for signing operations.""" + return cls.query.filter_by(is_active=True).all() + + @classmethod + def get_primary_key(cls) -> "OidcJwksKey | None": + """Get the primary signing key.""" + return cls.query.filter_by(is_primary=True).first() + + @classmethod + def get_key_by_kid(cls, kid: str) -> "OidcJwksKey | None": + """Get an active key by its key ID.""" + return cls.query.filter_by(kid=kid, is_active=True).first() diff --git a/gatehouse_app/models/oidc_refresh_token.py b/gatehouse_app/models/oidc/oidc_refresh_token.py similarity index 67% rename from gatehouse_app/models/oidc_refresh_token.py rename to gatehouse_app/models/oidc/oidc_refresh_token.py index a6459ea..3e1228f 100644 --- a/gatehouse_app/models/oidc_refresh_token.py +++ b/gatehouse_app/models/oidc/oidc_refresh_token.py @@ -1,5 +1,5 @@ """OIDC Refresh Token model for token rotation.""" -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel @@ -8,7 +8,8 @@ class OIDCRefreshToken(BaseModel): """OIDC Refresh Token model for token refresh and rotation. Refresh tokens are long-lived credentials used to obtain new access tokens. - They support token rotation for enhanced security. + They support token rotation for enhanced security — each use invalidates + the old token and issues a new one. """ __tablename__ = "oidc_refresh_tokens" @@ -21,16 +22,14 @@ class OIDCRefreshToken(BaseModel): db.String(36), db.ForeignKey("users.id"), nullable=False, index=True ) - # Token (hashed for security) + # Token (hashed for security — never store plaintext refresh tokens) token_hash = db.Column(db.String(255), nullable=False, unique=True, index=True) - # Associated access token ID (stores JWT JTI string — no FK to sessions) - access_token_id = db.Column( - db.String(255), nullable=True, index=True - ) + # Associated access token JTI (no FK — stored as string for lightweight lookup) + access_token_id = db.Column(db.String(255), nullable=True, index=True) # Token scope - scope = db.Column(db.JSON, nullable=True) # Granted scopes + scope = db.Column(db.JSON, nullable=True) # Timing expires_at = db.Column(db.DateTime, nullable=False, index=True) @@ -40,7 +39,7 @@ class OIDCRefreshToken(BaseModel): revoked_reason = db.Column(db.String(255), nullable=True) # Token rotation metadata - previous_token_hash = db.Column(db.String(255), nullable=True) # For rotation + previous_token_hash = db.Column(db.String(255), nullable=True) rotation_count = db.Column(db.Integer, default=0, nullable=False) # Request metadata @@ -53,25 +52,27 @@ class OIDCRefreshToken(BaseModel): def __repr__(self): """String representation of OIDCRefreshToken.""" - return f"" + return ( + f"" + ) - def is_expired(self): + def is_expired(self) -> bool: """Check if the refresh token has expired.""" - # Handle both timezone-aware and timezone-naive expires_at values expires_at = self.expires_at if expires_at.tzinfo is None: expires_at = expires_at.replace(tzinfo=timezone.utc) return datetime.now(timezone.utc) > expires_at - def is_revoked(self): + def is_revoked(self) -> bool: """Check if the refresh token has been revoked.""" return self.revoked_at is not None - def is_valid(self): + def is_valid(self) -> bool: """Check if the refresh token is valid for use.""" return not self.is_revoked() and not self.is_expired() and self.deleted_at is None - def revoke(self, reason=None): + def revoke(self, reason: str = None) -> None: """Revoke the refresh token. Args: @@ -81,8 +82,8 @@ class OIDCRefreshToken(BaseModel): self.revoked_reason = reason db.session.commit() - def rotate(self, new_token_hash): - """Rotate the refresh token (invalidate old, create new). + def rotate(self, new_token_hash: str) -> "OIDCRefreshToken": + """Rotate the refresh token — invalidate the old hash, store the new one. Args: new_token_hash: Hash of the new refresh token @@ -90,20 +91,25 @@ class OIDCRefreshToken(BaseModel): Returns: self for chaining """ - # Store reference to old token self.previous_token_hash = self.token_hash self.token_hash = new_token_hash self.rotation_count += 1 - # Extend expiration on rotation - from datetime import timedelta self.expires_at = datetime.now(timezone.utc) + timedelta(days=30) db.session.commit() return self @classmethod - def create_token(cls, client_id, user_id, token_hash, scope=None, - access_token_id=None, ip_address=None, user_agent=None, - lifetime_seconds=2592000): + def create_token( + cls, + client_id: str, + user_id: str, + token_hash: str, + scope=None, + access_token_id: str = None, + ip_address: str = None, + user_agent: str = None, + lifetime_seconds: int = 2592000, + ) -> "OIDCRefreshToken": """Create a new refresh token. Args: @@ -111,7 +117,7 @@ class OIDCRefreshToken(BaseModel): user_id: The user ID token_hash: Hashed refresh token scope: Granted scopes - access_token_id: Associated access token ID + access_token_id: Associated access token JTI ip_address: Client IP address user_agent: Client user agent lifetime_seconds: Token lifetime in seconds (default 30 days) @@ -119,7 +125,6 @@ class OIDCRefreshToken(BaseModel): Returns: OIDCRefreshToken instance """ - from datetime import timedelta token = cls( client_id=client_id, user_id=user_id, @@ -137,20 +142,7 @@ class OIDCRefreshToken(BaseModel): def to_dict(self, exclude=None): """Convert to dictionary, excluding sensitive fields.""" exclude = exclude or [] - # Always exclude token hashes - exclude.append("token_hash") - exclude.append("previous_token_hash") + for field in ("token_hash", "previous_token_hash"): + if field not in exclude: + exclude.append(field) return super().to_dict(exclude=exclude) - - -# Add relationship back to User model -from gatehouse_app.models.user import User -User.oidc_refresh_tokens = db.relationship( - "OIDCRefreshToken", back_populates="user", cascade="all, delete-orphan" -) - -# Add relationship back to OIDCClient model -from gatehouse_app.models.oidc_client import OIDCClient -OIDCClient.refresh_tokens = db.relationship( - "OIDCRefreshToken", back_populates="client", cascade="all, delete-orphan" -) diff --git a/gatehouse_app/models/oidc_session.py b/gatehouse_app/models/oidc/oidc_session.py similarity index 70% rename from gatehouse_app/models/oidc_session.py rename to gatehouse_app/models/oidc/oidc_session.py index 8d6a88b..5768bd6 100644 --- a/gatehouse_app/models/oidc_session.py +++ b/gatehouse_app/models/oidc/oidc_session.py @@ -1,5 +1,7 @@ """OIDC Session model for OIDC session tracking.""" -from datetime import datetime, timezone +import hashlib +import base64 +from datetime import datetime, timedelta, timezone from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel @@ -7,8 +9,8 @@ from gatehouse_app.models.base import BaseModel class OIDCSession(BaseModel): """OIDC Session model for tracking OIDC authentication sessions. - This model tracks the state during the OIDC authentication flow, - including PKCE parameters and nonce validation. + Tracks the state during the OIDC authorization flow, including PKCE + parameters and nonce validation. """ __tablename__ = "oidc_sessions" @@ -25,11 +27,11 @@ class OIDCSession(BaseModel): # State management state = db.Column(db.String(255), nullable=False, index=True) - nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation + nonce = db.Column(db.String(255), nullable=True) # Authorization request parameters redirect_uri = db.Column(db.String(512), nullable=False) - scope = db.Column(db.JSON, nullable=True) # Requested scopes + scope = db.Column(db.JSON, nullable=True) # PKCE parameters code_challenge = db.Column(db.String(255), nullable=True) @@ -45,50 +47,52 @@ class OIDCSession(BaseModel): def __repr__(self): """String representation of OIDCSession.""" - return f"" + return ( + f"" + ) - def is_expired(self): + def is_expired(self) -> bool: """Check if the OIDC session has expired.""" - return datetime.now(timezone.utc) > self.expires_at + expires_at = self.expires_at + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + return datetime.now(timezone.utc) > expires_at - def is_authenticated(self): + def is_authenticated(self) -> bool: """Check if the user has been authenticated in this session.""" return self.authenticated_at is not None - def mark_authenticated(self): + def mark_authenticated(self) -> None: """Mark the session as authenticated.""" self.authenticated_at = datetime.now(timezone.utc) db.session.commit() - def validate_nonce(self, expected_nonce): + def validate_nonce(self, expected_nonce: str) -> bool: """Validate the nonce matches the expected value. Args: expected_nonce: The expected nonce value Returns: - bool: True if nonce matches + True if nonce matches """ return self.nonce == expected_nonce - def validate_code_challenge(self, code_verifier): + def validate_code_challenge(self, code_verifier: str) -> bool: """Validate the code verifier against the stored code challenge. Args: code_verifier: The PKCE code verifier Returns: - bool: True if code challenge is valid + True if the challenge is satisfied """ if not self.code_challenge: return False if self.code_challenge_method == "S256": - import hashlib - import base64 - # SHA256 hash of code_verifier digest = hashlib.sha256(code_verifier.encode()).digest() - # Base64 URL encode without padding expected = base64.urlsafe_b64encode(digest).decode().rstrip("=") return self.code_challenge == expected elif self.code_challenge_method == "plain": @@ -97,9 +101,18 @@ class OIDCSession(BaseModel): return False @classmethod - def create_session(cls, user_id, client_id, state, redirect_uri, scope=None, - nonce=None, code_challenge=None, code_challenge_method=None, - lifetime_seconds=600): + def create_session( + cls, + user_id: str, + client_id: str, + state: str, + redirect_uri: str, + scope=None, + nonce: str = None, + code_challenge: str = None, + code_challenge_method: str = None, + lifetime_seconds: int = 600, + ) -> "OIDCSession": """Create a new OIDC session. Args: @@ -116,7 +129,6 @@ class OIDCSession(BaseModel): Returns: OIDCSession instance """ - from datetime import timedelta session = cls( user_id=user_id, client_id=client_id, @@ -133,7 +145,7 @@ class OIDCSession(BaseModel): return session @classmethod - def get_by_state(cls, state): + def get_by_state(cls, state: str) -> "OIDCSession | None": """Get a session by state parameter. Args: @@ -147,16 +159,3 @@ class OIDCSession(BaseModel): def to_dict(self, exclude=None): """Convert to dictionary.""" return super().to_dict(exclude=exclude) - - -# Add relationship back to User model -from gatehouse_app.models.user import User -User.oidc_sessions = db.relationship( - "OIDCSession", back_populates="user", cascade="all, delete-orphan" -) - -# Add relationship back to OIDCClient model -from gatehouse_app.models.oidc_client import OIDCClient -OIDCClient.oidc_sessions = db.relationship( - "OIDCSession", back_populates="client", cascade="all, delete-orphan" -) diff --git a/gatehouse_app/models/oidc_token_metadata.py b/gatehouse_app/models/oidc/oidc_token_metadata.py similarity index 67% rename from gatehouse_app/models/oidc_token_metadata.py rename to gatehouse_app/models/oidc/oidc_token_metadata.py index 2c6c7a8..be8c862 100644 --- a/gatehouse_app/models/oidc_token_metadata.py +++ b/gatehouse_app/models/oidc/oidc_token_metadata.py @@ -8,13 +8,14 @@ from gatehouse_app.models.base import BaseModel class OIDCTokenMetadata(BaseModel): """OIDC Token Metadata model for tracking issued tokens. - This model stores metadata about issued tokens (access tokens, refresh tokens, ID tokens) - for the purpose of token revocation. The id field matches the JTI (JWT ID) claim. + Stores metadata about issued tokens (access, refresh, ID) for revocation. + The ``id`` field on this model intentionally overrides the BaseModel UUID + to store the JWT JTI directly as the primary key for O(1) revocation checks. """ __tablename__ = "oidc_token_metadata" - # Token identifier (matches JTI in JWT) + # Primary key = JTI so revocation lookups are always a PK scan id = db.Column( db.String(36), primary_key=True, default=lambda: str(uuid.uuid4()) ) @@ -27,11 +28,11 @@ class OIDCTokenMetadata(BaseModel): db.String(36), db.ForeignKey("users.id"), nullable=False, index=True ) - # Token type - token_type = db.Column(db.String(50), nullable=False) # "access_token", "refresh_token", "id_token" + # Token type: "access_token", "refresh_token", or "id_token" + token_type = db.Column(db.String(50), nullable=False) - # Token identifier for revocation lookup - token_jti = db.Column(db.String(255), nullable=False, index=True) # JWT ID claim + # JWT ID claim (indexed for fast lookup when id != jti) + token_jti = db.Column(db.String(255), nullable=False, index=True) # Timing expires_at = db.Column(db.DateTime, nullable=False, index=True) @@ -46,25 +47,27 @@ class OIDCTokenMetadata(BaseModel): def __repr__(self): """String representation of OIDCTokenMetadata.""" - return f"" + return ( + f"" + ) - def is_expired(self): + def is_expired(self) -> bool: """Check if the token has expired.""" - # Handle both timezone-aware and timezone-naive expires_at values expires_at = self.expires_at if expires_at.tzinfo is None: expires_at = expires_at.replace(tzinfo=timezone.utc) return datetime.now(timezone.utc) > expires_at - def is_revoked(self): + def is_revoked(self) -> bool: """Check if the token has been revoked.""" return self.revoked_at is not None - def is_valid(self): + def is_valid(self) -> bool: """Check if the token is valid (not expired and not revoked).""" return not self.is_revoked() and not self.is_expired() and self.deleted_at is None - def revoke(self, reason=None): + def revoke(self, reason: str = None) -> None: """Revoke the token. Args: @@ -75,8 +78,16 @@ class OIDCTokenMetadata(BaseModel): db.session.commit() @classmethod - def create_metadata(cls, client_id, user_id, token_type, token_jti, - expires_at, ip_address=None, user_agent=None): + def create_metadata( + cls, + client_id: str, + user_id: str, + token_type: str, + token_jti: str, + expires_at, + ip_address: str = None, + user_agent: str = None, + ) -> "OIDCTokenMetadata": """Create token metadata for tracking. Args: @@ -85,8 +96,8 @@ class OIDCTokenMetadata(BaseModel): token_type: Type of token ("access_token", "refresh_token", "id_token") token_jti: JWT ID claim expires_at: Token expiration datetime - ip_address: Client IP address - user_agent: Client user agent + ip_address: Client IP address (unused column, kept for API compat) + user_agent: Client user agent (unused column, kept for API compat) Returns: OIDCTokenMetadata instance @@ -104,7 +115,7 @@ class OIDCTokenMetadata(BaseModel): return metadata @classmethod - def get_by_jti(cls, token_jti): + def get_by_jti(cls, token_jti: str) -> "OIDCTokenMetadata | None": """Get token metadata by JWT ID. Args: @@ -116,7 +127,7 @@ class OIDCTokenMetadata(BaseModel): return cls.query.filter_by(token_jti=token_jti, deleted_at=None).first() @classmethod - def revoke_by_jti(cls, token_jti, reason=None): + def revoke_by_jti(cls, token_jti: str, reason: str = None) -> bool: """Revoke a token by its JWT ID. Args: @@ -124,7 +135,7 @@ class OIDCTokenMetadata(BaseModel): reason: Optional revocation reason Returns: - bool: True if token was found and revoked + True if token was found and revoked, False otherwise """ metadata = cls.get_by_jti(token_jti) if metadata: @@ -133,47 +144,53 @@ class OIDCTokenMetadata(BaseModel): return False @classmethod - def revoke_all_for_user(cls, user_id, client_id=None, reason=None): + def revoke_all_for_user( + cls, user_id: str, client_id: str = None, reason: str = None + ) -> int: """Revoke all tokens for a user. Args: user_id: The user ID - client_id: Optional client ID to filter by + client_id: Optional client ID filter reason: Optional revocation reason Returns: - int: Number of tokens revoked + Number of tokens revoked """ - query = cls.query.filter_by(user_id=user_id, deleted_at=None) + query = cls.query.filter_by(user_id=user_id, deleted_at=None).filter( + cls.revoked_at.is_(None) + ) if client_id: query = query.filter_by(client_id=client_id) - tokens = query.filter(cls.revoked_at == None).all() count = 0 - for token in tokens: + for token in query.all(): token.revoke(reason) count += 1 return count @classmethod - def revoke_all_for_client(cls, client_id, user_id=None, reason=None): + def revoke_all_for_client( + cls, client_id: str, user_id: str = None, reason: str = None + ) -> int: """Revoke all tokens for a client. Args: client_id: The client ID - user_id: Optional user ID to filter by + user_id: Optional user ID filter reason: Optional revocation reason Returns: - int: Number of tokens revoked + Number of tokens revoked """ - query = cls.query.filter_by(client_id=client_id, deleted_at=None) + query = cls.query.filter_by(client_id=client_id, deleted_at=None).filter( + cls.revoked_at.is_(None) + ) if user_id: query = query.filter_by(user_id=user_id) - tokens = query.filter(cls.revoked_at == None).all() count = 0 - for token in tokens: + for token in query.all(): token.revoke(reason) count += 1 return count @@ -181,16 +198,3 @@ class OIDCTokenMetadata(BaseModel): def to_dict(self, exclude=None): """Convert to dictionary.""" return super().to_dict(exclude=exclude) - - -# Add relationship back to User model -from gatehouse_app.models.user import User -User.oidc_token_metadata = db.relationship( - "OIDCTokenMetadata", back_populates="user", cascade="all, delete-orphan" -) - -# Add relationship back to OIDCClient model -from gatehouse_app.models.oidc_client import OIDCClient -OIDCClient.token_metadata = db.relationship( - "OIDCTokenMetadata", back_populates="client", cascade="all, delete-orphan" -) diff --git a/gatehouse_app/models/oidc_jwks_key.py b/gatehouse_app/models/oidc_jwks_key.py deleted file mode 100644 index 07dcb80..0000000 --- a/gatehouse_app/models/oidc_jwks_key.py +++ /dev/null @@ -1,77 +0,0 @@ -"""OIDC JWKS Key model for persisting signing keys.""" -from datetime import datetime, timezone -from gatehouse_app.extensions import db -from gatehouse_app.models.base import BaseModel - - -class OidcJwksKey(BaseModel): - """ - OIDC JWKS Key model for persisting JSON Web Key Set signing keys. - - This model stores RSA/ECDSA key pairs used for signing OIDC tokens. - Multiple keys can be stored to support key rotation scenarios. - - Attributes: - id: Integer primary key - kid: Unique key ID used in JWT "kid" header - key_type: Type of key (e.g., "RSA", "EC") - private_key: PEM-encoded private key - public_key: PEM-encoded public key - algorithm: Signing algorithm (e.g., "RS256", "ES256") - created_at: When the key was created - is_active: Whether this key is currently active for signing - is_primary: Whether this is the primary signing key - expires_at: ... - """ - - __tablename__ = "oidc_jwks_keys" - - # Override the default UUID id with integer primary key - id = db.Column(db.Integer, primary_key=True) - - expires_at = db.Column(db.DateTime, nullable=True) - - # Key identification and type - kid = db.Column(db.String(255), unique=True, nullable=False, index=True) - key_type = db.Column(db.String(50), nullable=False) # e.g., "RSA", "EC" - algorithm = db.Column(db.String(50), nullable=False) # e.g., "RS256", "ES256" - - # Key material (PEM-encoded) - private_key = db.Column(db.Text, nullable=False) - public_key = db.Column(db.Text, nullable=False) - - # Key status - is_active = db.Column(db.Boolean, default=True, nullable=False) - is_primary = db.Column(db.Boolean, default=False, nullable=False) - - def __repr__(self): - """String representation of OidcJwksKey.""" - return f"" - - def to_dict(self, exclude_private_key=True): - """ - Convert model to dictionary. - - Args: - exclude_private_key: If True, excludes the private key from output - - Returns: - Dictionary representation of the model - """ - exclude = ["private_key"] if exclude_private_key else [] - return super().to_dict(exclude=exclude) - - @classmethod - def get_active_keys(cls): - """Get all active keys for signing operations.""" - return cls.query.filter(cls.is_active == True).all() - - @classmethod - def get_primary_key(cls): - """Get the primary signing key.""" - return cls.query.filter(cls.is_primary == True).first() - - @classmethod - def get_key_by_kid(cls, kid): - """Get a key by its key ID.""" - return cls.query.filter(cls.kid == kid, cls.is_active == True).first() \ No newline at end of file diff --git a/gatehouse_app/models/organization/__init__.py b/gatehouse_app/models/organization/__init__.py new file mode 100644 index 0000000..aa33f8e --- /dev/null +++ b/gatehouse_app/models/organization/__init__.py @@ -0,0 +1,27 @@ +"""Organization subpackage.""" +from gatehouse_app.models.organization.organization import Organization +from gatehouse_app.models.organization.organization_member import OrganizationMember +from gatehouse_app.models.organization.department import ( + Department, + DepartmentMembership, + DepartmentPrincipal, +) +from gatehouse_app.models.organization.department_cert_policy import ( + DepartmentCertPolicy, + STANDARD_EXTENSIONS, +) +from gatehouse_app.models.organization.principal import Principal, PrincipalMembership +from gatehouse_app.models.organization.org_invite_token import OrgInviteToken + +__all__ = [ + "Organization", + "OrganizationMember", + "Department", + "DepartmentMembership", + "DepartmentPrincipal", + "DepartmentCertPolicy", + "STANDARD_EXTENSIONS", + "Principal", + "PrincipalMembership", + "OrgInviteToken", +] diff --git a/gatehouse_app/models/organization/department.py b/gatehouse_app/models/organization/department.py new file mode 100644 index 0000000..800780b --- /dev/null +++ b/gatehouse_app/models/organization/department.py @@ -0,0 +1,196 @@ +"""Department, DepartmentMembership, and DepartmentPrincipal models.""" +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel + + +class Department(BaseModel): + """Department model representing an organizational unit for SSH access control. + + Departments are used to group users and assign SSH principals (access levels) + to them. A user can be a member of multiple departments, and each department + can have multiple principals assigned. + + Example: + - Department: "Engineering" + - Members: user1@example.com, user2@example.com + - Principals: "eng-prod", "eng-staging" + - Users get access based on their principal assignments + """ + + __tablename__ = "departments" + + organization_id = db.Column( + db.String(36), + db.ForeignKey("organizations.id"), + nullable=False, + index=True, + ) + name = db.Column(db.String(255), nullable=False, index=True) + description = db.Column(db.Text, nullable=True) + + # Relationships + organization = db.relationship("Organization", back_populates="departments") + memberships = db.relationship( + "DepartmentMembership", + back_populates="department", + cascade="all, delete-orphan", + ) + principal_links = db.relationship( + "DepartmentPrincipal", + back_populates="department", + cascade="all, delete-orphan", + ) + cert_policy = db.relationship( + "DepartmentCertPolicy", + back_populates="department", + uselist=False, + cascade="all, delete-orphan", + ) + + __table_args__ = ( + db.UniqueConstraint("organization_id", "name", name="uix_org_dept_name"), + ) + + def __repr__(self): + """String representation of Department.""" + return f"" + + def to_dict(self, exclude=None): + """Convert department to dictionary.""" + exclude = exclude or [] + data = super().to_dict(exclude=exclude) + data["member_count"] = len([m for m in self.memberships if m.deleted_at is None]) + data["principal_count"] = len([p for p in self.principal_links if p.deleted_at is None]) + return data + + def get_members(self, active_only: bool = True): + """Get all members of this department. + + Args: + active_only: If True, exclude soft-deleted members + + Returns: + List of DepartmentMembership objects + """ + if active_only: + return [m for m in self.memberships if m.deleted_at is None] + return list(self.memberships) + + def get_principals(self, active_only: bool = True): + """Get all principals assigned to this department. + + Args: + active_only: If True, exclude soft-deleted principals + + Returns: + List of Principal objects via DepartmentPrincipal + """ + if active_only: + return [ + p.principal + for p in self.principal_links + if p.deleted_at is None and p.principal.deleted_at is None + ] + return [p.principal for p in self.principal_links] + + def is_member(self, user_id: str) -> bool: + """Check if a user is a member of this department. + + Args: + user_id: ID of the user to check + + Returns: + True if user is an active member, False otherwise + """ + return ( + DepartmentMembership.query.filter_by( + user_id=user_id, + department_id=self.id, + deleted_at=None, + ).first() + is not None + ) + + def get_member_count(self) -> int: + """Get the count of active members in this department.""" + return len(self.get_members(active_only=True)) + + +class DepartmentMembership(BaseModel): + """Department membership model representing user membership in a department. + + When a user is added to a department, they become eligible for SSH principals + assigned to that department. + """ + + __tablename__ = "department_memberships" + + user_id = db.Column( + db.String(36), + db.ForeignKey("users.id"), + nullable=False, + index=True, + ) + department_id = db.Column( + db.String(36), + db.ForeignKey("departments.id"), + nullable=False, + index=True, + ) + + # Relationships + user = db.relationship("User", back_populates="department_memberships") + department = db.relationship("Department", back_populates="memberships") + + __table_args__ = ( + db.UniqueConstraint("user_id", "department_id", name="uix_user_dept"), + ) + + def __repr__(self): + """String representation of DepartmentMembership.""" + return ( + f"" + ) + + +class DepartmentPrincipal(BaseModel): + """Department principal assignment model. + + Represents the assignment of principals to departments. All members of a + department get access to its assigned principals (transitively). + + Example: + - Department: "Engineering" + - Principal: "eng-prod-servers" + - All engineering department members can SSH as "eng-prod-servers" + """ + + __tablename__ = "department_principals" + + department_id = db.Column( + db.String(36), + db.ForeignKey("departments.id"), + nullable=False, + index=True, + ) + principal_id = db.Column( + db.String(36), + db.ForeignKey("principals.id"), + nullable=False, + index=True, + ) + + # Relationships + department = db.relationship("Department", back_populates="principal_links") + principal = db.relationship("Principal", back_populates="department_links") + + __table_args__ = ( + db.UniqueConstraint("department_id", "principal_id", name="uix_dept_principal"), + ) + + def __repr__(self): + """String representation of DepartmentPrincipal.""" + return ( + f"" + ) diff --git a/gatehouse_app/models/organization/department_cert_policy.py b/gatehouse_app/models/organization/department_cert_policy.py new file mode 100644 index 0000000..357329f --- /dev/null +++ b/gatehouse_app/models/organization/department_cert_policy.py @@ -0,0 +1,76 @@ +"""DepartmentCertPolicy — per-department SSH certificate issuance rules.""" +from datetime import datetime, timezone +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel + + +# Standard SSH certificate extensions +STANDARD_EXTENSIONS = [ + "permit-X11-forwarding", + "permit-agent-forwarding", + "permit-pty", + "permit-port-forwarding", + "permit-user-rc", +] + + +class DepartmentCertPolicy(BaseModel): + """SSH certificate policy for a department. + + Controls: + - Whether members may choose their own expiry date (up to ``max_expiry_hours``) + - Default expiry hours when the user doesn't (or can't) pick + - Maximum expiry hours (hard ceiling, even for admins signing on behalf) + - Which SSH certificate extensions are granted to members of this department + - Any custom extensions the admin wants to add beyond the standard five + + Inherits ``id``, ``created_at``, ``updated_at``, and ``deleted_at`` from + :class:`BaseModel` so soft-delete and the standard timestamp behaviour are + consistent with every other model in the application. + """ + + __tablename__ = "department_cert_policies" + + department_id = db.Column( + db.String(36), + db.ForeignKey("departments.id"), + nullable=False, + unique=True, + index=True, + ) + + # Expiry control + allow_user_expiry = db.Column(db.Boolean, nullable=False, default=False) + default_expiry_hours = db.Column(db.Integer, nullable=False, default=1) + max_expiry_hours = db.Column(db.Integer, nullable=False, default=24) + + # Extensions — list of extension name strings + allowed_extensions = db.Column( + db.JSON, + nullable=False, + default=lambda: list(STANDARD_EXTENSIONS), + ) + # Admin-defined extras beyond the standard five + custom_extensions = db.Column(db.JSON, nullable=False, default=list) + + # Relationship back to department + department = db.relationship("Department", back_populates="cert_policy", uselist=False) + + def __repr__(self): + return ( + f"" + ) + + def all_extensions(self) -> list: + """Return the full list of enabled extensions (allowed + custom).""" + return list((self.allowed_extensions or []) + (self.custom_extensions or [])) + + def to_dict(self, exclude=None): + """Convert to dictionary.""" + exclude = exclude or [] + data = super().to_dict(exclude=exclude) + # Augment with computed / convenience fields not in the base columns + data["all_extensions"] = self.all_extensions() + data["standard_extensions"] = STANDARD_EXTENSIONS + return data diff --git a/gatehouse_app/models/organization/org_invite_token.py b/gatehouse_app/models/organization/org_invite_token.py new file mode 100644 index 0000000..2830b25 --- /dev/null +++ b/gatehouse_app/models/organization/org_invite_token.py @@ -0,0 +1,77 @@ +"""Organization invite token model.""" +import secrets +from datetime import datetime, timezone, timedelta + +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel + + +class OrgInviteToken(BaseModel): + """Token-based invitation to join an organization.""" + + __tablename__ = "org_invite_tokens" + + organization_id = db.Column( + db.String(36), + db.ForeignKey("organizations.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + invited_by_id = db.Column( + db.String(36), + db.ForeignKey("users.id", ondelete="SET NULL"), + nullable=True, + ) + email = db.Column(db.String(255), nullable=False, index=True) + role = db.Column(db.String(64), nullable=False, default="member") + token = db.Column(db.String(128), unique=True, nullable=False, index=True) + expires_at = db.Column(db.DateTime, nullable=False) + accepted_at = db.Column(db.DateTime, nullable=True) + + organization = db.relationship( + "Organization", + backref=db.backref("invite_tokens", cascade="all, delete-orphan"), + ) + invited_by = db.relationship("User", foreign_keys=[invited_by_id]) + + @classmethod + def generate( + cls, + organization_id: str, + email: str, + role: str = "member", + invited_by_id: str = None, + ttl_days: int = 7, + ) -> "OrgInviteToken": + """Create a new invite token for an organization.""" + token_value = secrets.token_urlsafe(48) + instance = cls( + organization_id=organization_id, + email=email.lower(), + role=role, + invited_by_id=invited_by_id, + token=token_value, + expires_at=datetime.now(timezone.utc) + timedelta(days=ttl_days), + ) + db.session.add(instance) + db.session.commit() + return instance + + @property + def is_valid(self) -> bool: + """Return True if the token is unused and not expired.""" + if self.accepted_at is not None: + return False + now = datetime.now(timezone.utc) + expires = self.expires_at + if expires.tzinfo is None: + expires = expires.replace(tzinfo=timezone.utc) + return now < expires + + def accept(self) -> None: + """Mark the invite as accepted.""" + self.accepted_at = datetime.now(timezone.utc) + db.session.commit() + + def __repr__(self) -> str: + return f"" diff --git a/gatehouse_app/models/organization.py b/gatehouse_app/models/organization/organization.py similarity index 81% rename from gatehouse_app/models/organization.py rename to gatehouse_app/models/organization/organization.py index 1c4170f..9be5c65 100644 --- a/gatehouse_app/models/organization.py +++ b/gatehouse_app/models/organization/organization.py @@ -34,6 +34,15 @@ class Organization(BaseModel): cascade="all, delete-orphan", foreign_keys="OrganizationSecurityPolicy.organization_id", ) + departments = db.relationship( + "Department", back_populates="organization", cascade="all, delete-orphan" + ) + principals = db.relationship( + "Principal", back_populates="organization", cascade="all, delete-orphan" + ) + cas = db.relationship( + "CA", back_populates="organization", cascade="all, delete-orphan" + ) def __repr__(self): """String representation of Organization.""" @@ -52,9 +61,9 @@ class Organization(BaseModel): return member.user return None - def is_member(self, user_id): + def is_member(self, user_id: str) -> bool: """Check if a user is a member of the organization.""" - from gatehouse_app.models.organization_member import OrganizationMember + from gatehouse_app.models.organization.organization_member import OrganizationMember return ( OrganizationMember.query.filter_by( diff --git a/gatehouse_app/models/organization_member.py b/gatehouse_app/models/organization/organization_member.py similarity index 78% rename from gatehouse_app/models/organization_member.py rename to gatehouse_app/models/organization/organization_member.py index 3247082..3b02242 100644 --- a/gatehouse_app/models/organization_member.py +++ b/gatehouse_app/models/organization/organization_member.py @@ -1,4 +1,4 @@ -"""Organization member model.""" +"""Organization member model.""" from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel from gatehouse_app.utils.constants import OrganizationRole @@ -21,31 +21,35 @@ class OrganizationMember(BaseModel): joined_at = db.Column(db.DateTime, nullable=True) # Relationships - user = db.relationship("User", foreign_keys=[user_id], back_populates="organization_memberships") + user = db.relationship( + "User", foreign_keys=[user_id], back_populates="organization_memberships" + ) organization = db.relationship("Organization", back_populates="members") invited_by = db.relationship("User", foreign_keys=[invited_by_id]) - # Unique constraint to prevent duplicate memberships __table_args__ = ( db.UniqueConstraint("user_id", "organization_id", name="uix_user_org"), ) def __repr__(self): """String representation of OrganizationMember.""" - return f"" + return ( + f"" + ) - def is_owner(self): + def is_owner(self) -> bool: """Check if member is an owner.""" return self.role == OrganizationRole.OWNER - def is_admin(self): + def is_admin(self) -> bool: """Check if member is an admin or owner.""" return self.role in [OrganizationRole.OWNER, OrganizationRole.ADMIN] - def can_manage_members(self): + def can_manage_members(self) -> bool: """Check if member can manage other members.""" return self.is_admin() - def can_delete_organization(self): + def can_delete_organization(self) -> bool: """Check if member can delete the organization.""" return self.is_owner() diff --git a/gatehouse_app/models/organization/principal.py b/gatehouse_app/models/organization/principal.py new file mode 100644 index 0000000..8f87fe3 --- /dev/null +++ b/gatehouse_app/models/organization/principal.py @@ -0,0 +1,215 @@ +"""Principal and PrincipalMembership models.""" +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel + + +class Principal(BaseModel): + """Principal model representing an SSH principal (access level/role). + + In SSH CA terminology, a principal is a string like "eng-prod-servers" or + "devops-admins" that represents a set of machines or access level. Users + can be granted access to principals, either directly or via department + membership. + + Example: + - Principal: "eng-prod-servers" + - Users with this principal can SSH to prod servers + - Can be assigned to departments or directly to users + """ + + __tablename__ = "principals" + + organization_id = db.Column( + db.String(36), + db.ForeignKey("organizations.id"), + nullable=False, + index=True, + ) + name = db.Column(db.String(255), nullable=False, index=True) + description = db.Column(db.Text, nullable=True) + + # Relationships + organization = db.relationship("Organization", back_populates="principals") + memberships = db.relationship( + "PrincipalMembership", + back_populates="principal", + cascade="all, delete-orphan", + ) + department_links = db.relationship( + "DepartmentPrincipal", + back_populates="principal", + cascade="all, delete-orphan", + ) + + __table_args__ = ( + db.UniqueConstraint("organization_id", "name", name="uix_org_principal_name"), + ) + + def __repr__(self): + """String representation of Principal.""" + return f"" + + def to_dict(self, exclude=None): + """Convert principal to dictionary.""" + exclude = exclude or [] + data = super().to_dict(exclude=exclude) + data["direct_member_count"] = len( + [m for m in self.memberships if m.deleted_at is None] + ) + data["department_count"] = len( + [d for d in self.department_links if d.deleted_at is None] + ) + return data + + def get_members(self, active_only: bool = True): + """Get all users who are directly assigned to this principal. + + Does NOT include users who get access via department membership. + + Args: + active_only: If True, exclude soft-deleted members + + Returns: + List of PrincipalMembership objects + """ + if active_only: + return [m for m in self.memberships if m.deleted_at is None] + return list(self.memberships) + + def get_all_members(self, active_only: bool = True): + """Get all users who have access to this principal. + + Includes both direct members and users via department membership. + + Args: + active_only: If True, exclude soft-deleted members + + Returns: + Set of User objects with access to this principal + """ + all_users: set = set() + + # Direct members + for membership in self.get_members(active_only=active_only): + if not active_only or membership.user.deleted_at is None: + all_users.add(membership.user) + + # Members via department assignment + for dept_link in self.department_links: + if dept_link.deleted_at is None or not active_only: + for dept_member in dept_link.department.get_members(active_only=active_only): + if not active_only or dept_member.user.deleted_at is None: + all_users.add(dept_member.user) + + return all_users + + def get_departments(self, active_only: bool = True): + """Get all departments this principal is assigned to. + + Args: + active_only: If True, exclude soft-deleted departments + + Returns: + List of Department objects + """ + if active_only: + return [ + d.department + for d in self.department_links + if d.deleted_at is None and d.department.deleted_at is None + ] + return [d.department for d in self.department_links] + + def is_member(self, user_id: str, include_via_department: bool = True) -> bool: + """Check if a user has access to this principal. + + Args: + user_id: ID of the user to check + include_via_department: If True, check department memberships too + + Returns: + True if user has access to this principal + """ + # Check direct membership + has_direct = ( + PrincipalMembership.query.filter_by( + user_id=user_id, + principal_id=self.id, + deleted_at=None, + ).first() + is not None + ) + + if has_direct: + return True + + if not include_via_department: + return False + + # Check department membership + dept_ids = [d.id for d in self.get_departments(active_only=True)] + if not dept_ids: + return False + + from gatehouse_app.models.organization.department import DepartmentMembership + + return ( + DepartmentMembership.query.filter( + DepartmentMembership.user_id == user_id, + DepartmentMembership.department_id.in_(dept_ids), + DepartmentMembership.deleted_at.is_(None), + ).first() + is not None + ) + + def get_member_count(self, include_via_department: bool = True) -> int: + """Get the count of active members with access to this principal. + + Args: + include_via_department: If True, include members via department + + Returns: + Count of members + """ + if not include_via_department: + return len(self.get_members(active_only=True)) + return len(self.get_all_members(active_only=True)) + + +class PrincipalMembership(BaseModel): + """Principal membership model representing direct user assignment to a principal. + + When a user is assigned directly to a principal, they get access to that + principal for SSH authentication. This is in addition to any principals + they get via department membership. + """ + + __tablename__ = "principal_memberships" + + user_id = db.Column( + db.String(36), + db.ForeignKey("users.id"), + nullable=False, + index=True, + ) + principal_id = db.Column( + db.String(36), + db.ForeignKey("principals.id"), + nullable=False, + index=True, + ) + + # Relationships + user = db.relationship("User", back_populates="principal_memberships") + principal = db.relationship("Principal", back_populates="memberships") + + __table_args__ = ( + db.UniqueConstraint("user_id", "principal_id", name="uix_user_principal"), + ) + + def __repr__(self): + """String representation of PrincipalMembership.""" + return ( + f"" + ) diff --git a/gatehouse_app/models/security/__init__.py b/gatehouse_app/models/security/__init__.py new file mode 100644 index 0000000..d24aef5 --- /dev/null +++ b/gatehouse_app/models/security/__init__.py @@ -0,0 +1,12 @@ +"""Security subpackage — organization and user security policies, MFA compliance.""" +from gatehouse_app.models.security.organization_security_policy import ( + OrganizationSecurityPolicy, +) +from gatehouse_app.models.security.user_security_policy import UserSecurityPolicy +from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance + +__all__ = [ + "OrganizationSecurityPolicy", + "UserSecurityPolicy", + "MfaPolicyCompliance", +] diff --git a/gatehouse_app/models/mfa_policy_compliance.py b/gatehouse_app/models/security/mfa_policy_compliance.py similarity index 76% rename from gatehouse_app/models/mfa_policy_compliance.py rename to gatehouse_app/models/security/mfa_policy_compliance.py index 6ecd217..5c2de13 100644 --- a/gatehouse_app/models/mfa_policy_compliance.py +++ b/gatehouse_app/models/security/mfa_policy_compliance.py @@ -1,4 +1,4 @@ -"""MfaPolicyCompliance model.""" +"""MfaPolicyCompliance model — per-user per-organization MFA compliance tracking.""" from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel from gatehouse_app.utils.constants import MfaComplianceStatus @@ -7,7 +7,8 @@ from gatehouse_app.utils.constants import MfaComplianceStatus class MfaPolicyCompliance(BaseModel): """MFA policy compliance tracking per user per organization. - Tracks each user's MFA compliance state separately for each organization membership. + Tracks each user's MFA compliance state separately for each organization + membership. One row per (user, org) pair. """ __tablename__ = "mfa_policy_compliance" @@ -25,13 +26,13 @@ class MfaPolicyCompliance(BaseModel): default=MfaComplianceStatus.NOT_APPLICABLE, ) - # Snapshot of org policy at the time this record became active + # Snapshot of org policy version when this record became active policy_version = db.Column(db.Integer, nullable=False) # When policy started applying to this user applied_at = db.Column(db.DateTime, nullable=True) - # Final deadline for this user to comply (per user, not global) + # Final deadline for this user to comply deadline_at = db.Column(db.DateTime, nullable=True) # When they became compliant under this policy_version @@ -45,9 +46,7 @@ class MfaPolicyCompliance(BaseModel): notification_count = db.Column(db.Integer, nullable=False, default=0) __table_args__ = ( - db.UniqueConstraint( - "user_id", "organization_id", name="uix_user_org_compliance" - ), + db.UniqueConstraint("user_id", "organization_id", name="uix_user_org_compliance"), ) # Relationships @@ -58,9 +57,11 @@ class MfaPolicyCompliance(BaseModel): def __repr__(self): """String representation of MfaPolicyCompliance.""" - return f"" + return ( + f"" + ) def to_dict(self, exclude=None): """Convert to dictionary.""" - exclude = exclude or [] - return super().to_dict(exclude=exclude) \ No newline at end of file + return super().to_dict(exclude=exclude or []) diff --git a/gatehouse_app/models/organization_security_policy.py b/gatehouse_app/models/security/organization_security_policy.py similarity index 83% rename from gatehouse_app/models/organization_security_policy.py rename to gatehouse_app/models/security/organization_security_policy.py index 991b72d..593781a 100644 --- a/gatehouse_app/models/organization_security_policy.py +++ b/gatehouse_app/models/security/organization_security_policy.py @@ -39,15 +39,19 @@ class OrganizationSecurityPolicy(BaseModel): # Relationships organization = db.relationship( - "Organization", back_populates="security_policy", foreign_keys=[organization_id] + "Organization", + back_populates="security_policy", + foreign_keys=[organization_id], ) updated_by_user = db.relationship("User", foreign_keys=[updated_by_user_id]) def __repr__(self): """String representation of OrganizationSecurityPolicy.""" - return f"" + return ( + f"" + ) def to_dict(self, exclude=None): """Convert to dictionary.""" - exclude = exclude or [] - return super().to_dict(exclude=exclude) \ No newline at end of file + return super().to_dict(exclude=exclude or []) diff --git a/gatehouse_app/models/user_security_policy.py b/gatehouse_app/models/security/user_security_policy.py similarity index 67% rename from gatehouse_app/models/user_security_policy.py rename to gatehouse_app/models/security/user_security_policy.py index d765575..a96ef84 100644 --- a/gatehouse_app/models/user_security_policy.py +++ b/gatehouse_app/models/security/user_security_policy.py @@ -1,4 +1,4 @@ -"""UserSecurityPolicy model.""" +"""UserSecurityPolicy model — per-user MFA overrides.""" from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel from gatehouse_app.utils.constants import MfaRequirementOverride @@ -7,7 +7,7 @@ from gatehouse_app.utils.constants import MfaRequirementOverride class UserSecurityPolicy(BaseModel): """User security policy model for per-user MFA overrides. - Stores per user overrides of organization level MFA requirements. + Stores per-user overrides of organization-level MFA requirements. """ __tablename__ = "user_security_policies" @@ -25,29 +25,27 @@ class UserSecurityPolicy(BaseModel): default=MfaRequirementOverride.INHERIT, ) - # If override is REQUIRED and you want to force a specific factor set + # If override is REQUIRED, optionally force a specific factor set force_totp = db.Column(db.Boolean, nullable=False, default=False) force_webauthn = db.Column(db.Boolean, nullable=False, default=False) __table_args__ = ( - db.UniqueConstraint( - "user_id", "organization_id", name="uix_user_org_policy" - ), + db.UniqueConstraint("user_id", "organization_id", name="uix_user_org_policy"), ) # Relationships user = db.relationship( "User", back_populates="security_policies", foreign_keys=[user_id] ) - organization = db.relationship( - "Organization", foreign_keys=[organization_id] - ) + organization = db.relationship("Organization", foreign_keys=[organization_id]) def __repr__(self): """String representation of UserSecurityPolicy.""" - return f"" + return ( + f"" + ) def to_dict(self, exclude=None): """Convert to dictionary.""" - exclude = exclude or [] - return super().to_dict(exclude=exclude) \ No newline at end of file + return super().to_dict(exclude=exclude or []) diff --git a/gatehouse_app/models/ssh_ca/__init__.py b/gatehouse_app/models/ssh_ca/__init__.py new file mode 100644 index 0000000..d6932b3 --- /dev/null +++ b/gatehouse_app/models/ssh_ca/__init__.py @@ -0,0 +1,17 @@ +"""SSH/CA subpackage — certificate authorities, SSH keys, and certificates.""" +from gatehouse_app.models.ssh_ca.ca import CA, KeyType, CertType, CaType, CAPermission +from gatehouse_app.models.ssh_ca.ssh_key import SSHKey +from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus +from gatehouse_app.models.ssh_ca.certificate_audit_log import CertificateAuditLog + +__all__ = [ + "CA", + "KeyType", + "CertType", + "CaType", + "CAPermission", + "SSHKey", + "SSHCertificate", + "CertificateStatus", + "CertificateAuditLog", +] diff --git a/gatehouse_app/models/ssh_ca/ca.py b/gatehouse_app/models/ssh_ca/ca.py new file mode 100644 index 0000000..182d842 --- /dev/null +++ b/gatehouse_app/models/ssh_ca/ca.py @@ -0,0 +1,238 @@ +"""Certificate Authority (CA) model.""" +from enum import Enum +from datetime import datetime, timezone +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel + + +class KeyType(str, Enum): + """SSH CA key types.""" + + ED25519 = "ed25519" + RSA = "rsa" + ECDSA = "ecdsa" + + +class CertType(str, Enum): + """SSH certificate types.""" + + USER = "user" + HOST = "host" + + +class CaType(str, Enum): + """CA signing type — whether this CA signs user or host certificates.""" + + USER = "user" + HOST = "host" + + +class CA(BaseModel): + """Certificate Authority (CA) model for SSH certificate signing. + + Each organization can have multiple CAs for different purposes + (e.g., production vs. staging). Private keys are encrypted at rest + and should be protected with KMS. + """ + + __tablename__ = "cas" + + organization_id = db.Column( + db.String(36), + db.ForeignKey("organizations.id"), + nullable=True, # NULL for the global system-config CA + index=True, + ) + + # CA identity + name = db.Column(db.String(255), nullable=False) + description = db.Column(db.Text, nullable=True) + + # CA signing type: 'user' signs user certificates, 'host' signs host certs + ca_type = db.Column( + db.Enum(CaType, values_callable=lambda x: [e.value for e in x]), + default=CaType.USER, + nullable=False, + ) + + # Key type (ED25519, RSA, ECDSA) + key_type = db.Column( + db.Enum(KeyType, values_callable=lambda x: [e.value for e in x]), + default=KeyType.ED25519, + nullable=False, + ) + + # Private key — PEM-encoded, encrypted at rest by database/KMS + private_key = db.Column(db.Text, nullable=False) + + # Public key — PEM format + public_key = db.Column(db.Text, nullable=False) + + # SHA256 fingerprint of the public key + fingerprint = db.Column(db.String(255), nullable=False, unique=True) + + # CRL (Certificate Revocation List) configuration + crl_enabled = db.Column(db.Boolean, default=True, nullable=False) + crl_endpoint = db.Column(db.String(512), nullable=True) + + # Default certificate validity in hours (overridable per request) + default_cert_validity_hours = db.Column(db.Integer, default=1, nullable=False) + + # Maximum validity duration allowed + max_cert_validity_hours = db.Column(db.Integer, default=24, nullable=False) + + # CA status + is_active = db.Column(db.Boolean, default=True, nullable=False, index=True) + + # Key rotation tracking + rotated_at = db.Column(db.DateTime, nullable=True) + rotation_reason = db.Column(db.String(255), nullable=True) + + # Monotonically-increasing serial counter. Every cert this CA issues + # gets the next value so serials are unique, ordered, and auditable. + # Protected by a row-level SELECT … FOR UPDATE in get_next_serial(). + next_serial_number = db.Column(db.BigInteger, default=1, nullable=False) + + # Relationships + organization = db.relationship("Organization", back_populates="cas") + certificates = db.relationship( + "SSHCertificate", + back_populates="ca", + cascade="all, delete-orphan", + ) + permissions = db.relationship( + "CAPermission", + back_populates="ca", + cascade="all, delete-orphan", + ) + + __table_args__ = ( + db.Index("idx_ca_org_active", "organization_id", "is_active"), + ) + + def __repr__(self): + """String representation of CA.""" + return ( + f"" + ) + + def to_dict(self, exclude=None): + """Convert CA to dictionary, never exposing the private key.""" + exclude = exclude or [] + if "private_key" not in exclude: + exclude.append("private_key") + data = super().to_dict(exclude=exclude) + + # Add computed fields + data["total_certs"] = len([c for c in self.certificates if c.deleted_at is None]) + data["active_certs"] = len( + [c for c in self.certificates if c.deleted_at is None and not c.revoked] + ) + data["revoked_certs"] = len( + [c for c in self.certificates if c.deleted_at is None and c.revoked] + ) + return data + + def get_active_certificates(self) -> list: + """Get all active (non-revoked) certificates issued by this CA.""" + return [ + c for c in self.certificates if c.deleted_at is None and not c.revoked + ] + + def rotate_key( + self, + new_private_key: str, + new_public_key: str, + new_fingerprint: str, + reason: str = None, + ) -> None: + """Rotate the CA's key pair. + + This should only be done in carefully controlled circumstances. + All existing certificates remain valid but no new certificates can be + signed with the old key after rotation. + + Args: + new_private_key: New PEM-encoded private key + new_public_key: New PEM-encoded public key + new_fingerprint: SHA256 fingerprint of new public key + reason: Optional reason for rotation + """ + self.private_key = new_private_key + self.public_key = new_public_key + self.fingerprint = new_fingerprint + self.rotated_at = datetime.now(timezone.utc) # Bug fix: was datetime.utcnow() + self.rotation_reason = reason + self.save() + + def get_next_serial(self) -> int: + """Atomically increment and return the next certificate serial number. + + Uses a SELECT … FOR UPDATE row lock so concurrent requests never + receive the same serial. Must be called inside an active DB + transaction (i.e. before the final session.commit()). + + Returns: + int: The serial number to embed in the next certificate. + """ + # Re-fetch this CA row with an exclusive row lock + locked = ( + db.session.query(CA) + .with_for_update() + .filter_by(id=self.id) + .one() + ) + serial = locked.next_serial_number + locked.next_serial_number = serial + 1 + db.session.flush() # write increment; commit happens in the caller + return serial + + +class CAPermission(BaseModel): + """Per-user CA permission model. + + Controls which users are allowed to sign certificates against a specific CA. + When a CA has any permission rows, the signing endpoint enforces the list; + CAs with no rows are open to all org members (backwards-compatible default). + + Permission values: + sign – user may request certificate signing + admin – user may sign AND manage the CA (rotate keys, delete, etc.) + """ + + __tablename__ = "ca_permissions" + + ca_id = db.Column( + db.String(36), + db.ForeignKey("cas.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + user_id = db.Column( + db.String(36), + db.ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + + permission = db.Column(db.String(50), nullable=False, default="sign") + + # Relationships + ca = db.relationship("CA", back_populates="permissions") + user = db.relationship("User", back_populates="ca_permissions") + + __table_args__ = ( + db.UniqueConstraint("ca_id", "user_id", name="uix_ca_permission"), + ) + + def __repr__(self): + return ( + f"" + ) + + def to_dict(self, exclude=None): + data = super().to_dict(exclude=exclude or []) + data["permission"] = self.permission + return data diff --git a/gatehouse_app/models/ssh_ca/certificate_audit_log.py b/gatehouse_app/models/ssh_ca/certificate_audit_log.py new file mode 100644 index 0000000..02f24d3 --- /dev/null +++ b/gatehouse_app/models/ssh_ca/certificate_audit_log.py @@ -0,0 +1,91 @@ +"""Certificate audit log model — tracks SSH certificate lifecycle events.""" +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel + + +class CertificateAuditLog(BaseModel): + """Audit log for SSH certificate lifecycle events. + + Tracks all operations on SSH certificates: signing, revocation, validation, + etc. Kept separate from the general AuditLog to provide detailed certificate + operation tracking without polluting the main audit stream. + """ + + __tablename__ = "certificate_audit_logs" + + # Reference to the certificate + certificate_id = db.Column( + db.String(36), + db.ForeignKey("ssh_certificates.id"), + nullable=False, + index=True, + ) + + # The user who performed the action (null for system actions) + user_id = db.Column( + db.String(36), + db.ForeignKey("users.id"), + nullable=True, + index=True, + ) + + # Action type (e.g., "signed", "revoked", "validated", "requested") + action = db.Column(db.String(50), nullable=False, index=True) + + # Request details + ip_address = db.Column(db.String(45), nullable=True) + user_agent = db.Column(db.String(512), nullable=True) + request_id = db.Column(db.String(36), nullable=True) + + # Detailed message + message = db.Column(db.Text, nullable=True) + + # Additional context + extra_data = db.Column(db.JSON, nullable=True) + + # Outcome + success = db.Column(db.Boolean, default=True, nullable=False) + error_message = db.Column(db.Text, nullable=True) + + # Relationships + certificate = db.relationship("SSHCertificate", back_populates="audit_logs") + user = db.relationship("User") + + __table_args__ = ( + db.Index("idx_cert_audit_cert_action", "certificate_id", "action"), + db.Index("idx_cert_audit_user", "user_id", "created_at"), + ) + + def __repr__(self): + """String representation of CertificateAuditLog.""" + return ( + f"" + ) + + @classmethod + def log( + cls, + certificate_id: str, + action: str, + user_id: str = None, + **kwargs, + ) -> "CertificateAuditLog": + """Create a certificate audit log entry. + + Args: + certificate_id: ID of the certificate + action: Action type (e.g., "signed", "revoked") + user_id: ID of the user performing the action (optional) + **kwargs: Additional fields (ip_address, user_agent, message, etc.) + + Returns: + CertificateAuditLog instance + """ + log_entry = cls( + certificate_id=certificate_id, + action=action, + user_id=user_id, + **kwargs, + ) + log_entry.save() + return log_entry diff --git a/gatehouse_app/models/ssh_ca/ssh_certificate.py b/gatehouse_app/models/ssh_ca/ssh_certificate.py new file mode 100644 index 0000000..f226a69 --- /dev/null +++ b/gatehouse_app/models/ssh_ca/ssh_certificate.py @@ -0,0 +1,176 @@ +"""SSH Certificate model — signed SSH user/host certificates.""" +from enum import Enum +from datetime import datetime, timezone +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel +from gatehouse_app.models.ssh_ca.ca import CertType + + +class CertificateStatus(str, Enum): + """SSH certificate lifecycle status.""" + + REQUESTED = "requested" # Waiting for signing + ISSUED = "issued" # Signed and valid + REVOKED = "revoked" # Manually revoked + EXPIRED = "expired" # Validity period ended + SUPERSEDED = "superseded" # Replaced by newer certificate + + +class SSHCertificate(BaseModel): + """SSH Certificate model representing a signed SSH user/host certificate. + + Certificates are issued by a CA and associated with an SSH public key. + They include principals (access levels), validity periods, and standard + OpenSSH certificate metadata. + """ + + __tablename__ = "ssh_certificates" + + # Certificate relationships + ca_id = db.Column( + db.String(36), + db.ForeignKey("cas.id"), + nullable=False, + index=True, + ) + user_id = db.Column( + db.String(36), + db.ForeignKey("users.id"), + nullable=False, + index=True, + ) + ssh_key_id = db.Column( + db.String(36), + db.ForeignKey("ssh_keys.id"), + nullable=False, + index=True, + ) + + # Certificate content — full signed certificate in OpenSSH format + certificate = db.Column(db.Text, nullable=False) + + # Certificate metadata + serial = db.Column(db.String(255), nullable=False, unique=True, index=True) + key_id = db.Column(db.String(255), nullable=False) # Usually user email + cert_type = db.Column( + db.Enum(CertType, values_callable=lambda x: [e.value for e in x]), + default=CertType.USER, + nullable=False, + ) + + # Principals — JSON list, e.g., ["prod-servers", "dev-servers"] + principals = db.Column(db.JSON, nullable=False, default=list) + + # Validity period + valid_after = db.Column(db.DateTime, nullable=False) + valid_before = db.Column(db.DateTime, nullable=False) + + # Revocation status + revoked = db.Column(db.Boolean, default=False, nullable=False, index=True) + revoked_at = db.Column(db.DateTime, nullable=True) + revoke_reason = db.Column(db.String(255), nullable=True) + + # Status tracking + status = db.Column( + db.Enum(CertificateStatus, values_callable=lambda x: [e.value for e in x]), + default=CertificateStatus.ISSUED, + nullable=False, + index=True, + ) + + # Request metadata + request_ip = db.Column(db.String(45), nullable=True) + request_user_agent = db.Column(db.String(512), nullable=True) + + # Critical options — OpenSSH critical options (JSON) + critical_options = db.Column(db.JSON, nullable=True, default=dict) + + # Extensions — OpenSSH extensions (JSON) + extensions = db.Column(db.JSON, nullable=True, default=dict) + + # Relationships + ca = db.relationship("CA", back_populates="certificates") + user = db.relationship("User", back_populates="ssh_certificates") + ssh_key = db.relationship( + "SSHKey", + back_populates="certificates", + foreign_keys="SSHCertificate.ssh_key_id", + ) + audit_logs = db.relationship( + "CertificateAuditLog", + back_populates="certificate", + cascade="all, delete-orphan", + ) + + __table_args__ = ( + db.Index("idx_cert_user_status", "user_id", "status"), + db.Index("idx_cert_validity", "valid_after", "valid_before"), + db.Index("idx_cert_revoked", "revoked", "revoked_at"), + ) + + def __repr__(self): + """String representation of SSHCertificate.""" + return f"" + + def to_dict(self, exclude=None): + """Convert certificate to dictionary. + + The raw ``certificate`` blob is excluded by default (it is large and + callers can request it explicitly by removing it from the exclude list). + """ + exclude = exclude or [] + if "certificate" not in exclude: + exclude.append("certificate") + data = super().to_dict(exclude=exclude) + data["is_valid"] = self.is_valid() + data["days_until_expiry"] = self.days_until_expiry() + return data + + def _aware(self, dt: datetime) -> datetime: + """Return a timezone-aware UTC datetime.""" + return dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt + + def is_valid(self) -> bool: + """Check if certificate is currently valid. + + Returns: + True if certificate is issued, not revoked, and within validity period + """ + if self.revoked or self.status == CertificateStatus.REVOKED: + return False + now = datetime.now(timezone.utc) + return self._aware(self.valid_after) <= now <= self._aware(self.valid_before) + + def is_expired(self) -> bool: + """Check if certificate has expired. + + Returns: + True if current time is past valid_before + """ + return datetime.now(timezone.utc) > self._aware(self.valid_before) + + def days_until_expiry(self) -> int: + """Get number of days until certificate expires. + + Returns: + Number of days remaining (negative if already expired) + """ + delta = self._aware(self.valid_before) - datetime.now(timezone.utc) + return delta.days + (1 if delta.seconds > 0 else 0) + + def revoke(self, reason: str = None) -> None: + """Revoke this certificate. + + Args: + reason: Optional reason for revocation + """ + self.revoked = True + self.revoked_at = datetime.now(timezone.utc) # Bug fix: was datetime.utcnow() + self.revoke_reason = reason + self.status = CertificateStatus.REVOKED + self.save() + + def mark_expired(self) -> None: + """Mark certificate as expired when validity period ends.""" + self.status = CertificateStatus.EXPIRED + self.save() diff --git a/gatehouse_app/models/ssh_ca/ssh_key.py b/gatehouse_app/models/ssh_ca/ssh_key.py new file mode 100644 index 0000000..218fd99 --- /dev/null +++ b/gatehouse_app/models/ssh_ca/ssh_key.py @@ -0,0 +1,98 @@ +"""SSH Key model — user SSH public keys registered for certificate signing.""" +from datetime import datetime, timezone +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel + + +class SSHKey(BaseModel): + """SSH Key model representing a user's SSH public key. + + Users register SSH public keys for certificate signing. Keys must be + verified (owner proved possession) before they can be used. + """ + + __tablename__ = "ssh_keys" + + user_id = db.Column( + db.String(36), + db.ForeignKey("users.id"), + nullable=False, + index=True, + ) + + # SSH key payload in OpenSSH format (e.g., "ssh-ed25519 AAAAB3Nz...") + payload = db.Column(db.Text, nullable=False, unique=True) + + # SHA256 fingerprint for quick comparison and deduplication + fingerprint = db.Column(db.String(255), nullable=False, unique=True, index=True) + + # Optional human-readable description (e.g., "My laptop key") + description = db.Column(db.String(255), nullable=True) + + # Verification status + verified = db.Column(db.Boolean, default=False, nullable=False, index=True) + verified_at = db.Column(db.DateTime, nullable=True) + + # Verification challenge — shown to user once, cleared after verification + verify_text = db.Column(db.String(255), nullable=True) + verify_text_created_at = db.Column(db.DateTime, nullable=True) + + # Key metadata extracted from the key + key_type = db.Column(db.String(50), nullable=True) # ssh-rsa, ssh-ed25519, etc. + key_bits = db.Column(db.Integer, nullable=True) # key length + key_comment = db.Column(db.String(255), nullable=True) + + # Relationships + user = db.relationship("User", back_populates="ssh_keys") + certificates = db.relationship( + "SSHCertificate", + back_populates="ssh_key", + cascade="all, delete-orphan", + foreign_keys="SSHCertificate.ssh_key_id", + ) + + __table_args__ = ( + db.Index("idx_ssh_key_user_verified", "user_id", "verified"), + ) + + def __repr__(self): + """String representation of SSHKey.""" + return f"" + + def to_dict(self, exclude=None): + """Convert SSH key to dictionary. + + ``payload`` and ``verify_text`` are never exposed through the API. + """ + exclude = exclude or [] + for field in ("payload", "verify_text"): + if field not in exclude: + exclude.append(field) + data = super().to_dict(exclude=exclude) + data["cert_count"] = len([c for c in self.certificates if c.deleted_at is None]) + return data + + def mark_verified(self) -> None: + """Mark this SSH key as verified and clear the challenge.""" + self.verified = True + self.verified_at = datetime.now(timezone.utc) # Bug fix: was datetime.utcnow() + self.verify_text = None + self.save() + + def needs_verification_refresh(self, max_age_hours: int = 24) -> bool: + """Check if verification challenge needs to be refreshed. + + Args: + max_age_hours: Maximum age of verification challenge in hours + + Returns: + True if verification challenge is stale or missing + """ + if not self.verify_text_created_at: + return True + age = datetime.now(timezone.utc) - self.verify_text_created_at.replace( + tzinfo=timezone.utc + ) if self.verify_text_created_at.tzinfo is None else ( + datetime.now(timezone.utc) - self.verify_text_created_at + ) + return age.total_seconds() > (max_age_hours * 3600) diff --git a/gatehouse_app/models/user.py b/gatehouse_app/models/user.py index 474269f..eecf942 100644 --- a/gatehouse_app/models/user.py +++ b/gatehouse_app/models/user.py @@ -1,153 +1,4 @@ -"""User model.""" -from gatehouse_app.extensions import db -from gatehouse_app.models.base import BaseModel -from gatehouse_app.utils.constants import UserStatus +"""Backward-compatibility shim — import from gatehouse_app.models.user.user instead.""" +from gatehouse_app.models.user.user import User # noqa: F401 - -class User(BaseModel): - """User model representing a user account.""" - - __tablename__ = "users" - - email = db.Column(db.String(255), unique=True, nullable=False, index=True) - email_verified = db.Column(db.Boolean, default=False, nullable=False) - full_name = db.Column(db.String(255), nullable=True) - avatar_url = db.Column(db.String(512), nullable=True) - status = db.Column( - db.Enum(UserStatus), default=UserStatus.ACTIVE, nullable=False, index=True - ) - last_login_at = db.Column(db.DateTime, nullable=True) - last_login_ip = db.Column(db.String(45), nullable=True) - - # Relationships - authentication_methods = db.relationship( - "AuthenticationMethod", back_populates="user", cascade="all, delete-orphan" - ) - sessions = db.relationship("Session", back_populates="user", cascade="all, delete-orphan") - organization_memberships = db.relationship( - "OrganizationMember", - back_populates="user", - cascade="all, delete-orphan", - foreign_keys="OrganizationMember.user_id", - ) - audit_logs = db.relationship("AuditLog", back_populates="user", cascade="all, delete-orphan") - security_policies = db.relationship( - "UserSecurityPolicy", - back_populates="user", - cascade="all, delete-orphan", - foreign_keys="UserSecurityPolicy.user_id", - ) - mfa_compliance = db.relationship( - "MfaPolicyCompliance", - back_populates="user", - cascade="all, delete-orphan", - foreign_keys="MfaPolicyCompliance.user_id", - ) - - def __repr__(self): - """String representation of User.""" - return f"" - - def to_dict(self, exclude=None): - """Convert user to dictionary, excluding sensitive fields by default.""" - exclude = exclude or [] - # Always exclude password-related fields - default_exclude = [] - all_exclude = list(set(default_exclude + exclude)) - return super().to_dict(exclude=all_exclude) - - def has_password_auth(self): - """Check if user has password authentication enabled.""" - from gatehouse_app.models.authentication_method import AuthenticationMethod - from gatehouse_app.utils.constants import AuthMethodType - - return ( - AuthenticationMethod.query.filter_by( - user_id=self.id, method_type=AuthMethodType.PASSWORD, deleted_at=None - ).first() - is not None - ) - - def get_organizations(self): - """Get all organizations the user is a member of.""" - return [membership.organization for membership in self.organization_memberships] - - def has_totp_enabled(self) -> bool: - """Check if user has TOTP enabled and verified. - - Returns: - True if user has a verified TOTP authentication method, False otherwise. - """ - from gatehouse_app.models.authentication_method import AuthenticationMethod - from gatehouse_app.utils.constants import AuthMethodType - - return ( - AuthenticationMethod.query.filter_by( - user_id=self.id, - method_type=AuthMethodType.TOTP, - verified=True, - deleted_at=None, - ).first() - is not None - ) - - def get_totp_method(self): - """Get user's TOTP authentication method. - - Returns: - The AuthenticationMethod instance for TOTP or None if not found. - - Note: - Returns the most recently created TOTP method to handle cases where - multiple enrollment attempts may exist. - """ - from gatehouse_app.models.authentication_method import AuthenticationMethod - from gatehouse_app.utils.constants import AuthMethodType - - return AuthenticationMethod.query.filter_by( - user_id=self.id, method_type=AuthMethodType.TOTP, deleted_at=None - ).order_by(AuthenticationMethod.created_at.desc()).first() - - def has_webauthn_enabled(self) -> bool: - """Check if user has any WebAuthn passkey credentials. - - Returns: - True if user has at least one WebAuthn credential, False otherwise. - """ - from gatehouse_app.models.authentication_method import AuthenticationMethod - from gatehouse_app.utils.constants import AuthMethodType - - return ( - AuthenticationMethod.query.filter_by( - user_id=self.id, - method_type=AuthMethodType.WEBAUTHN, - deleted_at=None, - ).first() - is not None - ) - - def get_webauthn_credentials(self): - """Get all WebAuthn credentials for the user. - - Returns: - List of AuthenticationMethod instances for WebAuthn, ordered by creation date. - """ - from gatehouse_app.models.authentication_method import AuthenticationMethod - from gatehouse_app.utils.constants import AuthMethodType - - return AuthenticationMethod.query.filter_by( - user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None - ).order_by(AuthenticationMethod.created_at.desc()).all() - - def get_webauthn_credential_count(self) -> int: - """Get the count of WebAuthn credentials for the user. - - Returns: - Number of WebAuthn credentials. - """ - from gatehouse_app.models.authentication_method import AuthenticationMethod - from gatehouse_app.utils.constants import AuthMethodType - - return AuthenticationMethod.query.filter_by( - user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None - ).count() +__all__ = ["User"] diff --git a/gatehouse_app/models/user/__init__.py b/gatehouse_app/models/user/__init__.py new file mode 100644 index 0000000..1d05e9a --- /dev/null +++ b/gatehouse_app/models/user/__init__.py @@ -0,0 +1,5 @@ +"""User subpackage.""" +from gatehouse_app.models.user.user import User +from gatehouse_app.models.user.session import Session + +__all__ = ["User", "Session"] diff --git a/gatehouse_app/models/session.py b/gatehouse_app/models/user/session.py similarity index 86% rename from gatehouse_app/models/session.py rename to gatehouse_app/models/user/session.py index 0290a20..9a78830 100644 --- a/gatehouse_app/models/session.py +++ b/gatehouse_app/models/user/session.py @@ -21,7 +21,9 @@ class Session(BaseModel): # Timing expires_at = db.Column(db.DateTime, nullable=False) - last_activity_at = db.Column(db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc)) + last_activity_at = db.Column( + db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc) + ) revoked_at = db.Column(db.DateTime, nullable=True) revoked_reason = db.Column(db.String(255), nullable=True) @@ -38,7 +40,6 @@ class Session(BaseModel): def is_active(self): """Check if session is currently active.""" now = datetime.now(timezone.utc) - # Make expires_at timezone-aware if it's naive expires_at = self.expires_at if expires_at.tzinfo is None: expires_at = expires_at.replace(tzinfo=timezone.utc) @@ -51,15 +52,13 @@ class Session(BaseModel): def is_expired(self): """Check if session has expired.""" now = datetime.now(timezone.utc) - # Make expires_at timezone-aware if it's naive expires_at = self.expires_at if expires_at.tzinfo is None: expires_at = expires_at.replace(tzinfo=timezone.utc) return now > expires_at - def refresh(self, duration_seconds=86400): - """ - Refresh session expiration. + def refresh(self, duration_seconds: int = 86400): + """Refresh session expiration. Args: duration_seconds: New session duration in seconds @@ -68,9 +67,8 @@ class Session(BaseModel): self.last_activity_at = datetime.now(timezone.utc) db.session.commit() - def revoke(self, reason=None): - """ - Revoke the session. + def revoke(self, reason: str = None): + """Revoke the session. Args: reason: Optional reason for revocation @@ -84,6 +82,5 @@ class Session(BaseModel): def to_dict(self, exclude=None): """Convert to dictionary, excluding sensitive fields.""" exclude = exclude or [] - # Exclude token from dict exclude.append("token") return super().to_dict(exclude=exclude) diff --git a/gatehouse_app/models/user/user.py b/gatehouse_app/models/user/user.py new file mode 100644 index 0000000..c2fb1c8 --- /dev/null +++ b/gatehouse_app/models/user/user.py @@ -0,0 +1,209 @@ +"""User model.""" +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel +from gatehouse_app.utils.constants import UserStatus + + +class User(BaseModel): + """User model representing a user account.""" + + __tablename__ = "users" + + email = db.Column(db.String(255), unique=True, nullable=False, index=True) + email_verified = db.Column(db.Boolean, default=False, nullable=False) + full_name = db.Column(db.String(255), nullable=True) + avatar_url = db.Column(db.String(512), nullable=True) + status = db.Column( + db.Enum(UserStatus), default=UserStatus.ACTIVE, nullable=False, index=True + ) + last_login_at = db.Column(db.DateTime, nullable=True) + last_login_ip = db.Column(db.String(45), nullable=True) + + # Account activation (email-link flow) + activated = db.Column(db.Boolean, default=True, nullable=False) + activation_key = db.Column(db.String(128), unique=True, nullable=True, index=True) + + # Relationships – defined here only for models that don't circular-import + authentication_methods = db.relationship( + "AuthenticationMethod", back_populates="user", cascade="all, delete-orphan" + ) + sessions = db.relationship("Session", back_populates="user", cascade="all, delete-orphan") + organization_memberships = db.relationship( + "OrganizationMember", + back_populates="user", + cascade="all, delete-orphan", + foreign_keys="OrganizationMember.user_id", + ) + audit_logs = db.relationship("AuditLog", back_populates="user", cascade="all, delete-orphan") + security_policies = db.relationship( + "UserSecurityPolicy", + back_populates="user", + cascade="all, delete-orphan", + foreign_keys="UserSecurityPolicy.user_id", + ) + mfa_compliance = db.relationship( + "MfaPolicyCompliance", + back_populates="user", + cascade="all, delete-orphan", + foreign_keys="MfaPolicyCompliance.user_id", + ) + department_memberships = db.relationship( + "DepartmentMembership", + back_populates="user", + cascade="all, delete-orphan", + foreign_keys="DepartmentMembership.user_id", + ) + principal_memberships = db.relationship( + "PrincipalMembership", + back_populates="user", + cascade="all, delete-orphan", + foreign_keys="PrincipalMembership.user_id", + ) + ssh_keys = db.relationship( + "SSHKey", + back_populates="user", + cascade="all, delete-orphan", + foreign_keys="SSHKey.user_id", + ) + ssh_certificates = db.relationship( + "SSHCertificate", + back_populates="user", + cascade="all, delete-orphan", + foreign_keys="SSHCertificate.user_id", + ) + ca_permissions = db.relationship( + "CAPermission", + back_populates="user", + cascade="all, delete-orphan", + foreign_keys="CAPermission.user_id", + ) + + # OIDC relationships – registered here (no monkey-patching needed) + oidc_auth_codes = db.relationship( + "OIDCAuthCode", back_populates="user", cascade="all, delete-orphan" + ) + oidc_refresh_tokens = db.relationship( + "OIDCRefreshToken", back_populates="user", cascade="all, delete-orphan" + ) + oidc_sessions = db.relationship( + "OIDCSession", back_populates="user", cascade="all, delete-orphan" + ) + oidc_token_metadata = db.relationship( + "OIDCTokenMetadata", back_populates="user", cascade="all, delete-orphan" + ) + oidc_audit_logs = db.relationship( + "OIDCAuditLog", back_populates="user", cascade="all, delete-orphan" + ) + + def __repr__(self): + """String representation of User.""" + return f"" + + def to_dict(self, exclude=None): + """Convert user to dictionary, excluding sensitive fields by default.""" + exclude = exclude or [] + return super().to_dict(exclude=exclude) + + def has_password_auth(self): + """Check if user has password authentication enabled.""" + from gatehouse_app.models.auth.authentication_method import AuthenticationMethod + from gatehouse_app.utils.constants import AuthMethodType + + return ( + AuthenticationMethod.query.filter_by( + user_id=self.id, method_type=AuthMethodType.PASSWORD, deleted_at=None + ).first() + is not None + ) + + def get_organizations(self): + """Get all organizations the user is a member of.""" + return [membership.organization for membership in self.organization_memberships] + + def has_totp_enabled(self) -> bool: + """Check if user has TOTP enabled and verified. + + Returns: + True if user has a verified TOTP authentication method, False otherwise. + """ + from gatehouse_app.models.auth.authentication_method import AuthenticationMethod + from gatehouse_app.utils.constants import AuthMethodType + + return ( + AuthenticationMethod.query.filter_by( + user_id=self.id, + method_type=AuthMethodType.TOTP, + verified=True, + deleted_at=None, + ).first() + is not None + ) + + def get_totp_method(self): + """Get user's TOTP authentication method. + + Returns: + The AuthenticationMethod instance for TOTP or None if not found. + + Note: + Returns the most recently created TOTP method to handle cases where + multiple enrollment attempts may exist. + """ + from gatehouse_app.models.auth.authentication_method import AuthenticationMethod + from gatehouse_app.utils.constants import AuthMethodType + + return ( + AuthenticationMethod.query.filter_by( + user_id=self.id, method_type=AuthMethodType.TOTP, deleted_at=None + ) + .order_by(AuthenticationMethod.created_at.desc()) + .first() + ) + + def has_webauthn_enabled(self) -> bool: + """Check if user has any WebAuthn passkey credentials. + + Returns: + True if user has at least one WebAuthn credential, False otherwise. + """ + from gatehouse_app.models.auth.authentication_method import AuthenticationMethod + from gatehouse_app.utils.constants import AuthMethodType + + return ( + AuthenticationMethod.query.filter_by( + user_id=self.id, + method_type=AuthMethodType.WEBAUTHN, + deleted_at=None, + ).first() + is not None + ) + + def get_webauthn_credentials(self): + """Get all WebAuthn credentials for the user. + + Returns: + List of AuthenticationMethod instances for WebAuthn, ordered by creation date. + """ + from gatehouse_app.models.auth.authentication_method import AuthenticationMethod + from gatehouse_app.utils.constants import AuthMethodType + + return ( + AuthenticationMethod.query.filter_by( + user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None + ) + .order_by(AuthenticationMethod.created_at.desc()) + .all() + ) + + def get_webauthn_credential_count(self) -> int: + """Get the count of WebAuthn credentials for the user. + + Returns: + Number of WebAuthn credentials. + """ + from gatehouse_app.models.auth.authentication_method import AuthenticationMethod + from gatehouse_app.utils.constants import AuthMethodType + + return AuthenticationMethod.query.filter_by( + user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None + ).count() diff --git a/gatehouse_app/schemas/auth_schema.py b/gatehouse_app/schemas/auth_schema.py index b2042f7..dff1758 100644 --- a/gatehouse_app/schemas/auth_schema.py +++ b/gatehouse_app/schemas/auth_schema.py @@ -25,7 +25,7 @@ class LoginSchema(Schema): email = fields.Email(required=True) password = fields.Str(required=True, validate=validate.Length(min=1)) - remember_me = fields.Bool(missing=False) + remember_me = fields.Bool(load_default=False) class RefreshTokenSchema(Schema): @@ -77,14 +77,38 @@ class TOTPVerifyEnrollmentSchema(Schema): class TOTPVerifySchema(Schema): """Schema for TOTP code verification during login.""" - code = fields.Str(required=True) - is_backup_code = fields.Bool(missing=False) + code = fields.Str( + required=True, + validate=validate.Length(min=1), + ) + is_backup_code = fields.Bool(load_default=False) client_timestamp = fields.Int( required=False, allow_none=True, metadata={"description": "Client UTC timestamp in seconds since epoch for TOTP verification"}, ) + @validates_schema + def validate_code_format(self, data, **kwargs): + """Validate code format depending on whether it's a backup code.""" + code = data.get("code", "") + is_backup_code = data.get("is_backup_code", False) + if is_backup_code: + # Backup codes are 16 uppercase hex characters + if not code or len(code) != 16 or not all(c in "0123456789ABCDEFabcdef" for c in code): + raise ValidationError( + "Backup code must be a 16-character hexadecimal string.", + field_name="code", + ) + else: + # Regular TOTP codes are exactly 6 digits + import re + if not re.match(r"^\d{6}$", code): + raise ValidationError( + "Code must be a 6-digit number.", + field_name="code", + ) + class TOTPDisableSchema(Schema): """Schema for disabling TOTP.""" diff --git a/gatehouse_app/services/audit_service.py b/gatehouse_app/services/audit_service.py index 3978aa5..ed70f14 100644 --- a/gatehouse_app/services/audit_service.py +++ b/gatehouse_app/services/audit_service.py @@ -1,6 +1,6 @@ """Audit service.""" from flask import request, g -from gatehouse_app.models.audit_log import AuditLog +from gatehouse_app.models.auth.audit_log import AuditLog from gatehouse_app.utils.constants import AuditAction @@ -59,7 +59,7 @@ class AuditService: ip_address=ip_address, user_agent=user_agent, request_id=request_id, - metadata=metadata, + extra_data=metadata, description=description, success=success, error_message=error_message, diff --git a/gatehouse_app/services/auth_service.py b/gatehouse_app/services/auth_service.py index 3df5bf3..9061f33 100644 --- a/gatehouse_app/services/auth_service.py +++ b/gatehouse_app/services/auth_service.py @@ -5,9 +5,9 @@ from datetime import datetime, timedelta, timezone from typing import Optional from flask import request, g, current_app from gatehouse_app.extensions import db, bcrypt -from gatehouse_app.models.user import User -from gatehouse_app.models.authentication_method import AuthenticationMethod -from gatehouse_app.models.session import Session +from gatehouse_app.models.user.user import User +from gatehouse_app.models.auth.authentication_method import AuthenticationMethod +from gatehouse_app.models.user.session import Session from gatehouse_app.utils.constants import AuthMethodType, SessionStatus, UserStatus, AuditAction from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError, AccountSuspendedError, AccountInactiveError from gatehouse_app.exceptions.validation_exceptions import EmailAlreadyExistsError @@ -102,7 +102,7 @@ class AuthService: if current_app.config.get('ENV') == 'development': logger.debug(f"[Auth] Account status: user_id={user.id}, status={user.status}") - if user.status == UserStatus.SUSPENDED: + if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED): raise AccountSuspendedError() if user.status == UserStatus.INACTIVE: raise AccountInactiveError() @@ -210,6 +210,22 @@ class AuthService: auth_method.password_hash = bcrypt.generate_password_hash(new_password).decode("utf-8") db.session.commit() + # Invalidate all other sessions so that if an attacker had a valid + # session token, changing the password actually locks them out. + # The current request's session (if any) is preserved so the user + # doesn't have to log in again immediately. + from flask import g as flask_g + current_session_id = getattr(flask_g, "current_session", None) + current_session_id = current_session_id.id if current_session_id else None + sessions_to_revoke = Session.query.filter( + Session.user_id == user.id, + Session.revoked_at == None, # noqa: E711 + ).all() + for sess in sessions_to_revoke: + if sess.id != current_session_id: + sess.revoke(reason="Password changed") + db.session.commit() + # Log password change AuditService.log_action( action=AuditAction.PASSWORD_CHANGE, @@ -482,9 +498,24 @@ class AuthService: if not secret: raise InvalidCredentialsError("TOTP secret not found") + # Replay-attack prevention: reject codes that have already been + # accepted within the current validity window. + if TOTPService.is_code_already_used(str(user.id), code): + AuditService.log_action( + action=AuditAction.TOTP_VERIFY_FAILED, + user_id=user.id, + resource_type="authentication_method", + resource_id=auth_method.id, + description="TOTP code replay attempt detected", + ) + raise InvalidCredentialsError("Invalid TOTP code") + is_valid = TOTPService.verify_code(secret, code, client_utc_timestamp=client_utc_timestamp) if is_valid: + # Mark this code as used to prevent replay within the validity window + TOTPService.mark_code_used(str(user.id), code) + auth_method.last_used_at = datetime.now(timezone.utc) db.session.commit() diff --git a/gatehouse_app/services/external_auth_service.py b/gatehouse_app/services/external_auth_service.py index 89bfbf1..57cad76 100644 --- a/gatehouse_app/services/external_auth_service.py +++ b/gatehouse_app/services/external_auth_service.py @@ -8,7 +8,7 @@ from flask import current_app from gatehouse_app.extensions import db from gatehouse_app.models import User, AuthenticationMethod -from gatehouse_app.models.authentication_method import ( +from gatehouse_app.models.auth.authentication_method import ( OAuthState, ApplicationProviderConfig, OrganizationProviderOverride @@ -736,9 +736,14 @@ class ExternalAuthService: 400, ) - # Generate PKCE - code_verifier = secrets.token_urlsafe(32) - code_challenge = cls._compute_s256_challenge(code_verifier) + # Generate PKCE — skip for confidential clients (Google, Microsoft) that use a + # client_secret. Sending code_challenge to Microsoft causes it to enforce PKCE on + # the token exchange, which then fails. Matches the behaviour of initiate_login_flow. + code_verifier = None + code_challenge = None + if provider_type_str not in ('google', 'microsoft'): + code_verifier = secrets.token_urlsafe(32) + code_challenge = cls._compute_s256_challenge(code_verifier) # Create OAuth state state = OAuthState.create_state( @@ -1210,12 +1215,35 @@ class ExternalAuthService: else: email_verified = data.get("email_verified", False) + sub = data.get("sub") + + # Derive email from sub when the provider omits the email claim. + # This happens with some OIDC servers (including the nav-security mock) + # that only return the minimal {sub, iss, iat, exp} set. + # Rule: if sub looks like an email address, use it directly. + # Otherwise, construct a deterministic fallback so we never get NULL. + raw_email = data.get("email") + if not raw_email and sub: + import re as _re + if _re.match(r"^[^@\s]+@[^@\s]+\.[^@\s]+$", sub): + raw_email = sub + email_verified = True # if sub IS the email it's already verified + else: + # e.g. "12345" → "12345@google.local" so we can store it + raw_email = f"{sub}@{provider or 'oauth'}.local" + email_verified = False + + # Derive display name when omitted + raw_name = data.get("name") or data.get("display_name") + if not raw_name and raw_email: + raw_name = raw_email.split("@")[0] + # Standardize user info return { - "provider_user_id": data.get("sub"), - "email": data.get("email"), + "provider_user_id": sub, + "email": raw_email, "email_verified": email_verified, - "name": data.get("name"), + "name": raw_name, "first_name": data.get("given_name"), "last_name": data.get("family_name"), "picture": data.get("picture"), diff --git a/gatehouse_app/services/mfa_policy_service.py b/gatehouse_app/services/mfa_policy_service.py index d553fc1..d7b1316 100644 --- a/gatehouse_app/services/mfa_policy_service.py +++ b/gatehouse_app/services/mfa_policy_service.py @@ -4,11 +4,11 @@ from datetime import datetime, timezone from typing import Optional, List, Dict, Any from gatehouse_app.extensions import db -from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy -from gatehouse_app.models.user_security_policy import UserSecurityPolicy -from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance -from gatehouse_app.models.user import User -from gatehouse_app.models.organization import Organization +from gatehouse_app.models.security.organization_security_policy import OrganizationSecurityPolicy +from gatehouse_app.models.security.user_security_policy import UserSecurityPolicy +from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance +from gatehouse_app.models.user.user import User +from gatehouse_app.models.organization.organization import Organization from gatehouse_app.services.audit_service import AuditService from gatehouse_app.utils.constants import ( MfaPolicyMode, @@ -702,7 +702,7 @@ class MfaPolicyService: if now is None: now = datetime.now(timezone.utc) - from gatehouse_app.models.organization_member import OrganizationMember + from gatehouse_app.models.organization.organization_member import OrganizationMember updated_count = 0 diff --git a/gatehouse_app/services/notification_service.py b/gatehouse_app/services/notification_service.py index d3afdca..fc9bbe0 100644 --- a/gatehouse_app/services/notification_service.py +++ b/gatehouse_app/services/notification_service.py @@ -19,9 +19,9 @@ import logging import json from gatehouse_app.extensions import db -from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance -from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy -from gatehouse_app.models.user import User +from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance +from gatehouse_app.models.security.organization_security_policy import OrganizationSecurityPolicy +from gatehouse_app.models.user.user import User from gatehouse_app.services.audit_service import AuditService from gatehouse_app.utils.constants import AuditAction @@ -37,6 +37,7 @@ class NotificationService: SMTP_PORT_KEY = "SMTP_PORT" SMTP_USERNAME_KEY = "SMTP_USERNAME" SMTP_PASSWORD_KEY = "SMTP_PASSWORD" + SMTP_USE_TLS_KEY = "SMTP_USE_TLS" FROM_ADDRESS_KEY = "FROM_ADDRESS" @staticmethod @@ -86,10 +87,9 @@ class NotificationService: if success: logger.info( f"Sent MFA deadline reminder to {user.email} " - f"({days_until_deadline} days remaining # Audit log -)" + f"({days_until_deadline} days remaining)" ) - AuditService.log_action( + AuditService.log_action( action=AuditAction.MFA_POLICY_USER_COMPLIANT, user_id=user.id, organization_id=compliance.organization_id, @@ -291,101 +291,62 @@ Gatehouse Security Team body: str, html_body: Optional[str] = None, ) -> bool: - """Send an email notification. + """Send an email via SMTP. - This method attempts to send an email using configured SMTP settings. - If email is not configured, it logs the notification instead. - - Args: - to_address: Recipient email address - subject: Email subject - body: Plain text email body - html_body: Optional HTML email body - - Returns: - True if email was sent (or logged), False on error + Returns True if the email was sent successfully, False otherwise. + If EMAIL_ENABLED is False, logs the email body instead (simulation mode). """ + import smtplib + from email.mime.multipart import MIMEMultipart + from email.mime.text import MIMEText + from flask import current_app + + email_enabled = current_app.config.get(NotificationService.EMAIL_ENABLED_KEY, False) + + if not email_enabled: + logger.info( + f"[EMAIL DISABLED] Would have sent to: {to_address} | Subject: {subject}\n" + f"Body: {body[:500]}" + ) + return False + + smtp_host = current_app.config.get(NotificationService.SMTP_HOST_KEY, "localhost") + smtp_port = int(current_app.config.get(NotificationService.SMTP_PORT_KEY, 587)) + smtp_username = current_app.config.get(NotificationService.SMTP_USERNAME_KEY) + smtp_password = current_app.config.get(NotificationService.SMTP_PASSWORD_KEY) + smtp_use_tls = current_app.config.get( + NotificationService.SMTP_USE_TLS_KEY, + smtp_port not in (25, 1025), + ) + from_address = current_app.config.get( + NotificationService.FROM_ADDRESS_KEY, "noreply@gatehouse.local" + ) + try: - from flask import current_app + msg = MIMEMultipart("alternative") + msg["Subject"] = subject + msg["From"] = from_address + msg["To"] = to_address + msg.attach(MIMEText(body, "plain")) + if html_body: + msg.attach(MIMEText(html_body, "html")) - # Check if email is configured - email_enabled = current_app.config.get( - NotificationService.EMAIL_ENABLED_KEY, False - ) - - if not email_enabled: - # Log the notification instead of sending - logger.info( - f"[EMAIL SIMULATION] To: {to_address}\n" - f"Subject: {subject}\n" - f"Body: {body[:200]}..." if len(body) > 200 else f"Body: {body}" - ) - return True - - # Get email configuration - smtp_host = current_app.config.get(NotificationService.SMTP_HOST_KEY) - smtp_port = current_app.config.get(NotificationService.SMTP_PORT_KEY, 587) - smtp_username = current_app.config.get(NotificationService.SMTP_USERNAME_KEY) - smtp_password = current_app.config.get(NotificationService.SMTP_PASSWORD_KEY) - from_address = current_app.config.get( - NotificationService.FROM_ADDRESS_KEY, "noreply@gatehouse.local" - ) - - # Import send_email based on available mail library - try: - from flask_mail import Message - - from gatehouse_app import mail - - msg = Message( - subject=subject, - recipients=[to_address], - body=body, - html=html_body, - sender=from_address, - ) - mail.send(msg) - logger.info(f"Email sent successfully to {to_address}") - return True - - except ImportError: - # Flask-Mail not available, use SMTP directly - import smtplib - from email.mime.text import MIMEText - from email.mime.multipart import MIMEMultipart - - msg = MIMEMultipart("alternative") - msg["Subject"] = subject - msg["From"] = from_address - msg["To"] = to_address - - # Attach plain text and HTML versions - part1 = MIMEText(body, "plain") - msg.attach(part1) - - if html_body: - part2 = MIMEText(html_body, "html") - msg.attach(part2) - - # Send via SMTP - with smtplib.SMTP(smtp_host, smtp_port) as server: + with smtplib.SMTP(smtp_host, smtp_port) as server: + server.ehlo() + if smtp_use_tls: server.starttls() - if smtp_username and smtp_password: - server.login(smtp_username, smtp_password) - server.send_message(msg) + server.ehlo() + if smtp_username and smtp_password: + server.login(smtp_username, smtp_password) + server.send_message(msg) - logger.info(f"Email sent successfully to {to_address}") - return True + logger.info(f"[EMAIL] Sent to {to_address} | Subject: {subject}") + return True except Exception as e: - logger.exception(f"Failed to send email to {to_address}: {e}") - # Log the notification as fallback - logger.info( - f"[EMAIL FALLBACK] To: {to_address}\n" - f"Subject: {subject}\n" - f"Body: {body[:500]}..." if len(body) > 500 else f"Body: {body}" - ) - return True # Return True to continue processing + logger.error(f"[EMAIL] Failed to send to {to_address}: {e}") + return False + @staticmethod def get_notification_stats(user_id: str) -> Dict[str, Any]: @@ -397,7 +358,7 @@ Gatehouse Security Team Returns: Dictionary with notification statistics """ - from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance + from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance stats = { "total_notifications": 0, diff --git a/gatehouse_app/services/oauth_flow_service.py b/gatehouse_app/services/oauth_flow_service.py index 89b1745..2889a7c 100644 --- a/gatehouse_app/services/oauth_flow_service.py +++ b/gatehouse_app/services/oauth_flow_service.py @@ -9,10 +9,10 @@ from flask import current_app, request, g, redirect from gatehouse_app.extensions import db from gatehouse_app.models import User, AuthenticationMethod -from gatehouse_app.models.authentication_method import OAuthState +from gatehouse_app.models.auth.authentication_method import OAuthState from gatehouse_app.models.base import BaseModel -from gatehouse_app.models.oidc_authorization_code import OIDCAuthCode -from gatehouse_app.utils.constants import AuthMethodType +from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode +from gatehouse_app.utils.constants import AuthMethodType, AuditAction from gatehouse_app.services.audit_service import AuditService from gatehouse_app.services.external_auth_service import ( ExternalAuthService, @@ -139,7 +139,7 @@ class OAuthFlowService: except ExternalAuthError as e: # Log failed initiation AuditService.log_action( - action="external_auth.login.initiated", + action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED, organization_id=organization_id, metadata={ "provider_type": provider_type_str, @@ -236,7 +236,7 @@ class OAuthFlowService: except ExternalAuthError as e: AuditService.log_action( - action="external_auth.register.initiated", + action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED, organization_id=organization_id, metadata={ "provider_type": provider_type_str, @@ -399,6 +399,27 @@ class OAuthFlowService: access_token=tokens["access_token"], ) + if not user_info.get("provider_user_id"): + raise OAuthFlowError( + "Provider did not return a user identifier (sub claim). " + "Cannot complete authentication.", + "MISSING_PROVIDER_USER_ID", + 400, + ) + + if not user_info.get("email"): + raise OAuthFlowError( + "Provider did not return an email address. " + "Cannot complete authentication.", + "MISSING_EMAIL", + 400, + ) + + logger.debug( + f"Got user_info from provider: sub={user_info['provider_user_id']}, " + f"email={user_info['email']}, email_verified={user_info.get('email_verified')}" + ) + # Look up user by provider_user_id auth_method = AuthenticationMethod.query.filter_by( method_type=provider_type, @@ -494,25 +515,52 @@ class OAuthFlowService: if not target_org and len(user_orgs) == 1: target_org = user_orgs[0] - # Priority 3: No orgs at all — auto-create a personal org and log in + # Priority 3: No orgs at all — send to org-setup instead of auto-creating if not target_org and len(user_orgs) == 0: - import re - import uuid - from gatehouse_app.services.organization_service import OrganizationService - org_name = f"{user_info.get('name') or user.email.split('@')[0]}'s Workspace" - # Build a URL-safe slug and ensure uniqueness with a short suffix - base_slug = re.sub(r"[^a-z0-9]+", "-", org_name.lower()).strip("-")[:40] - slug = f"{base_slug}-{uuid.uuid4().hex[:6]}" - org = OrganizationService.create_organization( - name=org_name, - slug=slug, - owner_user_id=user.id, - ) - target_org = org + from gatehouse_app.models.organization.org_invite_token import OrgInviteToken + from gatehouse_app.services.auth_service import AuthService as _AS + _now = datetime.now(timezone.utc) + _session = _AS.create_session(user=user, is_compliance_only=False) + _session_dict = _session.to_dict() + _session_dict["token"] = _session.token + _expires_at = _session.expires_at + if _expires_at.tzinfo is None: + _expires_at = _expires_at.replace(tzinfo=timezone.utc) + _session_dict["expires_in"] = int((_expires_at - _now).total_seconds()) + + _pending = OrgInviteToken.query.filter( + OrgInviteToken.email == user.email, + OrgInviteToken.accepted_at.is_(None), + OrgInviteToken.expires_at > _now, + OrgInviteToken.deleted_at.is_(None), + ).all() + _pending_list = [ + { + "token": inv.token, + "organization": { + "id": str(inv.organization_id), + "name": inv.organization.name, + }, + "role": inv.role, + "expires_at": inv.expires_at.isoformat(), + } + for inv in _pending + ] + + state_record.mark_used() logger.info( - f"OAuth login: auto-created org '{org.name}' (id={org.id}) " - f"for new user {user.id}" + f"OAuth login: user {user.id} has no org, redirecting to org-setup " + f"(pending_invites={len(_pending_list)})" ) + return { + "success": True, + "flow_type": "login", + "requires_org_creation": True, + "user": {"id": user.id, "email": user.email, "full_name": user.full_name}, + "session": _session_dict, + "pending_invites": _pending_list, + "state": state_record.state, + } # Priority 4: Multiple orgs — need user to pick one if not target_org: @@ -755,7 +803,7 @@ class OAuthFlowService: # If organization_id hint was provided and valid, create session for that org if state_record.organization_id: - from gatehouse_app.models.organization import Organization + from gatehouse_app.models.organization.organization import Organization org = Organization.query.get(state_record.organization_id) if org: from gatehouse_app.services.auth_service import AuthService @@ -784,7 +832,40 @@ class OAuthFlowService: "session": session_dict, } - # No organization hint or invalid - need to create/select org + # No organization hint or invalid - need to create/select org. + # Still create a session so the frontend can call /organizations + # and /invites after redirecting to /org-setup. + from gatehouse_app.services.auth_service import AuthService as _AS + from gatehouse_app.models.organization.org_invite_token import OrgInviteToken + _session = _AS.create_session(user=user, is_compliance_only=False) + _session_dict = _session.to_dict() + _session_dict["token"] = _session.token + _expires_at = _session.expires_at + if _expires_at.tzinfo is None: + _expires_at = _expires_at.replace(tzinfo=timezone.utc) + _now = datetime.now(timezone.utc) + _session_dict["expires_in"] = int((_expires_at - _now).total_seconds()) + + # Surface pending invitations so the UI can offer "join vs create" + _pending = OrgInviteToken.query.filter( + OrgInviteToken.email == user.email, + OrgInviteToken.accepted_at.is_(None), + OrgInviteToken.expires_at > _now, + OrgInviteToken.deleted_at.is_(None), + ).all() + _pending_list = [ + { + "token": inv.token, + "organization": { + "id": str(inv.organization_id), + "name": inv.organization.name, + }, + "role": inv.role, + "expires_at": inv.expires_at.isoformat(), + } + for inv in _pending + ] + return { "success": True, "flow_type": "register", @@ -794,6 +875,8 @@ class OAuthFlowService: "email": user.email, "full_name": user.full_name, }, + "session": _session_dict, + "pending_invites": _pending_list, "state": state_record.state, } @@ -966,8 +1049,8 @@ class OAuthFlowService: ) # Determine organization - from gatehouse_app.models.organization import Organization - from gatehouse_app.models.organization_member import OrganizationMember + from gatehouse_app.models.organization.organization import Organization + from gatehouse_app.models.organization.organization_member import OrganizationMember # Get user's organizations user_orgs = user.get_organizations() diff --git a/gatehouse_app/services/oidc_jwks_service.py b/gatehouse_app/services/oidc_jwks_service.py index c8423ef..269dc08 100644 --- a/gatehouse_app/services/oidc_jwks_service.py +++ b/gatehouse_app/services/oidc_jwks_service.py @@ -8,7 +8,7 @@ from typing import Dict, List, Optional, Tuple from flask import current_app from gatehouse_app.extensions import db -from gatehouse_app.models.oidc_jwks_key import OidcJwksKey +from gatehouse_app.models.oidc.oidc_jwks_key import OidcJwksKey class JWKSKey: diff --git a/gatehouse_app/services/oidc_service.py b/gatehouse_app/services/oidc_service.py index 6157c21..fc6e317 100644 --- a/gatehouse_app/services/oidc_service.py +++ b/gatehouse_app/services/oidc_service.py @@ -14,7 +14,7 @@ from gatehouse_app.models import ( User, OIDCClient, OIDCAuthCode, OIDCRefreshToken, OIDCSession, OIDCTokenMetadata ) -from gatehouse_app.models.organization_member import OrganizationMember +from gatehouse_app.models.organization.organization_member import OrganizationMember from gatehouse_app.exceptions.validation_exceptions import ( ValidationError, NotFoundError, BadRequestError ) diff --git a/gatehouse_app/services/oidc_token_service.py b/gatehouse_app/services/oidc_token_service.py index 3c1c929..5605d5f 100644 --- a/gatehouse_app/services/oidc_token_service.py +++ b/gatehouse_app/services/oidc_token_service.py @@ -11,7 +11,7 @@ import jwt from flask import current_app, g from gatehouse_app.models import User, OIDCClient -from gatehouse_app.models.organization_member import OrganizationMember +from gatehouse_app.models.organization.organization_member import OrganizationMember from gatehouse_app.services.oidc_jwks_service import OIDCJWKSService logger = logging.getLogger(__name__) diff --git a/gatehouse_app/services/organization_service.py b/gatehouse_app/services/organization_service.py index e802b84..27ec7f4 100644 --- a/gatehouse_app/services/organization_service.py +++ b/gatehouse_app/services/organization_service.py @@ -3,8 +3,8 @@ import logging from datetime import datetime, timezone from flask import current_app from gatehouse_app.extensions import db -from gatehouse_app.models.organization import Organization -from gatehouse_app.models.organization_member import OrganizationMember +from gatehouse_app.models.organization.organization import Organization +from gatehouse_app.models.organization.organization_member import OrganizationMember from gatehouse_app.exceptions.validation_exceptions import OrganizationNotFoundError, ConflictError from gatehouse_app.utils.constants import OrganizationRole, AuditAction from gatehouse_app.services.audit_service import AuditService @@ -188,11 +188,10 @@ class OrganizationService: Raises: ConflictError: If user is already a member """ - # Check if already a member + # Check if already a member (active or soft-deleted — both blocked by DB unique constraint) existing = OrganizationMember.query.filter_by( user_id=user_id, organization_id=org.id, - deleted_at=None, ).first() # Development-only debug logging for membership validation @@ -200,6 +199,25 @@ class OrganizationService: logger.debug(f"[Org] Member check: org_id={org.id}, user_id={user_id}, already_member={existing is not None}") if existing: + if existing.deleted_at is not None: + # Reactivate the soft-deleted membership with the new role + existing.deleted_at = None + existing.role = role + existing.invited_by_id = inviter_id + existing.invited_at = datetime.now(timezone.utc) + existing.joined_at = datetime.now(timezone.utc) + existing.save() + + AuditService.log_action( + action=AuditAction.ORG_MEMBER_ADD, + user_id=inviter_id, + organization_id=org.id, + resource_type="organization_member", + resource_id=existing.id, + metadata={"added_user_id": user_id, "role": role.value}, + description=f"Member re-added to organization with role: {role.value}", + ) + return existing raise ConflictError("User is already a member of this organization") # Create membership diff --git a/gatehouse_app/services/session_service.py b/gatehouse_app/services/session_service.py index 7186222..7103285 100644 --- a/gatehouse_app/services/session_service.py +++ b/gatehouse_app/services/session_service.py @@ -1,6 +1,6 @@ """Session service.""" from datetime import datetime, timezone -from gatehouse_app.models.session import Session +from gatehouse_app.models.user.session import Session from gatehouse_app.utils.constants import SessionStatus @@ -17,7 +17,7 @@ class SessionService: Returns: Session object if found and active, None otherwise """ - from gatehouse_app.models.session import Session + from gatehouse_app.models.user.session import Session from gatehouse_app.utils.constants import SessionStatus return Session.query.filter_by( token=token, diff --git a/gatehouse_app/services/ssh_ca_signing_service.py b/gatehouse_app/services/ssh_ca_signing_service.py new file mode 100644 index 0000000..1502ab8 --- /dev/null +++ b/gatehouse_app/services/ssh_ca_signing_service.py @@ -0,0 +1,359 @@ +"""SSH Certificate Authority signing service. + +Handles SSH certificate signing operations, leveraging sshkey-tools library. +This service is a Gatehouse-integrated version of the secuird/ssh_ca.py logic. +""" +import logging +import os +from datetime import datetime, timedelta, timezone +from typing import List, Dict, Any, Optional + +from sshkey_tools.cert import SSHCertificate, CertificateFields +from sshkey_tools.keys import PublicKey, PrivateKey + +from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config +from gatehouse_app.exceptions import SSHCAError, ValidationError +from gatehouse_app.utils.crypto import compute_ssh_fingerprint + +logger = logging.getLogger(__name__) + + +class SSHCASigningError(Exception): + """SSH CA signing operation error.""" + pass + + +class SSHCertificateSigningRequest: + """Represents an SSH certificate signing request.""" + + def __init__( + self, + ssh_public_key: str, + principals: List[str], + key_id: str, + cert_type: str = "user", + expiry_hours: Optional[int] = None, + critical_options: Optional[Dict[str, str]] = None, + extensions: Optional[List[str]] = None, + ): + """Initialize signing request. + + Args: + ssh_public_key: Public key in OpenSSH format (e.g., "ssh-ed25519 AAAA...") + principals: List of principals (e.g., ["prod-servers", "staging"]) + key_id: Key identifier (usually user email) + cert_type: Certificate type - "user" or "host" (default: user) + expiry_hours: Certificate validity in hours + critical_options: Critical options dict + extensions: List of extensions (e.g., ["permit-pty", "permit-agent-forwarding"]) + """ + self.ssh_public_key = ssh_public_key + self.principals = principals or [] + self.key_id = key_id + self.cert_type = cert_type + self.expiry_hours = expiry_hours + self.critical_options = critical_options or {} + self.extensions = extensions or [] + + def validate(self) -> List[str]: + """Validate the signing request. + + Returns: + List of validation errors (empty if valid) + """ + errors = [] + config = get_ssh_ca_config() + + # Validate cert type + if self.cert_type not in ("user", "host"): + errors.append(f"Invalid cert_type: {self.cert_type}. Must be 'user' or 'host'") + + # Validate SSH public key + if not self.ssh_public_key or len(self.ssh_public_key) < 16: + errors.append("SSH public key is missing or invalid") + else: + try: + PublicKey.from_string(self.ssh_public_key) + except Exception as e: + errors.append(f"SSH public key is not valid: {str(e)}") + + # Validate principals + if not self.principals or len(self.principals) == 0: + errors.append("At least one principal is required") + else: + max_principals = config.get_int('max_principals_per_cert') + if len(self.principals) > max_principals: + errors.append( + f"Too many principals ({len(self.principals)}). " + f"Maximum is {max_principals}" + ) + + # Validate key_id + if not self.key_id or len(self.key_id) < 5: + errors.append("key_id is missing or too short (minimum 5 characters)") + else: + max_id_len = config.get_int('max_key_id_length') + if len(self.key_id) > max_id_len: + errors.append(f"key_id exceeds maximum length of {max_id_len}") + + # Validate expiry_hours + if self.expiry_hours is not None: + if not isinstance(self.expiry_hours, int) or self.expiry_hours <= 0: + errors.append("expiry_hours must be a positive integer") + else: + max_validity = config.get_int('max_cert_validity_hours') + if self.expiry_hours > max_validity: + errors.append( + f"Requested expiry ({self.expiry_hours}h) exceeds " + f"maximum allowed ({max_validity}h)" + ) + + return errors + + +class SSHCertificateSigningResponse: + """Represents a signed SSH certificate response.""" + + def __init__( + self, + certificate: str, + serial: str, + valid_after: datetime, + valid_before: datetime, + principals: Optional[List[str]] = None, + ): + """Initialize signing response. + + Args: + certificate: Full certificate in OpenSSH format + serial: Certificate serial number + valid_after: Validity start datetime + valid_before: Validity end datetime + principals: List of principals the cert was issued for + """ + self.certificate = certificate + self.serial = serial + self.valid_after = valid_after + self.valid_before = valid_before + self.principals = principals or [] + + def to_dict(self) -> Dict[str, Any]: + """Convert response to dictionary.""" + return { + 'certificate': self.certificate, + 'serial': self.serial, + 'valid_after': self.valid_after.isoformat(), + 'valid_before': self.valid_before.isoformat(), + } + + +class SSHCASigningService: + """Service for signing SSH certificates. + + This service handles all SSH certificate signing operations. + It uses configuration from ssh_ca_config to apply rules and limits. + """ + + def __init__(self): + """Initialize the SSH CA signing service.""" + self.config = get_ssh_ca_config() + self.logger = logger + + def _load_ca_key_from_config(self) -> str: + """Load CA private key from config (local file or env var). + + Returns: + CA private key in PEM/OpenSSH format as string + + Raises: + SSHCASigningError: If key cannot be loaded + """ + # Check env var first + key_content = os.environ.get('SSH_CA_PRIVATE_KEY') + if key_content: + return key_content + + # Load from file path + key_path = self.config.get_str('ca_key_path', '').strip() + if not key_path: + raise SSHCASigningError( + "CA private key not configured. Set SSH_CA_PRIVATE_KEY env var " + "or ca_key_path in etc/ssh_ca.conf" + ) + + key_path = os.path.expandvars(os.path.expanduser(key_path)) + if not os.path.exists(key_path): + raise SSHCASigningError(f"CA private key file not found: {key_path}") + + with open(key_path, 'r') as f: + return f.read() + + def sign_certificate( + self, + signing_request: SSHCertificateSigningRequest, + ca_private_key: Optional[str] = None, + ca_obj=None, + ) -> "SSHCertificateSigningResponse": + """Sign an SSH certificate. + + Args: + signing_request: SSHCertificateSigningRequest instance + ca_private_key: CA private key in PEM format. If not provided, + loaded from config (ca_key_path or SSH_CA_PRIVATE_KEY env var) + ca_obj: Optional CA model instance. When supplied its monotonic + serial counter is incremented atomically (SELECT FOR UPDATE) + and the resulting integer is embedded in the certificate's + serial field. This ensures every issued cert has a unique, + ordered, auditable serial number. + + Returns: + SSHCertificateSigningResponse with signed certificate + + Raises: + SSHCASigningError: If signing fails + ValidationError: If request is invalid + """ + # Validate request + errors = signing_request.validate() + if errors: + error_msg = "; ".join(errors) + self.logger.error(f"Certificate signing validation failed: {error_msg}") + raise ValidationError(f"Certificate signing validation failed: {error_msg}") + + # Load CA key if not provided + if ca_private_key is None: + ca_private_key = self._load_ca_key_from_config() + + try: + # Parse CA private key + try: + ca_key = PrivateKey.from_string(ca_private_key) + except Exception as e: + self.logger.error(f"Failed to load CA private key: {str(e)}") + raise SSHCASigningError(f"Invalid CA private key: {str(e)}") + + # Parse user's public key + try: + user_pub_key = PublicKey.from_string(signing_request.ssh_public_key) + except Exception as e: + self.logger.error(f"Failed to parse user public key: {str(e)}") + raise SSHCASigningError(f"Invalid user public key: {str(e)}") + + # Create certificate + certificate = SSHCertificate.create( + subject_pubkey=user_pub_key, + ca_privkey=ca_key, + ) + + # Set validity period + now = datetime.now(timezone.utc) + expiry_hours = signing_request.expiry_hours or self.config.get_int('cert_validity_hours') + valid_before = now + timedelta(hours=expiry_hours) + + # Set certificate fields + # sshkey-tools: user=1, host=2 (not 0) + cert_type = 1 if signing_request.cert_type == "user" else 2 + + certificate.fields.cert_type = cert_type + certificate.fields.key_id = signing_request.key_id + certificate.fields.principals = signing_request.principals + certificate.fields.valid_after = now + certificate.fields.valid_before = valid_before + + # ── Serial number ──────────────────────────────────────────────── + # If a CA object is provided, use its monotonic counter so every + # certificate gets a unique, ordered, auditable serial. The + # counter increment is flushed inside get_next_serial(); the + # caller's commit() persists it atomically with the cert record. + if ca_obj is not None: + assigned_serial = ca_obj.get_next_serial() + certificate.fields.serial = assigned_serial + self.logger.debug( + f"Assigned serial {assigned_serial} from CA {ca_obj.id}" + ) + # ───────────────────────────────────────────────────────────────── + + # Set extensions — prefer policy-provided list, fall back to standard set + extensions = signing_request.extensions + if not extensions: + from gatehouse_app.models.organization.department_cert_policy import STANDARD_EXTENSIONS + extensions = list(STANDARD_EXTENSIONS) + + certificate.fields.extensions = extensions + certificate.fields.critical_options = signing_request.critical_options or {} + + # Validate certificate before signing + if not certificate.can_sign(): + raise SSHCASigningError("Certificate cannot be signed") + + # Sign the certificate + certificate.sign() + + # Verify the certificate + try: + certificate.verify(ca_key.public_key, raise_on_error=True) + except Exception as e: + self.logger.error(f"Certificate verification failed: {str(e)}") + raise SSHCASigningError(f"Certificate verification failed: {str(e)}") + + # Extract serial from certificate — use the integer we assigned + # when ca_obj was provided, otherwise fall back to whatever the + # library generated. + if ca_obj is not None: + serial = str(assigned_serial) + else: + serial = str(certificate.fields.serial).split(":")[-1].strip() if hasattr(certificate.fields.serial, '__str__') else str(certificate.fields.serial) + + # Build response + cert_string = certificate.to_string() + + self.logger.info( + f"Successfully signed certificate: serial={serial}, " + f"key_id={signing_request.key_id}, principals={signing_request.principals}" + ) + + return SSHCertificateSigningResponse( + certificate=cert_string, + serial=serial, + valid_after=now, + valid_before=valid_before, + principals=signing_request.principals, + ) + + except (SSHCASigningError, ValidationError): + raise + except Exception as e: + self.logger.error(f"Unexpected error during certificate signing: {str(e)}", exc_info=True) + raise SSHCASigningError(f"Error signing certificate: {str(e)}") + + def verify_ca_key(self, ca_private_key: str) -> Dict[str, Any]: + """Verify a CA private key is valid and extract metadata. + + Args: + ca_private_key: CA private key in PEM format + + Returns: + Dictionary with key metadata (fingerprint, key_type, etc.) + + Raises: + SSHCASigningError: If key is invalid + """ + try: + ca_key = PrivateKey.from_string(ca_private_key) + pub_key = ca_key.public_key + + # Compute fingerprint + fingerprint = compute_ssh_fingerprint(pub_key.to_string()) + + # Get key type + key_type = pub_key.keytype if hasattr(pub_key, 'keytype') else 'unknown' + + return { + 'fingerprint': fingerprint, + 'key_type': key_type, + 'public_key': pub_key.to_string(), + 'valid': True, + } + except Exception as e: + self.logger.error(f"CA key verification failed: {str(e)}") + raise SSHCASigningError(f"Invalid CA key: {str(e)}") diff --git a/gatehouse_app/services/ssh_key_service.py b/gatehouse_app/services/ssh_key_service.py new file mode 100644 index 0000000..db07b0b --- /dev/null +++ b/gatehouse_app/services/ssh_key_service.py @@ -0,0 +1,375 @@ +"""SSH Key management service.""" +import base64 +import logging +import os +import secrets +import subprocess +import tempfile +from datetime import datetime, timedelta +from typing import Optional, List, Dict, Any + +from gatehouse_app.extensions import db +from gatehouse_app.models import SSHKey, User +from gatehouse_app.exceptions import ( + SSHKeyError, + SSHKeyNotFoundError, + SSHKeyAlreadyExistsError, + SSHKeyNotVerifiedError, + ValidationError, + UserNotFoundError, +) +from gatehouse_app.utils.crypto import ( + compute_ssh_fingerprint, + verify_ssh_key_format, + extract_ssh_key_type, + extract_ssh_key_comment, +) +from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config + +logger = logging.getLogger(__name__) + + +class SSHKeyService: + """Service for managing SSH keys.""" + + def __init__(self): + """Initialize SSH key service.""" + self.config = get_ssh_ca_config() + + def add_ssh_key( + self, + user_id: str, + public_key: str, + description: Optional[str] = None, + ) -> SSHKey: + """Add an SSH public key for a user. + + Args: + user_id: ID of the user + public_key: SSH public key in OpenSSH format + description: Optional description of the key + + Returns: + Created SSHKey instance + + Raises: + UserNotFoundError: If user doesn't exist + SSHKeyError: If key format is invalid + SSHKeyAlreadyExistsError: If key already exists + """ + # Verify user exists + user = User.query.get(user_id) + if not user: + raise UserNotFoundError(f"User {user_id} not found") + + # Validate key format + if not verify_ssh_key_format(public_key): + raise SSHKeyError("Invalid SSH public key format") + + # Compute fingerprint + try: + fingerprint = compute_ssh_fingerprint(public_key) + except Exception as e: + logger.error(f"Failed to compute fingerprint: {str(e)}") + raise SSHKeyError(f"Failed to compute key fingerprint: {str(e)}") + + # Check for duplicate (including soft-deleted records — fingerprint is unique in DB) + existing = SSHKey.query.filter_by(fingerprint=fingerprint).first() + if existing: + if existing.deleted_at is not None: + # Restore the soft-deleted key: clear deleted_at and update fields + existing.deleted_at = None + existing.user_id = user_id + existing.description = description or existing.description + existing.verified = False + existing.verified_at = None + existing.verify_text = None + existing.verify_text_created_at = None + db.session.commit() + logger.info( + f"Restored soft-deleted SSH key for user {user_id}: " + f"fingerprint={fingerprint}" + ) + return existing + raise SSHKeyAlreadyExistsError( + f"SSH key with fingerprint {fingerprint} already exists" + ) + + # Extract metadata + key_type = extract_ssh_key_type(public_key) + key_comment = extract_ssh_key_comment(public_key) + + # Create SSH key record + ssh_key = SSHKey( + user_id=user_id, + payload=public_key, + fingerprint=fingerprint, + description=description, + key_type=key_type, + key_comment=key_comment, + verified=False, + ) + + ssh_key.save() + + logger.info( + f"SSH key added for user {user_id}: " + f"fingerprint={fingerprint}, type={key_type}" + ) + + return ssh_key + + def get_ssh_key(self, key_id: str) -> SSHKey: + """Get an SSH key by ID. + + Args: + key_id: SSH key ID + + Returns: + SSHKey instance + + Raises: + SSHKeyNotFoundError: If key not found + """ + key = SSHKey.query.filter_by(id=key_id, deleted_at=None).first() + if not key: + raise SSHKeyNotFoundError(f"SSH key {key_id} not found") + return key + + def get_user_ssh_keys(self, user_id: str) -> List[SSHKey]: + """Get all SSH keys for a user. + + Args: + user_id: User ID + + Returns: + List of SSHKey instances + """ + return SSHKey.query.filter_by(user_id=user_id, deleted_at=None).all() + + def get_user_verified_ssh_keys(self, user_id: str) -> List[SSHKey]: + """Get all verified SSH keys for a user. + + Args: + user_id: User ID + + Returns: + List of verified SSHKey instances + """ + return SSHKey.query.filter_by( + user_id=user_id, + verified=True, + deleted_at=None, + ).all() + + def delete_ssh_key(self, key_id: str) -> None: + """Soft-delete an SSH key. + + Args: + key_id: SSH key ID + + Raises: + SSHKeyNotFoundError: If key not found + """ + key = self.get_ssh_key(key_id) + key.delete() + + logger.info(f"SSH key deleted: {key_id}") + + def generate_verification_challenge(self, key_id: str) -> str: + """Generate a verification challenge for an SSH key. + + The user must sign this challenge text with their private key + to prove key ownership. + + Args: + key_id: SSH key ID + + Returns: + Verification challenge text + + Raises: + SSHKeyNotFoundError: If key not found + """ + key = self.get_ssh_key(key_id) + + # Generate random challenge + challenge = secrets.token_hex(32) + challenge_text = f"Please sign this to verify SSH key ownership: {challenge}" + + # Store challenge + key.verify_text = challenge_text + key.verify_text_created_at = datetime.utcnow() + key.save() + + logger.info(f"Generated verification challenge for SSH key {key_id}") + + return challenge_text + + def verify_ssh_key_ownership( + self, + key_id: str, + signature: str, + ) -> bool: + """Verify SSH key ownership via signature. + + The user must sign the verification challenge with their private key. + We verify the signature using the public key. + + Args: + key_id: SSH key ID + signature: Base64-encoded signature of the challenge + + Returns: + True if signature is valid + + Raises: + SSHKeyNotFoundError: If key not found + SSHKeyNotVerifiedError: If challenge is stale or missing + SSHKeyError: If verification fails + """ + key = self.get_ssh_key(key_id) + + # Check if challenge exists and is not stale + if not key.verify_text or not key.verify_text_created_at: + raise SSHKeyNotVerifiedError("No verification challenge generated") + + max_age = self.config.get_int('verification_challenge_max_age') + age = datetime.utcnow() - key.verify_text_created_at + if age.total_seconds() > (max_age * 3600): + raise SSHKeyNotVerifiedError("Verification challenge has expired") + + try: + # Verify the SSH signature using ssh-keygen -Y verify. + # The CLI signs the challenge with: ssh-keygen -Y sign -f -n file + # We verify with: ssh-keygen -Y verify -f -I -n file -s < + # + # allowed_signers format: " " + # We use the key fingerprint as the identity. + + sig_bytes = base64.b64decode(signature) + challenge_text = key.verify_text + "\n" + + with tempfile.TemporaryDirectory() as tmpdir: + allowed_signers_path = os.path.join(tmpdir, "allowed_signers") + sig_path = os.path.join(tmpdir, "message.sig") + message_path = os.path.join(tmpdir, "message.txt") + + identity = key.fingerprint + + # Write the allowed_signers file + with open(allowed_signers_path, "w") as f: + f.write(f"{identity} {key.payload}\n") + + # Write the signature file + with open(sig_path, "wb") as f: + f.write(sig_bytes) + + # Write the challenge message + with open(message_path, "w") as f: + f.write(challenge_text) + + result = subprocess.run( + [ + "ssh-keygen", "-Y", "verify", + "-f", allowed_signers_path, + "-I", identity, + "-n", "file", + "-s", sig_path, + ], + stdin=open(message_path, "rb"), + capture_output=True, + timeout=10, + ) + + if result.returncode != 0: + stderr = result.stderr.decode(errors="replace").strip() + logger.warning(f"SSH signature verification failed for key {key_id}: {stderr}") + raise SSHKeyError(f"Signature verification failed: {stderr}") + + key.mark_verified() + logger.info(f"SSH key verified: {key_id}") + return True + + except SSHKeyError: + raise + except Exception as e: + logger.error(f"SSH key verification failed: {str(e)}") + raise SSHKeyError(f"Signature verification failed: {str(e)}") + + def get_key_fingerprint(self, key_id: str) -> str: + """Get the fingerprint of an SSH key. + + Args: + key_id: SSH key ID + + Returns: + Fingerprint string + + Raises: + SSHKeyNotFoundError: If key not found + """ + key = self.get_ssh_key(key_id) + return key.fingerprint + + def update_ssh_key_description(self, key_id: str, description: str) -> SSHKey: + """Update the description of an SSH key. + + Args: + key_id: SSH key ID + description: New description + + Returns: + Updated SSHKey instance + + Raises: + SSHKeyNotFoundError: If key not found + """ + key = self.get_ssh_key(key_id) + key.description = description + key.save() + + return key + + def cleanup_expired_challenges(self) -> int: + """Clean up expired verification challenges. + + Returns: + Number of challenges cleaned + """ + max_age = self.config.get_int('verification_challenge_max_age') + threshold = datetime.utcnow() - timedelta(hours=max_age) + + expired = SSHKey.query.filter( + SSHKey.verify_text_created_at < threshold, + SSHKey.verify_text_created_at.isnot(None), + SSHKey.deleted_at.is_(None), + ).update({"verify_text": None, "verify_text_created_at": None}) + + db.session.commit() + + logger.info(f"Cleaned up {expired} expired verification challenges") + return expired + + def cleanup_unverified_keys(self) -> int: + """Delete unverified SSH keys older than configured days. + + Returns: + Number of keys deleted + """ + days = self.config.get_int('auto_delete_unverified_days') + threshold = datetime.utcnow() - timedelta(days=days) + + old_unverified = SSHKey.query.filter( + SSHKey.verified == False, + SSHKey.created_at < threshold, + SSHKey.deleted_at.is_(None), + ).all() + + count = 0 + for key in old_unverified: + key.delete() + count += 1 + + logger.info(f"Deleted {count} unverified SSH keys older than {days} days") + return count diff --git a/gatehouse_app/services/totp_service.py b/gatehouse_app/services/totp_service.py index b71f56a..c667e3e 100644 --- a/gatehouse_app/services/totp_service.py +++ b/gatehouse_app/services/totp_service.py @@ -11,10 +11,51 @@ from gatehouse_app.extensions import bcrypt logger = logging.getLogger(__name__) +# TOTP codes are valid for at most (2*window + 1) * 30s steps. +# With window=1 that's 3 steps = 90 seconds. We use a slightly +# generous TTL of 95 seconds to account for clock skew at boundaries. +_TOTP_USED_CODE_TTL = 95 + class TOTPService: """Service for TOTP operations.""" + # ------------------------------------------------------------------ + # Replay-attack prevention helpers + # ------------------------------------------------------------------ + + @staticmethod + def _used_key(user_id: str, code: str) -> str: + return f"totp:used:{user_id}:{code}" + + @staticmethod + def is_code_already_used(user_id: str, code: str) -> bool: + """Return True if *code* has already been accepted for *user_id* + within the current validity window (prevents replay attacks).""" + try: + from gatehouse_app.extensions import redis_client + if redis_client is None: + return False + return redis_client.exists(TOTPService._used_key(user_id, code)) == 1 + except Exception: + logger.warning("Redis unavailable for TOTP replay check; allowing code") + return False + + @staticmethod + def mark_code_used(user_id: str, code: str) -> None: + """Record *code* as consumed for *user_id* so it cannot be reused.""" + try: + from gatehouse_app.extensions import redis_client + if redis_client is None: + return + redis_client.setex( + TOTPService._used_key(user_id, code), + _TOTP_USED_CODE_TTL, + "1", + ) + except Exception: + logger.warning("Redis unavailable; TOTP used-code not recorded") + @staticmethod def generate_secret() -> str: """ diff --git a/gatehouse_app/services/user_service.py b/gatehouse_app/services/user_service.py index 94845d1..b8b5fc9 100644 --- a/gatehouse_app/services/user_service.py +++ b/gatehouse_app/services/user_service.py @@ -2,7 +2,7 @@ import logging from flask import current_app from gatehouse_app.extensions import db -from gatehouse_app.models.user import User +from gatehouse_app.models.user.user import User from gatehouse_app.exceptions.validation_exceptions import UserNotFoundError from gatehouse_app.utils.constants import AuditAction from gatehouse_app.services.audit_service import AuditService diff --git a/gatehouse_app/services/webauthn_service.py b/gatehouse_app/services/webauthn_service.py index 79ea50a..9f953d3 100644 --- a/gatehouse_app/services/webauthn_service.py +++ b/gatehouse_app/services/webauthn_service.py @@ -10,8 +10,8 @@ from flask import current_app from sqlalchemy.orm.attributes import flag_modified from gatehouse_app.extensions import db, redis_client -from gatehouse_app.models.user import User -from gatehouse_app.models.authentication_method import AuthenticationMethod +from gatehouse_app.models.user.user import User +from gatehouse_app.models.auth.authentication_method import AuthenticationMethod from gatehouse_app.utils.constants import AuthMethodType, AuditAction from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError from gatehouse_app.services.audit_service import AuditService @@ -641,6 +641,26 @@ class WebAuthnService: ) return True + + @classmethod + def credential_belongs_to_user(cls, credential_id: str, user: User) -> bool: + """Check whether *credential_id* exists and belongs to *user*. + + Args: + credential_id: The credential ID to look up + user: User instance + + Returns: + True if the credential exists and belongs to this user, False otherwise. + """ + auth_method = AuthenticationMethod.query.filter_by( + user_id=user.id, + method_type=AuthMethodType.WEBAUTHN, + deleted_at=None, + ).first() + if not auth_method or not auth_method.provider_data: + return False + return auth_method.provider_data.get("credential_id") == credential_id @classmethod def rename_credential(cls, credential_id: str, user: User, name: str) -> bool: diff --git a/gatehouse_app/utils/ca_key_encryption.py b/gatehouse_app/utils/ca_key_encryption.py new file mode 100644 index 0000000..183e3c3 --- /dev/null +++ b/gatehouse_app/utils/ca_key_encryption.py @@ -0,0 +1,206 @@ +"""Encryption helpers for CA private keys stored in the database. + +CA private keys are encrypted at rest using Fernet (AES-128-CBC + HMAC-SHA256) +from the ``cryptography`` package. The encryption key is derived from the +``CA_ENCRYPTION_KEY`` environment variable (or ``Flask.config["CA_ENCRYPTION_KEY"]``). + +Key derivation +-------------- +Fernet requires a URL-safe base64-encoded 32-byte key. We accept any string +from the env and derive the actual Fernet key using SHA-256 so that operators +can supply human-readable secrets without having to pre-encode them. + +Envelope format +--------------- +Encrypted values are stored as the string:: + + $fernet$ + +The ``$fernet$`` prefix lets the code distinguish already-encrypted values from +legacy plaintext PEM keys so that the migration path is safe and idempotent. + +Usage +----- +Encrypt before storing:: + + from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key + ca.private_key = encrypt_ca_key(private_key_pem) + +Decrypt before use:: + + from gatehouse_app.utils.ca_key_encryption import decrypt_ca_key + plaintext_pem = decrypt_ca_key(ca.private_key) +""" +import base64 +import hashlib +import logging +import os + +from cryptography.fernet import Fernet, InvalidToken + +logger = logging.getLogger(__name__) + +# Prefix that marks a stored value as Fernet-encrypted +_FERNET_PREFIX = "$fernet$" + + +class CAKeyEncryptionError(Exception): + """Raised when CA key encryption or decryption fails.""" + + +def _get_fernet() -> Fernet: + """Build a Fernet instance from the configured encryption key. + + Looks up ``CA_ENCRYPTION_KEY`` in the environment first, then falls back to + the Flask app config (if a request context is active). + + Raises: + CAKeyEncryptionError: if no key is configured or it is the insecure + placeholder value in a production-like environment. + """ + raw_key = os.environ.get("CA_ENCRYPTION_KEY") + + if not raw_key: + # Try Flask config if we're inside an app context + try: + from flask import current_app + raw_key = current_app.config.get("CA_ENCRYPTION_KEY") + except RuntimeError: + pass # No app context + + if not raw_key: + raise CAKeyEncryptionError( + "CA_ENCRYPTION_KEY is not set. " + "Set this environment variable before starting the application." + ) + + # Warn loudly when running with the placeholder in a non-test environment + env_name = os.environ.get("FLASK_ENV", "").lower() + if raw_key.startswith("dev-") and env_name not in ("development", "testing", "test"): + logger.warning( + "CA_ENCRYPTION_KEY appears to be a development placeholder. " + "Set a strong random key for production environments." + ) + + # Derive a 32-byte key from the raw secret via SHA-256, then URL-safe base64 + key_bytes = hashlib.sha256(raw_key.encode()).digest() + fernet_key = base64.urlsafe_b64encode(key_bytes) + return Fernet(fernet_key) + + +def encrypt_ca_key(plaintext_pem: str) -> str: + """Encrypt a CA private key PEM string. + + Idempotent: already-encrypted values are returned unchanged. + + Args: + plaintext_pem: CA private key in OpenSSH/PEM format. + + Returns: + Encrypted string with ``$fernet$`` prefix, safe for database storage. + + Raises: + CAKeyEncryptionError: if the key cannot be encrypted. + """ + if not plaintext_pem: + raise CAKeyEncryptionError("Cannot encrypt an empty key") + + # Already encrypted — do not double-encrypt + if plaintext_pem.startswith(_FERNET_PREFIX): + return plaintext_pem + + try: + fernet = _get_fernet() + token = fernet.encrypt(plaintext_pem.encode()).decode() + return f"{_FERNET_PREFIX}{token}" + except CAKeyEncryptionError: + raise + except Exception as exc: + raise CAKeyEncryptionError(f"Failed to encrypt CA key: {exc}") from exc + + +def decrypt_ca_key(stored_value: str) -> str: + """Decrypt a CA private key retrieved from the database. + + Idempotent: plaintext (legacy) values are returned unchanged so that the + system continues to work while a migration encrypts existing rows. + + Args: + stored_value: Value from ``CA.private_key`` column. + + Returns: + Plaintext PEM string ready for use with ``sshkey_tools``. + + Raises: + CAKeyEncryptionError: if decryption fails (wrong key, corrupted data). + """ + if not stored_value: + raise CAKeyEncryptionError("Cannot decrypt an empty value") + + # Legacy plaintext key — return as-is + if not stored_value.startswith(_FERNET_PREFIX): + logger.warning( + "CA private key appears to be stored as plaintext. " + "Run the migration to encrypt existing keys." + ) + return stored_value + + token = stored_value[len(_FERNET_PREFIX):] + try: + fernet = _get_fernet() + return fernet.decrypt(token.encode()).decode() + except InvalidToken as exc: + raise CAKeyEncryptionError( + "CA key decryption failed — the CA_ENCRYPTION_KEY may be incorrect " + "or the stored key is corrupted." + ) from exc + except CAKeyEncryptionError: + raise + except Exception as exc: + raise CAKeyEncryptionError(f"Unexpected decryption error: {exc}") from exc + + +def is_encrypted(stored_value: str) -> bool: + """Return True if the stored value has the ``$fernet$`` envelope. + + Args: + stored_value: Value from ``CA.private_key`` column. + """ + return bool(stored_value and stored_value.startswith(_FERNET_PREFIX)) + + +def reencrypt_ca_key(stored_value: str, old_raw_key: str, new_raw_key: str) -> str: + """Re-encrypt a CA key with a new encryption key (for key rotation). + + Args: + stored_value: Current value from ``CA.private_key`` (may or may not be encrypted). + old_raw_key: The current ``CA_ENCRYPTION_KEY`` value (raw secret string). + new_raw_key: The new ``CA_ENCRYPTION_KEY`` value to encrypt with. + + Returns: + New encrypted envelope string. + + Raises: + CAKeyEncryptionError: if decryption or re-encryption fails. + """ + # Decrypt with old key + if stored_value.startswith(_FERNET_PREFIX): + token = stored_value[len(_FERNET_PREFIX):] + old_key_bytes = base64.urlsafe_b64encode(hashlib.sha256(old_raw_key.encode()).digest()) + try: + plaintext = Fernet(old_key_bytes).decrypt(token.encode()).decode() + except InvalidToken as exc: + raise CAKeyEncryptionError( + "Re-encryption failed: could not decrypt with the old key." + ) from exc + else: + # Plaintext + plaintext = stored_value + + # Re-encrypt with new key + new_key_bytes = base64.urlsafe_b64encode(hashlib.sha256(new_raw_key.encode()).digest()) + try: + token = Fernet(new_key_bytes).encrypt(plaintext.encode()).decode() + return f"{_FERNET_PREFIX}{token}" + except Exception as exc: + raise CAKeyEncryptionError(f"Re-encryption with new key failed: {exc}") from exc diff --git a/gatehouse_app/utils/constants.py b/gatehouse_app/utils/constants.py index cd6a58b..2a99825 100644 --- a/gatehouse_app/utils/constants.py +++ b/gatehouse_app/utils/constants.py @@ -12,6 +12,16 @@ class UserStatus(str, Enum): COMPLIANCE_SUSPENDED = "compliance_suspended" +class Role(str, Enum): + """Generic role definitions (hierarchy: Admin > Manager > Member > Viewer > Guest).""" + + ADMIN = "admin" + MANAGER = "manager" + MEMBER = "member" + VIEWER = "viewer" + GUEST = "guest" + + class OrganizationRole(str, Enum): """Organization member roles.""" @@ -51,6 +61,9 @@ class AuditAction(str, Enum): USER_REGISTER = "user.register" USER_UPDATE = "user.update" USER_DELETE = "user.delete" + USER_HARD_DELETE = "user.hard_delete" + USER_SUSPEND = "user.suspend" + USER_UNSUSPEND = "user.unsuspend" PASSWORD_CHANGE = "user.password_change" PASSWORD_RESET = "user.password_reset" @@ -61,6 +74,7 @@ class AuditAction(str, Enum): ORG_MEMBER_ADD = "org.member.add" ORG_MEMBER_REMOVE = "org.member.remove" ORG_MEMBER_ROLE_CHANGE = "org.member.role_change" + ORG_OWNERSHIP_TRANSFERRED = "org.ownership.transferred" # Session actions SESSION_CREATE = "session.create" @@ -105,6 +119,37 @@ class AuditAction(str, Enum): EXTERNAL_AUTH_CONFIG_UPDATE = "external_auth.config.update" EXTERNAL_AUTH_CONFIG_DELETE = "external_auth.config.delete" + # SSH Key and Certificate actions + SSH_KEY_ADDED = "ssh.key.added" + SSH_KEY_VERIFIED = "ssh.key.verified" + SSH_KEY_DELETED = "ssh.key.deleted" + SSH_KEY_VALIDATION_FAILED = "ssh.key.validation.failed" + SSH_CERT_REQUESTED = "ssh.cert.requested" + SSH_CERT_ISSUED = "ssh.cert.issued" + SSH_CERT_FAILED = "ssh.cert.failed" + SSH_CERT_REVOKED = "ssh.cert.revoked" + SSH_CERT_EXPIRED = "ssh.cert.expired" + + # CA actions + CA_CREATED = "ca.created" + CA_UPDATED = "ca.updated" + CA_DELETED = "ca.deleted" + CA_KEY_ROTATED = "ca.key.rotated" + + # Principal actions + PRINCIPAL_CREATED = "principal.created" + PRINCIPAL_UPDATED = "principal.updated" + PRINCIPAL_DELETED = "principal.deleted" + PRINCIPAL_MEMBER_ADDED = "principal.member.added" + PRINCIPAL_MEMBER_REMOVED = "principal.member.removed" + + # Department actions + DEPARTMENT_CREATED = "department.created" + DEPARTMENT_UPDATED = "department.updated" + DEPARTMENT_DELETED = "department.deleted" + DEPARTMENT_MEMBER_ADDED = "department.member.added" + DEPARTMENT_MEMBER_REMOVED = "department.member.removed" + class OIDCGrantType(str, Enum): """OIDC grant types.""" diff --git a/gatehouse_app/utils/crypto.py b/gatehouse_app/utils/crypto.py new file mode 100644 index 0000000..0322a8c --- /dev/null +++ b/gatehouse_app/utils/crypto.py @@ -0,0 +1,128 @@ +"""Cryptographic utilities for SSH operations.""" +import hashlib +import base64 +from typing import Optional + + +def compute_ssh_fingerprint(public_key_str: str, hash_algorithm: str = "sha256") -> str: + """Compute the fingerprint of an SSH public key. + + Args: + public_key_str: SSH public key in OpenSSH format + hash_algorithm: Hash algorithm to use (sha256, sha1, md5) + + Returns: + Fingerprint string in the format "algorithm:hex_digest" + + Example: + >>> key = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKp2..." + >>> fp = compute_ssh_fingerprint(key) + >>> print(fp) + sha256:Kb+... + """ + if not public_key_str: + raise ValueError("Public key string is empty") + + # Parse OpenSSH format: "ssh-ed25519 [comment]" + parts = public_key_str.strip().split() + if len(parts) < 2: + raise ValueError("Invalid OpenSSH public key format") + + try: + # The base64-encoded key is the second part + key_bytes = base64.b64decode(parts[1]) + except Exception as e: + raise ValueError(f"Failed to decode public key: {str(e)}") + + # Compute hash + if hash_algorithm == "sha256": + digest = hashlib.sha256(key_bytes).digest() + # SSH format uses base64 encoding without padding + fingerprint = base64.b64encode(digest).decode().rstrip('=') + elif hash_algorithm == "sha1": + digest = hashlib.sha1(key_bytes).hexdigest() + fingerprint = digest + elif hash_algorithm == "md5": + digest = hashlib.md5(key_bytes).hexdigest() + # Format as colons + fingerprint = ':'.join(digest[i:i+2] for i in range(0, len(digest), 2)) + else: + raise ValueError(f"Unsupported hash algorithm: {hash_algorithm}") + + return f"{hash_algorithm}:{fingerprint}" + + +def verify_ssh_key_format(public_key_str: str) -> bool: + """Verify that a string is in valid OpenSSH public key format. + + Args: + public_key_str: Potential SSH public key + + Returns: + True if valid OpenSSH format, False otherwise + """ + if not public_key_str or not isinstance(public_key_str, str): + return False + + parts = public_key_str.strip().split() + + # Must have at least key type and key material + if len(parts) < 2: + return False + + key_type = parts[0] + + # Valid key types + valid_types = [ + 'ssh-rsa', + 'ssh-ed25519', + 'ecdsa-sha2-nistp256', + 'ecdsa-sha2-nistp384', + 'ecdsa-sha2-nistp521', + 'ssh-dss', + ] + + if key_type not in valid_types: + return False + + # Try to decode base64 + try: + base64.b64decode(parts[1]) + return True + except Exception: + return False + + +def extract_ssh_key_type(public_key_str: str) -> Optional[str]: + """Extract the key type from an OpenSSH public key. + + Args: + public_key_str: SSH public key in OpenSSH format + + Returns: + Key type (e.g., "ssh-ed25519") or None if invalid + """ + if not verify_ssh_key_format(public_key_str): + return None + + return public_key_str.strip().split()[0] + + +def extract_ssh_key_comment(public_key_str: str) -> Optional[str]: + """Extract the comment from an OpenSSH public key. + + Args: + public_key_str: SSH public key in OpenSSH format + + Returns: + Comment string or None if not present + """ + if not verify_ssh_key_format(public_key_str): + return None + + parts = public_key_str.strip().split() + if len(parts) >= 3: + # Everything after the second part is the comment + return ' '.join(parts[2:]) + + return None diff --git a/gatehouse_app/utils/decorators.py b/gatehouse_app/utils/decorators.py index e3b0085..5cbb649 100644 --- a/gatehouse_app/utils/decorators.py +++ b/gatehouse_app/utils/decorators.py @@ -64,11 +64,47 @@ def login_required(f): session.last_activity_at = datetime.now(timezone.utc) from gatehouse_app import db db.session.commit() - + # Set context variables g.current_user = session.user g.current_session = session - + + user = session.user + token_groups: list = [] + try: + if session.device_info: + # device_info may carry OIDC claims stored at login time + claims = session.device_info + # Normalise: Gatehouse stores roles as [{"organization_id":…,"role":…}] + roles_claim = claims.get("roles", []) + if isinstance(roles_claim, list): + for entry in roles_claim: + if isinstance(entry, dict): + role_val = entry.get("role") + if role_val: + token_groups.append(str(role_val)) + elif isinstance(entry, str): + token_groups.append(entry) + # Standard OIDC groups claim + groups_claim = claims.get("groups", []) + if isinstance(groups_claim, list): + token_groups.extend(str(g_) for g_ in groups_claim if g_) + except Exception: + pass # Never block auth over token_groups enrichment failure + user.token_groups = token_groups + + # Activation check: if the user has an `activated` attribute and it is + # explicitly False, block access. New accounts without the attribute are + # treated as active to avoid breaking existing sessions. + activated = getattr(user, "activated", None) + if activated is False: + return api_response( + success=False, + message="Account not yet activated. Please check your email for an activation link.", + status=403, + error_type="ACCOUNT_NOT_ACTIVATED", + ) + return f(*args, **kwargs) return decorated_function @@ -97,11 +133,12 @@ def require_role(*allowed_roles): raise ForbiddenError("Organization context required") # Check user's role in the organization - from gatehouse_app.models.organization_member import OrganizationMember + from gatehouse_app.models.organization.organization_member import OrganizationMember membership = OrganizationMember.query.filter_by( user_id=g.current_user.id, organization_id=org_id, + deleted_at=None, ).first() if not membership: diff --git a/manage.py b/manage.py index 088447e..a89d335 100644 --- a/manage.py +++ b/manage.py @@ -111,5 +111,79 @@ def mfa_compliance_status(): print("=" * 60) +@cli.command("configure_oauth") +def configure_oauth(): + """Interactively configure an OAuth provider at the application level. + + Usage: + python manage.py configure_oauth + + Supported providers: google, github, microsoft + """ + import getpass + from gatehouse_app.models.authentication_method import ApplicationProviderConfig + from gatehouse_app.extensions import db + + SUPPORTED = ["google", "github", "microsoft"] + + print("=" * 60) + print("OAuth Provider Configuration") + print("=" * 60) + print(f"Supported providers: {', '.join(SUPPORTED)}") + + provider = input("Provider [google/github/microsoft]: ").strip().lower() + if provider not in SUPPORTED: + print(f"❌ Unknown provider: {provider}") + return + + client_id = input("Client ID: ").strip() + if not client_id: + print("❌ client_id is required") + return + + client_secret = getpass.getpass("Client Secret (leave blank to keep existing): ").strip() + + with app.app_context(): + config = ApplicationProviderConfig.query.filter_by(provider_type=provider).first() + if config: + config.client_id = client_id + if client_secret: + config.set_client_secret(client_secret) + config.is_enabled = True + db.session.commit() + print(f"✅ Updated {provider} provider config.") + else: + config = ApplicationProviderConfig( + provider_type=provider, + client_id=client_id, + is_enabled=True, + ) + if client_secret: + config.set_client_secret(client_secret) + db.session.add(config) + db.session.commit() + print(f"✅ Created {provider} provider config.") + + +@cli.command("list_oauth") +def list_oauth(): + """List all configured OAuth providers. + + Usage: + python manage.py list_oauth + """ + from gatehouse_app.models.authentication_method import ApplicationProviderConfig + + with app.app_context(): + configs = ApplicationProviderConfig.query.all() + if not configs: + print("No OAuth providers configured.") + return + print(f"{'Provider':<15} {'Client ID':<40} {'Enabled'}") + print("-" * 65) + for c in configs: + print(f"{c.provider_type:<15} {c.client_id:<40} {c.is_enabled}") + + if __name__ == "__main__": cli() diff --git a/migrations/versions/006_add_departments_principals.py b/migrations/versions/006_add_departments_principals.py new file mode 100644 index 0000000..8022143 --- /dev/null +++ b/migrations/versions/006_add_departments_principals.py @@ -0,0 +1,127 @@ +"""Add Department and Principal models for SSH CA management. + +Revision ID: 006 +Revises: 005 +Create Date: 2026-02-27 10:00:00.000000 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '006' +down_revision = '005' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### Department table ### + op.create_table('departments', + sa.Column('organization_id', sa.String(length=36), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('id', sa.String(length=36), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('deleted_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('id'), + sa.UniqueConstraint('organization_id', 'name', name='uix_org_dept_name') + ) + op.create_index(op.f('ix_departments_organization_id'), 'departments', ['organization_id'], unique=False) + op.create_index(op.f('ix_departments_name'), 'departments', ['name'], unique=False) + + # ### DepartmentMembership table ### + op.create_table('department_memberships', + sa.Column('user_id', sa.String(length=36), nullable=False), + sa.Column('department_id', sa.String(length=36), nullable=False), + sa.Column('id', sa.String(length=36), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('deleted_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['department_id'], ['departments.id'], ), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('id'), + sa.UniqueConstraint('user_id', 'department_id', name='uix_user_dept') + ) + op.create_index(op.f('ix_department_memberships_user_id'), 'department_memberships', ['user_id'], unique=False) + op.create_index(op.f('ix_department_memberships_department_id'), 'department_memberships', ['department_id'], unique=False) + + # ### Principal table ### + op.create_table('principals', + sa.Column('organization_id', sa.String(length=36), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('id', sa.String(length=36), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('deleted_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('id'), + sa.UniqueConstraint('organization_id', 'name', name='uix_org_principal_name') + ) + op.create_index(op.f('ix_principals_organization_id'), 'principals', ['organization_id'], unique=False) + op.create_index(op.f('ix_principals_name'), 'principals', ['name'], unique=False) + + # ### PrincipalMembership table ### + op.create_table('principal_memberships', + sa.Column('user_id', sa.String(length=36), nullable=False), + sa.Column('principal_id', sa.String(length=36), nullable=False), + sa.Column('id', sa.String(length=36), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('deleted_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['principal_id'], ['principals.id'], ), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('id'), + sa.UniqueConstraint('user_id', 'principal_id', name='uix_user_principal') + ) + op.create_index(op.f('ix_principal_memberships_user_id'), 'principal_memberships', ['user_id'], unique=False) + op.create_index(op.f('ix_principal_memberships_principal_id'), 'principal_memberships', ['principal_id'], unique=False) + + # ### DepartmentPrincipal table ### + op.create_table('department_principals', + sa.Column('department_id', sa.String(length=36), nullable=False), + sa.Column('principal_id', sa.String(length=36), nullable=False), + sa.Column('id', sa.String(length=36), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('deleted_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['department_id'], ['departments.id'], ), + sa.ForeignKeyConstraint(['principal_id'], ['principals.id'], ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('id'), + sa.UniqueConstraint('department_id', 'principal_id', name='uix_dept_principal') + ) + op.create_index(op.f('ix_department_principals_department_id'), 'department_principals', ['department_id'], unique=False) + op.create_index(op.f('ix_department_principals_principal_id'), 'department_principals', ['principal_id'], unique=False) + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_department_principals_principal_id'), table_name='department_principals') + op.drop_index(op.f('ix_department_principals_department_id'), table_name='department_principals') + op.drop_table('department_principals') + + op.drop_index(op.f('ix_principal_memberships_principal_id'), table_name='principal_memberships') + op.drop_index(op.f('ix_principal_memberships_user_id'), table_name='principal_memberships') + op.drop_table('principal_memberships') + + op.drop_index(op.f('ix_principals_name'), table_name='principals') + op.drop_index(op.f('ix_principals_organization_id'), table_name='principals') + op.drop_table('principals') + + op.drop_index(op.f('ix_department_memberships_department_id'), table_name='department_memberships') + op.drop_index(op.f('ix_department_memberships_user_id'), table_name='department_memberships') + op.drop_table('department_memberships') + + op.drop_index(op.f('ix_departments_name'), table_name='departments') + op.drop_index(op.f('ix_departments_organization_id'), table_name='departments') + op.drop_table('departments') + # ### end Alembic commands ### diff --git a/migrations/versions/007_add_ssh_ca_models.py b/migrations/versions/007_add_ssh_ca_models.py new file mode 100644 index 0000000..9749930 --- /dev/null +++ b/migrations/versions/007_add_ssh_ca_models.py @@ -0,0 +1,173 @@ +"""Add SSH CA models: SSHKey, SSHCertificate, CA, CertificateAuditLog. + +Revision ID: 007 +Revises: 006 +Create Date: 2026-02-27 11:00:00.000000 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '007' +down_revision = '006' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### CA table ### + op.create_table('cas', + sa.Column('organization_id', sa.String(length=36), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('key_type', sa.Enum('ed25519', 'rsa', 'ecdsa', name='ca_key_type_enum'), nullable=False), + sa.Column('private_key', sa.Text(), nullable=False), + sa.Column('public_key', sa.Text(), nullable=False), + sa.Column('fingerprint', sa.String(length=255), nullable=False), + sa.Column('crl_enabled', sa.Boolean(), nullable=False), + sa.Column('crl_endpoint', sa.String(length=512), nullable=True), + sa.Column('default_cert_validity_hours', sa.Integer(), nullable=False), + sa.Column('max_cert_validity_hours', sa.Integer(), nullable=False), + sa.Column('is_active', sa.Boolean(), nullable=False), + sa.Column('rotated_at', sa.DateTime(), nullable=True), + sa.Column('rotation_reason', sa.String(length=255), nullable=True), + sa.Column('id', sa.String(length=36), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('deleted_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('id'), + sa.UniqueConstraint('fingerprint'), + sa.UniqueConstraint('organization_id', 'name', name='uix_org_ca_name') + ) + op.create_index(op.f('ix_cas_organization_id'), 'cas', ['organization_id'], unique=False) + op.create_index('idx_ca_org_active', 'cas', ['organization_id', 'is_active'], unique=False) + + # ### SSHKey table ### + op.create_table('ssh_keys', + sa.Column('user_id', sa.String(length=36), nullable=False), + sa.Column('payload', sa.Text(), nullable=False), + sa.Column('fingerprint', sa.String(length=255), nullable=False), + sa.Column('description', sa.String(length=255), nullable=True), + sa.Column('verified', sa.Boolean(), nullable=False), + sa.Column('verified_at', sa.DateTime(), nullable=True), + sa.Column('verify_text', sa.String(length=255), nullable=True), + sa.Column('verify_text_created_at', sa.DateTime(), nullable=True), + sa.Column('key_type', sa.String(length=50), nullable=True), + sa.Column('key_bits', sa.Integer(), nullable=True), + sa.Column('key_comment', sa.String(length=255), nullable=True), + sa.Column('id', sa.String(length=36), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('deleted_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('id'), + sa.UniqueConstraint('payload'), + sa.UniqueConstraint('fingerprint') + ) + op.create_index(op.f('ix_ssh_keys_user_id'), 'ssh_keys', ['user_id'], unique=False) + op.create_index(op.f('ix_ssh_keys_fingerprint'), 'ssh_keys', ['fingerprint'], unique=False) + op.create_index(op.f('ix_ssh_keys_verified'), 'ssh_keys', ['verified'], unique=False) + op.create_index('idx_ssh_key_user_verified', 'ssh_keys', ['user_id', 'verified'], unique=False) + + # ### SSHCertificate table ### + op.create_table('ssh_certificates', + sa.Column('ca_id', sa.String(length=36), nullable=False), + sa.Column('user_id', sa.String(length=36), nullable=False), + sa.Column('ssh_key_id', sa.String(length=36), nullable=False), + sa.Column('certificate', sa.Text(), nullable=False), + sa.Column('serial', sa.String(length=255), nullable=False), + sa.Column('key_id', sa.String(length=255), nullable=False), + sa.Column('cert_type', sa.Enum('user', 'host', name='ssh_cert_type_enum'), nullable=False), + sa.Column('principals', sa.JSON(), nullable=False), + sa.Column('valid_after', sa.DateTime(), nullable=False), + sa.Column('valid_before', sa.DateTime(), nullable=False), + sa.Column('revoked', sa.Boolean(), nullable=False), + sa.Column('revoked_at', sa.DateTime(), nullable=True), + sa.Column('revoke_reason', sa.String(length=255), nullable=True), + sa.Column('status', sa.Enum('requested', 'issued', 'revoked', 'expired', 'superseded', name='ssh_cert_status_enum'), nullable=False), + sa.Column('request_ip', sa.String(length=45), nullable=True), + sa.Column('request_user_agent', sa.String(length=512), nullable=True), + sa.Column('critical_options', sa.JSON(), nullable=True), + sa.Column('extensions', sa.JSON(), nullable=True), + sa.Column('id', sa.String(length=36), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('deleted_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['ca_id'], ['cas.id'], ), + sa.ForeignKeyConstraint(['ssh_key_id'], ['ssh_keys.id'], ), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('id'), + sa.UniqueConstraint('serial') + ) + op.create_index(op.f('ix_ssh_certificates_ca_id'), 'ssh_certificates', ['ca_id'], unique=False) + op.create_index(op.f('ix_ssh_certificates_user_id'), 'ssh_certificates', ['user_id'], unique=False) + op.create_index(op.f('ix_ssh_certificates_ssh_key_id'), 'ssh_certificates', ['ssh_key_id'], unique=False) + op.create_index(op.f('ix_ssh_certificates_serial'), 'ssh_certificates', ['serial'], unique=False) + op.create_index(op.f('ix_ssh_certificates_revoked'), 'ssh_certificates', ['revoked'], unique=False) + op.create_index(op.f('ix_ssh_certificates_status'), 'ssh_certificates', ['status'], unique=False) + op.create_index('idx_cert_user_status', 'ssh_certificates', ['user_id', 'status'], unique=False) + op.create_index('idx_cert_validity', 'ssh_certificates', ['valid_after', 'valid_before'], unique=False) + op.create_index('idx_cert_revoked', 'ssh_certificates', ['revoked', 'revoked_at'], unique=False) + + # ### CertificateAuditLog table ### + op.create_table('certificate_audit_logs', + sa.Column('certificate_id', sa.String(length=36), nullable=False), + sa.Column('user_id', sa.String(length=36), nullable=True), + sa.Column('action', sa.String(length=50), nullable=False), + sa.Column('ip_address', sa.String(length=45), nullable=True), + sa.Column('user_agent', sa.String(length=512), nullable=True), + sa.Column('request_id', sa.String(length=36), nullable=True), + sa.Column('message', sa.Text(), nullable=True), + sa.Column('extra_data', sa.JSON(), nullable=True), + sa.Column('success', sa.Boolean(), nullable=False), + sa.Column('error_message', sa.Text(), nullable=True), + sa.Column('id', sa.String(length=36), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('deleted_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['certificate_id'], ['ssh_certificates.id'], ), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('id') + ) + op.create_index(op.f('ix_certificate_audit_logs_certificate_id'), 'certificate_audit_logs', ['certificate_id'], unique=False) + op.create_index(op.f('ix_certificate_audit_logs_user_id'), 'certificate_audit_logs', ['user_id'], unique=False) + op.create_index(op.f('ix_certificate_audit_logs_action'), 'certificate_audit_logs', ['action'], unique=False) + op.create_index('idx_cert_audit_cert_action', 'certificate_audit_logs', ['certificate_id', 'action'], unique=False) + op.create_index('idx_cert_audit_user', 'certificate_audit_logs', ['user_id', 'created_at'], unique=False) + + +def downgrade(): + op.drop_index('idx_cert_audit_user', table_name='certificate_audit_logs') + op.drop_index('idx_cert_audit_cert_action', table_name='certificate_audit_logs') + op.drop_index(op.f('ix_certificate_audit_logs_action'), table_name='certificate_audit_logs') + op.drop_index(op.f('ix_certificate_audit_logs_user_id'), table_name='certificate_audit_logs') + op.drop_index(op.f('ix_certificate_audit_logs_certificate_id'), table_name='certificate_audit_logs') + op.drop_table('certificate_audit_logs') + + op.drop_index('idx_cert_revoked', table_name='ssh_certificates') + op.drop_index('idx_cert_validity', table_name='ssh_certificates') + op.drop_index('idx_cert_user_status', table_name='ssh_certificates') + op.drop_index(op.f('ix_ssh_certificates_status'), table_name='ssh_certificates') + op.drop_index(op.f('ix_ssh_certificates_revoked'), table_name='ssh_certificates') + op.drop_index(op.f('ix_ssh_certificates_serial'), table_name='ssh_certificates') + op.drop_index(op.f('ix_ssh_certificates_ssh_key_id'), table_name='ssh_certificates') + op.drop_index(op.f('ix_ssh_certificates_user_id'), table_name='ssh_certificates') + op.drop_index(op.f('ix_ssh_certificates_ca_id'), table_name='ssh_certificates') + op.drop_table('ssh_certificates') + + op.drop_index('idx_ssh_key_user_verified', table_name='ssh_keys') + op.drop_index(op.f('ix_ssh_keys_verified'), table_name='ssh_keys') + op.drop_index(op.f('ix_ssh_keys_fingerprint'), table_name='ssh_keys') + op.drop_index(op.f('ix_ssh_keys_user_id'), table_name='ssh_keys') + op.drop_table('ssh_keys') + + op.drop_index('idx_ca_org_active', table_name='cas') + op.drop_index(op.f('ix_cas_organization_id'), table_name='cas') + op.drop_table('cas') diff --git a/migrations/versions/008_fix_authmethodtype_enum.py b/migrations/versions/008_fix_authmethodtype_enum.py new file mode 100644 index 0000000..ccf56e9 --- /dev/null +++ b/migrations/versions/008_fix_authmethodtype_enum.py @@ -0,0 +1,53 @@ +"""Add TOTP and WEBAUTHN to authmethodtype enum. + +Revision ID: 008 +Revises: 007 +Create Date: 2026-02-27 15:00:00.000000 + +The original migration (001_base) created authmethodtype with only: + PASSWORD, GOOGLE, GITHUB, MICROSOFT, SAML, OIDC + +This migration adds the missing TOTP and WEBAUTHN values so +has_totp_enabled() and has_webauthn_enabled() queries work correctly. +""" +from alembic import op +import sqlalchemy as sa + + +revision = '008' +down_revision = '007' +branch_labels = None +depends_on = None + + +def upgrade(): + # Add TOTP to the enum (idempotent approach using DO block) + op.execute(""" + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_enum + WHERE enumlabel = 'TOTP' + AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'authmethodtype') + ) THEN + ALTER TYPE authmethodtype ADD VALUE 'TOTP'; + END IF; + END$$; + """) + op.execute(""" + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_enum + WHERE enumlabel = 'WEBAUTHN' + AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'authmethodtype') + ) THEN + ALTER TYPE authmethodtype ADD VALUE 'WEBAUTHN'; + END IF; + END$$; + """) + + +def downgrade(): + # PostgreSQL does not support removing enum values; downgrade is a no-op. + pass diff --git a/migrations/versions/009_sync_auditaction_enum.py b/migrations/versions/009_sync_auditaction_enum.py new file mode 100644 index 0000000..59bc210 --- /dev/null +++ b/migrations/versions/009_sync_auditaction_enum.py @@ -0,0 +1,61 @@ +"""Sync auditaction enum with all AuditAction Python enum values. + +Revision ID: 009 +Revises: 008 +Create Date: 2026-02-27 15:20:00.000000 + +The auditaction DB enum was only created with the initial 17 values from 001_base.py. +All TOTP, WebAuthn, OAuth, SSH, CA, Principal, and Department audit actions were added +to the Python enum but never synced to the DB type. +""" +from alembic import op + + +revision = '009' +down_revision = '008' +branch_labels = None +depends_on = None + +MISSING_VALUES = [ + 'TOTP_ENROLL_INITIATED', 'TOTP_ENROLL_COMPLETED', 'TOTP_VERIFY_SUCCESS', + 'TOTP_VERIFY_FAILED', 'TOTP_DISABLED', 'TOTP_BACKUP_CODE_USED', + 'TOTP_BACKUP_CODES_REGENERATED', 'WEBAUTHN_REGISTER_INITIATED', + 'WEBAUTHN_REGISTER_COMPLETED', 'WEBAUTHN_REGISTER_FAILED', + 'WEBAUTHN_LOGIN_INITIATED', 'WEBAUTHN_LOGIN_SUCCESS', 'WEBAUTHN_LOGIN_FAILED', + 'WEBAUTHN_CREDENTIAL_DELETED', 'WEBAUTHN_CREDENTIAL_RENAMED', + 'ORG_SECURITY_POLICY_UPDATE', 'USER_SECURITY_POLICY_OVERRIDE_UPDATE', + 'MFA_POLICY_USER_SUSPENDED', 'MFA_POLICY_USER_COMPLIANT', + 'EXTERNAL_AUTH_LINK_INITIATED', 'EXTERNAL_AUTH_LINK_COMPLETED', + 'EXTERNAL_AUTH_LINK_FAILED', 'EXTERNAL_AUTH_UNLINK', 'EXTERNAL_AUTH_LOGIN', + 'EXTERNAL_AUTH_LOGIN_FAILED', 'EXTERNAL_AUTH_TOKEN_REFRESH', + 'EXTERNAL_AUTH_CONFIG_CREATE', 'EXTERNAL_AUTH_CONFIG_UPDATE', + 'EXTERNAL_AUTH_CONFIG_DELETE', 'SSH_KEY_ADDED', 'SSH_KEY_VERIFIED', + 'SSH_KEY_DELETED', 'SSH_KEY_VALIDATION_FAILED', 'SSH_CERT_REQUESTED', + 'SSH_CERT_ISSUED', 'SSH_CERT_FAILED', 'SSH_CERT_REVOKED', 'SSH_CERT_EXPIRED', + 'CA_CREATED', 'CA_UPDATED', 'CA_DELETED', 'CA_KEY_ROTATED', + 'PRINCIPAL_CREATED', 'PRINCIPAL_UPDATED', 'PRINCIPAL_DELETED', + 'PRINCIPAL_MEMBER_ADDED', 'PRINCIPAL_MEMBER_REMOVED', + 'DEPARTMENT_CREATED', 'DEPARTMENT_UPDATED', 'DEPARTMENT_DELETED', + 'DEPARTMENT_MEMBER_ADDED', 'DEPARTMENT_MEMBER_REMOVED', +] + + +def upgrade(): + for val in MISSING_VALUES: + op.execute(f""" + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_enum + WHERE enumlabel = '{val}' + AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'auditaction') + ) THEN + ALTER TYPE auditaction ADD VALUE '{val}'; + END IF; + END$$; + """) + + +def downgrade(): + # PostgreSQL does not support removing enum values; downgrade is a no-op. + pass diff --git a/migrations/versions/010_password_reset_email_verify.py b/migrations/versions/010_password_reset_email_verify.py new file mode 100644 index 0000000..efb836e --- /dev/null +++ b/migrations/versions/010_password_reset_email_verify.py @@ -0,0 +1,50 @@ +"""add password reset and email verification token tables + +Revision ID: 010_password_reset_email_verify +Revises: 009_sync_auditaction_enum +Create Date: 2025-01-01 00:00:00.000000 +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '010_password_reset_email_verify' +down_revision = '009' +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + 'password_reset_tokens', + sa.Column('id', sa.String(36), primary_key=True, nullable=False), + sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False), + sa.Column('token', sa.String(128), nullable=False, unique=True), + sa.Column('expires_at', sa.DateTime, nullable=False), + sa.Column('used_at', sa.DateTime, nullable=True), + sa.Column('created_at', sa.DateTime, nullable=False), + sa.Column('updated_at', sa.DateTime, nullable=False), + sa.Column('deleted_at', sa.DateTime, nullable=True), + ) + op.create_index('ix_password_reset_tokens_user_id', 'password_reset_tokens', ['user_id']) + op.create_index('ix_password_reset_tokens_token', 'password_reset_tokens', ['token']) + + op.create_table( + 'email_verification_tokens', + sa.Column('id', sa.String(36), primary_key=True, nullable=False), + sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False), + sa.Column('token', sa.String(128), nullable=False, unique=True), + sa.Column('expires_at', sa.DateTime, nullable=False), + sa.Column('used_at', sa.DateTime, nullable=True), + sa.Column('created_at', sa.DateTime, nullable=False), + sa.Column('updated_at', sa.DateTime, nullable=False), + sa.Column('deleted_at', sa.DateTime, nullable=True), + ) + op.create_index('ix_email_verification_tokens_user_id', 'email_verification_tokens', ['user_id']) + op.create_index('ix_email_verification_tokens_token', 'email_verification_tokens', ['token']) + + +def downgrade(): + op.drop_table('email_verification_tokens') + op.drop_table('password_reset_tokens') diff --git a/migrations/versions/011_org_invite_tokens.py b/migrations/versions/011_org_invite_tokens.py new file mode 100644 index 0000000..003da63 --- /dev/null +++ b/migrations/versions/011_org_invite_tokens.py @@ -0,0 +1,38 @@ +"""add org_invite_tokens table + +Revision ID: 011_org_invite_tokens +Revises: 010_password_reset_email_verify +Create Date: 2025-01-01 00:00:00.000000 +""" +from alembic import op +import sqlalchemy as sa + + +revision = '011_org_invite_tokens' +down_revision = '010_password_reset_email_verify' +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + 'org_invite_tokens', + sa.Column('id', sa.String(36), primary_key=True, nullable=False), + sa.Column('organization_id', sa.String(36), sa.ForeignKey('organizations.id', ondelete='CASCADE'), nullable=False), + sa.Column('invited_by_id', sa.String(36), sa.ForeignKey('users.id', ondelete='SET NULL'), nullable=True), + sa.Column('email', sa.String(255), nullable=False), + sa.Column('role', sa.String(64), nullable=False, server_default='member'), + sa.Column('token', sa.String(128), nullable=False, unique=True), + sa.Column('expires_at', sa.DateTime, nullable=False), + sa.Column('accepted_at', sa.DateTime, nullable=True), + sa.Column('created_at', sa.DateTime, nullable=False), + sa.Column('updated_at', sa.DateTime, nullable=False), + sa.Column('deleted_at', sa.DateTime, nullable=True), + ) + op.create_index('ix_org_invite_tokens_organization_id', 'org_invite_tokens', ['organization_id']) + op.create_index('ix_org_invite_tokens_email', 'org_invite_tokens', ['email']) + op.create_index('ix_org_invite_tokens_token', 'org_invite_tokens', ['token']) + + +def downgrade(): + op.drop_table('org_invite_tokens') diff --git a/migrations/versions/012_ca_nullable_org_and_cert_serial.py b/migrations/versions/012_ca_nullable_org_and_cert_serial.py new file mode 100644 index 0000000..8b7dd0e --- /dev/null +++ b/migrations/versions/012_ca_nullable_org_and_cert_serial.py @@ -0,0 +1,33 @@ +"""Make CA.organization_id nullable (system CA) and add cert_id to sign response + +Revision ID: 012_ca_nullable_org_and_cert_serial +Revises: 011_org_invite_tokens +Create Date: 2025-01-01 00:00:00.000000 +""" +from alembic import op +import sqlalchemy as sa + + +revision = '012_ca_nullable_org' +down_revision = '011_org_invite_tokens' +branch_labels = None +depends_on = None + + +def upgrade(): + # Allow CA records without an org (e.g. the global system-config CA) + with op.batch_alter_table('cas', schema=None) as batch_op: + batch_op.alter_column( + 'organization_id', + existing_type=sa.String(36), + nullable=True, + ) + + +def downgrade(): + with op.batch_alter_table('cas', schema=None) as batch_op: + batch_op.alter_column( + 'organization_id', + existing_type=sa.String(36), + nullable=False, + ) diff --git a/migrations/versions/013_add_ca_type.py b/migrations/versions/013_add_ca_type.py new file mode 100644 index 0000000..1df4bc2 --- /dev/null +++ b/migrations/versions/013_add_ca_type.py @@ -0,0 +1,42 @@ +"""Add ca_type column to cas table (user/host). + +Revision ID: 013 +Revises: d34bfb72844e +Create Date: 2026-02-28 23:00:00.000000 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '013' +down_revision = 'd34bfb72844e' +branch_labels = None +depends_on = None + + +def upgrade(): + # Create the enum type first (PostgreSQL requires this) + ca_type_enum = sa.Enum('user', 'host', name='ca_type_enum') + ca_type_enum.create(op.get_bind(), checkfirst=True) + + # Add ca_type column with a default of 'user' so existing CAs stay valid + op.add_column( + 'cas', + sa.Column( + 'ca_type', + ca_type_enum, + nullable=False, + server_default='user', + ), + ) + + +def downgrade(): + op.drop_column('cas', 'ca_type') + # Drop the enum type (PostgreSQL only; SQLite ignores) + try: + op.execute("DROP TYPE IF EXISTS ca_type_enum") + except Exception: + pass diff --git a/migrations/versions/014_add_dept_cert_policy.py b/migrations/versions/014_add_dept_cert_policy.py new file mode 100644 index 0000000..58b6cb3 --- /dev/null +++ b/migrations/versions/014_add_dept_cert_policy.py @@ -0,0 +1,44 @@ +"""add_department_cert_policies + +Adds the department_cert_policies table which stores per-department +SSH certificate issuance rules: + - whether users may choose their own expiry + - default and maximum expiry durations + - allowed SSH certificate extensions +""" + +from alembic import op +import sqlalchemy as sa + +revision = "014_add_dept_cert_policy" +down_revision = "013" +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + "department_cert_policies", + sa.Column("id", sa.String(36), primary_key=True), + sa.Column("department_id", sa.String(36), sa.ForeignKey("departments.id"), nullable=False, unique=True), + # Whether users are allowed to specify their own expiry (up to max) + sa.Column("allow_user_expiry", sa.Boolean(), nullable=False, server_default="0"), + # Default validity in hours (used when user doesn't specify, or not allowed to) + sa.Column("default_expiry_hours", sa.Integer(), nullable=False, server_default="1"), + # Hard cap on validity; admin cannot be exceeded + sa.Column("max_expiry_hours", sa.Integer(), nullable=False, server_default="24"), + # JSON list of extension names that are enabled for this department + # e.g. ["permit-pty", "permit-agent-forwarding"] + sa.Column("allowed_extensions", sa.JSON(), nullable=False, server_default='["permit-pty","permit-agent-forwarding","permit-X11-forwarding","permit-port-forwarding","permit-user-rc"]'), + # Admin-defined custom extension names beyond the standard five + sa.Column("custom_extensions", sa.JSON(), nullable=False, server_default="[]"), + sa.Column("created_at", sa.DateTime(), nullable=True), + sa.Column("updated_at", sa.DateTime(), nullable=True), + sa.Column("deleted_at", sa.DateTime(), nullable=True), + ) + op.create_index("idx_dept_cert_policy_dept", "department_cert_policies", ["department_id"]) + + +def downgrade(): + op.drop_index("idx_dept_cert_policy_dept", "department_cert_policies") + op.drop_table("department_cert_policies") diff --git a/migrations/versions/015_add_user_suspend_audit_actions.py b/migrations/versions/015_add_user_suspend_audit_actions.py new file mode 100644 index 0000000..06834da --- /dev/null +++ b/migrations/versions/015_add_user_suspend_audit_actions.py @@ -0,0 +1,37 @@ +"""Add USER_SUSPEND and USER_UNSUSPEND to auditaction enum. + +Revision ID: 015_add_user_suspend_audit_actions +Revises: 014_add_dept_cert_policy +Create Date: 2026-03-02 + +USER_SUSPEND and USER_UNSUSPEND were added to the Python AuditAction enum +but were never synced to the PostgreSQL auditaction type, causing a +DataError (invalid enum value) whenever an admin suspends or unsuspends a user. +""" +from alembic import op + +revision = "015_user_suspend_audit" +down_revision = "014_add_dept_cert_policy" +branch_labels = None +depends_on = None + + +def upgrade(): + for val in ("USER_SUSPEND", "USER_UNSUSPEND"): + op.execute(f""" + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_enum + WHERE enumlabel = '{val}' + AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'auditaction') + ) THEN + ALTER TYPE auditaction ADD VALUE '{val}'; + END IF; + END$$; + """) + + +def downgrade(): + # PostgreSQL does not support removing enum values; downgrade is a no-op. + pass diff --git a/migrations/versions/016_encrypt_existing_ca_keys.py b/migrations/versions/016_encrypt_existing_ca_keys.py new file mode 100644 index 0000000..91acdda --- /dev/null +++ b/migrations/versions/016_encrypt_existing_ca_keys.py @@ -0,0 +1,168 @@ +"""Encrypt existing plaintext CA private keys at rest. + +Revision ID: 016_encrypt_existing_ca_keys +Revises: 015_add_user_suspend_audit_actions +Create Date: 2026-03-02 + +All CA private keys created before this migration were stored as plaintext PEM +strings in the ``cas.private_key`` column. This migration detects those rows +(by checking for the absence of the ``$fernet$`` prefix that encrypted values +carry) and re-encrypts them with the key derived from ``CA_ENCRYPTION_KEY``. + +The migration is safe to re-run: already-encrypted rows are left untouched. + +Prerequisites +------------- +``CA_ENCRYPTION_KEY`` must be set in the environment before running this +migration. The same value must be configured for the running application. + +To roll back to plaintext (downgrade): +The ``downgrade()`` function decrypts all rows back to plaintext PEM. This is +provided only for emergency rollback and should not be used in production once +the system has been running with encrypted keys. +""" +import os +import base64 +import hashlib +import logging + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.orm import Session + +logger = logging.getLogger(__name__) + +# Alembic revision identifiers +revision = "016_encrypt_ca_keys" +down_revision = "015_user_suspend_audit" +branch_labels = None +depends_on = None + +_FERNET_PREFIX = "$fernet$" + + +def _get_fernet(): + """Build a Fernet instance from CA_ENCRYPTION_KEY env var.""" + from cryptography.fernet import Fernet + + raw_key = os.environ.get("CA_ENCRYPTION_KEY") + if not raw_key: + raise RuntimeError( + "CA_ENCRYPTION_KEY environment variable is not set. " + "Set it before running this migration." + ) + key_bytes = base64.urlsafe_b64encode(hashlib.sha256(raw_key.encode()).digest()) + return Fernet(key_bytes) + + +def upgrade(): + """Encrypt plaintext CA private keys.""" + bind = op.get_bind() + session = Session(bind=bind) + + try: + fernet = _get_fernet() + except RuntimeError as exc: + raise RuntimeError(str(exc)) from exc + + # Fetch all non-deleted CA rows + rows = session.execute( + sa.text("SELECT id, private_key FROM cas WHERE deleted_at IS NULL") + ).fetchall() + + encrypted_count = 0 + skipped_count = 0 + + for row in rows: + ca_id, private_key = row[0], row[1] + + if not private_key: + logger.warning(f"CA {ca_id} has empty private_key — skipping") + skipped_count += 1 + continue + + if private_key.startswith(_FERNET_PREFIX): + # Already encrypted + skipped_count += 1 + continue + + # Encrypt + try: + token = fernet.encrypt(private_key.encode()).decode() + encrypted_value = f"{_FERNET_PREFIX}{token}" + session.execute( + sa.text("UPDATE cas SET private_key = :pk WHERE id = :id"), + {"pk": encrypted_value, "id": ca_id}, + ) + encrypted_count += 1 + logger.info(f"Encrypted private key for CA {ca_id}") + except Exception as exc: + session.rollback() + raise RuntimeError( + f"Failed to encrypt private key for CA {ca_id}: {exc}" + ) from exc + + session.commit() + logger.info( + f"CA key encryption migration complete: " + f"{encrypted_count} encrypted, {skipped_count} skipped" + ) + print( + f" [016_encrypt_ca_keys] {encrypted_count} CA private key(s) encrypted, " + f"{skipped_count} already encrypted or empty." + ) + + +def downgrade(): + """Decrypt CA private keys back to plaintext (emergency rollback only).""" + bind = op.get_bind() + session = Session(bind=bind) + + try: + fernet = _get_fernet() + except RuntimeError as exc: + raise RuntimeError(str(exc)) from exc + + rows = session.execute( + sa.text("SELECT id, private_key FROM cas WHERE deleted_at IS NULL") + ).fetchall() + + decrypted_count = 0 + skipped_count = 0 + + for row in rows: + ca_id, private_key = row[0], row[1] + + if not private_key or not private_key.startswith(_FERNET_PREFIX): + skipped_count += 1 + continue + + token = private_key[len(_FERNET_PREFIX):] + try: + from cryptography.fernet import InvalidToken + try: + plaintext = fernet.decrypt(token.encode()).decode() + except InvalidToken as exc: + raise RuntimeError( + f"Downgrade failed: cannot decrypt CA {ca_id} — wrong key or corrupted data." + ) from exc + + session.execute( + sa.text("UPDATE cas SET private_key = :pk WHERE id = :id"), + {"pk": plaintext, "id": ca_id}, + ) + decrypted_count += 1 + logger.warning(f"Decrypted (plaintext restore) private key for CA {ca_id}") + except RuntimeError: + session.rollback() + raise + + session.commit() + logger.warning( + f"CA key decryption (downgrade) complete: " + f"{decrypted_count} decrypted, {skipped_count} skipped" + ) + print( + f" [016_encrypt_ca_keys] DOWNGRADE: {decrypted_count} CA private key(s) " + f"decrypted to plaintext. WARNING: keys are now unencrypted at rest." + ) diff --git a/migrations/versions/017_add_ca_serial_counter.py b/migrations/versions/017_add_ca_serial_counter.py new file mode 100644 index 0000000..94b85e9 --- /dev/null +++ b/migrations/versions/017_add_ca_serial_counter.py @@ -0,0 +1,37 @@ +"""Add monotonic serial counter to CAs table. + +Each CA now owns a `next_serial_number` (BigInteger) that is atomically +incremented every time a certificate is signed. This guarantees: + - Serials are unique per CA + - Serials are monotonically increasing (auditable, no gaps by accident) + - The value embedded in the OpenSSH certificate matches what is stored + in the `ssh_certificates.serial` column + +Revision ID: 017_add_ca_serial_counter +Revises: 016_encrypt_ca_keys +Create Date: 2026-03-02 +""" +from alembic import op +import sqlalchemy as sa + +revision = "017_add_ca_serial_counter" +down_revision = "016_encrypt_ca_keys" +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table("cas", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "next_serial_number", + sa.BigInteger(), + nullable=False, + server_default="1", + ) + ) + + +def downgrade(): + with op.batch_alter_table("cas", schema=None) as batch_op: + batch_op.drop_column("next_serial_number") diff --git a/migrations/versions/018_add_ownership_and_hard_delete_audit_actions.py b/migrations/versions/018_add_ownership_and_hard_delete_audit_actions.py new file mode 100644 index 0000000..7cbbd32 --- /dev/null +++ b/migrations/versions/018_add_ownership_and_hard_delete_audit_actions.py @@ -0,0 +1,52 @@ +"""Add ORG_OWNERSHIP_TRANSFERRED and USER_HARD_DELETE to auditaction enum. + +Revision ID: 018_audit_enum_values +Revises: 017_add_ca_serial_counter +Create Date: 2026-03-02 + +ORG_OWNERSHIP_TRANSFERRED and USER_HARD_DELETE were added to the Python +AuditAction enum but were never synced to the PostgreSQL auditaction type, +causing a DataError (invalid enum value) when transferring org ownership +or hard-deleting a user. +""" +from alembic import op + +revision = "018_audit_enum_values" +down_revision = "017_add_ca_serial_counter" +branch_labels = None +depends_on = None + + +def upgrade(): + # ALTER TYPE ... ADD VALUE cannot run inside a transaction block in PostgreSQL. + # Alembic has already opened a transaction on the connection by the time our + # upgrade() runs, so we must: + # 1. Roll back that open transaction on the raw psycopg2 connection. + # 2. Switch to autocommit so the ALTER TYPE runs outside any transaction. + # 3. Restore the previous state afterwards. + conn = op.get_bind() + # SQLAlchemy 2.x: conn.connection is a _ConnectionFairy; .driver_connection is psycopg2 + fairy = conn.connection + raw = getattr(fairy, "driver_connection", None) or getattr(fairy, "dbapi_connection", fairy) + # Roll back the open transaction so psycopg2 allows us to change autocommit. + raw.rollback() + old_autocommit = raw.autocommit + raw.autocommit = True + try: + with raw.cursor() as cur: + for val in ("ORG_OWNERSHIP_TRANSFERRED", "USER_HARD_DELETE"): + cur.execute( + "SELECT 1 FROM pg_enum " + "WHERE enumlabel = %s " + "AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'auditaction')", + (val,), + ) + if not cur.fetchone(): + cur.execute(f"ALTER TYPE auditaction ADD VALUE '{val}'") + finally: + raw.autocommit = old_autocommit + + +def downgrade(): + # PostgreSQL does not support removing enum values; downgrade is a no-op. + pass diff --git a/migrations/versions/d34bfb72844e_add_activation_fields_and_ca_permissions.py b/migrations/versions/d34bfb72844e_add_activation_fields_and_ca_permissions.py new file mode 100644 index 0000000..83d9f72 --- /dev/null +++ b/migrations/versions/d34bfb72844e_add_activation_fields_and_ca_permissions.py @@ -0,0 +1,50 @@ +"""add_activation_fields_and_ca_permissions + +Revision ID: d34bfb72844e +Revises: 012_ca_nullable_org +Create Date: 2026-02-28 18:06:47.328552 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = 'd34bfb72844e' +down_revision = '012_ca_nullable_org' +branch_labels = None +depends_on = None + + +def upgrade(): + # Create ca_permissions table + op.create_table( + 'ca_permissions', + sa.Column('ca_id', sa.String(length=36), nullable=False), + sa.Column('user_id', sa.String(length=36), nullable=False), + sa.Column('permission', sa.String(length=50), nullable=False), + sa.Column('id', sa.String(length=36), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('deleted_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['ca_id'], ['cas.id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('ca_id', 'user_id', name='uix_ca_permission'), + ) + op.create_index('ix_ca_permissions_ca_id', 'ca_permissions', ['ca_id'], unique=False) + op.create_index('ix_ca_permissions_user_id', 'ca_permissions', ['user_id'], unique=False) + + # Add activation columns to users + op.add_column('users', sa.Column('activated', sa.Boolean(), nullable=False, + server_default=sa.text('true'))) + op.add_column('users', sa.Column('activation_key', sa.String(length=128), nullable=True)) + op.create_index('ix_users_activation_key', 'users', ['activation_key'], unique=True) + + +def downgrade(): + op.drop_index('ix_users_activation_key', table_name='users') + op.drop_column('users', 'activation_key') + op.drop_column('users', 'activated') + op.drop_index('ix_ca_permissions_user_id', table_name='ca_permissions') + op.drop_index('ix_ca_permissions_ca_id', table_name='ca_permissions') + op.drop_table('ca_permissions') diff --git a/requirements/base.txt b/requirements/base.txt index 2f9c20d..1f6591d 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -14,7 +14,7 @@ Flask-Marshmallow==0.15.0 marshmallow-sqlalchemy==0.29.0 # Security -bcrypt==4.1.2 +bcrypt==4.2.0 Flask-Bcrypt==1.0.1 pyotp==2.9.0 @@ -24,7 +24,7 @@ cbor2==5.6.0 # JWT / OIDC PyJWT==2.8.0 -cryptography==41.0.7 +cryptography==42.0.7 # CORS Flask-CORS==4.0.0 @@ -47,4 +47,7 @@ Flask-Limiter==3.5.0 # Logging python-json-logger==2.0.7 -qrcode[pil] \ No newline at end of file +qrcode[pil] + +# SSH CA Certificate signing +sshkey-tools==0.11.3 diff --git a/requirements/development.txt b/requirements/development.txt index bccb4ec..8c626fd 100644 --- a/requirements/development.txt +++ b/requirements/development.txt @@ -20,3 +20,25 @@ watchdog==3.0.0 # Documentation sphinx==7.2.6 + +# Web framework & Database +Flask==3.0.0 +Flask-SQLAlchemy==3.1.1 +Flask-Migrate==4.0.5 +sqlalchemy-cockroachdb==2.0.3 + +# Utilities +colorlog==6.8.0 +coloredlogs==15.0.1 +prettytable==3.10.2 +tabulate==0.9.0 +requests==2.31.0 +pytz==2023.3 +python-dotenv==1.0.0 +pydantic==2.5.0 +PyJWT==2.8.0 +cryptography==42.0.7 +pycryptodome==3.20.0 +psycopg2-binary==2.9.9 +sshkey-tools==0.11.3 +sendgrid==6.11.0 diff --git a/scripts/seed_data.py b/scripts/seed_data.py index 97450fa..e8c1d04 100644 --- a/scripts/seed_data.py +++ b/scripts/seed_data.py @@ -15,11 +15,11 @@ load_dotenv() from gatehouse_app import create_app from gatehouse_app.extensions import db -from gatehouse_app.models.user import User -from gatehouse_app.models.organization import Organization -from gatehouse_app.models.organization_member import OrganizationMember -from gatehouse_app.models.authentication_method import AuthenticationMethod -from gatehouse_app.models.oidc_client import OIDCClient +from gatehouse_app.models.user.user import User +from gatehouse_app.models.organization.organization import Organization +from gatehouse_app.models.organization.organization_member import OrganizationMember +from gatehouse_app.models.auth.authentication_method import AuthenticationMethod +from gatehouse_app.models.oidc.oidc_client import OIDCClient from gatehouse_app.services.auth_service import AuthService from gatehouse_app.services.organization_service import OrganizationService from gatehouse_app.utils.constants import OrganizationRole, UserStatus, AuthMethodType