diff --git a/client/gatehouse-cli.py b/client/gatehouse-cli.py
index e4f223a..2060c5a 100755
--- a/client/gatehouse-cli.py
+++ b/client/gatehouse-cli.py
@@ -51,8 +51,7 @@ class MyServer(BaseHTTPRequestHandler):
self.end_headers()
self.wfile.write(bytes("
OIDC Workflow Tool", "utf-8"))
self.wfile.write(bytes("The token has been received
", "utf-8"))
- self.wfile.write(bytes("Window closing in 5 seconds...
", "utf-8"))
- self.wfile.write(bytes("", "utf-8"))
+ self.wfile.write(bytes("You may now close this window.
", "utf-8"))
self.wfile.write(bytes("", "utf-8"))
parsed_url = urlparse(self.path)
diff --git a/config/base.py b/config/base.py
index 599a9cd..73a4d7f 100644
--- a/config/base.py
+++ b/config/base.py
@@ -28,6 +28,11 @@ class BaseConfig:
# Encryption key for sensitive data (client secrets, tokens, etc.)
ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY", "dev-encryption-key-change-in-production")
+
+ # Encryption key for CA private keys stored in the database.
+ # Must be set to a strong random secret in production.
+ # Any string is accepted — it is SHA-256 derived to a 32-byte Fernet key internally.
+ CA_ENCRYPTION_KEY = os.getenv("CA_ENCRYPTION_KEY", "dev-ca-encryption-key-change-in-production")
# Session configuration for WebAuthn cross-origin support
SESSION_COOKIE_SECURE = os.getenv("SESSION_COOKIE_SECURE", "True").lower() == "true"
@@ -72,6 +77,13 @@ class BaseConfig:
RATELIMIT_STORAGE_URL = os.getenv("RATELIMIT_STORAGE_URL", "redis://localhost:6379/1")
RATELIMIT_DEFAULT = "100/hour"
+ # Per-endpoint auth rate limits (override via env vars for each environment)
+ RATELIMIT_AUTH_REGISTER = os.getenv("RATELIMIT_AUTH_REGISTER", "10 per minute; 50 per hour")
+ RATELIMIT_AUTH_LOGIN = os.getenv("RATELIMIT_AUTH_LOGIN", "20 per minute; 100 per hour")
+ RATELIMIT_AUTH_TOTP_VERIFY = os.getenv("RATELIMIT_AUTH_TOTP_VERIFY", "20 per minute; 100 per hour")
+ RATELIMIT_AUTH_FORGOT_PASSWORD = os.getenv("RATELIMIT_AUTH_FORGOT_PASSWORD", "5 per minute; 20 per hour")
+ RATELIMIT_AUTH_RESET_PASSWORD = os.getenv("RATELIMIT_AUTH_RESET_PASSWORD", "10 per minute; 30 per hour")
+
# Logging
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
LOG_TO_STDOUT = os.getenv("LOG_TO_STDOUT", "False").lower() == "true"
diff --git a/config/testing.py b/config/testing.py
index 4ecfff0..aa988a7 100644
--- a/config/testing.py
+++ b/config/testing.py
@@ -12,6 +12,9 @@ class TestingConfig(BaseConfig):
# Explicitly set SECRET_KEY for testing
SECRET_KEY = os.getenv("SECRET_KEY", "test-secret-key-for-testing")
+ # CA key encryption — use a fixed test key so tests are deterministic
+ CA_ENCRYPTION_KEY = os.getenv("CA_ENCRYPTION_KEY", "test-ca-encryption-key-fixed-for-tests")
+
# Use in-memory SQLite for testing
SQLALCHEMY_DATABASE_URI = "sqlite:///:memory:"
SQLALCHEMY_ECHO = False
diff --git a/gatehouse_app/api/oidc.py b/gatehouse_app/api/oidc.py
index 3821e57..43bf5b4 100644
--- a/gatehouse_app/api/oidc.py
+++ b/gatehouse_app/api/oidc.py
@@ -22,7 +22,11 @@ from gatehouse_app.extensions import bcrypt as flask_bcrypt
from gatehouse_app.extensions import redis_client as _redis_client_ref # may be None until app init
from gatehouse_app.models import User, OIDCClient
from gatehouse_app.models.organization.organization import Organization
-from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
+from gatehouse_app.exceptions.auth_exceptions import (
+ InvalidCredentialsError,
+ AccountSuspendedError,
+ AccountInactiveError,
+)
# ---------------------------------------------------------------------------
# Helpers for Redis-backed OIDC pending state
@@ -343,6 +347,20 @@ def oidc_complete():
user_id = str(gh_session.user_id)
+ # Check the user is still active (not suspended after session was issued)
+ from gatehouse_app.models.user.user import User as _User
+ from gatehouse_app.utils.constants import UserStatus
+ _complete_user = _User.query.filter_by(id=user_id, deleted_at=None).first()
+ if not _complete_user or _complete_user.status in (
+ UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED, UserStatus.INACTIVE
+ ):
+ return api_response(
+ success=False,
+ message="Your account is not active or has been suspended.",
+ status=403,
+ error_type="ACCOUNT_SUSPENDED",
+ )
+
# Retrieve stashed OIDC params (consume = True removes from Redis atomically)
params = _fetch_oidc_params(oidc_session_id, consume=True)
if not params:
@@ -565,6 +583,28 @@ def oidc_authorize():
session["oidc_user_id"] = user_id
logger.debug("[OIDC] User authentication successful: user_id=%s, email=%s", user_id, email)
+ except AccountSuspendedError:
+ logger.debug("[OIDC] User authentication failed: account suspended for email=%s", email)
+ return _show_login_page(
+ client_id=client_id,
+ redirect_uri=redirect_uri,
+ scope=scope,
+ state=state,
+ nonce=nonce,
+ response_type=response_type,
+ error="Your account has been suspended. Please contact an administrator.",
+ )
+ except AccountInactiveError:
+ logger.debug("[OIDC] User authentication failed: account inactive for email=%s", email)
+ return _show_login_page(
+ client_id=client_id,
+ redirect_uri=redirect_uri,
+ scope=scope,
+ state=state,
+ nonce=nonce,
+ response_type=response_type,
+ error="Your account is not active. Please verify your email.",
+ )
except InvalidCredentialsError:
logger.debug("[OIDC] User authentication failed: invalid credentials for email=%s", email)
return _show_login_page(
@@ -600,7 +640,34 @@ def oidc_authorize():
if not user:
logger.debug("[OIDC] Redirecting with error: server_error (user not found)")
return _redirect_with_error(redirect_uri, "server_error", "User not found", state)
-
+
+ # Check account is still active (user could have been suspended after session start)
+ from gatehouse_app.utils.constants import UserStatus as _UserStatus
+ if user.status in (_UserStatus.SUSPENDED, _UserStatus.COMPLIANCE_SUSPENDED):
+ session.pop("oidc_user_id", None) # clear stale session
+ logger.debug("[OIDC] User is suspended, clearing session and showing login error: user_id=%s", user_id)
+ return _show_login_page(
+ client_id=client_id,
+ redirect_uri=redirect_uri,
+ scope=scope,
+ state=state,
+ nonce=nonce,
+ response_type=response_type,
+ error="Your account has been suspended. Please contact an administrator.",
+ )
+ if user.status == _UserStatus.INACTIVE:
+ session.pop("oidc_user_id", None)
+ logger.debug("[OIDC] User is inactive, clearing session and showing login error: user_id=%s", user_id)
+ return _show_login_page(
+ client_id=client_id,
+ redirect_uri=redirect_uri,
+ scope=scope,
+ state=state,
+ nonce=nonce,
+ response_type=response_type,
+ error="Your account is not active. Please verify your email.",
+ )
+
logger.debug("[OIDC] Generating authorization code...")
logger.debug("[OIDC] Authorization code params: client_id=%s, user_id=%s, redirect_uri=%s", client_id, user_id, redirect_uri)
logger.debug("[OIDC] Authorization code params: scopes=%s, state=%s, nonce=%s", valid_scopes, state, nonce)
diff --git a/gatehouse_app/api/v1/auth.py b/gatehouse_app/api/v1/auth.py
index d1ccc07..993f213 100644
--- a/gatehouse_app/api/v1/auth.py
+++ b/gatehouse_app/api/v1/auth.py
@@ -4,6 +4,7 @@ import logging
from flask import request, session, g, jsonify, current_app
from marshmallow import ValidationError
from gatehouse_app.api.v1 import api_v1_bp
+from gatehouse_app.extensions import limiter
from gatehouse_app.utils.response import api_response
from gatehouse_app.schemas.auth_schema import (
RegisterSchema,
@@ -32,6 +33,7 @@ from gatehouse_app.exceptions.validation_exceptions import ConflictError, NotFou
@api_v1_bp.route("/auth/register", methods=["POST"])
+@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_REGISTER"])
def register():
"""
Register a new user.
@@ -135,6 +137,7 @@ def register():
@api_v1_bp.route("/auth/login", methods=["POST"])
+@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_LOGIN"])
def login():
"""
Login user.
@@ -325,8 +328,13 @@ def get_current_user():
data={
"user": user.to_dict(),
"organizations": [
- {"id": org.id, "name": org.name, "slug": org.slug}
- for org in user.get_organizations()
+ {
+ "id": membership.organization.id,
+ "name": membership.organization.name,
+ "slug": membership.organization.slug,
+ "role": membership.role,
+ }
+ for membership in user.organization_memberships
],
},
message="User retrieved successfully",
@@ -478,6 +486,7 @@ def verify_totp_enrollment():
@api_v1_bp.route("/auth/totp/verify", methods=["POST"])
+@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_TOTP_VERIFY"])
def verify_totp():
"""
Verify TOTP code during login.
@@ -520,6 +529,18 @@ def verify_totp():
error_type="AUTHENTICATION_ERROR",
)
+ # Check account suspension before completing TOTP verification
+ from gatehouse_app.utils.constants import UserStatus
+ if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
+ session.pop("totp_pending_user_id", None)
+ session.pop("webauthn_pending_user_id", None)
+ return api_response(
+ success=False,
+ message="Account is suspended. Contact an administrator.",
+ status=403,
+ error_type="ACCOUNT_SUSPENDED",
+ )
+
# Verify TOTP code
AuthService.authenticate_with_totp(
user,
@@ -908,7 +929,18 @@ def begin_webauthn_login():
status=404,
error_type="NOT_FOUND",
)
-
+
+ # Check account suspension before proceeding
+ from gatehouse_app.utils.constants import UserStatus
+ if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
+ logger.warning(f"WebAuthn login begin - suspended account attempt: {user.email}")
+ return api_response(
+ success=False,
+ message="Account is suspended. Contact an administrator.",
+ status=403,
+ error_type="ACCOUNT_SUSPENDED",
+ )
+
# Check if user has any WebAuthn credentials
if not user.has_webauthn_enabled():
logger.warning(f"WebAuthn login begin - no credentials for user: {user.email}")
@@ -991,7 +1023,19 @@ def complete_webauthn_login():
status=401,
error_type="AUTHENTICATION_ERROR",
)
-
+
+ # Check account suspension before completing login
+ from gatehouse_app.utils.constants import UserStatus
+ if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
+ session.pop("webauthn_pending_user_id", None)
+ logger.warning(f"WebAuthn login complete - suspended account attempt: {user.email}")
+ return api_response(
+ success=False,
+ message="Account is suspended. Contact an administrator.",
+ status=403,
+ error_type="ACCOUNT_SUSPENDED",
+ )
+
# Extract challenge from client data
client_data = data.get("response", {}).get("clientDataJSON", "")
@@ -1129,6 +1173,19 @@ def delete_webauthn_credential(credential_id):
"""
user = g.current_user
+ # First check that the specific credential actually belongs to this user.
+ # Only then check whether it is the last one — otherwise a user with zero
+ # credentials gets a misleading "Cannot delete the last passkey" error
+ # instead of a 404.
+ credential_exists = WebAuthnService.credential_belongs_to_user(credential_id, user)
+ if not credential_exists:
+ return api_response(
+ success=False,
+ message="Credential not found",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
# Check if this is the last credential
credential_count = user.get_webauthn_credential_count()
if credential_count <= 1:
@@ -1238,6 +1295,7 @@ _pw_logger = logging.getLogger(__name__)
@api_v1_bp.route("/auth/forgot-password", methods=["POST"])
+@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_FORGOT_PASSWORD"])
def forgot_password():
"""Request a password reset email.
@@ -1294,6 +1352,7 @@ def forgot_password():
@api_v1_bp.route("/auth/reset-password", methods=["POST"])
+@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_RESET_PASSWORD"])
def reset_password():
"""Reset a user's password using a reset token.
@@ -1601,11 +1660,31 @@ def get_token():
302: Redirect to ``?token=``
"""
from flask import redirect as flask_redirect
+ from urllib.parse import urlparse
token = g.current_session.token
redirect_url = request.args.get("redirect", "").strip()
if redirect_url:
+ # Validate redirect URL against allowed origins to prevent open-redirect
+ # token exfiltration attacks (CWE-601).
+ allowed_origins = set(current_app.config.get("CORS_ORIGINS", []))
+ frontend_url = current_app.config.get("FRONTEND_URL", "")
+ if frontend_url:
+ parsed = urlparse(frontend_url)
+ allowed_origins.add(f"{parsed.scheme}://{parsed.netloc}")
+
+ parsed_redirect = urlparse(redirect_url)
+ redirect_origin = f"{parsed_redirect.scheme}://{parsed_redirect.netloc}"
+
+ if redirect_origin not in allowed_origins:
+ return api_response(
+ success=False,
+ message="Redirect URL is not allowed.",
+ status=400,
+ error_type="INVALID_REDIRECT",
+ )
+
sep = "&" if "?" in redirect_url else "?"
return flask_redirect(f"{redirect_url}{sep}token={token}", code=302)
diff --git a/gatehouse_app/api/v1/organizations.py b/gatehouse_app/api/v1/organizations.py
index 9f9db4c..7d8f4d3 100644
--- a/gatehouse_app/api/v1/organizations.py
+++ b/gatehouse_app/api/v1/organizations.py
@@ -226,6 +226,10 @@ def delete_organization(org_id):
"""
Delete organization (soft delete).
+ The owner may only delete the organization if they are the *sole* remaining
+ member. If other active members exist they must first transfer ownership
+ (or remove all other members) before deleting the organization.
+
Args:
org_id: Organization ID
@@ -234,9 +238,26 @@ def delete_organization(org_id):
401: Not authenticated
403: Not the owner
404: Organization not found
+ 409: Organization still has other members — transfer ownership first
"""
org = OrganizationService.get_organization_by_id(org_id)
+ # Guard: block deletion while non-owner members still exist so ownership
+ # can be transferred rather than silently orphaning them.
+ active_member_count = org.get_member_count()
+ if active_member_count > 1:
+ return api_response(
+ success=False,
+ message=(
+ "This organization still has other members. "
+ "Please transfer ownership to another member or remove all "
+ "other members before deleting the organization."
+ ),
+ status=409,
+ error_type="ORG_HAS_MEMBERS",
+ error_details={"member_count": active_member_count},
+ )
+
OrganizationService.delete_organization(
org=org,
user_id=g.current_user.id,
@@ -446,6 +467,152 @@ def update_member_role(org_id, user_id):
)
+@api_v1_bp.route("/organizations//transfer-ownership", methods=["POST"])
+@login_required
+@full_access_required
+def transfer_organization_ownership(org_id):
+ """Transfer organization ownership from the current user to another member.
+
+ Only the current OWNER of the organization may call this endpoint.
+ The caller will be demoted to ADMIN and the target user will be promoted to OWNER.
+
+ Request body:
+ new_owner_user_id (str): UUID of the member to promote to OWNER.
+
+ Returns:
+ 200: Ownership transferred successfully
+ 400: Validation error / missing fields
+ 403: Caller is not the OWNER of this org
+ 404: Organization or target member not found
+ 409: Target is already the OWNER
+ """
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
+ from gatehouse_app.utils.constants import OrganizationRole, AuditAction
+ from gatehouse_app.services.audit_service import AuditService
+
+ caller = g.current_user
+
+ data = request.get_json() or {}
+ new_owner_user_id = data.get("new_owner_user_id")
+ if not new_owner_user_id:
+ return api_response(
+ success=False,
+ message="new_owner_user_id is required",
+ status=400,
+ error_type="VALIDATION_ERROR",
+ )
+
+ if str(new_owner_user_id) == str(caller.id):
+ return api_response(
+ success=False,
+ message="You are already the owner of this organization.",
+ status=409,
+ error_type="CONFLICT",
+ )
+
+ # Fetch org (raises NotFound internally)
+ org = OrganizationService.get_organization_by_id(org_id)
+
+ # Confirm caller is the current OWNER
+ caller_membership = OrganizationMember.query.filter_by(
+ organization_id=org.id,
+ user_id=caller.id,
+ deleted_at=None,
+ ).first()
+ if not caller_membership or caller_membership.role != OrganizationRole.OWNER:
+ return api_response(
+ success=False,
+ message="Only the organization owner can transfer ownership.",
+ status=403,
+ error_type="AUTHORIZATION_ERROR",
+ )
+
+ # Verify the target is an active member
+ target_membership = OrganizationMember.query.filter_by(
+ organization_id=org.id,
+ user_id=new_owner_user_id,
+ deleted_at=None,
+ ).first()
+ if not target_membership:
+ return api_response(
+ success=False,
+ message="Target user is not a member of this organization.",
+ status=404,
+ error_type="NOT_FOUND",
+ )
+
+ if target_membership.role == OrganizationRole.OWNER:
+ return api_response(
+ success=False,
+ message="Target user is already the owner.",
+ status=409,
+ error_type="CONFLICT",
+ )
+
+ # ── Atomic role swap ─────────────────────────────────────────────────────
+ # Demote caller → ADMIN, promote target → OWNER.
+ # Both updates go through OrganizationService so all hooks/auditing fire.
+ try:
+ demoted = OrganizationService.update_member_role(
+ org=org,
+ user_id=str(caller.id),
+ new_role=OrganizationRole.ADMIN,
+ updater_id=str(caller.id),
+ )
+ promoted = OrganizationService.update_member_role(
+ org=org,
+ user_id=str(new_owner_user_id),
+ new_role=OrganizationRole.OWNER,
+ updater_id=str(caller.id),
+ )
+ except Exception as exc:
+ from gatehouse_app.extensions import db as _db
+ _db.session.rollback()
+ return api_response(
+ success=False,
+ message=f"Failed to transfer ownership: {exc}",
+ status=500,
+ error_type="SERVER_ERROR",
+ )
+
+ AuditService.log_action(
+ action=AuditAction.ORG_OWNERSHIP_TRANSFERRED,
+ user_id=caller.id,
+ organization_id=org.id,
+ resource_type="organization",
+ resource_id=str(org.id),
+ description=(
+ f"Ownership of '{org.name}' transferred from {caller.email} "
+ f"to {target_membership.user.email if target_membership.user else new_owner_user_id}"
+ ),
+ metadata={
+ "previous_owner_id": str(caller.id),
+ "previous_owner_email": caller.email,
+ "new_owner_id": str(new_owner_user_id),
+ "new_owner_email": (
+ target_membership.user.email if target_membership.user else None
+ ),
+ },
+ )
+
+ def _member_dict(m):
+ d = m.to_dict()
+ if m.user:
+ d["user"] = m.user.to_dict()
+ return d
+
+ return api_response(
+ data={
+ "previous_owner": _member_dict(demoted),
+ "new_owner": _member_dict(promoted),
+ },
+ message=(
+ f"Ownership of '{org.name}' successfully transferred to "
+ f"{target_membership.user.email if target_membership.user else new_owner_user_id}."
+ ),
+ )
+
+
@api_v1_bp.route("/organizations//audit-logs", methods=["GET"])
@login_required
@require_admin
@@ -756,10 +923,30 @@ def accept_invite(token):
inviter_id=invite.invited_by_id,
)
except Exception:
- pass # Already a member is fine
+ from gatehouse_app.extensions import db
+ db.session.rollback() # Clear broken transaction so invite.accept() can commit
invite.accept()
+ has_webauthn = user.has_webauthn_enabled()
+ has_totp = user.has_totp_enabled()
+
+ if has_webauthn:
+ from flask import session as flask_session
+ flask_session["webauthn_pending_user_id"] = user.id
+ return api_response(
+ data={"requires_webauthn": True},
+ message="Passkey verification required. Please use your passkey to complete sign-in.",
+ )
+
+ if has_totp:
+ from flask import session as flask_session
+ flask_session["totp_pending_user_id"] = user.id
+ return api_response(
+ data={"requires_totp": True},
+ message="TOTP code required. Please enter your 6-digit code from your authenticator app.",
+ )
+
user_session = AuthService.create_session(user)
return api_response(
@@ -1379,6 +1566,7 @@ def create_org_ca(org_id):
from gatehouse_app.models.ssh_ca.ca import CA, KeyType
from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.utils.crypto import compute_ssh_fingerprint
+ from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key
from marshmallow import Schema, fields as ma_fields, validate, ValidationError as MaValidationError
from sshkey_tools.keys import Ed25519PrivateKey, RsaPrivateKey, EcdsaPrivateKey
@@ -1448,13 +1636,16 @@ def create_org_ca(org_id):
public_key_str = private_key_obj.public_key.to_string()
fingerprint = compute_ssh_fingerprint(public_key_str)
+ # Encrypt the private key before storing in the database
+ encrypted_private_key = encrypt_ca_key(private_key_pem)
+
ca = CA(
organization_id=org_id,
name=data["name"],
description=data["description"],
ca_type=CaType(ca_type_val),
key_type=KeyType(key_type),
- private_key=private_key_pem,
+ private_key=encrypted_private_key,
public_key=public_key_str,
fingerprint=fingerprint,
default_cert_validity_hours=data["default_cert_validity_hours"],
@@ -1462,7 +1653,24 @@ def create_org_ca(org_id):
is_active=True,
)
db.session.add(ca)
- db.session.commit()
+ try:
+ db.session.commit()
+ except Exception as commit_exc:
+ db.session.rollback()
+ # Surface unique-constraint violations (soft-deleted record with same name) as a
+ # user-friendly 400 instead of a 500.
+ exc_str = str(commit_exc).lower()
+ if "uix_org_ca_name" in exc_str or "unique" in exc_str:
+ return api_response(
+ success=False,
+ message=(
+ "A CA with that name already exists in this organization "
+ "(it may have been recently deleted — choose a different name)."
+ ),
+ status=400,
+ error_type="DUPLICATE_NAME",
+ )
+ raise
return api_response(
data={"ca": ca.to_dict()},
@@ -1570,6 +1778,7 @@ def rotate_org_ca(org_id, ca_id):
from gatehouse_app.models.ssh_ca.ca import CA, KeyType
from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.utils.crypto import compute_ssh_fingerprint
+ from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.models import AuditLog
from sshkey_tools.keys import Ed25519PrivateKey, RsaPrivateKey, EcdsaPrivateKey
@@ -1609,8 +1818,11 @@ def rotate_org_ca(org_id, ca_id):
new_public_key = private_key_obj.public_key.to_string()
new_fingerprint = compute_ssh_fingerprint(new_public_key)
+ # Encrypt the new private key before storing
+ encrypted_new_private_key = encrypt_ca_key(new_private_key)
+
ca.rotate_key(
- new_private_key=new_private_key,
+ new_private_key=encrypted_new_private_key,
new_public_key=new_public_key,
new_fingerprint=new_fingerprint,
reason=reason,
diff --git a/gatehouse_app/api/v1/ssh.py b/gatehouse_app/api/v1/ssh.py
index 7491eb1..8a54ae5 100644
--- a/gatehouse_app/api/v1/ssh.py
+++ b/gatehouse_app/api/v1/ssh.py
@@ -15,6 +15,7 @@ from gatehouse_app.exceptions import (
)
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.models import AuditLog
+from gatehouse_app.models.ssh_ca.certificate_audit_log import CertificateAuditLog
from gatehouse_app.utils.decorators import login_required
from gatehouse_app.utils.response import api_response
@@ -78,11 +79,16 @@ def _get_or_create_system_ca():
with open(pub_key_path) as f:
pub_key = f.read().strip()
- # Load private key for the record (stored but not actually used for signing here)
+ # Load private key for the record (encrypt before storing in DB)
priv_key = ""
if os.path.exists(key_path):
with open(key_path) as f:
- priv_key = f.read()
+ raw_priv_key = f.read()
+ try:
+ from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key
+ priv_key = encrypt_ca_key(raw_priv_key)
+ except Exception:
+ priv_key = raw_priv_key # fallback: store as-is if encryption unavailable
fingerprint = compute_ssh_fingerprint(pub_key)
@@ -120,7 +126,7 @@ def _get_or_create_system_ca():
return None
-def _persist_certificate(user_id, ssh_key_id, ca, signing_response, request_ip=None, cert_type_str='user'):
+def _persist_certificate(user_id, ssh_key_id, ca, signing_response, request_ip=None, cert_type_str='user', cert_identity=None):
"""Save a signed certificate to the ssh_certificates table.
Args:
@@ -130,6 +136,8 @@ def _persist_certificate(user_id, ssh_key_id, ca, signing_response, request_ip=N
signing_response: SSHCertificateSigningResponse
request_ip: Client IP address
cert_type_str: 'user' or 'host' (from the sign request)
+ cert_identity: Rich OpenSSH key_id string (e.g. "user@host (Name) [org:slug]").
+ Falls back to str(ssh_key_id) when not provided.
Returns:
SSHCertificate instance or None if persistence failed
@@ -153,7 +161,7 @@ def _persist_certificate(user_id, ssh_key_id, ca, signing_response, request_ip=N
ssh_key_id=ssh_key_id,
certificate=signing_response.certificate,
serial=signing_response.serial,
- key_id=str(ssh_key_id),
+ key_id=cert_identity or str(ssh_key_id),
cert_type=resolved_cert_type,
principals=signing_response.principals,
valid_after=signing_response.valid_after,
@@ -465,7 +473,7 @@ def sign_certificate():
# ── Check account suspension ──────────────────────────────────────────────
from gatehouse_app.utils.constants import UserStatus
- if user.status == UserStatus.SUSPENDED:
+ if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
return api_response(
success=False,
message="Your account is suspended. Contact an administrator.",
@@ -482,6 +490,18 @@ def sign_certificate():
key_id = data.get('key_id') or data.get('cert_id')
expiry_hours = data.get('expiry_hours')
+ # ── Log the request ───────────────────────────────────────────────────────
+ AuditLog.log(
+ action=AuditAction.SSH_CERT_REQUESTED,
+ user_id=user_id,
+ resource_type='SSHCertificate',
+ ip_address=request.remote_addr,
+ description=(
+ f'{user.email} requested a certificate'
+ + (f' for principals: {", ".join(requested_principals)}' if requested_principals else '')
+ ),
+ )
+
# ── Resolve which principals the user is allowed to use ──────────────────
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.models.organization.principal import Principal, PrincipalMembership
@@ -601,11 +621,24 @@ def sign_certificate():
else:
policy_extensions = None # let signing service use its own defaults
+ # ── Build rich key_id identity for the OpenSSH cert ─────────────────────
+ # This appears in `ssh-keygen -L -f cert.pub` as the Key ID field and
+ # is stored in the DB cert record so it's auditable.
+ org_slugs = sorted({
+ om.organization.slug
+ for om in memberships
+ if om.organization and om.organization.deleted_at is None
+ and getattr(om.organization, 'slug', None)
+ })
+ org_slug = org_slugs[0] if org_slugs else "unknown"
+ full_name = getattr(user, 'full_name', None) or getattr(user, 'name', None) or "unknown"
+ cert_identity = f"{user.email} ({full_name}) [org:{org_slug}]"
+
signing_request = SSHCertificateSigningRequest(
ssh_public_key=ssh_key.payload,
principals=principals,
cert_type=cert_type,
- key_id=key_id,
+ key_id=cert_identity,
expiry_hours=int(expiry_hours) if expiry_hours else None,
extensions=policy_extensions,
)
@@ -620,7 +653,11 @@ def sign_certificate():
)
try:
- response = ssh_ca_service.sign_certificate(signing_request, ca_private_key=db_ca.private_key)
+ from gatehouse_app.utils.ca_key_encryption import decrypt_ca_key
+ ca_private_key_pem = decrypt_ca_key(db_ca.private_key)
+ response = ssh_ca_service.sign_certificate(
+ signing_request, ca_private_key=ca_private_key_pem, ca_obj=db_ca
+ )
except SSHCertificateError as e:
AuditLog.log(
action=AuditAction.SSH_CERT_FAILED,
@@ -649,6 +686,7 @@ def sign_certificate():
signing_response=response,
request_ip=request.remote_addr,
cert_type_str=cert_type,
+ cert_identity=cert_identity,
)
AuditLog.log(
@@ -657,9 +695,42 @@ def sign_certificate():
resource_type='SSHCertificate',
resource_id=cert_record.id if cert_record else key_id,
ip_address=request.remote_addr,
- description=f'Certificate issued for principals: {", ".join(principals)}',
+ description=(
+ f'Certificate serial={response.serial} issued for {user.email}; '
+ f'principals: {", ".join(principals)}'
+ ),
+ extra_data={
+ 'serial': response.serial,
+ 'key_id': cert_identity,
+ 'principals': principals,
+ 'ca_id': str(db_ca.id),
+ 'ssh_key_id': str(key_id),
+ },
)
+ if cert_record:
+ CertificateAuditLog.log(
+ certificate_id=cert_record.id,
+ action='issued',
+ user_id=user_id,
+ ip_address=request.remote_addr,
+ user_agent=request.headers.get('User-Agent'),
+ message=(
+ f'Certificate serial={response.serial} issued for {user.email}; '
+ f'principals: {", ".join(principals)}'
+ ),
+ extra_data={
+ 'serial': response.serial,
+ 'key_id': cert_identity,
+ 'principals': principals,
+ 'ca_id': str(db_ca.id),
+ 'ssh_key_id': str(key_id),
+ 'valid_after': response.valid_after.isoformat() if response.valid_after else None,
+ 'valid_before': response.valid_before.isoformat() if response.valid_before else None,
+ },
+ success=True,
+ )
+
result = {
'certificate': response.certificate,
'serial': response.serial,
@@ -753,6 +824,16 @@ def revoke_certificate(cert_id):
description=f'Revoked: {reason}',
)
+ CertificateAuditLog.log(
+ certificate_id=cert_id,
+ action='revoked',
+ user_id=user_id,
+ ip_address=request.remote_addr,
+ user_agent=request.headers.get('User-Agent'),
+ message=f'Certificate revoked: {reason}',
+ success=True,
+ )
+
return api_response(
success=True,
message='Certificate revoked successfully',
diff --git a/gatehouse_app/api/v1/users.py b/gatehouse_app/api/v1/users.py
index fd555cb..65c42c3 100644
--- a/gatehouse_app/api/v1/users.py
+++ b/gatehouse_app/api/v1/users.py
@@ -73,11 +73,51 @@ def delete_me():
"""
Delete current user account (soft delete).
+ Blocked if the user is the sole owner of any organization that has other
+ active members — they must transfer ownership or dissolve those organizations
+ first.
+
Returns:
200: Account deleted successfully
401: Not authenticated
+ 409: User is sole owner of one or more organizations with other members
"""
- UserService.delete_user(g.current_user, soft=True)
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
+ from gatehouse_app.utils.constants import OrganizationRole
+
+ user = g.current_user
+
+ # Find orgs where this user is the sole owner AND other members exist.
+ owned_memberships = OrganizationMember.query.filter_by(
+ user_id=user.id,
+ role=OrganizationRole.OWNER,
+ deleted_at=None,
+ ).all()
+
+ blocked_orgs = []
+ for membership in owned_memberships:
+ org = membership.organization
+ if org.deleted_at is not None:
+ continue
+ member_count = org.get_member_count()
+ if member_count > 1:
+ blocked_orgs.append(org.name)
+
+ if blocked_orgs:
+ names = ", ".join(f'"{n}"' for n in blocked_orgs)
+ return api_response(
+ success=False,
+ message=(
+ f"You are the sole owner of {len(blocked_orgs)} organization"
+ f"{'s' if len(blocked_orgs) > 1 else ''}: {names}. "
+ "Transfer ownership or delete those organizations before deleting your account."
+ ),
+ status=409,
+ error_type="USER_IS_SOLE_OWNER",
+ error_details={"organizations": blocked_orgs},
+ )
+
+ UserService.delete_user(user, soft=True)
return api_response(
message="Account deleted successfully",
@@ -454,6 +494,31 @@ def admin_suspend_user(user_id):
if not admin_in_shared_org:
return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR")
+ # ── Owner protection ──────────────────────────────────────────────────────
+ # An org owner cannot be suspended until they transfer ownership.
+ from gatehouse_app.utils.constants import OrganizationRole
+ owner_memberships = OrganizationMember.query.filter(
+ OrganizationMember.user_id == target.id,
+ OrganizationMember.role == OrganizationRole.OWNER,
+ OrganizationMember.deleted_at == None,
+ ).all()
+ if owner_memberships:
+ org_names = [
+ m.organization.name
+ for m in owner_memberships
+ if m.organization and not m.organization.deleted_at
+ ]
+ return api_response(
+ success=False,
+ message=(
+ f"Cannot suspend an organization owner. "
+ f"{target.email} is the owner of: {', '.join(org_names)}. "
+ "Transfer ownership to another member first."
+ ),
+ status=403,
+ error_type="OWNER_PROTECTION",
+ )
+
if target.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
return api_response(success=False, message="User is already suspended", status=409, error_type="CONFLICT")
@@ -645,3 +710,158 @@ def get_my_memberships():
data={"orgs": orgs_result},
message="Memberships retrieved",
)
+
+
+@api_v1_bp.route("/admin/users//delete", methods=["POST"])
+@login_required
+@full_access_required
+def admin_hard_delete_user(user_id):
+ """Permanently delete a user and ALL associated data (hard delete, irreversible).
+
+ Required body: {"confirm": true}
+
+ Pre-conditions:
+ - Caller is OWNER or ADMIN of a shared org with the target.
+ - Cannot delete yourself.
+ - Target must not be the OWNER of any active organization (transfer first).
+
+ Side-effects:
+ - All active SSH certificates are revoked before deletion.
+ - The user row and all cascaded rows are hard-deleted from the database.
+ - An audit log entry is written by the *caller* (so it is not lost with the user).
+ """
+ from gatehouse_app.models.organization.organization_member import OrganizationMember
+ from gatehouse_app.models.user.user import User as _User
+ from gatehouse_app.extensions import db as _db
+ from gatehouse_app.utils.constants import UserStatus, AuditAction, OrganizationRole
+ from gatehouse_app.services.audit_service import AuditService
+
+ caller = g.current_user
+ data = request.get_json() or {}
+
+ if not data.get("confirm"):
+ return api_response(
+ success=False,
+ message="Deletion requires explicit confirmation. Send {\"confirm\": true} to proceed.",
+ status=400,
+ error_type="CONFIRMATION_REQUIRED",
+ )
+
+ target = _User.query.filter_by(id=user_id).first()
+ if not target:
+ return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND")
+
+ if target.id == caller.id:
+ return api_response(
+ success=False,
+ message="Cannot delete your own account via this endpoint.",
+ status=400,
+ error_type="BAD_REQUEST",
+ )
+
+ # Caller must be OWNER/ADMIN of a shared org.
+ # Include soft-deleted memberships so that already-soft-deleted users can
+ # still be hard-deleted by an admin who shared an org with them.
+ target_org_ids = {m.organization_id for m in target.organization_memberships}
+ admin_in_shared_org = OrganizationMember.query.filter(
+ OrganizationMember.user_id == caller.id,
+ OrganizationMember.organization_id.in_(target_org_ids),
+ OrganizationMember.role.in_(["OWNER", "ADMIN"]),
+ OrganizationMember.deleted_at == None,
+ ).first()
+ if not admin_in_shared_org:
+ return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR")
+
+ # Block deletion if target is an org owner — they must transfer first
+ owner_memberships = OrganizationMember.query.filter(
+ OrganizationMember.user_id == target.id,
+ OrganizationMember.role == OrganizationRole.OWNER,
+ OrganizationMember.deleted_at == None,
+ ).all()
+ if owner_memberships:
+ org_names = [
+ m.organization.name
+ for m in owner_memberships
+ if m.organization and not m.organization.deleted_at
+ ]
+ return api_response(
+ success=False,
+ message=(
+ f"Cannot delete an organization owner. "
+ f"{target.email} is the owner of: {', '.join(org_names)}. "
+ "Transfer ownership to another member first."
+ ),
+ status=403,
+ error_type="OWNER_PROTECTION",
+ )
+
+ # ── Collect counts for audit metadata ────────────────────────────────────
+ from gatehouse_app.models.ssh_ca.ssh_key import SSHKey
+ from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus
+
+ ssh_key_count = SSHKey.query.filter_by(user_id=target.id, deleted_at=None).count()
+ active_cert_count = SSHCertificate.query.filter_by(
+ user_id=target.id, revoked=False
+ ).filter(SSHCertificate.deleted_at == None).count()
+
+ # ── Revoke all active SSH certificates before deletion ───────────────────
+ active_certs = SSHCertificate.query.filter_by(
+ user_id=target.id, revoked=False
+ ).filter(SSHCertificate.deleted_at == None).all()
+ for cert in active_certs:
+ try:
+ cert.revoke("account_deleted")
+ except Exception:
+ pass
+
+ if active_certs:
+ try:
+ _db.session.flush()
+ except Exception:
+ pass
+
+ # ── Hard delete ───────────────────────────────────────────────────────────
+ target_email = target.email # capture before deletion
+ target_id_str = str(target.id)
+
+ try:
+ _db.session.delete(target) # cascades to all child tables
+ _db.session.flush()
+ except Exception as exc:
+ _db.session.rollback()
+ import logging
+ logging.getLogger(__name__).error(f"Hard delete failed for {target_id_str}: {exc}")
+ return api_response(
+ success=False,
+ message="Failed to delete user account. Please try again.",
+ status=500,
+ error_type="SERVER_ERROR",
+ )
+
+ # ── Audit log (written as the caller so it survives the deletion) ─────────
+ AuditService.log_action(
+ action=AuditAction.USER_HARD_DELETE,
+ user_id=caller.id,
+ organization_id=admin_in_shared_org.organization_id,
+ resource_type="user",
+ resource_id=target_id_str,
+ description=f"Admin permanently deleted user account: {target_email}",
+ metadata={
+ "deleted_user_id": target_id_str,
+ "deleted_user_email": target_email,
+ "ssh_keys_deleted": ssh_key_count,
+ "certs_revoked": active_cert_count,
+ },
+ )
+
+ _db.session.commit()
+
+ return api_response(
+ message=f"User account {target_email} has been permanently deleted.",
+ data={
+ "deleted_user_id": target_id_str,
+ "deleted_user_email": target_email,
+ "ssh_keys_deleted": ssh_key_count,
+ "certs_revoked": active_cert_count,
+ },
+ )
diff --git a/gatehouse_app/models/ssh_ca/ca.py b/gatehouse_app/models/ssh_ca/ca.py
index 337f7c8..182d842 100644
--- a/gatehouse_app/models/ssh_ca/ca.py
+++ b/gatehouse_app/models/ssh_ca/ca.py
@@ -88,6 +88,11 @@ class CA(BaseModel):
rotated_at = db.Column(db.DateTime, nullable=True)
rotation_reason = db.Column(db.String(255), nullable=True)
+ # Monotonically-increasing serial counter. Every cert this CA issues
+ # gets the next value so serials are unique, ordered, and auditable.
+ # Protected by a row-level SELECT … FOR UPDATE in get_next_serial().
+ next_serial_number = db.Column(db.BigInteger, default=1, nullable=False)
+
# Relationships
organization = db.relationship("Organization", back_populates="cas")
certificates = db.relationship(
@@ -102,7 +107,6 @@ class CA(BaseModel):
)
__table_args__ = (
- db.UniqueConstraint("organization_id", "name", name="uix_org_ca_name"),
db.Index("idx_ca_org_active", "organization_id", "is_active"),
)
@@ -162,6 +166,28 @@ class CA(BaseModel):
self.rotation_reason = reason
self.save()
+ def get_next_serial(self) -> int:
+ """Atomically increment and return the next certificate serial number.
+
+ Uses a SELECT … FOR UPDATE row lock so concurrent requests never
+ receive the same serial. Must be called inside an active DB
+ transaction (i.e. before the final session.commit()).
+
+ Returns:
+ int: The serial number to embed in the next certificate.
+ """
+ # Re-fetch this CA row with an exclusive row lock
+ locked = (
+ db.session.query(CA)
+ .with_for_update()
+ .filter_by(id=self.id)
+ .one()
+ )
+ serial = locked.next_serial_number
+ locked.next_serial_number = serial + 1
+ db.session.flush() # write increment; commit happens in the caller
+ return serial
+
class CAPermission(BaseModel):
"""Per-user CA permission model.
diff --git a/gatehouse_app/schemas/auth_schema.py b/gatehouse_app/schemas/auth_schema.py
index 69d107b..dff1758 100644
--- a/gatehouse_app/schemas/auth_schema.py
+++ b/gatehouse_app/schemas/auth_schema.py
@@ -77,7 +77,10 @@ class TOTPVerifyEnrollmentSchema(Schema):
class TOTPVerifySchema(Schema):
"""Schema for TOTP code verification during login."""
- code = fields.Str(required=True)
+ code = fields.Str(
+ required=True,
+ validate=validate.Length(min=1),
+ )
is_backup_code = fields.Bool(load_default=False)
client_timestamp = fields.Int(
required=False,
@@ -85,6 +88,27 @@ class TOTPVerifySchema(Schema):
metadata={"description": "Client UTC timestamp in seconds since epoch for TOTP verification"},
)
+ @validates_schema
+ def validate_code_format(self, data, **kwargs):
+ """Validate code format depending on whether it's a backup code."""
+ code = data.get("code", "")
+ is_backup_code = data.get("is_backup_code", False)
+ if is_backup_code:
+ # Backup codes are 16 uppercase hex characters
+ if not code or len(code) != 16 or not all(c in "0123456789ABCDEFabcdef" for c in code):
+ raise ValidationError(
+ "Backup code must be a 16-character hexadecimal string.",
+ field_name="code",
+ )
+ else:
+ # Regular TOTP codes are exactly 6 digits
+ import re
+ if not re.match(r"^\d{6}$", code):
+ raise ValidationError(
+ "Code must be a 6-digit number.",
+ field_name="code",
+ )
+
class TOTPDisableSchema(Schema):
"""Schema for disabling TOTP."""
diff --git a/gatehouse_app/services/audit_service.py b/gatehouse_app/services/audit_service.py
index 85162fb..ed70f14 100644
--- a/gatehouse_app/services/audit_service.py
+++ b/gatehouse_app/services/audit_service.py
@@ -59,7 +59,7 @@ class AuditService:
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
- metadata=metadata,
+ extra_data=metadata,
description=description,
success=success,
error_message=error_message,
diff --git a/gatehouse_app/services/auth_service.py b/gatehouse_app/services/auth_service.py
index 4c591c7..9061f33 100644
--- a/gatehouse_app/services/auth_service.py
+++ b/gatehouse_app/services/auth_service.py
@@ -102,7 +102,7 @@ class AuthService:
if current_app.config.get('ENV') == 'development':
logger.debug(f"[Auth] Account status: user_id={user.id}, status={user.status}")
- if user.status == UserStatus.SUSPENDED:
+ if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
raise AccountSuspendedError()
if user.status == UserStatus.INACTIVE:
raise AccountInactiveError()
@@ -210,6 +210,22 @@ class AuthService:
auth_method.password_hash = bcrypt.generate_password_hash(new_password).decode("utf-8")
db.session.commit()
+ # Invalidate all other sessions so that if an attacker had a valid
+ # session token, changing the password actually locks them out.
+ # The current request's session (if any) is preserved so the user
+ # doesn't have to log in again immediately.
+ from flask import g as flask_g
+ current_session_id = getattr(flask_g, "current_session", None)
+ current_session_id = current_session_id.id if current_session_id else None
+ sessions_to_revoke = Session.query.filter(
+ Session.user_id == user.id,
+ Session.revoked_at == None, # noqa: E711
+ ).all()
+ for sess in sessions_to_revoke:
+ if sess.id != current_session_id:
+ sess.revoke(reason="Password changed")
+ db.session.commit()
+
# Log password change
AuditService.log_action(
action=AuditAction.PASSWORD_CHANGE,
@@ -482,9 +498,24 @@ class AuthService:
if not secret:
raise InvalidCredentialsError("TOTP secret not found")
+ # Replay-attack prevention: reject codes that have already been
+ # accepted within the current validity window.
+ if TOTPService.is_code_already_used(str(user.id), code):
+ AuditService.log_action(
+ action=AuditAction.TOTP_VERIFY_FAILED,
+ user_id=user.id,
+ resource_type="authentication_method",
+ resource_id=auth_method.id,
+ description="TOTP code replay attempt detected",
+ )
+ raise InvalidCredentialsError("Invalid TOTP code")
+
is_valid = TOTPService.verify_code(secret, code, client_utc_timestamp=client_utc_timestamp)
if is_valid:
+ # Mark this code as used to prevent replay within the validity window
+ TOTPService.mark_code_used(str(user.id), code)
+
auth_method.last_used_at = datetime.now(timezone.utc)
db.session.commit()
diff --git a/gatehouse_app/services/external_auth_service.py b/gatehouse_app/services/external_auth_service.py
index f105db4..57cad76 100644
--- a/gatehouse_app/services/external_auth_service.py
+++ b/gatehouse_app/services/external_auth_service.py
@@ -736,9 +736,14 @@ class ExternalAuthService:
400,
)
- # Generate PKCE
- code_verifier = secrets.token_urlsafe(32)
- code_challenge = cls._compute_s256_challenge(code_verifier)
+ # Generate PKCE — skip for confidential clients (Google, Microsoft) that use a
+ # client_secret. Sending code_challenge to Microsoft causes it to enforce PKCE on
+ # the token exchange, which then fails. Matches the behaviour of initiate_login_flow.
+ code_verifier = None
+ code_challenge = None
+ if provider_type_str not in ('google', 'microsoft'):
+ code_verifier = secrets.token_urlsafe(32)
+ code_challenge = cls._compute_s256_challenge(code_verifier)
# Create OAuth state
state = OAuthState.create_state(
diff --git a/gatehouse_app/services/organization_service.py b/gatehouse_app/services/organization_service.py
index 9bcccda..27ec7f4 100644
--- a/gatehouse_app/services/organization_service.py
+++ b/gatehouse_app/services/organization_service.py
@@ -188,11 +188,10 @@ class OrganizationService:
Raises:
ConflictError: If user is already a member
"""
- # Check if already a member
+ # Check if already a member (active or soft-deleted — both blocked by DB unique constraint)
existing = OrganizationMember.query.filter_by(
user_id=user_id,
organization_id=org.id,
- deleted_at=None,
).first()
# Development-only debug logging for membership validation
@@ -200,6 +199,25 @@ class OrganizationService:
logger.debug(f"[Org] Member check: org_id={org.id}, user_id={user_id}, already_member={existing is not None}")
if existing:
+ if existing.deleted_at is not None:
+ # Reactivate the soft-deleted membership with the new role
+ existing.deleted_at = None
+ existing.role = role
+ existing.invited_by_id = inviter_id
+ existing.invited_at = datetime.now(timezone.utc)
+ existing.joined_at = datetime.now(timezone.utc)
+ existing.save()
+
+ AuditService.log_action(
+ action=AuditAction.ORG_MEMBER_ADD,
+ user_id=inviter_id,
+ organization_id=org.id,
+ resource_type="organization_member",
+ resource_id=existing.id,
+ metadata={"added_user_id": user_id, "role": role.value},
+ description=f"Member re-added to organization with role: {role.value}",
+ )
+ return existing
raise ConflictError("User is already a member of this organization")
# Create membership
diff --git a/gatehouse_app/services/ssh_ca_signing_service.py b/gatehouse_app/services/ssh_ca_signing_service.py
index 308afc1..1502ab8 100644
--- a/gatehouse_app/services/ssh_ca_signing_service.py
+++ b/gatehouse_app/services/ssh_ca_signing_service.py
@@ -192,13 +192,19 @@ class SSHCASigningService:
self,
signing_request: SSHCertificateSigningRequest,
ca_private_key: Optional[str] = None,
- ) -> SSHCertificateSigningResponse:
+ ca_obj=None,
+ ) -> "SSHCertificateSigningResponse":
"""Sign an SSH certificate.
Args:
signing_request: SSHCertificateSigningRequest instance
ca_private_key: CA private key in PEM format. If not provided,
loaded from config (ca_key_path or SSH_CA_PRIVATE_KEY env var)
+ ca_obj: Optional CA model instance. When supplied its monotonic
+ serial counter is incremented atomically (SELECT FOR UPDATE)
+ and the resulting integer is embedded in the certificate's
+ serial field. This ensures every issued cert has a unique,
+ ordered, auditable serial number.
Returns:
SSHCertificateSigningResponse with signed certificate
@@ -245,13 +251,27 @@ class SSHCASigningService:
valid_before = now + timedelta(hours=expiry_hours)
# Set certificate fields
- cert_type = 1 if signing_request.cert_type == "user" else 0
+ # sshkey-tools: user=1, host=2 (not 0)
+ cert_type = 1 if signing_request.cert_type == "user" else 2
certificate.fields.cert_type = cert_type
certificate.fields.key_id = signing_request.key_id
certificate.fields.principals = signing_request.principals
certificate.fields.valid_after = now
certificate.fields.valid_before = valid_before
+
+ # ── Serial number ────────────────────────────────────────────────
+ # If a CA object is provided, use its monotonic counter so every
+ # certificate gets a unique, ordered, auditable serial. The
+ # counter increment is flushed inside get_next_serial(); the
+ # caller's commit() persists it atomically with the cert record.
+ if ca_obj is not None:
+ assigned_serial = ca_obj.get_next_serial()
+ certificate.fields.serial = assigned_serial
+ self.logger.debug(
+ f"Assigned serial {assigned_serial} from CA {ca_obj.id}"
+ )
+ # ─────────────────────────────────────────────────────────────────
# Set extensions — prefer policy-provided list, fall back to standard set
extensions = signing_request.extensions
@@ -276,8 +296,13 @@ class SSHCASigningService:
self.logger.error(f"Certificate verification failed: {str(e)}")
raise SSHCASigningError(f"Certificate verification failed: {str(e)}")
- # Extract serial from certificate
- serial = str(certificate.fields.serial).split(":")[-1].strip() if hasattr(certificate.fields.serial, '__str__') else str(certificate.fields.serial)
+ # Extract serial from certificate — use the integer we assigned
+ # when ca_obj was provided, otherwise fall back to whatever the
+ # library generated.
+ if ca_obj is not None:
+ serial = str(assigned_serial)
+ else:
+ serial = str(certificate.fields.serial).split(":")[-1].strip() if hasattr(certificate.fields.serial, '__str__') else str(certificate.fields.serial)
# Build response
cert_string = certificate.to_string()
diff --git a/gatehouse_app/services/totp_service.py b/gatehouse_app/services/totp_service.py
index b71f56a..c667e3e 100644
--- a/gatehouse_app/services/totp_service.py
+++ b/gatehouse_app/services/totp_service.py
@@ -11,10 +11,51 @@ from gatehouse_app.extensions import bcrypt
logger = logging.getLogger(__name__)
+# TOTP codes are valid for at most (2*window + 1) * 30s steps.
+# With window=1 that's 3 steps = 90 seconds. We use a slightly
+# generous TTL of 95 seconds to account for clock skew at boundaries.
+_TOTP_USED_CODE_TTL = 95
+
class TOTPService:
"""Service for TOTP operations."""
+ # ------------------------------------------------------------------
+ # Replay-attack prevention helpers
+ # ------------------------------------------------------------------
+
+ @staticmethod
+ def _used_key(user_id: str, code: str) -> str:
+ return f"totp:used:{user_id}:{code}"
+
+ @staticmethod
+ def is_code_already_used(user_id: str, code: str) -> bool:
+ """Return True if *code* has already been accepted for *user_id*
+ within the current validity window (prevents replay attacks)."""
+ try:
+ from gatehouse_app.extensions import redis_client
+ if redis_client is None:
+ return False
+ return redis_client.exists(TOTPService._used_key(user_id, code)) == 1
+ except Exception:
+ logger.warning("Redis unavailable for TOTP replay check; allowing code")
+ return False
+
+ @staticmethod
+ def mark_code_used(user_id: str, code: str) -> None:
+ """Record *code* as consumed for *user_id* so it cannot be reused."""
+ try:
+ from gatehouse_app.extensions import redis_client
+ if redis_client is None:
+ return
+ redis_client.setex(
+ TOTPService._used_key(user_id, code),
+ _TOTP_USED_CODE_TTL,
+ "1",
+ )
+ except Exception:
+ logger.warning("Redis unavailable; TOTP used-code not recorded")
+
@staticmethod
def generate_secret() -> str:
"""
diff --git a/gatehouse_app/services/webauthn_service.py b/gatehouse_app/services/webauthn_service.py
index cba6204..9f953d3 100644
--- a/gatehouse_app/services/webauthn_service.py
+++ b/gatehouse_app/services/webauthn_service.py
@@ -641,6 +641,26 @@ class WebAuthnService:
)
return True
+
+ @classmethod
+ def credential_belongs_to_user(cls, credential_id: str, user: User) -> bool:
+ """Check whether *credential_id* exists and belongs to *user*.
+
+ Args:
+ credential_id: The credential ID to look up
+ user: User instance
+
+ Returns:
+ True if the credential exists and belongs to this user, False otherwise.
+ """
+ auth_method = AuthenticationMethod.query.filter_by(
+ user_id=user.id,
+ method_type=AuthMethodType.WEBAUTHN,
+ deleted_at=None,
+ ).first()
+ if not auth_method or not auth_method.provider_data:
+ return False
+ return auth_method.provider_data.get("credential_id") == credential_id
@classmethod
def rename_credential(cls, credential_id: str, user: User, name: str) -> bool:
diff --git a/gatehouse_app/utils/ca_key_encryption.py b/gatehouse_app/utils/ca_key_encryption.py
new file mode 100644
index 0000000..183e3c3
--- /dev/null
+++ b/gatehouse_app/utils/ca_key_encryption.py
@@ -0,0 +1,206 @@
+"""Encryption helpers for CA private keys stored in the database.
+
+CA private keys are encrypted at rest using Fernet (AES-128-CBC + HMAC-SHA256)
+from the ``cryptography`` package. The encryption key is derived from the
+``CA_ENCRYPTION_KEY`` environment variable (or ``Flask.config["CA_ENCRYPTION_KEY"]``).
+
+Key derivation
+--------------
+Fernet requires a URL-safe base64-encoded 32-byte key. We accept any string
+from the env and derive the actual Fernet key using SHA-256 so that operators
+can supply human-readable secrets without having to pre-encode them.
+
+Envelope format
+---------------
+Encrypted values are stored as the string::
+
+ $fernet$
+
+The ``$fernet$`` prefix lets the code distinguish already-encrypted values from
+legacy plaintext PEM keys so that the migration path is safe and idempotent.
+
+Usage
+-----
+Encrypt before storing::
+
+ from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key
+ ca.private_key = encrypt_ca_key(private_key_pem)
+
+Decrypt before use::
+
+ from gatehouse_app.utils.ca_key_encryption import decrypt_ca_key
+ plaintext_pem = decrypt_ca_key(ca.private_key)
+"""
+import base64
+import hashlib
+import logging
+import os
+
+from cryptography.fernet import Fernet, InvalidToken
+
+logger = logging.getLogger(__name__)
+
+# Prefix that marks a stored value as Fernet-encrypted
+_FERNET_PREFIX = "$fernet$"
+
+
+class CAKeyEncryptionError(Exception):
+ """Raised when CA key encryption or decryption fails."""
+
+
+def _get_fernet() -> Fernet:
+ """Build a Fernet instance from the configured encryption key.
+
+ Looks up ``CA_ENCRYPTION_KEY`` in the environment first, then falls back to
+ the Flask app config (if a request context is active).
+
+ Raises:
+ CAKeyEncryptionError: if no key is configured or it is the insecure
+ placeholder value in a production-like environment.
+ """
+ raw_key = os.environ.get("CA_ENCRYPTION_KEY")
+
+ if not raw_key:
+ # Try Flask config if we're inside an app context
+ try:
+ from flask import current_app
+ raw_key = current_app.config.get("CA_ENCRYPTION_KEY")
+ except RuntimeError:
+ pass # No app context
+
+ if not raw_key:
+ raise CAKeyEncryptionError(
+ "CA_ENCRYPTION_KEY is not set. "
+ "Set this environment variable before starting the application."
+ )
+
+ # Warn loudly when running with the placeholder in a non-test environment
+ env_name = os.environ.get("FLASK_ENV", "").lower()
+ if raw_key.startswith("dev-") and env_name not in ("development", "testing", "test"):
+ logger.warning(
+ "CA_ENCRYPTION_KEY appears to be a development placeholder. "
+ "Set a strong random key for production environments."
+ )
+
+ # Derive a 32-byte key from the raw secret via SHA-256, then URL-safe base64
+ key_bytes = hashlib.sha256(raw_key.encode()).digest()
+ fernet_key = base64.urlsafe_b64encode(key_bytes)
+ return Fernet(fernet_key)
+
+
+def encrypt_ca_key(plaintext_pem: str) -> str:
+ """Encrypt a CA private key PEM string.
+
+ Idempotent: already-encrypted values are returned unchanged.
+
+ Args:
+ plaintext_pem: CA private key in OpenSSH/PEM format.
+
+ Returns:
+ Encrypted string with ``$fernet$`` prefix, safe for database storage.
+
+ Raises:
+ CAKeyEncryptionError: if the key cannot be encrypted.
+ """
+ if not plaintext_pem:
+ raise CAKeyEncryptionError("Cannot encrypt an empty key")
+
+ # Already encrypted — do not double-encrypt
+ if plaintext_pem.startswith(_FERNET_PREFIX):
+ return plaintext_pem
+
+ try:
+ fernet = _get_fernet()
+ token = fernet.encrypt(plaintext_pem.encode()).decode()
+ return f"{_FERNET_PREFIX}{token}"
+ except CAKeyEncryptionError:
+ raise
+ except Exception as exc:
+ raise CAKeyEncryptionError(f"Failed to encrypt CA key: {exc}") from exc
+
+
+def decrypt_ca_key(stored_value: str) -> str:
+ """Decrypt a CA private key retrieved from the database.
+
+ Idempotent: plaintext (legacy) values are returned unchanged so that the
+ system continues to work while a migration encrypts existing rows.
+
+ Args:
+ stored_value: Value from ``CA.private_key`` column.
+
+ Returns:
+ Plaintext PEM string ready for use with ``sshkey_tools``.
+
+ Raises:
+ CAKeyEncryptionError: if decryption fails (wrong key, corrupted data).
+ """
+ if not stored_value:
+ raise CAKeyEncryptionError("Cannot decrypt an empty value")
+
+ # Legacy plaintext key — return as-is
+ if not stored_value.startswith(_FERNET_PREFIX):
+ logger.warning(
+ "CA private key appears to be stored as plaintext. "
+ "Run the migration to encrypt existing keys."
+ )
+ return stored_value
+
+ token = stored_value[len(_FERNET_PREFIX):]
+ try:
+ fernet = _get_fernet()
+ return fernet.decrypt(token.encode()).decode()
+ except InvalidToken as exc:
+ raise CAKeyEncryptionError(
+ "CA key decryption failed — the CA_ENCRYPTION_KEY may be incorrect "
+ "or the stored key is corrupted."
+ ) from exc
+ except CAKeyEncryptionError:
+ raise
+ except Exception as exc:
+ raise CAKeyEncryptionError(f"Unexpected decryption error: {exc}") from exc
+
+
+def is_encrypted(stored_value: str) -> bool:
+ """Return True if the stored value has the ``$fernet$`` envelope.
+
+ Args:
+ stored_value: Value from ``CA.private_key`` column.
+ """
+ return bool(stored_value and stored_value.startswith(_FERNET_PREFIX))
+
+
+def reencrypt_ca_key(stored_value: str, old_raw_key: str, new_raw_key: str) -> str:
+ """Re-encrypt a CA key with a new encryption key (for key rotation).
+
+ Args:
+ stored_value: Current value from ``CA.private_key`` (may or may not be encrypted).
+ old_raw_key: The current ``CA_ENCRYPTION_KEY`` value (raw secret string).
+ new_raw_key: The new ``CA_ENCRYPTION_KEY`` value to encrypt with.
+
+ Returns:
+ New encrypted envelope string.
+
+ Raises:
+ CAKeyEncryptionError: if decryption or re-encryption fails.
+ """
+ # Decrypt with old key
+ if stored_value.startswith(_FERNET_PREFIX):
+ token = stored_value[len(_FERNET_PREFIX):]
+ old_key_bytes = base64.urlsafe_b64encode(hashlib.sha256(old_raw_key.encode()).digest())
+ try:
+ plaintext = Fernet(old_key_bytes).decrypt(token.encode()).decode()
+ except InvalidToken as exc:
+ raise CAKeyEncryptionError(
+ "Re-encryption failed: could not decrypt with the old key."
+ ) from exc
+ else:
+ # Plaintext
+ plaintext = stored_value
+
+ # Re-encrypt with new key
+ new_key_bytes = base64.urlsafe_b64encode(hashlib.sha256(new_raw_key.encode()).digest())
+ try:
+ token = Fernet(new_key_bytes).encrypt(plaintext.encode()).decode()
+ return f"{_FERNET_PREFIX}{token}"
+ except Exception as exc:
+ raise CAKeyEncryptionError(f"Re-encryption with new key failed: {exc}") from exc
diff --git a/gatehouse_app/utils/constants.py b/gatehouse_app/utils/constants.py
index a3983f3..2a99825 100644
--- a/gatehouse_app/utils/constants.py
+++ b/gatehouse_app/utils/constants.py
@@ -61,6 +61,7 @@ class AuditAction(str, Enum):
USER_REGISTER = "user.register"
USER_UPDATE = "user.update"
USER_DELETE = "user.delete"
+ USER_HARD_DELETE = "user.hard_delete"
USER_SUSPEND = "user.suspend"
USER_UNSUSPEND = "user.unsuspend"
PASSWORD_CHANGE = "user.password_change"
@@ -73,6 +74,7 @@ class AuditAction(str, Enum):
ORG_MEMBER_ADD = "org.member.add"
ORG_MEMBER_REMOVE = "org.member.remove"
ORG_MEMBER_ROLE_CHANGE = "org.member.role_change"
+ ORG_OWNERSHIP_TRANSFERRED = "org.ownership.transferred"
# Session actions
SESSION_CREATE = "session.create"
diff --git a/migrations/versions/015_add_user_suspend_audit_actions.py b/migrations/versions/015_add_user_suspend_audit_actions.py
new file mode 100644
index 0000000..06834da
--- /dev/null
+++ b/migrations/versions/015_add_user_suspend_audit_actions.py
@@ -0,0 +1,37 @@
+"""Add USER_SUSPEND and USER_UNSUSPEND to auditaction enum.
+
+Revision ID: 015_add_user_suspend_audit_actions
+Revises: 014_add_dept_cert_policy
+Create Date: 2026-03-02
+
+USER_SUSPEND and USER_UNSUSPEND were added to the Python AuditAction enum
+but were never synced to the PostgreSQL auditaction type, causing a
+DataError (invalid enum value) whenever an admin suspends or unsuspends a user.
+"""
+from alembic import op
+
+revision = "015_user_suspend_audit"
+down_revision = "014_add_dept_cert_policy"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ for val in ("USER_SUSPEND", "USER_UNSUSPEND"):
+ op.execute(f"""
+ DO $$
+ BEGIN
+ IF NOT EXISTS (
+ SELECT 1 FROM pg_enum
+ WHERE enumlabel = '{val}'
+ AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'auditaction')
+ ) THEN
+ ALTER TYPE auditaction ADD VALUE '{val}';
+ END IF;
+ END$$;
+ """)
+
+
+def downgrade():
+ # PostgreSQL does not support removing enum values; downgrade is a no-op.
+ pass
diff --git a/migrations/versions/016_encrypt_existing_ca_keys.py b/migrations/versions/016_encrypt_existing_ca_keys.py
new file mode 100644
index 0000000..91acdda
--- /dev/null
+++ b/migrations/versions/016_encrypt_existing_ca_keys.py
@@ -0,0 +1,168 @@
+"""Encrypt existing plaintext CA private keys at rest.
+
+Revision ID: 016_encrypt_existing_ca_keys
+Revises: 015_add_user_suspend_audit_actions
+Create Date: 2026-03-02
+
+All CA private keys created before this migration were stored as plaintext PEM
+strings in the ``cas.private_key`` column. This migration detects those rows
+(by checking for the absence of the ``$fernet$`` prefix that encrypted values
+carry) and re-encrypts them with the key derived from ``CA_ENCRYPTION_KEY``.
+
+The migration is safe to re-run: already-encrypted rows are left untouched.
+
+Prerequisites
+-------------
+``CA_ENCRYPTION_KEY`` must be set in the environment before running this
+migration. The same value must be configured for the running application.
+
+To roll back to plaintext (downgrade):
+The ``downgrade()`` function decrypts all rows back to plaintext PEM. This is
+provided only for emergency rollback and should not be used in production once
+the system has been running with encrypted keys.
+"""
+import os
+import base64
+import hashlib
+import logging
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.orm import Session
+
+logger = logging.getLogger(__name__)
+
+# Alembic revision identifiers
+revision = "016_encrypt_ca_keys"
+down_revision = "015_user_suspend_audit"
+branch_labels = None
+depends_on = None
+
+_FERNET_PREFIX = "$fernet$"
+
+
+def _get_fernet():
+ """Build a Fernet instance from CA_ENCRYPTION_KEY env var."""
+ from cryptography.fernet import Fernet
+
+ raw_key = os.environ.get("CA_ENCRYPTION_KEY")
+ if not raw_key:
+ raise RuntimeError(
+ "CA_ENCRYPTION_KEY environment variable is not set. "
+ "Set it before running this migration."
+ )
+ key_bytes = base64.urlsafe_b64encode(hashlib.sha256(raw_key.encode()).digest())
+ return Fernet(key_bytes)
+
+
+def upgrade():
+ """Encrypt plaintext CA private keys."""
+ bind = op.get_bind()
+ session = Session(bind=bind)
+
+ try:
+ fernet = _get_fernet()
+ except RuntimeError as exc:
+ raise RuntimeError(str(exc)) from exc
+
+ # Fetch all non-deleted CA rows
+ rows = session.execute(
+ sa.text("SELECT id, private_key FROM cas WHERE deleted_at IS NULL")
+ ).fetchall()
+
+ encrypted_count = 0
+ skipped_count = 0
+
+ for row in rows:
+ ca_id, private_key = row[0], row[1]
+
+ if not private_key:
+ logger.warning(f"CA {ca_id} has empty private_key — skipping")
+ skipped_count += 1
+ continue
+
+ if private_key.startswith(_FERNET_PREFIX):
+ # Already encrypted
+ skipped_count += 1
+ continue
+
+ # Encrypt
+ try:
+ token = fernet.encrypt(private_key.encode()).decode()
+ encrypted_value = f"{_FERNET_PREFIX}{token}"
+ session.execute(
+ sa.text("UPDATE cas SET private_key = :pk WHERE id = :id"),
+ {"pk": encrypted_value, "id": ca_id},
+ )
+ encrypted_count += 1
+ logger.info(f"Encrypted private key for CA {ca_id}")
+ except Exception as exc:
+ session.rollback()
+ raise RuntimeError(
+ f"Failed to encrypt private key for CA {ca_id}: {exc}"
+ ) from exc
+
+ session.commit()
+ logger.info(
+ f"CA key encryption migration complete: "
+ f"{encrypted_count} encrypted, {skipped_count} skipped"
+ )
+ print(
+ f" [016_encrypt_ca_keys] {encrypted_count} CA private key(s) encrypted, "
+ f"{skipped_count} already encrypted or empty."
+ )
+
+
+def downgrade():
+ """Decrypt CA private keys back to plaintext (emergency rollback only)."""
+ bind = op.get_bind()
+ session = Session(bind=bind)
+
+ try:
+ fernet = _get_fernet()
+ except RuntimeError as exc:
+ raise RuntimeError(str(exc)) from exc
+
+ rows = session.execute(
+ sa.text("SELECT id, private_key FROM cas WHERE deleted_at IS NULL")
+ ).fetchall()
+
+ decrypted_count = 0
+ skipped_count = 0
+
+ for row in rows:
+ ca_id, private_key = row[0], row[1]
+
+ if not private_key or not private_key.startswith(_FERNET_PREFIX):
+ skipped_count += 1
+ continue
+
+ token = private_key[len(_FERNET_PREFIX):]
+ try:
+ from cryptography.fernet import InvalidToken
+ try:
+ plaintext = fernet.decrypt(token.encode()).decode()
+ except InvalidToken as exc:
+ raise RuntimeError(
+ f"Downgrade failed: cannot decrypt CA {ca_id} — wrong key or corrupted data."
+ ) from exc
+
+ session.execute(
+ sa.text("UPDATE cas SET private_key = :pk WHERE id = :id"),
+ {"pk": plaintext, "id": ca_id},
+ )
+ decrypted_count += 1
+ logger.warning(f"Decrypted (plaintext restore) private key for CA {ca_id}")
+ except RuntimeError:
+ session.rollback()
+ raise
+
+ session.commit()
+ logger.warning(
+ f"CA key decryption (downgrade) complete: "
+ f"{decrypted_count} decrypted, {skipped_count} skipped"
+ )
+ print(
+ f" [016_encrypt_ca_keys] DOWNGRADE: {decrypted_count} CA private key(s) "
+ f"decrypted to plaintext. WARNING: keys are now unencrypted at rest."
+ )
diff --git a/migrations/versions/017_add_ca_serial_counter.py b/migrations/versions/017_add_ca_serial_counter.py
new file mode 100644
index 0000000..94b85e9
--- /dev/null
+++ b/migrations/versions/017_add_ca_serial_counter.py
@@ -0,0 +1,37 @@
+"""Add monotonic serial counter to CAs table.
+
+Each CA now owns a `next_serial_number` (BigInteger) that is atomically
+incremented every time a certificate is signed. This guarantees:
+ - Serials are unique per CA
+ - Serials are monotonically increasing (auditable, no gaps by accident)
+ - The value embedded in the OpenSSH certificate matches what is stored
+ in the `ssh_certificates.serial` column
+
+Revision ID: 017_add_ca_serial_counter
+Revises: 016_encrypt_ca_keys
+Create Date: 2026-03-02
+"""
+from alembic import op
+import sqlalchemy as sa
+
+revision = "017_add_ca_serial_counter"
+down_revision = "016_encrypt_ca_keys"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ with op.batch_alter_table("cas", schema=None) as batch_op:
+ batch_op.add_column(
+ sa.Column(
+ "next_serial_number",
+ sa.BigInteger(),
+ nullable=False,
+ server_default="1",
+ )
+ )
+
+
+def downgrade():
+ with op.batch_alter_table("cas", schema=None) as batch_op:
+ batch_op.drop_column("next_serial_number")
diff --git a/migrations/versions/018_add_ownership_and_hard_delete_audit_actions.py b/migrations/versions/018_add_ownership_and_hard_delete_audit_actions.py
new file mode 100644
index 0000000..7cbbd32
--- /dev/null
+++ b/migrations/versions/018_add_ownership_and_hard_delete_audit_actions.py
@@ -0,0 +1,52 @@
+"""Add ORG_OWNERSHIP_TRANSFERRED and USER_HARD_DELETE to auditaction enum.
+
+Revision ID: 018_audit_enum_values
+Revises: 017_add_ca_serial_counter
+Create Date: 2026-03-02
+
+ORG_OWNERSHIP_TRANSFERRED and USER_HARD_DELETE were added to the Python
+AuditAction enum but were never synced to the PostgreSQL auditaction type,
+causing a DataError (invalid enum value) when transferring org ownership
+or hard-deleting a user.
+"""
+from alembic import op
+
+revision = "018_audit_enum_values"
+down_revision = "017_add_ca_serial_counter"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ALTER TYPE ... ADD VALUE cannot run inside a transaction block in PostgreSQL.
+ # Alembic has already opened a transaction on the connection by the time our
+ # upgrade() runs, so we must:
+ # 1. Roll back that open transaction on the raw psycopg2 connection.
+ # 2. Switch to autocommit so the ALTER TYPE runs outside any transaction.
+ # 3. Restore the previous state afterwards.
+ conn = op.get_bind()
+ # SQLAlchemy 2.x: conn.connection is a _ConnectionFairy; .driver_connection is psycopg2
+ fairy = conn.connection
+ raw = getattr(fairy, "driver_connection", None) or getattr(fairy, "dbapi_connection", fairy)
+ # Roll back the open transaction so psycopg2 allows us to change autocommit.
+ raw.rollback()
+ old_autocommit = raw.autocommit
+ raw.autocommit = True
+ try:
+ with raw.cursor() as cur:
+ for val in ("ORG_OWNERSHIP_TRANSFERRED", "USER_HARD_DELETE"):
+ cur.execute(
+ "SELECT 1 FROM pg_enum "
+ "WHERE enumlabel = %s "
+ "AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'auditaction')",
+ (val,),
+ )
+ if not cur.fetchone():
+ cur.execute(f"ALTER TYPE auditaction ADD VALUE '{val}'")
+ finally:
+ raw.autocommit = old_autocommit
+
+
+def downgrade():
+ # PostgreSQL does not support removing enum values; downgrade is a no-op.
+ pass