Merge pull request #2 from jamesii-b/gatehouse/secuird-CA-merge-v2.01

Gatehouse with secuird CA Merge (Gatehouse Isolated)
This commit is contained in:
2026-03-03 13:52:52 +10:30
committed by GitHub
97 changed files with 11655 additions and 819 deletions
+6 -2
View File
@@ -40,5 +40,9 @@ LOG_TO_STDOUT=True
RATELIMIT_ENABLED=True
RATELIMIT_STORAGE_URL=redis://localhost:6379/1
# Testing
TESTING=False
# SSH CA
# Path to CA private key file (alternative to SSH_CA_PRIVATE_KEY env var)
SSH_CA_KEY_PATH=/path/to/ca-users
# Or set the key content directly (takes priority over SSH_CA_KEY_PATH):
# SSH_CA_PRIVATE_KEY=
+528
View File
@@ -0,0 +1,528 @@
#!/usr/bin/python3
import base64
import os
import sys
import webbrowser
import requests
import argparse
import jwt
import json
import datetime
import pytz
from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib.parse import urlparse, parse_qsl
from dotenv import load_dotenv
from sshkey_tools.cert import SSHCertificate
import logging
import coloredlogs
import subprocess
# Load environment variables from the .env file
load_dotenv()
# Get the API_URL from the environment variables
SIGN_URL = os.getenv("SIGN_URL", "http://localhost:5000")
LISTENER_HOST_NAME = "127.0.0.1"
LISTENER_SERVER_PORT = 8250
CACHE_FILE = os.path.expanduser('~/.gatehouse/token_cache.json')
os.makedirs(os.path.dirname(CACHE_FILE), exist_ok=True)
CERT_FILE_PATH = "/tmp/ssh-cert"
CHALLENGE_FILE_PATH = "/tmp/challenge.txt"
CHALLENGE_SIG_FILE_PATH = "/tmp/challenge.txt.sig"
# Configure logger
logger = logging.getLogger(__name__)
coloredlogs.install(level='DEBUG', logger=logger, fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
token = ""
def auth_headers(content_type="application/json"):
"""Return auth headers using the current cached token."""
return {"Authorization": f"Bearer {token}", "Content-Type": content_type}
class MyServer(BaseHTTPRequestHandler):
def do_GET(self):
"""Handle GET requests and process token reception."""
global server_done, token
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()
self.wfile.write(bytes("<html><head><title>OIDC Workflow Tool</title></head>", "utf-8"))
self.wfile.write(bytes("<body><p>The token has been received</p>", "utf-8"))
self.wfile.write(bytes("<p>You may now close this window.</p>", "utf-8"))
self.wfile.write(bytes("</body></html>", "utf-8"))
parsed_url = urlparse(self.path)
query_data = dict(parse_qsl(parsed_url.query))
received_token = query_data.get('token')
if received_token:
token = received_token
server_done = True
logger.info("Token received")
save_token_to_cache(token)
def log_message(self, format, *args):
"""Log messages using the logger instead of stdout."""
logger.info("%s - %s" % (self.client_address[0], format % args))
def load_token_from_cache():
"""Load the token from the cache file."""
if os.path.exists(CACHE_FILE):
with open(CACHE_FILE, 'r') as f:
data = json.load(f)
if 'token' in data:
return data['token']
return None
def save_token_to_cache(token):
"""Save the token to the cache file."""
with open(CACHE_FILE, 'w') as f:
json.dump({'token': token}, f)
def clear_token_cache():
"""Remove the cached token file."""
if os.path.exists(CACHE_FILE):
os.remove(CACHE_FILE)
logger.info("Cached token removed.")
else:
logger.info("No cached token found.")
def decode_and_validate_token(token):
"""Decode the JWT and validate its claims.
Returns True if the token is a valid, non-expired JWT.
Returns False if the token is not a JWT (e.g. opaque session token)
or if it has expired — callers should then fall back to /auth/me.
"""
try:
decoded_token = jwt.decode(token, options={"verify_signature": False})
except jwt.exceptions.DecodeError:
# Not a JWT — likely an opaque session token; let /auth/me handle it.
return False
except Exception as e:
logger.debug(f"Unexpected JWT decode error: {e}")
return False
iat = decoded_token.get('iat')
exp = decoded_token.get('exp')
if iat is None or exp is None:
logger.debug("JWT is missing 'iat' or 'exp' claims — treating as invalid.")
return False
now = datetime.datetime.now(pytz.UTC)
exp_dt = datetime.datetime.fromtimestamp(exp, pytz.UTC)
iat_dt = datetime.datetime.fromtimestamp(iat, pytz.UTC)
logger.debug(f"JWT iat={iat_dt.isoformat()} exp={exp_dt.isoformat()}")
if exp_dt < now:
logger.debug("JWT has expired.")
return False
if iat_dt > now:
logger.debug("JWT 'iat' is in the future — clock skew?")
return True
def request_token():
global server_done, token
server_done = False
logger.info("Starting request_token process.")
# Attempt to load the token from the cache
token = load_token_from_cache()
logger.debug("Token loaded from cache: %s", token)
# Validate the cached token, if it exists
if token:
try:
if decode_and_validate_token(token):
logger.info("Cached token is valid. Using cached token.")
return token
except Exception:
pass
# Try opaque token via /auth/me
try:
r = requests.get(
f"{SIGN_URL}/api/v1/auth/me",
headers={"Authorization": f"Bearer {token}"},
timeout=5,
)
if r.status_code == 200:
logger.info("Cached session token is valid. Using cached token.")
return token
except Exception:
pass
logger.info("Cached token is expired or invalid, requesting a new token.")
token = ""
# Prepare the redirect URL for the token request
redirect_url = f"http://{LISTENER_HOST_NAME}:{LISTENER_SERVER_PORT}/?token="
logger.info("Redirect URL: %s", redirect_url)
# Construct the token request URL
token_url = f"{SIGN_URL}/api/v1/token_please?redirect_url={redirect_url}"
logger.info("Token request URL: %s", token_url)
# Start the web server to handle the token response
logger.debug("Starting the HTTP server on %s:%d", LISTENER_HOST_NAME, LISTENER_SERVER_PORT)
webServer = HTTPServer((LISTENER_HOST_NAME, LISTENER_SERVER_PORT), MyServer)
# Open the web browser to initiate the token request
logger.info("Opening web browser to request token.")
webbrowser.open(token_url, new=2)
# Wait for the server to handle the request and receive the token
logger.debug("Waiting for the token response...")
while not server_done:
webServer.handle_request()
logger.debug("Server handled a request, server_done status: %s", server_done)
logger.info("Token received: %s", token)
return token
def get_activated_ssh_key():
"""Retrieve the list of SSH keys and return the ID of a verified key."""
try:
response = requests.get(f"{SIGN_URL}/api/v1/ssh/keys", headers=auth_headers())
if response.status_code != 200:
logger.error(f"Failed to retrieve SSH keys: {response.status_code} - {response.text}")
exit(1)
keys = response.json().get('data', {}).get('keys', [])
verified_keys = [k for k in keys if k['verified']]
if not verified_keys:
logger.error("No verified SSH keys found for the user.")
exit(1)
if len(verified_keys) > 1 and sys.stdout.isatty():
print("\nMultiple verified SSH keys found. Please choose one:")
for i, k in enumerate(verified_keys):
print(f" [{i+1}] {k['id'][:8]}... fingerprint={k.get('fingerprint','?')} name={k.get('key_comment','?')}")
try:
choice = int(input("Enter number: ").strip()) - 1
if 0 <= choice < len(verified_keys):
return verified_keys[choice]['id']
except (ValueError, EOFError):
pass
logger.info("Invalid choice; using the most recently added key.")
verified_keys.sort(key=lambda k: k.get('created_at', ''), reverse=True)
return verified_keys[0]['id']
except SystemExit:
raise
except Exception as e:
logger.error(f"Error while retrieving SSH keys: {e}")
exit(1)
def fetch_my_principals():
"""Fetch all principal names the current user is entitled to from the API.
For regular members: returns their assigned principals.
For org admins/owners: returns all principals in the org (they can sign for any).
"""
global token
response = requests.get(
f"{SIGN_URL}/api/v1/users/me/principals",
headers={"Authorization": f"Bearer {token}"},
timeout=10,
)
if response.status_code != 200:
logger.error(f"Failed to fetch principals from server: {response.status_code} - {response.text}")
exit(1)
orgs = response.json().get("data", {}).get("orgs", [])
principal_names = []
for org in orgs:
# Admins/owners get all principals; regular members get only their assigned ones
if org.get("is_admin"):
source = org.get("all_principals", [])
else:
source = org.get("my_principals", [])
for p in source:
if p["name"] not in principal_names:
principal_names.append(p["name"])
return principal_names
def request_certificate():
CERT_ID = os.getenv("CERT_ID") or get_activated_ssh_key()
principals = fetch_my_principals()
if not principals:
logger.error("You have no principals assigned. Contact your org admin.")
exit(1)
logger.info(f"Requesting certificate for principals: {', '.join(principals)}")
headers = {
'content-type': 'application/json',
"Authorization": "bearer " + token
}
payload = {
'cert_id': CERT_ID,
'principals': principals,
}
try:
response = requests.post(f"{SIGN_URL}/api/v1/ssh/sign", json=payload, headers=headers)
if response.status_code == 201:
json_result = response.json().get('data', response.json())
with open(CERT_FILE_PATH, 'w') as f:
f.write(json_result['certificate'])
logger.info(f"Certificate signed successfully, located at {CERT_FILE_PATH}")
logger.info(f"Valid for principals: {', '.join(json_result.get('principals', principals))}")
logger.info("You can login to your destination server with the following command")
logger.info(f"\tssh user@server -o CertificateFile={CERT_FILE_PATH}")
else:
logger.error("Error in response from server")
logger.error(f"Status code: {response.status_code}")
logger.error(f"Response text: {response.text}")
except Exception as e:
logger.error(f"Error during certificate signing: {e}")
def generate_and_sign_challenge(ssh_key_file, key_id):
"""Fetch a challenge from the server, sign it with the SSH key, and submit the signature."""
logger.debug(f"generate_and_sign_challenge - {ssh_key_file} {key_id}")
# Fetch challenge text
try:
response = requests.get(f"{SIGN_URL}/api/v1/ssh/keys/{key_id}/verify", headers=auth_headers())
if response.status_code != 200:
logger.error(f"Server returned unexpected code {response.status_code}")
return False
resp_json = response.json()
data = resp_json.get('data', resp_json)
challenge_text = data.get('challenge_text', data.get('validationText', '')) + "\n"
except Exception as e:
logger.error(f"Unable to fetch SSH Key validation data: {e}")
return False
# Sign the challenge
try:
for path in (CHALLENGE_FILE_PATH, CHALLENGE_SIG_FILE_PATH):
if os.path.exists(path):
os.remove(path)
with open(CHALLENGE_FILE_PATH, 'w') as f:
f.write(challenge_text)
subprocess.run(
["ssh-keygen", "-Y", "sign", "-f", ssh_key_file, "-n", "file", CHALLENGE_FILE_PATH],
check=True,
)
with open(CHALLENGE_SIG_FILE_PATH, 'rb') as f:
signature = base64.b64encode(f.read()).decode('utf-8')
except Exception as e:
logger.error(f"Unable to sign the challenge response: {e}")
return False
# Submit signature
try:
response = requests.post(
f"{SIGN_URL}/api/v1/ssh/keys/{key_id}/verify",
headers=auth_headers(),
json={"signature": signature},
)
if response.status_code == 200:
logger.info("SSH key verified successfully.")
else:
logger.error(f"Verification failed: {response.status_code} - {response.text}")
except Exception as e:
logger.error(f"Unable to submit the challenge response: {e}")
return signature
def remove_ssh_key(key_id=None):
"""
Remove an SSH key from the server. If key_id is None, list keys and prompt user to pick one.
"""
response = requests.get(f"{SIGN_URL}/api/v1/ssh/keys", headers=auth_headers())
if response.status_code != 200:
logger.error(f"Failed to list SSH keys: {response.status_code} - {response.text}")
exit(1)
keys = response.json().get('data', {}).get('keys', [])
if not keys:
logger.info("No SSH keys found for your user.")
return
if key_id:
target = next((k for k in keys if k['id'] == key_id), None)
if not target:
logger.error(f"Key ID {key_id} not found in your profile.")
exit(1)
keys_to_delete = [target]
else:
print("\nYour SSH keys:")
for i, k in enumerate(keys):
verified = "✓ verified" if k['verified'] else "✗ unverified"
print(f" [{i+1}] {k['id']} {verified} {k.get('description', '')} (added {k['created_at'][:10]})")
print(" [a] Delete ALL keys")
print(" [q] Quit")
choice = input("\nEnter number to delete (or 'a' for all, 'q' to quit): ").strip().lower()
if choice == 'q':
return
elif choice == 'a':
keys_to_delete = keys
else:
try:
idx = int(choice) - 1
if idx < 0 or idx >= len(keys):
raise ValueError()
keys_to_delete = [keys[idx]]
except ValueError:
logger.error("Invalid selection.")
exit(1)
for k in keys_to_delete:
del_response = requests.delete(f"{SIGN_URL}/api/v1/ssh/keys/{k['id']}", headers=auth_headers())
if del_response.status_code == 200:
logger.info(f"Key {k['id']} removed successfully.")
else:
logger.error(f"Failed to remove key {k['id']}: {del_response.status_code} - {del_response.text}")
def add_ssh_key(ssh_key_file):
"""Add an SSH key to the server and auto-verify it."""
if hasattr(ssh_key_file, 'read'):
key_bytes = ssh_key_file.read()
key_path = ssh_key_file.name
elif isinstance(ssh_key_file, bytes):
key_bytes = ssh_key_file
key_path = None
else:
key_path = str(ssh_key_file)
with open(key_path, 'rb') as f:
key_bytes = f.read()
ssh_key = key_bytes.decode('utf-8').strip()
payload = {
'description': 'Added via gatehouse CLI tool',
'key': ssh_key,
}
response = requests.post(f"{SIGN_URL}/api/v1/ssh/keys", json=payload, headers=auth_headers())
if response.status_code == 201:
ssh_key_id = response.json().get('data', {}).get('id')
logger.info(f"SSH key {ssh_key_id} added successfully")
if key_path:
private_key_path = key_path[:-4] if key_path.endswith('.pub') else key_path
generate_and_sign_challenge(private_key_path, ssh_key_id)
else:
logger.warning("No key file path available — skipping auto-verification. "
"Run with -k <path> to enable automatic key verification.")
else:
logger.error(f"Failed to add SSH key: {response.status_code} - {response.text}")
def checkCert():
logger.info("Running cert check")
if not os.path.isfile(CERT_FILE_PATH):
logger.warning("Certificate does not exist, new certificate required")
return 1
try:
certificate = SSHCertificate.from_file(CERT_FILE_PATH)
except Exception:
logger.warning("Certificate file is invalid or corrupt, renewal required")
return 1
# Get the current datetime
now = datetime.datetime.now()
logger.debug(certificate
)
# Check if the date is in the past or future
if certificate.get("valid_before") > now:
# Expiry is in the future
if args.force:
return 0
else:
logger.info("You have a valid SSH Certificate with the principals {} expiring at {}, not renewing. Use -f to force renewal".format(certificate.get("principals"), certificate.get("valid_before")))
return 0
else:
logger.warning("Certificate is not valid, renewal required")
return 1
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Sign an SSH key via a web service')
parser.add_argument("-k", "--ssh-key", type=argparse.FileType('rb'), dest="sshkeyfile", help="Add an SSH Public Key to your user profile in gatehouse")
parser.add_argument("-f", "--force", action='store_true', default=False, help="Force the certificate renewal")
parser.add_argument("-a", "--add-key", action='store_true', default=False, help="Add SSH key to the server")
parser.add_argument("-c", "--check-cert", action='store_true', default=False, help="Check the certificate, if it's valid exit 0, if it's invalid exit 1")
parser.add_argument("-r", "--request-cert", action='store_true', default=False, help="Request that gatehouse sign a new certificate for you based on an SSH public key on file in your profile")
parser.add_argument("--clear-cache", action='store_true', default=False, help="Remove the cached authentication token")
parser.add_argument("--remove-key", nargs='?', const='', metavar='KEY_ID', help="Remove an SSH key from your profile. Omit KEY_ID to pick interactively.")
parser.add_argument("--list-keys", action='store_true', default=False, help="List SSH keys in your profile")
args = parser.parse_args()
if not (args.check_cert or args.request_cert or args.add_key or args.clear_cache
or args.remove_key is not None or args.list_keys):
parser.error("At least one of --check-cert, --request-cert, --add-key, --list-keys, --remove-key, or --clear-cache must be provided.")
# Retrieve SSH key from environment variables if not provided via CLI
ssh_key_file = args.sshkeyfile if args.sshkeyfile else os.getenv('SSH_KEY_FILE')
if args.check_cert:
logger.info("Only checking certificate")
exit(checkCert())
if args.clear_cache:
clear_token_cache()
exit(0)
if args.remove_key is not None:
request_token()
remove_ssh_key(args.remove_key if args.remove_key else None)
exit(0)
if args.list_keys:
request_token()
response = requests.get(f"{SIGN_URL}/api/v1/ssh/keys", headers=auth_headers())
if response.status_code == 200:
keys = response.json().get('data', {}).get('keys', [])
if not keys:
print("No SSH keys found in your profile.")
else:
for k in keys:
verified = "✓ verified" if k.get('verified') else "✗ unverified"
print(f" {k['id']} {verified} {k.get('description', '')} (added {k['created_at'][:10]})")
else:
logger.error(f"Failed to list SSH keys: {response.status_code} - {response.text}")
exit(0)
if args.add_key:
request_token()
if not ssh_key_file:
logger.error("SSH key file is required to add SSH key")
exit(1)
# If ssh_key_file is retrieved from the environment, it will be a string (file path), so open it
if isinstance(ssh_key_file, str):
with open(ssh_key_file, 'rb') as f:
ssh_key_file = f.read()
add_ssh_key(ssh_key_file)
exit(0)
if args.request_cert:
request_token()
if args.force:
logger.info("Forcing renewal of certificate")
if args.force or checkCert() == 1:
request_certificate()
exit(0)
+20
View File
@@ -28,6 +28,11 @@ class BaseConfig:
# Encryption key for sensitive data (client secrets, tokens, etc.)
ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY", "dev-encryption-key-change-in-production")
# Encryption key for CA private keys stored in the database.
# Must be set to a strong random secret in production.
# Any string is accepted — it is SHA-256 derived to a 32-byte Fernet key internally.
CA_ENCRYPTION_KEY = os.getenv("CA_ENCRYPTION_KEY", "dev-ca-encryption-key-change-in-production")
# Session configuration for WebAuthn cross-origin support
SESSION_COOKIE_SECURE = os.getenv("SESSION_COOKIE_SECURE", "True").lower() == "true"
@@ -72,6 +77,13 @@ class BaseConfig:
RATELIMIT_STORAGE_URL = os.getenv("RATELIMIT_STORAGE_URL", "redis://localhost:6379/1")
RATELIMIT_DEFAULT = "100/hour"
# Per-endpoint auth rate limits (override via env vars for each environment)
RATELIMIT_AUTH_REGISTER = os.getenv("RATELIMIT_AUTH_REGISTER", "10 per minute; 50 per hour")
RATELIMIT_AUTH_LOGIN = os.getenv("RATELIMIT_AUTH_LOGIN", "20 per minute; 100 per hour")
RATELIMIT_AUTH_TOTP_VERIFY = os.getenv("RATELIMIT_AUTH_TOTP_VERIFY", "20 per minute; 100 per hour")
RATELIMIT_AUTH_FORGOT_PASSWORD = os.getenv("RATELIMIT_AUTH_FORGOT_PASSWORD", "5 per minute; 20 per hour")
RATELIMIT_AUTH_RESET_PASSWORD = os.getenv("RATELIMIT_AUTH_RESET_PASSWORD", "10 per minute; 30 per hour")
# Logging
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
LOG_TO_STDOUT = os.getenv("LOG_TO_STDOUT", "False").lower() == "true"
@@ -116,3 +128,11 @@ class BaseConfig:
# Frontend URL (for OAuth callback redirects)
FRONTEND_URL = os.getenv("FRONTEND_URL", "http://localhost:8080")
# Email / SMTP
EMAIL_ENABLED = os.getenv("EMAIL_ENABLED", "False").lower() == "true"
SMTP_HOST = os.getenv("SMTP_HOST", "smtp.gmail.com")
SMTP_PORT = int(os.getenv("SMTP_PORT", "587"))
SMTP_USERNAME = os.getenv("SMTP_USERNAME", "")
SMTP_PASSWORD = os.getenv("SMTP_PASSWORD", "")
FROM_ADDRESS = os.getenv("FROM_ADDRESS", "noreply@gatehouse.local")
+12
View File
@@ -25,3 +25,15 @@ class DevelopmentConfig(BaseConfig):
"CORS_ORIGINS",
"http://localhost:8080,http://localhost:3000,http://localhost:5173,https://ui.webauthn.local"
).split(",")
# ── Email / SMTP ──────────────────────────────────────────────────────────
# Read from .env so real SMTP credentials work in dev.
# Set EMAIL_ENABLED=false in .env to disable; defaults to True if SMTP_HOST is set.
EMAIL_ENABLED = os.getenv("EMAIL_ENABLED", "True").lower() == "true"
SMTP_HOST = os.getenv("SMTP_HOST", "localhost")
SMTP_PORT = int(os.getenv("SMTP_PORT", "1025"))
SMTP_USERNAME = os.getenv("SMTP_USERNAME") or None
SMTP_PASSWORD = os.getenv("SMTP_PASSWORD") or None
SMTP_USE_TLS = os.getenv("SMTP_USE_TLS", "").lower() == "true" if os.getenv("SMTP_USE_TLS") else int(os.getenv("SMTP_PORT", "1025")) not in (25, 1025)
FROM_ADDRESS = os.getenv("FROM_ADDRESS", "noreply@gatehouse.local")
EMAIL_FROM = FROM_ADDRESS # alias
+3
View File
@@ -12,6 +12,9 @@ class TestingConfig(BaseConfig):
# Explicitly set SECRET_KEY for testing
SECRET_KEY = os.getenv("SECRET_KEY", "test-secret-key-for-testing")
# CA key encryption — use a fixed test key so tests are deterministic
CA_ENCRYPTION_KEY = os.getenv("CA_ENCRYPTION_KEY", "test-ca-encryption-key-fixed-for-tests")
# Use in-memory SQLite for testing
SQLALCHEMY_DATABASE_URI = "sqlite:///:memory:"
SQLALCHEMY_ECHO = False
+30
View File
@@ -0,0 +1,30 @@
[default]
# Certificate validity period (in hours)
cert_validity_hours=8
# Maximum certificate validity allowed (in hours)
max_cert_validity_hours=720
# CA private key path (required for local encryption mode)
ca_key_path=
# Certificate Field Limits
max_principals_per_cert=256
max_key_id_length=255
# Verification challenge max age (in hours)
verification_challenge_max_age=24
# Cleanup: delete unverified SSH keys after this many days
auto_delete_unverified_days=30
[development]
ca_key_path=${SSH_CA_KEY_PATH}
cert_validity_hours=24
[production]
cert_validity_hours=8
[testing]
cert_validity_hours=8
+6
View File
@@ -2,6 +2,9 @@
import os
import logging
from dotenv import load_dotenv
load_dotenv(dotenv_path=os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), '.env'))
# Test debug logging - this should appear when running `flask run --debug`
_root_logger = logging.getLogger(__name__)
_root_logger.debug("[TEST] Debug logging is working!")
@@ -239,3 +242,6 @@ def initialize_oidc_jwks(app):
app.logger.info(f"[OIDC] Signing key initialized: kid={signing_key.kid}")
except Exception as e:
app.logger.error(f"[OIDC] Failed to initialize JWKS: {e}")
# Create default app instance for gunicorn/wsgi
app = create_app()
+71 -4
View File
@@ -21,8 +21,12 @@ from gatehouse_app.extensions import db
from gatehouse_app.extensions import bcrypt as flask_bcrypt
from gatehouse_app.extensions import redis_client as _redis_client_ref # may be None until app init
from gatehouse_app.models import User, OIDCClient
from gatehouse_app.models.organization import Organization
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.exceptions.auth_exceptions import (
InvalidCredentialsError,
AccountSuspendedError,
AccountInactiveError,
)
# ---------------------------------------------------------------------------
# Helpers for Redis-backed OIDC pending state
@@ -326,7 +330,7 @@ def oidc_complete():
400: invalid request
401: invalid token
"""
from gatehouse_app.models.session import Session as GHSession
from gatehouse_app.models.user.session import Session as GHSession
from gatehouse_app.utils.constants import SessionStatus
data = request.get_json(silent=True) or {}
@@ -343,6 +347,20 @@ def oidc_complete():
user_id = str(gh_session.user_id)
# Check the user is still active (not suspended after session was issued)
from gatehouse_app.models.user.user import User as _User
from gatehouse_app.utils.constants import UserStatus
_complete_user = _User.query.filter_by(id=user_id, deleted_at=None).first()
if not _complete_user or _complete_user.status in (
UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED, UserStatus.INACTIVE
):
return api_response(
success=False,
message="Your account is not active or has been suspended.",
status=403,
error_type="ACCOUNT_SUSPENDED",
)
# Retrieve stashed OIDC params (consume = True removes from Redis atomically)
params = _fetch_oidc_params(oidc_session_id, consume=True)
if not params:
@@ -565,6 +583,28 @@ def oidc_authorize():
session["oidc_user_id"] = user_id
logger.debug("[OIDC] User authentication successful: user_id=%s, email=%s", user_id, email)
except AccountSuspendedError:
logger.debug("[OIDC] User authentication failed: account suspended for email=%s", email)
return _show_login_page(
client_id=client_id,
redirect_uri=redirect_uri,
scope=scope,
state=state,
nonce=nonce,
response_type=response_type,
error="Your account has been suspended. Please contact an administrator.",
)
except AccountInactiveError:
logger.debug("[OIDC] User authentication failed: account inactive for email=%s", email)
return _show_login_page(
client_id=client_id,
redirect_uri=redirect_uri,
scope=scope,
state=state,
nonce=nonce,
response_type=response_type,
error="Your account is not active. Please verify your email.",
)
except InvalidCredentialsError:
logger.debug("[OIDC] User authentication failed: invalid credentials for email=%s", email)
return _show_login_page(
@@ -600,7 +640,34 @@ def oidc_authorize():
if not user:
logger.debug("[OIDC] Redirecting with error: server_error (user not found)")
return _redirect_with_error(redirect_uri, "server_error", "User not found", state)
# Check account is still active (user could have been suspended after session start)
from gatehouse_app.utils.constants import UserStatus as _UserStatus
if user.status in (_UserStatus.SUSPENDED, _UserStatus.COMPLIANCE_SUSPENDED):
session.pop("oidc_user_id", None) # clear stale session
logger.debug("[OIDC] User is suspended, clearing session and showing login error: user_id=%s", user_id)
return _show_login_page(
client_id=client_id,
redirect_uri=redirect_uri,
scope=scope,
state=state,
nonce=nonce,
response_type=response_type,
error="Your account has been suspended. Please contact an administrator.",
)
if user.status == _UserStatus.INACTIVE:
session.pop("oidc_user_id", None)
logger.debug("[OIDC] User is inactive, clearing session and showing login error: user_id=%s", user_id)
return _show_login_page(
client_id=client_id,
redirect_uri=redirect_uri,
scope=scope,
state=state,
nonce=nonce,
response_type=response_type,
error="Your account is not active. Please verify your email.",
)
logger.debug("[OIDC] Generating authorization code...")
logger.debug("[OIDC] Authorization code params: client_id=%s, user_id=%s, redirect_uri=%s", client_id, user_id, redirect_uri)
logger.debug("[OIDC] Authorization code params: scopes=%s, state=%s, nonce=%s", valid_scopes, state, nonce)
+3 -1
View File
@@ -5,4 +5,6 @@ from flask import Blueprint
api_v1_bp = Blueprint("api_v1", __name__)
# Import route modules to register them
from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth
from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth, departments, principals, ssh
api_v1_bp.register_blueprint(ssh.ssh_bp)
+557 -10
View File
@@ -1,8 +1,10 @@
"""Authentication endpoints."""
import json
from flask import request, session, g, jsonify
import logging
from flask import request, session, g, jsonify, current_app
from marshmallow import ValidationError
from gatehouse_app.api.v1 import api_v1_bp
from gatehouse_app.extensions import limiter
from gatehouse_app.utils.response import api_response
from gatehouse_app.schemas.auth_schema import (
RegisterSchema,
@@ -23,6 +25,7 @@ from gatehouse_app.services.auth_service import AuthService
from gatehouse_app.services.webauthn_service import WebAuthnService
from gatehouse_app.services.user_service import UserService
from gatehouse_app.services.mfa_policy_service import MfaPolicyService
from gatehouse_app.services.notification_service import NotificationService
from gatehouse_app.utils.decorators import login_required
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
@@ -30,6 +33,7 @@ from gatehouse_app.exceptions.validation_exceptions import ConflictError, NotFou
@api_v1_bp.route("/auth/register", methods=["POST"])
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_REGISTER"])
def register():
"""
Register a new user.
@@ -57,14 +61,66 @@ def register():
full_name=data.get("full_name"),
)
# Send verification email
try:
from gatehouse_app.models import EmailVerificationToken
verify_token = EmailVerificationToken.generate(user_id=user.id)
app_url = current_app.config.get("APP_URL", "http://localhost:8080")
verify_link = f"{app_url}/verify-email?token={verify_token.token}"
subject = "Verify your Gatehouse email address"
body = (
f"Hi {user.full_name or user.email},\n\n"
f"Welcome to Gatehouse! Please verify your email address by clicking the link below (valid for 24 hours):\n"
f"{verify_link}\n\n"
f"Gatehouse Security Team"
)
NotificationService._send_email(to_address=user.email, subject=subject, body=body)
except Exception as exc:
logging.getLogger(__name__).warning(f"Failed to send verification email on register: {exc}")
# Create session
user_session = AuthService.create_session(user)
# ── Post-registration hints ─────────────────────────────────────────
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
from gatehouse_app.models.user.user import User as _User
from datetime import datetime, timezone as _tz
now = datetime.now(_tz.utc)
pending_invites = OrgInviteToken.query.filter(
OrgInviteToken.email == user.email,
OrgInviteToken.accepted_at.is_(None),
OrgInviteToken.expires_at > now,
OrgInviteToken.deleted_at.is_(None),
).all()
# Determine if this is the very first user ever registered on this
# instance (exactly 1 active user means it must be this one).
total_users = _User.query.filter(_User.deleted_at.is_(None)).count()
is_first_user = total_users == 1
expires_str = user_session.expires_at.isoformat()
if expires_str[-1] != "Z":
expires_str += "Z"
return api_response(
data={
"user": user.to_dict(),
"token": user_session.token,
"expires_at": user_session.expires_at.isoformat() + "Z" if user_session.expires_at.isoformat()[-1] != "Z" else user_session.expires_at.isoformat(),
"expires_at": expires_str,
"is_first_user": is_first_user,
"pending_invites": [
{
"token": inv.token,
"organization": {
"id": str(inv.organization_id),
"name": inv.organization.name,
},
"role": inv.role,
"expires_at": inv.expires_at.isoformat(),
}
for inv in pending_invites
],
},
message="Registration successful",
status=201,
@@ -81,6 +137,7 @@ def register():
@api_v1_bp.route("/auth/login", methods=["POST"])
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_LOGIN"])
def login():
"""
Login user.
@@ -179,7 +236,9 @@ def login():
"organization_id": org.organization_id,
"organization_name": org.organization_name,
"status": org.status,
"effective_mode": org.effective_mode,
"deadline_at": org.deadline_at,
"applied_at": org.applied_at,
}
for org in policy_result.compliance_summary.orgs
],
@@ -189,6 +248,36 @@ def login():
if is_compliance_only:
response_data["requires_mfa_enrollment"] = True
# ── Org-setup hint for org-less users ────────────────────────────────
# If the user has no organisation memberships, surface any pending
# invitations so the UI can redirect straight to /org-setup instead of
# showing an empty dashboard.
user_orgs = user.get_organizations()
if not user_orgs:
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
from datetime import datetime, timezone as _tz
_now = datetime.now(_tz.utc)
pending_invites = OrgInviteToken.query.filter(
OrgInviteToken.email == user.email,
OrgInviteToken.accepted_at.is_(None),
OrgInviteToken.expires_at > _now,
OrgInviteToken.deleted_at.is_(None),
).all()
response_data["pending_invites"] = [
{
"token": inv.token,
"organization": {
"id": str(inv.organization_id),
"name": inv.organization.name,
},
"role": inv.role,
"expires_at": inv.expires_at.isoformat(),
}
for inv in pending_invites
]
# Flag so the UI knows to send this user through org-setup
response_data["requires_org_setup"] = True
return api_response(
data=response_data,
message="Login successful",
@@ -239,8 +328,13 @@ def get_current_user():
data={
"user": user.to_dict(),
"organizations": [
{"id": org.id, "name": org.name, "slug": org.slug}
for org in user.get_organizations()
{
"id": membership.organization.id,
"name": membership.organization.name,
"slug": membership.organization.slug,
"role": membership.role,
}
for membership in user.organization_memberships
],
},
message="User retrieved successfully",
@@ -284,7 +378,7 @@ def revoke_session(session_id):
401: Not authenticated
404: Session not found
"""
from gatehouse_app.models.session import Session
from gatehouse_app.models.user.session import Session
# Ensure session belongs to current user
user_session = Session.query.filter_by(
@@ -392,6 +486,7 @@ def verify_totp_enrollment():
@api_v1_bp.route("/auth/totp/verify", methods=["POST"])
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_TOTP_VERIFY"])
def verify_totp():
"""
Verify TOTP code during login.
@@ -424,7 +519,7 @@ def verify_totp():
)
# Get user from database
from gatehouse_app.models.user import User
from gatehouse_app.models.user.user import User
user = User.query.get(user_id)
if not user:
return api_response(
@@ -434,6 +529,18 @@ def verify_totp():
error_type="AUTHENTICATION_ERROR",
)
# Check account suspension before completing TOTP verification
from gatehouse_app.utils.constants import UserStatus
if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
session.pop("totp_pending_user_id", None)
session.pop("webauthn_pending_user_id", None)
return api_response(
success=False,
message="Account is suspended. Contact an administrator.",
status=403,
error_type="ACCOUNT_SUSPENDED",
)
# Verify TOTP code
AuthService.authenticate_with_totp(
user,
@@ -475,7 +582,9 @@ def verify_totp():
"organization_id": org.organization_id,
"organization_name": org.organization_name,
"status": org.status,
"effective_mode": org.effective_mode,
"deadline_at": org.deadline_at,
"applied_at": org.applied_at,
}
for org in policy_result.compliance_summary.orgs
],
@@ -806,7 +915,7 @@ def begin_webauthn_login():
data = schema.load(request.json)
# Find user by email
from gatehouse_app.models.user import User
from gatehouse_app.models.user.user import User
user = User.query.filter_by(
email=data["email"].lower(),
deleted_at=None
@@ -820,7 +929,18 @@ def begin_webauthn_login():
status=404,
error_type="NOT_FOUND",
)
# Check account suspension before proceeding
from gatehouse_app.utils.constants import UserStatus
if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
logger.warning(f"WebAuthn login begin - suspended account attempt: {user.email}")
return api_response(
success=False,
message="Account is suspended. Contact an administrator.",
status=403,
error_type="ACCOUNT_SUSPENDED",
)
# Check if user has any WebAuthn credentials
if not user.has_webauthn_enabled():
logger.warning(f"WebAuthn login begin - no credentials for user: {user.email}")
@@ -893,7 +1013,7 @@ def complete_webauthn_login():
data = schema.load(request.json)
# Get user from database
from gatehouse_app.models.user import User
from gatehouse_app.models.user.user import User
user = User.query.get(user_id)
if not user:
logger.error(f"WebAuthn login complete - user not found: {user_id}")
@@ -903,7 +1023,19 @@ def complete_webauthn_login():
status=401,
error_type="AUTHENTICATION_ERROR",
)
# Check account suspension before completing login
from gatehouse_app.utils.constants import UserStatus
if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
session.pop("webauthn_pending_user_id", None)
logger.warning(f"WebAuthn login complete - suspended account attempt: {user.email}")
return api_response(
success=False,
message="Account is suspended. Contact an administrator.",
status=403,
error_type="ACCOUNT_SUSPENDED",
)
# Extract challenge from client data
client_data = data.get("response", {}).get("clientDataJSON", "")
@@ -962,7 +1094,9 @@ def complete_webauthn_login():
"organization_id": org.organization_id,
"organization_name": org.organization_name,
"status": org.status,
"effective_mode": org.effective_mode,
"deadline_at": org.deadline_at,
"applied_at": org.applied_at,
}
for org in policy_result.compliance_summary.orgs
],
@@ -1039,6 +1173,19 @@ def delete_webauthn_credential(credential_id):
"""
user = g.current_user
# First check that the specific credential actually belongs to this user.
# Only then check whether it is the last one — otherwise a user with zero
# credentials gets a misleading "Cannot delete the last passkey" error
# instead of a 404.
credential_exists = WebAuthnService.credential_belongs_to_user(credential_id, user)
if not credential_exists:
return api_response(
success=False,
message="Credential not found",
status=404,
error_type="NOT_FOUND",
)
# Check if this is the last credential
credential_count = user.get_webauthn_credential_count()
if credential_count <= 1:
@@ -1142,3 +1289,403 @@ def get_webauthn_status():
},
message="WebAuthn status retrieved successfully",
)
_pw_logger = logging.getLogger(__name__)
@api_v1_bp.route("/auth/forgot-password", methods=["POST"])
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_FORGOT_PASSWORD"])
def forgot_password():
"""Request a password reset email.
Always returns 200 to avoid leaking account existence.
Request body:
email: User email address
Returns:
200: Password reset email sent (or silently no-op if email not found)
"""
from gatehouse_app.models import User, PasswordResetToken
data = request.get_json() or {}
email = (data.get("email") or "").strip().lower()
if not email:
return api_response(
success=False,
message="Email is required",
status=400,
error_type="VALIDATION_ERROR",
)
# Always return 200 — don't leak whether the email exists
user = User.query.filter_by(email=email, deleted_at=None).first()
if user:
try:
reset_token = PasswordResetToken.generate(user_id=user.id)
app_url = current_app.config.get("APP_URL", "http://localhost:8080")
reset_link = f"{app_url}/reset-password?token={reset_token.token}"
subject = "Reset your Gatehouse password"
body = (
f"Hi {user.full_name or user.email},\n\n"
f"You requested a password reset for your Gatehouse account.\n\n"
f"Click the link below to reset your password (valid for 2 hours):\n"
f"{reset_link}\n\n"
f"If you did not request this, you can safely ignore this email.\n\n"
f"Gatehouse Security Team"
)
NotificationService._send_email(
to_address=user.email,
subject=subject,
body=body,
)
_pw_logger.info(f"Password reset token generated for user {user.id}")
except Exception as exc:
_pw_logger.exception(f"Error generating password reset token: {exc}")
return api_response(
data={},
message="If an account exists for this email, you will receive a password reset link shortly.",
)
@api_v1_bp.route("/auth/reset-password", methods=["POST"])
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_RESET_PASSWORD"])
def reset_password():
"""Reset a user's password using a reset token.
Request body:
token: Password reset token from email
password: New password
password_confirm: Password confirmation
Returns:
200: Password reset successfully
400: Invalid or expired token / validation error
"""
import bcrypt as _bcrypt
from gatehouse_app.extensions import bcrypt
from gatehouse_app.models import PasswordResetToken, AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
data = request.get_json() or {}
token_value = (data.get("token") or "").strip()
new_password = data.get("password") or ""
password_confirm = data.get("password_confirm") or ""
if not token_value or not new_password:
return api_response(
success=False,
message="Token and new password are required",
status=400,
error_type="VALIDATION_ERROR",
)
if new_password != password_confirm:
return api_response(
success=False,
message="Passwords do not match",
status=400,
error_type="VALIDATION_ERROR",
)
if len(new_password) < 8:
return api_response(
success=False,
message="Password must be at least 8 characters",
status=400,
error_type="VALIDATION_ERROR",
)
reset_token = PasswordResetToken.query.filter_by(token=token_value).first()
if not reset_token or not reset_token.is_valid:
return api_response(
success=False,
message="This password reset link is invalid or has expired.",
status=400,
error_type="INVALID_TOKEN",
)
try:
user = reset_token.user
# Update the password hash on the authentication method
auth_method = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.PASSWORD,
deleted_at=None,
).first()
if auth_method:
auth_method.password_hash = bcrypt.generate_password_hash(new_password).decode("utf-8")
from gatehouse_app.extensions import db
db.session.add(auth_method)
reset_token.consume()
_pw_logger.info(f"Password reset for user {user.id}")
return api_response(
data={},
message="Your password has been reset. You can now sign in with your new password.",
)
except Exception as exc:
_pw_logger.exception(f"Error resetting password: {exc}")
return api_response(
success=False,
message="An error occurred while resetting your password.",
status=500,
error_type="INTERNAL_ERROR",
)
@api_v1_bp.route("/auth/verify-email", methods=["POST"])
def verify_email():
"""Verify a user's email address using a verification token.
Request body:
token: Email verification token
Returns:
200: Email verified successfully
400: Invalid or expired token
"""
from gatehouse_app.models import EmailVerificationToken
data = request.get_json() or {}
token_value = (data.get("token") or "").strip()
if not token_value:
return api_response(
success=False,
message="Verification token is required",
status=400,
error_type="VALIDATION_ERROR",
)
verify_token = EmailVerificationToken.query.filter_by(token=token_value).first()
if not verify_token or not verify_token.is_valid:
return api_response(
success=False,
message="This verification link is invalid or has expired.",
status=400,
error_type="INVALID_TOKEN",
)
try:
user = verify_token.user
user.email_verified = True
from gatehouse_app.extensions import db
db.session.add(user)
verify_token.consume()
_pw_logger.info(f"Email verified for user {user.id}")
return api_response(
data={},
message="Your email has been verified. You can now sign in.",
)
except Exception as exc:
_pw_logger.exception(f"Error verifying email: {exc}")
return api_response(
success=False,
message="An error occurred while verifying your email.",
status=500,
error_type="INTERNAL_ERROR",
)
@api_v1_bp.route("/auth/resend-verification", methods=["POST"])
def resend_verification():
"""Resend email verification link.
Always returns 200 to avoid leaking account existence.
Request body:
email: User email address
Returns:
200: Verification email sent (or silently no-op)
"""
from gatehouse_app.models import User, EmailVerificationToken
data = request.get_json() or {}
email = (data.get("email") or "").strip().lower()
if not email:
return api_response(
success=False,
message="Email is required",
status=400,
error_type="VALIDATION_ERROR",
)
user = User.query.filter_by(email=email, deleted_at=None).first()
if user and not user.email_verified:
try:
verify_token = EmailVerificationToken.generate(user_id=user.id)
app_url = current_app.config.get("APP_URL", "http://localhost:8080")
verify_link = f"{app_url}/verify-email?token={verify_token.token}"
subject = "Verify your Gatehouse email address"
body = (
f"Hi {user.full_name or user.email},\n\n"
f"Please verify your email address by clicking the link below (valid for 24 hours):\n"
f"{verify_link}\n\n"
f"Gatehouse Security Team"
)
NotificationService._send_email(
to_address=user.email,
subject=subject,
body=body,
)
_pw_logger.info(f"Verification email sent for user {user.id}")
except Exception as exc:
_pw_logger.exception(f"Error sending verification email: {exc}")
return api_response(
data={},
message="If an account exists for this email and is not yet verified, you will receive a verification link shortly.",
)
# =============================================================================
# Account Activation (separate from email-verification)
# =============================================================================
@api_v1_bp.route("/auth/activate", methods=["POST"])
def activate_account():
"""Activate a user account via a one-time activation code.
Request body:
code the activation_key from the welcome email
Returns:
200: Account activated, session token returned
400: Missing code
404: Invalid or already-used code
"""
import secrets
from gatehouse_app.models.user.user import User
from gatehouse_app.extensions import db
data = request.get_json() or {}
code = (data.get("code") or "").strip()
if not code:
return api_response(success=False, message="Activation code is required", status=400, error_type="VALIDATION_ERROR")
user = User.query.filter_by(activation_key=code, deleted_at=None).first()
if not user:
return api_response(success=False, message="Invalid or expired activation code", status=404, error_type="NOT_FOUND")
user.activated = True
user.activation_key = None # one-time use
db.session.add(user)
db.session.commit()
user_session = AuthService.create_session(user)
_pw_logger.info(f"Account activated for user {user.id}")
return api_response(
data={
"user": user.to_dict(),
"token": user_session.token,
"expires_at": user_session.expires_at.isoformat() + "Z"
if user_session.expires_at.isoformat()[-1] != "Z"
else user_session.expires_at.isoformat(),
},
message="Account activated successfully",
)
@api_v1_bp.route("/auth/resend-activation", methods=["POST"])
def resend_activation():
"""Re-send an account activation email.
Always returns 200 to avoid leaking whether an account exists.
Request body:
email user email address
"""
import secrets
from gatehouse_app.models.user.user import User
from gatehouse_app.extensions import db
data = request.get_json() or {}
email = (data.get("email") or "").strip().lower()
if not email:
return api_response(success=False, message="Email is required", status=400, error_type="VALIDATION_ERROR")
user = User.query.filter_by(email=email, deleted_at=None).first()
if user and not user.activated:
try:
code = secrets.token_urlsafe(32)
user.activation_key = code
db.session.add(user)
db.session.commit()
app_url = current_app.config.get("APP_URL", current_app.config.get("FRONTEND_URL", "http://localhost:8080"))
activate_link = f"{app_url}/activate?code={code}"
subject = "Activate your Gatehouse account"
body = (
f"Hi {user.full_name or user.email},\n\n"
f"Please activate your Gatehouse account by clicking the link below:\n"
f"{activate_link}\n\n"
f"If you did not create an account, you can safely ignore this email.\n\n"
f"Gatehouse Security Team"
)
NotificationService._send_email(to_address=user.email, subject=subject, body=body)
_pw_logger.info(f"Activation email re-sent to {user.id}")
except Exception as exc:
_pw_logger.exception(f"Error re-sending activation email: {exc}")
return api_response(
data={},
message="If an unactivated account exists for this email, you will receive a new activation link shortly.",
)
# =============================================================================
# Token retrieval / redirect (for CLI / external tools)
# =============================================================================
@api_v1_bp.route("/auth/token", methods=["GET"])
@login_required
def get_token():
"""Return the current session token, optionally redirecting to a URL.
Query parameters:
redirect optional URL to redirect to with the token appended as
a query param: ``<redirect>?token=<token>``
Returns:
200: JSON ``{"token": "<token>"}`` (no redirect given)
302: Redirect to ``<redirect>?token=<token>``
"""
from flask import redirect as flask_redirect
from urllib.parse import urlparse
token = g.current_session.token
redirect_url = request.args.get("redirect", "").strip()
if redirect_url:
# Validate redirect URL against allowed origins to prevent open-redirect
# token exfiltration attacks (CWE-601).
allowed_origins = set(current_app.config.get("CORS_ORIGINS", []))
frontend_url = current_app.config.get("FRONTEND_URL", "")
if frontend_url:
parsed = urlparse(frontend_url)
allowed_origins.add(f"{parsed.scheme}://{parsed.netloc}")
parsed_redirect = urlparse(redirect_url)
redirect_origin = f"{parsed_redirect.scheme}://{parsed_redirect.netloc}"
if redirect_origin not in allowed_origins:
return api_response(
success=False,
message="Redirect URL is not allowed.",
status=400,
error_type="INVALID_REDIRECT",
)
sep = "&" if "?" in redirect_url else "?"
return flask_redirect(f"{redirect_url}{sep}token={token}", code=302)
return api_response(data={"token": token}, message="Token retrieved")
+699
View File
@@ -0,0 +1,699 @@
"""Department endpoints."""
from flask import g, request
from marshmallow import Schema, fields, validate, ValidationError
from sqlalchemy.orm.attributes import flag_modified
from gatehouse_app.api.v1 import api_v1_bp
from gatehouse_app.utils.response import api_response
from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required
from gatehouse_app.models import Department, DepartmentMembership
from gatehouse_app.services.organization_service import OrganizationService
from gatehouse_app.services.user_service import UserService
from gatehouse_app.extensions import db
class DepartmentCreateSchema(Schema):
"""Schema for creating a department."""
name = fields.Str(required=True, validate=validate.Length(min=1, max=255))
description = fields.Str(allow_none=True, validate=validate.Length(max=2000))
class DepartmentUpdateSchema(Schema):
"""Schema for updating a department."""
name = fields.Str(validate=validate.Length(min=1, max=255))
description = fields.Str(allow_none=True, validate=validate.Length(max=2000))
class AddDepartmentMemberSchema(Schema):
"""Schema for adding a member to a department."""
email = fields.Email(required=True)
@api_v1_bp.route("/organizations/<org_id>/departments", methods=["GET"])
@login_required
@full_access_required
def list_departments(org_id):
"""
List all departments in an organization.
Args:
org_id: Organization ID
Returns:
200: List of departments
401: Not authenticated
403: Not a member
404: Organization not found
"""
org = OrganizationService.get_organization_by_id(org_id)
# Check if user is a member
if not org.is_member(g.current_user.id):
return api_response(
success=False,
message="You are not a member of this organization",
status=403,
error_type="AUTHORIZATION_ERROR",
)
departments = Department.query.filter_by(
organization_id=org_id,
deleted_at=None
).all()
return api_response(
data={
"departments": [d.to_dict() for d in departments],
"count": len(departments),
},
message="Departments retrieved successfully",
)
@api_v1_bp.route("/organizations/<org_id>/departments", methods=["POST"])
@login_required
@require_admin
@full_access_required
def create_department(org_id):
"""
Create a new department.
Args:
org_id: Organization ID
Request body:
name: Department name (required)
description: Optional description
Returns:
201: Department created successfully
400: Validation error
401: Not authenticated
403: Not an admin
404: Organization not found
409: Department name already exists in org
"""
try:
org = OrganizationService.get_organization_by_id(org_id)
schema = DepartmentCreateSchema()
data = schema.load(request.json or {})
# Check if department name already exists
existing = Department.query.filter_by(
organization_id=org_id,
name=data["name"],
deleted_at=None
).first()
if existing:
return api_response(
success=False,
message=f"Department '{data['name']}' already exists in this organization",
status=409,
error_type="CONFLICT",
)
# Create department
dept = Department(
organization_id=org_id,
name=data["name"],
description=data.get("description"),
)
db.session.add(dept)
db.session.commit()
return api_response(
data={"department": dept.to_dict()},
message="Department created successfully",
status=201,
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
@api_v1_bp.route("/organizations/<org_id>/departments/<dept_id>", methods=["GET"])
@login_required
@full_access_required
def get_department(org_id, dept_id):
"""
Get a specific department.
Args:
org_id: Organization ID
dept_id: Department ID
Returns:
200: Department data
401: Not authenticated
403: Not a member
404: Organization or department not found
"""
org = OrganizationService.get_organization_by_id(org_id)
if not org.is_member(g.current_user.id):
return api_response(
success=False,
message="You are not a member of this organization",
status=403,
error_type="AUTHORIZATION_ERROR",
)
dept = Department.query.filter_by(
id=dept_id,
organization_id=org_id,
deleted_at=None
).first()
if not dept:
return api_response(
success=False,
message="Department not found",
status=404,
error_type="NOT_FOUND",
)
return api_response(
data={"department": dept.to_dict()},
message="Department retrieved successfully",
)
@api_v1_bp.route("/organizations/<org_id>/departments/<dept_id>", methods=["PATCH"])
@login_required
@require_admin
@full_access_required
def update_department(org_id, dept_id):
"""
Update a department.
Args:
org_id: Organization ID
dept_id: Department ID
Request body:
name: Optional new name
description: Optional new description
Returns:
200: Department updated successfully
400: Validation error
401: Not authenticated
403: Not an admin
404: Organization or department not found
409: Name already exists
"""
try:
org = OrganizationService.get_organization_by_id(org_id)
dept = Department.query.filter_by(
id=dept_id,
organization_id=org_id,
deleted_at=None
).first()
if not dept:
return api_response(
success=False,
message="Department not found",
status=404,
error_type="NOT_FOUND",
)
schema = DepartmentUpdateSchema()
data = schema.load(request.json or {})
# Check if new name already exists
if "name" in data and data["name"] != dept.name:
existing = Department.query.filter_by(
organization_id=org_id,
name=data["name"],
deleted_at=None
).first()
if existing:
return api_response(
success=False,
message=f"Department '{data['name']}' already exists",
status=409,
error_type="CONFLICT",
)
# Update fields
for key, value in data.items():
setattr(dept, key, value)
db.session.commit()
return api_response(
data={"department": dept.to_dict()},
message="Department updated successfully",
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
@api_v1_bp.route("/organizations/<org_id>/departments/<dept_id>", methods=["DELETE"])
@login_required
@require_admin
@full_access_required
def delete_department(org_id, dept_id):
"""
Delete a department (soft delete).
Args:
org_id: Organization ID
dept_id: Department ID
Returns:
200: Department deleted successfully
401: Not authenticated
403: Not an admin
404: Organization or department not found
"""
org = OrganizationService.get_organization_by_id(org_id)
dept = Department.query.filter_by(
id=dept_id,
organization_id=org_id,
deleted_at=None
).first()
if not dept:
return api_response(
success=False,
message="Department not found",
status=404,
error_type="NOT_FOUND",
)
# Soft delete
dept.deleted_at = db.func.now()
db.session.commit()
return api_response(
message="Department deleted successfully",
)
@api_v1_bp.route("/organizations/<org_id>/departments/<dept_id>/members", methods=["GET"])
@login_required
@full_access_required
def get_department_members(org_id, dept_id):
"""
Get all members of a department.
Args:
org_id: Organization ID
dept_id: Department ID
Returns:
200: List of members
401: Not authenticated
403: Not a member
404: Organization or department not found
"""
org = OrganizationService.get_organization_by_id(org_id)
if not org.is_member(g.current_user.id):
return api_response(
success=False,
message="You are not a member of this organization",
status=403,
error_type="AUTHORIZATION_ERROR",
)
dept = Department.query.filter_by(
id=dept_id,
organization_id=org_id,
deleted_at=None
).first()
if not dept:
return api_response(
success=False,
message="Department not found",
status=404,
error_type="NOT_FOUND",
)
members = DepartmentMembership.query.filter_by(
department_id=dept_id,
deleted_at=None
).all()
members_data = []
for member in members:
member_dict = member.to_dict()
member_dict["user"] = member.user.to_dict()
members_data.append(member_dict)
return api_response(
data={
"members": members_data,
"count": len(members_data),
},
message="Members retrieved successfully",
)
@api_v1_bp.route("/organizations/<org_id>/departments/<dept_id>/members", methods=["POST"])
@login_required
@require_admin
@full_access_required
def add_department_member(org_id, dept_id):
"""
Add a member to a department.
Args:
org_id: Organization ID
dept_id: Department ID
Request body:
email: User email to add
Returns:
201: Member added successfully
400: Validation error
401: Not authenticated
403: Not an admin
404: Organization, department, or user not found
409: User already a member
"""
try:
org = OrganizationService.get_organization_by_id(org_id)
dept = Department.query.filter_by(
id=dept_id,
organization_id=org_id,
deleted_at=None
).first()
if not dept:
return api_response(
success=False,
message="Department not found",
status=404,
error_type="NOT_FOUND",
)
schema = AddDepartmentMemberSchema()
data = schema.load(request.json or {})
# Find user by email
user = UserService.get_user_by_email(data["email"])
if not user:
return api_response(
success=False,
message="User not found",
status=404,
error_type="NOT_FOUND",
)
# Check if already an active member
existing = DepartmentMembership.query.filter_by(
user_id=user.id,
department_id=dept_id,
deleted_at=None
).first()
if existing:
return api_response(
success=False,
message="User is already a member of this department",
status=409,
error_type="CONFLICT",
)
# Check for a previously soft-deleted row and resurrect it instead of inserting
soft_deleted = DepartmentMembership.query.filter(
DepartmentMembership.user_id == user.id,
DepartmentMembership.department_id == dept_id,
DepartmentMembership.deleted_at.isnot(None)
).first()
if soft_deleted:
soft_deleted.deleted_at = None
membership = soft_deleted
else:
membership = DepartmentMembership(
user_id=user.id,
department_id=dept_id,
)
db.session.add(membership)
db.session.commit()
member_dict = membership.to_dict()
member_dict["user"] = user.to_dict()
return api_response(
data={"member": member_dict},
message="Member added successfully",
status=201,
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
@api_v1_bp.route("/organizations/<org_id>/departments/<dept_id>/members/<user_id>", methods=["DELETE"])
@login_required
@require_admin
@full_access_required
def remove_department_member(org_id, dept_id, user_id):
"""
Remove a member from a department.
Args:
org_id: Organization ID
dept_id: Department ID
user_id: User ID to remove
Returns:
200: Member removed successfully
401: Not authenticated
403: Not an admin
404: Organization, department, or member not found
"""
org = OrganizationService.get_organization_by_id(org_id)
dept = Department.query.filter_by(
id=dept_id,
organization_id=org_id,
deleted_at=None
).first()
if not dept:
return api_response(
success=False,
message="Department not found",
status=404,
error_type="NOT_FOUND",
)
membership = DepartmentMembership.query.filter_by(
user_id=user_id,
department_id=dept_id,
deleted_at=None
).first()
if not membership:
return api_response(
success=False,
message="User is not a member of this department",
status=404,
error_type="NOT_FOUND",
)
# Soft delete
membership.deleted_at = db.func.now()
db.session.commit()
return api_response(
message="Member removed successfully",
)
@api_v1_bp.route("/organizations/<org_id>/departments/<dept_id>/principals", methods=["GET"])
@login_required
@full_access_required
def get_department_principals(org_id, dept_id):
"""Get all principals linked to a department."""
org = OrganizationService.get_organization_by_id(org_id)
if not org.is_member(g.current_user.id):
return api_response(
success=False,
message="You are not a member of this organization",
status=403,
error_type="AUTHORIZATION_ERROR",
)
dept = Department.query.filter_by(
id=dept_id,
organization_id=org_id,
deleted_at=None
).first()
if not dept:
return api_response(
success=False,
message="Department not found",
status=404,
error_type="NOT_FOUND",
)
principals = dept.get_principals(active_only=True)
return api_response(
data={
"principals": [p.to_dict() for p in principals],
"count": len(principals),
},
message="Principals retrieved successfully",
)
# ---------------------------------------------------------------------------
# Department Certificate Policy
# ---------------------------------------------------------------------------
@api_v1_bp.route("/organizations/<org_id>/departments/<dept_id>/cert-policy", methods=["GET"])
@login_required
@require_admin
@full_access_required
def get_dept_cert_policy(org_id, dept_id):
"""Get the certificate issuance policy for a department (admin only)."""
from gatehouse_app.models.organization.department_cert_policy import DepartmentCertPolicy, STANDARD_EXTENSIONS
dept = Department.query.filter_by(
id=dept_id, organization_id=org_id, deleted_at=None
).first()
if not dept:
return api_response(success=False, message="Department not found", status=404, error_type="NOT_FOUND")
policy = DepartmentCertPolicy.query.filter(
DepartmentCertPolicy.department_id == dept_id,
DepartmentCertPolicy.deleted_at.is_(None),
).first()
if policy:
data = policy.to_dict()
else:
# Return default (all standard extensions, no user expiry choice)
data = {
"department_id": str(dept_id),
"allow_user_expiry": False,
"default_expiry_hours": 1,
"max_expiry_hours": 24,
"allowed_extensions": list(STANDARD_EXTENSIONS),
"custom_extensions": [],
"all_extensions": list(STANDARD_EXTENSIONS),
"standard_extensions": list(STANDARD_EXTENSIONS),
}
return api_response(data={"cert_policy": data}, message="Certificate policy retrieved")
@api_v1_bp.route("/organizations/<org_id>/departments/<dept_id>/cert-policy", methods=["PUT"])
@login_required
@require_admin
@full_access_required
def set_dept_cert_policy(org_id, dept_id):
"""Create or update the certificate issuance policy for a department (admin only)."""
from gatehouse_app.models.organization.department_cert_policy import DepartmentCertPolicy, STANDARD_EXTENSIONS
dept = Department.query.filter_by(
id=dept_id, organization_id=org_id, deleted_at=None
).first()
if not dept:
return api_response(success=False, message="Department not found", status=404, error_type="NOT_FOUND")
body = request.get_json() or {}
# Validate expiry values
default_expiry = body.get("default_expiry_hours")
max_expiry = body.get("max_expiry_hours")
if default_expiry is not None:
try:
default_expiry = int(default_expiry)
if default_expiry < 1:
raise ValueError
except (ValueError, TypeError):
return api_response(success=False, message="default_expiry_hours must be a positive integer", status=400, error_type="VALIDATION_ERROR")
if max_expiry is not None:
try:
max_expiry = int(max_expiry)
if max_expiry < 1:
raise ValueError
except (ValueError, TypeError):
return api_response(success=False, message="max_expiry_hours must be a positive integer", status=400, error_type="VALIDATION_ERROR")
if default_expiry and max_expiry and default_expiry > max_expiry:
return api_response(success=False, message="default_expiry_hours cannot exceed max_expiry_hours", status=400, error_type="VALIDATION_ERROR")
# Validate allowed_extensions — must be subset of STANDARD_EXTENSIONS
allowed_extensions = body.get("allowed_extensions")
if allowed_extensions is not None:
if not isinstance(allowed_extensions, list):
return api_response(success=False, message="allowed_extensions must be a list", status=400, error_type="VALIDATION_ERROR")
invalid_ext = [e for e in allowed_extensions if e not in STANDARD_EXTENSIONS]
if invalid_ext:
return api_response(
success=False,
message=f"Invalid standard extensions: {', '.join(invalid_ext)}. Valid: {', '.join(STANDARD_EXTENSIONS)}",
status=400,
error_type="VALIDATION_ERROR",
)
# Validate custom_extensions — plain strings
custom_extensions = body.get("custom_extensions")
if custom_extensions is not None:
if not isinstance(custom_extensions, list) or not all(isinstance(e, str) for e in custom_extensions):
return api_response(success=False, message="custom_extensions must be a list of strings", status=400, error_type="VALIDATION_ERROR")
policy = DepartmentCertPolicy.query.filter(
DepartmentCertPolicy.department_id == dept_id,
DepartmentCertPolicy.deleted_at.is_(None),
).first()
if policy is None:
policy = DepartmentCertPolicy(department_id=dept_id)
db.session.add(policy)
if "allow_user_expiry" in body:
policy.allow_user_expiry = bool(body["allow_user_expiry"])
if default_expiry is not None:
policy.default_expiry_hours = default_expiry
if max_expiry is not None:
policy.max_expiry_hours = max_expiry
if allowed_extensions is not None:
policy.allowed_extensions = list(allowed_extensions)
flag_modified(policy, "allowed_extensions")
if custom_extensions is not None:
policy.custom_extensions = list(custom_extensions)
flag_modified(policy, "custom_extensions")
db.session.commit()
return api_response(data={"cert_policy": policy.to_dict()}, message="Certificate policy saved")
+405 -7
View File
@@ -46,6 +46,33 @@ def _pop_oidc_bridge(oauth_state: str) -> str | None:
pass
return None
def _store_cli_redirect(oauth_state: str, redirect_url: str) -> None:
"""Store CLI redirect_url keyed by OAuth state (for /token_please flow)."""
try:
import gatehouse_app.extensions as _ext
rc = _ext.redis_client
if rc is not None:
rc.setex(f"oauth_cli_redirect:{oauth_state}", _OAUTH_BRIDGE_TTL, redirect_url)
except Exception:
pass
def _pop_cli_redirect(oauth_state: str) -> str | None:
"""Retrieve and delete CLI redirect_url for the given OAuth state."""
try:
import gatehouse_app.extensions as _ext
rc = _ext.redis_client
if rc is not None:
key = f"oauth_cli_redirect:{oauth_state}"
val = rc.get(key)
if val:
rc.delete(key)
return val.decode() if isinstance(val, bytes) else val
except Exception:
pass
return None
logger = logging.getLogger(__name__)
@@ -69,6 +96,137 @@ def get_provider_type(provider: str) -> AuthMethodType:
return PROVIDER_TYPE_MAP[provider_lower]
@api_v1_bp.route("/token_please", methods=["GET"])
def token_please():
"""
CLI token acquisition endpoint.
Redirects the user's browser to the Gatehouse login page so they can
authenticate using any method (password, OAuth, passkey, TOTP, etc.).
On successful login the frontend delivers the session token directly to
the CLI's local callback server.
This endpoint is designed for CLI clients that:
1. Start a local HTTP server on LISTENER_SERVER_PORT (e.g. 8250)
2. Open a browser to /api/v1/token_please?redirect_url=http://127.0.0.1:8250/?token=
3. Wait for the browser to deliver the token to their local server
Query parameters:
redirect_url: Local callback URL where the token will be appended
"""
import secrets
from urllib.parse import urlencode, quote
from flask import current_app, redirect as flask_redirect
redirect_url = request.args.get("redirect_url", "").strip()
if not redirect_url:
return api_response(
success=False,
message="redirect_url query parameter is required",
status=400,
error_type="MISSING_REDIRECT_URL",
)
# Validate redirect_url is localhost/127.0.0.1 (security: prevent open redirect)
from urllib.parse import urlparse as _urlparse
parsed = _urlparse(redirect_url)
if parsed.hostname not in ("localhost", "127.0.0.1"):
return api_response(
success=False,
message="redirect_url must point to localhost",
status=400,
error_type="INVALID_REDIRECT_URL",
)
# Store the CLI redirect URL in Redis keyed by a short-lived token so the
# frontend can retrieve it after login without it being visible in the URL.
cli_token = secrets.token_urlsafe(32)
try:
import gatehouse_app.extensions as _ext
rc = _ext.redis_client
if rc is not None:
rc.setex(f"cli_redirect:{cli_token}", _OAUTH_BRIDGE_TTL, redirect_url)
else:
logger.warning("Redis not available; passing cli_redirect directly in URL")
cli_token = None
except Exception:
cli_token = None
frontend_url = current_app.config.get("FRONTEND_URL", "http://localhost:8080")
if cli_token:
# Pass an opaque token; the frontend exchanges it for the real URL via
# GET /api/v1/cli/redirect-url?token=<cli_token>
login_url = f"{frontend_url}/login?cli_token={cli_token}"
else:
# Fallback: put the redirect URL directly (still localhost-only, validated above)
login_url = f"{frontend_url}/login?cli_redirect={quote(redirect_url, safe='')}"
logger.info(f"CLI token_please: redirecting browser to Gatehouse login page")
return flask_redirect(login_url, code=302)
@api_v1_bp.route("/cli/redirect-url", methods=["GET"])
def cli_redirect_url_lookup():
"""
Exchange a short-lived cli_token for the CLI's local redirect URL.
Called by the frontend LoginPage after it detects the cli_token query
param so it can obtain the actual CLI callback URL from Redis without
exposing it in the browser URL bar.
Query parameters:
token: The cli_token issued by /token_please
Returns:
200: { "redirect_url": "http://127.0.0.1:8250/?token=" }
400: Missing token
404: Token not found or expired
"""
cli_token = request.args.get("token", "").strip()
if not cli_token:
return api_response(
success=False,
message="token query parameter is required",
status=400,
error_type="MISSING_TOKEN",
)
try:
import gatehouse_app.extensions as _ext
rc = _ext.redis_client
if rc is not None:
key = f"cli_redirect:{cli_token}"
val = rc.get(key)
if val is None:
return api_response(
success=False,
message="CLI token not found or expired",
status=404,
error_type="TOKEN_NOT_FOUND",
)
# Keep the key alive until the login actually completes (consume on use
# would break multi-step auth like TOTP), so we leave it as-is.
redirect_url = val.decode() if isinstance(val, bytes) else val
return api_response(data={"redirect_url": redirect_url})
except Exception as e:
logger.error(f"cli_redirect_url_lookup error: {e}")
return api_response(
success=False,
message="Internal error looking up CLI token",
status=500,
error_type="INTERNAL_ERROR",
)
return api_response(
success=False,
message="Redis not available",
status=503,
error_type="SERVICE_UNAVAILABLE",
)
# =============================================================================
# Provider Configuration Endpoints (Admin)
# =============================================================================
@@ -83,7 +241,7 @@ def list_providers():
200: List of providers with their configuration status
401: Not authenticated
"""
from gatehouse_app.models.authentication_method import ApplicationProviderConfig
from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig
from gatehouse_app.services.external_auth_service import ExternalProviderConfig
# Check app-level provider configs (ApplicationProviderConfig)
@@ -575,8 +733,6 @@ def initiate_oauth_authorize(provider: str):
"state": "state_token"
}
"""
provider_type = get_provider_type(provider)
# Get query parameters - organization_id is now optional
flow = request.args.get("flow", "login")
redirect_uri = request.args.get("redirect_uri")
@@ -592,7 +748,7 @@ def initiate_oauth_authorize(provider: str):
)
try:
# Initiate flow - organization_id is now optional
provider_type = get_provider_type(provider)
if flow == "login":
auth_url, state = OAuthFlowService.initiate_login_flow(
provider_type=provider_type,
@@ -626,6 +782,13 @@ def initiate_oauth_authorize(provider: str):
status=e.status_code,
error_type=e.error_type,
)
except ExternalAuthError as e:
return api_response(
success=False,
message=e.message,
status=e.status_code,
error_type=e.error_type,
)
@api_v1_bp.route("/auth/external/<provider>/callback", methods=["GET"])
@@ -666,8 +829,19 @@ def handle_oauth_callback(provider: str):
frontend_url = current_app.config.get("FRONTEND_URL", "http://localhost:8080")
frontend_callback = f"{frontend_url}/oauth/callback"
# Check if this is a CLI /token_please flow — retrieve stored redirect_url
cli_redirect_url = _pop_cli_redirect(state) if state else None
def redirect_error(message: str, error_type: str = "OAUTH_ERROR"):
"""Redirect to frontend with error params."""
"""Redirect to frontend (or CLI) with error params."""
if cli_redirect_url:
# CLI flow: return a plain error page instead of redirecting back
from flask import make_response
return make_response(
f"<html><body><h2>Authentication Error</h2><p>{message}</p>"
f"<p>You may close this window.</p></body></html>",
400,
)
params = {"error": message, "error_type": error_type}
if state:
params["state"] = state
@@ -706,8 +880,11 @@ def handle_oauth_callback(provider: str):
# Recover oidc_session_id if this was triggered from an OIDC bridge flow
oidc_session_id = _pop_oidc_bridge(state)
# Organization selection / creation flows are not supported in CLI mode
# (fall through to token redirect with whatever session we have)
# Organization selection needed (user belongs to multiple orgs)
if result.get("requires_org_selection"):
if result.get("requires_org_selection") and not cli_redirect_url:
import json
orgs = json.dumps(result.get("available_organizations", []))
params = {
@@ -722,12 +899,20 @@ def handle_oauth_callback(provider: str):
return flask_redirect(f"{frontend_callback}?{urlencode(params)}", code=302)
# Organization creation needed (new user via OAuth with no org)
if result.get("requires_org_creation"):
if result.get("requires_org_creation") and not cli_redirect_url:
import json as _json
session_data = result.get("session", {})
token = session_data.get("token", "")
expires_in = session_data.get("expires_in", 86400)
pending_invites = result.get("pending_invites", [])
params = {
"requires_org_creation": "1",
"state": result["state"],
"provider": provider,
"flow": flow_type,
"token": token,
"expires_in": str(expires_in),
"pending_invites": _json.dumps(pending_invites),
}
if oidc_session_id:
params["oidc_session_id"] = oidc_session_id
@@ -751,6 +936,19 @@ def handle_oauth_callback(provider: str):
user_info = result.get("user", {})
if user_info.get("email"):
params["email"] = user_info["email"]
# ── CLI /token_please flow: redirect to the CLI's local callback ─────
if cli_redirect_url:
# The CLI expects: http://127.0.0.1:8250/?token=<TOKEN>
# cli_redirect_url already ends with "token=" so just append the value
cli_final_url = cli_redirect_url + token
logger.info(
f"CLI token_please success: provider={provider}, user={user_info.get('email')}, "
f"redirecting to CLI callback"
)
return flask_redirect(cli_final_url, code=302)
# ── Frontend flow ─────────────────────────────────────────────────────
# Pass oidc_session_id through so the frontend can complete the OIDC flow
if oidc_session_id:
params["oidc_session_id"] = oidc_session_id
@@ -1049,3 +1247,203 @@ def _get_provider_endpoints(provider_type: AuthMethodType):
"UNSUPPORTED_PROVIDER",
400,
)
# =============================================================================
# Admin: Application-level OAuth Provider Management
# =============================================================================
@api_v1_bp.route("/admin/oauth/providers", methods=["GET"])
@login_required
def admin_list_app_providers():
"""List all application-level OAuth provider configurations (admin only).
Returns:
200: List of providers with client_id and enabled status
401: Not authenticated
403: Not an admin
"""
from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig
from gatehouse_app.models import OrganizationMember
from gatehouse_app.utils.constants import OrganizationRole
# Verify caller is admin in any org
admin_memberships = OrganizationMember.query.filter(
OrganizationMember.user_id == g.current_user.id,
OrganizationMember.role.in_([OrganizationRole.OWNER, OrganizationRole.ADMIN]),
).all()
if not admin_memberships:
return api_response(
success=False,
message="Admin access required",
status=403,
error_type="FORBIDDEN",
)
PROVIDERS = [
{"id": "google", "name": "Google"},
{"id": "github", "name": "GitHub"},
{"id": "microsoft", "name": "Microsoft"},
]
db_configs = {
c.provider_type: c
for c in ApplicationProviderConfig.query.all()
}
result = []
for p in PROVIDERS:
cfg = db_configs.get(p["id"])
result.append({
"id": p["id"],
"name": p["name"],
"is_configured": cfg is not None,
"is_enabled": cfg.is_enabled if cfg else False,
"client_id": cfg.client_id if cfg else None,
})
return api_response(
data={"providers": result},
message="OAuth providers retrieved successfully",
)
@api_v1_bp.route("/admin/oauth/providers/<provider>", methods=["PUT"])
@login_required
def admin_configure_app_provider(provider: str):
"""Create or update an application-level OAuth provider config (admin only).
Args:
provider: Provider type (google, github, microsoft)
Request body:
client_id: OAuth client ID
client_secret: OAuth client secret (optional omit to keep existing)
is_enabled: Whether the provider is enabled (default: true)
Returns:
200: Provider configuration updated
400: Validation error
401: Not authenticated
403: Not an admin
"""
from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig
from gatehouse_app.models import OrganizationMember
from gatehouse_app.utils.constants import OrganizationRole
from gatehouse_app.extensions import db
SUPPORTED = ["google", "github", "microsoft"]
if provider not in SUPPORTED:
return api_response(
success=False,
message=f"Unsupported provider. Must be one of: {', '.join(SUPPORTED)}",
status=400,
error_type="VALIDATION_ERROR",
)
# Verify caller is admin in any org
admin_memberships = OrganizationMember.query.filter(
OrganizationMember.user_id == g.current_user.id,
OrganizationMember.role.in_([OrganizationRole.OWNER, OrganizationRole.ADMIN]),
).all()
if not admin_memberships:
return api_response(
success=False,
message="Admin access required",
status=403,
error_type="FORBIDDEN",
)
data = request.json or {}
client_id = (data.get("client_id") or "").strip()
client_secret = (data.get("client_secret") or "").strip()
is_enabled = data.get("is_enabled", True)
if not client_id:
return api_response(
success=False,
message="client_id is required",
status=400,
error_type="VALIDATION_ERROR",
)
cfg = ApplicationProviderConfig.query.filter_by(provider_type=provider).first()
if cfg:
cfg.client_id = client_id
if client_secret:
cfg.set_client_secret(client_secret)
cfg.is_enabled = bool(is_enabled)
db.session.commit()
else:
cfg = ApplicationProviderConfig(
provider_type=provider,
client_id=client_id,
is_enabled=bool(is_enabled),
)
if client_secret:
cfg.set_client_secret(client_secret)
db.session.add(cfg)
db.session.commit()
return api_response(
data={
"provider": {
"id": provider,
"client_id": cfg.client_id,
"is_enabled": cfg.is_enabled,
}
},
message=f"{provider.capitalize()} OAuth provider configured successfully",
)
@api_v1_bp.route("/admin/oauth/providers/<provider>", methods=["DELETE"])
@login_required
def admin_delete_app_provider(provider: str):
"""Delete an application-level OAuth provider config (admin only).
Args:
provider: Provider type (google, github, microsoft)
Returns:
200: Provider configuration deleted
404: Provider not found
401: Not authenticated
403: Not an admin
"""
from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig
from gatehouse_app.models import OrganizationMember
from gatehouse_app.utils.constants import OrganizationRole
from gatehouse_app.extensions import db
# Verify caller is admin in any org
admin_memberships = OrganizationMember.query.filter(
OrganizationMember.user_id == g.current_user.id,
OrganizationMember.role.in_([OrganizationRole.OWNER, OrganizationRole.ADMIN]),
).all()
if not admin_memberships:
return api_response(
success=False,
message="Admin access required",
status=403,
error_type="FORBIDDEN",
)
cfg = ApplicationProviderConfig.query.filter_by(provider_type=provider).first()
if not cfg:
return api_response(
success=False,
message=f"Provider '{provider}' is not configured",
status=404,
error_type="NOT_FOUND",
)
db.session.delete(cfg)
db.session.commit()
return api_response(
message=f"{provider.capitalize()} OAuth provider configuration removed",
)
File diff suppressed because it is too large Load Diff
+35 -11
View File
@@ -195,20 +195,46 @@ def get_org_mfa_compliance(org_id):
limit = min(int(request.args.get("limit", 100)), 100)
offset = int(request.args.get("offset", 0))
page = int(request.args.get("page", 1))
page_size = min(int(request.args.get("page_size", limit)), 100)
effective_offset = offset if request.args.get("offset") else (page - 1) * page_size
compliance_list = MfaPolicyService.get_org_compliance_list(
organization_id=org_id,
status=status,
limit=limit,
offset=offset,
limit=page_size,
offset=effective_offset,
)
def format_member(c):
"""Normalize compliance record to UI-expected shape."""
if isinstance(c, dict):
return {
"user_id": c.get("user_id"),
"user_email": c.get("email"),
"user_name": c.get("full_name"),
"status": c.get("status"),
"deadline_at": c.get("deadline_at"),
"compliant_at": c.get("compliant_at"),
"last_notified_at": c.get("notified_at"),
}
return {
"user_id": getattr(c, "user_id", None),
"user_email": getattr(c, "email", None),
"user_name": getattr(c, "full_name", None),
"status": getattr(c, "status", None),
"deadline_at": getattr(c, "deadline_at", None),
"compliant_at": getattr(c, "compliant_at", None),
"last_notified_at": getattr(c, "notified_at", None),
}
return api_response(
data={
"compliance": compliance_list,
"members": [format_member(c) for c in compliance_list],
"count": len(compliance_list),
"limit": limit,
"offset": offset,
"page": page,
"page_size": page_size,
},
message="Compliance records retrieved successfully",
)
@@ -325,12 +351,10 @@ def get_my_mfa_compliance():
return api_response(
data={
"mfa_compliance": {
"overall_status": compliance_summary.overall_status,
"missing_methods": compliance_summary.missing_methods,
"deadline_at": compliance_summary.deadline_at,
"orgs": orgs,
}
"overall_status": compliance_summary.overall_status,
"missing_methods": compliance_summary.missing_methods,
"deadline_at": compliance_summary.deadline_at,
"orgs": orgs,
},
message="MFA compliance retrieved successfully",
)
+779
View File
@@ -0,0 +1,779 @@
"""Principal endpoints."""
from flask import g, request
from marshmallow import Schema, fields, validate, ValidationError
from gatehouse_app.api.v1 import api_v1_bp
from gatehouse_app.utils.response import api_response
from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required
from gatehouse_app.models import Principal, PrincipalMembership, Department, DepartmentPrincipal
from gatehouse_app.services.organization_service import OrganizationService
from gatehouse_app.services.user_service import UserService
from gatehouse_app.exceptions import OrganizationNotFoundError
from gatehouse_app.extensions import db
class PrincipalCreateSchema(Schema):
"""Schema for creating a principal."""
name = fields.Str(required=True, validate=validate.Length(min=1, max=255))
description = fields.Str(allow_none=True, validate=validate.Length(max=2000))
class PrincipalUpdateSchema(Schema):
"""Schema for updating a principal."""
name = fields.Str(validate=validate.Length(min=1, max=255))
description = fields.Str(allow_none=True, validate=validate.Length(max=2000))
class AddPrincipalMemberSchema(Schema):
"""Schema for adding a member to a principal."""
email = fields.Email(required=True)
class LinkPrincipalSchema(Schema):
"""Schema for linking principal to department."""
department_id = fields.Str(required=True)
@api_v1_bp.route("/organizations/<org_id>/principals", methods=["GET"])
@login_required
@full_access_required
def list_principals(org_id):
"""
List all principals in an organization.
Args:
org_id: Organization ID
Returns:
200: List of principals
401: Not authenticated
403: Not a member
404: Organization not found
"""
org = OrganizationService.get_organization_by_id(org_id)
if not org.is_member(g.current_user.id):
return api_response(
success=False,
message="You are not a member of this organization",
status=403,
error_type="AUTHORIZATION_ERROR",
)
principals = Principal.query.filter_by(
organization_id=org_id,
deleted_at=None
).all()
return api_response(
data={
"principals": [p.to_dict() for p in principals],
"count": len(principals),
},
message="Principals retrieved successfully",
)
@api_v1_bp.route("/organizations/<org_id>/principals", methods=["POST"])
@login_required
@require_admin
@full_access_required
def create_principal(org_id):
"""
Create a new principal.
Args:
org_id: Organization ID
Request body:
name: Principal name (required)
description: Optional description
Returns:
201: Principal created successfully
400: Validation error
401: Not authenticated
403: Not an admin
404: Organization not found
409: Principal name already exists
"""
try:
org = OrganizationService.get_organization_by_id(org_id)
schema = PrincipalCreateSchema()
data = schema.load(request.json or {})
# Check if principal name already exists
existing = Principal.query.filter_by(
organization_id=org_id,
name=data["name"],
deleted_at=None
).first()
if existing:
return api_response(
success=False,
message=f"Principal '{data['name']}' already exists",
status=409,
error_type="CONFLICT",
)
# Create principal
principal = Principal(
organization_id=org_id,
name=data["name"],
description=data.get("description"),
)
db.session.add(principal)
db.session.commit()
return api_response(
data={"principal": principal.to_dict()},
message="Principal created successfully",
status=201,
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
@api_v1_bp.route("/organizations/<org_id>/principals/<principal_id>", methods=["GET"])
@login_required
@full_access_required
def get_principal(org_id, principal_id):
"""
Get a specific principal.
Args:
org_id: Organization ID
principal_id: Principal ID
Returns:
200: Principal data
401: Not authenticated
403: Not a member
404: Organization or principal not found
"""
org = OrganizationService.get_organization_by_id(org_id)
if not org.is_member(g.current_user.id):
return api_response(
success=False,
message="You are not a member of this organization",
status=403,
error_type="AUTHORIZATION_ERROR",
)
principal = Principal.query.filter_by(
id=principal_id,
organization_id=org_id,
deleted_at=None
).first()
if not principal:
return api_response(
success=False,
message="Principal not found",
status=404,
error_type="NOT_FOUND",
)
return api_response(
data={"principal": principal.to_dict()},
message="Principal retrieved successfully",
)
@api_v1_bp.route("/organizations/<org_id>/principals/<principal_id>", methods=["PATCH"])
@login_required
@require_admin
@full_access_required
def update_principal(org_id, principal_id):
"""
Update a principal.
Args:
org_id: Organization ID
principal_id: Principal ID
Request body:
name: Optional new name
description: Optional new description
Returns:
200: Principal updated successfully
400: Validation error
401: Not authenticated
403: Not an admin
404: Organization or principal not found
409: Name already exists
"""
try:
org = OrganizationService.get_organization_by_id(org_id)
principal = Principal.query.filter_by(
id=principal_id,
organization_id=org_id,
deleted_at=None
).first()
if not principal:
return api_response(
success=False,
message="Principal not found",
status=404,
error_type="NOT_FOUND",
)
schema = PrincipalUpdateSchema()
data = schema.load(request.json or {})
# Check if new name already exists
if "name" in data and data["name"] != principal.name:
existing = Principal.query.filter_by(
organization_id=org_id,
name=data["name"],
deleted_at=None
).first()
if existing:
return api_response(
success=False,
message=f"Principal '{data['name']}' already exists",
status=409,
error_type="CONFLICT",
)
# Update fields
for key, value in data.items():
setattr(principal, key, value)
db.session.commit()
return api_response(
data={"principal": principal.to_dict()},
message="Principal updated successfully",
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
@api_v1_bp.route("/organizations/<org_id>/principals/<principal_id>", methods=["DELETE"])
@login_required
@require_admin
@full_access_required
def delete_principal(org_id, principal_id):
"""
Delete a principal (soft delete).
Args:
org_id: Organization ID
principal_id: Principal ID
Returns:
200: Principal deleted successfully
401: Not authenticated
403: Not an admin
404: Organization or principal not found
"""
org = OrganizationService.get_organization_by_id(org_id)
principal = Principal.query.filter_by(
id=principal_id,
organization_id=org_id,
deleted_at=None
).first()
if not principal:
return api_response(
success=False,
message="Principal not found",
status=404,
error_type="NOT_FOUND",
)
# Soft delete
principal.deleted_at = db.func.now()
db.session.commit()
return api_response(
message="Principal deleted successfully",
)
@api_v1_bp.route("/organizations/<org_id>/principals/<principal_id>/members", methods=["GET"])
@login_required
@full_access_required
def get_principal_members(org_id, principal_id):
"""
Get all members (direct + via department) with access to a principal.
Args:
org_id: Organization ID
principal_id: Principal ID
Returns:
200: List of members
401: Not authenticated
403: Not a member
404: Organization or principal not found
"""
org = OrganizationService.get_organization_by_id(org_id)
if not org.is_member(g.current_user.id):
return api_response(
success=False,
message="You are not a member of this organization",
status=403,
error_type="AUTHORIZATION_ERROR",
)
principal = Principal.query.filter_by(
id=principal_id,
organization_id=org_id,
deleted_at=None
).first()
if not principal:
return api_response(
success=False,
message="Principal not found",
status=404,
error_type="NOT_FOUND",
)
# Get direct members
direct_members = PrincipalMembership.query.filter_by(
principal_id=principal_id,
deleted_at=None
).all()
all_users = set()
for membership in direct_members:
if membership.user.deleted_at is None:
all_users.add(membership.user)
# Get members via departments
dept_links = DepartmentPrincipal.query.filter_by(
principal_id=principal_id,
deleted_at=None
).all()
for link in dept_links:
dept = link.department
if dept.deleted_at is None:
dept_members = dept.get_members(active_only=True)
for dept_member in dept_members:
if dept_member.user.deleted_at is None:
all_users.add(dept_member.user)
users_data = [u.to_dict() for u in all_users]
return api_response(
data={
"members": users_data,
"count": len(users_data),
},
message="Members retrieved successfully",
)
@api_v1_bp.route("/organizations/<org_id>/principals/<principal_id>/members", methods=["POST"])
@login_required
@require_admin
@full_access_required
def add_principal_member(org_id, principal_id):
"""
Add a direct member to a principal.
Args:
org_id: Organization ID
principal_id: Principal ID
Request body:
email: User email to add
Returns:
201: Member added successfully
400: Validation error
401: Not authenticated
403: Not an admin
404: Organization, principal, or user not found
409: User already a member
"""
try:
org = OrganizationService.get_organization_by_id(org_id)
principal = Principal.query.filter_by(
id=principal_id,
organization_id=org_id,
deleted_at=None
).first()
if not principal:
return api_response(
success=False,
message="Principal not found",
status=404,
error_type="NOT_FOUND",
)
schema = AddPrincipalMemberSchema()
data = schema.load(request.json or {})
# Find user by email
user = UserService.get_user_by_email(data["email"])
if not user:
return api_response(
success=False,
message="User not found",
status=404,
error_type="NOT_FOUND",
)
# Check if already a member
existing = PrincipalMembership.query.filter_by(
user_id=user.id,
principal_id=principal_id,
deleted_at=None
).first()
if existing:
return api_response(
success=False,
message="User is already a member of this principal",
status=409,
error_type="CONFLICT",
)
soft_deleted = PrincipalMembership.query.filter(
PrincipalMembership.user_id == user.id,
PrincipalMembership.principal_id == principal_id,
PrincipalMembership.deleted_at.isnot(None)
).first()
if soft_deleted:
soft_deleted.deleted_at = None
membership = soft_deleted
else:
membership = PrincipalMembership(
user_id=user.id,
principal_id=principal_id,
)
db.session.add(membership)
db.session.commit()
member_dict = membership.to_dict()
member_dict["user"] = user.to_dict()
return api_response(
data={"member": member_dict},
message="Member added successfully",
status=201,
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
@api_v1_bp.route("/organizations/<org_id>/principals/<principal_id>/members/<user_id>", methods=["DELETE"])
@login_required
@require_admin
@full_access_required
def remove_principal_member(org_id, principal_id, user_id):
"""
Remove a direct member from a principal.
Args:
org_id: Organization ID
principal_id: Principal ID
user_id: User ID to remove
Returns:
200: Member removed successfully
401: Not authenticated
403: Not an admin
404: Organization, principal, or member not found
"""
org = OrganizationService.get_organization_by_id(org_id)
principal = Principal.query.filter_by(
id=principal_id,
organization_id=org_id,
deleted_at=None
).first()
if not principal:
return api_response(
success=False,
message="Principal not found",
status=404,
error_type="NOT_FOUND",
)
membership = PrincipalMembership.query.filter_by(
user_id=user_id,
principal_id=principal_id,
deleted_at=None
).first()
if not membership:
return api_response(
success=False,
message="User is not a member of this principal",
status=404,
error_type="NOT_FOUND",
)
# Soft delete
membership.deleted_at = db.func.now()
db.session.commit()
return api_response(
message="Member removed successfully",
)
@api_v1_bp.route("/organizations/<org_id>/principals/<principal_id>/departments", methods=["GET"])
@login_required
@full_access_required
def get_principal_departments(org_id, principal_id):
"""
Get all departments this principal is assigned to.
Args:
org_id: Organization ID
principal_id: Principal ID
Returns:
200: List of departments
401: Not authenticated
403: Not a member
404: Organization or principal not found
"""
org = OrganizationService.get_organization_by_id(org_id)
if not org.is_member(g.current_user.id):
return api_response(
success=False,
message="You are not a member of this organization",
status=403,
error_type="AUTHORIZATION_ERROR",
)
principal = Principal.query.filter_by(
id=principal_id,
organization_id=org_id,
deleted_at=None
).first()
if not principal:
return api_response(
success=False,
message="Principal not found",
status=404,
error_type="NOT_FOUND",
)
depts = principal.get_departments(active_only=True)
return api_response(
data={
"departments": [d.to_dict() for d in depts],
"count": len(depts),
},
message="Departments retrieved successfully",
)
@api_v1_bp.route("/organizations/<org_id>/principals/<principal_id>/departments/<dept_id>", methods=["POST"])
@login_required
@require_admin
@full_access_required
def link_principal_to_department(org_id, principal_id, dept_id):
"""
Link a principal to a department.
Args:
org_id: Organization ID
principal_id: Principal ID
dept_id: Department ID
Returns:
201: Principal linked successfully
401: Not authenticated
403: Not an admin
404: Organization, principal, or department not found
409: Already linked
"""
try:
org = OrganizationService.get_organization_by_id(org_id)
except OrganizationNotFoundError:
return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND")
principal = Principal.query.filter_by(
id=principal_id,
organization_id=org_id,
deleted_at=None
).first()
if not principal:
return api_response(
success=False,
message="Principal not found",
status=404,
error_type="NOT_FOUND",
)
dept = Department.query.filter_by(
id=dept_id,
organization_id=org_id,
deleted_at=None
).first()
if not dept:
return api_response(
success=False,
message="Department not found",
status=404,
error_type="NOT_FOUND",
)
existing = DepartmentPrincipal.query.filter_by(
department_id=dept_id,
principal_id=principal_id,
deleted_at=None
).first()
if existing:
return api_response(
success=False,
message="Principal is already linked to this department",
status=409,
error_type="CONFLICT",
)
soft_deleted = DepartmentPrincipal.query.filter(
DepartmentPrincipal.department_id == dept_id,
DepartmentPrincipal.principal_id == principal_id,
DepartmentPrincipal.deleted_at.isnot(None),
).first()
try:
if soft_deleted:
soft_deleted.deleted_at = None
else:
link = DepartmentPrincipal(
department_id=dept_id,
principal_id=principal_id,
)
db.session.add(link)
db.session.commit()
except Exception:
db.session.rollback()
return api_response(
success=False,
message="Failed to link principal to department",
status=500,
error_type="SERVER_ERROR",
)
return api_response(
data={
"principal": principal.to_dict(),
"department": dept.to_dict(),
},
message="Principal linked to department successfully",
status=201,
)
@api_v1_bp.route("/organizations/<org_id>/principals/<principal_id>/departments/<dept_id>", methods=["DELETE"])
@login_required
@require_admin
@full_access_required
def unlink_principal_from_department(org_id, principal_id, dept_id):
"""
Unlink a principal from a department.
Args:
org_id: Organization ID
principal_id: Principal ID
dept_id: Department ID
Returns:
200: Principal unlinked successfully
401: Not authenticated
403: Not an admin
404: Organization, principal, department, or link not found
"""
org = OrganizationService.get_organization_by_id(org_id)
principal = Principal.query.filter_by(
id=principal_id,
organization_id=org_id,
deleted_at=None
).first()
if not principal:
return api_response(
success=False,
message="Principal not found",
status=404,
error_type="NOT_FOUND",
)
dept = Department.query.filter_by(
id=dept_id,
organization_id=org_id,
deleted_at=None
).first()
if not dept:
return api_response(
success=False,
message="Department not found",
status=404,
error_type="NOT_FOUND",
)
link = DepartmentPrincipal.query.filter_by(
department_id=dept_id,
principal_id=principal_id,
deleted_at=None
).first()
if not link:
return api_response(
success=False,
message="Principal is not linked to this department",
status=404,
error_type="NOT_FOUND",
)
# Soft delete
link.deleted_at = db.func.now()
db.session.commit()
return api_response(
message="Principal unlinked from department successfully",
)
File diff suppressed because it is too large Load Diff
+714 -6
View File
@@ -73,11 +73,51 @@ def delete_me():
"""
Delete current user account (soft delete).
Blocked if the user is the sole owner of any organization that has other
active members they must transfer ownership or dissolve those organizations
first.
Returns:
200: Account deleted successfully
401: Not authenticated
409: User is sole owner of one or more organizations with other members
"""
UserService.delete_user(g.current_user, soft=True)
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.utils.constants import OrganizationRole
user = g.current_user
# Find orgs where this user is the sole owner AND other members exist.
owned_memberships = OrganizationMember.query.filter_by(
user_id=user.id,
role=OrganizationRole.OWNER,
deleted_at=None,
).all()
blocked_orgs = []
for membership in owned_memberships:
org = membership.organization
if org.deleted_at is not None:
continue
member_count = org.get_member_count()
if member_count > 1:
blocked_orgs.append(org.name)
if blocked_orgs:
names = ", ".join(f'"{n}"' for n in blocked_orgs)
return api_response(
success=False,
message=(
f"You are the sole owner of {len(blocked_orgs)} organization"
f"{'s' if len(blocked_orgs) > 1 else ''}: {names}. "
"Transfer ownership or delete those organizations before deleting your account."
),
status=409,
error_type="USER_IS_SOLE_OWNER",
error_details={"organizations": blocked_orgs},
)
UserService.delete_user(user, soft=True)
return api_response(
message="Account deleted successfully",
@@ -142,18 +182,686 @@ def change_password():
@full_access_required
def get_my_organizations():
"""
Get all organizations current user is a member of.
Get all organizations current user is a member of, including the user's role.
Returns:
200: List of organizations
200: List of organizations with role
401: Not authenticated
"""
organizations = UserService.get_user_organizations(g.current_user)
from gatehouse_app.models.organization.organization_member import OrganizationMember
user = g.current_user
memberships = OrganizationMember.query.filter_by(
user_id=user.id,
deleted_at=None,
).all()
orgs = []
for membership in memberships:
org = membership.organization
if not org or org.deleted_at is not None:
continue
org_dict = org.to_dict()
org_dict["role"] = membership.role.value if hasattr(membership.role, "value") else str(membership.role)
orgs.append(org_dict)
return api_response(
data={
"organizations": [org.to_dict() for org in organizations],
"count": len(organizations),
"organizations": orgs,
"count": len(orgs),
},
message="Organizations retrieved successfully",
)
@api_v1_bp.route("/users/me/principals", methods=["GET"])
@login_required
@full_access_required
def get_my_principals():
"""Return all principals the current user can sign certificates for.
For each organization the user belongs to, returns:
- Their effective principals (direct membership + via department)
- Their role in that org (so the frontend can offer admin-mode selection)
- All principals in the org (admin/owner only so they can pick any)
Returns:
200: {
orgs: [{
org_id, org_name, role,
my_principals: [{id, name, description}],
all_principals: [{id, name, description}] # populated for admin/owner only
}]
}
"""
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.models.organization.principal import Principal, PrincipalMembership
from gatehouse_app.models.organization.department import DepartmentMembership, DepartmentPrincipal
from gatehouse_app.utils.constants import OrganizationRole
user = g.current_user
user_id = user.id
# Get all org memberships
memberships = OrganizationMember.query.filter_by(
user_id=user_id,
).all()
orgs_result = []
for membership in memberships:
org = membership.organization
if not org or org.deleted_at is not None:
continue
role = membership.role
is_admin = role in (OrganizationRole.ADMIN, OrganizationRole.OWNER)
# Collect the user's effective principals for this org
# Track direct vs via-department separately
direct_principal_ids = set()
via_dept_principal_ids = set()
# Direct memberships
direct = PrincipalMembership.query.filter_by(
user_id=user_id,
deleted_at=None,
).all()
for pm in direct:
if pm.principal and pm.principal.organization_id == org.id and pm.principal.deleted_at is None:
direct_principal_ids.add(pm.principal_id)
# Via department
dept_memberships = DepartmentMembership.query.filter_by(
user_id=user_id,
deleted_at=None,
).all()
for dm in dept_memberships:
if dm.department and dm.department.organization_id == org.id and dm.department.deleted_at is None:
dept_principals = DepartmentPrincipal.query.filter_by(
department_id=dm.department_id,
deleted_at=None,
).all()
for dp in dept_principals:
if dp.principal and dp.principal.deleted_at is None:
via_dept_principal_ids.add(dp.principal_id)
effective_principal_ids = direct_principal_ids | via_dept_principal_ids
# Fetch principal objects
my_principals = []
if effective_principal_ids:
my_p = Principal.query.filter(
Principal.id.in_(list(effective_principal_ids)),
Principal.deleted_at == None,
).all()
my_principals = [
{
"id": p.id,
"name": p.name,
"description": p.description,
# direct=True means removable via API; False=inherited via department
"direct": p.id in direct_principal_ids,
}
for p in my_p
]
# For admins/owners: also return all principals in the org
all_principals = []
if is_admin:
all_p = Principal.query.filter_by(
organization_id=org.id,
deleted_at=None,
).all()
all_principals = [{"id": p.id, "name": p.name, "description": p.description} for p in all_p]
orgs_result.append({
"org_id": org.id,
"org_name": org.name,
"role": role.value if hasattr(role, "value") else role,
"is_admin": is_admin,
"my_principals": my_principals,
"all_principals": all_principals,
})
return api_response(
data={"orgs": orgs_result},
message="Principals retrieved successfully",
)
@api_v1_bp.route("/admin/users", methods=["GET"])
@login_required
@full_access_required
def admin_list_users():
"""List all users the caller has admin rights to see.
The caller must be an OWNER or ADMIN of at least one organization.
Returns users that share an organization with the caller and where the
caller holds admin/owner role in that organization.
Query params:
q optional search string (matched against name/email)
page page number (default 1)
per_page page size (default 50, max 200)
"""
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.models.user.user import User as _User
from gatehouse_app.extensions import db as _db
from sqlalchemy import or_
caller = g.current_user
# Find orgs where caller is admin/owner
admin_memberships = OrganizationMember.query.filter(
OrganizationMember.user_id == caller.id,
OrganizationMember.role.in_(["OWNER", "ADMIN"]),
OrganizationMember.deleted_at == None,
).all()
if not admin_memberships:
return api_response(
success=False,
message="Admin or owner role required",
status=403,
error_type="AUTHORIZATION_ERROR",
)
admin_org_ids = [m.organization_id for m in admin_memberships]
# Collect user IDs in those orgs
member_rows = OrganizationMember.query.filter(
OrganizationMember.organization_id.in_(admin_org_ids),
OrganizationMember.deleted_at == None,
).all()
visible_user_ids = list({row.user_id for row in member_rows})
# Optional search
q = request.args.get("q", "").strip()
try:
page = max(1, int(request.args.get("page", 1)))
per_page = min(200, max(1, int(request.args.get("per_page", 50))))
except ValueError:
page, per_page = 1, 50
query = _User.query.filter(
_User.id.in_(visible_user_ids),
_User.deleted_at == None,
)
if q:
like = f"%{q}%"
query = query.filter(or_(_User.email.ilike(like), _User.full_name.ilike(like)))
total = query.count()
users = query.order_by(_User.email).offset((page - 1) * per_page).limit(per_page).all()
member_lookup: dict = {}
for row in member_rows:
if row.user_id not in member_lookup:
member_lookup[row.user_id] = {
"organization_id": row.organization_id,
"role": row.role.value if hasattr(row.role, "value") else row.role,
}
users_data = []
for u in users:
d = u.to_dict()
m = member_lookup.get(u.id, {})
d["org_role"] = m.get("role", "member")
d["org_id"] = m.get("organization_id")
users_data.append(d)
return api_response(
data={
"users": users_data,
"count": total,
"page": page,
"per_page": per_page,
"pages": (total + per_page - 1) // per_page,
},
message="Users retrieved successfully",
)
@api_v1_bp.route("/admin/users/<user_id>", methods=["GET"])
@login_required
@full_access_required
def admin_get_user(user_id):
"""Get a single user's profile (admin view with SSH keys)."""
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.models.user.user import User as _User
from gatehouse_app.models.ssh_ca.ssh_key import SSHKey
caller = g.current_user
target = _User.query.filter_by(id=user_id, deleted_at=None).first()
if not target:
return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND")
# Verify caller has admin access to a shared org
target_org_ids = {m.organization_id for m in target.organization_memberships if m.deleted_at is None}
has_access = OrganizationMember.query.filter(
OrganizationMember.user_id == caller.id,
OrganizationMember.organization_id.in_(target_org_ids),
OrganizationMember.role.in_(["OWNER", "ADMIN"]),
OrganizationMember.deleted_at == None,
).first() is not None
if not has_access:
return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR")
ssh_keys = SSHKey.query.filter_by(user_id=user_id, deleted_at=None).all()
return api_response(
data={
"user": target.to_dict(),
"ssh_keys": [k.to_dict() for k in ssh_keys],
},
message="User retrieved",
)
@api_v1_bp.route("/admin/users/<user_id>/suspend", methods=["POST"])
@login_required
@full_access_required
def admin_suspend_user(user_id):
"""Suspend a user account (blocks CA issuance and login).
The caller must be an OWNER or ADMIN of an organization the target user belongs to.
"""
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.models.user.user import User as _User
from gatehouse_app.extensions import db as _db
from gatehouse_app.utils.constants import UserStatus, AuditAction
from gatehouse_app.services.audit_service import AuditService
caller = g.current_user
target = _User.query.filter_by(id=user_id, deleted_at=None).first()
if not target:
return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND")
if target.id == caller.id:
return api_response(success=False, message="Cannot suspend yourself", status=400, error_type="BAD_REQUEST")
# Verify caller has admin access to a shared org
target_org_ids = {m.organization_id for m in target.organization_memberships if m.deleted_at is None}
admin_in_shared_org = OrganizationMember.query.filter(
OrganizationMember.user_id == caller.id,
OrganizationMember.organization_id.in_(target_org_ids),
OrganizationMember.role.in_(["OWNER", "ADMIN"]),
OrganizationMember.deleted_at == None,
).first()
if not admin_in_shared_org:
return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR")
# ── Owner protection ──────────────────────────────────────────────────────
# An org owner cannot be suspended until they transfer ownership.
from gatehouse_app.utils.constants import OrganizationRole
owner_memberships = OrganizationMember.query.filter(
OrganizationMember.user_id == target.id,
OrganizationMember.role == OrganizationRole.OWNER,
OrganizationMember.deleted_at == None,
).all()
if owner_memberships:
org_names = [
m.organization.name
for m in owner_memberships
if m.organization and not m.organization.deleted_at
]
return api_response(
success=False,
message=(
f"Cannot suspend an organization owner. "
f"{target.email} is the owner of: {', '.join(org_names)}. "
"Transfer ownership to another member first."
),
status=403,
error_type="OWNER_PROTECTION",
)
if target.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
return api_response(success=False, message="User is already suspended", status=409, error_type="CONFLICT")
target.status = UserStatus.SUSPENDED
_db.session.commit()
AuditService.log_action(
action=AuditAction.USER_SUSPEND,
user_id=caller.id,
organization_id=admin_in_shared_org.organization_id,
resource_type="user",
resource_id=str(target.id),
description=f"Admin suspended user {target.email}",
metadata={"target_user_id": str(target.id), "target_email": target.email},
)
return api_response(data={"user": target.to_dict()}, message="User suspended successfully")
@api_v1_bp.route("/admin/users/<user_id>/unsuspend", methods=["POST"])
@login_required
@full_access_required
def admin_unsuspend_user(user_id):
"""Restore a suspended user account to active status."""
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.models.user.user import User as _User
from gatehouse_app.extensions import db as _db
from gatehouse_app.utils.constants import UserStatus, AuditAction
from gatehouse_app.services.audit_service import AuditService
caller = g.current_user
target = _User.query.filter_by(id=user_id, deleted_at=None).first()
if not target:
return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND")
# Verify caller has admin access to a shared org
target_org_ids = {m.organization_id for m in target.organization_memberships if m.deleted_at is None}
admin_in_shared_org = OrganizationMember.query.filter(
OrganizationMember.user_id == caller.id,
OrganizationMember.organization_id.in_(target_org_ids),
OrganizationMember.role.in_(["OWNER", "ADMIN"]),
OrganizationMember.deleted_at == None,
).first()
if not admin_in_shared_org:
return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR")
if target.status not in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
return api_response(success=False, message="User is not suspended", status=409, error_type="CONFLICT")
target.status = UserStatus.ACTIVE
_db.session.commit()
AuditService.log_action(
action=AuditAction.USER_UNSUSPEND,
user_id=caller.id,
organization_id=admin_in_shared_org.organization_id,
resource_type="user",
resource_id=str(target.id),
description=f"Admin unsuspended user {target.email}",
metadata={"target_user_id": str(target.id), "target_email": target.email},
)
return api_response(data={"user": target.to_dict()}, message="User unsuspended successfully")
@api_v1_bp.route("/users/me/invites", methods=["GET"])
@login_required
def get_my_pending_invites():
"""Return pending (unaccepted, non-expired) invitations for the current user's email."""
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
from datetime import datetime, timezone
user = g.current_user
now = datetime.now(timezone.utc)
invites = OrgInviteToken.query.filter(
OrgInviteToken.email == user.email,
OrgInviteToken.accepted_at.is_(None),
OrgInviteToken.expires_at > now,
OrgInviteToken.deleted_at.is_(None),
).all()
return api_response(
data={
"invites": [
{
"token": i.token,
"organization": {"id": str(i.organization_id), "name": i.organization.name},
"role": i.role,
"expires_at": i.expires_at.isoformat(),
}
for i in invites
]
},
message="Pending invitations retrieved",
)
@api_v1_bp.route("/users/me/memberships", methods=["GET"])
@login_required
def get_my_memberships():
"""Return the current user's department and principal memberships across all orgs.
Returns:
200: {
orgs: [{
org_id, org_name, role,
departments: [{id, name, description}],
principals: [{id, name, description, via_department: bool}]
}]
}
"""
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.models.organization.department import DepartmentMembership, DepartmentPrincipal, Department
from gatehouse_app.models.organization.principal import Principal, PrincipalMembership
user = g.current_user
memberships = OrganizationMember.query.filter_by(
user_id=user.id,
deleted_at=None,
).all()
orgs_result = []
for membership in memberships:
org = membership.organization
if not org or org.deleted_at is not None:
continue
# Departments the user belongs to
dept_memberships = DepartmentMembership.query.filter_by(
user_id=user.id,
deleted_at=None,
).all()
user_depts = [
dm.department for dm in dept_memberships
if dm.department
and dm.department.organization_id == org.id
and dm.department.deleted_at is None
]
# Principals: direct
direct_pm = PrincipalMembership.query.filter_by(
user_id=user.id,
deleted_at=None,
).all()
direct_principal_ids = {
pm.principal_id for pm in direct_pm
if pm.principal
and pm.principal.organization_id == org.id
and pm.principal.deleted_at is None
}
# Principals: via department
via_dept_principal_ids = set()
for dept in user_depts:
for dp in DepartmentPrincipal.query.filter_by(department_id=dept.id, deleted_at=None).all():
if dp.principal and dp.principal.deleted_at is None:
via_dept_principal_ids.add(dp.principal_id)
all_principal_ids = direct_principal_ids | via_dept_principal_ids
principals_list = []
if all_principal_ids:
for p in Principal.query.filter(
Principal.id.in_(list(all_principal_ids)),
Principal.deleted_at == None,
).all():
principals_list.append({
"id": str(p.id),
"name": p.name,
"description": p.description,
"via_department": p.id not in direct_principal_ids,
})
role = membership.role
orgs_result.append({
"org_id": str(org.id),
"org_name": org.name,
"role": role.value if hasattr(role, "value") else role,
"departments": [
{"id": str(d.id), "name": d.name, "description": d.description}
for d in user_depts
],
"principals": principals_list,
})
return api_response(
data={"orgs": orgs_result},
message="Memberships retrieved",
)
@api_v1_bp.route("/admin/users/<user_id>/delete", methods=["POST"])
@login_required
@full_access_required
def admin_hard_delete_user(user_id):
"""Permanently delete a user and ALL associated data (hard delete, irreversible).
Required body: {"confirm": true}
Pre-conditions:
- Caller is OWNER or ADMIN of a shared org with the target.
- Cannot delete yourself.
- Target must not be the OWNER of any active organization (transfer first).
Side-effects:
- All active SSH certificates are revoked before deletion.
- The user row and all cascaded rows are hard-deleted from the database.
- An audit log entry is written by the *caller* (so it is not lost with the user).
"""
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.models.user.user import User as _User
from gatehouse_app.extensions import db as _db
from gatehouse_app.utils.constants import UserStatus, AuditAction, OrganizationRole
from gatehouse_app.services.audit_service import AuditService
caller = g.current_user
data = request.get_json() or {}
if not data.get("confirm"):
return api_response(
success=False,
message="Deletion requires explicit confirmation. Send {\"confirm\": true} to proceed.",
status=400,
error_type="CONFIRMATION_REQUIRED",
)
target = _User.query.filter_by(id=user_id).first()
if not target:
return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND")
if target.id == caller.id:
return api_response(
success=False,
message="Cannot delete your own account via this endpoint.",
status=400,
error_type="BAD_REQUEST",
)
# Caller must be OWNER/ADMIN of a shared org.
# Include soft-deleted memberships so that already-soft-deleted users can
# still be hard-deleted by an admin who shared an org with them.
target_org_ids = {m.organization_id for m in target.organization_memberships}
admin_in_shared_org = OrganizationMember.query.filter(
OrganizationMember.user_id == caller.id,
OrganizationMember.organization_id.in_(target_org_ids),
OrganizationMember.role.in_(["OWNER", "ADMIN"]),
OrganizationMember.deleted_at == None,
).first()
if not admin_in_shared_org:
return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR")
# Block deletion if target is an org owner — they must transfer first
owner_memberships = OrganizationMember.query.filter(
OrganizationMember.user_id == target.id,
OrganizationMember.role == OrganizationRole.OWNER,
OrganizationMember.deleted_at == None,
).all()
if owner_memberships:
org_names = [
m.organization.name
for m in owner_memberships
if m.organization and not m.organization.deleted_at
]
return api_response(
success=False,
message=(
f"Cannot delete an organization owner. "
f"{target.email} is the owner of: {', '.join(org_names)}. "
"Transfer ownership to another member first."
),
status=403,
error_type="OWNER_PROTECTION",
)
# ── Collect counts for audit metadata ────────────────────────────────────
from gatehouse_app.models.ssh_ca.ssh_key import SSHKey
from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus
ssh_key_count = SSHKey.query.filter_by(user_id=target.id, deleted_at=None).count()
active_cert_count = SSHCertificate.query.filter_by(
user_id=target.id, revoked=False
).filter(SSHCertificate.deleted_at == None).count()
# ── Revoke all active SSH certificates before deletion ───────────────────
active_certs = SSHCertificate.query.filter_by(
user_id=target.id, revoked=False
).filter(SSHCertificate.deleted_at == None).all()
for cert in active_certs:
try:
cert.revoke("account_deleted")
except Exception:
pass
if active_certs:
try:
_db.session.flush()
except Exception:
pass
# ── Hard delete ───────────────────────────────────────────────────────────
target_email = target.email # capture before deletion
target_id_str = str(target.id)
try:
_db.session.delete(target) # cascades to all child tables
_db.session.flush()
except Exception as exc:
_db.session.rollback()
import logging
logging.getLogger(__name__).error(f"Hard delete failed for {target_id_str}: {exc}")
return api_response(
success=False,
message="Failed to delete user account. Please try again.",
status=500,
error_type="SERVER_ERROR",
)
# ── Audit log (written as the caller so it survives the deletion) ─────────
AuditService.log_action(
action=AuditAction.USER_HARD_DELETE,
user_id=caller.id,
organization_id=admin_in_shared_org.organization_id,
resource_type="user",
resource_id=target_id_str,
description=f"Admin permanently deleted user account: {target_email}",
metadata={
"deleted_user_id": target_id_str,
"deleted_user_email": target_email,
"ssh_keys_deleted": ssh_key_count,
"certs_revoked": active_cert_count,
},
)
_db.session.commit()
return api_response(
message=f"User account {target_email} has been permanently deleted.",
data={
"deleted_user_id": target_id_str,
"deleted_user_email": target_email,
"ssh_keys_deleted": ssh_key_count,
"certs_revoked": active_cert_count,
},
)
+234
View File
@@ -0,0 +1,234 @@
"""SSH CA Configuration Manager.
Handles loading and managing SSH CA configuration from etc/ssh_ca.conf
and environment variables.
"""
import os
import configparser
from pathlib import Path
from typing import Optional, Union
class SSHCAConfig:
"""Configuration manager for SSH CA settings.
Loads configuration from:
1. etc/ssh_ca.conf file
2. Environment variables (override config file)
3. Application environment-specific defaults
Example:
config = SSHCAConfig()
cert_hours = config.get_int('cert_validity_hours')
key_path = config.get_str('ca_key_path')
"""
# Configuration file location (relative to project root)
DEFAULT_CONFIG_FILE = "etc/ssh_ca.conf"
# Default values if config file is missing
DEFAULTS = {
'cert_validity_hours': '8',
'max_cert_validity_hours': '720',
'ca_key_path': '',
'max_principals_per_cert': '256',
'max_key_id_length': '255',
'verification_challenge_max_age': '24',
'auto_delete_unverified_days': '30',
}
def __init__(self, config_file: Optional[str] = None, environment: Optional[str] = None):
"""Initialize SSH CA configuration.
Args:
config_file: Path to config file (default: etc/ssh_ca.conf)
environment: Environment name (development, production, testing)
Default: value of FLASK_ENV or 'development'
"""
self.config = configparser.ConfigParser()
# Determine environment
if environment is None:
environment = os.environ.get('FLASK_ENV', 'development')
self.environment = environment
# Load config file
if config_file is None:
# Try to find config file relative to this module
module_dir = Path(__file__).parent.parent.parent
config_file = module_dir / self.DEFAULT_CONFIG_FILE
self.config_file = config_file
self._load_config()
def _load_config(self):
"""Load configuration from file and apply environment-specific overrides."""
# Set defaults
self.config['default'] = self.DEFAULTS.copy()
# Load config file if it exists
if Path(self.config_file).exists():
self.config.read(self.config_file)
# Apply environment-specific configuration
if self.environment in self.config:
for key, value in self.config[self.environment].items():
self.config['default'][key] = value
def get_str(self, key: str, default: Optional[str] = None) -> str:
"""Get a string configuration value.
First checks environment variables (SSH_CA_<KEY>), then config file.
Args:
key: Configuration key
default: Default value if not found
Returns:
Configuration value as string
"""
env_key = f"SSH_CA_{key.upper()}"
# Check environment variable first
if env_key in os.environ:
return os.environ[env_key]
# Check config file
if key in self.config['default']:
value = self.config['default'][key]
# Handle environment variable substitution
return os.path.expandvars(value)
# Return default
if default is not None:
return default
return self.DEFAULTS.get(key, '')
def get_int(self, key: str, default: Optional[int] = None) -> int:
"""Get an integer configuration value.
Args:
key: Configuration key
default: Default value if not found
Returns:
Configuration value as integer
Raises:
ValueError: If value cannot be converted to integer
"""
str_value = self.get_str(key)
if not str_value:
if default is not None:
return default
raise ValueError(f"No value found for {key}")
try:
return int(str_value)
except ValueError:
if default is not None:
return default
raise ValueError(f"Configuration {key}={str_value} is not a valid integer")
def get_bool(self, key: str, default: Optional[bool] = None) -> bool:
"""Get a boolean configuration value.
Args:
key: Configuration key
default: Default value if not found
Returns:
Configuration value as boolean
"""
str_value = self.get_str(key)
if not str_value:
if default is not None:
return default
return False
return str_value.lower() in ('true', '1', 'yes', 'on')
def get_list(self, key: str, delimiter: str = ',', default: Optional[list] = None) -> list:
"""Get a comma-separated list configuration value.
Args:
key: Configuration key
delimiter: Delimiter between items (default: comma)
default: Default value if not found
Returns:
Configuration value as list of strings
"""
str_value = self.get_str(key)
if not str_value:
if default is not None:
return default
return []
return [item.strip() for item in str_value.split(delimiter) if item.strip()]
def validate_config(self) -> list:
"""Validate SSH CA configuration.
Returns:
List of validation error messages (empty if valid)
"""
errors = []
# Check cert validity hours
try:
validity = self.get_int('cert_validity_hours')
max_validity = self.get_int('max_cert_validity_hours')
if validity > max_validity:
errors.append(
f"cert_validity_hours ({validity}) > max_cert_validity_hours ({max_validity})"
)
except ValueError as e:
errors.append(f"Invalid cert validity hours: {e}")
# Check principals limit
max_principals = self.get_int('max_principals_per_cert')
if max_principals > 256:
errors.append(f"max_principals_per_cert ({max_principals}) exceeds SSH limit of 256")
# Check ca_key_path is set
if not self.get_str('ca_key_path', '').strip():
errors.append("ca_key_path is not set")
return errors
def to_dict(self) -> dict:
"""Export current configuration as dictionary.
"""
return dict(self.config['default'])
def __repr__(self):
"""String representation of configuration."""
return f"<SSHCAConfig environment={self.environment} file={self.config_file}>"
# Global configuration instance
_config_instance = None
def get_ssh_ca_config() -> SSHCAConfig:
"""Get the global SSH CA configuration instance.
This function uses a singleton pattern to ensure only one
configuration instance is created and reused.
Returns:
SSHCAConfig instance
"""
global _config_instance
if _config_instance is None:
_config_instance = SSHCAConfig()
return _config_instance
def reset_config_instance():
"""Reset the global configuration instance.
"""
global _config_instance
_config_instance = None
+29
View File
@@ -19,6 +19,21 @@ from gatehouse_app.exceptions.validation_exceptions import (
OrganizationNotFoundError,
UserNotFoundError,
)
from gatehouse_app.exceptions.ssh_exceptions import (
SSHCAError,
SSHKeyError,
SSHKeyNotFoundError,
SSHKeyAlreadyExistsError,
SSHKeyNotVerifiedError,
SSHCertificateError,
SSHCertificateNotFoundError,
CAError,
CANotFoundError,
PrincipalError,
PrincipalNotFoundError,
DepartmentError,
DepartmentNotFoundError,
)
__all__ = [
"BaseAPIException",
@@ -37,4 +52,18 @@ __all__ = [
"EmailAlreadyExistsError",
"OrganizationNotFoundError",
"UserNotFoundError",
"SSHCAError",
"SSHKeyError",
"SSHKeyNotFoundError",
"SSHKeyAlreadyExistsError",
"SSHKeyNotVerifiedError",
"SSHCertificateError",
"SSHCertificateNotFoundError",
"CAError",
"CANotFoundError",
"PrincipalError",
"PrincipalNotFoundError",
"DepartmentError",
"DepartmentNotFoundError",
]
+2 -1
View File
@@ -16,9 +16,10 @@ class BaseAPIException(Exception):
message: Custom error message
error_details: Additional error details dictionary
"""
super().__init__()
super().__init__(self.message)
if message:
self.message = message
super().__init__(message) # update args so str(e) works
self.error_details = error_details or {}
def to_dict(self):
@@ -0,0 +1,93 @@
"""SSH-specific exceptions."""
from gatehouse_app.exceptions.base import BaseAPIException
class SSHCAError(BaseAPIException):
"""Base exception for SSH CA operations."""
status_code = 500
error_type = "SSH_CA_ERROR"
class SSHKeyError(BaseAPIException):
"""Exception for SSH key operations."""
status_code = 400
error_type = "SSH_KEY_ERROR"
class SSHKeyNotFoundError(BaseAPIException):
"""SSH key not found."""
status_code = 404
error_type = "SSH_KEY_NOT_FOUND"
class SSHKeyAlreadyExistsError(BaseAPIException):
"""SSH key already exists (duplicate fingerprint)."""
status_code = 409
error_type = "SSH_KEY_ALREADY_EXISTS"
class SSHKeyNotVerifiedError(BaseAPIException):
"""SSH key has not been verified."""
status_code = 400
error_type = "SSH_KEY_NOT_VERIFIED"
class SSHCertificateError(BaseAPIException):
"""Exception for SSH certificate operations."""
status_code = 400
error_type = "SSH_CERT_ERROR"
class SSHCertificateNotFoundError(BaseAPIException):
"""SSH certificate not found."""
status_code = 404
error_type = "SSH_CERT_NOT_FOUND"
class CAError(BaseAPIException):
"""Exception for Certificate Authority operations."""
status_code = 400
error_type = "CA_ERROR"
class CANotFoundError(BaseAPIException):
"""Certificate Authority not found."""
status_code = 404
error_type = "CA_NOT_FOUND"
class PrincipalError(BaseAPIException):
"""Exception for principal operations."""
status_code = 400
error_type = "PRINCIPAL_ERROR"
class PrincipalNotFoundError(BaseAPIException):
"""Principal not found."""
status_code = 404
error_type = "PRINCIPAL_NOT_FOUND"
class DepartmentError(BaseAPIException):
"""Exception for department operations."""
status_code = 400
error_type = "DEPARTMENT_ERROR"
class DepartmentNotFoundError(BaseAPIException):
"""Department not found."""
status_code = 404
error_type = "DEPARTMENT_NOT_FOUND"
+12 -2
View File
@@ -1,5 +1,6 @@
"""Security headers middleware."""
from flask import request
import os
from flask import current_app, request
class SecurityHeadersMiddleware:
@@ -34,13 +35,22 @@ class SecurityHeadersMiddleware:
)
# Content Security Policy
try:
flask_env = current_app.config.get("ENV") or os.environ.get("FLASK_ENV", "production")
if flask_env == "development":
connect_src = "connect-src 'self' http://localhost:5000 http://127.0.0.1:5000"
else:
connect_src = "connect-src 'self'"
except RuntimeError:
connect_src = "connect-src 'self'"
response.headers["Content-Security-Policy"] = (
"default-src 'self'; "
"script-src 'self' 'unsafe-inline'; "
"style-src 'self' 'unsafe-inline'; "
"img-src 'self' data: https:; "
"font-src 'self' data:; "
"connect-src 'self'"
+ connect_src
)
# Referrer Policy
+124 -18
View File
@@ -1,43 +1,149 @@
"""Models package."""
from gatehouse_app.models.base import BaseModel
from gatehouse_app.models.user import User
from gatehouse_app.models.organization import Organization
from gatehouse_app.models.organization_member import OrganizationMember
from gatehouse_app.models.authentication_method import (
"""Models package.
Sub-packages
------------
models.user User, Session
models.organization Organization, OrganizationMember, Department,
DepartmentMembership, DepartmentPrincipal,
DepartmentCertPolicy, Principal, PrincipalMembership,
OrgInviteToken
models.auth AuthenticationMethod, ApplicationProviderConfig,
OrganizationProviderOverride, OAuthState,
AuditLog, PasswordResetToken, EmailVerificationToken
models.oidc OIDCClient, OIDCAuthCode, OIDCRefreshToken, OIDCSession,
OIDCTokenMetadata, OIDCAuditLog, OidcJwksKey
models.ssh_ca CA, KeyType, CertType, CaType, CAPermission,
SSHKey, SSHCertificate, CertificateStatus,
CertificateAuditLog
models.security OrganizationSecurityPolicy, UserSecurityPolicy,
MfaPolicyCompliance
All names are re-exported here so that existing code using the flat import
style (``from gatehouse_app.models import X``) or the old per-file style
(``from gatehouse_app.models.user import User``) continue to work unchanged.
"""
# ── Base ──────────────────────────────────────────────────────────────────────
from gatehouse_app.models.base import BaseModel # noqa: F401
# ── User ──────────────────────────────────────────────────────────────────────
from gatehouse_app.models.user.user import User # noqa: F401
from gatehouse_app.models.user.session import Session # noqa: F401
# ── Organization ──────────────────────────────────────────────────────────────
from gatehouse_app.models.organization.organization import Organization # noqa: F401
from gatehouse_app.models.organization.organization_member import ( # noqa: F401
OrganizationMember,
)
from gatehouse_app.models.organization.department import ( # noqa: F401
Department,
DepartmentMembership,
DepartmentPrincipal,
)
from gatehouse_app.models.organization.department_cert_policy import ( # noqa: F401
DepartmentCertPolicy,
STANDARD_EXTENSIONS,
)
from gatehouse_app.models.organization.principal import ( # noqa: F401
Principal,
PrincipalMembership,
)
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken # noqa: F401
# ── Auth ──────────────────────────────────────────────────────────────────────
from gatehouse_app.models.auth.authentication_method import ( # noqa: F401
AuthenticationMethod,
ApplicationProviderConfig,
OrganizationProviderOverride,
OAuthState,
)
from gatehouse_app.models.session import Session
from gatehouse_app.models.audit_log import AuditLog
from gatehouse_app.models.oidc_client import OIDCClient
from gatehouse_app.models.oidc_authorization_code import OIDCAuthCode
from gatehouse_app.models.oidc_refresh_token import OIDCRefreshToken
from gatehouse_app.models.oidc_session import OIDCSession
from gatehouse_app.models.oidc_token_metadata import OIDCTokenMetadata
from gatehouse_app.models.oidc_audit_log import OIDCAuditLog
from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy
from gatehouse_app.models.user_security_policy import UserSecurityPolicy
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
from gatehouse_app.models.auth.audit_log import AuditLog # noqa: F401
from gatehouse_app.models.auth.password_reset_token import PasswordResetToken # noqa: F401
from gatehouse_app.models.auth.email_verification_token import ( # noqa: F401
EmailVerificationToken,
)
# ── OIDC ──────────────────────────────────────────────────────────────────────
from gatehouse_app.models.oidc.oidc_client import OIDCClient # noqa: F401
from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode # noqa: F401
from gatehouse_app.models.oidc.oidc_refresh_token import OIDCRefreshToken # noqa: F401
from gatehouse_app.models.oidc.oidc_session import OIDCSession # noqa: F401
from gatehouse_app.models.oidc.oidc_token_metadata import OIDCTokenMetadata # noqa: F401
from gatehouse_app.models.oidc.oidc_audit_log import OIDCAuditLog # noqa: F401
from gatehouse_app.models.oidc.oidc_jwks_key import OidcJwksKey # noqa: F401
# ── SSH / CA ──────────────────────────────────────────────────────────────────
from gatehouse_app.models.ssh_ca.ca import ( # noqa: F401
CA,
KeyType,
CertType,
CaType,
CAPermission,
)
from gatehouse_app.models.ssh_ca.ssh_key import SSHKey # noqa: F401
from gatehouse_app.models.ssh_ca.ssh_certificate import ( # noqa: F401
SSHCertificate,
CertificateStatus,
)
from gatehouse_app.models.ssh_ca.certificate_audit_log import ( # noqa: F401
CertificateAuditLog,
)
# ── Security ──────────────────────────────────────────────────────────────────
from gatehouse_app.models.security.organization_security_policy import ( # noqa: F401
OrganizationSecurityPolicy,
)
from gatehouse_app.models.security.user_security_policy import ( # noqa: F401
UserSecurityPolicy,
)
from gatehouse_app.models.security.mfa_policy_compliance import ( # noqa: F401
MfaPolicyCompliance,
)
__all__ = [
# Base
"BaseModel",
# User
"User",
"Session",
# Organization
"Organization",
"OrganizationMember",
"Department",
"DepartmentMembership",
"DepartmentPrincipal",
"DepartmentCertPolicy",
"STANDARD_EXTENSIONS",
"Principal",
"PrincipalMembership",
"OrgInviteToken",
# Auth
"AuthenticationMethod",
"ApplicationProviderConfig",
"OrganizationProviderOverride",
"OAuthState",
"Session",
"AuditLog",
"PasswordResetToken",
"EmailVerificationToken",
# OIDC
"OIDCClient",
"OIDCAuthCode",
"OIDCRefreshToken",
"OIDCSession",
"OIDCTokenMetadata",
"OIDCAuditLog",
"OidcJwksKey",
# SSH / CA
"CA",
"KeyType",
"CertType",
"CaType",
"CAPermission",
"SSHKey",
"SSHCertificate",
"CertificateStatus",
"CertificateAuditLog",
# Security
"OrganizationSecurityPolicy",
"UserSecurityPolicy",
"MfaPolicyCompliance",
+20
View File
@@ -0,0 +1,20 @@
"""Auth subpackage — authentication methods, tokens, and audit logs."""
from gatehouse_app.models.auth.authentication_method import (
AuthenticationMethod,
ApplicationProviderConfig,
OrganizationProviderOverride,
OAuthState,
)
from gatehouse_app.models.auth.audit_log import AuditLog
from gatehouse_app.models.auth.password_reset_token import PasswordResetToken
from gatehouse_app.models.auth.email_verification_token import EmailVerificationToken
__all__ = [
"AuthenticationMethod",
"ApplicationProviderConfig",
"OrganizationProviderOverride",
"OAuthState",
"AuditLog",
"PasswordResetToken",
"EmailVerificationToken",
]
@@ -26,14 +26,13 @@ class AuditLog(BaseModel):
extra_data = db.Column(db.JSON, nullable=True)
description = db.Column(db.Text, nullable=True)
# Success/failure
# Outcome
success = db.Column(db.Boolean, default=True, nullable=False)
error_message = db.Column(db.Text, nullable=True)
# Relationships
user = db.relationship("User", back_populates="audit_logs")
# Indexes for common queries
__table_args__ = (
db.Index("idx_audit_user_action", "user_id", "action"),
db.Index("idx_audit_resource", "resource_type", "resource_id"),
@@ -45,9 +44,8 @@ class AuditLog(BaseModel):
return f"<AuditLog action={self.action} user_id={self.user_id}>"
@classmethod
def log(cls, action, user_id=None, **kwargs):
"""
Create an audit log entry.
def log(cls, action, user_id=None, **kwargs) -> "AuditLog":
"""Create an audit log entry.
Args:
action: AuditAction enum value
@@ -1,4 +1,4 @@
"""Authentication method model."""
"""Authentication method model — user credentials and OAuth provider config."""
from datetime import datetime, timedelta, timezone
import secrets
from gatehouse_app.extensions import db
@@ -35,7 +35,6 @@ class AuthenticationMethod(BaseModel):
# Relationships
user = db.relationship("User", back_populates="authentication_methods")
# Ensure unique provider combinations
__table_args__ = (
db.Index("idx_user_method", "user_id", "method_type"),
db.UniqueConstraint(
@@ -45,13 +44,15 @@ class AuthenticationMethod(BaseModel):
def __repr__(self):
"""String representation of AuthenticationMethod."""
return f"<AuthenticationMethod user_id={self.user_id} type={self.method_type}>"
return (
f"<AuthenticationMethod user_id={self.user_id} type={self.method_type}>"
)
def is_password(self):
def is_password(self) -> bool:
"""Check if this is a password authentication method."""
return self.method_type == AuthMethodType.PASSWORD
def is_oauth(self):
def is_oauth(self) -> bool:
"""Check if this is an OAuth authentication method."""
return self.method_type in [
AuthMethodType.GOOGLE,
@@ -59,32 +60,32 @@ class AuthenticationMethod(BaseModel):
AuthMethodType.MICROSOFT,
]
def is_totp(self):
def is_totp(self) -> bool:
"""Check if this is a TOTP authentication method."""
return self.method_type == AuthMethodType.TOTP
def is_webauthn(self):
def is_webauthn(self) -> bool:
"""Check if this is a WebAuthn authentication method."""
return self.method_type == AuthMethodType.WEBAUTHN
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
# Always exclude password hash and TOTP secrets
exclude.append("password_hash")
exclude.append("totp_secret")
exclude.append("totp_backup_codes")
# Always exclude credential material
for field in ("password_hash", "totp_secret", "totp_backup_codes"):
if field not in exclude:
exclude.append(field)
return super().to_dict(exclude=exclude)
def to_webauthn_dict(self):
"""Convert WebAuthn credential to public dictionary.
Returns:
Dictionary with safe-to-expose credential information.
Dictionary with safe-to-expose credential information, or None.
"""
if not self.is_webauthn() or not self.provider_data:
return None
data = self.provider_data
return {
"id": data.get("credential_id"),
@@ -98,26 +99,26 @@ class AuthenticationMethod(BaseModel):
class ApplicationProviderConfig(BaseModel):
"""Application-wide OAuth provider configuration.
This model stores OAuth provider credentials at the application level,
allowing users to authenticate without needing to specify an organization first.
Stores OAuth provider credentials at the application level, allowing users
to authenticate without needing to specify an organization first.
"""
__tablename__ = "application_provider_configs"
# Provider identification
provider_type = db.Column(db.String(50), nullable=False, unique=True, index=True)
# OAuth credentials (encrypted)
# OAuth credentials (client_secret encrypted at rest)
client_id = db.Column(db.String(255), nullable=False)
client_secret_encrypted = db.Column(db.String(512), nullable=True)
# Provider status
is_enabled = db.Column(db.Boolean, default=True, nullable=False)
# Default redirect URL
default_redirect_url = db.Column(db.String(2048), nullable=True)
# Provider-specific settings (JSON)
additional_config = db.Column(db.JSON, nullable=True)
@@ -126,28 +127,34 @@ class ApplicationProviderConfig(BaseModel):
"OrganizationProviderOverride",
back_populates="application_config",
foreign_keys="OrganizationProviderOverride.provider_type",
primaryjoin="ApplicationProviderConfig.provider_type==OrganizationProviderOverride.provider_type",
cascade="all, delete-orphan"
primaryjoin=(
"ApplicationProviderConfig.provider_type"
"==OrganizationProviderOverride.provider_type"
),
cascade="all, delete-orphan",
)
def __repr__(self):
"""String representation of ApplicationProviderConfig."""
return f"<ApplicationProviderConfig provider={self.provider_type} enabled={self.is_enabled}>"
return (
f"<ApplicationProviderConfig provider={self.provider_type} "
f"enabled={self.is_enabled}>"
)
def set_client_secret(self, plaintext_secret: str):
def set_client_secret(self, plaintext_secret: str) -> None:
"""Encrypt and store client secret.
Args:
plaintext_secret: The plaintext OAuth client secret
"""
if plaintext_secret:
self.client_secret_encrypted = encrypt(plaintext_secret)
def get_client_secret(self) -> str:
def get_client_secret(self) -> str | None:
"""Decrypt and return client secret.
Returns:
The plaintext OAuth client secret
The plaintext OAuth client secret, or None if not set.
"""
if self.client_secret_encrypted:
return decrypt(self.client_secret_encrypted)
@@ -156,37 +163,38 @@ class ApplicationProviderConfig(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
# Always exclude encrypted client secret
exclude.append("client_secret_encrypted")
if "client_secret_encrypted" not in exclude:
exclude.append("client_secret_encrypted")
return super().to_dict(exclude=exclude)
class OrganizationProviderOverride(BaseModel):
"""Organization-specific OAuth configuration overrides.
This model allows organizations to override application-level OAuth settings
for enterprise SSO scenarios or custom provider configurations.
Allows organizations to override application-level OAuth settings for
enterprise SSO scenarios or custom provider configurations.
"""
__tablename__ = "organization_provider_overrides"
# References
organization_id = db.Column(
db.String(36), db.ForeignKey("organizations.id"),
nullable=False, index=True
db.String(36),
db.ForeignKey("organizations.id"),
nullable=False,
index=True,
)
provider_type = db.Column(db.String(50), nullable=False, index=True)
# Override OAuth credentials (encrypted, nullable - only if overriding)
# Override OAuth credentials (encrypted, nullable only set when overriding)
client_id = db.Column(db.String(255), nullable=True)
client_secret_encrypted = db.Column(db.String(512), nullable=True)
# Provider status
is_enabled = db.Column(db.Boolean, default=True, nullable=False)
# Redirect URL override
redirect_url_override = db.Column(db.String(2048), nullable=True)
# Provider-specific settings override (JSON)
additional_config = db.Column(db.JSON, nullable=True)
@@ -196,37 +204,33 @@ class OrganizationProviderOverride(BaseModel):
"ApplicationProviderConfig",
back_populates="organization_overrides",
foreign_keys=[provider_type],
primaryjoin="ApplicationProviderConfig.provider_type==OrganizationProviderOverride.provider_type",
viewonly=True
primaryjoin=(
"ApplicationProviderConfig.provider_type"
"==OrganizationProviderOverride.provider_type"
),
viewonly=True,
)
# Unique constraint on (organization_id, provider_type)
__table_args__ = (
db.UniqueConstraint(
"organization_id", "provider_type",
name="uix_org_provider_type"
"organization_id", "provider_type", name="uix_org_provider_type"
),
)
def __repr__(self):
"""String representation of OrganizationProviderOverride."""
return f"<OrganizationProviderOverride org={self.organization_id} provider={self.provider_type}>"
return (
f"<OrganizationProviderOverride org={self.organization_id} "
f"provider={self.provider_type}>"
)
def set_client_secret(self, plaintext_secret: str):
"""Encrypt and store client secret override.
Args:
plaintext_secret: The plaintext OAuth client secret
"""
def set_client_secret(self, plaintext_secret: str) -> None:
"""Encrypt and store client secret override."""
if plaintext_secret:
self.client_secret_encrypted = encrypt(plaintext_secret)
def get_client_secret(self) -> str:
"""Decrypt and return client secret override.
Returns:
The plaintext OAuth client secret
"""
def get_client_secret(self) -> str | None:
"""Decrypt and return client secret override."""
if self.client_secret_encrypted:
return decrypt(self.client_secret_encrypted)
return None
@@ -234,53 +238,52 @@ class OrganizationProviderOverride(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
# Always exclude encrypted client secret
exclude.append("client_secret_encrypted")
if "client_secret_encrypted" not in exclude:
exclude.append("client_secret_encrypted")
return super().to_dict(exclude=exclude)
class OAuthState(BaseModel):
"""OAuth flow state tracking.
This model tracks OAuth authentication flow state, including PKCE parameters
and organization context (which is now optional to support login flows where
the organization isn't known until after authentication).
Tracks OAuth authentication flow state, including PKCE parameters and
organization context (which is optional to support login flows where the
organization isn't known until after authentication).
"""
__tablename__ = "oauth_states"
# OAuth state parameter (unique, used for CSRF protection)
state = db.Column(db.String(64), unique=True, nullable=False, index=True)
# Flow type: "login", "register", "link"
flow_type = db.Column(db.String(50), nullable=False)
# Provider type
provider_type = db.Column(db.String(50), nullable=False)
# User context (optional - not set for login/register flows)
# User context (optional not set for login/register flows)
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True)
# Organization context (NOW OPTIONAL - for SSO discovery or post-auth)
# Organization context (optional — for SSO discovery or post-auth)
organization_id = db.Column(
db.String(36), db.ForeignKey("organizations.id"),
nullable=True, index=True
db.String(36), db.ForeignKey("organizations.id"), nullable=True, index=True
)
# PKCE parameters
nonce = db.Column(db.String(128), nullable=True)
code_verifier = db.Column(db.String(128), nullable=True)
code_challenge = db.Column(db.String(128), nullable=True)
# OAuth parameters
redirect_uri = db.Column(db.String(2048), nullable=True)
# Post-auth redirect (for frontend routing)
return_url = db.Column(db.String(2048), nullable=True)
# Additional state data
extra_data = db.Column(db.JSON, nullable=True)
# Expiration and usage tracking
expires_at = db.Column(db.DateTime, nullable=False, index=True)
used = db.Column(db.Boolean, default=False, nullable=False)
@@ -291,7 +294,10 @@ class OAuthState(BaseModel):
def __repr__(self):
"""String representation of OAuthState."""
return f"<OAuthState state={self.state[:8]}... flow={self.flow_type} provider={self.provider_type}>"
return (
f"<OAuthState state={self.state[:8]}... "
f"flow={self.flow_type} provider={self.provider_type}>"
)
@classmethod
def create_state(
@@ -306,10 +312,10 @@ class OAuthState(BaseModel):
code_challenge: str = None,
nonce: str = None,
extra_data: dict = None,
lifetime_seconds: int = 600
):
"""Create a new OAuth state with auto-generated state parameter.
lifetime_seconds: int = 600,
) -> "OAuthState":
"""Create a new OAuth state with an auto-generated state parameter.
Args:
flow_type: Type of flow ("login", "register", "link")
provider_type: OAuth provider type
@@ -322,13 +328,13 @@ class OAuthState(BaseModel):
nonce: OpenID Connect nonce
extra_data: Additional state data
lifetime_seconds: How long the state is valid (default 10 minutes)
Returns:
New OAuthState instance
"""
state = secrets.token_urlsafe(32)
expires_at = datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds)
oauth_state = cls(
state=state,
flow_type=flow_type,
@@ -342,31 +348,30 @@ class OAuthState(BaseModel):
nonce=nonce,
extra_data=extra_data,
expires_at=expires_at,
used=False
used=False,
)
oauth_state.save()
return oauth_state
def is_valid(self) -> bool:
"""Check if the OAuth state is still valid.
Returns:
True if state hasn't expired and hasn't been used
True if state hasn't expired and hasn't been used.
"""
now = datetime.now(timezone.utc)
# Make expires_at timezone-aware if it's naive (database returns naive datetimes)
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return not self.used and expires_at > now
def mark_used(self):
def mark_used(self) -> None:
"""Mark the state as used to prevent replay attacks."""
self.used = True
self.save()
@classmethod
def cleanup_expired(cls):
def cleanup_expired(cls) -> None:
"""Remove expired OAuth states."""
now = datetime.now(timezone.utc)
cls.query.filter(cls.expires_at < now).delete()
@@ -375,6 +380,7 @@ class OAuthState(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
# Exclude code_verifier as it's sensitive
exclude.append("code_verifier")
# code_verifier must never be exposed
if "code_verifier" not in exclude:
exclude.append("code_verifier")
return super().to_dict(exclude=exclude)
@@ -0,0 +1,68 @@
"""Email verification token model."""
import secrets
from datetime import datetime, timezone, timedelta
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class EmailVerificationToken(BaseModel):
"""Single-use token for verifying a user's email address."""
__tablename__ = "email_verification_tokens"
user_id = db.Column(
db.String(36),
db.ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
token = db.Column(db.String(128), unique=True, nullable=False, index=True)
expires_at = db.Column(db.DateTime, nullable=False)
used_at = db.Column(db.DateTime, nullable=True)
user = db.relationship(
"User",
backref=db.backref("email_verification_tokens", cascade="all, delete-orphan"),
)
@classmethod
def generate(cls, user_id: str, ttl_hours: int = 24) -> "EmailVerificationToken":
"""Create a new verification token for a user.
Any existing unused tokens for this user are invalidated first.
"""
cls.query.filter_by(user_id=user_id, used_at=None).delete()
db.session.flush()
token_value = secrets.token_urlsafe(48)
instance = cls(
user_id=user_id,
token=token_value,
expires_at=datetime.now(timezone.utc) + timedelta(hours=ttl_hours),
)
db.session.add(instance)
db.session.commit()
return instance
@property
def is_valid(self) -> bool:
"""Return True if the token has not been used and has not expired."""
if self.used_at is not None:
return False
now = datetime.now(timezone.utc)
expires = self.expires_at
if expires.tzinfo is None:
expires = expires.replace(tzinfo=timezone.utc)
return now < expires
def consume(self) -> None:
"""Mark the token as used."""
self.used_at = datetime.now(timezone.utc)
db.session.commit()
def __repr__(self) -> str:
return (
f"<EmailVerificationToken user_id={self.user_id} "
f"used={self.used_at is not None}>"
)
@@ -0,0 +1,69 @@
"""Password reset token model."""
import secrets
from datetime import datetime, timezone, timedelta
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class PasswordResetToken(BaseModel):
"""Single-use token for resetting a user's password."""
__tablename__ = "password_reset_tokens"
user_id = db.Column(
db.String(36),
db.ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
token = db.Column(db.String(128), unique=True, nullable=False, index=True)
expires_at = db.Column(db.DateTime, nullable=False)
used_at = db.Column(db.DateTime, nullable=True)
user = db.relationship(
"User",
backref=db.backref("password_reset_tokens", cascade="all, delete-orphan"),
)
@classmethod
def generate(cls, user_id: str, ttl_hours: int = 2) -> "PasswordResetToken":
"""Create a new password reset token for a user.
Any existing unused tokens for this user are invalidated first.
"""
# Invalidate any existing unused tokens for this user
cls.query.filter_by(user_id=user_id, used_at=None).delete()
db.session.flush()
token_value = secrets.token_urlsafe(48)
instance = cls(
user_id=user_id,
token=token_value,
expires_at=datetime.now(timezone.utc) + timedelta(hours=ttl_hours),
)
db.session.add(instance)
db.session.commit()
return instance
@property
def is_valid(self) -> bool:
"""Return True if the token has not been used and has not expired."""
if self.used_at is not None:
return False
now = datetime.now(timezone.utc)
expires = self.expires_at
if expires.tzinfo is None:
expires = expires.replace(tzinfo=timezone.utc)
return now < expires
def consume(self) -> None:
"""Mark the token as used."""
self.used_at = datetime.now(timezone.utc)
db.session.commit()
def __repr__(self) -> str:
return (
f"<PasswordResetToken user_id={self.user_id} "
f"used={self.used_at is not None}>"
)
+4 -1
View File
@@ -82,7 +82,10 @@ class BaseModel(db.Model):
if column.name not in exclude:
value = getattr(self, column.name)
if isinstance(value, datetime):
result[column.name] = value.isoformat()
if value.tzinfo is None:
result[column.name] = value.isoformat() + "Z"
else:
result[column.name] = value.astimezone(timezone.utc).isoformat().replace("+00:00", "Z")
else:
result[column.name] = value
return result
+18
View File
@@ -0,0 +1,18 @@
"""OIDC subpackage — clients, tokens, sessions, and audit logs."""
from gatehouse_app.models.oidc.oidc_client import OIDCClient
from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode
from gatehouse_app.models.oidc.oidc_refresh_token import OIDCRefreshToken
from gatehouse_app.models.oidc.oidc_session import OIDCSession
from gatehouse_app.models.oidc.oidc_token_metadata import OIDCTokenMetadata
from gatehouse_app.models.oidc.oidc_audit_log import OIDCAuditLog
from gatehouse_app.models.oidc.oidc_jwks_key import OidcJwksKey
__all__ = [
"OIDCClient",
"OIDCAuthCode",
"OIDCRefreshToken",
"OIDCSession",
"OIDCTokenMetadata",
"OIDCAuditLog",
"OidcJwksKey",
]
@@ -1,5 +1,4 @@
"""OIDC Audit Log model for comprehensive OIDC event tracking."""
from datetime import datetime
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
@@ -7,8 +6,7 @@ from gatehouse_app.models.base import BaseModel
class OIDCAuditLog(BaseModel):
"""OIDC Audit Log model for comprehensive OIDC event tracking.
This model logs all OIDC-related events for security, compliance,
and debugging purposes.
Logs all OIDC-related events for security, compliance, and debugging.
"""
__tablename__ = "oidc_audit_logs"
@@ -46,16 +44,29 @@ class OIDCAuditLog(BaseModel):
def __repr__(self):
"""String representation of OIDCAuditLog."""
status = "success" if self.success else "failed"
return f"<OIDCAuditLog event={self.event_type} status={status} client={self.client_id}>"
return (
f"<OIDCAuditLog event={self.event_type} "
f"status={status} client={self.client_id}>"
)
@classmethod
def log_event(cls, event_type, client_id=None, user_id=None, success=True,
error_code=None, error_description=None, ip_address=None,
user_agent=None, request_id=None, event_metadata=None):
def log_event(
cls,
event_type: str,
client_id: str = None,
user_id: str = None,
success: bool = True,
error_code: str = None,
error_description: str = None,
ip_address: str = None,
user_agent: str = None,
request_id: str = None,
event_metadata: dict = None,
) -> "OIDCAuditLog":
"""Log an OIDC event.
Args:
event_type: Type of event (e.g., "authorization_request", "token_issue")
event_type: Type of event (e.g., "authorization_request")
client_id: The OIDC client ID
user_id: The user ID
success: Whether the event was successful
@@ -86,9 +97,19 @@ class OIDCAuditLog(BaseModel):
return log
@classmethod
def log_authorization_request(cls, client_id, user_id, redirect_uri, scope,
ip_address=None, user_agent=None, request_id=None,
success=True, error_code=None, error_description=None):
def log_authorization_request(
cls,
client_id: str,
user_id: str,
redirect_uri: str,
scope,
ip_address: str = None,
user_agent: str = None,
request_id: str = None,
success: bool = True,
error_code: str = None,
error_description: str = None,
) -> "OIDCAuditLog":
"""Log an authorization request event."""
return cls.log_event(
event_type="authorization_request",
@@ -100,15 +121,19 @@ class OIDCAuditLog(BaseModel):
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
event_metadata={
"redirect_uri": redirect_uri,
"scope": scope,
}
event_metadata={"redirect_uri": redirect_uri, "scope": scope},
)
@classmethod
def log_token_issue(cls, client_id, user_id, token_type,
ip_address=None, user_agent=None, request_id=None):
def log_token_issue(
cls,
client_id: str,
user_id: str,
token_type: str,
ip_address: str = None,
user_agent: str = None,
request_id: str = None,
) -> "OIDCAuditLog":
"""Log a token issuance event."""
return cls.log_event(
event_type="token_issue",
@@ -118,12 +143,20 @@ class OIDCAuditLog(BaseModel):
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
event_metadata={"token_type": token_type}
event_metadata={"token_type": token_type},
)
@classmethod
def log_token_revocation(cls, client_id, user_id, token_type, reason=None,
ip_address=None, user_agent=None, request_id=None):
def log_token_revocation(
cls,
client_id: str,
user_id: str,
token_type: str,
reason: str = None,
ip_address: str = None,
user_agent: str = None,
request_id: str = None,
) -> "OIDCAuditLog":
"""Log a token revocation event."""
return cls.log_event(
event_type="token_revocation",
@@ -133,15 +166,19 @@ class OIDCAuditLog(BaseModel):
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
event_metadata={
"token_type": token_type,
"reason": reason,
}
event_metadata={"token_type": token_type, "reason": reason},
)
@classmethod
def log_authentication_failure(cls, client_id, error_code, error_description,
ip_address=None, user_agent=None, request_id=None):
def log_authentication_failure(
cls,
client_id: str,
error_code: str,
error_description: str,
ip_address: str = None,
user_agent: str = None,
request_id: str = None,
) -> "OIDCAuditLog":
"""Log an authentication failure event."""
return cls.log_event(
event_type="authentication_failure",
@@ -155,7 +192,7 @@ class OIDCAuditLog(BaseModel):
)
@classmethod
def get_events_for_user(cls, user_id, limit=100):
def get_events_for_user(cls, user_id: str, limit: int = 100) -> list:
"""Get audit events for a user.
Args:
@@ -165,13 +202,15 @@ class OIDCAuditLog(BaseModel):
Returns:
List of OIDCAuditLog instances
"""
return cls.query.filter_by(user_id=user_id, deleted_at=None)\
.order_by(cls.created_at.desc())\
.limit(limit)\
return (
cls.query.filter_by(user_id=user_id, deleted_at=None)
.order_by(cls.created_at.desc())
.limit(limit)
.all()
)
@classmethod
def get_events_for_client(cls, client_id, limit=100):
def get_events_for_client(cls, client_id: str, limit: int = 100) -> list:
"""Get audit events for a client.
Args:
@@ -181,14 +220,22 @@ class OIDCAuditLog(BaseModel):
Returns:
List of OIDCAuditLog instances
"""
return cls.query.filter_by(client_id=client_id, deleted_at=None)\
.order_by(cls.created_at.desc())\
.limit(limit)\
return (
cls.query.filter_by(client_id=client_id, deleted_at=None)
.order_by(cls.created_at.desc())
.limit(limit)
.all()
)
@classmethod
def get_failed_events(cls, client_id=None, user_id=None, start_date=None,
end_date=None, limit=100):
def get_failed_events(
cls,
client_id: str = None,
user_id: str = None,
start_date=None,
end_date=None,
limit: int = 100,
) -> list:
"""Get failed audit events.
Args:
@@ -210,22 +257,8 @@ class OIDCAuditLog(BaseModel):
query = query.filter(cls.created_at >= start_date)
if end_date:
query = query.filter(cls.created_at <= end_date)
return query.order_by(cls.created_at.desc()).limit(limit).all()
def to_dict(self, exclude=None):
"""Convert to dictionary."""
return super().to_dict(exclude=exclude)
# Add relationship back to User model
from gatehouse_app.models.user import User
User.oidc_audit_logs = db.relationship(
"OIDCAuditLog", back_populates="user", cascade="all, delete-orphan"
)
# Add relationship back to OIDCClient model
from gatehouse_app.models.oidc_client import OIDCClient
OIDCClient.audit_logs = db.relationship(
"OIDCAuditLog", back_populates="client", cascade="all, delete-orphan"
)
@@ -1,14 +1,14 @@
"""OIDC Authorization Code model for auth code flow."""
"""OIDC Authorization Code model for the authorization code grant flow."""
from datetime import datetime, timedelta, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class OIDCAuthCode(BaseModel):
"""OIDC Authorization Code model for authorization code flow.
"""OIDC Authorization Code model for the authorization code grant flow.
Authorization codes are single-use, short-lived codes used in the
authorization code grant flow. The code is hashed for security.
Authorization codes are single-use, short-lived codes. The code itself is
hashed before storage so that a database breach cannot replay codes.
"""
__tablename__ = "oidc_authorization_codes"
@@ -26,9 +26,9 @@ class OIDCAuthCode(BaseModel):
# Request parameters
redirect_uri = db.Column(db.String(512), nullable=False)
scope = db.Column(db.JSON, nullable=True) # Requested scopes
nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation
code_verifier = db.Column(db.String(255), nullable=True) # For PKCE
scope = db.Column(db.JSON, nullable=True)
nonce = db.Column(db.String(255), nullable=True)
code_verifier = db.Column(db.String(255), nullable=True)
# Status tracking
expires_at = db.Column(db.DateTime, nullable=False, index=True)
@@ -39,37 +39,48 @@ class OIDCAuthCode(BaseModel):
ip_address = db.Column(db.String(45), nullable=True)
user_agent = db.Column(db.Text, nullable=True)
# Relationships
# Relationships — back_populates declared on User and OIDCClient
client = db.relationship("OIDCClient", back_populates="authorization_codes")
user = db.relationship("User", back_populates="oidc_auth_codes")
def __repr__(self):
"""String representation of OIDCAuthCode."""
return f"<OIDCAuthCode client_id={self.client_id} user_id={self.user_id} used={self.is_used}>"
return (
f"<OIDCAuthCode client_id={self.client_id} "
f"user_id={self.user_id} used={self.is_used}>"
)
def is_expired(self):
def is_expired(self) -> bool:
"""Check if the authorization code has expired."""
# Handle both timezone-aware and timezone-naive expires_at values
expires_at = self.expires_at
if expires_at.tzinfo is None:
# Make naive datetime timezone-aware (UTC)
expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at
def is_valid(self):
def is_valid(self) -> bool:
"""Check if the authorization code is valid for use."""
return not self.is_used and not self.is_expired() and self.deleted_at is None
def mark_as_used(self):
def mark_as_used(self) -> None:
"""Mark the authorization code as used."""
self.is_used = True
self.used_at = datetime.now(timezone.utc)
db.session.commit()
@classmethod
def create_code(cls, client_id, user_id, code_hash, redirect_uri, scope=None,
nonce=None, code_verifier=None, ip_address=None, user_agent=None,
lifetime_seconds=600):
def create_code(
cls,
client_id: str,
user_id: str,
code_hash: str,
redirect_uri: str,
scope=None,
nonce: str = None,
code_verifier: str = None,
ip_address: str = None,
user_agent: str = None,
lifetime_seconds: int = 600,
) -> "OIDCAuthCode":
"""Create a new authorization code.
Args:
@@ -79,7 +90,7 @@ class OIDCAuthCode(BaseModel):
redirect_uri: The redirect URI
scope: Requested scopes
nonce: OIDC nonce
code_verifier: PKCE code verifier
code_verifier: PKCE code verifier (stored hashed server-side)
ip_address: Client IP address
user_agent: Client user agent
lifetime_seconds: Code lifetime in seconds (default 10 minutes)
@@ -106,20 +117,7 @@ class OIDCAuthCode(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
# Always exclude code hash
exclude.append("code_hash")
exclude.append("code_verifier")
for field in ("code_hash", "code_verifier"):
if field not in exclude:
exclude.append(field)
return super().to_dict(exclude=exclude)
# Add relationship back to User model
from gatehouse_app.models.user import User
User.oidc_auth_codes = db.relationship(
"OIDCAuthCode", back_populates="user", cascade="all, delete-orphan"
)
# Add relationship back to OIDCClient model
from gatehouse_app.models.oidc_client import OIDCClient
OIDCClient.authorization_codes = db.relationship(
"OIDCAuthCode", back_populates="client", cascade="all, delete-orphan"
)
@@ -17,10 +17,10 @@ class OIDCClient(BaseModel):
client_secret_hash = db.Column(db.String(255), nullable=False)
# OAuth/OIDC configuration
redirect_uris = db.Column(db.JSON, nullable=False) # List of allowed redirect URIs
grant_types = db.Column(db.JSON, nullable=False) # List of allowed grant types
response_types = db.Column(db.JSON, nullable=False) # List of allowed response types
scopes = db.Column(db.JSON, nullable=False) # List of allowed scopes
redirect_uris = db.Column(db.JSON, nullable=False) # Allowed redirect URIs
grant_types = db.Column(db.JSON, nullable=False) # Allowed grant types
response_types = db.Column(db.JSON, nullable=False) # Allowed response types
scopes = db.Column(db.JSON, nullable=False) # Allowed scopes
# Client metadata
logo_uri = db.Column(db.String(512), nullable=True)
@@ -41,6 +41,23 @@ class OIDCClient(BaseModel):
# Relationships
organization = db.relationship("Organization", back_populates="oidc_clients")
# OIDC sub-resource relationships (declared here, not monkey-patched elsewhere)
authorization_codes = db.relationship(
"OIDCAuthCode", back_populates="client", cascade="all, delete-orphan"
)
refresh_tokens = db.relationship(
"OIDCRefreshToken", back_populates="client", cascade="all, delete-orphan"
)
oidc_sessions = db.relationship(
"OIDCSession", back_populates="client", cascade="all, delete-orphan"
)
token_metadata = db.relationship(
"OIDCTokenMetadata", back_populates="client", cascade="all, delete-orphan"
)
audit_logs = db.relationship(
"OIDCAuditLog", back_populates="client", cascade="all, delete-orphan"
)
def __repr__(self):
"""String representation of OIDCClient."""
return f"<OIDCClient {self.name} client_id={self.client_id}>"
@@ -48,22 +65,22 @@ class OIDCClient(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
# Always exclude client secret
exclude.append("client_secret_hash")
if "client_secret_hash" not in exclude:
exclude.append("client_secret_hash")
return super().to_dict(exclude=exclude)
def has_grant_type(self, grant_type):
def has_grant_type(self, grant_type) -> bool:
"""Check if client supports a specific grant type."""
return grant_type in self.grant_types
def has_response_type(self, response_type):
def has_response_type(self, response_type) -> bool:
"""Check if client supports a specific response type."""
return response_type in self.response_types
def is_redirect_uri_allowed(self, redirect_uri):
def is_redirect_uri_allowed(self, redirect_uri: str) -> bool:
"""Check if a redirect URI is allowed for this client."""
return redirect_uri in self.redirect_uris
def has_scope(self, scope):
def has_scope(self, scope: str) -> bool:
"""Check if client is allowed to request a specific scope."""
return scope in self.scopes
@@ -0,0 +1,76 @@
"""OIDC JWKS Key model for persisting signing keys."""
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class OidcJwksKey(BaseModel):
"""OIDC JWKS Key model for persisting JSON Web Key Set signing keys.
Stores RSA/ECDSA key pairs used for signing OIDC tokens. Multiple keys can
be stored to support key rotation scenarios.
Attributes:
kid: Unique key ID used in JWT ``kid`` header
key_type: Type of key (e.g., "RSA", "EC")
private_key: PEM-encoded private key (never exposed in API responses)
public_key: PEM-encoded public key
algorithm: Signing algorithm (e.g., "RS256", "ES256")
is_active: Whether this key is currently used for signing/verification
is_primary: Whether this is the primary signing key
expires_at: Optional expiry for key rotation enforcement
"""
__tablename__ = "oidc_jwks_keys"
# Override the default UUID id with integer primary key for JWKS key sets
id = db.Column(db.Integer, primary_key=True)
expires_at = db.Column(db.DateTime, nullable=True)
# Key identification and type
kid = db.Column(db.String(255), unique=True, nullable=False, index=True)
key_type = db.Column(db.String(50), nullable=False) # e.g., "RSA", "EC"
algorithm = db.Column(db.String(50), nullable=False) # e.g., "RS256", "ES256"
# Key material (PEM-encoded) — private_key must never be returned by API
private_key = db.Column(db.Text, nullable=False)
public_key = db.Column(db.Text, nullable=False)
# Key status
is_active = db.Column(db.Boolean, default=True, nullable=False)
is_primary = db.Column(db.Boolean, default=False, nullable=False)
def __repr__(self):
"""String representation of OidcJwksKey."""
return (
f"<OidcJwksKey kid={self.kid} "
f"key_type={self.key_type} algorithm={self.algorithm}>"
)
def to_dict(self, exclude_private_key: bool = True):
"""Convert model to dictionary.
Args:
exclude_private_key: If True (default), excludes the private key.
Returns:
Dictionary representation of the model
"""
exclude = ["private_key"] if exclude_private_key else []
return super().to_dict(exclude=exclude)
@classmethod
def get_active_keys(cls) -> list:
"""Get all active keys for signing operations."""
return cls.query.filter_by(is_active=True).all()
@classmethod
def get_primary_key(cls) -> "OidcJwksKey | None":
"""Get the primary signing key."""
return cls.query.filter_by(is_primary=True).first()
@classmethod
def get_key_by_kid(cls, kid: str) -> "OidcJwksKey | None":
"""Get an active key by its key ID."""
return cls.query.filter_by(kid=kid, is_active=True).first()
@@ -1,5 +1,5 @@
"""OIDC Refresh Token model for token rotation."""
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
@@ -8,7 +8,8 @@ class OIDCRefreshToken(BaseModel):
"""OIDC Refresh Token model for token refresh and rotation.
Refresh tokens are long-lived credentials used to obtain new access tokens.
They support token rotation for enhanced security.
They support token rotation for enhanced security each use invalidates
the old token and issues a new one.
"""
__tablename__ = "oidc_refresh_tokens"
@@ -21,16 +22,14 @@ class OIDCRefreshToken(BaseModel):
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
)
# Token (hashed for security)
# Token (hashed for security — never store plaintext refresh tokens)
token_hash = db.Column(db.String(255), nullable=False, unique=True, index=True)
# Associated access token ID (stores JWT JTI string — no FK to sessions)
access_token_id = db.Column(
db.String(255), nullable=True, index=True
)
# Associated access token JTI (no FK — stored as string for lightweight lookup)
access_token_id = db.Column(db.String(255), nullable=True, index=True)
# Token scope
scope = db.Column(db.JSON, nullable=True) # Granted scopes
scope = db.Column(db.JSON, nullable=True)
# Timing
expires_at = db.Column(db.DateTime, nullable=False, index=True)
@@ -40,7 +39,7 @@ class OIDCRefreshToken(BaseModel):
revoked_reason = db.Column(db.String(255), nullable=True)
# Token rotation metadata
previous_token_hash = db.Column(db.String(255), nullable=True) # For rotation
previous_token_hash = db.Column(db.String(255), nullable=True)
rotation_count = db.Column(db.Integer, default=0, nullable=False)
# Request metadata
@@ -53,25 +52,27 @@ class OIDCRefreshToken(BaseModel):
def __repr__(self):
"""String representation of OIDCRefreshToken."""
return f"<OIDCRefreshToken client_id={self.client_id} user_id={self.user_id} revoked={self.is_revoked()}>"
return (
f"<OIDCRefreshToken client_id={self.client_id} "
f"user_id={self.user_id} revoked={self.is_revoked()}>"
)
def is_expired(self):
def is_expired(self) -> bool:
"""Check if the refresh token has expired."""
# Handle both timezone-aware and timezone-naive expires_at values
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at
def is_revoked(self):
def is_revoked(self) -> bool:
"""Check if the refresh token has been revoked."""
return self.revoked_at is not None
def is_valid(self):
def is_valid(self) -> bool:
"""Check if the refresh token is valid for use."""
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
def revoke(self, reason=None):
def revoke(self, reason: str = None) -> None:
"""Revoke the refresh token.
Args:
@@ -81,8 +82,8 @@ class OIDCRefreshToken(BaseModel):
self.revoked_reason = reason
db.session.commit()
def rotate(self, new_token_hash):
"""Rotate the refresh token (invalidate old, create new).
def rotate(self, new_token_hash: str) -> "OIDCRefreshToken":
"""Rotate the refresh token invalidate the old hash, store the new one.
Args:
new_token_hash: Hash of the new refresh token
@@ -90,20 +91,25 @@ class OIDCRefreshToken(BaseModel):
Returns:
self for chaining
"""
# Store reference to old token
self.previous_token_hash = self.token_hash
self.token_hash = new_token_hash
self.rotation_count += 1
# Extend expiration on rotation
from datetime import timedelta
self.expires_at = datetime.now(timezone.utc) + timedelta(days=30)
db.session.commit()
return self
@classmethod
def create_token(cls, client_id, user_id, token_hash, scope=None,
access_token_id=None, ip_address=None, user_agent=None,
lifetime_seconds=2592000):
def create_token(
cls,
client_id: str,
user_id: str,
token_hash: str,
scope=None,
access_token_id: str = None,
ip_address: str = None,
user_agent: str = None,
lifetime_seconds: int = 2592000,
) -> "OIDCRefreshToken":
"""Create a new refresh token.
Args:
@@ -111,7 +117,7 @@ class OIDCRefreshToken(BaseModel):
user_id: The user ID
token_hash: Hashed refresh token
scope: Granted scopes
access_token_id: Associated access token ID
access_token_id: Associated access token JTI
ip_address: Client IP address
user_agent: Client user agent
lifetime_seconds: Token lifetime in seconds (default 30 days)
@@ -119,7 +125,6 @@ class OIDCRefreshToken(BaseModel):
Returns:
OIDCRefreshToken instance
"""
from datetime import timedelta
token = cls(
client_id=client_id,
user_id=user_id,
@@ -137,20 +142,7 @@ class OIDCRefreshToken(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
# Always exclude token hashes
exclude.append("token_hash")
exclude.append("previous_token_hash")
for field in ("token_hash", "previous_token_hash"):
if field not in exclude:
exclude.append(field)
return super().to_dict(exclude=exclude)
# Add relationship back to User model
from gatehouse_app.models.user import User
User.oidc_refresh_tokens = db.relationship(
"OIDCRefreshToken", back_populates="user", cascade="all, delete-orphan"
)
# Add relationship back to OIDCClient model
from gatehouse_app.models.oidc_client import OIDCClient
OIDCClient.refresh_tokens = db.relationship(
"OIDCRefreshToken", back_populates="client", cascade="all, delete-orphan"
)
@@ -1,5 +1,7 @@
"""OIDC Session model for OIDC session tracking."""
from datetime import datetime, timezone
import hashlib
import base64
from datetime import datetime, timedelta, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
@@ -7,8 +9,8 @@ from gatehouse_app.models.base import BaseModel
class OIDCSession(BaseModel):
"""OIDC Session model for tracking OIDC authentication sessions.
This model tracks the state during the OIDC authentication flow,
including PKCE parameters and nonce validation.
Tracks the state during the OIDC authorization flow, including PKCE
parameters and nonce validation.
"""
__tablename__ = "oidc_sessions"
@@ -25,11 +27,11 @@ class OIDCSession(BaseModel):
# State management
state = db.Column(db.String(255), nullable=False, index=True)
nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation
nonce = db.Column(db.String(255), nullable=True)
# Authorization request parameters
redirect_uri = db.Column(db.String(512), nullable=False)
scope = db.Column(db.JSON, nullable=True) # Requested scopes
scope = db.Column(db.JSON, nullable=True)
# PKCE parameters
code_challenge = db.Column(db.String(255), nullable=True)
@@ -45,50 +47,52 @@ class OIDCSession(BaseModel):
def __repr__(self):
"""String representation of OIDCSession."""
return f"<OIDCSession user_id={self.user_id} client_id={self.client_id} state={self.state[:8]}...>"
return (
f"<OIDCSession user_id={self.user_id} "
f"client_id={self.client_id} state={self.state[:8]}...>"
)
def is_expired(self):
def is_expired(self) -> bool:
"""Check if the OIDC session has expired."""
return datetime.now(timezone.utc) > self.expires_at
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at
def is_authenticated(self):
def is_authenticated(self) -> bool:
"""Check if the user has been authenticated in this session."""
return self.authenticated_at is not None
def mark_authenticated(self):
def mark_authenticated(self) -> None:
"""Mark the session as authenticated."""
self.authenticated_at = datetime.now(timezone.utc)
db.session.commit()
def validate_nonce(self, expected_nonce):
def validate_nonce(self, expected_nonce: str) -> bool:
"""Validate the nonce matches the expected value.
Args:
expected_nonce: The expected nonce value
Returns:
bool: True if nonce matches
True if nonce matches
"""
return self.nonce == expected_nonce
def validate_code_challenge(self, code_verifier):
def validate_code_challenge(self, code_verifier: str) -> bool:
"""Validate the code verifier against the stored code challenge.
Args:
code_verifier: The PKCE code verifier
Returns:
bool: True if code challenge is valid
True if the challenge is satisfied
"""
if not self.code_challenge:
return False
if self.code_challenge_method == "S256":
import hashlib
import base64
# SHA256 hash of code_verifier
digest = hashlib.sha256(code_verifier.encode()).digest()
# Base64 URL encode without padding
expected = base64.urlsafe_b64encode(digest).decode().rstrip("=")
return self.code_challenge == expected
elif self.code_challenge_method == "plain":
@@ -97,9 +101,18 @@ class OIDCSession(BaseModel):
return False
@classmethod
def create_session(cls, user_id, client_id, state, redirect_uri, scope=None,
nonce=None, code_challenge=None, code_challenge_method=None,
lifetime_seconds=600):
def create_session(
cls,
user_id: str,
client_id: str,
state: str,
redirect_uri: str,
scope=None,
nonce: str = None,
code_challenge: str = None,
code_challenge_method: str = None,
lifetime_seconds: int = 600,
) -> "OIDCSession":
"""Create a new OIDC session.
Args:
@@ -116,7 +129,6 @@ class OIDCSession(BaseModel):
Returns:
OIDCSession instance
"""
from datetime import timedelta
session = cls(
user_id=user_id,
client_id=client_id,
@@ -133,7 +145,7 @@ class OIDCSession(BaseModel):
return session
@classmethod
def get_by_state(cls, state):
def get_by_state(cls, state: str) -> "OIDCSession | None":
"""Get a session by state parameter.
Args:
@@ -147,16 +159,3 @@ class OIDCSession(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary."""
return super().to_dict(exclude=exclude)
# Add relationship back to User model
from gatehouse_app.models.user import User
User.oidc_sessions = db.relationship(
"OIDCSession", back_populates="user", cascade="all, delete-orphan"
)
# Add relationship back to OIDCClient model
from gatehouse_app.models.oidc_client import OIDCClient
OIDCClient.oidc_sessions = db.relationship(
"OIDCSession", back_populates="client", cascade="all, delete-orphan"
)
@@ -8,13 +8,14 @@ from gatehouse_app.models.base import BaseModel
class OIDCTokenMetadata(BaseModel):
"""OIDC Token Metadata model for tracking issued tokens.
This model stores metadata about issued tokens (access tokens, refresh tokens, ID tokens)
for the purpose of token revocation. The id field matches the JTI (JWT ID) claim.
Stores metadata about issued tokens (access, refresh, ID) for revocation.
The ``id`` field on this model intentionally overrides the BaseModel UUID
to store the JWT JTI directly as the primary key for O(1) revocation checks.
"""
__tablename__ = "oidc_token_metadata"
# Token identifier (matches JTI in JWT)
# Primary key = JTI so revocation lookups are always a PK scan
id = db.Column(
db.String(36), primary_key=True, default=lambda: str(uuid.uuid4())
)
@@ -27,11 +28,11 @@ class OIDCTokenMetadata(BaseModel):
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
)
# Token type
token_type = db.Column(db.String(50), nullable=False) # "access_token", "refresh_token", "id_token"
# Token type: "access_token", "refresh_token", or "id_token"
token_type = db.Column(db.String(50), nullable=False)
# Token identifier for revocation lookup
token_jti = db.Column(db.String(255), nullable=False, index=True) # JWT ID claim
# JWT ID claim (indexed for fast lookup when id != jti)
token_jti = db.Column(db.String(255), nullable=False, index=True)
# Timing
expires_at = db.Column(db.DateTime, nullable=False, index=True)
@@ -46,25 +47,27 @@ class OIDCTokenMetadata(BaseModel):
def __repr__(self):
"""String representation of OIDCTokenMetadata."""
return f"<OIDCTokenMetadata jti={self.token_jti[:8]}... type={self.token_type} revoked={self.is_revoked()}>"
return (
f"<OIDCTokenMetadata jti={self.token_jti[:8]}... "
f"type={self.token_type} revoked={self.is_revoked()}>"
)
def is_expired(self):
def is_expired(self) -> bool:
"""Check if the token has expired."""
# Handle both timezone-aware and timezone-naive expires_at values
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at
def is_revoked(self):
def is_revoked(self) -> bool:
"""Check if the token has been revoked."""
return self.revoked_at is not None
def is_valid(self):
def is_valid(self) -> bool:
"""Check if the token is valid (not expired and not revoked)."""
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
def revoke(self, reason=None):
def revoke(self, reason: str = None) -> None:
"""Revoke the token.
Args:
@@ -75,8 +78,16 @@ class OIDCTokenMetadata(BaseModel):
db.session.commit()
@classmethod
def create_metadata(cls, client_id, user_id, token_type, token_jti,
expires_at, ip_address=None, user_agent=None):
def create_metadata(
cls,
client_id: str,
user_id: str,
token_type: str,
token_jti: str,
expires_at,
ip_address: str = None,
user_agent: str = None,
) -> "OIDCTokenMetadata":
"""Create token metadata for tracking.
Args:
@@ -85,8 +96,8 @@ class OIDCTokenMetadata(BaseModel):
token_type: Type of token ("access_token", "refresh_token", "id_token")
token_jti: JWT ID claim
expires_at: Token expiration datetime
ip_address: Client IP address
user_agent: Client user agent
ip_address: Client IP address (unused column, kept for API compat)
user_agent: Client user agent (unused column, kept for API compat)
Returns:
OIDCTokenMetadata instance
@@ -104,7 +115,7 @@ class OIDCTokenMetadata(BaseModel):
return metadata
@classmethod
def get_by_jti(cls, token_jti):
def get_by_jti(cls, token_jti: str) -> "OIDCTokenMetadata | None":
"""Get token metadata by JWT ID.
Args:
@@ -116,7 +127,7 @@ class OIDCTokenMetadata(BaseModel):
return cls.query.filter_by(token_jti=token_jti, deleted_at=None).first()
@classmethod
def revoke_by_jti(cls, token_jti, reason=None):
def revoke_by_jti(cls, token_jti: str, reason: str = None) -> bool:
"""Revoke a token by its JWT ID.
Args:
@@ -124,7 +135,7 @@ class OIDCTokenMetadata(BaseModel):
reason: Optional revocation reason
Returns:
bool: True if token was found and revoked
True if token was found and revoked, False otherwise
"""
metadata = cls.get_by_jti(token_jti)
if metadata:
@@ -133,47 +144,53 @@ class OIDCTokenMetadata(BaseModel):
return False
@classmethod
def revoke_all_for_user(cls, user_id, client_id=None, reason=None):
def revoke_all_for_user(
cls, user_id: str, client_id: str = None, reason: str = None
) -> int:
"""Revoke all tokens for a user.
Args:
user_id: The user ID
client_id: Optional client ID to filter by
client_id: Optional client ID filter
reason: Optional revocation reason
Returns:
int: Number of tokens revoked
Number of tokens revoked
"""
query = cls.query.filter_by(user_id=user_id, deleted_at=None)
query = cls.query.filter_by(user_id=user_id, deleted_at=None).filter(
cls.revoked_at.is_(None)
)
if client_id:
query = query.filter_by(client_id=client_id)
tokens = query.filter(cls.revoked_at == None).all()
count = 0
for token in tokens:
for token in query.all():
token.revoke(reason)
count += 1
return count
@classmethod
def revoke_all_for_client(cls, client_id, user_id=None, reason=None):
def revoke_all_for_client(
cls, client_id: str, user_id: str = None, reason: str = None
) -> int:
"""Revoke all tokens for a client.
Args:
client_id: The client ID
user_id: Optional user ID to filter by
user_id: Optional user ID filter
reason: Optional revocation reason
Returns:
int: Number of tokens revoked
Number of tokens revoked
"""
query = cls.query.filter_by(client_id=client_id, deleted_at=None)
query = cls.query.filter_by(client_id=client_id, deleted_at=None).filter(
cls.revoked_at.is_(None)
)
if user_id:
query = query.filter_by(user_id=user_id)
tokens = query.filter(cls.revoked_at == None).all()
count = 0
for token in tokens:
for token in query.all():
token.revoke(reason)
count += 1
return count
@@ -181,16 +198,3 @@ class OIDCTokenMetadata(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary."""
return super().to_dict(exclude=exclude)
# Add relationship back to User model
from gatehouse_app.models.user import User
User.oidc_token_metadata = db.relationship(
"OIDCTokenMetadata", back_populates="user", cascade="all, delete-orphan"
)
# Add relationship back to OIDCClient model
from gatehouse_app.models.oidc_client import OIDCClient
OIDCClient.token_metadata = db.relationship(
"OIDCTokenMetadata", back_populates="client", cascade="all, delete-orphan"
)
-77
View File
@@ -1,77 +0,0 @@
"""OIDC JWKS Key model for persisting signing keys."""
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class OidcJwksKey(BaseModel):
"""
OIDC JWKS Key model for persisting JSON Web Key Set signing keys.
This model stores RSA/ECDSA key pairs used for signing OIDC tokens.
Multiple keys can be stored to support key rotation scenarios.
Attributes:
id: Integer primary key
kid: Unique key ID used in JWT "kid" header
key_type: Type of key (e.g., "RSA", "EC")
private_key: PEM-encoded private key
public_key: PEM-encoded public key
algorithm: Signing algorithm (e.g., "RS256", "ES256")
created_at: When the key was created
is_active: Whether this key is currently active for signing
is_primary: Whether this is the primary signing key
expires_at: ...
"""
__tablename__ = "oidc_jwks_keys"
# Override the default UUID id with integer primary key
id = db.Column(db.Integer, primary_key=True)
expires_at = db.Column(db.DateTime, nullable=True)
# Key identification and type
kid = db.Column(db.String(255), unique=True, nullable=False, index=True)
key_type = db.Column(db.String(50), nullable=False) # e.g., "RSA", "EC"
algorithm = db.Column(db.String(50), nullable=False) # e.g., "RS256", "ES256"
# Key material (PEM-encoded)
private_key = db.Column(db.Text, nullable=False)
public_key = db.Column(db.Text, nullable=False)
# Key status
is_active = db.Column(db.Boolean, default=True, nullable=False)
is_primary = db.Column(db.Boolean, default=False, nullable=False)
def __repr__(self):
"""String representation of OidcJwksKey."""
return f"<OidcJwksKey kid={self.kid} key_type={self.key_type} algorithm={self.algorithm}>"
def to_dict(self, exclude_private_key=True):
"""
Convert model to dictionary.
Args:
exclude_private_key: If True, excludes the private key from output
Returns:
Dictionary representation of the model
"""
exclude = ["private_key"] if exclude_private_key else []
return super().to_dict(exclude=exclude)
@classmethod
def get_active_keys(cls):
"""Get all active keys for signing operations."""
return cls.query.filter(cls.is_active == True).all()
@classmethod
def get_primary_key(cls):
"""Get the primary signing key."""
return cls.query.filter(cls.is_primary == True).first()
@classmethod
def get_key_by_kid(cls, kid):
"""Get a key by its key ID."""
return cls.query.filter(cls.kid == kid, cls.is_active == True).first()
@@ -0,0 +1,27 @@
"""Organization subpackage."""
from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.models.organization.department import (
Department,
DepartmentMembership,
DepartmentPrincipal,
)
from gatehouse_app.models.organization.department_cert_policy import (
DepartmentCertPolicy,
STANDARD_EXTENSIONS,
)
from gatehouse_app.models.organization.principal import Principal, PrincipalMembership
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
__all__ = [
"Organization",
"OrganizationMember",
"Department",
"DepartmentMembership",
"DepartmentPrincipal",
"DepartmentCertPolicy",
"STANDARD_EXTENSIONS",
"Principal",
"PrincipalMembership",
"OrgInviteToken",
]
@@ -0,0 +1,196 @@
"""Department, DepartmentMembership, and DepartmentPrincipal models."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class Department(BaseModel):
"""Department model representing an organizational unit for SSH access control.
Departments are used to group users and assign SSH principals (access levels)
to them. A user can be a member of multiple departments, and each department
can have multiple principals assigned.
Example:
- Department: "Engineering"
- Members: user1@example.com, user2@example.com
- Principals: "eng-prod", "eng-staging"
- Users get access based on their principal assignments
"""
__tablename__ = "departments"
organization_id = db.Column(
db.String(36),
db.ForeignKey("organizations.id"),
nullable=False,
index=True,
)
name = db.Column(db.String(255), nullable=False, index=True)
description = db.Column(db.Text, nullable=True)
# Relationships
organization = db.relationship("Organization", back_populates="departments")
memberships = db.relationship(
"DepartmentMembership",
back_populates="department",
cascade="all, delete-orphan",
)
principal_links = db.relationship(
"DepartmentPrincipal",
back_populates="department",
cascade="all, delete-orphan",
)
cert_policy = db.relationship(
"DepartmentCertPolicy",
back_populates="department",
uselist=False,
cascade="all, delete-orphan",
)
__table_args__ = (
db.UniqueConstraint("organization_id", "name", name="uix_org_dept_name"),
)
def __repr__(self):
"""String representation of Department."""
return f"<Department {self.name} (org_id={self.organization_id})>"
def to_dict(self, exclude=None):
"""Convert department to dictionary."""
exclude = exclude or []
data = super().to_dict(exclude=exclude)
data["member_count"] = len([m for m in self.memberships if m.deleted_at is None])
data["principal_count"] = len([p for p in self.principal_links if p.deleted_at is None])
return data
def get_members(self, active_only: bool = True):
"""Get all members of this department.
Args:
active_only: If True, exclude soft-deleted members
Returns:
List of DepartmentMembership objects
"""
if active_only:
return [m for m in self.memberships if m.deleted_at is None]
return list(self.memberships)
def get_principals(self, active_only: bool = True):
"""Get all principals assigned to this department.
Args:
active_only: If True, exclude soft-deleted principals
Returns:
List of Principal objects via DepartmentPrincipal
"""
if active_only:
return [
p.principal
for p in self.principal_links
if p.deleted_at is None and p.principal.deleted_at is None
]
return [p.principal for p in self.principal_links]
def is_member(self, user_id: str) -> bool:
"""Check if a user is a member of this department.
Args:
user_id: ID of the user to check
Returns:
True if user is an active member, False otherwise
"""
return (
DepartmentMembership.query.filter_by(
user_id=user_id,
department_id=self.id,
deleted_at=None,
).first()
is not None
)
def get_member_count(self) -> int:
"""Get the count of active members in this department."""
return len(self.get_members(active_only=True))
class DepartmentMembership(BaseModel):
"""Department membership model representing user membership in a department.
When a user is added to a department, they become eligible for SSH principals
assigned to that department.
"""
__tablename__ = "department_memberships"
user_id = db.Column(
db.String(36),
db.ForeignKey("users.id"),
nullable=False,
index=True,
)
department_id = db.Column(
db.String(36),
db.ForeignKey("departments.id"),
nullable=False,
index=True,
)
# Relationships
user = db.relationship("User", back_populates="department_memberships")
department = db.relationship("Department", back_populates="memberships")
__table_args__ = (
db.UniqueConstraint("user_id", "department_id", name="uix_user_dept"),
)
def __repr__(self):
"""String representation of DepartmentMembership."""
return (
f"<DepartmentMembership user_id={self.user_id} dept_id={self.department_id}>"
)
class DepartmentPrincipal(BaseModel):
"""Department principal assignment model.
Represents the assignment of principals to departments. All members of a
department get access to its assigned principals (transitively).
Example:
- Department: "Engineering"
- Principal: "eng-prod-servers"
- All engineering department members can SSH as "eng-prod-servers"
"""
__tablename__ = "department_principals"
department_id = db.Column(
db.String(36),
db.ForeignKey("departments.id"),
nullable=False,
index=True,
)
principal_id = db.Column(
db.String(36),
db.ForeignKey("principals.id"),
nullable=False,
index=True,
)
# Relationships
department = db.relationship("Department", back_populates="principal_links")
principal = db.relationship("Principal", back_populates="department_links")
__table_args__ = (
db.UniqueConstraint("department_id", "principal_id", name="uix_dept_principal"),
)
def __repr__(self):
"""String representation of DepartmentPrincipal."""
return (
f"<DepartmentPrincipal dept_id={self.department_id} "
f"principal_id={self.principal_id}>"
)
@@ -0,0 +1,76 @@
"""DepartmentCertPolicy — per-department SSH certificate issuance rules."""
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
# Standard SSH certificate extensions
STANDARD_EXTENSIONS = [
"permit-X11-forwarding",
"permit-agent-forwarding",
"permit-pty",
"permit-port-forwarding",
"permit-user-rc",
]
class DepartmentCertPolicy(BaseModel):
"""SSH certificate policy for a department.
Controls:
- Whether members may choose their own expiry date (up to ``max_expiry_hours``)
- Default expiry hours when the user doesn't (or can't) pick
- Maximum expiry hours (hard ceiling, even for admins signing on behalf)
- Which SSH certificate extensions are granted to members of this department
- Any custom extensions the admin wants to add beyond the standard five
Inherits ``id``, ``created_at``, ``updated_at``, and ``deleted_at`` from
:class:`BaseModel` so soft-delete and the standard timestamp behaviour are
consistent with every other model in the application.
"""
__tablename__ = "department_cert_policies"
department_id = db.Column(
db.String(36),
db.ForeignKey("departments.id"),
nullable=False,
unique=True,
index=True,
)
# Expiry control
allow_user_expiry = db.Column(db.Boolean, nullable=False, default=False)
default_expiry_hours = db.Column(db.Integer, nullable=False, default=1)
max_expiry_hours = db.Column(db.Integer, nullable=False, default=24)
# Extensions — list of extension name strings
allowed_extensions = db.Column(
db.JSON,
nullable=False,
default=lambda: list(STANDARD_EXTENSIONS),
)
# Admin-defined extras beyond the standard five
custom_extensions = db.Column(db.JSON, nullable=False, default=list)
# Relationship back to department
department = db.relationship("Department", back_populates="cert_policy", uselist=False)
def __repr__(self):
return (
f"<DepartmentCertPolicy dept={self.department_id} "
f"allow_user_expiry={self.allow_user_expiry}>"
)
def all_extensions(self) -> list:
"""Return the full list of enabled extensions (allowed + custom)."""
return list((self.allowed_extensions or []) + (self.custom_extensions or []))
def to_dict(self, exclude=None):
"""Convert to dictionary."""
exclude = exclude or []
data = super().to_dict(exclude=exclude)
# Augment with computed / convenience fields not in the base columns
data["all_extensions"] = self.all_extensions()
data["standard_extensions"] = STANDARD_EXTENSIONS
return data
@@ -0,0 +1,77 @@
"""Organization invite token model."""
import secrets
from datetime import datetime, timezone, timedelta
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class OrgInviteToken(BaseModel):
"""Token-based invitation to join an organization."""
__tablename__ = "org_invite_tokens"
organization_id = db.Column(
db.String(36),
db.ForeignKey("organizations.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
invited_by_id = db.Column(
db.String(36),
db.ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
)
email = db.Column(db.String(255), nullable=False, index=True)
role = db.Column(db.String(64), nullable=False, default="member")
token = db.Column(db.String(128), unique=True, nullable=False, index=True)
expires_at = db.Column(db.DateTime, nullable=False)
accepted_at = db.Column(db.DateTime, nullable=True)
organization = db.relationship(
"Organization",
backref=db.backref("invite_tokens", cascade="all, delete-orphan"),
)
invited_by = db.relationship("User", foreign_keys=[invited_by_id])
@classmethod
def generate(
cls,
organization_id: str,
email: str,
role: str = "member",
invited_by_id: str = None,
ttl_days: int = 7,
) -> "OrgInviteToken":
"""Create a new invite token for an organization."""
token_value = secrets.token_urlsafe(48)
instance = cls(
organization_id=organization_id,
email=email.lower(),
role=role,
invited_by_id=invited_by_id,
token=token_value,
expires_at=datetime.now(timezone.utc) + timedelta(days=ttl_days),
)
db.session.add(instance)
db.session.commit()
return instance
@property
def is_valid(self) -> bool:
"""Return True if the token is unused and not expired."""
if self.accepted_at is not None:
return False
now = datetime.now(timezone.utc)
expires = self.expires_at
if expires.tzinfo is None:
expires = expires.replace(tzinfo=timezone.utc)
return now < expires
def accept(self) -> None:
"""Mark the invite as accepted."""
self.accepted_at = datetime.now(timezone.utc)
db.session.commit()
def __repr__(self) -> str:
return f"<OrgInviteToken org={self.organization_id} email={self.email}>"
@@ -34,6 +34,15 @@ class Organization(BaseModel):
cascade="all, delete-orphan",
foreign_keys="OrganizationSecurityPolicy.organization_id",
)
departments = db.relationship(
"Department", back_populates="organization", cascade="all, delete-orphan"
)
principals = db.relationship(
"Principal", back_populates="organization", cascade="all, delete-orphan"
)
cas = db.relationship(
"CA", back_populates="organization", cascade="all, delete-orphan"
)
def __repr__(self):
"""String representation of Organization."""
@@ -52,9 +61,9 @@ class Organization(BaseModel):
return member.user
return None
def is_member(self, user_id):
def is_member(self, user_id: str) -> bool:
"""Check if a user is a member of the organization."""
from gatehouse_app.models.organization_member import OrganizationMember
from gatehouse_app.models.organization.organization_member import OrganizationMember
return (
OrganizationMember.query.filter_by(
@@ -1,4 +1,4 @@
"""Organization member model."""
"""Organization member model."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import OrganizationRole
@@ -21,31 +21,35 @@ class OrganizationMember(BaseModel):
joined_at = db.Column(db.DateTime, nullable=True)
# Relationships
user = db.relationship("User", foreign_keys=[user_id], back_populates="organization_memberships")
user = db.relationship(
"User", foreign_keys=[user_id], back_populates="organization_memberships"
)
organization = db.relationship("Organization", back_populates="members")
invited_by = db.relationship("User", foreign_keys=[invited_by_id])
# Unique constraint to prevent duplicate memberships
__table_args__ = (
db.UniqueConstraint("user_id", "organization_id", name="uix_user_org"),
)
def __repr__(self):
"""String representation of OrganizationMember."""
return f"<OrganizationMember user_id={self.user_id} org_id={self.organization_id} role={self.role}>"
return (
f"<OrganizationMember user_id={self.user_id} "
f"org_id={self.organization_id} role={self.role}>"
)
def is_owner(self):
def is_owner(self) -> bool:
"""Check if member is an owner."""
return self.role == OrganizationRole.OWNER
def is_admin(self):
def is_admin(self) -> bool:
"""Check if member is an admin or owner."""
return self.role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
def can_manage_members(self):
def can_manage_members(self) -> bool:
"""Check if member can manage other members."""
return self.is_admin()
def can_delete_organization(self):
def can_delete_organization(self) -> bool:
"""Check if member can delete the organization."""
return self.is_owner()
@@ -0,0 +1,215 @@
"""Principal and PrincipalMembership models."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class Principal(BaseModel):
"""Principal model representing an SSH principal (access level/role).
In SSH CA terminology, a principal is a string like "eng-prod-servers" or
"devops-admins" that represents a set of machines or access level. Users
can be granted access to principals, either directly or via department
membership.
Example:
- Principal: "eng-prod-servers"
- Users with this principal can SSH to prod servers
- Can be assigned to departments or directly to users
"""
__tablename__ = "principals"
organization_id = db.Column(
db.String(36),
db.ForeignKey("organizations.id"),
nullable=False,
index=True,
)
name = db.Column(db.String(255), nullable=False, index=True)
description = db.Column(db.Text, nullable=True)
# Relationships
organization = db.relationship("Organization", back_populates="principals")
memberships = db.relationship(
"PrincipalMembership",
back_populates="principal",
cascade="all, delete-orphan",
)
department_links = db.relationship(
"DepartmentPrincipal",
back_populates="principal",
cascade="all, delete-orphan",
)
__table_args__ = (
db.UniqueConstraint("organization_id", "name", name="uix_org_principal_name"),
)
def __repr__(self):
"""String representation of Principal."""
return f"<Principal {self.name} (org_id={self.organization_id})>"
def to_dict(self, exclude=None):
"""Convert principal to dictionary."""
exclude = exclude or []
data = super().to_dict(exclude=exclude)
data["direct_member_count"] = len(
[m for m in self.memberships if m.deleted_at is None]
)
data["department_count"] = len(
[d for d in self.department_links if d.deleted_at is None]
)
return data
def get_members(self, active_only: bool = True):
"""Get all users who are directly assigned to this principal.
Does NOT include users who get access via department membership.
Args:
active_only: If True, exclude soft-deleted members
Returns:
List of PrincipalMembership objects
"""
if active_only:
return [m for m in self.memberships if m.deleted_at is None]
return list(self.memberships)
def get_all_members(self, active_only: bool = True):
"""Get all users who have access to this principal.
Includes both direct members and users via department membership.
Args:
active_only: If True, exclude soft-deleted members
Returns:
Set of User objects with access to this principal
"""
all_users: set = set()
# Direct members
for membership in self.get_members(active_only=active_only):
if not active_only or membership.user.deleted_at is None:
all_users.add(membership.user)
# Members via department assignment
for dept_link in self.department_links:
if dept_link.deleted_at is None or not active_only:
for dept_member in dept_link.department.get_members(active_only=active_only):
if not active_only or dept_member.user.deleted_at is None:
all_users.add(dept_member.user)
return all_users
def get_departments(self, active_only: bool = True):
"""Get all departments this principal is assigned to.
Args:
active_only: If True, exclude soft-deleted departments
Returns:
List of Department objects
"""
if active_only:
return [
d.department
for d in self.department_links
if d.deleted_at is None and d.department.deleted_at is None
]
return [d.department for d in self.department_links]
def is_member(self, user_id: str, include_via_department: bool = True) -> bool:
"""Check if a user has access to this principal.
Args:
user_id: ID of the user to check
include_via_department: If True, check department memberships too
Returns:
True if user has access to this principal
"""
# Check direct membership
has_direct = (
PrincipalMembership.query.filter_by(
user_id=user_id,
principal_id=self.id,
deleted_at=None,
).first()
is not None
)
if has_direct:
return True
if not include_via_department:
return False
# Check department membership
dept_ids = [d.id for d in self.get_departments(active_only=True)]
if not dept_ids:
return False
from gatehouse_app.models.organization.department import DepartmentMembership
return (
DepartmentMembership.query.filter(
DepartmentMembership.user_id == user_id,
DepartmentMembership.department_id.in_(dept_ids),
DepartmentMembership.deleted_at.is_(None),
).first()
is not None
)
def get_member_count(self, include_via_department: bool = True) -> int:
"""Get the count of active members with access to this principal.
Args:
include_via_department: If True, include members via department
Returns:
Count of members
"""
if not include_via_department:
return len(self.get_members(active_only=True))
return len(self.get_all_members(active_only=True))
class PrincipalMembership(BaseModel):
"""Principal membership model representing direct user assignment to a principal.
When a user is assigned directly to a principal, they get access to that
principal for SSH authentication. This is in addition to any principals
they get via department membership.
"""
__tablename__ = "principal_memberships"
user_id = db.Column(
db.String(36),
db.ForeignKey("users.id"),
nullable=False,
index=True,
)
principal_id = db.Column(
db.String(36),
db.ForeignKey("principals.id"),
nullable=False,
index=True,
)
# Relationships
user = db.relationship("User", back_populates="principal_memberships")
principal = db.relationship("Principal", back_populates="memberships")
__table_args__ = (
db.UniqueConstraint("user_id", "principal_id", name="uix_user_principal"),
)
def __repr__(self):
"""String representation of PrincipalMembership."""
return (
f"<PrincipalMembership user_id={self.user_id} "
f"principal_id={self.principal_id}>"
)
+12
View File
@@ -0,0 +1,12 @@
"""Security subpackage — organization and user security policies, MFA compliance."""
from gatehouse_app.models.security.organization_security_policy import (
OrganizationSecurityPolicy,
)
from gatehouse_app.models.security.user_security_policy import UserSecurityPolicy
from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance
__all__ = [
"OrganizationSecurityPolicy",
"UserSecurityPolicy",
"MfaPolicyCompliance",
]
@@ -1,4 +1,4 @@
"""MfaPolicyCompliance model."""
"""MfaPolicyCompliance model — per-user per-organization MFA compliance tracking."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import MfaComplianceStatus
@@ -7,7 +7,8 @@ from gatehouse_app.utils.constants import MfaComplianceStatus
class MfaPolicyCompliance(BaseModel):
"""MFA policy compliance tracking per user per organization.
Tracks each user's MFA compliance state separately for each organization membership.
Tracks each user's MFA compliance state separately for each organization
membership. One row per (user, org) pair.
"""
__tablename__ = "mfa_policy_compliance"
@@ -25,13 +26,13 @@ class MfaPolicyCompliance(BaseModel):
default=MfaComplianceStatus.NOT_APPLICABLE,
)
# Snapshot of org policy at the time this record became active
# Snapshot of org policy version when this record became active
policy_version = db.Column(db.Integer, nullable=False)
# When policy started applying to this user
applied_at = db.Column(db.DateTime, nullable=True)
# Final deadline for this user to comply (per user, not global)
# Final deadline for this user to comply
deadline_at = db.Column(db.DateTime, nullable=True)
# When they became compliant under this policy_version
@@ -45,9 +46,7 @@ class MfaPolicyCompliance(BaseModel):
notification_count = db.Column(db.Integer, nullable=False, default=0)
__table_args__ = (
db.UniqueConstraint(
"user_id", "organization_id", name="uix_user_org_compliance"
),
db.UniqueConstraint("user_id", "organization_id", name="uix_user_org_compliance"),
)
# Relationships
@@ -58,9 +57,11 @@ class MfaPolicyCompliance(BaseModel):
def __repr__(self):
"""String representation of MfaPolicyCompliance."""
return f"<MfaPolicyCompliance user={self.user_id} org={self.organization_id} status={self.status}>"
return (
f"<MfaPolicyCompliance user={self.user_id} "
f"org={self.organization_id} status={self.status}>"
)
def to_dict(self, exclude=None):
"""Convert to dictionary."""
exclude = exclude or []
return super().to_dict(exclude=exclude)
return super().to_dict(exclude=exclude or [])
@@ -39,15 +39,19 @@ class OrganizationSecurityPolicy(BaseModel):
# Relationships
organization = db.relationship(
"Organization", back_populates="security_policy", foreign_keys=[organization_id]
"Organization",
back_populates="security_policy",
foreign_keys=[organization_id],
)
updated_by_user = db.relationship("User", foreign_keys=[updated_by_user_id])
def __repr__(self):
"""String representation of OrganizationSecurityPolicy."""
return f"<OrganizationSecurityPolicy org={self.organization_id} mode={self.mfa_policy_mode}>"
return (
f"<OrganizationSecurityPolicy "
f"org={self.organization_id} mode={self.mfa_policy_mode}>"
)
def to_dict(self, exclude=None):
"""Convert to dictionary."""
exclude = exclude or []
return super().to_dict(exclude=exclude)
return super().to_dict(exclude=exclude or [])
@@ -1,4 +1,4 @@
"""UserSecurityPolicy model."""
"""UserSecurityPolicy model — per-user MFA overrides."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import MfaRequirementOverride
@@ -7,7 +7,7 @@ from gatehouse_app.utils.constants import MfaRequirementOverride
class UserSecurityPolicy(BaseModel):
"""User security policy model for per-user MFA overrides.
Stores per user overrides of organization level MFA requirements.
Stores per-user overrides of organization-level MFA requirements.
"""
__tablename__ = "user_security_policies"
@@ -25,29 +25,27 @@ class UserSecurityPolicy(BaseModel):
default=MfaRequirementOverride.INHERIT,
)
# If override is REQUIRED and you want to force a specific factor set
# If override is REQUIRED, optionally force a specific factor set
force_totp = db.Column(db.Boolean, nullable=False, default=False)
force_webauthn = db.Column(db.Boolean, nullable=False, default=False)
__table_args__ = (
db.UniqueConstraint(
"user_id", "organization_id", name="uix_user_org_policy"
),
db.UniqueConstraint("user_id", "organization_id", name="uix_user_org_policy"),
)
# Relationships
user = db.relationship(
"User", back_populates="security_policies", foreign_keys=[user_id]
)
organization = db.relationship(
"Organization", foreign_keys=[organization_id]
)
organization = db.relationship("Organization", foreign_keys=[organization_id])
def __repr__(self):
"""String representation of UserSecurityPolicy."""
return f"<UserSecurityPolicy user={self.user_id} org={self.organization_id} mode={self.mfa_override_mode}>"
return (
f"<UserSecurityPolicy user={self.user_id} "
f"org={self.organization_id} mode={self.mfa_override_mode}>"
)
def to_dict(self, exclude=None):
"""Convert to dictionary."""
exclude = exclude or []
return super().to_dict(exclude=exclude)
return super().to_dict(exclude=exclude or [])
+17
View File
@@ -0,0 +1,17 @@
"""SSH/CA subpackage — certificate authorities, SSH keys, and certificates."""
from gatehouse_app.models.ssh_ca.ca import CA, KeyType, CertType, CaType, CAPermission
from gatehouse_app.models.ssh_ca.ssh_key import SSHKey
from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus
from gatehouse_app.models.ssh_ca.certificate_audit_log import CertificateAuditLog
__all__ = [
"CA",
"KeyType",
"CertType",
"CaType",
"CAPermission",
"SSHKey",
"SSHCertificate",
"CertificateStatus",
"CertificateAuditLog",
]
+238
View File
@@ -0,0 +1,238 @@
"""Certificate Authority (CA) model."""
from enum import Enum
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class KeyType(str, Enum):
"""SSH CA key types."""
ED25519 = "ed25519"
RSA = "rsa"
ECDSA = "ecdsa"
class CertType(str, Enum):
"""SSH certificate types."""
USER = "user"
HOST = "host"
class CaType(str, Enum):
"""CA signing type — whether this CA signs user or host certificates."""
USER = "user"
HOST = "host"
class CA(BaseModel):
"""Certificate Authority (CA) model for SSH certificate signing.
Each organization can have multiple CAs for different purposes
(e.g., production vs. staging). Private keys are encrypted at rest
and should be protected with KMS.
"""
__tablename__ = "cas"
organization_id = db.Column(
db.String(36),
db.ForeignKey("organizations.id"),
nullable=True, # NULL for the global system-config CA
index=True,
)
# CA identity
name = db.Column(db.String(255), nullable=False)
description = db.Column(db.Text, nullable=True)
# CA signing type: 'user' signs user certificates, 'host' signs host certs
ca_type = db.Column(
db.Enum(CaType, values_callable=lambda x: [e.value for e in x]),
default=CaType.USER,
nullable=False,
)
# Key type (ED25519, RSA, ECDSA)
key_type = db.Column(
db.Enum(KeyType, values_callable=lambda x: [e.value for e in x]),
default=KeyType.ED25519,
nullable=False,
)
# Private key — PEM-encoded, encrypted at rest by database/KMS
private_key = db.Column(db.Text, nullable=False)
# Public key — PEM format
public_key = db.Column(db.Text, nullable=False)
# SHA256 fingerprint of the public key
fingerprint = db.Column(db.String(255), nullable=False, unique=True)
# CRL (Certificate Revocation List) configuration
crl_enabled = db.Column(db.Boolean, default=True, nullable=False)
crl_endpoint = db.Column(db.String(512), nullable=True)
# Default certificate validity in hours (overridable per request)
default_cert_validity_hours = db.Column(db.Integer, default=1, nullable=False)
# Maximum validity duration allowed
max_cert_validity_hours = db.Column(db.Integer, default=24, nullable=False)
# CA status
is_active = db.Column(db.Boolean, default=True, nullable=False, index=True)
# Key rotation tracking
rotated_at = db.Column(db.DateTime, nullable=True)
rotation_reason = db.Column(db.String(255), nullable=True)
# Monotonically-increasing serial counter. Every cert this CA issues
# gets the next value so serials are unique, ordered, and auditable.
# Protected by a row-level SELECT … FOR UPDATE in get_next_serial().
next_serial_number = db.Column(db.BigInteger, default=1, nullable=False)
# Relationships
organization = db.relationship("Organization", back_populates="cas")
certificates = db.relationship(
"SSHCertificate",
back_populates="ca",
cascade="all, delete-orphan",
)
permissions = db.relationship(
"CAPermission",
back_populates="ca",
cascade="all, delete-orphan",
)
__table_args__ = (
db.Index("idx_ca_org_active", "organization_id", "is_active"),
)
def __repr__(self):
"""String representation of CA."""
return (
f"<CA {self.name} "
f"(org_id={self.organization_id}, type={self.key_type})>"
)
def to_dict(self, exclude=None):
"""Convert CA to dictionary, never exposing the private key."""
exclude = exclude or []
if "private_key" not in exclude:
exclude.append("private_key")
data = super().to_dict(exclude=exclude)
# Add computed fields
data["total_certs"] = len([c for c in self.certificates if c.deleted_at is None])
data["active_certs"] = len(
[c for c in self.certificates if c.deleted_at is None and not c.revoked]
)
data["revoked_certs"] = len(
[c for c in self.certificates if c.deleted_at is None and c.revoked]
)
return data
def get_active_certificates(self) -> list:
"""Get all active (non-revoked) certificates issued by this CA."""
return [
c for c in self.certificates if c.deleted_at is None and not c.revoked
]
def rotate_key(
self,
new_private_key: str,
new_public_key: str,
new_fingerprint: str,
reason: str = None,
) -> None:
"""Rotate the CA's key pair.
This should only be done in carefully controlled circumstances.
All existing certificates remain valid but no new certificates can be
signed with the old key after rotation.
Args:
new_private_key: New PEM-encoded private key
new_public_key: New PEM-encoded public key
new_fingerprint: SHA256 fingerprint of new public key
reason: Optional reason for rotation
"""
self.private_key = new_private_key
self.public_key = new_public_key
self.fingerprint = new_fingerprint
self.rotated_at = datetime.now(timezone.utc) # Bug fix: was datetime.utcnow()
self.rotation_reason = reason
self.save()
def get_next_serial(self) -> int:
"""Atomically increment and return the next certificate serial number.
Uses a SELECT FOR UPDATE row lock so concurrent requests never
receive the same serial. Must be called inside an active DB
transaction (i.e. before the final session.commit()).
Returns:
int: The serial number to embed in the next certificate.
"""
# Re-fetch this CA row with an exclusive row lock
locked = (
db.session.query(CA)
.with_for_update()
.filter_by(id=self.id)
.one()
)
serial = locked.next_serial_number
locked.next_serial_number = serial + 1
db.session.flush() # write increment; commit happens in the caller
return serial
class CAPermission(BaseModel):
"""Per-user CA permission model.
Controls which users are allowed to sign certificates against a specific CA.
When a CA has any permission rows, the signing endpoint enforces the list;
CAs with no rows are open to all org members (backwards-compatible default).
Permission values:
sign user may request certificate signing
admin user may sign AND manage the CA (rotate keys, delete, etc.)
"""
__tablename__ = "ca_permissions"
ca_id = db.Column(
db.String(36),
db.ForeignKey("cas.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
user_id = db.Column(
db.String(36),
db.ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
permission = db.Column(db.String(50), nullable=False, default="sign")
# Relationships
ca = db.relationship("CA", back_populates="permissions")
user = db.relationship("User", back_populates="ca_permissions")
__table_args__ = (
db.UniqueConstraint("ca_id", "user_id", name="uix_ca_permission"),
)
def __repr__(self):
return (
f"<CAPermission ca_id={self.ca_id} "
f"user_id={self.user_id} permission={self.permission}>"
)
def to_dict(self, exclude=None):
data = super().to_dict(exclude=exclude or [])
data["permission"] = self.permission
return data
@@ -0,0 +1,91 @@
"""Certificate audit log model — tracks SSH certificate lifecycle events."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class CertificateAuditLog(BaseModel):
"""Audit log for SSH certificate lifecycle events.
Tracks all operations on SSH certificates: signing, revocation, validation,
etc. Kept separate from the general AuditLog to provide detailed certificate
operation tracking without polluting the main audit stream.
"""
__tablename__ = "certificate_audit_logs"
# Reference to the certificate
certificate_id = db.Column(
db.String(36),
db.ForeignKey("ssh_certificates.id"),
nullable=False,
index=True,
)
# The user who performed the action (null for system actions)
user_id = db.Column(
db.String(36),
db.ForeignKey("users.id"),
nullable=True,
index=True,
)
# Action type (e.g., "signed", "revoked", "validated", "requested")
action = db.Column(db.String(50), nullable=False, index=True)
# Request details
ip_address = db.Column(db.String(45), nullable=True)
user_agent = db.Column(db.String(512), nullable=True)
request_id = db.Column(db.String(36), nullable=True)
# Detailed message
message = db.Column(db.Text, nullable=True)
# Additional context
extra_data = db.Column(db.JSON, nullable=True)
# Outcome
success = db.Column(db.Boolean, default=True, nullable=False)
error_message = db.Column(db.Text, nullable=True)
# Relationships
certificate = db.relationship("SSHCertificate", back_populates="audit_logs")
user = db.relationship("User")
__table_args__ = (
db.Index("idx_cert_audit_cert_action", "certificate_id", "action"),
db.Index("idx_cert_audit_user", "user_id", "created_at"),
)
def __repr__(self):
"""String representation of CertificateAuditLog."""
return (
f"<CertificateAuditLog cert_id={self.certificate_id} action={self.action}>"
)
@classmethod
def log(
cls,
certificate_id: str,
action: str,
user_id: str = None,
**kwargs,
) -> "CertificateAuditLog":
"""Create a certificate audit log entry.
Args:
certificate_id: ID of the certificate
action: Action type (e.g., "signed", "revoked")
user_id: ID of the user performing the action (optional)
**kwargs: Additional fields (ip_address, user_agent, message, etc.)
Returns:
CertificateAuditLog instance
"""
log_entry = cls(
certificate_id=certificate_id,
action=action,
user_id=user_id,
**kwargs,
)
log_entry.save()
return log_entry
@@ -0,0 +1,176 @@
"""SSH Certificate model — signed SSH user/host certificates."""
from enum import Enum
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.models.ssh_ca.ca import CertType
class CertificateStatus(str, Enum):
"""SSH certificate lifecycle status."""
REQUESTED = "requested" # Waiting for signing
ISSUED = "issued" # Signed and valid
REVOKED = "revoked" # Manually revoked
EXPIRED = "expired" # Validity period ended
SUPERSEDED = "superseded" # Replaced by newer certificate
class SSHCertificate(BaseModel):
"""SSH Certificate model representing a signed SSH user/host certificate.
Certificates are issued by a CA and associated with an SSH public key.
They include principals (access levels), validity periods, and standard
OpenSSH certificate metadata.
"""
__tablename__ = "ssh_certificates"
# Certificate relationships
ca_id = db.Column(
db.String(36),
db.ForeignKey("cas.id"),
nullable=False,
index=True,
)
user_id = db.Column(
db.String(36),
db.ForeignKey("users.id"),
nullable=False,
index=True,
)
ssh_key_id = db.Column(
db.String(36),
db.ForeignKey("ssh_keys.id"),
nullable=False,
index=True,
)
# Certificate content — full signed certificate in OpenSSH format
certificate = db.Column(db.Text, nullable=False)
# Certificate metadata
serial = db.Column(db.String(255), nullable=False, unique=True, index=True)
key_id = db.Column(db.String(255), nullable=False) # Usually user email
cert_type = db.Column(
db.Enum(CertType, values_callable=lambda x: [e.value for e in x]),
default=CertType.USER,
nullable=False,
)
# Principals — JSON list, e.g., ["prod-servers", "dev-servers"]
principals = db.Column(db.JSON, nullable=False, default=list)
# Validity period
valid_after = db.Column(db.DateTime, nullable=False)
valid_before = db.Column(db.DateTime, nullable=False)
# Revocation status
revoked = db.Column(db.Boolean, default=False, nullable=False, index=True)
revoked_at = db.Column(db.DateTime, nullable=True)
revoke_reason = db.Column(db.String(255), nullable=True)
# Status tracking
status = db.Column(
db.Enum(CertificateStatus, values_callable=lambda x: [e.value for e in x]),
default=CertificateStatus.ISSUED,
nullable=False,
index=True,
)
# Request metadata
request_ip = db.Column(db.String(45), nullable=True)
request_user_agent = db.Column(db.String(512), nullable=True)
# Critical options — OpenSSH critical options (JSON)
critical_options = db.Column(db.JSON, nullable=True, default=dict)
# Extensions — OpenSSH extensions (JSON)
extensions = db.Column(db.JSON, nullable=True, default=dict)
# Relationships
ca = db.relationship("CA", back_populates="certificates")
user = db.relationship("User", back_populates="ssh_certificates")
ssh_key = db.relationship(
"SSHKey",
back_populates="certificates",
foreign_keys="SSHCertificate.ssh_key_id",
)
audit_logs = db.relationship(
"CertificateAuditLog",
back_populates="certificate",
cascade="all, delete-orphan",
)
__table_args__ = (
db.Index("idx_cert_user_status", "user_id", "status"),
db.Index("idx_cert_validity", "valid_after", "valid_before"),
db.Index("idx_cert_revoked", "revoked", "revoked_at"),
)
def __repr__(self):
"""String representation of SSHCertificate."""
return f"<SSHCertificate serial={self.serial[:16]}... user_id={self.user_id}>"
def to_dict(self, exclude=None):
"""Convert certificate to dictionary.
The raw ``certificate`` blob is excluded by default (it is large and
callers can request it explicitly by removing it from the exclude list).
"""
exclude = exclude or []
if "certificate" not in exclude:
exclude.append("certificate")
data = super().to_dict(exclude=exclude)
data["is_valid"] = self.is_valid()
data["days_until_expiry"] = self.days_until_expiry()
return data
def _aware(self, dt: datetime) -> datetime:
"""Return a timezone-aware UTC datetime."""
return dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt
def is_valid(self) -> bool:
"""Check if certificate is currently valid.
Returns:
True if certificate is issued, not revoked, and within validity period
"""
if self.revoked or self.status == CertificateStatus.REVOKED:
return False
now = datetime.now(timezone.utc)
return self._aware(self.valid_after) <= now <= self._aware(self.valid_before)
def is_expired(self) -> bool:
"""Check if certificate has expired.
Returns:
True if current time is past valid_before
"""
return datetime.now(timezone.utc) > self._aware(self.valid_before)
def days_until_expiry(self) -> int:
"""Get number of days until certificate expires.
Returns:
Number of days remaining (negative if already expired)
"""
delta = self._aware(self.valid_before) - datetime.now(timezone.utc)
return delta.days + (1 if delta.seconds > 0 else 0)
def revoke(self, reason: str = None) -> None:
"""Revoke this certificate.
Args:
reason: Optional reason for revocation
"""
self.revoked = True
self.revoked_at = datetime.now(timezone.utc) # Bug fix: was datetime.utcnow()
self.revoke_reason = reason
self.status = CertificateStatus.REVOKED
self.save()
def mark_expired(self) -> None:
"""Mark certificate as expired when validity period ends."""
self.status = CertificateStatus.EXPIRED
self.save()
+98
View File
@@ -0,0 +1,98 @@
"""SSH Key model — user SSH public keys registered for certificate signing."""
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class SSHKey(BaseModel):
"""SSH Key model representing a user's SSH public key.
Users register SSH public keys for certificate signing. Keys must be
verified (owner proved possession) before they can be used.
"""
__tablename__ = "ssh_keys"
user_id = db.Column(
db.String(36),
db.ForeignKey("users.id"),
nullable=False,
index=True,
)
# SSH key payload in OpenSSH format (e.g., "ssh-ed25519 AAAAB3Nz...")
payload = db.Column(db.Text, nullable=False, unique=True)
# SHA256 fingerprint for quick comparison and deduplication
fingerprint = db.Column(db.String(255), nullable=False, unique=True, index=True)
# Optional human-readable description (e.g., "My laptop key")
description = db.Column(db.String(255), nullable=True)
# Verification status
verified = db.Column(db.Boolean, default=False, nullable=False, index=True)
verified_at = db.Column(db.DateTime, nullable=True)
# Verification challenge — shown to user once, cleared after verification
verify_text = db.Column(db.String(255), nullable=True)
verify_text_created_at = db.Column(db.DateTime, nullable=True)
# Key metadata extracted from the key
key_type = db.Column(db.String(50), nullable=True) # ssh-rsa, ssh-ed25519, etc.
key_bits = db.Column(db.Integer, nullable=True) # key length
key_comment = db.Column(db.String(255), nullable=True)
# Relationships
user = db.relationship("User", back_populates="ssh_keys")
certificates = db.relationship(
"SSHCertificate",
back_populates="ssh_key",
cascade="all, delete-orphan",
foreign_keys="SSHCertificate.ssh_key_id",
)
__table_args__ = (
db.Index("idx_ssh_key_user_verified", "user_id", "verified"),
)
def __repr__(self):
"""String representation of SSHKey."""
return f"<SSHKey {self.fingerprint[:16]}... user_id={self.user_id}>"
def to_dict(self, exclude=None):
"""Convert SSH key to dictionary.
``payload`` and ``verify_text`` are never exposed through the API.
"""
exclude = exclude or []
for field in ("payload", "verify_text"):
if field not in exclude:
exclude.append(field)
data = super().to_dict(exclude=exclude)
data["cert_count"] = len([c for c in self.certificates if c.deleted_at is None])
return data
def mark_verified(self) -> None:
"""Mark this SSH key as verified and clear the challenge."""
self.verified = True
self.verified_at = datetime.now(timezone.utc) # Bug fix: was datetime.utcnow()
self.verify_text = None
self.save()
def needs_verification_refresh(self, max_age_hours: int = 24) -> bool:
"""Check if verification challenge needs to be refreshed.
Args:
max_age_hours: Maximum age of verification challenge in hours
Returns:
True if verification challenge is stale or missing
"""
if not self.verify_text_created_at:
return True
age = datetime.now(timezone.utc) - self.verify_text_created_at.replace(
tzinfo=timezone.utc
) if self.verify_text_created_at.tzinfo is None else (
datetime.now(timezone.utc) - self.verify_text_created_at
)
return age.total_seconds() > (max_age_hours * 3600)
+3 -152
View File
@@ -1,153 +1,4 @@
"""User model."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import UserStatus
"""Backward-compatibility shim — import from gatehouse_app.models.user.user instead."""
from gatehouse_app.models.user.user import User # noqa: F401
class User(BaseModel):
"""User model representing a user account."""
__tablename__ = "users"
email = db.Column(db.String(255), unique=True, nullable=False, index=True)
email_verified = db.Column(db.Boolean, default=False, nullable=False)
full_name = db.Column(db.String(255), nullable=True)
avatar_url = db.Column(db.String(512), nullable=True)
status = db.Column(
db.Enum(UserStatus), default=UserStatus.ACTIVE, nullable=False, index=True
)
last_login_at = db.Column(db.DateTime, nullable=True)
last_login_ip = db.Column(db.String(45), nullable=True)
# Relationships
authentication_methods = db.relationship(
"AuthenticationMethod", back_populates="user", cascade="all, delete-orphan"
)
sessions = db.relationship("Session", back_populates="user", cascade="all, delete-orphan")
organization_memberships = db.relationship(
"OrganizationMember",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="OrganizationMember.user_id",
)
audit_logs = db.relationship("AuditLog", back_populates="user", cascade="all, delete-orphan")
security_policies = db.relationship(
"UserSecurityPolicy",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="UserSecurityPolicy.user_id",
)
mfa_compliance = db.relationship(
"MfaPolicyCompliance",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="MfaPolicyCompliance.user_id",
)
def __repr__(self):
"""String representation of User."""
return f"<User {self.email}>"
def to_dict(self, exclude=None):
"""Convert user to dictionary, excluding sensitive fields by default."""
exclude = exclude or []
# Always exclude password-related fields
default_exclude = []
all_exclude = list(set(default_exclude + exclude))
return super().to_dict(exclude=all_exclude)
def has_password_auth(self):
"""Check if user has password authentication enabled."""
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.PASSWORD, deleted_at=None
).first()
is not None
)
def get_organizations(self):
"""Get all organizations the user is a member of."""
return [membership.organization for membership in self.organization_memberships]
def has_totp_enabled(self) -> bool:
"""Check if user has TOTP enabled and verified.
Returns:
True if user has a verified TOTP authentication method, False otherwise.
"""
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id,
method_type=AuthMethodType.TOTP,
verified=True,
deleted_at=None,
).first()
is not None
)
def get_totp_method(self):
"""Get user's TOTP authentication method.
Returns:
The AuthenticationMethod instance for TOTP or None if not found.
Note:
Returns the most recently created TOTP method to handle cases where
multiple enrollment attempts may exist.
"""
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.TOTP, deleted_at=None
).order_by(AuthenticationMethod.created_at.desc()).first()
def has_webauthn_enabled(self) -> bool:
"""Check if user has any WebAuthn passkey credentials.
Returns:
True if user has at least one WebAuthn credential, False otherwise.
"""
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id,
method_type=AuthMethodType.WEBAUTHN,
deleted_at=None,
).first()
is not None
)
def get_webauthn_credentials(self):
"""Get all WebAuthn credentials for the user.
Returns:
List of AuthenticationMethod instances for WebAuthn, ordered by creation date.
"""
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
).order_by(AuthenticationMethod.created_at.desc()).all()
def get_webauthn_credential_count(self) -> int:
"""Get the count of WebAuthn credentials for the user.
Returns:
Number of WebAuthn credentials.
"""
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
).count()
__all__ = ["User"]
+5
View File
@@ -0,0 +1,5 @@
"""User subpackage."""
from gatehouse_app.models.user.user import User
from gatehouse_app.models.user.session import Session
__all__ = ["User", "Session"]
@@ -21,7 +21,9 @@ class Session(BaseModel):
# Timing
expires_at = db.Column(db.DateTime, nullable=False)
last_activity_at = db.Column(db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc))
last_activity_at = db.Column(
db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc)
)
revoked_at = db.Column(db.DateTime, nullable=True)
revoked_reason = db.Column(db.String(255), nullable=True)
@@ -38,7 +40,6 @@ class Session(BaseModel):
def is_active(self):
"""Check if session is currently active."""
now = datetime.now(timezone.utc)
# Make expires_at timezone-aware if it's naive
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
@@ -51,15 +52,13 @@ class Session(BaseModel):
def is_expired(self):
"""Check if session has expired."""
now = datetime.now(timezone.utc)
# Make expires_at timezone-aware if it's naive
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return now > expires_at
def refresh(self, duration_seconds=86400):
"""
Refresh session expiration.
def refresh(self, duration_seconds: int = 86400):
"""Refresh session expiration.
Args:
duration_seconds: New session duration in seconds
@@ -68,9 +67,8 @@ class Session(BaseModel):
self.last_activity_at = datetime.now(timezone.utc)
db.session.commit()
def revoke(self, reason=None):
"""
Revoke the session.
def revoke(self, reason: str = None):
"""Revoke the session.
Args:
reason: Optional reason for revocation
@@ -84,6 +82,5 @@ class Session(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
# Exclude token from dict
exclude.append("token")
return super().to_dict(exclude=exclude)
+209
View File
@@ -0,0 +1,209 @@
"""User model."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import UserStatus
class User(BaseModel):
"""User model representing a user account."""
__tablename__ = "users"
email = db.Column(db.String(255), unique=True, nullable=False, index=True)
email_verified = db.Column(db.Boolean, default=False, nullable=False)
full_name = db.Column(db.String(255), nullable=True)
avatar_url = db.Column(db.String(512), nullable=True)
status = db.Column(
db.Enum(UserStatus), default=UserStatus.ACTIVE, nullable=False, index=True
)
last_login_at = db.Column(db.DateTime, nullable=True)
last_login_ip = db.Column(db.String(45), nullable=True)
# Account activation (email-link flow)
activated = db.Column(db.Boolean, default=True, nullable=False)
activation_key = db.Column(db.String(128), unique=True, nullable=True, index=True)
# Relationships defined here only for models that don't circular-import
authentication_methods = db.relationship(
"AuthenticationMethod", back_populates="user", cascade="all, delete-orphan"
)
sessions = db.relationship("Session", back_populates="user", cascade="all, delete-orphan")
organization_memberships = db.relationship(
"OrganizationMember",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="OrganizationMember.user_id",
)
audit_logs = db.relationship("AuditLog", back_populates="user", cascade="all, delete-orphan")
security_policies = db.relationship(
"UserSecurityPolicy",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="UserSecurityPolicy.user_id",
)
mfa_compliance = db.relationship(
"MfaPolicyCompliance",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="MfaPolicyCompliance.user_id",
)
department_memberships = db.relationship(
"DepartmentMembership",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="DepartmentMembership.user_id",
)
principal_memberships = db.relationship(
"PrincipalMembership",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="PrincipalMembership.user_id",
)
ssh_keys = db.relationship(
"SSHKey",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="SSHKey.user_id",
)
ssh_certificates = db.relationship(
"SSHCertificate",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="SSHCertificate.user_id",
)
ca_permissions = db.relationship(
"CAPermission",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="CAPermission.user_id",
)
# OIDC relationships registered here (no monkey-patching needed)
oidc_auth_codes = db.relationship(
"OIDCAuthCode", back_populates="user", cascade="all, delete-orphan"
)
oidc_refresh_tokens = db.relationship(
"OIDCRefreshToken", back_populates="user", cascade="all, delete-orphan"
)
oidc_sessions = db.relationship(
"OIDCSession", back_populates="user", cascade="all, delete-orphan"
)
oidc_token_metadata = db.relationship(
"OIDCTokenMetadata", back_populates="user", cascade="all, delete-orphan"
)
oidc_audit_logs = db.relationship(
"OIDCAuditLog", back_populates="user", cascade="all, delete-orphan"
)
def __repr__(self):
"""String representation of User."""
return f"<User {self.email}>"
def to_dict(self, exclude=None):
"""Convert user to dictionary, excluding sensitive fields by default."""
exclude = exclude or []
return super().to_dict(exclude=exclude)
def has_password_auth(self):
"""Check if user has password authentication enabled."""
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.PASSWORD, deleted_at=None
).first()
is not None
)
def get_organizations(self):
"""Get all organizations the user is a member of."""
return [membership.organization for membership in self.organization_memberships]
def has_totp_enabled(self) -> bool:
"""Check if user has TOTP enabled and verified.
Returns:
True if user has a verified TOTP authentication method, False otherwise.
"""
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id,
method_type=AuthMethodType.TOTP,
verified=True,
deleted_at=None,
).first()
is not None
)
def get_totp_method(self):
"""Get user's TOTP authentication method.
Returns:
The AuthenticationMethod instance for TOTP or None if not found.
Note:
Returns the most recently created TOTP method to handle cases where
multiple enrollment attempts may exist.
"""
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.TOTP, deleted_at=None
)
.order_by(AuthenticationMethod.created_at.desc())
.first()
)
def has_webauthn_enabled(self) -> bool:
"""Check if user has any WebAuthn passkey credentials.
Returns:
True if user has at least one WebAuthn credential, False otherwise.
"""
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id,
method_type=AuthMethodType.WEBAUTHN,
deleted_at=None,
).first()
is not None
)
def get_webauthn_credentials(self):
"""Get all WebAuthn credentials for the user.
Returns:
List of AuthenticationMethod instances for WebAuthn, ordered by creation date.
"""
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
)
.order_by(AuthenticationMethod.created_at.desc())
.all()
)
def get_webauthn_credential_count(self) -> int:
"""Get the count of WebAuthn credentials for the user.
Returns:
Number of WebAuthn credentials.
"""
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
).count()
+27 -3
View File
@@ -25,7 +25,7 @@ class LoginSchema(Schema):
email = fields.Email(required=True)
password = fields.Str(required=True, validate=validate.Length(min=1))
remember_me = fields.Bool(missing=False)
remember_me = fields.Bool(load_default=False)
class RefreshTokenSchema(Schema):
@@ -77,14 +77,38 @@ class TOTPVerifyEnrollmentSchema(Schema):
class TOTPVerifySchema(Schema):
"""Schema for TOTP code verification during login."""
code = fields.Str(required=True)
is_backup_code = fields.Bool(missing=False)
code = fields.Str(
required=True,
validate=validate.Length(min=1),
)
is_backup_code = fields.Bool(load_default=False)
client_timestamp = fields.Int(
required=False,
allow_none=True,
metadata={"description": "Client UTC timestamp in seconds since epoch for TOTP verification"},
)
@validates_schema
def validate_code_format(self, data, **kwargs):
"""Validate code format depending on whether it's a backup code."""
code = data.get("code", "")
is_backup_code = data.get("is_backup_code", False)
if is_backup_code:
# Backup codes are 16 uppercase hex characters
if not code or len(code) != 16 or not all(c in "0123456789ABCDEFabcdef" for c in code):
raise ValidationError(
"Backup code must be a 16-character hexadecimal string.",
field_name="code",
)
else:
# Regular TOTP codes are exactly 6 digits
import re
if not re.match(r"^\d{6}$", code):
raise ValidationError(
"Code must be a 6-digit number.",
field_name="code",
)
class TOTPDisableSchema(Schema):
"""Schema for disabling TOTP."""
+2 -2
View File
@@ -1,6 +1,6 @@
"""Audit service."""
from flask import request, g
from gatehouse_app.models.audit_log import AuditLog
from gatehouse_app.models.auth.audit_log import AuditLog
from gatehouse_app.utils.constants import AuditAction
@@ -59,7 +59,7 @@ class AuditService:
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
metadata=metadata,
extra_data=metadata,
description=description,
success=success,
error_message=error_message,
+35 -4
View File
@@ -5,9 +5,9 @@ from datetime import datetime, timedelta, timezone
from typing import Optional
from flask import request, g, current_app
from gatehouse_app.extensions import db, bcrypt
from gatehouse_app.models.user import User
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.models.session import Session
from gatehouse_app.models.user.user import User
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
from gatehouse_app.models.user.session import Session
from gatehouse_app.utils.constants import AuthMethodType, SessionStatus, UserStatus, AuditAction
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError, AccountSuspendedError, AccountInactiveError
from gatehouse_app.exceptions.validation_exceptions import EmailAlreadyExistsError
@@ -102,7 +102,7 @@ class AuthService:
if current_app.config.get('ENV') == 'development':
logger.debug(f"[Auth] Account status: user_id={user.id}, status={user.status}")
if user.status == UserStatus.SUSPENDED:
if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
raise AccountSuspendedError()
if user.status == UserStatus.INACTIVE:
raise AccountInactiveError()
@@ -210,6 +210,22 @@ class AuthService:
auth_method.password_hash = bcrypt.generate_password_hash(new_password).decode("utf-8")
db.session.commit()
# Invalidate all other sessions so that if an attacker had a valid
# session token, changing the password actually locks them out.
# The current request's session (if any) is preserved so the user
# doesn't have to log in again immediately.
from flask import g as flask_g
current_session_id = getattr(flask_g, "current_session", None)
current_session_id = current_session_id.id if current_session_id else None
sessions_to_revoke = Session.query.filter(
Session.user_id == user.id,
Session.revoked_at == None, # noqa: E711
).all()
for sess in sessions_to_revoke:
if sess.id != current_session_id:
sess.revoke(reason="Password changed")
db.session.commit()
# Log password change
AuditService.log_action(
action=AuditAction.PASSWORD_CHANGE,
@@ -482,9 +498,24 @@ class AuthService:
if not secret:
raise InvalidCredentialsError("TOTP secret not found")
# Replay-attack prevention: reject codes that have already been
# accepted within the current validity window.
if TOTPService.is_code_already_used(str(user.id), code):
AuditService.log_action(
action=AuditAction.TOTP_VERIFY_FAILED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="TOTP code replay attempt detected",
)
raise InvalidCredentialsError("Invalid TOTP code")
is_valid = TOTPService.verify_code(secret, code, client_utc_timestamp=client_utc_timestamp)
if is_valid:
# Mark this code as used to prevent replay within the validity window
TOTPService.mark_code_used(str(user.id), code)
auth_method.last_used_at = datetime.now(timezone.utc)
db.session.commit()
@@ -8,7 +8,7 @@ from flask import current_app
from gatehouse_app.extensions import db
from gatehouse_app.models import User, AuthenticationMethod
from gatehouse_app.models.authentication_method import (
from gatehouse_app.models.auth.authentication_method import (
OAuthState,
ApplicationProviderConfig,
OrganizationProviderOverride
@@ -736,9 +736,14 @@ class ExternalAuthService:
400,
)
# Generate PKCE
code_verifier = secrets.token_urlsafe(32)
code_challenge = cls._compute_s256_challenge(code_verifier)
# Generate PKCE — skip for confidential clients (Google, Microsoft) that use a
# client_secret. Sending code_challenge to Microsoft causes it to enforce PKCE on
# the token exchange, which then fails. Matches the behaviour of initiate_login_flow.
code_verifier = None
code_challenge = None
if provider_type_str not in ('google', 'microsoft'):
code_verifier = secrets.token_urlsafe(32)
code_challenge = cls._compute_s256_challenge(code_verifier)
# Create OAuth state
state = OAuthState.create_state(
@@ -1210,12 +1215,35 @@ class ExternalAuthService:
else:
email_verified = data.get("email_verified", False)
sub = data.get("sub")
# Derive email from sub when the provider omits the email claim.
# This happens with some OIDC servers (including the nav-security mock)
# that only return the minimal {sub, iss, iat, exp} set.
# Rule: if sub looks like an email address, use it directly.
# Otherwise, construct a deterministic fallback so we never get NULL.
raw_email = data.get("email")
if not raw_email and sub:
import re as _re
if _re.match(r"^[^@\s]+@[^@\s]+\.[^@\s]+$", sub):
raw_email = sub
email_verified = True # if sub IS the email it's already verified
else:
# e.g. "12345" → "12345@google.local" so we can store it
raw_email = f"{sub}@{provider or 'oauth'}.local"
email_verified = False
# Derive display name when omitted
raw_name = data.get("name") or data.get("display_name")
if not raw_name and raw_email:
raw_name = raw_email.split("@")[0]
# Standardize user info
return {
"provider_user_id": data.get("sub"),
"email": data.get("email"),
"provider_user_id": sub,
"email": raw_email,
"email_verified": email_verified,
"name": data.get("name"),
"name": raw_name,
"first_name": data.get("given_name"),
"last_name": data.get("family_name"),
"picture": data.get("picture"),
+6 -6
View File
@@ -4,11 +4,11 @@ from datetime import datetime, timezone
from typing import Optional, List, Dict, Any
from gatehouse_app.extensions import db
from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy
from gatehouse_app.models.user_security_policy import UserSecurityPolicy
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
from gatehouse_app.models.user import User
from gatehouse_app.models.organization import Organization
from gatehouse_app.models.security.organization_security_policy import OrganizationSecurityPolicy
from gatehouse_app.models.security.user_security_policy import UserSecurityPolicy
from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance
from gatehouse_app.models.user.user import User
from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.services.audit_service import AuditService
from gatehouse_app.utils.constants import (
MfaPolicyMode,
@@ -702,7 +702,7 @@ class MfaPolicyService:
if now is None:
now = datetime.now(timezone.utc)
from gatehouse_app.models.organization_member import OrganizationMember
from gatehouse_app.models.organization.organization_member import OrganizationMember
updated_count = 0
+55 -94
View File
@@ -19,9 +19,9 @@ import logging
import json
from gatehouse_app.extensions import db
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy
from gatehouse_app.models.user import User
from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance
from gatehouse_app.models.security.organization_security_policy import OrganizationSecurityPolicy
from gatehouse_app.models.user.user import User
from gatehouse_app.services.audit_service import AuditService
from gatehouse_app.utils.constants import AuditAction
@@ -37,6 +37,7 @@ class NotificationService:
SMTP_PORT_KEY = "SMTP_PORT"
SMTP_USERNAME_KEY = "SMTP_USERNAME"
SMTP_PASSWORD_KEY = "SMTP_PASSWORD"
SMTP_USE_TLS_KEY = "SMTP_USE_TLS"
FROM_ADDRESS_KEY = "FROM_ADDRESS"
@staticmethod
@@ -86,10 +87,9 @@ class NotificationService:
if success:
logger.info(
f"Sent MFA deadline reminder to {user.email} "
f"({days_until_deadline} days remaining # Audit log
)"
f"({days_until_deadline} days remaining)"
)
AuditService.log_action(
AuditService.log_action(
action=AuditAction.MFA_POLICY_USER_COMPLIANT,
user_id=user.id,
organization_id=compliance.organization_id,
@@ -291,101 +291,62 @@ Gatehouse Security Team
body: str,
html_body: Optional[str] = None,
) -> bool:
"""Send an email notification.
"""Send an email via SMTP.
This method attempts to send an email using configured SMTP settings.
If email is not configured, it logs the notification instead.
Args:
to_address: Recipient email address
subject: Email subject
body: Plain text email body
html_body: Optional HTML email body
Returns:
True if email was sent (or logged), False on error
Returns True if the email was sent successfully, False otherwise.
If EMAIL_ENABLED is False, logs the email body instead (simulation mode).
"""
import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from flask import current_app
email_enabled = current_app.config.get(NotificationService.EMAIL_ENABLED_KEY, False)
if not email_enabled:
logger.info(
f"[EMAIL DISABLED] Would have sent to: {to_address} | Subject: {subject}\n"
f"Body: {body[:500]}"
)
return False
smtp_host = current_app.config.get(NotificationService.SMTP_HOST_KEY, "localhost")
smtp_port = int(current_app.config.get(NotificationService.SMTP_PORT_KEY, 587))
smtp_username = current_app.config.get(NotificationService.SMTP_USERNAME_KEY)
smtp_password = current_app.config.get(NotificationService.SMTP_PASSWORD_KEY)
smtp_use_tls = current_app.config.get(
NotificationService.SMTP_USE_TLS_KEY,
smtp_port not in (25, 1025),
)
from_address = current_app.config.get(
NotificationService.FROM_ADDRESS_KEY, "noreply@gatehouse.local"
)
try:
from flask import current_app
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["From"] = from_address
msg["To"] = to_address
msg.attach(MIMEText(body, "plain"))
if html_body:
msg.attach(MIMEText(html_body, "html"))
# Check if email is configured
email_enabled = current_app.config.get(
NotificationService.EMAIL_ENABLED_KEY, False
)
if not email_enabled:
# Log the notification instead of sending
logger.info(
f"[EMAIL SIMULATION] To: {to_address}\n"
f"Subject: {subject}\n"
f"Body: {body[:200]}..." if len(body) > 200 else f"Body: {body}"
)
return True
# Get email configuration
smtp_host = current_app.config.get(NotificationService.SMTP_HOST_KEY)
smtp_port = current_app.config.get(NotificationService.SMTP_PORT_KEY, 587)
smtp_username = current_app.config.get(NotificationService.SMTP_USERNAME_KEY)
smtp_password = current_app.config.get(NotificationService.SMTP_PASSWORD_KEY)
from_address = current_app.config.get(
NotificationService.FROM_ADDRESS_KEY, "noreply@gatehouse.local"
)
# Import send_email based on available mail library
try:
from flask_mail import Message
from gatehouse_app import mail
msg = Message(
subject=subject,
recipients=[to_address],
body=body,
html=html_body,
sender=from_address,
)
mail.send(msg)
logger.info(f"Email sent successfully to {to_address}")
return True
except ImportError:
# Flask-Mail not available, use SMTP directly
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["From"] = from_address
msg["To"] = to_address
# Attach plain text and HTML versions
part1 = MIMEText(body, "plain")
msg.attach(part1)
if html_body:
part2 = MIMEText(html_body, "html")
msg.attach(part2)
# Send via SMTP
with smtplib.SMTP(smtp_host, smtp_port) as server:
with smtplib.SMTP(smtp_host, smtp_port) as server:
server.ehlo()
if smtp_use_tls:
server.starttls()
if smtp_username and smtp_password:
server.login(smtp_username, smtp_password)
server.send_message(msg)
server.ehlo()
if smtp_username and smtp_password:
server.login(smtp_username, smtp_password)
server.send_message(msg)
logger.info(f"Email sent successfully to {to_address}")
return True
logger.info(f"[EMAIL] Sent to {to_address} | Subject: {subject}")
return True
except Exception as e:
logger.exception(f"Failed to send email to {to_address}: {e}")
# Log the notification as fallback
logger.info(
f"[EMAIL FALLBACK] To: {to_address}\n"
f"Subject: {subject}\n"
f"Body: {body[:500]}..." if len(body) > 500 else f"Body: {body}"
)
return True # Return True to continue processing
logger.error(f"[EMAIL] Failed to send to {to_address}: {e}")
return False
@staticmethod
def get_notification_stats(user_id: str) -> Dict[str, Any]:
@@ -397,7 +358,7 @@ Gatehouse Security Team
Returns:
Dictionary with notification statistics
"""
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance
stats = {
"total_notifications": 0,
+108 -25
View File
@@ -9,10 +9,10 @@ from flask import current_app, request, g, redirect
from gatehouse_app.extensions import db
from gatehouse_app.models import User, AuthenticationMethod
from gatehouse_app.models.authentication_method import OAuthState
from gatehouse_app.models.auth.authentication_method import OAuthState
from gatehouse_app.models.base import BaseModel
from gatehouse_app.models.oidc_authorization_code import OIDCAuthCode
from gatehouse_app.utils.constants import AuthMethodType
from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode
from gatehouse_app.utils.constants import AuthMethodType, AuditAction
from gatehouse_app.services.audit_service import AuditService
from gatehouse_app.services.external_auth_service import (
ExternalAuthService,
@@ -139,7 +139,7 @@ class OAuthFlowService:
except ExternalAuthError as e:
# Log failed initiation
AuditService.log_action(
action="external_auth.login.initiated",
action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED,
organization_id=organization_id,
metadata={
"provider_type": provider_type_str,
@@ -236,7 +236,7 @@ class OAuthFlowService:
except ExternalAuthError as e:
AuditService.log_action(
action="external_auth.register.initiated",
action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED,
organization_id=organization_id,
metadata={
"provider_type": provider_type_str,
@@ -399,6 +399,27 @@ class OAuthFlowService:
access_token=tokens["access_token"],
)
if not user_info.get("provider_user_id"):
raise OAuthFlowError(
"Provider did not return a user identifier (sub claim). "
"Cannot complete authentication.",
"MISSING_PROVIDER_USER_ID",
400,
)
if not user_info.get("email"):
raise OAuthFlowError(
"Provider did not return an email address. "
"Cannot complete authentication.",
"MISSING_EMAIL",
400,
)
logger.debug(
f"Got user_info from provider: sub={user_info['provider_user_id']}, "
f"email={user_info['email']}, email_verified={user_info.get('email_verified')}"
)
# Look up user by provider_user_id
auth_method = AuthenticationMethod.query.filter_by(
method_type=provider_type,
@@ -494,25 +515,52 @@ class OAuthFlowService:
if not target_org and len(user_orgs) == 1:
target_org = user_orgs[0]
# Priority 3: No orgs at all — auto-create a personal org and log in
# Priority 3: No orgs at all — send to org-setup instead of auto-creating
if not target_org and len(user_orgs) == 0:
import re
import uuid
from gatehouse_app.services.organization_service import OrganizationService
org_name = f"{user_info.get('name') or user.email.split('@')[0]}'s Workspace"
# Build a URL-safe slug and ensure uniqueness with a short suffix
base_slug = re.sub(r"[^a-z0-9]+", "-", org_name.lower()).strip("-")[:40]
slug = f"{base_slug}-{uuid.uuid4().hex[:6]}"
org = OrganizationService.create_organization(
name=org_name,
slug=slug,
owner_user_id=user.id,
)
target_org = org
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
from gatehouse_app.services.auth_service import AuthService as _AS
_now = datetime.now(timezone.utc)
_session = _AS.create_session(user=user, is_compliance_only=False)
_session_dict = _session.to_dict()
_session_dict["token"] = _session.token
_expires_at = _session.expires_at
if _expires_at.tzinfo is None:
_expires_at = _expires_at.replace(tzinfo=timezone.utc)
_session_dict["expires_in"] = int((_expires_at - _now).total_seconds())
_pending = OrgInviteToken.query.filter(
OrgInviteToken.email == user.email,
OrgInviteToken.accepted_at.is_(None),
OrgInviteToken.expires_at > _now,
OrgInviteToken.deleted_at.is_(None),
).all()
_pending_list = [
{
"token": inv.token,
"organization": {
"id": str(inv.organization_id),
"name": inv.organization.name,
},
"role": inv.role,
"expires_at": inv.expires_at.isoformat(),
}
for inv in _pending
]
state_record.mark_used()
logger.info(
f"OAuth login: auto-created org '{org.name}' (id={org.id}) "
f"for new user {user.id}"
f"OAuth login: user {user.id} has no org, redirecting to org-setup "
f"(pending_invites={len(_pending_list)})"
)
return {
"success": True,
"flow_type": "login",
"requires_org_creation": True,
"user": {"id": user.id, "email": user.email, "full_name": user.full_name},
"session": _session_dict,
"pending_invites": _pending_list,
"state": state_record.state,
}
# Priority 4: Multiple orgs — need user to pick one
if not target_org:
@@ -755,7 +803,7 @@ class OAuthFlowService:
# If organization_id hint was provided and valid, create session for that org
if state_record.organization_id:
from gatehouse_app.models.organization import Organization
from gatehouse_app.models.organization.organization import Organization
org = Organization.query.get(state_record.organization_id)
if org:
from gatehouse_app.services.auth_service import AuthService
@@ -784,7 +832,40 @@ class OAuthFlowService:
"session": session_dict,
}
# No organization hint or invalid - need to create/select org
# No organization hint or invalid - need to create/select org.
# Still create a session so the frontend can call /organizations
# and /invites after redirecting to /org-setup.
from gatehouse_app.services.auth_service import AuthService as _AS
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
_session = _AS.create_session(user=user, is_compliance_only=False)
_session_dict = _session.to_dict()
_session_dict["token"] = _session.token
_expires_at = _session.expires_at
if _expires_at.tzinfo is None:
_expires_at = _expires_at.replace(tzinfo=timezone.utc)
_now = datetime.now(timezone.utc)
_session_dict["expires_in"] = int((_expires_at - _now).total_seconds())
# Surface pending invitations so the UI can offer "join vs create"
_pending = OrgInviteToken.query.filter(
OrgInviteToken.email == user.email,
OrgInviteToken.accepted_at.is_(None),
OrgInviteToken.expires_at > _now,
OrgInviteToken.deleted_at.is_(None),
).all()
_pending_list = [
{
"token": inv.token,
"organization": {
"id": str(inv.organization_id),
"name": inv.organization.name,
},
"role": inv.role,
"expires_at": inv.expires_at.isoformat(),
}
for inv in _pending
]
return {
"success": True,
"flow_type": "register",
@@ -794,6 +875,8 @@ class OAuthFlowService:
"email": user.email,
"full_name": user.full_name,
},
"session": _session_dict,
"pending_invites": _pending_list,
"state": state_record.state,
}
@@ -966,8 +1049,8 @@ class OAuthFlowService:
)
# Determine organization
from gatehouse_app.models.organization import Organization
from gatehouse_app.models.organization_member import OrganizationMember
from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.models.organization.organization_member import OrganizationMember
# Get user's organizations
user_orgs = user.get_organizations()
+1 -1
View File
@@ -8,7 +8,7 @@ from typing import Dict, List, Optional, Tuple
from flask import current_app
from gatehouse_app.extensions import db
from gatehouse_app.models.oidc_jwks_key import OidcJwksKey
from gatehouse_app.models.oidc.oidc_jwks_key import OidcJwksKey
class JWKSKey:
+1 -1
View File
@@ -14,7 +14,7 @@ from gatehouse_app.models import (
User, OIDCClient, OIDCAuthCode, OIDCRefreshToken,
OIDCSession, OIDCTokenMetadata
)
from gatehouse_app.models.organization_member import OrganizationMember
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.exceptions.validation_exceptions import (
ValidationError, NotFoundError, BadRequestError
)
+1 -1
View File
@@ -11,7 +11,7 @@ import jwt
from flask import current_app, g
from gatehouse_app.models import User, OIDCClient
from gatehouse_app.models.organization_member import OrganizationMember
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.services.oidc_jwks_service import OIDCJWKSService
logger = logging.getLogger(__name__)
+22 -4
View File
@@ -3,8 +3,8 @@ import logging
from datetime import datetime, timezone
from flask import current_app
from gatehouse_app.extensions import db
from gatehouse_app.models.organization import Organization
from gatehouse_app.models.organization_member import OrganizationMember
from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.exceptions.validation_exceptions import OrganizationNotFoundError, ConflictError
from gatehouse_app.utils.constants import OrganizationRole, AuditAction
from gatehouse_app.services.audit_service import AuditService
@@ -188,11 +188,10 @@ class OrganizationService:
Raises:
ConflictError: If user is already a member
"""
# Check if already a member
# Check if already a member (active or soft-deleted — both blocked by DB unique constraint)
existing = OrganizationMember.query.filter_by(
user_id=user_id,
organization_id=org.id,
deleted_at=None,
).first()
# Development-only debug logging for membership validation
@@ -200,6 +199,25 @@ class OrganizationService:
logger.debug(f"[Org] Member check: org_id={org.id}, user_id={user_id}, already_member={existing is not None}")
if existing:
if existing.deleted_at is not None:
# Reactivate the soft-deleted membership with the new role
existing.deleted_at = None
existing.role = role
existing.invited_by_id = inviter_id
existing.invited_at = datetime.now(timezone.utc)
existing.joined_at = datetime.now(timezone.utc)
existing.save()
AuditService.log_action(
action=AuditAction.ORG_MEMBER_ADD,
user_id=inviter_id,
organization_id=org.id,
resource_type="organization_member",
resource_id=existing.id,
metadata={"added_user_id": user_id, "role": role.value},
description=f"Member re-added to organization with role: {role.value}",
)
return existing
raise ConflictError("User is already a member of this organization")
# Create membership
+2 -2
View File
@@ -1,6 +1,6 @@
"""Session service."""
from datetime import datetime, timezone
from gatehouse_app.models.session import Session
from gatehouse_app.models.user.session import Session
from gatehouse_app.utils.constants import SessionStatus
@@ -17,7 +17,7 @@ class SessionService:
Returns:
Session object if found and active, None otherwise
"""
from gatehouse_app.models.session import Session
from gatehouse_app.models.user.session import Session
from gatehouse_app.utils.constants import SessionStatus
return Session.query.filter_by(
token=token,
@@ -0,0 +1,359 @@
"""SSH Certificate Authority signing service.
Handles SSH certificate signing operations, leveraging sshkey-tools library.
This service is a Gatehouse-integrated version of the secuird/ssh_ca.py logic.
"""
import logging
import os
from datetime import datetime, timedelta, timezone
from typing import List, Dict, Any, Optional
from sshkey_tools.cert import SSHCertificate, CertificateFields
from sshkey_tools.keys import PublicKey, PrivateKey
from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config
from gatehouse_app.exceptions import SSHCAError, ValidationError
from gatehouse_app.utils.crypto import compute_ssh_fingerprint
logger = logging.getLogger(__name__)
class SSHCASigningError(Exception):
"""SSH CA signing operation error."""
pass
class SSHCertificateSigningRequest:
"""Represents an SSH certificate signing request."""
def __init__(
self,
ssh_public_key: str,
principals: List[str],
key_id: str,
cert_type: str = "user",
expiry_hours: Optional[int] = None,
critical_options: Optional[Dict[str, str]] = None,
extensions: Optional[List[str]] = None,
):
"""Initialize signing request.
Args:
ssh_public_key: Public key in OpenSSH format (e.g., "ssh-ed25519 AAAA...")
principals: List of principals (e.g., ["prod-servers", "staging"])
key_id: Key identifier (usually user email)
cert_type: Certificate type - "user" or "host" (default: user)
expiry_hours: Certificate validity in hours
critical_options: Critical options dict
extensions: List of extensions (e.g., ["permit-pty", "permit-agent-forwarding"])
"""
self.ssh_public_key = ssh_public_key
self.principals = principals or []
self.key_id = key_id
self.cert_type = cert_type
self.expiry_hours = expiry_hours
self.critical_options = critical_options or {}
self.extensions = extensions or []
def validate(self) -> List[str]:
"""Validate the signing request.
Returns:
List of validation errors (empty if valid)
"""
errors = []
config = get_ssh_ca_config()
# Validate cert type
if self.cert_type not in ("user", "host"):
errors.append(f"Invalid cert_type: {self.cert_type}. Must be 'user' or 'host'")
# Validate SSH public key
if not self.ssh_public_key or len(self.ssh_public_key) < 16:
errors.append("SSH public key is missing or invalid")
else:
try:
PublicKey.from_string(self.ssh_public_key)
except Exception as e:
errors.append(f"SSH public key is not valid: {str(e)}")
# Validate principals
if not self.principals or len(self.principals) == 0:
errors.append("At least one principal is required")
else:
max_principals = config.get_int('max_principals_per_cert')
if len(self.principals) > max_principals:
errors.append(
f"Too many principals ({len(self.principals)}). "
f"Maximum is {max_principals}"
)
# Validate key_id
if not self.key_id or len(self.key_id) < 5:
errors.append("key_id is missing or too short (minimum 5 characters)")
else:
max_id_len = config.get_int('max_key_id_length')
if len(self.key_id) > max_id_len:
errors.append(f"key_id exceeds maximum length of {max_id_len}")
# Validate expiry_hours
if self.expiry_hours is not None:
if not isinstance(self.expiry_hours, int) or self.expiry_hours <= 0:
errors.append("expiry_hours must be a positive integer")
else:
max_validity = config.get_int('max_cert_validity_hours')
if self.expiry_hours > max_validity:
errors.append(
f"Requested expiry ({self.expiry_hours}h) exceeds "
f"maximum allowed ({max_validity}h)"
)
return errors
class SSHCertificateSigningResponse:
"""Represents a signed SSH certificate response."""
def __init__(
self,
certificate: str,
serial: str,
valid_after: datetime,
valid_before: datetime,
principals: Optional[List[str]] = None,
):
"""Initialize signing response.
Args:
certificate: Full certificate in OpenSSH format
serial: Certificate serial number
valid_after: Validity start datetime
valid_before: Validity end datetime
principals: List of principals the cert was issued for
"""
self.certificate = certificate
self.serial = serial
self.valid_after = valid_after
self.valid_before = valid_before
self.principals = principals or []
def to_dict(self) -> Dict[str, Any]:
"""Convert response to dictionary."""
return {
'certificate': self.certificate,
'serial': self.serial,
'valid_after': self.valid_after.isoformat(),
'valid_before': self.valid_before.isoformat(),
}
class SSHCASigningService:
"""Service for signing SSH certificates.
This service handles all SSH certificate signing operations.
It uses configuration from ssh_ca_config to apply rules and limits.
"""
def __init__(self):
"""Initialize the SSH CA signing service."""
self.config = get_ssh_ca_config()
self.logger = logger
def _load_ca_key_from_config(self) -> str:
"""Load CA private key from config (local file or env var).
Returns:
CA private key in PEM/OpenSSH format as string
Raises:
SSHCASigningError: If key cannot be loaded
"""
# Check env var first
key_content = os.environ.get('SSH_CA_PRIVATE_KEY')
if key_content:
return key_content
# Load from file path
key_path = self.config.get_str('ca_key_path', '').strip()
if not key_path:
raise SSHCASigningError(
"CA private key not configured. Set SSH_CA_PRIVATE_KEY env var "
"or ca_key_path in etc/ssh_ca.conf"
)
key_path = os.path.expandvars(os.path.expanduser(key_path))
if not os.path.exists(key_path):
raise SSHCASigningError(f"CA private key file not found: {key_path}")
with open(key_path, 'r') as f:
return f.read()
def sign_certificate(
self,
signing_request: SSHCertificateSigningRequest,
ca_private_key: Optional[str] = None,
ca_obj=None,
) -> "SSHCertificateSigningResponse":
"""Sign an SSH certificate.
Args:
signing_request: SSHCertificateSigningRequest instance
ca_private_key: CA private key in PEM format. If not provided,
loaded from config (ca_key_path or SSH_CA_PRIVATE_KEY env var)
ca_obj: Optional CA model instance. When supplied its monotonic
serial counter is incremented atomically (SELECT FOR UPDATE)
and the resulting integer is embedded in the certificate's
serial field. This ensures every issued cert has a unique,
ordered, auditable serial number.
Returns:
SSHCertificateSigningResponse with signed certificate
Raises:
SSHCASigningError: If signing fails
ValidationError: If request is invalid
"""
# Validate request
errors = signing_request.validate()
if errors:
error_msg = "; ".join(errors)
self.logger.error(f"Certificate signing validation failed: {error_msg}")
raise ValidationError(f"Certificate signing validation failed: {error_msg}")
# Load CA key if not provided
if ca_private_key is None:
ca_private_key = self._load_ca_key_from_config()
try:
# Parse CA private key
try:
ca_key = PrivateKey.from_string(ca_private_key)
except Exception as e:
self.logger.error(f"Failed to load CA private key: {str(e)}")
raise SSHCASigningError(f"Invalid CA private key: {str(e)}")
# Parse user's public key
try:
user_pub_key = PublicKey.from_string(signing_request.ssh_public_key)
except Exception as e:
self.logger.error(f"Failed to parse user public key: {str(e)}")
raise SSHCASigningError(f"Invalid user public key: {str(e)}")
# Create certificate
certificate = SSHCertificate.create(
subject_pubkey=user_pub_key,
ca_privkey=ca_key,
)
# Set validity period
now = datetime.now(timezone.utc)
expiry_hours = signing_request.expiry_hours or self.config.get_int('cert_validity_hours')
valid_before = now + timedelta(hours=expiry_hours)
# Set certificate fields
# sshkey-tools: user=1, host=2 (not 0)
cert_type = 1 if signing_request.cert_type == "user" else 2
certificate.fields.cert_type = cert_type
certificate.fields.key_id = signing_request.key_id
certificate.fields.principals = signing_request.principals
certificate.fields.valid_after = now
certificate.fields.valid_before = valid_before
# ── Serial number ────────────────────────────────────────────────
# If a CA object is provided, use its monotonic counter so every
# certificate gets a unique, ordered, auditable serial. The
# counter increment is flushed inside get_next_serial(); the
# caller's commit() persists it atomically with the cert record.
if ca_obj is not None:
assigned_serial = ca_obj.get_next_serial()
certificate.fields.serial = assigned_serial
self.logger.debug(
f"Assigned serial {assigned_serial} from CA {ca_obj.id}"
)
# ─────────────────────────────────────────────────────────────────
# Set extensions — prefer policy-provided list, fall back to standard set
extensions = signing_request.extensions
if not extensions:
from gatehouse_app.models.organization.department_cert_policy import STANDARD_EXTENSIONS
extensions = list(STANDARD_EXTENSIONS)
certificate.fields.extensions = extensions
certificate.fields.critical_options = signing_request.critical_options or {}
# Validate certificate before signing
if not certificate.can_sign():
raise SSHCASigningError("Certificate cannot be signed")
# Sign the certificate
certificate.sign()
# Verify the certificate
try:
certificate.verify(ca_key.public_key, raise_on_error=True)
except Exception as e:
self.logger.error(f"Certificate verification failed: {str(e)}")
raise SSHCASigningError(f"Certificate verification failed: {str(e)}")
# Extract serial from certificate — use the integer we assigned
# when ca_obj was provided, otherwise fall back to whatever the
# library generated.
if ca_obj is not None:
serial = str(assigned_serial)
else:
serial = str(certificate.fields.serial).split(":")[-1].strip() if hasattr(certificate.fields.serial, '__str__') else str(certificate.fields.serial)
# Build response
cert_string = certificate.to_string()
self.logger.info(
f"Successfully signed certificate: serial={serial}, "
f"key_id={signing_request.key_id}, principals={signing_request.principals}"
)
return SSHCertificateSigningResponse(
certificate=cert_string,
serial=serial,
valid_after=now,
valid_before=valid_before,
principals=signing_request.principals,
)
except (SSHCASigningError, ValidationError):
raise
except Exception as e:
self.logger.error(f"Unexpected error during certificate signing: {str(e)}", exc_info=True)
raise SSHCASigningError(f"Error signing certificate: {str(e)}")
def verify_ca_key(self, ca_private_key: str) -> Dict[str, Any]:
"""Verify a CA private key is valid and extract metadata.
Args:
ca_private_key: CA private key in PEM format
Returns:
Dictionary with key metadata (fingerprint, key_type, etc.)
Raises:
SSHCASigningError: If key is invalid
"""
try:
ca_key = PrivateKey.from_string(ca_private_key)
pub_key = ca_key.public_key
# Compute fingerprint
fingerprint = compute_ssh_fingerprint(pub_key.to_string())
# Get key type
key_type = pub_key.keytype if hasattr(pub_key, 'keytype') else 'unknown'
return {
'fingerprint': fingerprint,
'key_type': key_type,
'public_key': pub_key.to_string(),
'valid': True,
}
except Exception as e:
self.logger.error(f"CA key verification failed: {str(e)}")
raise SSHCASigningError(f"Invalid CA key: {str(e)}")
+375
View File
@@ -0,0 +1,375 @@
"""SSH Key management service."""
import base64
import logging
import os
import secrets
import subprocess
import tempfile
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any
from gatehouse_app.extensions import db
from gatehouse_app.models import SSHKey, User
from gatehouse_app.exceptions import (
SSHKeyError,
SSHKeyNotFoundError,
SSHKeyAlreadyExistsError,
SSHKeyNotVerifiedError,
ValidationError,
UserNotFoundError,
)
from gatehouse_app.utils.crypto import (
compute_ssh_fingerprint,
verify_ssh_key_format,
extract_ssh_key_type,
extract_ssh_key_comment,
)
from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config
logger = logging.getLogger(__name__)
class SSHKeyService:
"""Service for managing SSH keys."""
def __init__(self):
"""Initialize SSH key service."""
self.config = get_ssh_ca_config()
def add_ssh_key(
self,
user_id: str,
public_key: str,
description: Optional[str] = None,
) -> SSHKey:
"""Add an SSH public key for a user.
Args:
user_id: ID of the user
public_key: SSH public key in OpenSSH format
description: Optional description of the key
Returns:
Created SSHKey instance
Raises:
UserNotFoundError: If user doesn't exist
SSHKeyError: If key format is invalid
SSHKeyAlreadyExistsError: If key already exists
"""
# Verify user exists
user = User.query.get(user_id)
if not user:
raise UserNotFoundError(f"User {user_id} not found")
# Validate key format
if not verify_ssh_key_format(public_key):
raise SSHKeyError("Invalid SSH public key format")
# Compute fingerprint
try:
fingerprint = compute_ssh_fingerprint(public_key)
except Exception as e:
logger.error(f"Failed to compute fingerprint: {str(e)}")
raise SSHKeyError(f"Failed to compute key fingerprint: {str(e)}")
# Check for duplicate (including soft-deleted records — fingerprint is unique in DB)
existing = SSHKey.query.filter_by(fingerprint=fingerprint).first()
if existing:
if existing.deleted_at is not None:
# Restore the soft-deleted key: clear deleted_at and update fields
existing.deleted_at = None
existing.user_id = user_id
existing.description = description or existing.description
existing.verified = False
existing.verified_at = None
existing.verify_text = None
existing.verify_text_created_at = None
db.session.commit()
logger.info(
f"Restored soft-deleted SSH key for user {user_id}: "
f"fingerprint={fingerprint}"
)
return existing
raise SSHKeyAlreadyExistsError(
f"SSH key with fingerprint {fingerprint} already exists"
)
# Extract metadata
key_type = extract_ssh_key_type(public_key)
key_comment = extract_ssh_key_comment(public_key)
# Create SSH key record
ssh_key = SSHKey(
user_id=user_id,
payload=public_key,
fingerprint=fingerprint,
description=description,
key_type=key_type,
key_comment=key_comment,
verified=False,
)
ssh_key.save()
logger.info(
f"SSH key added for user {user_id}: "
f"fingerprint={fingerprint}, type={key_type}"
)
return ssh_key
def get_ssh_key(self, key_id: str) -> SSHKey:
"""Get an SSH key by ID.
Args:
key_id: SSH key ID
Returns:
SSHKey instance
Raises:
SSHKeyNotFoundError: If key not found
"""
key = SSHKey.query.filter_by(id=key_id, deleted_at=None).first()
if not key:
raise SSHKeyNotFoundError(f"SSH key {key_id} not found")
return key
def get_user_ssh_keys(self, user_id: str) -> List[SSHKey]:
"""Get all SSH keys for a user.
Args:
user_id: User ID
Returns:
List of SSHKey instances
"""
return SSHKey.query.filter_by(user_id=user_id, deleted_at=None).all()
def get_user_verified_ssh_keys(self, user_id: str) -> List[SSHKey]:
"""Get all verified SSH keys for a user.
Args:
user_id: User ID
Returns:
List of verified SSHKey instances
"""
return SSHKey.query.filter_by(
user_id=user_id,
verified=True,
deleted_at=None,
).all()
def delete_ssh_key(self, key_id: str) -> None:
"""Soft-delete an SSH key.
Args:
key_id: SSH key ID
Raises:
SSHKeyNotFoundError: If key not found
"""
key = self.get_ssh_key(key_id)
key.delete()
logger.info(f"SSH key deleted: {key_id}")
def generate_verification_challenge(self, key_id: str) -> str:
"""Generate a verification challenge for an SSH key.
The user must sign this challenge text with their private key
to prove key ownership.
Args:
key_id: SSH key ID
Returns:
Verification challenge text
Raises:
SSHKeyNotFoundError: If key not found
"""
key = self.get_ssh_key(key_id)
# Generate random challenge
challenge = secrets.token_hex(32)
challenge_text = f"Please sign this to verify SSH key ownership: {challenge}"
# Store challenge
key.verify_text = challenge_text
key.verify_text_created_at = datetime.utcnow()
key.save()
logger.info(f"Generated verification challenge for SSH key {key_id}")
return challenge_text
def verify_ssh_key_ownership(
self,
key_id: str,
signature: str,
) -> bool:
"""Verify SSH key ownership via signature.
The user must sign the verification challenge with their private key.
We verify the signature using the public key.
Args:
key_id: SSH key ID
signature: Base64-encoded signature of the challenge
Returns:
True if signature is valid
Raises:
SSHKeyNotFoundError: If key not found
SSHKeyNotVerifiedError: If challenge is stale or missing
SSHKeyError: If verification fails
"""
key = self.get_ssh_key(key_id)
# Check if challenge exists and is not stale
if not key.verify_text or not key.verify_text_created_at:
raise SSHKeyNotVerifiedError("No verification challenge generated")
max_age = self.config.get_int('verification_challenge_max_age')
age = datetime.utcnow() - key.verify_text_created_at
if age.total_seconds() > (max_age * 3600):
raise SSHKeyNotVerifiedError("Verification challenge has expired")
try:
# Verify the SSH signature using ssh-keygen -Y verify.
# The CLI signs the challenge with: ssh-keygen -Y sign -f <key> -n file <challenge>
# We verify with: ssh-keygen -Y verify -f <allowed_signers> -I <identity> -n file -s <sig> < <message>
#
# allowed_signers format: "<identity> <keytype> <pubkey>"
# We use the key fingerprint as the identity.
sig_bytes = base64.b64decode(signature)
challenge_text = key.verify_text + "\n"
with tempfile.TemporaryDirectory() as tmpdir:
allowed_signers_path = os.path.join(tmpdir, "allowed_signers")
sig_path = os.path.join(tmpdir, "message.sig")
message_path = os.path.join(tmpdir, "message.txt")
identity = key.fingerprint
# Write the allowed_signers file
with open(allowed_signers_path, "w") as f:
f.write(f"{identity} {key.payload}\n")
# Write the signature file
with open(sig_path, "wb") as f:
f.write(sig_bytes)
# Write the challenge message
with open(message_path, "w") as f:
f.write(challenge_text)
result = subprocess.run(
[
"ssh-keygen", "-Y", "verify",
"-f", allowed_signers_path,
"-I", identity,
"-n", "file",
"-s", sig_path,
],
stdin=open(message_path, "rb"),
capture_output=True,
timeout=10,
)
if result.returncode != 0:
stderr = result.stderr.decode(errors="replace").strip()
logger.warning(f"SSH signature verification failed for key {key_id}: {stderr}")
raise SSHKeyError(f"Signature verification failed: {stderr}")
key.mark_verified()
logger.info(f"SSH key verified: {key_id}")
return True
except SSHKeyError:
raise
except Exception as e:
logger.error(f"SSH key verification failed: {str(e)}")
raise SSHKeyError(f"Signature verification failed: {str(e)}")
def get_key_fingerprint(self, key_id: str) -> str:
"""Get the fingerprint of an SSH key.
Args:
key_id: SSH key ID
Returns:
Fingerprint string
Raises:
SSHKeyNotFoundError: If key not found
"""
key = self.get_ssh_key(key_id)
return key.fingerprint
def update_ssh_key_description(self, key_id: str, description: str) -> SSHKey:
"""Update the description of an SSH key.
Args:
key_id: SSH key ID
description: New description
Returns:
Updated SSHKey instance
Raises:
SSHKeyNotFoundError: If key not found
"""
key = self.get_ssh_key(key_id)
key.description = description
key.save()
return key
def cleanup_expired_challenges(self) -> int:
"""Clean up expired verification challenges.
Returns:
Number of challenges cleaned
"""
max_age = self.config.get_int('verification_challenge_max_age')
threshold = datetime.utcnow() - timedelta(hours=max_age)
expired = SSHKey.query.filter(
SSHKey.verify_text_created_at < threshold,
SSHKey.verify_text_created_at.isnot(None),
SSHKey.deleted_at.is_(None),
).update({"verify_text": None, "verify_text_created_at": None})
db.session.commit()
logger.info(f"Cleaned up {expired} expired verification challenges")
return expired
def cleanup_unverified_keys(self) -> int:
"""Delete unverified SSH keys older than configured days.
Returns:
Number of keys deleted
"""
days = self.config.get_int('auto_delete_unverified_days')
threshold = datetime.utcnow() - timedelta(days=days)
old_unverified = SSHKey.query.filter(
SSHKey.verified == False,
SSHKey.created_at < threshold,
SSHKey.deleted_at.is_(None),
).all()
count = 0
for key in old_unverified:
key.delete()
count += 1
logger.info(f"Deleted {count} unverified SSH keys older than {days} days")
return count
+41
View File
@@ -11,10 +11,51 @@ from gatehouse_app.extensions import bcrypt
logger = logging.getLogger(__name__)
# TOTP codes are valid for at most (2*window + 1) * 30s steps.
# With window=1 that's 3 steps = 90 seconds. We use a slightly
# generous TTL of 95 seconds to account for clock skew at boundaries.
_TOTP_USED_CODE_TTL = 95
class TOTPService:
"""Service for TOTP operations."""
# ------------------------------------------------------------------
# Replay-attack prevention helpers
# ------------------------------------------------------------------
@staticmethod
def _used_key(user_id: str, code: str) -> str:
return f"totp:used:{user_id}:{code}"
@staticmethod
def is_code_already_used(user_id: str, code: str) -> bool:
"""Return True if *code* has already been accepted for *user_id*
within the current validity window (prevents replay attacks)."""
try:
from gatehouse_app.extensions import redis_client
if redis_client is None:
return False
return redis_client.exists(TOTPService._used_key(user_id, code)) == 1
except Exception:
logger.warning("Redis unavailable for TOTP replay check; allowing code")
return False
@staticmethod
def mark_code_used(user_id: str, code: str) -> None:
"""Record *code* as consumed for *user_id* so it cannot be reused."""
try:
from gatehouse_app.extensions import redis_client
if redis_client is None:
return
redis_client.setex(
TOTPService._used_key(user_id, code),
_TOTP_USED_CODE_TTL,
"1",
)
except Exception:
logger.warning("Redis unavailable; TOTP used-code not recorded")
@staticmethod
def generate_secret() -> str:
"""
+1 -1
View File
@@ -2,7 +2,7 @@
import logging
from flask import current_app
from gatehouse_app.extensions import db
from gatehouse_app.models.user import User
from gatehouse_app.models.user.user import User
from gatehouse_app.exceptions.validation_exceptions import UserNotFoundError
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.services.audit_service import AuditService
+22 -2
View File
@@ -10,8 +10,8 @@ from flask import current_app
from sqlalchemy.orm.attributes import flag_modified
from gatehouse_app.extensions import db, redis_client
from gatehouse_app.models.user import User
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.models.user.user import User
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType, AuditAction
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
from gatehouse_app.services.audit_service import AuditService
@@ -641,6 +641,26 @@ class WebAuthnService:
)
return True
@classmethod
def credential_belongs_to_user(cls, credential_id: str, user: User) -> bool:
"""Check whether *credential_id* exists and belongs to *user*.
Args:
credential_id: The credential ID to look up
user: User instance
Returns:
True if the credential exists and belongs to this user, False otherwise.
"""
auth_method = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.WEBAUTHN,
deleted_at=None,
).first()
if not auth_method or not auth_method.provider_data:
return False
return auth_method.provider_data.get("credential_id") == credential_id
@classmethod
def rename_credential(cls, credential_id: str, user: User, name: str) -> bool:
+206
View File
@@ -0,0 +1,206 @@
"""Encryption helpers for CA private keys stored in the database.
CA private keys are encrypted at rest using Fernet (AES-128-CBC + HMAC-SHA256)
from the ``cryptography`` package. The encryption key is derived from the
``CA_ENCRYPTION_KEY`` environment variable (or ``Flask.config["CA_ENCRYPTION_KEY"]``).
Key derivation
--------------
Fernet requires a URL-safe base64-encoded 32-byte key. We accept any string
from the env and derive the actual Fernet key using SHA-256 so that operators
can supply human-readable secrets without having to pre-encode them.
Envelope format
---------------
Encrypted values are stored as the string::
$fernet$<fernet_token>
The ``$fernet$`` prefix lets the code distinguish already-encrypted values from
legacy plaintext PEM keys so that the migration path is safe and idempotent.
Usage
-----
Encrypt before storing::
from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key
ca.private_key = encrypt_ca_key(private_key_pem)
Decrypt before use::
from gatehouse_app.utils.ca_key_encryption import decrypt_ca_key
plaintext_pem = decrypt_ca_key(ca.private_key)
"""
import base64
import hashlib
import logging
import os
from cryptography.fernet import Fernet, InvalidToken
logger = logging.getLogger(__name__)
# Prefix that marks a stored value as Fernet-encrypted
_FERNET_PREFIX = "$fernet$"
class CAKeyEncryptionError(Exception):
"""Raised when CA key encryption or decryption fails."""
def _get_fernet() -> Fernet:
"""Build a Fernet instance from the configured encryption key.
Looks up ``CA_ENCRYPTION_KEY`` in the environment first, then falls back to
the Flask app config (if a request context is active).
Raises:
CAKeyEncryptionError: if no key is configured or it is the insecure
placeholder value in a production-like environment.
"""
raw_key = os.environ.get("CA_ENCRYPTION_KEY")
if not raw_key:
# Try Flask config if we're inside an app context
try:
from flask import current_app
raw_key = current_app.config.get("CA_ENCRYPTION_KEY")
except RuntimeError:
pass # No app context
if not raw_key:
raise CAKeyEncryptionError(
"CA_ENCRYPTION_KEY is not set. "
"Set this environment variable before starting the application."
)
# Warn loudly when running with the placeholder in a non-test environment
env_name = os.environ.get("FLASK_ENV", "").lower()
if raw_key.startswith("dev-") and env_name not in ("development", "testing", "test"):
logger.warning(
"CA_ENCRYPTION_KEY appears to be a development placeholder. "
"Set a strong random key for production environments."
)
# Derive a 32-byte key from the raw secret via SHA-256, then URL-safe base64
key_bytes = hashlib.sha256(raw_key.encode()).digest()
fernet_key = base64.urlsafe_b64encode(key_bytes)
return Fernet(fernet_key)
def encrypt_ca_key(plaintext_pem: str) -> str:
"""Encrypt a CA private key PEM string.
Idempotent: already-encrypted values are returned unchanged.
Args:
plaintext_pem: CA private key in OpenSSH/PEM format.
Returns:
Encrypted string with ``$fernet$`` prefix, safe for database storage.
Raises:
CAKeyEncryptionError: if the key cannot be encrypted.
"""
if not plaintext_pem:
raise CAKeyEncryptionError("Cannot encrypt an empty key")
# Already encrypted — do not double-encrypt
if plaintext_pem.startswith(_FERNET_PREFIX):
return plaintext_pem
try:
fernet = _get_fernet()
token = fernet.encrypt(plaintext_pem.encode()).decode()
return f"{_FERNET_PREFIX}{token}"
except CAKeyEncryptionError:
raise
except Exception as exc:
raise CAKeyEncryptionError(f"Failed to encrypt CA key: {exc}") from exc
def decrypt_ca_key(stored_value: str) -> str:
"""Decrypt a CA private key retrieved from the database.
Idempotent: plaintext (legacy) values are returned unchanged so that the
system continues to work while a migration encrypts existing rows.
Args:
stored_value: Value from ``CA.private_key`` column.
Returns:
Plaintext PEM string ready for use with ``sshkey_tools``.
Raises:
CAKeyEncryptionError: if decryption fails (wrong key, corrupted data).
"""
if not stored_value:
raise CAKeyEncryptionError("Cannot decrypt an empty value")
# Legacy plaintext key — return as-is
if not stored_value.startswith(_FERNET_PREFIX):
logger.warning(
"CA private key appears to be stored as plaintext. "
"Run the migration to encrypt existing keys."
)
return stored_value
token = stored_value[len(_FERNET_PREFIX):]
try:
fernet = _get_fernet()
return fernet.decrypt(token.encode()).decode()
except InvalidToken as exc:
raise CAKeyEncryptionError(
"CA key decryption failed — the CA_ENCRYPTION_KEY may be incorrect "
"or the stored key is corrupted."
) from exc
except CAKeyEncryptionError:
raise
except Exception as exc:
raise CAKeyEncryptionError(f"Unexpected decryption error: {exc}") from exc
def is_encrypted(stored_value: str) -> bool:
"""Return True if the stored value has the ``$fernet$`` envelope.
Args:
stored_value: Value from ``CA.private_key`` column.
"""
return bool(stored_value and stored_value.startswith(_FERNET_PREFIX))
def reencrypt_ca_key(stored_value: str, old_raw_key: str, new_raw_key: str) -> str:
"""Re-encrypt a CA key with a new encryption key (for key rotation).
Args:
stored_value: Current value from ``CA.private_key`` (may or may not be encrypted).
old_raw_key: The current ``CA_ENCRYPTION_KEY`` value (raw secret string).
new_raw_key: The new ``CA_ENCRYPTION_KEY`` value to encrypt with.
Returns:
New encrypted envelope string.
Raises:
CAKeyEncryptionError: if decryption or re-encryption fails.
"""
# Decrypt with old key
if stored_value.startswith(_FERNET_PREFIX):
token = stored_value[len(_FERNET_PREFIX):]
old_key_bytes = base64.urlsafe_b64encode(hashlib.sha256(old_raw_key.encode()).digest())
try:
plaintext = Fernet(old_key_bytes).decrypt(token.encode()).decode()
except InvalidToken as exc:
raise CAKeyEncryptionError(
"Re-encryption failed: could not decrypt with the old key."
) from exc
else:
# Plaintext
plaintext = stored_value
# Re-encrypt with new key
new_key_bytes = base64.urlsafe_b64encode(hashlib.sha256(new_raw_key.encode()).digest())
try:
token = Fernet(new_key_bytes).encrypt(plaintext.encode()).decode()
return f"{_FERNET_PREFIX}{token}"
except Exception as exc:
raise CAKeyEncryptionError(f"Re-encryption with new key failed: {exc}") from exc
+45
View File
@@ -12,6 +12,16 @@ class UserStatus(str, Enum):
COMPLIANCE_SUSPENDED = "compliance_suspended"
class Role(str, Enum):
"""Generic role definitions (hierarchy: Admin > Manager > Member > Viewer > Guest)."""
ADMIN = "admin"
MANAGER = "manager"
MEMBER = "member"
VIEWER = "viewer"
GUEST = "guest"
class OrganizationRole(str, Enum):
"""Organization member roles."""
@@ -51,6 +61,9 @@ class AuditAction(str, Enum):
USER_REGISTER = "user.register"
USER_UPDATE = "user.update"
USER_DELETE = "user.delete"
USER_HARD_DELETE = "user.hard_delete"
USER_SUSPEND = "user.suspend"
USER_UNSUSPEND = "user.unsuspend"
PASSWORD_CHANGE = "user.password_change"
PASSWORD_RESET = "user.password_reset"
@@ -61,6 +74,7 @@ class AuditAction(str, Enum):
ORG_MEMBER_ADD = "org.member.add"
ORG_MEMBER_REMOVE = "org.member.remove"
ORG_MEMBER_ROLE_CHANGE = "org.member.role_change"
ORG_OWNERSHIP_TRANSFERRED = "org.ownership.transferred"
# Session actions
SESSION_CREATE = "session.create"
@@ -105,6 +119,37 @@ class AuditAction(str, Enum):
EXTERNAL_AUTH_CONFIG_UPDATE = "external_auth.config.update"
EXTERNAL_AUTH_CONFIG_DELETE = "external_auth.config.delete"
# SSH Key and Certificate actions
SSH_KEY_ADDED = "ssh.key.added"
SSH_KEY_VERIFIED = "ssh.key.verified"
SSH_KEY_DELETED = "ssh.key.deleted"
SSH_KEY_VALIDATION_FAILED = "ssh.key.validation.failed"
SSH_CERT_REQUESTED = "ssh.cert.requested"
SSH_CERT_ISSUED = "ssh.cert.issued"
SSH_CERT_FAILED = "ssh.cert.failed"
SSH_CERT_REVOKED = "ssh.cert.revoked"
SSH_CERT_EXPIRED = "ssh.cert.expired"
# CA actions
CA_CREATED = "ca.created"
CA_UPDATED = "ca.updated"
CA_DELETED = "ca.deleted"
CA_KEY_ROTATED = "ca.key.rotated"
# Principal actions
PRINCIPAL_CREATED = "principal.created"
PRINCIPAL_UPDATED = "principal.updated"
PRINCIPAL_DELETED = "principal.deleted"
PRINCIPAL_MEMBER_ADDED = "principal.member.added"
PRINCIPAL_MEMBER_REMOVED = "principal.member.removed"
# Department actions
DEPARTMENT_CREATED = "department.created"
DEPARTMENT_UPDATED = "department.updated"
DEPARTMENT_DELETED = "department.deleted"
DEPARTMENT_MEMBER_ADDED = "department.member.added"
DEPARTMENT_MEMBER_REMOVED = "department.member.removed"
class OIDCGrantType(str, Enum):
"""OIDC grant types."""
+128
View File
@@ -0,0 +1,128 @@
"""Cryptographic utilities for SSH operations."""
import hashlib
import base64
from typing import Optional
def compute_ssh_fingerprint(public_key_str: str, hash_algorithm: str = "sha256") -> str:
"""Compute the fingerprint of an SSH public key.
Args:
public_key_str: SSH public key in OpenSSH format
hash_algorithm: Hash algorithm to use (sha256, sha1, md5)
Returns:
Fingerprint string in the format "algorithm:hex_digest"
Example:
>>> key = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKp2..."
>>> fp = compute_ssh_fingerprint(key)
>>> print(fp)
sha256:Kb+...
"""
if not public_key_str:
raise ValueError("Public key string is empty")
# Parse OpenSSH format: "ssh-ed25519 <base64> [comment]"
parts = public_key_str.strip().split()
if len(parts) < 2:
raise ValueError("Invalid OpenSSH public key format")
try:
# The base64-encoded key is the second part
key_bytes = base64.b64decode(parts[1])
except Exception as e:
raise ValueError(f"Failed to decode public key: {str(e)}")
# Compute hash
if hash_algorithm == "sha256":
digest = hashlib.sha256(key_bytes).digest()
# SSH format uses base64 encoding without padding
fingerprint = base64.b64encode(digest).decode().rstrip('=')
elif hash_algorithm == "sha1":
digest = hashlib.sha1(key_bytes).hexdigest()
fingerprint = digest
elif hash_algorithm == "md5":
digest = hashlib.md5(key_bytes).hexdigest()
# Format as colons
fingerprint = ':'.join(digest[i:i+2] for i in range(0, len(digest), 2))
else:
raise ValueError(f"Unsupported hash algorithm: {hash_algorithm}")
return f"{hash_algorithm}:{fingerprint}"
def verify_ssh_key_format(public_key_str: str) -> bool:
"""Verify that a string is in valid OpenSSH public key format.
Args:
public_key_str: Potential SSH public key
Returns:
True if valid OpenSSH format, False otherwise
"""
if not public_key_str or not isinstance(public_key_str, str):
return False
parts = public_key_str.strip().split()
# Must have at least key type and key material
if len(parts) < 2:
return False
key_type = parts[0]
# Valid key types
valid_types = [
'ssh-rsa',
'ssh-ed25519',
'ecdsa-sha2-nistp256',
'ecdsa-sha2-nistp384',
'ecdsa-sha2-nistp521',
'ssh-dss',
]
if key_type not in valid_types:
return False
# Try to decode base64
try:
base64.b64decode(parts[1])
return True
except Exception:
return False
def extract_ssh_key_type(public_key_str: str) -> Optional[str]:
"""Extract the key type from an OpenSSH public key.
Args:
public_key_str: SSH public key in OpenSSH format
Returns:
Key type (e.g., "ssh-ed25519") or None if invalid
"""
if not verify_ssh_key_format(public_key_str):
return None
return public_key_str.strip().split()[0]
def extract_ssh_key_comment(public_key_str: str) -> Optional[str]:
"""Extract the comment from an OpenSSH public key.
Args:
public_key_str: SSH public key in OpenSSH format
Returns:
Comment string or None if not present
"""
if not verify_ssh_key_format(public_key_str):
return None
parts = public_key_str.strip().split()
if len(parts) >= 3:
# Everything after the second part is the comment
return ' '.join(parts[2:])
return None
+40 -3
View File
@@ -64,11 +64,47 @@ def login_required(f):
session.last_activity_at = datetime.now(timezone.utc)
from gatehouse_app import db
db.session.commit()
# Set context variables
g.current_user = session.user
g.current_session = session
user = session.user
token_groups: list = []
try:
if session.device_info:
# device_info may carry OIDC claims stored at login time
claims = session.device_info
# Normalise: Gatehouse stores roles as [{"organization_id":…,"role":…}]
roles_claim = claims.get("roles", [])
if isinstance(roles_claim, list):
for entry in roles_claim:
if isinstance(entry, dict):
role_val = entry.get("role")
if role_val:
token_groups.append(str(role_val))
elif isinstance(entry, str):
token_groups.append(entry)
# Standard OIDC groups claim
groups_claim = claims.get("groups", [])
if isinstance(groups_claim, list):
token_groups.extend(str(g_) for g_ in groups_claim if g_)
except Exception:
pass # Never block auth over token_groups enrichment failure
user.token_groups = token_groups
# Activation check: if the user has an `activated` attribute and it is
# explicitly False, block access. New accounts without the attribute are
# treated as active to avoid breaking existing sessions.
activated = getattr(user, "activated", None)
if activated is False:
return api_response(
success=False,
message="Account not yet activated. Please check your email for an activation link.",
status=403,
error_type="ACCOUNT_NOT_ACTIVATED",
)
return f(*args, **kwargs)
return decorated_function
@@ -97,11 +133,12 @@ def require_role(*allowed_roles):
raise ForbiddenError("Organization context required")
# Check user's role in the organization
from gatehouse_app.models.organization_member import OrganizationMember
from gatehouse_app.models.organization.organization_member import OrganizationMember
membership = OrganizationMember.query.filter_by(
user_id=g.current_user.id,
organization_id=org_id,
deleted_at=None,
).first()
if not membership:
+74
View File
@@ -111,5 +111,79 @@ def mfa_compliance_status():
print("=" * 60)
@cli.command("configure_oauth")
def configure_oauth():
"""Interactively configure an OAuth provider at the application level.
Usage:
python manage.py configure_oauth
Supported providers: google, github, microsoft
"""
import getpass
from gatehouse_app.models.authentication_method import ApplicationProviderConfig
from gatehouse_app.extensions import db
SUPPORTED = ["google", "github", "microsoft"]
print("=" * 60)
print("OAuth Provider Configuration")
print("=" * 60)
print(f"Supported providers: {', '.join(SUPPORTED)}")
provider = input("Provider [google/github/microsoft]: ").strip().lower()
if provider not in SUPPORTED:
print(f"❌ Unknown provider: {provider}")
return
client_id = input("Client ID: ").strip()
if not client_id:
print("❌ client_id is required")
return
client_secret = getpass.getpass("Client Secret (leave blank to keep existing): ").strip()
with app.app_context():
config = ApplicationProviderConfig.query.filter_by(provider_type=provider).first()
if config:
config.client_id = client_id
if client_secret:
config.set_client_secret(client_secret)
config.is_enabled = True
db.session.commit()
print(f"✅ Updated {provider} provider config.")
else:
config = ApplicationProviderConfig(
provider_type=provider,
client_id=client_id,
is_enabled=True,
)
if client_secret:
config.set_client_secret(client_secret)
db.session.add(config)
db.session.commit()
print(f"✅ Created {provider} provider config.")
@cli.command("list_oauth")
def list_oauth():
"""List all configured OAuth providers.
Usage:
python manage.py list_oauth
"""
from gatehouse_app.models.authentication_method import ApplicationProviderConfig
with app.app_context():
configs = ApplicationProviderConfig.query.all()
if not configs:
print("No OAuth providers configured.")
return
print(f"{'Provider':<15} {'Client ID':<40} {'Enabled'}")
print("-" * 65)
for c in configs:
print(f"{c.provider_type:<15} {c.client_id:<40} {c.is_enabled}")
if __name__ == "__main__":
cli()
@@ -0,0 +1,127 @@
"""Add Department and Principal models for SSH CA management.
Revision ID: 006
Revises: 005
Create Date: 2026-02-27 10:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '006'
down_revision = '005'
branch_labels = None
depends_on = None
def upgrade():
# ### Department table ###
op.create_table('departments',
sa.Column('organization_id', sa.String(length=36), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('organization_id', 'name', name='uix_org_dept_name')
)
op.create_index(op.f('ix_departments_organization_id'), 'departments', ['organization_id'], unique=False)
op.create_index(op.f('ix_departments_name'), 'departments', ['name'], unique=False)
# ### DepartmentMembership table ###
op.create_table('department_memberships',
sa.Column('user_id', sa.String(length=36), nullable=False),
sa.Column('department_id', sa.String(length=36), nullable=False),
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['department_id'], ['departments.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('user_id', 'department_id', name='uix_user_dept')
)
op.create_index(op.f('ix_department_memberships_user_id'), 'department_memberships', ['user_id'], unique=False)
op.create_index(op.f('ix_department_memberships_department_id'), 'department_memberships', ['department_id'], unique=False)
# ### Principal table ###
op.create_table('principals',
sa.Column('organization_id', sa.String(length=36), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('organization_id', 'name', name='uix_org_principal_name')
)
op.create_index(op.f('ix_principals_organization_id'), 'principals', ['organization_id'], unique=False)
op.create_index(op.f('ix_principals_name'), 'principals', ['name'], unique=False)
# ### PrincipalMembership table ###
op.create_table('principal_memberships',
sa.Column('user_id', sa.String(length=36), nullable=False),
sa.Column('principal_id', sa.String(length=36), nullable=False),
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['principal_id'], ['principals.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('user_id', 'principal_id', name='uix_user_principal')
)
op.create_index(op.f('ix_principal_memberships_user_id'), 'principal_memberships', ['user_id'], unique=False)
op.create_index(op.f('ix_principal_memberships_principal_id'), 'principal_memberships', ['principal_id'], unique=False)
# ### DepartmentPrincipal table ###
op.create_table('department_principals',
sa.Column('department_id', sa.String(length=36), nullable=False),
sa.Column('principal_id', sa.String(length=36), nullable=False),
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['department_id'], ['departments.id'], ),
sa.ForeignKeyConstraint(['principal_id'], ['principals.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('department_id', 'principal_id', name='uix_dept_principal')
)
op.create_index(op.f('ix_department_principals_department_id'), 'department_principals', ['department_id'], unique=False)
op.create_index(op.f('ix_department_principals_principal_id'), 'department_principals', ['principal_id'], unique=False)
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_department_principals_principal_id'), table_name='department_principals')
op.drop_index(op.f('ix_department_principals_department_id'), table_name='department_principals')
op.drop_table('department_principals')
op.drop_index(op.f('ix_principal_memberships_principal_id'), table_name='principal_memberships')
op.drop_index(op.f('ix_principal_memberships_user_id'), table_name='principal_memberships')
op.drop_table('principal_memberships')
op.drop_index(op.f('ix_principals_name'), table_name='principals')
op.drop_index(op.f('ix_principals_organization_id'), table_name='principals')
op.drop_table('principals')
op.drop_index(op.f('ix_department_memberships_department_id'), table_name='department_memberships')
op.drop_index(op.f('ix_department_memberships_user_id'), table_name='department_memberships')
op.drop_table('department_memberships')
op.drop_index(op.f('ix_departments_name'), table_name='departments')
op.drop_index(op.f('ix_departments_organization_id'), table_name='departments')
op.drop_table('departments')
# ### end Alembic commands ###
@@ -0,0 +1,173 @@
"""Add SSH CA models: SSHKey, SSHCertificate, CA, CertificateAuditLog.
Revision ID: 007
Revises: 006
Create Date: 2026-02-27 11:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '007'
down_revision = '006'
branch_labels = None
depends_on = None
def upgrade():
# ### CA table ###
op.create_table('cas',
sa.Column('organization_id', sa.String(length=36), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('key_type', sa.Enum('ed25519', 'rsa', 'ecdsa', name='ca_key_type_enum'), nullable=False),
sa.Column('private_key', sa.Text(), nullable=False),
sa.Column('public_key', sa.Text(), nullable=False),
sa.Column('fingerprint', sa.String(length=255), nullable=False),
sa.Column('crl_enabled', sa.Boolean(), nullable=False),
sa.Column('crl_endpoint', sa.String(length=512), nullable=True),
sa.Column('default_cert_validity_hours', sa.Integer(), nullable=False),
sa.Column('max_cert_validity_hours', sa.Integer(), nullable=False),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('rotated_at', sa.DateTime(), nullable=True),
sa.Column('rotation_reason', sa.String(length=255), nullable=True),
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('fingerprint'),
sa.UniqueConstraint('organization_id', 'name', name='uix_org_ca_name')
)
op.create_index(op.f('ix_cas_organization_id'), 'cas', ['organization_id'], unique=False)
op.create_index('idx_ca_org_active', 'cas', ['organization_id', 'is_active'], unique=False)
# ### SSHKey table ###
op.create_table('ssh_keys',
sa.Column('user_id', sa.String(length=36), nullable=False),
sa.Column('payload', sa.Text(), nullable=False),
sa.Column('fingerprint', sa.String(length=255), nullable=False),
sa.Column('description', sa.String(length=255), nullable=True),
sa.Column('verified', sa.Boolean(), nullable=False),
sa.Column('verified_at', sa.DateTime(), nullable=True),
sa.Column('verify_text', sa.String(length=255), nullable=True),
sa.Column('verify_text_created_at', sa.DateTime(), nullable=True),
sa.Column('key_type', sa.String(length=50), nullable=True),
sa.Column('key_bits', sa.Integer(), nullable=True),
sa.Column('key_comment', sa.String(length=255), nullable=True),
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('payload'),
sa.UniqueConstraint('fingerprint')
)
op.create_index(op.f('ix_ssh_keys_user_id'), 'ssh_keys', ['user_id'], unique=False)
op.create_index(op.f('ix_ssh_keys_fingerprint'), 'ssh_keys', ['fingerprint'], unique=False)
op.create_index(op.f('ix_ssh_keys_verified'), 'ssh_keys', ['verified'], unique=False)
op.create_index('idx_ssh_key_user_verified', 'ssh_keys', ['user_id', 'verified'], unique=False)
# ### SSHCertificate table ###
op.create_table('ssh_certificates',
sa.Column('ca_id', sa.String(length=36), nullable=False),
sa.Column('user_id', sa.String(length=36), nullable=False),
sa.Column('ssh_key_id', sa.String(length=36), nullable=False),
sa.Column('certificate', sa.Text(), nullable=False),
sa.Column('serial', sa.String(length=255), nullable=False),
sa.Column('key_id', sa.String(length=255), nullable=False),
sa.Column('cert_type', sa.Enum('user', 'host', name='ssh_cert_type_enum'), nullable=False),
sa.Column('principals', sa.JSON(), nullable=False),
sa.Column('valid_after', sa.DateTime(), nullable=False),
sa.Column('valid_before', sa.DateTime(), nullable=False),
sa.Column('revoked', sa.Boolean(), nullable=False),
sa.Column('revoked_at', sa.DateTime(), nullable=True),
sa.Column('revoke_reason', sa.String(length=255), nullable=True),
sa.Column('status', sa.Enum('requested', 'issued', 'revoked', 'expired', 'superseded', name='ssh_cert_status_enum'), nullable=False),
sa.Column('request_ip', sa.String(length=45), nullable=True),
sa.Column('request_user_agent', sa.String(length=512), nullable=True),
sa.Column('critical_options', sa.JSON(), nullable=True),
sa.Column('extensions', sa.JSON(), nullable=True),
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['ca_id'], ['cas.id'], ),
sa.ForeignKeyConstraint(['ssh_key_id'], ['ssh_keys.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('serial')
)
op.create_index(op.f('ix_ssh_certificates_ca_id'), 'ssh_certificates', ['ca_id'], unique=False)
op.create_index(op.f('ix_ssh_certificates_user_id'), 'ssh_certificates', ['user_id'], unique=False)
op.create_index(op.f('ix_ssh_certificates_ssh_key_id'), 'ssh_certificates', ['ssh_key_id'], unique=False)
op.create_index(op.f('ix_ssh_certificates_serial'), 'ssh_certificates', ['serial'], unique=False)
op.create_index(op.f('ix_ssh_certificates_revoked'), 'ssh_certificates', ['revoked'], unique=False)
op.create_index(op.f('ix_ssh_certificates_status'), 'ssh_certificates', ['status'], unique=False)
op.create_index('idx_cert_user_status', 'ssh_certificates', ['user_id', 'status'], unique=False)
op.create_index('idx_cert_validity', 'ssh_certificates', ['valid_after', 'valid_before'], unique=False)
op.create_index('idx_cert_revoked', 'ssh_certificates', ['revoked', 'revoked_at'], unique=False)
# ### CertificateAuditLog table ###
op.create_table('certificate_audit_logs',
sa.Column('certificate_id', sa.String(length=36), nullable=False),
sa.Column('user_id', sa.String(length=36), nullable=True),
sa.Column('action', sa.String(length=50), nullable=False),
sa.Column('ip_address', sa.String(length=45), nullable=True),
sa.Column('user_agent', sa.String(length=512), nullable=True),
sa.Column('request_id', sa.String(length=36), nullable=True),
sa.Column('message', sa.Text(), nullable=True),
sa.Column('extra_data', sa.JSON(), nullable=True),
sa.Column('success', sa.Boolean(), nullable=False),
sa.Column('error_message', sa.Text(), nullable=True),
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['certificate_id'], ['ssh_certificates.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id')
)
op.create_index(op.f('ix_certificate_audit_logs_certificate_id'), 'certificate_audit_logs', ['certificate_id'], unique=False)
op.create_index(op.f('ix_certificate_audit_logs_user_id'), 'certificate_audit_logs', ['user_id'], unique=False)
op.create_index(op.f('ix_certificate_audit_logs_action'), 'certificate_audit_logs', ['action'], unique=False)
op.create_index('idx_cert_audit_cert_action', 'certificate_audit_logs', ['certificate_id', 'action'], unique=False)
op.create_index('idx_cert_audit_user', 'certificate_audit_logs', ['user_id', 'created_at'], unique=False)
def downgrade():
op.drop_index('idx_cert_audit_user', table_name='certificate_audit_logs')
op.drop_index('idx_cert_audit_cert_action', table_name='certificate_audit_logs')
op.drop_index(op.f('ix_certificate_audit_logs_action'), table_name='certificate_audit_logs')
op.drop_index(op.f('ix_certificate_audit_logs_user_id'), table_name='certificate_audit_logs')
op.drop_index(op.f('ix_certificate_audit_logs_certificate_id'), table_name='certificate_audit_logs')
op.drop_table('certificate_audit_logs')
op.drop_index('idx_cert_revoked', table_name='ssh_certificates')
op.drop_index('idx_cert_validity', table_name='ssh_certificates')
op.drop_index('idx_cert_user_status', table_name='ssh_certificates')
op.drop_index(op.f('ix_ssh_certificates_status'), table_name='ssh_certificates')
op.drop_index(op.f('ix_ssh_certificates_revoked'), table_name='ssh_certificates')
op.drop_index(op.f('ix_ssh_certificates_serial'), table_name='ssh_certificates')
op.drop_index(op.f('ix_ssh_certificates_ssh_key_id'), table_name='ssh_certificates')
op.drop_index(op.f('ix_ssh_certificates_user_id'), table_name='ssh_certificates')
op.drop_index(op.f('ix_ssh_certificates_ca_id'), table_name='ssh_certificates')
op.drop_table('ssh_certificates')
op.drop_index('idx_ssh_key_user_verified', table_name='ssh_keys')
op.drop_index(op.f('ix_ssh_keys_verified'), table_name='ssh_keys')
op.drop_index(op.f('ix_ssh_keys_fingerprint'), table_name='ssh_keys')
op.drop_index(op.f('ix_ssh_keys_user_id'), table_name='ssh_keys')
op.drop_table('ssh_keys')
op.drop_index('idx_ca_org_active', table_name='cas')
op.drop_index(op.f('ix_cas_organization_id'), table_name='cas')
op.drop_table('cas')
@@ -0,0 +1,53 @@
"""Add TOTP and WEBAUTHN to authmethodtype enum.
Revision ID: 008
Revises: 007
Create Date: 2026-02-27 15:00:00.000000
The original migration (001_base) created authmethodtype with only:
PASSWORD, GOOGLE, GITHUB, MICROSOFT, SAML, OIDC
This migration adds the missing TOTP and WEBAUTHN values so
has_totp_enabled() and has_webauthn_enabled() queries work correctly.
"""
from alembic import op
import sqlalchemy as sa
revision = '008'
down_revision = '007'
branch_labels = None
depends_on = None
def upgrade():
# Add TOTP to the enum (idempotent approach using DO block)
op.execute("""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM pg_enum
WHERE enumlabel = 'TOTP'
AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'authmethodtype')
) THEN
ALTER TYPE authmethodtype ADD VALUE 'TOTP';
END IF;
END$$;
""")
op.execute("""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM pg_enum
WHERE enumlabel = 'WEBAUTHN'
AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'authmethodtype')
) THEN
ALTER TYPE authmethodtype ADD VALUE 'WEBAUTHN';
END IF;
END$$;
""")
def downgrade():
# PostgreSQL does not support removing enum values; downgrade is a no-op.
pass
@@ -0,0 +1,61 @@
"""Sync auditaction enum with all AuditAction Python enum values.
Revision ID: 009
Revises: 008
Create Date: 2026-02-27 15:20:00.000000
The auditaction DB enum was only created with the initial 17 values from 001_base.py.
All TOTP, WebAuthn, OAuth, SSH, CA, Principal, and Department audit actions were added
to the Python enum but never synced to the DB type.
"""
from alembic import op
revision = '009'
down_revision = '008'
branch_labels = None
depends_on = None
MISSING_VALUES = [
'TOTP_ENROLL_INITIATED', 'TOTP_ENROLL_COMPLETED', 'TOTP_VERIFY_SUCCESS',
'TOTP_VERIFY_FAILED', 'TOTP_DISABLED', 'TOTP_BACKUP_CODE_USED',
'TOTP_BACKUP_CODES_REGENERATED', 'WEBAUTHN_REGISTER_INITIATED',
'WEBAUTHN_REGISTER_COMPLETED', 'WEBAUTHN_REGISTER_FAILED',
'WEBAUTHN_LOGIN_INITIATED', 'WEBAUTHN_LOGIN_SUCCESS', 'WEBAUTHN_LOGIN_FAILED',
'WEBAUTHN_CREDENTIAL_DELETED', 'WEBAUTHN_CREDENTIAL_RENAMED',
'ORG_SECURITY_POLICY_UPDATE', 'USER_SECURITY_POLICY_OVERRIDE_UPDATE',
'MFA_POLICY_USER_SUSPENDED', 'MFA_POLICY_USER_COMPLIANT',
'EXTERNAL_AUTH_LINK_INITIATED', 'EXTERNAL_AUTH_LINK_COMPLETED',
'EXTERNAL_AUTH_LINK_FAILED', 'EXTERNAL_AUTH_UNLINK', 'EXTERNAL_AUTH_LOGIN',
'EXTERNAL_AUTH_LOGIN_FAILED', 'EXTERNAL_AUTH_TOKEN_REFRESH',
'EXTERNAL_AUTH_CONFIG_CREATE', 'EXTERNAL_AUTH_CONFIG_UPDATE',
'EXTERNAL_AUTH_CONFIG_DELETE', 'SSH_KEY_ADDED', 'SSH_KEY_VERIFIED',
'SSH_KEY_DELETED', 'SSH_KEY_VALIDATION_FAILED', 'SSH_CERT_REQUESTED',
'SSH_CERT_ISSUED', 'SSH_CERT_FAILED', 'SSH_CERT_REVOKED', 'SSH_CERT_EXPIRED',
'CA_CREATED', 'CA_UPDATED', 'CA_DELETED', 'CA_KEY_ROTATED',
'PRINCIPAL_CREATED', 'PRINCIPAL_UPDATED', 'PRINCIPAL_DELETED',
'PRINCIPAL_MEMBER_ADDED', 'PRINCIPAL_MEMBER_REMOVED',
'DEPARTMENT_CREATED', 'DEPARTMENT_UPDATED', 'DEPARTMENT_DELETED',
'DEPARTMENT_MEMBER_ADDED', 'DEPARTMENT_MEMBER_REMOVED',
]
def upgrade():
for val in MISSING_VALUES:
op.execute(f"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM pg_enum
WHERE enumlabel = '{val}'
AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'auditaction')
) THEN
ALTER TYPE auditaction ADD VALUE '{val}';
END IF;
END$$;
""")
def downgrade():
# PostgreSQL does not support removing enum values; downgrade is a no-op.
pass
@@ -0,0 +1,50 @@
"""add password reset and email verification token tables
Revision ID: 010_password_reset_email_verify
Revises: 009_sync_auditaction_enum
Create Date: 2025-01-01 00:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '010_password_reset_email_verify'
down_revision = '009'
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
'password_reset_tokens',
sa.Column('id', sa.String(36), primary_key=True, nullable=False),
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False),
sa.Column('token', sa.String(128), nullable=False, unique=True),
sa.Column('expires_at', sa.DateTime, nullable=False),
sa.Column('used_at', sa.DateTime, nullable=True),
sa.Column('created_at', sa.DateTime, nullable=False),
sa.Column('updated_at', sa.DateTime, nullable=False),
sa.Column('deleted_at', sa.DateTime, nullable=True),
)
op.create_index('ix_password_reset_tokens_user_id', 'password_reset_tokens', ['user_id'])
op.create_index('ix_password_reset_tokens_token', 'password_reset_tokens', ['token'])
op.create_table(
'email_verification_tokens',
sa.Column('id', sa.String(36), primary_key=True, nullable=False),
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False),
sa.Column('token', sa.String(128), nullable=False, unique=True),
sa.Column('expires_at', sa.DateTime, nullable=False),
sa.Column('used_at', sa.DateTime, nullable=True),
sa.Column('created_at', sa.DateTime, nullable=False),
sa.Column('updated_at', sa.DateTime, nullable=False),
sa.Column('deleted_at', sa.DateTime, nullable=True),
)
op.create_index('ix_email_verification_tokens_user_id', 'email_verification_tokens', ['user_id'])
op.create_index('ix_email_verification_tokens_token', 'email_verification_tokens', ['token'])
def downgrade():
op.drop_table('email_verification_tokens')
op.drop_table('password_reset_tokens')
@@ -0,0 +1,38 @@
"""add org_invite_tokens table
Revision ID: 011_org_invite_tokens
Revises: 010_password_reset_email_verify
Create Date: 2025-01-01 00:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = '011_org_invite_tokens'
down_revision = '010_password_reset_email_verify'
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
'org_invite_tokens',
sa.Column('id', sa.String(36), primary_key=True, nullable=False),
sa.Column('organization_id', sa.String(36), sa.ForeignKey('organizations.id', ondelete='CASCADE'), nullable=False),
sa.Column('invited_by_id', sa.String(36), sa.ForeignKey('users.id', ondelete='SET NULL'), nullable=True),
sa.Column('email', sa.String(255), nullable=False),
sa.Column('role', sa.String(64), nullable=False, server_default='member'),
sa.Column('token', sa.String(128), nullable=False, unique=True),
sa.Column('expires_at', sa.DateTime, nullable=False),
sa.Column('accepted_at', sa.DateTime, nullable=True),
sa.Column('created_at', sa.DateTime, nullable=False),
sa.Column('updated_at', sa.DateTime, nullable=False),
sa.Column('deleted_at', sa.DateTime, nullable=True),
)
op.create_index('ix_org_invite_tokens_organization_id', 'org_invite_tokens', ['organization_id'])
op.create_index('ix_org_invite_tokens_email', 'org_invite_tokens', ['email'])
op.create_index('ix_org_invite_tokens_token', 'org_invite_tokens', ['token'])
def downgrade():
op.drop_table('org_invite_tokens')
@@ -0,0 +1,33 @@
"""Make CA.organization_id nullable (system CA) and add cert_id to sign response
Revision ID: 012_ca_nullable_org_and_cert_serial
Revises: 011_org_invite_tokens
Create Date: 2025-01-01 00:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = '012_ca_nullable_org'
down_revision = '011_org_invite_tokens'
branch_labels = None
depends_on = None
def upgrade():
# Allow CA records without an org (e.g. the global system-config CA)
with op.batch_alter_table('cas', schema=None) as batch_op:
batch_op.alter_column(
'organization_id',
existing_type=sa.String(36),
nullable=True,
)
def downgrade():
with op.batch_alter_table('cas', schema=None) as batch_op:
batch_op.alter_column(
'organization_id',
existing_type=sa.String(36),
nullable=False,
)
+42
View File
@@ -0,0 +1,42 @@
"""Add ca_type column to cas table (user/host).
Revision ID: 013
Revises: d34bfb72844e
Create Date: 2026-02-28 23:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '013'
down_revision = 'd34bfb72844e'
branch_labels = None
depends_on = None
def upgrade():
# Create the enum type first (PostgreSQL requires this)
ca_type_enum = sa.Enum('user', 'host', name='ca_type_enum')
ca_type_enum.create(op.get_bind(), checkfirst=True)
# Add ca_type column with a default of 'user' so existing CAs stay valid
op.add_column(
'cas',
sa.Column(
'ca_type',
ca_type_enum,
nullable=False,
server_default='user',
),
)
def downgrade():
op.drop_column('cas', 'ca_type')
# Drop the enum type (PostgreSQL only; SQLite ignores)
try:
op.execute("DROP TYPE IF EXISTS ca_type_enum")
except Exception:
pass
@@ -0,0 +1,44 @@
"""add_department_cert_policies
Adds the department_cert_policies table which stores per-department
SSH certificate issuance rules:
- whether users may choose their own expiry
- default and maximum expiry durations
- allowed SSH certificate extensions
"""
from alembic import op
import sqlalchemy as sa
revision = "014_add_dept_cert_policy"
down_revision = "013"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"department_cert_policies",
sa.Column("id", sa.String(36), primary_key=True),
sa.Column("department_id", sa.String(36), sa.ForeignKey("departments.id"), nullable=False, unique=True),
# Whether users are allowed to specify their own expiry (up to max)
sa.Column("allow_user_expiry", sa.Boolean(), nullable=False, server_default="0"),
# Default validity in hours (used when user doesn't specify, or not allowed to)
sa.Column("default_expiry_hours", sa.Integer(), nullable=False, server_default="1"),
# Hard cap on validity; admin cannot be exceeded
sa.Column("max_expiry_hours", sa.Integer(), nullable=False, server_default="24"),
# JSON list of extension names that are enabled for this department
# e.g. ["permit-pty", "permit-agent-forwarding"]
sa.Column("allowed_extensions", sa.JSON(), nullable=False, server_default='["permit-pty","permit-agent-forwarding","permit-X11-forwarding","permit-port-forwarding","permit-user-rc"]'),
# Admin-defined custom extension names beyond the standard five
sa.Column("custom_extensions", sa.JSON(), nullable=False, server_default="[]"),
sa.Column("created_at", sa.DateTime(), nullable=True),
sa.Column("updated_at", sa.DateTime(), nullable=True),
sa.Column("deleted_at", sa.DateTime(), nullable=True),
)
op.create_index("idx_dept_cert_policy_dept", "department_cert_policies", ["department_id"])
def downgrade():
op.drop_index("idx_dept_cert_policy_dept", "department_cert_policies")
op.drop_table("department_cert_policies")
@@ -0,0 +1,37 @@
"""Add USER_SUSPEND and USER_UNSUSPEND to auditaction enum.
Revision ID: 015_add_user_suspend_audit_actions
Revises: 014_add_dept_cert_policy
Create Date: 2026-03-02
USER_SUSPEND and USER_UNSUSPEND were added to the Python AuditAction enum
but were never synced to the PostgreSQL auditaction type, causing a
DataError (invalid enum value) whenever an admin suspends or unsuspends a user.
"""
from alembic import op
revision = "015_user_suspend_audit"
down_revision = "014_add_dept_cert_policy"
branch_labels = None
depends_on = None
def upgrade():
for val in ("USER_SUSPEND", "USER_UNSUSPEND"):
op.execute(f"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM pg_enum
WHERE enumlabel = '{val}'
AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'auditaction')
) THEN
ALTER TYPE auditaction ADD VALUE '{val}';
END IF;
END$$;
""")
def downgrade():
# PostgreSQL does not support removing enum values; downgrade is a no-op.
pass
@@ -0,0 +1,168 @@
"""Encrypt existing plaintext CA private keys at rest.
Revision ID: 016_encrypt_existing_ca_keys
Revises: 015_add_user_suspend_audit_actions
Create Date: 2026-03-02
All CA private keys created before this migration were stored as plaintext PEM
strings in the ``cas.private_key`` column. This migration detects those rows
(by checking for the absence of the ``$fernet$`` prefix that encrypted values
carry) and re-encrypts them with the key derived from ``CA_ENCRYPTION_KEY``.
The migration is safe to re-run: already-encrypted rows are left untouched.
Prerequisites
-------------
``CA_ENCRYPTION_KEY`` must be set in the environment before running this
migration. The same value must be configured for the running application.
To roll back to plaintext (downgrade):
The ``downgrade()`` function decrypts all rows back to plaintext PEM. This is
provided only for emergency rollback and should not be used in production once
the system has been running with encrypted keys.
"""
import os
import base64
import hashlib
import logging
from alembic import op
import sqlalchemy as sa
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
# Alembic revision identifiers
revision = "016_encrypt_ca_keys"
down_revision = "015_user_suspend_audit"
branch_labels = None
depends_on = None
_FERNET_PREFIX = "$fernet$"
def _get_fernet():
"""Build a Fernet instance from CA_ENCRYPTION_KEY env var."""
from cryptography.fernet import Fernet
raw_key = os.environ.get("CA_ENCRYPTION_KEY")
if not raw_key:
raise RuntimeError(
"CA_ENCRYPTION_KEY environment variable is not set. "
"Set it before running this migration."
)
key_bytes = base64.urlsafe_b64encode(hashlib.sha256(raw_key.encode()).digest())
return Fernet(key_bytes)
def upgrade():
"""Encrypt plaintext CA private keys."""
bind = op.get_bind()
session = Session(bind=bind)
try:
fernet = _get_fernet()
except RuntimeError as exc:
raise RuntimeError(str(exc)) from exc
# Fetch all non-deleted CA rows
rows = session.execute(
sa.text("SELECT id, private_key FROM cas WHERE deleted_at IS NULL")
).fetchall()
encrypted_count = 0
skipped_count = 0
for row in rows:
ca_id, private_key = row[0], row[1]
if not private_key:
logger.warning(f"CA {ca_id} has empty private_key — skipping")
skipped_count += 1
continue
if private_key.startswith(_FERNET_PREFIX):
# Already encrypted
skipped_count += 1
continue
# Encrypt
try:
token = fernet.encrypt(private_key.encode()).decode()
encrypted_value = f"{_FERNET_PREFIX}{token}"
session.execute(
sa.text("UPDATE cas SET private_key = :pk WHERE id = :id"),
{"pk": encrypted_value, "id": ca_id},
)
encrypted_count += 1
logger.info(f"Encrypted private key for CA {ca_id}")
except Exception as exc:
session.rollback()
raise RuntimeError(
f"Failed to encrypt private key for CA {ca_id}: {exc}"
) from exc
session.commit()
logger.info(
f"CA key encryption migration complete: "
f"{encrypted_count} encrypted, {skipped_count} skipped"
)
print(
f" [016_encrypt_ca_keys] {encrypted_count} CA private key(s) encrypted, "
f"{skipped_count} already encrypted or empty."
)
def downgrade():
"""Decrypt CA private keys back to plaintext (emergency rollback only)."""
bind = op.get_bind()
session = Session(bind=bind)
try:
fernet = _get_fernet()
except RuntimeError as exc:
raise RuntimeError(str(exc)) from exc
rows = session.execute(
sa.text("SELECT id, private_key FROM cas WHERE deleted_at IS NULL")
).fetchall()
decrypted_count = 0
skipped_count = 0
for row in rows:
ca_id, private_key = row[0], row[1]
if not private_key or not private_key.startswith(_FERNET_PREFIX):
skipped_count += 1
continue
token = private_key[len(_FERNET_PREFIX):]
try:
from cryptography.fernet import InvalidToken
try:
plaintext = fernet.decrypt(token.encode()).decode()
except InvalidToken as exc:
raise RuntimeError(
f"Downgrade failed: cannot decrypt CA {ca_id} — wrong key or corrupted data."
) from exc
session.execute(
sa.text("UPDATE cas SET private_key = :pk WHERE id = :id"),
{"pk": plaintext, "id": ca_id},
)
decrypted_count += 1
logger.warning(f"Decrypted (plaintext restore) private key for CA {ca_id}")
except RuntimeError:
session.rollback()
raise
session.commit()
logger.warning(
f"CA key decryption (downgrade) complete: "
f"{decrypted_count} decrypted, {skipped_count} skipped"
)
print(
f" [016_encrypt_ca_keys] DOWNGRADE: {decrypted_count} CA private key(s) "
f"decrypted to plaintext. WARNING: keys are now unencrypted at rest."
)
@@ -0,0 +1,37 @@
"""Add monotonic serial counter to CAs table.
Each CA now owns a `next_serial_number` (BigInteger) that is atomically
incremented every time a certificate is signed. This guarantees:
- Serials are unique per CA
- Serials are monotonically increasing (auditable, no gaps by accident)
- The value embedded in the OpenSSH certificate matches what is stored
in the `ssh_certificates.serial` column
Revision ID: 017_add_ca_serial_counter
Revises: 016_encrypt_ca_keys
Create Date: 2026-03-02
"""
from alembic import op
import sqlalchemy as sa
revision = "017_add_ca_serial_counter"
down_revision = "016_encrypt_ca_keys"
branch_labels = None
depends_on = None
def upgrade():
with op.batch_alter_table("cas", schema=None) as batch_op:
batch_op.add_column(
sa.Column(
"next_serial_number",
sa.BigInteger(),
nullable=False,
server_default="1",
)
)
def downgrade():
with op.batch_alter_table("cas", schema=None) as batch_op:
batch_op.drop_column("next_serial_number")
@@ -0,0 +1,52 @@
"""Add ORG_OWNERSHIP_TRANSFERRED and USER_HARD_DELETE to auditaction enum.
Revision ID: 018_audit_enum_values
Revises: 017_add_ca_serial_counter
Create Date: 2026-03-02
ORG_OWNERSHIP_TRANSFERRED and USER_HARD_DELETE were added to the Python
AuditAction enum but were never synced to the PostgreSQL auditaction type,
causing a DataError (invalid enum value) when transferring org ownership
or hard-deleting a user.
"""
from alembic import op
revision = "018_audit_enum_values"
down_revision = "017_add_ca_serial_counter"
branch_labels = None
depends_on = None
def upgrade():
# ALTER TYPE ... ADD VALUE cannot run inside a transaction block in PostgreSQL.
# Alembic has already opened a transaction on the connection by the time our
# upgrade() runs, so we must:
# 1. Roll back that open transaction on the raw psycopg2 connection.
# 2. Switch to autocommit so the ALTER TYPE runs outside any transaction.
# 3. Restore the previous state afterwards.
conn = op.get_bind()
# SQLAlchemy 2.x: conn.connection is a _ConnectionFairy; .driver_connection is psycopg2
fairy = conn.connection
raw = getattr(fairy, "driver_connection", None) or getattr(fairy, "dbapi_connection", fairy)
# Roll back the open transaction so psycopg2 allows us to change autocommit.
raw.rollback()
old_autocommit = raw.autocommit
raw.autocommit = True
try:
with raw.cursor() as cur:
for val in ("ORG_OWNERSHIP_TRANSFERRED", "USER_HARD_DELETE"):
cur.execute(
"SELECT 1 FROM pg_enum "
"WHERE enumlabel = %s "
"AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'auditaction')",
(val,),
)
if not cur.fetchone():
cur.execute(f"ALTER TYPE auditaction ADD VALUE '{val}'")
finally:
raw.autocommit = old_autocommit
def downgrade():
# PostgreSQL does not support removing enum values; downgrade is a no-op.
pass
@@ -0,0 +1,50 @@
"""add_activation_fields_and_ca_permissions
Revision ID: d34bfb72844e
Revises: 012_ca_nullable_org
Create Date: 2026-02-28 18:06:47.328552
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'd34bfb72844e'
down_revision = '012_ca_nullable_org'
branch_labels = None
depends_on = None
def upgrade():
# Create ca_permissions table
op.create_table(
'ca_permissions',
sa.Column('ca_id', sa.String(length=36), nullable=False),
sa.Column('user_id', sa.String(length=36), nullable=False),
sa.Column('permission', sa.String(length=50), nullable=False),
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['ca_id'], ['cas.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('ca_id', 'user_id', name='uix_ca_permission'),
)
op.create_index('ix_ca_permissions_ca_id', 'ca_permissions', ['ca_id'], unique=False)
op.create_index('ix_ca_permissions_user_id', 'ca_permissions', ['user_id'], unique=False)
# Add activation columns to users
op.add_column('users', sa.Column('activated', sa.Boolean(), nullable=False,
server_default=sa.text('true')))
op.add_column('users', sa.Column('activation_key', sa.String(length=128), nullable=True))
op.create_index('ix_users_activation_key', 'users', ['activation_key'], unique=True)
def downgrade():
op.drop_index('ix_users_activation_key', table_name='users')
op.drop_column('users', 'activation_key')
op.drop_column('users', 'activated')
op.drop_index('ix_ca_permissions_user_id', table_name='ca_permissions')
op.drop_index('ix_ca_permissions_ca_id', table_name='ca_permissions')
op.drop_table('ca_permissions')
+6 -3
View File
@@ -14,7 +14,7 @@ Flask-Marshmallow==0.15.0
marshmallow-sqlalchemy==0.29.0
# Security
bcrypt==4.1.2
bcrypt==4.2.0
Flask-Bcrypt==1.0.1
pyotp==2.9.0
@@ -24,7 +24,7 @@ cbor2==5.6.0
# JWT / OIDC
PyJWT==2.8.0
cryptography==41.0.7
cryptography==42.0.7
# CORS
Flask-CORS==4.0.0
@@ -47,4 +47,7 @@ Flask-Limiter==3.5.0
# Logging
python-json-logger==2.0.7
qrcode[pil]
qrcode[pil]
# SSH CA Certificate signing
sshkey-tools==0.11.3
+22
View File
@@ -20,3 +20,25 @@ watchdog==3.0.0
# Documentation
sphinx==7.2.6
# Web framework & Database
Flask==3.0.0
Flask-SQLAlchemy==3.1.1
Flask-Migrate==4.0.5
sqlalchemy-cockroachdb==2.0.3
# Utilities
colorlog==6.8.0
coloredlogs==15.0.1
prettytable==3.10.2
tabulate==0.9.0
requests==2.31.0
pytz==2023.3
python-dotenv==1.0.0
pydantic==2.5.0
PyJWT==2.8.0
cryptography==42.0.7
pycryptodome==3.20.0
psycopg2-binary==2.9.9
sshkey-tools==0.11.3
sendgrid==6.11.0
+5 -5
View File
@@ -15,11 +15,11 @@ load_dotenv()
from gatehouse_app import create_app
from gatehouse_app.extensions import db
from gatehouse_app.models.user import User
from gatehouse_app.models.organization import Organization
from gatehouse_app.models.organization_member import OrganizationMember
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.models.oidc_client import OIDCClient
from gatehouse_app.models.user.user import User
from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
from gatehouse_app.models.oidc.oidc_client import OIDCClient
from gatehouse_app.services.auth_service import AuthService
from gatehouse_app.services.organization_service import OrganizationService
from gatehouse_app.utils.constants import OrganizationRole, UserStatus, AuthMethodType