diff --git a/.env.example b/.env.example
index d87fa74..9e97462 100644
--- a/.env.example
+++ b/.env.example
@@ -40,5 +40,9 @@ LOG_TO_STDOUT=True
RATELIMIT_ENABLED=True
RATELIMIT_STORAGE_URL=redis://localhost:6379/1
-# Testing
-TESTING=False
+# SSH CA
+# Path to CA private key file (alternative to SSH_CA_PRIVATE_KEY env var)
+SSH_CA_KEY_PATH=/path/to/ca-users
+# Or set the key content directly (takes priority over SSH_CA_KEY_PATH):
+# SSH_CA_PRIVATE_KEY=
+
diff --git a/client/gatehouse-cli.py b/client/gatehouse-cli.py
new file mode 100755
index 0000000..2060c5a
--- /dev/null
+++ b/client/gatehouse-cli.py
@@ -0,0 +1,528 @@
+#!/usr/bin/python3
+import base64
+import os
+import sys
+import webbrowser
+import requests
+import argparse
+import jwt
+import json
+import datetime
+import pytz
+from http.server import BaseHTTPRequestHandler, HTTPServer
+from urllib.parse import urlparse, parse_qsl
+from dotenv import load_dotenv
+from sshkey_tools.cert import SSHCertificate
+import logging
+import coloredlogs
+import subprocess
+
+# Load environment variables from the .env file
+load_dotenv()
+
+# Get the API_URL from the environment variables
+SIGN_URL = os.getenv("SIGN_URL", "http://localhost:5000")
+LISTENER_HOST_NAME = "127.0.0.1"
+LISTENER_SERVER_PORT = 8250
+CACHE_FILE = os.path.expanduser('~/.gatehouse/token_cache.json')
+os.makedirs(os.path.dirname(CACHE_FILE), exist_ok=True)
+CERT_FILE_PATH = "/tmp/ssh-cert"
+CHALLENGE_FILE_PATH = "/tmp/challenge.txt"
+CHALLENGE_SIG_FILE_PATH = "/tmp/challenge.txt.sig"
+
+# Configure logger
+logger = logging.getLogger(__name__)
+coloredlogs.install(level='DEBUG', logger=logger, fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+
+token = ""
+
+def auth_headers(content_type="application/json"):
+ """Return auth headers using the current cached token."""
+ return {"Authorization": f"Bearer {token}", "Content-Type": content_type}
+
+
+class MyServer(BaseHTTPRequestHandler):
+ def do_GET(self):
+ """Handle GET requests and process token reception."""
+ global server_done, token
+
+ self.send_response(200)
+ self.send_header("Content-type", "text/html")
+ self.end_headers()
+ self.wfile.write(bytes("
OIDC Workflow Tool", "utf-8"))
+ self.wfile.write(bytes("The token has been received
", "utf-8"))
+ self.wfile.write(bytes("You may now close this window.
", "utf-8"))
+ self.wfile.write(bytes("", "utf-8"))
+
+ parsed_url = urlparse(self.path)
+ query_data = dict(parse_qsl(parsed_url.query))
+ received_token = query_data.get('token')
+
+ if received_token:
+ token = received_token
+ server_done = True
+ logger.info("Token received")
+ save_token_to_cache(token)
+
+ def log_message(self, format, *args):
+ """Log messages using the logger instead of stdout."""
+ logger.info("%s - %s" % (self.client_address[0], format % args))
+
+
+def load_token_from_cache():
+ """Load the token from the cache file."""
+ if os.path.exists(CACHE_FILE):
+ with open(CACHE_FILE, 'r') as f:
+ data = json.load(f)
+ if 'token' in data:
+ return data['token']
+ return None
+
+def save_token_to_cache(token):
+ """Save the token to the cache file."""
+ with open(CACHE_FILE, 'w') as f:
+ json.dump({'token': token}, f)
+
+def clear_token_cache():
+ """Remove the cached token file."""
+ if os.path.exists(CACHE_FILE):
+ os.remove(CACHE_FILE)
+ logger.info("Cached token removed.")
+ else:
+ logger.info("No cached token found.")
+
+def decode_and_validate_token(token):
+ """Decode the JWT and validate its claims.
+
+ Returns True if the token is a valid, non-expired JWT.
+ Returns False if the token is not a JWT (e.g. opaque session token)
+ or if it has expired — callers should then fall back to /auth/me.
+ """
+ try:
+ decoded_token = jwt.decode(token, options={"verify_signature": False})
+ except jwt.exceptions.DecodeError:
+ # Not a JWT — likely an opaque session token; let /auth/me handle it.
+ return False
+ except Exception as e:
+ logger.debug(f"Unexpected JWT decode error: {e}")
+ return False
+
+ iat = decoded_token.get('iat')
+ exp = decoded_token.get('exp')
+
+ if iat is None or exp is None:
+ logger.debug("JWT is missing 'iat' or 'exp' claims — treating as invalid.")
+ return False
+
+ now = datetime.datetime.now(pytz.UTC)
+ exp_dt = datetime.datetime.fromtimestamp(exp, pytz.UTC)
+ iat_dt = datetime.datetime.fromtimestamp(iat, pytz.UTC)
+
+ logger.debug(f"JWT iat={iat_dt.isoformat()} exp={exp_dt.isoformat()}")
+
+ if exp_dt < now:
+ logger.debug("JWT has expired.")
+ return False
+
+ if iat_dt > now:
+ logger.debug("JWT 'iat' is in the future — clock skew?")
+
+ return True
+
+def request_token():
+ global server_done, token
+ server_done = False
+ logger.info("Starting request_token process.")
+
+ # Attempt to load the token from the cache
+ token = load_token_from_cache()
+ logger.debug("Token loaded from cache: %s", token)
+
+ # Validate the cached token, if it exists
+ if token:
+ try:
+ if decode_and_validate_token(token):
+ logger.info("Cached token is valid. Using cached token.")
+ return token
+ except Exception:
+ pass
+ # Try opaque token via /auth/me
+ try:
+ r = requests.get(
+ f"{SIGN_URL}/api/v1/auth/me",
+ headers={"Authorization": f"Bearer {token}"},
+ timeout=5,
+ )
+ if r.status_code == 200:
+ logger.info("Cached session token is valid. Using cached token.")
+ return token
+ except Exception:
+ pass
+ logger.info("Cached token is expired or invalid, requesting a new token.")
+ token = ""
+
+ # Prepare the redirect URL for the token request
+ redirect_url = f"http://{LISTENER_HOST_NAME}:{LISTENER_SERVER_PORT}/?token="
+ logger.info("Redirect URL: %s", redirect_url)
+
+ # Construct the token request URL
+ token_url = f"{SIGN_URL}/api/v1/token_please?redirect_url={redirect_url}"
+ logger.info("Token request URL: %s", token_url)
+
+ # Start the web server to handle the token response
+ logger.debug("Starting the HTTP server on %s:%d", LISTENER_HOST_NAME, LISTENER_SERVER_PORT)
+ webServer = HTTPServer((LISTENER_HOST_NAME, LISTENER_SERVER_PORT), MyServer)
+
+ # Open the web browser to initiate the token request
+ logger.info("Opening web browser to request token.")
+ webbrowser.open(token_url, new=2)
+
+ # Wait for the server to handle the request and receive the token
+ logger.debug("Waiting for the token response...")
+ while not server_done:
+ webServer.handle_request()
+ logger.debug("Server handled a request, server_done status: %s", server_done)
+
+ logger.info("Token received: %s", token)
+ return token
+
+def get_activated_ssh_key():
+ """Retrieve the list of SSH keys and return the ID of a verified key."""
+ try:
+ response = requests.get(f"{SIGN_URL}/api/v1/ssh/keys", headers=auth_headers())
+ if response.status_code != 200:
+ logger.error(f"Failed to retrieve SSH keys: {response.status_code} - {response.text}")
+ exit(1)
+
+ keys = response.json().get('data', {}).get('keys', [])
+ verified_keys = [k for k in keys if k['verified']]
+
+ if not verified_keys:
+ logger.error("No verified SSH keys found for the user.")
+ exit(1)
+
+ if len(verified_keys) > 1 and sys.stdout.isatty():
+ print("\nMultiple verified SSH keys found. Please choose one:")
+ for i, k in enumerate(verified_keys):
+ print(f" [{i+1}] {k['id'][:8]}... fingerprint={k.get('fingerprint','?')} name={k.get('key_comment','?')}")
+ try:
+ choice = int(input("Enter number: ").strip()) - 1
+ if 0 <= choice < len(verified_keys):
+ return verified_keys[choice]['id']
+ except (ValueError, EOFError):
+ pass
+ logger.info("Invalid choice; using the most recently added key.")
+
+ verified_keys.sort(key=lambda k: k.get('created_at', ''), reverse=True)
+ return verified_keys[0]['id']
+
+ except SystemExit:
+ raise
+ except Exception as e:
+ logger.error(f"Error while retrieving SSH keys: {e}")
+ exit(1)
+
+
+def fetch_my_principals():
+ """Fetch all principal names the current user is entitled to from the API.
+ For regular members: returns their assigned principals.
+ For org admins/owners: returns all principals in the org (they can sign for any).
+ """
+ global token
+ response = requests.get(
+ f"{SIGN_URL}/api/v1/users/me/principals",
+ headers={"Authorization": f"Bearer {token}"},
+ timeout=10,
+ )
+ if response.status_code != 200:
+ logger.error(f"Failed to fetch principals from server: {response.status_code} - {response.text}")
+ exit(1)
+
+ orgs = response.json().get("data", {}).get("orgs", [])
+ principal_names = []
+ for org in orgs:
+ # Admins/owners get all principals; regular members get only their assigned ones
+ if org.get("is_admin"):
+ source = org.get("all_principals", [])
+ else:
+ source = org.get("my_principals", [])
+ for p in source:
+ if p["name"] not in principal_names:
+ principal_names.append(p["name"])
+
+ return principal_names
+
+
+def request_certificate():
+ CERT_ID = os.getenv("CERT_ID") or get_activated_ssh_key()
+
+ principals = fetch_my_principals()
+ if not principals:
+ logger.error("You have no principals assigned. Contact your org admin.")
+ exit(1)
+ logger.info(f"Requesting certificate for principals: {', '.join(principals)}")
+
+ headers = {
+ 'content-type': 'application/json',
+ "Authorization": "bearer " + token
+ }
+
+ payload = {
+ 'cert_id': CERT_ID,
+ 'principals': principals,
+ }
+
+ try:
+ response = requests.post(f"{SIGN_URL}/api/v1/ssh/sign", json=payload, headers=headers)
+
+ if response.status_code == 201:
+ json_result = response.json().get('data', response.json())
+ with open(CERT_FILE_PATH, 'w') as f:
+ f.write(json_result['certificate'])
+ logger.info(f"Certificate signed successfully, located at {CERT_FILE_PATH}")
+ logger.info(f"Valid for principals: {', '.join(json_result.get('principals', principals))}")
+ logger.info("You can login to your destination server with the following command")
+ logger.info(f"\tssh user@server -o CertificateFile={CERT_FILE_PATH}")
+ else:
+ logger.error("Error in response from server")
+ logger.error(f"Status code: {response.status_code}")
+ logger.error(f"Response text: {response.text}")
+ except Exception as e:
+ logger.error(f"Error during certificate signing: {e}")
+
+def generate_and_sign_challenge(ssh_key_file, key_id):
+ """Fetch a challenge from the server, sign it with the SSH key, and submit the signature."""
+ logger.debug(f"generate_and_sign_challenge - {ssh_key_file} {key_id}")
+
+ # Fetch challenge text
+ try:
+ response = requests.get(f"{SIGN_URL}/api/v1/ssh/keys/{key_id}/verify", headers=auth_headers())
+ if response.status_code != 200:
+ logger.error(f"Server returned unexpected code {response.status_code}")
+ return False
+ resp_json = response.json()
+ data = resp_json.get('data', resp_json)
+ challenge_text = data.get('challenge_text', data.get('validationText', '')) + "\n"
+ except Exception as e:
+ logger.error(f"Unable to fetch SSH Key validation data: {e}")
+ return False
+
+ # Sign the challenge
+ try:
+ for path in (CHALLENGE_FILE_PATH, CHALLENGE_SIG_FILE_PATH):
+ if os.path.exists(path):
+ os.remove(path)
+
+ with open(CHALLENGE_FILE_PATH, 'w') as f:
+ f.write(challenge_text)
+
+ subprocess.run(
+ ["ssh-keygen", "-Y", "sign", "-f", ssh_key_file, "-n", "file", CHALLENGE_FILE_PATH],
+ check=True,
+ )
+
+ with open(CHALLENGE_SIG_FILE_PATH, 'rb') as f:
+ signature = base64.b64encode(f.read()).decode('utf-8')
+ except Exception as e:
+ logger.error(f"Unable to sign the challenge response: {e}")
+ return False
+
+ # Submit signature
+ try:
+ response = requests.post(
+ f"{SIGN_URL}/api/v1/ssh/keys/{key_id}/verify",
+ headers=auth_headers(),
+ json={"signature": signature},
+ )
+ if response.status_code == 200:
+ logger.info("SSH key verified successfully.")
+ else:
+ logger.error(f"Verification failed: {response.status_code} - {response.text}")
+ except Exception as e:
+ logger.error(f"Unable to submit the challenge response: {e}")
+
+ return signature
+
+def remove_ssh_key(key_id=None):
+ """
+ Remove an SSH key from the server. If key_id is None, list keys and prompt user to pick one.
+ """
+ response = requests.get(f"{SIGN_URL}/api/v1/ssh/keys", headers=auth_headers())
+ if response.status_code != 200:
+ logger.error(f"Failed to list SSH keys: {response.status_code} - {response.text}")
+ exit(1)
+
+ keys = response.json().get('data', {}).get('keys', [])
+ if not keys:
+ logger.info("No SSH keys found for your user.")
+ return
+
+ if key_id:
+ target = next((k for k in keys if k['id'] == key_id), None)
+ if not target:
+ logger.error(f"Key ID {key_id} not found in your profile.")
+ exit(1)
+ keys_to_delete = [target]
+ else:
+ print("\nYour SSH keys:")
+ for i, k in enumerate(keys):
+ verified = "✓ verified" if k['verified'] else "✗ unverified"
+ print(f" [{i+1}] {k['id']} {verified} {k.get('description', '')} (added {k['created_at'][:10]})")
+ print(" [a] Delete ALL keys")
+ print(" [q] Quit")
+ choice = input("\nEnter number to delete (or 'a' for all, 'q' to quit): ").strip().lower()
+
+ if choice == 'q':
+ return
+ elif choice == 'a':
+ keys_to_delete = keys
+ else:
+ try:
+ idx = int(choice) - 1
+ if idx < 0 or idx >= len(keys):
+ raise ValueError()
+ keys_to_delete = [keys[idx]]
+ except ValueError:
+ logger.error("Invalid selection.")
+ exit(1)
+
+ for k in keys_to_delete:
+ del_response = requests.delete(f"{SIGN_URL}/api/v1/ssh/keys/{k['id']}", headers=auth_headers())
+ if del_response.status_code == 200:
+ logger.info(f"Key {k['id']} removed successfully.")
+ else:
+ logger.error(f"Failed to remove key {k['id']}: {del_response.status_code} - {del_response.text}")
+
+
+def add_ssh_key(ssh_key_file):
+ """Add an SSH key to the server and auto-verify it."""
+ if hasattr(ssh_key_file, 'read'):
+ key_bytes = ssh_key_file.read()
+ key_path = ssh_key_file.name
+ elif isinstance(ssh_key_file, bytes):
+ key_bytes = ssh_key_file
+ key_path = None
+ else:
+ key_path = str(ssh_key_file)
+ with open(key_path, 'rb') as f:
+ key_bytes = f.read()
+
+ ssh_key = key_bytes.decode('utf-8').strip()
+ payload = {
+ 'description': 'Added via gatehouse CLI tool',
+ 'key': ssh_key,
+ }
+
+ response = requests.post(f"{SIGN_URL}/api/v1/ssh/keys", json=payload, headers=auth_headers())
+ if response.status_code == 201:
+ ssh_key_id = response.json().get('data', {}).get('id')
+ logger.info(f"SSH key {ssh_key_id} added successfully")
+ if key_path:
+ private_key_path = key_path[:-4] if key_path.endswith('.pub') else key_path
+ generate_and_sign_challenge(private_key_path, ssh_key_id)
+ else:
+ logger.warning("No key file path available — skipping auto-verification. "
+ "Run with -k to enable automatic key verification.")
+ else:
+ logger.error(f"Failed to add SSH key: {response.status_code} - {response.text}")
+
+def checkCert():
+ logger.info("Running cert check")
+ if not os.path.isfile(CERT_FILE_PATH):
+ logger.warning("Certificate does not exist, new certificate required")
+ return 1
+
+ try:
+ certificate = SSHCertificate.from_file(CERT_FILE_PATH)
+ except Exception:
+ logger.warning("Certificate file is invalid or corrupt, renewal required")
+ return 1
+
+ # Get the current datetime
+ now = datetime.datetime.now()
+ logger.debug(certificate
+ )
+
+ # Check if the date is in the past or future
+ if certificate.get("valid_before") > now:
+ # Expiry is in the future
+ if args.force:
+ return 0
+ else:
+ logger.info("You have a valid SSH Certificate with the principals {} expiring at {}, not renewing. Use -f to force renewal".format(certificate.get("principals"), certificate.get("valid_before")))
+ return 0
+ else:
+ logger.warning("Certificate is not valid, renewal required")
+ return 1
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='Sign an SSH key via a web service')
+ parser.add_argument("-k", "--ssh-key", type=argparse.FileType('rb'), dest="sshkeyfile", help="Add an SSH Public Key to your user profile in gatehouse")
+ parser.add_argument("-f", "--force", action='store_true', default=False, help="Force the certificate renewal")
+ parser.add_argument("-a", "--add-key", action='store_true', default=False, help="Add SSH key to the server")
+ parser.add_argument("-c", "--check-cert", action='store_true', default=False, help="Check the certificate, if it's valid exit 0, if it's invalid exit 1")
+ parser.add_argument("-r", "--request-cert", action='store_true', default=False, help="Request that gatehouse sign a new certificate for you based on an SSH public key on file in your profile")
+ parser.add_argument("--clear-cache", action='store_true', default=False, help="Remove the cached authentication token")
+ parser.add_argument("--remove-key", nargs='?', const='', metavar='KEY_ID', help="Remove an SSH key from your profile. Omit KEY_ID to pick interactively.")
+ parser.add_argument("--list-keys", action='store_true', default=False, help="List SSH keys in your profile")
+
+ args = parser.parse_args()
+ if not (args.check_cert or args.request_cert or args.add_key or args.clear_cache
+ or args.remove_key is not None or args.list_keys):
+ parser.error("At least one of --check-cert, --request-cert, --add-key, --list-keys, --remove-key, or --clear-cache must be provided.")
+
+
+ # Retrieve SSH key from environment variables if not provided via CLI
+ ssh_key_file = args.sshkeyfile if args.sshkeyfile else os.getenv('SSH_KEY_FILE')
+
+ if args.check_cert:
+ logger.info("Only checking certificate")
+ exit(checkCert())
+
+ if args.clear_cache:
+ clear_token_cache()
+ exit(0)
+
+ if args.remove_key is not None:
+ request_token()
+ remove_ssh_key(args.remove_key if args.remove_key else None)
+ exit(0)
+
+ if args.list_keys:
+ request_token()
+ response = requests.get(f"{SIGN_URL}/api/v1/ssh/keys", headers=auth_headers())
+ if response.status_code == 200:
+ keys = response.json().get('data', {}).get('keys', [])
+ if not keys:
+ print("No SSH keys found in your profile.")
+ else:
+ for k in keys:
+ verified = "✓ verified" if k.get('verified') else "✗ unverified"
+ print(f" {k['id']} {verified} {k.get('description', '')} (added {k['created_at'][:10]})")
+ else:
+ logger.error(f"Failed to list SSH keys: {response.status_code} - {response.text}")
+ exit(0)
+
+ if args.add_key:
+ request_token()
+
+ if not ssh_key_file:
+ logger.error("SSH key file is required to add SSH key")
+ exit(1)
+
+ # If ssh_key_file is retrieved from the environment, it will be a string (file path), so open it
+ if isinstance(ssh_key_file, str):
+ with open(ssh_key_file, 'rb') as f:
+ ssh_key_file = f.read()
+
+ add_ssh_key(ssh_key_file)
+ exit(0)
+
+
+ if args.request_cert:
+ request_token()
+ if args.force:
+ logger.info("Forcing renewal of certificate")
+ if args.force or checkCert() == 1:
+ request_certificate()
+ exit(0)
diff --git a/config/base.py b/config/base.py
index 4a2dfe6..73a4d7f 100644
--- a/config/base.py
+++ b/config/base.py
@@ -28,6 +28,11 @@ class BaseConfig:
# Encryption key for sensitive data (client secrets, tokens, etc.)
ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY", "dev-encryption-key-change-in-production")
+
+ # Encryption key for CA private keys stored in the database.
+ # Must be set to a strong random secret in production.
+ # Any string is accepted — it is SHA-256 derived to a 32-byte Fernet key internally.
+ CA_ENCRYPTION_KEY = os.getenv("CA_ENCRYPTION_KEY", "dev-ca-encryption-key-change-in-production")
# Session configuration for WebAuthn cross-origin support
SESSION_COOKIE_SECURE = os.getenv("SESSION_COOKIE_SECURE", "True").lower() == "true"
@@ -72,6 +77,13 @@ class BaseConfig:
RATELIMIT_STORAGE_URL = os.getenv("RATELIMIT_STORAGE_URL", "redis://localhost:6379/1")
RATELIMIT_DEFAULT = "100/hour"
+ # Per-endpoint auth rate limits (override via env vars for each environment)
+ RATELIMIT_AUTH_REGISTER = os.getenv("RATELIMIT_AUTH_REGISTER", "10 per minute; 50 per hour")
+ RATELIMIT_AUTH_LOGIN = os.getenv("RATELIMIT_AUTH_LOGIN", "20 per minute; 100 per hour")
+ RATELIMIT_AUTH_TOTP_VERIFY = os.getenv("RATELIMIT_AUTH_TOTP_VERIFY", "20 per minute; 100 per hour")
+ RATELIMIT_AUTH_FORGOT_PASSWORD = os.getenv("RATELIMIT_AUTH_FORGOT_PASSWORD", "5 per minute; 20 per hour")
+ RATELIMIT_AUTH_RESET_PASSWORD = os.getenv("RATELIMIT_AUTH_RESET_PASSWORD", "10 per minute; 30 per hour")
+
# Logging
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
LOG_TO_STDOUT = os.getenv("LOG_TO_STDOUT", "False").lower() == "true"
@@ -116,3 +128,11 @@ class BaseConfig:
# Frontend URL (for OAuth callback redirects)
FRONTEND_URL = os.getenv("FRONTEND_URL", "http://localhost:8080")
+
+ # Email / SMTP
+ EMAIL_ENABLED = os.getenv("EMAIL_ENABLED", "False").lower() == "true"
+ SMTP_HOST = os.getenv("SMTP_HOST", "smtp.gmail.com")
+ SMTP_PORT = int(os.getenv("SMTP_PORT", "587"))
+ SMTP_USERNAME = os.getenv("SMTP_USERNAME", "")
+ SMTP_PASSWORD = os.getenv("SMTP_PASSWORD", "")
+ FROM_ADDRESS = os.getenv("FROM_ADDRESS", "noreply@gatehouse.local")
diff --git a/config/development.py b/config/development.py
index 840a84b..5436adf 100644
--- a/config/development.py
+++ b/config/development.py
@@ -25,3 +25,15 @@ class DevelopmentConfig(BaseConfig):
"CORS_ORIGINS",
"http://localhost:8080,http://localhost:3000,http://localhost:5173,https://ui.webauthn.local"
).split(",")
+
+ # ── Email / SMTP ──────────────────────────────────────────────────────────
+ # Read from .env so real SMTP credentials work in dev.
+ # Set EMAIL_ENABLED=false in .env to disable; defaults to True if SMTP_HOST is set.
+ EMAIL_ENABLED = os.getenv("EMAIL_ENABLED", "True").lower() == "true"
+ SMTP_HOST = os.getenv("SMTP_HOST", "localhost")
+ SMTP_PORT = int(os.getenv("SMTP_PORT", "1025"))
+ SMTP_USERNAME = os.getenv("SMTP_USERNAME") or None
+ SMTP_PASSWORD = os.getenv("SMTP_PASSWORD") or None
+ SMTP_USE_TLS = os.getenv("SMTP_USE_TLS", "").lower() == "true" if os.getenv("SMTP_USE_TLS") else int(os.getenv("SMTP_PORT", "1025")) not in (25, 1025)
+ FROM_ADDRESS = os.getenv("FROM_ADDRESS", "noreply@gatehouse.local")
+ EMAIL_FROM = FROM_ADDRESS # alias
diff --git a/config/testing.py b/config/testing.py
index 4ecfff0..aa988a7 100644
--- a/config/testing.py
+++ b/config/testing.py
@@ -12,6 +12,9 @@ class TestingConfig(BaseConfig):
# Explicitly set SECRET_KEY for testing
SECRET_KEY = os.getenv("SECRET_KEY", "test-secret-key-for-testing")
+ # CA key encryption — use a fixed test key so tests are deterministic
+ CA_ENCRYPTION_KEY = os.getenv("CA_ENCRYPTION_KEY", "test-ca-encryption-key-fixed-for-tests")
+
# Use in-memory SQLite for testing
SQLALCHEMY_DATABASE_URI = "sqlite:///:memory:"
SQLALCHEMY_ECHO = False
diff --git a/etc/ssh_ca.conf b/etc/ssh_ca.conf
new file mode 100644
index 0000000..212de10
--- /dev/null
+++ b/etc/ssh_ca.conf
@@ -0,0 +1,30 @@
+
+[default]
+# Certificate validity period (in hours)
+cert_validity_hours=8
+
+# Maximum certificate validity allowed (in hours)
+max_cert_validity_hours=720
+
+# CA private key path (required for local encryption mode)
+ca_key_path=
+
+# Certificate Field Limits
+max_principals_per_cert=256
+max_key_id_length=255
+
+# Verification challenge max age (in hours)
+verification_challenge_max_age=24
+
+# Cleanup: delete unverified SSH keys after this many days
+auto_delete_unverified_days=30
+
+[development]
+ca_key_path=${SSH_CA_KEY_PATH}
+cert_validity_hours=24
+
+[production]
+cert_validity_hours=8
+
+[testing]
+cert_validity_hours=8
diff --git a/gatehouse_app/__init__.py b/gatehouse_app/__init__.py
index a4e1043..f17c784 100644
--- a/gatehouse_app/__init__.py
+++ b/gatehouse_app/__init__.py
@@ -2,6 +2,9 @@
import os
import logging
+from dotenv import load_dotenv
+load_dotenv(dotenv_path=os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), '.env'))
+
# Test debug logging - this should appear when running `flask run --debug`
_root_logger = logging.getLogger(__name__)
_root_logger.debug("[TEST] Debug logging is working!")
@@ -239,3 +242,6 @@ def initialize_oidc_jwks(app):
app.logger.info(f"[OIDC] Signing key initialized: kid={signing_key.kid}")
except Exception as e:
app.logger.error(f"[OIDC] Failed to initialize JWKS: {e}")
+
+# Create default app instance for gunicorn/wsgi
+app = create_app()
diff --git a/gatehouse_app/api/oidc.py b/gatehouse_app/api/oidc.py
index 5ec7700..43bf5b4 100644
--- a/gatehouse_app/api/oidc.py
+++ b/gatehouse_app/api/oidc.py
@@ -21,8 +21,12 @@ from gatehouse_app.extensions import db
from gatehouse_app.extensions import bcrypt as flask_bcrypt
from gatehouse_app.extensions import redis_client as _redis_client_ref # may be None until app init
from gatehouse_app.models import User, OIDCClient
-from gatehouse_app.models.organization import Organization
-from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
+from gatehouse_app.models.organization.organization import Organization
+from gatehouse_app.exceptions.auth_exceptions import (
+ InvalidCredentialsError,
+ AccountSuspendedError,
+ AccountInactiveError,
+)
# ---------------------------------------------------------------------------
# Helpers for Redis-backed OIDC pending state
@@ -326,7 +330,7 @@ def oidc_complete():
400: invalid request
401: invalid token
"""
- from gatehouse_app.models.session import Session as GHSession
+ from gatehouse_app.models.user.session import Session as GHSession
from gatehouse_app.utils.constants import SessionStatus
data = request.get_json(silent=True) or {}
@@ -343,6 +347,20 @@ def oidc_complete():
user_id = str(gh_session.user_id)
+ # Check the user is still active (not suspended after session was issued)
+ from gatehouse_app.models.user.user import User as _User
+ from gatehouse_app.utils.constants import UserStatus
+ _complete_user = _User.query.filter_by(id=user_id, deleted_at=None).first()
+ if not _complete_user or _complete_user.status in (
+ UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED, UserStatus.INACTIVE
+ ):
+ return api_response(
+ success=False,
+ message="Your account is not active or has been suspended.",
+ status=403,
+ error_type="ACCOUNT_SUSPENDED",
+ )
+
# Retrieve stashed OIDC params (consume = True removes from Redis atomically)
params = _fetch_oidc_params(oidc_session_id, consume=True)
if not params:
@@ -565,6 +583,28 @@ def oidc_authorize():
session["oidc_user_id"] = user_id
logger.debug("[OIDC] User authentication successful: user_id=%s, email=%s", user_id, email)
+ except AccountSuspendedError:
+ logger.debug("[OIDC] User authentication failed: account suspended for email=%s", email)
+ return _show_login_page(
+ client_id=client_id,
+ redirect_uri=redirect_uri,
+ scope=scope,
+ state=state,
+ nonce=nonce,
+ response_type=response_type,
+ error="Your account has been suspended. Please contact an administrator.",
+ )
+ except AccountInactiveError:
+ logger.debug("[OIDC] User authentication failed: account inactive for email=%s", email)
+ return _show_login_page(
+ client_id=client_id,
+ redirect_uri=redirect_uri,
+ scope=scope,
+ state=state,
+ nonce=nonce,
+ response_type=response_type,
+ error="Your account is not active. Please verify your email.",
+ )
except InvalidCredentialsError:
logger.debug("[OIDC] User authentication failed: invalid credentials for email=%s", email)
return _show_login_page(
@@ -600,7 +640,34 @@ def oidc_authorize():
if not user:
logger.debug("[OIDC] Redirecting with error: server_error (user not found)")
return _redirect_with_error(redirect_uri, "server_error", "User not found", state)
-
+
+ # Check account is still active (user could have been suspended after session start)
+ from gatehouse_app.utils.constants import UserStatus as _UserStatus
+ if user.status in (_UserStatus.SUSPENDED, _UserStatus.COMPLIANCE_SUSPENDED):
+ session.pop("oidc_user_id", None) # clear stale session
+ logger.debug("[OIDC] User is suspended, clearing session and showing login error: user_id=%s", user_id)
+ return _show_login_page(
+ client_id=client_id,
+ redirect_uri=redirect_uri,
+ scope=scope,
+ state=state,
+ nonce=nonce,
+ response_type=response_type,
+ error="Your account has been suspended. Please contact an administrator.",
+ )
+ if user.status == _UserStatus.INACTIVE:
+ session.pop("oidc_user_id", None)
+ logger.debug("[OIDC] User is inactive, clearing session and showing login error: user_id=%s", user_id)
+ return _show_login_page(
+ client_id=client_id,
+ redirect_uri=redirect_uri,
+ scope=scope,
+ state=state,
+ nonce=nonce,
+ response_type=response_type,
+ error="Your account is not active. Please verify your email.",
+ )
+
logger.debug("[OIDC] Generating authorization code...")
logger.debug("[OIDC] Authorization code params: client_id=%s, user_id=%s, redirect_uri=%s", client_id, user_id, redirect_uri)
logger.debug("[OIDC] Authorization code params: scopes=%s, state=%s, nonce=%s", valid_scopes, state, nonce)
diff --git a/gatehouse_app/api/v1/__init__.py b/gatehouse_app/api/v1/__init__.py
index b97cd94..14534ed 100644
--- a/gatehouse_app/api/v1/__init__.py
+++ b/gatehouse_app/api/v1/__init__.py
@@ -5,4 +5,6 @@ from flask import Blueprint
api_v1_bp = Blueprint("api_v1", __name__)
# Import route modules to register them
-from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth
+from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth, departments, principals, ssh
+
+api_v1_bp.register_blueprint(ssh.ssh_bp)
diff --git a/gatehouse_app/api/v1/auth.py b/gatehouse_app/api/v1/auth.py
index ebffe8d..993f213 100644
--- a/gatehouse_app/api/v1/auth.py
+++ b/gatehouse_app/api/v1/auth.py
@@ -1,8 +1,10 @@
"""Authentication endpoints."""
import json
-from flask import request, session, g, jsonify
+import logging
+from flask import request, session, g, jsonify, current_app
from marshmallow import ValidationError
from gatehouse_app.api.v1 import api_v1_bp
+from gatehouse_app.extensions import limiter
from gatehouse_app.utils.response import api_response
from gatehouse_app.schemas.auth_schema import (
RegisterSchema,
@@ -23,6 +25,7 @@ from gatehouse_app.services.auth_service import AuthService
from gatehouse_app.services.webauthn_service import WebAuthnService
from gatehouse_app.services.user_service import UserService
from gatehouse_app.services.mfa_policy_service import MfaPolicyService
+from gatehouse_app.services.notification_service import NotificationService
from gatehouse_app.utils.decorators import login_required
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
@@ -30,6 +33,7 @@ from gatehouse_app.exceptions.validation_exceptions import ConflictError, NotFou
@api_v1_bp.route("/auth/register", methods=["POST"])
+@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_REGISTER"])
def register():
"""
Register a new user.
@@ -57,14 +61,66 @@ def register():
full_name=data.get("full_name"),
)
+ # Send verification email
+ try:
+ from gatehouse_app.models import EmailVerificationToken
+ verify_token = EmailVerificationToken.generate(user_id=user.id)
+ app_url = current_app.config.get("APP_URL", "http://localhost:8080")
+ verify_link = f"{app_url}/verify-email?token={verify_token.token}"
+ subject = "Verify your Gatehouse email address"
+ body = (
+ f"Hi {user.full_name or user.email},\n\n"
+ f"Welcome to Gatehouse! Please verify your email address by clicking the link below (valid for 24 hours):\n"
+ f"{verify_link}\n\n"
+ f"Gatehouse Security Team"
+ )
+ NotificationService._send_email(to_address=user.email, subject=subject, body=body)
+ except Exception as exc:
+ logging.getLogger(__name__).warning(f"Failed to send verification email on register: {exc}")
+
# Create session
user_session = AuthService.create_session(user)
+ # ── Post-registration hints ─────────────────────────────────────────
+ from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
+ from gatehouse_app.models.user.user import User as _User
+ from datetime import datetime, timezone as _tz
+
+ now = datetime.now(_tz.utc)
+ pending_invites = OrgInviteToken.query.filter(
+ OrgInviteToken.email == user.email,
+ OrgInviteToken.accepted_at.is_(None),
+ OrgInviteToken.expires_at > now,
+ OrgInviteToken.deleted_at.is_(None),
+ ).all()
+
+ # Determine if this is the very first user ever registered on this
+ # instance (exactly 1 active user means it must be this one).
+ total_users = _User.query.filter(_User.deleted_at.is_(None)).count()
+ is_first_user = total_users == 1
+
+ expires_str = user_session.expires_at.isoformat()
+ if expires_str[-1] != "Z":
+ expires_str += "Z"
+
return api_response(
data={
"user": user.to_dict(),
"token": user_session.token,
- "expires_at": user_session.expires_at.isoformat() + "Z" if user_session.expires_at.isoformat()[-1] != "Z" else user_session.expires_at.isoformat(),
+ "expires_at": expires_str,
+ "is_first_user": is_first_user,
+ "pending_invites": [
+ {
+ "token": inv.token,
+ "organization": {
+ "id": str(inv.organization_id),
+ "name": inv.organization.name,
+ },
+ "role": inv.role,
+ "expires_at": inv.expires_at.isoformat(),
+ }
+ for inv in pending_invites
+ ],
},
message="Registration successful",
status=201,
@@ -81,6 +137,7 @@ def register():
@api_v1_bp.route("/auth/login", methods=["POST"])
+@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_LOGIN"])
def login():
"""
Login user.
@@ -179,7 +236,9 @@ def login():
"organization_id": org.organization_id,
"organization_name": org.organization_name,
"status": org.status,
+ "effective_mode": org.effective_mode,
"deadline_at": org.deadline_at,
+ "applied_at": org.applied_at,
}
for org in policy_result.compliance_summary.orgs
],
@@ -189,6 +248,36 @@ def login():
if is_compliance_only:
response_data["requires_mfa_enrollment"] = True
+ # ── Org-setup hint for org-less users ────────────────────────────────
+ # If the user has no organisation memberships, surface any pending
+ # invitations so the UI can redirect straight to /org-setup instead of
+ # showing an empty dashboard.
+ user_orgs = user.get_organizations()
+ if not user_orgs:
+ from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
+ from datetime import datetime, timezone as _tz
+ _now = datetime.now(_tz.utc)
+ pending_invites = OrgInviteToken.query.filter(
+ OrgInviteToken.email == user.email,
+ OrgInviteToken.accepted_at.is_(None),
+ OrgInviteToken.expires_at > _now,
+ OrgInviteToken.deleted_at.is_(None),
+ ).all()
+ response_data["pending_invites"] = [
+ {
+ "token": inv.token,
+ "organization": {
+ "id": str(inv.organization_id),
+ "name": inv.organization.name,
+ },
+ "role": inv.role,
+ "expires_at": inv.expires_at.isoformat(),
+ }
+ for inv in pending_invites
+ ]
+ # Flag so the UI knows to send this user through org-setup
+ response_data["requires_org_setup"] = True
+
return api_response(
data=response_data,
message="Login successful",
@@ -239,8 +328,13 @@ def get_current_user():
data={
"user": user.to_dict(),
"organizations": [
- {"id": org.id, "name": org.name, "slug": org.slug}
- for org in user.get_organizations()
+ {
+ "id": membership.organization.id,
+ "name": membership.organization.name,
+ "slug": membership.organization.slug,
+ "role": membership.role,
+ }
+ for membership in user.organization_memberships
],
},
message="User retrieved successfully",
@@ -284,7 +378,7 @@ def revoke_session(session_id):
401: Not authenticated
404: Session not found
"""
- from gatehouse_app.models.session import Session
+ from gatehouse_app.models.user.session import Session
# Ensure session belongs to current user
user_session = Session.query.filter_by(
@@ -392,6 +486,7 @@ def verify_totp_enrollment():
@api_v1_bp.route("/auth/totp/verify", methods=["POST"])
+@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_TOTP_VERIFY"])
def verify_totp():
"""
Verify TOTP code during login.
@@ -424,7 +519,7 @@ def verify_totp():
)
# Get user from database
- from gatehouse_app.models.user import User
+ from gatehouse_app.models.user.user import User
user = User.query.get(user_id)
if not user:
return api_response(
@@ -434,6 +529,18 @@ def verify_totp():
error_type="AUTHENTICATION_ERROR",
)
+ # Check account suspension before completing TOTP verification
+ from gatehouse_app.utils.constants import UserStatus
+ if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
+ session.pop("totp_pending_user_id", None)
+ session.pop("webauthn_pending_user_id", None)
+ return api_response(
+ success=False,
+ message="Account is suspended. Contact an administrator.",
+ status=403,
+ error_type="ACCOUNT_SUSPENDED",
+ )
+
# Verify TOTP code
AuthService.authenticate_with_totp(
user,
@@ -475,7 +582,9 @@ def verify_totp():
"organization_id": org.organization_id,
"organization_name": org.organization_name,
"status": org.status,
+ "effective_mode": org.effective_mode,
"deadline_at": org.deadline_at,
+ "applied_at": org.applied_at,
}
for org in policy_result.compliance_summary.orgs
],
@@ -806,7 +915,7 @@ def begin_webauthn_login():
data = schema.load(request.json)
# Find user by email
- from gatehouse_app.models.user import User
+ from gatehouse_app.models.user.user import User
user = User.query.filter_by(
email=data["email"].lower(),
deleted_at=None
@@ -820,7 +929,18 @@ def begin_webauthn_login():
status=404,
error_type="NOT_FOUND",
)
-
+
+ # Check account suspension before proceeding
+ from gatehouse_app.utils.constants import UserStatus
+ if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
+ logger.warning(f"WebAuthn login begin - suspended account attempt: {user.email}")
+ return api_response(
+ success=False,
+ message="Account is suspended. Contact an administrator.",
+ status=403,
+ error_type="ACCOUNT_SUSPENDED",
+ )
+
# Check if user has any WebAuthn credentials
if not user.has_webauthn_enabled():
logger.warning(f"WebAuthn login begin - no credentials for user: {user.email}")
@@ -893,7 +1013,7 @@ def complete_webauthn_login():
data = schema.load(request.json)
# Get user from database
- from gatehouse_app.models.user import User
+ from gatehouse_app.models.user.user import User
user = User.query.get(user_id)
if not user:
logger.error(f"WebAuthn login complete - user not found: {user_id}")
@@ -903,7 +1023,19 @@ def complete_webauthn_login():
status=401,
error_type="AUTHENTICATION_ERROR",
)
-
+
+ # Check account suspension before completing login
+ from gatehouse_app.utils.constants import UserStatus
+ if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
+ session.pop("webauthn_pending_user_id", None)
+ logger.warning(f"WebAuthn login complete - suspended account attempt: {user.email}")
+ return api_response(
+ success=False,
+ message="Account is suspended. Contact an administrator.",
+ status=403,
+ error_type="ACCOUNT_SUSPENDED",
+ )
+
# Extract challenge from client data
client_data = data.get("response", {}).get("clientDataJSON", "")
@@ -962,7 +1094,9 @@ def complete_webauthn_login():
"organization_id": org.organization_id,
"organization_name": org.organization_name,
"status": org.status,
+ "effective_mode": org.effective_mode,
"deadline_at": org.deadline_at,
+ "applied_at": org.applied_at,
}
for org in policy_result.compliance_summary.orgs
],
@@ -1039,6 +1173,19 @@ def delete_webauthn_credential(credential_id):
"""
user = g.current_user
+ # First check that the specific credential actually belongs to this user.
+ # Only then check whether it is the last one — otherwise a user with zero
+ # credentials gets a misleading "Cannot delete the last passkey" error
+ # instead of a 404.
+ credential_exists = WebAuthnService.credential_belongs_to_user(credential_id, user)
+ if not credential_exists:
+ return api_response(
+ success=False,
+ message="Credential not found",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
# Check if this is the last credential
credential_count = user.get_webauthn_credential_count()
if credential_count <= 1:
@@ -1142,3 +1289,403 @@ def get_webauthn_status():
},
message="WebAuthn status retrieved successfully",
)
+
+
+_pw_logger = logging.getLogger(__name__)
+
+
+@api_v1_bp.route("/auth/forgot-password", methods=["POST"])
+@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_FORGOT_PASSWORD"])
+def forgot_password():
+ """Request a password reset email.
+
+ Always returns 200 to avoid leaking account existence.
+
+ Request body:
+ email: User email address
+
+ Returns:
+ 200: Password reset email sent (or silently no-op if email not found)
+ """
+ from gatehouse_app.models import User, PasswordResetToken
+
+ data = request.get_json() or {}
+ email = (data.get("email") or "").strip().lower()
+
+ if not email:
+ return api_response(
+ success=False,
+ message="Email is required",
+ status=400,
+ error_type="VALIDATION_ERROR",
+ )
+
+ # Always return 200 — don't leak whether the email exists
+ user = User.query.filter_by(email=email, deleted_at=None).first()
+ if user:
+ try:
+ reset_token = PasswordResetToken.generate(user_id=user.id)
+ app_url = current_app.config.get("APP_URL", "http://localhost:8080")
+ reset_link = f"{app_url}/reset-password?token={reset_token.token}"
+ subject = "Reset your Gatehouse password"
+ body = (
+ f"Hi {user.full_name or user.email},\n\n"
+ f"You requested a password reset for your Gatehouse account.\n\n"
+ f"Click the link below to reset your password (valid for 2 hours):\n"
+ f"{reset_link}\n\n"
+ f"If you did not request this, you can safely ignore this email.\n\n"
+ f"Gatehouse Security Team"
+ )
+ NotificationService._send_email(
+ to_address=user.email,
+ subject=subject,
+ body=body,
+ )
+ _pw_logger.info(f"Password reset token generated for user {user.id}")
+ except Exception as exc:
+ _pw_logger.exception(f"Error generating password reset token: {exc}")
+
+ return api_response(
+ data={},
+ message="If an account exists for this email, you will receive a password reset link shortly.",
+ )
+
+
+@api_v1_bp.route("/auth/reset-password", methods=["POST"])
+@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_RESET_PASSWORD"])
+def reset_password():
+ """Reset a user's password using a reset token.
+
+ Request body:
+ token: Password reset token from email
+ password: New password
+ password_confirm: Password confirmation
+
+ Returns:
+ 200: Password reset successfully
+ 400: Invalid or expired token / validation error
+ """
+ import bcrypt as _bcrypt
+ from gatehouse_app.extensions import bcrypt
+ from gatehouse_app.models import PasswordResetToken, AuthenticationMethod
+ from gatehouse_app.utils.constants import AuthMethodType
+
+ data = request.get_json() or {}
+ token_value = (data.get("token") or "").strip()
+ new_password = data.get("password") or ""
+ password_confirm = data.get("password_confirm") or ""
+
+ if not token_value or not new_password:
+ return api_response(
+ success=False,
+ message="Token and new password are required",
+ status=400,
+ error_type="VALIDATION_ERROR",
+ )
+
+ if new_password != password_confirm:
+ return api_response(
+ success=False,
+ message="Passwords do not match",
+ status=400,
+ error_type="VALIDATION_ERROR",
+ )
+
+ if len(new_password) < 8:
+ return api_response(
+ success=False,
+ message="Password must be at least 8 characters",
+ status=400,
+ error_type="VALIDATION_ERROR",
+ )
+
+ reset_token = PasswordResetToken.query.filter_by(token=token_value).first()
+ if not reset_token or not reset_token.is_valid:
+ return api_response(
+ success=False,
+ message="This password reset link is invalid or has expired.",
+ status=400,
+ error_type="INVALID_TOKEN",
+ )
+
+ try:
+ user = reset_token.user
+ # Update the password hash on the authentication method
+ auth_method = AuthenticationMethod.query.filter_by(
+ user_id=user.id,
+ method_type=AuthMethodType.PASSWORD,
+ deleted_at=None,
+ ).first()
+ if auth_method:
+ auth_method.password_hash = bcrypt.generate_password_hash(new_password).decode("utf-8")
+ from gatehouse_app.extensions import db
+ db.session.add(auth_method)
+
+ reset_token.consume()
+ _pw_logger.info(f"Password reset for user {user.id}")
+
+ return api_response(
+ data={},
+ message="Your password has been reset. You can now sign in with your new password.",
+ )
+ except Exception as exc:
+ _pw_logger.exception(f"Error resetting password: {exc}")
+ return api_response(
+ success=False,
+ message="An error occurred while resetting your password.",
+ status=500,
+ error_type="INTERNAL_ERROR",
+ )
+
+
+@api_v1_bp.route("/auth/verify-email", methods=["POST"])
+def verify_email():
+ """Verify a user's email address using a verification token.
+
+ Request body:
+ token: Email verification token
+
+ Returns:
+ 200: Email verified successfully
+ 400: Invalid or expired token
+ """
+ from gatehouse_app.models import EmailVerificationToken
+
+ data = request.get_json() or {}
+ token_value = (data.get("token") or "").strip()
+
+ if not token_value:
+ return api_response(
+ success=False,
+ message="Verification token is required",
+ status=400,
+ error_type="VALIDATION_ERROR",
+ )
+
+ verify_token = EmailVerificationToken.query.filter_by(token=token_value).first()
+ if not verify_token or not verify_token.is_valid:
+ return api_response(
+ success=False,
+ message="This verification link is invalid or has expired.",
+ status=400,
+ error_type="INVALID_TOKEN",
+ )
+
+ try:
+ user = verify_token.user
+ user.email_verified = True
+ from gatehouse_app.extensions import db
+ db.session.add(user)
+ verify_token.consume()
+ _pw_logger.info(f"Email verified for user {user.id}")
+
+ return api_response(
+ data={},
+ message="Your email has been verified. You can now sign in.",
+ )
+ except Exception as exc:
+ _pw_logger.exception(f"Error verifying email: {exc}")
+ return api_response(
+ success=False,
+ message="An error occurred while verifying your email.",
+ status=500,
+ error_type="INTERNAL_ERROR",
+ )
+
+
+@api_v1_bp.route("/auth/resend-verification", methods=["POST"])
+def resend_verification():
+ """Resend email verification link.
+
+ Always returns 200 to avoid leaking account existence.
+
+ Request body:
+ email: User email address
+
+ Returns:
+ 200: Verification email sent (or silently no-op)
+ """
+ from gatehouse_app.models import User, EmailVerificationToken
+
+ data = request.get_json() or {}
+ email = (data.get("email") or "").strip().lower()
+
+ if not email:
+ return api_response(
+ success=False,
+ message="Email is required",
+ status=400,
+ error_type="VALIDATION_ERROR",
+ )
+
+ user = User.query.filter_by(email=email, deleted_at=None).first()
+ if user and not user.email_verified:
+ try:
+ verify_token = EmailVerificationToken.generate(user_id=user.id)
+ app_url = current_app.config.get("APP_URL", "http://localhost:8080")
+ verify_link = f"{app_url}/verify-email?token={verify_token.token}"
+ subject = "Verify your Gatehouse email address"
+ body = (
+ f"Hi {user.full_name or user.email},\n\n"
+ f"Please verify your email address by clicking the link below (valid for 24 hours):\n"
+ f"{verify_link}\n\n"
+ f"Gatehouse Security Team"
+ )
+ NotificationService._send_email(
+ to_address=user.email,
+ subject=subject,
+ body=body,
+ )
+ _pw_logger.info(f"Verification email sent for user {user.id}")
+ except Exception as exc:
+ _pw_logger.exception(f"Error sending verification email: {exc}")
+
+ return api_response(
+ data={},
+ message="If an account exists for this email and is not yet verified, you will receive a verification link shortly.",
+ )
+
+
+# =============================================================================
+# Account Activation (separate from email-verification)
+# =============================================================================
+
+@api_v1_bp.route("/auth/activate", methods=["POST"])
+def activate_account():
+ """Activate a user account via a one-time activation code.
+
+ Request body:
+ code – the activation_key from the welcome email
+
+ Returns:
+ 200: Account activated, session token returned
+ 400: Missing code
+ 404: Invalid or already-used code
+ """
+ import secrets
+ from gatehouse_app.models.user.user import User
+ from gatehouse_app.extensions import db
+
+ data = request.get_json() or {}
+ code = (data.get("code") or "").strip()
+ if not code:
+ return api_response(success=False, message="Activation code is required", status=400, error_type="VALIDATION_ERROR")
+
+ user = User.query.filter_by(activation_key=code, deleted_at=None).first()
+ if not user:
+ return api_response(success=False, message="Invalid or expired activation code", status=404, error_type="NOT_FOUND")
+
+ user.activated = True
+ user.activation_key = None # one-time use
+ db.session.add(user)
+ db.session.commit()
+
+ user_session = AuthService.create_session(user)
+ _pw_logger.info(f"Account activated for user {user.id}")
+
+ return api_response(
+ data={
+ "user": user.to_dict(),
+ "token": user_session.token,
+ "expires_at": user_session.expires_at.isoformat() + "Z"
+ if user_session.expires_at.isoformat()[-1] != "Z"
+ else user_session.expires_at.isoformat(),
+ },
+ message="Account activated successfully",
+ )
+
+
+@api_v1_bp.route("/auth/resend-activation", methods=["POST"])
+def resend_activation():
+ """Re-send an account activation email.
+
+ Always returns 200 to avoid leaking whether an account exists.
+
+ Request body:
+ email – user email address
+ """
+ import secrets
+ from gatehouse_app.models.user.user import User
+ from gatehouse_app.extensions import db
+
+ data = request.get_json() or {}
+ email = (data.get("email") or "").strip().lower()
+ if not email:
+ return api_response(success=False, message="Email is required", status=400, error_type="VALIDATION_ERROR")
+
+ user = User.query.filter_by(email=email, deleted_at=None).first()
+ if user and not user.activated:
+ try:
+ code = secrets.token_urlsafe(32)
+ user.activation_key = code
+ db.session.add(user)
+ db.session.commit()
+
+ app_url = current_app.config.get("APP_URL", current_app.config.get("FRONTEND_URL", "http://localhost:8080"))
+ activate_link = f"{app_url}/activate?code={code}"
+ subject = "Activate your Gatehouse account"
+ body = (
+ f"Hi {user.full_name or user.email},\n\n"
+ f"Please activate your Gatehouse account by clicking the link below:\n"
+ f"{activate_link}\n\n"
+ f"If you did not create an account, you can safely ignore this email.\n\n"
+ f"Gatehouse Security Team"
+ )
+ NotificationService._send_email(to_address=user.email, subject=subject, body=body)
+ _pw_logger.info(f"Activation email re-sent to {user.id}")
+ except Exception as exc:
+ _pw_logger.exception(f"Error re-sending activation email: {exc}")
+
+ return api_response(
+ data={},
+ message="If an unactivated account exists for this email, you will receive a new activation link shortly.",
+ )
+
+
+# =============================================================================
+# Token retrieval / redirect (for CLI / external tools)
+# =============================================================================
+
+@api_v1_bp.route("/auth/token", methods=["GET"])
+@login_required
+def get_token():
+ """Return the current session token, optionally redirecting to a URL.
+
+ Query parameters:
+ redirect – optional URL to redirect to with the token appended as
+ a query param: ``?token=``
+
+ Returns:
+ 200: JSON ``{"token": ""}`` (no redirect given)
+ 302: Redirect to ``?token=``
+ """
+ from flask import redirect as flask_redirect
+ from urllib.parse import urlparse
+
+ token = g.current_session.token
+ redirect_url = request.args.get("redirect", "").strip()
+
+ if redirect_url:
+ # Validate redirect URL against allowed origins to prevent open-redirect
+ # token exfiltration attacks (CWE-601).
+ allowed_origins = set(current_app.config.get("CORS_ORIGINS", []))
+ frontend_url = current_app.config.get("FRONTEND_URL", "")
+ if frontend_url:
+ parsed = urlparse(frontend_url)
+ allowed_origins.add(f"{parsed.scheme}://{parsed.netloc}")
+
+ parsed_redirect = urlparse(redirect_url)
+ redirect_origin = f"{parsed_redirect.scheme}://{parsed_redirect.netloc}"
+
+ if redirect_origin not in allowed_origins:
+ return api_response(
+ success=False,
+ message="Redirect URL is not allowed.",
+ status=400,
+ error_type="INVALID_REDIRECT",
+ )
+
+ sep = "&" if "?" in redirect_url else "?"
+ return flask_redirect(f"{redirect_url}{sep}token={token}", code=302)
+
+ return api_response(data={"token": token}, message="Token retrieved")
diff --git a/gatehouse_app/api/v1/departments.py b/gatehouse_app/api/v1/departments.py
new file mode 100644
index 0000000..d305c66
--- /dev/null
+++ b/gatehouse_app/api/v1/departments.py
@@ -0,0 +1,699 @@
+"""Department endpoints."""
+from flask import g, request
+from marshmallow import Schema, fields, validate, ValidationError
+from sqlalchemy.orm.attributes import flag_modified
+
+from gatehouse_app.api.v1 import api_v1_bp
+from gatehouse_app.utils.response import api_response
+from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required
+from gatehouse_app.models import Department, DepartmentMembership
+from gatehouse_app.services.organization_service import OrganizationService
+from gatehouse_app.services.user_service import UserService
+from gatehouse_app.extensions import db
+
+
+class DepartmentCreateSchema(Schema):
+ """Schema for creating a department."""
+ name = fields.Str(required=True, validate=validate.Length(min=1, max=255))
+ description = fields.Str(allow_none=True, validate=validate.Length(max=2000))
+
+
+class DepartmentUpdateSchema(Schema):
+ """Schema for updating a department."""
+ name = fields.Str(validate=validate.Length(min=1, max=255))
+ description = fields.Str(allow_none=True, validate=validate.Length(max=2000))
+
+
+class AddDepartmentMemberSchema(Schema):
+ """Schema for adding a member to a department."""
+ email = fields.Email(required=True)
+
+
+@api_v1_bp.route("/organizations//departments", methods=["GET"])
+@login_required
+@full_access_required
+def list_departments(org_id):
+ """
+ List all departments in an organization.
+
+ Args:
+ org_id: Organization ID
+
+ Returns:
+ 200: List of departments
+ 401: Not authenticated
+ 403: Not a member
+ 404: Organization not found
+ """
+ org = OrganizationService.get_organization_by_id(org_id)
+
+ # Check if user is a member
+ if not org.is_member(g.current_user.id):
+ return api_response(
+ success=False,
+ message="You are not a member of this organization",
+ status=403,
+ error_type="AUTHORIZATION_ERROR",
+ )
+
+ departments = Department.query.filter_by(
+ organization_id=org_id,
+ deleted_at=None
+ ).all()
+
+ return api_response(
+ data={
+ "departments": [d.to_dict() for d in departments],
+ "count": len(departments),
+ },
+ message="Departments retrieved successfully",
+ )
+
+
+@api_v1_bp.route("/organizations//departments", methods=["POST"])
+@login_required
+@require_admin
+@full_access_required
+def create_department(org_id):
+ """
+ Create a new department.
+
+ Args:
+ org_id: Organization ID
+
+ Request body:
+ name: Department name (required)
+ description: Optional description
+
+ Returns:
+ 201: Department created successfully
+ 400: Validation error
+ 401: Not authenticated
+ 403: Not an admin
+ 404: Organization not found
+ 409: Department name already exists in org
+ """
+ try:
+ org = OrganizationService.get_organization_by_id(org_id)
+
+ schema = DepartmentCreateSchema()
+ data = schema.load(request.json or {})
+
+ # Check if department name already exists
+ existing = Department.query.filter_by(
+ organization_id=org_id,
+ name=data["name"],
+ deleted_at=None
+ ).first()
+
+ if existing:
+ return api_response(
+ success=False,
+ message=f"Department '{data['name']}' already exists in this organization",
+ status=409,
+ error_type="CONFLICT",
+ )
+
+ # Create department
+ dept = Department(
+ organization_id=org_id,
+ name=data["name"],
+ description=data.get("description"),
+ )
+ db.session.add(dept)
+ db.session.commit()
+
+ return api_response(
+ data={"department": dept.to_dict()},
+ message="Department created successfully",
+ status=201,
+ )
+
+ except ValidationError as e:
+ return api_response(
+ success=False,
+ message="Validation failed",
+ status=400,
+ error_type="VALIDATION_ERROR",
+ error_details=e.messages,
+ )
+
+
+@api_v1_bp.route("/organizations//departments/", methods=["GET"])
+@login_required
+@full_access_required
+def get_department(org_id, dept_id):
+ """
+ Get a specific department.
+
+ Args:
+ org_id: Organization ID
+ dept_id: Department ID
+
+ Returns:
+ 200: Department data
+ 401: Not authenticated
+ 403: Not a member
+ 404: Organization or department not found
+ """
+ org = OrganizationService.get_organization_by_id(org_id)
+
+ if not org.is_member(g.current_user.id):
+ return api_response(
+ success=False,
+ message="You are not a member of this organization",
+ status=403,
+ error_type="AUTHORIZATION_ERROR",
+ )
+
+ dept = Department.query.filter_by(
+ id=dept_id,
+ organization_id=org_id,
+ deleted_at=None
+ ).first()
+
+ if not dept:
+ return api_response(
+ success=False,
+ message="Department not found",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ return api_response(
+ data={"department": dept.to_dict()},
+ message="Department retrieved successfully",
+ )
+
+
+@api_v1_bp.route("/organizations//departments/", methods=["PATCH"])
+@login_required
+@require_admin
+@full_access_required
+def update_department(org_id, dept_id):
+ """
+ Update a department.
+
+ Args:
+ org_id: Organization ID
+ dept_id: Department ID
+
+ Request body:
+ name: Optional new name
+ description: Optional new description
+
+ Returns:
+ 200: Department updated successfully
+ 400: Validation error
+ 401: Not authenticated
+ 403: Not an admin
+ 404: Organization or department not found
+ 409: Name already exists
+ """
+ try:
+ org = OrganizationService.get_organization_by_id(org_id)
+
+ dept = Department.query.filter_by(
+ id=dept_id,
+ organization_id=org_id,
+ deleted_at=None
+ ).first()
+
+ if not dept:
+ return api_response(
+ success=False,
+ message="Department not found",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ schema = DepartmentUpdateSchema()
+ data = schema.load(request.json or {})
+
+ # Check if new name already exists
+ if "name" in data and data["name"] != dept.name:
+ existing = Department.query.filter_by(
+ organization_id=org_id,
+ name=data["name"],
+ deleted_at=None
+ ).first()
+ if existing:
+ return api_response(
+ success=False,
+ message=f"Department '{data['name']}' already exists",
+ status=409,
+ error_type="CONFLICT",
+ )
+
+ # Update fields
+ for key, value in data.items():
+ setattr(dept, key, value)
+
+ db.session.commit()
+
+ return api_response(
+ data={"department": dept.to_dict()},
+ message="Department updated successfully",
+ )
+
+ except ValidationError as e:
+ return api_response(
+ success=False,
+ message="Validation failed",
+ status=400,
+ error_type="VALIDATION_ERROR",
+ error_details=e.messages,
+ )
+
+
+@api_v1_bp.route("/organizations//departments/", methods=["DELETE"])
+@login_required
+@require_admin
+@full_access_required
+def delete_department(org_id, dept_id):
+ """
+ Delete a department (soft delete).
+
+ Args:
+ org_id: Organization ID
+ dept_id: Department ID
+
+ Returns:
+ 200: Department deleted successfully
+ 401: Not authenticated
+ 403: Not an admin
+ 404: Organization or department not found
+ """
+ org = OrganizationService.get_organization_by_id(org_id)
+
+ dept = Department.query.filter_by(
+ id=dept_id,
+ organization_id=org_id,
+ deleted_at=None
+ ).first()
+
+ if not dept:
+ return api_response(
+ success=False,
+ message="Department not found",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ # Soft delete
+ dept.deleted_at = db.func.now()
+ db.session.commit()
+
+ return api_response(
+ message="Department deleted successfully",
+ )
+
+
+@api_v1_bp.route("/organizations//departments//members", methods=["GET"])
+@login_required
+@full_access_required
+def get_department_members(org_id, dept_id):
+ """
+ Get all members of a department.
+
+ Args:
+ org_id: Organization ID
+ dept_id: Department ID
+
+ Returns:
+ 200: List of members
+ 401: Not authenticated
+ 403: Not a member
+ 404: Organization or department not found
+ """
+ org = OrganizationService.get_organization_by_id(org_id)
+
+ if not org.is_member(g.current_user.id):
+ return api_response(
+ success=False,
+ message="You are not a member of this organization",
+ status=403,
+ error_type="AUTHORIZATION_ERROR",
+ )
+
+ dept = Department.query.filter_by(
+ id=dept_id,
+ organization_id=org_id,
+ deleted_at=None
+ ).first()
+
+ if not dept:
+ return api_response(
+ success=False,
+ message="Department not found",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ members = DepartmentMembership.query.filter_by(
+ department_id=dept_id,
+ deleted_at=None
+ ).all()
+
+ members_data = []
+ for member in members:
+ member_dict = member.to_dict()
+ member_dict["user"] = member.user.to_dict()
+ members_data.append(member_dict)
+
+ return api_response(
+ data={
+ "members": members_data,
+ "count": len(members_data),
+ },
+ message="Members retrieved successfully",
+ )
+
+
+@api_v1_bp.route("/organizations//departments//members", methods=["POST"])
+@login_required
+@require_admin
+@full_access_required
+def add_department_member(org_id, dept_id):
+ """
+ Add a member to a department.
+
+ Args:
+ org_id: Organization ID
+ dept_id: Department ID
+
+ Request body:
+ email: User email to add
+
+ Returns:
+ 201: Member added successfully
+ 400: Validation error
+ 401: Not authenticated
+ 403: Not an admin
+ 404: Organization, department, or user not found
+ 409: User already a member
+ """
+ try:
+ org = OrganizationService.get_organization_by_id(org_id)
+
+ dept = Department.query.filter_by(
+ id=dept_id,
+ organization_id=org_id,
+ deleted_at=None
+ ).first()
+
+ if not dept:
+ return api_response(
+ success=False,
+ message="Department not found",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ schema = AddDepartmentMemberSchema()
+ data = schema.load(request.json or {})
+
+ # Find user by email
+ user = UserService.get_user_by_email(data["email"])
+ if not user:
+ return api_response(
+ success=False,
+ message="User not found",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ # Check if already an active member
+ existing = DepartmentMembership.query.filter_by(
+ user_id=user.id,
+ department_id=dept_id,
+ deleted_at=None
+ ).first()
+
+ if existing:
+ return api_response(
+ success=False,
+ message="User is already a member of this department",
+ status=409,
+ error_type="CONFLICT",
+ )
+
+ # Check for a previously soft-deleted row and resurrect it instead of inserting
+ soft_deleted = DepartmentMembership.query.filter(
+ DepartmentMembership.user_id == user.id,
+ DepartmentMembership.department_id == dept_id,
+ DepartmentMembership.deleted_at.isnot(None)
+ ).first()
+
+ if soft_deleted:
+ soft_deleted.deleted_at = None
+ membership = soft_deleted
+ else:
+ membership = DepartmentMembership(
+ user_id=user.id,
+ department_id=dept_id,
+ )
+ db.session.add(membership)
+
+ db.session.commit()
+
+ member_dict = membership.to_dict()
+ member_dict["user"] = user.to_dict()
+
+ return api_response(
+ data={"member": member_dict},
+ message="Member added successfully",
+ status=201,
+ )
+
+ except ValidationError as e:
+ return api_response(
+ success=False,
+ message="Validation failed",
+ status=400,
+ error_type="VALIDATION_ERROR",
+ error_details=e.messages,
+ )
+
+
+@api_v1_bp.route("/organizations//departments//members/", methods=["DELETE"])
+@login_required
+@require_admin
+@full_access_required
+def remove_department_member(org_id, dept_id, user_id):
+ """
+ Remove a member from a department.
+
+ Args:
+ org_id: Organization ID
+ dept_id: Department ID
+ user_id: User ID to remove
+
+ Returns:
+ 200: Member removed successfully
+ 401: Not authenticated
+ 403: Not an admin
+ 404: Organization, department, or member not found
+ """
+ org = OrganizationService.get_organization_by_id(org_id)
+
+ dept = Department.query.filter_by(
+ id=dept_id,
+ organization_id=org_id,
+ deleted_at=None
+ ).first()
+
+ if not dept:
+ return api_response(
+ success=False,
+ message="Department not found",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ membership = DepartmentMembership.query.filter_by(
+ user_id=user_id,
+ department_id=dept_id,
+ deleted_at=None
+ ).first()
+
+ if not membership:
+ return api_response(
+ success=False,
+ message="User is not a member of this department",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ # Soft delete
+ membership.deleted_at = db.func.now()
+ db.session.commit()
+
+ return api_response(
+ message="Member removed successfully",
+ )
+
+
+@api_v1_bp.route("/organizations//departments//principals", methods=["GET"])
+@login_required
+@full_access_required
+def get_department_principals(org_id, dept_id):
+ """Get all principals linked to a department."""
+ org = OrganizationService.get_organization_by_id(org_id)
+
+ if not org.is_member(g.current_user.id):
+ return api_response(
+ success=False,
+ message="You are not a member of this organization",
+ status=403,
+ error_type="AUTHORIZATION_ERROR",
+ )
+
+ dept = Department.query.filter_by(
+ id=dept_id,
+ organization_id=org_id,
+ deleted_at=None
+ ).first()
+
+ if not dept:
+ return api_response(
+ success=False,
+ message="Department not found",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ principals = dept.get_principals(active_only=True)
+
+ return api_response(
+ data={
+ "principals": [p.to_dict() for p in principals],
+ "count": len(principals),
+ },
+ message="Principals retrieved successfully",
+ )
+
+
+# ---------------------------------------------------------------------------
+# Department Certificate Policy
+# ---------------------------------------------------------------------------
+
+@api_v1_bp.route("/organizations//departments//cert-policy", methods=["GET"])
+@login_required
+@require_admin
+@full_access_required
+def get_dept_cert_policy(org_id, dept_id):
+ """Get the certificate issuance policy for a department (admin only)."""
+ from gatehouse_app.models.organization.department_cert_policy import DepartmentCertPolicy, STANDARD_EXTENSIONS
+
+ dept = Department.query.filter_by(
+ id=dept_id, organization_id=org_id, deleted_at=None
+ ).first()
+ if not dept:
+ return api_response(success=False, message="Department not found", status=404, error_type="NOT_FOUND")
+
+ policy = DepartmentCertPolicy.query.filter(
+ DepartmentCertPolicy.department_id == dept_id,
+ DepartmentCertPolicy.deleted_at.is_(None),
+ ).first()
+
+ if policy:
+ data = policy.to_dict()
+ else:
+ # Return default (all standard extensions, no user expiry choice)
+ data = {
+ "department_id": str(dept_id),
+ "allow_user_expiry": False,
+ "default_expiry_hours": 1,
+ "max_expiry_hours": 24,
+ "allowed_extensions": list(STANDARD_EXTENSIONS),
+ "custom_extensions": [],
+ "all_extensions": list(STANDARD_EXTENSIONS),
+ "standard_extensions": list(STANDARD_EXTENSIONS),
+ }
+
+ return api_response(data={"cert_policy": data}, message="Certificate policy retrieved")
+
+
+@api_v1_bp.route("/organizations//departments//cert-policy", methods=["PUT"])
+@login_required
+@require_admin
+@full_access_required
+def set_dept_cert_policy(org_id, dept_id):
+ """Create or update the certificate issuance policy for a department (admin only)."""
+ from gatehouse_app.models.organization.department_cert_policy import DepartmentCertPolicy, STANDARD_EXTENSIONS
+
+ dept = Department.query.filter_by(
+ id=dept_id, organization_id=org_id, deleted_at=None
+ ).first()
+ if not dept:
+ return api_response(success=False, message="Department not found", status=404, error_type="NOT_FOUND")
+
+ body = request.get_json() or {}
+
+ # Validate expiry values
+ default_expiry = body.get("default_expiry_hours")
+ max_expiry = body.get("max_expiry_hours")
+ if default_expiry is not None:
+ try:
+ default_expiry = int(default_expiry)
+ if default_expiry < 1:
+ raise ValueError
+ except (ValueError, TypeError):
+ return api_response(success=False, message="default_expiry_hours must be a positive integer", status=400, error_type="VALIDATION_ERROR")
+ if max_expiry is not None:
+ try:
+ max_expiry = int(max_expiry)
+ if max_expiry < 1:
+ raise ValueError
+ except (ValueError, TypeError):
+ return api_response(success=False, message="max_expiry_hours must be a positive integer", status=400, error_type="VALIDATION_ERROR")
+ if default_expiry and max_expiry and default_expiry > max_expiry:
+ return api_response(success=False, message="default_expiry_hours cannot exceed max_expiry_hours", status=400, error_type="VALIDATION_ERROR")
+
+ # Validate allowed_extensions — must be subset of STANDARD_EXTENSIONS
+ allowed_extensions = body.get("allowed_extensions")
+ if allowed_extensions is not None:
+ if not isinstance(allowed_extensions, list):
+ return api_response(success=False, message="allowed_extensions must be a list", status=400, error_type="VALIDATION_ERROR")
+ invalid_ext = [e for e in allowed_extensions if e not in STANDARD_EXTENSIONS]
+ if invalid_ext:
+ return api_response(
+ success=False,
+ message=f"Invalid standard extensions: {', '.join(invalid_ext)}. Valid: {', '.join(STANDARD_EXTENSIONS)}",
+ status=400,
+ error_type="VALIDATION_ERROR",
+ )
+
+ # Validate custom_extensions — plain strings
+ custom_extensions = body.get("custom_extensions")
+ if custom_extensions is not None:
+ if not isinstance(custom_extensions, list) or not all(isinstance(e, str) for e in custom_extensions):
+ return api_response(success=False, message="custom_extensions must be a list of strings", status=400, error_type="VALIDATION_ERROR")
+
+ policy = DepartmentCertPolicy.query.filter(
+ DepartmentCertPolicy.department_id == dept_id,
+ DepartmentCertPolicy.deleted_at.is_(None),
+ ).first()
+
+ if policy is None:
+ policy = DepartmentCertPolicy(department_id=dept_id)
+ db.session.add(policy)
+
+ if "allow_user_expiry" in body:
+ policy.allow_user_expiry = bool(body["allow_user_expiry"])
+ if default_expiry is not None:
+ policy.default_expiry_hours = default_expiry
+ if max_expiry is not None:
+ policy.max_expiry_hours = max_expiry
+ if allowed_extensions is not None:
+ policy.allowed_extensions = list(allowed_extensions)
+ flag_modified(policy, "allowed_extensions")
+ if custom_extensions is not None:
+ policy.custom_extensions = list(custom_extensions)
+ flag_modified(policy, "custom_extensions")
+
+ db.session.commit()
+
+ return api_response(data={"cert_policy": policy.to_dict()}, message="Certificate policy saved")
+
diff --git a/gatehouse_app/api/v1/external_auth.py b/gatehouse_app/api/v1/external_auth.py
index bb2969d..a23101c 100644
--- a/gatehouse_app/api/v1/external_auth.py
+++ b/gatehouse_app/api/v1/external_auth.py
@@ -46,6 +46,33 @@ def _pop_oidc_bridge(oauth_state: str) -> str | None:
pass
return None
+
+def _store_cli_redirect(oauth_state: str, redirect_url: str) -> None:
+ """Store CLI redirect_url keyed by OAuth state (for /token_please flow)."""
+ try:
+ import gatehouse_app.extensions as _ext
+ rc = _ext.redis_client
+ if rc is not None:
+ rc.setex(f"oauth_cli_redirect:{oauth_state}", _OAUTH_BRIDGE_TTL, redirect_url)
+ except Exception:
+ pass
+
+
+def _pop_cli_redirect(oauth_state: str) -> str | None:
+ """Retrieve and delete CLI redirect_url for the given OAuth state."""
+ try:
+ import gatehouse_app.extensions as _ext
+ rc = _ext.redis_client
+ if rc is not None:
+ key = f"oauth_cli_redirect:{oauth_state}"
+ val = rc.get(key)
+ if val:
+ rc.delete(key)
+ return val.decode() if isinstance(val, bytes) else val
+ except Exception:
+ pass
+ return None
+
logger = logging.getLogger(__name__)
@@ -69,6 +96,137 @@ def get_provider_type(provider: str) -> AuthMethodType:
return PROVIDER_TYPE_MAP[provider_lower]
+@api_v1_bp.route("/token_please", methods=["GET"])
+def token_please():
+ """
+ CLI token acquisition endpoint.
+
+ Redirects the user's browser to the Gatehouse login page so they can
+ authenticate using any method (password, OAuth, passkey, TOTP, etc.).
+ On successful login the frontend delivers the session token directly to
+ the CLI's local callback server.
+
+ This endpoint is designed for CLI clients that:
+ 1. Start a local HTTP server on LISTENER_SERVER_PORT (e.g. 8250)
+ 2. Open a browser to /api/v1/token_please?redirect_url=http://127.0.0.1:8250/?token=
+ 3. Wait for the browser to deliver the token to their local server
+
+ Query parameters:
+ redirect_url: Local callback URL where the token will be appended
+ """
+ import secrets
+ from urllib.parse import urlencode, quote
+ from flask import current_app, redirect as flask_redirect
+
+ redirect_url = request.args.get("redirect_url", "").strip()
+
+ if not redirect_url:
+ return api_response(
+ success=False,
+ message="redirect_url query parameter is required",
+ status=400,
+ error_type="MISSING_REDIRECT_URL",
+ )
+
+ # Validate redirect_url is localhost/127.0.0.1 (security: prevent open redirect)
+ from urllib.parse import urlparse as _urlparse
+ parsed = _urlparse(redirect_url)
+ if parsed.hostname not in ("localhost", "127.0.0.1"):
+ return api_response(
+ success=False,
+ message="redirect_url must point to localhost",
+ status=400,
+ error_type="INVALID_REDIRECT_URL",
+ )
+
+ # Store the CLI redirect URL in Redis keyed by a short-lived token so the
+ # frontend can retrieve it after login without it being visible in the URL.
+ cli_token = secrets.token_urlsafe(32)
+ try:
+ import gatehouse_app.extensions as _ext
+ rc = _ext.redis_client
+ if rc is not None:
+ rc.setex(f"cli_redirect:{cli_token}", _OAUTH_BRIDGE_TTL, redirect_url)
+ else:
+ logger.warning("Redis not available; passing cli_redirect directly in URL")
+ cli_token = None
+ except Exception:
+ cli_token = None
+
+ frontend_url = current_app.config.get("FRONTEND_URL", "http://localhost:8080")
+
+ if cli_token:
+ # Pass an opaque token; the frontend exchanges it for the real URL via
+ # GET /api/v1/cli/redirect-url?token=
+ login_url = f"{frontend_url}/login?cli_token={cli_token}"
+ else:
+ # Fallback: put the redirect URL directly (still localhost-only, validated above)
+ login_url = f"{frontend_url}/login?cli_redirect={quote(redirect_url, safe='')}"
+
+ logger.info(f"CLI token_please: redirecting browser to Gatehouse login page")
+ return flask_redirect(login_url, code=302)
+
+
+@api_v1_bp.route("/cli/redirect-url", methods=["GET"])
+def cli_redirect_url_lookup():
+ """
+ Exchange a short-lived cli_token for the CLI's local redirect URL.
+
+ Called by the frontend LoginPage after it detects the cli_token query
+ param so it can obtain the actual CLI callback URL from Redis without
+ exposing it in the browser URL bar.
+
+ Query parameters:
+ token: The cli_token issued by /token_please
+
+ Returns:
+ 200: { "redirect_url": "http://127.0.0.1:8250/?token=" }
+ 400: Missing token
+ 404: Token not found or expired
+ """
+ cli_token = request.args.get("token", "").strip()
+ if not cli_token:
+ return api_response(
+ success=False,
+ message="token query parameter is required",
+ status=400,
+ error_type="MISSING_TOKEN",
+ )
+
+ try:
+ import gatehouse_app.extensions as _ext
+ rc = _ext.redis_client
+ if rc is not None:
+ key = f"cli_redirect:{cli_token}"
+ val = rc.get(key)
+ if val is None:
+ return api_response(
+ success=False,
+ message="CLI token not found or expired",
+ status=404,
+ error_type="TOKEN_NOT_FOUND",
+ )
+ # Keep the key alive until the login actually completes (consume on use
+ # would break multi-step auth like TOTP), so we leave it as-is.
+ redirect_url = val.decode() if isinstance(val, bytes) else val
+ return api_response(data={"redirect_url": redirect_url})
+ except Exception as e:
+ logger.error(f"cli_redirect_url_lookup error: {e}")
+ return api_response(
+ success=False,
+ message="Internal error looking up CLI token",
+ status=500,
+ error_type="INTERNAL_ERROR",
+ )
+
+ return api_response(
+ success=False,
+ message="Redis not available",
+ status=503,
+ error_type="SERVICE_UNAVAILABLE",
+ )
+
+
# =============================================================================
# Provider Configuration Endpoints (Admin)
# =============================================================================
@@ -83,7 +241,7 @@ def list_providers():
200: List of providers with their configuration status
401: Not authenticated
"""
- from gatehouse_app.models.authentication_method import ApplicationProviderConfig
+ from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig
from gatehouse_app.services.external_auth_service import ExternalProviderConfig
# Check app-level provider configs (ApplicationProviderConfig)
@@ -575,8 +733,6 @@ def initiate_oauth_authorize(provider: str):
"state": "state_token"
}
"""
- provider_type = get_provider_type(provider)
-
# Get query parameters - organization_id is now optional
flow = request.args.get("flow", "login")
redirect_uri = request.args.get("redirect_uri")
@@ -592,7 +748,7 @@ def initiate_oauth_authorize(provider: str):
)
try:
- # Initiate flow - organization_id is now optional
+ provider_type = get_provider_type(provider)
if flow == "login":
auth_url, state = OAuthFlowService.initiate_login_flow(
provider_type=provider_type,
@@ -626,6 +782,13 @@ def initiate_oauth_authorize(provider: str):
status=e.status_code,
error_type=e.error_type,
)
+ except ExternalAuthError as e:
+ return api_response(
+ success=False,
+ message=e.message,
+ status=e.status_code,
+ error_type=e.error_type,
+ )
@api_v1_bp.route("/auth/external//callback", methods=["GET"])
@@ -666,8 +829,19 @@ def handle_oauth_callback(provider: str):
frontend_url = current_app.config.get("FRONTEND_URL", "http://localhost:8080")
frontend_callback = f"{frontend_url}/oauth/callback"
+ # Check if this is a CLI /token_please flow — retrieve stored redirect_url
+ cli_redirect_url = _pop_cli_redirect(state) if state else None
+
def redirect_error(message: str, error_type: str = "OAUTH_ERROR"):
- """Redirect to frontend with error params."""
+ """Redirect to frontend (or CLI) with error params."""
+ if cli_redirect_url:
+ # CLI flow: return a plain error page instead of redirecting back
+ from flask import make_response
+ return make_response(
+ f"Authentication Error
{message}
"
+ f"You may close this window.
",
+ 400,
+ )
params = {"error": message, "error_type": error_type}
if state:
params["state"] = state
@@ -706,8 +880,11 @@ def handle_oauth_callback(provider: str):
# Recover oidc_session_id if this was triggered from an OIDC bridge flow
oidc_session_id = _pop_oidc_bridge(state)
+ # Organization selection / creation flows are not supported in CLI mode
+ # (fall through to token redirect with whatever session we have)
+
# Organization selection needed (user belongs to multiple orgs)
- if result.get("requires_org_selection"):
+ if result.get("requires_org_selection") and not cli_redirect_url:
import json
orgs = json.dumps(result.get("available_organizations", []))
params = {
@@ -722,12 +899,20 @@ def handle_oauth_callback(provider: str):
return flask_redirect(f"{frontend_callback}?{urlencode(params)}", code=302)
# Organization creation needed (new user via OAuth with no org)
- if result.get("requires_org_creation"):
+ if result.get("requires_org_creation") and not cli_redirect_url:
+ import json as _json
+ session_data = result.get("session", {})
+ token = session_data.get("token", "")
+ expires_in = session_data.get("expires_in", 86400)
+ pending_invites = result.get("pending_invites", [])
params = {
"requires_org_creation": "1",
"state": result["state"],
"provider": provider,
"flow": flow_type,
+ "token": token,
+ "expires_in": str(expires_in),
+ "pending_invites": _json.dumps(pending_invites),
}
if oidc_session_id:
params["oidc_session_id"] = oidc_session_id
@@ -751,6 +936,19 @@ def handle_oauth_callback(provider: str):
user_info = result.get("user", {})
if user_info.get("email"):
params["email"] = user_info["email"]
+
+ # ── CLI /token_please flow: redirect to the CLI's local callback ─────
+ if cli_redirect_url:
+ # The CLI expects: http://127.0.0.1:8250/?token=
+ # cli_redirect_url already ends with "token=" so just append the value
+ cli_final_url = cli_redirect_url + token
+ logger.info(
+ f"CLI token_please success: provider={provider}, user={user_info.get('email')}, "
+ f"redirecting to CLI callback"
+ )
+ return flask_redirect(cli_final_url, code=302)
+
+ # ── Frontend flow ─────────────────────────────────────────────────────
# Pass oidc_session_id through so the frontend can complete the OIDC flow
if oidc_session_id:
params["oidc_session_id"] = oidc_session_id
@@ -1049,3 +1247,203 @@ def _get_provider_endpoints(provider_type: AuthMethodType):
"UNSUPPORTED_PROVIDER",
400,
)
+
+
+# =============================================================================
+# Admin: Application-level OAuth Provider Management
+# =============================================================================
+
+@api_v1_bp.route("/admin/oauth/providers", methods=["GET"])
+@login_required
+def admin_list_app_providers():
+ """List all application-level OAuth provider configurations (admin only).
+
+ Returns:
+ 200: List of providers with client_id and enabled status
+ 401: Not authenticated
+ 403: Not an admin
+ """
+ from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig
+ from gatehouse_app.models import OrganizationMember
+ from gatehouse_app.utils.constants import OrganizationRole
+
+ # Verify caller is admin in any org
+ admin_memberships = OrganizationMember.query.filter(
+ OrganizationMember.user_id == g.current_user.id,
+ OrganizationMember.role.in_([OrganizationRole.OWNER, OrganizationRole.ADMIN]),
+ ).all()
+
+ if not admin_memberships:
+ return api_response(
+ success=False,
+ message="Admin access required",
+ status=403,
+ error_type="FORBIDDEN",
+ )
+
+ PROVIDERS = [
+ {"id": "google", "name": "Google"},
+ {"id": "github", "name": "GitHub"},
+ {"id": "microsoft", "name": "Microsoft"},
+ ]
+
+ db_configs = {
+ c.provider_type: c
+ for c in ApplicationProviderConfig.query.all()
+ }
+
+ result = []
+ for p in PROVIDERS:
+ cfg = db_configs.get(p["id"])
+ result.append({
+ "id": p["id"],
+ "name": p["name"],
+ "is_configured": cfg is not None,
+ "is_enabled": cfg.is_enabled if cfg else False,
+ "client_id": cfg.client_id if cfg else None,
+ })
+
+ return api_response(
+ data={"providers": result},
+ message="OAuth providers retrieved successfully",
+ )
+
+
+@api_v1_bp.route("/admin/oauth/providers/", methods=["PUT"])
+@login_required
+def admin_configure_app_provider(provider: str):
+ """Create or update an application-level OAuth provider config (admin only).
+
+ Args:
+ provider: Provider type (google, github, microsoft)
+
+ Request body:
+ client_id: OAuth client ID
+ client_secret: OAuth client secret (optional — omit to keep existing)
+ is_enabled: Whether the provider is enabled (default: true)
+
+ Returns:
+ 200: Provider configuration updated
+ 400: Validation error
+ 401: Not authenticated
+ 403: Not an admin
+ """
+ from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig
+ from gatehouse_app.models import OrganizationMember
+ from gatehouse_app.utils.constants import OrganizationRole
+ from gatehouse_app.extensions import db
+
+ SUPPORTED = ["google", "github", "microsoft"]
+ if provider not in SUPPORTED:
+ return api_response(
+ success=False,
+ message=f"Unsupported provider. Must be one of: {', '.join(SUPPORTED)}",
+ status=400,
+ error_type="VALIDATION_ERROR",
+ )
+
+ # Verify caller is admin in any org
+ admin_memberships = OrganizationMember.query.filter(
+ OrganizationMember.user_id == g.current_user.id,
+ OrganizationMember.role.in_([OrganizationRole.OWNER, OrganizationRole.ADMIN]),
+ ).all()
+
+ if not admin_memberships:
+ return api_response(
+ success=False,
+ message="Admin access required",
+ status=403,
+ error_type="FORBIDDEN",
+ )
+
+ data = request.json or {}
+ client_id = (data.get("client_id") or "").strip()
+ client_secret = (data.get("client_secret") or "").strip()
+ is_enabled = data.get("is_enabled", True)
+
+ if not client_id:
+ return api_response(
+ success=False,
+ message="client_id is required",
+ status=400,
+ error_type="VALIDATION_ERROR",
+ )
+
+ cfg = ApplicationProviderConfig.query.filter_by(provider_type=provider).first()
+ if cfg:
+ cfg.client_id = client_id
+ if client_secret:
+ cfg.set_client_secret(client_secret)
+ cfg.is_enabled = bool(is_enabled)
+ db.session.commit()
+ else:
+ cfg = ApplicationProviderConfig(
+ provider_type=provider,
+ client_id=client_id,
+ is_enabled=bool(is_enabled),
+ )
+ if client_secret:
+ cfg.set_client_secret(client_secret)
+ db.session.add(cfg)
+ db.session.commit()
+
+ return api_response(
+ data={
+ "provider": {
+ "id": provider,
+ "client_id": cfg.client_id,
+ "is_enabled": cfg.is_enabled,
+ }
+ },
+ message=f"{provider.capitalize()} OAuth provider configured successfully",
+ )
+
+
+@api_v1_bp.route("/admin/oauth/providers/", methods=["DELETE"])
+@login_required
+def admin_delete_app_provider(provider: str):
+ """Delete an application-level OAuth provider config (admin only).
+
+ Args:
+ provider: Provider type (google, github, microsoft)
+
+ Returns:
+ 200: Provider configuration deleted
+ 404: Provider not found
+ 401: Not authenticated
+ 403: Not an admin
+ """
+ from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig
+ from gatehouse_app.models import OrganizationMember
+ from gatehouse_app.utils.constants import OrganizationRole
+ from gatehouse_app.extensions import db
+
+ # Verify caller is admin in any org
+ admin_memberships = OrganizationMember.query.filter(
+ OrganizationMember.user_id == g.current_user.id,
+ OrganizationMember.role.in_([OrganizationRole.OWNER, OrganizationRole.ADMIN]),
+ ).all()
+
+ if not admin_memberships:
+ return api_response(
+ success=False,
+ message="Admin access required",
+ status=403,
+ error_type="FORBIDDEN",
+ )
+
+ cfg = ApplicationProviderConfig.query.filter_by(provider_type=provider).first()
+ if not cfg:
+ return api_response(
+ success=False,
+ message=f"Provider '{provider}' is not configured",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ db.session.delete(cfg)
+ db.session.commit()
+
+ return api_response(
+ message=f"{provider.capitalize()} OAuth provider configuration removed",
+ )
diff --git a/gatehouse_app/api/v1/organizations.py b/gatehouse_app/api/v1/organizations.py
index 98c39e1..7d8f4d3 100644
--- a/gatehouse_app/api/v1/organizations.py
+++ b/gatehouse_app/api/v1/organizations.py
@@ -1,5 +1,5 @@
"""Organization endpoints."""
-from flask import g, request
+from flask import g, request, current_app
from marshmallow import ValidationError
from gatehouse_app.api.v1 import api_v1_bp
from gatehouse_app.utils.response import api_response
@@ -13,6 +13,72 @@ from gatehouse_app.schemas.organization_schema import (
from gatehouse_app.services.organization_service import OrganizationService
from gatehouse_app.services.user_service import UserService
from gatehouse_app.utils.constants import OrganizationRole
+from gatehouse_app.extensions import db
+
+
+
+def _get_system_ca_dict():
+ """Return a synthetic read-only CA dict for the config-file CA, or None.
+
+ This is injected into the org CA list when no DB CA exists for a given
+ ca_type so that the admin UI correctly shows "configured" rather than
+ "Not configured" when a system-level CA key is present.
+
+ The returned dict has ``is_system=True`` so the frontend can render it
+ as read-only (no delete / edit / generate buttons).
+ """
+ import os
+ try:
+ from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config
+ from gatehouse_app.utils.crypto import compute_ssh_fingerprint
+
+ # Check env var first (takes priority over file path)
+ priv_key = os.environ.get("SSH_CA_PRIVATE_KEY", "").strip()
+ pub_key = ""
+
+ if not priv_key:
+ cfg = get_ssh_ca_config()
+ key_path = cfg.get_str("ca_key_path", "").strip()
+ if not key_path:
+ return None
+ pub_path = key_path + ".pub"
+ if not os.path.exists(pub_path):
+ return None
+ with open(pub_path) as f:
+ pub_key = f.read().strip()
+ else:
+ # Derive the public key from the private key
+ from sshkey_tools.keys import PrivateKey
+ pk = PrivateKey.from_string(priv_key)
+ pub_key = pk.public_key.to_string()
+
+ fingerprint = compute_ssh_fingerprint(pub_key)
+ return {
+ "id": f"system-ca-{fingerprint[:16]}",
+ "organization_id": None,
+ "name": "System CA (config file)",
+ "description": (
+ "Read-only — this CA is loaded from the server's SSH_CA_PRIVATE_KEY "
+ "environment variable or etc/ssh_ca.conf. Manage it on the server."
+ ),
+ # ca_type is set by the caller
+ "ca_type": "user",
+ "key_type": "unknown",
+ "public_key": pub_key,
+ "fingerprint": fingerprint,
+ "is_active": True,
+ "is_system": True,
+ "default_cert_validity_hours": 0,
+ "max_cert_validity_hours": 0,
+ "total_certs": 0,
+ "active_certs": 0,
+ "revoked_certs": 0,
+ "created_at": None,
+ "updated_at": None,
+ }
+ except Exception:
+ return None
+
@api_v1_bp.route("/organizations", methods=["POST"])
@@ -160,6 +226,10 @@ def delete_organization(org_id):
"""
Delete organization (soft delete).
+ The owner may only delete the organization if they are the *sole* remaining
+ member. If other active members exist they must first transfer ownership
+ (or remove all other members) before deleting the organization.
+
Args:
org_id: Organization ID
@@ -168,9 +238,26 @@ def delete_organization(org_id):
401: Not authenticated
403: Not the owner
404: Organization not found
+ 409: Organization still has other members — transfer ownership first
"""
org = OrganizationService.get_organization_by_id(org_id)
+ # Guard: block deletion while non-owner members still exist so ownership
+ # can be transferred rather than silently orphaning them.
+ active_member_count = org.get_member_count()
+ if active_member_count > 1:
+ return api_response(
+ success=False,
+ message=(
+ "This organization still has other members. "
+ "Please transfer ownership to another member or remove all "
+ "other members before deleting the organization."
+ ),
+ status=409,
+ error_type="ORG_HAS_MEMBERS",
+ error_details={"member_count": active_member_count},
+ )
+
OrganizationService.delete_organization(
org=org,
user_id=g.current_user.id,
@@ -378,3 +465,1399 @@ def update_member_role(org_id, user_id):
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
+
+
+@api_v1_bp.route("/organizations//transfer-ownership", methods=["POST"])
+@login_required
+@full_access_required
+def transfer_organization_ownership(org_id):
+ """Transfer organization ownership from the current user to another member.
+
+ Only the current OWNER of the organization may call this endpoint.
+ The caller will be demoted to ADMIN and the target user will be promoted to OWNER.
+
+ Request body:
+ new_owner_user_id (str): UUID of the member to promote to OWNER.
+
+ Returns:
+ 200: Ownership transferred successfully
+ 400: Validation error / missing fields
+ 403: Caller is not the OWNER of this org
+ 404: Organization or target member not found
+ 409: Target is already the OWNER
+ """
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
+ from gatehouse_app.utils.constants import OrganizationRole, AuditAction
+ from gatehouse_app.services.audit_service import AuditService
+
+ caller = g.current_user
+
+ data = request.get_json() or {}
+ new_owner_user_id = data.get("new_owner_user_id")
+ if not new_owner_user_id:
+ return api_response(
+ success=False,
+ message="new_owner_user_id is required",
+ status=400,
+ error_type="VALIDATION_ERROR",
+ )
+
+ if str(new_owner_user_id) == str(caller.id):
+ return api_response(
+ success=False,
+ message="You are already the owner of this organization.",
+ status=409,
+ error_type="CONFLICT",
+ )
+
+ # Fetch org (raises NotFound internally)
+ org = OrganizationService.get_organization_by_id(org_id)
+
+ # Confirm caller is the current OWNER
+ caller_membership = OrganizationMember.query.filter_by(
+ organization_id=org.id,
+ user_id=caller.id,
+ deleted_at=None,
+ ).first()
+ if not caller_membership or caller_membership.role != OrganizationRole.OWNER:
+ return api_response(
+ success=False,
+ message="Only the organization owner can transfer ownership.",
+ status=403,
+ error_type="AUTHORIZATION_ERROR",
+ )
+
+ # Verify the target is an active member
+ target_membership = OrganizationMember.query.filter_by(
+ organization_id=org.id,
+ user_id=new_owner_user_id,
+ deleted_at=None,
+ ).first()
+ if not target_membership:
+ return api_response(
+ success=False,
+ message="Target user is not a member of this organization.",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ if target_membership.role == OrganizationRole.OWNER:
+ return api_response(
+ success=False,
+ message="Target user is already the owner.",
+ status=409,
+ error_type="CONFLICT",
+ )
+
+ # ── Atomic role swap ─────────────────────────────────────────────────────
+ # Demote caller → ADMIN, promote target → OWNER.
+ # Both updates go through OrganizationService so all hooks/auditing fire.
+ try:
+ demoted = OrganizationService.update_member_role(
+ org=org,
+ user_id=str(caller.id),
+ new_role=OrganizationRole.ADMIN,
+ updater_id=str(caller.id),
+ )
+ promoted = OrganizationService.update_member_role(
+ org=org,
+ user_id=str(new_owner_user_id),
+ new_role=OrganizationRole.OWNER,
+ updater_id=str(caller.id),
+ )
+ except Exception as exc:
+ from gatehouse_app.extensions import db as _db
+ _db.session.rollback()
+ return api_response(
+ success=False,
+ message=f"Failed to transfer ownership: {exc}",
+ status=500,
+ error_type="SERVER_ERROR",
+ )
+
+ AuditService.log_action(
+ action=AuditAction.ORG_OWNERSHIP_TRANSFERRED,
+ user_id=caller.id,
+ organization_id=org.id,
+ resource_type="organization",
+ resource_id=str(org.id),
+ description=(
+ f"Ownership of '{org.name}' transferred from {caller.email} "
+ f"to {target_membership.user.email if target_membership.user else new_owner_user_id}"
+ ),
+ metadata={
+ "previous_owner_id": str(caller.id),
+ "previous_owner_email": caller.email,
+ "new_owner_id": str(new_owner_user_id),
+ "new_owner_email": (
+ target_membership.user.email if target_membership.user else None
+ ),
+ },
+ )
+
+ def _member_dict(m):
+ d = m.to_dict()
+ if m.user:
+ d["user"] = m.user.to_dict()
+ return d
+
+ return api_response(
+ data={
+ "previous_owner": _member_dict(demoted),
+ "new_owner": _member_dict(promoted),
+ },
+ message=(
+ f"Ownership of '{org.name}' successfully transferred to "
+ f"{target_membership.user.email if target_membership.user else new_owner_user_id}."
+ ),
+ )
+
+
+@api_v1_bp.route("/organizations//audit-logs", methods=["GET"])
+@login_required
+@require_admin
+@full_access_required
+def get_organization_audit_logs(org_id):
+ """
+ Get audit logs for an organization.
+
+ Query params:
+ page: Page number (default 1)
+ per_page: Results per page (default 50, max 200)
+ action: Filter by action type
+
+ Returns:
+ 200: List of audit log entries
+ 401: Not authenticated
+ 403: Not a member / insufficient permissions
+ 404: Organization not found
+ """
+ from gatehouse_app.models.auth.audit_log import AuditLog
+
+ # Ensure org exists and user is a member (full_access_required handles this)
+ OrganizationService.get_organization_by_id(org_id)
+
+ page = int(request.args.get("page", 1))
+ per_page = min(int(request.args.get("per_page", 50)), 200)
+ action_filter = request.args.get("action")
+
+ query = AuditLog.query.filter_by(organization_id=org_id)
+ if action_filter:
+ query = query.filter_by(action=action_filter)
+
+ query = query.order_by(AuditLog.created_at.desc())
+ total = query.count()
+ logs = query.offset((page - 1) * per_page).limit(per_page).all()
+
+ def log_to_dict(log):
+ return {
+ "id": log.id,
+ "action": log.action.value if log.action else None,
+ "user_id": log.user_id,
+ "user_email": log.user.email if log.user else None,
+ "user": {"id": log.user.id, "email": log.user.email, "full_name": log.user.full_name} if log.user else None,
+ "organization_id": log.organization_id,
+ "resource_type": log.resource_type,
+ "resource_id": log.resource_id,
+ "ip_address": log.ip_address,
+ "user_agent": log.user_agent,
+ "request_id": log.request_id,
+ "description": log.description,
+ "success": log.success,
+ "error_message": log.error_message,
+ "metadata": log.extra_data,
+ "created_at": log.created_at.isoformat() if log.created_at else None,
+ "updated_at": log.updated_at.isoformat() if log.updated_at else None,
+ }
+
+ return api_response(
+ data={
+ "audit_logs": [log_to_dict(log) for log in logs],
+ "count": total,
+ "page": page,
+ "per_page": per_page,
+ "pages": (total + per_page - 1) // per_page,
+ },
+ message="Audit logs retrieved successfully",
+ )
+
+
+# ============================================================================
+# Organization Invite Tokens
+# ============================================================================
+
+@api_v1_bp.route("/organizations//invites", methods=["POST"])
+@login_required
+@require_admin
+def create_org_invite(org_id):
+ """Create an invite token for an organization.
+
+ Request body:
+ email: Email address to invite
+ role: Role to assign (default: member)
+
+ Returns:
+ 201: Invite created
+ 400: Validation error
+ 403: Not an admin
+ 404: Organization not found
+ """
+ from gatehouse_app.models import OrgInviteToken, Organization
+ from gatehouse_app.services.notification_service import NotificationService
+ from flask import current_app
+
+ org = Organization.query.filter_by(id=org_id, deleted_at=None).first()
+ if not org:
+ return api_response(success=False, message="Organization not found", status=404)
+
+ data = request.get_json() or {}
+ email = (data.get("email") or "").strip().lower()
+ role = (data.get("role") or "member").strip()
+
+ if not email:
+ return api_response(success=False, message="Email is required", status=400, error_type="VALIDATION_ERROR")
+
+ invite = OrgInviteToken.generate(
+ organization_id=org_id,
+ email=email,
+ role=role,
+ invited_by_id=g.current_user.id,
+ )
+
+ app_url = current_app.config.get("APP_URL", "http://localhost:8080")
+ invite_link = f"{app_url}/invite?token={invite.token}"
+
+ email_sent = NotificationService._send_email(
+ to_address=email,
+ subject=f"You're invited to join {org.name} on Gatehouse",
+ body=(
+ f"You've been invited to join {org.name} on Gatehouse.\n\n"
+ f"Click the link below to accept the invitation (valid for 7 days):\n"
+ f"{invite_link}\n\n"
+ f"Gatehouse Security Team"
+ ),
+ )
+
+ # In dev mode email may not be configured — always log the link so it's findable
+ import logging
+ if not email_sent:
+ logging.getLogger(__name__).warning(
+ f"[INVITE LINK] Email not sent (EMAIL_ENABLED=False or SMTP down). "
+ f"Invite for {email} → {invite_link}"
+ )
+ else:
+ logging.getLogger(__name__).info(
+ f"[INVITE] Email sent successfully to {email}"
+ )
+
+ response_data = {
+ "invite": {
+ "id": invite.id,
+ "email": invite.email,
+ "role": invite.role,
+ "expires_at": invite.expires_at.isoformat() + "Z",
+ # Only include invite_link when email delivery failed — signals frontend to show copy dialog
+ **({"invite_link": invite_link} if not email_sent else {}),
+ }
+ }
+
+ return api_response(
+ data=response_data,
+ message="Invite sent successfully",
+ status=201,
+ )
+
+
+@api_v1_bp.route("/organizations//invites", methods=["GET"])
+@login_required
+@require_admin
+def list_org_invites(org_id):
+ """List pending invite tokens for an organization.
+
+ Returns:
+ 200: List of invites
+ 403: Not an admin
+ 404: Organization not found
+ """
+ from gatehouse_app.models import OrgInviteToken, Organization
+
+ org = Organization.query.filter_by(id=org_id, deleted_at=None).first()
+ if not org:
+ return api_response(success=False, message="Organization not found", status=404)
+
+ invites = (
+ OrgInviteToken.query.filter_by(organization_id=org_id)
+ .filter(OrgInviteToken.accepted_at == None)
+ .filter(OrgInviteToken.deleted_at == None)
+ .all()
+ )
+
+ def invite_to_dict(inv):
+ return {
+ "id": inv.id,
+ "email": inv.email,
+ "role": inv.role,
+ "invited_by_id": inv.invited_by_id,
+ "created_at": inv.created_at.isoformat() + "Z",
+ "expires_at": inv.expires_at.isoformat() + "Z",
+ }
+
+ return api_response(
+ data={"invites": [invite_to_dict(i) for i in invites]},
+ message="Invites retrieved",
+ )
+
+
+@api_v1_bp.route("/organizations//invites/", methods=["DELETE"])
+@login_required
+@require_admin
+def cancel_org_invite(org_id, invite_id):
+ """Cancel (soft-delete) an organization invite.
+
+ Returns:
+ 200: Invite cancelled
+ 403: Not an admin
+ 404: Invite not found
+ """
+ from gatehouse_app.models import OrgInviteToken, Organization
+
+ org = Organization.query.filter_by(id=org_id, deleted_at=None).first()
+ if not org:
+ return api_response(success=False, message="Organization not found", status=404)
+
+ invite = OrgInviteToken.query.filter_by(id=invite_id, organization_id=org_id, deleted_at=None).first()
+ if not invite:
+ return api_response(success=False, message="Invite not found", status=404)
+
+ # Soft delete the invite so it's no longer usable
+ invite.delete(soft=True)
+
+ return api_response(data={}, message="Invite cancelled")
+
+
+@api_v1_bp.route("/invites/", methods=["GET"])
+def get_invite(token):
+ """Get invite details by token.
+
+ Returns:
+ 200: Invite details (org name, email)
+ 400: Invalid or expired token
+ """
+ from gatehouse_app.models import OrgInviteToken, User
+
+ invite = OrgInviteToken.query.filter_by(token=token).first()
+ if not invite or not invite.is_valid:
+ return api_response(success=False, message="This invitation link is invalid or has expired.", status=400, error_type="INVALID_TOKEN")
+
+ user_exists = User.query.filter_by(email=invite.email, deleted_at=None).first() is not None
+
+ return api_response(
+ data={
+ "email": invite.email,
+ "organization": {"id": invite.organization_id, "name": invite.organization.name},
+ "role": invite.role,
+ "user_exists": user_exists,
+ },
+ message="Invite found",
+ )
+
+
+@api_v1_bp.route("/invites//accept", methods=["POST"])
+def accept_invite(token):
+ """Accept an organization invite.
+
+ Creates the user account (if not already registered) and adds them
+ to the organization.
+
+ Request body:
+ full_name: User's display name
+ password: Password for new account (if not already registered)
+ password_confirm: Password confirmation
+
+ Returns:
+ 200: Invite accepted, returns user token
+ 400: Invalid/expired token or validation error
+ 409: Already a member
+ """
+ from gatehouse_app.models import OrgInviteToken, User
+ from gatehouse_app.services.auth_service import AuthService
+ from gatehouse_app.services.organization_service import OrganizationService
+ from gatehouse_app.utils.constants import OrganizationRole
+
+ invite = OrgInviteToken.query.filter_by(token=token).first()
+ if not invite or not invite.is_valid:
+ return api_response(success=False, message="This invitation link is invalid or has expired.", status=400, error_type="INVALID_TOKEN")
+
+ data = request.get_json() or {}
+ full_name = data.get("full_name") or ""
+ password = data.get("password") or ""
+ password_confirm = data.get("password_confirm") or ""
+
+ user = User.query.filter_by(email=invite.email, deleted_at=None).first()
+
+ if not user:
+ # Register a new user
+ if not password:
+ return api_response(success=False, message="Password is required for new accounts.", status=400, error_type="VALIDATION_ERROR")
+ if password != password_confirm:
+ return api_response(success=False, message="Passwords do not match.", status=400, error_type="VALIDATION_ERROR")
+ if len(password) < 8:
+ return api_response(success=False, message="Password must be at least 8 characters.", status=400, error_type="VALIDATION_ERROR")
+ try:
+ user = AuthService.register_user(email=invite.email, password=password, full_name=full_name or None)
+ except Exception as exc:
+ return api_response(success=False, message=str(exc), status=400, error_type="REGISTRATION_ERROR")
+
+ # Add to org
+ role_value = invite.role
+ try:
+ org_role = OrganizationRole(role_value)
+ except ValueError:
+ org_role = OrganizationRole.MEMBER
+
+ try:
+ OrganizationService.add_member(
+ org=invite.organization,
+ user_id=user.id,
+ role=org_role,
+ inviter_id=invite.invited_by_id,
+ )
+ except Exception:
+ from gatehouse_app.extensions import db
+ db.session.rollback() # Clear broken transaction so invite.accept() can commit
+
+ invite.accept()
+
+ has_webauthn = user.has_webauthn_enabled()
+ has_totp = user.has_totp_enabled()
+
+ if has_webauthn:
+ from flask import session as flask_session
+ flask_session["webauthn_pending_user_id"] = user.id
+ return api_response(
+ data={"requires_webauthn": True},
+ message="Passkey verification required. Please use your passkey to complete sign-in.",
+ )
+
+ if has_totp:
+ from flask import session as flask_session
+ flask_session["totp_pending_user_id"] = user.id
+ return api_response(
+ data={"requires_totp": True},
+ message="TOTP code required. Please enter your 6-digit code from your authenticator app.",
+ )
+
+ user_session = AuthService.create_session(user)
+
+ return api_response(
+ data={
+ "user": user.to_dict(),
+ "token": user_session.token,
+ "expires_at": user_session.expires_at.isoformat() + "Z",
+ },
+ message="Invitation accepted. Welcome!",
+ )
+
+
+# ============================================================================
+# Organization OIDC Clients
+# ============================================================================
+
+@api_v1_bp.route("/organizations//clients", methods=["GET"])
+@login_required
+@require_admin
+@full_access_required
+def list_org_clients(org_id):
+ """List OIDC clients for an organization.
+
+ Returns:
+ 200: List of OIDC clients
+ 403: Not an admin
+ 404: Organization not found
+ """
+ from gatehouse_app.models import OIDCClient, Organization
+
+ org = Organization.query.filter_by(id=org_id, deleted_at=None).first()
+ if not org:
+ return api_response(success=False, message="Organization not found", status=404)
+
+ clients = OIDCClient.query.filter_by(organization_id=org_id, is_active=True).all()
+
+ def client_to_dict(c):
+ return {
+ "id": c.id,
+ "name": c.name,
+ "client_id": c.client_id,
+ "redirect_uris": c.redirect_uris,
+ "scopes": c.scopes,
+ "grant_types": c.grant_types,
+ "is_active": c.is_active,
+ "created_at": c.created_at.isoformat() + "Z",
+ }
+
+ return api_response(
+ data={"clients": [client_to_dict(c) for c in clients], "count": len(clients)},
+ message="Clients retrieved successfully",
+ )
+
+
+@api_v1_bp.route("/organizations//clients", methods=["POST"])
+@login_required
+@require_admin
+def create_org_client(org_id):
+ """Create a new OIDC client for an organization.
+
+ Request body:
+ name: Client name
+ redirect_uris: List of allowed redirect URIs (newline or comma separated string)
+
+ Returns:
+ 201: Client created with client_id and client_secret
+ 403: Not an admin
+ 404: Organization not found
+ """
+ import secrets as _secrets
+ from gatehouse_app.extensions import bcrypt
+ from gatehouse_app.models import OIDCClient, Organization
+
+ org = Organization.query.filter_by(id=org_id, deleted_at=None).first()
+ if not org:
+ return api_response(success=False, message="Organization not found", status=404)
+
+ data = request.get_json() or {}
+ name = (data.get("name") or "").strip()
+ redirect_uris_raw = data.get("redirect_uris") or []
+
+ if not name:
+ return api_response(success=False, message="Client name is required", status=400, error_type="VALIDATION_ERROR")
+
+ if isinstance(redirect_uris_raw, str):
+ redirect_uris = [u.strip() for u in redirect_uris_raw.replace(",", "\n").splitlines() if u.strip()]
+ else:
+ redirect_uris = [u.strip() for u in redirect_uris_raw if isinstance(u, str) and u.strip()]
+
+ if not redirect_uris:
+ return api_response(success=False, message="At least one redirect URI is required", status=400, error_type="VALIDATION_ERROR")
+
+ client_id = _secrets.token_hex(16)
+ client_secret = _secrets.token_urlsafe(32)
+
+ client = OIDCClient(
+ organization_id=org_id,
+ name=name,
+ client_id=client_id,
+ client_secret_hash=bcrypt.generate_password_hash(client_secret).decode("utf-8"),
+ redirect_uris=redirect_uris,
+ grant_types=["authorization_code", "refresh_token"],
+ response_types=["code"],
+ scopes=["openid", "profile", "email"],
+ is_active=True,
+ is_confidential=True,
+ )
+ from gatehouse_app.extensions import db
+ db.session.add(client)
+ db.session.commit()
+
+ return api_response(
+ data={
+ "client": {
+ "id": client.id,
+ "name": client.name,
+ "client_id": client.client_id,
+ "client_secret": client_secret, # Only returned once
+ "redirect_uris": client.redirect_uris,
+ "scopes": client.scopes,
+ "created_at": client.created_at.isoformat() + "Z",
+ }
+ },
+ message="OIDC client created successfully",
+ status=201,
+ )
+
+
+@api_v1_bp.route("/organizations//clients/", methods=["DELETE"])
+@login_required
+@require_admin
+def delete_org_client(org_id, client_id):
+ """Deactivate an OIDC client.
+
+ Returns:
+ 200: Client deactivated
+ 403: Not an admin
+ 404: Client not found
+ """
+ from gatehouse_app.models import OIDCClient
+ from gatehouse_app.extensions import db
+
+ client = OIDCClient.query.filter_by(id=client_id, organization_id=org_id).first()
+ if not client:
+ return api_response(success=False, message="Client not found", status=404)
+
+ client.is_active = False
+ db.session.commit()
+
+ return api_response(data={}, message="Client deactivated successfully")
+
+
+@api_v1_bp.route("/organizations//members//send-mfa-reminder", methods=["POST"])
+@login_required
+@require_admin
+def send_mfa_reminder(org_id, user_id):
+ """Send an MFA reminder email to a specific member.
+
+ Returns:
+ 200: Reminder sent (or silently skipped if no deadline record)
+ 403: Not an admin
+ 404: Member not found
+ """
+ from gatehouse_app.models import User, MfaPolicyCompliance, OrganizationSecurityPolicy
+ from gatehouse_app.services.notification_service import NotificationService
+
+ user = User.query.filter_by(id=user_id, deleted_at=None).first()
+ if not user:
+ return api_response(success=False, message="User not found", status=404)
+
+ compliance = MfaPolicyCompliance.query.filter_by(
+ user_id=user_id, organization_id=org_id
+ ).first()
+ policy = OrganizationSecurityPolicy.query.filter_by(organization_id=org_id).first()
+
+ if compliance and policy and compliance.deadline_at:
+ NotificationService.send_mfa_deadline_reminder(user, compliance, policy)
+ else:
+ # No compliance deadline — send a generic nudge
+ NotificationService._send_email(
+ to_address=user.email,
+ subject="Reminder: Set up multi-factor authentication",
+ body=(
+ f"Hi {user.full_name or user.email},\n\n"
+ "Your organization administrator has asked you to set up "
+ "multi-factor authentication (MFA) on your Gatehouse account.\n\n"
+ "Please log in and configure MFA as soon as possible.\n\n"
+ "Gatehouse Security Team"
+ ),
+ )
+
+ return api_response(data={}, message="Reminder sent successfully")
+
+
+# =============================================================================
+# System-wide Audit Log (admin view) + User self audit
+# =============================================================================
+
+def _audit_log_to_dict(log):
+ """Serialize an AuditLog record to a dict."""
+ return {
+ "id": log.id,
+ "action": log.action.value if log.action else None,
+ "user_id": log.user_id,
+ "user": (
+ {"id": log.user.id, "email": log.user.email, "full_name": log.user.full_name}
+ if log.user else None
+ ),
+ "organization_id": log.organization_id,
+ "resource_type": log.resource_type,
+ "resource_id": log.resource_id,
+ "ip_address": log.ip_address,
+ "user_agent": log.user_agent,
+ "request_id": log.request_id,
+ "description": log.description,
+ "success": log.success,
+ "error_message": log.error_message,
+ "metadata": log.extra_data,
+ "created_at": log.created_at.isoformat() if log.created_at else None,
+ "updated_at": log.updated_at.isoformat() if log.updated_at else None,
+ }
+
+
+@api_v1_bp.route("/audit-logs", methods=["GET"])
+@login_required
+def get_system_audit_logs():
+ """
+ Get all audit logs (system-wide). Any authenticated user can query
+ their own logs; org owners/admins also see org-scoped logs; this
+ endpoint returns ALL logs for users who own at least one org
+ (acting as an admin view).
+
+ Query params:
+ page – page number (default 1)
+ per_page – results per page (default 50, max 200)
+ action – filter by AuditAction value
+ user_id – filter by user id
+ resource_type – filter by resource type
+ success – "true"/"false"
+ q – free-text search on description
+ """
+ from gatehouse_app.models.auth.audit_log import AuditLog
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
+
+ current_user = g.current_user
+ page = max(1, int(request.args.get("page", 1)))
+ per_page = min(int(request.args.get("per_page", 50)), 200)
+
+ # Check if the user is an admin or owner of any org to grant admin-level access
+ is_admin = OrganizationMember.query.filter(
+ OrganizationMember.user_id == current_user.id,
+ OrganizationMember.role.in_(["OWNER", "ADMIN"]),
+ OrganizationMember.deleted_at == None,
+ ).first() is not None
+
+ query = AuditLog.query
+
+ if not is_admin:
+ # Non-admins can only see their own logs
+ query = query.filter(AuditLog.user_id == current_user.id)
+
+ # Optional filters
+ action_filter = request.args.get("action")
+ if action_filter:
+ query = query.filter(AuditLog.action == action_filter)
+
+ user_id_filter = request.args.get("user_id")
+ if user_id_filter:
+ query = query.filter(AuditLog.user_id == user_id_filter)
+
+ resource_type_filter = request.args.get("resource_type")
+ if resource_type_filter:
+ query = query.filter(AuditLog.resource_type == resource_type_filter)
+
+ success_filter = request.args.get("success")
+ if success_filter is not None:
+ query = query.filter(AuditLog.success == (success_filter.lower() == "true"))
+
+ q = request.args.get("q", "").strip()
+ if q:
+ query = query.filter(AuditLog.description.ilike(f"%{q}%"))
+
+ query = query.order_by(AuditLog.created_at.desc())
+ total = query.count()
+ logs = query.offset((page - 1) * per_page).limit(per_page).all()
+
+ return api_response(
+ data={
+ "audit_logs": [_audit_log_to_dict(log) for log in logs],
+ "count": total,
+ "page": page,
+ "per_page": per_page,
+ "pages": (total + per_page - 1) // per_page,
+ "is_admin_view": is_admin,
+ },
+ message="Audit logs retrieved",
+ )
+
+
+@api_v1_bp.route("/auth/audit-logs", methods=["GET"])
+@login_required
+def get_my_audit_logs():
+ """
+ Get audit logs for the currently authenticated user only.
+
+ Query params:
+ page – page number (default 1)
+ per_page – results per page (default 50, max 200)
+ action – filter by AuditAction value
+ """
+ from gatehouse_app.models.auth.audit_log import AuditLog
+
+ current_user = g.current_user
+ page = max(1, int(request.args.get("page", 1)))
+ per_page = min(int(request.args.get("per_page", 50)), 200)
+
+ query = AuditLog.query.filter(AuditLog.user_id == current_user.id)
+
+ action_filter = request.args.get("action")
+ if action_filter:
+ query = query.filter(AuditLog.action == action_filter)
+
+ query = query.order_by(AuditLog.created_at.desc())
+ total = query.count()
+ logs = query.offset((page - 1) * per_page).limit(per_page).all()
+
+ return api_response(
+ data={
+ "audit_logs": [_audit_log_to_dict(log) for log in logs],
+ "count": total,
+ "page": page,
+ "per_page": per_page,
+ "pages": (total + per_page - 1) // per_page,
+ },
+ message="Activity retrieved",
+ )
+
+
+
+@api_v1_bp.route("/organizations//roles", methods=["GET"])
+@login_required
+def list_organization_roles(org_id):
+ """List the available roles for an organization.
+
+ Returns the canonical set of OrganizationRole values together with every
+ current member assigned to each role.
+
+ Returns:
+ 200: roles list with member counts
+ 401: Not authenticated
+ 404: Organization not found
+ """
+ from gatehouse_app.models.organization.organization import Organization
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
+
+ org = Organization.query.filter_by(id=org_id, deleted_at=None).first()
+ if not org:
+ return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND")
+
+ # Load all active members grouped by role
+ members = OrganizationMember.query.filter_by(organization_id=org_id, deleted_at=None).all()
+ by_role: dict = {r.value: [] for r in OrganizationRole}
+ for m in members:
+ role_key = m.role.value if hasattr(m.role, "value") else str(m.role)
+ if role_key in by_role:
+ by_role[role_key].append({
+ "user_id": m.user_id,
+ "email": m.user.email if m.user else None,
+ "full_name": m.user.full_name if m.user else None,
+ "joined_at": m.created_at.isoformat() if m.created_at else None,
+ })
+
+ roles = [
+ {
+ "role": r.value,
+ "member_count": len(by_role[r.value]),
+ "members": by_role[r.value],
+ }
+ for r in OrganizationRole
+ ]
+ return api_response(data={"roles": roles, "organization_id": org_id}, message="Roles retrieved")
+
+
+@api_v1_bp.route("/organizations//roles//members", methods=["POST"])
+@login_required
+@require_admin
+def assign_role_to_member(org_id, role_name):
+ """Assign a role to a user in the organization (admin/owner only).
+
+ This is a convenience endpoint equivalent to PATCH
+ /organizations//members//role but driven by role name.
+
+ Request body:
+ user_id – UUID of the member to assign
+
+ Returns:
+ 200: Role assigned
+ 400: Invalid role / missing user_id
+ 403: Not an admin/owner
+ 404: Org or member not found
+ """
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
+ from gatehouse_app.extensions import db
+
+ try:
+ new_role = OrganizationRole(role_name.lower())
+ except ValueError:
+ valid = [r.value for r in OrganizationRole]
+ return api_response(success=False, message=f"Invalid role. Must be one of: {valid}", status=400, error_type="VALIDATION_ERROR")
+
+ data = request.get_json() or {}
+ target_user_id = data.get("user_id")
+ if not target_user_id:
+ return api_response(success=False, message="user_id is required", status=400, error_type="VALIDATION_ERROR")
+
+ membership = OrganizationMember.query.filter_by(
+ organization_id=org_id, user_id=target_user_id, deleted_at=None
+ ).first()
+ if not membership:
+ return api_response(success=False, message="Member not found in this organization", status=404, error_type="NOT_FOUND")
+
+ membership.role = new_role
+ db.session.commit()
+ return api_response(
+ data={"user_id": target_user_id, "role": new_role.value},
+ message=f"Role updated to {new_role.value}",
+ )
+
+
+@api_v1_bp.route("/organizations//roles//members/", methods=["DELETE"])
+@login_required
+@require_admin
+def remove_role_from_member(org_id, role_name, user_id):
+ """Demote a member to GUEST (effectively removing a named role).
+
+ Removing a role downgrades the member to GUEST rather than removing them
+ from the organization entirely. Use the existing DELETE
+ /organizations//members/ endpoint to fully remove.
+
+ Returns:
+ 200: Role removed (member demoted to GUEST)
+ 400: Invalid role name
+ 403: Not an admin/owner
+ 404: Org or member not found
+ """
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
+ from gatehouse_app.extensions import db
+
+ try:
+ OrganizationRole(role_name.lower()) # validate the name
+ except ValueError:
+ valid = [r.value for r in OrganizationRole]
+ return api_response(success=False, message=f"Invalid role. Must be one of: {valid}", status=400, error_type="VALIDATION_ERROR")
+
+ membership = OrganizationMember.query.filter_by(
+ organization_id=org_id, user_id=user_id, deleted_at=None
+ ).first()
+ if not membership:
+ return api_response(success=False, message="Member not found in this organization", status=404, error_type="NOT_FOUND")
+
+ membership.role = OrganizationRole.GUEST
+ db.session.commit()
+ return api_response(
+ data={"user_id": user_id, "role": OrganizationRole.GUEST.value},
+ message="Role removed; member demoted to GUEST",
+ )
+
+
+@api_v1_bp.route("/organizations//cas", methods=["GET"])
+@login_required
+@require_admin
+def list_org_cas(org_id):
+ """List all Certificate Authorities for an organization.
+
+ If the system config-file CA is configured (via SSH_CA_PRIVATE_KEY env var
+ or ca_key_path in etc/ssh_ca.conf) and no DB CA exists for a given ca_type,
+ a synthetic read-only entry is injected so the UI correctly shows the
+ system CA as configured rather than "Not configured".
+
+ Returns:
+ 200: List of CAs (private_key excluded)
+ 403: Not admin/owner
+ 404: Org not found
+ """
+ from gatehouse_app.models.ssh_ca.ca import CA, CaType
+ from gatehouse_app.models.organization.organization import Organization
+
+ org = Organization.query.filter_by(id=org_id, deleted_at=None).first()
+ if not org:
+ return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND")
+
+ cas = CA.query.filter_by(organization_id=org_id, deleted_at=None).all()
+ ca_list = [ca.to_dict() for ca in cas]
+
+ # Determine which ca_types are already covered by a DB CA
+ covered_types = {ca.ca_type for ca in cas}
+
+ # Check whether a system config-file CA is available
+ system_ca_dict = _get_system_ca_dict()
+ if system_ca_dict:
+ # Inject a synthetic entry for each ca_type NOT covered by a real DB CA.
+ # The system CA only signs user certs (cert_type="user"), so we only
+ # inject it for the user slot. Host signing always needs a DB CA.
+ if CaType.USER not in covered_types:
+ ca_list.append({**system_ca_dict, "ca_type": "user"})
+
+ return api_response(
+ data={"cas": ca_list, "count": len(ca_list)},
+ message="CAs retrieved",
+ )
+
+
+@api_v1_bp.route("/organizations//cas/", methods=["PATCH"])
+@login_required
+@require_admin
+def update_org_ca(org_id, ca_id):
+ """Update CA configuration (validity hours).
+
+ Request body:
+ default_cert_validity_hours: Default validity in hours (optional)
+ max_cert_validity_hours: Maximum validity in hours (optional)
+
+ Returns:
+ 200: CA updated successfully
+ 400: Validation error
+ 403: Not admin/owner
+ 404: Org or CA not found
+ """
+ from gatehouse_app.models.ssh_ca.ca import CA
+ from gatehouse_app.models.organization.organization import Organization
+ from marshmallow import Schema, fields, validate, ValidationError
+
+ org = Organization.query.filter_by(id=org_id, deleted_at=None).first()
+ if not org:
+ return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND")
+
+ ca = CA.query.filter_by(id=ca_id, organization_id=org_id, deleted_at=None).first()
+ if not ca:
+ return api_response(success=False, message="CA not found", status=404, error_type="NOT_FOUND")
+
+ try:
+ class CAUpdateSchema(Schema):
+ default_cert_validity_hours = fields.Int(
+ validate=validate.Range(min=1),
+ required=False
+ )
+ max_cert_validity_hours = fields.Int(
+ validate=validate.Range(min=1),
+ required=False
+ )
+
+ schema = CAUpdateSchema()
+ data = schema.load(request.json or {})
+
+ # Validate that max >= default if both are provided
+ default_hours = data.get('default_cert_validity_hours', ca.default_cert_validity_hours)
+ max_hours = data.get('max_cert_validity_hours', ca.max_cert_validity_hours)
+
+ if default_hours > max_hours:
+ return api_response(
+ success=False,
+ message="Default validity must be less than or equal to maximum validity",
+ status=400,
+ error_type="VALIDATION_ERROR",
+ )
+
+ # Update fields
+ if 'default_cert_validity_hours' in data:
+ ca.default_cert_validity_hours = data['default_cert_validity_hours']
+ if 'max_cert_validity_hours' in data:
+ ca.max_cert_validity_hours = data['max_cert_validity_hours']
+
+ db.session.commit()
+
+ return api_response(
+ data={"ca": ca.to_dict()},
+ message="CA updated successfully",
+ )
+
+ except ValidationError as e:
+ return api_response(
+ success=False,
+ message="Validation failed",
+ status=400,
+ error_type="VALIDATION_ERROR",
+ error_details=e.messages,
+ )
+ except Exception as e:
+ db.session.rollback()
+ return api_response(
+ success=False,
+ message="Failed to update CA",
+ status=500,
+ error_type="SERVER_ERROR",
+ )
+
+
+@api_v1_bp.route("/organizations//cas", methods=["POST"])
+@login_required
+@require_admin
+def create_org_ca(org_id):
+ """Create a new Certificate Authority for an organization.
+
+ Request body:
+ name: CA display name (required)
+ description: Optional description
+ key_type: "ed25519" (default), "rsa", or "ecdsa"
+ default_cert_validity_hours: Default cert validity in hours (optional)
+ max_cert_validity_hours: Max cert validity in hours (optional)
+
+ Returns:
+ 201: CA created successfully
+ 400: Validation error or name already taken
+ 403: Not admin/owner
+ 404: Org not found
+ """
+ from gatehouse_app.models.ssh_ca.ca import CA, KeyType
+ from gatehouse_app.models.organization.organization import Organization
+ from gatehouse_app.utils.crypto import compute_ssh_fingerprint
+ from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key
+ from marshmallow import Schema, fields as ma_fields, validate, ValidationError as MaValidationError
+ from sshkey_tools.keys import Ed25519PrivateKey, RsaPrivateKey, EcdsaPrivateKey
+
+ org = Organization.query.filter_by(id=org_id, deleted_at=None).first()
+ if not org:
+ return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND")
+
+ class CreateCASchema(Schema):
+ name = ma_fields.Str(required=True, validate=validate.Length(min=1, max=255))
+ description = ma_fields.Str(load_default=None, allow_none=True)
+ ca_type = ma_fields.Str(load_default="user", validate=validate.OneOf(["user", "host"]))
+ key_type = ma_fields.Str(load_default="ed25519", validate=validate.OneOf(["ed25519", "rsa", "ecdsa"]))
+ default_cert_validity_hours = ma_fields.Int(load_default=8, validate=validate.Range(min=1))
+ max_cert_validity_hours = ma_fields.Int(load_default=720, validate=validate.Range(min=1))
+
+ try:
+ schema = CreateCASchema()
+ data = schema.load(request.get_json() or {})
+
+ # Check name uniqueness within org
+ existing = CA.query.filter_by(
+ organization_id=org_id, name=data["name"], deleted_at=None
+ ).first()
+ if existing:
+ return api_response(
+ success=False,
+ message="A CA with that name already exists in this organization",
+ status=400,
+ error_type="DUPLICATE_NAME",
+ )
+
+ # Enforce one CA per type per org
+ from gatehouse_app.models.ssh_ca.ca import CaType
+ ca_type_val = data["ca_type"]
+ existing_type = CA.query.filter_by(
+ organization_id=org_id, deleted_at=None
+ ).filter(CA.ca_type == CaType(ca_type_val)).first()
+ if existing_type:
+ type_label = "User" if ca_type_val == "user" else "Host"
+ return api_response(
+ success=False,
+ message=f"A {type_label} CA already exists for this organization. "
+ f"You can only have one {type_label} CA per organization.",
+ status=400,
+ error_type="DUPLICATE_CA_TYPE",
+ )
+
+ # Validate cross-field
+ if data["default_cert_validity_hours"] > data["max_cert_validity_hours"]:
+ return api_response(
+ success=False,
+ message="Default validity must be less than or equal to maximum validity",
+ status=400,
+ error_type="VALIDATION_ERROR",
+ )
+
+ # Generate key pair
+ key_type = data["key_type"]
+ if key_type == "ed25519":
+ private_key_obj = Ed25519PrivateKey.generate()
+ elif key_type == "rsa":
+ private_key_obj = RsaPrivateKey.generate(4096)
+ else: # ecdsa
+ private_key_obj = EcdsaPrivateKey.generate()
+
+ private_key_pem = private_key_obj.to_string()
+ public_key_str = private_key_obj.public_key.to_string()
+ fingerprint = compute_ssh_fingerprint(public_key_str)
+
+ # Encrypt the private key before storing in the database
+ encrypted_private_key = encrypt_ca_key(private_key_pem)
+
+ ca = CA(
+ organization_id=org_id,
+ name=data["name"],
+ description=data["description"],
+ ca_type=CaType(ca_type_val),
+ key_type=KeyType(key_type),
+ private_key=encrypted_private_key,
+ public_key=public_key_str,
+ fingerprint=fingerprint,
+ default_cert_validity_hours=data["default_cert_validity_hours"],
+ max_cert_validity_hours=data["max_cert_validity_hours"],
+ is_active=True,
+ )
+ db.session.add(ca)
+ try:
+ db.session.commit()
+ except Exception as commit_exc:
+ db.session.rollback()
+ # Surface unique-constraint violations (soft-deleted record with same name) as a
+ # user-friendly 400 instead of a 500.
+ exc_str = str(commit_exc).lower()
+ if "uix_org_ca_name" in exc_str or "unique" in exc_str:
+ return api_response(
+ success=False,
+ message=(
+ "A CA with that name already exists in this organization "
+ "(it may have been recently deleted — choose a different name)."
+ ),
+ status=400,
+ error_type="DUPLICATE_NAME",
+ )
+ raise
+
+ return api_response(
+ data={"ca": ca.to_dict()},
+ message="CA created successfully",
+ status=201,
+ )
+
+ except MaValidationError as e:
+ return api_response(
+ success=False,
+ message="Validation failed",
+ status=400,
+ error_type="VALIDATION_ERROR",
+ error_details=e.messages,
+ )
+ except Exception as e:
+ db.session.rollback()
+ current_app.logger.exception("Failed to create CA")
+ return api_response(
+ success=False,
+ message="Failed to create CA",
+ status=500,
+ error_type="SERVER_ERROR",
+ )
+
+
+@api_v1_bp.route("/organizations//cas/", methods=["DELETE"])
+@login_required
+@require_admin
+def delete_org_ca(org_id, ca_id):
+ """Soft-delete a Certificate Authority.
+
+ Deactivates the CA so no new certificates can be signed with it.
+ Existing certificates remain valid until they expire.
+
+ Returns:
+ 200: CA deleted successfully
+ 403: Not admin/owner
+ 404: Org or CA not found
+ """
+ from gatehouse_app.models.ssh_ca.ca import CA
+ from gatehouse_app.models.organization.organization import Organization
+ from gatehouse_app.utils.constants import AuditAction
+ from gatehouse_app.models import AuditLog
+
+ org = Organization.query.filter_by(id=org_id, deleted_at=None).first()
+ if not org:
+ return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND")
+
+ ca = CA.query.filter_by(id=ca_id, organization_id=org_id, deleted_at=None).first()
+ if not ca:
+ return api_response(success=False, message="CA not found", status=404, error_type="NOT_FOUND")
+
+ try:
+ ca_name = ca.name
+ ca_type = ca.ca_type.value if hasattr(ca.ca_type, "value") else str(ca.ca_type)
+ ca.is_active = False
+ ca.delete(soft=True)
+
+ AuditLog.log(
+ action=AuditAction.CA_DELETED,
+ user_id=g.current_user.id,
+ resource_type="CA",
+ resource_id=ca_id,
+ organization_id=org_id,
+ ip_address=request.remote_addr,
+ description=f"CA '{ca_name}' ({ca_type}) deleted",
+ )
+
+ return api_response(
+ data={"ca_id": ca_id},
+ message="CA deleted successfully",
+ )
+ except Exception as e:
+ db.session.rollback()
+ current_app.logger.exception("Failed to delete CA")
+ return api_response(
+ success=False,
+ message="Failed to delete CA",
+ status=500,
+ error_type="SERVER_ERROR",
+ )
+
+
+@api_v1_bp.route("/organizations//cas//rotate", methods=["POST"])
+@login_required
+@require_admin
+def rotate_org_ca(org_id, ca_id):
+ """Rotate (replace) a CA's key pair.
+
+ Generates a new key pair of the same or different type. The old public key
+ fingerprint is returned so admins can update TrustedUserCAKeys / known_hosts
+ on their servers. All previously-issued certificates remain valid until they
+ expire but no new certificates will be signed with the old key.
+
+ Request body (all optional):
+ key_type: "ed25519" (default keeps current), "rsa", or "ecdsa"
+ reason: Human-readable reason for the rotation
+
+ Returns:
+ 200: CA rotated — { ca, old_fingerprint }
+ 403: Not admin/owner
+ 404: Org or CA not found
+ """
+ from gatehouse_app.models.ssh_ca.ca import CA, KeyType
+ from gatehouse_app.models.organization.organization import Organization
+ from gatehouse_app.utils.crypto import compute_ssh_fingerprint
+ from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key
+ from gatehouse_app.utils.constants import AuditAction
+ from gatehouse_app.models import AuditLog
+ from sshkey_tools.keys import Ed25519PrivateKey, RsaPrivateKey, EcdsaPrivateKey
+
+ org = Organization.query.filter_by(id=org_id, deleted_at=None).first()
+ if not org:
+ return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND")
+
+ ca = CA.query.filter_by(id=ca_id, organization_id=org_id, deleted_at=None).first()
+ if not ca:
+ return api_response(success=False, message="CA not found", status=404, error_type="NOT_FOUND")
+
+ data = request.get_json() or {}
+ new_key_type = data.get("key_type") or (ca.key_type.value if hasattr(ca.key_type, "value") else str(ca.key_type))
+ reason = data.get("reason", "Admin-initiated key rotation")
+
+ if new_key_type not in ("ed25519", "rsa", "ecdsa"):
+ return api_response(
+ success=False,
+ message="Invalid key_type. Must be one of: ed25519, rsa, ecdsa",
+ status=400,
+ error_type="VALIDATION_ERROR",
+ )
+
+ try:
+ old_fingerprint = ca.fingerprint
+
+ # Generate new key pair
+ if new_key_type == "ed25519":
+ private_key_obj = Ed25519PrivateKey.generate()
+ elif new_key_type == "rsa":
+ private_key_obj = RsaPrivateKey.generate(4096)
+ else: # ecdsa
+ private_key_obj = EcdsaPrivateKey.generate()
+
+ new_private_key = private_key_obj.to_string()
+ new_public_key = private_key_obj.public_key.to_string()
+ new_fingerprint = compute_ssh_fingerprint(new_public_key)
+
+ # Encrypt the new private key before storing
+ encrypted_new_private_key = encrypt_ca_key(new_private_key)
+
+ ca.rotate_key(
+ new_private_key=encrypted_new_private_key,
+ new_public_key=new_public_key,
+ new_fingerprint=new_fingerprint,
+ reason=reason,
+ )
+ ca.key_type = KeyType(new_key_type)
+ db.session.commit()
+
+ AuditLog.log(
+ action=AuditAction.CA_KEY_ROTATED,
+ user_id=g.current_user.id,
+ resource_type="CA",
+ resource_id=ca_id,
+ organization_id=org_id,
+ ip_address=request.remote_addr,
+ description=(
+ f"CA '{ca.name}' key rotated. "
+ f"Old fingerprint: {old_fingerprint}, New fingerprint: {new_fingerprint}. "
+ f"Reason: {reason}"
+ ),
+ )
+
+ return api_response(
+ data={
+ "ca": ca.to_dict(),
+ "old_fingerprint": old_fingerprint,
+ },
+ message="CA key rotated successfully. Update TrustedUserCAKeys / known_hosts on your servers.",
+ )
+ except Exception as e:
+ db.session.rollback()
+ current_app.logger.exception("Failed to rotate CA key")
+ return api_response(
+ success=False,
+ message="Failed to rotate CA key",
+ status=500,
+ error_type="SERVER_ERROR",
+ )
+
diff --git a/gatehouse_app/api/v1/policies.py b/gatehouse_app/api/v1/policies.py
index ddf16f2..c97465f 100644
--- a/gatehouse_app/api/v1/policies.py
+++ b/gatehouse_app/api/v1/policies.py
@@ -195,20 +195,46 @@ def get_org_mfa_compliance(org_id):
limit = min(int(request.args.get("limit", 100)), 100)
offset = int(request.args.get("offset", 0))
+ page = int(request.args.get("page", 1))
+ page_size = min(int(request.args.get("page_size", limit)), 100)
+
+ effective_offset = offset if request.args.get("offset") else (page - 1) * page_size
compliance_list = MfaPolicyService.get_org_compliance_list(
organization_id=org_id,
status=status,
- limit=limit,
- offset=offset,
+ limit=page_size,
+ offset=effective_offset,
)
+ def format_member(c):
+ """Normalize compliance record to UI-expected shape."""
+ if isinstance(c, dict):
+ return {
+ "user_id": c.get("user_id"),
+ "user_email": c.get("email"),
+ "user_name": c.get("full_name"),
+ "status": c.get("status"),
+ "deadline_at": c.get("deadline_at"),
+ "compliant_at": c.get("compliant_at"),
+ "last_notified_at": c.get("notified_at"),
+ }
+ return {
+ "user_id": getattr(c, "user_id", None),
+ "user_email": getattr(c, "email", None),
+ "user_name": getattr(c, "full_name", None),
+ "status": getattr(c, "status", None),
+ "deadline_at": getattr(c, "deadline_at", None),
+ "compliant_at": getattr(c, "compliant_at", None),
+ "last_notified_at": getattr(c, "notified_at", None),
+ }
+
return api_response(
data={
- "compliance": compliance_list,
+ "members": [format_member(c) for c in compliance_list],
"count": len(compliance_list),
- "limit": limit,
- "offset": offset,
+ "page": page,
+ "page_size": page_size,
},
message="Compliance records retrieved successfully",
)
@@ -325,12 +351,10 @@ def get_my_mfa_compliance():
return api_response(
data={
- "mfa_compliance": {
- "overall_status": compliance_summary.overall_status,
- "missing_methods": compliance_summary.missing_methods,
- "deadline_at": compliance_summary.deadline_at,
- "orgs": orgs,
- }
+ "overall_status": compliance_summary.overall_status,
+ "missing_methods": compliance_summary.missing_methods,
+ "deadline_at": compliance_summary.deadline_at,
+ "orgs": orgs,
},
message="MFA compliance retrieved successfully",
)
\ No newline at end of file
diff --git a/gatehouse_app/api/v1/principals.py b/gatehouse_app/api/v1/principals.py
new file mode 100644
index 0000000..0da8315
--- /dev/null
+++ b/gatehouse_app/api/v1/principals.py
@@ -0,0 +1,779 @@
+"""Principal endpoints."""
+from flask import g, request
+from marshmallow import Schema, fields, validate, ValidationError
+
+from gatehouse_app.api.v1 import api_v1_bp
+from gatehouse_app.utils.response import api_response
+from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required
+from gatehouse_app.models import Principal, PrincipalMembership, Department, DepartmentPrincipal
+from gatehouse_app.services.organization_service import OrganizationService
+from gatehouse_app.services.user_service import UserService
+from gatehouse_app.exceptions import OrganizationNotFoundError
+from gatehouse_app.extensions import db
+
+
+class PrincipalCreateSchema(Schema):
+ """Schema for creating a principal."""
+ name = fields.Str(required=True, validate=validate.Length(min=1, max=255))
+ description = fields.Str(allow_none=True, validate=validate.Length(max=2000))
+
+
+class PrincipalUpdateSchema(Schema):
+ """Schema for updating a principal."""
+ name = fields.Str(validate=validate.Length(min=1, max=255))
+ description = fields.Str(allow_none=True, validate=validate.Length(max=2000))
+
+
+class AddPrincipalMemberSchema(Schema):
+ """Schema for adding a member to a principal."""
+ email = fields.Email(required=True)
+
+
+class LinkPrincipalSchema(Schema):
+ """Schema for linking principal to department."""
+ department_id = fields.Str(required=True)
+
+
+@api_v1_bp.route("/organizations//principals", methods=["GET"])
+@login_required
+@full_access_required
+def list_principals(org_id):
+ """
+ List all principals in an organization.
+
+ Args:
+ org_id: Organization ID
+
+ Returns:
+ 200: List of principals
+ 401: Not authenticated
+ 403: Not a member
+ 404: Organization not found
+ """
+ org = OrganizationService.get_organization_by_id(org_id)
+
+ if not org.is_member(g.current_user.id):
+ return api_response(
+ success=False,
+ message="You are not a member of this organization",
+ status=403,
+ error_type="AUTHORIZATION_ERROR",
+ )
+
+ principals = Principal.query.filter_by(
+ organization_id=org_id,
+ deleted_at=None
+ ).all()
+
+ return api_response(
+ data={
+ "principals": [p.to_dict() for p in principals],
+ "count": len(principals),
+ },
+ message="Principals retrieved successfully",
+ )
+
+
+@api_v1_bp.route("/organizations//principals", methods=["POST"])
+@login_required
+@require_admin
+@full_access_required
+def create_principal(org_id):
+ """
+ Create a new principal.
+
+ Args:
+ org_id: Organization ID
+
+ Request body:
+ name: Principal name (required)
+ description: Optional description
+
+ Returns:
+ 201: Principal created successfully
+ 400: Validation error
+ 401: Not authenticated
+ 403: Not an admin
+ 404: Organization not found
+ 409: Principal name already exists
+ """
+ try:
+ org = OrganizationService.get_organization_by_id(org_id)
+
+ schema = PrincipalCreateSchema()
+ data = schema.load(request.json or {})
+
+ # Check if principal name already exists
+ existing = Principal.query.filter_by(
+ organization_id=org_id,
+ name=data["name"],
+ deleted_at=None
+ ).first()
+
+ if existing:
+ return api_response(
+ success=False,
+ message=f"Principal '{data['name']}' already exists",
+ status=409,
+ error_type="CONFLICT",
+ )
+
+ # Create principal
+ principal = Principal(
+ organization_id=org_id,
+ name=data["name"],
+ description=data.get("description"),
+ )
+ db.session.add(principal)
+ db.session.commit()
+
+ return api_response(
+ data={"principal": principal.to_dict()},
+ message="Principal created successfully",
+ status=201,
+ )
+
+ except ValidationError as e:
+ return api_response(
+ success=False,
+ message="Validation failed",
+ status=400,
+ error_type="VALIDATION_ERROR",
+ error_details=e.messages,
+ )
+
+
+@api_v1_bp.route("/organizations//principals/", methods=["GET"])
+@login_required
+@full_access_required
+def get_principal(org_id, principal_id):
+ """
+ Get a specific principal.
+
+ Args:
+ org_id: Organization ID
+ principal_id: Principal ID
+
+ Returns:
+ 200: Principal data
+ 401: Not authenticated
+ 403: Not a member
+ 404: Organization or principal not found
+ """
+ org = OrganizationService.get_organization_by_id(org_id)
+
+ if not org.is_member(g.current_user.id):
+ return api_response(
+ success=False,
+ message="You are not a member of this organization",
+ status=403,
+ error_type="AUTHORIZATION_ERROR",
+ )
+
+ principal = Principal.query.filter_by(
+ id=principal_id,
+ organization_id=org_id,
+ deleted_at=None
+ ).first()
+
+ if not principal:
+ return api_response(
+ success=False,
+ message="Principal not found",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ return api_response(
+ data={"principal": principal.to_dict()},
+ message="Principal retrieved successfully",
+ )
+
+
+@api_v1_bp.route("/organizations//principals/", methods=["PATCH"])
+@login_required
+@require_admin
+@full_access_required
+def update_principal(org_id, principal_id):
+ """
+ Update a principal.
+
+ Args:
+ org_id: Organization ID
+ principal_id: Principal ID
+
+ Request body:
+ name: Optional new name
+ description: Optional new description
+
+ Returns:
+ 200: Principal updated successfully
+ 400: Validation error
+ 401: Not authenticated
+ 403: Not an admin
+ 404: Organization or principal not found
+ 409: Name already exists
+ """
+ try:
+ org = OrganizationService.get_organization_by_id(org_id)
+
+ principal = Principal.query.filter_by(
+ id=principal_id,
+ organization_id=org_id,
+ deleted_at=None
+ ).first()
+
+ if not principal:
+ return api_response(
+ success=False,
+ message="Principal not found",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ schema = PrincipalUpdateSchema()
+ data = schema.load(request.json or {})
+
+ # Check if new name already exists
+ if "name" in data and data["name"] != principal.name:
+ existing = Principal.query.filter_by(
+ organization_id=org_id,
+ name=data["name"],
+ deleted_at=None
+ ).first()
+ if existing:
+ return api_response(
+ success=False,
+ message=f"Principal '{data['name']}' already exists",
+ status=409,
+ error_type="CONFLICT",
+ )
+
+ # Update fields
+ for key, value in data.items():
+ setattr(principal, key, value)
+
+ db.session.commit()
+
+ return api_response(
+ data={"principal": principal.to_dict()},
+ message="Principal updated successfully",
+ )
+
+ except ValidationError as e:
+ return api_response(
+ success=False,
+ message="Validation failed",
+ status=400,
+ error_type="VALIDATION_ERROR",
+ error_details=e.messages,
+ )
+
+
+@api_v1_bp.route("/organizations//principals/", methods=["DELETE"])
+@login_required
+@require_admin
+@full_access_required
+def delete_principal(org_id, principal_id):
+ """
+ Delete a principal (soft delete).
+
+ Args:
+ org_id: Organization ID
+ principal_id: Principal ID
+
+ Returns:
+ 200: Principal deleted successfully
+ 401: Not authenticated
+ 403: Not an admin
+ 404: Organization or principal not found
+ """
+ org = OrganizationService.get_organization_by_id(org_id)
+
+ principal = Principal.query.filter_by(
+ id=principal_id,
+ organization_id=org_id,
+ deleted_at=None
+ ).first()
+
+ if not principal:
+ return api_response(
+ success=False,
+ message="Principal not found",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ # Soft delete
+ principal.deleted_at = db.func.now()
+ db.session.commit()
+
+ return api_response(
+ message="Principal deleted successfully",
+ )
+
+
+@api_v1_bp.route("/organizations//principals//members", methods=["GET"])
+@login_required
+@full_access_required
+def get_principal_members(org_id, principal_id):
+ """
+ Get all members (direct + via department) with access to a principal.
+
+ Args:
+ org_id: Organization ID
+ principal_id: Principal ID
+
+ Returns:
+ 200: List of members
+ 401: Not authenticated
+ 403: Not a member
+ 404: Organization or principal not found
+ """
+ org = OrganizationService.get_organization_by_id(org_id)
+
+ if not org.is_member(g.current_user.id):
+ return api_response(
+ success=False,
+ message="You are not a member of this organization",
+ status=403,
+ error_type="AUTHORIZATION_ERROR",
+ )
+
+ principal = Principal.query.filter_by(
+ id=principal_id,
+ organization_id=org_id,
+ deleted_at=None
+ ).first()
+
+ if not principal:
+ return api_response(
+ success=False,
+ message="Principal not found",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ # Get direct members
+ direct_members = PrincipalMembership.query.filter_by(
+ principal_id=principal_id,
+ deleted_at=None
+ ).all()
+
+ all_users = set()
+ for membership in direct_members:
+ if membership.user.deleted_at is None:
+ all_users.add(membership.user)
+
+ # Get members via departments
+ dept_links = DepartmentPrincipal.query.filter_by(
+ principal_id=principal_id,
+ deleted_at=None
+ ).all()
+
+ for link in dept_links:
+ dept = link.department
+ if dept.deleted_at is None:
+ dept_members = dept.get_members(active_only=True)
+ for dept_member in dept_members:
+ if dept_member.user.deleted_at is None:
+ all_users.add(dept_member.user)
+
+ users_data = [u.to_dict() for u in all_users]
+
+ return api_response(
+ data={
+ "members": users_data,
+ "count": len(users_data),
+ },
+ message="Members retrieved successfully",
+ )
+
+
+@api_v1_bp.route("/organizations//principals//members", methods=["POST"])
+@login_required
+@require_admin
+@full_access_required
+def add_principal_member(org_id, principal_id):
+ """
+ Add a direct member to a principal.
+
+ Args:
+ org_id: Organization ID
+ principal_id: Principal ID
+
+ Request body:
+ email: User email to add
+
+ Returns:
+ 201: Member added successfully
+ 400: Validation error
+ 401: Not authenticated
+ 403: Not an admin
+ 404: Organization, principal, or user not found
+ 409: User already a member
+ """
+ try:
+ org = OrganizationService.get_organization_by_id(org_id)
+
+ principal = Principal.query.filter_by(
+ id=principal_id,
+ organization_id=org_id,
+ deleted_at=None
+ ).first()
+
+ if not principal:
+ return api_response(
+ success=False,
+ message="Principal not found",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ schema = AddPrincipalMemberSchema()
+ data = schema.load(request.json or {})
+
+ # Find user by email
+ user = UserService.get_user_by_email(data["email"])
+ if not user:
+ return api_response(
+ success=False,
+ message="User not found",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ # Check if already a member
+ existing = PrincipalMembership.query.filter_by(
+ user_id=user.id,
+ principal_id=principal_id,
+ deleted_at=None
+ ).first()
+
+ if existing:
+ return api_response(
+ success=False,
+ message="User is already a member of this principal",
+ status=409,
+ error_type="CONFLICT",
+ )
+
+ soft_deleted = PrincipalMembership.query.filter(
+ PrincipalMembership.user_id == user.id,
+ PrincipalMembership.principal_id == principal_id,
+ PrincipalMembership.deleted_at.isnot(None)
+ ).first()
+
+ if soft_deleted:
+ soft_deleted.deleted_at = None
+ membership = soft_deleted
+ else:
+ membership = PrincipalMembership(
+ user_id=user.id,
+ principal_id=principal_id,
+ )
+ db.session.add(membership)
+
+ db.session.commit()
+
+ member_dict = membership.to_dict()
+ member_dict["user"] = user.to_dict()
+
+ return api_response(
+ data={"member": member_dict},
+ message="Member added successfully",
+ status=201,
+ )
+
+ except ValidationError as e:
+ return api_response(
+ success=False,
+ message="Validation failed",
+ status=400,
+ error_type="VALIDATION_ERROR",
+ error_details=e.messages,
+ )
+
+
+@api_v1_bp.route("/organizations//principals//members/", methods=["DELETE"])
+@login_required
+@require_admin
+@full_access_required
+def remove_principal_member(org_id, principal_id, user_id):
+ """
+ Remove a direct member from a principal.
+
+ Args:
+ org_id: Organization ID
+ principal_id: Principal ID
+ user_id: User ID to remove
+
+ Returns:
+ 200: Member removed successfully
+ 401: Not authenticated
+ 403: Not an admin
+ 404: Organization, principal, or member not found
+ """
+ org = OrganizationService.get_organization_by_id(org_id)
+
+ principal = Principal.query.filter_by(
+ id=principal_id,
+ organization_id=org_id,
+ deleted_at=None
+ ).first()
+
+ if not principal:
+ return api_response(
+ success=False,
+ message="Principal not found",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ membership = PrincipalMembership.query.filter_by(
+ user_id=user_id,
+ principal_id=principal_id,
+ deleted_at=None
+ ).first()
+
+ if not membership:
+ return api_response(
+ success=False,
+ message="User is not a member of this principal",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ # Soft delete
+ membership.deleted_at = db.func.now()
+ db.session.commit()
+
+ return api_response(
+ message="Member removed successfully",
+ )
+
+
+@api_v1_bp.route("/organizations//principals//departments", methods=["GET"])
+@login_required
+@full_access_required
+def get_principal_departments(org_id, principal_id):
+ """
+ Get all departments this principal is assigned to.
+
+ Args:
+ org_id: Organization ID
+ principal_id: Principal ID
+
+ Returns:
+ 200: List of departments
+ 401: Not authenticated
+ 403: Not a member
+ 404: Organization or principal not found
+ """
+ org = OrganizationService.get_organization_by_id(org_id)
+
+ if not org.is_member(g.current_user.id):
+ return api_response(
+ success=False,
+ message="You are not a member of this organization",
+ status=403,
+ error_type="AUTHORIZATION_ERROR",
+ )
+
+ principal = Principal.query.filter_by(
+ id=principal_id,
+ organization_id=org_id,
+ deleted_at=None
+ ).first()
+
+ if not principal:
+ return api_response(
+ success=False,
+ message="Principal not found",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ depts = principal.get_departments(active_only=True)
+
+ return api_response(
+ data={
+ "departments": [d.to_dict() for d in depts],
+ "count": len(depts),
+ },
+ message="Departments retrieved successfully",
+ )
+
+
+@api_v1_bp.route("/organizations//principals//departments/", methods=["POST"])
+@login_required
+@require_admin
+@full_access_required
+def link_principal_to_department(org_id, principal_id, dept_id):
+ """
+ Link a principal to a department.
+
+ Args:
+ org_id: Organization ID
+ principal_id: Principal ID
+ dept_id: Department ID
+
+ Returns:
+ 201: Principal linked successfully
+ 401: Not authenticated
+ 403: Not an admin
+ 404: Organization, principal, or department not found
+ 409: Already linked
+ """
+ try:
+ org = OrganizationService.get_organization_by_id(org_id)
+ except OrganizationNotFoundError:
+ return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND")
+
+ principal = Principal.query.filter_by(
+ id=principal_id,
+ organization_id=org_id,
+ deleted_at=None
+ ).first()
+
+ if not principal:
+ return api_response(
+ success=False,
+ message="Principal not found",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ dept = Department.query.filter_by(
+ id=dept_id,
+ organization_id=org_id,
+ deleted_at=None
+ ).first()
+
+ if not dept:
+ return api_response(
+ success=False,
+ message="Department not found",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ existing = DepartmentPrincipal.query.filter_by(
+ department_id=dept_id,
+ principal_id=principal_id,
+ deleted_at=None
+ ).first()
+
+ if existing:
+ return api_response(
+ success=False,
+ message="Principal is already linked to this department",
+ status=409,
+ error_type="CONFLICT",
+ )
+
+ soft_deleted = DepartmentPrincipal.query.filter(
+ DepartmentPrincipal.department_id == dept_id,
+ DepartmentPrincipal.principal_id == principal_id,
+ DepartmentPrincipal.deleted_at.isnot(None),
+ ).first()
+
+ try:
+ if soft_deleted:
+ soft_deleted.deleted_at = None
+ else:
+ link = DepartmentPrincipal(
+ department_id=dept_id,
+ principal_id=principal_id,
+ )
+ db.session.add(link)
+ db.session.commit()
+ except Exception:
+ db.session.rollback()
+ return api_response(
+ success=False,
+ message="Failed to link principal to department",
+ status=500,
+ error_type="SERVER_ERROR",
+ )
+
+ return api_response(
+ data={
+ "principal": principal.to_dict(),
+ "department": dept.to_dict(),
+ },
+ message="Principal linked to department successfully",
+ status=201,
+ )
+
+
+@api_v1_bp.route("/organizations//principals//departments/", methods=["DELETE"])
+@login_required
+@require_admin
+@full_access_required
+def unlink_principal_from_department(org_id, principal_id, dept_id):
+ """
+ Unlink a principal from a department.
+
+ Args:
+ org_id: Organization ID
+ principal_id: Principal ID
+ dept_id: Department ID
+
+ Returns:
+ 200: Principal unlinked successfully
+ 401: Not authenticated
+ 403: Not an admin
+ 404: Organization, principal, department, or link not found
+ """
+ org = OrganizationService.get_organization_by_id(org_id)
+
+ principal = Principal.query.filter_by(
+ id=principal_id,
+ organization_id=org_id,
+ deleted_at=None
+ ).first()
+
+ if not principal:
+ return api_response(
+ success=False,
+ message="Principal not found",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ dept = Department.query.filter_by(
+ id=dept_id,
+ organization_id=org_id,
+ deleted_at=None
+ ).first()
+
+ if not dept:
+ return api_response(
+ success=False,
+ message="Department not found",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ link = DepartmentPrincipal.query.filter_by(
+ department_id=dept_id,
+ principal_id=principal_id,
+ deleted_at=None
+ ).first()
+
+ if not link:
+ return api_response(
+ success=False,
+ message="Principal is not linked to this department",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ # Soft delete
+ link.deleted_at = db.func.now()
+ db.session.commit()
+
+ return api_response(
+ message="Principal unlinked from department successfully",
+ )
diff --git a/gatehouse_app/api/v1/ssh.py b/gatehouse_app/api/v1/ssh.py
new file mode 100644
index 0000000..8a54ae5
--- /dev/null
+++ b/gatehouse_app/api/v1/ssh.py
@@ -0,0 +1,1108 @@
+"""SSH Key and Certificate API routes."""
+from flask import Blueprint, request, jsonify, g
+from sqlalchemy.exc import IntegrityError
+from gatehouse_app.services.ssh_key_service import SSHKeyService
+from gatehouse_app.services.ssh_ca_signing_service import (
+ SSHCASigningService,
+ SSHCertificateSigningRequest,
+)
+from gatehouse_app.exceptions import (
+ SSHKeyError,
+ SSHKeyNotFoundError,
+ SSHCertificateError,
+ ValidationError,
+ SSHKeyAlreadyExistsError,
+)
+from gatehouse_app.utils.constants import AuditAction
+from gatehouse_app.models import AuditLog
+from gatehouse_app.models.ssh_ca.certificate_audit_log import CertificateAuditLog
+from gatehouse_app.utils.decorators import login_required
+from gatehouse_app.utils.response import api_response
+
+ssh_bp = Blueprint('ssh', __name__, url_prefix='/ssh')
+ssh_key_service = SSHKeyService()
+ssh_ca_service = SSHCASigningService()
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _get_org_ca_for_user(user, ca_type: str = "user"):
+ """Return the active DB CA of the given type for the user's first org, or None.
+
+ Args:
+ user: The current user object.
+ ca_type: ``"user"`` (default) or ``"host"`` — selects the CA that signs
+ the corresponding certificate type.
+ """
+ try:
+ from gatehouse_app.models.ssh_ca.ca import CA, CaType
+ org_ids = [m.organization_id for m in user.organization_memberships]
+ if not org_ids:
+ return None
+ return CA.query.filter(
+ CA.organization_id.in_(org_ids),
+ CA.ca_type == CaType(ca_type),
+ CA.is_active == True, # noqa: E712
+ ).first()
+ except Exception:
+ return None
+
+
+def _get_or_create_system_ca():
+ """
+ Return a CA DB record representing the config-file CA.
+
+ This is used as the ``ca_id`` FK when persisting certificates that were
+ signed by the globally-configured CA key (not an org-specific DB CA).
+ The record is created on first use and has no ``organization_id``.
+ """
+ from gatehouse_app.extensions import db
+ from gatehouse_app.models.ssh_ca.ca import CA, KeyType
+ from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config
+ from gatehouse_app.utils.crypto import compute_ssh_fingerprint
+ import os
+
+ try:
+ existing = CA.query.filter_by(name="system-config-ca").first()
+ if existing:
+ return existing
+
+ cfg = get_ssh_ca_config()
+ key_path = cfg.get_str("ca_key_path", "").strip()
+ pub_key_path = key_path + ".pub"
+
+ if not os.path.exists(pub_key_path):
+ return None
+
+ with open(pub_key_path) as f:
+ pub_key = f.read().strip()
+
+ # Load private key for the record (encrypt before storing in DB)
+ priv_key = ""
+ if os.path.exists(key_path):
+ with open(key_path) as f:
+ raw_priv_key = f.read()
+ try:
+ from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key
+ priv_key = encrypt_ca_key(raw_priv_key)
+ except Exception:
+ priv_key = raw_priv_key # fallback: store as-is if encryption unavailable
+
+ fingerprint = compute_ssh_fingerprint(pub_key)
+
+ # Check by fingerprint in case it was created under a different name
+ existing_by_fp = CA.query.filter_by(fingerprint=fingerprint).first()
+ if existing_by_fp:
+ return existing_by_fp
+
+ system_ca = CA(
+ name="system-config-ca",
+ description="Global CA loaded from etc/ssh_ca.conf (ca_key_path)",
+ key_type=KeyType.ED25519,
+ private_key=priv_key,
+ public_key=pub_key,
+ fingerprint=fingerprint,
+ is_active=True,
+ default_cert_validity_hours=24,
+ max_cert_validity_hours=720,
+ )
+ # organization_id is nullable=False in schema — we need a dummy org or
+ # need to allow NULL. Use None; the DB constraint will tell us quickly.
+ # If the migration enforces NOT NULL we'll catch the error gracefully.
+ db.session.add(system_ca)
+ db.session.commit()
+ return system_ca
+ except Exception as exc:
+ import logging
+ logging.getLogger(__name__).warning(
+ f"Could not upsert system-config-ca: {exc}"
+ )
+ try:
+ db.session.rollback()
+ except Exception:
+ pass
+ return None
+
+
+def _persist_certificate(user_id, ssh_key_id, ca, signing_response, request_ip=None, cert_type_str='user', cert_identity=None):
+ """Save a signed certificate to the ssh_certificates table.
+
+ Args:
+ user_id: UUID of the user
+ ssh_key_id: UUID of the SSH key that was signed
+ ca: CA model instance (may be None — cert still returned but not persisted)
+ signing_response: SSHCertificateSigningResponse
+ request_ip: Client IP address
+ cert_type_str: 'user' or 'host' (from the sign request)
+ cert_identity: Rich OpenSSH key_id string (e.g. "user@host (Name) [org:slug]").
+ Falls back to str(ssh_key_id) when not provided.
+
+ Returns:
+ SSHCertificate instance or None if persistence failed
+ """
+ if ca is None:
+ return None
+
+ try:
+ from gatehouse_app.extensions import db
+ from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus
+ from gatehouse_app.models.ssh_ca.ca import CertType
+
+ try:
+ resolved_cert_type = CertType(cert_type_str)
+ except ValueError:
+ resolved_cert_type = CertType.USER
+
+ cert_record = SSHCertificate(
+ ca_id=ca.id,
+ user_id=user_id,
+ ssh_key_id=ssh_key_id,
+ certificate=signing_response.certificate,
+ serial=signing_response.serial,
+ key_id=cert_identity or str(ssh_key_id),
+ cert_type=resolved_cert_type,
+ principals=signing_response.principals,
+ valid_after=signing_response.valid_after,
+ valid_before=signing_response.valid_before,
+ revoked=False,
+ status=CertificateStatus.ISSUED,
+ request_ip=request_ip,
+ )
+ db.session.add(cert_record)
+ db.session.commit()
+ return cert_record
+ except Exception as exc:
+ import logging
+ logging.getLogger(__name__).warning(
+ f"Failed to persist certificate to DB: {exc}"
+ )
+ try:
+ from gatehouse_app.extensions import db as _db
+ _db.session.rollback()
+ except Exception:
+ pass
+ return None
+
+
+
+def _get_merged_dept_cert_policy(user_id):
+ """Return a merged cert policy view for the given user across all their departments.
+
+ Rules for merging when a user belongs to multiple departments:
+ - ``allow_user_expiry``: True only if ALL departments allow it.
+ - ``default_expiry_hours``: minimum across departments (most restrictive).
+ - ``max_expiry_hours``: minimum across departments (most restrictive).
+ - ``extensions``: intersection — only extensions allowed by ALL departments.
+
+ Returns a plain dict with keys:
+ allow_user_expiry, default_expiry_hours, max_expiry_hours, extensions
+ Or None if the user has no department memberships or no policies are configured.
+ """
+ from gatehouse_app.models.organization.department import DepartmentMembership
+ from gatehouse_app.models.organization.department_cert_policy import DepartmentCertPolicy, STANDARD_EXTENSIONS
+
+ memberships = DepartmentMembership.query.filter_by(user_id=user_id, deleted_at=None).all()
+ dept_ids = [m.department_id for m in memberships if m.department and m.department.deleted_at is None]
+ if not dept_ids:
+ return None
+
+ policies = DepartmentCertPolicy.query.filter(
+ DepartmentCertPolicy.department_id.in_(dept_ids),
+ DepartmentCertPolicy.deleted_at.is_(None),
+ ).all()
+ if not policies:
+ return None
+
+ allow_user_expiry = all(p.allow_user_expiry for p in policies)
+ default_expiry_hours = min(p.default_expiry_hours for p in policies)
+ max_expiry_hours = min(p.max_expiry_hours for p in policies)
+
+ # Intersection of all_extensions() across policies
+ ext_sets = [set(p.all_extensions()) for p in policies]
+ extensions = list(ext_sets[0].intersection(*ext_sets[1:]))
+
+ return {
+ "allow_user_expiry": allow_user_expiry,
+ "default_expiry_hours": default_expiry_hours,
+ "max_expiry_hours": max_expiry_hours,
+ "extensions": extensions,
+ }
+
+
+@ssh_bp.route('/dept-cert-policy', methods=['GET'])
+@login_required
+def get_my_dept_cert_policy():
+ """Return the merged department certificate policy for the current user.
+
+ Admins always get allow_user_expiry=True so the frontend shows the expiry
+ picker for them regardless of the member-facing toggle setting.
+ """
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
+ from gatehouse_app.models.organization.department_cert_policy import STANDARD_EXTENSIONS
+ from gatehouse_app.utils.constants import OrganizationRole
+
+ user = g.current_user
+ user_id = user.id
+
+ # Check if caller is an org admin/owner
+ is_org_admin = OrganizationMember.query.filter(
+ OrganizationMember.user_id == user_id,
+ OrganizationMember.role.in_(["OWNER", "ADMIN"]),
+ OrganizationMember.deleted_at == None,
+ ).first() is not None
+
+ policy = _get_merged_dept_cert_policy(user_id)
+ if policy is None:
+ policy = {
+ "allow_user_expiry": is_org_admin, # admins default to True even without a dept policy
+ "default_expiry_hours": 1,
+ "max_expiry_hours": 24,
+ "extensions": list(STANDARD_EXTENSIONS),
+ }
+ elif is_org_admin:
+ # Override allow_user_expiry for admins — they can always pick
+ policy = {**policy, "allow_user_expiry": True}
+
+ return api_response(data={"policy": policy}, message="Certificate policy retrieved")
+
+
+@ssh_bp.route('/keys', methods=['GET'])
+@login_required
+def list_ssh_keys():
+ """Get all SSH keys for current user."""
+ user_id = g.current_user.id
+
+ keys = ssh_key_service.get_user_ssh_keys(user_id)
+ return api_response(
+ data={
+ 'keys': [k.to_dict() for k in keys],
+ 'count': len(keys),
+ },
+ message="SSH keys retrieved successfully"
+ )
+
+
+@ssh_bp.route('/keys', methods=['POST'])
+@login_required
+def add_ssh_key():
+ """Add a new SSH public key for current user."""
+ user_id = g.current_user.id
+
+ data = request.get_json()
+ if not data:
+ return api_response(success=False, message='No JSON data provided', status=400, error_type='BAD_REQUEST')
+
+ public_key = data.get('public_key') or data.get('key')
+ description = data.get('description')
+
+ if not public_key:
+ return api_response(success=False, message='public_key is required', status=400, error_type='BAD_REQUEST')
+
+ try:
+ ssh_key = ssh_key_service.add_ssh_key(
+ user_id=user_id,
+ public_key=public_key,
+ description=description,
+ )
+
+ AuditLog.log(
+ action=AuditAction.SSH_KEY_ADDED,
+ user_id=user_id,
+ resource_type='SSHKey',
+ resource_id=ssh_key.id,
+ ip_address=request.remote_addr,
+ )
+
+ return api_response(success=True, message='SSH key added', data=ssh_key.to_dict(), status=201)
+
+ except SSHKeyAlreadyExistsError as e:
+ return api_response(success=False, message=e.message, status=409, error_type='SSH_KEY_ALREADY_EXISTS')
+ except IntegrityError:
+ return api_response(success=False, message='SSH key already exists', status=409, error_type='SSH_KEY_ALREADY_EXISTS')
+ except SSHKeyError as e:
+ return api_response(success=False, message=str(e), status=400, error_type='SSH_KEY_ERROR')
+ except ValidationError as e:
+ return api_response(success=False, message=str(e), status=400, error_type='VALIDATION_ERROR')
+
+
+@ssh_bp.route('/keys/', methods=['GET'])
+@login_required
+def get_ssh_key(key_id):
+ """Get a specific SSH key."""
+ user_id = g.current_user.id
+
+ try:
+ ssh_key = ssh_key_service.get_ssh_key(key_id)
+
+ if ssh_key.user_id != user_id:
+ return api_response(success=False, message='Forbidden', status=403, error_type='FORBIDDEN')
+
+ return api_response(success=True, message='SSH key retrieved', data=ssh_key.to_dict(), status=200)
+
+ except SSHKeyNotFoundError:
+ return api_response(success=False, message='SSH key not found', status=404, error_type='NOT_FOUND')
+
+
+@ssh_bp.route('/keys/', methods=['DELETE'])
+@login_required
+def delete_ssh_key(key_id):
+ """Delete an SSH key."""
+ user_id = g.current_user.id
+
+ try:
+ ssh_key = ssh_key_service.get_ssh_key(key_id)
+
+ if ssh_key.user_id != user_id:
+ return api_response(success=False, message='Forbidden', status=403, error_type='FORBIDDEN')
+
+ ssh_key_service.delete_ssh_key(key_id)
+
+ AuditLog.log(
+ action=AuditAction.SSH_KEY_DELETED,
+ user_id=user_id,
+ resource_type='SSHKey',
+ resource_id=key_id,
+ ip_address=request.remote_addr,
+ )
+
+ return api_response(success=True, message='SSH key deleted', data={'status': 'deleted'}, status=200)
+
+ except SSHKeyNotFoundError:
+ return api_response(success=False, message='SSH key not found', status=404, error_type='NOT_FOUND')
+
+
+@ssh_bp.route('/keys//verify', methods=['GET', 'POST'])
+@login_required
+def verify_ssh_key(key_id):
+ """Generate or verify SSH key ownership challenge."""
+ user_id = g.current_user.id
+
+ try:
+ ssh_key = ssh_key_service.get_ssh_key(key_id)
+
+ if ssh_key.user_id != user_id:
+ return api_response(success=False, message='Forbidden', status=403, error_type='FORBIDDEN')
+
+ # GET — return a fresh challenge
+ if request.method == 'GET':
+ challenge = ssh_key_service.generate_verification_challenge(key_id)
+ return api_response(success=True, message='Challenge generated', data={
+ 'challenge_text': challenge,
+ 'validationText': challenge,
+ 'key_id': key_id,
+ }, status=200)
+
+ # POST — verify signature or generate challenge
+ data = request.get_json() or {}
+ action = data.get('action', 'verify_signature')
+
+ if action == 'verify_signature':
+ signature = data.get('signature')
+ if not signature:
+ return api_response(success=False, message='signature is required', status=400, error_type='BAD_REQUEST')
+
+ try:
+ verified = ssh_key_service.verify_ssh_key_ownership(key_id, signature)
+
+ AuditLog.log(
+ action=AuditAction.SSH_KEY_VERIFIED,
+ user_id=user_id,
+ resource_type='SSHKey',
+ resource_id=key_id,
+ ip_address=request.remote_addr,
+ success=verified,
+ )
+
+ return api_response(success=True, message='Verification complete', data={'verified': verified}, status=200)
+
+ except Exception as e:
+ AuditLog.log(
+ action=AuditAction.SSH_KEY_VALIDATION_FAILED,
+ user_id=user_id,
+ resource_type='SSHKey',
+ resource_id=key_id,
+ ip_address=request.remote_addr,
+ success=False,
+ error_message=str(e),
+ )
+ return api_response(success=False, message=str(e), status=400, error_type='VERIFICATION_FAILED')
+
+ else: # generate_challenge
+ challenge = ssh_key_service.generate_verification_challenge(key_id)
+ return api_response(success=True, message='Challenge generated', data={
+ 'challenge_text': challenge,
+ 'challenge': challenge,
+ }, status=200)
+
+ except SSHKeyNotFoundError:
+ return api_response(success=False, message='SSH key not found', status=404, error_type='NOT_FOUND')
+
+
+@ssh_bp.route('/keys//update-description', methods=['PATCH'])
+@login_required
+def update_ssh_key_description(key_id):
+ """Update SSH key description."""
+ user_id = g.current_user.id
+
+ data = request.get_json()
+ if not data or 'description' not in data:
+ return api_response(success=False, message='description is required', status=400, error_type='BAD_REQUEST')
+
+ try:
+ ssh_key = ssh_key_service.get_ssh_key(key_id)
+
+ if ssh_key.user_id != user_id:
+ return api_response(success=False, message='Forbidden', status=403, error_type='FORBIDDEN')
+
+ updated_key = ssh_key_service.update_ssh_key_description(key_id, data['description'])
+
+ return api_response(success=True, message='Description updated', data=updated_key.to_dict(), status=200)
+
+ except SSHKeyNotFoundError:
+ return api_response(success=False, message='SSH key not found', status=404, error_type='NOT_FOUND')
+
+
+@ssh_bp.route('/sign', methods=['POST'])
+@login_required
+def sign_certificate():
+ """Sign an SSH certificate for the current user."""
+ user = g.current_user
+ user_id = user.id
+
+ # ── Check account suspension ──────────────────────────────────────────────
+ from gatehouse_app.utils.constants import UserStatus
+ if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
+ return api_response(
+ success=False,
+ message="Your account is suspended. Contact an administrator.",
+ status=403,
+ error_type="ACCOUNT_SUSPENDED",
+ )
+
+ data = request.get_json()
+ if not data:
+ return api_response(success=False, message="No JSON data provided", status=400, error_type="BAD_REQUEST")
+
+ requested_principals = data.get('principals') or []
+ cert_type = data.get('cert_type', 'user')
+ key_id = data.get('key_id') or data.get('cert_id')
+ expiry_hours = data.get('expiry_hours')
+
+ # ── Log the request ───────────────────────────────────────────────────────
+ AuditLog.log(
+ action=AuditAction.SSH_CERT_REQUESTED,
+ user_id=user_id,
+ resource_type='SSHCertificate',
+ ip_address=request.remote_addr,
+ description=(
+ f'{user.email} requested a certificate'
+ + (f' for principals: {", ".join(requested_principals)}' if requested_principals else '')
+ ),
+ )
+
+ # ── Resolve which principals the user is allowed to use ──────────────────
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
+ from gatehouse_app.models.organization.principal import Principal, PrincipalMembership
+ from gatehouse_app.models.organization.department import DepartmentMembership, DepartmentPrincipal
+ from gatehouse_app.utils.constants import OrganizationRole
+
+ allowed_principal_names = set()
+
+ memberships = OrganizationMember.query.filter_by(user_id=user_id).all()
+ for om in memberships:
+ org = om.organization
+ if not org or org.deleted_at is not None:
+ continue
+ role = om.role
+ if role in (OrganizationRole.ADMIN, OrganizationRole.OWNER):
+ # Admin/owner can use any principal in the org
+ for p in Principal.query.filter_by(organization_id=org.id, deleted_at=None).all():
+ allowed_principal_names.add(p.name)
+ else:
+ # Direct memberships
+ for pm in PrincipalMembership.query.filter_by(user_id=user_id, deleted_at=None).all():
+ if pm.principal and pm.principal.organization_id == org.id and pm.principal.deleted_at is None:
+ allowed_principal_names.add(pm.principal.name)
+ # Via department
+ for dm in DepartmentMembership.query.filter_by(user_id=user_id, deleted_at=None).all():
+ if dm.department and dm.department.organization_id == org.id and dm.department.deleted_at is None:
+ for dp in DepartmentPrincipal.query.filter_by(department_id=dm.department_id, deleted_at=None).all():
+ if dp.principal and dp.principal.deleted_at is None:
+ allowed_principal_names.add(dp.principal.name)
+
+ # ── Determine final principals list ─────────────────────────────────────
+ if not requested_principals:
+ # Auto-resolve: use all principals the user is assigned to
+ principals = list(allowed_principal_names)
+ if not principals:
+ return api_response(
+ success=False,
+ message="You have no principals assigned. Ask an admin to add you to a principal.",
+ status=400,
+ error_type="NO_PRINCIPALS",
+ )
+ else:
+ # Validate each requested principal is within the user's allowed set
+ invalid = [p for p in requested_principals if p not in allowed_principal_names]
+ if invalid:
+ return api_response(
+ success=False,
+ message=f"You are not authorised to request principals: {', '.join(invalid)}",
+ status=403,
+ error_type="UNAUTHORIZED_PRINCIPALS",
+ )
+ principals = requested_principals
+
+ # ── Key resolution ────────────────────────────────────────────────────────
+ if not key_id:
+ verified_keys = ssh_key_service.get_user_verified_ssh_keys(user_id)
+ if not verified_keys:
+ return api_response(
+ success=False,
+ message="No verified SSH keys found. Verify a key before requesting a certificate.",
+ status=400,
+ error_type="NO_VERIFIED_KEYS",
+ )
+ key_id = verified_keys[0].id
+
+ try:
+ ssh_key = ssh_key_service.get_ssh_key(key_id)
+ except SSHKeyNotFoundError:
+ return api_response(success=False, message="SSH key not found", status=404, error_type="NOT_FOUND")
+
+ if ssh_key.user_id != user_id:
+ return api_response(success=False, message="Forbidden", status=403, error_type="FORBIDDEN")
+
+ if not ssh_key.verified:
+ return api_response(
+ success=False,
+ message="SSH key is not verified. Verify it before requesting a certificate.",
+ status=400,
+ error_type="KEY_NOT_VERIFIED",
+ )
+
+ db_ca = _get_org_ca_for_user(user, ca_type=cert_type)
+ if db_ca is None:
+ return api_response(
+ success=False,
+ message=(
+ "No active Certificate Authority is configured for your organization. "
+ "An admin must generate a CA on the Certificate Authorities page before "
+ "certificates can be issued."
+ ),
+ status=503,
+ error_type="CA_NOT_CONFIGURED",
+ )
+
+ # Determine if the caller is an org admin/owner (admins can always choose expiry)
+ is_org_admin = any(
+ om.role in (OrganizationRole.ADMIN, OrganizationRole.OWNER)
+ for om in memberships
+ if om.organization and om.organization.deleted_at is None
+ )
+
+ # ── Apply department certificate policy ───────────────────────────────────
+ dept_policy = _get_merged_dept_cert_policy(user_id)
+ if dept_policy:
+ if is_org_admin:
+ # Admins can always choose their own expiry, but still capped at dept max
+ if expiry_hours is not None:
+ expiry_hours = min(int(expiry_hours), dept_policy["max_expiry_hours"])
+ elif not dept_policy["allow_user_expiry"]:
+ # Regular members: ignore user-requested expiry; use dept default
+ expiry_hours = dept_policy["default_expiry_hours"]
+ else:
+ # Regular members allowed to pick, cap at dept maximum
+ if expiry_hours is not None:
+ expiry_hours = min(int(expiry_hours), dept_policy["max_expiry_hours"])
+ policy_extensions = dept_policy["extensions"]
+ else:
+ policy_extensions = None # let signing service use its own defaults
+
+ # ── Build rich key_id identity for the OpenSSH cert ─────────────────────
+ # This appears in `ssh-keygen -L -f cert.pub` as the Key ID field and
+ # is stored in the DB cert record so it's auditable.
+ org_slugs = sorted({
+ om.organization.slug
+ for om in memberships
+ if om.organization and om.organization.deleted_at is None
+ and getattr(om.organization, 'slug', None)
+ })
+ org_slug = org_slugs[0] if org_slugs else "unknown"
+ full_name = getattr(user, 'full_name', None) or getattr(user, 'name', None) or "unknown"
+ cert_identity = f"{user.email} ({full_name}) [org:{org_slug}]"
+
+ signing_request = SSHCertificateSigningRequest(
+ ssh_public_key=ssh_key.payload,
+ principals=principals,
+ cert_type=cert_type,
+ key_id=cert_identity,
+ expiry_hours=int(expiry_hours) if expiry_hours else None,
+ extensions=policy_extensions,
+ )
+ validation_errors = signing_request.validate()
+ if validation_errors:
+ return api_response(
+ success=False,
+ message="Invalid signing request",
+ status=400,
+ error_type="VALIDATION_ERROR",
+ error_details={"errors": validation_errors},
+ )
+
+ try:
+ from gatehouse_app.utils.ca_key_encryption import decrypt_ca_key
+ ca_private_key_pem = decrypt_ca_key(db_ca.private_key)
+ response = ssh_ca_service.sign_certificate(
+ signing_request, ca_private_key=ca_private_key_pem, ca_obj=db_ca
+ )
+ except SSHCertificateError as e:
+ AuditLog.log(
+ action=AuditAction.SSH_CERT_FAILED,
+ user_id=user_id,
+ resource_type='SSHCertificate',
+ ip_address=request.remote_addr,
+ success=False,
+ error_message=str(e),
+ )
+ return api_response(success=False, message=str(e), status=400, error_type="SIGNING_FAILED")
+ except Exception as e:
+ AuditLog.log(
+ action=AuditAction.SSH_CERT_FAILED,
+ user_id=user_id,
+ resource_type='SSHCertificate',
+ ip_address=request.remote_addr,
+ success=False,
+ error_message=str(e),
+ )
+ return api_response(success=False, message="Certificate signing failed", status=500, error_type="SERVER_ERROR")
+
+ cert_record = _persist_certificate(
+ user_id=user_id,
+ ssh_key_id=key_id,
+ ca=db_ca,
+ signing_response=response,
+ request_ip=request.remote_addr,
+ cert_type_str=cert_type,
+ cert_identity=cert_identity,
+ )
+
+ AuditLog.log(
+ action=AuditAction.SSH_CERT_ISSUED,
+ user_id=user_id,
+ resource_type='SSHCertificate',
+ resource_id=cert_record.id if cert_record else key_id,
+ ip_address=request.remote_addr,
+ description=(
+ f'Certificate serial={response.serial} issued for {user.email}; '
+ f'principals: {", ".join(principals)}'
+ ),
+ extra_data={
+ 'serial': response.serial,
+ 'key_id': cert_identity,
+ 'principals': principals,
+ 'ca_id': str(db_ca.id),
+ 'ssh_key_id': str(key_id),
+ },
+ )
+
+ if cert_record:
+ CertificateAuditLog.log(
+ certificate_id=cert_record.id,
+ action='issued',
+ user_id=user_id,
+ ip_address=request.remote_addr,
+ user_agent=request.headers.get('User-Agent'),
+ message=(
+ f'Certificate serial={response.serial} issued for {user.email}; '
+ f'principals: {", ".join(principals)}'
+ ),
+ extra_data={
+ 'serial': response.serial,
+ 'key_id': cert_identity,
+ 'principals': principals,
+ 'ca_id': str(db_ca.id),
+ 'ssh_key_id': str(key_id),
+ 'valid_after': response.valid_after.isoformat() if response.valid_after else None,
+ 'valid_before': response.valid_before.isoformat() if response.valid_before else None,
+ },
+ success=True,
+ )
+
+ result = {
+ 'certificate': response.certificate,
+ 'serial': response.serial,
+ 'principals': response.principals,
+ 'valid_after': response.valid_after.isoformat() if response.valid_after else None,
+ 'valid_before': response.valid_before.isoformat() if response.valid_before else None,
+ }
+ if cert_record:
+ result['cert_id'] = str(cert_record.id)
+
+ return api_response(data=result, message="Certificate signed successfully", status=201)
+
+
+@ssh_bp.route('/certificates', methods=['GET'])
+@login_required
+def list_certificates():
+ """List all SSH certificates issued for the current user."""
+ user_id = g.current_user.id
+
+ try:
+ from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate
+ certs = (
+ SSHCertificate.query
+ .filter_by(user_id=user_id, deleted_at=None)
+ .order_by(SSHCertificate.created_at.desc())
+ .all()
+ )
+ return api_response(
+ data={
+ 'certificates': [c.to_dict() for c in certs],
+ 'count': len(certs),
+ },
+ message="Certificates retrieved successfully"
+ )
+ except Exception as e:
+ return api_response(
+ success=False,
+ message=str(e),
+ status=500,
+ error_type='INTERNAL_ERROR'
+ )
+
+
+@ssh_bp.route('/certificates/', methods=['GET'])
+@login_required
+def get_certificate(cert_id):
+ """Get a specific issued certificate (metadata only)."""
+ user_id = g.current_user.id
+
+ try:
+ from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate
+ cert = SSHCertificate.query.filter_by(id=cert_id, deleted_at=None).first()
+ if not cert:
+ return api_response(success=False, message='Certificate not found', status=404, error_type='NOT_FOUND')
+ if cert.user_id != user_id:
+ return api_response(success=False, message='Forbidden', status=403, error_type='FORBIDDEN')
+ data = cert.to_dict()
+ data['certificate'] = cert.certificate
+ return api_response(success=True, message='Certificate retrieved', data=data, status=200)
+ except Exception as e:
+ return api_response(success=False, message=str(e), status=500, error_type='INTERNAL_ERROR')
+
+
+@ssh_bp.route('/certificates//revoke', methods=['POST'])
+@login_required
+def revoke_certificate(cert_id):
+ """Revoke an issued certificate."""
+ user_id = g.current_user.id
+
+ data = request.get_json() or {}
+ reason = data.get('reason', 'User requested revocation')
+
+ try:
+ from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate
+ cert = SSHCertificate.query.filter_by(id=cert_id, deleted_at=None).first()
+ if not cert:
+ return api_response(success=False, message='Certificate not found', status=404, error_type='NOT_FOUND')
+ if cert.user_id != user_id:
+ return api_response(success=False, message='Forbidden', status=403, error_type='FORBIDDEN')
+ if cert.revoked:
+ return api_response(success=False, message='Certificate is already revoked', status=409, error_type='ALREADY_REVOKED')
+
+ cert.revoke(reason=reason)
+
+ AuditLog.log(
+ action=AuditAction.SSH_CERT_REVOKED,
+ user_id=user_id,
+ resource_type='SSHCertificate',
+ resource_id=cert_id,
+ ip_address=request.remote_addr,
+ description=f'Revoked: {reason}',
+ )
+
+ CertificateAuditLog.log(
+ certificate_id=cert_id,
+ action='revoked',
+ user_id=user_id,
+ ip_address=request.remote_addr,
+ user_agent=request.headers.get('User-Agent'),
+ message=f'Certificate revoked: {reason}',
+ success=True,
+ )
+
+ return api_response(
+ success=True,
+ message='Certificate revoked successfully',
+ data={'status': 'revoked', 'cert_id': cert_id, 'reason': reason},
+ status=200,
+ )
+ except Exception as e:
+ return api_response(success=False, message=str(e), status=500, error_type='INTERNAL_ERROR')
+
+
+@ssh_bp.route('/ca/public-key', methods=['GET'])
+@login_required
+def get_ca_public_key():
+ """
+ Return the CA public key for this user's organization.
+
+ Server admins should add this key to their host's ``TrustedUserCAKeys``
+ directive so that certificates issued by gatehouse are trusted.
+
+ Query parameters:
+ ca_type: 'user' (default) or 'host' — which CA's public key to return
+ format: 'openssh' (default) or 'text' — affects Content-Type only
+
+ Returns:
+ { "public_key": "ssh-ed25519 AAAA...",
+ "fingerprint": "SHA256:...",
+ "ca_name": "..." }
+ """
+ user = g.current_user
+ ca_type = request.args.get("ca_type", "user")
+ if ca_type not in ("user", "host"):
+ return api_response(
+ success=False,
+ message="ca_type must be 'user' or 'host'",
+ status=400,
+ error_type="BAD_REQUEST",
+ )
+
+ db_ca = _get_org_ca_for_user(user, ca_type=ca_type)
+ if db_ca:
+ return api_response(
+ data={
+ 'public_key': db_ca.public_key,
+ 'fingerprint': db_ca.fingerprint,
+ 'ca_name': db_ca.name,
+ 'ca_type': ca_type,
+ 'source': 'db',
+ },
+ message="CA public key retrieved successfully"
+ )
+
+ return api_response(
+ success=False,
+ message=(
+ f"No {ca_type} CA is configured for your organization. "
+ "An admin must generate one on the Certificate Authorities page."
+ ),
+ status=404,
+ error_type="CA_NOT_CONFIGURED",
+ )
+
+
+# ---------------------------------------------------------------------------
+# CA Permissions
+# ---------------------------------------------------------------------------
+
+@ssh_bp.route('/ca//permissions', methods=['GET'])
+@login_required
+def list_ca_permissions(ca_id):
+ """List permissions for a Certificate Authority.
+
+ Returns:
+ 200: { ca_id, permissions: [...], open_to_all: bool }
+ 403: Not admin/owner
+ 404: CA not found
+ """
+ from gatehouse_app.models.ssh_ca.ca import CA, CAPermission
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
+ from gatehouse_app.utils.constants import OrganizationRole
+
+ user = g.current_user
+
+ ca = CA.query.filter_by(id=ca_id, deleted_at=None).first()
+ if not ca:
+ return api_response(success=False, message="CA not found", status=404, error_type="NOT_FOUND")
+
+ # Verify user is admin/owner of the CA's org
+ if ca.organization_id:
+ membership = OrganizationMember.query.filter_by(
+ organization_id=ca.organization_id,
+ user_id=user.id,
+ deleted_at=None,
+ ).first()
+ if not membership or membership.role not in (OrganizationRole.ADMIN, OrganizationRole.OWNER):
+ return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN")
+
+ perms = CAPermission.query.filter_by(ca_id=ca_id, deleted_at=None).all()
+ perm_list = []
+ for p in perms:
+ d = p.to_dict()
+ d["user_email"] = p.user.email if p.user else None
+ perm_list.append(d)
+
+ return api_response(
+ data={
+ "ca_id": ca_id,
+ "permissions": perm_list,
+ "open_to_all": len(perms) == 0,
+ },
+ message="CA permissions retrieved",
+ )
+
+
+@ssh_bp.route('/ca//permissions', methods=['POST'])
+@login_required
+def add_ca_permission(ca_id):
+ """Grant a user permission on a Certificate Authority.
+
+ Request body:
+ user_id: UUID of the user to grant access
+ permission: "sign" or "admin" (default: "sign")
+
+ Returns:
+ 201: Permission granted
+ 400: Validation error
+ 403: Not admin/owner
+ 404: CA or user not found
+ 409: Permission already exists
+ """
+ from gatehouse_app.models.ssh_ca.ca import CA, CAPermission
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
+ from gatehouse_app.models.user import User
+ from gatehouse_app.utils.constants import OrganizationRole, AuditAction
+ from gatehouse_app.models import AuditLog
+ from gatehouse_app.extensions import db
+
+ user = g.current_user
+
+ ca = CA.query.filter_by(id=ca_id, deleted_at=None).first()
+ if not ca:
+ return api_response(success=False, message="CA not found", status=404, error_type="NOT_FOUND")
+
+ # Verify user is admin/owner of the CA's org
+ if ca.organization_id:
+ membership = OrganizationMember.query.filter_by(
+ organization_id=ca.organization_id,
+ user_id=user.id,
+ deleted_at=None,
+ ).first()
+ if not membership or membership.role not in (OrganizationRole.ADMIN, OrganizationRole.OWNER):
+ return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN")
+
+ data = request.get_json() or {}
+ target_user_id = (data.get("user_id") or "").strip()
+ permission = data.get("permission", "sign")
+
+ if not target_user_id:
+ return api_response(success=False, message="user_id is required", status=400, error_type="VALIDATION_ERROR")
+ if permission not in ("sign", "admin"):
+ return api_response(
+ success=False,
+ message="permission must be 'sign' or 'admin'",
+ status=400,
+ error_type="VALIDATION_ERROR",
+ )
+
+ target_user = User.query.filter_by(id=target_user_id, deleted_at=None).first()
+ if not target_user:
+ return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND")
+
+ # Check for duplicate
+ existing = CAPermission.query.filter_by(
+ ca_id=ca_id, user_id=target_user_id, deleted_at=None
+ ).first()
+ if existing:
+ # Update permission level if different
+ if existing.permission != permission:
+ existing.permission = permission
+ db.session.commit()
+ d = existing.to_dict()
+ d["user_email"] = target_user.email
+ return api_response(
+ data={"message": "Permission updated", "permission": d},
+ message="Permission updated",
+ )
+ return api_response(
+ success=False,
+ message="User already has this permission on the CA",
+ status=409,
+ error_type="DUPLICATE",
+ )
+
+ perm = CAPermission(
+ ca_id=ca_id,
+ user_id=target_user_id,
+ permission=permission,
+ )
+ db.session.add(perm)
+ db.session.commit()
+
+ AuditLog.log(
+ action=AuditAction.CA_UPDATED,
+ user_id=user.id,
+ resource_type="CAPermission",
+ resource_id=perm.id,
+ ip_address=request.remote_addr,
+ description=f"Granted '{permission}' on CA '{ca.name}' to user {target_user.email}",
+ )
+
+ d = perm.to_dict()
+ d["user_email"] = target_user.email
+ return api_response(
+ data={"message": "Permission granted", "permission": d},
+ message="Permission granted",
+ status=201,
+ )
+
+
+@ssh_bp.route('/ca//permissions/', methods=['DELETE'])
+@login_required
+def remove_ca_permission(ca_id, target_user_id):
+ """Revoke a user's permission on a Certificate Authority.
+
+ Returns:
+ 200: Permission revoked
+ 403: Not admin/owner
+ 404: CA or permission not found
+ """
+ from gatehouse_app.models.ssh_ca.ca import CA, CAPermission
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
+ from gatehouse_app.utils.constants import OrganizationRole, AuditAction
+ from gatehouse_app.models import AuditLog
+ from gatehouse_app.extensions import db
+
+ user = g.current_user
+
+ ca = CA.query.filter_by(id=ca_id, deleted_at=None).first()
+ if not ca:
+ return api_response(success=False, message="CA not found", status=404, error_type="NOT_FOUND")
+
+ # Verify user is admin/owner of the CA's org
+ if ca.organization_id:
+ membership = OrganizationMember.query.filter_by(
+ organization_id=ca.organization_id,
+ user_id=user.id,
+ deleted_at=None,
+ ).first()
+ if not membership or membership.role not in (OrganizationRole.ADMIN, OrganizationRole.OWNER):
+ return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN")
+
+ perm = CAPermission.query.filter_by(
+ ca_id=ca_id, user_id=target_user_id, deleted_at=None
+ ).first()
+ if not perm:
+ return api_response(success=False, message="Permission not found", status=404, error_type="NOT_FOUND")
+
+ perm.delete(soft=True)
+
+ AuditLog.log(
+ action=AuditAction.CA_UPDATED,
+ user_id=user.id,
+ resource_type="CAPermission",
+ resource_id=perm.id,
+ ip_address=request.remote_addr,
+ description=f"Revoked permission on CA '{ca.name}' from user {target_user_id}",
+ )
+
+ return api_response(
+ data={},
+ message="Permission revoked",
+ )
+
diff --git a/gatehouse_app/api/v1/users.py b/gatehouse_app/api/v1/users.py
index 407e373..65c42c3 100644
--- a/gatehouse_app/api/v1/users.py
+++ b/gatehouse_app/api/v1/users.py
@@ -73,11 +73,51 @@ def delete_me():
"""
Delete current user account (soft delete).
+ Blocked if the user is the sole owner of any organization that has other
+ active members — they must transfer ownership or dissolve those organizations
+ first.
+
Returns:
200: Account deleted successfully
401: Not authenticated
+ 409: User is sole owner of one or more organizations with other members
"""
- UserService.delete_user(g.current_user, soft=True)
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
+ from gatehouse_app.utils.constants import OrganizationRole
+
+ user = g.current_user
+
+ # Find orgs where this user is the sole owner AND other members exist.
+ owned_memberships = OrganizationMember.query.filter_by(
+ user_id=user.id,
+ role=OrganizationRole.OWNER,
+ deleted_at=None,
+ ).all()
+
+ blocked_orgs = []
+ for membership in owned_memberships:
+ org = membership.organization
+ if org.deleted_at is not None:
+ continue
+ member_count = org.get_member_count()
+ if member_count > 1:
+ blocked_orgs.append(org.name)
+
+ if blocked_orgs:
+ names = ", ".join(f'"{n}"' for n in blocked_orgs)
+ return api_response(
+ success=False,
+ message=(
+ f"You are the sole owner of {len(blocked_orgs)} organization"
+ f"{'s' if len(blocked_orgs) > 1 else ''}: {names}. "
+ "Transfer ownership or delete those organizations before deleting your account."
+ ),
+ status=409,
+ error_type="USER_IS_SOLE_OWNER",
+ error_details={"organizations": blocked_orgs},
+ )
+
+ UserService.delete_user(user, soft=True)
return api_response(
message="Account deleted successfully",
@@ -142,18 +182,686 @@ def change_password():
@full_access_required
def get_my_organizations():
"""
- Get all organizations current user is a member of.
+ Get all organizations current user is a member of, including the user's role.
Returns:
- 200: List of organizations
+ 200: List of organizations with role
401: Not authenticated
"""
- organizations = UserService.get_user_organizations(g.current_user)
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
+
+ user = g.current_user
+ memberships = OrganizationMember.query.filter_by(
+ user_id=user.id,
+ deleted_at=None,
+ ).all()
+
+ orgs = []
+ for membership in memberships:
+ org = membership.organization
+ if not org or org.deleted_at is not None:
+ continue
+ org_dict = org.to_dict()
+ org_dict["role"] = membership.role.value if hasattr(membership.role, "value") else str(membership.role)
+ orgs.append(org_dict)
return api_response(
data={
- "organizations": [org.to_dict() for org in organizations],
- "count": len(organizations),
+ "organizations": orgs,
+ "count": len(orgs),
},
message="Organizations retrieved successfully",
)
+
+
+@api_v1_bp.route("/users/me/principals", methods=["GET"])
+@login_required
+@full_access_required
+def get_my_principals():
+ """Return all principals the current user can sign certificates for.
+
+ For each organization the user belongs to, returns:
+ - Their effective principals (direct membership + via department)
+ - Their role in that org (so the frontend can offer admin-mode selection)
+ - All principals in the org (admin/owner only — so they can pick any)
+
+ Returns:
+ 200: {
+ orgs: [{
+ org_id, org_name, role,
+ my_principals: [{id, name, description}],
+ all_principals: [{id, name, description}] # populated for admin/owner only
+ }]
+ }
+ """
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
+ from gatehouse_app.models.organization.principal import Principal, PrincipalMembership
+ from gatehouse_app.models.organization.department import DepartmentMembership, DepartmentPrincipal
+ from gatehouse_app.utils.constants import OrganizationRole
+
+ user = g.current_user
+ user_id = user.id
+
+ # Get all org memberships
+ memberships = OrganizationMember.query.filter_by(
+ user_id=user_id,
+ ).all()
+
+ orgs_result = []
+ for membership in memberships:
+ org = membership.organization
+ if not org or org.deleted_at is not None:
+ continue
+
+ role = membership.role
+ is_admin = role in (OrganizationRole.ADMIN, OrganizationRole.OWNER)
+
+ # Collect the user's effective principals for this org
+ # Track direct vs via-department separately
+ direct_principal_ids = set()
+ via_dept_principal_ids = set()
+
+ # Direct memberships
+ direct = PrincipalMembership.query.filter_by(
+ user_id=user_id,
+ deleted_at=None,
+ ).all()
+ for pm in direct:
+ if pm.principal and pm.principal.organization_id == org.id and pm.principal.deleted_at is None:
+ direct_principal_ids.add(pm.principal_id)
+
+ # Via department
+ dept_memberships = DepartmentMembership.query.filter_by(
+ user_id=user_id,
+ deleted_at=None,
+ ).all()
+ for dm in dept_memberships:
+ if dm.department and dm.department.organization_id == org.id and dm.department.deleted_at is None:
+ dept_principals = DepartmentPrincipal.query.filter_by(
+ department_id=dm.department_id,
+ deleted_at=None,
+ ).all()
+ for dp in dept_principals:
+ if dp.principal and dp.principal.deleted_at is None:
+ via_dept_principal_ids.add(dp.principal_id)
+
+ effective_principal_ids = direct_principal_ids | via_dept_principal_ids
+
+ # Fetch principal objects
+ my_principals = []
+ if effective_principal_ids:
+ my_p = Principal.query.filter(
+ Principal.id.in_(list(effective_principal_ids)),
+ Principal.deleted_at == None,
+ ).all()
+ my_principals = [
+ {
+ "id": p.id,
+ "name": p.name,
+ "description": p.description,
+ # direct=True means removable via API; False=inherited via department
+ "direct": p.id in direct_principal_ids,
+ }
+ for p in my_p
+ ]
+
+ # For admins/owners: also return all principals in the org
+ all_principals = []
+ if is_admin:
+ all_p = Principal.query.filter_by(
+ organization_id=org.id,
+ deleted_at=None,
+ ).all()
+ all_principals = [{"id": p.id, "name": p.name, "description": p.description} for p in all_p]
+
+ orgs_result.append({
+ "org_id": org.id,
+ "org_name": org.name,
+ "role": role.value if hasattr(role, "value") else role,
+ "is_admin": is_admin,
+ "my_principals": my_principals,
+ "all_principals": all_principals,
+ })
+
+ return api_response(
+ data={"orgs": orgs_result},
+ message="Principals retrieved successfully",
+ )
+
+
+@api_v1_bp.route("/admin/users", methods=["GET"])
+@login_required
+@full_access_required
+def admin_list_users():
+ """List all users the caller has admin rights to see.
+
+ The caller must be an OWNER or ADMIN of at least one organization.
+ Returns users that share an organization with the caller and where the
+ caller holds admin/owner role in that organization.
+
+ Query params:
+ q – optional search string (matched against name/email)
+ page – page number (default 1)
+ per_page – page size (default 50, max 200)
+ """
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
+ from gatehouse_app.models.user.user import User as _User
+ from gatehouse_app.extensions import db as _db
+ from sqlalchemy import or_
+
+ caller = g.current_user
+
+ # Find orgs where caller is admin/owner
+ admin_memberships = OrganizationMember.query.filter(
+ OrganizationMember.user_id == caller.id,
+ OrganizationMember.role.in_(["OWNER", "ADMIN"]),
+ OrganizationMember.deleted_at == None,
+ ).all()
+
+ if not admin_memberships:
+ return api_response(
+ success=False,
+ message="Admin or owner role required",
+ status=403,
+ error_type="AUTHORIZATION_ERROR",
+ )
+
+ admin_org_ids = [m.organization_id for m in admin_memberships]
+
+ # Collect user IDs in those orgs
+ member_rows = OrganizationMember.query.filter(
+ OrganizationMember.organization_id.in_(admin_org_ids),
+ OrganizationMember.deleted_at == None,
+ ).all()
+ visible_user_ids = list({row.user_id for row in member_rows})
+
+ # Optional search
+ q = request.args.get("q", "").strip()
+ try:
+ page = max(1, int(request.args.get("page", 1)))
+ per_page = min(200, max(1, int(request.args.get("per_page", 50))))
+ except ValueError:
+ page, per_page = 1, 50
+
+ query = _User.query.filter(
+ _User.id.in_(visible_user_ids),
+ _User.deleted_at == None,
+ )
+ if q:
+ like = f"%{q}%"
+ query = query.filter(or_(_User.email.ilike(like), _User.full_name.ilike(like)))
+
+ total = query.count()
+ users = query.order_by(_User.email).offset((page - 1) * per_page).limit(per_page).all()
+
+ member_lookup: dict = {}
+ for row in member_rows:
+ if row.user_id not in member_lookup:
+ member_lookup[row.user_id] = {
+ "organization_id": row.organization_id,
+ "role": row.role.value if hasattr(row.role, "value") else row.role,
+ }
+
+ users_data = []
+ for u in users:
+ d = u.to_dict()
+ m = member_lookup.get(u.id, {})
+ d["org_role"] = m.get("role", "member")
+ d["org_id"] = m.get("organization_id")
+ users_data.append(d)
+
+ return api_response(
+ data={
+ "users": users_data,
+ "count": total,
+ "page": page,
+ "per_page": per_page,
+ "pages": (total + per_page - 1) // per_page,
+ },
+ message="Users retrieved successfully",
+ )
+
+
+@api_v1_bp.route("/admin/users/", methods=["GET"])
+@login_required
+@full_access_required
+def admin_get_user(user_id):
+ """Get a single user's profile (admin view with SSH keys)."""
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
+ from gatehouse_app.models.user.user import User as _User
+ from gatehouse_app.models.ssh_ca.ssh_key import SSHKey
+
+ caller = g.current_user
+
+ target = _User.query.filter_by(id=user_id, deleted_at=None).first()
+ if not target:
+ return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND")
+
+ # Verify caller has admin access to a shared org
+ target_org_ids = {m.organization_id for m in target.organization_memberships if m.deleted_at is None}
+ has_access = OrganizationMember.query.filter(
+ OrganizationMember.user_id == caller.id,
+ OrganizationMember.organization_id.in_(target_org_ids),
+ OrganizationMember.role.in_(["OWNER", "ADMIN"]),
+ OrganizationMember.deleted_at == None,
+ ).first() is not None
+
+ if not has_access:
+ return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR")
+
+ ssh_keys = SSHKey.query.filter_by(user_id=user_id, deleted_at=None).all()
+
+ return api_response(
+ data={
+ "user": target.to_dict(),
+ "ssh_keys": [k.to_dict() for k in ssh_keys],
+ },
+ message="User retrieved",
+ )
+
+
+@api_v1_bp.route("/admin/users//suspend", methods=["POST"])
+@login_required
+@full_access_required
+def admin_suspend_user(user_id):
+ """Suspend a user account (blocks CA issuance and login).
+
+ The caller must be an OWNER or ADMIN of an organization the target user belongs to.
+ """
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
+ from gatehouse_app.models.user.user import User as _User
+ from gatehouse_app.extensions import db as _db
+ from gatehouse_app.utils.constants import UserStatus, AuditAction
+ from gatehouse_app.services.audit_service import AuditService
+
+ caller = g.current_user
+ target = _User.query.filter_by(id=user_id, deleted_at=None).first()
+ if not target:
+ return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND")
+
+ if target.id == caller.id:
+ return api_response(success=False, message="Cannot suspend yourself", status=400, error_type="BAD_REQUEST")
+
+ # Verify caller has admin access to a shared org
+ target_org_ids = {m.organization_id for m in target.organization_memberships if m.deleted_at is None}
+ admin_in_shared_org = OrganizationMember.query.filter(
+ OrganizationMember.user_id == caller.id,
+ OrganizationMember.organization_id.in_(target_org_ids),
+ OrganizationMember.role.in_(["OWNER", "ADMIN"]),
+ OrganizationMember.deleted_at == None,
+ ).first()
+
+ if not admin_in_shared_org:
+ return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR")
+
+ # ── Owner protection ──────────────────────────────────────────────────────
+ # An org owner cannot be suspended until they transfer ownership.
+ from gatehouse_app.utils.constants import OrganizationRole
+ owner_memberships = OrganizationMember.query.filter(
+ OrganizationMember.user_id == target.id,
+ OrganizationMember.role == OrganizationRole.OWNER,
+ OrganizationMember.deleted_at == None,
+ ).all()
+ if owner_memberships:
+ org_names = [
+ m.organization.name
+ for m in owner_memberships
+ if m.organization and not m.organization.deleted_at
+ ]
+ return api_response(
+ success=False,
+ message=(
+ f"Cannot suspend an organization owner. "
+ f"{target.email} is the owner of: {', '.join(org_names)}. "
+ "Transfer ownership to another member first."
+ ),
+ status=403,
+ error_type="OWNER_PROTECTION",
+ )
+
+ if target.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
+ return api_response(success=False, message="User is already suspended", status=409, error_type="CONFLICT")
+
+ target.status = UserStatus.SUSPENDED
+ _db.session.commit()
+
+ AuditService.log_action(
+ action=AuditAction.USER_SUSPEND,
+ user_id=caller.id,
+ organization_id=admin_in_shared_org.organization_id,
+ resource_type="user",
+ resource_id=str(target.id),
+ description=f"Admin suspended user {target.email}",
+ metadata={"target_user_id": str(target.id), "target_email": target.email},
+ )
+
+ return api_response(data={"user": target.to_dict()}, message="User suspended successfully")
+
+
+@api_v1_bp.route("/admin/users//unsuspend", methods=["POST"])
+@login_required
+@full_access_required
+def admin_unsuspend_user(user_id):
+ """Restore a suspended user account to active status."""
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
+ from gatehouse_app.models.user.user import User as _User
+ from gatehouse_app.extensions import db as _db
+ from gatehouse_app.utils.constants import UserStatus, AuditAction
+ from gatehouse_app.services.audit_service import AuditService
+
+ caller = g.current_user
+ target = _User.query.filter_by(id=user_id, deleted_at=None).first()
+ if not target:
+ return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND")
+
+ # Verify caller has admin access to a shared org
+ target_org_ids = {m.organization_id for m in target.organization_memberships if m.deleted_at is None}
+ admin_in_shared_org = OrganizationMember.query.filter(
+ OrganizationMember.user_id == caller.id,
+ OrganizationMember.organization_id.in_(target_org_ids),
+ OrganizationMember.role.in_(["OWNER", "ADMIN"]),
+ OrganizationMember.deleted_at == None,
+ ).first()
+
+ if not admin_in_shared_org:
+ return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR")
+
+ if target.status not in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
+ return api_response(success=False, message="User is not suspended", status=409, error_type="CONFLICT")
+
+ target.status = UserStatus.ACTIVE
+ _db.session.commit()
+
+ AuditService.log_action(
+ action=AuditAction.USER_UNSUSPEND,
+ user_id=caller.id,
+ organization_id=admin_in_shared_org.organization_id,
+ resource_type="user",
+ resource_id=str(target.id),
+ description=f"Admin unsuspended user {target.email}",
+ metadata={"target_user_id": str(target.id), "target_email": target.email},
+ )
+
+ return api_response(data={"user": target.to_dict()}, message="User unsuspended successfully")
+
+
+@api_v1_bp.route("/users/me/invites", methods=["GET"])
+@login_required
+def get_my_pending_invites():
+ """Return pending (unaccepted, non-expired) invitations for the current user's email."""
+ from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
+ from datetime import datetime, timezone
+
+ user = g.current_user
+ now = datetime.now(timezone.utc)
+
+ invites = OrgInviteToken.query.filter(
+ OrgInviteToken.email == user.email,
+ OrgInviteToken.accepted_at.is_(None),
+ OrgInviteToken.expires_at > now,
+ OrgInviteToken.deleted_at.is_(None),
+ ).all()
+
+ return api_response(
+ data={
+ "invites": [
+ {
+ "token": i.token,
+ "organization": {"id": str(i.organization_id), "name": i.organization.name},
+ "role": i.role,
+ "expires_at": i.expires_at.isoformat(),
+ }
+ for i in invites
+ ]
+ },
+ message="Pending invitations retrieved",
+ )
+
+
+@api_v1_bp.route("/users/me/memberships", methods=["GET"])
+@login_required
+def get_my_memberships():
+ """Return the current user's department and principal memberships across all orgs.
+
+ Returns:
+ 200: {
+ orgs: [{
+ org_id, org_name, role,
+ departments: [{id, name, description}],
+ principals: [{id, name, description, via_department: bool}]
+ }]
+ }
+ """
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
+ from gatehouse_app.models.organization.department import DepartmentMembership, DepartmentPrincipal, Department
+ from gatehouse_app.models.organization.principal import Principal, PrincipalMembership
+
+ user = g.current_user
+
+ memberships = OrganizationMember.query.filter_by(
+ user_id=user.id,
+ deleted_at=None,
+ ).all()
+
+ orgs_result = []
+ for membership in memberships:
+ org = membership.organization
+ if not org or org.deleted_at is not None:
+ continue
+
+ # Departments the user belongs to
+ dept_memberships = DepartmentMembership.query.filter_by(
+ user_id=user.id,
+ deleted_at=None,
+ ).all()
+ user_depts = [
+ dm.department for dm in dept_memberships
+ if dm.department
+ and dm.department.organization_id == org.id
+ and dm.department.deleted_at is None
+ ]
+
+ # Principals: direct
+ direct_pm = PrincipalMembership.query.filter_by(
+ user_id=user.id,
+ deleted_at=None,
+ ).all()
+ direct_principal_ids = {
+ pm.principal_id for pm in direct_pm
+ if pm.principal
+ and pm.principal.organization_id == org.id
+ and pm.principal.deleted_at is None
+ }
+
+ # Principals: via department
+ via_dept_principal_ids = set()
+ for dept in user_depts:
+ for dp in DepartmentPrincipal.query.filter_by(department_id=dept.id, deleted_at=None).all():
+ if dp.principal and dp.principal.deleted_at is None:
+ via_dept_principal_ids.add(dp.principal_id)
+
+ all_principal_ids = direct_principal_ids | via_dept_principal_ids
+ principals_list = []
+ if all_principal_ids:
+ for p in Principal.query.filter(
+ Principal.id.in_(list(all_principal_ids)),
+ Principal.deleted_at == None,
+ ).all():
+ principals_list.append({
+ "id": str(p.id),
+ "name": p.name,
+ "description": p.description,
+ "via_department": p.id not in direct_principal_ids,
+ })
+
+ role = membership.role
+ orgs_result.append({
+ "org_id": str(org.id),
+ "org_name": org.name,
+ "role": role.value if hasattr(role, "value") else role,
+ "departments": [
+ {"id": str(d.id), "name": d.name, "description": d.description}
+ for d in user_depts
+ ],
+ "principals": principals_list,
+ })
+
+ return api_response(
+ data={"orgs": orgs_result},
+ message="Memberships retrieved",
+ )
+
+
+@api_v1_bp.route("/admin/users//delete", methods=["POST"])
+@login_required
+@full_access_required
+def admin_hard_delete_user(user_id):
+ """Permanently delete a user and ALL associated data (hard delete, irreversible).
+
+ Required body: {"confirm": true}
+
+ Pre-conditions:
+ - Caller is OWNER or ADMIN of a shared org with the target.
+ - Cannot delete yourself.
+ - Target must not be the OWNER of any active organization (transfer first).
+
+ Side-effects:
+ - All active SSH certificates are revoked before deletion.
+ - The user row and all cascaded rows are hard-deleted from the database.
+ - An audit log entry is written by the *caller* (so it is not lost with the user).
+ """
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
+ from gatehouse_app.models.user.user import User as _User
+ from gatehouse_app.extensions import db as _db
+ from gatehouse_app.utils.constants import UserStatus, AuditAction, OrganizationRole
+ from gatehouse_app.services.audit_service import AuditService
+
+ caller = g.current_user
+ data = request.get_json() or {}
+
+ if not data.get("confirm"):
+ return api_response(
+ success=False,
+ message="Deletion requires explicit confirmation. Send {\"confirm\": true} to proceed.",
+ status=400,
+ error_type="CONFIRMATION_REQUIRED",
+ )
+
+ target = _User.query.filter_by(id=user_id).first()
+ if not target:
+ return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND")
+
+ if target.id == caller.id:
+ return api_response(
+ success=False,
+ message="Cannot delete your own account via this endpoint.",
+ status=400,
+ error_type="BAD_REQUEST",
+ )
+
+ # Caller must be OWNER/ADMIN of a shared org.
+ # Include soft-deleted memberships so that already-soft-deleted users can
+ # still be hard-deleted by an admin who shared an org with them.
+ target_org_ids = {m.organization_id for m in target.organization_memberships}
+ admin_in_shared_org = OrganizationMember.query.filter(
+ OrganizationMember.user_id == caller.id,
+ OrganizationMember.organization_id.in_(target_org_ids),
+ OrganizationMember.role.in_(["OWNER", "ADMIN"]),
+ OrganizationMember.deleted_at == None,
+ ).first()
+ if not admin_in_shared_org:
+ return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR")
+
+ # Block deletion if target is an org owner — they must transfer first
+ owner_memberships = OrganizationMember.query.filter(
+ OrganizationMember.user_id == target.id,
+ OrganizationMember.role == OrganizationRole.OWNER,
+ OrganizationMember.deleted_at == None,
+ ).all()
+ if owner_memberships:
+ org_names = [
+ m.organization.name
+ for m in owner_memberships
+ if m.organization and not m.organization.deleted_at
+ ]
+ return api_response(
+ success=False,
+ message=(
+ f"Cannot delete an organization owner. "
+ f"{target.email} is the owner of: {', '.join(org_names)}. "
+ "Transfer ownership to another member first."
+ ),
+ status=403,
+ error_type="OWNER_PROTECTION",
+ )
+
+ # ── Collect counts for audit metadata ────────────────────────────────────
+ from gatehouse_app.models.ssh_ca.ssh_key import SSHKey
+ from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus
+
+ ssh_key_count = SSHKey.query.filter_by(user_id=target.id, deleted_at=None).count()
+ active_cert_count = SSHCertificate.query.filter_by(
+ user_id=target.id, revoked=False
+ ).filter(SSHCertificate.deleted_at == None).count()
+
+ # ── Revoke all active SSH certificates before deletion ───────────────────
+ active_certs = SSHCertificate.query.filter_by(
+ user_id=target.id, revoked=False
+ ).filter(SSHCertificate.deleted_at == None).all()
+ for cert in active_certs:
+ try:
+ cert.revoke("account_deleted")
+ except Exception:
+ pass
+
+ if active_certs:
+ try:
+ _db.session.flush()
+ except Exception:
+ pass
+
+ # ── Hard delete ───────────────────────────────────────────────────────────
+ target_email = target.email # capture before deletion
+ target_id_str = str(target.id)
+
+ try:
+ _db.session.delete(target) # cascades to all child tables
+ _db.session.flush()
+ except Exception as exc:
+ _db.session.rollback()
+ import logging
+ logging.getLogger(__name__).error(f"Hard delete failed for {target_id_str}: {exc}")
+ return api_response(
+ success=False,
+ message="Failed to delete user account. Please try again.",
+ status=500,
+ error_type="SERVER_ERROR",
+ )
+
+ # ── Audit log (written as the caller so it survives the deletion) ─────────
+ AuditService.log_action(
+ action=AuditAction.USER_HARD_DELETE,
+ user_id=caller.id,
+ organization_id=admin_in_shared_org.organization_id,
+ resource_type="user",
+ resource_id=target_id_str,
+ description=f"Admin permanently deleted user account: {target_email}",
+ metadata={
+ "deleted_user_id": target_id_str,
+ "deleted_user_email": target_email,
+ "ssh_keys_deleted": ssh_key_count,
+ "certs_revoked": active_cert_count,
+ },
+ )
+
+ _db.session.commit()
+
+ return api_response(
+ message=f"User account {target_email} has been permanently deleted.",
+ data={
+ "deleted_user_id": target_id_str,
+ "deleted_user_email": target_email,
+ "ssh_keys_deleted": ssh_key_count,
+ "certs_revoked": active_cert_count,
+ },
+ )
diff --git a/gatehouse_app/config/ssh_ca_config.py b/gatehouse_app/config/ssh_ca_config.py
new file mode 100644
index 0000000..7b62a96
--- /dev/null
+++ b/gatehouse_app/config/ssh_ca_config.py
@@ -0,0 +1,234 @@
+"""SSH CA Configuration Manager.
+
+Handles loading and managing SSH CA configuration from etc/ssh_ca.conf
+and environment variables.
+"""
+import os
+import configparser
+from pathlib import Path
+from typing import Optional, Union
+
+
+class SSHCAConfig:
+ """Configuration manager for SSH CA settings.
+
+ Loads configuration from:
+ 1. etc/ssh_ca.conf file
+ 2. Environment variables (override config file)
+ 3. Application environment-specific defaults
+
+ Example:
+ config = SSHCAConfig()
+ cert_hours = config.get_int('cert_validity_hours')
+ key_path = config.get_str('ca_key_path')
+ """
+
+ # Configuration file location (relative to project root)
+ DEFAULT_CONFIG_FILE = "etc/ssh_ca.conf"
+
+ # Default values if config file is missing
+ DEFAULTS = {
+ 'cert_validity_hours': '8',
+ 'max_cert_validity_hours': '720',
+ 'ca_key_path': '',
+ 'max_principals_per_cert': '256',
+ 'max_key_id_length': '255',
+ 'verification_challenge_max_age': '24',
+ 'auto_delete_unverified_days': '30',
+ }
+
+ def __init__(self, config_file: Optional[str] = None, environment: Optional[str] = None):
+ """Initialize SSH CA configuration.
+
+ Args:
+ config_file: Path to config file (default: etc/ssh_ca.conf)
+ environment: Environment name (development, production, testing)
+ Default: value of FLASK_ENV or 'development'
+ """
+ self.config = configparser.ConfigParser()
+
+ # Determine environment
+ if environment is None:
+ environment = os.environ.get('FLASK_ENV', 'development')
+ self.environment = environment
+
+ # Load config file
+ if config_file is None:
+ # Try to find config file relative to this module
+ module_dir = Path(__file__).parent.parent.parent
+ config_file = module_dir / self.DEFAULT_CONFIG_FILE
+
+ self.config_file = config_file
+ self._load_config()
+
+ def _load_config(self):
+ """Load configuration from file and apply environment-specific overrides."""
+ # Set defaults
+ self.config['default'] = self.DEFAULTS.copy()
+
+ # Load config file if it exists
+ if Path(self.config_file).exists():
+ self.config.read(self.config_file)
+
+ # Apply environment-specific configuration
+ if self.environment in self.config:
+ for key, value in self.config[self.environment].items():
+ self.config['default'][key] = value
+
+ def get_str(self, key: str, default: Optional[str] = None) -> str:
+ """Get a string configuration value.
+
+ First checks environment variables (SSH_CA_), then config file.
+
+ Args:
+ key: Configuration key
+ default: Default value if not found
+
+ Returns:
+ Configuration value as string
+ """
+ env_key = f"SSH_CA_{key.upper()}"
+
+ # Check environment variable first
+ if env_key in os.environ:
+ return os.environ[env_key]
+
+ # Check config file
+ if key in self.config['default']:
+ value = self.config['default'][key]
+ # Handle environment variable substitution
+ return os.path.expandvars(value)
+
+ # Return default
+ if default is not None:
+ return default
+
+ return self.DEFAULTS.get(key, '')
+
+ def get_int(self, key: str, default: Optional[int] = None) -> int:
+ """Get an integer configuration value.
+
+ Args:
+ key: Configuration key
+ default: Default value if not found
+
+ Returns:
+ Configuration value as integer
+
+ Raises:
+ ValueError: If value cannot be converted to integer
+ """
+ str_value = self.get_str(key)
+ if not str_value:
+ if default is not None:
+ return default
+ raise ValueError(f"No value found for {key}")
+
+ try:
+ return int(str_value)
+ except ValueError:
+ if default is not None:
+ return default
+ raise ValueError(f"Configuration {key}={str_value} is not a valid integer")
+
+ def get_bool(self, key: str, default: Optional[bool] = None) -> bool:
+ """Get a boolean configuration value.
+
+ Args:
+ key: Configuration key
+ default: Default value if not found
+
+ Returns:
+ Configuration value as boolean
+ """
+ str_value = self.get_str(key)
+ if not str_value:
+ if default is not None:
+ return default
+ return False
+
+ return str_value.lower() in ('true', '1', 'yes', 'on')
+
+ def get_list(self, key: str, delimiter: str = ',', default: Optional[list] = None) -> list:
+ """Get a comma-separated list configuration value.
+
+ Args:
+ key: Configuration key
+ delimiter: Delimiter between items (default: comma)
+ default: Default value if not found
+
+ Returns:
+ Configuration value as list of strings
+ """
+ str_value = self.get_str(key)
+ if not str_value:
+ if default is not None:
+ return default
+ return []
+
+ return [item.strip() for item in str_value.split(delimiter) if item.strip()]
+
+ def validate_config(self) -> list:
+ """Validate SSH CA configuration.
+
+ Returns:
+ List of validation error messages (empty if valid)
+ """
+ errors = []
+
+ # Check cert validity hours
+ try:
+ validity = self.get_int('cert_validity_hours')
+ max_validity = self.get_int('max_cert_validity_hours')
+ if validity > max_validity:
+ errors.append(
+ f"cert_validity_hours ({validity}) > max_cert_validity_hours ({max_validity})"
+ )
+ except ValueError as e:
+ errors.append(f"Invalid cert validity hours: {e}")
+
+ # Check principals limit
+ max_principals = self.get_int('max_principals_per_cert')
+ if max_principals > 256:
+ errors.append(f"max_principals_per_cert ({max_principals}) exceeds SSH limit of 256")
+
+ # Check ca_key_path is set
+ if not self.get_str('ca_key_path', '').strip():
+ errors.append("ca_key_path is not set")
+
+ return errors
+
+ def to_dict(self) -> dict:
+ """Export current configuration as dictionary.
+ """
+ return dict(self.config['default'])
+
+ def __repr__(self):
+ """String representation of configuration."""
+ return f""
+
+
+# Global configuration instance
+_config_instance = None
+
+
+def get_ssh_ca_config() -> SSHCAConfig:
+ """Get the global SSH CA configuration instance.
+
+ This function uses a singleton pattern to ensure only one
+ configuration instance is created and reused.
+
+ Returns:
+ SSHCAConfig instance
+ """
+ global _config_instance
+ if _config_instance is None:
+ _config_instance = SSHCAConfig()
+ return _config_instance
+
+
+def reset_config_instance():
+ """Reset the global configuration instance.
+ """
+ global _config_instance
+ _config_instance = None
diff --git a/gatehouse_app/exceptions/__init__.py b/gatehouse_app/exceptions/__init__.py
index bd8e6f8..06c3f34 100644
--- a/gatehouse_app/exceptions/__init__.py
+++ b/gatehouse_app/exceptions/__init__.py
@@ -19,6 +19,21 @@ from gatehouse_app.exceptions.validation_exceptions import (
OrganizationNotFoundError,
UserNotFoundError,
)
+from gatehouse_app.exceptions.ssh_exceptions import (
+ SSHCAError,
+ SSHKeyError,
+ SSHKeyNotFoundError,
+ SSHKeyAlreadyExistsError,
+ SSHKeyNotVerifiedError,
+ SSHCertificateError,
+ SSHCertificateNotFoundError,
+ CAError,
+ CANotFoundError,
+ PrincipalError,
+ PrincipalNotFoundError,
+ DepartmentError,
+ DepartmentNotFoundError,
+)
__all__ = [
"BaseAPIException",
@@ -37,4 +52,18 @@ __all__ = [
"EmailAlreadyExistsError",
"OrganizationNotFoundError",
"UserNotFoundError",
+ "SSHCAError",
+ "SSHKeyError",
+ "SSHKeyNotFoundError",
+ "SSHKeyAlreadyExistsError",
+ "SSHKeyNotVerifiedError",
+ "SSHCertificateError",
+ "SSHCertificateNotFoundError",
+ "CAError",
+ "CANotFoundError",
+ "PrincipalError",
+ "PrincipalNotFoundError",
+ "DepartmentError",
+ "DepartmentNotFoundError",
]
+
diff --git a/gatehouse_app/exceptions/base.py b/gatehouse_app/exceptions/base.py
index 3f1f42b..f7cfb0e 100644
--- a/gatehouse_app/exceptions/base.py
+++ b/gatehouse_app/exceptions/base.py
@@ -16,9 +16,10 @@ class BaseAPIException(Exception):
message: Custom error message
error_details: Additional error details dictionary
"""
- super().__init__()
+ super().__init__(self.message)
if message:
self.message = message
+ super().__init__(message) # update args so str(e) works
self.error_details = error_details or {}
def to_dict(self):
diff --git a/gatehouse_app/exceptions/ssh_exceptions.py b/gatehouse_app/exceptions/ssh_exceptions.py
new file mode 100644
index 0000000..141c1e9
--- /dev/null
+++ b/gatehouse_app/exceptions/ssh_exceptions.py
@@ -0,0 +1,93 @@
+"""SSH-specific exceptions."""
+from gatehouse_app.exceptions.base import BaseAPIException
+
+
+class SSHCAError(BaseAPIException):
+ """Base exception for SSH CA operations."""
+
+ status_code = 500
+ error_type = "SSH_CA_ERROR"
+
+
+class SSHKeyError(BaseAPIException):
+ """Exception for SSH key operations."""
+
+ status_code = 400
+ error_type = "SSH_KEY_ERROR"
+
+
+class SSHKeyNotFoundError(BaseAPIException):
+ """SSH key not found."""
+
+ status_code = 404
+ error_type = "SSH_KEY_NOT_FOUND"
+
+
+class SSHKeyAlreadyExistsError(BaseAPIException):
+ """SSH key already exists (duplicate fingerprint)."""
+
+ status_code = 409
+ error_type = "SSH_KEY_ALREADY_EXISTS"
+
+
+class SSHKeyNotVerifiedError(BaseAPIException):
+ """SSH key has not been verified."""
+
+ status_code = 400
+ error_type = "SSH_KEY_NOT_VERIFIED"
+
+
+class SSHCertificateError(BaseAPIException):
+ """Exception for SSH certificate operations."""
+
+ status_code = 400
+ error_type = "SSH_CERT_ERROR"
+
+
+class SSHCertificateNotFoundError(BaseAPIException):
+ """SSH certificate not found."""
+
+ status_code = 404
+ error_type = "SSH_CERT_NOT_FOUND"
+
+
+class CAError(BaseAPIException):
+ """Exception for Certificate Authority operations."""
+
+ status_code = 400
+ error_type = "CA_ERROR"
+
+
+class CANotFoundError(BaseAPIException):
+ """Certificate Authority not found."""
+
+ status_code = 404
+ error_type = "CA_NOT_FOUND"
+
+
+class PrincipalError(BaseAPIException):
+ """Exception for principal operations."""
+
+ status_code = 400
+ error_type = "PRINCIPAL_ERROR"
+
+
+class PrincipalNotFoundError(BaseAPIException):
+ """Principal not found."""
+
+ status_code = 404
+ error_type = "PRINCIPAL_NOT_FOUND"
+
+
+class DepartmentError(BaseAPIException):
+ """Exception for department operations."""
+
+ status_code = 400
+ error_type = "DEPARTMENT_ERROR"
+
+
+class DepartmentNotFoundError(BaseAPIException):
+ """Department not found."""
+
+ status_code = 404
+ error_type = "DEPARTMENT_NOT_FOUND"
diff --git a/gatehouse_app/middleware/security_headers.py b/gatehouse_app/middleware/security_headers.py
index eb1b6a4..ea03dc4 100644
--- a/gatehouse_app/middleware/security_headers.py
+++ b/gatehouse_app/middleware/security_headers.py
@@ -1,5 +1,6 @@
"""Security headers middleware."""
-from flask import request
+import os
+from flask import current_app, request
class SecurityHeadersMiddleware:
@@ -34,13 +35,22 @@ class SecurityHeadersMiddleware:
)
# Content Security Policy
+ try:
+ flask_env = current_app.config.get("ENV") or os.environ.get("FLASK_ENV", "production")
+ if flask_env == "development":
+ connect_src = "connect-src 'self' http://localhost:5000 http://127.0.0.1:5000"
+ else:
+ connect_src = "connect-src 'self'"
+ except RuntimeError:
+ connect_src = "connect-src 'self'"
+
response.headers["Content-Security-Policy"] = (
"default-src 'self'; "
"script-src 'self' 'unsafe-inline'; "
"style-src 'self' 'unsafe-inline'; "
"img-src 'self' data: https:; "
"font-src 'self' data:; "
- "connect-src 'self'"
+ + connect_src
)
# Referrer Policy
diff --git a/gatehouse_app/models/__init__.py b/gatehouse_app/models/__init__.py
index 99ef6fb..f15764c 100644
--- a/gatehouse_app/models/__init__.py
+++ b/gatehouse_app/models/__init__.py
@@ -1,43 +1,149 @@
-"""Models package."""
-from gatehouse_app.models.base import BaseModel
-from gatehouse_app.models.user import User
-from gatehouse_app.models.organization import Organization
-from gatehouse_app.models.organization_member import OrganizationMember
-from gatehouse_app.models.authentication_method import (
+"""Models package.
+
+Sub-packages
+------------
+models.user — User, Session
+models.organization — Organization, OrganizationMember, Department,
+ DepartmentMembership, DepartmentPrincipal,
+ DepartmentCertPolicy, Principal, PrincipalMembership,
+ OrgInviteToken
+models.auth — AuthenticationMethod, ApplicationProviderConfig,
+ OrganizationProviderOverride, OAuthState,
+ AuditLog, PasswordResetToken, EmailVerificationToken
+models.oidc — OIDCClient, OIDCAuthCode, OIDCRefreshToken, OIDCSession,
+ OIDCTokenMetadata, OIDCAuditLog, OidcJwksKey
+models.ssh_ca — CA, KeyType, CertType, CaType, CAPermission,
+ SSHKey, SSHCertificate, CertificateStatus,
+ CertificateAuditLog
+models.security — OrganizationSecurityPolicy, UserSecurityPolicy,
+ MfaPolicyCompliance
+
+All names are re-exported here so that existing code using the flat import
+style (``from gatehouse_app.models import X``) or the old per-file style
+(``from gatehouse_app.models.user import User``) continue to work unchanged.
+"""
+
+# ── Base ──────────────────────────────────────────────────────────────────────
+from gatehouse_app.models.base import BaseModel # noqa: F401
+
+# ── User ──────────────────────────────────────────────────────────────────────
+from gatehouse_app.models.user.user import User # noqa: F401
+from gatehouse_app.models.user.session import Session # noqa: F401
+
+# ── Organization ──────────────────────────────────────────────────────────────
+from gatehouse_app.models.organization.organization import Organization # noqa: F401
+from gatehouse_app.models.organization.organization_member import ( # noqa: F401
+ OrganizationMember,
+)
+from gatehouse_app.models.organization.department import ( # noqa: F401
+ Department,
+ DepartmentMembership,
+ DepartmentPrincipal,
+)
+from gatehouse_app.models.organization.department_cert_policy import ( # noqa: F401
+ DepartmentCertPolicy,
+ STANDARD_EXTENSIONS,
+)
+from gatehouse_app.models.organization.principal import ( # noqa: F401
+ Principal,
+ PrincipalMembership,
+)
+from gatehouse_app.models.organization.org_invite_token import OrgInviteToken # noqa: F401
+
+# ── Auth ──────────────────────────────────────────────────────────────────────
+from gatehouse_app.models.auth.authentication_method import ( # noqa: F401
AuthenticationMethod,
ApplicationProviderConfig,
OrganizationProviderOverride,
OAuthState,
)
-from gatehouse_app.models.session import Session
-from gatehouse_app.models.audit_log import AuditLog
-from gatehouse_app.models.oidc_client import OIDCClient
-from gatehouse_app.models.oidc_authorization_code import OIDCAuthCode
-from gatehouse_app.models.oidc_refresh_token import OIDCRefreshToken
-from gatehouse_app.models.oidc_session import OIDCSession
-from gatehouse_app.models.oidc_token_metadata import OIDCTokenMetadata
-from gatehouse_app.models.oidc_audit_log import OIDCAuditLog
-from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy
-from gatehouse_app.models.user_security_policy import UserSecurityPolicy
-from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
+from gatehouse_app.models.auth.audit_log import AuditLog # noqa: F401
+from gatehouse_app.models.auth.password_reset_token import PasswordResetToken # noqa: F401
+from gatehouse_app.models.auth.email_verification_token import ( # noqa: F401
+ EmailVerificationToken,
+)
+
+# ── OIDC ──────────────────────────────────────────────────────────────────────
+from gatehouse_app.models.oidc.oidc_client import OIDCClient # noqa: F401
+from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode # noqa: F401
+from gatehouse_app.models.oidc.oidc_refresh_token import OIDCRefreshToken # noqa: F401
+from gatehouse_app.models.oidc.oidc_session import OIDCSession # noqa: F401
+from gatehouse_app.models.oidc.oidc_token_metadata import OIDCTokenMetadata # noqa: F401
+from gatehouse_app.models.oidc.oidc_audit_log import OIDCAuditLog # noqa: F401
+from gatehouse_app.models.oidc.oidc_jwks_key import OidcJwksKey # noqa: F401
+
+# ── SSH / CA ──────────────────────────────────────────────────────────────────
+from gatehouse_app.models.ssh_ca.ca import ( # noqa: F401
+ CA,
+ KeyType,
+ CertType,
+ CaType,
+ CAPermission,
+)
+from gatehouse_app.models.ssh_ca.ssh_key import SSHKey # noqa: F401
+from gatehouse_app.models.ssh_ca.ssh_certificate import ( # noqa: F401
+ SSHCertificate,
+ CertificateStatus,
+)
+from gatehouse_app.models.ssh_ca.certificate_audit_log import ( # noqa: F401
+ CertificateAuditLog,
+)
+
+# ── Security ──────────────────────────────────────────────────────────────────
+from gatehouse_app.models.security.organization_security_policy import ( # noqa: F401
+ OrganizationSecurityPolicy,
+)
+from gatehouse_app.models.security.user_security_policy import ( # noqa: F401
+ UserSecurityPolicy,
+)
+from gatehouse_app.models.security.mfa_policy_compliance import ( # noqa: F401
+ MfaPolicyCompliance,
+)
__all__ = [
+ # Base
"BaseModel",
+ # User
"User",
+ "Session",
+ # Organization
"Organization",
"OrganizationMember",
+ "Department",
+ "DepartmentMembership",
+ "DepartmentPrincipal",
+ "DepartmentCertPolicy",
+ "STANDARD_EXTENSIONS",
+ "Principal",
+ "PrincipalMembership",
+ "OrgInviteToken",
+ # Auth
"AuthenticationMethod",
"ApplicationProviderConfig",
"OrganizationProviderOverride",
"OAuthState",
- "Session",
"AuditLog",
+ "PasswordResetToken",
+ "EmailVerificationToken",
+ # OIDC
"OIDCClient",
"OIDCAuthCode",
"OIDCRefreshToken",
"OIDCSession",
"OIDCTokenMetadata",
"OIDCAuditLog",
+ "OidcJwksKey",
+ # SSH / CA
+ "CA",
+ "KeyType",
+ "CertType",
+ "CaType",
+ "CAPermission",
+ "SSHKey",
+ "SSHCertificate",
+ "CertificateStatus",
+ "CertificateAuditLog",
+ # Security
"OrganizationSecurityPolicy",
"UserSecurityPolicy",
"MfaPolicyCompliance",
diff --git a/gatehouse_app/models/auth/__init__.py b/gatehouse_app/models/auth/__init__.py
new file mode 100644
index 0000000..e28b467
--- /dev/null
+++ b/gatehouse_app/models/auth/__init__.py
@@ -0,0 +1,20 @@
+"""Auth subpackage — authentication methods, tokens, and audit logs."""
+from gatehouse_app.models.auth.authentication_method import (
+ AuthenticationMethod,
+ ApplicationProviderConfig,
+ OrganizationProviderOverride,
+ OAuthState,
+)
+from gatehouse_app.models.auth.audit_log import AuditLog
+from gatehouse_app.models.auth.password_reset_token import PasswordResetToken
+from gatehouse_app.models.auth.email_verification_token import EmailVerificationToken
+
+__all__ = [
+ "AuthenticationMethod",
+ "ApplicationProviderConfig",
+ "OrganizationProviderOverride",
+ "OAuthState",
+ "AuditLog",
+ "PasswordResetToken",
+ "EmailVerificationToken",
+]
diff --git a/gatehouse_app/models/audit_log.py b/gatehouse_app/models/auth/audit_log.py
similarity index 92%
rename from gatehouse_app/models/audit_log.py
rename to gatehouse_app/models/auth/audit_log.py
index 3e3cea1..849f915 100644
--- a/gatehouse_app/models/audit_log.py
+++ b/gatehouse_app/models/auth/audit_log.py
@@ -26,14 +26,13 @@ class AuditLog(BaseModel):
extra_data = db.Column(db.JSON, nullable=True)
description = db.Column(db.Text, nullable=True)
- # Success/failure
+ # Outcome
success = db.Column(db.Boolean, default=True, nullable=False)
error_message = db.Column(db.Text, nullable=True)
# Relationships
user = db.relationship("User", back_populates="audit_logs")
- # Indexes for common queries
__table_args__ = (
db.Index("idx_audit_user_action", "user_id", "action"),
db.Index("idx_audit_resource", "resource_type", "resource_id"),
@@ -45,9 +44,8 @@ class AuditLog(BaseModel):
return f""
@classmethod
- def log(cls, action, user_id=None, **kwargs):
- """
- Create an audit log entry.
+ def log(cls, action, user_id=None, **kwargs) -> "AuditLog":
+ """Create an audit log entry.
Args:
action: AuditAction enum value
diff --git a/gatehouse_app/models/authentication_method.py b/gatehouse_app/models/auth/authentication_method.py
similarity index 74%
rename from gatehouse_app/models/authentication_method.py
rename to gatehouse_app/models/auth/authentication_method.py
index 3766d52..3fbd7c0 100644
--- a/gatehouse_app/models/authentication_method.py
+++ b/gatehouse_app/models/auth/authentication_method.py
@@ -1,4 +1,4 @@
-"""Authentication method model."""
+"""Authentication method model — user credentials and OAuth provider config."""
from datetime import datetime, timedelta, timezone
import secrets
from gatehouse_app.extensions import db
@@ -35,7 +35,6 @@ class AuthenticationMethod(BaseModel):
# Relationships
user = db.relationship("User", back_populates="authentication_methods")
- # Ensure unique provider combinations
__table_args__ = (
db.Index("idx_user_method", "user_id", "method_type"),
db.UniqueConstraint(
@@ -45,13 +44,15 @@ class AuthenticationMethod(BaseModel):
def __repr__(self):
"""String representation of AuthenticationMethod."""
- return f""
+ return (
+ f""
+ )
- def is_password(self):
+ def is_password(self) -> bool:
"""Check if this is a password authentication method."""
return self.method_type == AuthMethodType.PASSWORD
- def is_oauth(self):
+ def is_oauth(self) -> bool:
"""Check if this is an OAuth authentication method."""
return self.method_type in [
AuthMethodType.GOOGLE,
@@ -59,32 +60,32 @@ class AuthenticationMethod(BaseModel):
AuthMethodType.MICROSOFT,
]
- def is_totp(self):
+ def is_totp(self) -> bool:
"""Check if this is a TOTP authentication method."""
return self.method_type == AuthMethodType.TOTP
- def is_webauthn(self):
+ def is_webauthn(self) -> bool:
"""Check if this is a WebAuthn authentication method."""
return self.method_type == AuthMethodType.WEBAUTHN
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
- # Always exclude password hash and TOTP secrets
- exclude.append("password_hash")
- exclude.append("totp_secret")
- exclude.append("totp_backup_codes")
+ # Always exclude credential material
+ for field in ("password_hash", "totp_secret", "totp_backup_codes"):
+ if field not in exclude:
+ exclude.append(field)
return super().to_dict(exclude=exclude)
def to_webauthn_dict(self):
"""Convert WebAuthn credential to public dictionary.
-
+
Returns:
- Dictionary with safe-to-expose credential information.
+ Dictionary with safe-to-expose credential information, or None.
"""
if not self.is_webauthn() or not self.provider_data:
return None
-
+
data = self.provider_data
return {
"id": data.get("credential_id"),
@@ -98,26 +99,26 @@ class AuthenticationMethod(BaseModel):
class ApplicationProviderConfig(BaseModel):
"""Application-wide OAuth provider configuration.
-
- This model stores OAuth provider credentials at the application level,
- allowing users to authenticate without needing to specify an organization first.
+
+ Stores OAuth provider credentials at the application level, allowing users
+ to authenticate without needing to specify an organization first.
"""
__tablename__ = "application_provider_configs"
# Provider identification
provider_type = db.Column(db.String(50), nullable=False, unique=True, index=True)
-
- # OAuth credentials (encrypted)
+
+ # OAuth credentials (client_secret encrypted at rest)
client_id = db.Column(db.String(255), nullable=False)
client_secret_encrypted = db.Column(db.String(512), nullable=True)
-
+
# Provider status
is_enabled = db.Column(db.Boolean, default=True, nullable=False)
-
+
# Default redirect URL
default_redirect_url = db.Column(db.String(2048), nullable=True)
-
+
# Provider-specific settings (JSON)
additional_config = db.Column(db.JSON, nullable=True)
@@ -126,28 +127,34 @@ class ApplicationProviderConfig(BaseModel):
"OrganizationProviderOverride",
back_populates="application_config",
foreign_keys="OrganizationProviderOverride.provider_type",
- primaryjoin="ApplicationProviderConfig.provider_type==OrganizationProviderOverride.provider_type",
- cascade="all, delete-orphan"
+ primaryjoin=(
+ "ApplicationProviderConfig.provider_type"
+ "==OrganizationProviderOverride.provider_type"
+ ),
+ cascade="all, delete-orphan",
)
def __repr__(self):
"""String representation of ApplicationProviderConfig."""
- return f""
+ return (
+ f""
+ )
- def set_client_secret(self, plaintext_secret: str):
+ def set_client_secret(self, plaintext_secret: str) -> None:
"""Encrypt and store client secret.
-
+
Args:
plaintext_secret: The plaintext OAuth client secret
"""
if plaintext_secret:
self.client_secret_encrypted = encrypt(plaintext_secret)
- def get_client_secret(self) -> str:
+ def get_client_secret(self) -> str | None:
"""Decrypt and return client secret.
-
+
Returns:
- The plaintext OAuth client secret
+ The plaintext OAuth client secret, or None if not set.
"""
if self.client_secret_encrypted:
return decrypt(self.client_secret_encrypted)
@@ -156,37 +163,38 @@ class ApplicationProviderConfig(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
- # Always exclude encrypted client secret
- exclude.append("client_secret_encrypted")
+ if "client_secret_encrypted" not in exclude:
+ exclude.append("client_secret_encrypted")
return super().to_dict(exclude=exclude)
class OrganizationProviderOverride(BaseModel):
"""Organization-specific OAuth configuration overrides.
-
- This model allows organizations to override application-level OAuth settings
- for enterprise SSO scenarios or custom provider configurations.
+
+ Allows organizations to override application-level OAuth settings for
+ enterprise SSO scenarios or custom provider configurations.
"""
__tablename__ = "organization_provider_overrides"
- # References
organization_id = db.Column(
- db.String(36), db.ForeignKey("organizations.id"),
- nullable=False, index=True
+ db.String(36),
+ db.ForeignKey("organizations.id"),
+ nullable=False,
+ index=True,
)
provider_type = db.Column(db.String(50), nullable=False, index=True)
-
- # Override OAuth credentials (encrypted, nullable - only if overriding)
+
+ # Override OAuth credentials (encrypted, nullable — only set when overriding)
client_id = db.Column(db.String(255), nullable=True)
client_secret_encrypted = db.Column(db.String(512), nullable=True)
-
+
# Provider status
is_enabled = db.Column(db.Boolean, default=True, nullable=False)
-
+
# Redirect URL override
redirect_url_override = db.Column(db.String(2048), nullable=True)
-
+
# Provider-specific settings override (JSON)
additional_config = db.Column(db.JSON, nullable=True)
@@ -196,37 +204,33 @@ class OrganizationProviderOverride(BaseModel):
"ApplicationProviderConfig",
back_populates="organization_overrides",
foreign_keys=[provider_type],
- primaryjoin="ApplicationProviderConfig.provider_type==OrganizationProviderOverride.provider_type",
- viewonly=True
+ primaryjoin=(
+ "ApplicationProviderConfig.provider_type"
+ "==OrganizationProviderOverride.provider_type"
+ ),
+ viewonly=True,
)
- # Unique constraint on (organization_id, provider_type)
__table_args__ = (
db.UniqueConstraint(
- "organization_id", "provider_type",
- name="uix_org_provider_type"
+ "organization_id", "provider_type", name="uix_org_provider_type"
),
)
def __repr__(self):
"""String representation of OrganizationProviderOverride."""
- return f""
+ return (
+ f""
+ )
- def set_client_secret(self, plaintext_secret: str):
- """Encrypt and store client secret override.
-
- Args:
- plaintext_secret: The plaintext OAuth client secret
- """
+ def set_client_secret(self, plaintext_secret: str) -> None:
+ """Encrypt and store client secret override."""
if plaintext_secret:
self.client_secret_encrypted = encrypt(plaintext_secret)
- def get_client_secret(self) -> str:
- """Decrypt and return client secret override.
-
- Returns:
- The plaintext OAuth client secret
- """
+ def get_client_secret(self) -> str | None:
+ """Decrypt and return client secret override."""
if self.client_secret_encrypted:
return decrypt(self.client_secret_encrypted)
return None
@@ -234,53 +238,52 @@ class OrganizationProviderOverride(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
- # Always exclude encrypted client secret
- exclude.append("client_secret_encrypted")
+ if "client_secret_encrypted" not in exclude:
+ exclude.append("client_secret_encrypted")
return super().to_dict(exclude=exclude)
class OAuthState(BaseModel):
"""OAuth flow state tracking.
-
- This model tracks OAuth authentication flow state, including PKCE parameters
- and organization context (which is now optional to support login flows where
- the organization isn't known until after authentication).
+
+ Tracks OAuth authentication flow state, including PKCE parameters and
+ organization context (which is optional to support login flows where the
+ organization isn't known until after authentication).
"""
__tablename__ = "oauth_states"
# OAuth state parameter (unique, used for CSRF protection)
state = db.Column(db.String(64), unique=True, nullable=False, index=True)
-
+
# Flow type: "login", "register", "link"
flow_type = db.Column(db.String(50), nullable=False)
-
+
# Provider type
provider_type = db.Column(db.String(50), nullable=False)
-
- # User context (optional - not set for login/register flows)
+
+ # User context (optional — not set for login/register flows)
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True)
-
- # Organization context (NOW OPTIONAL - for SSO discovery or post-auth)
+
+ # Organization context (optional — for SSO discovery or post-auth)
organization_id = db.Column(
- db.String(36), db.ForeignKey("organizations.id"),
- nullable=True, index=True
+ db.String(36), db.ForeignKey("organizations.id"), nullable=True, index=True
)
-
+
# PKCE parameters
nonce = db.Column(db.String(128), nullable=True)
code_verifier = db.Column(db.String(128), nullable=True)
code_challenge = db.Column(db.String(128), nullable=True)
-
+
# OAuth parameters
redirect_uri = db.Column(db.String(2048), nullable=True)
-
+
# Post-auth redirect (for frontend routing)
return_url = db.Column(db.String(2048), nullable=True)
-
+
# Additional state data
extra_data = db.Column(db.JSON, nullable=True)
-
+
# Expiration and usage tracking
expires_at = db.Column(db.DateTime, nullable=False, index=True)
used = db.Column(db.Boolean, default=False, nullable=False)
@@ -291,7 +294,10 @@ class OAuthState(BaseModel):
def __repr__(self):
"""String representation of OAuthState."""
- return f""
+ return (
+ f""
+ )
@classmethod
def create_state(
@@ -306,10 +312,10 @@ class OAuthState(BaseModel):
code_challenge: str = None,
nonce: str = None,
extra_data: dict = None,
- lifetime_seconds: int = 600
- ):
- """Create a new OAuth state with auto-generated state parameter.
-
+ lifetime_seconds: int = 600,
+ ) -> "OAuthState":
+ """Create a new OAuth state with an auto-generated state parameter.
+
Args:
flow_type: Type of flow ("login", "register", "link")
provider_type: OAuth provider type
@@ -322,13 +328,13 @@ class OAuthState(BaseModel):
nonce: OpenID Connect nonce
extra_data: Additional state data
lifetime_seconds: How long the state is valid (default 10 minutes)
-
+
Returns:
New OAuthState instance
"""
state = secrets.token_urlsafe(32)
expires_at = datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds)
-
+
oauth_state = cls(
state=state,
flow_type=flow_type,
@@ -342,31 +348,30 @@ class OAuthState(BaseModel):
nonce=nonce,
extra_data=extra_data,
expires_at=expires_at,
- used=False
+ used=False,
)
oauth_state.save()
return oauth_state
def is_valid(self) -> bool:
"""Check if the OAuth state is still valid.
-
+
Returns:
- True if state hasn't expired and hasn't been used
+ True if state hasn't expired and hasn't been used.
"""
now = datetime.now(timezone.utc)
- # Make expires_at timezone-aware if it's naive (database returns naive datetimes)
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return not self.used and expires_at > now
- def mark_used(self):
+ def mark_used(self) -> None:
"""Mark the state as used to prevent replay attacks."""
self.used = True
self.save()
@classmethod
- def cleanup_expired(cls):
+ def cleanup_expired(cls) -> None:
"""Remove expired OAuth states."""
now = datetime.now(timezone.utc)
cls.query.filter(cls.expires_at < now).delete()
@@ -375,6 +380,7 @@ class OAuthState(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
- # Exclude code_verifier as it's sensitive
- exclude.append("code_verifier")
+ # code_verifier must never be exposed
+ if "code_verifier" not in exclude:
+ exclude.append("code_verifier")
return super().to_dict(exclude=exclude)
diff --git a/gatehouse_app/models/auth/email_verification_token.py b/gatehouse_app/models/auth/email_verification_token.py
new file mode 100644
index 0000000..9f40682
--- /dev/null
+++ b/gatehouse_app/models/auth/email_verification_token.py
@@ -0,0 +1,68 @@
+"""Email verification token model."""
+import secrets
+from datetime import datetime, timezone, timedelta
+
+from gatehouse_app.extensions import db
+from gatehouse_app.models.base import BaseModel
+
+
+class EmailVerificationToken(BaseModel):
+ """Single-use token for verifying a user's email address."""
+
+ __tablename__ = "email_verification_tokens"
+
+ user_id = db.Column(
+ db.String(36),
+ db.ForeignKey("users.id", ondelete="CASCADE"),
+ nullable=False,
+ index=True,
+ )
+ token = db.Column(db.String(128), unique=True, nullable=False, index=True)
+ expires_at = db.Column(db.DateTime, nullable=False)
+ used_at = db.Column(db.DateTime, nullable=True)
+
+ user = db.relationship(
+ "User",
+ backref=db.backref("email_verification_tokens", cascade="all, delete-orphan"),
+ )
+
+ @classmethod
+ def generate(cls, user_id: str, ttl_hours: int = 24) -> "EmailVerificationToken":
+ """Create a new verification token for a user.
+
+ Any existing unused tokens for this user are invalidated first.
+ """
+ cls.query.filter_by(user_id=user_id, used_at=None).delete()
+ db.session.flush()
+
+ token_value = secrets.token_urlsafe(48)
+ instance = cls(
+ user_id=user_id,
+ token=token_value,
+ expires_at=datetime.now(timezone.utc) + timedelta(hours=ttl_hours),
+ )
+ db.session.add(instance)
+ db.session.commit()
+ return instance
+
+ @property
+ def is_valid(self) -> bool:
+ """Return True if the token has not been used and has not expired."""
+ if self.used_at is not None:
+ return False
+ now = datetime.now(timezone.utc)
+ expires = self.expires_at
+ if expires.tzinfo is None:
+ expires = expires.replace(tzinfo=timezone.utc)
+ return now < expires
+
+ def consume(self) -> None:
+ """Mark the token as used."""
+ self.used_at = datetime.now(timezone.utc)
+ db.session.commit()
+
+ def __repr__(self) -> str:
+ return (
+ f""
+ )
diff --git a/gatehouse_app/models/auth/password_reset_token.py b/gatehouse_app/models/auth/password_reset_token.py
new file mode 100644
index 0000000..53072ef
--- /dev/null
+++ b/gatehouse_app/models/auth/password_reset_token.py
@@ -0,0 +1,69 @@
+"""Password reset token model."""
+import secrets
+from datetime import datetime, timezone, timedelta
+
+from gatehouse_app.extensions import db
+from gatehouse_app.models.base import BaseModel
+
+
+class PasswordResetToken(BaseModel):
+ """Single-use token for resetting a user's password."""
+
+ __tablename__ = "password_reset_tokens"
+
+ user_id = db.Column(
+ db.String(36),
+ db.ForeignKey("users.id", ondelete="CASCADE"),
+ nullable=False,
+ index=True,
+ )
+ token = db.Column(db.String(128), unique=True, nullable=False, index=True)
+ expires_at = db.Column(db.DateTime, nullable=False)
+ used_at = db.Column(db.DateTime, nullable=True)
+
+ user = db.relationship(
+ "User",
+ backref=db.backref("password_reset_tokens", cascade="all, delete-orphan"),
+ )
+
+ @classmethod
+ def generate(cls, user_id: str, ttl_hours: int = 2) -> "PasswordResetToken":
+ """Create a new password reset token for a user.
+
+ Any existing unused tokens for this user are invalidated first.
+ """
+ # Invalidate any existing unused tokens for this user
+ cls.query.filter_by(user_id=user_id, used_at=None).delete()
+ db.session.flush()
+
+ token_value = secrets.token_urlsafe(48)
+ instance = cls(
+ user_id=user_id,
+ token=token_value,
+ expires_at=datetime.now(timezone.utc) + timedelta(hours=ttl_hours),
+ )
+ db.session.add(instance)
+ db.session.commit()
+ return instance
+
+ @property
+ def is_valid(self) -> bool:
+ """Return True if the token has not been used and has not expired."""
+ if self.used_at is not None:
+ return False
+ now = datetime.now(timezone.utc)
+ expires = self.expires_at
+ if expires.tzinfo is None:
+ expires = expires.replace(tzinfo=timezone.utc)
+ return now < expires
+
+ def consume(self) -> None:
+ """Mark the token as used."""
+ self.used_at = datetime.now(timezone.utc)
+ db.session.commit()
+
+ def __repr__(self) -> str:
+ return (
+ f""
+ )
diff --git a/gatehouse_app/models/base.py b/gatehouse_app/models/base.py
index 0a63735..d1ce1c4 100644
--- a/gatehouse_app/models/base.py
+++ b/gatehouse_app/models/base.py
@@ -82,7 +82,10 @@ class BaseModel(db.Model):
if column.name not in exclude:
value = getattr(self, column.name)
if isinstance(value, datetime):
- result[column.name] = value.isoformat()
+ if value.tzinfo is None:
+ result[column.name] = value.isoformat() + "Z"
+ else:
+ result[column.name] = value.astimezone(timezone.utc).isoformat().replace("+00:00", "Z")
else:
result[column.name] = value
return result
diff --git a/gatehouse_app/models/oidc/__init__.py b/gatehouse_app/models/oidc/__init__.py
new file mode 100644
index 0000000..2cb08da
--- /dev/null
+++ b/gatehouse_app/models/oidc/__init__.py
@@ -0,0 +1,18 @@
+"""OIDC subpackage — clients, tokens, sessions, and audit logs."""
+from gatehouse_app.models.oidc.oidc_client import OIDCClient
+from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode
+from gatehouse_app.models.oidc.oidc_refresh_token import OIDCRefreshToken
+from gatehouse_app.models.oidc.oidc_session import OIDCSession
+from gatehouse_app.models.oidc.oidc_token_metadata import OIDCTokenMetadata
+from gatehouse_app.models.oidc.oidc_audit_log import OIDCAuditLog
+from gatehouse_app.models.oidc.oidc_jwks_key import OidcJwksKey
+
+__all__ = [
+ "OIDCClient",
+ "OIDCAuthCode",
+ "OIDCRefreshToken",
+ "OIDCSession",
+ "OIDCTokenMetadata",
+ "OIDCAuditLog",
+ "OidcJwksKey",
+]
diff --git a/gatehouse_app/models/oidc_audit_log.py b/gatehouse_app/models/oidc/oidc_audit_log.py
similarity index 68%
rename from gatehouse_app/models/oidc_audit_log.py
rename to gatehouse_app/models/oidc/oidc_audit_log.py
index 39b21a5..c0ae557 100644
--- a/gatehouse_app/models/oidc_audit_log.py
+++ b/gatehouse_app/models/oidc/oidc_audit_log.py
@@ -1,5 +1,4 @@
"""OIDC Audit Log model for comprehensive OIDC event tracking."""
-from datetime import datetime
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
@@ -7,8 +6,7 @@ from gatehouse_app.models.base import BaseModel
class OIDCAuditLog(BaseModel):
"""OIDC Audit Log model for comprehensive OIDC event tracking.
- This model logs all OIDC-related events for security, compliance,
- and debugging purposes.
+ Logs all OIDC-related events for security, compliance, and debugging.
"""
__tablename__ = "oidc_audit_logs"
@@ -46,16 +44,29 @@ class OIDCAuditLog(BaseModel):
def __repr__(self):
"""String representation of OIDCAuditLog."""
status = "success" if self.success else "failed"
- return f""
+ return (
+ f""
+ )
@classmethod
- def log_event(cls, event_type, client_id=None, user_id=None, success=True,
- error_code=None, error_description=None, ip_address=None,
- user_agent=None, request_id=None, event_metadata=None):
+ def log_event(
+ cls,
+ event_type: str,
+ client_id: str = None,
+ user_id: str = None,
+ success: bool = True,
+ error_code: str = None,
+ error_description: str = None,
+ ip_address: str = None,
+ user_agent: str = None,
+ request_id: str = None,
+ event_metadata: dict = None,
+ ) -> "OIDCAuditLog":
"""Log an OIDC event.
Args:
- event_type: Type of event (e.g., "authorization_request", "token_issue")
+ event_type: Type of event (e.g., "authorization_request")
client_id: The OIDC client ID
user_id: The user ID
success: Whether the event was successful
@@ -86,9 +97,19 @@ class OIDCAuditLog(BaseModel):
return log
@classmethod
- def log_authorization_request(cls, client_id, user_id, redirect_uri, scope,
- ip_address=None, user_agent=None, request_id=None,
- success=True, error_code=None, error_description=None):
+ def log_authorization_request(
+ cls,
+ client_id: str,
+ user_id: str,
+ redirect_uri: str,
+ scope,
+ ip_address: str = None,
+ user_agent: str = None,
+ request_id: str = None,
+ success: bool = True,
+ error_code: str = None,
+ error_description: str = None,
+ ) -> "OIDCAuditLog":
"""Log an authorization request event."""
return cls.log_event(
event_type="authorization_request",
@@ -100,15 +121,19 @@ class OIDCAuditLog(BaseModel):
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
- event_metadata={
- "redirect_uri": redirect_uri,
- "scope": scope,
- }
+ event_metadata={"redirect_uri": redirect_uri, "scope": scope},
)
@classmethod
- def log_token_issue(cls, client_id, user_id, token_type,
- ip_address=None, user_agent=None, request_id=None):
+ def log_token_issue(
+ cls,
+ client_id: str,
+ user_id: str,
+ token_type: str,
+ ip_address: str = None,
+ user_agent: str = None,
+ request_id: str = None,
+ ) -> "OIDCAuditLog":
"""Log a token issuance event."""
return cls.log_event(
event_type="token_issue",
@@ -118,12 +143,20 @@ class OIDCAuditLog(BaseModel):
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
- event_metadata={"token_type": token_type}
+ event_metadata={"token_type": token_type},
)
@classmethod
- def log_token_revocation(cls, client_id, user_id, token_type, reason=None,
- ip_address=None, user_agent=None, request_id=None):
+ def log_token_revocation(
+ cls,
+ client_id: str,
+ user_id: str,
+ token_type: str,
+ reason: str = None,
+ ip_address: str = None,
+ user_agent: str = None,
+ request_id: str = None,
+ ) -> "OIDCAuditLog":
"""Log a token revocation event."""
return cls.log_event(
event_type="token_revocation",
@@ -133,15 +166,19 @@ class OIDCAuditLog(BaseModel):
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
- event_metadata={
- "token_type": token_type,
- "reason": reason,
- }
+ event_metadata={"token_type": token_type, "reason": reason},
)
@classmethod
- def log_authentication_failure(cls, client_id, error_code, error_description,
- ip_address=None, user_agent=None, request_id=None):
+ def log_authentication_failure(
+ cls,
+ client_id: str,
+ error_code: str,
+ error_description: str,
+ ip_address: str = None,
+ user_agent: str = None,
+ request_id: str = None,
+ ) -> "OIDCAuditLog":
"""Log an authentication failure event."""
return cls.log_event(
event_type="authentication_failure",
@@ -155,7 +192,7 @@ class OIDCAuditLog(BaseModel):
)
@classmethod
- def get_events_for_user(cls, user_id, limit=100):
+ def get_events_for_user(cls, user_id: str, limit: int = 100) -> list:
"""Get audit events for a user.
Args:
@@ -165,13 +202,15 @@ class OIDCAuditLog(BaseModel):
Returns:
List of OIDCAuditLog instances
"""
- return cls.query.filter_by(user_id=user_id, deleted_at=None)\
- .order_by(cls.created_at.desc())\
- .limit(limit)\
+ return (
+ cls.query.filter_by(user_id=user_id, deleted_at=None)
+ .order_by(cls.created_at.desc())
+ .limit(limit)
.all()
+ )
@classmethod
- def get_events_for_client(cls, client_id, limit=100):
+ def get_events_for_client(cls, client_id: str, limit: int = 100) -> list:
"""Get audit events for a client.
Args:
@@ -181,14 +220,22 @@ class OIDCAuditLog(BaseModel):
Returns:
List of OIDCAuditLog instances
"""
- return cls.query.filter_by(client_id=client_id, deleted_at=None)\
- .order_by(cls.created_at.desc())\
- .limit(limit)\
+ return (
+ cls.query.filter_by(client_id=client_id, deleted_at=None)
+ .order_by(cls.created_at.desc())
+ .limit(limit)
.all()
+ )
@classmethod
- def get_failed_events(cls, client_id=None, user_id=None, start_date=None,
- end_date=None, limit=100):
+ def get_failed_events(
+ cls,
+ client_id: str = None,
+ user_id: str = None,
+ start_date=None,
+ end_date=None,
+ limit: int = 100,
+ ) -> list:
"""Get failed audit events.
Args:
@@ -210,22 +257,8 @@ class OIDCAuditLog(BaseModel):
query = query.filter(cls.created_at >= start_date)
if end_date:
query = query.filter(cls.created_at <= end_date)
-
return query.order_by(cls.created_at.desc()).limit(limit).all()
def to_dict(self, exclude=None):
"""Convert to dictionary."""
return super().to_dict(exclude=exclude)
-
-
-# Add relationship back to User model
-from gatehouse_app.models.user import User
-User.oidc_audit_logs = db.relationship(
- "OIDCAuditLog", back_populates="user", cascade="all, delete-orphan"
-)
-
-# Add relationship back to OIDCClient model
-from gatehouse_app.models.oidc_client import OIDCClient
-OIDCClient.audit_logs = db.relationship(
- "OIDCAuditLog", back_populates="client", cascade="all, delete-orphan"
-)
diff --git a/gatehouse_app/models/oidc_authorization_code.py b/gatehouse_app/models/oidc/oidc_authorization_code.py
similarity index 65%
rename from gatehouse_app/models/oidc_authorization_code.py
rename to gatehouse_app/models/oidc/oidc_authorization_code.py
index 640078e..3884592 100644
--- a/gatehouse_app/models/oidc_authorization_code.py
+++ b/gatehouse_app/models/oidc/oidc_authorization_code.py
@@ -1,14 +1,14 @@
-"""OIDC Authorization Code model for auth code flow."""
+"""OIDC Authorization Code model for the authorization code grant flow."""
from datetime import datetime, timedelta, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class OIDCAuthCode(BaseModel):
- """OIDC Authorization Code model for authorization code flow.
+ """OIDC Authorization Code model for the authorization code grant flow.
- Authorization codes are single-use, short-lived codes used in the
- authorization code grant flow. The code is hashed for security.
+ Authorization codes are single-use, short-lived codes. The code itself is
+ hashed before storage so that a database breach cannot replay codes.
"""
__tablename__ = "oidc_authorization_codes"
@@ -26,9 +26,9 @@ class OIDCAuthCode(BaseModel):
# Request parameters
redirect_uri = db.Column(db.String(512), nullable=False)
- scope = db.Column(db.JSON, nullable=True) # Requested scopes
- nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation
- code_verifier = db.Column(db.String(255), nullable=True) # For PKCE
+ scope = db.Column(db.JSON, nullable=True)
+ nonce = db.Column(db.String(255), nullable=True)
+ code_verifier = db.Column(db.String(255), nullable=True)
# Status tracking
expires_at = db.Column(db.DateTime, nullable=False, index=True)
@@ -39,37 +39,48 @@ class OIDCAuthCode(BaseModel):
ip_address = db.Column(db.String(45), nullable=True)
user_agent = db.Column(db.Text, nullable=True)
- # Relationships
+ # Relationships — back_populates declared on User and OIDCClient
client = db.relationship("OIDCClient", back_populates="authorization_codes")
user = db.relationship("User", back_populates="oidc_auth_codes")
def __repr__(self):
"""String representation of OIDCAuthCode."""
- return f""
+ return (
+ f""
+ )
- def is_expired(self):
+ def is_expired(self) -> bool:
"""Check if the authorization code has expired."""
- # Handle both timezone-aware and timezone-naive expires_at values
expires_at = self.expires_at
if expires_at.tzinfo is None:
- # Make naive datetime timezone-aware (UTC)
expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at
- def is_valid(self):
+ def is_valid(self) -> bool:
"""Check if the authorization code is valid for use."""
return not self.is_used and not self.is_expired() and self.deleted_at is None
- def mark_as_used(self):
+ def mark_as_used(self) -> None:
"""Mark the authorization code as used."""
self.is_used = True
self.used_at = datetime.now(timezone.utc)
db.session.commit()
@classmethod
- def create_code(cls, client_id, user_id, code_hash, redirect_uri, scope=None,
- nonce=None, code_verifier=None, ip_address=None, user_agent=None,
- lifetime_seconds=600):
+ def create_code(
+ cls,
+ client_id: str,
+ user_id: str,
+ code_hash: str,
+ redirect_uri: str,
+ scope=None,
+ nonce: str = None,
+ code_verifier: str = None,
+ ip_address: str = None,
+ user_agent: str = None,
+ lifetime_seconds: int = 600,
+ ) -> "OIDCAuthCode":
"""Create a new authorization code.
Args:
@@ -79,7 +90,7 @@ class OIDCAuthCode(BaseModel):
redirect_uri: The redirect URI
scope: Requested scopes
nonce: OIDC nonce
- code_verifier: PKCE code verifier
+ code_verifier: PKCE code verifier (stored hashed server-side)
ip_address: Client IP address
user_agent: Client user agent
lifetime_seconds: Code lifetime in seconds (default 10 minutes)
@@ -106,20 +117,7 @@ class OIDCAuthCode(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
- # Always exclude code hash
- exclude.append("code_hash")
- exclude.append("code_verifier")
+ for field in ("code_hash", "code_verifier"):
+ if field not in exclude:
+ exclude.append(field)
return super().to_dict(exclude=exclude)
-
-
-# Add relationship back to User model
-from gatehouse_app.models.user import User
-User.oidc_auth_codes = db.relationship(
- "OIDCAuthCode", back_populates="user", cascade="all, delete-orphan"
-)
-
-# Add relationship back to OIDCClient model
-from gatehouse_app.models.oidc_client import OIDCClient
-OIDCClient.authorization_codes = db.relationship(
- "OIDCAuthCode", back_populates="client", cascade="all, delete-orphan"
-)
diff --git a/gatehouse_app/models/oidc_client.py b/gatehouse_app/models/oidc/oidc_client.py
similarity index 63%
rename from gatehouse_app/models/oidc_client.py
rename to gatehouse_app/models/oidc/oidc_client.py
index a446983..03c0b18 100644
--- a/gatehouse_app/models/oidc_client.py
+++ b/gatehouse_app/models/oidc/oidc_client.py
@@ -17,10 +17,10 @@ class OIDCClient(BaseModel):
client_secret_hash = db.Column(db.String(255), nullable=False)
# OAuth/OIDC configuration
- redirect_uris = db.Column(db.JSON, nullable=False) # List of allowed redirect URIs
- grant_types = db.Column(db.JSON, nullable=False) # List of allowed grant types
- response_types = db.Column(db.JSON, nullable=False) # List of allowed response types
- scopes = db.Column(db.JSON, nullable=False) # List of allowed scopes
+ redirect_uris = db.Column(db.JSON, nullable=False) # Allowed redirect URIs
+ grant_types = db.Column(db.JSON, nullable=False) # Allowed grant types
+ response_types = db.Column(db.JSON, nullable=False) # Allowed response types
+ scopes = db.Column(db.JSON, nullable=False) # Allowed scopes
# Client metadata
logo_uri = db.Column(db.String(512), nullable=True)
@@ -41,6 +41,23 @@ class OIDCClient(BaseModel):
# Relationships
organization = db.relationship("Organization", back_populates="oidc_clients")
+ # OIDC sub-resource relationships (declared here, not monkey-patched elsewhere)
+ authorization_codes = db.relationship(
+ "OIDCAuthCode", back_populates="client", cascade="all, delete-orphan"
+ )
+ refresh_tokens = db.relationship(
+ "OIDCRefreshToken", back_populates="client", cascade="all, delete-orphan"
+ )
+ oidc_sessions = db.relationship(
+ "OIDCSession", back_populates="client", cascade="all, delete-orphan"
+ )
+ token_metadata = db.relationship(
+ "OIDCTokenMetadata", back_populates="client", cascade="all, delete-orphan"
+ )
+ audit_logs = db.relationship(
+ "OIDCAuditLog", back_populates="client", cascade="all, delete-orphan"
+ )
+
def __repr__(self):
"""String representation of OIDCClient."""
return f""
@@ -48,22 +65,22 @@ class OIDCClient(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
- # Always exclude client secret
- exclude.append("client_secret_hash")
+ if "client_secret_hash" not in exclude:
+ exclude.append("client_secret_hash")
return super().to_dict(exclude=exclude)
- def has_grant_type(self, grant_type):
+ def has_grant_type(self, grant_type) -> bool:
"""Check if client supports a specific grant type."""
return grant_type in self.grant_types
- def has_response_type(self, response_type):
+ def has_response_type(self, response_type) -> bool:
"""Check if client supports a specific response type."""
return response_type in self.response_types
- def is_redirect_uri_allowed(self, redirect_uri):
+ def is_redirect_uri_allowed(self, redirect_uri: str) -> bool:
"""Check if a redirect URI is allowed for this client."""
return redirect_uri in self.redirect_uris
- def has_scope(self, scope):
+ def has_scope(self, scope: str) -> bool:
"""Check if client is allowed to request a specific scope."""
return scope in self.scopes
diff --git a/gatehouse_app/models/oidc/oidc_jwks_key.py b/gatehouse_app/models/oidc/oidc_jwks_key.py
new file mode 100644
index 0000000..f8fa982
--- /dev/null
+++ b/gatehouse_app/models/oidc/oidc_jwks_key.py
@@ -0,0 +1,76 @@
+"""OIDC JWKS Key model for persisting signing keys."""
+from datetime import datetime, timezone
+from gatehouse_app.extensions import db
+from gatehouse_app.models.base import BaseModel
+
+
+class OidcJwksKey(BaseModel):
+ """OIDC JWKS Key model for persisting JSON Web Key Set signing keys.
+
+ Stores RSA/ECDSA key pairs used for signing OIDC tokens. Multiple keys can
+ be stored to support key rotation scenarios.
+
+ Attributes:
+ kid: Unique key ID used in JWT ``kid`` header
+ key_type: Type of key (e.g., "RSA", "EC")
+ private_key: PEM-encoded private key (never exposed in API responses)
+ public_key: PEM-encoded public key
+ algorithm: Signing algorithm (e.g., "RS256", "ES256")
+ is_active: Whether this key is currently used for signing/verification
+ is_primary: Whether this is the primary signing key
+ expires_at: Optional expiry for key rotation enforcement
+ """
+
+ __tablename__ = "oidc_jwks_keys"
+
+ # Override the default UUID id with integer primary key for JWKS key sets
+ id = db.Column(db.Integer, primary_key=True)
+
+ expires_at = db.Column(db.DateTime, nullable=True)
+
+ # Key identification and type
+ kid = db.Column(db.String(255), unique=True, nullable=False, index=True)
+ key_type = db.Column(db.String(50), nullable=False) # e.g., "RSA", "EC"
+ algorithm = db.Column(db.String(50), nullable=False) # e.g., "RS256", "ES256"
+
+ # Key material (PEM-encoded) — private_key must never be returned by API
+ private_key = db.Column(db.Text, nullable=False)
+ public_key = db.Column(db.Text, nullable=False)
+
+ # Key status
+ is_active = db.Column(db.Boolean, default=True, nullable=False)
+ is_primary = db.Column(db.Boolean, default=False, nullable=False)
+
+ def __repr__(self):
+ """String representation of OidcJwksKey."""
+ return (
+ f""
+ )
+
+ def to_dict(self, exclude_private_key: bool = True):
+ """Convert model to dictionary.
+
+ Args:
+ exclude_private_key: If True (default), excludes the private key.
+
+ Returns:
+ Dictionary representation of the model
+ """
+ exclude = ["private_key"] if exclude_private_key else []
+ return super().to_dict(exclude=exclude)
+
+ @classmethod
+ def get_active_keys(cls) -> list:
+ """Get all active keys for signing operations."""
+ return cls.query.filter_by(is_active=True).all()
+
+ @classmethod
+ def get_primary_key(cls) -> "OidcJwksKey | None":
+ """Get the primary signing key."""
+ return cls.query.filter_by(is_primary=True).first()
+
+ @classmethod
+ def get_key_by_kid(cls, kid: str) -> "OidcJwksKey | None":
+ """Get an active key by its key ID."""
+ return cls.query.filter_by(kid=kid, is_active=True).first()
diff --git a/gatehouse_app/models/oidc_refresh_token.py b/gatehouse_app/models/oidc/oidc_refresh_token.py
similarity index 67%
rename from gatehouse_app/models/oidc_refresh_token.py
rename to gatehouse_app/models/oidc/oidc_refresh_token.py
index a6459ea..3e1228f 100644
--- a/gatehouse_app/models/oidc_refresh_token.py
+++ b/gatehouse_app/models/oidc/oidc_refresh_token.py
@@ -1,5 +1,5 @@
"""OIDC Refresh Token model for token rotation."""
-from datetime import datetime, timezone
+from datetime import datetime, timedelta, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
@@ -8,7 +8,8 @@ class OIDCRefreshToken(BaseModel):
"""OIDC Refresh Token model for token refresh and rotation.
Refresh tokens are long-lived credentials used to obtain new access tokens.
- They support token rotation for enhanced security.
+ They support token rotation for enhanced security — each use invalidates
+ the old token and issues a new one.
"""
__tablename__ = "oidc_refresh_tokens"
@@ -21,16 +22,14 @@ class OIDCRefreshToken(BaseModel):
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
)
- # Token (hashed for security)
+ # Token (hashed for security — never store plaintext refresh tokens)
token_hash = db.Column(db.String(255), nullable=False, unique=True, index=True)
- # Associated access token ID (stores JWT JTI string — no FK to sessions)
- access_token_id = db.Column(
- db.String(255), nullable=True, index=True
- )
+ # Associated access token JTI (no FK — stored as string for lightweight lookup)
+ access_token_id = db.Column(db.String(255), nullable=True, index=True)
# Token scope
- scope = db.Column(db.JSON, nullable=True) # Granted scopes
+ scope = db.Column(db.JSON, nullable=True)
# Timing
expires_at = db.Column(db.DateTime, nullable=False, index=True)
@@ -40,7 +39,7 @@ class OIDCRefreshToken(BaseModel):
revoked_reason = db.Column(db.String(255), nullable=True)
# Token rotation metadata
- previous_token_hash = db.Column(db.String(255), nullable=True) # For rotation
+ previous_token_hash = db.Column(db.String(255), nullable=True)
rotation_count = db.Column(db.Integer, default=0, nullable=False)
# Request metadata
@@ -53,25 +52,27 @@ class OIDCRefreshToken(BaseModel):
def __repr__(self):
"""String representation of OIDCRefreshToken."""
- return f""
+ return (
+ f""
+ )
- def is_expired(self):
+ def is_expired(self) -> bool:
"""Check if the refresh token has expired."""
- # Handle both timezone-aware and timezone-naive expires_at values
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at
- def is_revoked(self):
+ def is_revoked(self) -> bool:
"""Check if the refresh token has been revoked."""
return self.revoked_at is not None
- def is_valid(self):
+ def is_valid(self) -> bool:
"""Check if the refresh token is valid for use."""
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
- def revoke(self, reason=None):
+ def revoke(self, reason: str = None) -> None:
"""Revoke the refresh token.
Args:
@@ -81,8 +82,8 @@ class OIDCRefreshToken(BaseModel):
self.revoked_reason = reason
db.session.commit()
- def rotate(self, new_token_hash):
- """Rotate the refresh token (invalidate old, create new).
+ def rotate(self, new_token_hash: str) -> "OIDCRefreshToken":
+ """Rotate the refresh token — invalidate the old hash, store the new one.
Args:
new_token_hash: Hash of the new refresh token
@@ -90,20 +91,25 @@ class OIDCRefreshToken(BaseModel):
Returns:
self for chaining
"""
- # Store reference to old token
self.previous_token_hash = self.token_hash
self.token_hash = new_token_hash
self.rotation_count += 1
- # Extend expiration on rotation
- from datetime import timedelta
self.expires_at = datetime.now(timezone.utc) + timedelta(days=30)
db.session.commit()
return self
@classmethod
- def create_token(cls, client_id, user_id, token_hash, scope=None,
- access_token_id=None, ip_address=None, user_agent=None,
- lifetime_seconds=2592000):
+ def create_token(
+ cls,
+ client_id: str,
+ user_id: str,
+ token_hash: str,
+ scope=None,
+ access_token_id: str = None,
+ ip_address: str = None,
+ user_agent: str = None,
+ lifetime_seconds: int = 2592000,
+ ) -> "OIDCRefreshToken":
"""Create a new refresh token.
Args:
@@ -111,7 +117,7 @@ class OIDCRefreshToken(BaseModel):
user_id: The user ID
token_hash: Hashed refresh token
scope: Granted scopes
- access_token_id: Associated access token ID
+ access_token_id: Associated access token JTI
ip_address: Client IP address
user_agent: Client user agent
lifetime_seconds: Token lifetime in seconds (default 30 days)
@@ -119,7 +125,6 @@ class OIDCRefreshToken(BaseModel):
Returns:
OIDCRefreshToken instance
"""
- from datetime import timedelta
token = cls(
client_id=client_id,
user_id=user_id,
@@ -137,20 +142,7 @@ class OIDCRefreshToken(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
- # Always exclude token hashes
- exclude.append("token_hash")
- exclude.append("previous_token_hash")
+ for field in ("token_hash", "previous_token_hash"):
+ if field not in exclude:
+ exclude.append(field)
return super().to_dict(exclude=exclude)
-
-
-# Add relationship back to User model
-from gatehouse_app.models.user import User
-User.oidc_refresh_tokens = db.relationship(
- "OIDCRefreshToken", back_populates="user", cascade="all, delete-orphan"
-)
-
-# Add relationship back to OIDCClient model
-from gatehouse_app.models.oidc_client import OIDCClient
-OIDCClient.refresh_tokens = db.relationship(
- "OIDCRefreshToken", back_populates="client", cascade="all, delete-orphan"
-)
diff --git a/gatehouse_app/models/oidc_session.py b/gatehouse_app/models/oidc/oidc_session.py
similarity index 70%
rename from gatehouse_app/models/oidc_session.py
rename to gatehouse_app/models/oidc/oidc_session.py
index 8d6a88b..5768bd6 100644
--- a/gatehouse_app/models/oidc_session.py
+++ b/gatehouse_app/models/oidc/oidc_session.py
@@ -1,5 +1,7 @@
"""OIDC Session model for OIDC session tracking."""
-from datetime import datetime, timezone
+import hashlib
+import base64
+from datetime import datetime, timedelta, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
@@ -7,8 +9,8 @@ from gatehouse_app.models.base import BaseModel
class OIDCSession(BaseModel):
"""OIDC Session model for tracking OIDC authentication sessions.
- This model tracks the state during the OIDC authentication flow,
- including PKCE parameters and nonce validation.
+ Tracks the state during the OIDC authorization flow, including PKCE
+ parameters and nonce validation.
"""
__tablename__ = "oidc_sessions"
@@ -25,11 +27,11 @@ class OIDCSession(BaseModel):
# State management
state = db.Column(db.String(255), nullable=False, index=True)
- nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation
+ nonce = db.Column(db.String(255), nullable=True)
# Authorization request parameters
redirect_uri = db.Column(db.String(512), nullable=False)
- scope = db.Column(db.JSON, nullable=True) # Requested scopes
+ scope = db.Column(db.JSON, nullable=True)
# PKCE parameters
code_challenge = db.Column(db.String(255), nullable=True)
@@ -45,50 +47,52 @@ class OIDCSession(BaseModel):
def __repr__(self):
"""String representation of OIDCSession."""
- return f""
+ return (
+ f""
+ )
- def is_expired(self):
+ def is_expired(self) -> bool:
"""Check if the OIDC session has expired."""
- return datetime.now(timezone.utc) > self.expires_at
+ expires_at = self.expires_at
+ if expires_at.tzinfo is None:
+ expires_at = expires_at.replace(tzinfo=timezone.utc)
+ return datetime.now(timezone.utc) > expires_at
- def is_authenticated(self):
+ def is_authenticated(self) -> bool:
"""Check if the user has been authenticated in this session."""
return self.authenticated_at is not None
- def mark_authenticated(self):
+ def mark_authenticated(self) -> None:
"""Mark the session as authenticated."""
self.authenticated_at = datetime.now(timezone.utc)
db.session.commit()
- def validate_nonce(self, expected_nonce):
+ def validate_nonce(self, expected_nonce: str) -> bool:
"""Validate the nonce matches the expected value.
Args:
expected_nonce: The expected nonce value
Returns:
- bool: True if nonce matches
+ True if nonce matches
"""
return self.nonce == expected_nonce
- def validate_code_challenge(self, code_verifier):
+ def validate_code_challenge(self, code_verifier: str) -> bool:
"""Validate the code verifier against the stored code challenge.
Args:
code_verifier: The PKCE code verifier
Returns:
- bool: True if code challenge is valid
+ True if the challenge is satisfied
"""
if not self.code_challenge:
return False
if self.code_challenge_method == "S256":
- import hashlib
- import base64
- # SHA256 hash of code_verifier
digest = hashlib.sha256(code_verifier.encode()).digest()
- # Base64 URL encode without padding
expected = base64.urlsafe_b64encode(digest).decode().rstrip("=")
return self.code_challenge == expected
elif self.code_challenge_method == "plain":
@@ -97,9 +101,18 @@ class OIDCSession(BaseModel):
return False
@classmethod
- def create_session(cls, user_id, client_id, state, redirect_uri, scope=None,
- nonce=None, code_challenge=None, code_challenge_method=None,
- lifetime_seconds=600):
+ def create_session(
+ cls,
+ user_id: str,
+ client_id: str,
+ state: str,
+ redirect_uri: str,
+ scope=None,
+ nonce: str = None,
+ code_challenge: str = None,
+ code_challenge_method: str = None,
+ lifetime_seconds: int = 600,
+ ) -> "OIDCSession":
"""Create a new OIDC session.
Args:
@@ -116,7 +129,6 @@ class OIDCSession(BaseModel):
Returns:
OIDCSession instance
"""
- from datetime import timedelta
session = cls(
user_id=user_id,
client_id=client_id,
@@ -133,7 +145,7 @@ class OIDCSession(BaseModel):
return session
@classmethod
- def get_by_state(cls, state):
+ def get_by_state(cls, state: str) -> "OIDCSession | None":
"""Get a session by state parameter.
Args:
@@ -147,16 +159,3 @@ class OIDCSession(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary."""
return super().to_dict(exclude=exclude)
-
-
-# Add relationship back to User model
-from gatehouse_app.models.user import User
-User.oidc_sessions = db.relationship(
- "OIDCSession", back_populates="user", cascade="all, delete-orphan"
-)
-
-# Add relationship back to OIDCClient model
-from gatehouse_app.models.oidc_client import OIDCClient
-OIDCClient.oidc_sessions = db.relationship(
- "OIDCSession", back_populates="client", cascade="all, delete-orphan"
-)
diff --git a/gatehouse_app/models/oidc_token_metadata.py b/gatehouse_app/models/oidc/oidc_token_metadata.py
similarity index 67%
rename from gatehouse_app/models/oidc_token_metadata.py
rename to gatehouse_app/models/oidc/oidc_token_metadata.py
index 2c6c7a8..be8c862 100644
--- a/gatehouse_app/models/oidc_token_metadata.py
+++ b/gatehouse_app/models/oidc/oidc_token_metadata.py
@@ -8,13 +8,14 @@ from gatehouse_app.models.base import BaseModel
class OIDCTokenMetadata(BaseModel):
"""OIDC Token Metadata model for tracking issued tokens.
- This model stores metadata about issued tokens (access tokens, refresh tokens, ID tokens)
- for the purpose of token revocation. The id field matches the JTI (JWT ID) claim.
+ Stores metadata about issued tokens (access, refresh, ID) for revocation.
+ The ``id`` field on this model intentionally overrides the BaseModel UUID
+ to store the JWT JTI directly as the primary key for O(1) revocation checks.
"""
__tablename__ = "oidc_token_metadata"
- # Token identifier (matches JTI in JWT)
+ # Primary key = JTI so revocation lookups are always a PK scan
id = db.Column(
db.String(36), primary_key=True, default=lambda: str(uuid.uuid4())
)
@@ -27,11 +28,11 @@ class OIDCTokenMetadata(BaseModel):
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
)
- # Token type
- token_type = db.Column(db.String(50), nullable=False) # "access_token", "refresh_token", "id_token"
+ # Token type: "access_token", "refresh_token", or "id_token"
+ token_type = db.Column(db.String(50), nullable=False)
- # Token identifier for revocation lookup
- token_jti = db.Column(db.String(255), nullable=False, index=True) # JWT ID claim
+ # JWT ID claim (indexed for fast lookup when id != jti)
+ token_jti = db.Column(db.String(255), nullable=False, index=True)
# Timing
expires_at = db.Column(db.DateTime, nullable=False, index=True)
@@ -46,25 +47,27 @@ class OIDCTokenMetadata(BaseModel):
def __repr__(self):
"""String representation of OIDCTokenMetadata."""
- return f""
+ return (
+ f""
+ )
- def is_expired(self):
+ def is_expired(self) -> bool:
"""Check if the token has expired."""
- # Handle both timezone-aware and timezone-naive expires_at values
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at
- def is_revoked(self):
+ def is_revoked(self) -> bool:
"""Check if the token has been revoked."""
return self.revoked_at is not None
- def is_valid(self):
+ def is_valid(self) -> bool:
"""Check if the token is valid (not expired and not revoked)."""
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
- def revoke(self, reason=None):
+ def revoke(self, reason: str = None) -> None:
"""Revoke the token.
Args:
@@ -75,8 +78,16 @@ class OIDCTokenMetadata(BaseModel):
db.session.commit()
@classmethod
- def create_metadata(cls, client_id, user_id, token_type, token_jti,
- expires_at, ip_address=None, user_agent=None):
+ def create_metadata(
+ cls,
+ client_id: str,
+ user_id: str,
+ token_type: str,
+ token_jti: str,
+ expires_at,
+ ip_address: str = None,
+ user_agent: str = None,
+ ) -> "OIDCTokenMetadata":
"""Create token metadata for tracking.
Args:
@@ -85,8 +96,8 @@ class OIDCTokenMetadata(BaseModel):
token_type: Type of token ("access_token", "refresh_token", "id_token")
token_jti: JWT ID claim
expires_at: Token expiration datetime
- ip_address: Client IP address
- user_agent: Client user agent
+ ip_address: Client IP address (unused column, kept for API compat)
+ user_agent: Client user agent (unused column, kept for API compat)
Returns:
OIDCTokenMetadata instance
@@ -104,7 +115,7 @@ class OIDCTokenMetadata(BaseModel):
return metadata
@classmethod
- def get_by_jti(cls, token_jti):
+ def get_by_jti(cls, token_jti: str) -> "OIDCTokenMetadata | None":
"""Get token metadata by JWT ID.
Args:
@@ -116,7 +127,7 @@ class OIDCTokenMetadata(BaseModel):
return cls.query.filter_by(token_jti=token_jti, deleted_at=None).first()
@classmethod
- def revoke_by_jti(cls, token_jti, reason=None):
+ def revoke_by_jti(cls, token_jti: str, reason: str = None) -> bool:
"""Revoke a token by its JWT ID.
Args:
@@ -124,7 +135,7 @@ class OIDCTokenMetadata(BaseModel):
reason: Optional revocation reason
Returns:
- bool: True if token was found and revoked
+ True if token was found and revoked, False otherwise
"""
metadata = cls.get_by_jti(token_jti)
if metadata:
@@ -133,47 +144,53 @@ class OIDCTokenMetadata(BaseModel):
return False
@classmethod
- def revoke_all_for_user(cls, user_id, client_id=None, reason=None):
+ def revoke_all_for_user(
+ cls, user_id: str, client_id: str = None, reason: str = None
+ ) -> int:
"""Revoke all tokens for a user.
Args:
user_id: The user ID
- client_id: Optional client ID to filter by
+ client_id: Optional client ID filter
reason: Optional revocation reason
Returns:
- int: Number of tokens revoked
+ Number of tokens revoked
"""
- query = cls.query.filter_by(user_id=user_id, deleted_at=None)
+ query = cls.query.filter_by(user_id=user_id, deleted_at=None).filter(
+ cls.revoked_at.is_(None)
+ )
if client_id:
query = query.filter_by(client_id=client_id)
- tokens = query.filter(cls.revoked_at == None).all()
count = 0
- for token in tokens:
+ for token in query.all():
token.revoke(reason)
count += 1
return count
@classmethod
- def revoke_all_for_client(cls, client_id, user_id=None, reason=None):
+ def revoke_all_for_client(
+ cls, client_id: str, user_id: str = None, reason: str = None
+ ) -> int:
"""Revoke all tokens for a client.
Args:
client_id: The client ID
- user_id: Optional user ID to filter by
+ user_id: Optional user ID filter
reason: Optional revocation reason
Returns:
- int: Number of tokens revoked
+ Number of tokens revoked
"""
- query = cls.query.filter_by(client_id=client_id, deleted_at=None)
+ query = cls.query.filter_by(client_id=client_id, deleted_at=None).filter(
+ cls.revoked_at.is_(None)
+ )
if user_id:
query = query.filter_by(user_id=user_id)
- tokens = query.filter(cls.revoked_at == None).all()
count = 0
- for token in tokens:
+ for token in query.all():
token.revoke(reason)
count += 1
return count
@@ -181,16 +198,3 @@ class OIDCTokenMetadata(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary."""
return super().to_dict(exclude=exclude)
-
-
-# Add relationship back to User model
-from gatehouse_app.models.user import User
-User.oidc_token_metadata = db.relationship(
- "OIDCTokenMetadata", back_populates="user", cascade="all, delete-orphan"
-)
-
-# Add relationship back to OIDCClient model
-from gatehouse_app.models.oidc_client import OIDCClient
-OIDCClient.token_metadata = db.relationship(
- "OIDCTokenMetadata", back_populates="client", cascade="all, delete-orphan"
-)
diff --git a/gatehouse_app/models/oidc_jwks_key.py b/gatehouse_app/models/oidc_jwks_key.py
deleted file mode 100644
index 07dcb80..0000000
--- a/gatehouse_app/models/oidc_jwks_key.py
+++ /dev/null
@@ -1,77 +0,0 @@
-"""OIDC JWKS Key model for persisting signing keys."""
-from datetime import datetime, timezone
-from gatehouse_app.extensions import db
-from gatehouse_app.models.base import BaseModel
-
-
-class OidcJwksKey(BaseModel):
- """
- OIDC JWKS Key model for persisting JSON Web Key Set signing keys.
-
- This model stores RSA/ECDSA key pairs used for signing OIDC tokens.
- Multiple keys can be stored to support key rotation scenarios.
-
- Attributes:
- id: Integer primary key
- kid: Unique key ID used in JWT "kid" header
- key_type: Type of key (e.g., "RSA", "EC")
- private_key: PEM-encoded private key
- public_key: PEM-encoded public key
- algorithm: Signing algorithm (e.g., "RS256", "ES256")
- created_at: When the key was created
- is_active: Whether this key is currently active for signing
- is_primary: Whether this is the primary signing key
- expires_at: ...
- """
-
- __tablename__ = "oidc_jwks_keys"
-
- # Override the default UUID id with integer primary key
- id = db.Column(db.Integer, primary_key=True)
-
- expires_at = db.Column(db.DateTime, nullable=True)
-
- # Key identification and type
- kid = db.Column(db.String(255), unique=True, nullable=False, index=True)
- key_type = db.Column(db.String(50), nullable=False) # e.g., "RSA", "EC"
- algorithm = db.Column(db.String(50), nullable=False) # e.g., "RS256", "ES256"
-
- # Key material (PEM-encoded)
- private_key = db.Column(db.Text, nullable=False)
- public_key = db.Column(db.Text, nullable=False)
-
- # Key status
- is_active = db.Column(db.Boolean, default=True, nullable=False)
- is_primary = db.Column(db.Boolean, default=False, nullable=False)
-
- def __repr__(self):
- """String representation of OidcJwksKey."""
- return f""
-
- def to_dict(self, exclude_private_key=True):
- """
- Convert model to dictionary.
-
- Args:
- exclude_private_key: If True, excludes the private key from output
-
- Returns:
- Dictionary representation of the model
- """
- exclude = ["private_key"] if exclude_private_key else []
- return super().to_dict(exclude=exclude)
-
- @classmethod
- def get_active_keys(cls):
- """Get all active keys for signing operations."""
- return cls.query.filter(cls.is_active == True).all()
-
- @classmethod
- def get_primary_key(cls):
- """Get the primary signing key."""
- return cls.query.filter(cls.is_primary == True).first()
-
- @classmethod
- def get_key_by_kid(cls, kid):
- """Get a key by its key ID."""
- return cls.query.filter(cls.kid == kid, cls.is_active == True).first()
\ No newline at end of file
diff --git a/gatehouse_app/models/organization/__init__.py b/gatehouse_app/models/organization/__init__.py
new file mode 100644
index 0000000..aa33f8e
--- /dev/null
+++ b/gatehouse_app/models/organization/__init__.py
@@ -0,0 +1,27 @@
+"""Organization subpackage."""
+from gatehouse_app.models.organization.organization import Organization
+from gatehouse_app.models.organization.organization_member import OrganizationMember
+from gatehouse_app.models.organization.department import (
+ Department,
+ DepartmentMembership,
+ DepartmentPrincipal,
+)
+from gatehouse_app.models.organization.department_cert_policy import (
+ DepartmentCertPolicy,
+ STANDARD_EXTENSIONS,
+)
+from gatehouse_app.models.organization.principal import Principal, PrincipalMembership
+from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
+
+__all__ = [
+ "Organization",
+ "OrganizationMember",
+ "Department",
+ "DepartmentMembership",
+ "DepartmentPrincipal",
+ "DepartmentCertPolicy",
+ "STANDARD_EXTENSIONS",
+ "Principal",
+ "PrincipalMembership",
+ "OrgInviteToken",
+]
diff --git a/gatehouse_app/models/organization/department.py b/gatehouse_app/models/organization/department.py
new file mode 100644
index 0000000..800780b
--- /dev/null
+++ b/gatehouse_app/models/organization/department.py
@@ -0,0 +1,196 @@
+"""Department, DepartmentMembership, and DepartmentPrincipal models."""
+from gatehouse_app.extensions import db
+from gatehouse_app.models.base import BaseModel
+
+
+class Department(BaseModel):
+ """Department model representing an organizational unit for SSH access control.
+
+ Departments are used to group users and assign SSH principals (access levels)
+ to them. A user can be a member of multiple departments, and each department
+ can have multiple principals assigned.
+
+ Example:
+ - Department: "Engineering"
+ - Members: user1@example.com, user2@example.com
+ - Principals: "eng-prod", "eng-staging"
+ - Users get access based on their principal assignments
+ """
+
+ __tablename__ = "departments"
+
+ organization_id = db.Column(
+ db.String(36),
+ db.ForeignKey("organizations.id"),
+ nullable=False,
+ index=True,
+ )
+ name = db.Column(db.String(255), nullable=False, index=True)
+ description = db.Column(db.Text, nullable=True)
+
+ # Relationships
+ organization = db.relationship("Organization", back_populates="departments")
+ memberships = db.relationship(
+ "DepartmentMembership",
+ back_populates="department",
+ cascade="all, delete-orphan",
+ )
+ principal_links = db.relationship(
+ "DepartmentPrincipal",
+ back_populates="department",
+ cascade="all, delete-orphan",
+ )
+ cert_policy = db.relationship(
+ "DepartmentCertPolicy",
+ back_populates="department",
+ uselist=False,
+ cascade="all, delete-orphan",
+ )
+
+ __table_args__ = (
+ db.UniqueConstraint("organization_id", "name", name="uix_org_dept_name"),
+ )
+
+ def __repr__(self):
+ """String representation of Department."""
+ return f""
+
+ def to_dict(self, exclude=None):
+ """Convert department to dictionary."""
+ exclude = exclude or []
+ data = super().to_dict(exclude=exclude)
+ data["member_count"] = len([m for m in self.memberships if m.deleted_at is None])
+ data["principal_count"] = len([p for p in self.principal_links if p.deleted_at is None])
+ return data
+
+ def get_members(self, active_only: bool = True):
+ """Get all members of this department.
+
+ Args:
+ active_only: If True, exclude soft-deleted members
+
+ Returns:
+ List of DepartmentMembership objects
+ """
+ if active_only:
+ return [m for m in self.memberships if m.deleted_at is None]
+ return list(self.memberships)
+
+ def get_principals(self, active_only: bool = True):
+ """Get all principals assigned to this department.
+
+ Args:
+ active_only: If True, exclude soft-deleted principals
+
+ Returns:
+ List of Principal objects via DepartmentPrincipal
+ """
+ if active_only:
+ return [
+ p.principal
+ for p in self.principal_links
+ if p.deleted_at is None and p.principal.deleted_at is None
+ ]
+ return [p.principal for p in self.principal_links]
+
+ def is_member(self, user_id: str) -> bool:
+ """Check if a user is a member of this department.
+
+ Args:
+ user_id: ID of the user to check
+
+ Returns:
+ True if user is an active member, False otherwise
+ """
+ return (
+ DepartmentMembership.query.filter_by(
+ user_id=user_id,
+ department_id=self.id,
+ deleted_at=None,
+ ).first()
+ is not None
+ )
+
+ def get_member_count(self) -> int:
+ """Get the count of active members in this department."""
+ return len(self.get_members(active_only=True))
+
+
+class DepartmentMembership(BaseModel):
+ """Department membership model representing user membership in a department.
+
+ When a user is added to a department, they become eligible for SSH principals
+ assigned to that department.
+ """
+
+ __tablename__ = "department_memberships"
+
+ user_id = db.Column(
+ db.String(36),
+ db.ForeignKey("users.id"),
+ nullable=False,
+ index=True,
+ )
+ department_id = db.Column(
+ db.String(36),
+ db.ForeignKey("departments.id"),
+ nullable=False,
+ index=True,
+ )
+
+ # Relationships
+ user = db.relationship("User", back_populates="department_memberships")
+ department = db.relationship("Department", back_populates="memberships")
+
+ __table_args__ = (
+ db.UniqueConstraint("user_id", "department_id", name="uix_user_dept"),
+ )
+
+ def __repr__(self):
+ """String representation of DepartmentMembership."""
+ return (
+ f""
+ )
+
+
+class DepartmentPrincipal(BaseModel):
+ """Department principal assignment model.
+
+ Represents the assignment of principals to departments. All members of a
+ department get access to its assigned principals (transitively).
+
+ Example:
+ - Department: "Engineering"
+ - Principal: "eng-prod-servers"
+ - All engineering department members can SSH as "eng-prod-servers"
+ """
+
+ __tablename__ = "department_principals"
+
+ department_id = db.Column(
+ db.String(36),
+ db.ForeignKey("departments.id"),
+ nullable=False,
+ index=True,
+ )
+ principal_id = db.Column(
+ db.String(36),
+ db.ForeignKey("principals.id"),
+ nullable=False,
+ index=True,
+ )
+
+ # Relationships
+ department = db.relationship("Department", back_populates="principal_links")
+ principal = db.relationship("Principal", back_populates="department_links")
+
+ __table_args__ = (
+ db.UniqueConstraint("department_id", "principal_id", name="uix_dept_principal"),
+ )
+
+ def __repr__(self):
+ """String representation of DepartmentPrincipal."""
+ return (
+ f""
+ )
diff --git a/gatehouse_app/models/organization/department_cert_policy.py b/gatehouse_app/models/organization/department_cert_policy.py
new file mode 100644
index 0000000..357329f
--- /dev/null
+++ b/gatehouse_app/models/organization/department_cert_policy.py
@@ -0,0 +1,76 @@
+"""DepartmentCertPolicy — per-department SSH certificate issuance rules."""
+from datetime import datetime, timezone
+from gatehouse_app.extensions import db
+from gatehouse_app.models.base import BaseModel
+
+
+# Standard SSH certificate extensions
+STANDARD_EXTENSIONS = [
+ "permit-X11-forwarding",
+ "permit-agent-forwarding",
+ "permit-pty",
+ "permit-port-forwarding",
+ "permit-user-rc",
+]
+
+
+class DepartmentCertPolicy(BaseModel):
+ """SSH certificate policy for a department.
+
+ Controls:
+ - Whether members may choose their own expiry date (up to ``max_expiry_hours``)
+ - Default expiry hours when the user doesn't (or can't) pick
+ - Maximum expiry hours (hard ceiling, even for admins signing on behalf)
+ - Which SSH certificate extensions are granted to members of this department
+ - Any custom extensions the admin wants to add beyond the standard five
+
+ Inherits ``id``, ``created_at``, ``updated_at``, and ``deleted_at`` from
+ :class:`BaseModel` so soft-delete and the standard timestamp behaviour are
+ consistent with every other model in the application.
+ """
+
+ __tablename__ = "department_cert_policies"
+
+ department_id = db.Column(
+ db.String(36),
+ db.ForeignKey("departments.id"),
+ nullable=False,
+ unique=True,
+ index=True,
+ )
+
+ # Expiry control
+ allow_user_expiry = db.Column(db.Boolean, nullable=False, default=False)
+ default_expiry_hours = db.Column(db.Integer, nullable=False, default=1)
+ max_expiry_hours = db.Column(db.Integer, nullable=False, default=24)
+
+ # Extensions — list of extension name strings
+ allowed_extensions = db.Column(
+ db.JSON,
+ nullable=False,
+ default=lambda: list(STANDARD_EXTENSIONS),
+ )
+ # Admin-defined extras beyond the standard five
+ custom_extensions = db.Column(db.JSON, nullable=False, default=list)
+
+ # Relationship back to department
+ department = db.relationship("Department", back_populates="cert_policy", uselist=False)
+
+ def __repr__(self):
+ return (
+ f""
+ )
+
+ def all_extensions(self) -> list:
+ """Return the full list of enabled extensions (allowed + custom)."""
+ return list((self.allowed_extensions or []) + (self.custom_extensions or []))
+
+ def to_dict(self, exclude=None):
+ """Convert to dictionary."""
+ exclude = exclude or []
+ data = super().to_dict(exclude=exclude)
+ # Augment with computed / convenience fields not in the base columns
+ data["all_extensions"] = self.all_extensions()
+ data["standard_extensions"] = STANDARD_EXTENSIONS
+ return data
diff --git a/gatehouse_app/models/organization/org_invite_token.py b/gatehouse_app/models/organization/org_invite_token.py
new file mode 100644
index 0000000..2830b25
--- /dev/null
+++ b/gatehouse_app/models/organization/org_invite_token.py
@@ -0,0 +1,77 @@
+"""Organization invite token model."""
+import secrets
+from datetime import datetime, timezone, timedelta
+
+from gatehouse_app.extensions import db
+from gatehouse_app.models.base import BaseModel
+
+
+class OrgInviteToken(BaseModel):
+ """Token-based invitation to join an organization."""
+
+ __tablename__ = "org_invite_tokens"
+
+ organization_id = db.Column(
+ db.String(36),
+ db.ForeignKey("organizations.id", ondelete="CASCADE"),
+ nullable=False,
+ index=True,
+ )
+ invited_by_id = db.Column(
+ db.String(36),
+ db.ForeignKey("users.id", ondelete="SET NULL"),
+ nullable=True,
+ )
+ email = db.Column(db.String(255), nullable=False, index=True)
+ role = db.Column(db.String(64), nullable=False, default="member")
+ token = db.Column(db.String(128), unique=True, nullable=False, index=True)
+ expires_at = db.Column(db.DateTime, nullable=False)
+ accepted_at = db.Column(db.DateTime, nullable=True)
+
+ organization = db.relationship(
+ "Organization",
+ backref=db.backref("invite_tokens", cascade="all, delete-orphan"),
+ )
+ invited_by = db.relationship("User", foreign_keys=[invited_by_id])
+
+ @classmethod
+ def generate(
+ cls,
+ organization_id: str,
+ email: str,
+ role: str = "member",
+ invited_by_id: str = None,
+ ttl_days: int = 7,
+ ) -> "OrgInviteToken":
+ """Create a new invite token for an organization."""
+ token_value = secrets.token_urlsafe(48)
+ instance = cls(
+ organization_id=organization_id,
+ email=email.lower(),
+ role=role,
+ invited_by_id=invited_by_id,
+ token=token_value,
+ expires_at=datetime.now(timezone.utc) + timedelta(days=ttl_days),
+ )
+ db.session.add(instance)
+ db.session.commit()
+ return instance
+
+ @property
+ def is_valid(self) -> bool:
+ """Return True if the token is unused and not expired."""
+ if self.accepted_at is not None:
+ return False
+ now = datetime.now(timezone.utc)
+ expires = self.expires_at
+ if expires.tzinfo is None:
+ expires = expires.replace(tzinfo=timezone.utc)
+ return now < expires
+
+ def accept(self) -> None:
+ """Mark the invite as accepted."""
+ self.accepted_at = datetime.now(timezone.utc)
+ db.session.commit()
+
+ def __repr__(self) -> str:
+ return f""
diff --git a/gatehouse_app/models/organization.py b/gatehouse_app/models/organization/organization.py
similarity index 81%
rename from gatehouse_app/models/organization.py
rename to gatehouse_app/models/organization/organization.py
index 1c4170f..9be5c65 100644
--- a/gatehouse_app/models/organization.py
+++ b/gatehouse_app/models/organization/organization.py
@@ -34,6 +34,15 @@ class Organization(BaseModel):
cascade="all, delete-orphan",
foreign_keys="OrganizationSecurityPolicy.organization_id",
)
+ departments = db.relationship(
+ "Department", back_populates="organization", cascade="all, delete-orphan"
+ )
+ principals = db.relationship(
+ "Principal", back_populates="organization", cascade="all, delete-orphan"
+ )
+ cas = db.relationship(
+ "CA", back_populates="organization", cascade="all, delete-orphan"
+ )
def __repr__(self):
"""String representation of Organization."""
@@ -52,9 +61,9 @@ class Organization(BaseModel):
return member.user
return None
- def is_member(self, user_id):
+ def is_member(self, user_id: str) -> bool:
"""Check if a user is a member of the organization."""
- from gatehouse_app.models.organization_member import OrganizationMember
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
return (
OrganizationMember.query.filter_by(
diff --git a/gatehouse_app/models/organization_member.py b/gatehouse_app/models/organization/organization_member.py
similarity index 78%
rename from gatehouse_app/models/organization_member.py
rename to gatehouse_app/models/organization/organization_member.py
index 3247082..3b02242 100644
--- a/gatehouse_app/models/organization_member.py
+++ b/gatehouse_app/models/organization/organization_member.py
@@ -1,4 +1,4 @@
-"""Organization member model."""
+"""Organization member model."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import OrganizationRole
@@ -21,31 +21,35 @@ class OrganizationMember(BaseModel):
joined_at = db.Column(db.DateTime, nullable=True)
# Relationships
- user = db.relationship("User", foreign_keys=[user_id], back_populates="organization_memberships")
+ user = db.relationship(
+ "User", foreign_keys=[user_id], back_populates="organization_memberships"
+ )
organization = db.relationship("Organization", back_populates="members")
invited_by = db.relationship("User", foreign_keys=[invited_by_id])
- # Unique constraint to prevent duplicate memberships
__table_args__ = (
db.UniqueConstraint("user_id", "organization_id", name="uix_user_org"),
)
def __repr__(self):
"""String representation of OrganizationMember."""
- return f""
+ return (
+ f""
+ )
- def is_owner(self):
+ def is_owner(self) -> bool:
"""Check if member is an owner."""
return self.role == OrganizationRole.OWNER
- def is_admin(self):
+ def is_admin(self) -> bool:
"""Check if member is an admin or owner."""
return self.role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
- def can_manage_members(self):
+ def can_manage_members(self) -> bool:
"""Check if member can manage other members."""
return self.is_admin()
- def can_delete_organization(self):
+ def can_delete_organization(self) -> bool:
"""Check if member can delete the organization."""
return self.is_owner()
diff --git a/gatehouse_app/models/organization/principal.py b/gatehouse_app/models/organization/principal.py
new file mode 100644
index 0000000..8f87fe3
--- /dev/null
+++ b/gatehouse_app/models/organization/principal.py
@@ -0,0 +1,215 @@
+"""Principal and PrincipalMembership models."""
+from gatehouse_app.extensions import db
+from gatehouse_app.models.base import BaseModel
+
+
+class Principal(BaseModel):
+ """Principal model representing an SSH principal (access level/role).
+
+ In SSH CA terminology, a principal is a string like "eng-prod-servers" or
+ "devops-admins" that represents a set of machines or access level. Users
+ can be granted access to principals, either directly or via department
+ membership.
+
+ Example:
+ - Principal: "eng-prod-servers"
+ - Users with this principal can SSH to prod servers
+ - Can be assigned to departments or directly to users
+ """
+
+ __tablename__ = "principals"
+
+ organization_id = db.Column(
+ db.String(36),
+ db.ForeignKey("organizations.id"),
+ nullable=False,
+ index=True,
+ )
+ name = db.Column(db.String(255), nullable=False, index=True)
+ description = db.Column(db.Text, nullable=True)
+
+ # Relationships
+ organization = db.relationship("Organization", back_populates="principals")
+ memberships = db.relationship(
+ "PrincipalMembership",
+ back_populates="principal",
+ cascade="all, delete-orphan",
+ )
+ department_links = db.relationship(
+ "DepartmentPrincipal",
+ back_populates="principal",
+ cascade="all, delete-orphan",
+ )
+
+ __table_args__ = (
+ db.UniqueConstraint("organization_id", "name", name="uix_org_principal_name"),
+ )
+
+ def __repr__(self):
+ """String representation of Principal."""
+ return f""
+
+ def to_dict(self, exclude=None):
+ """Convert principal to dictionary."""
+ exclude = exclude or []
+ data = super().to_dict(exclude=exclude)
+ data["direct_member_count"] = len(
+ [m for m in self.memberships if m.deleted_at is None]
+ )
+ data["department_count"] = len(
+ [d for d in self.department_links if d.deleted_at is None]
+ )
+ return data
+
+ def get_members(self, active_only: bool = True):
+ """Get all users who are directly assigned to this principal.
+
+ Does NOT include users who get access via department membership.
+
+ Args:
+ active_only: If True, exclude soft-deleted members
+
+ Returns:
+ List of PrincipalMembership objects
+ """
+ if active_only:
+ return [m for m in self.memberships if m.deleted_at is None]
+ return list(self.memberships)
+
+ def get_all_members(self, active_only: bool = True):
+ """Get all users who have access to this principal.
+
+ Includes both direct members and users via department membership.
+
+ Args:
+ active_only: If True, exclude soft-deleted members
+
+ Returns:
+ Set of User objects with access to this principal
+ """
+ all_users: set = set()
+
+ # Direct members
+ for membership in self.get_members(active_only=active_only):
+ if not active_only or membership.user.deleted_at is None:
+ all_users.add(membership.user)
+
+ # Members via department assignment
+ for dept_link in self.department_links:
+ if dept_link.deleted_at is None or not active_only:
+ for dept_member in dept_link.department.get_members(active_only=active_only):
+ if not active_only or dept_member.user.deleted_at is None:
+ all_users.add(dept_member.user)
+
+ return all_users
+
+ def get_departments(self, active_only: bool = True):
+ """Get all departments this principal is assigned to.
+
+ Args:
+ active_only: If True, exclude soft-deleted departments
+
+ Returns:
+ List of Department objects
+ """
+ if active_only:
+ return [
+ d.department
+ for d in self.department_links
+ if d.deleted_at is None and d.department.deleted_at is None
+ ]
+ return [d.department for d in self.department_links]
+
+ def is_member(self, user_id: str, include_via_department: bool = True) -> bool:
+ """Check if a user has access to this principal.
+
+ Args:
+ user_id: ID of the user to check
+ include_via_department: If True, check department memberships too
+
+ Returns:
+ True if user has access to this principal
+ """
+ # Check direct membership
+ has_direct = (
+ PrincipalMembership.query.filter_by(
+ user_id=user_id,
+ principal_id=self.id,
+ deleted_at=None,
+ ).first()
+ is not None
+ )
+
+ if has_direct:
+ return True
+
+ if not include_via_department:
+ return False
+
+ # Check department membership
+ dept_ids = [d.id for d in self.get_departments(active_only=True)]
+ if not dept_ids:
+ return False
+
+ from gatehouse_app.models.organization.department import DepartmentMembership
+
+ return (
+ DepartmentMembership.query.filter(
+ DepartmentMembership.user_id == user_id,
+ DepartmentMembership.department_id.in_(dept_ids),
+ DepartmentMembership.deleted_at.is_(None),
+ ).first()
+ is not None
+ )
+
+ def get_member_count(self, include_via_department: bool = True) -> int:
+ """Get the count of active members with access to this principal.
+
+ Args:
+ include_via_department: If True, include members via department
+
+ Returns:
+ Count of members
+ """
+ if not include_via_department:
+ return len(self.get_members(active_only=True))
+ return len(self.get_all_members(active_only=True))
+
+
+class PrincipalMembership(BaseModel):
+ """Principal membership model representing direct user assignment to a principal.
+
+ When a user is assigned directly to a principal, they get access to that
+ principal for SSH authentication. This is in addition to any principals
+ they get via department membership.
+ """
+
+ __tablename__ = "principal_memberships"
+
+ user_id = db.Column(
+ db.String(36),
+ db.ForeignKey("users.id"),
+ nullable=False,
+ index=True,
+ )
+ principal_id = db.Column(
+ db.String(36),
+ db.ForeignKey("principals.id"),
+ nullable=False,
+ index=True,
+ )
+
+ # Relationships
+ user = db.relationship("User", back_populates="principal_memberships")
+ principal = db.relationship("Principal", back_populates="memberships")
+
+ __table_args__ = (
+ db.UniqueConstraint("user_id", "principal_id", name="uix_user_principal"),
+ )
+
+ def __repr__(self):
+ """String representation of PrincipalMembership."""
+ return (
+ f""
+ )
diff --git a/gatehouse_app/models/security/__init__.py b/gatehouse_app/models/security/__init__.py
new file mode 100644
index 0000000..d24aef5
--- /dev/null
+++ b/gatehouse_app/models/security/__init__.py
@@ -0,0 +1,12 @@
+"""Security subpackage — organization and user security policies, MFA compliance."""
+from gatehouse_app.models.security.organization_security_policy import (
+ OrganizationSecurityPolicy,
+)
+from gatehouse_app.models.security.user_security_policy import UserSecurityPolicy
+from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance
+
+__all__ = [
+ "OrganizationSecurityPolicy",
+ "UserSecurityPolicy",
+ "MfaPolicyCompliance",
+]
diff --git a/gatehouse_app/models/mfa_policy_compliance.py b/gatehouse_app/models/security/mfa_policy_compliance.py
similarity index 76%
rename from gatehouse_app/models/mfa_policy_compliance.py
rename to gatehouse_app/models/security/mfa_policy_compliance.py
index 6ecd217..5c2de13 100644
--- a/gatehouse_app/models/mfa_policy_compliance.py
+++ b/gatehouse_app/models/security/mfa_policy_compliance.py
@@ -1,4 +1,4 @@
-"""MfaPolicyCompliance model."""
+"""MfaPolicyCompliance model — per-user per-organization MFA compliance tracking."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import MfaComplianceStatus
@@ -7,7 +7,8 @@ from gatehouse_app.utils.constants import MfaComplianceStatus
class MfaPolicyCompliance(BaseModel):
"""MFA policy compliance tracking per user per organization.
- Tracks each user's MFA compliance state separately for each organization membership.
+ Tracks each user's MFA compliance state separately for each organization
+ membership. One row per (user, org) pair.
"""
__tablename__ = "mfa_policy_compliance"
@@ -25,13 +26,13 @@ class MfaPolicyCompliance(BaseModel):
default=MfaComplianceStatus.NOT_APPLICABLE,
)
- # Snapshot of org policy at the time this record became active
+ # Snapshot of org policy version when this record became active
policy_version = db.Column(db.Integer, nullable=False)
# When policy started applying to this user
applied_at = db.Column(db.DateTime, nullable=True)
- # Final deadline for this user to comply (per user, not global)
+ # Final deadline for this user to comply
deadline_at = db.Column(db.DateTime, nullable=True)
# When they became compliant under this policy_version
@@ -45,9 +46,7 @@ class MfaPolicyCompliance(BaseModel):
notification_count = db.Column(db.Integer, nullable=False, default=0)
__table_args__ = (
- db.UniqueConstraint(
- "user_id", "organization_id", name="uix_user_org_compliance"
- ),
+ db.UniqueConstraint("user_id", "organization_id", name="uix_user_org_compliance"),
)
# Relationships
@@ -58,9 +57,11 @@ class MfaPolicyCompliance(BaseModel):
def __repr__(self):
"""String representation of MfaPolicyCompliance."""
- return f""
+ return (
+ f""
+ )
def to_dict(self, exclude=None):
"""Convert to dictionary."""
- exclude = exclude or []
- return super().to_dict(exclude=exclude)
\ No newline at end of file
+ return super().to_dict(exclude=exclude or [])
diff --git a/gatehouse_app/models/organization_security_policy.py b/gatehouse_app/models/security/organization_security_policy.py
similarity index 83%
rename from gatehouse_app/models/organization_security_policy.py
rename to gatehouse_app/models/security/organization_security_policy.py
index 991b72d..593781a 100644
--- a/gatehouse_app/models/organization_security_policy.py
+++ b/gatehouse_app/models/security/organization_security_policy.py
@@ -39,15 +39,19 @@ class OrganizationSecurityPolicy(BaseModel):
# Relationships
organization = db.relationship(
- "Organization", back_populates="security_policy", foreign_keys=[organization_id]
+ "Organization",
+ back_populates="security_policy",
+ foreign_keys=[organization_id],
)
updated_by_user = db.relationship("User", foreign_keys=[updated_by_user_id])
def __repr__(self):
"""String representation of OrganizationSecurityPolicy."""
- return f""
+ return (
+ f""
+ )
def to_dict(self, exclude=None):
"""Convert to dictionary."""
- exclude = exclude or []
- return super().to_dict(exclude=exclude)
\ No newline at end of file
+ return super().to_dict(exclude=exclude or [])
diff --git a/gatehouse_app/models/user_security_policy.py b/gatehouse_app/models/security/user_security_policy.py
similarity index 67%
rename from gatehouse_app/models/user_security_policy.py
rename to gatehouse_app/models/security/user_security_policy.py
index d765575..a96ef84 100644
--- a/gatehouse_app/models/user_security_policy.py
+++ b/gatehouse_app/models/security/user_security_policy.py
@@ -1,4 +1,4 @@
-"""UserSecurityPolicy model."""
+"""UserSecurityPolicy model — per-user MFA overrides."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import MfaRequirementOverride
@@ -7,7 +7,7 @@ from gatehouse_app.utils.constants import MfaRequirementOverride
class UserSecurityPolicy(BaseModel):
"""User security policy model for per-user MFA overrides.
- Stores per user overrides of organization level MFA requirements.
+ Stores per-user overrides of organization-level MFA requirements.
"""
__tablename__ = "user_security_policies"
@@ -25,29 +25,27 @@ class UserSecurityPolicy(BaseModel):
default=MfaRequirementOverride.INHERIT,
)
- # If override is REQUIRED and you want to force a specific factor set
+ # If override is REQUIRED, optionally force a specific factor set
force_totp = db.Column(db.Boolean, nullable=False, default=False)
force_webauthn = db.Column(db.Boolean, nullable=False, default=False)
__table_args__ = (
- db.UniqueConstraint(
- "user_id", "organization_id", name="uix_user_org_policy"
- ),
+ db.UniqueConstraint("user_id", "organization_id", name="uix_user_org_policy"),
)
# Relationships
user = db.relationship(
"User", back_populates="security_policies", foreign_keys=[user_id]
)
- organization = db.relationship(
- "Organization", foreign_keys=[organization_id]
- )
+ organization = db.relationship("Organization", foreign_keys=[organization_id])
def __repr__(self):
"""String representation of UserSecurityPolicy."""
- return f""
+ return (
+ f""
+ )
def to_dict(self, exclude=None):
"""Convert to dictionary."""
- exclude = exclude or []
- return super().to_dict(exclude=exclude)
\ No newline at end of file
+ return super().to_dict(exclude=exclude or [])
diff --git a/gatehouse_app/models/ssh_ca/__init__.py b/gatehouse_app/models/ssh_ca/__init__.py
new file mode 100644
index 0000000..d6932b3
--- /dev/null
+++ b/gatehouse_app/models/ssh_ca/__init__.py
@@ -0,0 +1,17 @@
+"""SSH/CA subpackage — certificate authorities, SSH keys, and certificates."""
+from gatehouse_app.models.ssh_ca.ca import CA, KeyType, CertType, CaType, CAPermission
+from gatehouse_app.models.ssh_ca.ssh_key import SSHKey
+from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus
+from gatehouse_app.models.ssh_ca.certificate_audit_log import CertificateAuditLog
+
+__all__ = [
+ "CA",
+ "KeyType",
+ "CertType",
+ "CaType",
+ "CAPermission",
+ "SSHKey",
+ "SSHCertificate",
+ "CertificateStatus",
+ "CertificateAuditLog",
+]
diff --git a/gatehouse_app/models/ssh_ca/ca.py b/gatehouse_app/models/ssh_ca/ca.py
new file mode 100644
index 0000000..182d842
--- /dev/null
+++ b/gatehouse_app/models/ssh_ca/ca.py
@@ -0,0 +1,238 @@
+"""Certificate Authority (CA) model."""
+from enum import Enum
+from datetime import datetime, timezone
+from gatehouse_app.extensions import db
+from gatehouse_app.models.base import BaseModel
+
+
+class KeyType(str, Enum):
+ """SSH CA key types."""
+
+ ED25519 = "ed25519"
+ RSA = "rsa"
+ ECDSA = "ecdsa"
+
+
+class CertType(str, Enum):
+ """SSH certificate types."""
+
+ USER = "user"
+ HOST = "host"
+
+
+class CaType(str, Enum):
+ """CA signing type — whether this CA signs user or host certificates."""
+
+ USER = "user"
+ HOST = "host"
+
+
+class CA(BaseModel):
+ """Certificate Authority (CA) model for SSH certificate signing.
+
+ Each organization can have multiple CAs for different purposes
+ (e.g., production vs. staging). Private keys are encrypted at rest
+ and should be protected with KMS.
+ """
+
+ __tablename__ = "cas"
+
+ organization_id = db.Column(
+ db.String(36),
+ db.ForeignKey("organizations.id"),
+ nullable=True, # NULL for the global system-config CA
+ index=True,
+ )
+
+ # CA identity
+ name = db.Column(db.String(255), nullable=False)
+ description = db.Column(db.Text, nullable=True)
+
+ # CA signing type: 'user' signs user certificates, 'host' signs host certs
+ ca_type = db.Column(
+ db.Enum(CaType, values_callable=lambda x: [e.value for e in x]),
+ default=CaType.USER,
+ nullable=False,
+ )
+
+ # Key type (ED25519, RSA, ECDSA)
+ key_type = db.Column(
+ db.Enum(KeyType, values_callable=lambda x: [e.value for e in x]),
+ default=KeyType.ED25519,
+ nullable=False,
+ )
+
+ # Private key — PEM-encoded, encrypted at rest by database/KMS
+ private_key = db.Column(db.Text, nullable=False)
+
+ # Public key — PEM format
+ public_key = db.Column(db.Text, nullable=False)
+
+ # SHA256 fingerprint of the public key
+ fingerprint = db.Column(db.String(255), nullable=False, unique=True)
+
+ # CRL (Certificate Revocation List) configuration
+ crl_enabled = db.Column(db.Boolean, default=True, nullable=False)
+ crl_endpoint = db.Column(db.String(512), nullable=True)
+
+ # Default certificate validity in hours (overridable per request)
+ default_cert_validity_hours = db.Column(db.Integer, default=1, nullable=False)
+
+ # Maximum validity duration allowed
+ max_cert_validity_hours = db.Column(db.Integer, default=24, nullable=False)
+
+ # CA status
+ is_active = db.Column(db.Boolean, default=True, nullable=False, index=True)
+
+ # Key rotation tracking
+ rotated_at = db.Column(db.DateTime, nullable=True)
+ rotation_reason = db.Column(db.String(255), nullable=True)
+
+ # Monotonically-increasing serial counter. Every cert this CA issues
+ # gets the next value so serials are unique, ordered, and auditable.
+ # Protected by a row-level SELECT … FOR UPDATE in get_next_serial().
+ next_serial_number = db.Column(db.BigInteger, default=1, nullable=False)
+
+ # Relationships
+ organization = db.relationship("Organization", back_populates="cas")
+ certificates = db.relationship(
+ "SSHCertificate",
+ back_populates="ca",
+ cascade="all, delete-orphan",
+ )
+ permissions = db.relationship(
+ "CAPermission",
+ back_populates="ca",
+ cascade="all, delete-orphan",
+ )
+
+ __table_args__ = (
+ db.Index("idx_ca_org_active", "organization_id", "is_active"),
+ )
+
+ def __repr__(self):
+ """String representation of CA."""
+ return (
+ f""
+ )
+
+ def to_dict(self, exclude=None):
+ """Convert CA to dictionary, never exposing the private key."""
+ exclude = exclude or []
+ if "private_key" not in exclude:
+ exclude.append("private_key")
+ data = super().to_dict(exclude=exclude)
+
+ # Add computed fields
+ data["total_certs"] = len([c for c in self.certificates if c.deleted_at is None])
+ data["active_certs"] = len(
+ [c for c in self.certificates if c.deleted_at is None and not c.revoked]
+ )
+ data["revoked_certs"] = len(
+ [c for c in self.certificates if c.deleted_at is None and c.revoked]
+ )
+ return data
+
+ def get_active_certificates(self) -> list:
+ """Get all active (non-revoked) certificates issued by this CA."""
+ return [
+ c for c in self.certificates if c.deleted_at is None and not c.revoked
+ ]
+
+ def rotate_key(
+ self,
+ new_private_key: str,
+ new_public_key: str,
+ new_fingerprint: str,
+ reason: str = None,
+ ) -> None:
+ """Rotate the CA's key pair.
+
+ This should only be done in carefully controlled circumstances.
+ All existing certificates remain valid but no new certificates can be
+ signed with the old key after rotation.
+
+ Args:
+ new_private_key: New PEM-encoded private key
+ new_public_key: New PEM-encoded public key
+ new_fingerprint: SHA256 fingerprint of new public key
+ reason: Optional reason for rotation
+ """
+ self.private_key = new_private_key
+ self.public_key = new_public_key
+ self.fingerprint = new_fingerprint
+ self.rotated_at = datetime.now(timezone.utc) # Bug fix: was datetime.utcnow()
+ self.rotation_reason = reason
+ self.save()
+
+ def get_next_serial(self) -> int:
+ """Atomically increment and return the next certificate serial number.
+
+ Uses a SELECT … FOR UPDATE row lock so concurrent requests never
+ receive the same serial. Must be called inside an active DB
+ transaction (i.e. before the final session.commit()).
+
+ Returns:
+ int: The serial number to embed in the next certificate.
+ """
+ # Re-fetch this CA row with an exclusive row lock
+ locked = (
+ db.session.query(CA)
+ .with_for_update()
+ .filter_by(id=self.id)
+ .one()
+ )
+ serial = locked.next_serial_number
+ locked.next_serial_number = serial + 1
+ db.session.flush() # write increment; commit happens in the caller
+ return serial
+
+
+class CAPermission(BaseModel):
+ """Per-user CA permission model.
+
+ Controls which users are allowed to sign certificates against a specific CA.
+ When a CA has any permission rows, the signing endpoint enforces the list;
+ CAs with no rows are open to all org members (backwards-compatible default).
+
+ Permission values:
+ sign – user may request certificate signing
+ admin – user may sign AND manage the CA (rotate keys, delete, etc.)
+ """
+
+ __tablename__ = "ca_permissions"
+
+ ca_id = db.Column(
+ db.String(36),
+ db.ForeignKey("cas.id", ondelete="CASCADE"),
+ nullable=False,
+ index=True,
+ )
+ user_id = db.Column(
+ db.String(36),
+ db.ForeignKey("users.id", ondelete="CASCADE"),
+ nullable=False,
+ index=True,
+ )
+
+ permission = db.Column(db.String(50), nullable=False, default="sign")
+
+ # Relationships
+ ca = db.relationship("CA", back_populates="permissions")
+ user = db.relationship("User", back_populates="ca_permissions")
+
+ __table_args__ = (
+ db.UniqueConstraint("ca_id", "user_id", name="uix_ca_permission"),
+ )
+
+ def __repr__(self):
+ return (
+ f""
+ )
+
+ def to_dict(self, exclude=None):
+ data = super().to_dict(exclude=exclude or [])
+ data["permission"] = self.permission
+ return data
diff --git a/gatehouse_app/models/ssh_ca/certificate_audit_log.py b/gatehouse_app/models/ssh_ca/certificate_audit_log.py
new file mode 100644
index 0000000..02f24d3
--- /dev/null
+++ b/gatehouse_app/models/ssh_ca/certificate_audit_log.py
@@ -0,0 +1,91 @@
+"""Certificate audit log model — tracks SSH certificate lifecycle events."""
+from gatehouse_app.extensions import db
+from gatehouse_app.models.base import BaseModel
+
+
+class CertificateAuditLog(BaseModel):
+ """Audit log for SSH certificate lifecycle events.
+
+ Tracks all operations on SSH certificates: signing, revocation, validation,
+ etc. Kept separate from the general AuditLog to provide detailed certificate
+ operation tracking without polluting the main audit stream.
+ """
+
+ __tablename__ = "certificate_audit_logs"
+
+ # Reference to the certificate
+ certificate_id = db.Column(
+ db.String(36),
+ db.ForeignKey("ssh_certificates.id"),
+ nullable=False,
+ index=True,
+ )
+
+ # The user who performed the action (null for system actions)
+ user_id = db.Column(
+ db.String(36),
+ db.ForeignKey("users.id"),
+ nullable=True,
+ index=True,
+ )
+
+ # Action type (e.g., "signed", "revoked", "validated", "requested")
+ action = db.Column(db.String(50), nullable=False, index=True)
+
+ # Request details
+ ip_address = db.Column(db.String(45), nullable=True)
+ user_agent = db.Column(db.String(512), nullable=True)
+ request_id = db.Column(db.String(36), nullable=True)
+
+ # Detailed message
+ message = db.Column(db.Text, nullable=True)
+
+ # Additional context
+ extra_data = db.Column(db.JSON, nullable=True)
+
+ # Outcome
+ success = db.Column(db.Boolean, default=True, nullable=False)
+ error_message = db.Column(db.Text, nullable=True)
+
+ # Relationships
+ certificate = db.relationship("SSHCertificate", back_populates="audit_logs")
+ user = db.relationship("User")
+
+ __table_args__ = (
+ db.Index("idx_cert_audit_cert_action", "certificate_id", "action"),
+ db.Index("idx_cert_audit_user", "user_id", "created_at"),
+ )
+
+ def __repr__(self):
+ """String representation of CertificateAuditLog."""
+ return (
+ f""
+ )
+
+ @classmethod
+ def log(
+ cls,
+ certificate_id: str,
+ action: str,
+ user_id: str = None,
+ **kwargs,
+ ) -> "CertificateAuditLog":
+ """Create a certificate audit log entry.
+
+ Args:
+ certificate_id: ID of the certificate
+ action: Action type (e.g., "signed", "revoked")
+ user_id: ID of the user performing the action (optional)
+ **kwargs: Additional fields (ip_address, user_agent, message, etc.)
+
+ Returns:
+ CertificateAuditLog instance
+ """
+ log_entry = cls(
+ certificate_id=certificate_id,
+ action=action,
+ user_id=user_id,
+ **kwargs,
+ )
+ log_entry.save()
+ return log_entry
diff --git a/gatehouse_app/models/ssh_ca/ssh_certificate.py b/gatehouse_app/models/ssh_ca/ssh_certificate.py
new file mode 100644
index 0000000..f226a69
--- /dev/null
+++ b/gatehouse_app/models/ssh_ca/ssh_certificate.py
@@ -0,0 +1,176 @@
+"""SSH Certificate model — signed SSH user/host certificates."""
+from enum import Enum
+from datetime import datetime, timezone
+from gatehouse_app.extensions import db
+from gatehouse_app.models.base import BaseModel
+from gatehouse_app.models.ssh_ca.ca import CertType
+
+
+class CertificateStatus(str, Enum):
+ """SSH certificate lifecycle status."""
+
+ REQUESTED = "requested" # Waiting for signing
+ ISSUED = "issued" # Signed and valid
+ REVOKED = "revoked" # Manually revoked
+ EXPIRED = "expired" # Validity period ended
+ SUPERSEDED = "superseded" # Replaced by newer certificate
+
+
+class SSHCertificate(BaseModel):
+ """SSH Certificate model representing a signed SSH user/host certificate.
+
+ Certificates are issued by a CA and associated with an SSH public key.
+ They include principals (access levels), validity periods, and standard
+ OpenSSH certificate metadata.
+ """
+
+ __tablename__ = "ssh_certificates"
+
+ # Certificate relationships
+ ca_id = db.Column(
+ db.String(36),
+ db.ForeignKey("cas.id"),
+ nullable=False,
+ index=True,
+ )
+ user_id = db.Column(
+ db.String(36),
+ db.ForeignKey("users.id"),
+ nullable=False,
+ index=True,
+ )
+ ssh_key_id = db.Column(
+ db.String(36),
+ db.ForeignKey("ssh_keys.id"),
+ nullable=False,
+ index=True,
+ )
+
+ # Certificate content — full signed certificate in OpenSSH format
+ certificate = db.Column(db.Text, nullable=False)
+
+ # Certificate metadata
+ serial = db.Column(db.String(255), nullable=False, unique=True, index=True)
+ key_id = db.Column(db.String(255), nullable=False) # Usually user email
+ cert_type = db.Column(
+ db.Enum(CertType, values_callable=lambda x: [e.value for e in x]),
+ default=CertType.USER,
+ nullable=False,
+ )
+
+ # Principals — JSON list, e.g., ["prod-servers", "dev-servers"]
+ principals = db.Column(db.JSON, nullable=False, default=list)
+
+ # Validity period
+ valid_after = db.Column(db.DateTime, nullable=False)
+ valid_before = db.Column(db.DateTime, nullable=False)
+
+ # Revocation status
+ revoked = db.Column(db.Boolean, default=False, nullable=False, index=True)
+ revoked_at = db.Column(db.DateTime, nullable=True)
+ revoke_reason = db.Column(db.String(255), nullable=True)
+
+ # Status tracking
+ status = db.Column(
+ db.Enum(CertificateStatus, values_callable=lambda x: [e.value for e in x]),
+ default=CertificateStatus.ISSUED,
+ nullable=False,
+ index=True,
+ )
+
+ # Request metadata
+ request_ip = db.Column(db.String(45), nullable=True)
+ request_user_agent = db.Column(db.String(512), nullable=True)
+
+ # Critical options — OpenSSH critical options (JSON)
+ critical_options = db.Column(db.JSON, nullable=True, default=dict)
+
+ # Extensions — OpenSSH extensions (JSON)
+ extensions = db.Column(db.JSON, nullable=True, default=dict)
+
+ # Relationships
+ ca = db.relationship("CA", back_populates="certificates")
+ user = db.relationship("User", back_populates="ssh_certificates")
+ ssh_key = db.relationship(
+ "SSHKey",
+ back_populates="certificates",
+ foreign_keys="SSHCertificate.ssh_key_id",
+ )
+ audit_logs = db.relationship(
+ "CertificateAuditLog",
+ back_populates="certificate",
+ cascade="all, delete-orphan",
+ )
+
+ __table_args__ = (
+ db.Index("idx_cert_user_status", "user_id", "status"),
+ db.Index("idx_cert_validity", "valid_after", "valid_before"),
+ db.Index("idx_cert_revoked", "revoked", "revoked_at"),
+ )
+
+ def __repr__(self):
+ """String representation of SSHCertificate."""
+ return f""
+
+ def to_dict(self, exclude=None):
+ """Convert certificate to dictionary.
+
+ The raw ``certificate`` blob is excluded by default (it is large and
+ callers can request it explicitly by removing it from the exclude list).
+ """
+ exclude = exclude or []
+ if "certificate" not in exclude:
+ exclude.append("certificate")
+ data = super().to_dict(exclude=exclude)
+ data["is_valid"] = self.is_valid()
+ data["days_until_expiry"] = self.days_until_expiry()
+ return data
+
+ def _aware(self, dt: datetime) -> datetime:
+ """Return a timezone-aware UTC datetime."""
+ return dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt
+
+ def is_valid(self) -> bool:
+ """Check if certificate is currently valid.
+
+ Returns:
+ True if certificate is issued, not revoked, and within validity period
+ """
+ if self.revoked or self.status == CertificateStatus.REVOKED:
+ return False
+ now = datetime.now(timezone.utc)
+ return self._aware(self.valid_after) <= now <= self._aware(self.valid_before)
+
+ def is_expired(self) -> bool:
+ """Check if certificate has expired.
+
+ Returns:
+ True if current time is past valid_before
+ """
+ return datetime.now(timezone.utc) > self._aware(self.valid_before)
+
+ def days_until_expiry(self) -> int:
+ """Get number of days until certificate expires.
+
+ Returns:
+ Number of days remaining (negative if already expired)
+ """
+ delta = self._aware(self.valid_before) - datetime.now(timezone.utc)
+ return delta.days + (1 if delta.seconds > 0 else 0)
+
+ def revoke(self, reason: str = None) -> None:
+ """Revoke this certificate.
+
+ Args:
+ reason: Optional reason for revocation
+ """
+ self.revoked = True
+ self.revoked_at = datetime.now(timezone.utc) # Bug fix: was datetime.utcnow()
+ self.revoke_reason = reason
+ self.status = CertificateStatus.REVOKED
+ self.save()
+
+ def mark_expired(self) -> None:
+ """Mark certificate as expired when validity period ends."""
+ self.status = CertificateStatus.EXPIRED
+ self.save()
diff --git a/gatehouse_app/models/ssh_ca/ssh_key.py b/gatehouse_app/models/ssh_ca/ssh_key.py
new file mode 100644
index 0000000..218fd99
--- /dev/null
+++ b/gatehouse_app/models/ssh_ca/ssh_key.py
@@ -0,0 +1,98 @@
+"""SSH Key model — user SSH public keys registered for certificate signing."""
+from datetime import datetime, timezone
+from gatehouse_app.extensions import db
+from gatehouse_app.models.base import BaseModel
+
+
+class SSHKey(BaseModel):
+ """SSH Key model representing a user's SSH public key.
+
+ Users register SSH public keys for certificate signing. Keys must be
+ verified (owner proved possession) before they can be used.
+ """
+
+ __tablename__ = "ssh_keys"
+
+ user_id = db.Column(
+ db.String(36),
+ db.ForeignKey("users.id"),
+ nullable=False,
+ index=True,
+ )
+
+ # SSH key payload in OpenSSH format (e.g., "ssh-ed25519 AAAAB3Nz...")
+ payload = db.Column(db.Text, nullable=False, unique=True)
+
+ # SHA256 fingerprint for quick comparison and deduplication
+ fingerprint = db.Column(db.String(255), nullable=False, unique=True, index=True)
+
+ # Optional human-readable description (e.g., "My laptop key")
+ description = db.Column(db.String(255), nullable=True)
+
+ # Verification status
+ verified = db.Column(db.Boolean, default=False, nullable=False, index=True)
+ verified_at = db.Column(db.DateTime, nullable=True)
+
+ # Verification challenge — shown to user once, cleared after verification
+ verify_text = db.Column(db.String(255), nullable=True)
+ verify_text_created_at = db.Column(db.DateTime, nullable=True)
+
+ # Key metadata extracted from the key
+ key_type = db.Column(db.String(50), nullable=True) # ssh-rsa, ssh-ed25519, etc.
+ key_bits = db.Column(db.Integer, nullable=True) # key length
+ key_comment = db.Column(db.String(255), nullable=True)
+
+ # Relationships
+ user = db.relationship("User", back_populates="ssh_keys")
+ certificates = db.relationship(
+ "SSHCertificate",
+ back_populates="ssh_key",
+ cascade="all, delete-orphan",
+ foreign_keys="SSHCertificate.ssh_key_id",
+ )
+
+ __table_args__ = (
+ db.Index("idx_ssh_key_user_verified", "user_id", "verified"),
+ )
+
+ def __repr__(self):
+ """String representation of SSHKey."""
+ return f""
+
+ def to_dict(self, exclude=None):
+ """Convert SSH key to dictionary.
+
+ ``payload`` and ``verify_text`` are never exposed through the API.
+ """
+ exclude = exclude or []
+ for field in ("payload", "verify_text"):
+ if field not in exclude:
+ exclude.append(field)
+ data = super().to_dict(exclude=exclude)
+ data["cert_count"] = len([c for c in self.certificates if c.deleted_at is None])
+ return data
+
+ def mark_verified(self) -> None:
+ """Mark this SSH key as verified and clear the challenge."""
+ self.verified = True
+ self.verified_at = datetime.now(timezone.utc) # Bug fix: was datetime.utcnow()
+ self.verify_text = None
+ self.save()
+
+ def needs_verification_refresh(self, max_age_hours: int = 24) -> bool:
+ """Check if verification challenge needs to be refreshed.
+
+ Args:
+ max_age_hours: Maximum age of verification challenge in hours
+
+ Returns:
+ True if verification challenge is stale or missing
+ """
+ if not self.verify_text_created_at:
+ return True
+ age = datetime.now(timezone.utc) - self.verify_text_created_at.replace(
+ tzinfo=timezone.utc
+ ) if self.verify_text_created_at.tzinfo is None else (
+ datetime.now(timezone.utc) - self.verify_text_created_at
+ )
+ return age.total_seconds() > (max_age_hours * 3600)
diff --git a/gatehouse_app/models/user.py b/gatehouse_app/models/user.py
index 474269f..eecf942 100644
--- a/gatehouse_app/models/user.py
+++ b/gatehouse_app/models/user.py
@@ -1,153 +1,4 @@
-"""User model."""
-from gatehouse_app.extensions import db
-from gatehouse_app.models.base import BaseModel
-from gatehouse_app.utils.constants import UserStatus
+"""Backward-compatibility shim — import from gatehouse_app.models.user.user instead."""
+from gatehouse_app.models.user.user import User # noqa: F401
-
-class User(BaseModel):
- """User model representing a user account."""
-
- __tablename__ = "users"
-
- email = db.Column(db.String(255), unique=True, nullable=False, index=True)
- email_verified = db.Column(db.Boolean, default=False, nullable=False)
- full_name = db.Column(db.String(255), nullable=True)
- avatar_url = db.Column(db.String(512), nullable=True)
- status = db.Column(
- db.Enum(UserStatus), default=UserStatus.ACTIVE, nullable=False, index=True
- )
- last_login_at = db.Column(db.DateTime, nullable=True)
- last_login_ip = db.Column(db.String(45), nullable=True)
-
- # Relationships
- authentication_methods = db.relationship(
- "AuthenticationMethod", back_populates="user", cascade="all, delete-orphan"
- )
- sessions = db.relationship("Session", back_populates="user", cascade="all, delete-orphan")
- organization_memberships = db.relationship(
- "OrganizationMember",
- back_populates="user",
- cascade="all, delete-orphan",
- foreign_keys="OrganizationMember.user_id",
- )
- audit_logs = db.relationship("AuditLog", back_populates="user", cascade="all, delete-orphan")
- security_policies = db.relationship(
- "UserSecurityPolicy",
- back_populates="user",
- cascade="all, delete-orphan",
- foreign_keys="UserSecurityPolicy.user_id",
- )
- mfa_compliance = db.relationship(
- "MfaPolicyCompliance",
- back_populates="user",
- cascade="all, delete-orphan",
- foreign_keys="MfaPolicyCompliance.user_id",
- )
-
- def __repr__(self):
- """String representation of User."""
- return f""
-
- def to_dict(self, exclude=None):
- """Convert user to dictionary, excluding sensitive fields by default."""
- exclude = exclude or []
- # Always exclude password-related fields
- default_exclude = []
- all_exclude = list(set(default_exclude + exclude))
- return super().to_dict(exclude=all_exclude)
-
- def has_password_auth(self):
- """Check if user has password authentication enabled."""
- from gatehouse_app.models.authentication_method import AuthenticationMethod
- from gatehouse_app.utils.constants import AuthMethodType
-
- return (
- AuthenticationMethod.query.filter_by(
- user_id=self.id, method_type=AuthMethodType.PASSWORD, deleted_at=None
- ).first()
- is not None
- )
-
- def get_organizations(self):
- """Get all organizations the user is a member of."""
- return [membership.organization for membership in self.organization_memberships]
-
- def has_totp_enabled(self) -> bool:
- """Check if user has TOTP enabled and verified.
-
- Returns:
- True if user has a verified TOTP authentication method, False otherwise.
- """
- from gatehouse_app.models.authentication_method import AuthenticationMethod
- from gatehouse_app.utils.constants import AuthMethodType
-
- return (
- AuthenticationMethod.query.filter_by(
- user_id=self.id,
- method_type=AuthMethodType.TOTP,
- verified=True,
- deleted_at=None,
- ).first()
- is not None
- )
-
- def get_totp_method(self):
- """Get user's TOTP authentication method.
-
- Returns:
- The AuthenticationMethod instance for TOTP or None if not found.
-
- Note:
- Returns the most recently created TOTP method to handle cases where
- multiple enrollment attempts may exist.
- """
- from gatehouse_app.models.authentication_method import AuthenticationMethod
- from gatehouse_app.utils.constants import AuthMethodType
-
- return AuthenticationMethod.query.filter_by(
- user_id=self.id, method_type=AuthMethodType.TOTP, deleted_at=None
- ).order_by(AuthenticationMethod.created_at.desc()).first()
-
- def has_webauthn_enabled(self) -> bool:
- """Check if user has any WebAuthn passkey credentials.
-
- Returns:
- True if user has at least one WebAuthn credential, False otherwise.
- """
- from gatehouse_app.models.authentication_method import AuthenticationMethod
- from gatehouse_app.utils.constants import AuthMethodType
-
- return (
- AuthenticationMethod.query.filter_by(
- user_id=self.id,
- method_type=AuthMethodType.WEBAUTHN,
- deleted_at=None,
- ).first()
- is not None
- )
-
- def get_webauthn_credentials(self):
- """Get all WebAuthn credentials for the user.
-
- Returns:
- List of AuthenticationMethod instances for WebAuthn, ordered by creation date.
- """
- from gatehouse_app.models.authentication_method import AuthenticationMethod
- from gatehouse_app.utils.constants import AuthMethodType
-
- return AuthenticationMethod.query.filter_by(
- user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
- ).order_by(AuthenticationMethod.created_at.desc()).all()
-
- def get_webauthn_credential_count(self) -> int:
- """Get the count of WebAuthn credentials for the user.
-
- Returns:
- Number of WebAuthn credentials.
- """
- from gatehouse_app.models.authentication_method import AuthenticationMethod
- from gatehouse_app.utils.constants import AuthMethodType
-
- return AuthenticationMethod.query.filter_by(
- user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
- ).count()
+__all__ = ["User"]
diff --git a/gatehouse_app/models/user/__init__.py b/gatehouse_app/models/user/__init__.py
new file mode 100644
index 0000000..1d05e9a
--- /dev/null
+++ b/gatehouse_app/models/user/__init__.py
@@ -0,0 +1,5 @@
+"""User subpackage."""
+from gatehouse_app.models.user.user import User
+from gatehouse_app.models.user.session import Session
+
+__all__ = ["User", "Session"]
diff --git a/gatehouse_app/models/session.py b/gatehouse_app/models/user/session.py
similarity index 86%
rename from gatehouse_app/models/session.py
rename to gatehouse_app/models/user/session.py
index 0290a20..9a78830 100644
--- a/gatehouse_app/models/session.py
+++ b/gatehouse_app/models/user/session.py
@@ -21,7 +21,9 @@ class Session(BaseModel):
# Timing
expires_at = db.Column(db.DateTime, nullable=False)
- last_activity_at = db.Column(db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc))
+ last_activity_at = db.Column(
+ db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc)
+ )
revoked_at = db.Column(db.DateTime, nullable=True)
revoked_reason = db.Column(db.String(255), nullable=True)
@@ -38,7 +40,6 @@ class Session(BaseModel):
def is_active(self):
"""Check if session is currently active."""
now = datetime.now(timezone.utc)
- # Make expires_at timezone-aware if it's naive
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
@@ -51,15 +52,13 @@ class Session(BaseModel):
def is_expired(self):
"""Check if session has expired."""
now = datetime.now(timezone.utc)
- # Make expires_at timezone-aware if it's naive
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return now > expires_at
- def refresh(self, duration_seconds=86400):
- """
- Refresh session expiration.
+ def refresh(self, duration_seconds: int = 86400):
+ """Refresh session expiration.
Args:
duration_seconds: New session duration in seconds
@@ -68,9 +67,8 @@ class Session(BaseModel):
self.last_activity_at = datetime.now(timezone.utc)
db.session.commit()
- def revoke(self, reason=None):
- """
- Revoke the session.
+ def revoke(self, reason: str = None):
+ """Revoke the session.
Args:
reason: Optional reason for revocation
@@ -84,6 +82,5 @@ class Session(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
- # Exclude token from dict
exclude.append("token")
return super().to_dict(exclude=exclude)
diff --git a/gatehouse_app/models/user/user.py b/gatehouse_app/models/user/user.py
new file mode 100644
index 0000000..c2fb1c8
--- /dev/null
+++ b/gatehouse_app/models/user/user.py
@@ -0,0 +1,209 @@
+"""User model."""
+from gatehouse_app.extensions import db
+from gatehouse_app.models.base import BaseModel
+from gatehouse_app.utils.constants import UserStatus
+
+
+class User(BaseModel):
+ """User model representing a user account."""
+
+ __tablename__ = "users"
+
+ email = db.Column(db.String(255), unique=True, nullable=False, index=True)
+ email_verified = db.Column(db.Boolean, default=False, nullable=False)
+ full_name = db.Column(db.String(255), nullable=True)
+ avatar_url = db.Column(db.String(512), nullable=True)
+ status = db.Column(
+ db.Enum(UserStatus), default=UserStatus.ACTIVE, nullable=False, index=True
+ )
+ last_login_at = db.Column(db.DateTime, nullable=True)
+ last_login_ip = db.Column(db.String(45), nullable=True)
+
+ # Account activation (email-link flow)
+ activated = db.Column(db.Boolean, default=True, nullable=False)
+ activation_key = db.Column(db.String(128), unique=True, nullable=True, index=True)
+
+ # Relationships – defined here only for models that don't circular-import
+ authentication_methods = db.relationship(
+ "AuthenticationMethod", back_populates="user", cascade="all, delete-orphan"
+ )
+ sessions = db.relationship("Session", back_populates="user", cascade="all, delete-orphan")
+ organization_memberships = db.relationship(
+ "OrganizationMember",
+ back_populates="user",
+ cascade="all, delete-orphan",
+ foreign_keys="OrganizationMember.user_id",
+ )
+ audit_logs = db.relationship("AuditLog", back_populates="user", cascade="all, delete-orphan")
+ security_policies = db.relationship(
+ "UserSecurityPolicy",
+ back_populates="user",
+ cascade="all, delete-orphan",
+ foreign_keys="UserSecurityPolicy.user_id",
+ )
+ mfa_compliance = db.relationship(
+ "MfaPolicyCompliance",
+ back_populates="user",
+ cascade="all, delete-orphan",
+ foreign_keys="MfaPolicyCompliance.user_id",
+ )
+ department_memberships = db.relationship(
+ "DepartmentMembership",
+ back_populates="user",
+ cascade="all, delete-orphan",
+ foreign_keys="DepartmentMembership.user_id",
+ )
+ principal_memberships = db.relationship(
+ "PrincipalMembership",
+ back_populates="user",
+ cascade="all, delete-orphan",
+ foreign_keys="PrincipalMembership.user_id",
+ )
+ ssh_keys = db.relationship(
+ "SSHKey",
+ back_populates="user",
+ cascade="all, delete-orphan",
+ foreign_keys="SSHKey.user_id",
+ )
+ ssh_certificates = db.relationship(
+ "SSHCertificate",
+ back_populates="user",
+ cascade="all, delete-orphan",
+ foreign_keys="SSHCertificate.user_id",
+ )
+ ca_permissions = db.relationship(
+ "CAPermission",
+ back_populates="user",
+ cascade="all, delete-orphan",
+ foreign_keys="CAPermission.user_id",
+ )
+
+ # OIDC relationships – registered here (no monkey-patching needed)
+ oidc_auth_codes = db.relationship(
+ "OIDCAuthCode", back_populates="user", cascade="all, delete-orphan"
+ )
+ oidc_refresh_tokens = db.relationship(
+ "OIDCRefreshToken", back_populates="user", cascade="all, delete-orphan"
+ )
+ oidc_sessions = db.relationship(
+ "OIDCSession", back_populates="user", cascade="all, delete-orphan"
+ )
+ oidc_token_metadata = db.relationship(
+ "OIDCTokenMetadata", back_populates="user", cascade="all, delete-orphan"
+ )
+ oidc_audit_logs = db.relationship(
+ "OIDCAuditLog", back_populates="user", cascade="all, delete-orphan"
+ )
+
+ def __repr__(self):
+ """String representation of User."""
+ return f""
+
+ def to_dict(self, exclude=None):
+ """Convert user to dictionary, excluding sensitive fields by default."""
+ exclude = exclude or []
+ return super().to_dict(exclude=exclude)
+
+ def has_password_auth(self):
+ """Check if user has password authentication enabled."""
+ from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
+ from gatehouse_app.utils.constants import AuthMethodType
+
+ return (
+ AuthenticationMethod.query.filter_by(
+ user_id=self.id, method_type=AuthMethodType.PASSWORD, deleted_at=None
+ ).first()
+ is not None
+ )
+
+ def get_organizations(self):
+ """Get all organizations the user is a member of."""
+ return [membership.organization for membership in self.organization_memberships]
+
+ def has_totp_enabled(self) -> bool:
+ """Check if user has TOTP enabled and verified.
+
+ Returns:
+ True if user has a verified TOTP authentication method, False otherwise.
+ """
+ from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
+ from gatehouse_app.utils.constants import AuthMethodType
+
+ return (
+ AuthenticationMethod.query.filter_by(
+ user_id=self.id,
+ method_type=AuthMethodType.TOTP,
+ verified=True,
+ deleted_at=None,
+ ).first()
+ is not None
+ )
+
+ def get_totp_method(self):
+ """Get user's TOTP authentication method.
+
+ Returns:
+ The AuthenticationMethod instance for TOTP or None if not found.
+
+ Note:
+ Returns the most recently created TOTP method to handle cases where
+ multiple enrollment attempts may exist.
+ """
+ from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
+ from gatehouse_app.utils.constants import AuthMethodType
+
+ return (
+ AuthenticationMethod.query.filter_by(
+ user_id=self.id, method_type=AuthMethodType.TOTP, deleted_at=None
+ )
+ .order_by(AuthenticationMethod.created_at.desc())
+ .first()
+ )
+
+ def has_webauthn_enabled(self) -> bool:
+ """Check if user has any WebAuthn passkey credentials.
+
+ Returns:
+ True if user has at least one WebAuthn credential, False otherwise.
+ """
+ from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
+ from gatehouse_app.utils.constants import AuthMethodType
+
+ return (
+ AuthenticationMethod.query.filter_by(
+ user_id=self.id,
+ method_type=AuthMethodType.WEBAUTHN,
+ deleted_at=None,
+ ).first()
+ is not None
+ )
+
+ def get_webauthn_credentials(self):
+ """Get all WebAuthn credentials for the user.
+
+ Returns:
+ List of AuthenticationMethod instances for WebAuthn, ordered by creation date.
+ """
+ from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
+ from gatehouse_app.utils.constants import AuthMethodType
+
+ return (
+ AuthenticationMethod.query.filter_by(
+ user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
+ )
+ .order_by(AuthenticationMethod.created_at.desc())
+ .all()
+ )
+
+ def get_webauthn_credential_count(self) -> int:
+ """Get the count of WebAuthn credentials for the user.
+
+ Returns:
+ Number of WebAuthn credentials.
+ """
+ from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
+ from gatehouse_app.utils.constants import AuthMethodType
+
+ return AuthenticationMethod.query.filter_by(
+ user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
+ ).count()
diff --git a/gatehouse_app/schemas/auth_schema.py b/gatehouse_app/schemas/auth_schema.py
index b2042f7..dff1758 100644
--- a/gatehouse_app/schemas/auth_schema.py
+++ b/gatehouse_app/schemas/auth_schema.py
@@ -25,7 +25,7 @@ class LoginSchema(Schema):
email = fields.Email(required=True)
password = fields.Str(required=True, validate=validate.Length(min=1))
- remember_me = fields.Bool(missing=False)
+ remember_me = fields.Bool(load_default=False)
class RefreshTokenSchema(Schema):
@@ -77,14 +77,38 @@ class TOTPVerifyEnrollmentSchema(Schema):
class TOTPVerifySchema(Schema):
"""Schema for TOTP code verification during login."""
- code = fields.Str(required=True)
- is_backup_code = fields.Bool(missing=False)
+ code = fields.Str(
+ required=True,
+ validate=validate.Length(min=1),
+ )
+ is_backup_code = fields.Bool(load_default=False)
client_timestamp = fields.Int(
required=False,
allow_none=True,
metadata={"description": "Client UTC timestamp in seconds since epoch for TOTP verification"},
)
+ @validates_schema
+ def validate_code_format(self, data, **kwargs):
+ """Validate code format depending on whether it's a backup code."""
+ code = data.get("code", "")
+ is_backup_code = data.get("is_backup_code", False)
+ if is_backup_code:
+ # Backup codes are 16 uppercase hex characters
+ if not code or len(code) != 16 or not all(c in "0123456789ABCDEFabcdef" for c in code):
+ raise ValidationError(
+ "Backup code must be a 16-character hexadecimal string.",
+ field_name="code",
+ )
+ else:
+ # Regular TOTP codes are exactly 6 digits
+ import re
+ if not re.match(r"^\d{6}$", code):
+ raise ValidationError(
+ "Code must be a 6-digit number.",
+ field_name="code",
+ )
+
class TOTPDisableSchema(Schema):
"""Schema for disabling TOTP."""
diff --git a/gatehouse_app/services/audit_service.py b/gatehouse_app/services/audit_service.py
index 3978aa5..ed70f14 100644
--- a/gatehouse_app/services/audit_service.py
+++ b/gatehouse_app/services/audit_service.py
@@ -1,6 +1,6 @@
"""Audit service."""
from flask import request, g
-from gatehouse_app.models.audit_log import AuditLog
+from gatehouse_app.models.auth.audit_log import AuditLog
from gatehouse_app.utils.constants import AuditAction
@@ -59,7 +59,7 @@ class AuditService:
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
- metadata=metadata,
+ extra_data=metadata,
description=description,
success=success,
error_message=error_message,
diff --git a/gatehouse_app/services/auth_service.py b/gatehouse_app/services/auth_service.py
index 3df5bf3..9061f33 100644
--- a/gatehouse_app/services/auth_service.py
+++ b/gatehouse_app/services/auth_service.py
@@ -5,9 +5,9 @@ from datetime import datetime, timedelta, timezone
from typing import Optional
from flask import request, g, current_app
from gatehouse_app.extensions import db, bcrypt
-from gatehouse_app.models.user import User
-from gatehouse_app.models.authentication_method import AuthenticationMethod
-from gatehouse_app.models.session import Session
+from gatehouse_app.models.user.user import User
+from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
+from gatehouse_app.models.user.session import Session
from gatehouse_app.utils.constants import AuthMethodType, SessionStatus, UserStatus, AuditAction
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError, AccountSuspendedError, AccountInactiveError
from gatehouse_app.exceptions.validation_exceptions import EmailAlreadyExistsError
@@ -102,7 +102,7 @@ class AuthService:
if current_app.config.get('ENV') == 'development':
logger.debug(f"[Auth] Account status: user_id={user.id}, status={user.status}")
- if user.status == UserStatus.SUSPENDED:
+ if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
raise AccountSuspendedError()
if user.status == UserStatus.INACTIVE:
raise AccountInactiveError()
@@ -210,6 +210,22 @@ class AuthService:
auth_method.password_hash = bcrypt.generate_password_hash(new_password).decode("utf-8")
db.session.commit()
+ # Invalidate all other sessions so that if an attacker had a valid
+ # session token, changing the password actually locks them out.
+ # The current request's session (if any) is preserved so the user
+ # doesn't have to log in again immediately.
+ from flask import g as flask_g
+ current_session_id = getattr(flask_g, "current_session", None)
+ current_session_id = current_session_id.id if current_session_id else None
+ sessions_to_revoke = Session.query.filter(
+ Session.user_id == user.id,
+ Session.revoked_at == None, # noqa: E711
+ ).all()
+ for sess in sessions_to_revoke:
+ if sess.id != current_session_id:
+ sess.revoke(reason="Password changed")
+ db.session.commit()
+
# Log password change
AuditService.log_action(
action=AuditAction.PASSWORD_CHANGE,
@@ -482,9 +498,24 @@ class AuthService:
if not secret:
raise InvalidCredentialsError("TOTP secret not found")
+ # Replay-attack prevention: reject codes that have already been
+ # accepted within the current validity window.
+ if TOTPService.is_code_already_used(str(user.id), code):
+ AuditService.log_action(
+ action=AuditAction.TOTP_VERIFY_FAILED,
+ user_id=user.id,
+ resource_type="authentication_method",
+ resource_id=auth_method.id,
+ description="TOTP code replay attempt detected",
+ )
+ raise InvalidCredentialsError("Invalid TOTP code")
+
is_valid = TOTPService.verify_code(secret, code, client_utc_timestamp=client_utc_timestamp)
if is_valid:
+ # Mark this code as used to prevent replay within the validity window
+ TOTPService.mark_code_used(str(user.id), code)
+
auth_method.last_used_at = datetime.now(timezone.utc)
db.session.commit()
diff --git a/gatehouse_app/services/external_auth_service.py b/gatehouse_app/services/external_auth_service.py
index 89bfbf1..57cad76 100644
--- a/gatehouse_app/services/external_auth_service.py
+++ b/gatehouse_app/services/external_auth_service.py
@@ -8,7 +8,7 @@ from flask import current_app
from gatehouse_app.extensions import db
from gatehouse_app.models import User, AuthenticationMethod
-from gatehouse_app.models.authentication_method import (
+from gatehouse_app.models.auth.authentication_method import (
OAuthState,
ApplicationProviderConfig,
OrganizationProviderOverride
@@ -736,9 +736,14 @@ class ExternalAuthService:
400,
)
- # Generate PKCE
- code_verifier = secrets.token_urlsafe(32)
- code_challenge = cls._compute_s256_challenge(code_verifier)
+ # Generate PKCE — skip for confidential clients (Google, Microsoft) that use a
+ # client_secret. Sending code_challenge to Microsoft causes it to enforce PKCE on
+ # the token exchange, which then fails. Matches the behaviour of initiate_login_flow.
+ code_verifier = None
+ code_challenge = None
+ if provider_type_str not in ('google', 'microsoft'):
+ code_verifier = secrets.token_urlsafe(32)
+ code_challenge = cls._compute_s256_challenge(code_verifier)
# Create OAuth state
state = OAuthState.create_state(
@@ -1210,12 +1215,35 @@ class ExternalAuthService:
else:
email_verified = data.get("email_verified", False)
+ sub = data.get("sub")
+
+ # Derive email from sub when the provider omits the email claim.
+ # This happens with some OIDC servers (including the nav-security mock)
+ # that only return the minimal {sub, iss, iat, exp} set.
+ # Rule: if sub looks like an email address, use it directly.
+ # Otherwise, construct a deterministic fallback so we never get NULL.
+ raw_email = data.get("email")
+ if not raw_email and sub:
+ import re as _re
+ if _re.match(r"^[^@\s]+@[^@\s]+\.[^@\s]+$", sub):
+ raw_email = sub
+ email_verified = True # if sub IS the email it's already verified
+ else:
+ # e.g. "12345" → "12345@google.local" so we can store it
+ raw_email = f"{sub}@{provider or 'oauth'}.local"
+ email_verified = False
+
+ # Derive display name when omitted
+ raw_name = data.get("name") or data.get("display_name")
+ if not raw_name and raw_email:
+ raw_name = raw_email.split("@")[0]
+
# Standardize user info
return {
- "provider_user_id": data.get("sub"),
- "email": data.get("email"),
+ "provider_user_id": sub,
+ "email": raw_email,
"email_verified": email_verified,
- "name": data.get("name"),
+ "name": raw_name,
"first_name": data.get("given_name"),
"last_name": data.get("family_name"),
"picture": data.get("picture"),
diff --git a/gatehouse_app/services/mfa_policy_service.py b/gatehouse_app/services/mfa_policy_service.py
index d553fc1..d7b1316 100644
--- a/gatehouse_app/services/mfa_policy_service.py
+++ b/gatehouse_app/services/mfa_policy_service.py
@@ -4,11 +4,11 @@ from datetime import datetime, timezone
from typing import Optional, List, Dict, Any
from gatehouse_app.extensions import db
-from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy
-from gatehouse_app.models.user_security_policy import UserSecurityPolicy
-from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
-from gatehouse_app.models.user import User
-from gatehouse_app.models.organization import Organization
+from gatehouse_app.models.security.organization_security_policy import OrganizationSecurityPolicy
+from gatehouse_app.models.security.user_security_policy import UserSecurityPolicy
+from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance
+from gatehouse_app.models.user.user import User
+from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.services.audit_service import AuditService
from gatehouse_app.utils.constants import (
MfaPolicyMode,
@@ -702,7 +702,7 @@ class MfaPolicyService:
if now is None:
now = datetime.now(timezone.utc)
- from gatehouse_app.models.organization_member import OrganizationMember
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
updated_count = 0
diff --git a/gatehouse_app/services/notification_service.py b/gatehouse_app/services/notification_service.py
index d3afdca..fc9bbe0 100644
--- a/gatehouse_app/services/notification_service.py
+++ b/gatehouse_app/services/notification_service.py
@@ -19,9 +19,9 @@ import logging
import json
from gatehouse_app.extensions import db
-from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
-from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy
-from gatehouse_app.models.user import User
+from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance
+from gatehouse_app.models.security.organization_security_policy import OrganizationSecurityPolicy
+from gatehouse_app.models.user.user import User
from gatehouse_app.services.audit_service import AuditService
from gatehouse_app.utils.constants import AuditAction
@@ -37,6 +37,7 @@ class NotificationService:
SMTP_PORT_KEY = "SMTP_PORT"
SMTP_USERNAME_KEY = "SMTP_USERNAME"
SMTP_PASSWORD_KEY = "SMTP_PASSWORD"
+ SMTP_USE_TLS_KEY = "SMTP_USE_TLS"
FROM_ADDRESS_KEY = "FROM_ADDRESS"
@staticmethod
@@ -86,10 +87,9 @@ class NotificationService:
if success:
logger.info(
f"Sent MFA deadline reminder to {user.email} "
- f"({days_until_deadline} days remaining # Audit log
-)"
+ f"({days_until_deadline} days remaining)"
)
- AuditService.log_action(
+ AuditService.log_action(
action=AuditAction.MFA_POLICY_USER_COMPLIANT,
user_id=user.id,
organization_id=compliance.organization_id,
@@ -291,101 +291,62 @@ Gatehouse Security Team
body: str,
html_body: Optional[str] = None,
) -> bool:
- """Send an email notification.
+ """Send an email via SMTP.
- This method attempts to send an email using configured SMTP settings.
- If email is not configured, it logs the notification instead.
-
- Args:
- to_address: Recipient email address
- subject: Email subject
- body: Plain text email body
- html_body: Optional HTML email body
-
- Returns:
- True if email was sent (or logged), False on error
+ Returns True if the email was sent successfully, False otherwise.
+ If EMAIL_ENABLED is False, logs the email body instead (simulation mode).
"""
+ import smtplib
+ from email.mime.multipart import MIMEMultipart
+ from email.mime.text import MIMEText
+ from flask import current_app
+
+ email_enabled = current_app.config.get(NotificationService.EMAIL_ENABLED_KEY, False)
+
+ if not email_enabled:
+ logger.info(
+ f"[EMAIL DISABLED] Would have sent to: {to_address} | Subject: {subject}\n"
+ f"Body: {body[:500]}"
+ )
+ return False
+
+ smtp_host = current_app.config.get(NotificationService.SMTP_HOST_KEY, "localhost")
+ smtp_port = int(current_app.config.get(NotificationService.SMTP_PORT_KEY, 587))
+ smtp_username = current_app.config.get(NotificationService.SMTP_USERNAME_KEY)
+ smtp_password = current_app.config.get(NotificationService.SMTP_PASSWORD_KEY)
+ smtp_use_tls = current_app.config.get(
+ NotificationService.SMTP_USE_TLS_KEY,
+ smtp_port not in (25, 1025),
+ )
+ from_address = current_app.config.get(
+ NotificationService.FROM_ADDRESS_KEY, "noreply@gatehouse.local"
+ )
+
try:
- from flask import current_app
+ msg = MIMEMultipart("alternative")
+ msg["Subject"] = subject
+ msg["From"] = from_address
+ msg["To"] = to_address
+ msg.attach(MIMEText(body, "plain"))
+ if html_body:
+ msg.attach(MIMEText(html_body, "html"))
- # Check if email is configured
- email_enabled = current_app.config.get(
- NotificationService.EMAIL_ENABLED_KEY, False
- )
-
- if not email_enabled:
- # Log the notification instead of sending
- logger.info(
- f"[EMAIL SIMULATION] To: {to_address}\n"
- f"Subject: {subject}\n"
- f"Body: {body[:200]}..." if len(body) > 200 else f"Body: {body}"
- )
- return True
-
- # Get email configuration
- smtp_host = current_app.config.get(NotificationService.SMTP_HOST_KEY)
- smtp_port = current_app.config.get(NotificationService.SMTP_PORT_KEY, 587)
- smtp_username = current_app.config.get(NotificationService.SMTP_USERNAME_KEY)
- smtp_password = current_app.config.get(NotificationService.SMTP_PASSWORD_KEY)
- from_address = current_app.config.get(
- NotificationService.FROM_ADDRESS_KEY, "noreply@gatehouse.local"
- )
-
- # Import send_email based on available mail library
- try:
- from flask_mail import Message
-
- from gatehouse_app import mail
-
- msg = Message(
- subject=subject,
- recipients=[to_address],
- body=body,
- html=html_body,
- sender=from_address,
- )
- mail.send(msg)
- logger.info(f"Email sent successfully to {to_address}")
- return True
-
- except ImportError:
- # Flask-Mail not available, use SMTP directly
- import smtplib
- from email.mime.text import MIMEText
- from email.mime.multipart import MIMEMultipart
-
- msg = MIMEMultipart("alternative")
- msg["Subject"] = subject
- msg["From"] = from_address
- msg["To"] = to_address
-
- # Attach plain text and HTML versions
- part1 = MIMEText(body, "plain")
- msg.attach(part1)
-
- if html_body:
- part2 = MIMEText(html_body, "html")
- msg.attach(part2)
-
- # Send via SMTP
- with smtplib.SMTP(smtp_host, smtp_port) as server:
+ with smtplib.SMTP(smtp_host, smtp_port) as server:
+ server.ehlo()
+ if smtp_use_tls:
server.starttls()
- if smtp_username and smtp_password:
- server.login(smtp_username, smtp_password)
- server.send_message(msg)
+ server.ehlo()
+ if smtp_username and smtp_password:
+ server.login(smtp_username, smtp_password)
+ server.send_message(msg)
- logger.info(f"Email sent successfully to {to_address}")
- return True
+ logger.info(f"[EMAIL] Sent to {to_address} | Subject: {subject}")
+ return True
except Exception as e:
- logger.exception(f"Failed to send email to {to_address}: {e}")
- # Log the notification as fallback
- logger.info(
- f"[EMAIL FALLBACK] To: {to_address}\n"
- f"Subject: {subject}\n"
- f"Body: {body[:500]}..." if len(body) > 500 else f"Body: {body}"
- )
- return True # Return True to continue processing
+ logger.error(f"[EMAIL] Failed to send to {to_address}: {e}")
+ return False
+
@staticmethod
def get_notification_stats(user_id: str) -> Dict[str, Any]:
@@ -397,7 +358,7 @@ Gatehouse Security Team
Returns:
Dictionary with notification statistics
"""
- from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
+ from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance
stats = {
"total_notifications": 0,
diff --git a/gatehouse_app/services/oauth_flow_service.py b/gatehouse_app/services/oauth_flow_service.py
index 89b1745..2889a7c 100644
--- a/gatehouse_app/services/oauth_flow_service.py
+++ b/gatehouse_app/services/oauth_flow_service.py
@@ -9,10 +9,10 @@ from flask import current_app, request, g, redirect
from gatehouse_app.extensions import db
from gatehouse_app.models import User, AuthenticationMethod
-from gatehouse_app.models.authentication_method import OAuthState
+from gatehouse_app.models.auth.authentication_method import OAuthState
from gatehouse_app.models.base import BaseModel
-from gatehouse_app.models.oidc_authorization_code import OIDCAuthCode
-from gatehouse_app.utils.constants import AuthMethodType
+from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode
+from gatehouse_app.utils.constants import AuthMethodType, AuditAction
from gatehouse_app.services.audit_service import AuditService
from gatehouse_app.services.external_auth_service import (
ExternalAuthService,
@@ -139,7 +139,7 @@ class OAuthFlowService:
except ExternalAuthError as e:
# Log failed initiation
AuditService.log_action(
- action="external_auth.login.initiated",
+ action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED,
organization_id=organization_id,
metadata={
"provider_type": provider_type_str,
@@ -236,7 +236,7 @@ class OAuthFlowService:
except ExternalAuthError as e:
AuditService.log_action(
- action="external_auth.register.initiated",
+ action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED,
organization_id=organization_id,
metadata={
"provider_type": provider_type_str,
@@ -399,6 +399,27 @@ class OAuthFlowService:
access_token=tokens["access_token"],
)
+ if not user_info.get("provider_user_id"):
+ raise OAuthFlowError(
+ "Provider did not return a user identifier (sub claim). "
+ "Cannot complete authentication.",
+ "MISSING_PROVIDER_USER_ID",
+ 400,
+ )
+
+ if not user_info.get("email"):
+ raise OAuthFlowError(
+ "Provider did not return an email address. "
+ "Cannot complete authentication.",
+ "MISSING_EMAIL",
+ 400,
+ )
+
+ logger.debug(
+ f"Got user_info from provider: sub={user_info['provider_user_id']}, "
+ f"email={user_info['email']}, email_verified={user_info.get('email_verified')}"
+ )
+
# Look up user by provider_user_id
auth_method = AuthenticationMethod.query.filter_by(
method_type=provider_type,
@@ -494,25 +515,52 @@ class OAuthFlowService:
if not target_org and len(user_orgs) == 1:
target_org = user_orgs[0]
- # Priority 3: No orgs at all — auto-create a personal org and log in
+ # Priority 3: No orgs at all — send to org-setup instead of auto-creating
if not target_org and len(user_orgs) == 0:
- import re
- import uuid
- from gatehouse_app.services.organization_service import OrganizationService
- org_name = f"{user_info.get('name') or user.email.split('@')[0]}'s Workspace"
- # Build a URL-safe slug and ensure uniqueness with a short suffix
- base_slug = re.sub(r"[^a-z0-9]+", "-", org_name.lower()).strip("-")[:40]
- slug = f"{base_slug}-{uuid.uuid4().hex[:6]}"
- org = OrganizationService.create_organization(
- name=org_name,
- slug=slug,
- owner_user_id=user.id,
- )
- target_org = org
+ from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
+ from gatehouse_app.services.auth_service import AuthService as _AS
+ _now = datetime.now(timezone.utc)
+ _session = _AS.create_session(user=user, is_compliance_only=False)
+ _session_dict = _session.to_dict()
+ _session_dict["token"] = _session.token
+ _expires_at = _session.expires_at
+ if _expires_at.tzinfo is None:
+ _expires_at = _expires_at.replace(tzinfo=timezone.utc)
+ _session_dict["expires_in"] = int((_expires_at - _now).total_seconds())
+
+ _pending = OrgInviteToken.query.filter(
+ OrgInviteToken.email == user.email,
+ OrgInviteToken.accepted_at.is_(None),
+ OrgInviteToken.expires_at > _now,
+ OrgInviteToken.deleted_at.is_(None),
+ ).all()
+ _pending_list = [
+ {
+ "token": inv.token,
+ "organization": {
+ "id": str(inv.organization_id),
+ "name": inv.organization.name,
+ },
+ "role": inv.role,
+ "expires_at": inv.expires_at.isoformat(),
+ }
+ for inv in _pending
+ ]
+
+ state_record.mark_used()
logger.info(
- f"OAuth login: auto-created org '{org.name}' (id={org.id}) "
- f"for new user {user.id}"
+ f"OAuth login: user {user.id} has no org, redirecting to org-setup "
+ f"(pending_invites={len(_pending_list)})"
)
+ return {
+ "success": True,
+ "flow_type": "login",
+ "requires_org_creation": True,
+ "user": {"id": user.id, "email": user.email, "full_name": user.full_name},
+ "session": _session_dict,
+ "pending_invites": _pending_list,
+ "state": state_record.state,
+ }
# Priority 4: Multiple orgs — need user to pick one
if not target_org:
@@ -755,7 +803,7 @@ class OAuthFlowService:
# If organization_id hint was provided and valid, create session for that org
if state_record.organization_id:
- from gatehouse_app.models.organization import Organization
+ from gatehouse_app.models.organization.organization import Organization
org = Organization.query.get(state_record.organization_id)
if org:
from gatehouse_app.services.auth_service import AuthService
@@ -784,7 +832,40 @@ class OAuthFlowService:
"session": session_dict,
}
- # No organization hint or invalid - need to create/select org
+ # No organization hint or invalid - need to create/select org.
+ # Still create a session so the frontend can call /organizations
+ # and /invites after redirecting to /org-setup.
+ from gatehouse_app.services.auth_service import AuthService as _AS
+ from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
+ _session = _AS.create_session(user=user, is_compliance_only=False)
+ _session_dict = _session.to_dict()
+ _session_dict["token"] = _session.token
+ _expires_at = _session.expires_at
+ if _expires_at.tzinfo is None:
+ _expires_at = _expires_at.replace(tzinfo=timezone.utc)
+ _now = datetime.now(timezone.utc)
+ _session_dict["expires_in"] = int((_expires_at - _now).total_seconds())
+
+ # Surface pending invitations so the UI can offer "join vs create"
+ _pending = OrgInviteToken.query.filter(
+ OrgInviteToken.email == user.email,
+ OrgInviteToken.accepted_at.is_(None),
+ OrgInviteToken.expires_at > _now,
+ OrgInviteToken.deleted_at.is_(None),
+ ).all()
+ _pending_list = [
+ {
+ "token": inv.token,
+ "organization": {
+ "id": str(inv.organization_id),
+ "name": inv.organization.name,
+ },
+ "role": inv.role,
+ "expires_at": inv.expires_at.isoformat(),
+ }
+ for inv in _pending
+ ]
+
return {
"success": True,
"flow_type": "register",
@@ -794,6 +875,8 @@ class OAuthFlowService:
"email": user.email,
"full_name": user.full_name,
},
+ "session": _session_dict,
+ "pending_invites": _pending_list,
"state": state_record.state,
}
@@ -966,8 +1049,8 @@ class OAuthFlowService:
)
# Determine organization
- from gatehouse_app.models.organization import Organization
- from gatehouse_app.models.organization_member import OrganizationMember
+ from gatehouse_app.models.organization.organization import Organization
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
# Get user's organizations
user_orgs = user.get_organizations()
diff --git a/gatehouse_app/services/oidc_jwks_service.py b/gatehouse_app/services/oidc_jwks_service.py
index c8423ef..269dc08 100644
--- a/gatehouse_app/services/oidc_jwks_service.py
+++ b/gatehouse_app/services/oidc_jwks_service.py
@@ -8,7 +8,7 @@ from typing import Dict, List, Optional, Tuple
from flask import current_app
from gatehouse_app.extensions import db
-from gatehouse_app.models.oidc_jwks_key import OidcJwksKey
+from gatehouse_app.models.oidc.oidc_jwks_key import OidcJwksKey
class JWKSKey:
diff --git a/gatehouse_app/services/oidc_service.py b/gatehouse_app/services/oidc_service.py
index 6157c21..fc6e317 100644
--- a/gatehouse_app/services/oidc_service.py
+++ b/gatehouse_app/services/oidc_service.py
@@ -14,7 +14,7 @@ from gatehouse_app.models import (
User, OIDCClient, OIDCAuthCode, OIDCRefreshToken,
OIDCSession, OIDCTokenMetadata
)
-from gatehouse_app.models.organization_member import OrganizationMember
+from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.exceptions.validation_exceptions import (
ValidationError, NotFoundError, BadRequestError
)
diff --git a/gatehouse_app/services/oidc_token_service.py b/gatehouse_app/services/oidc_token_service.py
index 3c1c929..5605d5f 100644
--- a/gatehouse_app/services/oidc_token_service.py
+++ b/gatehouse_app/services/oidc_token_service.py
@@ -11,7 +11,7 @@ import jwt
from flask import current_app, g
from gatehouse_app.models import User, OIDCClient
-from gatehouse_app.models.organization_member import OrganizationMember
+from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.services.oidc_jwks_service import OIDCJWKSService
logger = logging.getLogger(__name__)
diff --git a/gatehouse_app/services/organization_service.py b/gatehouse_app/services/organization_service.py
index e802b84..27ec7f4 100644
--- a/gatehouse_app/services/organization_service.py
+++ b/gatehouse_app/services/organization_service.py
@@ -3,8 +3,8 @@ import logging
from datetime import datetime, timezone
from flask import current_app
from gatehouse_app.extensions import db
-from gatehouse_app.models.organization import Organization
-from gatehouse_app.models.organization_member import OrganizationMember
+from gatehouse_app.models.organization.organization import Organization
+from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.exceptions.validation_exceptions import OrganizationNotFoundError, ConflictError
from gatehouse_app.utils.constants import OrganizationRole, AuditAction
from gatehouse_app.services.audit_service import AuditService
@@ -188,11 +188,10 @@ class OrganizationService:
Raises:
ConflictError: If user is already a member
"""
- # Check if already a member
+ # Check if already a member (active or soft-deleted — both blocked by DB unique constraint)
existing = OrganizationMember.query.filter_by(
user_id=user_id,
organization_id=org.id,
- deleted_at=None,
).first()
# Development-only debug logging for membership validation
@@ -200,6 +199,25 @@ class OrganizationService:
logger.debug(f"[Org] Member check: org_id={org.id}, user_id={user_id}, already_member={existing is not None}")
if existing:
+ if existing.deleted_at is not None:
+ # Reactivate the soft-deleted membership with the new role
+ existing.deleted_at = None
+ existing.role = role
+ existing.invited_by_id = inviter_id
+ existing.invited_at = datetime.now(timezone.utc)
+ existing.joined_at = datetime.now(timezone.utc)
+ existing.save()
+
+ AuditService.log_action(
+ action=AuditAction.ORG_MEMBER_ADD,
+ user_id=inviter_id,
+ organization_id=org.id,
+ resource_type="organization_member",
+ resource_id=existing.id,
+ metadata={"added_user_id": user_id, "role": role.value},
+ description=f"Member re-added to organization with role: {role.value}",
+ )
+ return existing
raise ConflictError("User is already a member of this organization")
# Create membership
diff --git a/gatehouse_app/services/session_service.py b/gatehouse_app/services/session_service.py
index 7186222..7103285 100644
--- a/gatehouse_app/services/session_service.py
+++ b/gatehouse_app/services/session_service.py
@@ -1,6 +1,6 @@
"""Session service."""
from datetime import datetime, timezone
-from gatehouse_app.models.session import Session
+from gatehouse_app.models.user.session import Session
from gatehouse_app.utils.constants import SessionStatus
@@ -17,7 +17,7 @@ class SessionService:
Returns:
Session object if found and active, None otherwise
"""
- from gatehouse_app.models.session import Session
+ from gatehouse_app.models.user.session import Session
from gatehouse_app.utils.constants import SessionStatus
return Session.query.filter_by(
token=token,
diff --git a/gatehouse_app/services/ssh_ca_signing_service.py b/gatehouse_app/services/ssh_ca_signing_service.py
new file mode 100644
index 0000000..1502ab8
--- /dev/null
+++ b/gatehouse_app/services/ssh_ca_signing_service.py
@@ -0,0 +1,359 @@
+"""SSH Certificate Authority signing service.
+
+Handles SSH certificate signing operations, leveraging sshkey-tools library.
+This service is a Gatehouse-integrated version of the secuird/ssh_ca.py logic.
+"""
+import logging
+import os
+from datetime import datetime, timedelta, timezone
+from typing import List, Dict, Any, Optional
+
+from sshkey_tools.cert import SSHCertificate, CertificateFields
+from sshkey_tools.keys import PublicKey, PrivateKey
+
+from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config
+from gatehouse_app.exceptions import SSHCAError, ValidationError
+from gatehouse_app.utils.crypto import compute_ssh_fingerprint
+
+logger = logging.getLogger(__name__)
+
+
+class SSHCASigningError(Exception):
+ """SSH CA signing operation error."""
+ pass
+
+
+class SSHCertificateSigningRequest:
+ """Represents an SSH certificate signing request."""
+
+ def __init__(
+ self,
+ ssh_public_key: str,
+ principals: List[str],
+ key_id: str,
+ cert_type: str = "user",
+ expiry_hours: Optional[int] = None,
+ critical_options: Optional[Dict[str, str]] = None,
+ extensions: Optional[List[str]] = None,
+ ):
+ """Initialize signing request.
+
+ Args:
+ ssh_public_key: Public key in OpenSSH format (e.g., "ssh-ed25519 AAAA...")
+ principals: List of principals (e.g., ["prod-servers", "staging"])
+ key_id: Key identifier (usually user email)
+ cert_type: Certificate type - "user" or "host" (default: user)
+ expiry_hours: Certificate validity in hours
+ critical_options: Critical options dict
+ extensions: List of extensions (e.g., ["permit-pty", "permit-agent-forwarding"])
+ """
+ self.ssh_public_key = ssh_public_key
+ self.principals = principals or []
+ self.key_id = key_id
+ self.cert_type = cert_type
+ self.expiry_hours = expiry_hours
+ self.critical_options = critical_options or {}
+ self.extensions = extensions or []
+
+ def validate(self) -> List[str]:
+ """Validate the signing request.
+
+ Returns:
+ List of validation errors (empty if valid)
+ """
+ errors = []
+ config = get_ssh_ca_config()
+
+ # Validate cert type
+ if self.cert_type not in ("user", "host"):
+ errors.append(f"Invalid cert_type: {self.cert_type}. Must be 'user' or 'host'")
+
+ # Validate SSH public key
+ if not self.ssh_public_key or len(self.ssh_public_key) < 16:
+ errors.append("SSH public key is missing or invalid")
+ else:
+ try:
+ PublicKey.from_string(self.ssh_public_key)
+ except Exception as e:
+ errors.append(f"SSH public key is not valid: {str(e)}")
+
+ # Validate principals
+ if not self.principals or len(self.principals) == 0:
+ errors.append("At least one principal is required")
+ else:
+ max_principals = config.get_int('max_principals_per_cert')
+ if len(self.principals) > max_principals:
+ errors.append(
+ f"Too many principals ({len(self.principals)}). "
+ f"Maximum is {max_principals}"
+ )
+
+ # Validate key_id
+ if not self.key_id or len(self.key_id) < 5:
+ errors.append("key_id is missing or too short (minimum 5 characters)")
+ else:
+ max_id_len = config.get_int('max_key_id_length')
+ if len(self.key_id) > max_id_len:
+ errors.append(f"key_id exceeds maximum length of {max_id_len}")
+
+ # Validate expiry_hours
+ if self.expiry_hours is not None:
+ if not isinstance(self.expiry_hours, int) or self.expiry_hours <= 0:
+ errors.append("expiry_hours must be a positive integer")
+ else:
+ max_validity = config.get_int('max_cert_validity_hours')
+ if self.expiry_hours > max_validity:
+ errors.append(
+ f"Requested expiry ({self.expiry_hours}h) exceeds "
+ f"maximum allowed ({max_validity}h)"
+ )
+
+ return errors
+
+
+class SSHCertificateSigningResponse:
+ """Represents a signed SSH certificate response."""
+
+ def __init__(
+ self,
+ certificate: str,
+ serial: str,
+ valid_after: datetime,
+ valid_before: datetime,
+ principals: Optional[List[str]] = None,
+ ):
+ """Initialize signing response.
+
+ Args:
+ certificate: Full certificate in OpenSSH format
+ serial: Certificate serial number
+ valid_after: Validity start datetime
+ valid_before: Validity end datetime
+ principals: List of principals the cert was issued for
+ """
+ self.certificate = certificate
+ self.serial = serial
+ self.valid_after = valid_after
+ self.valid_before = valid_before
+ self.principals = principals or []
+
+ def to_dict(self) -> Dict[str, Any]:
+ """Convert response to dictionary."""
+ return {
+ 'certificate': self.certificate,
+ 'serial': self.serial,
+ 'valid_after': self.valid_after.isoformat(),
+ 'valid_before': self.valid_before.isoformat(),
+ }
+
+
+class SSHCASigningService:
+ """Service for signing SSH certificates.
+
+ This service handles all SSH certificate signing operations.
+ It uses configuration from ssh_ca_config to apply rules and limits.
+ """
+
+ def __init__(self):
+ """Initialize the SSH CA signing service."""
+ self.config = get_ssh_ca_config()
+ self.logger = logger
+
+ def _load_ca_key_from_config(self) -> str:
+ """Load CA private key from config (local file or env var).
+
+ Returns:
+ CA private key in PEM/OpenSSH format as string
+
+ Raises:
+ SSHCASigningError: If key cannot be loaded
+ """
+ # Check env var first
+ key_content = os.environ.get('SSH_CA_PRIVATE_KEY')
+ if key_content:
+ return key_content
+
+ # Load from file path
+ key_path = self.config.get_str('ca_key_path', '').strip()
+ if not key_path:
+ raise SSHCASigningError(
+ "CA private key not configured. Set SSH_CA_PRIVATE_KEY env var "
+ "or ca_key_path in etc/ssh_ca.conf"
+ )
+
+ key_path = os.path.expandvars(os.path.expanduser(key_path))
+ if not os.path.exists(key_path):
+ raise SSHCASigningError(f"CA private key file not found: {key_path}")
+
+ with open(key_path, 'r') as f:
+ return f.read()
+
+ def sign_certificate(
+ self,
+ signing_request: SSHCertificateSigningRequest,
+ ca_private_key: Optional[str] = None,
+ ca_obj=None,
+ ) -> "SSHCertificateSigningResponse":
+ """Sign an SSH certificate.
+
+ Args:
+ signing_request: SSHCertificateSigningRequest instance
+ ca_private_key: CA private key in PEM format. If not provided,
+ loaded from config (ca_key_path or SSH_CA_PRIVATE_KEY env var)
+ ca_obj: Optional CA model instance. When supplied its monotonic
+ serial counter is incremented atomically (SELECT FOR UPDATE)
+ and the resulting integer is embedded in the certificate's
+ serial field. This ensures every issued cert has a unique,
+ ordered, auditable serial number.
+
+ Returns:
+ SSHCertificateSigningResponse with signed certificate
+
+ Raises:
+ SSHCASigningError: If signing fails
+ ValidationError: If request is invalid
+ """
+ # Validate request
+ errors = signing_request.validate()
+ if errors:
+ error_msg = "; ".join(errors)
+ self.logger.error(f"Certificate signing validation failed: {error_msg}")
+ raise ValidationError(f"Certificate signing validation failed: {error_msg}")
+
+ # Load CA key if not provided
+ if ca_private_key is None:
+ ca_private_key = self._load_ca_key_from_config()
+
+ try:
+ # Parse CA private key
+ try:
+ ca_key = PrivateKey.from_string(ca_private_key)
+ except Exception as e:
+ self.logger.error(f"Failed to load CA private key: {str(e)}")
+ raise SSHCASigningError(f"Invalid CA private key: {str(e)}")
+
+ # Parse user's public key
+ try:
+ user_pub_key = PublicKey.from_string(signing_request.ssh_public_key)
+ except Exception as e:
+ self.logger.error(f"Failed to parse user public key: {str(e)}")
+ raise SSHCASigningError(f"Invalid user public key: {str(e)}")
+
+ # Create certificate
+ certificate = SSHCertificate.create(
+ subject_pubkey=user_pub_key,
+ ca_privkey=ca_key,
+ )
+
+ # Set validity period
+ now = datetime.now(timezone.utc)
+ expiry_hours = signing_request.expiry_hours or self.config.get_int('cert_validity_hours')
+ valid_before = now + timedelta(hours=expiry_hours)
+
+ # Set certificate fields
+ # sshkey-tools: user=1, host=2 (not 0)
+ cert_type = 1 if signing_request.cert_type == "user" else 2
+
+ certificate.fields.cert_type = cert_type
+ certificate.fields.key_id = signing_request.key_id
+ certificate.fields.principals = signing_request.principals
+ certificate.fields.valid_after = now
+ certificate.fields.valid_before = valid_before
+
+ # ── Serial number ────────────────────────────────────────────────
+ # If a CA object is provided, use its monotonic counter so every
+ # certificate gets a unique, ordered, auditable serial. The
+ # counter increment is flushed inside get_next_serial(); the
+ # caller's commit() persists it atomically with the cert record.
+ if ca_obj is not None:
+ assigned_serial = ca_obj.get_next_serial()
+ certificate.fields.serial = assigned_serial
+ self.logger.debug(
+ f"Assigned serial {assigned_serial} from CA {ca_obj.id}"
+ )
+ # ─────────────────────────────────────────────────────────────────
+
+ # Set extensions — prefer policy-provided list, fall back to standard set
+ extensions = signing_request.extensions
+ if not extensions:
+ from gatehouse_app.models.organization.department_cert_policy import STANDARD_EXTENSIONS
+ extensions = list(STANDARD_EXTENSIONS)
+
+ certificate.fields.extensions = extensions
+ certificate.fields.critical_options = signing_request.critical_options or {}
+
+ # Validate certificate before signing
+ if not certificate.can_sign():
+ raise SSHCASigningError("Certificate cannot be signed")
+
+ # Sign the certificate
+ certificate.sign()
+
+ # Verify the certificate
+ try:
+ certificate.verify(ca_key.public_key, raise_on_error=True)
+ except Exception as e:
+ self.logger.error(f"Certificate verification failed: {str(e)}")
+ raise SSHCASigningError(f"Certificate verification failed: {str(e)}")
+
+ # Extract serial from certificate — use the integer we assigned
+ # when ca_obj was provided, otherwise fall back to whatever the
+ # library generated.
+ if ca_obj is not None:
+ serial = str(assigned_serial)
+ else:
+ serial = str(certificate.fields.serial).split(":")[-1].strip() if hasattr(certificate.fields.serial, '__str__') else str(certificate.fields.serial)
+
+ # Build response
+ cert_string = certificate.to_string()
+
+ self.logger.info(
+ f"Successfully signed certificate: serial={serial}, "
+ f"key_id={signing_request.key_id}, principals={signing_request.principals}"
+ )
+
+ return SSHCertificateSigningResponse(
+ certificate=cert_string,
+ serial=serial,
+ valid_after=now,
+ valid_before=valid_before,
+ principals=signing_request.principals,
+ )
+
+ except (SSHCASigningError, ValidationError):
+ raise
+ except Exception as e:
+ self.logger.error(f"Unexpected error during certificate signing: {str(e)}", exc_info=True)
+ raise SSHCASigningError(f"Error signing certificate: {str(e)}")
+
+ def verify_ca_key(self, ca_private_key: str) -> Dict[str, Any]:
+ """Verify a CA private key is valid and extract metadata.
+
+ Args:
+ ca_private_key: CA private key in PEM format
+
+ Returns:
+ Dictionary with key metadata (fingerprint, key_type, etc.)
+
+ Raises:
+ SSHCASigningError: If key is invalid
+ """
+ try:
+ ca_key = PrivateKey.from_string(ca_private_key)
+ pub_key = ca_key.public_key
+
+ # Compute fingerprint
+ fingerprint = compute_ssh_fingerprint(pub_key.to_string())
+
+ # Get key type
+ key_type = pub_key.keytype if hasattr(pub_key, 'keytype') else 'unknown'
+
+ return {
+ 'fingerprint': fingerprint,
+ 'key_type': key_type,
+ 'public_key': pub_key.to_string(),
+ 'valid': True,
+ }
+ except Exception as e:
+ self.logger.error(f"CA key verification failed: {str(e)}")
+ raise SSHCASigningError(f"Invalid CA key: {str(e)}")
diff --git a/gatehouse_app/services/ssh_key_service.py b/gatehouse_app/services/ssh_key_service.py
new file mode 100644
index 0000000..db07b0b
--- /dev/null
+++ b/gatehouse_app/services/ssh_key_service.py
@@ -0,0 +1,375 @@
+"""SSH Key management service."""
+import base64
+import logging
+import os
+import secrets
+import subprocess
+import tempfile
+from datetime import datetime, timedelta
+from typing import Optional, List, Dict, Any
+
+from gatehouse_app.extensions import db
+from gatehouse_app.models import SSHKey, User
+from gatehouse_app.exceptions import (
+ SSHKeyError,
+ SSHKeyNotFoundError,
+ SSHKeyAlreadyExistsError,
+ SSHKeyNotVerifiedError,
+ ValidationError,
+ UserNotFoundError,
+)
+from gatehouse_app.utils.crypto import (
+ compute_ssh_fingerprint,
+ verify_ssh_key_format,
+ extract_ssh_key_type,
+ extract_ssh_key_comment,
+)
+from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config
+
+logger = logging.getLogger(__name__)
+
+
+class SSHKeyService:
+ """Service for managing SSH keys."""
+
+ def __init__(self):
+ """Initialize SSH key service."""
+ self.config = get_ssh_ca_config()
+
+ def add_ssh_key(
+ self,
+ user_id: str,
+ public_key: str,
+ description: Optional[str] = None,
+ ) -> SSHKey:
+ """Add an SSH public key for a user.
+
+ Args:
+ user_id: ID of the user
+ public_key: SSH public key in OpenSSH format
+ description: Optional description of the key
+
+ Returns:
+ Created SSHKey instance
+
+ Raises:
+ UserNotFoundError: If user doesn't exist
+ SSHKeyError: If key format is invalid
+ SSHKeyAlreadyExistsError: If key already exists
+ """
+ # Verify user exists
+ user = User.query.get(user_id)
+ if not user:
+ raise UserNotFoundError(f"User {user_id} not found")
+
+ # Validate key format
+ if not verify_ssh_key_format(public_key):
+ raise SSHKeyError("Invalid SSH public key format")
+
+ # Compute fingerprint
+ try:
+ fingerprint = compute_ssh_fingerprint(public_key)
+ except Exception as e:
+ logger.error(f"Failed to compute fingerprint: {str(e)}")
+ raise SSHKeyError(f"Failed to compute key fingerprint: {str(e)}")
+
+ # Check for duplicate (including soft-deleted records — fingerprint is unique in DB)
+ existing = SSHKey.query.filter_by(fingerprint=fingerprint).first()
+ if existing:
+ if existing.deleted_at is not None:
+ # Restore the soft-deleted key: clear deleted_at and update fields
+ existing.deleted_at = None
+ existing.user_id = user_id
+ existing.description = description or existing.description
+ existing.verified = False
+ existing.verified_at = None
+ existing.verify_text = None
+ existing.verify_text_created_at = None
+ db.session.commit()
+ logger.info(
+ f"Restored soft-deleted SSH key for user {user_id}: "
+ f"fingerprint={fingerprint}"
+ )
+ return existing
+ raise SSHKeyAlreadyExistsError(
+ f"SSH key with fingerprint {fingerprint} already exists"
+ )
+
+ # Extract metadata
+ key_type = extract_ssh_key_type(public_key)
+ key_comment = extract_ssh_key_comment(public_key)
+
+ # Create SSH key record
+ ssh_key = SSHKey(
+ user_id=user_id,
+ payload=public_key,
+ fingerprint=fingerprint,
+ description=description,
+ key_type=key_type,
+ key_comment=key_comment,
+ verified=False,
+ )
+
+ ssh_key.save()
+
+ logger.info(
+ f"SSH key added for user {user_id}: "
+ f"fingerprint={fingerprint}, type={key_type}"
+ )
+
+ return ssh_key
+
+ def get_ssh_key(self, key_id: str) -> SSHKey:
+ """Get an SSH key by ID.
+
+ Args:
+ key_id: SSH key ID
+
+ Returns:
+ SSHKey instance
+
+ Raises:
+ SSHKeyNotFoundError: If key not found
+ """
+ key = SSHKey.query.filter_by(id=key_id, deleted_at=None).first()
+ if not key:
+ raise SSHKeyNotFoundError(f"SSH key {key_id} not found")
+ return key
+
+ def get_user_ssh_keys(self, user_id: str) -> List[SSHKey]:
+ """Get all SSH keys for a user.
+
+ Args:
+ user_id: User ID
+
+ Returns:
+ List of SSHKey instances
+ """
+ return SSHKey.query.filter_by(user_id=user_id, deleted_at=None).all()
+
+ def get_user_verified_ssh_keys(self, user_id: str) -> List[SSHKey]:
+ """Get all verified SSH keys for a user.
+
+ Args:
+ user_id: User ID
+
+ Returns:
+ List of verified SSHKey instances
+ """
+ return SSHKey.query.filter_by(
+ user_id=user_id,
+ verified=True,
+ deleted_at=None,
+ ).all()
+
+ def delete_ssh_key(self, key_id: str) -> None:
+ """Soft-delete an SSH key.
+
+ Args:
+ key_id: SSH key ID
+
+ Raises:
+ SSHKeyNotFoundError: If key not found
+ """
+ key = self.get_ssh_key(key_id)
+ key.delete()
+
+ logger.info(f"SSH key deleted: {key_id}")
+
+ def generate_verification_challenge(self, key_id: str) -> str:
+ """Generate a verification challenge for an SSH key.
+
+ The user must sign this challenge text with their private key
+ to prove key ownership.
+
+ Args:
+ key_id: SSH key ID
+
+ Returns:
+ Verification challenge text
+
+ Raises:
+ SSHKeyNotFoundError: If key not found
+ """
+ key = self.get_ssh_key(key_id)
+
+ # Generate random challenge
+ challenge = secrets.token_hex(32)
+ challenge_text = f"Please sign this to verify SSH key ownership: {challenge}"
+
+ # Store challenge
+ key.verify_text = challenge_text
+ key.verify_text_created_at = datetime.utcnow()
+ key.save()
+
+ logger.info(f"Generated verification challenge for SSH key {key_id}")
+
+ return challenge_text
+
+ def verify_ssh_key_ownership(
+ self,
+ key_id: str,
+ signature: str,
+ ) -> bool:
+ """Verify SSH key ownership via signature.
+
+ The user must sign the verification challenge with their private key.
+ We verify the signature using the public key.
+
+ Args:
+ key_id: SSH key ID
+ signature: Base64-encoded signature of the challenge
+
+ Returns:
+ True if signature is valid
+
+ Raises:
+ SSHKeyNotFoundError: If key not found
+ SSHKeyNotVerifiedError: If challenge is stale or missing
+ SSHKeyError: If verification fails
+ """
+ key = self.get_ssh_key(key_id)
+
+ # Check if challenge exists and is not stale
+ if not key.verify_text or not key.verify_text_created_at:
+ raise SSHKeyNotVerifiedError("No verification challenge generated")
+
+ max_age = self.config.get_int('verification_challenge_max_age')
+ age = datetime.utcnow() - key.verify_text_created_at
+ if age.total_seconds() > (max_age * 3600):
+ raise SSHKeyNotVerifiedError("Verification challenge has expired")
+
+ try:
+ # Verify the SSH signature using ssh-keygen -Y verify.
+ # The CLI signs the challenge with: ssh-keygen -Y sign -f -n file
+ # We verify with: ssh-keygen -Y verify -f -I -n file -s <
+ #
+ # allowed_signers format: " "
+ # We use the key fingerprint as the identity.
+
+ sig_bytes = base64.b64decode(signature)
+ challenge_text = key.verify_text + "\n"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ allowed_signers_path = os.path.join(tmpdir, "allowed_signers")
+ sig_path = os.path.join(tmpdir, "message.sig")
+ message_path = os.path.join(tmpdir, "message.txt")
+
+ identity = key.fingerprint
+
+ # Write the allowed_signers file
+ with open(allowed_signers_path, "w") as f:
+ f.write(f"{identity} {key.payload}\n")
+
+ # Write the signature file
+ with open(sig_path, "wb") as f:
+ f.write(sig_bytes)
+
+ # Write the challenge message
+ with open(message_path, "w") as f:
+ f.write(challenge_text)
+
+ result = subprocess.run(
+ [
+ "ssh-keygen", "-Y", "verify",
+ "-f", allowed_signers_path,
+ "-I", identity,
+ "-n", "file",
+ "-s", sig_path,
+ ],
+ stdin=open(message_path, "rb"),
+ capture_output=True,
+ timeout=10,
+ )
+
+ if result.returncode != 0:
+ stderr = result.stderr.decode(errors="replace").strip()
+ logger.warning(f"SSH signature verification failed for key {key_id}: {stderr}")
+ raise SSHKeyError(f"Signature verification failed: {stderr}")
+
+ key.mark_verified()
+ logger.info(f"SSH key verified: {key_id}")
+ return True
+
+ except SSHKeyError:
+ raise
+ except Exception as e:
+ logger.error(f"SSH key verification failed: {str(e)}")
+ raise SSHKeyError(f"Signature verification failed: {str(e)}")
+
+ def get_key_fingerprint(self, key_id: str) -> str:
+ """Get the fingerprint of an SSH key.
+
+ Args:
+ key_id: SSH key ID
+
+ Returns:
+ Fingerprint string
+
+ Raises:
+ SSHKeyNotFoundError: If key not found
+ """
+ key = self.get_ssh_key(key_id)
+ return key.fingerprint
+
+ def update_ssh_key_description(self, key_id: str, description: str) -> SSHKey:
+ """Update the description of an SSH key.
+
+ Args:
+ key_id: SSH key ID
+ description: New description
+
+ Returns:
+ Updated SSHKey instance
+
+ Raises:
+ SSHKeyNotFoundError: If key not found
+ """
+ key = self.get_ssh_key(key_id)
+ key.description = description
+ key.save()
+
+ return key
+
+ def cleanup_expired_challenges(self) -> int:
+ """Clean up expired verification challenges.
+
+ Returns:
+ Number of challenges cleaned
+ """
+ max_age = self.config.get_int('verification_challenge_max_age')
+ threshold = datetime.utcnow() - timedelta(hours=max_age)
+
+ expired = SSHKey.query.filter(
+ SSHKey.verify_text_created_at < threshold,
+ SSHKey.verify_text_created_at.isnot(None),
+ SSHKey.deleted_at.is_(None),
+ ).update({"verify_text": None, "verify_text_created_at": None})
+
+ db.session.commit()
+
+ logger.info(f"Cleaned up {expired} expired verification challenges")
+ return expired
+
+ def cleanup_unverified_keys(self) -> int:
+ """Delete unverified SSH keys older than configured days.
+
+ Returns:
+ Number of keys deleted
+ """
+ days = self.config.get_int('auto_delete_unverified_days')
+ threshold = datetime.utcnow() - timedelta(days=days)
+
+ old_unverified = SSHKey.query.filter(
+ SSHKey.verified == False,
+ SSHKey.created_at < threshold,
+ SSHKey.deleted_at.is_(None),
+ ).all()
+
+ count = 0
+ for key in old_unverified:
+ key.delete()
+ count += 1
+
+ logger.info(f"Deleted {count} unverified SSH keys older than {days} days")
+ return count
diff --git a/gatehouse_app/services/totp_service.py b/gatehouse_app/services/totp_service.py
index b71f56a..c667e3e 100644
--- a/gatehouse_app/services/totp_service.py
+++ b/gatehouse_app/services/totp_service.py
@@ -11,10 +11,51 @@ from gatehouse_app.extensions import bcrypt
logger = logging.getLogger(__name__)
+# TOTP codes are valid for at most (2*window + 1) * 30s steps.
+# With window=1 that's 3 steps = 90 seconds. We use a slightly
+# generous TTL of 95 seconds to account for clock skew at boundaries.
+_TOTP_USED_CODE_TTL = 95
+
class TOTPService:
"""Service for TOTP operations."""
+ # ------------------------------------------------------------------
+ # Replay-attack prevention helpers
+ # ------------------------------------------------------------------
+
+ @staticmethod
+ def _used_key(user_id: str, code: str) -> str:
+ return f"totp:used:{user_id}:{code}"
+
+ @staticmethod
+ def is_code_already_used(user_id: str, code: str) -> bool:
+ """Return True if *code* has already been accepted for *user_id*
+ within the current validity window (prevents replay attacks)."""
+ try:
+ from gatehouse_app.extensions import redis_client
+ if redis_client is None:
+ return False
+ return redis_client.exists(TOTPService._used_key(user_id, code)) == 1
+ except Exception:
+ logger.warning("Redis unavailable for TOTP replay check; allowing code")
+ return False
+
+ @staticmethod
+ def mark_code_used(user_id: str, code: str) -> None:
+ """Record *code* as consumed for *user_id* so it cannot be reused."""
+ try:
+ from gatehouse_app.extensions import redis_client
+ if redis_client is None:
+ return
+ redis_client.setex(
+ TOTPService._used_key(user_id, code),
+ _TOTP_USED_CODE_TTL,
+ "1",
+ )
+ except Exception:
+ logger.warning("Redis unavailable; TOTP used-code not recorded")
+
@staticmethod
def generate_secret() -> str:
"""
diff --git a/gatehouse_app/services/user_service.py b/gatehouse_app/services/user_service.py
index 94845d1..b8b5fc9 100644
--- a/gatehouse_app/services/user_service.py
+++ b/gatehouse_app/services/user_service.py
@@ -2,7 +2,7 @@
import logging
from flask import current_app
from gatehouse_app.extensions import db
-from gatehouse_app.models.user import User
+from gatehouse_app.models.user.user import User
from gatehouse_app.exceptions.validation_exceptions import UserNotFoundError
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.services.audit_service import AuditService
diff --git a/gatehouse_app/services/webauthn_service.py b/gatehouse_app/services/webauthn_service.py
index 79ea50a..9f953d3 100644
--- a/gatehouse_app/services/webauthn_service.py
+++ b/gatehouse_app/services/webauthn_service.py
@@ -10,8 +10,8 @@ from flask import current_app
from sqlalchemy.orm.attributes import flag_modified
from gatehouse_app.extensions import db, redis_client
-from gatehouse_app.models.user import User
-from gatehouse_app.models.authentication_method import AuthenticationMethod
+from gatehouse_app.models.user.user import User
+from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType, AuditAction
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
from gatehouse_app.services.audit_service import AuditService
@@ -641,6 +641,26 @@ class WebAuthnService:
)
return True
+
+ @classmethod
+ def credential_belongs_to_user(cls, credential_id: str, user: User) -> bool:
+ """Check whether *credential_id* exists and belongs to *user*.
+
+ Args:
+ credential_id: The credential ID to look up
+ user: User instance
+
+ Returns:
+ True if the credential exists and belongs to this user, False otherwise.
+ """
+ auth_method = AuthenticationMethod.query.filter_by(
+ user_id=user.id,
+ method_type=AuthMethodType.WEBAUTHN,
+ deleted_at=None,
+ ).first()
+ if not auth_method or not auth_method.provider_data:
+ return False
+ return auth_method.provider_data.get("credential_id") == credential_id
@classmethod
def rename_credential(cls, credential_id: str, user: User, name: str) -> bool:
diff --git a/gatehouse_app/utils/ca_key_encryption.py b/gatehouse_app/utils/ca_key_encryption.py
new file mode 100644
index 0000000..183e3c3
--- /dev/null
+++ b/gatehouse_app/utils/ca_key_encryption.py
@@ -0,0 +1,206 @@
+"""Encryption helpers for CA private keys stored in the database.
+
+CA private keys are encrypted at rest using Fernet (AES-128-CBC + HMAC-SHA256)
+from the ``cryptography`` package. The encryption key is derived from the
+``CA_ENCRYPTION_KEY`` environment variable (or ``Flask.config["CA_ENCRYPTION_KEY"]``).
+
+Key derivation
+--------------
+Fernet requires a URL-safe base64-encoded 32-byte key. We accept any string
+from the env and derive the actual Fernet key using SHA-256 so that operators
+can supply human-readable secrets without having to pre-encode them.
+
+Envelope format
+---------------
+Encrypted values are stored as the string::
+
+ $fernet$
+
+The ``$fernet$`` prefix lets the code distinguish already-encrypted values from
+legacy plaintext PEM keys so that the migration path is safe and idempotent.
+
+Usage
+-----
+Encrypt before storing::
+
+ from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key
+ ca.private_key = encrypt_ca_key(private_key_pem)
+
+Decrypt before use::
+
+ from gatehouse_app.utils.ca_key_encryption import decrypt_ca_key
+ plaintext_pem = decrypt_ca_key(ca.private_key)
+"""
+import base64
+import hashlib
+import logging
+import os
+
+from cryptography.fernet import Fernet, InvalidToken
+
+logger = logging.getLogger(__name__)
+
+# Prefix that marks a stored value as Fernet-encrypted
+_FERNET_PREFIX = "$fernet$"
+
+
+class CAKeyEncryptionError(Exception):
+ """Raised when CA key encryption or decryption fails."""
+
+
+def _get_fernet() -> Fernet:
+ """Build a Fernet instance from the configured encryption key.
+
+ Looks up ``CA_ENCRYPTION_KEY`` in the environment first, then falls back to
+ the Flask app config (if a request context is active).
+
+ Raises:
+ CAKeyEncryptionError: if no key is configured or it is the insecure
+ placeholder value in a production-like environment.
+ """
+ raw_key = os.environ.get("CA_ENCRYPTION_KEY")
+
+ if not raw_key:
+ # Try Flask config if we're inside an app context
+ try:
+ from flask import current_app
+ raw_key = current_app.config.get("CA_ENCRYPTION_KEY")
+ except RuntimeError:
+ pass # No app context
+
+ if not raw_key:
+ raise CAKeyEncryptionError(
+ "CA_ENCRYPTION_KEY is not set. "
+ "Set this environment variable before starting the application."
+ )
+
+ # Warn loudly when running with the placeholder in a non-test environment
+ env_name = os.environ.get("FLASK_ENV", "").lower()
+ if raw_key.startswith("dev-") and env_name not in ("development", "testing", "test"):
+ logger.warning(
+ "CA_ENCRYPTION_KEY appears to be a development placeholder. "
+ "Set a strong random key for production environments."
+ )
+
+ # Derive a 32-byte key from the raw secret via SHA-256, then URL-safe base64
+ key_bytes = hashlib.sha256(raw_key.encode()).digest()
+ fernet_key = base64.urlsafe_b64encode(key_bytes)
+ return Fernet(fernet_key)
+
+
+def encrypt_ca_key(plaintext_pem: str) -> str:
+ """Encrypt a CA private key PEM string.
+
+ Idempotent: already-encrypted values are returned unchanged.
+
+ Args:
+ plaintext_pem: CA private key in OpenSSH/PEM format.
+
+ Returns:
+ Encrypted string with ``$fernet$`` prefix, safe for database storage.
+
+ Raises:
+ CAKeyEncryptionError: if the key cannot be encrypted.
+ """
+ if not plaintext_pem:
+ raise CAKeyEncryptionError("Cannot encrypt an empty key")
+
+ # Already encrypted — do not double-encrypt
+ if plaintext_pem.startswith(_FERNET_PREFIX):
+ return plaintext_pem
+
+ try:
+ fernet = _get_fernet()
+ token = fernet.encrypt(plaintext_pem.encode()).decode()
+ return f"{_FERNET_PREFIX}{token}"
+ except CAKeyEncryptionError:
+ raise
+ except Exception as exc:
+ raise CAKeyEncryptionError(f"Failed to encrypt CA key: {exc}") from exc
+
+
+def decrypt_ca_key(stored_value: str) -> str:
+ """Decrypt a CA private key retrieved from the database.
+
+ Idempotent: plaintext (legacy) values are returned unchanged so that the
+ system continues to work while a migration encrypts existing rows.
+
+ Args:
+ stored_value: Value from ``CA.private_key`` column.
+
+ Returns:
+ Plaintext PEM string ready for use with ``sshkey_tools``.
+
+ Raises:
+ CAKeyEncryptionError: if decryption fails (wrong key, corrupted data).
+ """
+ if not stored_value:
+ raise CAKeyEncryptionError("Cannot decrypt an empty value")
+
+ # Legacy plaintext key — return as-is
+ if not stored_value.startswith(_FERNET_PREFIX):
+ logger.warning(
+ "CA private key appears to be stored as plaintext. "
+ "Run the migration to encrypt existing keys."
+ )
+ return stored_value
+
+ token = stored_value[len(_FERNET_PREFIX):]
+ try:
+ fernet = _get_fernet()
+ return fernet.decrypt(token.encode()).decode()
+ except InvalidToken as exc:
+ raise CAKeyEncryptionError(
+ "CA key decryption failed — the CA_ENCRYPTION_KEY may be incorrect "
+ "or the stored key is corrupted."
+ ) from exc
+ except CAKeyEncryptionError:
+ raise
+ except Exception as exc:
+ raise CAKeyEncryptionError(f"Unexpected decryption error: {exc}") from exc
+
+
+def is_encrypted(stored_value: str) -> bool:
+ """Return True if the stored value has the ``$fernet$`` envelope.
+
+ Args:
+ stored_value: Value from ``CA.private_key`` column.
+ """
+ return bool(stored_value and stored_value.startswith(_FERNET_PREFIX))
+
+
+def reencrypt_ca_key(stored_value: str, old_raw_key: str, new_raw_key: str) -> str:
+ """Re-encrypt a CA key with a new encryption key (for key rotation).
+
+ Args:
+ stored_value: Current value from ``CA.private_key`` (may or may not be encrypted).
+ old_raw_key: The current ``CA_ENCRYPTION_KEY`` value (raw secret string).
+ new_raw_key: The new ``CA_ENCRYPTION_KEY`` value to encrypt with.
+
+ Returns:
+ New encrypted envelope string.
+
+ Raises:
+ CAKeyEncryptionError: if decryption or re-encryption fails.
+ """
+ # Decrypt with old key
+ if stored_value.startswith(_FERNET_PREFIX):
+ token = stored_value[len(_FERNET_PREFIX):]
+ old_key_bytes = base64.urlsafe_b64encode(hashlib.sha256(old_raw_key.encode()).digest())
+ try:
+ plaintext = Fernet(old_key_bytes).decrypt(token.encode()).decode()
+ except InvalidToken as exc:
+ raise CAKeyEncryptionError(
+ "Re-encryption failed: could not decrypt with the old key."
+ ) from exc
+ else:
+ # Plaintext
+ plaintext = stored_value
+
+ # Re-encrypt with new key
+ new_key_bytes = base64.urlsafe_b64encode(hashlib.sha256(new_raw_key.encode()).digest())
+ try:
+ token = Fernet(new_key_bytes).encrypt(plaintext.encode()).decode()
+ return f"{_FERNET_PREFIX}{token}"
+ except Exception as exc:
+ raise CAKeyEncryptionError(f"Re-encryption with new key failed: {exc}") from exc
diff --git a/gatehouse_app/utils/constants.py b/gatehouse_app/utils/constants.py
index cd6a58b..2a99825 100644
--- a/gatehouse_app/utils/constants.py
+++ b/gatehouse_app/utils/constants.py
@@ -12,6 +12,16 @@ class UserStatus(str, Enum):
COMPLIANCE_SUSPENDED = "compliance_suspended"
+class Role(str, Enum):
+ """Generic role definitions (hierarchy: Admin > Manager > Member > Viewer > Guest)."""
+
+ ADMIN = "admin"
+ MANAGER = "manager"
+ MEMBER = "member"
+ VIEWER = "viewer"
+ GUEST = "guest"
+
+
class OrganizationRole(str, Enum):
"""Organization member roles."""
@@ -51,6 +61,9 @@ class AuditAction(str, Enum):
USER_REGISTER = "user.register"
USER_UPDATE = "user.update"
USER_DELETE = "user.delete"
+ USER_HARD_DELETE = "user.hard_delete"
+ USER_SUSPEND = "user.suspend"
+ USER_UNSUSPEND = "user.unsuspend"
PASSWORD_CHANGE = "user.password_change"
PASSWORD_RESET = "user.password_reset"
@@ -61,6 +74,7 @@ class AuditAction(str, Enum):
ORG_MEMBER_ADD = "org.member.add"
ORG_MEMBER_REMOVE = "org.member.remove"
ORG_MEMBER_ROLE_CHANGE = "org.member.role_change"
+ ORG_OWNERSHIP_TRANSFERRED = "org.ownership.transferred"
# Session actions
SESSION_CREATE = "session.create"
@@ -105,6 +119,37 @@ class AuditAction(str, Enum):
EXTERNAL_AUTH_CONFIG_UPDATE = "external_auth.config.update"
EXTERNAL_AUTH_CONFIG_DELETE = "external_auth.config.delete"
+ # SSH Key and Certificate actions
+ SSH_KEY_ADDED = "ssh.key.added"
+ SSH_KEY_VERIFIED = "ssh.key.verified"
+ SSH_KEY_DELETED = "ssh.key.deleted"
+ SSH_KEY_VALIDATION_FAILED = "ssh.key.validation.failed"
+ SSH_CERT_REQUESTED = "ssh.cert.requested"
+ SSH_CERT_ISSUED = "ssh.cert.issued"
+ SSH_CERT_FAILED = "ssh.cert.failed"
+ SSH_CERT_REVOKED = "ssh.cert.revoked"
+ SSH_CERT_EXPIRED = "ssh.cert.expired"
+
+ # CA actions
+ CA_CREATED = "ca.created"
+ CA_UPDATED = "ca.updated"
+ CA_DELETED = "ca.deleted"
+ CA_KEY_ROTATED = "ca.key.rotated"
+
+ # Principal actions
+ PRINCIPAL_CREATED = "principal.created"
+ PRINCIPAL_UPDATED = "principal.updated"
+ PRINCIPAL_DELETED = "principal.deleted"
+ PRINCIPAL_MEMBER_ADDED = "principal.member.added"
+ PRINCIPAL_MEMBER_REMOVED = "principal.member.removed"
+
+ # Department actions
+ DEPARTMENT_CREATED = "department.created"
+ DEPARTMENT_UPDATED = "department.updated"
+ DEPARTMENT_DELETED = "department.deleted"
+ DEPARTMENT_MEMBER_ADDED = "department.member.added"
+ DEPARTMENT_MEMBER_REMOVED = "department.member.removed"
+
class OIDCGrantType(str, Enum):
"""OIDC grant types."""
diff --git a/gatehouse_app/utils/crypto.py b/gatehouse_app/utils/crypto.py
new file mode 100644
index 0000000..0322a8c
--- /dev/null
+++ b/gatehouse_app/utils/crypto.py
@@ -0,0 +1,128 @@
+"""Cryptographic utilities for SSH operations."""
+import hashlib
+import base64
+from typing import Optional
+
+
+def compute_ssh_fingerprint(public_key_str: str, hash_algorithm: str = "sha256") -> str:
+ """Compute the fingerprint of an SSH public key.
+
+ Args:
+ public_key_str: SSH public key in OpenSSH format
+ hash_algorithm: Hash algorithm to use (sha256, sha1, md5)
+
+ Returns:
+ Fingerprint string in the format "algorithm:hex_digest"
+
+ Example:
+ >>> key = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKp2..."
+ >>> fp = compute_ssh_fingerprint(key)
+ >>> print(fp)
+ sha256:Kb+...
+ """
+ if not public_key_str:
+ raise ValueError("Public key string is empty")
+
+ # Parse OpenSSH format: "ssh-ed25519 [comment]"
+ parts = public_key_str.strip().split()
+ if len(parts) < 2:
+ raise ValueError("Invalid OpenSSH public key format")
+
+ try:
+ # The base64-encoded key is the second part
+ key_bytes = base64.b64decode(parts[1])
+ except Exception as e:
+ raise ValueError(f"Failed to decode public key: {str(e)}")
+
+ # Compute hash
+ if hash_algorithm == "sha256":
+ digest = hashlib.sha256(key_bytes).digest()
+ # SSH format uses base64 encoding without padding
+ fingerprint = base64.b64encode(digest).decode().rstrip('=')
+ elif hash_algorithm == "sha1":
+ digest = hashlib.sha1(key_bytes).hexdigest()
+ fingerprint = digest
+ elif hash_algorithm == "md5":
+ digest = hashlib.md5(key_bytes).hexdigest()
+ # Format as colons
+ fingerprint = ':'.join(digest[i:i+2] for i in range(0, len(digest), 2))
+ else:
+ raise ValueError(f"Unsupported hash algorithm: {hash_algorithm}")
+
+ return f"{hash_algorithm}:{fingerprint}"
+
+
+def verify_ssh_key_format(public_key_str: str) -> bool:
+ """Verify that a string is in valid OpenSSH public key format.
+
+ Args:
+ public_key_str: Potential SSH public key
+
+ Returns:
+ True if valid OpenSSH format, False otherwise
+ """
+ if not public_key_str or not isinstance(public_key_str, str):
+ return False
+
+ parts = public_key_str.strip().split()
+
+ # Must have at least key type and key material
+ if len(parts) < 2:
+ return False
+
+ key_type = parts[0]
+
+ # Valid key types
+ valid_types = [
+ 'ssh-rsa',
+ 'ssh-ed25519',
+ 'ecdsa-sha2-nistp256',
+ 'ecdsa-sha2-nistp384',
+ 'ecdsa-sha2-nistp521',
+ 'ssh-dss',
+ ]
+
+ if key_type not in valid_types:
+ return False
+
+ # Try to decode base64
+ try:
+ base64.b64decode(parts[1])
+ return True
+ except Exception:
+ return False
+
+
+def extract_ssh_key_type(public_key_str: str) -> Optional[str]:
+ """Extract the key type from an OpenSSH public key.
+
+ Args:
+ public_key_str: SSH public key in OpenSSH format
+
+ Returns:
+ Key type (e.g., "ssh-ed25519") or None if invalid
+ """
+ if not verify_ssh_key_format(public_key_str):
+ return None
+
+ return public_key_str.strip().split()[0]
+
+
+def extract_ssh_key_comment(public_key_str: str) -> Optional[str]:
+ """Extract the comment from an OpenSSH public key.
+
+ Args:
+ public_key_str: SSH public key in OpenSSH format
+
+ Returns:
+ Comment string or None if not present
+ """
+ if not verify_ssh_key_format(public_key_str):
+ return None
+
+ parts = public_key_str.strip().split()
+ if len(parts) >= 3:
+ # Everything after the second part is the comment
+ return ' '.join(parts[2:])
+
+ return None
diff --git a/gatehouse_app/utils/decorators.py b/gatehouse_app/utils/decorators.py
index e3b0085..5cbb649 100644
--- a/gatehouse_app/utils/decorators.py
+++ b/gatehouse_app/utils/decorators.py
@@ -64,11 +64,47 @@ def login_required(f):
session.last_activity_at = datetime.now(timezone.utc)
from gatehouse_app import db
db.session.commit()
-
+
# Set context variables
g.current_user = session.user
g.current_session = session
-
+
+ user = session.user
+ token_groups: list = []
+ try:
+ if session.device_info:
+ # device_info may carry OIDC claims stored at login time
+ claims = session.device_info
+ # Normalise: Gatehouse stores roles as [{"organization_id":…,"role":…}]
+ roles_claim = claims.get("roles", [])
+ if isinstance(roles_claim, list):
+ for entry in roles_claim:
+ if isinstance(entry, dict):
+ role_val = entry.get("role")
+ if role_val:
+ token_groups.append(str(role_val))
+ elif isinstance(entry, str):
+ token_groups.append(entry)
+ # Standard OIDC groups claim
+ groups_claim = claims.get("groups", [])
+ if isinstance(groups_claim, list):
+ token_groups.extend(str(g_) for g_ in groups_claim if g_)
+ except Exception:
+ pass # Never block auth over token_groups enrichment failure
+ user.token_groups = token_groups
+
+ # Activation check: if the user has an `activated` attribute and it is
+ # explicitly False, block access. New accounts without the attribute are
+ # treated as active to avoid breaking existing sessions.
+ activated = getattr(user, "activated", None)
+ if activated is False:
+ return api_response(
+ success=False,
+ message="Account not yet activated. Please check your email for an activation link.",
+ status=403,
+ error_type="ACCOUNT_NOT_ACTIVATED",
+ )
+
return f(*args, **kwargs)
return decorated_function
@@ -97,11 +133,12 @@ def require_role(*allowed_roles):
raise ForbiddenError("Organization context required")
# Check user's role in the organization
- from gatehouse_app.models.organization_member import OrganizationMember
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
membership = OrganizationMember.query.filter_by(
user_id=g.current_user.id,
organization_id=org_id,
+ deleted_at=None,
).first()
if not membership:
diff --git a/manage.py b/manage.py
index 088447e..a89d335 100644
--- a/manage.py
+++ b/manage.py
@@ -111,5 +111,79 @@ def mfa_compliance_status():
print("=" * 60)
+@cli.command("configure_oauth")
+def configure_oauth():
+ """Interactively configure an OAuth provider at the application level.
+
+ Usage:
+ python manage.py configure_oauth
+
+ Supported providers: google, github, microsoft
+ """
+ import getpass
+ from gatehouse_app.models.authentication_method import ApplicationProviderConfig
+ from gatehouse_app.extensions import db
+
+ SUPPORTED = ["google", "github", "microsoft"]
+
+ print("=" * 60)
+ print("OAuth Provider Configuration")
+ print("=" * 60)
+ print(f"Supported providers: {', '.join(SUPPORTED)}")
+
+ provider = input("Provider [google/github/microsoft]: ").strip().lower()
+ if provider not in SUPPORTED:
+ print(f"❌ Unknown provider: {provider}")
+ return
+
+ client_id = input("Client ID: ").strip()
+ if not client_id:
+ print("❌ client_id is required")
+ return
+
+ client_secret = getpass.getpass("Client Secret (leave blank to keep existing): ").strip()
+
+ with app.app_context():
+ config = ApplicationProviderConfig.query.filter_by(provider_type=provider).first()
+ if config:
+ config.client_id = client_id
+ if client_secret:
+ config.set_client_secret(client_secret)
+ config.is_enabled = True
+ db.session.commit()
+ print(f"✅ Updated {provider} provider config.")
+ else:
+ config = ApplicationProviderConfig(
+ provider_type=provider,
+ client_id=client_id,
+ is_enabled=True,
+ )
+ if client_secret:
+ config.set_client_secret(client_secret)
+ db.session.add(config)
+ db.session.commit()
+ print(f"✅ Created {provider} provider config.")
+
+
+@cli.command("list_oauth")
+def list_oauth():
+ """List all configured OAuth providers.
+
+ Usage:
+ python manage.py list_oauth
+ """
+ from gatehouse_app.models.authentication_method import ApplicationProviderConfig
+
+ with app.app_context():
+ configs = ApplicationProviderConfig.query.all()
+ if not configs:
+ print("No OAuth providers configured.")
+ return
+ print(f"{'Provider':<15} {'Client ID':<40} {'Enabled'}")
+ print("-" * 65)
+ for c in configs:
+ print(f"{c.provider_type:<15} {c.client_id:<40} {c.is_enabled}")
+
+
if __name__ == "__main__":
cli()
diff --git a/migrations/versions/006_add_departments_principals.py b/migrations/versions/006_add_departments_principals.py
new file mode 100644
index 0000000..8022143
--- /dev/null
+++ b/migrations/versions/006_add_departments_principals.py
@@ -0,0 +1,127 @@
+"""Add Department and Principal models for SSH CA management.
+
+Revision ID: 006
+Revises: 005
+Create Date: 2026-02-27 10:00:00.000000
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '006'
+down_revision = '005'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### Department table ###
+ op.create_table('departments',
+ sa.Column('organization_id', sa.String(length=36), nullable=False),
+ sa.Column('name', sa.String(length=255), nullable=False),
+ sa.Column('description', sa.Text(), nullable=True),
+ sa.Column('id', sa.String(length=36), nullable=False),
+ sa.Column('created_at', sa.DateTime(), nullable=False),
+ sa.Column('updated_at', sa.DateTime(), nullable=False),
+ sa.Column('deleted_at', sa.DateTime(), nullable=True),
+ sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
+ sa.PrimaryKeyConstraint('id'),
+ sa.UniqueConstraint('id'),
+ sa.UniqueConstraint('organization_id', 'name', name='uix_org_dept_name')
+ )
+ op.create_index(op.f('ix_departments_organization_id'), 'departments', ['organization_id'], unique=False)
+ op.create_index(op.f('ix_departments_name'), 'departments', ['name'], unique=False)
+
+ # ### DepartmentMembership table ###
+ op.create_table('department_memberships',
+ sa.Column('user_id', sa.String(length=36), nullable=False),
+ sa.Column('department_id', sa.String(length=36), nullable=False),
+ sa.Column('id', sa.String(length=36), nullable=False),
+ sa.Column('created_at', sa.DateTime(), nullable=False),
+ sa.Column('updated_at', sa.DateTime(), nullable=False),
+ sa.Column('deleted_at', sa.DateTime(), nullable=True),
+ sa.ForeignKeyConstraint(['department_id'], ['departments.id'], ),
+ sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
+ sa.PrimaryKeyConstraint('id'),
+ sa.UniqueConstraint('id'),
+ sa.UniqueConstraint('user_id', 'department_id', name='uix_user_dept')
+ )
+ op.create_index(op.f('ix_department_memberships_user_id'), 'department_memberships', ['user_id'], unique=False)
+ op.create_index(op.f('ix_department_memberships_department_id'), 'department_memberships', ['department_id'], unique=False)
+
+ # ### Principal table ###
+ op.create_table('principals',
+ sa.Column('organization_id', sa.String(length=36), nullable=False),
+ sa.Column('name', sa.String(length=255), nullable=False),
+ sa.Column('description', sa.Text(), nullable=True),
+ sa.Column('id', sa.String(length=36), nullable=False),
+ sa.Column('created_at', sa.DateTime(), nullable=False),
+ sa.Column('updated_at', sa.DateTime(), nullable=False),
+ sa.Column('deleted_at', sa.DateTime(), nullable=True),
+ sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
+ sa.PrimaryKeyConstraint('id'),
+ sa.UniqueConstraint('id'),
+ sa.UniqueConstraint('organization_id', 'name', name='uix_org_principal_name')
+ )
+ op.create_index(op.f('ix_principals_organization_id'), 'principals', ['organization_id'], unique=False)
+ op.create_index(op.f('ix_principals_name'), 'principals', ['name'], unique=False)
+
+ # ### PrincipalMembership table ###
+ op.create_table('principal_memberships',
+ sa.Column('user_id', sa.String(length=36), nullable=False),
+ sa.Column('principal_id', sa.String(length=36), nullable=False),
+ sa.Column('id', sa.String(length=36), nullable=False),
+ sa.Column('created_at', sa.DateTime(), nullable=False),
+ sa.Column('updated_at', sa.DateTime(), nullable=False),
+ sa.Column('deleted_at', sa.DateTime(), nullable=True),
+ sa.ForeignKeyConstraint(['principal_id'], ['principals.id'], ),
+ sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
+ sa.PrimaryKeyConstraint('id'),
+ sa.UniqueConstraint('id'),
+ sa.UniqueConstraint('user_id', 'principal_id', name='uix_user_principal')
+ )
+ op.create_index(op.f('ix_principal_memberships_user_id'), 'principal_memberships', ['user_id'], unique=False)
+ op.create_index(op.f('ix_principal_memberships_principal_id'), 'principal_memberships', ['principal_id'], unique=False)
+
+ # ### DepartmentPrincipal table ###
+ op.create_table('department_principals',
+ sa.Column('department_id', sa.String(length=36), nullable=False),
+ sa.Column('principal_id', sa.String(length=36), nullable=False),
+ sa.Column('id', sa.String(length=36), nullable=False),
+ sa.Column('created_at', sa.DateTime(), nullable=False),
+ sa.Column('updated_at', sa.DateTime(), nullable=False),
+ sa.Column('deleted_at', sa.DateTime(), nullable=True),
+ sa.ForeignKeyConstraint(['department_id'], ['departments.id'], ),
+ sa.ForeignKeyConstraint(['principal_id'], ['principals.id'], ),
+ sa.PrimaryKeyConstraint('id'),
+ sa.UniqueConstraint('id'),
+ sa.UniqueConstraint('department_id', 'principal_id', name='uix_dept_principal')
+ )
+ op.create_index(op.f('ix_department_principals_department_id'), 'department_principals', ['department_id'], unique=False)
+ op.create_index(op.f('ix_department_principals_principal_id'), 'department_principals', ['principal_id'], unique=False)
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_index(op.f('ix_department_principals_principal_id'), table_name='department_principals')
+ op.drop_index(op.f('ix_department_principals_department_id'), table_name='department_principals')
+ op.drop_table('department_principals')
+
+ op.drop_index(op.f('ix_principal_memberships_principal_id'), table_name='principal_memberships')
+ op.drop_index(op.f('ix_principal_memberships_user_id'), table_name='principal_memberships')
+ op.drop_table('principal_memberships')
+
+ op.drop_index(op.f('ix_principals_name'), table_name='principals')
+ op.drop_index(op.f('ix_principals_organization_id'), table_name='principals')
+ op.drop_table('principals')
+
+ op.drop_index(op.f('ix_department_memberships_department_id'), table_name='department_memberships')
+ op.drop_index(op.f('ix_department_memberships_user_id'), table_name='department_memberships')
+ op.drop_table('department_memberships')
+
+ op.drop_index(op.f('ix_departments_name'), table_name='departments')
+ op.drop_index(op.f('ix_departments_organization_id'), table_name='departments')
+ op.drop_table('departments')
+ # ### end Alembic commands ###
diff --git a/migrations/versions/007_add_ssh_ca_models.py b/migrations/versions/007_add_ssh_ca_models.py
new file mode 100644
index 0000000..9749930
--- /dev/null
+++ b/migrations/versions/007_add_ssh_ca_models.py
@@ -0,0 +1,173 @@
+"""Add SSH CA models: SSHKey, SSHCertificate, CA, CertificateAuditLog.
+
+Revision ID: 007
+Revises: 006
+Create Date: 2026-02-27 11:00:00.000000
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '007'
+down_revision = '006'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### CA table ###
+ op.create_table('cas',
+ sa.Column('organization_id', sa.String(length=36), nullable=False),
+ sa.Column('name', sa.String(length=255), nullable=False),
+ sa.Column('description', sa.Text(), nullable=True),
+ sa.Column('key_type', sa.Enum('ed25519', 'rsa', 'ecdsa', name='ca_key_type_enum'), nullable=False),
+ sa.Column('private_key', sa.Text(), nullable=False),
+ sa.Column('public_key', sa.Text(), nullable=False),
+ sa.Column('fingerprint', sa.String(length=255), nullable=False),
+ sa.Column('crl_enabled', sa.Boolean(), nullable=False),
+ sa.Column('crl_endpoint', sa.String(length=512), nullable=True),
+ sa.Column('default_cert_validity_hours', sa.Integer(), nullable=False),
+ sa.Column('max_cert_validity_hours', sa.Integer(), nullable=False),
+ sa.Column('is_active', sa.Boolean(), nullable=False),
+ sa.Column('rotated_at', sa.DateTime(), nullable=True),
+ sa.Column('rotation_reason', sa.String(length=255), nullable=True),
+ sa.Column('id', sa.String(length=36), nullable=False),
+ sa.Column('created_at', sa.DateTime(), nullable=False),
+ sa.Column('updated_at', sa.DateTime(), nullable=False),
+ sa.Column('deleted_at', sa.DateTime(), nullable=True),
+ sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
+ sa.PrimaryKeyConstraint('id'),
+ sa.UniqueConstraint('id'),
+ sa.UniqueConstraint('fingerprint'),
+ sa.UniqueConstraint('organization_id', 'name', name='uix_org_ca_name')
+ )
+ op.create_index(op.f('ix_cas_organization_id'), 'cas', ['organization_id'], unique=False)
+ op.create_index('idx_ca_org_active', 'cas', ['organization_id', 'is_active'], unique=False)
+
+ # ### SSHKey table ###
+ op.create_table('ssh_keys',
+ sa.Column('user_id', sa.String(length=36), nullable=False),
+ sa.Column('payload', sa.Text(), nullable=False),
+ sa.Column('fingerprint', sa.String(length=255), nullable=False),
+ sa.Column('description', sa.String(length=255), nullable=True),
+ sa.Column('verified', sa.Boolean(), nullable=False),
+ sa.Column('verified_at', sa.DateTime(), nullable=True),
+ sa.Column('verify_text', sa.String(length=255), nullable=True),
+ sa.Column('verify_text_created_at', sa.DateTime(), nullable=True),
+ sa.Column('key_type', sa.String(length=50), nullable=True),
+ sa.Column('key_bits', sa.Integer(), nullable=True),
+ sa.Column('key_comment', sa.String(length=255), nullable=True),
+ sa.Column('id', sa.String(length=36), nullable=False),
+ sa.Column('created_at', sa.DateTime(), nullable=False),
+ sa.Column('updated_at', sa.DateTime(), nullable=False),
+ sa.Column('deleted_at', sa.DateTime(), nullable=True),
+ sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
+ sa.PrimaryKeyConstraint('id'),
+ sa.UniqueConstraint('id'),
+ sa.UniqueConstraint('payload'),
+ sa.UniqueConstraint('fingerprint')
+ )
+ op.create_index(op.f('ix_ssh_keys_user_id'), 'ssh_keys', ['user_id'], unique=False)
+ op.create_index(op.f('ix_ssh_keys_fingerprint'), 'ssh_keys', ['fingerprint'], unique=False)
+ op.create_index(op.f('ix_ssh_keys_verified'), 'ssh_keys', ['verified'], unique=False)
+ op.create_index('idx_ssh_key_user_verified', 'ssh_keys', ['user_id', 'verified'], unique=False)
+
+ # ### SSHCertificate table ###
+ op.create_table('ssh_certificates',
+ sa.Column('ca_id', sa.String(length=36), nullable=False),
+ sa.Column('user_id', sa.String(length=36), nullable=False),
+ sa.Column('ssh_key_id', sa.String(length=36), nullable=False),
+ sa.Column('certificate', sa.Text(), nullable=False),
+ sa.Column('serial', sa.String(length=255), nullable=False),
+ sa.Column('key_id', sa.String(length=255), nullable=False),
+ sa.Column('cert_type', sa.Enum('user', 'host', name='ssh_cert_type_enum'), nullable=False),
+ sa.Column('principals', sa.JSON(), nullable=False),
+ sa.Column('valid_after', sa.DateTime(), nullable=False),
+ sa.Column('valid_before', sa.DateTime(), nullable=False),
+ sa.Column('revoked', sa.Boolean(), nullable=False),
+ sa.Column('revoked_at', sa.DateTime(), nullable=True),
+ sa.Column('revoke_reason', sa.String(length=255), nullable=True),
+ sa.Column('status', sa.Enum('requested', 'issued', 'revoked', 'expired', 'superseded', name='ssh_cert_status_enum'), nullable=False),
+ sa.Column('request_ip', sa.String(length=45), nullable=True),
+ sa.Column('request_user_agent', sa.String(length=512), nullable=True),
+ sa.Column('critical_options', sa.JSON(), nullable=True),
+ sa.Column('extensions', sa.JSON(), nullable=True),
+ sa.Column('id', sa.String(length=36), nullable=False),
+ sa.Column('created_at', sa.DateTime(), nullable=False),
+ sa.Column('updated_at', sa.DateTime(), nullable=False),
+ sa.Column('deleted_at', sa.DateTime(), nullable=True),
+ sa.ForeignKeyConstraint(['ca_id'], ['cas.id'], ),
+ sa.ForeignKeyConstraint(['ssh_key_id'], ['ssh_keys.id'], ),
+ sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
+ sa.PrimaryKeyConstraint('id'),
+ sa.UniqueConstraint('id'),
+ sa.UniqueConstraint('serial')
+ )
+ op.create_index(op.f('ix_ssh_certificates_ca_id'), 'ssh_certificates', ['ca_id'], unique=False)
+ op.create_index(op.f('ix_ssh_certificates_user_id'), 'ssh_certificates', ['user_id'], unique=False)
+ op.create_index(op.f('ix_ssh_certificates_ssh_key_id'), 'ssh_certificates', ['ssh_key_id'], unique=False)
+ op.create_index(op.f('ix_ssh_certificates_serial'), 'ssh_certificates', ['serial'], unique=False)
+ op.create_index(op.f('ix_ssh_certificates_revoked'), 'ssh_certificates', ['revoked'], unique=False)
+ op.create_index(op.f('ix_ssh_certificates_status'), 'ssh_certificates', ['status'], unique=False)
+ op.create_index('idx_cert_user_status', 'ssh_certificates', ['user_id', 'status'], unique=False)
+ op.create_index('idx_cert_validity', 'ssh_certificates', ['valid_after', 'valid_before'], unique=False)
+ op.create_index('idx_cert_revoked', 'ssh_certificates', ['revoked', 'revoked_at'], unique=False)
+
+ # ### CertificateAuditLog table ###
+ op.create_table('certificate_audit_logs',
+ sa.Column('certificate_id', sa.String(length=36), nullable=False),
+ sa.Column('user_id', sa.String(length=36), nullable=True),
+ sa.Column('action', sa.String(length=50), nullable=False),
+ sa.Column('ip_address', sa.String(length=45), nullable=True),
+ sa.Column('user_agent', sa.String(length=512), nullable=True),
+ sa.Column('request_id', sa.String(length=36), nullable=True),
+ sa.Column('message', sa.Text(), nullable=True),
+ sa.Column('extra_data', sa.JSON(), nullable=True),
+ sa.Column('success', sa.Boolean(), nullable=False),
+ sa.Column('error_message', sa.Text(), nullable=True),
+ sa.Column('id', sa.String(length=36), nullable=False),
+ sa.Column('created_at', sa.DateTime(), nullable=False),
+ sa.Column('updated_at', sa.DateTime(), nullable=False),
+ sa.Column('deleted_at', sa.DateTime(), nullable=True),
+ sa.ForeignKeyConstraint(['certificate_id'], ['ssh_certificates.id'], ),
+ sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
+ sa.PrimaryKeyConstraint('id'),
+ sa.UniqueConstraint('id')
+ )
+ op.create_index(op.f('ix_certificate_audit_logs_certificate_id'), 'certificate_audit_logs', ['certificate_id'], unique=False)
+ op.create_index(op.f('ix_certificate_audit_logs_user_id'), 'certificate_audit_logs', ['user_id'], unique=False)
+ op.create_index(op.f('ix_certificate_audit_logs_action'), 'certificate_audit_logs', ['action'], unique=False)
+ op.create_index('idx_cert_audit_cert_action', 'certificate_audit_logs', ['certificate_id', 'action'], unique=False)
+ op.create_index('idx_cert_audit_user', 'certificate_audit_logs', ['user_id', 'created_at'], unique=False)
+
+
+def downgrade():
+ op.drop_index('idx_cert_audit_user', table_name='certificate_audit_logs')
+ op.drop_index('idx_cert_audit_cert_action', table_name='certificate_audit_logs')
+ op.drop_index(op.f('ix_certificate_audit_logs_action'), table_name='certificate_audit_logs')
+ op.drop_index(op.f('ix_certificate_audit_logs_user_id'), table_name='certificate_audit_logs')
+ op.drop_index(op.f('ix_certificate_audit_logs_certificate_id'), table_name='certificate_audit_logs')
+ op.drop_table('certificate_audit_logs')
+
+ op.drop_index('idx_cert_revoked', table_name='ssh_certificates')
+ op.drop_index('idx_cert_validity', table_name='ssh_certificates')
+ op.drop_index('idx_cert_user_status', table_name='ssh_certificates')
+ op.drop_index(op.f('ix_ssh_certificates_status'), table_name='ssh_certificates')
+ op.drop_index(op.f('ix_ssh_certificates_revoked'), table_name='ssh_certificates')
+ op.drop_index(op.f('ix_ssh_certificates_serial'), table_name='ssh_certificates')
+ op.drop_index(op.f('ix_ssh_certificates_ssh_key_id'), table_name='ssh_certificates')
+ op.drop_index(op.f('ix_ssh_certificates_user_id'), table_name='ssh_certificates')
+ op.drop_index(op.f('ix_ssh_certificates_ca_id'), table_name='ssh_certificates')
+ op.drop_table('ssh_certificates')
+
+ op.drop_index('idx_ssh_key_user_verified', table_name='ssh_keys')
+ op.drop_index(op.f('ix_ssh_keys_verified'), table_name='ssh_keys')
+ op.drop_index(op.f('ix_ssh_keys_fingerprint'), table_name='ssh_keys')
+ op.drop_index(op.f('ix_ssh_keys_user_id'), table_name='ssh_keys')
+ op.drop_table('ssh_keys')
+
+ op.drop_index('idx_ca_org_active', table_name='cas')
+ op.drop_index(op.f('ix_cas_organization_id'), table_name='cas')
+ op.drop_table('cas')
diff --git a/migrations/versions/008_fix_authmethodtype_enum.py b/migrations/versions/008_fix_authmethodtype_enum.py
new file mode 100644
index 0000000..ccf56e9
--- /dev/null
+++ b/migrations/versions/008_fix_authmethodtype_enum.py
@@ -0,0 +1,53 @@
+"""Add TOTP and WEBAUTHN to authmethodtype enum.
+
+Revision ID: 008
+Revises: 007
+Create Date: 2026-02-27 15:00:00.000000
+
+The original migration (001_base) created authmethodtype with only:
+ PASSWORD, GOOGLE, GITHUB, MICROSOFT, SAML, OIDC
+
+This migration adds the missing TOTP and WEBAUTHN values so
+has_totp_enabled() and has_webauthn_enabled() queries work correctly.
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+revision = '008'
+down_revision = '007'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # Add TOTP to the enum (idempotent approach using DO block)
+ op.execute("""
+ DO $$
+ BEGIN
+ IF NOT EXISTS (
+ SELECT 1 FROM pg_enum
+ WHERE enumlabel = 'TOTP'
+ AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'authmethodtype')
+ ) THEN
+ ALTER TYPE authmethodtype ADD VALUE 'TOTP';
+ END IF;
+ END$$;
+ """)
+ op.execute("""
+ DO $$
+ BEGIN
+ IF NOT EXISTS (
+ SELECT 1 FROM pg_enum
+ WHERE enumlabel = 'WEBAUTHN'
+ AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'authmethodtype')
+ ) THEN
+ ALTER TYPE authmethodtype ADD VALUE 'WEBAUTHN';
+ END IF;
+ END$$;
+ """)
+
+
+def downgrade():
+ # PostgreSQL does not support removing enum values; downgrade is a no-op.
+ pass
diff --git a/migrations/versions/009_sync_auditaction_enum.py b/migrations/versions/009_sync_auditaction_enum.py
new file mode 100644
index 0000000..59bc210
--- /dev/null
+++ b/migrations/versions/009_sync_auditaction_enum.py
@@ -0,0 +1,61 @@
+"""Sync auditaction enum with all AuditAction Python enum values.
+
+Revision ID: 009
+Revises: 008
+Create Date: 2026-02-27 15:20:00.000000
+
+The auditaction DB enum was only created with the initial 17 values from 001_base.py.
+All TOTP, WebAuthn, OAuth, SSH, CA, Principal, and Department audit actions were added
+to the Python enum but never synced to the DB type.
+"""
+from alembic import op
+
+
+revision = '009'
+down_revision = '008'
+branch_labels = None
+depends_on = None
+
+MISSING_VALUES = [
+ 'TOTP_ENROLL_INITIATED', 'TOTP_ENROLL_COMPLETED', 'TOTP_VERIFY_SUCCESS',
+ 'TOTP_VERIFY_FAILED', 'TOTP_DISABLED', 'TOTP_BACKUP_CODE_USED',
+ 'TOTP_BACKUP_CODES_REGENERATED', 'WEBAUTHN_REGISTER_INITIATED',
+ 'WEBAUTHN_REGISTER_COMPLETED', 'WEBAUTHN_REGISTER_FAILED',
+ 'WEBAUTHN_LOGIN_INITIATED', 'WEBAUTHN_LOGIN_SUCCESS', 'WEBAUTHN_LOGIN_FAILED',
+ 'WEBAUTHN_CREDENTIAL_DELETED', 'WEBAUTHN_CREDENTIAL_RENAMED',
+ 'ORG_SECURITY_POLICY_UPDATE', 'USER_SECURITY_POLICY_OVERRIDE_UPDATE',
+ 'MFA_POLICY_USER_SUSPENDED', 'MFA_POLICY_USER_COMPLIANT',
+ 'EXTERNAL_AUTH_LINK_INITIATED', 'EXTERNAL_AUTH_LINK_COMPLETED',
+ 'EXTERNAL_AUTH_LINK_FAILED', 'EXTERNAL_AUTH_UNLINK', 'EXTERNAL_AUTH_LOGIN',
+ 'EXTERNAL_AUTH_LOGIN_FAILED', 'EXTERNAL_AUTH_TOKEN_REFRESH',
+ 'EXTERNAL_AUTH_CONFIG_CREATE', 'EXTERNAL_AUTH_CONFIG_UPDATE',
+ 'EXTERNAL_AUTH_CONFIG_DELETE', 'SSH_KEY_ADDED', 'SSH_KEY_VERIFIED',
+ 'SSH_KEY_DELETED', 'SSH_KEY_VALIDATION_FAILED', 'SSH_CERT_REQUESTED',
+ 'SSH_CERT_ISSUED', 'SSH_CERT_FAILED', 'SSH_CERT_REVOKED', 'SSH_CERT_EXPIRED',
+ 'CA_CREATED', 'CA_UPDATED', 'CA_DELETED', 'CA_KEY_ROTATED',
+ 'PRINCIPAL_CREATED', 'PRINCIPAL_UPDATED', 'PRINCIPAL_DELETED',
+ 'PRINCIPAL_MEMBER_ADDED', 'PRINCIPAL_MEMBER_REMOVED',
+ 'DEPARTMENT_CREATED', 'DEPARTMENT_UPDATED', 'DEPARTMENT_DELETED',
+ 'DEPARTMENT_MEMBER_ADDED', 'DEPARTMENT_MEMBER_REMOVED',
+]
+
+
+def upgrade():
+ for val in MISSING_VALUES:
+ op.execute(f"""
+ DO $$
+ BEGIN
+ IF NOT EXISTS (
+ SELECT 1 FROM pg_enum
+ WHERE enumlabel = '{val}'
+ AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'auditaction')
+ ) THEN
+ ALTER TYPE auditaction ADD VALUE '{val}';
+ END IF;
+ END$$;
+ """)
+
+
+def downgrade():
+ # PostgreSQL does not support removing enum values; downgrade is a no-op.
+ pass
diff --git a/migrations/versions/010_password_reset_email_verify.py b/migrations/versions/010_password_reset_email_verify.py
new file mode 100644
index 0000000..efb836e
--- /dev/null
+++ b/migrations/versions/010_password_reset_email_verify.py
@@ -0,0 +1,50 @@
+"""add password reset and email verification token tables
+
+Revision ID: 010_password_reset_email_verify
+Revises: 009_sync_auditaction_enum
+Create Date: 2025-01-01 00:00:00.000000
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '010_password_reset_email_verify'
+down_revision = '009'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ op.create_table(
+ 'password_reset_tokens',
+ sa.Column('id', sa.String(36), primary_key=True, nullable=False),
+ sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False),
+ sa.Column('token', sa.String(128), nullable=False, unique=True),
+ sa.Column('expires_at', sa.DateTime, nullable=False),
+ sa.Column('used_at', sa.DateTime, nullable=True),
+ sa.Column('created_at', sa.DateTime, nullable=False),
+ sa.Column('updated_at', sa.DateTime, nullable=False),
+ sa.Column('deleted_at', sa.DateTime, nullable=True),
+ )
+ op.create_index('ix_password_reset_tokens_user_id', 'password_reset_tokens', ['user_id'])
+ op.create_index('ix_password_reset_tokens_token', 'password_reset_tokens', ['token'])
+
+ op.create_table(
+ 'email_verification_tokens',
+ sa.Column('id', sa.String(36), primary_key=True, nullable=False),
+ sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False),
+ sa.Column('token', sa.String(128), nullable=False, unique=True),
+ sa.Column('expires_at', sa.DateTime, nullable=False),
+ sa.Column('used_at', sa.DateTime, nullable=True),
+ sa.Column('created_at', sa.DateTime, nullable=False),
+ sa.Column('updated_at', sa.DateTime, nullable=False),
+ sa.Column('deleted_at', sa.DateTime, nullable=True),
+ )
+ op.create_index('ix_email_verification_tokens_user_id', 'email_verification_tokens', ['user_id'])
+ op.create_index('ix_email_verification_tokens_token', 'email_verification_tokens', ['token'])
+
+
+def downgrade():
+ op.drop_table('email_verification_tokens')
+ op.drop_table('password_reset_tokens')
diff --git a/migrations/versions/011_org_invite_tokens.py b/migrations/versions/011_org_invite_tokens.py
new file mode 100644
index 0000000..003da63
--- /dev/null
+++ b/migrations/versions/011_org_invite_tokens.py
@@ -0,0 +1,38 @@
+"""add org_invite_tokens table
+
+Revision ID: 011_org_invite_tokens
+Revises: 010_password_reset_email_verify
+Create Date: 2025-01-01 00:00:00.000000
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+revision = '011_org_invite_tokens'
+down_revision = '010_password_reset_email_verify'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ op.create_table(
+ 'org_invite_tokens',
+ sa.Column('id', sa.String(36), primary_key=True, nullable=False),
+ sa.Column('organization_id', sa.String(36), sa.ForeignKey('organizations.id', ondelete='CASCADE'), nullable=False),
+ sa.Column('invited_by_id', sa.String(36), sa.ForeignKey('users.id', ondelete='SET NULL'), nullable=True),
+ sa.Column('email', sa.String(255), nullable=False),
+ sa.Column('role', sa.String(64), nullable=False, server_default='member'),
+ sa.Column('token', sa.String(128), nullable=False, unique=True),
+ sa.Column('expires_at', sa.DateTime, nullable=False),
+ sa.Column('accepted_at', sa.DateTime, nullable=True),
+ sa.Column('created_at', sa.DateTime, nullable=False),
+ sa.Column('updated_at', sa.DateTime, nullable=False),
+ sa.Column('deleted_at', sa.DateTime, nullable=True),
+ )
+ op.create_index('ix_org_invite_tokens_organization_id', 'org_invite_tokens', ['organization_id'])
+ op.create_index('ix_org_invite_tokens_email', 'org_invite_tokens', ['email'])
+ op.create_index('ix_org_invite_tokens_token', 'org_invite_tokens', ['token'])
+
+
+def downgrade():
+ op.drop_table('org_invite_tokens')
diff --git a/migrations/versions/012_ca_nullable_org_and_cert_serial.py b/migrations/versions/012_ca_nullable_org_and_cert_serial.py
new file mode 100644
index 0000000..8b7dd0e
--- /dev/null
+++ b/migrations/versions/012_ca_nullable_org_and_cert_serial.py
@@ -0,0 +1,33 @@
+"""Make CA.organization_id nullable (system CA) and add cert_id to sign response
+
+Revision ID: 012_ca_nullable_org_and_cert_serial
+Revises: 011_org_invite_tokens
+Create Date: 2025-01-01 00:00:00.000000
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+revision = '012_ca_nullable_org'
+down_revision = '011_org_invite_tokens'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # Allow CA records without an org (e.g. the global system-config CA)
+ with op.batch_alter_table('cas', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'organization_id',
+ existing_type=sa.String(36),
+ nullable=True,
+ )
+
+
+def downgrade():
+ with op.batch_alter_table('cas', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'organization_id',
+ existing_type=sa.String(36),
+ nullable=False,
+ )
diff --git a/migrations/versions/013_add_ca_type.py b/migrations/versions/013_add_ca_type.py
new file mode 100644
index 0000000..1df4bc2
--- /dev/null
+++ b/migrations/versions/013_add_ca_type.py
@@ -0,0 +1,42 @@
+"""Add ca_type column to cas table (user/host).
+
+Revision ID: 013
+Revises: d34bfb72844e
+Create Date: 2026-02-28 23:00:00.000000
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '013'
+down_revision = 'd34bfb72844e'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # Create the enum type first (PostgreSQL requires this)
+ ca_type_enum = sa.Enum('user', 'host', name='ca_type_enum')
+ ca_type_enum.create(op.get_bind(), checkfirst=True)
+
+ # Add ca_type column with a default of 'user' so existing CAs stay valid
+ op.add_column(
+ 'cas',
+ sa.Column(
+ 'ca_type',
+ ca_type_enum,
+ nullable=False,
+ server_default='user',
+ ),
+ )
+
+
+def downgrade():
+ op.drop_column('cas', 'ca_type')
+ # Drop the enum type (PostgreSQL only; SQLite ignores)
+ try:
+ op.execute("DROP TYPE IF EXISTS ca_type_enum")
+ except Exception:
+ pass
diff --git a/migrations/versions/014_add_dept_cert_policy.py b/migrations/versions/014_add_dept_cert_policy.py
new file mode 100644
index 0000000..58b6cb3
--- /dev/null
+++ b/migrations/versions/014_add_dept_cert_policy.py
@@ -0,0 +1,44 @@
+"""add_department_cert_policies
+
+Adds the department_cert_policies table which stores per-department
+SSH certificate issuance rules:
+ - whether users may choose their own expiry
+ - default and maximum expiry durations
+ - allowed SSH certificate extensions
+"""
+
+from alembic import op
+import sqlalchemy as sa
+
+revision = "014_add_dept_cert_policy"
+down_revision = "013"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ op.create_table(
+ "department_cert_policies",
+ sa.Column("id", sa.String(36), primary_key=True),
+ sa.Column("department_id", sa.String(36), sa.ForeignKey("departments.id"), nullable=False, unique=True),
+ # Whether users are allowed to specify their own expiry (up to max)
+ sa.Column("allow_user_expiry", sa.Boolean(), nullable=False, server_default="0"),
+ # Default validity in hours (used when user doesn't specify, or not allowed to)
+ sa.Column("default_expiry_hours", sa.Integer(), nullable=False, server_default="1"),
+ # Hard cap on validity; admin cannot be exceeded
+ sa.Column("max_expiry_hours", sa.Integer(), nullable=False, server_default="24"),
+ # JSON list of extension names that are enabled for this department
+ # e.g. ["permit-pty", "permit-agent-forwarding"]
+ sa.Column("allowed_extensions", sa.JSON(), nullable=False, server_default='["permit-pty","permit-agent-forwarding","permit-X11-forwarding","permit-port-forwarding","permit-user-rc"]'),
+ # Admin-defined custom extension names beyond the standard five
+ sa.Column("custom_extensions", sa.JSON(), nullable=False, server_default="[]"),
+ sa.Column("created_at", sa.DateTime(), nullable=True),
+ sa.Column("updated_at", sa.DateTime(), nullable=True),
+ sa.Column("deleted_at", sa.DateTime(), nullable=True),
+ )
+ op.create_index("idx_dept_cert_policy_dept", "department_cert_policies", ["department_id"])
+
+
+def downgrade():
+ op.drop_index("idx_dept_cert_policy_dept", "department_cert_policies")
+ op.drop_table("department_cert_policies")
diff --git a/migrations/versions/015_add_user_suspend_audit_actions.py b/migrations/versions/015_add_user_suspend_audit_actions.py
new file mode 100644
index 0000000..06834da
--- /dev/null
+++ b/migrations/versions/015_add_user_suspend_audit_actions.py
@@ -0,0 +1,37 @@
+"""Add USER_SUSPEND and USER_UNSUSPEND to auditaction enum.
+
+Revision ID: 015_add_user_suspend_audit_actions
+Revises: 014_add_dept_cert_policy
+Create Date: 2026-03-02
+
+USER_SUSPEND and USER_UNSUSPEND were added to the Python AuditAction enum
+but were never synced to the PostgreSQL auditaction type, causing a
+DataError (invalid enum value) whenever an admin suspends or unsuspends a user.
+"""
+from alembic import op
+
+revision = "015_user_suspend_audit"
+down_revision = "014_add_dept_cert_policy"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ for val in ("USER_SUSPEND", "USER_UNSUSPEND"):
+ op.execute(f"""
+ DO $$
+ BEGIN
+ IF NOT EXISTS (
+ SELECT 1 FROM pg_enum
+ WHERE enumlabel = '{val}'
+ AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'auditaction')
+ ) THEN
+ ALTER TYPE auditaction ADD VALUE '{val}';
+ END IF;
+ END$$;
+ """)
+
+
+def downgrade():
+ # PostgreSQL does not support removing enum values; downgrade is a no-op.
+ pass
diff --git a/migrations/versions/016_encrypt_existing_ca_keys.py b/migrations/versions/016_encrypt_existing_ca_keys.py
new file mode 100644
index 0000000..91acdda
--- /dev/null
+++ b/migrations/versions/016_encrypt_existing_ca_keys.py
@@ -0,0 +1,168 @@
+"""Encrypt existing plaintext CA private keys at rest.
+
+Revision ID: 016_encrypt_existing_ca_keys
+Revises: 015_add_user_suspend_audit_actions
+Create Date: 2026-03-02
+
+All CA private keys created before this migration were stored as plaintext PEM
+strings in the ``cas.private_key`` column. This migration detects those rows
+(by checking for the absence of the ``$fernet$`` prefix that encrypted values
+carry) and re-encrypts them with the key derived from ``CA_ENCRYPTION_KEY``.
+
+The migration is safe to re-run: already-encrypted rows are left untouched.
+
+Prerequisites
+-------------
+``CA_ENCRYPTION_KEY`` must be set in the environment before running this
+migration. The same value must be configured for the running application.
+
+To roll back to plaintext (downgrade):
+The ``downgrade()`` function decrypts all rows back to plaintext PEM. This is
+provided only for emergency rollback and should not be used in production once
+the system has been running with encrypted keys.
+"""
+import os
+import base64
+import hashlib
+import logging
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.orm import Session
+
+logger = logging.getLogger(__name__)
+
+# Alembic revision identifiers
+revision = "016_encrypt_ca_keys"
+down_revision = "015_user_suspend_audit"
+branch_labels = None
+depends_on = None
+
+_FERNET_PREFIX = "$fernet$"
+
+
+def _get_fernet():
+ """Build a Fernet instance from CA_ENCRYPTION_KEY env var."""
+ from cryptography.fernet import Fernet
+
+ raw_key = os.environ.get("CA_ENCRYPTION_KEY")
+ if not raw_key:
+ raise RuntimeError(
+ "CA_ENCRYPTION_KEY environment variable is not set. "
+ "Set it before running this migration."
+ )
+ key_bytes = base64.urlsafe_b64encode(hashlib.sha256(raw_key.encode()).digest())
+ return Fernet(key_bytes)
+
+
+def upgrade():
+ """Encrypt plaintext CA private keys."""
+ bind = op.get_bind()
+ session = Session(bind=bind)
+
+ try:
+ fernet = _get_fernet()
+ except RuntimeError as exc:
+ raise RuntimeError(str(exc)) from exc
+
+ # Fetch all non-deleted CA rows
+ rows = session.execute(
+ sa.text("SELECT id, private_key FROM cas WHERE deleted_at IS NULL")
+ ).fetchall()
+
+ encrypted_count = 0
+ skipped_count = 0
+
+ for row in rows:
+ ca_id, private_key = row[0], row[1]
+
+ if not private_key:
+ logger.warning(f"CA {ca_id} has empty private_key — skipping")
+ skipped_count += 1
+ continue
+
+ if private_key.startswith(_FERNET_PREFIX):
+ # Already encrypted
+ skipped_count += 1
+ continue
+
+ # Encrypt
+ try:
+ token = fernet.encrypt(private_key.encode()).decode()
+ encrypted_value = f"{_FERNET_PREFIX}{token}"
+ session.execute(
+ sa.text("UPDATE cas SET private_key = :pk WHERE id = :id"),
+ {"pk": encrypted_value, "id": ca_id},
+ )
+ encrypted_count += 1
+ logger.info(f"Encrypted private key for CA {ca_id}")
+ except Exception as exc:
+ session.rollback()
+ raise RuntimeError(
+ f"Failed to encrypt private key for CA {ca_id}: {exc}"
+ ) from exc
+
+ session.commit()
+ logger.info(
+ f"CA key encryption migration complete: "
+ f"{encrypted_count} encrypted, {skipped_count} skipped"
+ )
+ print(
+ f" [016_encrypt_ca_keys] {encrypted_count} CA private key(s) encrypted, "
+ f"{skipped_count} already encrypted or empty."
+ )
+
+
+def downgrade():
+ """Decrypt CA private keys back to plaintext (emergency rollback only)."""
+ bind = op.get_bind()
+ session = Session(bind=bind)
+
+ try:
+ fernet = _get_fernet()
+ except RuntimeError as exc:
+ raise RuntimeError(str(exc)) from exc
+
+ rows = session.execute(
+ sa.text("SELECT id, private_key FROM cas WHERE deleted_at IS NULL")
+ ).fetchall()
+
+ decrypted_count = 0
+ skipped_count = 0
+
+ for row in rows:
+ ca_id, private_key = row[0], row[1]
+
+ if not private_key or not private_key.startswith(_FERNET_PREFIX):
+ skipped_count += 1
+ continue
+
+ token = private_key[len(_FERNET_PREFIX):]
+ try:
+ from cryptography.fernet import InvalidToken
+ try:
+ plaintext = fernet.decrypt(token.encode()).decode()
+ except InvalidToken as exc:
+ raise RuntimeError(
+ f"Downgrade failed: cannot decrypt CA {ca_id} — wrong key or corrupted data."
+ ) from exc
+
+ session.execute(
+ sa.text("UPDATE cas SET private_key = :pk WHERE id = :id"),
+ {"pk": plaintext, "id": ca_id},
+ )
+ decrypted_count += 1
+ logger.warning(f"Decrypted (plaintext restore) private key for CA {ca_id}")
+ except RuntimeError:
+ session.rollback()
+ raise
+
+ session.commit()
+ logger.warning(
+ f"CA key decryption (downgrade) complete: "
+ f"{decrypted_count} decrypted, {skipped_count} skipped"
+ )
+ print(
+ f" [016_encrypt_ca_keys] DOWNGRADE: {decrypted_count} CA private key(s) "
+ f"decrypted to plaintext. WARNING: keys are now unencrypted at rest."
+ )
diff --git a/migrations/versions/017_add_ca_serial_counter.py b/migrations/versions/017_add_ca_serial_counter.py
new file mode 100644
index 0000000..94b85e9
--- /dev/null
+++ b/migrations/versions/017_add_ca_serial_counter.py
@@ -0,0 +1,37 @@
+"""Add monotonic serial counter to CAs table.
+
+Each CA now owns a `next_serial_number` (BigInteger) that is atomically
+incremented every time a certificate is signed. This guarantees:
+ - Serials are unique per CA
+ - Serials are monotonically increasing (auditable, no gaps by accident)
+ - The value embedded in the OpenSSH certificate matches what is stored
+ in the `ssh_certificates.serial` column
+
+Revision ID: 017_add_ca_serial_counter
+Revises: 016_encrypt_ca_keys
+Create Date: 2026-03-02
+"""
+from alembic import op
+import sqlalchemy as sa
+
+revision = "017_add_ca_serial_counter"
+down_revision = "016_encrypt_ca_keys"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ with op.batch_alter_table("cas", schema=None) as batch_op:
+ batch_op.add_column(
+ sa.Column(
+ "next_serial_number",
+ sa.BigInteger(),
+ nullable=False,
+ server_default="1",
+ )
+ )
+
+
+def downgrade():
+ with op.batch_alter_table("cas", schema=None) as batch_op:
+ batch_op.drop_column("next_serial_number")
diff --git a/migrations/versions/018_add_ownership_and_hard_delete_audit_actions.py b/migrations/versions/018_add_ownership_and_hard_delete_audit_actions.py
new file mode 100644
index 0000000..7cbbd32
--- /dev/null
+++ b/migrations/versions/018_add_ownership_and_hard_delete_audit_actions.py
@@ -0,0 +1,52 @@
+"""Add ORG_OWNERSHIP_TRANSFERRED and USER_HARD_DELETE to auditaction enum.
+
+Revision ID: 018_audit_enum_values
+Revises: 017_add_ca_serial_counter
+Create Date: 2026-03-02
+
+ORG_OWNERSHIP_TRANSFERRED and USER_HARD_DELETE were added to the Python
+AuditAction enum but were never synced to the PostgreSQL auditaction type,
+causing a DataError (invalid enum value) when transferring org ownership
+or hard-deleting a user.
+"""
+from alembic import op
+
+revision = "018_audit_enum_values"
+down_revision = "017_add_ca_serial_counter"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ALTER TYPE ... ADD VALUE cannot run inside a transaction block in PostgreSQL.
+ # Alembic has already opened a transaction on the connection by the time our
+ # upgrade() runs, so we must:
+ # 1. Roll back that open transaction on the raw psycopg2 connection.
+ # 2. Switch to autocommit so the ALTER TYPE runs outside any transaction.
+ # 3. Restore the previous state afterwards.
+ conn = op.get_bind()
+ # SQLAlchemy 2.x: conn.connection is a _ConnectionFairy; .driver_connection is psycopg2
+ fairy = conn.connection
+ raw = getattr(fairy, "driver_connection", None) or getattr(fairy, "dbapi_connection", fairy)
+ # Roll back the open transaction so psycopg2 allows us to change autocommit.
+ raw.rollback()
+ old_autocommit = raw.autocommit
+ raw.autocommit = True
+ try:
+ with raw.cursor() as cur:
+ for val in ("ORG_OWNERSHIP_TRANSFERRED", "USER_HARD_DELETE"):
+ cur.execute(
+ "SELECT 1 FROM pg_enum "
+ "WHERE enumlabel = %s "
+ "AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'auditaction')",
+ (val,),
+ )
+ if not cur.fetchone():
+ cur.execute(f"ALTER TYPE auditaction ADD VALUE '{val}'")
+ finally:
+ raw.autocommit = old_autocommit
+
+
+def downgrade():
+ # PostgreSQL does not support removing enum values; downgrade is a no-op.
+ pass
diff --git a/migrations/versions/d34bfb72844e_add_activation_fields_and_ca_permissions.py b/migrations/versions/d34bfb72844e_add_activation_fields_and_ca_permissions.py
new file mode 100644
index 0000000..83d9f72
--- /dev/null
+++ b/migrations/versions/d34bfb72844e_add_activation_fields_and_ca_permissions.py
@@ -0,0 +1,50 @@
+"""add_activation_fields_and_ca_permissions
+
+Revision ID: d34bfb72844e
+Revises: 012_ca_nullable_org
+Create Date: 2026-02-28 18:06:47.328552
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision = 'd34bfb72844e'
+down_revision = '012_ca_nullable_org'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # Create ca_permissions table
+ op.create_table(
+ 'ca_permissions',
+ sa.Column('ca_id', sa.String(length=36), nullable=False),
+ sa.Column('user_id', sa.String(length=36), nullable=False),
+ sa.Column('permission', sa.String(length=50), nullable=False),
+ sa.Column('id', sa.String(length=36), nullable=False),
+ sa.Column('created_at', sa.DateTime(), nullable=False),
+ sa.Column('updated_at', sa.DateTime(), nullable=False),
+ sa.Column('deleted_at', sa.DateTime(), nullable=True),
+ sa.ForeignKeyConstraint(['ca_id'], ['cas.id'], ondelete='CASCADE'),
+ sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
+ sa.PrimaryKeyConstraint('id'),
+ sa.UniqueConstraint('ca_id', 'user_id', name='uix_ca_permission'),
+ )
+ op.create_index('ix_ca_permissions_ca_id', 'ca_permissions', ['ca_id'], unique=False)
+ op.create_index('ix_ca_permissions_user_id', 'ca_permissions', ['user_id'], unique=False)
+
+ # Add activation columns to users
+ op.add_column('users', sa.Column('activated', sa.Boolean(), nullable=False,
+ server_default=sa.text('true')))
+ op.add_column('users', sa.Column('activation_key', sa.String(length=128), nullable=True))
+ op.create_index('ix_users_activation_key', 'users', ['activation_key'], unique=True)
+
+
+def downgrade():
+ op.drop_index('ix_users_activation_key', table_name='users')
+ op.drop_column('users', 'activation_key')
+ op.drop_column('users', 'activated')
+ op.drop_index('ix_ca_permissions_user_id', table_name='ca_permissions')
+ op.drop_index('ix_ca_permissions_ca_id', table_name='ca_permissions')
+ op.drop_table('ca_permissions')
diff --git a/requirements/base.txt b/requirements/base.txt
index 2f9c20d..1f6591d 100644
--- a/requirements/base.txt
+++ b/requirements/base.txt
@@ -14,7 +14,7 @@ Flask-Marshmallow==0.15.0
marshmallow-sqlalchemy==0.29.0
# Security
-bcrypt==4.1.2
+bcrypt==4.2.0
Flask-Bcrypt==1.0.1
pyotp==2.9.0
@@ -24,7 +24,7 @@ cbor2==5.6.0
# JWT / OIDC
PyJWT==2.8.0
-cryptography==41.0.7
+cryptography==42.0.7
# CORS
Flask-CORS==4.0.0
@@ -47,4 +47,7 @@ Flask-Limiter==3.5.0
# Logging
python-json-logger==2.0.7
-qrcode[pil]
\ No newline at end of file
+qrcode[pil]
+
+# SSH CA Certificate signing
+sshkey-tools==0.11.3
diff --git a/requirements/development.txt b/requirements/development.txt
index bccb4ec..8c626fd 100644
--- a/requirements/development.txt
+++ b/requirements/development.txt
@@ -20,3 +20,25 @@ watchdog==3.0.0
# Documentation
sphinx==7.2.6
+
+# Web framework & Database
+Flask==3.0.0
+Flask-SQLAlchemy==3.1.1
+Flask-Migrate==4.0.5
+sqlalchemy-cockroachdb==2.0.3
+
+# Utilities
+colorlog==6.8.0
+coloredlogs==15.0.1
+prettytable==3.10.2
+tabulate==0.9.0
+requests==2.31.0
+pytz==2023.3
+python-dotenv==1.0.0
+pydantic==2.5.0
+PyJWT==2.8.0
+cryptography==42.0.7
+pycryptodome==3.20.0
+psycopg2-binary==2.9.9
+sshkey-tools==0.11.3
+sendgrid==6.11.0
diff --git a/scripts/seed_data.py b/scripts/seed_data.py
index 97450fa..e8c1d04 100644
--- a/scripts/seed_data.py
+++ b/scripts/seed_data.py
@@ -15,11 +15,11 @@ load_dotenv()
from gatehouse_app import create_app
from gatehouse_app.extensions import db
-from gatehouse_app.models.user import User
-from gatehouse_app.models.organization import Organization
-from gatehouse_app.models.organization_member import OrganizationMember
-from gatehouse_app.models.authentication_method import AuthenticationMethod
-from gatehouse_app.models.oidc_client import OIDCClient
+from gatehouse_app.models.user.user import User
+from gatehouse_app.models.organization.organization import Organization
+from gatehouse_app.models.organization.organization_member import OrganizationMember
+from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
+from gatehouse_app.models.oidc.oidc_client import OIDCClient
from gatehouse_app.services.auth_service import AuthService
from gatehouse_app.services.organization_service import OrganizationService
from gatehouse_app.utils.constants import OrganizationRole, UserStatus, AuthMethodType