Fix(Feat): CA, Audits, Rte Limit

CA Encryption, Serials, Rate Limiter, Account suspension blocks login
Transfer Ownership & Delete Account
This commit is contained in:
2026-03-02 23:53:51 +05:45
parent be87fd90b1
commit 5250d18eb0
23 changed files with 1399 additions and 34 deletions
+1 -2
View File
@@ -51,8 +51,7 @@ class MyServer(BaseHTTPRequestHandler):
self.end_headers() self.end_headers()
self.wfile.write(bytes("<html><head><title>OIDC Workflow Tool</title></head>", "utf-8")) self.wfile.write(bytes("<html><head><title>OIDC Workflow Tool</title></head>", "utf-8"))
self.wfile.write(bytes("<body><p>The token has been received</p>", "utf-8")) self.wfile.write(bytes("<body><p>The token has been received</p>", "utf-8"))
self.wfile.write(bytes("<p>Window closing in <span id='countdown'>5</span> seconds...</p>", "utf-8")) self.wfile.write(bytes("<p>You may now close this window.</p>", "utf-8"))
self.wfile.write(bytes("<script>var count = 5; setInterval(function() { count--; document.getElementById('countdown').textContent = count; if (count === 0) window.close(); }, 1000);</script>", "utf-8"))
self.wfile.write(bytes("</body></html>", "utf-8")) self.wfile.write(bytes("</body></html>", "utf-8"))
parsed_url = urlparse(self.path) parsed_url = urlparse(self.path)
+12
View File
@@ -29,6 +29,11 @@ class BaseConfig:
# Encryption key for sensitive data (client secrets, tokens, etc.) # Encryption key for sensitive data (client secrets, tokens, etc.)
ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY", "dev-encryption-key-change-in-production") ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY", "dev-encryption-key-change-in-production")
# Encryption key for CA private keys stored in the database.
# Must be set to a strong random secret in production.
# Any string is accepted — it is SHA-256 derived to a 32-byte Fernet key internally.
CA_ENCRYPTION_KEY = os.getenv("CA_ENCRYPTION_KEY", "dev-ca-encryption-key-change-in-production")
# Session configuration for WebAuthn cross-origin support # Session configuration for WebAuthn cross-origin support
SESSION_COOKIE_SECURE = os.getenv("SESSION_COOKIE_SECURE", "True").lower() == "true" SESSION_COOKIE_SECURE = os.getenv("SESSION_COOKIE_SECURE", "True").lower() == "true"
SESSION_COOKIE_HTTPONLY = True SESSION_COOKIE_HTTPONLY = True
@@ -72,6 +77,13 @@ class BaseConfig:
RATELIMIT_STORAGE_URL = os.getenv("RATELIMIT_STORAGE_URL", "redis://localhost:6379/1") RATELIMIT_STORAGE_URL = os.getenv("RATELIMIT_STORAGE_URL", "redis://localhost:6379/1")
RATELIMIT_DEFAULT = "100/hour" RATELIMIT_DEFAULT = "100/hour"
# Per-endpoint auth rate limits (override via env vars for each environment)
RATELIMIT_AUTH_REGISTER = os.getenv("RATELIMIT_AUTH_REGISTER", "10 per minute; 50 per hour")
RATELIMIT_AUTH_LOGIN = os.getenv("RATELIMIT_AUTH_LOGIN", "20 per minute; 100 per hour")
RATELIMIT_AUTH_TOTP_VERIFY = os.getenv("RATELIMIT_AUTH_TOTP_VERIFY", "20 per minute; 100 per hour")
RATELIMIT_AUTH_FORGOT_PASSWORD = os.getenv("RATELIMIT_AUTH_FORGOT_PASSWORD", "5 per minute; 20 per hour")
RATELIMIT_AUTH_RESET_PASSWORD = os.getenv("RATELIMIT_AUTH_RESET_PASSWORD", "10 per minute; 30 per hour")
# Logging # Logging
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO") LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
LOG_TO_STDOUT = os.getenv("LOG_TO_STDOUT", "False").lower() == "true" LOG_TO_STDOUT = os.getenv("LOG_TO_STDOUT", "False").lower() == "true"
+3
View File
@@ -12,6 +12,9 @@ class TestingConfig(BaseConfig):
# Explicitly set SECRET_KEY for testing # Explicitly set SECRET_KEY for testing
SECRET_KEY = os.getenv("SECRET_KEY", "test-secret-key-for-testing") SECRET_KEY = os.getenv("SECRET_KEY", "test-secret-key-for-testing")
# CA key encryption — use a fixed test key so tests are deterministic
CA_ENCRYPTION_KEY = os.getenv("CA_ENCRYPTION_KEY", "test-ca-encryption-key-fixed-for-tests")
# Use in-memory SQLite for testing # Use in-memory SQLite for testing
SQLALCHEMY_DATABASE_URI = "sqlite:///:memory:" SQLALCHEMY_DATABASE_URI = "sqlite:///:memory:"
SQLALCHEMY_ECHO = False SQLALCHEMY_ECHO = False
+68 -1
View File
@@ -22,7 +22,11 @@ from gatehouse_app.extensions import bcrypt as flask_bcrypt
from gatehouse_app.extensions import redis_client as _redis_client_ref # may be None until app init from gatehouse_app.extensions import redis_client as _redis_client_ref # may be None until app init
from gatehouse_app.models import User, OIDCClient from gatehouse_app.models import User, OIDCClient
from gatehouse_app.models.organization.organization import Organization from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError from gatehouse_app.exceptions.auth_exceptions import (
InvalidCredentialsError,
AccountSuspendedError,
AccountInactiveError,
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Helpers for Redis-backed OIDC pending state # Helpers for Redis-backed OIDC pending state
@@ -343,6 +347,20 @@ def oidc_complete():
user_id = str(gh_session.user_id) user_id = str(gh_session.user_id)
# Check the user is still active (not suspended after session was issued)
from gatehouse_app.models.user.user import User as _User
from gatehouse_app.utils.constants import UserStatus
_complete_user = _User.query.filter_by(id=user_id, deleted_at=None).first()
if not _complete_user or _complete_user.status in (
UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED, UserStatus.INACTIVE
):
return api_response(
success=False,
message="Your account is not active or has been suspended.",
status=403,
error_type="ACCOUNT_SUSPENDED",
)
# Retrieve stashed OIDC params (consume = True removes from Redis atomically) # Retrieve stashed OIDC params (consume = True removes from Redis atomically)
params = _fetch_oidc_params(oidc_session_id, consume=True) params = _fetch_oidc_params(oidc_session_id, consume=True)
if not params: if not params:
@@ -565,6 +583,28 @@ def oidc_authorize():
session["oidc_user_id"] = user_id session["oidc_user_id"] = user_id
logger.debug("[OIDC] User authentication successful: user_id=%s, email=%s", user_id, email) logger.debug("[OIDC] User authentication successful: user_id=%s, email=%s", user_id, email)
except AccountSuspendedError:
logger.debug("[OIDC] User authentication failed: account suspended for email=%s", email)
return _show_login_page(
client_id=client_id,
redirect_uri=redirect_uri,
scope=scope,
state=state,
nonce=nonce,
response_type=response_type,
error="Your account has been suspended. Please contact an administrator.",
)
except AccountInactiveError:
logger.debug("[OIDC] User authentication failed: account inactive for email=%s", email)
return _show_login_page(
client_id=client_id,
redirect_uri=redirect_uri,
scope=scope,
state=state,
nonce=nonce,
response_type=response_type,
error="Your account is not active. Please verify your email.",
)
except InvalidCredentialsError: except InvalidCredentialsError:
logger.debug("[OIDC] User authentication failed: invalid credentials for email=%s", email) logger.debug("[OIDC] User authentication failed: invalid credentials for email=%s", email)
return _show_login_page( return _show_login_page(
@@ -601,6 +641,33 @@ def oidc_authorize():
logger.debug("[OIDC] Redirecting with error: server_error (user not found)") logger.debug("[OIDC] Redirecting with error: server_error (user not found)")
return _redirect_with_error(redirect_uri, "server_error", "User not found", state) return _redirect_with_error(redirect_uri, "server_error", "User not found", state)
# Check account is still active (user could have been suspended after session start)
from gatehouse_app.utils.constants import UserStatus as _UserStatus
if user.status in (_UserStatus.SUSPENDED, _UserStatus.COMPLIANCE_SUSPENDED):
session.pop("oidc_user_id", None) # clear stale session
logger.debug("[OIDC] User is suspended, clearing session and showing login error: user_id=%s", user_id)
return _show_login_page(
client_id=client_id,
redirect_uri=redirect_uri,
scope=scope,
state=state,
nonce=nonce,
response_type=response_type,
error="Your account has been suspended. Please contact an administrator.",
)
if user.status == _UserStatus.INACTIVE:
session.pop("oidc_user_id", None)
logger.debug("[OIDC] User is inactive, clearing session and showing login error: user_id=%s", user_id)
return _show_login_page(
client_id=client_id,
redirect_uri=redirect_uri,
scope=scope,
state=state,
nonce=nonce,
response_type=response_type,
error="Your account is not active. Please verify your email.",
)
logger.debug("[OIDC] Generating authorization code...") logger.debug("[OIDC] Generating authorization code...")
logger.debug("[OIDC] Authorization code params: client_id=%s, user_id=%s, redirect_uri=%s", client_id, user_id, redirect_uri) logger.debug("[OIDC] Authorization code params: client_id=%s, user_id=%s, redirect_uri=%s", client_id, user_id, redirect_uri)
logger.debug("[OIDC] Authorization code params: scopes=%s, state=%s, nonce=%s", valid_scopes, state, nonce) logger.debug("[OIDC] Authorization code params: scopes=%s, state=%s, nonce=%s", valid_scopes, state, nonce)
+81 -2
View File
@@ -4,6 +4,7 @@ import logging
from flask import request, session, g, jsonify, current_app from flask import request, session, g, jsonify, current_app
from marshmallow import ValidationError from marshmallow import ValidationError
from gatehouse_app.api.v1 import api_v1_bp from gatehouse_app.api.v1 import api_v1_bp
from gatehouse_app.extensions import limiter
from gatehouse_app.utils.response import api_response from gatehouse_app.utils.response import api_response
from gatehouse_app.schemas.auth_schema import ( from gatehouse_app.schemas.auth_schema import (
RegisterSchema, RegisterSchema,
@@ -32,6 +33,7 @@ from gatehouse_app.exceptions.validation_exceptions import ConflictError, NotFou
@api_v1_bp.route("/auth/register", methods=["POST"]) @api_v1_bp.route("/auth/register", methods=["POST"])
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_REGISTER"])
def register(): def register():
""" """
Register a new user. Register a new user.
@@ -135,6 +137,7 @@ def register():
@api_v1_bp.route("/auth/login", methods=["POST"]) @api_v1_bp.route("/auth/login", methods=["POST"])
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_LOGIN"])
def login(): def login():
""" """
Login user. Login user.
@@ -325,8 +328,13 @@ def get_current_user():
data={ data={
"user": user.to_dict(), "user": user.to_dict(),
"organizations": [ "organizations": [
{"id": org.id, "name": org.name, "slug": org.slug} {
for org in user.get_organizations() "id": membership.organization.id,
"name": membership.organization.name,
"slug": membership.organization.slug,
"role": membership.role,
}
for membership in user.organization_memberships
], ],
}, },
message="User retrieved successfully", message="User retrieved successfully",
@@ -478,6 +486,7 @@ def verify_totp_enrollment():
@api_v1_bp.route("/auth/totp/verify", methods=["POST"]) @api_v1_bp.route("/auth/totp/verify", methods=["POST"])
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_TOTP_VERIFY"])
def verify_totp(): def verify_totp():
""" """
Verify TOTP code during login. Verify TOTP code during login.
@@ -520,6 +529,18 @@ def verify_totp():
error_type="AUTHENTICATION_ERROR", error_type="AUTHENTICATION_ERROR",
) )
# Check account suspension before completing TOTP verification
from gatehouse_app.utils.constants import UserStatus
if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
session.pop("totp_pending_user_id", None)
session.pop("webauthn_pending_user_id", None)
return api_response(
success=False,
message="Account is suspended. Contact an administrator.",
status=403,
error_type="ACCOUNT_SUSPENDED",
)
# Verify TOTP code # Verify TOTP code
AuthService.authenticate_with_totp( AuthService.authenticate_with_totp(
user, user,
@@ -909,6 +930,17 @@ def begin_webauthn_login():
error_type="NOT_FOUND", error_type="NOT_FOUND",
) )
# Check account suspension before proceeding
from gatehouse_app.utils.constants import UserStatus
if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
logger.warning(f"WebAuthn login begin - suspended account attempt: {user.email}")
return api_response(
success=False,
message="Account is suspended. Contact an administrator.",
status=403,
error_type="ACCOUNT_SUSPENDED",
)
# Check if user has any WebAuthn credentials # Check if user has any WebAuthn credentials
if not user.has_webauthn_enabled(): if not user.has_webauthn_enabled():
logger.warning(f"WebAuthn login begin - no credentials for user: {user.email}") logger.warning(f"WebAuthn login begin - no credentials for user: {user.email}")
@@ -992,6 +1024,18 @@ def complete_webauthn_login():
error_type="AUTHENTICATION_ERROR", error_type="AUTHENTICATION_ERROR",
) )
# Check account suspension before completing login
from gatehouse_app.utils.constants import UserStatus
if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
session.pop("webauthn_pending_user_id", None)
logger.warning(f"WebAuthn login complete - suspended account attempt: {user.email}")
return api_response(
success=False,
message="Account is suspended. Contact an administrator.",
status=403,
error_type="ACCOUNT_SUSPENDED",
)
# Extract challenge from client data # Extract challenge from client data
client_data = data.get("response", {}).get("clientDataJSON", "") client_data = data.get("response", {}).get("clientDataJSON", "")
@@ -1129,6 +1173,19 @@ def delete_webauthn_credential(credential_id):
""" """
user = g.current_user user = g.current_user
# First check that the specific credential actually belongs to this user.
# Only then check whether it is the last one — otherwise a user with zero
# credentials gets a misleading "Cannot delete the last passkey" error
# instead of a 404.
credential_exists = WebAuthnService.credential_belongs_to_user(credential_id, user)
if not credential_exists:
return api_response(
success=False,
message="Credential not found",
status=404,
error_type="NOT_FOUND",
)
# Check if this is the last credential # Check if this is the last credential
credential_count = user.get_webauthn_credential_count() credential_count = user.get_webauthn_credential_count()
if credential_count <= 1: if credential_count <= 1:
@@ -1238,6 +1295,7 @@ _pw_logger = logging.getLogger(__name__)
@api_v1_bp.route("/auth/forgot-password", methods=["POST"]) @api_v1_bp.route("/auth/forgot-password", methods=["POST"])
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_FORGOT_PASSWORD"])
def forgot_password(): def forgot_password():
"""Request a password reset email. """Request a password reset email.
@@ -1294,6 +1352,7 @@ def forgot_password():
@api_v1_bp.route("/auth/reset-password", methods=["POST"]) @api_v1_bp.route("/auth/reset-password", methods=["POST"])
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_RESET_PASSWORD"])
def reset_password(): def reset_password():
"""Reset a user's password using a reset token. """Reset a user's password using a reset token.
@@ -1601,11 +1660,31 @@ def get_token():
302: Redirect to ``<redirect>?token=<token>`` 302: Redirect to ``<redirect>?token=<token>``
""" """
from flask import redirect as flask_redirect from flask import redirect as flask_redirect
from urllib.parse import urlparse
token = g.current_session.token token = g.current_session.token
redirect_url = request.args.get("redirect", "").strip() redirect_url = request.args.get("redirect", "").strip()
if redirect_url: if redirect_url:
# Validate redirect URL against allowed origins to prevent open-redirect
# token exfiltration attacks (CWE-601).
allowed_origins = set(current_app.config.get("CORS_ORIGINS", []))
frontend_url = current_app.config.get("FRONTEND_URL", "")
if frontend_url:
parsed = urlparse(frontend_url)
allowed_origins.add(f"{parsed.scheme}://{parsed.netloc}")
parsed_redirect = urlparse(redirect_url)
redirect_origin = f"{parsed_redirect.scheme}://{parsed_redirect.netloc}"
if redirect_origin not in allowed_origins:
return api_response(
success=False,
message="Redirect URL is not allowed.",
status=400,
error_type="INVALID_REDIRECT",
)
sep = "&" if "?" in redirect_url else "?" sep = "&" if "?" in redirect_url else "?"
return flask_redirect(f"{redirect_url}{sep}token={token}", code=302) return flask_redirect(f"{redirect_url}{sep}token={token}", code=302)
+216 -4
View File
@@ -226,6 +226,10 @@ def delete_organization(org_id):
""" """
Delete organization (soft delete). Delete organization (soft delete).
The owner may only delete the organization if they are the *sole* remaining
member. If other active members exist they must first transfer ownership
(or remove all other members) before deleting the organization.
Args: Args:
org_id: Organization ID org_id: Organization ID
@@ -234,9 +238,26 @@ def delete_organization(org_id):
401: Not authenticated 401: Not authenticated
403: Not the owner 403: Not the owner
404: Organization not found 404: Organization not found
409: Organization still has other members transfer ownership first
""" """
org = OrganizationService.get_organization_by_id(org_id) org = OrganizationService.get_organization_by_id(org_id)
# Guard: block deletion while non-owner members still exist so ownership
# can be transferred rather than silently orphaning them.
active_member_count = org.get_member_count()
if active_member_count > 1:
return api_response(
success=False,
message=(
"This organization still has other members. "
"Please transfer ownership to another member or remove all "
"other members before deleting the organization."
),
status=409,
error_type="ORG_HAS_MEMBERS",
error_details={"member_count": active_member_count},
)
OrganizationService.delete_organization( OrganizationService.delete_organization(
org=org, org=org,
user_id=g.current_user.id, user_id=g.current_user.id,
@@ -446,6 +467,152 @@ def update_member_role(org_id, user_id):
) )
@api_v1_bp.route("/organizations/<org_id>/transfer-ownership", methods=["POST"])
@login_required
@full_access_required
def transfer_organization_ownership(org_id):
"""Transfer organization ownership from the current user to another member.
Only the current OWNER of the organization may call this endpoint.
The caller will be demoted to ADMIN and the target user will be promoted to OWNER.
Request body:
new_owner_user_id (str): UUID of the member to promote to OWNER.
Returns:
200: Ownership transferred successfully
400: Validation error / missing fields
403: Caller is not the OWNER of this org
404: Organization or target member not found
409: Target is already the OWNER
"""
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.utils.constants import OrganizationRole, AuditAction
from gatehouse_app.services.audit_service import AuditService
caller = g.current_user
data = request.get_json() or {}
new_owner_user_id = data.get("new_owner_user_id")
if not new_owner_user_id:
return api_response(
success=False,
message="new_owner_user_id is required",
status=400,
error_type="VALIDATION_ERROR",
)
if str(new_owner_user_id) == str(caller.id):
return api_response(
success=False,
message="You are already the owner of this organization.",
status=409,
error_type="CONFLICT",
)
# Fetch org (raises NotFound internally)
org = OrganizationService.get_organization_by_id(org_id)
# Confirm caller is the current OWNER
caller_membership = OrganizationMember.query.filter_by(
organization_id=org.id,
user_id=caller.id,
deleted_at=None,
).first()
if not caller_membership or caller_membership.role != OrganizationRole.OWNER:
return api_response(
success=False,
message="Only the organization owner can transfer ownership.",
status=403,
error_type="AUTHORIZATION_ERROR",
)
# Verify the target is an active member
target_membership = OrganizationMember.query.filter_by(
organization_id=org.id,
user_id=new_owner_user_id,
deleted_at=None,
).first()
if not target_membership:
return api_response(
success=False,
message="Target user is not a member of this organization.",
status=404,
error_type="NOT_FOUND",
)
if target_membership.role == OrganizationRole.OWNER:
return api_response(
success=False,
message="Target user is already the owner.",
status=409,
error_type="CONFLICT",
)
# ── Atomic role swap ─────────────────────────────────────────────────────
# Demote caller → ADMIN, promote target → OWNER.
# Both updates go through OrganizationService so all hooks/auditing fire.
try:
demoted = OrganizationService.update_member_role(
org=org,
user_id=str(caller.id),
new_role=OrganizationRole.ADMIN,
updater_id=str(caller.id),
)
promoted = OrganizationService.update_member_role(
org=org,
user_id=str(new_owner_user_id),
new_role=OrganizationRole.OWNER,
updater_id=str(caller.id),
)
except Exception as exc:
from gatehouse_app.extensions import db as _db
_db.session.rollback()
return api_response(
success=False,
message=f"Failed to transfer ownership: {exc}",
status=500,
error_type="SERVER_ERROR",
)
AuditService.log_action(
action=AuditAction.ORG_OWNERSHIP_TRANSFERRED,
user_id=caller.id,
organization_id=org.id,
resource_type="organization",
resource_id=str(org.id),
description=(
f"Ownership of '{org.name}' transferred from {caller.email} "
f"to {target_membership.user.email if target_membership.user else new_owner_user_id}"
),
metadata={
"previous_owner_id": str(caller.id),
"previous_owner_email": caller.email,
"new_owner_id": str(new_owner_user_id),
"new_owner_email": (
target_membership.user.email if target_membership.user else None
),
},
)
def _member_dict(m):
d = m.to_dict()
if m.user:
d["user"] = m.user.to_dict()
return d
return api_response(
data={
"previous_owner": _member_dict(demoted),
"new_owner": _member_dict(promoted),
},
message=(
f"Ownership of '{org.name}' successfully transferred to "
f"{target_membership.user.email if target_membership.user else new_owner_user_id}."
),
)
@api_v1_bp.route("/organizations/<org_id>/audit-logs", methods=["GET"]) @api_v1_bp.route("/organizations/<org_id>/audit-logs", methods=["GET"])
@login_required @login_required
@require_admin @require_admin
@@ -756,10 +923,30 @@ def accept_invite(token):
inviter_id=invite.invited_by_id, inviter_id=invite.invited_by_id,
) )
except Exception: except Exception:
pass # Already a member is fine from gatehouse_app.extensions import db
db.session.rollback() # Clear broken transaction so invite.accept() can commit
invite.accept() invite.accept()
has_webauthn = user.has_webauthn_enabled()
has_totp = user.has_totp_enabled()
if has_webauthn:
from flask import session as flask_session
flask_session["webauthn_pending_user_id"] = user.id
return api_response(
data={"requires_webauthn": True},
message="Passkey verification required. Please use your passkey to complete sign-in.",
)
if has_totp:
from flask import session as flask_session
flask_session["totp_pending_user_id"] = user.id
return api_response(
data={"requires_totp": True},
message="TOTP code required. Please enter your 6-digit code from your authenticator app.",
)
user_session = AuthService.create_session(user) user_session = AuthService.create_session(user)
return api_response( return api_response(
@@ -1379,6 +1566,7 @@ def create_org_ca(org_id):
from gatehouse_app.models.ssh_ca.ca import CA, KeyType from gatehouse_app.models.ssh_ca.ca import CA, KeyType
from gatehouse_app.models.organization.organization import Organization from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.utils.crypto import compute_ssh_fingerprint from gatehouse_app.utils.crypto import compute_ssh_fingerprint
from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key
from marshmallow import Schema, fields as ma_fields, validate, ValidationError as MaValidationError from marshmallow import Schema, fields as ma_fields, validate, ValidationError as MaValidationError
from sshkey_tools.keys import Ed25519PrivateKey, RsaPrivateKey, EcdsaPrivateKey from sshkey_tools.keys import Ed25519PrivateKey, RsaPrivateKey, EcdsaPrivateKey
@@ -1448,13 +1636,16 @@ def create_org_ca(org_id):
public_key_str = private_key_obj.public_key.to_string() public_key_str = private_key_obj.public_key.to_string()
fingerprint = compute_ssh_fingerprint(public_key_str) fingerprint = compute_ssh_fingerprint(public_key_str)
# Encrypt the private key before storing in the database
encrypted_private_key = encrypt_ca_key(private_key_pem)
ca = CA( ca = CA(
organization_id=org_id, organization_id=org_id,
name=data["name"], name=data["name"],
description=data["description"], description=data["description"],
ca_type=CaType(ca_type_val), ca_type=CaType(ca_type_val),
key_type=KeyType(key_type), key_type=KeyType(key_type),
private_key=private_key_pem, private_key=encrypted_private_key,
public_key=public_key_str, public_key=public_key_str,
fingerprint=fingerprint, fingerprint=fingerprint,
default_cert_validity_hours=data["default_cert_validity_hours"], default_cert_validity_hours=data["default_cert_validity_hours"],
@@ -1462,7 +1653,24 @@ def create_org_ca(org_id):
is_active=True, is_active=True,
) )
db.session.add(ca) db.session.add(ca)
db.session.commit() try:
db.session.commit()
except Exception as commit_exc:
db.session.rollback()
# Surface unique-constraint violations (soft-deleted record with same name) as a
# user-friendly 400 instead of a 500.
exc_str = str(commit_exc).lower()
if "uix_org_ca_name" in exc_str or "unique" in exc_str:
return api_response(
success=False,
message=(
"A CA with that name already exists in this organization "
"(it may have been recently deleted — choose a different name)."
),
status=400,
error_type="DUPLICATE_NAME",
)
raise
return api_response( return api_response(
data={"ca": ca.to_dict()}, data={"ca": ca.to_dict()},
@@ -1570,6 +1778,7 @@ def rotate_org_ca(org_id, ca_id):
from gatehouse_app.models.ssh_ca.ca import CA, KeyType from gatehouse_app.models.ssh_ca.ca import CA, KeyType
from gatehouse_app.models.organization.organization import Organization from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.utils.crypto import compute_ssh_fingerprint from gatehouse_app.utils.crypto import compute_ssh_fingerprint
from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key
from gatehouse_app.utils.constants import AuditAction from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.models import AuditLog from gatehouse_app.models import AuditLog
from sshkey_tools.keys import Ed25519PrivateKey, RsaPrivateKey, EcdsaPrivateKey from sshkey_tools.keys import Ed25519PrivateKey, RsaPrivateKey, EcdsaPrivateKey
@@ -1609,8 +1818,11 @@ def rotate_org_ca(org_id, ca_id):
new_public_key = private_key_obj.public_key.to_string() new_public_key = private_key_obj.public_key.to_string()
new_fingerprint = compute_ssh_fingerprint(new_public_key) new_fingerprint = compute_ssh_fingerprint(new_public_key)
# Encrypt the new private key before storing
encrypted_new_private_key = encrypt_ca_key(new_private_key)
ca.rotate_key( ca.rotate_key(
new_private_key=new_private_key, new_private_key=encrypted_new_private_key,
new_public_key=new_public_key, new_public_key=new_public_key,
new_fingerprint=new_fingerprint, new_fingerprint=new_fingerprint,
reason=reason, reason=reason,
+89 -8
View File
@@ -15,6 +15,7 @@ from gatehouse_app.exceptions import (
) )
from gatehouse_app.utils.constants import AuditAction from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.models import AuditLog from gatehouse_app.models import AuditLog
from gatehouse_app.models.ssh_ca.certificate_audit_log import CertificateAuditLog
from gatehouse_app.utils.decorators import login_required from gatehouse_app.utils.decorators import login_required
from gatehouse_app.utils.response import api_response from gatehouse_app.utils.response import api_response
@@ -78,11 +79,16 @@ def _get_or_create_system_ca():
with open(pub_key_path) as f: with open(pub_key_path) as f:
pub_key = f.read().strip() pub_key = f.read().strip()
# Load private key for the record (stored but not actually used for signing here) # Load private key for the record (encrypt before storing in DB)
priv_key = "" priv_key = ""
if os.path.exists(key_path): if os.path.exists(key_path):
with open(key_path) as f: with open(key_path) as f:
priv_key = f.read() raw_priv_key = f.read()
try:
from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key
priv_key = encrypt_ca_key(raw_priv_key)
except Exception:
priv_key = raw_priv_key # fallback: store as-is if encryption unavailable
fingerprint = compute_ssh_fingerprint(pub_key) fingerprint = compute_ssh_fingerprint(pub_key)
@@ -120,7 +126,7 @@ def _get_or_create_system_ca():
return None return None
def _persist_certificate(user_id, ssh_key_id, ca, signing_response, request_ip=None, cert_type_str='user'): def _persist_certificate(user_id, ssh_key_id, ca, signing_response, request_ip=None, cert_type_str='user', cert_identity=None):
"""Save a signed certificate to the ssh_certificates table. """Save a signed certificate to the ssh_certificates table.
Args: Args:
@@ -130,6 +136,8 @@ def _persist_certificate(user_id, ssh_key_id, ca, signing_response, request_ip=N
signing_response: SSHCertificateSigningResponse signing_response: SSHCertificateSigningResponse
request_ip: Client IP address request_ip: Client IP address
cert_type_str: 'user' or 'host' (from the sign request) cert_type_str: 'user' or 'host' (from the sign request)
cert_identity: Rich OpenSSH key_id string (e.g. "user@host (Name) [org:slug]").
Falls back to str(ssh_key_id) when not provided.
Returns: Returns:
SSHCertificate instance or None if persistence failed SSHCertificate instance or None if persistence failed
@@ -153,7 +161,7 @@ def _persist_certificate(user_id, ssh_key_id, ca, signing_response, request_ip=N
ssh_key_id=ssh_key_id, ssh_key_id=ssh_key_id,
certificate=signing_response.certificate, certificate=signing_response.certificate,
serial=signing_response.serial, serial=signing_response.serial,
key_id=str(ssh_key_id), key_id=cert_identity or str(ssh_key_id),
cert_type=resolved_cert_type, cert_type=resolved_cert_type,
principals=signing_response.principals, principals=signing_response.principals,
valid_after=signing_response.valid_after, valid_after=signing_response.valid_after,
@@ -465,7 +473,7 @@ def sign_certificate():
# ── Check account suspension ────────────────────────────────────────────── # ── Check account suspension ──────────────────────────────────────────────
from gatehouse_app.utils.constants import UserStatus from gatehouse_app.utils.constants import UserStatus
if user.status == UserStatus.SUSPENDED: if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
return api_response( return api_response(
success=False, success=False,
message="Your account is suspended. Contact an administrator.", message="Your account is suspended. Contact an administrator.",
@@ -482,6 +490,18 @@ def sign_certificate():
key_id = data.get('key_id') or data.get('cert_id') key_id = data.get('key_id') or data.get('cert_id')
expiry_hours = data.get('expiry_hours') expiry_hours = data.get('expiry_hours')
# ── Log the request ───────────────────────────────────────────────────────
AuditLog.log(
action=AuditAction.SSH_CERT_REQUESTED,
user_id=user_id,
resource_type='SSHCertificate',
ip_address=request.remote_addr,
description=(
f'{user.email} requested a certificate'
+ (f' for principals: {", ".join(requested_principals)}' if requested_principals else '')
),
)
# ── Resolve which principals the user is allowed to use ────────────────── # ── Resolve which principals the user is allowed to use ──────────────────
from gatehouse_app.models.organization.organization_member import OrganizationMember from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.models.organization.principal import Principal, PrincipalMembership from gatehouse_app.models.organization.principal import Principal, PrincipalMembership
@@ -601,11 +621,24 @@ def sign_certificate():
else: else:
policy_extensions = None # let signing service use its own defaults policy_extensions = None # let signing service use its own defaults
# ── Build rich key_id identity for the OpenSSH cert ─────────────────────
# This appears in `ssh-keygen -L -f cert.pub` as the Key ID field and
# is stored in the DB cert record so it's auditable.
org_slugs = sorted({
om.organization.slug
for om in memberships
if om.organization and om.organization.deleted_at is None
and getattr(om.organization, 'slug', None)
})
org_slug = org_slugs[0] if org_slugs else "unknown"
full_name = getattr(user, 'full_name', None) or getattr(user, 'name', None) or "unknown"
cert_identity = f"{user.email} ({full_name}) [org:{org_slug}]"
signing_request = SSHCertificateSigningRequest( signing_request = SSHCertificateSigningRequest(
ssh_public_key=ssh_key.payload, ssh_public_key=ssh_key.payload,
principals=principals, principals=principals,
cert_type=cert_type, cert_type=cert_type,
key_id=key_id, key_id=cert_identity,
expiry_hours=int(expiry_hours) if expiry_hours else None, expiry_hours=int(expiry_hours) if expiry_hours else None,
extensions=policy_extensions, extensions=policy_extensions,
) )
@@ -620,7 +653,11 @@ def sign_certificate():
) )
try: try:
response = ssh_ca_service.sign_certificate(signing_request, ca_private_key=db_ca.private_key) from gatehouse_app.utils.ca_key_encryption import decrypt_ca_key
ca_private_key_pem = decrypt_ca_key(db_ca.private_key)
response = ssh_ca_service.sign_certificate(
signing_request, ca_private_key=ca_private_key_pem, ca_obj=db_ca
)
except SSHCertificateError as e: except SSHCertificateError as e:
AuditLog.log( AuditLog.log(
action=AuditAction.SSH_CERT_FAILED, action=AuditAction.SSH_CERT_FAILED,
@@ -649,6 +686,7 @@ def sign_certificate():
signing_response=response, signing_response=response,
request_ip=request.remote_addr, request_ip=request.remote_addr,
cert_type_str=cert_type, cert_type_str=cert_type,
cert_identity=cert_identity,
) )
AuditLog.log( AuditLog.log(
@@ -657,9 +695,42 @@ def sign_certificate():
resource_type='SSHCertificate', resource_type='SSHCertificate',
resource_id=cert_record.id if cert_record else key_id, resource_id=cert_record.id if cert_record else key_id,
ip_address=request.remote_addr, ip_address=request.remote_addr,
description=f'Certificate issued for principals: {", ".join(principals)}', description=(
f'Certificate serial={response.serial} issued for {user.email}; '
f'principals: {", ".join(principals)}'
),
extra_data={
'serial': response.serial,
'key_id': cert_identity,
'principals': principals,
'ca_id': str(db_ca.id),
'ssh_key_id': str(key_id),
},
) )
if cert_record:
CertificateAuditLog.log(
certificate_id=cert_record.id,
action='issued',
user_id=user_id,
ip_address=request.remote_addr,
user_agent=request.headers.get('User-Agent'),
message=(
f'Certificate serial={response.serial} issued for {user.email}; '
f'principals: {", ".join(principals)}'
),
extra_data={
'serial': response.serial,
'key_id': cert_identity,
'principals': principals,
'ca_id': str(db_ca.id),
'ssh_key_id': str(key_id),
'valid_after': response.valid_after.isoformat() if response.valid_after else None,
'valid_before': response.valid_before.isoformat() if response.valid_before else None,
},
success=True,
)
result = { result = {
'certificate': response.certificate, 'certificate': response.certificate,
'serial': response.serial, 'serial': response.serial,
@@ -753,6 +824,16 @@ def revoke_certificate(cert_id):
description=f'Revoked: {reason}', description=f'Revoked: {reason}',
) )
CertificateAuditLog.log(
certificate_id=cert_id,
action='revoked',
user_id=user_id,
ip_address=request.remote_addr,
user_agent=request.headers.get('User-Agent'),
message=f'Certificate revoked: {reason}',
success=True,
)
return api_response( return api_response(
success=True, success=True,
message='Certificate revoked successfully', message='Certificate revoked successfully',
+221 -1
View File
@@ -73,11 +73,51 @@ def delete_me():
""" """
Delete current user account (soft delete). Delete current user account (soft delete).
Blocked if the user is the sole owner of any organization that has other
active members they must transfer ownership or dissolve those organizations
first.
Returns: Returns:
200: Account deleted successfully 200: Account deleted successfully
401: Not authenticated 401: Not authenticated
409: User is sole owner of one or more organizations with other members
""" """
UserService.delete_user(g.current_user, soft=True) from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.utils.constants import OrganizationRole
user = g.current_user
# Find orgs where this user is the sole owner AND other members exist.
owned_memberships = OrganizationMember.query.filter_by(
user_id=user.id,
role=OrganizationRole.OWNER,
deleted_at=None,
).all()
blocked_orgs = []
for membership in owned_memberships:
org = membership.organization
if org.deleted_at is not None:
continue
member_count = org.get_member_count()
if member_count > 1:
blocked_orgs.append(org.name)
if blocked_orgs:
names = ", ".join(f'"{n}"' for n in blocked_orgs)
return api_response(
success=False,
message=(
f"You are the sole owner of {len(blocked_orgs)} organization"
f"{'s' if len(blocked_orgs) > 1 else ''}: {names}. "
"Transfer ownership or delete those organizations before deleting your account."
),
status=409,
error_type="USER_IS_SOLE_OWNER",
error_details={"organizations": blocked_orgs},
)
UserService.delete_user(user, soft=True)
return api_response( return api_response(
message="Account deleted successfully", message="Account deleted successfully",
@@ -454,6 +494,31 @@ def admin_suspend_user(user_id):
if not admin_in_shared_org: if not admin_in_shared_org:
return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR") return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR")
# ── Owner protection ──────────────────────────────────────────────────────
# An org owner cannot be suspended until they transfer ownership.
from gatehouse_app.utils.constants import OrganizationRole
owner_memberships = OrganizationMember.query.filter(
OrganizationMember.user_id == target.id,
OrganizationMember.role == OrganizationRole.OWNER,
OrganizationMember.deleted_at == None,
).all()
if owner_memberships:
org_names = [
m.organization.name
for m in owner_memberships
if m.organization and not m.organization.deleted_at
]
return api_response(
success=False,
message=(
f"Cannot suspend an organization owner. "
f"{target.email} is the owner of: {', '.join(org_names)}. "
"Transfer ownership to another member first."
),
status=403,
error_type="OWNER_PROTECTION",
)
if target.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED): if target.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
return api_response(success=False, message="User is already suspended", status=409, error_type="CONFLICT") return api_response(success=False, message="User is already suspended", status=409, error_type="CONFLICT")
@@ -645,3 +710,158 @@ def get_my_memberships():
data={"orgs": orgs_result}, data={"orgs": orgs_result},
message="Memberships retrieved", message="Memberships retrieved",
) )
@api_v1_bp.route("/admin/users/<user_id>/delete", methods=["POST"])
@login_required
@full_access_required
def admin_hard_delete_user(user_id):
"""Permanently delete a user and ALL associated data (hard delete, irreversible).
Required body: {"confirm": true}
Pre-conditions:
- Caller is OWNER or ADMIN of a shared org with the target.
- Cannot delete yourself.
- Target must not be the OWNER of any active organization (transfer first).
Side-effects:
- All active SSH certificates are revoked before deletion.
- The user row and all cascaded rows are hard-deleted from the database.
- An audit log entry is written by the *caller* (so it is not lost with the user).
"""
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.models.user.user import User as _User
from gatehouse_app.extensions import db as _db
from gatehouse_app.utils.constants import UserStatus, AuditAction, OrganizationRole
from gatehouse_app.services.audit_service import AuditService
caller = g.current_user
data = request.get_json() or {}
if not data.get("confirm"):
return api_response(
success=False,
message="Deletion requires explicit confirmation. Send {\"confirm\": true} to proceed.",
status=400,
error_type="CONFIRMATION_REQUIRED",
)
target = _User.query.filter_by(id=user_id).first()
if not target:
return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND")
if target.id == caller.id:
return api_response(
success=False,
message="Cannot delete your own account via this endpoint.",
status=400,
error_type="BAD_REQUEST",
)
# Caller must be OWNER/ADMIN of a shared org.
# Include soft-deleted memberships so that already-soft-deleted users can
# still be hard-deleted by an admin who shared an org with them.
target_org_ids = {m.organization_id for m in target.organization_memberships}
admin_in_shared_org = OrganizationMember.query.filter(
OrganizationMember.user_id == caller.id,
OrganizationMember.organization_id.in_(target_org_ids),
OrganizationMember.role.in_(["OWNER", "ADMIN"]),
OrganizationMember.deleted_at == None,
).first()
if not admin_in_shared_org:
return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR")
# Block deletion if target is an org owner — they must transfer first
owner_memberships = OrganizationMember.query.filter(
OrganizationMember.user_id == target.id,
OrganizationMember.role == OrganizationRole.OWNER,
OrganizationMember.deleted_at == None,
).all()
if owner_memberships:
org_names = [
m.organization.name
for m in owner_memberships
if m.organization and not m.organization.deleted_at
]
return api_response(
success=False,
message=(
f"Cannot delete an organization owner. "
f"{target.email} is the owner of: {', '.join(org_names)}. "
"Transfer ownership to another member first."
),
status=403,
error_type="OWNER_PROTECTION",
)
# ── Collect counts for audit metadata ────────────────────────────────────
from gatehouse_app.models.ssh_ca.ssh_key import SSHKey
from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus
ssh_key_count = SSHKey.query.filter_by(user_id=target.id, deleted_at=None).count()
active_cert_count = SSHCertificate.query.filter_by(
user_id=target.id, revoked=False
).filter(SSHCertificate.deleted_at == None).count()
# ── Revoke all active SSH certificates before deletion ───────────────────
active_certs = SSHCertificate.query.filter_by(
user_id=target.id, revoked=False
).filter(SSHCertificate.deleted_at == None).all()
for cert in active_certs:
try:
cert.revoke("account_deleted")
except Exception:
pass
if active_certs:
try:
_db.session.flush()
except Exception:
pass
# ── Hard delete ───────────────────────────────────────────────────────────
target_email = target.email # capture before deletion
target_id_str = str(target.id)
try:
_db.session.delete(target) # cascades to all child tables
_db.session.flush()
except Exception as exc:
_db.session.rollback()
import logging
logging.getLogger(__name__).error(f"Hard delete failed for {target_id_str}: {exc}")
return api_response(
success=False,
message="Failed to delete user account. Please try again.",
status=500,
error_type="SERVER_ERROR",
)
# ── Audit log (written as the caller so it survives the deletion) ─────────
AuditService.log_action(
action=AuditAction.USER_HARD_DELETE,
user_id=caller.id,
organization_id=admin_in_shared_org.organization_id,
resource_type="user",
resource_id=target_id_str,
description=f"Admin permanently deleted user account: {target_email}",
metadata={
"deleted_user_id": target_id_str,
"deleted_user_email": target_email,
"ssh_keys_deleted": ssh_key_count,
"certs_revoked": active_cert_count,
},
)
_db.session.commit()
return api_response(
message=f"User account {target_email} has been permanently deleted.",
data={
"deleted_user_id": target_id_str,
"deleted_user_email": target_email,
"ssh_keys_deleted": ssh_key_count,
"certs_revoked": active_cert_count,
},
)
+27 -1
View File
@@ -88,6 +88,11 @@ class CA(BaseModel):
rotated_at = db.Column(db.DateTime, nullable=True) rotated_at = db.Column(db.DateTime, nullable=True)
rotation_reason = db.Column(db.String(255), nullable=True) rotation_reason = db.Column(db.String(255), nullable=True)
# Monotonically-increasing serial counter. Every cert this CA issues
# gets the next value so serials are unique, ordered, and auditable.
# Protected by a row-level SELECT … FOR UPDATE in get_next_serial().
next_serial_number = db.Column(db.BigInteger, default=1, nullable=False)
# Relationships # Relationships
organization = db.relationship("Organization", back_populates="cas") organization = db.relationship("Organization", back_populates="cas")
certificates = db.relationship( certificates = db.relationship(
@@ -102,7 +107,6 @@ class CA(BaseModel):
) )
__table_args__ = ( __table_args__ = (
db.UniqueConstraint("organization_id", "name", name="uix_org_ca_name"),
db.Index("idx_ca_org_active", "organization_id", "is_active"), db.Index("idx_ca_org_active", "organization_id", "is_active"),
) )
@@ -162,6 +166,28 @@ class CA(BaseModel):
self.rotation_reason = reason self.rotation_reason = reason
self.save() self.save()
def get_next_serial(self) -> int:
"""Atomically increment and return the next certificate serial number.
Uses a SELECT FOR UPDATE row lock so concurrent requests never
receive the same serial. Must be called inside an active DB
transaction (i.e. before the final session.commit()).
Returns:
int: The serial number to embed in the next certificate.
"""
# Re-fetch this CA row with an exclusive row lock
locked = (
db.session.query(CA)
.with_for_update()
.filter_by(id=self.id)
.one()
)
serial = locked.next_serial_number
locked.next_serial_number = serial + 1
db.session.flush() # write increment; commit happens in the caller
return serial
class CAPermission(BaseModel): class CAPermission(BaseModel):
"""Per-user CA permission model. """Per-user CA permission model.
+25 -1
View File
@@ -77,7 +77,10 @@ class TOTPVerifyEnrollmentSchema(Schema):
class TOTPVerifySchema(Schema): class TOTPVerifySchema(Schema):
"""Schema for TOTP code verification during login.""" """Schema for TOTP code verification during login."""
code = fields.Str(required=True) code = fields.Str(
required=True,
validate=validate.Length(min=1),
)
is_backup_code = fields.Bool(load_default=False) is_backup_code = fields.Bool(load_default=False)
client_timestamp = fields.Int( client_timestamp = fields.Int(
required=False, required=False,
@@ -85,6 +88,27 @@ class TOTPVerifySchema(Schema):
metadata={"description": "Client UTC timestamp in seconds since epoch for TOTP verification"}, metadata={"description": "Client UTC timestamp in seconds since epoch for TOTP verification"},
) )
@validates_schema
def validate_code_format(self, data, **kwargs):
"""Validate code format depending on whether it's a backup code."""
code = data.get("code", "")
is_backup_code = data.get("is_backup_code", False)
if is_backup_code:
# Backup codes are 16 uppercase hex characters
if not code or len(code) != 16 or not all(c in "0123456789ABCDEFabcdef" for c in code):
raise ValidationError(
"Backup code must be a 16-character hexadecimal string.",
field_name="code",
)
else:
# Regular TOTP codes are exactly 6 digits
import re
if not re.match(r"^\d{6}$", code):
raise ValidationError(
"Code must be a 6-digit number.",
field_name="code",
)
class TOTPDisableSchema(Schema): class TOTPDisableSchema(Schema):
"""Schema for disabling TOTP.""" """Schema for disabling TOTP."""
+1 -1
View File
@@ -59,7 +59,7 @@ class AuditService:
ip_address=ip_address, ip_address=ip_address,
user_agent=user_agent, user_agent=user_agent,
request_id=request_id, request_id=request_id,
metadata=metadata, extra_data=metadata,
description=description, description=description,
success=success, success=success,
error_message=error_message, error_message=error_message,
+32 -1
View File
@@ -102,7 +102,7 @@ class AuthService:
if current_app.config.get('ENV') == 'development': if current_app.config.get('ENV') == 'development':
logger.debug(f"[Auth] Account status: user_id={user.id}, status={user.status}") logger.debug(f"[Auth] Account status: user_id={user.id}, status={user.status}")
if user.status == UserStatus.SUSPENDED: if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
raise AccountSuspendedError() raise AccountSuspendedError()
if user.status == UserStatus.INACTIVE: if user.status == UserStatus.INACTIVE:
raise AccountInactiveError() raise AccountInactiveError()
@@ -210,6 +210,22 @@ class AuthService:
auth_method.password_hash = bcrypt.generate_password_hash(new_password).decode("utf-8") auth_method.password_hash = bcrypt.generate_password_hash(new_password).decode("utf-8")
db.session.commit() db.session.commit()
# Invalidate all other sessions so that if an attacker had a valid
# session token, changing the password actually locks them out.
# The current request's session (if any) is preserved so the user
# doesn't have to log in again immediately.
from flask import g as flask_g
current_session_id = getattr(flask_g, "current_session", None)
current_session_id = current_session_id.id if current_session_id else None
sessions_to_revoke = Session.query.filter(
Session.user_id == user.id,
Session.revoked_at == None, # noqa: E711
).all()
for sess in sessions_to_revoke:
if sess.id != current_session_id:
sess.revoke(reason="Password changed")
db.session.commit()
# Log password change # Log password change
AuditService.log_action( AuditService.log_action(
action=AuditAction.PASSWORD_CHANGE, action=AuditAction.PASSWORD_CHANGE,
@@ -482,9 +498,24 @@ class AuthService:
if not secret: if not secret:
raise InvalidCredentialsError("TOTP secret not found") raise InvalidCredentialsError("TOTP secret not found")
# Replay-attack prevention: reject codes that have already been
# accepted within the current validity window.
if TOTPService.is_code_already_used(str(user.id), code):
AuditService.log_action(
action=AuditAction.TOTP_VERIFY_FAILED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="TOTP code replay attempt detected",
)
raise InvalidCredentialsError("Invalid TOTP code")
is_valid = TOTPService.verify_code(secret, code, client_utc_timestamp=client_utc_timestamp) is_valid = TOTPService.verify_code(secret, code, client_utc_timestamp=client_utc_timestamp)
if is_valid: if is_valid:
# Mark this code as used to prevent replay within the validity window
TOTPService.mark_code_used(str(user.id), code)
auth_method.last_used_at = datetime.now(timezone.utc) auth_method.last_used_at = datetime.now(timezone.utc)
db.session.commit() db.session.commit()
@@ -736,9 +736,14 @@ class ExternalAuthService:
400, 400,
) )
# Generate PKCE # Generate PKCE — skip for confidential clients (Google, Microsoft) that use a
code_verifier = secrets.token_urlsafe(32) # client_secret. Sending code_challenge to Microsoft causes it to enforce PKCE on
code_challenge = cls._compute_s256_challenge(code_verifier) # the token exchange, which then fails. Matches the behaviour of initiate_login_flow.
code_verifier = None
code_challenge = None
if provider_type_str not in ('google', 'microsoft'):
code_verifier = secrets.token_urlsafe(32)
code_challenge = cls._compute_s256_challenge(code_verifier)
# Create OAuth state # Create OAuth state
state = OAuthState.create_state( state = OAuthState.create_state(
+20 -2
View File
@@ -188,11 +188,10 @@ class OrganizationService:
Raises: Raises:
ConflictError: If user is already a member ConflictError: If user is already a member
""" """
# Check if already a member # Check if already a member (active or soft-deleted — both blocked by DB unique constraint)
existing = OrganizationMember.query.filter_by( existing = OrganizationMember.query.filter_by(
user_id=user_id, user_id=user_id,
organization_id=org.id, organization_id=org.id,
deleted_at=None,
).first() ).first()
# Development-only debug logging for membership validation # Development-only debug logging for membership validation
@@ -200,6 +199,25 @@ class OrganizationService:
logger.debug(f"[Org] Member check: org_id={org.id}, user_id={user_id}, already_member={existing is not None}") logger.debug(f"[Org] Member check: org_id={org.id}, user_id={user_id}, already_member={existing is not None}")
if existing: if existing:
if existing.deleted_at is not None:
# Reactivate the soft-deleted membership with the new role
existing.deleted_at = None
existing.role = role
existing.invited_by_id = inviter_id
existing.invited_at = datetime.now(timezone.utc)
existing.joined_at = datetime.now(timezone.utc)
existing.save()
AuditService.log_action(
action=AuditAction.ORG_MEMBER_ADD,
user_id=inviter_id,
organization_id=org.id,
resource_type="organization_member",
resource_id=existing.id,
metadata={"added_user_id": user_id, "role": role.value},
description=f"Member re-added to organization with role: {role.value}",
)
return existing
raise ConflictError("User is already a member of this organization") raise ConflictError("User is already a member of this organization")
# Create membership # Create membership
@@ -192,13 +192,19 @@ class SSHCASigningService:
self, self,
signing_request: SSHCertificateSigningRequest, signing_request: SSHCertificateSigningRequest,
ca_private_key: Optional[str] = None, ca_private_key: Optional[str] = None,
) -> SSHCertificateSigningResponse: ca_obj=None,
) -> "SSHCertificateSigningResponse":
"""Sign an SSH certificate. """Sign an SSH certificate.
Args: Args:
signing_request: SSHCertificateSigningRequest instance signing_request: SSHCertificateSigningRequest instance
ca_private_key: CA private key in PEM format. If not provided, ca_private_key: CA private key in PEM format. If not provided,
loaded from config (ca_key_path or SSH_CA_PRIVATE_KEY env var) loaded from config (ca_key_path or SSH_CA_PRIVATE_KEY env var)
ca_obj: Optional CA model instance. When supplied its monotonic
serial counter is incremented atomically (SELECT FOR UPDATE)
and the resulting integer is embedded in the certificate's
serial field. This ensures every issued cert has a unique,
ordered, auditable serial number.
Returns: Returns:
SSHCertificateSigningResponse with signed certificate SSHCertificateSigningResponse with signed certificate
@@ -245,7 +251,8 @@ class SSHCASigningService:
valid_before = now + timedelta(hours=expiry_hours) valid_before = now + timedelta(hours=expiry_hours)
# Set certificate fields # Set certificate fields
cert_type = 1 if signing_request.cert_type == "user" else 0 # sshkey-tools: user=1, host=2 (not 0)
cert_type = 1 if signing_request.cert_type == "user" else 2
certificate.fields.cert_type = cert_type certificate.fields.cert_type = cert_type
certificate.fields.key_id = signing_request.key_id certificate.fields.key_id = signing_request.key_id
@@ -253,6 +260,19 @@ class SSHCASigningService:
certificate.fields.valid_after = now certificate.fields.valid_after = now
certificate.fields.valid_before = valid_before certificate.fields.valid_before = valid_before
# ── Serial number ────────────────────────────────────────────────
# If a CA object is provided, use its monotonic counter so every
# certificate gets a unique, ordered, auditable serial. The
# counter increment is flushed inside get_next_serial(); the
# caller's commit() persists it atomically with the cert record.
if ca_obj is not None:
assigned_serial = ca_obj.get_next_serial()
certificate.fields.serial = assigned_serial
self.logger.debug(
f"Assigned serial {assigned_serial} from CA {ca_obj.id}"
)
# ─────────────────────────────────────────────────────────────────
# Set extensions — prefer policy-provided list, fall back to standard set # Set extensions — prefer policy-provided list, fall back to standard set
extensions = signing_request.extensions extensions = signing_request.extensions
if not extensions: if not extensions:
@@ -276,8 +296,13 @@ class SSHCASigningService:
self.logger.error(f"Certificate verification failed: {str(e)}") self.logger.error(f"Certificate verification failed: {str(e)}")
raise SSHCASigningError(f"Certificate verification failed: {str(e)}") raise SSHCASigningError(f"Certificate verification failed: {str(e)}")
# Extract serial from certificate # Extract serial from certificate — use the integer we assigned
serial = str(certificate.fields.serial).split(":")[-1].strip() if hasattr(certificate.fields.serial, '__str__') else str(certificate.fields.serial) # when ca_obj was provided, otherwise fall back to whatever the
# library generated.
if ca_obj is not None:
serial = str(assigned_serial)
else:
serial = str(certificate.fields.serial).split(":")[-1].strip() if hasattr(certificate.fields.serial, '__str__') else str(certificate.fields.serial)
# Build response # Build response
cert_string = certificate.to_string() cert_string = certificate.to_string()
+41
View File
@@ -11,10 +11,51 @@ from gatehouse_app.extensions import bcrypt
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# TOTP codes are valid for at most (2*window + 1) * 30s steps.
# With window=1 that's 3 steps = 90 seconds. We use a slightly
# generous TTL of 95 seconds to account for clock skew at boundaries.
_TOTP_USED_CODE_TTL = 95
class TOTPService: class TOTPService:
"""Service for TOTP operations.""" """Service for TOTP operations."""
# ------------------------------------------------------------------
# Replay-attack prevention helpers
# ------------------------------------------------------------------
@staticmethod
def _used_key(user_id: str, code: str) -> str:
return f"totp:used:{user_id}:{code}"
@staticmethod
def is_code_already_used(user_id: str, code: str) -> bool:
"""Return True if *code* has already been accepted for *user_id*
within the current validity window (prevents replay attacks)."""
try:
from gatehouse_app.extensions import redis_client
if redis_client is None:
return False
return redis_client.exists(TOTPService._used_key(user_id, code)) == 1
except Exception:
logger.warning("Redis unavailable for TOTP replay check; allowing code")
return False
@staticmethod
def mark_code_used(user_id: str, code: str) -> None:
"""Record *code* as consumed for *user_id* so it cannot be reused."""
try:
from gatehouse_app.extensions import redis_client
if redis_client is None:
return
redis_client.setex(
TOTPService._used_key(user_id, code),
_TOTP_USED_CODE_TTL,
"1",
)
except Exception:
logger.warning("Redis unavailable; TOTP used-code not recorded")
@staticmethod @staticmethod
def generate_secret() -> str: def generate_secret() -> str:
""" """
@@ -642,6 +642,26 @@ class WebAuthnService:
return True return True
@classmethod
def credential_belongs_to_user(cls, credential_id: str, user: User) -> bool:
"""Check whether *credential_id* exists and belongs to *user*.
Args:
credential_id: The credential ID to look up
user: User instance
Returns:
True if the credential exists and belongs to this user, False otherwise.
"""
auth_method = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.WEBAUTHN,
deleted_at=None,
).first()
if not auth_method or not auth_method.provider_data:
return False
return auth_method.provider_data.get("credential_id") == credential_id
@classmethod @classmethod
def rename_credential(cls, credential_id: str, user: User, name: str) -> bool: def rename_credential(cls, credential_id: str, user: User, name: str) -> bool:
"""Rename a passkey credential. """Rename a passkey credential.
+206
View File
@@ -0,0 +1,206 @@
"""Encryption helpers for CA private keys stored in the database.
CA private keys are encrypted at rest using Fernet (AES-128-CBC + HMAC-SHA256)
from the ``cryptography`` package. The encryption key is derived from the
``CA_ENCRYPTION_KEY`` environment variable (or ``Flask.config["CA_ENCRYPTION_KEY"]``).
Key derivation
--------------
Fernet requires a URL-safe base64-encoded 32-byte key. We accept any string
from the env and derive the actual Fernet key using SHA-256 so that operators
can supply human-readable secrets without having to pre-encode them.
Envelope format
---------------
Encrypted values are stored as the string::
$fernet$<fernet_token>
The ``$fernet$`` prefix lets the code distinguish already-encrypted values from
legacy plaintext PEM keys so that the migration path is safe and idempotent.
Usage
-----
Encrypt before storing::
from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key
ca.private_key = encrypt_ca_key(private_key_pem)
Decrypt before use::
from gatehouse_app.utils.ca_key_encryption import decrypt_ca_key
plaintext_pem = decrypt_ca_key(ca.private_key)
"""
import base64
import hashlib
import logging
import os
from cryptography.fernet import Fernet, InvalidToken
logger = logging.getLogger(__name__)
# Prefix that marks a stored value as Fernet-encrypted
_FERNET_PREFIX = "$fernet$"
class CAKeyEncryptionError(Exception):
"""Raised when CA key encryption or decryption fails."""
def _get_fernet() -> Fernet:
"""Build a Fernet instance from the configured encryption key.
Looks up ``CA_ENCRYPTION_KEY`` in the environment first, then falls back to
the Flask app config (if a request context is active).
Raises:
CAKeyEncryptionError: if no key is configured or it is the insecure
placeholder value in a production-like environment.
"""
raw_key = os.environ.get("CA_ENCRYPTION_KEY")
if not raw_key:
# Try Flask config if we're inside an app context
try:
from flask import current_app
raw_key = current_app.config.get("CA_ENCRYPTION_KEY")
except RuntimeError:
pass # No app context
if not raw_key:
raise CAKeyEncryptionError(
"CA_ENCRYPTION_KEY is not set. "
"Set this environment variable before starting the application."
)
# Warn loudly when running with the placeholder in a non-test environment
env_name = os.environ.get("FLASK_ENV", "").lower()
if raw_key.startswith("dev-") and env_name not in ("development", "testing", "test"):
logger.warning(
"CA_ENCRYPTION_KEY appears to be a development placeholder. "
"Set a strong random key for production environments."
)
# Derive a 32-byte key from the raw secret via SHA-256, then URL-safe base64
key_bytes = hashlib.sha256(raw_key.encode()).digest()
fernet_key = base64.urlsafe_b64encode(key_bytes)
return Fernet(fernet_key)
def encrypt_ca_key(plaintext_pem: str) -> str:
"""Encrypt a CA private key PEM string.
Idempotent: already-encrypted values are returned unchanged.
Args:
plaintext_pem: CA private key in OpenSSH/PEM format.
Returns:
Encrypted string with ``$fernet$`` prefix, safe for database storage.
Raises:
CAKeyEncryptionError: if the key cannot be encrypted.
"""
if not plaintext_pem:
raise CAKeyEncryptionError("Cannot encrypt an empty key")
# Already encrypted — do not double-encrypt
if plaintext_pem.startswith(_FERNET_PREFIX):
return plaintext_pem
try:
fernet = _get_fernet()
token = fernet.encrypt(plaintext_pem.encode()).decode()
return f"{_FERNET_PREFIX}{token}"
except CAKeyEncryptionError:
raise
except Exception as exc:
raise CAKeyEncryptionError(f"Failed to encrypt CA key: {exc}") from exc
def decrypt_ca_key(stored_value: str) -> str:
"""Decrypt a CA private key retrieved from the database.
Idempotent: plaintext (legacy) values are returned unchanged so that the
system continues to work while a migration encrypts existing rows.
Args:
stored_value: Value from ``CA.private_key`` column.
Returns:
Plaintext PEM string ready for use with ``sshkey_tools``.
Raises:
CAKeyEncryptionError: if decryption fails (wrong key, corrupted data).
"""
if not stored_value:
raise CAKeyEncryptionError("Cannot decrypt an empty value")
# Legacy plaintext key — return as-is
if not stored_value.startswith(_FERNET_PREFIX):
logger.warning(
"CA private key appears to be stored as plaintext. "
"Run the migration to encrypt existing keys."
)
return stored_value
token = stored_value[len(_FERNET_PREFIX):]
try:
fernet = _get_fernet()
return fernet.decrypt(token.encode()).decode()
except InvalidToken as exc:
raise CAKeyEncryptionError(
"CA key decryption failed — the CA_ENCRYPTION_KEY may be incorrect "
"or the stored key is corrupted."
) from exc
except CAKeyEncryptionError:
raise
except Exception as exc:
raise CAKeyEncryptionError(f"Unexpected decryption error: {exc}") from exc
def is_encrypted(stored_value: str) -> bool:
"""Return True if the stored value has the ``$fernet$`` envelope.
Args:
stored_value: Value from ``CA.private_key`` column.
"""
return bool(stored_value and stored_value.startswith(_FERNET_PREFIX))
def reencrypt_ca_key(stored_value: str, old_raw_key: str, new_raw_key: str) -> str:
"""Re-encrypt a CA key with a new encryption key (for key rotation).
Args:
stored_value: Current value from ``CA.private_key`` (may or may not be encrypted).
old_raw_key: The current ``CA_ENCRYPTION_KEY`` value (raw secret string).
new_raw_key: The new ``CA_ENCRYPTION_KEY`` value to encrypt with.
Returns:
New encrypted envelope string.
Raises:
CAKeyEncryptionError: if decryption or re-encryption fails.
"""
# Decrypt with old key
if stored_value.startswith(_FERNET_PREFIX):
token = stored_value[len(_FERNET_PREFIX):]
old_key_bytes = base64.urlsafe_b64encode(hashlib.sha256(old_raw_key.encode()).digest())
try:
plaintext = Fernet(old_key_bytes).decrypt(token.encode()).decode()
except InvalidToken as exc:
raise CAKeyEncryptionError(
"Re-encryption failed: could not decrypt with the old key."
) from exc
else:
# Plaintext
plaintext = stored_value
# Re-encrypt with new key
new_key_bytes = base64.urlsafe_b64encode(hashlib.sha256(new_raw_key.encode()).digest())
try:
token = Fernet(new_key_bytes).encrypt(plaintext.encode()).decode()
return f"{_FERNET_PREFIX}{token}"
except Exception as exc:
raise CAKeyEncryptionError(f"Re-encryption with new key failed: {exc}") from exc
+2
View File
@@ -61,6 +61,7 @@ class AuditAction(str, Enum):
USER_REGISTER = "user.register" USER_REGISTER = "user.register"
USER_UPDATE = "user.update" USER_UPDATE = "user.update"
USER_DELETE = "user.delete" USER_DELETE = "user.delete"
USER_HARD_DELETE = "user.hard_delete"
USER_SUSPEND = "user.suspend" USER_SUSPEND = "user.suspend"
USER_UNSUSPEND = "user.unsuspend" USER_UNSUSPEND = "user.unsuspend"
PASSWORD_CHANGE = "user.password_change" PASSWORD_CHANGE = "user.password_change"
@@ -73,6 +74,7 @@ class AuditAction(str, Enum):
ORG_MEMBER_ADD = "org.member.add" ORG_MEMBER_ADD = "org.member.add"
ORG_MEMBER_REMOVE = "org.member.remove" ORG_MEMBER_REMOVE = "org.member.remove"
ORG_MEMBER_ROLE_CHANGE = "org.member.role_change" ORG_MEMBER_ROLE_CHANGE = "org.member.role_change"
ORG_OWNERSHIP_TRANSFERRED = "org.ownership.transferred"
# Session actions # Session actions
SESSION_CREATE = "session.create" SESSION_CREATE = "session.create"
@@ -0,0 +1,37 @@
"""Add USER_SUSPEND and USER_UNSUSPEND to auditaction enum.
Revision ID: 015_add_user_suspend_audit_actions
Revises: 014_add_dept_cert_policy
Create Date: 2026-03-02
USER_SUSPEND and USER_UNSUSPEND were added to the Python AuditAction enum
but were never synced to the PostgreSQL auditaction type, causing a
DataError (invalid enum value) whenever an admin suspends or unsuspends a user.
"""
from alembic import op
revision = "015_user_suspend_audit"
down_revision = "014_add_dept_cert_policy"
branch_labels = None
depends_on = None
def upgrade():
for val in ("USER_SUSPEND", "USER_UNSUSPEND"):
op.execute(f"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM pg_enum
WHERE enumlabel = '{val}'
AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'auditaction')
) THEN
ALTER TYPE auditaction ADD VALUE '{val}';
END IF;
END$$;
""")
def downgrade():
# PostgreSQL does not support removing enum values; downgrade is a no-op.
pass
@@ -0,0 +1,168 @@
"""Encrypt existing plaintext CA private keys at rest.
Revision ID: 016_encrypt_existing_ca_keys
Revises: 015_add_user_suspend_audit_actions
Create Date: 2026-03-02
All CA private keys created before this migration were stored as plaintext PEM
strings in the ``cas.private_key`` column. This migration detects those rows
(by checking for the absence of the ``$fernet$`` prefix that encrypted values
carry) and re-encrypts them with the key derived from ``CA_ENCRYPTION_KEY``.
The migration is safe to re-run: already-encrypted rows are left untouched.
Prerequisites
-------------
``CA_ENCRYPTION_KEY`` must be set in the environment before running this
migration. The same value must be configured for the running application.
To roll back to plaintext (downgrade):
The ``downgrade()`` function decrypts all rows back to plaintext PEM. This is
provided only for emergency rollback and should not be used in production once
the system has been running with encrypted keys.
"""
import os
import base64
import hashlib
import logging
from alembic import op
import sqlalchemy as sa
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
# Alembic revision identifiers
revision = "016_encrypt_ca_keys"
down_revision = "015_user_suspend_audit"
branch_labels = None
depends_on = None
_FERNET_PREFIX = "$fernet$"
def _get_fernet():
"""Build a Fernet instance from CA_ENCRYPTION_KEY env var."""
from cryptography.fernet import Fernet
raw_key = os.environ.get("CA_ENCRYPTION_KEY")
if not raw_key:
raise RuntimeError(
"CA_ENCRYPTION_KEY environment variable is not set. "
"Set it before running this migration."
)
key_bytes = base64.urlsafe_b64encode(hashlib.sha256(raw_key.encode()).digest())
return Fernet(key_bytes)
def upgrade():
"""Encrypt plaintext CA private keys."""
bind = op.get_bind()
session = Session(bind=bind)
try:
fernet = _get_fernet()
except RuntimeError as exc:
raise RuntimeError(str(exc)) from exc
# Fetch all non-deleted CA rows
rows = session.execute(
sa.text("SELECT id, private_key FROM cas WHERE deleted_at IS NULL")
).fetchall()
encrypted_count = 0
skipped_count = 0
for row in rows:
ca_id, private_key = row[0], row[1]
if not private_key:
logger.warning(f"CA {ca_id} has empty private_key — skipping")
skipped_count += 1
continue
if private_key.startswith(_FERNET_PREFIX):
# Already encrypted
skipped_count += 1
continue
# Encrypt
try:
token = fernet.encrypt(private_key.encode()).decode()
encrypted_value = f"{_FERNET_PREFIX}{token}"
session.execute(
sa.text("UPDATE cas SET private_key = :pk WHERE id = :id"),
{"pk": encrypted_value, "id": ca_id},
)
encrypted_count += 1
logger.info(f"Encrypted private key for CA {ca_id}")
except Exception as exc:
session.rollback()
raise RuntimeError(
f"Failed to encrypt private key for CA {ca_id}: {exc}"
) from exc
session.commit()
logger.info(
f"CA key encryption migration complete: "
f"{encrypted_count} encrypted, {skipped_count} skipped"
)
print(
f" [016_encrypt_ca_keys] {encrypted_count} CA private key(s) encrypted, "
f"{skipped_count} already encrypted or empty."
)
def downgrade():
"""Decrypt CA private keys back to plaintext (emergency rollback only)."""
bind = op.get_bind()
session = Session(bind=bind)
try:
fernet = _get_fernet()
except RuntimeError as exc:
raise RuntimeError(str(exc)) from exc
rows = session.execute(
sa.text("SELECT id, private_key FROM cas WHERE deleted_at IS NULL")
).fetchall()
decrypted_count = 0
skipped_count = 0
for row in rows:
ca_id, private_key = row[0], row[1]
if not private_key or not private_key.startswith(_FERNET_PREFIX):
skipped_count += 1
continue
token = private_key[len(_FERNET_PREFIX):]
try:
from cryptography.fernet import InvalidToken
try:
plaintext = fernet.decrypt(token.encode()).decode()
except InvalidToken as exc:
raise RuntimeError(
f"Downgrade failed: cannot decrypt CA {ca_id} — wrong key or corrupted data."
) from exc
session.execute(
sa.text("UPDATE cas SET private_key = :pk WHERE id = :id"),
{"pk": plaintext, "id": ca_id},
)
decrypted_count += 1
logger.warning(f"Decrypted (plaintext restore) private key for CA {ca_id}")
except RuntimeError:
session.rollback()
raise
session.commit()
logger.warning(
f"CA key decryption (downgrade) complete: "
f"{decrypted_count} decrypted, {skipped_count} skipped"
)
print(
f" [016_encrypt_ca_keys] DOWNGRADE: {decrypted_count} CA private key(s) "
f"decrypted to plaintext. WARNING: keys are now unencrypted at rest."
)
@@ -0,0 +1,37 @@
"""Add monotonic serial counter to CAs table.
Each CA now owns a `next_serial_number` (BigInteger) that is atomically
incremented every time a certificate is signed. This guarantees:
- Serials are unique per CA
- Serials are monotonically increasing (auditable, no gaps by accident)
- The value embedded in the OpenSSH certificate matches what is stored
in the `ssh_certificates.serial` column
Revision ID: 017_add_ca_serial_counter
Revises: 016_encrypt_ca_keys
Create Date: 2026-03-02
"""
from alembic import op
import sqlalchemy as sa
revision = "017_add_ca_serial_counter"
down_revision = "016_encrypt_ca_keys"
branch_labels = None
depends_on = None
def upgrade():
with op.batch_alter_table("cas", schema=None) as batch_op:
batch_op.add_column(
sa.Column(
"next_serial_number",
sa.BigInteger(),
nullable=False,
server_default="1",
)
)
def downgrade():
with op.batch_alter_table("cas", schema=None) as batch_op:
batch_op.drop_column("next_serial_number")
@@ -0,0 +1,52 @@
"""Add ORG_OWNERSHIP_TRANSFERRED and USER_HARD_DELETE to auditaction enum.
Revision ID: 018_audit_enum_values
Revises: 017_add_ca_serial_counter
Create Date: 2026-03-02
ORG_OWNERSHIP_TRANSFERRED and USER_HARD_DELETE were added to the Python
AuditAction enum but were never synced to the PostgreSQL auditaction type,
causing a DataError (invalid enum value) when transferring org ownership
or hard-deleting a user.
"""
from alembic import op
revision = "018_audit_enum_values"
down_revision = "017_add_ca_serial_counter"
branch_labels = None
depends_on = None
def upgrade():
# ALTER TYPE ... ADD VALUE cannot run inside a transaction block in PostgreSQL.
# Alembic has already opened a transaction on the connection by the time our
# upgrade() runs, so we must:
# 1. Roll back that open transaction on the raw psycopg2 connection.
# 2. Switch to autocommit so the ALTER TYPE runs outside any transaction.
# 3. Restore the previous state afterwards.
conn = op.get_bind()
# SQLAlchemy 2.x: conn.connection is a _ConnectionFairy; .driver_connection is psycopg2
fairy = conn.connection
raw = getattr(fairy, "driver_connection", None) or getattr(fairy, "dbapi_connection", fairy)
# Roll back the open transaction so psycopg2 allows us to change autocommit.
raw.rollback()
old_autocommit = raw.autocommit
raw.autocommit = True
try:
with raw.cursor() as cur:
for val in ("ORG_OWNERSHIP_TRANSFERRED", "USER_HARD_DELETE"):
cur.execute(
"SELECT 1 FROM pg_enum "
"WHERE enumlabel = %s "
"AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'auditaction')",
(val,),
)
if not cur.fetchone():
cur.execute(f"ALTER TYPE auditaction ADD VALUE '{val}'")
finally:
raw.autocommit = old_autocommit
def downgrade():
# PostgreSQL does not support removing enum values; downgrade is a no-op.
pass