Fix(Feat): CA, Audits, Rte Limit
CA Encryption, Serials, Rate Limiter, Account suspension blocks login Transfer Ownership & Delete Account
This commit is contained in:
@@ -51,8 +51,7 @@ class MyServer(BaseHTTPRequestHandler):
|
||||
self.end_headers()
|
||||
self.wfile.write(bytes("<html><head><title>OIDC Workflow Tool</title></head>", "utf-8"))
|
||||
self.wfile.write(bytes("<body><p>The token has been received</p>", "utf-8"))
|
||||
self.wfile.write(bytes("<p>Window closing in <span id='countdown'>5</span> seconds...</p>", "utf-8"))
|
||||
self.wfile.write(bytes("<script>var count = 5; setInterval(function() { count--; document.getElementById('countdown').textContent = count; if (count === 0) window.close(); }, 1000);</script>", "utf-8"))
|
||||
self.wfile.write(bytes("<p>You may now close this window.</p>", "utf-8"))
|
||||
self.wfile.write(bytes("</body></html>", "utf-8"))
|
||||
|
||||
parsed_url = urlparse(self.path)
|
||||
|
||||
@@ -28,6 +28,11 @@ class BaseConfig:
|
||||
|
||||
# Encryption key for sensitive data (client secrets, tokens, etc.)
|
||||
ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY", "dev-encryption-key-change-in-production")
|
||||
|
||||
# Encryption key for CA private keys stored in the database.
|
||||
# Must be set to a strong random secret in production.
|
||||
# Any string is accepted — it is SHA-256 derived to a 32-byte Fernet key internally.
|
||||
CA_ENCRYPTION_KEY = os.getenv("CA_ENCRYPTION_KEY", "dev-ca-encryption-key-change-in-production")
|
||||
|
||||
# Session configuration for WebAuthn cross-origin support
|
||||
SESSION_COOKIE_SECURE = os.getenv("SESSION_COOKIE_SECURE", "True").lower() == "true"
|
||||
@@ -72,6 +77,13 @@ class BaseConfig:
|
||||
RATELIMIT_STORAGE_URL = os.getenv("RATELIMIT_STORAGE_URL", "redis://localhost:6379/1")
|
||||
RATELIMIT_DEFAULT = "100/hour"
|
||||
|
||||
# Per-endpoint auth rate limits (override via env vars for each environment)
|
||||
RATELIMIT_AUTH_REGISTER = os.getenv("RATELIMIT_AUTH_REGISTER", "10 per minute; 50 per hour")
|
||||
RATELIMIT_AUTH_LOGIN = os.getenv("RATELIMIT_AUTH_LOGIN", "20 per minute; 100 per hour")
|
||||
RATELIMIT_AUTH_TOTP_VERIFY = os.getenv("RATELIMIT_AUTH_TOTP_VERIFY", "20 per minute; 100 per hour")
|
||||
RATELIMIT_AUTH_FORGOT_PASSWORD = os.getenv("RATELIMIT_AUTH_FORGOT_PASSWORD", "5 per minute; 20 per hour")
|
||||
RATELIMIT_AUTH_RESET_PASSWORD = os.getenv("RATELIMIT_AUTH_RESET_PASSWORD", "10 per minute; 30 per hour")
|
||||
|
||||
# Logging
|
||||
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
|
||||
LOG_TO_STDOUT = os.getenv("LOG_TO_STDOUT", "False").lower() == "true"
|
||||
|
||||
@@ -12,6 +12,9 @@ class TestingConfig(BaseConfig):
|
||||
# Explicitly set SECRET_KEY for testing
|
||||
SECRET_KEY = os.getenv("SECRET_KEY", "test-secret-key-for-testing")
|
||||
|
||||
# CA key encryption — use a fixed test key so tests are deterministic
|
||||
CA_ENCRYPTION_KEY = os.getenv("CA_ENCRYPTION_KEY", "test-ca-encryption-key-fixed-for-tests")
|
||||
|
||||
# Use in-memory SQLite for testing
|
||||
SQLALCHEMY_DATABASE_URI = "sqlite:///:memory:"
|
||||
SQLALCHEMY_ECHO = False
|
||||
|
||||
@@ -22,7 +22,11 @@ from gatehouse_app.extensions import bcrypt as flask_bcrypt
|
||||
from gatehouse_app.extensions import redis_client as _redis_client_ref # may be None until app init
|
||||
from gatehouse_app.models import User, OIDCClient
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
|
||||
from gatehouse_app.exceptions.auth_exceptions import (
|
||||
InvalidCredentialsError,
|
||||
AccountSuspendedError,
|
||||
AccountInactiveError,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers for Redis-backed OIDC pending state
|
||||
@@ -343,6 +347,20 @@ def oidc_complete():
|
||||
|
||||
user_id = str(gh_session.user_id)
|
||||
|
||||
# Check the user is still active (not suspended after session was issued)
|
||||
from gatehouse_app.models.user.user import User as _User
|
||||
from gatehouse_app.utils.constants import UserStatus
|
||||
_complete_user = _User.query.filter_by(id=user_id, deleted_at=None).first()
|
||||
if not _complete_user or _complete_user.status in (
|
||||
UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED, UserStatus.INACTIVE
|
||||
):
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Your account is not active or has been suspended.",
|
||||
status=403,
|
||||
error_type="ACCOUNT_SUSPENDED",
|
||||
)
|
||||
|
||||
# Retrieve stashed OIDC params (consume = True removes from Redis atomically)
|
||||
params = _fetch_oidc_params(oidc_session_id, consume=True)
|
||||
if not params:
|
||||
@@ -565,6 +583,28 @@ def oidc_authorize():
|
||||
session["oidc_user_id"] = user_id
|
||||
|
||||
logger.debug("[OIDC] User authentication successful: user_id=%s, email=%s", user_id, email)
|
||||
except AccountSuspendedError:
|
||||
logger.debug("[OIDC] User authentication failed: account suspended for email=%s", email)
|
||||
return _show_login_page(
|
||||
client_id=client_id,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
response_type=response_type,
|
||||
error="Your account has been suspended. Please contact an administrator.",
|
||||
)
|
||||
except AccountInactiveError:
|
||||
logger.debug("[OIDC] User authentication failed: account inactive for email=%s", email)
|
||||
return _show_login_page(
|
||||
client_id=client_id,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
response_type=response_type,
|
||||
error="Your account is not active. Please verify your email.",
|
||||
)
|
||||
except InvalidCredentialsError:
|
||||
logger.debug("[OIDC] User authentication failed: invalid credentials for email=%s", email)
|
||||
return _show_login_page(
|
||||
@@ -600,7 +640,34 @@ def oidc_authorize():
|
||||
if not user:
|
||||
logger.debug("[OIDC] Redirecting with error: server_error (user not found)")
|
||||
return _redirect_with_error(redirect_uri, "server_error", "User not found", state)
|
||||
|
||||
|
||||
# Check account is still active (user could have been suspended after session start)
|
||||
from gatehouse_app.utils.constants import UserStatus as _UserStatus
|
||||
if user.status in (_UserStatus.SUSPENDED, _UserStatus.COMPLIANCE_SUSPENDED):
|
||||
session.pop("oidc_user_id", None) # clear stale session
|
||||
logger.debug("[OIDC] User is suspended, clearing session and showing login error: user_id=%s", user_id)
|
||||
return _show_login_page(
|
||||
client_id=client_id,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
response_type=response_type,
|
||||
error="Your account has been suspended. Please contact an administrator.",
|
||||
)
|
||||
if user.status == _UserStatus.INACTIVE:
|
||||
session.pop("oidc_user_id", None)
|
||||
logger.debug("[OIDC] User is inactive, clearing session and showing login error: user_id=%s", user_id)
|
||||
return _show_login_page(
|
||||
client_id=client_id,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
response_type=response_type,
|
||||
error="Your account is not active. Please verify your email.",
|
||||
)
|
||||
|
||||
logger.debug("[OIDC] Generating authorization code...")
|
||||
logger.debug("[OIDC] Authorization code params: client_id=%s, user_id=%s, redirect_uri=%s", client_id, user_id, redirect_uri)
|
||||
logger.debug("[OIDC] Authorization code params: scopes=%s, state=%s, nonce=%s", valid_scopes, state, nonce)
|
||||
|
||||
@@ -4,6 +4,7 @@ import logging
|
||||
from flask import request, session, g, jsonify, current_app
|
||||
from marshmallow import ValidationError
|
||||
from gatehouse_app.api.v1 import api_v1_bp
|
||||
from gatehouse_app.extensions import limiter
|
||||
from gatehouse_app.utils.response import api_response
|
||||
from gatehouse_app.schemas.auth_schema import (
|
||||
RegisterSchema,
|
||||
@@ -32,6 +33,7 @@ from gatehouse_app.exceptions.validation_exceptions import ConflictError, NotFou
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/register", methods=["POST"])
|
||||
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_REGISTER"])
|
||||
def register():
|
||||
"""
|
||||
Register a new user.
|
||||
@@ -135,6 +137,7 @@ def register():
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/login", methods=["POST"])
|
||||
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_LOGIN"])
|
||||
def login():
|
||||
"""
|
||||
Login user.
|
||||
@@ -325,8 +328,13 @@ def get_current_user():
|
||||
data={
|
||||
"user": user.to_dict(),
|
||||
"organizations": [
|
||||
{"id": org.id, "name": org.name, "slug": org.slug}
|
||||
for org in user.get_organizations()
|
||||
{
|
||||
"id": membership.organization.id,
|
||||
"name": membership.organization.name,
|
||||
"slug": membership.organization.slug,
|
||||
"role": membership.role,
|
||||
}
|
||||
for membership in user.organization_memberships
|
||||
],
|
||||
},
|
||||
message="User retrieved successfully",
|
||||
@@ -478,6 +486,7 @@ def verify_totp_enrollment():
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/totp/verify", methods=["POST"])
|
||||
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_TOTP_VERIFY"])
|
||||
def verify_totp():
|
||||
"""
|
||||
Verify TOTP code during login.
|
||||
@@ -520,6 +529,18 @@ def verify_totp():
|
||||
error_type="AUTHENTICATION_ERROR",
|
||||
)
|
||||
|
||||
# Check account suspension before completing TOTP verification
|
||||
from gatehouse_app.utils.constants import UserStatus
|
||||
if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
|
||||
session.pop("totp_pending_user_id", None)
|
||||
session.pop("webauthn_pending_user_id", None)
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Account is suspended. Contact an administrator.",
|
||||
status=403,
|
||||
error_type="ACCOUNT_SUSPENDED",
|
||||
)
|
||||
|
||||
# Verify TOTP code
|
||||
AuthService.authenticate_with_totp(
|
||||
user,
|
||||
@@ -908,7 +929,18 @@ def begin_webauthn_login():
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
|
||||
# Check account suspension before proceeding
|
||||
from gatehouse_app.utils.constants import UserStatus
|
||||
if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
|
||||
logger.warning(f"WebAuthn login begin - suspended account attempt: {user.email}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Account is suspended. Contact an administrator.",
|
||||
status=403,
|
||||
error_type="ACCOUNT_SUSPENDED",
|
||||
)
|
||||
|
||||
# Check if user has any WebAuthn credentials
|
||||
if not user.has_webauthn_enabled():
|
||||
logger.warning(f"WebAuthn login begin - no credentials for user: {user.email}")
|
||||
@@ -991,7 +1023,19 @@ def complete_webauthn_login():
|
||||
status=401,
|
||||
error_type="AUTHENTICATION_ERROR",
|
||||
)
|
||||
|
||||
|
||||
# Check account suspension before completing login
|
||||
from gatehouse_app.utils.constants import UserStatus
|
||||
if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
|
||||
session.pop("webauthn_pending_user_id", None)
|
||||
logger.warning(f"WebAuthn login complete - suspended account attempt: {user.email}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Account is suspended. Contact an administrator.",
|
||||
status=403,
|
||||
error_type="ACCOUNT_SUSPENDED",
|
||||
)
|
||||
|
||||
# Extract challenge from client data
|
||||
client_data = data.get("response", {}).get("clientDataJSON", "")
|
||||
|
||||
@@ -1129,6 +1173,19 @@ def delete_webauthn_credential(credential_id):
|
||||
"""
|
||||
user = g.current_user
|
||||
|
||||
# First check that the specific credential actually belongs to this user.
|
||||
# Only then check whether it is the last one — otherwise a user with zero
|
||||
# credentials gets a misleading "Cannot delete the last passkey" error
|
||||
# instead of a 404.
|
||||
credential_exists = WebAuthnService.credential_belongs_to_user(credential_id, user)
|
||||
if not credential_exists:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Credential not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Check if this is the last credential
|
||||
credential_count = user.get_webauthn_credential_count()
|
||||
if credential_count <= 1:
|
||||
@@ -1238,6 +1295,7 @@ _pw_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/forgot-password", methods=["POST"])
|
||||
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_FORGOT_PASSWORD"])
|
||||
def forgot_password():
|
||||
"""Request a password reset email.
|
||||
|
||||
@@ -1294,6 +1352,7 @@ def forgot_password():
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/reset-password", methods=["POST"])
|
||||
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_RESET_PASSWORD"])
|
||||
def reset_password():
|
||||
"""Reset a user's password using a reset token.
|
||||
|
||||
@@ -1601,11 +1660,31 @@ def get_token():
|
||||
302: Redirect to ``<redirect>?token=<token>``
|
||||
"""
|
||||
from flask import redirect as flask_redirect
|
||||
from urllib.parse import urlparse
|
||||
|
||||
token = g.current_session.token
|
||||
redirect_url = request.args.get("redirect", "").strip()
|
||||
|
||||
if redirect_url:
|
||||
# Validate redirect URL against allowed origins to prevent open-redirect
|
||||
# token exfiltration attacks (CWE-601).
|
||||
allowed_origins = set(current_app.config.get("CORS_ORIGINS", []))
|
||||
frontend_url = current_app.config.get("FRONTEND_URL", "")
|
||||
if frontend_url:
|
||||
parsed = urlparse(frontend_url)
|
||||
allowed_origins.add(f"{parsed.scheme}://{parsed.netloc}")
|
||||
|
||||
parsed_redirect = urlparse(redirect_url)
|
||||
redirect_origin = f"{parsed_redirect.scheme}://{parsed_redirect.netloc}"
|
||||
|
||||
if redirect_origin not in allowed_origins:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Redirect URL is not allowed.",
|
||||
status=400,
|
||||
error_type="INVALID_REDIRECT",
|
||||
)
|
||||
|
||||
sep = "&" if "?" in redirect_url else "?"
|
||||
return flask_redirect(f"{redirect_url}{sep}token={token}", code=302)
|
||||
|
||||
|
||||
@@ -226,6 +226,10 @@ def delete_organization(org_id):
|
||||
"""
|
||||
Delete organization (soft delete).
|
||||
|
||||
The owner may only delete the organization if they are the *sole* remaining
|
||||
member. If other active members exist they must first transfer ownership
|
||||
(or remove all other members) before deleting the organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
|
||||
@@ -234,9 +238,26 @@ def delete_organization(org_id):
|
||||
401: Not authenticated
|
||||
403: Not the owner
|
||||
404: Organization not found
|
||||
409: Organization still has other members — transfer ownership first
|
||||
"""
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
# Guard: block deletion while non-owner members still exist so ownership
|
||||
# can be transferred rather than silently orphaning them.
|
||||
active_member_count = org.get_member_count()
|
||||
if active_member_count > 1:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=(
|
||||
"This organization still has other members. "
|
||||
"Please transfer ownership to another member or remove all "
|
||||
"other members before deleting the organization."
|
||||
),
|
||||
status=409,
|
||||
error_type="ORG_HAS_MEMBERS",
|
||||
error_details={"member_count": active_member_count},
|
||||
)
|
||||
|
||||
OrganizationService.delete_organization(
|
||||
org=org,
|
||||
user_id=g.current_user.id,
|
||||
@@ -446,6 +467,152 @@ def update_member_role(org_id, user_id):
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/transfer-ownership", methods=["POST"])
|
||||
@login_required
|
||||
@full_access_required
|
||||
def transfer_organization_ownership(org_id):
|
||||
"""Transfer organization ownership from the current user to another member.
|
||||
|
||||
Only the current OWNER of the organization may call this endpoint.
|
||||
The caller will be demoted to ADMIN and the target user will be promoted to OWNER.
|
||||
|
||||
Request body:
|
||||
new_owner_user_id (str): UUID of the member to promote to OWNER.
|
||||
|
||||
Returns:
|
||||
200: Ownership transferred successfully
|
||||
400: Validation error / missing fields
|
||||
403: Caller is not the OWNER of this org
|
||||
404: Organization or target member not found
|
||||
409: Target is already the OWNER
|
||||
"""
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.utils.constants import OrganizationRole, AuditAction
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
|
||||
caller = g.current_user
|
||||
|
||||
data = request.get_json() or {}
|
||||
new_owner_user_id = data.get("new_owner_user_id")
|
||||
if not new_owner_user_id:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="new_owner_user_id is required",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
)
|
||||
|
||||
if str(new_owner_user_id) == str(caller.id):
|
||||
return api_response(
|
||||
success=False,
|
||||
message="You are already the owner of this organization.",
|
||||
status=409,
|
||||
error_type="CONFLICT",
|
||||
)
|
||||
|
||||
# Fetch org (raises NotFound internally)
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
# Confirm caller is the current OWNER
|
||||
caller_membership = OrganizationMember.query.filter_by(
|
||||
organization_id=org.id,
|
||||
user_id=caller.id,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
if not caller_membership or caller_membership.role != OrganizationRole.OWNER:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Only the organization owner can transfer ownership.",
|
||||
status=403,
|
||||
error_type="AUTHORIZATION_ERROR",
|
||||
)
|
||||
|
||||
# Verify the target is an active member
|
||||
target_membership = OrganizationMember.query.filter_by(
|
||||
organization_id=org.id,
|
||||
user_id=new_owner_user_id,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
if not target_membership:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Target user is not a member of this organization.",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
if target_membership.role == OrganizationRole.OWNER:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Target user is already the owner.",
|
||||
status=409,
|
||||
error_type="CONFLICT",
|
||||
)
|
||||
|
||||
# ── Atomic role swap ─────────────────────────────────────────────────────
|
||||
# Demote caller → ADMIN, promote target → OWNER.
|
||||
# Both updates go through OrganizationService so all hooks/auditing fire.
|
||||
try:
|
||||
demoted = OrganizationService.update_member_role(
|
||||
org=org,
|
||||
user_id=str(caller.id),
|
||||
new_role=OrganizationRole.ADMIN,
|
||||
updater_id=str(caller.id),
|
||||
)
|
||||
promoted = OrganizationService.update_member_role(
|
||||
org=org,
|
||||
user_id=str(new_owner_user_id),
|
||||
new_role=OrganizationRole.OWNER,
|
||||
updater_id=str(caller.id),
|
||||
)
|
||||
except Exception as exc:
|
||||
from gatehouse_app.extensions import db as _db
|
||||
_db.session.rollback()
|
||||
return api_response(
|
||||
success=False,
|
||||
message=f"Failed to transfer ownership: {exc}",
|
||||
status=500,
|
||||
error_type="SERVER_ERROR",
|
||||
)
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_OWNERSHIP_TRANSFERRED,
|
||||
user_id=caller.id,
|
||||
organization_id=org.id,
|
||||
resource_type="organization",
|
||||
resource_id=str(org.id),
|
||||
description=(
|
||||
f"Ownership of '{org.name}' transferred from {caller.email} "
|
||||
f"to {target_membership.user.email if target_membership.user else new_owner_user_id}"
|
||||
),
|
||||
metadata={
|
||||
"previous_owner_id": str(caller.id),
|
||||
"previous_owner_email": caller.email,
|
||||
"new_owner_id": str(new_owner_user_id),
|
||||
"new_owner_email": (
|
||||
target_membership.user.email if target_membership.user else None
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
def _member_dict(m):
|
||||
d = m.to_dict()
|
||||
if m.user:
|
||||
d["user"] = m.user.to_dict()
|
||||
return d
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"previous_owner": _member_dict(demoted),
|
||||
"new_owner": _member_dict(promoted),
|
||||
},
|
||||
message=(
|
||||
f"Ownership of '{org.name}' successfully transferred to "
|
||||
f"{target_membership.user.email if target_membership.user else new_owner_user_id}."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/audit-logs", methods=["GET"])
|
||||
@login_required
|
||||
@require_admin
|
||||
@@ -756,10 +923,30 @@ def accept_invite(token):
|
||||
inviter_id=invite.invited_by_id,
|
||||
)
|
||||
except Exception:
|
||||
pass # Already a member is fine
|
||||
from gatehouse_app.extensions import db
|
||||
db.session.rollback() # Clear broken transaction so invite.accept() can commit
|
||||
|
||||
invite.accept()
|
||||
|
||||
has_webauthn = user.has_webauthn_enabled()
|
||||
has_totp = user.has_totp_enabled()
|
||||
|
||||
if has_webauthn:
|
||||
from flask import session as flask_session
|
||||
flask_session["webauthn_pending_user_id"] = user.id
|
||||
return api_response(
|
||||
data={"requires_webauthn": True},
|
||||
message="Passkey verification required. Please use your passkey to complete sign-in.",
|
||||
)
|
||||
|
||||
if has_totp:
|
||||
from flask import session as flask_session
|
||||
flask_session["totp_pending_user_id"] = user.id
|
||||
return api_response(
|
||||
data={"requires_totp": True},
|
||||
message="TOTP code required. Please enter your 6-digit code from your authenticator app.",
|
||||
)
|
||||
|
||||
user_session = AuthService.create_session(user)
|
||||
|
||||
return api_response(
|
||||
@@ -1379,6 +1566,7 @@ def create_org_ca(org_id):
|
||||
from gatehouse_app.models.ssh_ca.ca import CA, KeyType
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
from gatehouse_app.utils.crypto import compute_ssh_fingerprint
|
||||
from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key
|
||||
from marshmallow import Schema, fields as ma_fields, validate, ValidationError as MaValidationError
|
||||
from sshkey_tools.keys import Ed25519PrivateKey, RsaPrivateKey, EcdsaPrivateKey
|
||||
|
||||
@@ -1448,13 +1636,16 @@ def create_org_ca(org_id):
|
||||
public_key_str = private_key_obj.public_key.to_string()
|
||||
fingerprint = compute_ssh_fingerprint(public_key_str)
|
||||
|
||||
# Encrypt the private key before storing in the database
|
||||
encrypted_private_key = encrypt_ca_key(private_key_pem)
|
||||
|
||||
ca = CA(
|
||||
organization_id=org_id,
|
||||
name=data["name"],
|
||||
description=data["description"],
|
||||
ca_type=CaType(ca_type_val),
|
||||
key_type=KeyType(key_type),
|
||||
private_key=private_key_pem,
|
||||
private_key=encrypted_private_key,
|
||||
public_key=public_key_str,
|
||||
fingerprint=fingerprint,
|
||||
default_cert_validity_hours=data["default_cert_validity_hours"],
|
||||
@@ -1462,7 +1653,24 @@ def create_org_ca(org_id):
|
||||
is_active=True,
|
||||
)
|
||||
db.session.add(ca)
|
||||
db.session.commit()
|
||||
try:
|
||||
db.session.commit()
|
||||
except Exception as commit_exc:
|
||||
db.session.rollback()
|
||||
# Surface unique-constraint violations (soft-deleted record with same name) as a
|
||||
# user-friendly 400 instead of a 500.
|
||||
exc_str = str(commit_exc).lower()
|
||||
if "uix_org_ca_name" in exc_str or "unique" in exc_str:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=(
|
||||
"A CA with that name already exists in this organization "
|
||||
"(it may have been recently deleted — choose a different name)."
|
||||
),
|
||||
status=400,
|
||||
error_type="DUPLICATE_NAME",
|
||||
)
|
||||
raise
|
||||
|
||||
return api_response(
|
||||
data={"ca": ca.to_dict()},
|
||||
@@ -1570,6 +1778,7 @@ def rotate_org_ca(org_id, ca_id):
|
||||
from gatehouse_app.models.ssh_ca.ca import CA, KeyType
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
from gatehouse_app.utils.crypto import compute_ssh_fingerprint
|
||||
from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key
|
||||
from gatehouse_app.utils.constants import AuditAction
|
||||
from gatehouse_app.models import AuditLog
|
||||
from sshkey_tools.keys import Ed25519PrivateKey, RsaPrivateKey, EcdsaPrivateKey
|
||||
@@ -1609,8 +1818,11 @@ def rotate_org_ca(org_id, ca_id):
|
||||
new_public_key = private_key_obj.public_key.to_string()
|
||||
new_fingerprint = compute_ssh_fingerprint(new_public_key)
|
||||
|
||||
# Encrypt the new private key before storing
|
||||
encrypted_new_private_key = encrypt_ca_key(new_private_key)
|
||||
|
||||
ca.rotate_key(
|
||||
new_private_key=new_private_key,
|
||||
new_private_key=encrypted_new_private_key,
|
||||
new_public_key=new_public_key,
|
||||
new_fingerprint=new_fingerprint,
|
||||
reason=reason,
|
||||
|
||||
@@ -15,6 +15,7 @@ from gatehouse_app.exceptions import (
|
||||
)
|
||||
from gatehouse_app.utils.constants import AuditAction
|
||||
from gatehouse_app.models import AuditLog
|
||||
from gatehouse_app.models.ssh_ca.certificate_audit_log import CertificateAuditLog
|
||||
from gatehouse_app.utils.decorators import login_required
|
||||
from gatehouse_app.utils.response import api_response
|
||||
|
||||
@@ -78,11 +79,16 @@ def _get_or_create_system_ca():
|
||||
with open(pub_key_path) as f:
|
||||
pub_key = f.read().strip()
|
||||
|
||||
# Load private key for the record (stored but not actually used for signing here)
|
||||
# Load private key for the record (encrypt before storing in DB)
|
||||
priv_key = ""
|
||||
if os.path.exists(key_path):
|
||||
with open(key_path) as f:
|
||||
priv_key = f.read()
|
||||
raw_priv_key = f.read()
|
||||
try:
|
||||
from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key
|
||||
priv_key = encrypt_ca_key(raw_priv_key)
|
||||
except Exception:
|
||||
priv_key = raw_priv_key # fallback: store as-is if encryption unavailable
|
||||
|
||||
fingerprint = compute_ssh_fingerprint(pub_key)
|
||||
|
||||
@@ -120,7 +126,7 @@ def _get_or_create_system_ca():
|
||||
return None
|
||||
|
||||
|
||||
def _persist_certificate(user_id, ssh_key_id, ca, signing_response, request_ip=None, cert_type_str='user'):
|
||||
def _persist_certificate(user_id, ssh_key_id, ca, signing_response, request_ip=None, cert_type_str='user', cert_identity=None):
|
||||
"""Save a signed certificate to the ssh_certificates table.
|
||||
|
||||
Args:
|
||||
@@ -130,6 +136,8 @@ def _persist_certificate(user_id, ssh_key_id, ca, signing_response, request_ip=N
|
||||
signing_response: SSHCertificateSigningResponse
|
||||
request_ip: Client IP address
|
||||
cert_type_str: 'user' or 'host' (from the sign request)
|
||||
cert_identity: Rich OpenSSH key_id string (e.g. "user@host (Name) [org:slug]").
|
||||
Falls back to str(ssh_key_id) when not provided.
|
||||
|
||||
Returns:
|
||||
SSHCertificate instance or None if persistence failed
|
||||
@@ -153,7 +161,7 @@ def _persist_certificate(user_id, ssh_key_id, ca, signing_response, request_ip=N
|
||||
ssh_key_id=ssh_key_id,
|
||||
certificate=signing_response.certificate,
|
||||
serial=signing_response.serial,
|
||||
key_id=str(ssh_key_id),
|
||||
key_id=cert_identity or str(ssh_key_id),
|
||||
cert_type=resolved_cert_type,
|
||||
principals=signing_response.principals,
|
||||
valid_after=signing_response.valid_after,
|
||||
@@ -465,7 +473,7 @@ def sign_certificate():
|
||||
|
||||
# ── Check account suspension ──────────────────────────────────────────────
|
||||
from gatehouse_app.utils.constants import UserStatus
|
||||
if user.status == UserStatus.SUSPENDED:
|
||||
if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Your account is suspended. Contact an administrator.",
|
||||
@@ -482,6 +490,18 @@ def sign_certificate():
|
||||
key_id = data.get('key_id') or data.get('cert_id')
|
||||
expiry_hours = data.get('expiry_hours')
|
||||
|
||||
# ── Log the request ───────────────────────────────────────────────────────
|
||||
AuditLog.log(
|
||||
action=AuditAction.SSH_CERT_REQUESTED,
|
||||
user_id=user_id,
|
||||
resource_type='SSHCertificate',
|
||||
ip_address=request.remote_addr,
|
||||
description=(
|
||||
f'{user.email} requested a certificate'
|
||||
+ (f' for principals: {", ".join(requested_principals)}' if requested_principals else '')
|
||||
),
|
||||
)
|
||||
|
||||
# ── Resolve which principals the user is allowed to use ──────────────────
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.organization.principal import Principal, PrincipalMembership
|
||||
@@ -601,11 +621,24 @@ def sign_certificate():
|
||||
else:
|
||||
policy_extensions = None # let signing service use its own defaults
|
||||
|
||||
# ── Build rich key_id identity for the OpenSSH cert ─────────────────────
|
||||
# This appears in `ssh-keygen -L -f cert.pub` as the Key ID field and
|
||||
# is stored in the DB cert record so it's auditable.
|
||||
org_slugs = sorted({
|
||||
om.organization.slug
|
||||
for om in memberships
|
||||
if om.organization and om.organization.deleted_at is None
|
||||
and getattr(om.organization, 'slug', None)
|
||||
})
|
||||
org_slug = org_slugs[0] if org_slugs else "unknown"
|
||||
full_name = getattr(user, 'full_name', None) or getattr(user, 'name', None) or "unknown"
|
||||
cert_identity = f"{user.email} ({full_name}) [org:{org_slug}]"
|
||||
|
||||
signing_request = SSHCertificateSigningRequest(
|
||||
ssh_public_key=ssh_key.payload,
|
||||
principals=principals,
|
||||
cert_type=cert_type,
|
||||
key_id=key_id,
|
||||
key_id=cert_identity,
|
||||
expiry_hours=int(expiry_hours) if expiry_hours else None,
|
||||
extensions=policy_extensions,
|
||||
)
|
||||
@@ -620,7 +653,11 @@ def sign_certificate():
|
||||
)
|
||||
|
||||
try:
|
||||
response = ssh_ca_service.sign_certificate(signing_request, ca_private_key=db_ca.private_key)
|
||||
from gatehouse_app.utils.ca_key_encryption import decrypt_ca_key
|
||||
ca_private_key_pem = decrypt_ca_key(db_ca.private_key)
|
||||
response = ssh_ca_service.sign_certificate(
|
||||
signing_request, ca_private_key=ca_private_key_pem, ca_obj=db_ca
|
||||
)
|
||||
except SSHCertificateError as e:
|
||||
AuditLog.log(
|
||||
action=AuditAction.SSH_CERT_FAILED,
|
||||
@@ -649,6 +686,7 @@ def sign_certificate():
|
||||
signing_response=response,
|
||||
request_ip=request.remote_addr,
|
||||
cert_type_str=cert_type,
|
||||
cert_identity=cert_identity,
|
||||
)
|
||||
|
||||
AuditLog.log(
|
||||
@@ -657,9 +695,42 @@ def sign_certificate():
|
||||
resource_type='SSHCertificate',
|
||||
resource_id=cert_record.id if cert_record else key_id,
|
||||
ip_address=request.remote_addr,
|
||||
description=f'Certificate issued for principals: {", ".join(principals)}',
|
||||
description=(
|
||||
f'Certificate serial={response.serial} issued for {user.email}; '
|
||||
f'principals: {", ".join(principals)}'
|
||||
),
|
||||
extra_data={
|
||||
'serial': response.serial,
|
||||
'key_id': cert_identity,
|
||||
'principals': principals,
|
||||
'ca_id': str(db_ca.id),
|
||||
'ssh_key_id': str(key_id),
|
||||
},
|
||||
)
|
||||
|
||||
if cert_record:
|
||||
CertificateAuditLog.log(
|
||||
certificate_id=cert_record.id,
|
||||
action='issued',
|
||||
user_id=user_id,
|
||||
ip_address=request.remote_addr,
|
||||
user_agent=request.headers.get('User-Agent'),
|
||||
message=(
|
||||
f'Certificate serial={response.serial} issued for {user.email}; '
|
||||
f'principals: {", ".join(principals)}'
|
||||
),
|
||||
extra_data={
|
||||
'serial': response.serial,
|
||||
'key_id': cert_identity,
|
||||
'principals': principals,
|
||||
'ca_id': str(db_ca.id),
|
||||
'ssh_key_id': str(key_id),
|
||||
'valid_after': response.valid_after.isoformat() if response.valid_after else None,
|
||||
'valid_before': response.valid_before.isoformat() if response.valid_before else None,
|
||||
},
|
||||
success=True,
|
||||
)
|
||||
|
||||
result = {
|
||||
'certificate': response.certificate,
|
||||
'serial': response.serial,
|
||||
@@ -753,6 +824,16 @@ def revoke_certificate(cert_id):
|
||||
description=f'Revoked: {reason}',
|
||||
)
|
||||
|
||||
CertificateAuditLog.log(
|
||||
certificate_id=cert_id,
|
||||
action='revoked',
|
||||
user_id=user_id,
|
||||
ip_address=request.remote_addr,
|
||||
user_agent=request.headers.get('User-Agent'),
|
||||
message=f'Certificate revoked: {reason}',
|
||||
success=True,
|
||||
)
|
||||
|
||||
return api_response(
|
||||
success=True,
|
||||
message='Certificate revoked successfully',
|
||||
|
||||
@@ -73,11 +73,51 @@ def delete_me():
|
||||
"""
|
||||
Delete current user account (soft delete).
|
||||
|
||||
Blocked if the user is the sole owner of any organization that has other
|
||||
active members — they must transfer ownership or dissolve those organizations
|
||||
first.
|
||||
|
||||
Returns:
|
||||
200: Account deleted successfully
|
||||
401: Not authenticated
|
||||
409: User is sole owner of one or more organizations with other members
|
||||
"""
|
||||
UserService.delete_user(g.current_user, soft=True)
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
|
||||
user = g.current_user
|
||||
|
||||
# Find orgs where this user is the sole owner AND other members exist.
|
||||
owned_memberships = OrganizationMember.query.filter_by(
|
||||
user_id=user.id,
|
||||
role=OrganizationRole.OWNER,
|
||||
deleted_at=None,
|
||||
).all()
|
||||
|
||||
blocked_orgs = []
|
||||
for membership in owned_memberships:
|
||||
org = membership.organization
|
||||
if org.deleted_at is not None:
|
||||
continue
|
||||
member_count = org.get_member_count()
|
||||
if member_count > 1:
|
||||
blocked_orgs.append(org.name)
|
||||
|
||||
if blocked_orgs:
|
||||
names = ", ".join(f'"{n}"' for n in blocked_orgs)
|
||||
return api_response(
|
||||
success=False,
|
||||
message=(
|
||||
f"You are the sole owner of {len(blocked_orgs)} organization"
|
||||
f"{'s' if len(blocked_orgs) > 1 else ''}: {names}. "
|
||||
"Transfer ownership or delete those organizations before deleting your account."
|
||||
),
|
||||
status=409,
|
||||
error_type="USER_IS_SOLE_OWNER",
|
||||
error_details={"organizations": blocked_orgs},
|
||||
)
|
||||
|
||||
UserService.delete_user(user, soft=True)
|
||||
|
||||
return api_response(
|
||||
message="Account deleted successfully",
|
||||
@@ -454,6 +494,31 @@ def admin_suspend_user(user_id):
|
||||
if not admin_in_shared_org:
|
||||
return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR")
|
||||
|
||||
# ── Owner protection ──────────────────────────────────────────────────────
|
||||
# An org owner cannot be suspended until they transfer ownership.
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
owner_memberships = OrganizationMember.query.filter(
|
||||
OrganizationMember.user_id == target.id,
|
||||
OrganizationMember.role == OrganizationRole.OWNER,
|
||||
OrganizationMember.deleted_at == None,
|
||||
).all()
|
||||
if owner_memberships:
|
||||
org_names = [
|
||||
m.organization.name
|
||||
for m in owner_memberships
|
||||
if m.organization and not m.organization.deleted_at
|
||||
]
|
||||
return api_response(
|
||||
success=False,
|
||||
message=(
|
||||
f"Cannot suspend an organization owner. "
|
||||
f"{target.email} is the owner of: {', '.join(org_names)}. "
|
||||
"Transfer ownership to another member first."
|
||||
),
|
||||
status=403,
|
||||
error_type="OWNER_PROTECTION",
|
||||
)
|
||||
|
||||
if target.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
|
||||
return api_response(success=False, message="User is already suspended", status=409, error_type="CONFLICT")
|
||||
|
||||
@@ -645,3 +710,158 @@ def get_my_memberships():
|
||||
data={"orgs": orgs_result},
|
||||
message="Memberships retrieved",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/admin/users/<user_id>/delete", methods=["POST"])
|
||||
@login_required
|
||||
@full_access_required
|
||||
def admin_hard_delete_user(user_id):
|
||||
"""Permanently delete a user and ALL associated data (hard delete, irreversible).
|
||||
|
||||
Required body: {"confirm": true}
|
||||
|
||||
Pre-conditions:
|
||||
- Caller is OWNER or ADMIN of a shared org with the target.
|
||||
- Cannot delete yourself.
|
||||
- Target must not be the OWNER of any active organization (transfer first).
|
||||
|
||||
Side-effects:
|
||||
- All active SSH certificates are revoked before deletion.
|
||||
- The user row and all cascaded rows are hard-deleted from the database.
|
||||
- An audit log entry is written by the *caller* (so it is not lost with the user).
|
||||
"""
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.user.user import User as _User
|
||||
from gatehouse_app.extensions import db as _db
|
||||
from gatehouse_app.utils.constants import UserStatus, AuditAction, OrganizationRole
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
|
||||
caller = g.current_user
|
||||
data = request.get_json() or {}
|
||||
|
||||
if not data.get("confirm"):
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Deletion requires explicit confirmation. Send {\"confirm\": true} to proceed.",
|
||||
status=400,
|
||||
error_type="CONFIRMATION_REQUIRED",
|
||||
)
|
||||
|
||||
target = _User.query.filter_by(id=user_id).first()
|
||||
if not target:
|
||||
return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND")
|
||||
|
||||
if target.id == caller.id:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Cannot delete your own account via this endpoint.",
|
||||
status=400,
|
||||
error_type="BAD_REQUEST",
|
||||
)
|
||||
|
||||
# Caller must be OWNER/ADMIN of a shared org.
|
||||
# Include soft-deleted memberships so that already-soft-deleted users can
|
||||
# still be hard-deleted by an admin who shared an org with them.
|
||||
target_org_ids = {m.organization_id for m in target.organization_memberships}
|
||||
admin_in_shared_org = OrganizationMember.query.filter(
|
||||
OrganizationMember.user_id == caller.id,
|
||||
OrganizationMember.organization_id.in_(target_org_ids),
|
||||
OrganizationMember.role.in_(["OWNER", "ADMIN"]),
|
||||
OrganizationMember.deleted_at == None,
|
||||
).first()
|
||||
if not admin_in_shared_org:
|
||||
return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR")
|
||||
|
||||
# Block deletion if target is an org owner — they must transfer first
|
||||
owner_memberships = OrganizationMember.query.filter(
|
||||
OrganizationMember.user_id == target.id,
|
||||
OrganizationMember.role == OrganizationRole.OWNER,
|
||||
OrganizationMember.deleted_at == None,
|
||||
).all()
|
||||
if owner_memberships:
|
||||
org_names = [
|
||||
m.organization.name
|
||||
for m in owner_memberships
|
||||
if m.organization and not m.organization.deleted_at
|
||||
]
|
||||
return api_response(
|
||||
success=False,
|
||||
message=(
|
||||
f"Cannot delete an organization owner. "
|
||||
f"{target.email} is the owner of: {', '.join(org_names)}. "
|
||||
"Transfer ownership to another member first."
|
||||
),
|
||||
status=403,
|
||||
error_type="OWNER_PROTECTION",
|
||||
)
|
||||
|
||||
# ── Collect counts for audit metadata ────────────────────────────────────
|
||||
from gatehouse_app.models.ssh_ca.ssh_key import SSHKey
|
||||
from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus
|
||||
|
||||
ssh_key_count = SSHKey.query.filter_by(user_id=target.id, deleted_at=None).count()
|
||||
active_cert_count = SSHCertificate.query.filter_by(
|
||||
user_id=target.id, revoked=False
|
||||
).filter(SSHCertificate.deleted_at == None).count()
|
||||
|
||||
# ── Revoke all active SSH certificates before deletion ───────────────────
|
||||
active_certs = SSHCertificate.query.filter_by(
|
||||
user_id=target.id, revoked=False
|
||||
).filter(SSHCertificate.deleted_at == None).all()
|
||||
for cert in active_certs:
|
||||
try:
|
||||
cert.revoke("account_deleted")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if active_certs:
|
||||
try:
|
||||
_db.session.flush()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ── Hard delete ───────────────────────────────────────────────────────────
|
||||
target_email = target.email # capture before deletion
|
||||
target_id_str = str(target.id)
|
||||
|
||||
try:
|
||||
_db.session.delete(target) # cascades to all child tables
|
||||
_db.session.flush()
|
||||
except Exception as exc:
|
||||
_db.session.rollback()
|
||||
import logging
|
||||
logging.getLogger(__name__).error(f"Hard delete failed for {target_id_str}: {exc}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Failed to delete user account. Please try again.",
|
||||
status=500,
|
||||
error_type="SERVER_ERROR",
|
||||
)
|
||||
|
||||
# ── Audit log (written as the caller so it survives the deletion) ─────────
|
||||
AuditService.log_action(
|
||||
action=AuditAction.USER_HARD_DELETE,
|
||||
user_id=caller.id,
|
||||
organization_id=admin_in_shared_org.organization_id,
|
||||
resource_type="user",
|
||||
resource_id=target_id_str,
|
||||
description=f"Admin permanently deleted user account: {target_email}",
|
||||
metadata={
|
||||
"deleted_user_id": target_id_str,
|
||||
"deleted_user_email": target_email,
|
||||
"ssh_keys_deleted": ssh_key_count,
|
||||
"certs_revoked": active_cert_count,
|
||||
},
|
||||
)
|
||||
|
||||
_db.session.commit()
|
||||
|
||||
return api_response(
|
||||
message=f"User account {target_email} has been permanently deleted.",
|
||||
data={
|
||||
"deleted_user_id": target_id_str,
|
||||
"deleted_user_email": target_email,
|
||||
"ssh_keys_deleted": ssh_key_count,
|
||||
"certs_revoked": active_cert_count,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -88,6 +88,11 @@ class CA(BaseModel):
|
||||
rotated_at = db.Column(db.DateTime, nullable=True)
|
||||
rotation_reason = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Monotonically-increasing serial counter. Every cert this CA issues
|
||||
# gets the next value so serials are unique, ordered, and auditable.
|
||||
# Protected by a row-level SELECT … FOR UPDATE in get_next_serial().
|
||||
next_serial_number = db.Column(db.BigInteger, default=1, nullable=False)
|
||||
|
||||
# Relationships
|
||||
organization = db.relationship("Organization", back_populates="cas")
|
||||
certificates = db.relationship(
|
||||
@@ -102,7 +107,6 @@ class CA(BaseModel):
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint("organization_id", "name", name="uix_org_ca_name"),
|
||||
db.Index("idx_ca_org_active", "organization_id", "is_active"),
|
||||
)
|
||||
|
||||
@@ -162,6 +166,28 @@ class CA(BaseModel):
|
||||
self.rotation_reason = reason
|
||||
self.save()
|
||||
|
||||
def get_next_serial(self) -> int:
|
||||
"""Atomically increment and return the next certificate serial number.
|
||||
|
||||
Uses a SELECT … FOR UPDATE row lock so concurrent requests never
|
||||
receive the same serial. Must be called inside an active DB
|
||||
transaction (i.e. before the final session.commit()).
|
||||
|
||||
Returns:
|
||||
int: The serial number to embed in the next certificate.
|
||||
"""
|
||||
# Re-fetch this CA row with an exclusive row lock
|
||||
locked = (
|
||||
db.session.query(CA)
|
||||
.with_for_update()
|
||||
.filter_by(id=self.id)
|
||||
.one()
|
||||
)
|
||||
serial = locked.next_serial_number
|
||||
locked.next_serial_number = serial + 1
|
||||
db.session.flush() # write increment; commit happens in the caller
|
||||
return serial
|
||||
|
||||
|
||||
class CAPermission(BaseModel):
|
||||
"""Per-user CA permission model.
|
||||
|
||||
@@ -77,7 +77,10 @@ class TOTPVerifyEnrollmentSchema(Schema):
|
||||
class TOTPVerifySchema(Schema):
|
||||
"""Schema for TOTP code verification during login."""
|
||||
|
||||
code = fields.Str(required=True)
|
||||
code = fields.Str(
|
||||
required=True,
|
||||
validate=validate.Length(min=1),
|
||||
)
|
||||
is_backup_code = fields.Bool(load_default=False)
|
||||
client_timestamp = fields.Int(
|
||||
required=False,
|
||||
@@ -85,6 +88,27 @@ class TOTPVerifySchema(Schema):
|
||||
metadata={"description": "Client UTC timestamp in seconds since epoch for TOTP verification"},
|
||||
)
|
||||
|
||||
@validates_schema
|
||||
def validate_code_format(self, data, **kwargs):
|
||||
"""Validate code format depending on whether it's a backup code."""
|
||||
code = data.get("code", "")
|
||||
is_backup_code = data.get("is_backup_code", False)
|
||||
if is_backup_code:
|
||||
# Backup codes are 16 uppercase hex characters
|
||||
if not code or len(code) != 16 or not all(c in "0123456789ABCDEFabcdef" for c in code):
|
||||
raise ValidationError(
|
||||
"Backup code must be a 16-character hexadecimal string.",
|
||||
field_name="code",
|
||||
)
|
||||
else:
|
||||
# Regular TOTP codes are exactly 6 digits
|
||||
import re
|
||||
if not re.match(r"^\d{6}$", code):
|
||||
raise ValidationError(
|
||||
"Code must be a 6-digit number.",
|
||||
field_name="code",
|
||||
)
|
||||
|
||||
|
||||
class TOTPDisableSchema(Schema):
|
||||
"""Schema for disabling TOTP."""
|
||||
|
||||
@@ -59,7 +59,7 @@ class AuditService:
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
metadata=metadata,
|
||||
extra_data=metadata,
|
||||
description=description,
|
||||
success=success,
|
||||
error_message=error_message,
|
||||
|
||||
@@ -102,7 +102,7 @@ class AuthService:
|
||||
if current_app.config.get('ENV') == 'development':
|
||||
logger.debug(f"[Auth] Account status: user_id={user.id}, status={user.status}")
|
||||
|
||||
if user.status == UserStatus.SUSPENDED:
|
||||
if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
|
||||
raise AccountSuspendedError()
|
||||
if user.status == UserStatus.INACTIVE:
|
||||
raise AccountInactiveError()
|
||||
@@ -210,6 +210,22 @@ class AuthService:
|
||||
auth_method.password_hash = bcrypt.generate_password_hash(new_password).decode("utf-8")
|
||||
db.session.commit()
|
||||
|
||||
# Invalidate all other sessions so that if an attacker had a valid
|
||||
# session token, changing the password actually locks them out.
|
||||
# The current request's session (if any) is preserved so the user
|
||||
# doesn't have to log in again immediately.
|
||||
from flask import g as flask_g
|
||||
current_session_id = getattr(flask_g, "current_session", None)
|
||||
current_session_id = current_session_id.id if current_session_id else None
|
||||
sessions_to_revoke = Session.query.filter(
|
||||
Session.user_id == user.id,
|
||||
Session.revoked_at == None, # noqa: E711
|
||||
).all()
|
||||
for sess in sessions_to_revoke:
|
||||
if sess.id != current_session_id:
|
||||
sess.revoke(reason="Password changed")
|
||||
db.session.commit()
|
||||
|
||||
# Log password change
|
||||
AuditService.log_action(
|
||||
action=AuditAction.PASSWORD_CHANGE,
|
||||
@@ -482,9 +498,24 @@ class AuthService:
|
||||
if not secret:
|
||||
raise InvalidCredentialsError("TOTP secret not found")
|
||||
|
||||
# Replay-attack prevention: reject codes that have already been
|
||||
# accepted within the current validity window.
|
||||
if TOTPService.is_code_already_used(str(user.id), code):
|
||||
AuditService.log_action(
|
||||
action=AuditAction.TOTP_VERIFY_FAILED,
|
||||
user_id=user.id,
|
||||
resource_type="authentication_method",
|
||||
resource_id=auth_method.id,
|
||||
description="TOTP code replay attempt detected",
|
||||
)
|
||||
raise InvalidCredentialsError("Invalid TOTP code")
|
||||
|
||||
is_valid = TOTPService.verify_code(secret, code, client_utc_timestamp=client_utc_timestamp)
|
||||
|
||||
if is_valid:
|
||||
# Mark this code as used to prevent replay within the validity window
|
||||
TOTPService.mark_code_used(str(user.id), code)
|
||||
|
||||
auth_method.last_used_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@@ -736,9 +736,14 @@ class ExternalAuthService:
|
||||
400,
|
||||
)
|
||||
|
||||
# Generate PKCE
|
||||
code_verifier = secrets.token_urlsafe(32)
|
||||
code_challenge = cls._compute_s256_challenge(code_verifier)
|
||||
# Generate PKCE — skip for confidential clients (Google, Microsoft) that use a
|
||||
# client_secret. Sending code_challenge to Microsoft causes it to enforce PKCE on
|
||||
# the token exchange, which then fails. Matches the behaviour of initiate_login_flow.
|
||||
code_verifier = None
|
||||
code_challenge = None
|
||||
if provider_type_str not in ('google', 'microsoft'):
|
||||
code_verifier = secrets.token_urlsafe(32)
|
||||
code_challenge = cls._compute_s256_challenge(code_verifier)
|
||||
|
||||
# Create OAuth state
|
||||
state = OAuthState.create_state(
|
||||
|
||||
@@ -188,11 +188,10 @@ class OrganizationService:
|
||||
Raises:
|
||||
ConflictError: If user is already a member
|
||||
"""
|
||||
# Check if already a member
|
||||
# Check if already a member (active or soft-deleted — both blocked by DB unique constraint)
|
||||
existing = OrganizationMember.query.filter_by(
|
||||
user_id=user_id,
|
||||
organization_id=org.id,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
|
||||
# Development-only debug logging for membership validation
|
||||
@@ -200,6 +199,25 @@ class OrganizationService:
|
||||
logger.debug(f"[Org] Member check: org_id={org.id}, user_id={user_id}, already_member={existing is not None}")
|
||||
|
||||
if existing:
|
||||
if existing.deleted_at is not None:
|
||||
# Reactivate the soft-deleted membership with the new role
|
||||
existing.deleted_at = None
|
||||
existing.role = role
|
||||
existing.invited_by_id = inviter_id
|
||||
existing.invited_at = datetime.now(timezone.utc)
|
||||
existing.joined_at = datetime.now(timezone.utc)
|
||||
existing.save()
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_MEMBER_ADD,
|
||||
user_id=inviter_id,
|
||||
organization_id=org.id,
|
||||
resource_type="organization_member",
|
||||
resource_id=existing.id,
|
||||
metadata={"added_user_id": user_id, "role": role.value},
|
||||
description=f"Member re-added to organization with role: {role.value}",
|
||||
)
|
||||
return existing
|
||||
raise ConflictError("User is already a member of this organization")
|
||||
|
||||
# Create membership
|
||||
|
||||
@@ -192,13 +192,19 @@ class SSHCASigningService:
|
||||
self,
|
||||
signing_request: SSHCertificateSigningRequest,
|
||||
ca_private_key: Optional[str] = None,
|
||||
) -> SSHCertificateSigningResponse:
|
||||
ca_obj=None,
|
||||
) -> "SSHCertificateSigningResponse":
|
||||
"""Sign an SSH certificate.
|
||||
|
||||
Args:
|
||||
signing_request: SSHCertificateSigningRequest instance
|
||||
ca_private_key: CA private key in PEM format. If not provided,
|
||||
loaded from config (ca_key_path or SSH_CA_PRIVATE_KEY env var)
|
||||
ca_obj: Optional CA model instance. When supplied its monotonic
|
||||
serial counter is incremented atomically (SELECT FOR UPDATE)
|
||||
and the resulting integer is embedded in the certificate's
|
||||
serial field. This ensures every issued cert has a unique,
|
||||
ordered, auditable serial number.
|
||||
|
||||
Returns:
|
||||
SSHCertificateSigningResponse with signed certificate
|
||||
@@ -245,13 +251,27 @@ class SSHCASigningService:
|
||||
valid_before = now + timedelta(hours=expiry_hours)
|
||||
|
||||
# Set certificate fields
|
||||
cert_type = 1 if signing_request.cert_type == "user" else 0
|
||||
# sshkey-tools: user=1, host=2 (not 0)
|
||||
cert_type = 1 if signing_request.cert_type == "user" else 2
|
||||
|
||||
certificate.fields.cert_type = cert_type
|
||||
certificate.fields.key_id = signing_request.key_id
|
||||
certificate.fields.principals = signing_request.principals
|
||||
certificate.fields.valid_after = now
|
||||
certificate.fields.valid_before = valid_before
|
||||
|
||||
# ── Serial number ────────────────────────────────────────────────
|
||||
# If a CA object is provided, use its monotonic counter so every
|
||||
# certificate gets a unique, ordered, auditable serial. The
|
||||
# counter increment is flushed inside get_next_serial(); the
|
||||
# caller's commit() persists it atomically with the cert record.
|
||||
if ca_obj is not None:
|
||||
assigned_serial = ca_obj.get_next_serial()
|
||||
certificate.fields.serial = assigned_serial
|
||||
self.logger.debug(
|
||||
f"Assigned serial {assigned_serial} from CA {ca_obj.id}"
|
||||
)
|
||||
# ─────────────────────────────────────────────────────────────────
|
||||
|
||||
# Set extensions — prefer policy-provided list, fall back to standard set
|
||||
extensions = signing_request.extensions
|
||||
@@ -276,8 +296,13 @@ class SSHCASigningService:
|
||||
self.logger.error(f"Certificate verification failed: {str(e)}")
|
||||
raise SSHCASigningError(f"Certificate verification failed: {str(e)}")
|
||||
|
||||
# Extract serial from certificate
|
||||
serial = str(certificate.fields.serial).split(":")[-1].strip() if hasattr(certificate.fields.serial, '__str__') else str(certificate.fields.serial)
|
||||
# Extract serial from certificate — use the integer we assigned
|
||||
# when ca_obj was provided, otherwise fall back to whatever the
|
||||
# library generated.
|
||||
if ca_obj is not None:
|
||||
serial = str(assigned_serial)
|
||||
else:
|
||||
serial = str(certificate.fields.serial).split(":")[-1].strip() if hasattr(certificate.fields.serial, '__str__') else str(certificate.fields.serial)
|
||||
|
||||
# Build response
|
||||
cert_string = certificate.to_string()
|
||||
|
||||
@@ -11,10 +11,51 @@ from gatehouse_app.extensions import bcrypt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# TOTP codes are valid for at most (2*window + 1) * 30s steps.
|
||||
# With window=1 that's 3 steps = 90 seconds. We use a slightly
|
||||
# generous TTL of 95 seconds to account for clock skew at boundaries.
|
||||
_TOTP_USED_CODE_TTL = 95
|
||||
|
||||
|
||||
class TOTPService:
|
||||
"""Service for TOTP operations."""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Replay-attack prevention helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _used_key(user_id: str, code: str) -> str:
|
||||
return f"totp:used:{user_id}:{code}"
|
||||
|
||||
@staticmethod
|
||||
def is_code_already_used(user_id: str, code: str) -> bool:
|
||||
"""Return True if *code* has already been accepted for *user_id*
|
||||
within the current validity window (prevents replay attacks)."""
|
||||
try:
|
||||
from gatehouse_app.extensions import redis_client
|
||||
if redis_client is None:
|
||||
return False
|
||||
return redis_client.exists(TOTPService._used_key(user_id, code)) == 1
|
||||
except Exception:
|
||||
logger.warning("Redis unavailable for TOTP replay check; allowing code")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def mark_code_used(user_id: str, code: str) -> None:
|
||||
"""Record *code* as consumed for *user_id* so it cannot be reused."""
|
||||
try:
|
||||
from gatehouse_app.extensions import redis_client
|
||||
if redis_client is None:
|
||||
return
|
||||
redis_client.setex(
|
||||
TOTPService._used_key(user_id, code),
|
||||
_TOTP_USED_CODE_TTL,
|
||||
"1",
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Redis unavailable; TOTP used-code not recorded")
|
||||
|
||||
@staticmethod
|
||||
def generate_secret() -> str:
|
||||
"""
|
||||
|
||||
@@ -641,6 +641,26 @@ class WebAuthnService:
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def credential_belongs_to_user(cls, credential_id: str, user: User) -> bool:
|
||||
"""Check whether *credential_id* exists and belongs to *user*.
|
||||
|
||||
Args:
|
||||
credential_id: The credential ID to look up
|
||||
user: User instance
|
||||
|
||||
Returns:
|
||||
True if the credential exists and belongs to this user, False otherwise.
|
||||
"""
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
user_id=user.id,
|
||||
method_type=AuthMethodType.WEBAUTHN,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
if not auth_method or not auth_method.provider_data:
|
||||
return False
|
||||
return auth_method.provider_data.get("credential_id") == credential_id
|
||||
|
||||
@classmethod
|
||||
def rename_credential(cls, credential_id: str, user: User, name: str) -> bool:
|
||||
|
||||
@@ -0,0 +1,206 @@
|
||||
"""Encryption helpers for CA private keys stored in the database.
|
||||
|
||||
CA private keys are encrypted at rest using Fernet (AES-128-CBC + HMAC-SHA256)
|
||||
from the ``cryptography`` package. The encryption key is derived from the
|
||||
``CA_ENCRYPTION_KEY`` environment variable (or ``Flask.config["CA_ENCRYPTION_KEY"]``).
|
||||
|
||||
Key derivation
|
||||
--------------
|
||||
Fernet requires a URL-safe base64-encoded 32-byte key. We accept any string
|
||||
from the env and derive the actual Fernet key using SHA-256 so that operators
|
||||
can supply human-readable secrets without having to pre-encode them.
|
||||
|
||||
Envelope format
|
||||
---------------
|
||||
Encrypted values are stored as the string::
|
||||
|
||||
$fernet$<fernet_token>
|
||||
|
||||
The ``$fernet$`` prefix lets the code distinguish already-encrypted values from
|
||||
legacy plaintext PEM keys so that the migration path is safe and idempotent.
|
||||
|
||||
Usage
|
||||
-----
|
||||
Encrypt before storing::
|
||||
|
||||
from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key
|
||||
ca.private_key = encrypt_ca_key(private_key_pem)
|
||||
|
||||
Decrypt before use::
|
||||
|
||||
from gatehouse_app.utils.ca_key_encryption import decrypt_ca_key
|
||||
plaintext_pem = decrypt_ca_key(ca.private_key)
|
||||
"""
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Prefix that marks a stored value as Fernet-encrypted
|
||||
_FERNET_PREFIX = "$fernet$"
|
||||
|
||||
|
||||
class CAKeyEncryptionError(Exception):
|
||||
"""Raised when CA key encryption or decryption fails."""
|
||||
|
||||
|
||||
def _get_fernet() -> Fernet:
|
||||
"""Build a Fernet instance from the configured encryption key.
|
||||
|
||||
Looks up ``CA_ENCRYPTION_KEY`` in the environment first, then falls back to
|
||||
the Flask app config (if a request context is active).
|
||||
|
||||
Raises:
|
||||
CAKeyEncryptionError: if no key is configured or it is the insecure
|
||||
placeholder value in a production-like environment.
|
||||
"""
|
||||
raw_key = os.environ.get("CA_ENCRYPTION_KEY")
|
||||
|
||||
if not raw_key:
|
||||
# Try Flask config if we're inside an app context
|
||||
try:
|
||||
from flask import current_app
|
||||
raw_key = current_app.config.get("CA_ENCRYPTION_KEY")
|
||||
except RuntimeError:
|
||||
pass # No app context
|
||||
|
||||
if not raw_key:
|
||||
raise CAKeyEncryptionError(
|
||||
"CA_ENCRYPTION_KEY is not set. "
|
||||
"Set this environment variable before starting the application."
|
||||
)
|
||||
|
||||
# Warn loudly when running with the placeholder in a non-test environment
|
||||
env_name = os.environ.get("FLASK_ENV", "").lower()
|
||||
if raw_key.startswith("dev-") and env_name not in ("development", "testing", "test"):
|
||||
logger.warning(
|
||||
"CA_ENCRYPTION_KEY appears to be a development placeholder. "
|
||||
"Set a strong random key for production environments."
|
||||
)
|
||||
|
||||
# Derive a 32-byte key from the raw secret via SHA-256, then URL-safe base64
|
||||
key_bytes = hashlib.sha256(raw_key.encode()).digest()
|
||||
fernet_key = base64.urlsafe_b64encode(key_bytes)
|
||||
return Fernet(fernet_key)
|
||||
|
||||
|
||||
def encrypt_ca_key(plaintext_pem: str) -> str:
|
||||
"""Encrypt a CA private key PEM string.
|
||||
|
||||
Idempotent: already-encrypted values are returned unchanged.
|
||||
|
||||
Args:
|
||||
plaintext_pem: CA private key in OpenSSH/PEM format.
|
||||
|
||||
Returns:
|
||||
Encrypted string with ``$fernet$`` prefix, safe for database storage.
|
||||
|
||||
Raises:
|
||||
CAKeyEncryptionError: if the key cannot be encrypted.
|
||||
"""
|
||||
if not plaintext_pem:
|
||||
raise CAKeyEncryptionError("Cannot encrypt an empty key")
|
||||
|
||||
# Already encrypted — do not double-encrypt
|
||||
if plaintext_pem.startswith(_FERNET_PREFIX):
|
||||
return plaintext_pem
|
||||
|
||||
try:
|
||||
fernet = _get_fernet()
|
||||
token = fernet.encrypt(plaintext_pem.encode()).decode()
|
||||
return f"{_FERNET_PREFIX}{token}"
|
||||
except CAKeyEncryptionError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise CAKeyEncryptionError(f"Failed to encrypt CA key: {exc}") from exc
|
||||
|
||||
|
||||
def decrypt_ca_key(stored_value: str) -> str:
|
||||
"""Decrypt a CA private key retrieved from the database.
|
||||
|
||||
Idempotent: plaintext (legacy) values are returned unchanged so that the
|
||||
system continues to work while a migration encrypts existing rows.
|
||||
|
||||
Args:
|
||||
stored_value: Value from ``CA.private_key`` column.
|
||||
|
||||
Returns:
|
||||
Plaintext PEM string ready for use with ``sshkey_tools``.
|
||||
|
||||
Raises:
|
||||
CAKeyEncryptionError: if decryption fails (wrong key, corrupted data).
|
||||
"""
|
||||
if not stored_value:
|
||||
raise CAKeyEncryptionError("Cannot decrypt an empty value")
|
||||
|
||||
# Legacy plaintext key — return as-is
|
||||
if not stored_value.startswith(_FERNET_PREFIX):
|
||||
logger.warning(
|
||||
"CA private key appears to be stored as plaintext. "
|
||||
"Run the migration to encrypt existing keys."
|
||||
)
|
||||
return stored_value
|
||||
|
||||
token = stored_value[len(_FERNET_PREFIX):]
|
||||
try:
|
||||
fernet = _get_fernet()
|
||||
return fernet.decrypt(token.encode()).decode()
|
||||
except InvalidToken as exc:
|
||||
raise CAKeyEncryptionError(
|
||||
"CA key decryption failed — the CA_ENCRYPTION_KEY may be incorrect "
|
||||
"or the stored key is corrupted."
|
||||
) from exc
|
||||
except CAKeyEncryptionError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise CAKeyEncryptionError(f"Unexpected decryption error: {exc}") from exc
|
||||
|
||||
|
||||
def is_encrypted(stored_value: str) -> bool:
|
||||
"""Return True if the stored value has the ``$fernet$`` envelope.
|
||||
|
||||
Args:
|
||||
stored_value: Value from ``CA.private_key`` column.
|
||||
"""
|
||||
return bool(stored_value and stored_value.startswith(_FERNET_PREFIX))
|
||||
|
||||
|
||||
def reencrypt_ca_key(stored_value: str, old_raw_key: str, new_raw_key: str) -> str:
|
||||
"""Re-encrypt a CA key with a new encryption key (for key rotation).
|
||||
|
||||
Args:
|
||||
stored_value: Current value from ``CA.private_key`` (may or may not be encrypted).
|
||||
old_raw_key: The current ``CA_ENCRYPTION_KEY`` value (raw secret string).
|
||||
new_raw_key: The new ``CA_ENCRYPTION_KEY`` value to encrypt with.
|
||||
|
||||
Returns:
|
||||
New encrypted envelope string.
|
||||
|
||||
Raises:
|
||||
CAKeyEncryptionError: if decryption or re-encryption fails.
|
||||
"""
|
||||
# Decrypt with old key
|
||||
if stored_value.startswith(_FERNET_PREFIX):
|
||||
token = stored_value[len(_FERNET_PREFIX):]
|
||||
old_key_bytes = base64.urlsafe_b64encode(hashlib.sha256(old_raw_key.encode()).digest())
|
||||
try:
|
||||
plaintext = Fernet(old_key_bytes).decrypt(token.encode()).decode()
|
||||
except InvalidToken as exc:
|
||||
raise CAKeyEncryptionError(
|
||||
"Re-encryption failed: could not decrypt with the old key."
|
||||
) from exc
|
||||
else:
|
||||
# Plaintext
|
||||
plaintext = stored_value
|
||||
|
||||
# Re-encrypt with new key
|
||||
new_key_bytes = base64.urlsafe_b64encode(hashlib.sha256(new_raw_key.encode()).digest())
|
||||
try:
|
||||
token = Fernet(new_key_bytes).encrypt(plaintext.encode()).decode()
|
||||
return f"{_FERNET_PREFIX}{token}"
|
||||
except Exception as exc:
|
||||
raise CAKeyEncryptionError(f"Re-encryption with new key failed: {exc}") from exc
|
||||
@@ -61,6 +61,7 @@ class AuditAction(str, Enum):
|
||||
USER_REGISTER = "user.register"
|
||||
USER_UPDATE = "user.update"
|
||||
USER_DELETE = "user.delete"
|
||||
USER_HARD_DELETE = "user.hard_delete"
|
||||
USER_SUSPEND = "user.suspend"
|
||||
USER_UNSUSPEND = "user.unsuspend"
|
||||
PASSWORD_CHANGE = "user.password_change"
|
||||
@@ -73,6 +74,7 @@ class AuditAction(str, Enum):
|
||||
ORG_MEMBER_ADD = "org.member.add"
|
||||
ORG_MEMBER_REMOVE = "org.member.remove"
|
||||
ORG_MEMBER_ROLE_CHANGE = "org.member.role_change"
|
||||
ORG_OWNERSHIP_TRANSFERRED = "org.ownership.transferred"
|
||||
|
||||
# Session actions
|
||||
SESSION_CREATE = "session.create"
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Add USER_SUSPEND and USER_UNSUSPEND to auditaction enum.
|
||||
|
||||
Revision ID: 015_add_user_suspend_audit_actions
|
||||
Revises: 014_add_dept_cert_policy
|
||||
Create Date: 2026-03-02
|
||||
|
||||
USER_SUSPEND and USER_UNSUSPEND were added to the Python AuditAction enum
|
||||
but were never synced to the PostgreSQL auditaction type, causing a
|
||||
DataError (invalid enum value) whenever an admin suspends or unsuspends a user.
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
revision = "015_user_suspend_audit"
|
||||
down_revision = "014_add_dept_cert_policy"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
for val in ("USER_SUSPEND", "USER_UNSUSPEND"):
|
||||
op.execute(f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_enum
|
||||
WHERE enumlabel = '{val}'
|
||||
AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'auditaction')
|
||||
) THEN
|
||||
ALTER TYPE auditaction ADD VALUE '{val}';
|
||||
END IF;
|
||||
END$$;
|
||||
""")
|
||||
|
||||
|
||||
def downgrade():
|
||||
# PostgreSQL does not support removing enum values; downgrade is a no-op.
|
||||
pass
|
||||
@@ -0,0 +1,168 @@
|
||||
"""Encrypt existing plaintext CA private keys at rest.
|
||||
|
||||
Revision ID: 016_encrypt_existing_ca_keys
|
||||
Revises: 015_add_user_suspend_audit_actions
|
||||
Create Date: 2026-03-02
|
||||
|
||||
All CA private keys created before this migration were stored as plaintext PEM
|
||||
strings in the ``cas.private_key`` column. This migration detects those rows
|
||||
(by checking for the absence of the ``$fernet$`` prefix that encrypted values
|
||||
carry) and re-encrypts them with the key derived from ``CA_ENCRYPTION_KEY``.
|
||||
|
||||
The migration is safe to re-run: already-encrypted rows are left untouched.
|
||||
|
||||
Prerequisites
|
||||
-------------
|
||||
``CA_ENCRYPTION_KEY`` must be set in the environment before running this
|
||||
migration. The same value must be configured for the running application.
|
||||
|
||||
To roll back to plaintext (downgrade):
|
||||
The ``downgrade()`` function decrypts all rows back to plaintext PEM. This is
|
||||
provided only for emergency rollback and should not be used in production once
|
||||
the system has been running with encrypted keys.
|
||||
"""
|
||||
import os
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Alembic revision identifiers
|
||||
revision = "016_encrypt_ca_keys"
|
||||
down_revision = "015_user_suspend_audit"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
_FERNET_PREFIX = "$fernet$"
|
||||
|
||||
|
||||
def _get_fernet():
|
||||
"""Build a Fernet instance from CA_ENCRYPTION_KEY env var."""
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
raw_key = os.environ.get("CA_ENCRYPTION_KEY")
|
||||
if not raw_key:
|
||||
raise RuntimeError(
|
||||
"CA_ENCRYPTION_KEY environment variable is not set. "
|
||||
"Set it before running this migration."
|
||||
)
|
||||
key_bytes = base64.urlsafe_b64encode(hashlib.sha256(raw_key.encode()).digest())
|
||||
return Fernet(key_bytes)
|
||||
|
||||
|
||||
def upgrade():
|
||||
"""Encrypt plaintext CA private keys."""
|
||||
bind = op.get_bind()
|
||||
session = Session(bind=bind)
|
||||
|
||||
try:
|
||||
fernet = _get_fernet()
|
||||
except RuntimeError as exc:
|
||||
raise RuntimeError(str(exc)) from exc
|
||||
|
||||
# Fetch all non-deleted CA rows
|
||||
rows = session.execute(
|
||||
sa.text("SELECT id, private_key FROM cas WHERE deleted_at IS NULL")
|
||||
).fetchall()
|
||||
|
||||
encrypted_count = 0
|
||||
skipped_count = 0
|
||||
|
||||
for row in rows:
|
||||
ca_id, private_key = row[0], row[1]
|
||||
|
||||
if not private_key:
|
||||
logger.warning(f"CA {ca_id} has empty private_key — skipping")
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
if private_key.startswith(_FERNET_PREFIX):
|
||||
# Already encrypted
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
# Encrypt
|
||||
try:
|
||||
token = fernet.encrypt(private_key.encode()).decode()
|
||||
encrypted_value = f"{_FERNET_PREFIX}{token}"
|
||||
session.execute(
|
||||
sa.text("UPDATE cas SET private_key = :pk WHERE id = :id"),
|
||||
{"pk": encrypted_value, "id": ca_id},
|
||||
)
|
||||
encrypted_count += 1
|
||||
logger.info(f"Encrypted private key for CA {ca_id}")
|
||||
except Exception as exc:
|
||||
session.rollback()
|
||||
raise RuntimeError(
|
||||
f"Failed to encrypt private key for CA {ca_id}: {exc}"
|
||||
) from exc
|
||||
|
||||
session.commit()
|
||||
logger.info(
|
||||
f"CA key encryption migration complete: "
|
||||
f"{encrypted_count} encrypted, {skipped_count} skipped"
|
||||
)
|
||||
print(
|
||||
f" [016_encrypt_ca_keys] {encrypted_count} CA private key(s) encrypted, "
|
||||
f"{skipped_count} already encrypted or empty."
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
"""Decrypt CA private keys back to plaintext (emergency rollback only)."""
|
||||
bind = op.get_bind()
|
||||
session = Session(bind=bind)
|
||||
|
||||
try:
|
||||
fernet = _get_fernet()
|
||||
except RuntimeError as exc:
|
||||
raise RuntimeError(str(exc)) from exc
|
||||
|
||||
rows = session.execute(
|
||||
sa.text("SELECT id, private_key FROM cas WHERE deleted_at IS NULL")
|
||||
).fetchall()
|
||||
|
||||
decrypted_count = 0
|
||||
skipped_count = 0
|
||||
|
||||
for row in rows:
|
||||
ca_id, private_key = row[0], row[1]
|
||||
|
||||
if not private_key or not private_key.startswith(_FERNET_PREFIX):
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
token = private_key[len(_FERNET_PREFIX):]
|
||||
try:
|
||||
from cryptography.fernet import InvalidToken
|
||||
try:
|
||||
plaintext = fernet.decrypt(token.encode()).decode()
|
||||
except InvalidToken as exc:
|
||||
raise RuntimeError(
|
||||
f"Downgrade failed: cannot decrypt CA {ca_id} — wrong key or corrupted data."
|
||||
) from exc
|
||||
|
||||
session.execute(
|
||||
sa.text("UPDATE cas SET private_key = :pk WHERE id = :id"),
|
||||
{"pk": plaintext, "id": ca_id},
|
||||
)
|
||||
decrypted_count += 1
|
||||
logger.warning(f"Decrypted (plaintext restore) private key for CA {ca_id}")
|
||||
except RuntimeError:
|
||||
session.rollback()
|
||||
raise
|
||||
|
||||
session.commit()
|
||||
logger.warning(
|
||||
f"CA key decryption (downgrade) complete: "
|
||||
f"{decrypted_count} decrypted, {skipped_count} skipped"
|
||||
)
|
||||
print(
|
||||
f" [016_encrypt_ca_keys] DOWNGRADE: {decrypted_count} CA private key(s) "
|
||||
f"decrypted to plaintext. WARNING: keys are now unencrypted at rest."
|
||||
)
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Add monotonic serial counter to CAs table.
|
||||
|
||||
Each CA now owns a `next_serial_number` (BigInteger) that is atomically
|
||||
incremented every time a certificate is signed. This guarantees:
|
||||
- Serials are unique per CA
|
||||
- Serials are monotonically increasing (auditable, no gaps by accident)
|
||||
- The value embedded in the OpenSSH certificate matches what is stored
|
||||
in the `ssh_certificates.serial` column
|
||||
|
||||
Revision ID: 017_add_ca_serial_counter
|
||||
Revises: 016_encrypt_ca_keys
|
||||
Create Date: 2026-03-02
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "017_add_ca_serial_counter"
|
||||
down_revision = "016_encrypt_ca_keys"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
with op.batch_alter_table("cas", schema=None) as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column(
|
||||
"next_serial_number",
|
||||
sa.BigInteger(),
|
||||
nullable=False,
|
||||
server_default="1",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
with op.batch_alter_table("cas", schema=None) as batch_op:
|
||||
batch_op.drop_column("next_serial_number")
|
||||
@@ -0,0 +1,52 @@
|
||||
"""Add ORG_OWNERSHIP_TRANSFERRED and USER_HARD_DELETE to auditaction enum.
|
||||
|
||||
Revision ID: 018_audit_enum_values
|
||||
Revises: 017_add_ca_serial_counter
|
||||
Create Date: 2026-03-02
|
||||
|
||||
ORG_OWNERSHIP_TRANSFERRED and USER_HARD_DELETE were added to the Python
|
||||
AuditAction enum but were never synced to the PostgreSQL auditaction type,
|
||||
causing a DataError (invalid enum value) when transferring org ownership
|
||||
or hard-deleting a user.
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
revision = "018_audit_enum_values"
|
||||
down_revision = "017_add_ca_serial_counter"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ALTER TYPE ... ADD VALUE cannot run inside a transaction block in PostgreSQL.
|
||||
# Alembic has already opened a transaction on the connection by the time our
|
||||
# upgrade() runs, so we must:
|
||||
# 1. Roll back that open transaction on the raw psycopg2 connection.
|
||||
# 2. Switch to autocommit so the ALTER TYPE runs outside any transaction.
|
||||
# 3. Restore the previous state afterwards.
|
||||
conn = op.get_bind()
|
||||
# SQLAlchemy 2.x: conn.connection is a _ConnectionFairy; .driver_connection is psycopg2
|
||||
fairy = conn.connection
|
||||
raw = getattr(fairy, "driver_connection", None) or getattr(fairy, "dbapi_connection", fairy)
|
||||
# Roll back the open transaction so psycopg2 allows us to change autocommit.
|
||||
raw.rollback()
|
||||
old_autocommit = raw.autocommit
|
||||
raw.autocommit = True
|
||||
try:
|
||||
with raw.cursor() as cur:
|
||||
for val in ("ORG_OWNERSHIP_TRANSFERRED", "USER_HARD_DELETE"):
|
||||
cur.execute(
|
||||
"SELECT 1 FROM pg_enum "
|
||||
"WHERE enumlabel = %s "
|
||||
"AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'auditaction')",
|
||||
(val,),
|
||||
)
|
||||
if not cur.fetchone():
|
||||
cur.execute(f"ALTER TYPE auditaction ADD VALUE '{val}'")
|
||||
finally:
|
||||
raw.autocommit = old_autocommit
|
||||
|
||||
|
||||
def downgrade():
|
||||
# PostgreSQL does not support removing enum values; downgrade is a no-op.
|
||||
pass
|
||||
Reference in New Issue
Block a user