From b2212ab4d63ec1c1302678d66fe0e42e94cce6f3 Mon Sep 17 00:00:00 2001 From: James Bhattarai Date: Fri, 27 Feb 2026 21:59:01 +0545 Subject: [PATCH] Feat: Added CA-merged with Securid-Principals, Depart, Client-CLI --- client/gatehouse-cli.py | 154 +++-- etc/ssh_ca.conf | 114 ++++ gatehouse_app/__init__.py | 6 + gatehouse_app/api/v1/__init__.py | 4 +- gatehouse_app/api/v1/external_auth.py | 136 +++- gatehouse_app/api/v1/organizations.py | 556 +++++++++++++++- gatehouse_app/api/v1/ssh.py | 615 ++++++++++++++++++ gatehouse_app/config/ssh_ca_config.py | 271 ++++++++ gatehouse_app/exceptions/__init__.py | 29 + gatehouse_app/exceptions/base.py | 3 +- gatehouse_app/exceptions/ssh_exceptions.py | 93 +++ gatehouse_app/models/__init__.py | 17 + gatehouse_app/models/ca.py | 155 +++++ gatehouse_app/models/certificate_audit_log.py | 83 +++ gatehouse_app/models/organization.py | 3 + gatehouse_app/models/ssh_certificate.py | 175 +++++ gatehouse_app/models/ssh_key.py | 96 +++ gatehouse_app/models/user.py | 12 + gatehouse_app/services/oauth_flow_service.py | 27 +- .../services/ssh_ca_signing_service.py | 333 ++++++++++ gatehouse_app/services/ssh_key_service.py | 373 +++++++++++ gatehouse_app/utils/constants.py | 41 ++ gatehouse_app/utils/crypto.py | 128 ++++ migrations/versions/007_add_ssh_ca_models.py | 173 +++++ .../versions/008_fix_authmethodtype_enum.py | 53 ++ .../versions/009_sync_auditaction_enum.py | 61 ++ .../012_ca_nullable_org_and_cert_serial.py | 33 + requirements/base.txt | 5 +- requirements/development.txt | 22 + 29 files changed, 3718 insertions(+), 53 deletions(-) create mode 100644 etc/ssh_ca.conf create mode 100644 gatehouse_app/api/v1/ssh.py create mode 100644 gatehouse_app/config/ssh_ca_config.py create mode 100644 gatehouse_app/exceptions/ssh_exceptions.py create mode 100644 gatehouse_app/models/ca.py create mode 100644 gatehouse_app/models/certificate_audit_log.py create mode 100644 gatehouse_app/models/ssh_certificate.py create mode 100644 gatehouse_app/models/ssh_key.py create mode 100644 gatehouse_app/services/ssh_ca_signing_service.py create mode 100644 gatehouse_app/services/ssh_key_service.py create mode 100644 gatehouse_app/utils/crypto.py create mode 100644 migrations/versions/007_add_ssh_ca_models.py create mode 100644 migrations/versions/008_fix_authmethodtype_enum.py create mode 100644 migrations/versions/009_sync_auditaction_enum.py create mode 100644 migrations/versions/012_ca_nullable_org_and_cert_serial.py diff --git a/client/gatehouse-cli.py b/client/gatehouse-cli.py index a5d87f4..b45ce1a 100755 --- a/client/gatehouse-cli.py +++ b/client/gatehouse-cli.py @@ -2,6 +2,7 @@ import base64 from datetime import datetime import os +import sys import webbrowser import requests import argparse @@ -22,13 +23,12 @@ import base64 load_dotenv() # Get the API_URL from the environment variables -SIGN_URL = os.getenv("SIGN_URL", "http://localhost:1234") +SIGN_URL = os.getenv("SIGN_URL", "http://localhost:5000") LISTENER_HOST_NAME = "127.0.0.1" LISTENER_SERVER_PORT = 8250 -CA_API_HOST = "127.0.0.1" -CA_SERVER_PORT = 1234 -CACHE_FILE = 'token_cache.json' ###need to change it to secure location and permissions if used in production -CERT_FILE_PATH = "/tmp/ssl-cert" +CACHE_FILE = os.path.expanduser('~/.gatehouse/token_cache.json') +os.makedirs(os.path.dirname(CACHE_FILE), exist_ok=True) +CERT_FILE_PATH = "/tmp/ssh-cert" CHALLENGE_FILE_PATH = "/tmp/challenge.txt" CHALLENGE_SIG_FILE_PATH = "/tmp/challenge.txt.sig" @@ -116,7 +116,7 @@ def decode_and_validate_token(token): return True # Token is valid except Exception as e: - logger.error(f"Token validation failed: {e}") + logger.debug(f"Token validation failed: {e}") return False def request_token(): @@ -129,19 +129,34 @@ def request_token(): logger.debug("Token loaded from cache: %s", token) # Validate the cached token, if it exists - if token and decode_and_validate_token(token): - logger.info("Cached token is valid. Using cached token.") - return token - - logger.info("No valid cached token found, proceeding to request a new token.") - token = "" + if token: + try: + if decode_and_validate_token(token): + logger.info("Cached token is valid. Using cached token.") + return token + except Exception: + pass + # Try opaque token via /auth/me + try: + r = requests.get( + f"{SIGN_URL}/api/v1/auth/me", + headers={"Authorization": f"Bearer {token}"}, + timeout=5, + ) + if r.status_code == 200: + logger.info("Cached session token is valid. Using cached token.") + return token + except Exception: + pass + logger.info("Cached token is expired or invalid, requesting a new token.") + token = "" # Prepare the redirect URL for the token request redirect_url = f"http://{LISTENER_HOST_NAME}:{LISTENER_SERVER_PORT}/?token=" logger.info("Redirect URL: %s", redirect_url) # Construct the token request URL - token_url = f"{SIGN_URL}/token_please?redirect_url={redirect_url}" + token_url = f"{SIGN_URL}/api/v1/token_please?redirect_url={redirect_url}" logger.info("Token request URL: %s", token_url) # Start the web server to handle the token response @@ -168,10 +183,10 @@ def get_activated_ssh_key(): 'Authorization': f'Bearer {token}', 'Content-Type': 'application/json' } - response = requests.get(f"{SIGN_URL}/api/ssh-keys", headers=headers) + response = requests.get(f"{SIGN_URL}/api/v1/ssh/keys", headers=headers) if response.status_code == 200: - keys = response.json().get('ssh_keys', []) + keys = response.json().get('keys', []) verified_keys = [key for key in keys if key['verified']] if not verified_keys: @@ -179,8 +194,19 @@ def get_activated_ssh_key(): exit(1) if len(verified_keys) > 1: - logger.error("Multiple verified SSH keys found. Please specify CERT_ID.") - exit(1) + # If running interactively, let the user pick; otherwise use the most recently added key + if sys.stdout.isatty(): + print("\nMultiple verified SSH keys found. Please choose one:") + for i, k in enumerate(verified_keys): + print(f" [{i+1}] {k['id'][:8]}... fingerprint={k.get('fingerprint','?')} name={k.get('key_comment','?')}") + try: + choice = int(input("Enter number: ").strip()) - 1 + if 0 <= choice < len(verified_keys): + return verified_keys[choice]['id'] + except (ValueError, EOFError): + pass + logger.info("Multiple verified SSH keys found; using the most recently added one.") + verified_keys.sort(key=lambda k: k.get('created_at', ''), reverse=True) return verified_keys[0]['id'] @@ -193,26 +219,35 @@ def get_activated_ssh_key(): exit(1) -def request_certificate(): +def request_certificate(principals=None): CERT_ID = os.getenv("CERT_ID") or get_activated_ssh_key() + if not principals: + env_principals = os.getenv("PRINCIPALS") + if env_principals: + principals = [p.strip() for p in env_principals.split(',')] + else: + principals = [os.getlogin()] + headers = { 'content-type': 'application/json', "Authorization": "bearer " + token } payload = { - 'cert_id': CERT_ID + 'cert_id': CERT_ID, + 'principals': principals, } try: - response = requests.post(f"{SIGN_URL}/sign_cert", json=payload, headers=headers) + response = requests.post(f"{SIGN_URL}/api/v1/ssh/sign", json=payload, headers=headers) - if response.status_code == 200: + if response.status_code == 201: json_result = response.json() with open(CERT_FILE_PATH, 'w') as f: f.write(json_result['certificate']) logger.info(f"Certificate signed successfully, located at {CERT_FILE_PATH}") + logger.info(f"Valid for principals: {', '.join(json_result.get('principals', principals))}") logger.info("You can login to your destination server with the following command") logger.info(f"\tssh user@server -o CertificateFile={CERT_FILE_PATH}") else: @@ -238,14 +273,14 @@ def generate_and_sign_challenge(ssh_key_file,key_id): # Send the POST request response = requests.get( - f"http://{CA_API_HOST}:{CA_SERVER_PORT}/api/ssh-key/{key_id}/validationData", + f"{SIGN_URL}/api/v1/ssh/keys/{key_id}/verify", headers=headers ) if response.status_code!=200: logger.error(f"Server returned unexpected code {response.status_code}") return False - challenge_text=response.json()['validationText']+"\n" + challenge_text=response.json().get('challenge_text', response.json().get('validationText', ''))+"\n" except Exception as e: logger.error(f"Unable to fetch SSH Key validation data {e}") @@ -291,7 +326,7 @@ def submit_signature_validation(signature, key_id): # Send the POST request response = requests.post( - f"http://{CA_API_HOST}:{CA_SERVER_PORT}/api/ssh-key/{key_id}/validate", + f"{SIGN_URL}/api/v1/ssh/keys/{key_id}/verify", headers=headers, json=payload ) @@ -317,12 +352,12 @@ def remove_ssh_key(key_id=None): } # List keys first - response = requests.get(f"{SIGN_URL}/api/ssh-keys", headers=headers) + response = requests.get(f"{SIGN_URL}/api/v1/ssh/keys", headers=headers) if response.status_code != 200: logger.error(f"Failed to list SSH keys: {response.status_code} - {response.text}") exit(1) - keys = response.json().get('ssh_keys', []) + keys = response.json().get('keys', []) if not keys: logger.info("No SSH keys found for your user.") return @@ -359,7 +394,7 @@ def remove_ssh_key(key_id=None): exit(1) for k in keys_to_delete: - del_response = requests.delete(f"{SIGN_URL}/api/ssh-key/{k['id']}", headers=headers) + del_response = requests.delete(f"{SIGN_URL}/api/v1/ssh/keys/{k['id']}", headers=headers) if del_response.status_code == 200: logger.info(f"Key {k['id']} removed successfully.") else: @@ -381,20 +416,40 @@ def add_ssh_key(ssh_key_file): headers = { 'Authorization': f'Bearer {token}', 'Content-Type': 'application/json' - } - - ssh_key = ssh_key_file.read().decode('utf-8') + } + + if hasattr(ssh_key_file, 'read'): + # File object (e.g. argparse.FileType('rb')) + key_bytes = ssh_key_file.read() + key_path = ssh_key_file.name + elif isinstance(ssh_key_file, bytes): + key_bytes = ssh_key_file + key_path = None + else: + # String path + key_path = str(ssh_key_file) + with open(key_path, 'rb') as f: + key_bytes = f.read() + + ssh_key = key_bytes.decode('utf-8').strip() + payload = { 'description': 'Added via gatehouse CLI tool', 'key': ssh_key } - response = requests.post(f"{SIGN_URL}/api/ssh-key/add", json=payload, headers=headers) + response = requests.post(f"{SIGN_URL}/api/v1/ssh/keys", json=payload, headers=headers) - if response.status_code == 200: - ssh_key_id=response.json()['key_id'] + if response.status_code == 201: + ssh_key_id=response.json()['id'] logger.info(f"SSH key {ssh_key_id} added successfully") - generate_and_sign_challenge(ssh_key_file.name,ssh_key_id) + if key_path: + # Strip .pub suffix to get the private key path for signing + private_key_path = key_path[:-4] if key_path.endswith('.pub') else key_path + generate_and_sign_challenge(private_key_path, ssh_key_id) + else: + logger.warning("No key file path available — skipping auto-verification. " + "Run with -k to enable automatic key verification.") else: logger.error(f"Failed to add SSH key: {response.status_code} - {response.text}") @@ -431,13 +486,15 @@ if __name__ == "__main__": parser.add_argument("-a", "--add-key", action='store_true', default=False, help="Add SSH key to the server") parser.add_argument("-c", "--check-cert", action='store_true', default=False, help="Check the certificate, if it's valid exit 0, if it's invalid exit 1") parser.add_argument("-r", "--request-cert", action='store_true', default=False, help="Request that gatehouse sign a new certificate for you based on an SSH public key on file in your profile") + parser.add_argument("--principals", nargs='+', metavar='PRINCIPAL', help="Unix usernames for the certificate (default: current OS user)") parser.add_argument("--clear-cache", action='store_true', default=False, help="Remove the cached authentication token") parser.add_argument("--remove-key", nargs='?', const='', metavar='KEY_ID', help="Remove an SSH key from your profile. Omit KEY_ID to pick interactively.") + parser.add_argument("--list-keys", action='store_true', default=False, help="List SSH keys in your profile") args = parser.parse_args() - # Ensure that one of --check-cert, --request-cert, or --add-key is provided - if not (args.check_cert or args.request_cert or args.add_key or args.clear_cache or args.remove_key is not None): - parser.error("At least one of --check-cert, --request-cert, --add-key, --validate-key, or --clear-cache must be provided.") + if not (args.check_cert or args.request_cert or args.add_key or args.clear_cache + or args.remove_key is not None or args.list_keys): + parser.error("At least one of --check-cert, --request-cert, --add-key, --list-keys, --remove-key, or --clear-cache must be provided.") # Retrieve SSH key from environment variables if not provided via CLI @@ -456,6 +513,25 @@ if __name__ == "__main__": remove_ssh_key(args.remove_key if args.remove_key else None) exit(0) + if args.list_keys: + request_token() + response = requests.get( + f"{SIGN_URL}/api/v1/ssh/keys", + headers={"Authorization": f"Bearer {token}"}, + ) + if response.status_code == 200: + data = response.json() + keys = data.get('keys', []) + if not keys: + print("No SSH keys found in your profile.") + else: + for k in keys: + verified = "✓ verified" if k.get('verified') else "✗ unverified" + print(f" {k['id']} {verified} {k.get('description', '')} (added {k['created_at'][:10]})") + else: + logger.error(f"Failed to list SSH keys: {response.status_code} - {response.text}") + exit(0) + if args.add_key: request_token() @@ -476,10 +552,10 @@ if __name__ == "__main__": if args.force: request_token() logger.info("Forcing renewal of certificate") - request_certificate() + request_certificate(principals=args.principals) if checkCert() == 1: request_token() - request_certificate() + request_certificate(principals=args.principals) exit(0) diff --git a/etc/ssh_ca.conf b/etc/ssh_ca.conf new file mode 100644 index 0000000..babe4dc --- /dev/null +++ b/etc/ssh_ca.conf @@ -0,0 +1,114 @@ + +[default] +# Certificate validity period (in hours) +# Default: 1 hour +cert_validity_hours=1 + +# Maximum certificate validity allowed (in hours) +# Default: 24 hours +# Prevents users from requesting certificates valid longer than this +max_cert_validity_hours=24 + +# Certificate Request Limits +# Maximum number of certificates per user +max_certs_per_user=100 + +# Certificate revocation list (CRL) configuration +crl_enabled=true +# CRL endpoint URL - set to your domain where CRL is served +crl_endpoint=https://ca.example.com/crl +# CRL refresh interval (in hours) +crl_refresh_hours=24 + +# CA Key Configuration +# Default key type for new CAs (ed25519, rsa, ecdsa) +default_key_type=ed25519 + +# RSA key size (if using RSA) +rsa_key_bits=4096 + +# Private key encryption +# Method: kms (AWS Key Management Service) or local (for development only) +private_key_encryption=kms +# AWS KMS Key ID (only used if private_key_encryption=kms) +aws_kms_key_id=${SSH_CA_KMS_KEY_ID} + +# SSH Certificate Extensions +# Default extensions to add to certificates +extensions_enabled=true +extensions=permit-X11-forwarding,permit-agent-forwarding,permit-pty,permit-port-forwarding,permit-user-rc + +# Critical Options +# Critical options to add to certificates (rarely needed) +critical_options_enabled=false + +# Certificate Field Limits +# Maximum number of principals per certificate (SSH limitation is 256) +max_principals_per_cert=256 + +# Maximum length for key_id field +max_key_id_length=255 + +# Logging Configuration +# Log level for SSH CA operations (DEBUG, INFO, WARNING, ERROR) +log_level=INFO + +# Audit Configuration +# Log all certificate signing operations +audit_enabled=true + +# Security Configuration +# Require SSH key verification before issuing certificates +require_key_verification=true + +# Verification challenge max age (in hours) +verification_challenge_max_age=24 + +# Rate limiting for certificate signing +# Max certificates per minute per user +rate_limit_certs_per_minute=5 + +# Request timeout (in seconds) +request_timeout=30 + +# Cleanup Configuration +# Automatically delete unverified SSH keys after this many days +auto_delete_unverified_days=30 + +# Archive expired certificates after this many days +archive_expired_days=365 + +# CLI OAuth Configuration (for secuird-cli.py compatibility) +# OAuth token endpoint for CLI clients +oauth_token_endpoint=/api/v1/oauth2/token +# OAuth userinfo endpoint for CLI clients +oauth_userinfo_endpoint=/api/v1/oauth2/userinfo + +[development] +# Override settings for development environment +private_key_encryption=local +ca_key_path=/home/james/cory/secuird/certs/ca-users +log_level=DEBUG +cert_validity_hours=24 +max_cert_validity_hours=720 +rate_limit_certs_per_minute=100 +require_key_verification=false + +[production] +# Override settings for production environment +private_key_encryption=kms +log_level=WARNING +cert_validity_hours=1 +max_cert_validity_hours=24 +rate_limit_certs_per_minute=5 +require_key_verification=true + +[testing] +# Override settings for testing environment +private_key_encryption=local +log_level=DEBUG +cert_validity_hours=1 +max_cert_validity_hours=24 +rate_limit_certs_per_minute=100 +require_key_verification=true +audit_enabled=false diff --git a/gatehouse_app/__init__.py b/gatehouse_app/__init__.py index a4e1043..f17c784 100644 --- a/gatehouse_app/__init__.py +++ b/gatehouse_app/__init__.py @@ -2,6 +2,9 @@ import os import logging +from dotenv import load_dotenv +load_dotenv(dotenv_path=os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), '.env')) + # Test debug logging - this should appear when running `flask run --debug` _root_logger = logging.getLogger(__name__) _root_logger.debug("[TEST] Debug logging is working!") @@ -239,3 +242,6 @@ def initialize_oidc_jwks(app): app.logger.info(f"[OIDC] Signing key initialized: kid={signing_key.kid}") except Exception as e: app.logger.error(f"[OIDC] Failed to initialize JWKS: {e}") + +# Create default app instance for gunicorn/wsgi +app = create_app() diff --git a/gatehouse_app/api/v1/__init__.py b/gatehouse_app/api/v1/__init__.py index b97cd94..14534ed 100644 --- a/gatehouse_app/api/v1/__init__.py +++ b/gatehouse_app/api/v1/__init__.py @@ -5,4 +5,6 @@ from flask import Blueprint api_v1_bp = Blueprint("api_v1", __name__) # Import route modules to register them -from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth +from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth, departments, principals, ssh + +api_v1_bp.register_blueprint(ssh.ssh_bp) diff --git a/gatehouse_app/api/v1/external_auth.py b/gatehouse_app/api/v1/external_auth.py index bb2969d..d68fd77 100644 --- a/gatehouse_app/api/v1/external_auth.py +++ b/gatehouse_app/api/v1/external_auth.py @@ -46,6 +46,33 @@ def _pop_oidc_bridge(oauth_state: str) -> str | None: pass return None + +def _store_cli_redirect(oauth_state: str, redirect_url: str) -> None: + """Store CLI redirect_url keyed by OAuth state (for /token_please flow).""" + try: + import gatehouse_app.extensions as _ext + rc = _ext.redis_client + if rc is not None: + rc.setex(f"oauth_cli_redirect:{oauth_state}", _OAUTH_BRIDGE_TTL, redirect_url) + except Exception: + pass + + +def _pop_cli_redirect(oauth_state: str) -> str | None: + """Retrieve and delete CLI redirect_url for the given OAuth state.""" + try: + import gatehouse_app.extensions as _ext + rc = _ext.redis_client + if rc is not None: + key = f"oauth_cli_redirect:{oauth_state}" + val = rc.get(key) + if val: + rc.delete(key) + return val.decode() if isinstance(val, bytes) else val + except Exception: + pass + return None + logger = logging.getLogger(__name__) @@ -69,6 +96,71 @@ def get_provider_type(provider: str) -> AuthMethodType: return PROVIDER_TYPE_MAP[provider_lower] +@api_v1_bp.route("/token_please", methods=["GET"]) +def token_please(): + """ + CLI token acquisition endpoint. + + Initiates an OAuth login flow and, on success, redirects the user's browser + to the CLI's local callback server (redirect_url) with the session token + appended, e.g.: http://127.0.0.1:8250/?token= + + This endpoint is designed for CLI clients that: + 1. Start a local HTTP server on LISTENER_SERVER_PORT (e.g. 8250) + 2. Open a browser to /api/v1/token_please?redirect_url=http://127.0.0.1:8250/?token= + 3. Wait for the browser to POST the token back to their local server + + Query parameters: + redirect_url: Local callback URL where the token will be appended + provider: OAuth provider to use (default: 'google') + """ + from urllib.parse import urlencode + from flask import current_app, redirect as flask_redirect + + redirect_url = request.args.get("redirect_url", "").strip() + provider = request.args.get("provider", "google").lower() + + if not redirect_url: + return api_response( + success=False, + message="redirect_url query parameter is required", + status=400, + error_type="MISSING_REDIRECT_URL", + ) + + # Validate redirect_url is localhost/127.0.0.1 (security: prevent open redirect) + from urllib.parse import urlparse as _urlparse + parsed = _urlparse(redirect_url) + if parsed.hostname not in ("localhost", "127.0.0.1"): + return api_response( + success=False, + message="redirect_url must point to localhost", + status=400, + error_type="INVALID_REDIRECT_URL", + ) + + try: + provider_type = get_provider_type(provider) + auth_url, state = OAuthFlowService.initiate_login_flow( + provider_type=provider_type, + organization_id=None, + redirect_uri=None, + ) + except (OAuthFlowError, ExternalAuthError) as e: + return api_response( + success=False, + message=getattr(e, "message", str(e)), + status=getattr(e, "status_code", 400), + error_type=getattr(e, "error_type", "OAUTH_ERROR"), + ) + + # Store the CLI redirect URL so the callback can use it + _store_cli_redirect(state, redirect_url) + + logger.info(f"CLI token_please: provider={provider}, redirect_url={redirect_url}, redirecting to OAuth") + return flask_redirect(auth_url, code=302) + + # ============================================================================= # Provider Configuration Endpoints (Admin) # ============================================================================= @@ -575,8 +667,6 @@ def initiate_oauth_authorize(provider: str): "state": "state_token" } """ - provider_type = get_provider_type(provider) - # Get query parameters - organization_id is now optional flow = request.args.get("flow", "login") redirect_uri = request.args.get("redirect_uri") @@ -592,7 +682,7 @@ def initiate_oauth_authorize(provider: str): ) try: - # Initiate flow - organization_id is now optional + provider_type = get_provider_type(provider) if flow == "login": auth_url, state = OAuthFlowService.initiate_login_flow( provider_type=provider_type, @@ -626,6 +716,13 @@ def initiate_oauth_authorize(provider: str): status=e.status_code, error_type=e.error_type, ) + except ExternalAuthError as e: + return api_response( + success=False, + message=e.message, + status=e.status_code, + error_type=e.error_type, + ) @api_v1_bp.route("/auth/external//callback", methods=["GET"]) @@ -666,8 +763,19 @@ def handle_oauth_callback(provider: str): frontend_url = current_app.config.get("FRONTEND_URL", "http://localhost:8080") frontend_callback = f"{frontend_url}/oauth/callback" + # Check if this is a CLI /token_please flow — retrieve stored redirect_url + cli_redirect_url = _pop_cli_redirect(state) if state else None + def redirect_error(message: str, error_type: str = "OAUTH_ERROR"): - """Redirect to frontend with error params.""" + """Redirect to frontend (or CLI) with error params.""" + if cli_redirect_url: + # CLI flow: return a plain error page instead of redirecting back + from flask import make_response + return make_response( + f"

Authentication Error

{message}

" + f"

You may close this window.

", + 400, + ) params = {"error": message, "error_type": error_type} if state: params["state"] = state @@ -706,8 +814,11 @@ def handle_oauth_callback(provider: str): # Recover oidc_session_id if this was triggered from an OIDC bridge flow oidc_session_id = _pop_oidc_bridge(state) + # Organization selection / creation flows are not supported in CLI mode + # (fall through to token redirect with whatever session we have) + # Organization selection needed (user belongs to multiple orgs) - if result.get("requires_org_selection"): + if result.get("requires_org_selection") and not cli_redirect_url: import json orgs = json.dumps(result.get("available_organizations", [])) params = { @@ -722,7 +833,7 @@ def handle_oauth_callback(provider: str): return flask_redirect(f"{frontend_callback}?{urlencode(params)}", code=302) # Organization creation needed (new user via OAuth with no org) - if result.get("requires_org_creation"): + if result.get("requires_org_creation") and not cli_redirect_url: params = { "requires_org_creation": "1", "state": result["state"], @@ -751,6 +862,19 @@ def handle_oauth_callback(provider: str): user_info = result.get("user", {}) if user_info.get("email"): params["email"] = user_info["email"] + + # ── CLI /token_please flow: redirect to the CLI's local callback ───── + if cli_redirect_url: + # The CLI expects: http://127.0.0.1:8250/?token= + # cli_redirect_url already ends with "token=" so just append the value + cli_final_url = cli_redirect_url + token + logger.info( + f"CLI token_please success: provider={provider}, user={user_info.get('email')}, " + f"redirecting to CLI callback" + ) + return flask_redirect(cli_final_url, code=302) + + # ── Frontend flow ───────────────────────────────────────────────────── # Pass oidc_session_id through so the frontend can complete the OIDC flow if oidc_session_id: params["oidc_session_id"] = oidc_session_id diff --git a/gatehouse_app/api/v1/organizations.py b/gatehouse_app/api/v1/organizations.py index 8f4f6cb..9826f06 100644 --- a/gatehouse_app/api/v1/organizations.py +++ b/gatehouse_app/api/v1/organizations.py @@ -13,8 +13,6 @@ from gatehouse_app.schemas.organization_schema import ( from gatehouse_app.services.organization_service import OrganizationService from gatehouse_app.services.user_service import UserService from gatehouse_app.utils.constants import OrganizationRole - -########jb- need to implement departs, principals @api_v1_bp.route("/organizations", methods=["POST"]) @login_required @full_access_required @@ -378,3 +376,557 @@ def update_member_role(org_id, user_id): error_type="VALIDATION_ERROR", error_details=e.messages, ) + + +@api_v1_bp.route("/organizations//audit-logs", methods=["GET"]) +@login_required +@full_access_required +def get_organization_audit_logs(org_id): + """ + Get audit logs for an organization. + + Query params: + page: Page number (default 1) + per_page: Results per page (default 50, max 200) + action: Filter by action type + + Returns: + 200: List of audit log entries + 401: Not authenticated + 403: Not a member / insufficient permissions + 404: Organization not found + """ + from gatehouse_app.models.audit_log import AuditLog + + # Ensure org exists and user is a member (full_access_required handles this) + OrganizationService.get_organization_by_id(org_id) + + page = int(request.args.get("page", 1)) + per_page = min(int(request.args.get("per_page", 50)), 200) + action_filter = request.args.get("action") + + query = AuditLog.query.filter_by(organization_id=org_id) + if action_filter: + query = query.filter_by(action=action_filter) + + query = query.order_by(AuditLog.created_at.desc()) + total = query.count() + logs = query.offset((page - 1) * per_page).limit(per_page).all() + + def log_to_dict(log): + return { + "id": log.id, + "action": log.action.value if log.action else None, + "user_id": log.user_id, + "user_email": log.user.email if log.user else None, + "user": {"id": log.user.id, "email": log.user.email, "full_name": log.user.full_name} if log.user else None, + "organization_id": log.organization_id, + "resource_type": log.resource_type, + "resource_id": log.resource_id, + "ip_address": log.ip_address, + "user_agent": log.user_agent, + "request_id": log.request_id, + "description": log.description, + "success": log.success, + "error_message": log.error_message, + "metadata": log.extra_data, + "created_at": log.created_at.isoformat() if log.created_at else None, + "updated_at": log.updated_at.isoformat() if log.updated_at else None, + } + + return api_response( + data={ + "audit_logs": [log_to_dict(log) for log in logs], + "count": total, + "page": page, + "per_page": per_page, + "pages": (total + per_page - 1) // per_page, + }, + message="Audit logs retrieved successfully", + ) + + +# ============================================================================ +# Organization Invite Tokens +# ============================================================================ + +@api_v1_bp.route("/organizations//invites", methods=["POST"]) +@login_required +@require_admin +def create_org_invite(org_id): + """Create an invite token for an organization. + + Request body: + email: Email address to invite + role: Role to assign (default: member) + + Returns: + 201: Invite created + 400: Validation error + 403: Not an admin + 404: Organization not found + """ + from gatehouse_app.models import OrgInviteToken, Organization + from gatehouse_app.services.notification_service import NotificationService + from flask import current_app + + org = Organization.query.filter_by(id=org_id, deleted_at=None).first() + if not org: + return api_response(success=False, message="Organization not found", status=404) + + data = request.get_json() or {} + email = (data.get("email") or "").strip().lower() + role = (data.get("role") or "member").strip() + + if not email: + return api_response(success=False, message="Email is required", status=400, error_type="VALIDATION_ERROR") + + invite = OrgInviteToken.generate( + organization_id=org_id, + email=email, + role=role, + invited_by_id=g.current_user.id, + ) + + app_url = current_app.config.get("APP_URL", "http://localhost:8080") + invite_link = f"{app_url}/invite?token={invite.token}" + + NotificationService._send_email( + to_address=email, + subject=f"You're invited to join {org.name} on Gatehouse", + body=( + f"You've been invited to join {org.name} on Gatehouse.\n\n" + f"Click the link below to accept the invitation (valid for 7 days):\n" + f"{invite_link}\n\n" + f"Gatehouse Security Team" + ), + ) + + return api_response( + data={"invite": {"id": invite.id, "email": invite.email, "role": invite.role, "expires_at": invite.expires_at.isoformat() + "Z"}}, + message="Invite sent successfully", + status=201, + ) + + +@api_v1_bp.route("/invites/", methods=["GET"]) +def get_invite(token): + """Get invite details by token. + + Returns: + 200: Invite details (org name, email) + 400: Invalid or expired token + """ + from gatehouse_app.models import OrgInviteToken + + invite = OrgInviteToken.query.filter_by(token=token).first() + if not invite or not invite.is_valid: + return api_response(success=False, message="This invitation link is invalid or has expired.", status=400, error_type="INVALID_TOKEN") + + return api_response( + data={ + "email": invite.email, + "organization": {"id": invite.organization_id, "name": invite.organization.name}, + "role": invite.role, + }, + message="Invite found", + ) + + +@api_v1_bp.route("/invites//accept", methods=["POST"]) +def accept_invite(token): + """Accept an organization invite. + + Creates the user account (if not already registered) and adds them + to the organization. + + Request body: + full_name: User's display name + password: Password for new account (if not already registered) + password_confirm: Password confirmation + + Returns: + 200: Invite accepted, returns user token + 400: Invalid/expired token or validation error + 409: Already a member + """ + from gatehouse_app.models import OrgInviteToken, User + from gatehouse_app.services.auth_service import AuthService + from gatehouse_app.services.organization_service import OrganizationService + from gatehouse_app.utils.constants import OrganizationRole + + invite = OrgInviteToken.query.filter_by(token=token).first() + if not invite or not invite.is_valid: + return api_response(success=False, message="This invitation link is invalid or has expired.", status=400, error_type="INVALID_TOKEN") + + data = request.get_json() or {} + full_name = data.get("full_name") or "" + password = data.get("password") or "" + password_confirm = data.get("password_confirm") or "" + + user = User.query.filter_by(email=invite.email, deleted_at=None).first() + + if not user: + # Register a new user + if not password: + return api_response(success=False, message="Password is required for new accounts.", status=400, error_type="VALIDATION_ERROR") + if password != password_confirm: + return api_response(success=False, message="Passwords do not match.", status=400, error_type="VALIDATION_ERROR") + if len(password) < 8: + return api_response(success=False, message="Password must be at least 8 characters.", status=400, error_type="VALIDATION_ERROR") + try: + user = AuthService.register_user(email=invite.email, password=password, full_name=full_name or None) + except Exception as exc: + return api_response(success=False, message=str(exc), status=400, error_type="REGISTRATION_ERROR") + + # Add to org + role_value = invite.role + try: + org_role = OrganizationRole(role_value) + except ValueError: + org_role = OrganizationRole.MEMBER + + try: + OrganizationService.add_member( + org=invite.organization, + user_id=user.id, + role=org_role, + inviter_id=invite.invited_by_id, + ) + except Exception: + pass # Already a member is fine + + invite.accept() + + user_session = AuthService.create_session(user) + + return api_response( + data={ + "user": user.to_dict(), + "token": user_session.token, + "expires_at": user_session.expires_at.isoformat() + "Z", + }, + message="Invitation accepted. Welcome!", + ) + + +# ============================================================================ +# Organization OIDC Clients +# ============================================================================ + +@api_v1_bp.route("/organizations//clients", methods=["GET"]) +@login_required +def list_org_clients(org_id): + """List OIDC clients for an organization. + + Returns: + 200: List of OIDC clients + 403: Not a member + 404: Organization not found + """ + from gatehouse_app.models import OIDCClient, Organization + + org = Organization.query.filter_by(id=org_id, deleted_at=None).first() + if not org: + return api_response(success=False, message="Organization not found", status=404) + + clients = OIDCClient.query.filter_by(organization_id=org_id, is_active=True).all() + + def client_to_dict(c): + return { + "id": c.id, + "name": c.name, + "client_id": c.client_id, + "redirect_uris": c.redirect_uris, + "scopes": c.scopes, + "grant_types": c.grant_types, + "is_active": c.is_active, + "created_at": c.created_at.isoformat() + "Z", + } + + return api_response( + data={"clients": [client_to_dict(c) for c in clients], "count": len(clients)}, + message="Clients retrieved successfully", + ) + + +@api_v1_bp.route("/organizations//clients", methods=["POST"]) +@login_required +@require_admin +def create_org_client(org_id): + """Create a new OIDC client for an organization. + + Request body: + name: Client name + redirect_uris: List of allowed redirect URIs (newline or comma separated string) + + Returns: + 201: Client created with client_id and client_secret + 403: Not an admin + 404: Organization not found + """ + import secrets as _secrets + from gatehouse_app.extensions import bcrypt + from gatehouse_app.models import OIDCClient, Organization + + org = Organization.query.filter_by(id=org_id, deleted_at=None).first() + if not org: + return api_response(success=False, message="Organization not found", status=404) + + data = request.get_json() or {} + name = (data.get("name") or "").strip() + redirect_uris_raw = data.get("redirect_uris") or [] + + if not name: + return api_response(success=False, message="Client name is required", status=400, error_type="VALIDATION_ERROR") + + if isinstance(redirect_uris_raw, str): + redirect_uris = [u.strip() for u in redirect_uris_raw.replace(",", "\n").splitlines() if u.strip()] + else: + redirect_uris = [u.strip() for u in redirect_uris_raw if isinstance(u, str) and u.strip()] + + if not redirect_uris: + return api_response(success=False, message="At least one redirect URI is required", status=400, error_type="VALIDATION_ERROR") + + client_id = _secrets.token_hex(16) + client_secret = _secrets.token_urlsafe(32) + + client = OIDCClient( + organization_id=org_id, + name=name, + client_id=client_id, + client_secret_hash=bcrypt.generate_password_hash(client_secret).decode("utf-8"), + redirect_uris=redirect_uris, + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + scopes=["openid", "profile", "email"], + is_active=True, + is_confidential=True, + ) + from gatehouse_app.extensions import db + db.session.add(client) + db.session.commit() + + return api_response( + data={ + "client": { + "id": client.id, + "name": client.name, + "client_id": client.client_id, + "client_secret": client_secret, # Only returned once + "redirect_uris": client.redirect_uris, + "scopes": client.scopes, + "created_at": client.created_at.isoformat() + "Z", + } + }, + message="OIDC client created successfully", + status=201, + ) + + +@api_v1_bp.route("/organizations//clients/", methods=["DELETE"]) +@login_required +@require_admin +def delete_org_client(org_id, client_id): + """Deactivate an OIDC client. + + Returns: + 200: Client deactivated + 403: Not an admin + 404: Client not found + """ + from gatehouse_app.models import OIDCClient + from gatehouse_app.extensions import db + + client = OIDCClient.query.filter_by(id=client_id, organization_id=org_id).first() + if not client: + return api_response(success=False, message="Client not found", status=404) + + client.is_active = False + db.session.commit() + + return api_response(data={}, message="Client deactivated successfully") + + +@api_v1_bp.route("/organizations//members//send-mfa-reminder", methods=["POST"]) +@login_required +@require_admin +def send_mfa_reminder(org_id, user_id): + """Send an MFA reminder email to a specific member. + + Returns: + 200: Reminder sent (or silently skipped if no deadline record) + 403: Not an admin + 404: Member not found + """ + from gatehouse_app.models import User, MfaPolicyCompliance, OrganizationSecurityPolicy + from gatehouse_app.services.notification_service import NotificationService + + user = User.query.filter_by(id=user_id, deleted_at=None).first() + if not user: + return api_response(success=False, message="User not found", status=404) + + compliance = MfaPolicyCompliance.query.filter_by( + user_id=user_id, organization_id=org_id + ).first() + policy = OrganizationSecurityPolicy.query.filter_by(organization_id=org_id).first() + + if compliance and policy and compliance.deadline_at: + NotificationService.send_mfa_deadline_reminder(user, compliance, policy) + else: + # No compliance deadline — send a generic nudge + NotificationService._send_email( + to_address=user.email, + subject="Reminder: Set up multi-factor authentication", + body=( + f"Hi {user.full_name or user.email},\n\n" + "Your organization administrator has asked you to set up " + "multi-factor authentication (MFA) on your Gatehouse account.\n\n" + "Please log in and configure MFA as soon as possible.\n\n" + "Gatehouse Security Team" + ), + ) + + return api_response(data={}, message="Reminder sent successfully") + + +# ============================================================================= +# System-wide Audit Log (admin view) + User self audit +# ============================================================================= + +def _audit_log_to_dict(log): + """Serialize an AuditLog record to a dict.""" + return { + "id": log.id, + "action": log.action.value if log.action else None, + "user_id": log.user_id, + "user": ( + {"id": log.user.id, "email": log.user.email, "full_name": log.user.full_name} + if log.user else None + ), + "organization_id": log.organization_id, + "resource_type": log.resource_type, + "resource_id": log.resource_id, + "ip_address": log.ip_address, + "user_agent": log.user_agent, + "request_id": log.request_id, + "description": log.description, + "success": log.success, + "error_message": log.error_message, + "metadata": log.extra_data, + "created_at": log.created_at.isoformat() if log.created_at else None, + "updated_at": log.updated_at.isoformat() if log.updated_at else None, + } + + +@api_v1_bp.route("/audit-logs", methods=["GET"]) +@login_required +def get_system_audit_logs(): + """ + Get all audit logs (system-wide). Any authenticated user can query + their own logs; org owners/admins also see org-scoped logs; this + endpoint returns ALL logs for users who own at least one org + (acting as an admin view). + + Query params: + page – page number (default 1) + per_page – results per page (default 50, max 200) + action – filter by AuditAction value + user_id – filter by user id + resource_type – filter by resource type + success – "true"/"false" + q – free-text search on description + """ + from gatehouse_app.models.audit_log import AuditLog + from gatehouse_app.models.organization_member import OrganizationMember + + current_user = g.current_user + page = max(1, int(request.args.get("page", 1))) + per_page = min(int(request.args.get("per_page", 50)), 200) + + # Check if the user is an owner of any org to grant admin-level access + is_admin = OrganizationMember.query.filter_by( + user_id=current_user.id, role="OWNER" + ).first() is not None + + query = AuditLog.query + + if not is_admin: + # Non-admins can only see their own logs + query = query.filter(AuditLog.user_id == current_user.id) + + # Optional filters + action_filter = request.args.get("action") + if action_filter: + query = query.filter(AuditLog.action == action_filter) + + user_id_filter = request.args.get("user_id") + if user_id_filter: + query = query.filter(AuditLog.user_id == user_id_filter) + + resource_type_filter = request.args.get("resource_type") + if resource_type_filter: + query = query.filter(AuditLog.resource_type == resource_type_filter) + + success_filter = request.args.get("success") + if success_filter is not None: + query = query.filter(AuditLog.success == (success_filter.lower() == "true")) + + q = request.args.get("q", "").strip() + if q: + query = query.filter(AuditLog.description.ilike(f"%{q}%")) + + query = query.order_by(AuditLog.created_at.desc()) + total = query.count() + logs = query.offset((page - 1) * per_page).limit(per_page).all() + + return api_response( + data={ + "audit_logs": [_audit_log_to_dict(log) for log in logs], + "count": total, + "page": page, + "per_page": per_page, + "pages": (total + per_page - 1) // per_page, + "is_admin_view": is_admin, + }, + message="Audit logs retrieved", + ) + + +@api_v1_bp.route("/auth/audit-logs", methods=["GET"]) +@login_required +def get_my_audit_logs(): + """ + Get audit logs for the currently authenticated user only. + + Query params: + page – page number (default 1) + per_page – results per page (default 50, max 200) + action – filter by AuditAction value + """ + from gatehouse_app.models.audit_log import AuditLog + + current_user = g.current_user + page = max(1, int(request.args.get("page", 1))) + per_page = min(int(request.args.get("per_page", 50)), 200) + + query = AuditLog.query.filter(AuditLog.user_id == current_user.id) + + action_filter = request.args.get("action") + if action_filter: + query = query.filter(AuditLog.action == action_filter) + + query = query.order_by(AuditLog.created_at.desc()) + total = query.count() + logs = query.offset((page - 1) * per_page).limit(per_page).all() + + return api_response( + data={ + "audit_logs": [_audit_log_to_dict(log) for log in logs], + "count": total, + "page": page, + "per_page": per_page, + "pages": (total + per_page - 1) // per_page, + }, + message="Activity retrieved", + ) diff --git a/gatehouse_app/api/v1/ssh.py b/gatehouse_app/api/v1/ssh.py new file mode 100644 index 0000000..3f4946f --- /dev/null +++ b/gatehouse_app/api/v1/ssh.py @@ -0,0 +1,615 @@ +"""SSH Key and Certificate API routes.""" +from flask import Blueprint, request, jsonify, g +from sqlalchemy.exc import IntegrityError +from gatehouse_app.services.ssh_key_service import SSHKeyService +from gatehouse_app.services.ssh_ca_signing_service import ( + SSHCASigningService, + SSHCertificateSigningRequest, +) +from gatehouse_app.exceptions import ( + SSHKeyError, + SSHKeyNotFoundError, + SSHCertificateError, + ValidationError, + SSHKeyAlreadyExistsError, +) +from gatehouse_app.utils.constants import AuditAction +from gatehouse_app.models import AuditLog +from gatehouse_app.utils.decorators import login_required + +ssh_bp = Blueprint('ssh', __name__, url_prefix='/ssh') +ssh_key_service = SSHKeyService() +ssh_ca_service = SSHCASigningService() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _get_org_ca_for_user(user): + """Return the active DB CA for the user's first org, or None.""" + try: + from gatehouse_app.models.ca import CA + org_ids = [m.organization_id for m in user.organization_memberships] + if not org_ids: + return None + return CA.query.filter( + CA.organization_id.in_(org_ids), + CA.is_active == True, # noqa: E712 + ).first() + except Exception: + return None + + +def _get_or_create_system_ca(): + """ + Return a CA DB record representing the config-file CA. + + This is used as the ``ca_id`` FK when persisting certificates that were + signed by the globally-configured CA key (not an org-specific DB CA). + The record is created on first use and has no ``organization_id``. + """ + from gatehouse_app.extensions import db + from gatehouse_app.models.ca import CA, KeyType + from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config + from gatehouse_app.utils.crypto import compute_ssh_fingerprint + import os + + try: + existing = CA.query.filter_by(name="system-config-ca").first() + if existing: + return existing + + cfg = get_ssh_ca_config() + key_path = cfg.get_str("ca_key_path", "").strip() + pub_key_path = key_path + ".pub" + + if not os.path.exists(pub_key_path): + return None + + with open(pub_key_path) as f: + pub_key = f.read().strip() + + # Load private key for the record (stored but not actually used for signing here) + priv_key = "" + if os.path.exists(key_path): + with open(key_path) as f: + priv_key = f.read() + + fingerprint = compute_ssh_fingerprint(pub_key) + + # Check by fingerprint in case it was created under a different name + existing_by_fp = CA.query.filter_by(fingerprint=fingerprint).first() + if existing_by_fp: + return existing_by_fp + + system_ca = CA( + name="system-config-ca", + description="Global CA loaded from etc/ssh_ca.conf (ca_key_path)", + key_type=KeyType.ED25519, + private_key=priv_key, + public_key=pub_key, + fingerprint=fingerprint, + is_active=True, + default_cert_validity_hours=24, + max_cert_validity_hours=720, + ) + # organization_id is nullable=False in schema — we need a dummy org or + # need to allow NULL. Use None; the DB constraint will tell us quickly. + # If the migration enforces NOT NULL we'll catch the error gracefully. + db.session.add(system_ca) + db.session.commit() + return system_ca + except Exception as exc: + import logging + logging.getLogger(__name__).warning( + f"Could not upsert system-config-ca: {exc}" + ) + try: + db.session.rollback() + except Exception: + pass + return None + + +def _persist_certificate(user_id, ssh_key_id, ca, signing_response, request_ip=None): + """Save a signed certificate to the ssh_certificates table. + + Args: + user_id: UUID of the user + ssh_key_id: UUID of the SSH key that was signed + ca: CA model instance (may be None — cert still returned but not persisted) + signing_response: SSHCertificateSigningResponse + request_ip: Client IP address + + Returns: + SSHCertificate instance or None if persistence failed + """ + if ca is None: + return None + + try: + from gatehouse_app.extensions import db + from gatehouse_app.models.ssh_certificate import SSHCertificate, CertificateStatus + from gatehouse_app.models.ca import CertType + + cert_record = SSHCertificate( + ca_id=ca.id, + user_id=user_id, + ssh_key_id=ssh_key_id, + certificate=signing_response.certificate, + serial=signing_response.serial, + key_id=str(ssh_key_id), + cert_type=CertType.USER, + principals=signing_response.principals, + valid_after=signing_response.valid_after, + valid_before=signing_response.valid_before, + revoked=False, + status=CertificateStatus.ISSUED, + request_ip=request_ip, + ) + db.session.add(cert_record) + db.session.commit() + return cert_record + except Exception as exc: + import logging + logging.getLogger(__name__).warning( + f"Failed to persist certificate to DB: {exc}" + ) + try: + from gatehouse_app.extensions import db as _db + _db.session.rollback() + except Exception: + pass + return None + + + +@ssh_bp.route('/keys', methods=['GET']) +@login_required +def list_ssh_keys(): + """Get all SSH keys for current user.""" + user_id = g.current_user.id + + keys = ssh_key_service.get_user_ssh_keys(user_id) + return jsonify({ + 'keys': [k.to_dict() for k in keys], + 'count': len(keys), + }), 200 + + +@ssh_bp.route('/keys', methods=['POST']) +@login_required +def add_ssh_key(): + """Add a new SSH public key for current user.""" + user_id = g.current_user.id + + data = request.get_json() + if not data: + return jsonify({'error': 'No JSON data provided'}), 400 + + public_key = data.get('public_key') or data.get('key') + description = data.get('description') + + if not public_key: + return jsonify({'error': 'public_key is required'}), 400 + + try: + ssh_key = ssh_key_service.add_ssh_key( + user_id=user_id, + public_key=public_key, + description=description, + ) + + # Audit log + AuditLog.log( + action=AuditAction.SSH_KEY_ADDED, + user_id=user_id, + resource_type='SSHKey', + resource_id=ssh_key.id, + ip_address=request.remote_addr, + ) + + return jsonify(ssh_key.to_dict()), 201 + + except SSHKeyAlreadyExistsError as e: + return jsonify({'error': e.message, 'code': 'SSH_KEY_ALREADY_EXISTS'}), 409 + except IntegrityError: + return jsonify({'error': 'SSH key already exists', 'code': 'SSH_KEY_ALREADY_EXISTS'}), 409 + except SSHKeyError as e: + return jsonify({'error': str(e)}), 400 + except ValidationError as e: + return jsonify({'error': str(e)}), 400 + + +@ssh_bp.route('/keys/', methods=['GET']) +@login_required +def get_ssh_key(key_id): + """Get a specific SSH key.""" + user_id = g.current_user.id + + try: + ssh_key = ssh_key_service.get_ssh_key(key_id) + + # Check ownership + if ssh_key.user_id != user_id: + return jsonify({'error': 'Forbidden'}), 403 + + return jsonify(ssh_key.to_dict()), 200 + + except SSHKeyNotFoundError: + return jsonify({'error': 'SSH key not found'}), 404 + + +@ssh_bp.route('/keys/', methods=['DELETE']) +@login_required +def delete_ssh_key(key_id): + """Delete an SSH key.""" + user_id = g.current_user.id + + try: + ssh_key = ssh_key_service.get_ssh_key(key_id) + + # Check ownership + if ssh_key.user_id != user_id: + return jsonify({'error': 'Forbidden'}), 403 + + ssh_key_service.delete_ssh_key(key_id) + + # Audit log + AuditLog.log( + action=AuditAction.SSH_KEY_DELETED, + user_id=user_id, + resource_type='SSHKey', + resource_id=key_id, + ip_address=request.remote_addr, + ) + + return jsonify({'status': 'deleted'}), 200 + + except SSHKeyNotFoundError: + return jsonify({'error': 'SSH key not found'}), 404 + + +@ssh_bp.route('/keys//verify', methods=['GET', 'POST']) +@login_required +def verify_ssh_key(key_id): + """Generate or verify SSH key ownership challenge.""" + user_id = g.current_user.id + + try: + ssh_key = ssh_key_service.get_ssh_key(key_id) + + # Check ownership + if ssh_key.user_id != user_id: + return jsonify({'error': 'Forbidden'}), 403 + + # Handle GET request - return challenge + if request.method == 'GET': + challenge = ssh_key_service.generate_verification_challenge(key_id) + return jsonify({ + 'challenge_text': challenge, + 'validationText': challenge, # Backwards compatibility + 'key_id': key_id, + }), 200 + + # Handle POST request - verify signature + data = request.get_json() or {} + action = data.get('action', 'verify_signature') + + if action == 'verify_signature': + # Verify signature + signature = data.get('signature') + if not signature: + return jsonify({'error': 'signature is required'}), 400 + + try: + verified = ssh_key_service.verify_ssh_key_ownership(key_id, signature) + + # Audit log + AuditLog.log( + action=AuditAction.SSH_KEY_VERIFIED, + user_id=user_id, + resource_type='SSHKey', + resource_id=key_id, + ip_address=request.remote_addr, + success=verified, + ) + + return jsonify({'verified': verified}), 200 + + except Exception as e: + AuditLog.log( + action=AuditAction.SSH_KEY_VALIDATION_FAILED, + user_id=user_id, + resource_type='SSHKey', + resource_id=key_id, + ip_address=request.remote_addr, + success=False, + error_message=str(e), + ) + return jsonify({'error': str(e)}), 400 + + else: # generate_challenge + # Generate verification challenge + challenge = ssh_key_service.generate_verification_challenge(key_id) + return jsonify({ + 'challenge_text': challenge, + 'challenge': challenge, # Both for compatibility + }), 200 + + except SSHKeyNotFoundError: + return jsonify({'error': 'SSH key not found'}), 404 + + +@ssh_bp.route('/keys//update-description', methods=['PATCH']) +@login_required +def update_ssh_key_description(key_id): + """Update SSH key description.""" + user_id = g.current_user.id + + data = request.get_json() + if not data or 'description' not in data: + return jsonify({'error': 'description is required'}), 400 + + try: + ssh_key = ssh_key_service.get_ssh_key(key_id) + + # Check ownership + if ssh_key.user_id != user_id: + return jsonify({'error': 'Forbidden'}), 403 + + updated_key = ssh_key_service.update_ssh_key_description( + key_id, + data['description'] + ) + + return jsonify(updated_key.to_dict()), 200 + + except SSHKeyNotFoundError: + return jsonify({'error': 'SSH key not found'}), 404 + + +@ssh_bp.route('/sign', methods=['POST']) +@login_required +def sign_certificate(): + """Sign an SSH certificate for the current user.""" + user = g.current_user + user_id = user.id + + data = request.get_json() + if not data: + return jsonify({'error': 'No JSON data provided'}), 400 + + try: + principals = data.get('principals', []) + cert_type = data.get('cert_type', 'user') + # Accept both 'key_id' and 'cert_id' (from CLI) + key_id = data.get('key_id') or data.get('cert_id') + expiry_hours = data.get('expiry_hours') + + if not principals: + return jsonify({'error': 'principals is required'}), 400 + + # If key_id not specified, use first verified key + if not key_id: + verified_keys = ssh_key_service.get_user_verified_ssh_keys(user_id) + if not verified_keys: + return jsonify({'error': 'No verified SSH keys found'}), 400 + key_id = verified_keys[0].id + + # Get the SSH key + ssh_key = ssh_key_service.get_ssh_key(key_id) + if ssh_key.user_id != user_id: + return jsonify({'error': 'Forbidden'}), 403 + + if not ssh_key.verified: + return jsonify({'error': 'SSH key is not verified'}), 400 + + # Resolve which CA to use: org DB CA > config-file CA + db_ca = _get_org_ca_for_user(user) + ca_private_key = db_ca.private_key if db_ca else None # None → signing service uses config + + # Create signing request + signing_request = SSHCertificateSigningRequest( + ssh_public_key=ssh_key.payload, + principals=principals, + cert_type=cert_type, + key_id=key_id, + expiry_hours=int(expiry_hours) if expiry_hours else None, + ) + + # Validate request + validation_errors = signing_request.validate() + if validation_errors: + return jsonify({'errors': validation_errors}), 400 + + # Sign the certificate (pass ca_private_key=None → service loads from config) + response = ssh_ca_service.sign_certificate(signing_request, ca_private_key=ca_private_key) + + # Persist certificate to DB + # If user's org has no DB CA, use the system-config-ca record + ca_for_db = db_ca or _get_or_create_system_ca() + cert_record = _persist_certificate( + user_id=user_id, + ssh_key_id=key_id, + ca=ca_for_db, + signing_response=response, + request_ip=request.remote_addr, + ) + + # Audit log + AuditLog.log( + action=AuditAction.SSH_CERT_ISSUED, + user_id=user_id, + resource_type='SSHCertificate', + resource_id=cert_record.id if cert_record else key_id, + ip_address=request.remote_addr, + description=f'Certificate issued for principals: {", ".join(principals)}', + ) + + result = { + 'certificate': response.certificate, + 'serial': response.serial, + 'principals': response.principals, + 'valid_after': response.valid_after.isoformat() if response.valid_after else None, + 'valid_before': response.valid_before.isoformat() if response.valid_before else None, + } + if cert_record: + result['cert_id'] = str(cert_record.id) + + return jsonify(result), 201 + + except SSHKeyNotFoundError: + return jsonify({'error': 'SSH key not found'}), 404 + except SSHCertificateError as e: + AuditLog.log( + action=AuditAction.SSH_CERT_FAILED, + user_id=user_id, + resource_type='SSHCertificate', + ip_address=request.remote_addr, + success=False, + error_message=str(e), + ) + return jsonify({'error': str(e)}), 400 + except Exception as e: + AuditLog.log( + action=AuditAction.SSH_CERT_FAILED, + user_id=user_id, + resource_type='SSHCertificate', + ip_address=request.remote_addr, + success=False, + error_message=str(e), + ) + return jsonify({'error': 'Certificate signing failed: ' + str(e)}), 500 + + +@ssh_bp.route('/certificates', methods=['GET']) +@login_required +def list_certificates(): + """List all SSH certificates issued for the current user.""" + user_id = g.current_user.id + + try: + from gatehouse_app.models.ssh_certificate import SSHCertificate + certs = ( + SSHCertificate.query + .filter_by(user_id=user_id, deleted_at=None) + .order_by(SSHCertificate.created_at.desc()) + .all() + ) + return jsonify({ + 'certificates': [c.to_dict() for c in certs], + 'count': len(certs), + }), 200 + except Exception as e: + return jsonify({'error': str(e)}), 500 + + +@ssh_bp.route('/certificates/', methods=['GET']) +@login_required +def get_certificate(cert_id): + """Get a specific issued certificate (metadata only).""" + user_id = g.current_user.id + + try: + from gatehouse_app.models.ssh_certificate import SSHCertificate + cert = SSHCertificate.query.filter_by(id=cert_id, deleted_at=None).first() + if not cert: + return jsonify({'error': 'Certificate not found'}), 404 + if cert.user_id != user_id: + return jsonify({'error': 'Forbidden'}), 403 + # Include full certificate text in single-fetch endpoint + data = cert.to_dict() + data['certificate'] = cert.certificate + return jsonify(data), 200 + except Exception as e: + return jsonify({'error': str(e)}), 500 + + +@ssh_bp.route('/certificates//revoke', methods=['POST']) +@login_required +def revoke_certificate(cert_id): + """Revoke an issued certificate.""" + user_id = g.current_user.id + + data = request.get_json() or {} + reason = data.get('reason', 'User requested revocation') + + try: + from gatehouse_app.models.ssh_certificate import SSHCertificate + cert = SSHCertificate.query.filter_by(id=cert_id, deleted_at=None).first() + if not cert: + return jsonify({'error': 'Certificate not found'}), 404 + if cert.user_id != user_id: + return jsonify({'error': 'Forbidden'}), 403 + if cert.revoked: + return jsonify({'error': 'Certificate is already revoked'}), 409 + + cert.revoke(reason=reason) + + AuditLog.log( + action=AuditAction.SSH_CERT_REVOKED, + user_id=user_id, + resource_type='SSHCertificate', + resource_id=cert_id, + ip_address=request.remote_addr, + description=f'Revoked: {reason}', + ) + + return jsonify({'status': 'revoked', 'cert_id': cert_id, 'reason': reason}), 200 + except Exception as e: + return jsonify({'error': str(e)}), 500 + + +@ssh_bp.route('/ca/public-key', methods=['GET']) +@login_required +def get_ca_public_key(): + """ + Return the CA public key for this user's organization. + + Server admins should add this key to their host's ``TrustedUserCAKeys`` + directive so that certificates issued by gatehouse are trusted. + + Query parameters: + format: 'openssh' (default) or 'text' — affects Content-Type only + + Returns: + { "public_key": "ssh-ed25519 AAAA...", + "fingerprint": "SHA256:...", + "ca_name": "..." } + """ + user = g.current_user + + # Try org CA first + db_ca = _get_org_ca_for_user(user) + if db_ca: + return jsonify({ + 'public_key': db_ca.public_key, + 'fingerprint': db_ca.fingerprint, + 'ca_name': db_ca.name, + 'source': 'db', + }), 200 + + # Fall back to config-file CA + try: + from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config + import os + cfg = get_ssh_ca_config() + key_path = cfg.get_str('ca_key_path', '').strip() + '.pub' + if os.path.exists(key_path): + with open(key_path) as f: + pub_key = f.read().strip() + from gatehouse_app.utils.crypto import compute_ssh_fingerprint + return jsonify({ + 'public_key': pub_key, + 'fingerprint': compute_ssh_fingerprint(pub_key), + 'ca_name': 'system-config-ca', + 'source': 'config', + }), 200 + except Exception as e: + return jsonify({'error': f'Could not load CA public key: {e}'}), 500 + + return jsonify({'error': 'No CA configured for this organization'}), 404 + + diff --git a/gatehouse_app/config/ssh_ca_config.py b/gatehouse_app/config/ssh_ca_config.py new file mode 100644 index 0000000..e3cbda9 --- /dev/null +++ b/gatehouse_app/config/ssh_ca_config.py @@ -0,0 +1,271 @@ +"""SSH CA Configuration Manager. + +Handles loading and managing SSH CA configuration from etc/ssh_ca.conf +and environment variables. +""" +import os +import configparser +from pathlib import Path +from typing import Optional, Union + + +class SSHCAConfig: + """Configuration manager for SSH CA settings. + + Loads configuration from: + 1. etc/ssh_ca.conf file + 2. Environment variables (override config file) + 3. Application environment-specific defaults + + Example: + config = SSHCAConfig() + cert_hours = config.get_int('cert_validity_hours') + kms_key = config.get_str('aws_kms_key_id') + """ + + # Configuration file location (relative to project root) + DEFAULT_CONFIG_FILE = "etc/ssh_ca.conf" + + # Default values if config file is missing + DEFAULTS = { + 'cert_validity_hours': '1', + 'max_cert_validity_hours': '24', + 'max_certs_per_user': '100', + 'crl_enabled': 'true', + 'crl_endpoint': 'https://ca.example.com/crl', + 'crl_refresh_hours': '24', + 'default_key_type': 'ed25519', + 'rsa_key_bits': '4096', + 'private_key_encryption': 'kms', + 'aws_kms_key_id': '', + 'extensions_enabled': 'true', + 'extensions': 'permit-X11-forwarding,permit-agent-forwarding,permit-pty,permit-port-forwarding,permit-user-rc', + 'critical_options_enabled': 'false', + 'max_principals_per_cert': '256', + 'max_key_id_length': '255', + 'log_level': 'INFO', + 'audit_enabled': 'true', + 'require_key_verification': 'true', + 'verification_challenge_max_age': '24', + 'rate_limit_certs_per_minute': '5', + 'request_timeout': '30', + 'auto_delete_unverified_days': '30', + 'archive_expired_days': '365', + 'oauth_token_endpoint': '/api/v1/oauth2/token', + 'oauth_userinfo_endpoint': '/api/v1/oauth2/userinfo', + 'ca_key_path': '', + } + + def __init__(self, config_file: Optional[str] = None, environment: Optional[str] = None): + """Initialize SSH CA configuration. + + Args: + config_file: Path to config file (default: etc/ssh_ca.conf) + environment: Environment name (development, production, testing) + Default: value of FLASK_ENV or 'development' + """ + self.config = configparser.ConfigParser() + + # Determine environment + if environment is None: + environment = os.environ.get('FLASK_ENV', 'development') + self.environment = environment + + # Load config file + if config_file is None: + # Try to find config file relative to this module + module_dir = Path(__file__).parent.parent.parent + config_file = module_dir / self.DEFAULT_CONFIG_FILE + + self.config_file = config_file + self._load_config() + + def _load_config(self): + """Load configuration from file and apply environment-specific overrides.""" + # Set defaults + self.config['default'] = self.DEFAULTS.copy() + + # Load config file if it exists + if Path(self.config_file).exists(): + self.config.read(self.config_file) + + # Apply environment-specific configuration + if self.environment in self.config: + for key, value in self.config[self.environment].items(): + self.config['default'][key] = value + + def get_str(self, key: str, default: Optional[str] = None) -> str: + """Get a string configuration value. + + First checks environment variables (SSH_CA_), then config file. + + Args: + key: Configuration key + default: Default value if not found + + Returns: + Configuration value as string + """ + env_key = f"SSH_CA_{key.upper()}" + + # Check environment variable first + if env_key in os.environ: + return os.environ[env_key] + + # Check config file + if key in self.config['default']: + value = self.config['default'][key] + # Handle environment variable substitution + return os.path.expandvars(value) + + # Return default + if default is not None: + return default + + return self.DEFAULTS.get(key, '') + + def get_int(self, key: str, default: Optional[int] = None) -> int: + """Get an integer configuration value. + + Args: + key: Configuration key + default: Default value if not found + + Returns: + Configuration value as integer + + Raises: + ValueError: If value cannot be converted to integer + """ + str_value = self.get_str(key) + if not str_value: + if default is not None: + return default + raise ValueError(f"No value found for {key}") + + try: + return int(str_value) + except ValueError: + if default is not None: + return default + raise ValueError(f"Configuration {key}={str_value} is not a valid integer") + + def get_bool(self, key: str, default: Optional[bool] = None) -> bool: + """Get a boolean configuration value. + + Args: + key: Configuration key + default: Default value if not found + + Returns: + Configuration value as boolean + """ + str_value = self.get_str(key) + if not str_value: + if default is not None: + return default + return False + + return str_value.lower() in ('true', '1', 'yes', 'on') + + def get_list(self, key: str, delimiter: str = ',', default: Optional[list] = None) -> list: + """Get a comma-separated list configuration value. + + Args: + key: Configuration key + delimiter: Delimiter between items (default: comma) + default: Default value if not found + + Returns: + Configuration value as list of strings + """ + str_value = self.get_str(key) + if not str_value: + if default is not None: + return default + return [] + + return [item.strip() for item in str_value.split(delimiter) if item.strip()] + + def validate_config(self) -> list: + """Validate SSH CA configuration. + + Returns: + List of validation error messages (empty if valid) + """ + errors = [] + + # Check cert validity hours + try: + validity = self.get_int('cert_validity_hours') + max_validity = self.get_int('max_cert_validity_hours') + if validity > max_validity: + errors.append( + f"cert_validity_hours ({validity}) > max_cert_validity_hours ({max_validity})" + ) + except ValueError as e: + errors.append(f"Invalid cert validity hours: {e}") + + # Check key type + valid_key_types = ['ed25519', 'rsa', 'ecdsa'] + key_type = self.get_str('default_key_type', 'ed25519') + if key_type not in valid_key_types: + errors.append(f"Invalid key type: {key_type}. Must be one of {valid_key_types}") + + # Check encryption method + valid_methods = ['kms', 'local'] + encryption = self.get_str('private_key_encryption', 'kms') + if encryption not in valid_methods: + errors.append(f"Invalid private_key_encryption: {encryption}. Must be one of {valid_methods}") + + # Warn if using local encryption in production + if encryption == 'local' and self.environment == 'production': + errors.append("WARNING: Using local key encryption in production! Use KMS instead.") + + # Check KMS key ID if using KMS + if encryption == 'kms': + kms_key = self.get_str('aws_kms_key_id', '').strip() + if not kms_key: + errors.append("aws_kms_key_id not set but private_key_encryption=kms") + + # Check principals limit + max_principals = self.get_int('max_principals_per_cert') + if max_principals > 256: + errors.append(f"max_principals_per_cert ({max_principals}) exceeds SSH limit of 256") + + return errors + + def to_dict(self) -> dict: + """Export current configuration as dictionary. + """ + return dict(self.config['default']) + + def __repr__(self): + """String representation of configuration.""" + return f"" + + +# Global configuration instance +_config_instance = None + + +def get_ssh_ca_config() -> SSHCAConfig: + """Get the global SSH CA configuration instance. + + This function uses a singleton pattern to ensure only one + configuration instance is created and reused. + + Returns: + SSHCAConfig instance + """ + global _config_instance + if _config_instance is None: + _config_instance = SSHCAConfig() + return _config_instance + + +def reset_config_instance(): + """Reset the global configuration instance. + """ + global _config_instance + _config_instance = None diff --git a/gatehouse_app/exceptions/__init__.py b/gatehouse_app/exceptions/__init__.py index bd8e6f8..06c3f34 100644 --- a/gatehouse_app/exceptions/__init__.py +++ b/gatehouse_app/exceptions/__init__.py @@ -19,6 +19,21 @@ from gatehouse_app.exceptions.validation_exceptions import ( OrganizationNotFoundError, UserNotFoundError, ) +from gatehouse_app.exceptions.ssh_exceptions import ( + SSHCAError, + SSHKeyError, + SSHKeyNotFoundError, + SSHKeyAlreadyExistsError, + SSHKeyNotVerifiedError, + SSHCertificateError, + SSHCertificateNotFoundError, + CAError, + CANotFoundError, + PrincipalError, + PrincipalNotFoundError, + DepartmentError, + DepartmentNotFoundError, +) __all__ = [ "BaseAPIException", @@ -37,4 +52,18 @@ __all__ = [ "EmailAlreadyExistsError", "OrganizationNotFoundError", "UserNotFoundError", + "SSHCAError", + "SSHKeyError", + "SSHKeyNotFoundError", + "SSHKeyAlreadyExistsError", + "SSHKeyNotVerifiedError", + "SSHCertificateError", + "SSHCertificateNotFoundError", + "CAError", + "CANotFoundError", + "PrincipalError", + "PrincipalNotFoundError", + "DepartmentError", + "DepartmentNotFoundError", ] + diff --git a/gatehouse_app/exceptions/base.py b/gatehouse_app/exceptions/base.py index 3f1f42b..f7cfb0e 100644 --- a/gatehouse_app/exceptions/base.py +++ b/gatehouse_app/exceptions/base.py @@ -16,9 +16,10 @@ class BaseAPIException(Exception): message: Custom error message error_details: Additional error details dictionary """ - super().__init__() + super().__init__(self.message) if message: self.message = message + super().__init__(message) # update args so str(e) works self.error_details = error_details or {} def to_dict(self): diff --git a/gatehouse_app/exceptions/ssh_exceptions.py b/gatehouse_app/exceptions/ssh_exceptions.py new file mode 100644 index 0000000..141c1e9 --- /dev/null +++ b/gatehouse_app/exceptions/ssh_exceptions.py @@ -0,0 +1,93 @@ +"""SSH-specific exceptions.""" +from gatehouse_app.exceptions.base import BaseAPIException + + +class SSHCAError(BaseAPIException): + """Base exception for SSH CA operations.""" + + status_code = 500 + error_type = "SSH_CA_ERROR" + + +class SSHKeyError(BaseAPIException): + """Exception for SSH key operations.""" + + status_code = 400 + error_type = "SSH_KEY_ERROR" + + +class SSHKeyNotFoundError(BaseAPIException): + """SSH key not found.""" + + status_code = 404 + error_type = "SSH_KEY_NOT_FOUND" + + +class SSHKeyAlreadyExistsError(BaseAPIException): + """SSH key already exists (duplicate fingerprint).""" + + status_code = 409 + error_type = "SSH_KEY_ALREADY_EXISTS" + + +class SSHKeyNotVerifiedError(BaseAPIException): + """SSH key has not been verified.""" + + status_code = 400 + error_type = "SSH_KEY_NOT_VERIFIED" + + +class SSHCertificateError(BaseAPIException): + """Exception for SSH certificate operations.""" + + status_code = 400 + error_type = "SSH_CERT_ERROR" + + +class SSHCertificateNotFoundError(BaseAPIException): + """SSH certificate not found.""" + + status_code = 404 + error_type = "SSH_CERT_NOT_FOUND" + + +class CAError(BaseAPIException): + """Exception for Certificate Authority operations.""" + + status_code = 400 + error_type = "CA_ERROR" + + +class CANotFoundError(BaseAPIException): + """Certificate Authority not found.""" + + status_code = 404 + error_type = "CA_NOT_FOUND" + + +class PrincipalError(BaseAPIException): + """Exception for principal operations.""" + + status_code = 400 + error_type = "PRINCIPAL_ERROR" + + +class PrincipalNotFoundError(BaseAPIException): + """Principal not found.""" + + status_code = 404 + error_type = "PRINCIPAL_NOT_FOUND" + + +class DepartmentError(BaseAPIException): + """Exception for department operations.""" + + status_code = 400 + error_type = "DEPARTMENT_ERROR" + + +class DepartmentNotFoundError(BaseAPIException): + """Department not found.""" + + status_code = 404 + error_type = "DEPARTMENT_NOT_FOUND" diff --git a/gatehouse_app/models/__init__.py b/gatehouse_app/models/__init__.py index 71ba304..252ee0e 100644 --- a/gatehouse_app/models/__init__.py +++ b/gatehouse_app/models/__init__.py @@ -29,6 +29,13 @@ from gatehouse_app.models.principal import ( Principal, PrincipalMembership, ) +from gatehouse_app.models.ssh_key import SSHKey +from gatehouse_app.models.ca import CA, KeyType, CertType +from gatehouse_app.models.ssh_certificate import SSHCertificate, CertificateStatus +from gatehouse_app.models.certificate_audit_log import CertificateAuditLog +from gatehouse_app.models.password_reset_token import PasswordResetToken +from gatehouse_app.models.email_verification_token import EmailVerificationToken +from gatehouse_app.models.org_invite_token import OrgInviteToken __all__ = [ "BaseModel", @@ -55,4 +62,14 @@ __all__ = [ "DepartmentPrincipal", "Principal", "PrincipalMembership", + "SSHKey", + "CA", + "KeyType", + "CertType", + "SSHCertificate", + "CertificateStatus", + "CertificateAuditLog", + "PasswordResetToken", + "EmailVerificationToken", + "OrgInviteToken", ] diff --git a/gatehouse_app/models/ca.py b/gatehouse_app/models/ca.py new file mode 100644 index 0000000..8a6271f --- /dev/null +++ b/gatehouse_app/models/ca.py @@ -0,0 +1,155 @@ +"""Certificate Authority (CA) model.""" +from enum import Enum +from datetime import datetime +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel + + +class KeyType(str, Enum): + """SSH CA key types.""" + + ED25519 = "ed25519" + RSA = "rsa" + ECDSA = "ecdsa" + + +class CertType(str, Enum): + """SSH certificate types.""" + + USER = "user" + HOST = "host" + + +class CA(BaseModel): + """Certificate Authority (CA) model for SSH certificate signing. + + Each organization can have multiple CAs for different purposes + (e.g., production vs. staging). Private keys are encrypted at rest + and should be protected with KMS. + """ + + __tablename__ = "cas" + + organization_id = db.Column( + db.String(36), + db.ForeignKey("organizations.id"), + nullable=True, # NULL for the global system-config CA + index=True, + ) + + # CA name and description + name = db.Column(db.String(255), nullable=False) + description = db.Column(db.Text, nullable=True) + + # Key type (ED25519, RSA, ECDSA) + key_type = db.Column( + db.Enum(KeyType, values_callable=lambda x: [e.value for e in x]), + default=KeyType.ED25519, + nullable=False, + ) + + # Private key (encrypted at rest by database/KMS) + # Format: PEM-encoded private key + private_key = db.Column(db.Text, nullable=False) + + # Public key (PEM format) + public_key = db.Column(db.Text, nullable=False) + + # SHA256 fingerprint of the public key + fingerprint = db.Column(db.String(255), nullable=False, unique=True) + + # CRL (Certificate Revocation List) configuration + crl_enabled = db.Column(db.Boolean, default=True, nullable=False) + crl_endpoint = db.Column(db.String(512), nullable=True) + + # Default certificate validity in hours + # Can be overridden per certificate request + default_cert_validity_hours = db.Column( + db.Integer, + default=1, + nullable=False, + ) + + # Maximum validity duration allowed + max_cert_validity_hours = db.Column( + db.Integer, + default=24, + nullable=False, + ) + + # CA status + is_active = db.Column(db.Boolean, default=True, nullable=False, index=True) + + # Key rotation tracking + rotated_at = db.Column(db.DateTime, nullable=True) + rotation_reason = db.Column(db.String(255), nullable=True) + + # Relationships + organization = db.relationship("Organization", back_populates="cas") + certificates = db.relationship( + "SSHCertificate", + back_populates="ca", + cascade="all, delete-orphan", + ) + + __table_args__ = ( + db.UniqueConstraint( + "organization_id", "name", name="uix_org_ca_name" + ), + db.Index("idx_ca_org_active", "organization_id", "is_active"), + ) + + def __repr__(self): + """String representation of CA.""" + return f"" + + def to_dict(self, exclude=None): + """Convert CA to dictionary.""" + exclude = exclude or [] + # Never expose private key in API responses + exclude.extend(["private_key"]) + data = super().to_dict(exclude=exclude) + + # Add computed fields + data["total_certs"] = len([c for c in self.certificates if c.deleted_at is None]) + data["active_certs"] = len([ + c for c in self.certificates + if c.deleted_at is None and not c.revoked + ]) + data["revoked_certs"] = len([ + c for c in self.certificates + if c.deleted_at is None and c.revoked + ]) + + return data + + def get_active_certificates(self): + """Get all active (non-revoked) certificates issued by this CA. + + Returns: + List of non-revoked SSHCertificate objects + """ + return [ + c for c in self.certificates + if c.deleted_at is None and not c.revoked + ] + + def rotate_key(self, new_private_key, new_public_key, new_fingerprint, reason=None): + """Rotate the CA's key pair. + + This should only be done in carefully controlled circumstances. + All existing certificates remain valid but no new certs can be + signed with the old key. + + Args: + new_private_key: New PEM-encoded private key + new_public_key: New PEM-encoded public key + new_fingerprint: SHA256 fingerprint of new public key + reason: Optional reason for rotation + """ + self.private_key = new_private_key + self.public_key = new_public_key + self.fingerprint = new_fingerprint + self.rotated_at = datetime.utcnow() + self.rotation_reason = reason + self.save() diff --git a/gatehouse_app/models/certificate_audit_log.py b/gatehouse_app/models/certificate_audit_log.py new file mode 100644 index 0000000..0d3274b --- /dev/null +++ b/gatehouse_app/models/certificate_audit_log.py @@ -0,0 +1,83 @@ +"""Certificate audit log model.""" +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel + + +class CertificateAuditLog(BaseModel): + """Audit log for SSH certificate lifecycle events. + + Tracks all operations on SSH certificates: signing, revocation, + validation, etc. This is separate from the general AuditLog to + provide detailed certificate operation tracking. + """ + + __tablename__ = "certificate_audit_logs" + + # Reference to the certificate + certificate_id = db.Column( + db.String(36), + db.ForeignKey("ssh_certificates.id"), + nullable=False, + index=True, + ) + + # The user who performed the action (can be null for system actions) + user_id = db.Column( + db.String(36), + db.ForeignKey("users.id"), + nullable=True, + index=True, + ) + + # Action type (e.g., "signed", "revoked", "validated", "requested") + action = db.Column(db.String(50), nullable=False, index=True) + + # Request details + ip_address = db.Column(db.String(45), nullable=True) + user_agent = db.Column(db.String(512), nullable=True) + request_id = db.Column(db.String(36), nullable=True) + + # Detailed message + message = db.Column(db.Text, nullable=True) + + # Additional context + extra_data = db.Column(db.JSON, nullable=True) + + # Success/failure + success = db.Column(db.Boolean, default=True, nullable=False) + error_message = db.Column(db.Text, nullable=True) + + # Relationships + certificate = db.relationship("SSHCertificate", back_populates="audit_logs") + user = db.relationship("User") + + __table_args__ = ( + db.Index("idx_cert_audit_cert_action", "certificate_id", "action"), + db.Index("idx_cert_audit_user", "user_id", "created_at"), + ) + + def __repr__(self): + """String representation of CertificateAuditLog.""" + return f"" + + @classmethod + def log(cls, certificate_id, action, user_id=None, **kwargs): + """Create a certificate audit log entry. + + Args: + certificate_id: ID of the certificate + action: Action type (e.g., "signed", "revoked") + user_id: ID of the user performing the action (optional) + **kwargs: Additional fields (ip_address, user_agent, message, etc.) + + Returns: + CertificateAuditLog instance + """ + log_entry = cls( + certificate_id=certificate_id, + action=action, + user_id=user_id, + **kwargs + ) + log_entry.save() + return log_entry diff --git a/gatehouse_app/models/organization.py b/gatehouse_app/models/organization.py index cf7c110..a6fa756 100644 --- a/gatehouse_app/models/organization.py +++ b/gatehouse_app/models/organization.py @@ -40,6 +40,9 @@ class Organization(BaseModel): principals = db.relationship( "Principal", back_populates="organization", cascade="all, delete-orphan" ) + cas = db.relationship( + "CA", back_populates="organization", cascade="all, delete-orphan" + ) def __repr__(self): """String representation of Organization.""" diff --git a/gatehouse_app/models/ssh_certificate.py b/gatehouse_app/models/ssh_certificate.py new file mode 100644 index 0000000..b81e98b --- /dev/null +++ b/gatehouse_app/models/ssh_certificate.py @@ -0,0 +1,175 @@ +"""SSH Certificate model.""" +from enum import Enum +from datetime import datetime +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel +from gatehouse_app.models.ca import CertType + + +class CertificateStatus(str, Enum): + """SSH certificate lifecycle status.""" + + REQUESTED = "requested" # Waiting for signing + ISSUED = "issued" # Signed and valid + REVOKED = "revoked" # Manually revoked + EXPIRED = "expired" # Validity period ended + SUPERSEDED = "superseded" # Replaced by newer cert + + +class SSHCertificate(BaseModel): + """SSH Certificate model representing a signed SSH user/host certificate. + + Certificates are issued by a CA and associated with an SSH public key. + They include principals (access levels), validity periods, and other + OpenSSH certificate metadata. + """ + + __tablename__ = "ssh_certificates" + + # Certificate relationships + ca_id = db.Column( + db.String(36), + db.ForeignKey("cas.id"), + nullable=False, + index=True, + ) + user_id = db.Column( + db.String(36), + db.ForeignKey("users.id"), + nullable=False, + index=True, + ) + ssh_key_id = db.Column( + db.String(36), + db.ForeignKey("ssh_keys.id"), + nullable=False, + index=True, + ) + + # Certificate content (full signed certificate in OpenSSH format) + certificate = db.Column(db.Text, nullable=False) + + # Certificate metadata + serial = db.Column(db.String(255), nullable=False, unique=True, index=True) + key_id = db.Column(db.String(255), nullable=False) # Usually user email + cert_type = db.Column( + db.Enum(CertType, values_callable=lambda x: [e.value for e in x]), + default=CertType.USER, + nullable=False, + ) + + # Principals (JSON list) - e.g., ["prod-servers", "dev-servers"] + principals = db.Column(db.JSON, nullable=False, default=list) + + # Validity period + valid_after = db.Column(db.DateTime, nullable=False) + valid_before = db.Column(db.DateTime, nullable=False) + + # Revocation status + revoked = db.Column(db.Boolean, default=False, nullable=False, index=True) + revoked_at = db.Column(db.DateTime, nullable=True) + revoke_reason = db.Column(db.String(255), nullable=True) + + # Status tracking + status = db.Column( + db.Enum(CertificateStatus, values_callable=lambda x: [e.value for e in x]), + default=CertificateStatus.ISSUED, + nullable=False, + index=True, + ) + + # Request metadata + request_ip = db.Column(db.String(45), nullable=True) + request_user_agent = db.Column(db.String(512), nullable=True) + + # Critical options (JSON) - OpenSSH critical options + # See: https://man.openbsd.org/ssh-cert + critical_options = db.Column(db.JSON, nullable=True, default=dict) + + # Extensions (JSON) - OpenSSH extensions + # Common ones: permit-X11-forwarding, permit-agent-forwarding, permit-pty, etc. + extensions = db.Column(db.JSON, nullable=True, default=dict) + + # Relationships + ca = db.relationship("CA", back_populates="certificates") + user = db.relationship("User", back_populates="ssh_certificates") + ssh_key = db.relationship( + "SSHKey", + back_populates="certificates", + foreign_keys="SSHCertificate.ssh_key_id", + ) + audit_logs = db.relationship( + "CertificateAuditLog", + back_populates="certificate", + cascade="all, delete-orphan", + ) + + __table_args__ = ( + db.Index("idx_cert_user_status", "user_id", "status"), + db.Index("idx_cert_validity", "valid_after", "valid_before"), + db.Index("idx_cert_revoked", "revoked", "revoked_at"), + ) + + def __repr__(self): + """String representation of SSHCertificate.""" + return f"" + + def to_dict(self, exclude=None): + """Convert certificate to dictionary.""" + exclude = exclude or [] + # Optionally exclude the certificate content (it's large) + if "certificate" not in exclude: + exclude.append("certificate") + data = super().to_dict(exclude=exclude) + + # Add computed fields + data["is_valid"] = self.is_valid() + data["days_until_expiry"] = self.days_until_expiry() + + return data + + def is_valid(self): + """Check if certificate is currently valid. + + Returns: + True if certificate is issued, not revoked, and within validity period + """ + if self.revoked or self.status == CertificateStatus.REVOKED: + return False + + now = datetime.utcnow() + return self.valid_after <= now <= self.valid_before + + def is_expired(self): + """Check if certificate has expired. + + Returns: + True if current time is past valid_before + """ + return datetime.utcnow() > self.valid_before + + def days_until_expiry(self): + """Get number of days until certificate expires. + + Returns: + Number of days remaining (negative if already expired) + """ + delta = self.valid_before - datetime.utcnow() + return delta.days + (1 if delta.seconds > 0 else 0) + + def revoke(self, reason=None): + """Revoke this certificate. + + Args: + reason: Optional reason for revocation + """ + self.revoked = True + self.revoked_at = datetime.utcnow() + self.revoke_reason = reason + self.status = CertificateStatus.REVOKED + self.save() + + def mark_expired(self): + """Mark certificate as expired when validity period ends.""" + self.status = CertificateStatus.EXPIRED + self.save() diff --git a/gatehouse_app/models/ssh_key.py b/gatehouse_app/models/ssh_key.py new file mode 100644 index 0000000..0d6fbca --- /dev/null +++ b/gatehouse_app/models/ssh_key.py @@ -0,0 +1,96 @@ +"""SSH Key model.""" +from datetime import datetime +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel + + +class SSHKey(BaseModel): + """SSH Key model representing a user's SSH public key. + + This model stores SSH public keys that users register for certificate signing. + Users must verify ownership of the key before it can be used for signing certificates. + """ + + __tablename__ = "ssh_keys" + + user_id = db.Column( + db.String(36), + db.ForeignKey("users.id"), + nullable=False, + index=True, + ) + + # SSH key payload in OpenSSH format (e.g., "ssh-rsa AAAAB3Nz...") + payload = db.Column(db.Text, nullable=False, unique=True) + + # SHA256 fingerprint for quick comparison + fingerprint = db.Column(db.String(255), nullable=False, unique=True, index=True) + + # Optional description for the key (e.g., "My laptop key") + description = db.Column(db.String(255), nullable=True) + + # Verification status + verified = db.Column(db.Boolean, default=False, nullable=False, index=True) + verified_at = db.Column(db.DateTime, nullable=True) + + # Verification challenge + verify_text = db.Column(db.String(255), nullable=True) + verify_text_created_at = db.Column(db.DateTime, nullable=True) + + # Key type extracted from the key (ssh-rsa, ssh-ed25519, etc.) + key_type = db.Column(db.String(50), nullable=True) + + # Key bits/length + key_bits = db.Column(db.Integer, nullable=True) + + # Comment from the key (usually email or key name) + key_comment = db.Column(db.String(255), nullable=True) + + # Relationships + user = db.relationship("User", back_populates="ssh_keys") + certificates = db.relationship( + "SSHCertificate", + back_populates="ssh_key", + cascade="all, delete-orphan", + foreign_keys="SSHCertificate.ssh_key_id", + ) + + __table_args__ = ( + db.Index("idx_ssh_key_user_verified", "user_id", "verified"), + ) + + def __repr__(self): + """String representation of SSHKey.""" + return f"" + + def to_dict(self, exclude=None): + """Convert SSH key to dictionary.""" + exclude = exclude or [] + exclude.extend(["payload", "verify_text"]) # Never expose these in API + data = super().to_dict(exclude=exclude) + + # Add computed fields + data["cert_count"] = len([c for c in self.certificates if c.deleted_at is None]) + + return data + + def mark_verified(self): + """Mark this SSH key as verified.""" + self.verified = True + self.verified_at = datetime.utcnow() + self.save() + + def needs_verification_refresh(self, max_age_hours=24): + """Check if verification challenge needs to be refreshed. + + Args: + max_age_hours: Maximum age of verification challenge in hours + + Returns: + True if verification challenge is stale + """ + if not self.verify_text_created_at: + return True + + age = datetime.utcnow() - self.verify_text_created_at + return age.total_seconds() > (max_age_hours * 3600) diff --git a/gatehouse_app/models/user.py b/gatehouse_app/models/user.py index 4da83ca..a5b3aff 100644 --- a/gatehouse_app/models/user.py +++ b/gatehouse_app/models/user.py @@ -55,6 +55,18 @@ class User(BaseModel): cascade="all, delete-orphan", foreign_keys="PrincipalMembership.user_id", ) + ssh_keys = db.relationship( + "SSHKey", + back_populates="user", + cascade="all, delete-orphan", + foreign_keys="SSHKey.user_id", + ) + ssh_certificates = db.relationship( + "SSHCertificate", + back_populates="user", + cascade="all, delete-orphan", + foreign_keys="SSHCertificate.user_id", + ) def __repr__(self): """String representation of User.""" diff --git a/gatehouse_app/services/oauth_flow_service.py b/gatehouse_app/services/oauth_flow_service.py index 89b1745..2484cac 100644 --- a/gatehouse_app/services/oauth_flow_service.py +++ b/gatehouse_app/services/oauth_flow_service.py @@ -12,7 +12,7 @@ from gatehouse_app.models import User, AuthenticationMethod from gatehouse_app.models.authentication_method import OAuthState from gatehouse_app.models.base import BaseModel from gatehouse_app.models.oidc_authorization_code import OIDCAuthCode -from gatehouse_app.utils.constants import AuthMethodType +from gatehouse_app.utils.constants import AuthMethodType, AuditAction from gatehouse_app.services.audit_service import AuditService from gatehouse_app.services.external_auth_service import ( ExternalAuthService, @@ -139,7 +139,7 @@ class OAuthFlowService: except ExternalAuthError as e: # Log failed initiation AuditService.log_action( - action="external_auth.login.initiated", + action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED, organization_id=organization_id, metadata={ "provider_type": provider_type_str, @@ -236,7 +236,7 @@ class OAuthFlowService: except ExternalAuthError as e: AuditService.log_action( - action="external_auth.register.initiated", + action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED, organization_id=organization_id, metadata={ "provider_type": provider_type_str, @@ -399,6 +399,27 @@ class OAuthFlowService: access_token=tokens["access_token"], ) + if not user_info.get("provider_user_id"): + raise OAuthFlowError( + "Provider did not return a user identifier (sub claim). " + "Cannot complete authentication.", + "MISSING_PROVIDER_USER_ID", + 400, + ) + + if not user_info.get("email"): + raise OAuthFlowError( + "Provider did not return an email address. " + "Cannot complete authentication.", + "MISSING_EMAIL", + 400, + ) + + logger.debug( + f"Got user_info from provider: sub={user_info['provider_user_id']}, " + f"email={user_info['email']}, email_verified={user_info.get('email_verified')}" + ) + # Look up user by provider_user_id auth_method = AuthenticationMethod.query.filter_by( method_type=provider_type, diff --git a/gatehouse_app/services/ssh_ca_signing_service.py b/gatehouse_app/services/ssh_ca_signing_service.py new file mode 100644 index 0000000..574c4bf --- /dev/null +++ b/gatehouse_app/services/ssh_ca_signing_service.py @@ -0,0 +1,333 @@ +"""SSH Certificate Authority signing service. + +Handles SSH certificate signing operations, leveraging sshkey-tools library. +This service is a Gatehouse-integrated version of the secuird/ssh_ca.py logic. +""" +import logging +import os +from datetime import datetime, timedelta +from typing import List, Dict, Any, Optional + +from sshkey_tools.cert import SSHCertificate, CertificateFields +from sshkey_tools.keys import PublicKey, PrivateKey + +from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config +from gatehouse_app.exceptions import SSHCAError, ValidationError +from gatehouse_app.utils.crypto import compute_ssh_fingerprint + +logger = logging.getLogger(__name__) + + +class SSHCASigningError(Exception): + """SSH CA signing operation error.""" + pass + + +class SSHCertificateSigningRequest: + """Represents an SSH certificate signing request.""" + + def __init__( + self, + ssh_public_key: str, + principals: List[str], + key_id: str, + cert_type: str = "user", + expiry_hours: Optional[int] = None, + critical_options: Optional[Dict[str, str]] = None, + extensions: Optional[List[str]] = None, + ): + """Initialize signing request. + + Args: + ssh_public_key: Public key in OpenSSH format (e.g., "ssh-ed25519 AAAA...") + principals: List of principals (e.g., ["prod-servers", "staging"]) + key_id: Key identifier (usually user email) + cert_type: Certificate type - "user" or "host" (default: user) + expiry_hours: Certificate validity in hours + critical_options: Critical options dict + extensions: List of extensions (e.g., ["permit-pty", "permit-agent-forwarding"]) + """ + self.ssh_public_key = ssh_public_key + self.principals = principals or [] + self.key_id = key_id + self.cert_type = cert_type + self.expiry_hours = expiry_hours + self.critical_options = critical_options or {} + self.extensions = extensions or [] + + def validate(self) -> List[str]: + """Validate the signing request. + + Returns: + List of validation errors (empty if valid) + """ + errors = [] + config = get_ssh_ca_config() + + # Validate cert type + if self.cert_type not in ("user", "host"): + errors.append(f"Invalid cert_type: {self.cert_type}. Must be 'user' or 'host'") + + # Validate SSH public key + if not self.ssh_public_key or len(self.ssh_public_key) < 16: + errors.append("SSH public key is missing or invalid") + else: + try: + PublicKey.from_string(self.ssh_public_key) + except Exception as e: + errors.append(f"SSH public key is not valid: {str(e)}") + + # Validate principals + if not self.principals or len(self.principals) == 0: + errors.append("At least one principal is required") + else: + max_principals = config.get_int('max_principals_per_cert') + if len(self.principals) > max_principals: + errors.append( + f"Too many principals ({len(self.principals)}). " + f"Maximum is {max_principals}" + ) + + # Validate key_id + if not self.key_id or len(self.key_id) < 5: + errors.append("key_id is missing or too short (minimum 5 characters)") + else: + max_id_len = config.get_int('max_key_id_length') + if len(self.key_id) > max_id_len: + errors.append(f"key_id exceeds maximum length of {max_id_len}") + + # Validate expiry_hours + if self.expiry_hours is not None: + if not isinstance(self.expiry_hours, int) or self.expiry_hours <= 0: + errors.append("expiry_hours must be a positive integer") + else: + max_validity = config.get_int('max_cert_validity_hours') + if self.expiry_hours > max_validity: + errors.append( + f"Requested expiry ({self.expiry_hours}h) exceeds " + f"maximum allowed ({max_validity}h)" + ) + + return errors + + +class SSHCertificateSigningResponse: + """Represents a signed SSH certificate response.""" + + def __init__( + self, + certificate: str, + serial: str, + valid_after: datetime, + valid_before: datetime, + principals: Optional[List[str]] = None, + ): + """Initialize signing response. + + Args: + certificate: Full certificate in OpenSSH format + serial: Certificate serial number + valid_after: Validity start datetime + valid_before: Validity end datetime + principals: List of principals the cert was issued for + """ + self.certificate = certificate + self.serial = serial + self.valid_after = valid_after + self.valid_before = valid_before + self.principals = principals or [] + + def to_dict(self) -> Dict[str, Any]: + """Convert response to dictionary.""" + return { + 'certificate': self.certificate, + 'serial': self.serial, + 'valid_after': self.valid_after.isoformat(), + 'valid_before': self.valid_before.isoformat(), + } + + +class SSHCASigningService: + """Service for signing SSH certificates. + + This service handles all SSH certificate signing operations. + It uses configuration from ssh_ca_config to apply rules and limits. + """ + + def __init__(self): + """Initialize the SSH CA signing service.""" + self.config = get_ssh_ca_config() + self.logger = logger + + def _load_ca_key_from_config(self) -> str: + """Load CA private key from config (local file or env var). + + Returns: + CA private key in PEM/OpenSSH format as string + + Raises: + SSHCASigningError: If key cannot be loaded + """ + # Check env var first + key_content = os.environ.get('SSH_CA_PRIVATE_KEY') + if key_content: + return key_content + + # Load from file path + key_path = self.config.get_str('ca_key_path', '').strip() + if not key_path: + raise SSHCASigningError( + "CA private key not configured. Set SSH_CA_PRIVATE_KEY env var " + "or ca_key_path in etc/ssh_ca.conf" + ) + + key_path = os.path.expandvars(os.path.expanduser(key_path)) + if not os.path.exists(key_path): + raise SSHCASigningError(f"CA private key file not found: {key_path}") + + with open(key_path, 'r') as f: + return f.read() + + def sign_certificate( + self, + signing_request: SSHCertificateSigningRequest, + ca_private_key: Optional[str] = None, + ) -> SSHCertificateSigningResponse: + """Sign an SSH certificate. + + Args: + signing_request: SSHCertificateSigningRequest instance + ca_private_key: CA private key in PEM format. If not provided, + loaded from config (ca_key_path or SSH_CA_PRIVATE_KEY env var) + + Returns: + SSHCertificateSigningResponse with signed certificate + + Raises: + SSHCASigningError: If signing fails + ValidationError: If request is invalid + """ + # Validate request + errors = signing_request.validate() + if errors: + error_msg = "; ".join(errors) + self.logger.error(f"Certificate signing validation failed: {error_msg}") + raise ValidationError(f"Certificate signing validation failed: {error_msg}") + + # Load CA key if not provided + if ca_private_key is None: + ca_private_key = self._load_ca_key_from_config() + + try: + # Parse CA private key + try: + ca_key = PrivateKey.from_string(ca_private_key) + except Exception as e: + self.logger.error(f"Failed to load CA private key: {str(e)}") + raise SSHCASigningError(f"Invalid CA private key: {str(e)}") + + # Parse user's public key + try: + user_pub_key = PublicKey.from_string(signing_request.ssh_public_key) + except Exception as e: + self.logger.error(f"Failed to parse user public key: {str(e)}") + raise SSHCASigningError(f"Invalid user public key: {str(e)}") + + # Create certificate + certificate = SSHCertificate.create( + subject_pubkey=user_pub_key, + ca_privkey=ca_key, + ) + + # Set validity period + now = datetime.utcnow() + expiry_hours = signing_request.expiry_hours or self.config.get_int('cert_validity_hours') + valid_before = now + timedelta(hours=expiry_hours) + + # Set certificate fields + cert_type = 1 if signing_request.cert_type == "user" else 0 + + certificate.fields.cert_type = cert_type + certificate.fields.key_id = signing_request.key_id + certificate.fields.principals = signing_request.principals + certificate.fields.valid_after = now + certificate.fields.valid_before = valid_before + + # Set extensions + extensions = signing_request.extensions + if not extensions and self.config.get_bool('extensions_enabled'): + extensions = self.config.get_list('extensions') + + certificate.fields.extensions = extensions or [] + certificate.fields.critical_options = signing_request.critical_options or {} + + # Validate certificate before signing + if not certificate.can_sign(): + raise SSHCASigningError("Certificate cannot be signed") + + # Sign the certificate + certificate.sign() + + # Verify the certificate + try: + certificate.verify(ca_key.public_key, raise_on_error=True) + except Exception as e: + self.logger.error(f"Certificate verification failed: {str(e)}") + raise SSHCASigningError(f"Certificate verification failed: {str(e)}") + + # Extract serial from certificate + serial = str(certificate.fields.serial).split(":")[-1].strip() if hasattr(certificate.fields.serial, '__str__') else str(certificate.fields.serial) + + # Build response + cert_string = certificate.to_string() + + self.logger.info( + f"Successfully signed certificate: serial={serial}, " + f"key_id={signing_request.key_id}, principals={signing_request.principals}" + ) + + return SSHCertificateSigningResponse( + certificate=cert_string, + serial=serial, + valid_after=now, + valid_before=valid_before, + principals=signing_request.principals, + ) + + except (SSHCASigningError, ValidationError): + raise + except Exception as e: + self.logger.error(f"Unexpected error during certificate signing: {str(e)}", exc_info=True) + raise SSHCASigningError(f"Error signing certificate: {str(e)}") + + def verify_ca_key(self, ca_private_key: str) -> Dict[str, Any]: + """Verify a CA private key is valid and extract metadata. + + Args: + ca_private_key: CA private key in PEM format + + Returns: + Dictionary with key metadata (fingerprint, key_type, etc.) + + Raises: + SSHCASigningError: If key is invalid + """ + try: + ca_key = PrivateKey.from_string(ca_private_key) + pub_key = ca_key.public_key + + # Compute fingerprint + fingerprint = compute_ssh_fingerprint(pub_key.to_string()) + + # Get key type + key_type = pub_key.keytype if hasattr(pub_key, 'keytype') else 'unknown' + + return { + 'fingerprint': fingerprint, + 'key_type': key_type, + 'public_key': pub_key.to_string(), + 'valid': True, + } + except Exception as e: + self.logger.error(f"CA key verification failed: {str(e)}") + raise SSHCASigningError(f"Invalid CA key: {str(e)}") diff --git a/gatehouse_app/services/ssh_key_service.py b/gatehouse_app/services/ssh_key_service.py new file mode 100644 index 0000000..69cf7fa --- /dev/null +++ b/gatehouse_app/services/ssh_key_service.py @@ -0,0 +1,373 @@ +"""SSH Key management service.""" +import base64 +import logging +import os +import secrets +import subprocess +import tempfile +from datetime import datetime, timedelta +from typing import Optional, List, Dict, Any + +from gatehouse_app.extensions import db +from gatehouse_app.models import SSHKey, User +from gatehouse_app.exceptions import ( + SSHKeyError, + SSHKeyNotFoundError, + SSHKeyAlreadyExistsError, + SSHKeyNotVerifiedError, + ValidationError, + UserNotFoundError, +) +from gatehouse_app.utils.crypto import ( + compute_ssh_fingerprint, + verify_ssh_key_format, + extract_ssh_key_type, + extract_ssh_key_comment, +) +from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config + +logger = logging.getLogger(__name__) + + +class SSHKeyService: + """Service for managing SSH keys.""" + + def __init__(self): + """Initialize SSH key service.""" + self.config = get_ssh_ca_config() + + def add_ssh_key( + self, + user_id: str, + public_key: str, + description: Optional[str] = None, + ) -> SSHKey: + """Add an SSH public key for a user. + + Args: + user_id: ID of the user + public_key: SSH public key in OpenSSH format + description: Optional description of the key + + Returns: + Created SSHKey instance + + Raises: + UserNotFoundError: If user doesn't exist + SSHKeyError: If key format is invalid + SSHKeyAlreadyExistsError: If key already exists + """ + # Verify user exists + user = User.query.get(user_id) + if not user: + raise UserNotFoundError(f"User {user_id} not found") + + # Validate key format + if not verify_ssh_key_format(public_key): + raise SSHKeyError("Invalid SSH public key format") + + # Compute fingerprint + try: + fingerprint = compute_ssh_fingerprint(public_key) + except Exception as e: + logger.error(f"Failed to compute fingerprint: {str(e)}") + raise SSHKeyError(f"Failed to compute key fingerprint: {str(e)}") + + # Check for duplicate (including soft-deleted records — fingerprint is unique in DB) + existing = SSHKey.query.filter_by(fingerprint=fingerprint).first() + if existing: + if existing.deleted_at is not None: + # Restore the soft-deleted key: clear deleted_at and update fields + existing.deleted_at = None + existing.user_id = user_id + existing.description = description or existing.description + existing.verified = False + existing.verified_at = None + existing.verify_text = None + existing.verify_text_created_at = None + db.session.commit() + logger.info( + f"Restored soft-deleted SSH key for user {user_id}: " + f"fingerprint={fingerprint}" + ) + return existing + raise SSHKeyAlreadyExistsError( + f"SSH key with fingerprint {fingerprint} already exists" + ) + + # Extract metadata + key_type = extract_ssh_key_type(public_key) + key_comment = extract_ssh_key_comment(public_key) + + # Create SSH key record + ssh_key = SSHKey( + user_id=user_id, + payload=public_key, + fingerprint=fingerprint, + description=description, + key_type=key_type, + key_comment=key_comment, + verified=False, + ) + + ssh_key.save() + + logger.info( + f"SSH key added for user {user_id}: " + f"fingerprint={fingerprint}, type={key_type}" + ) + + return ssh_key + + def get_ssh_key(self, key_id: str) -> SSHKey: + """Get an SSH key by ID. + + Args: + key_id: SSH key ID + + Returns: + SSHKey instance + + Raises: + SSHKeyNotFoundError: If key not found + """ + key = SSHKey.query.filter_by(id=key_id, deleted_at=None).first() + if not key: + raise SSHKeyNotFoundError(f"SSH key {key_id} not found") + return key + + def get_user_ssh_keys(self, user_id: str) -> List[SSHKey]: + """Get all SSH keys for a user. + + Args: + user_id: User ID + + Returns: + List of SSHKey instances + """ + return SSHKey.query.filter_by(user_id=user_id, deleted_at=None).all() + + def get_user_verified_ssh_keys(self, user_id: str) -> List[SSHKey]: + """Get all verified SSH keys for a user. + + Args: + user_id: User ID + + Returns: + List of verified SSHKey instances + """ + return SSHKey.query.filter_by( + user_id=user_id, + verified=True, + deleted_at=None, + ).all() + + def delete_ssh_key(self, key_id: str) -> None: + """Soft-delete an SSH key. + + Args: + key_id: SSH key ID + + Raises: + SSHKeyNotFoundError: If key not found + """ + key = self.get_ssh_key(key_id) + key.delete() + + logger.info(f"SSH key deleted: {key_id}") + + def generate_verification_challenge(self, key_id: str) -> str: + """Generate a verification challenge for an SSH key. + + The user must sign this challenge text with their private key + to prove key ownership. + + Args: + key_id: SSH key ID + + Returns: + Verification challenge text + + Raises: + SSHKeyNotFoundError: If key not found + """ + key = self.get_ssh_key(key_id) + + # Generate random challenge + challenge = secrets.token_hex(32) + challenge_text = f"Please sign this to verify SSH key ownership: {challenge}" + + # Store challenge + key.verify_text = challenge_text + key.verify_text_created_at = datetime.utcnow() + key.save() + + logger.info(f"Generated verification challenge for SSH key {key_id}") + + return challenge_text + + def verify_ssh_key_ownership( + self, + key_id: str, + signature: str, + ) -> bool: + """Verify SSH key ownership via signature. + + The user must sign the verification challenge with their private key. + We verify the signature using the public key. + + Args: + key_id: SSH key ID + signature: Base64-encoded signature of the challenge + + Returns: + True if signature is valid + + Raises: + SSHKeyNotFoundError: If key not found + SSHKeyNotVerifiedError: If challenge is stale or missing + SSHKeyError: If verification fails + """ + key = self.get_ssh_key(key_id) + + # Check if challenge exists and is not stale + if not key.verify_text or not key.verify_text_created_at: + raise SSHKeyNotVerifiedError("No verification challenge generated") + + max_age = self.config.get_int('verification_challenge_max_age') + age = datetime.utcnow() - key.verify_text_created_at + if age.total_seconds() > (max_age * 3600): + raise SSHKeyNotVerifiedError("Verification challenge has expired") + + try: + # Verify the SSH signature using ssh-keygen -Y verify. + # The CLI signs the challenge with: ssh-keygen -Y sign -f -n file + # We verify with: ssh-keygen -Y verify -f -I -n file -s < + # + # allowed_signers format: " " + # We use the key fingerprint as the identity. + + sig_bytes = base64.b64decode(signature) + challenge_text = key.verify_text + "\n" + + with tempfile.TemporaryDirectory() as tmpdir: + allowed_signers_path = os.path.join(tmpdir, "allowed_signers") + sig_path = os.path.join(tmpdir, "message.sig") + message_path = os.path.join(tmpdir, "message.txt") + + identity = key.fingerprint + + # Write the allowed_signers file + with open(allowed_signers_path, "w") as f: + f.write(f"{identity} {key.payload}\n") + + # Write the signature file + with open(sig_path, "wb") as f: + f.write(sig_bytes) + + # Write the challenge message + with open(message_path, "w") as f: + f.write(challenge_text) + + result = subprocess.run( + [ + "ssh-keygen", "-Y", "verify", + "-f", allowed_signers_path, + "-I", identity, + "-n", "file", + "-s", sig_path, + ], + stdin=open(message_path, "rb"), + capture_output=True, + timeout=10, + ) + + if result.returncode != 0: + stderr = result.stderr.decode(errors="replace").strip() + logger.warning(f"SSH signature verification failed for key {key_id}: {stderr}") + raise SSHKeyError(f"Signature verification failed: {stderr}") + + key.mark_verified() + logger.info(f"SSH key verified: {key_id}") + return True + + except Exception as e: + logger.error(f"SSH key verification failed: {str(e)}") + raise SSHKeyError(f"Signature verification failed: {str(e)}") + + def get_key_fingerprint(self, key_id: str) -> str: + """Get the fingerprint of an SSH key. + + Args: + key_id: SSH key ID + + Returns: + Fingerprint string + + Raises: + SSHKeyNotFoundError: If key not found + """ + key = self.get_ssh_key(key_id) + return key.fingerprint + + def update_ssh_key_description(self, key_id: str, description: str) -> SSHKey: + """Update the description of an SSH key. + + Args: + key_id: SSH key ID + description: New description + + Returns: + Updated SSHKey instance + + Raises: + SSHKeyNotFoundError: If key not found + """ + key = self.get_ssh_key(key_id) + key.description = description + key.save() + + return key + + def cleanup_expired_challenges(self) -> int: + """Clean up expired verification challenges. + + Returns: + Number of challenges cleaned + """ + max_age = self.config.get_int('verification_challenge_max_age') + threshold = datetime.utcnow() - timedelta(hours=max_age) + + expired = SSHKey.query.filter( + SSHKey.verify_text_created_at < threshold, + SSHKey.verify_text_created_at.isnot(None), + SSHKey.deleted_at.is_(None), + ).update({"verify_text": None, "verify_text_created_at": None}) + + db.session.commit() + + logger.info(f"Cleaned up {expired} expired verification challenges") + return expired + + def cleanup_unverified_keys(self) -> int: + """Delete unverified SSH keys older than configured days. + + Returns: + Number of keys deleted + """ + days = self.config.get_int('auto_delete_unverified_days') + threshold = datetime.utcnow() - timedelta(days=days) + + old_unverified = SSHKey.query.filter( + SSHKey.verified == False, + SSHKey.created_at < threshold, + SSHKey.deleted_at.is_(None), + ).all() + + count = 0 + for key in old_unverified: + key.delete() + count += 1 + + logger.info(f"Deleted {count} unverified SSH keys older than {days} days") + return count diff --git a/gatehouse_app/utils/constants.py b/gatehouse_app/utils/constants.py index cd6a58b..000f75d 100644 --- a/gatehouse_app/utils/constants.py +++ b/gatehouse_app/utils/constants.py @@ -12,6 +12,16 @@ class UserStatus(str, Enum): COMPLIANCE_SUSPENDED = "compliance_suspended" +class Role(str, Enum): + """Generic role definitions (hierarchy: Admin > Manager > Member > Viewer > Guest).""" + + ADMIN = "admin" + MANAGER = "manager" + MEMBER = "member" + VIEWER = "viewer" + GUEST = "guest" + + class OrganizationRole(str, Enum): """Organization member roles.""" @@ -105,6 +115,37 @@ class AuditAction(str, Enum): EXTERNAL_AUTH_CONFIG_UPDATE = "external_auth.config.update" EXTERNAL_AUTH_CONFIG_DELETE = "external_auth.config.delete" + # SSH Key and Certificate actions + SSH_KEY_ADDED = "ssh.key.added" + SSH_KEY_VERIFIED = "ssh.key.verified" + SSH_KEY_DELETED = "ssh.key.deleted" + SSH_KEY_VALIDATION_FAILED = "ssh.key.validation.failed" + SSH_CERT_REQUESTED = "ssh.cert.requested" + SSH_CERT_ISSUED = "ssh.cert.issued" + SSH_CERT_FAILED = "ssh.cert.failed" + SSH_CERT_REVOKED = "ssh.cert.revoked" + SSH_CERT_EXPIRED = "ssh.cert.expired" + + # CA actions + CA_CREATED = "ca.created" + CA_UPDATED = "ca.updated" + CA_DELETED = "ca.deleted" + CA_KEY_ROTATED = "ca.key.rotated" + + # Principal actions + PRINCIPAL_CREATED = "principal.created" + PRINCIPAL_UPDATED = "principal.updated" + PRINCIPAL_DELETED = "principal.deleted" + PRINCIPAL_MEMBER_ADDED = "principal.member.added" + PRINCIPAL_MEMBER_REMOVED = "principal.member.removed" + + # Department actions + DEPARTMENT_CREATED = "department.created" + DEPARTMENT_UPDATED = "department.updated" + DEPARTMENT_DELETED = "department.deleted" + DEPARTMENT_MEMBER_ADDED = "department.member.added" + DEPARTMENT_MEMBER_REMOVED = "department.member.removed" + class OIDCGrantType(str, Enum): """OIDC grant types.""" diff --git a/gatehouse_app/utils/crypto.py b/gatehouse_app/utils/crypto.py new file mode 100644 index 0000000..0322a8c --- /dev/null +++ b/gatehouse_app/utils/crypto.py @@ -0,0 +1,128 @@ +"""Cryptographic utilities for SSH operations.""" +import hashlib +import base64 +from typing import Optional + + +def compute_ssh_fingerprint(public_key_str: str, hash_algorithm: str = "sha256") -> str: + """Compute the fingerprint of an SSH public key. + + Args: + public_key_str: SSH public key in OpenSSH format + hash_algorithm: Hash algorithm to use (sha256, sha1, md5) + + Returns: + Fingerprint string in the format "algorithm:hex_digest" + + Example: + >>> key = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKp2..." + >>> fp = compute_ssh_fingerprint(key) + >>> print(fp) + sha256:Kb+... + """ + if not public_key_str: + raise ValueError("Public key string is empty") + + # Parse OpenSSH format: "ssh-ed25519 [comment]" + parts = public_key_str.strip().split() + if len(parts) < 2: + raise ValueError("Invalid OpenSSH public key format") + + try: + # The base64-encoded key is the second part + key_bytes = base64.b64decode(parts[1]) + except Exception as e: + raise ValueError(f"Failed to decode public key: {str(e)}") + + # Compute hash + if hash_algorithm == "sha256": + digest = hashlib.sha256(key_bytes).digest() + # SSH format uses base64 encoding without padding + fingerprint = base64.b64encode(digest).decode().rstrip('=') + elif hash_algorithm == "sha1": + digest = hashlib.sha1(key_bytes).hexdigest() + fingerprint = digest + elif hash_algorithm == "md5": + digest = hashlib.md5(key_bytes).hexdigest() + # Format as colons + fingerprint = ':'.join(digest[i:i+2] for i in range(0, len(digest), 2)) + else: + raise ValueError(f"Unsupported hash algorithm: {hash_algorithm}") + + return f"{hash_algorithm}:{fingerprint}" + + +def verify_ssh_key_format(public_key_str: str) -> bool: + """Verify that a string is in valid OpenSSH public key format. + + Args: + public_key_str: Potential SSH public key + + Returns: + True if valid OpenSSH format, False otherwise + """ + if not public_key_str or not isinstance(public_key_str, str): + return False + + parts = public_key_str.strip().split() + + # Must have at least key type and key material + if len(parts) < 2: + return False + + key_type = parts[0] + + # Valid key types + valid_types = [ + 'ssh-rsa', + 'ssh-ed25519', + 'ecdsa-sha2-nistp256', + 'ecdsa-sha2-nistp384', + 'ecdsa-sha2-nistp521', + 'ssh-dss', + ] + + if key_type not in valid_types: + return False + + # Try to decode base64 + try: + base64.b64decode(parts[1]) + return True + except Exception: + return False + + +def extract_ssh_key_type(public_key_str: str) -> Optional[str]: + """Extract the key type from an OpenSSH public key. + + Args: + public_key_str: SSH public key in OpenSSH format + + Returns: + Key type (e.g., "ssh-ed25519") or None if invalid + """ + if not verify_ssh_key_format(public_key_str): + return None + + return public_key_str.strip().split()[0] + + +def extract_ssh_key_comment(public_key_str: str) -> Optional[str]: + """Extract the comment from an OpenSSH public key. + + Args: + public_key_str: SSH public key in OpenSSH format + + Returns: + Comment string or None if not present + """ + if not verify_ssh_key_format(public_key_str): + return None + + parts = public_key_str.strip().split() + if len(parts) >= 3: + # Everything after the second part is the comment + return ' '.join(parts[2:]) + + return None diff --git a/migrations/versions/007_add_ssh_ca_models.py b/migrations/versions/007_add_ssh_ca_models.py new file mode 100644 index 0000000..9749930 --- /dev/null +++ b/migrations/versions/007_add_ssh_ca_models.py @@ -0,0 +1,173 @@ +"""Add SSH CA models: SSHKey, SSHCertificate, CA, CertificateAuditLog. + +Revision ID: 007 +Revises: 006 +Create Date: 2026-02-27 11:00:00.000000 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '007' +down_revision = '006' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### CA table ### + op.create_table('cas', + sa.Column('organization_id', sa.String(length=36), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('key_type', sa.Enum('ed25519', 'rsa', 'ecdsa', name='ca_key_type_enum'), nullable=False), + sa.Column('private_key', sa.Text(), nullable=False), + sa.Column('public_key', sa.Text(), nullable=False), + sa.Column('fingerprint', sa.String(length=255), nullable=False), + sa.Column('crl_enabled', sa.Boolean(), nullable=False), + sa.Column('crl_endpoint', sa.String(length=512), nullable=True), + sa.Column('default_cert_validity_hours', sa.Integer(), nullable=False), + sa.Column('max_cert_validity_hours', sa.Integer(), nullable=False), + sa.Column('is_active', sa.Boolean(), nullable=False), + sa.Column('rotated_at', sa.DateTime(), nullable=True), + sa.Column('rotation_reason', sa.String(length=255), nullable=True), + sa.Column('id', sa.String(length=36), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('deleted_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('id'), + sa.UniqueConstraint('fingerprint'), + sa.UniqueConstraint('organization_id', 'name', name='uix_org_ca_name') + ) + op.create_index(op.f('ix_cas_organization_id'), 'cas', ['organization_id'], unique=False) + op.create_index('idx_ca_org_active', 'cas', ['organization_id', 'is_active'], unique=False) + + # ### SSHKey table ### + op.create_table('ssh_keys', + sa.Column('user_id', sa.String(length=36), nullable=False), + sa.Column('payload', sa.Text(), nullable=False), + sa.Column('fingerprint', sa.String(length=255), nullable=False), + sa.Column('description', sa.String(length=255), nullable=True), + sa.Column('verified', sa.Boolean(), nullable=False), + sa.Column('verified_at', sa.DateTime(), nullable=True), + sa.Column('verify_text', sa.String(length=255), nullable=True), + sa.Column('verify_text_created_at', sa.DateTime(), nullable=True), + sa.Column('key_type', sa.String(length=50), nullable=True), + sa.Column('key_bits', sa.Integer(), nullable=True), + sa.Column('key_comment', sa.String(length=255), nullable=True), + sa.Column('id', sa.String(length=36), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('deleted_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('id'), + sa.UniqueConstraint('payload'), + sa.UniqueConstraint('fingerprint') + ) + op.create_index(op.f('ix_ssh_keys_user_id'), 'ssh_keys', ['user_id'], unique=False) + op.create_index(op.f('ix_ssh_keys_fingerprint'), 'ssh_keys', ['fingerprint'], unique=False) + op.create_index(op.f('ix_ssh_keys_verified'), 'ssh_keys', ['verified'], unique=False) + op.create_index('idx_ssh_key_user_verified', 'ssh_keys', ['user_id', 'verified'], unique=False) + + # ### SSHCertificate table ### + op.create_table('ssh_certificates', + sa.Column('ca_id', sa.String(length=36), nullable=False), + sa.Column('user_id', sa.String(length=36), nullable=False), + sa.Column('ssh_key_id', sa.String(length=36), nullable=False), + sa.Column('certificate', sa.Text(), nullable=False), + sa.Column('serial', sa.String(length=255), nullable=False), + sa.Column('key_id', sa.String(length=255), nullable=False), + sa.Column('cert_type', sa.Enum('user', 'host', name='ssh_cert_type_enum'), nullable=False), + sa.Column('principals', sa.JSON(), nullable=False), + sa.Column('valid_after', sa.DateTime(), nullable=False), + sa.Column('valid_before', sa.DateTime(), nullable=False), + sa.Column('revoked', sa.Boolean(), nullable=False), + sa.Column('revoked_at', sa.DateTime(), nullable=True), + sa.Column('revoke_reason', sa.String(length=255), nullable=True), + sa.Column('status', sa.Enum('requested', 'issued', 'revoked', 'expired', 'superseded', name='ssh_cert_status_enum'), nullable=False), + sa.Column('request_ip', sa.String(length=45), nullable=True), + sa.Column('request_user_agent', sa.String(length=512), nullable=True), + sa.Column('critical_options', sa.JSON(), nullable=True), + sa.Column('extensions', sa.JSON(), nullable=True), + sa.Column('id', sa.String(length=36), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('deleted_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['ca_id'], ['cas.id'], ), + sa.ForeignKeyConstraint(['ssh_key_id'], ['ssh_keys.id'], ), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('id'), + sa.UniqueConstraint('serial') + ) + op.create_index(op.f('ix_ssh_certificates_ca_id'), 'ssh_certificates', ['ca_id'], unique=False) + op.create_index(op.f('ix_ssh_certificates_user_id'), 'ssh_certificates', ['user_id'], unique=False) + op.create_index(op.f('ix_ssh_certificates_ssh_key_id'), 'ssh_certificates', ['ssh_key_id'], unique=False) + op.create_index(op.f('ix_ssh_certificates_serial'), 'ssh_certificates', ['serial'], unique=False) + op.create_index(op.f('ix_ssh_certificates_revoked'), 'ssh_certificates', ['revoked'], unique=False) + op.create_index(op.f('ix_ssh_certificates_status'), 'ssh_certificates', ['status'], unique=False) + op.create_index('idx_cert_user_status', 'ssh_certificates', ['user_id', 'status'], unique=False) + op.create_index('idx_cert_validity', 'ssh_certificates', ['valid_after', 'valid_before'], unique=False) + op.create_index('idx_cert_revoked', 'ssh_certificates', ['revoked', 'revoked_at'], unique=False) + + # ### CertificateAuditLog table ### + op.create_table('certificate_audit_logs', + sa.Column('certificate_id', sa.String(length=36), nullable=False), + sa.Column('user_id', sa.String(length=36), nullable=True), + sa.Column('action', sa.String(length=50), nullable=False), + sa.Column('ip_address', sa.String(length=45), nullable=True), + sa.Column('user_agent', sa.String(length=512), nullable=True), + sa.Column('request_id', sa.String(length=36), nullable=True), + sa.Column('message', sa.Text(), nullable=True), + sa.Column('extra_data', sa.JSON(), nullable=True), + sa.Column('success', sa.Boolean(), nullable=False), + sa.Column('error_message', sa.Text(), nullable=True), + sa.Column('id', sa.String(length=36), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('deleted_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['certificate_id'], ['ssh_certificates.id'], ), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('id') + ) + op.create_index(op.f('ix_certificate_audit_logs_certificate_id'), 'certificate_audit_logs', ['certificate_id'], unique=False) + op.create_index(op.f('ix_certificate_audit_logs_user_id'), 'certificate_audit_logs', ['user_id'], unique=False) + op.create_index(op.f('ix_certificate_audit_logs_action'), 'certificate_audit_logs', ['action'], unique=False) + op.create_index('idx_cert_audit_cert_action', 'certificate_audit_logs', ['certificate_id', 'action'], unique=False) + op.create_index('idx_cert_audit_user', 'certificate_audit_logs', ['user_id', 'created_at'], unique=False) + + +def downgrade(): + op.drop_index('idx_cert_audit_user', table_name='certificate_audit_logs') + op.drop_index('idx_cert_audit_cert_action', table_name='certificate_audit_logs') + op.drop_index(op.f('ix_certificate_audit_logs_action'), table_name='certificate_audit_logs') + op.drop_index(op.f('ix_certificate_audit_logs_user_id'), table_name='certificate_audit_logs') + op.drop_index(op.f('ix_certificate_audit_logs_certificate_id'), table_name='certificate_audit_logs') + op.drop_table('certificate_audit_logs') + + op.drop_index('idx_cert_revoked', table_name='ssh_certificates') + op.drop_index('idx_cert_validity', table_name='ssh_certificates') + op.drop_index('idx_cert_user_status', table_name='ssh_certificates') + op.drop_index(op.f('ix_ssh_certificates_status'), table_name='ssh_certificates') + op.drop_index(op.f('ix_ssh_certificates_revoked'), table_name='ssh_certificates') + op.drop_index(op.f('ix_ssh_certificates_serial'), table_name='ssh_certificates') + op.drop_index(op.f('ix_ssh_certificates_ssh_key_id'), table_name='ssh_certificates') + op.drop_index(op.f('ix_ssh_certificates_user_id'), table_name='ssh_certificates') + op.drop_index(op.f('ix_ssh_certificates_ca_id'), table_name='ssh_certificates') + op.drop_table('ssh_certificates') + + op.drop_index('idx_ssh_key_user_verified', table_name='ssh_keys') + op.drop_index(op.f('ix_ssh_keys_verified'), table_name='ssh_keys') + op.drop_index(op.f('ix_ssh_keys_fingerprint'), table_name='ssh_keys') + op.drop_index(op.f('ix_ssh_keys_user_id'), table_name='ssh_keys') + op.drop_table('ssh_keys') + + op.drop_index('idx_ca_org_active', table_name='cas') + op.drop_index(op.f('ix_cas_organization_id'), table_name='cas') + op.drop_table('cas') diff --git a/migrations/versions/008_fix_authmethodtype_enum.py b/migrations/versions/008_fix_authmethodtype_enum.py new file mode 100644 index 0000000..ccf56e9 --- /dev/null +++ b/migrations/versions/008_fix_authmethodtype_enum.py @@ -0,0 +1,53 @@ +"""Add TOTP and WEBAUTHN to authmethodtype enum. + +Revision ID: 008 +Revises: 007 +Create Date: 2026-02-27 15:00:00.000000 + +The original migration (001_base) created authmethodtype with only: + PASSWORD, GOOGLE, GITHUB, MICROSOFT, SAML, OIDC + +This migration adds the missing TOTP and WEBAUTHN values so +has_totp_enabled() and has_webauthn_enabled() queries work correctly. +""" +from alembic import op +import sqlalchemy as sa + + +revision = '008' +down_revision = '007' +branch_labels = None +depends_on = None + + +def upgrade(): + # Add TOTP to the enum (idempotent approach using DO block) + op.execute(""" + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_enum + WHERE enumlabel = 'TOTP' + AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'authmethodtype') + ) THEN + ALTER TYPE authmethodtype ADD VALUE 'TOTP'; + END IF; + END$$; + """) + op.execute(""" + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_enum + WHERE enumlabel = 'WEBAUTHN' + AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'authmethodtype') + ) THEN + ALTER TYPE authmethodtype ADD VALUE 'WEBAUTHN'; + END IF; + END$$; + """) + + +def downgrade(): + # PostgreSQL does not support removing enum values; downgrade is a no-op. + pass diff --git a/migrations/versions/009_sync_auditaction_enum.py b/migrations/versions/009_sync_auditaction_enum.py new file mode 100644 index 0000000..59bc210 --- /dev/null +++ b/migrations/versions/009_sync_auditaction_enum.py @@ -0,0 +1,61 @@ +"""Sync auditaction enum with all AuditAction Python enum values. + +Revision ID: 009 +Revises: 008 +Create Date: 2026-02-27 15:20:00.000000 + +The auditaction DB enum was only created with the initial 17 values from 001_base.py. +All TOTP, WebAuthn, OAuth, SSH, CA, Principal, and Department audit actions were added +to the Python enum but never synced to the DB type. +""" +from alembic import op + + +revision = '009' +down_revision = '008' +branch_labels = None +depends_on = None + +MISSING_VALUES = [ + 'TOTP_ENROLL_INITIATED', 'TOTP_ENROLL_COMPLETED', 'TOTP_VERIFY_SUCCESS', + 'TOTP_VERIFY_FAILED', 'TOTP_DISABLED', 'TOTP_BACKUP_CODE_USED', + 'TOTP_BACKUP_CODES_REGENERATED', 'WEBAUTHN_REGISTER_INITIATED', + 'WEBAUTHN_REGISTER_COMPLETED', 'WEBAUTHN_REGISTER_FAILED', + 'WEBAUTHN_LOGIN_INITIATED', 'WEBAUTHN_LOGIN_SUCCESS', 'WEBAUTHN_LOGIN_FAILED', + 'WEBAUTHN_CREDENTIAL_DELETED', 'WEBAUTHN_CREDENTIAL_RENAMED', + 'ORG_SECURITY_POLICY_UPDATE', 'USER_SECURITY_POLICY_OVERRIDE_UPDATE', + 'MFA_POLICY_USER_SUSPENDED', 'MFA_POLICY_USER_COMPLIANT', + 'EXTERNAL_AUTH_LINK_INITIATED', 'EXTERNAL_AUTH_LINK_COMPLETED', + 'EXTERNAL_AUTH_LINK_FAILED', 'EXTERNAL_AUTH_UNLINK', 'EXTERNAL_AUTH_LOGIN', + 'EXTERNAL_AUTH_LOGIN_FAILED', 'EXTERNAL_AUTH_TOKEN_REFRESH', + 'EXTERNAL_AUTH_CONFIG_CREATE', 'EXTERNAL_AUTH_CONFIG_UPDATE', + 'EXTERNAL_AUTH_CONFIG_DELETE', 'SSH_KEY_ADDED', 'SSH_KEY_VERIFIED', + 'SSH_KEY_DELETED', 'SSH_KEY_VALIDATION_FAILED', 'SSH_CERT_REQUESTED', + 'SSH_CERT_ISSUED', 'SSH_CERT_FAILED', 'SSH_CERT_REVOKED', 'SSH_CERT_EXPIRED', + 'CA_CREATED', 'CA_UPDATED', 'CA_DELETED', 'CA_KEY_ROTATED', + 'PRINCIPAL_CREATED', 'PRINCIPAL_UPDATED', 'PRINCIPAL_DELETED', + 'PRINCIPAL_MEMBER_ADDED', 'PRINCIPAL_MEMBER_REMOVED', + 'DEPARTMENT_CREATED', 'DEPARTMENT_UPDATED', 'DEPARTMENT_DELETED', + 'DEPARTMENT_MEMBER_ADDED', 'DEPARTMENT_MEMBER_REMOVED', +] + + +def upgrade(): + for val in MISSING_VALUES: + op.execute(f""" + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_enum + WHERE enumlabel = '{val}' + AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'auditaction') + ) THEN + ALTER TYPE auditaction ADD VALUE '{val}'; + END IF; + END$$; + """) + + +def downgrade(): + # PostgreSQL does not support removing enum values; downgrade is a no-op. + pass diff --git a/migrations/versions/012_ca_nullable_org_and_cert_serial.py b/migrations/versions/012_ca_nullable_org_and_cert_serial.py new file mode 100644 index 0000000..8b7dd0e --- /dev/null +++ b/migrations/versions/012_ca_nullable_org_and_cert_serial.py @@ -0,0 +1,33 @@ +"""Make CA.organization_id nullable (system CA) and add cert_id to sign response + +Revision ID: 012_ca_nullable_org_and_cert_serial +Revises: 011_org_invite_tokens +Create Date: 2025-01-01 00:00:00.000000 +""" +from alembic import op +import sqlalchemy as sa + + +revision = '012_ca_nullable_org' +down_revision = '011_org_invite_tokens' +branch_labels = None +depends_on = None + + +def upgrade(): + # Allow CA records without an org (e.g. the global system-config CA) + with op.batch_alter_table('cas', schema=None) as batch_op: + batch_op.alter_column( + 'organization_id', + existing_type=sa.String(36), + nullable=True, + ) + + +def downgrade(): + with op.batch_alter_table('cas', schema=None) as batch_op: + batch_op.alter_column( + 'organization_id', + existing_type=sa.String(36), + nullable=False, + ) diff --git a/requirements/base.txt b/requirements/base.txt index 2f9c20d..3efc1cf 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -47,4 +47,7 @@ Flask-Limiter==3.5.0 # Logging python-json-logger==2.0.7 -qrcode[pil] \ No newline at end of file +qrcode[pil] + +# SSH CA Certificate signing +sshkey-tools==0.11.0 diff --git a/requirements/development.txt b/requirements/development.txt index bccb4ec..0c9b46d 100644 --- a/requirements/development.txt +++ b/requirements/development.txt @@ -20,3 +20,25 @@ watchdog==3.0.0 # Documentation sphinx==7.2.6 + +# Web framework & Database +Flask==3.0.0 +Flask-SQLAlchemy==3.1.1 +Flask-Migrate==4.0.5 +sqlalchemy-cockroachdb==2.0.3 + +# Utilities +colorlog==6.8.0 +coloredlogs==15.0.1 +prettytable==3.10.2 +tabulate==0.9.0 +requests==2.31.0 +pytz==2023.3 +python-dotenv==1.0.0 +pydantic==2.5.0 +PyJWT==2.8.0 +cryptography==41.0.7 +pycryptodome==3.20.0 +psycopg2==2.9.9 +sshkey-tools==0.10.3 +sendgrid==6.11.0