Fix(Feat): CA, Audits, Rte Limit
CA Encryption, Serials, Rate Limiter, Account suspension blocks login Transfer Ownership & Delete Account
This commit is contained in:
@@ -51,8 +51,7 @@ class MyServer(BaseHTTPRequestHandler):
|
|||||||
self.end_headers()
|
self.end_headers()
|
||||||
self.wfile.write(bytes("<html><head><title>OIDC Workflow Tool</title></head>", "utf-8"))
|
self.wfile.write(bytes("<html><head><title>OIDC Workflow Tool</title></head>", "utf-8"))
|
||||||
self.wfile.write(bytes("<body><p>The token has been received</p>", "utf-8"))
|
self.wfile.write(bytes("<body><p>The token has been received</p>", "utf-8"))
|
||||||
self.wfile.write(bytes("<p>Window closing in <span id='countdown'>5</span> seconds...</p>", "utf-8"))
|
self.wfile.write(bytes("<p>You may now close this window.</p>", "utf-8"))
|
||||||
self.wfile.write(bytes("<script>var count = 5; setInterval(function() { count--; document.getElementById('countdown').textContent = count; if (count === 0) window.close(); }, 1000);</script>", "utf-8"))
|
|
||||||
self.wfile.write(bytes("</body></html>", "utf-8"))
|
self.wfile.write(bytes("</body></html>", "utf-8"))
|
||||||
|
|
||||||
parsed_url = urlparse(self.path)
|
parsed_url = urlparse(self.path)
|
||||||
|
|||||||
@@ -29,6 +29,11 @@ class BaseConfig:
|
|||||||
# Encryption key for sensitive data (client secrets, tokens, etc.)
|
# Encryption key for sensitive data (client secrets, tokens, etc.)
|
||||||
ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY", "dev-encryption-key-change-in-production")
|
ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY", "dev-encryption-key-change-in-production")
|
||||||
|
|
||||||
|
# Encryption key for CA private keys stored in the database.
|
||||||
|
# Must be set to a strong random secret in production.
|
||||||
|
# Any string is accepted — it is SHA-256 derived to a 32-byte Fernet key internally.
|
||||||
|
CA_ENCRYPTION_KEY = os.getenv("CA_ENCRYPTION_KEY", "dev-ca-encryption-key-change-in-production")
|
||||||
|
|
||||||
# Session configuration for WebAuthn cross-origin support
|
# Session configuration for WebAuthn cross-origin support
|
||||||
SESSION_COOKIE_SECURE = os.getenv("SESSION_COOKIE_SECURE", "True").lower() == "true"
|
SESSION_COOKIE_SECURE = os.getenv("SESSION_COOKIE_SECURE", "True").lower() == "true"
|
||||||
SESSION_COOKIE_HTTPONLY = True
|
SESSION_COOKIE_HTTPONLY = True
|
||||||
@@ -72,6 +77,13 @@ class BaseConfig:
|
|||||||
RATELIMIT_STORAGE_URL = os.getenv("RATELIMIT_STORAGE_URL", "redis://localhost:6379/1")
|
RATELIMIT_STORAGE_URL = os.getenv("RATELIMIT_STORAGE_URL", "redis://localhost:6379/1")
|
||||||
RATELIMIT_DEFAULT = "100/hour"
|
RATELIMIT_DEFAULT = "100/hour"
|
||||||
|
|
||||||
|
# Per-endpoint auth rate limits (override via env vars for each environment)
|
||||||
|
RATELIMIT_AUTH_REGISTER = os.getenv("RATELIMIT_AUTH_REGISTER", "10 per minute; 50 per hour")
|
||||||
|
RATELIMIT_AUTH_LOGIN = os.getenv("RATELIMIT_AUTH_LOGIN", "20 per minute; 100 per hour")
|
||||||
|
RATELIMIT_AUTH_TOTP_VERIFY = os.getenv("RATELIMIT_AUTH_TOTP_VERIFY", "20 per minute; 100 per hour")
|
||||||
|
RATELIMIT_AUTH_FORGOT_PASSWORD = os.getenv("RATELIMIT_AUTH_FORGOT_PASSWORD", "5 per minute; 20 per hour")
|
||||||
|
RATELIMIT_AUTH_RESET_PASSWORD = os.getenv("RATELIMIT_AUTH_RESET_PASSWORD", "10 per minute; 30 per hour")
|
||||||
|
|
||||||
# Logging
|
# Logging
|
||||||
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
|
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
|
||||||
LOG_TO_STDOUT = os.getenv("LOG_TO_STDOUT", "False").lower() == "true"
|
LOG_TO_STDOUT = os.getenv("LOG_TO_STDOUT", "False").lower() == "true"
|
||||||
|
|||||||
@@ -12,6 +12,9 @@ class TestingConfig(BaseConfig):
|
|||||||
# Explicitly set SECRET_KEY for testing
|
# Explicitly set SECRET_KEY for testing
|
||||||
SECRET_KEY = os.getenv("SECRET_KEY", "test-secret-key-for-testing")
|
SECRET_KEY = os.getenv("SECRET_KEY", "test-secret-key-for-testing")
|
||||||
|
|
||||||
|
# CA key encryption — use a fixed test key so tests are deterministic
|
||||||
|
CA_ENCRYPTION_KEY = os.getenv("CA_ENCRYPTION_KEY", "test-ca-encryption-key-fixed-for-tests")
|
||||||
|
|
||||||
# Use in-memory SQLite for testing
|
# Use in-memory SQLite for testing
|
||||||
SQLALCHEMY_DATABASE_URI = "sqlite:///:memory:"
|
SQLALCHEMY_DATABASE_URI = "sqlite:///:memory:"
|
||||||
SQLALCHEMY_ECHO = False
|
SQLALCHEMY_ECHO = False
|
||||||
|
|||||||
@@ -22,7 +22,11 @@ from gatehouse_app.extensions import bcrypt as flask_bcrypt
|
|||||||
from gatehouse_app.extensions import redis_client as _redis_client_ref # may be None until app init
|
from gatehouse_app.extensions import redis_client as _redis_client_ref # may be None until app init
|
||||||
from gatehouse_app.models import User, OIDCClient
|
from gatehouse_app.models import User, OIDCClient
|
||||||
from gatehouse_app.models.organization.organization import Organization
|
from gatehouse_app.models.organization.organization import Organization
|
||||||
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
|
from gatehouse_app.exceptions.auth_exceptions import (
|
||||||
|
InvalidCredentialsError,
|
||||||
|
AccountSuspendedError,
|
||||||
|
AccountInactiveError,
|
||||||
|
)
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Helpers for Redis-backed OIDC pending state
|
# Helpers for Redis-backed OIDC pending state
|
||||||
@@ -343,6 +347,20 @@ def oidc_complete():
|
|||||||
|
|
||||||
user_id = str(gh_session.user_id)
|
user_id = str(gh_session.user_id)
|
||||||
|
|
||||||
|
# Check the user is still active (not suspended after session was issued)
|
||||||
|
from gatehouse_app.models.user.user import User as _User
|
||||||
|
from gatehouse_app.utils.constants import UserStatus
|
||||||
|
_complete_user = _User.query.filter_by(id=user_id, deleted_at=None).first()
|
||||||
|
if not _complete_user or _complete_user.status in (
|
||||||
|
UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED, UserStatus.INACTIVE
|
||||||
|
):
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="Your account is not active or has been suspended.",
|
||||||
|
status=403,
|
||||||
|
error_type="ACCOUNT_SUSPENDED",
|
||||||
|
)
|
||||||
|
|
||||||
# Retrieve stashed OIDC params (consume = True removes from Redis atomically)
|
# Retrieve stashed OIDC params (consume = True removes from Redis atomically)
|
||||||
params = _fetch_oidc_params(oidc_session_id, consume=True)
|
params = _fetch_oidc_params(oidc_session_id, consume=True)
|
||||||
if not params:
|
if not params:
|
||||||
@@ -565,6 +583,28 @@ def oidc_authorize():
|
|||||||
session["oidc_user_id"] = user_id
|
session["oidc_user_id"] = user_id
|
||||||
|
|
||||||
logger.debug("[OIDC] User authentication successful: user_id=%s, email=%s", user_id, email)
|
logger.debug("[OIDC] User authentication successful: user_id=%s, email=%s", user_id, email)
|
||||||
|
except AccountSuspendedError:
|
||||||
|
logger.debug("[OIDC] User authentication failed: account suspended for email=%s", email)
|
||||||
|
return _show_login_page(
|
||||||
|
client_id=client_id,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
scope=scope,
|
||||||
|
state=state,
|
||||||
|
nonce=nonce,
|
||||||
|
response_type=response_type,
|
||||||
|
error="Your account has been suspended. Please contact an administrator.",
|
||||||
|
)
|
||||||
|
except AccountInactiveError:
|
||||||
|
logger.debug("[OIDC] User authentication failed: account inactive for email=%s", email)
|
||||||
|
return _show_login_page(
|
||||||
|
client_id=client_id,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
scope=scope,
|
||||||
|
state=state,
|
||||||
|
nonce=nonce,
|
||||||
|
response_type=response_type,
|
||||||
|
error="Your account is not active. Please verify your email.",
|
||||||
|
)
|
||||||
except InvalidCredentialsError:
|
except InvalidCredentialsError:
|
||||||
logger.debug("[OIDC] User authentication failed: invalid credentials for email=%s", email)
|
logger.debug("[OIDC] User authentication failed: invalid credentials for email=%s", email)
|
||||||
return _show_login_page(
|
return _show_login_page(
|
||||||
@@ -601,6 +641,33 @@ def oidc_authorize():
|
|||||||
logger.debug("[OIDC] Redirecting with error: server_error (user not found)")
|
logger.debug("[OIDC] Redirecting with error: server_error (user not found)")
|
||||||
return _redirect_with_error(redirect_uri, "server_error", "User not found", state)
|
return _redirect_with_error(redirect_uri, "server_error", "User not found", state)
|
||||||
|
|
||||||
|
# Check account is still active (user could have been suspended after session start)
|
||||||
|
from gatehouse_app.utils.constants import UserStatus as _UserStatus
|
||||||
|
if user.status in (_UserStatus.SUSPENDED, _UserStatus.COMPLIANCE_SUSPENDED):
|
||||||
|
session.pop("oidc_user_id", None) # clear stale session
|
||||||
|
logger.debug("[OIDC] User is suspended, clearing session and showing login error: user_id=%s", user_id)
|
||||||
|
return _show_login_page(
|
||||||
|
client_id=client_id,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
scope=scope,
|
||||||
|
state=state,
|
||||||
|
nonce=nonce,
|
||||||
|
response_type=response_type,
|
||||||
|
error="Your account has been suspended. Please contact an administrator.",
|
||||||
|
)
|
||||||
|
if user.status == _UserStatus.INACTIVE:
|
||||||
|
session.pop("oidc_user_id", None)
|
||||||
|
logger.debug("[OIDC] User is inactive, clearing session and showing login error: user_id=%s", user_id)
|
||||||
|
return _show_login_page(
|
||||||
|
client_id=client_id,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
scope=scope,
|
||||||
|
state=state,
|
||||||
|
nonce=nonce,
|
||||||
|
response_type=response_type,
|
||||||
|
error="Your account is not active. Please verify your email.",
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug("[OIDC] Generating authorization code...")
|
logger.debug("[OIDC] Generating authorization code...")
|
||||||
logger.debug("[OIDC] Authorization code params: client_id=%s, user_id=%s, redirect_uri=%s", client_id, user_id, redirect_uri)
|
logger.debug("[OIDC] Authorization code params: client_id=%s, user_id=%s, redirect_uri=%s", client_id, user_id, redirect_uri)
|
||||||
logger.debug("[OIDC] Authorization code params: scopes=%s, state=%s, nonce=%s", valid_scopes, state, nonce)
|
logger.debug("[OIDC] Authorization code params: scopes=%s, state=%s, nonce=%s", valid_scopes, state, nonce)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import logging
|
|||||||
from flask import request, session, g, jsonify, current_app
|
from flask import request, session, g, jsonify, current_app
|
||||||
from marshmallow import ValidationError
|
from marshmallow import ValidationError
|
||||||
from gatehouse_app.api.v1 import api_v1_bp
|
from gatehouse_app.api.v1 import api_v1_bp
|
||||||
|
from gatehouse_app.extensions import limiter
|
||||||
from gatehouse_app.utils.response import api_response
|
from gatehouse_app.utils.response import api_response
|
||||||
from gatehouse_app.schemas.auth_schema import (
|
from gatehouse_app.schemas.auth_schema import (
|
||||||
RegisterSchema,
|
RegisterSchema,
|
||||||
@@ -32,6 +33,7 @@ from gatehouse_app.exceptions.validation_exceptions import ConflictError, NotFou
|
|||||||
|
|
||||||
|
|
||||||
@api_v1_bp.route("/auth/register", methods=["POST"])
|
@api_v1_bp.route("/auth/register", methods=["POST"])
|
||||||
|
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_REGISTER"])
|
||||||
def register():
|
def register():
|
||||||
"""
|
"""
|
||||||
Register a new user.
|
Register a new user.
|
||||||
@@ -135,6 +137,7 @@ def register():
|
|||||||
|
|
||||||
|
|
||||||
@api_v1_bp.route("/auth/login", methods=["POST"])
|
@api_v1_bp.route("/auth/login", methods=["POST"])
|
||||||
|
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_LOGIN"])
|
||||||
def login():
|
def login():
|
||||||
"""
|
"""
|
||||||
Login user.
|
Login user.
|
||||||
@@ -325,8 +328,13 @@ def get_current_user():
|
|||||||
data={
|
data={
|
||||||
"user": user.to_dict(),
|
"user": user.to_dict(),
|
||||||
"organizations": [
|
"organizations": [
|
||||||
{"id": org.id, "name": org.name, "slug": org.slug}
|
{
|
||||||
for org in user.get_organizations()
|
"id": membership.organization.id,
|
||||||
|
"name": membership.organization.name,
|
||||||
|
"slug": membership.organization.slug,
|
||||||
|
"role": membership.role,
|
||||||
|
}
|
||||||
|
for membership in user.organization_memberships
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
message="User retrieved successfully",
|
message="User retrieved successfully",
|
||||||
@@ -478,6 +486,7 @@ def verify_totp_enrollment():
|
|||||||
|
|
||||||
|
|
||||||
@api_v1_bp.route("/auth/totp/verify", methods=["POST"])
|
@api_v1_bp.route("/auth/totp/verify", methods=["POST"])
|
||||||
|
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_TOTP_VERIFY"])
|
||||||
def verify_totp():
|
def verify_totp():
|
||||||
"""
|
"""
|
||||||
Verify TOTP code during login.
|
Verify TOTP code during login.
|
||||||
@@ -520,6 +529,18 @@ def verify_totp():
|
|||||||
error_type="AUTHENTICATION_ERROR",
|
error_type="AUTHENTICATION_ERROR",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check account suspension before completing TOTP verification
|
||||||
|
from gatehouse_app.utils.constants import UserStatus
|
||||||
|
if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
|
||||||
|
session.pop("totp_pending_user_id", None)
|
||||||
|
session.pop("webauthn_pending_user_id", None)
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="Account is suspended. Contact an administrator.",
|
||||||
|
status=403,
|
||||||
|
error_type="ACCOUNT_SUSPENDED",
|
||||||
|
)
|
||||||
|
|
||||||
# Verify TOTP code
|
# Verify TOTP code
|
||||||
AuthService.authenticate_with_totp(
|
AuthService.authenticate_with_totp(
|
||||||
user,
|
user,
|
||||||
@@ -909,6 +930,17 @@ def begin_webauthn_login():
|
|||||||
error_type="NOT_FOUND",
|
error_type="NOT_FOUND",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check account suspension before proceeding
|
||||||
|
from gatehouse_app.utils.constants import UserStatus
|
||||||
|
if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
|
||||||
|
logger.warning(f"WebAuthn login begin - suspended account attempt: {user.email}")
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="Account is suspended. Contact an administrator.",
|
||||||
|
status=403,
|
||||||
|
error_type="ACCOUNT_SUSPENDED",
|
||||||
|
)
|
||||||
|
|
||||||
# Check if user has any WebAuthn credentials
|
# Check if user has any WebAuthn credentials
|
||||||
if not user.has_webauthn_enabled():
|
if not user.has_webauthn_enabled():
|
||||||
logger.warning(f"WebAuthn login begin - no credentials for user: {user.email}")
|
logger.warning(f"WebAuthn login begin - no credentials for user: {user.email}")
|
||||||
@@ -992,6 +1024,18 @@ def complete_webauthn_login():
|
|||||||
error_type="AUTHENTICATION_ERROR",
|
error_type="AUTHENTICATION_ERROR",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check account suspension before completing login
|
||||||
|
from gatehouse_app.utils.constants import UserStatus
|
||||||
|
if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
|
||||||
|
session.pop("webauthn_pending_user_id", None)
|
||||||
|
logger.warning(f"WebAuthn login complete - suspended account attempt: {user.email}")
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="Account is suspended. Contact an administrator.",
|
||||||
|
status=403,
|
||||||
|
error_type="ACCOUNT_SUSPENDED",
|
||||||
|
)
|
||||||
|
|
||||||
# Extract challenge from client data
|
# Extract challenge from client data
|
||||||
client_data = data.get("response", {}).get("clientDataJSON", "")
|
client_data = data.get("response", {}).get("clientDataJSON", "")
|
||||||
|
|
||||||
@@ -1129,6 +1173,19 @@ def delete_webauthn_credential(credential_id):
|
|||||||
"""
|
"""
|
||||||
user = g.current_user
|
user = g.current_user
|
||||||
|
|
||||||
|
# First check that the specific credential actually belongs to this user.
|
||||||
|
# Only then check whether it is the last one — otherwise a user with zero
|
||||||
|
# credentials gets a misleading "Cannot delete the last passkey" error
|
||||||
|
# instead of a 404.
|
||||||
|
credential_exists = WebAuthnService.credential_belongs_to_user(credential_id, user)
|
||||||
|
if not credential_exists:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="Credential not found",
|
||||||
|
status=404,
|
||||||
|
error_type="NOT_FOUND",
|
||||||
|
)
|
||||||
|
|
||||||
# Check if this is the last credential
|
# Check if this is the last credential
|
||||||
credential_count = user.get_webauthn_credential_count()
|
credential_count = user.get_webauthn_credential_count()
|
||||||
if credential_count <= 1:
|
if credential_count <= 1:
|
||||||
@@ -1238,6 +1295,7 @@ _pw_logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
@api_v1_bp.route("/auth/forgot-password", methods=["POST"])
|
@api_v1_bp.route("/auth/forgot-password", methods=["POST"])
|
||||||
|
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_FORGOT_PASSWORD"])
|
||||||
def forgot_password():
|
def forgot_password():
|
||||||
"""Request a password reset email.
|
"""Request a password reset email.
|
||||||
|
|
||||||
@@ -1294,6 +1352,7 @@ def forgot_password():
|
|||||||
|
|
||||||
|
|
||||||
@api_v1_bp.route("/auth/reset-password", methods=["POST"])
|
@api_v1_bp.route("/auth/reset-password", methods=["POST"])
|
||||||
|
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_RESET_PASSWORD"])
|
||||||
def reset_password():
|
def reset_password():
|
||||||
"""Reset a user's password using a reset token.
|
"""Reset a user's password using a reset token.
|
||||||
|
|
||||||
@@ -1601,11 +1660,31 @@ def get_token():
|
|||||||
302: Redirect to ``<redirect>?token=<token>``
|
302: Redirect to ``<redirect>?token=<token>``
|
||||||
"""
|
"""
|
||||||
from flask import redirect as flask_redirect
|
from flask import redirect as flask_redirect
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
token = g.current_session.token
|
token = g.current_session.token
|
||||||
redirect_url = request.args.get("redirect", "").strip()
|
redirect_url = request.args.get("redirect", "").strip()
|
||||||
|
|
||||||
if redirect_url:
|
if redirect_url:
|
||||||
|
# Validate redirect URL against allowed origins to prevent open-redirect
|
||||||
|
# token exfiltration attacks (CWE-601).
|
||||||
|
allowed_origins = set(current_app.config.get("CORS_ORIGINS", []))
|
||||||
|
frontend_url = current_app.config.get("FRONTEND_URL", "")
|
||||||
|
if frontend_url:
|
||||||
|
parsed = urlparse(frontend_url)
|
||||||
|
allowed_origins.add(f"{parsed.scheme}://{parsed.netloc}")
|
||||||
|
|
||||||
|
parsed_redirect = urlparse(redirect_url)
|
||||||
|
redirect_origin = f"{parsed_redirect.scheme}://{parsed_redirect.netloc}"
|
||||||
|
|
||||||
|
if redirect_origin not in allowed_origins:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="Redirect URL is not allowed.",
|
||||||
|
status=400,
|
||||||
|
error_type="INVALID_REDIRECT",
|
||||||
|
)
|
||||||
|
|
||||||
sep = "&" if "?" in redirect_url else "?"
|
sep = "&" if "?" in redirect_url else "?"
|
||||||
return flask_redirect(f"{redirect_url}{sep}token={token}", code=302)
|
return flask_redirect(f"{redirect_url}{sep}token={token}", code=302)
|
||||||
|
|
||||||
|
|||||||
@@ -226,6 +226,10 @@ def delete_organization(org_id):
|
|||||||
"""
|
"""
|
||||||
Delete organization (soft delete).
|
Delete organization (soft delete).
|
||||||
|
|
||||||
|
The owner may only delete the organization if they are the *sole* remaining
|
||||||
|
member. If other active members exist they must first transfer ownership
|
||||||
|
(or remove all other members) before deleting the organization.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
org_id: Organization ID
|
org_id: Organization ID
|
||||||
|
|
||||||
@@ -234,9 +238,26 @@ def delete_organization(org_id):
|
|||||||
401: Not authenticated
|
401: Not authenticated
|
||||||
403: Not the owner
|
403: Not the owner
|
||||||
404: Organization not found
|
404: Organization not found
|
||||||
|
409: Organization still has other members — transfer ownership first
|
||||||
"""
|
"""
|
||||||
org = OrganizationService.get_organization_by_id(org_id)
|
org = OrganizationService.get_organization_by_id(org_id)
|
||||||
|
|
||||||
|
# Guard: block deletion while non-owner members still exist so ownership
|
||||||
|
# can be transferred rather than silently orphaning them.
|
||||||
|
active_member_count = org.get_member_count()
|
||||||
|
if active_member_count > 1:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message=(
|
||||||
|
"This organization still has other members. "
|
||||||
|
"Please transfer ownership to another member or remove all "
|
||||||
|
"other members before deleting the organization."
|
||||||
|
),
|
||||||
|
status=409,
|
||||||
|
error_type="ORG_HAS_MEMBERS",
|
||||||
|
error_details={"member_count": active_member_count},
|
||||||
|
)
|
||||||
|
|
||||||
OrganizationService.delete_organization(
|
OrganizationService.delete_organization(
|
||||||
org=org,
|
org=org,
|
||||||
user_id=g.current_user.id,
|
user_id=g.current_user.id,
|
||||||
@@ -446,6 +467,152 @@ def update_member_role(org_id, user_id):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@api_v1_bp.route("/organizations/<org_id>/transfer-ownership", methods=["POST"])
|
||||||
|
@login_required
|
||||||
|
@full_access_required
|
||||||
|
def transfer_organization_ownership(org_id):
|
||||||
|
"""Transfer organization ownership from the current user to another member.
|
||||||
|
|
||||||
|
Only the current OWNER of the organization may call this endpoint.
|
||||||
|
The caller will be demoted to ADMIN and the target user will be promoted to OWNER.
|
||||||
|
|
||||||
|
Request body:
|
||||||
|
new_owner_user_id (str): UUID of the member to promote to OWNER.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
200: Ownership transferred successfully
|
||||||
|
400: Validation error / missing fields
|
||||||
|
403: Caller is not the OWNER of this org
|
||||||
|
404: Organization or target member not found
|
||||||
|
409: Target is already the OWNER
|
||||||
|
"""
|
||||||
|
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||||
|
from gatehouse_app.utils.constants import OrganizationRole, AuditAction
|
||||||
|
from gatehouse_app.services.audit_service import AuditService
|
||||||
|
|
||||||
|
caller = g.current_user
|
||||||
|
|
||||||
|
data = request.get_json() or {}
|
||||||
|
new_owner_user_id = data.get("new_owner_user_id")
|
||||||
|
if not new_owner_user_id:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="new_owner_user_id is required",
|
||||||
|
status=400,
|
||||||
|
error_type="VALIDATION_ERROR",
|
||||||
|
)
|
||||||
|
|
||||||
|
if str(new_owner_user_id) == str(caller.id):
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="You are already the owner of this organization.",
|
||||||
|
status=409,
|
||||||
|
error_type="CONFLICT",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fetch org (raises NotFound internally)
|
||||||
|
org = OrganizationService.get_organization_by_id(org_id)
|
||||||
|
|
||||||
|
# Confirm caller is the current OWNER
|
||||||
|
caller_membership = OrganizationMember.query.filter_by(
|
||||||
|
organization_id=org.id,
|
||||||
|
user_id=caller.id,
|
||||||
|
deleted_at=None,
|
||||||
|
).first()
|
||||||
|
if not caller_membership or caller_membership.role != OrganizationRole.OWNER:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="Only the organization owner can transfer ownership.",
|
||||||
|
status=403,
|
||||||
|
error_type="AUTHORIZATION_ERROR",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the target is an active member
|
||||||
|
target_membership = OrganizationMember.query.filter_by(
|
||||||
|
organization_id=org.id,
|
||||||
|
user_id=new_owner_user_id,
|
||||||
|
deleted_at=None,
|
||||||
|
).first()
|
||||||
|
if not target_membership:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="Target user is not a member of this organization.",
|
||||||
|
status=404,
|
||||||
|
error_type="NOT_FOUND",
|
||||||
|
)
|
||||||
|
|
||||||
|
if target_membership.role == OrganizationRole.OWNER:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="Target user is already the owner.",
|
||||||
|
status=409,
|
||||||
|
error_type="CONFLICT",
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Atomic role swap ─────────────────────────────────────────────────────
|
||||||
|
# Demote caller → ADMIN, promote target → OWNER.
|
||||||
|
# Both updates go through OrganizationService so all hooks/auditing fire.
|
||||||
|
try:
|
||||||
|
demoted = OrganizationService.update_member_role(
|
||||||
|
org=org,
|
||||||
|
user_id=str(caller.id),
|
||||||
|
new_role=OrganizationRole.ADMIN,
|
||||||
|
updater_id=str(caller.id),
|
||||||
|
)
|
||||||
|
promoted = OrganizationService.update_member_role(
|
||||||
|
org=org,
|
||||||
|
user_id=str(new_owner_user_id),
|
||||||
|
new_role=OrganizationRole.OWNER,
|
||||||
|
updater_id=str(caller.id),
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
from gatehouse_app.extensions import db as _db
|
||||||
|
_db.session.rollback()
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message=f"Failed to transfer ownership: {exc}",
|
||||||
|
status=500,
|
||||||
|
error_type="SERVER_ERROR",
|
||||||
|
)
|
||||||
|
|
||||||
|
AuditService.log_action(
|
||||||
|
action=AuditAction.ORG_OWNERSHIP_TRANSFERRED,
|
||||||
|
user_id=caller.id,
|
||||||
|
organization_id=org.id,
|
||||||
|
resource_type="organization",
|
||||||
|
resource_id=str(org.id),
|
||||||
|
description=(
|
||||||
|
f"Ownership of '{org.name}' transferred from {caller.email} "
|
||||||
|
f"to {target_membership.user.email if target_membership.user else new_owner_user_id}"
|
||||||
|
),
|
||||||
|
metadata={
|
||||||
|
"previous_owner_id": str(caller.id),
|
||||||
|
"previous_owner_email": caller.email,
|
||||||
|
"new_owner_id": str(new_owner_user_id),
|
||||||
|
"new_owner_email": (
|
||||||
|
target_membership.user.email if target_membership.user else None
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _member_dict(m):
|
||||||
|
d = m.to_dict()
|
||||||
|
if m.user:
|
||||||
|
d["user"] = m.user.to_dict()
|
||||||
|
return d
|
||||||
|
|
||||||
|
return api_response(
|
||||||
|
data={
|
||||||
|
"previous_owner": _member_dict(demoted),
|
||||||
|
"new_owner": _member_dict(promoted),
|
||||||
|
},
|
||||||
|
message=(
|
||||||
|
f"Ownership of '{org.name}' successfully transferred to "
|
||||||
|
f"{target_membership.user.email if target_membership.user else new_owner_user_id}."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@api_v1_bp.route("/organizations/<org_id>/audit-logs", methods=["GET"])
|
@api_v1_bp.route("/organizations/<org_id>/audit-logs", methods=["GET"])
|
||||||
@login_required
|
@login_required
|
||||||
@require_admin
|
@require_admin
|
||||||
@@ -756,10 +923,30 @@ def accept_invite(token):
|
|||||||
inviter_id=invite.invited_by_id,
|
inviter_id=invite.invited_by_id,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass # Already a member is fine
|
from gatehouse_app.extensions import db
|
||||||
|
db.session.rollback() # Clear broken transaction so invite.accept() can commit
|
||||||
|
|
||||||
invite.accept()
|
invite.accept()
|
||||||
|
|
||||||
|
has_webauthn = user.has_webauthn_enabled()
|
||||||
|
has_totp = user.has_totp_enabled()
|
||||||
|
|
||||||
|
if has_webauthn:
|
||||||
|
from flask import session as flask_session
|
||||||
|
flask_session["webauthn_pending_user_id"] = user.id
|
||||||
|
return api_response(
|
||||||
|
data={"requires_webauthn": True},
|
||||||
|
message="Passkey verification required. Please use your passkey to complete sign-in.",
|
||||||
|
)
|
||||||
|
|
||||||
|
if has_totp:
|
||||||
|
from flask import session as flask_session
|
||||||
|
flask_session["totp_pending_user_id"] = user.id
|
||||||
|
return api_response(
|
||||||
|
data={"requires_totp": True},
|
||||||
|
message="TOTP code required. Please enter your 6-digit code from your authenticator app.",
|
||||||
|
)
|
||||||
|
|
||||||
user_session = AuthService.create_session(user)
|
user_session = AuthService.create_session(user)
|
||||||
|
|
||||||
return api_response(
|
return api_response(
|
||||||
@@ -1379,6 +1566,7 @@ def create_org_ca(org_id):
|
|||||||
from gatehouse_app.models.ssh_ca.ca import CA, KeyType
|
from gatehouse_app.models.ssh_ca.ca import CA, KeyType
|
||||||
from gatehouse_app.models.organization.organization import Organization
|
from gatehouse_app.models.organization.organization import Organization
|
||||||
from gatehouse_app.utils.crypto import compute_ssh_fingerprint
|
from gatehouse_app.utils.crypto import compute_ssh_fingerprint
|
||||||
|
from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key
|
||||||
from marshmallow import Schema, fields as ma_fields, validate, ValidationError as MaValidationError
|
from marshmallow import Schema, fields as ma_fields, validate, ValidationError as MaValidationError
|
||||||
from sshkey_tools.keys import Ed25519PrivateKey, RsaPrivateKey, EcdsaPrivateKey
|
from sshkey_tools.keys import Ed25519PrivateKey, RsaPrivateKey, EcdsaPrivateKey
|
||||||
|
|
||||||
@@ -1448,13 +1636,16 @@ def create_org_ca(org_id):
|
|||||||
public_key_str = private_key_obj.public_key.to_string()
|
public_key_str = private_key_obj.public_key.to_string()
|
||||||
fingerprint = compute_ssh_fingerprint(public_key_str)
|
fingerprint = compute_ssh_fingerprint(public_key_str)
|
||||||
|
|
||||||
|
# Encrypt the private key before storing in the database
|
||||||
|
encrypted_private_key = encrypt_ca_key(private_key_pem)
|
||||||
|
|
||||||
ca = CA(
|
ca = CA(
|
||||||
organization_id=org_id,
|
organization_id=org_id,
|
||||||
name=data["name"],
|
name=data["name"],
|
||||||
description=data["description"],
|
description=data["description"],
|
||||||
ca_type=CaType(ca_type_val),
|
ca_type=CaType(ca_type_val),
|
||||||
key_type=KeyType(key_type),
|
key_type=KeyType(key_type),
|
||||||
private_key=private_key_pem,
|
private_key=encrypted_private_key,
|
||||||
public_key=public_key_str,
|
public_key=public_key_str,
|
||||||
fingerprint=fingerprint,
|
fingerprint=fingerprint,
|
||||||
default_cert_validity_hours=data["default_cert_validity_hours"],
|
default_cert_validity_hours=data["default_cert_validity_hours"],
|
||||||
@@ -1462,7 +1653,24 @@ def create_org_ca(org_id):
|
|||||||
is_active=True,
|
is_active=True,
|
||||||
)
|
)
|
||||||
db.session.add(ca)
|
db.session.add(ca)
|
||||||
|
try:
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
except Exception as commit_exc:
|
||||||
|
db.session.rollback()
|
||||||
|
# Surface unique-constraint violations (soft-deleted record with same name) as a
|
||||||
|
# user-friendly 400 instead of a 500.
|
||||||
|
exc_str = str(commit_exc).lower()
|
||||||
|
if "uix_org_ca_name" in exc_str or "unique" in exc_str:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message=(
|
||||||
|
"A CA with that name already exists in this organization "
|
||||||
|
"(it may have been recently deleted — choose a different name)."
|
||||||
|
),
|
||||||
|
status=400,
|
||||||
|
error_type="DUPLICATE_NAME",
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
return api_response(
|
return api_response(
|
||||||
data={"ca": ca.to_dict()},
|
data={"ca": ca.to_dict()},
|
||||||
@@ -1570,6 +1778,7 @@ def rotate_org_ca(org_id, ca_id):
|
|||||||
from gatehouse_app.models.ssh_ca.ca import CA, KeyType
|
from gatehouse_app.models.ssh_ca.ca import CA, KeyType
|
||||||
from gatehouse_app.models.organization.organization import Organization
|
from gatehouse_app.models.organization.organization import Organization
|
||||||
from gatehouse_app.utils.crypto import compute_ssh_fingerprint
|
from gatehouse_app.utils.crypto import compute_ssh_fingerprint
|
||||||
|
from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key
|
||||||
from gatehouse_app.utils.constants import AuditAction
|
from gatehouse_app.utils.constants import AuditAction
|
||||||
from gatehouse_app.models import AuditLog
|
from gatehouse_app.models import AuditLog
|
||||||
from sshkey_tools.keys import Ed25519PrivateKey, RsaPrivateKey, EcdsaPrivateKey
|
from sshkey_tools.keys import Ed25519PrivateKey, RsaPrivateKey, EcdsaPrivateKey
|
||||||
@@ -1609,8 +1818,11 @@ def rotate_org_ca(org_id, ca_id):
|
|||||||
new_public_key = private_key_obj.public_key.to_string()
|
new_public_key = private_key_obj.public_key.to_string()
|
||||||
new_fingerprint = compute_ssh_fingerprint(new_public_key)
|
new_fingerprint = compute_ssh_fingerprint(new_public_key)
|
||||||
|
|
||||||
|
# Encrypt the new private key before storing
|
||||||
|
encrypted_new_private_key = encrypt_ca_key(new_private_key)
|
||||||
|
|
||||||
ca.rotate_key(
|
ca.rotate_key(
|
||||||
new_private_key=new_private_key,
|
new_private_key=encrypted_new_private_key,
|
||||||
new_public_key=new_public_key,
|
new_public_key=new_public_key,
|
||||||
new_fingerprint=new_fingerprint,
|
new_fingerprint=new_fingerprint,
|
||||||
reason=reason,
|
reason=reason,
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from gatehouse_app.exceptions import (
|
|||||||
)
|
)
|
||||||
from gatehouse_app.utils.constants import AuditAction
|
from gatehouse_app.utils.constants import AuditAction
|
||||||
from gatehouse_app.models import AuditLog
|
from gatehouse_app.models import AuditLog
|
||||||
|
from gatehouse_app.models.ssh_ca.certificate_audit_log import CertificateAuditLog
|
||||||
from gatehouse_app.utils.decorators import login_required
|
from gatehouse_app.utils.decorators import login_required
|
||||||
from gatehouse_app.utils.response import api_response
|
from gatehouse_app.utils.response import api_response
|
||||||
|
|
||||||
@@ -78,11 +79,16 @@ def _get_or_create_system_ca():
|
|||||||
with open(pub_key_path) as f:
|
with open(pub_key_path) as f:
|
||||||
pub_key = f.read().strip()
|
pub_key = f.read().strip()
|
||||||
|
|
||||||
# Load private key for the record (stored but not actually used for signing here)
|
# Load private key for the record (encrypt before storing in DB)
|
||||||
priv_key = ""
|
priv_key = ""
|
||||||
if os.path.exists(key_path):
|
if os.path.exists(key_path):
|
||||||
with open(key_path) as f:
|
with open(key_path) as f:
|
||||||
priv_key = f.read()
|
raw_priv_key = f.read()
|
||||||
|
try:
|
||||||
|
from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key
|
||||||
|
priv_key = encrypt_ca_key(raw_priv_key)
|
||||||
|
except Exception:
|
||||||
|
priv_key = raw_priv_key # fallback: store as-is if encryption unavailable
|
||||||
|
|
||||||
fingerprint = compute_ssh_fingerprint(pub_key)
|
fingerprint = compute_ssh_fingerprint(pub_key)
|
||||||
|
|
||||||
@@ -120,7 +126,7 @@ def _get_or_create_system_ca():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _persist_certificate(user_id, ssh_key_id, ca, signing_response, request_ip=None, cert_type_str='user'):
|
def _persist_certificate(user_id, ssh_key_id, ca, signing_response, request_ip=None, cert_type_str='user', cert_identity=None):
|
||||||
"""Save a signed certificate to the ssh_certificates table.
|
"""Save a signed certificate to the ssh_certificates table.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -130,6 +136,8 @@ def _persist_certificate(user_id, ssh_key_id, ca, signing_response, request_ip=N
|
|||||||
signing_response: SSHCertificateSigningResponse
|
signing_response: SSHCertificateSigningResponse
|
||||||
request_ip: Client IP address
|
request_ip: Client IP address
|
||||||
cert_type_str: 'user' or 'host' (from the sign request)
|
cert_type_str: 'user' or 'host' (from the sign request)
|
||||||
|
cert_identity: Rich OpenSSH key_id string (e.g. "user@host (Name) [org:slug]").
|
||||||
|
Falls back to str(ssh_key_id) when not provided.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SSHCertificate instance or None if persistence failed
|
SSHCertificate instance or None if persistence failed
|
||||||
@@ -153,7 +161,7 @@ def _persist_certificate(user_id, ssh_key_id, ca, signing_response, request_ip=N
|
|||||||
ssh_key_id=ssh_key_id,
|
ssh_key_id=ssh_key_id,
|
||||||
certificate=signing_response.certificate,
|
certificate=signing_response.certificate,
|
||||||
serial=signing_response.serial,
|
serial=signing_response.serial,
|
||||||
key_id=str(ssh_key_id),
|
key_id=cert_identity or str(ssh_key_id),
|
||||||
cert_type=resolved_cert_type,
|
cert_type=resolved_cert_type,
|
||||||
principals=signing_response.principals,
|
principals=signing_response.principals,
|
||||||
valid_after=signing_response.valid_after,
|
valid_after=signing_response.valid_after,
|
||||||
@@ -465,7 +473,7 @@ def sign_certificate():
|
|||||||
|
|
||||||
# ── Check account suspension ──────────────────────────────────────────────
|
# ── Check account suspension ──────────────────────────────────────────────
|
||||||
from gatehouse_app.utils.constants import UserStatus
|
from gatehouse_app.utils.constants import UserStatus
|
||||||
if user.status == UserStatus.SUSPENDED:
|
if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
|
||||||
return api_response(
|
return api_response(
|
||||||
success=False,
|
success=False,
|
||||||
message="Your account is suspended. Contact an administrator.",
|
message="Your account is suspended. Contact an administrator.",
|
||||||
@@ -482,6 +490,18 @@ def sign_certificate():
|
|||||||
key_id = data.get('key_id') or data.get('cert_id')
|
key_id = data.get('key_id') or data.get('cert_id')
|
||||||
expiry_hours = data.get('expiry_hours')
|
expiry_hours = data.get('expiry_hours')
|
||||||
|
|
||||||
|
# ── Log the request ───────────────────────────────────────────────────────
|
||||||
|
AuditLog.log(
|
||||||
|
action=AuditAction.SSH_CERT_REQUESTED,
|
||||||
|
user_id=user_id,
|
||||||
|
resource_type='SSHCertificate',
|
||||||
|
ip_address=request.remote_addr,
|
||||||
|
description=(
|
||||||
|
f'{user.email} requested a certificate'
|
||||||
|
+ (f' for principals: {", ".join(requested_principals)}' if requested_principals else '')
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
# ── Resolve which principals the user is allowed to use ──────────────────
|
# ── Resolve which principals the user is allowed to use ──────────────────
|
||||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||||
from gatehouse_app.models.organization.principal import Principal, PrincipalMembership
|
from gatehouse_app.models.organization.principal import Principal, PrincipalMembership
|
||||||
@@ -601,11 +621,24 @@ def sign_certificate():
|
|||||||
else:
|
else:
|
||||||
policy_extensions = None # let signing service use its own defaults
|
policy_extensions = None # let signing service use its own defaults
|
||||||
|
|
||||||
|
# ── Build rich key_id identity for the OpenSSH cert ─────────────────────
|
||||||
|
# This appears in `ssh-keygen -L -f cert.pub` as the Key ID field and
|
||||||
|
# is stored in the DB cert record so it's auditable.
|
||||||
|
org_slugs = sorted({
|
||||||
|
om.organization.slug
|
||||||
|
for om in memberships
|
||||||
|
if om.organization and om.organization.deleted_at is None
|
||||||
|
and getattr(om.organization, 'slug', None)
|
||||||
|
})
|
||||||
|
org_slug = org_slugs[0] if org_slugs else "unknown"
|
||||||
|
full_name = getattr(user, 'full_name', None) or getattr(user, 'name', None) or "unknown"
|
||||||
|
cert_identity = f"{user.email} ({full_name}) [org:{org_slug}]"
|
||||||
|
|
||||||
signing_request = SSHCertificateSigningRequest(
|
signing_request = SSHCertificateSigningRequest(
|
||||||
ssh_public_key=ssh_key.payload,
|
ssh_public_key=ssh_key.payload,
|
||||||
principals=principals,
|
principals=principals,
|
||||||
cert_type=cert_type,
|
cert_type=cert_type,
|
||||||
key_id=key_id,
|
key_id=cert_identity,
|
||||||
expiry_hours=int(expiry_hours) if expiry_hours else None,
|
expiry_hours=int(expiry_hours) if expiry_hours else None,
|
||||||
extensions=policy_extensions,
|
extensions=policy_extensions,
|
||||||
)
|
)
|
||||||
@@ -620,7 +653,11 @@ def sign_certificate():
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = ssh_ca_service.sign_certificate(signing_request, ca_private_key=db_ca.private_key)
|
from gatehouse_app.utils.ca_key_encryption import decrypt_ca_key
|
||||||
|
ca_private_key_pem = decrypt_ca_key(db_ca.private_key)
|
||||||
|
response = ssh_ca_service.sign_certificate(
|
||||||
|
signing_request, ca_private_key=ca_private_key_pem, ca_obj=db_ca
|
||||||
|
)
|
||||||
except SSHCertificateError as e:
|
except SSHCertificateError as e:
|
||||||
AuditLog.log(
|
AuditLog.log(
|
||||||
action=AuditAction.SSH_CERT_FAILED,
|
action=AuditAction.SSH_CERT_FAILED,
|
||||||
@@ -649,6 +686,7 @@ def sign_certificate():
|
|||||||
signing_response=response,
|
signing_response=response,
|
||||||
request_ip=request.remote_addr,
|
request_ip=request.remote_addr,
|
||||||
cert_type_str=cert_type,
|
cert_type_str=cert_type,
|
||||||
|
cert_identity=cert_identity,
|
||||||
)
|
)
|
||||||
|
|
||||||
AuditLog.log(
|
AuditLog.log(
|
||||||
@@ -657,7 +695,40 @@ def sign_certificate():
|
|||||||
resource_type='SSHCertificate',
|
resource_type='SSHCertificate',
|
||||||
resource_id=cert_record.id if cert_record else key_id,
|
resource_id=cert_record.id if cert_record else key_id,
|
||||||
ip_address=request.remote_addr,
|
ip_address=request.remote_addr,
|
||||||
description=f'Certificate issued for principals: {", ".join(principals)}',
|
description=(
|
||||||
|
f'Certificate serial={response.serial} issued for {user.email}; '
|
||||||
|
f'principals: {", ".join(principals)}'
|
||||||
|
),
|
||||||
|
extra_data={
|
||||||
|
'serial': response.serial,
|
||||||
|
'key_id': cert_identity,
|
||||||
|
'principals': principals,
|
||||||
|
'ca_id': str(db_ca.id),
|
||||||
|
'ssh_key_id': str(key_id),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if cert_record:
|
||||||
|
CertificateAuditLog.log(
|
||||||
|
certificate_id=cert_record.id,
|
||||||
|
action='issued',
|
||||||
|
user_id=user_id,
|
||||||
|
ip_address=request.remote_addr,
|
||||||
|
user_agent=request.headers.get('User-Agent'),
|
||||||
|
message=(
|
||||||
|
f'Certificate serial={response.serial} issued for {user.email}; '
|
||||||
|
f'principals: {", ".join(principals)}'
|
||||||
|
),
|
||||||
|
extra_data={
|
||||||
|
'serial': response.serial,
|
||||||
|
'key_id': cert_identity,
|
||||||
|
'principals': principals,
|
||||||
|
'ca_id': str(db_ca.id),
|
||||||
|
'ssh_key_id': str(key_id),
|
||||||
|
'valid_after': response.valid_after.isoformat() if response.valid_after else None,
|
||||||
|
'valid_before': response.valid_before.isoformat() if response.valid_before else None,
|
||||||
|
},
|
||||||
|
success=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
@@ -753,6 +824,16 @@ def revoke_certificate(cert_id):
|
|||||||
description=f'Revoked: {reason}',
|
description=f'Revoked: {reason}',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
CertificateAuditLog.log(
|
||||||
|
certificate_id=cert_id,
|
||||||
|
action='revoked',
|
||||||
|
user_id=user_id,
|
||||||
|
ip_address=request.remote_addr,
|
||||||
|
user_agent=request.headers.get('User-Agent'),
|
||||||
|
message=f'Certificate revoked: {reason}',
|
||||||
|
success=True,
|
||||||
|
)
|
||||||
|
|
||||||
return api_response(
|
return api_response(
|
||||||
success=True,
|
success=True,
|
||||||
message='Certificate revoked successfully',
|
message='Certificate revoked successfully',
|
||||||
|
|||||||
@@ -73,11 +73,51 @@ def delete_me():
|
|||||||
"""
|
"""
|
||||||
Delete current user account (soft delete).
|
Delete current user account (soft delete).
|
||||||
|
|
||||||
|
Blocked if the user is the sole owner of any organization that has other
|
||||||
|
active members — they must transfer ownership or dissolve those organizations
|
||||||
|
first.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
200: Account deleted successfully
|
200: Account deleted successfully
|
||||||
401: Not authenticated
|
401: Not authenticated
|
||||||
|
409: User is sole owner of one or more organizations with other members
|
||||||
"""
|
"""
|
||||||
UserService.delete_user(g.current_user, soft=True)
|
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||||
|
from gatehouse_app.utils.constants import OrganizationRole
|
||||||
|
|
||||||
|
user = g.current_user
|
||||||
|
|
||||||
|
# Find orgs where this user is the sole owner AND other members exist.
|
||||||
|
owned_memberships = OrganizationMember.query.filter_by(
|
||||||
|
user_id=user.id,
|
||||||
|
role=OrganizationRole.OWNER,
|
||||||
|
deleted_at=None,
|
||||||
|
).all()
|
||||||
|
|
||||||
|
blocked_orgs = []
|
||||||
|
for membership in owned_memberships:
|
||||||
|
org = membership.organization
|
||||||
|
if org.deleted_at is not None:
|
||||||
|
continue
|
||||||
|
member_count = org.get_member_count()
|
||||||
|
if member_count > 1:
|
||||||
|
blocked_orgs.append(org.name)
|
||||||
|
|
||||||
|
if blocked_orgs:
|
||||||
|
names = ", ".join(f'"{n}"' for n in blocked_orgs)
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message=(
|
||||||
|
f"You are the sole owner of {len(blocked_orgs)} organization"
|
||||||
|
f"{'s' if len(blocked_orgs) > 1 else ''}: {names}. "
|
||||||
|
"Transfer ownership or delete those organizations before deleting your account."
|
||||||
|
),
|
||||||
|
status=409,
|
||||||
|
error_type="USER_IS_SOLE_OWNER",
|
||||||
|
error_details={"organizations": blocked_orgs},
|
||||||
|
)
|
||||||
|
|
||||||
|
UserService.delete_user(user, soft=True)
|
||||||
|
|
||||||
return api_response(
|
return api_response(
|
||||||
message="Account deleted successfully",
|
message="Account deleted successfully",
|
||||||
@@ -454,6 +494,31 @@ def admin_suspend_user(user_id):
|
|||||||
if not admin_in_shared_org:
|
if not admin_in_shared_org:
|
||||||
return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR")
|
return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR")
|
||||||
|
|
||||||
|
# ── Owner protection ──────────────────────────────────────────────────────
|
||||||
|
# An org owner cannot be suspended until they transfer ownership.
|
||||||
|
from gatehouse_app.utils.constants import OrganizationRole
|
||||||
|
owner_memberships = OrganizationMember.query.filter(
|
||||||
|
OrganizationMember.user_id == target.id,
|
||||||
|
OrganizationMember.role == OrganizationRole.OWNER,
|
||||||
|
OrganizationMember.deleted_at == None,
|
||||||
|
).all()
|
||||||
|
if owner_memberships:
|
||||||
|
org_names = [
|
||||||
|
m.organization.name
|
||||||
|
for m in owner_memberships
|
||||||
|
if m.organization and not m.organization.deleted_at
|
||||||
|
]
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message=(
|
||||||
|
f"Cannot suspend an organization owner. "
|
||||||
|
f"{target.email} is the owner of: {', '.join(org_names)}. "
|
||||||
|
"Transfer ownership to another member first."
|
||||||
|
),
|
||||||
|
status=403,
|
||||||
|
error_type="OWNER_PROTECTION",
|
||||||
|
)
|
||||||
|
|
||||||
if target.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
|
if target.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
|
||||||
return api_response(success=False, message="User is already suspended", status=409, error_type="CONFLICT")
|
return api_response(success=False, message="User is already suspended", status=409, error_type="CONFLICT")
|
||||||
|
|
||||||
@@ -645,3 +710,158 @@ def get_my_memberships():
|
|||||||
data={"orgs": orgs_result},
|
data={"orgs": orgs_result},
|
||||||
message="Memberships retrieved",
|
message="Memberships retrieved",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@api_v1_bp.route("/admin/users/<user_id>/delete", methods=["POST"])
|
||||||
|
@login_required
|
||||||
|
@full_access_required
|
||||||
|
def admin_hard_delete_user(user_id):
|
||||||
|
"""Permanently delete a user and ALL associated data (hard delete, irreversible).
|
||||||
|
|
||||||
|
Required body: {"confirm": true}
|
||||||
|
|
||||||
|
Pre-conditions:
|
||||||
|
- Caller is OWNER or ADMIN of a shared org with the target.
|
||||||
|
- Cannot delete yourself.
|
||||||
|
- Target must not be the OWNER of any active organization (transfer first).
|
||||||
|
|
||||||
|
Side-effects:
|
||||||
|
- All active SSH certificates are revoked before deletion.
|
||||||
|
- The user row and all cascaded rows are hard-deleted from the database.
|
||||||
|
- An audit log entry is written by the *caller* (so it is not lost with the user).
|
||||||
|
"""
|
||||||
|
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||||
|
from gatehouse_app.models.user.user import User as _User
|
||||||
|
from gatehouse_app.extensions import db as _db
|
||||||
|
from gatehouse_app.utils.constants import UserStatus, AuditAction, OrganizationRole
|
||||||
|
from gatehouse_app.services.audit_service import AuditService
|
||||||
|
|
||||||
|
caller = g.current_user
|
||||||
|
data = request.get_json() or {}
|
||||||
|
|
||||||
|
if not data.get("confirm"):
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="Deletion requires explicit confirmation. Send {\"confirm\": true} to proceed.",
|
||||||
|
status=400,
|
||||||
|
error_type="CONFIRMATION_REQUIRED",
|
||||||
|
)
|
||||||
|
|
||||||
|
target = _User.query.filter_by(id=user_id).first()
|
||||||
|
if not target:
|
||||||
|
return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND")
|
||||||
|
|
||||||
|
if target.id == caller.id:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="Cannot delete your own account via this endpoint.",
|
||||||
|
status=400,
|
||||||
|
error_type="BAD_REQUEST",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Caller must be OWNER/ADMIN of a shared org.
|
||||||
|
# Include soft-deleted memberships so that already-soft-deleted users can
|
||||||
|
# still be hard-deleted by an admin who shared an org with them.
|
||||||
|
target_org_ids = {m.organization_id for m in target.organization_memberships}
|
||||||
|
admin_in_shared_org = OrganizationMember.query.filter(
|
||||||
|
OrganizationMember.user_id == caller.id,
|
||||||
|
OrganizationMember.organization_id.in_(target_org_ids),
|
||||||
|
OrganizationMember.role.in_(["OWNER", "ADMIN"]),
|
||||||
|
OrganizationMember.deleted_at == None,
|
||||||
|
).first()
|
||||||
|
if not admin_in_shared_org:
|
||||||
|
return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR")
|
||||||
|
|
||||||
|
# Block deletion if target is an org owner — they must transfer first
|
||||||
|
owner_memberships = OrganizationMember.query.filter(
|
||||||
|
OrganizationMember.user_id == target.id,
|
||||||
|
OrganizationMember.role == OrganizationRole.OWNER,
|
||||||
|
OrganizationMember.deleted_at == None,
|
||||||
|
).all()
|
||||||
|
if owner_memberships:
|
||||||
|
org_names = [
|
||||||
|
m.organization.name
|
||||||
|
for m in owner_memberships
|
||||||
|
if m.organization and not m.organization.deleted_at
|
||||||
|
]
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message=(
|
||||||
|
f"Cannot delete an organization owner. "
|
||||||
|
f"{target.email} is the owner of: {', '.join(org_names)}. "
|
||||||
|
"Transfer ownership to another member first."
|
||||||
|
),
|
||||||
|
status=403,
|
||||||
|
error_type="OWNER_PROTECTION",
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Collect counts for audit metadata ────────────────────────────────────
|
||||||
|
from gatehouse_app.models.ssh_ca.ssh_key import SSHKey
|
||||||
|
from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus
|
||||||
|
|
||||||
|
ssh_key_count = SSHKey.query.filter_by(user_id=target.id, deleted_at=None).count()
|
||||||
|
active_cert_count = SSHCertificate.query.filter_by(
|
||||||
|
user_id=target.id, revoked=False
|
||||||
|
).filter(SSHCertificate.deleted_at == None).count()
|
||||||
|
|
||||||
|
# ── Revoke all active SSH certificates before deletion ───────────────────
|
||||||
|
active_certs = SSHCertificate.query.filter_by(
|
||||||
|
user_id=target.id, revoked=False
|
||||||
|
).filter(SSHCertificate.deleted_at == None).all()
|
||||||
|
for cert in active_certs:
|
||||||
|
try:
|
||||||
|
cert.revoke("account_deleted")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if active_certs:
|
||||||
|
try:
|
||||||
|
_db.session.flush()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# ── Hard delete ───────────────────────────────────────────────────────────
|
||||||
|
target_email = target.email # capture before deletion
|
||||||
|
target_id_str = str(target.id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
_db.session.delete(target) # cascades to all child tables
|
||||||
|
_db.session.flush()
|
||||||
|
except Exception as exc:
|
||||||
|
_db.session.rollback()
|
||||||
|
import logging
|
||||||
|
logging.getLogger(__name__).error(f"Hard delete failed for {target_id_str}: {exc}")
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="Failed to delete user account. Please try again.",
|
||||||
|
status=500,
|
||||||
|
error_type="SERVER_ERROR",
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Audit log (written as the caller so it survives the deletion) ─────────
|
||||||
|
AuditService.log_action(
|
||||||
|
action=AuditAction.USER_HARD_DELETE,
|
||||||
|
user_id=caller.id,
|
||||||
|
organization_id=admin_in_shared_org.organization_id,
|
||||||
|
resource_type="user",
|
||||||
|
resource_id=target_id_str,
|
||||||
|
description=f"Admin permanently deleted user account: {target_email}",
|
||||||
|
metadata={
|
||||||
|
"deleted_user_id": target_id_str,
|
||||||
|
"deleted_user_email": target_email,
|
||||||
|
"ssh_keys_deleted": ssh_key_count,
|
||||||
|
"certs_revoked": active_cert_count,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
_db.session.commit()
|
||||||
|
|
||||||
|
return api_response(
|
||||||
|
message=f"User account {target_email} has been permanently deleted.",
|
||||||
|
data={
|
||||||
|
"deleted_user_id": target_id_str,
|
||||||
|
"deleted_user_email": target_email,
|
||||||
|
"ssh_keys_deleted": ssh_key_count,
|
||||||
|
"certs_revoked": active_cert_count,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|||||||
@@ -88,6 +88,11 @@ class CA(BaseModel):
|
|||||||
rotated_at = db.Column(db.DateTime, nullable=True)
|
rotated_at = db.Column(db.DateTime, nullable=True)
|
||||||
rotation_reason = db.Column(db.String(255), nullable=True)
|
rotation_reason = db.Column(db.String(255), nullable=True)
|
||||||
|
|
||||||
|
# Monotonically-increasing serial counter. Every cert this CA issues
|
||||||
|
# gets the next value so serials are unique, ordered, and auditable.
|
||||||
|
# Protected by a row-level SELECT … FOR UPDATE in get_next_serial().
|
||||||
|
next_serial_number = db.Column(db.BigInteger, default=1, nullable=False)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
organization = db.relationship("Organization", back_populates="cas")
|
organization = db.relationship("Organization", back_populates="cas")
|
||||||
certificates = db.relationship(
|
certificates = db.relationship(
|
||||||
@@ -102,7 +107,6 @@ class CA(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.UniqueConstraint("organization_id", "name", name="uix_org_ca_name"),
|
|
||||||
db.Index("idx_ca_org_active", "organization_id", "is_active"),
|
db.Index("idx_ca_org_active", "organization_id", "is_active"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -162,6 +166,28 @@ class CA(BaseModel):
|
|||||||
self.rotation_reason = reason
|
self.rotation_reason = reason
|
||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
|
def get_next_serial(self) -> int:
|
||||||
|
"""Atomically increment and return the next certificate serial number.
|
||||||
|
|
||||||
|
Uses a SELECT … FOR UPDATE row lock so concurrent requests never
|
||||||
|
receive the same serial. Must be called inside an active DB
|
||||||
|
transaction (i.e. before the final session.commit()).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: The serial number to embed in the next certificate.
|
||||||
|
"""
|
||||||
|
# Re-fetch this CA row with an exclusive row lock
|
||||||
|
locked = (
|
||||||
|
db.session.query(CA)
|
||||||
|
.with_for_update()
|
||||||
|
.filter_by(id=self.id)
|
||||||
|
.one()
|
||||||
|
)
|
||||||
|
serial = locked.next_serial_number
|
||||||
|
locked.next_serial_number = serial + 1
|
||||||
|
db.session.flush() # write increment; commit happens in the caller
|
||||||
|
return serial
|
||||||
|
|
||||||
|
|
||||||
class CAPermission(BaseModel):
|
class CAPermission(BaseModel):
|
||||||
"""Per-user CA permission model.
|
"""Per-user CA permission model.
|
||||||
|
|||||||
@@ -77,7 +77,10 @@ class TOTPVerifyEnrollmentSchema(Schema):
|
|||||||
class TOTPVerifySchema(Schema):
|
class TOTPVerifySchema(Schema):
|
||||||
"""Schema for TOTP code verification during login."""
|
"""Schema for TOTP code verification during login."""
|
||||||
|
|
||||||
code = fields.Str(required=True)
|
code = fields.Str(
|
||||||
|
required=True,
|
||||||
|
validate=validate.Length(min=1),
|
||||||
|
)
|
||||||
is_backup_code = fields.Bool(load_default=False)
|
is_backup_code = fields.Bool(load_default=False)
|
||||||
client_timestamp = fields.Int(
|
client_timestamp = fields.Int(
|
||||||
required=False,
|
required=False,
|
||||||
@@ -85,6 +88,27 @@ class TOTPVerifySchema(Schema):
|
|||||||
metadata={"description": "Client UTC timestamp in seconds since epoch for TOTP verification"},
|
metadata={"description": "Client UTC timestamp in seconds since epoch for TOTP verification"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@validates_schema
|
||||||
|
def validate_code_format(self, data, **kwargs):
|
||||||
|
"""Validate code format depending on whether it's a backup code."""
|
||||||
|
code = data.get("code", "")
|
||||||
|
is_backup_code = data.get("is_backup_code", False)
|
||||||
|
if is_backup_code:
|
||||||
|
# Backup codes are 16 uppercase hex characters
|
||||||
|
if not code or len(code) != 16 or not all(c in "0123456789ABCDEFabcdef" for c in code):
|
||||||
|
raise ValidationError(
|
||||||
|
"Backup code must be a 16-character hexadecimal string.",
|
||||||
|
field_name="code",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Regular TOTP codes are exactly 6 digits
|
||||||
|
import re
|
||||||
|
if not re.match(r"^\d{6}$", code):
|
||||||
|
raise ValidationError(
|
||||||
|
"Code must be a 6-digit number.",
|
||||||
|
field_name="code",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TOTPDisableSchema(Schema):
|
class TOTPDisableSchema(Schema):
|
||||||
"""Schema for disabling TOTP."""
|
"""Schema for disabling TOTP."""
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ class AuditService:
|
|||||||
ip_address=ip_address,
|
ip_address=ip_address,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
metadata=metadata,
|
extra_data=metadata,
|
||||||
description=description,
|
description=description,
|
||||||
success=success,
|
success=success,
|
||||||
error_message=error_message,
|
error_message=error_message,
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ class AuthService:
|
|||||||
if current_app.config.get('ENV') == 'development':
|
if current_app.config.get('ENV') == 'development':
|
||||||
logger.debug(f"[Auth] Account status: user_id={user.id}, status={user.status}")
|
logger.debug(f"[Auth] Account status: user_id={user.id}, status={user.status}")
|
||||||
|
|
||||||
if user.status == UserStatus.SUSPENDED:
|
if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
|
||||||
raise AccountSuspendedError()
|
raise AccountSuspendedError()
|
||||||
if user.status == UserStatus.INACTIVE:
|
if user.status == UserStatus.INACTIVE:
|
||||||
raise AccountInactiveError()
|
raise AccountInactiveError()
|
||||||
@@ -210,6 +210,22 @@ class AuthService:
|
|||||||
auth_method.password_hash = bcrypt.generate_password_hash(new_password).decode("utf-8")
|
auth_method.password_hash = bcrypt.generate_password_hash(new_password).decode("utf-8")
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
|
# Invalidate all other sessions so that if an attacker had a valid
|
||||||
|
# session token, changing the password actually locks them out.
|
||||||
|
# The current request's session (if any) is preserved so the user
|
||||||
|
# doesn't have to log in again immediately.
|
||||||
|
from flask import g as flask_g
|
||||||
|
current_session_id = getattr(flask_g, "current_session", None)
|
||||||
|
current_session_id = current_session_id.id if current_session_id else None
|
||||||
|
sessions_to_revoke = Session.query.filter(
|
||||||
|
Session.user_id == user.id,
|
||||||
|
Session.revoked_at == None, # noqa: E711
|
||||||
|
).all()
|
||||||
|
for sess in sessions_to_revoke:
|
||||||
|
if sess.id != current_session_id:
|
||||||
|
sess.revoke(reason="Password changed")
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
# Log password change
|
# Log password change
|
||||||
AuditService.log_action(
|
AuditService.log_action(
|
||||||
action=AuditAction.PASSWORD_CHANGE,
|
action=AuditAction.PASSWORD_CHANGE,
|
||||||
@@ -482,9 +498,24 @@ class AuthService:
|
|||||||
if not secret:
|
if not secret:
|
||||||
raise InvalidCredentialsError("TOTP secret not found")
|
raise InvalidCredentialsError("TOTP secret not found")
|
||||||
|
|
||||||
|
# Replay-attack prevention: reject codes that have already been
|
||||||
|
# accepted within the current validity window.
|
||||||
|
if TOTPService.is_code_already_used(str(user.id), code):
|
||||||
|
AuditService.log_action(
|
||||||
|
action=AuditAction.TOTP_VERIFY_FAILED,
|
||||||
|
user_id=user.id,
|
||||||
|
resource_type="authentication_method",
|
||||||
|
resource_id=auth_method.id,
|
||||||
|
description="TOTP code replay attempt detected",
|
||||||
|
)
|
||||||
|
raise InvalidCredentialsError("Invalid TOTP code")
|
||||||
|
|
||||||
is_valid = TOTPService.verify_code(secret, code, client_utc_timestamp=client_utc_timestamp)
|
is_valid = TOTPService.verify_code(secret, code, client_utc_timestamp=client_utc_timestamp)
|
||||||
|
|
||||||
if is_valid:
|
if is_valid:
|
||||||
|
# Mark this code as used to prevent replay within the validity window
|
||||||
|
TOTPService.mark_code_used(str(user.id), code)
|
||||||
|
|
||||||
auth_method.last_used_at = datetime.now(timezone.utc)
|
auth_method.last_used_at = datetime.now(timezone.utc)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
|
|||||||
@@ -736,7 +736,12 @@ class ExternalAuthService:
|
|||||||
400,
|
400,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate PKCE
|
# Generate PKCE — skip for confidential clients (Google, Microsoft) that use a
|
||||||
|
# client_secret. Sending code_challenge to Microsoft causes it to enforce PKCE on
|
||||||
|
# the token exchange, which then fails. Matches the behaviour of initiate_login_flow.
|
||||||
|
code_verifier = None
|
||||||
|
code_challenge = None
|
||||||
|
if provider_type_str not in ('google', 'microsoft'):
|
||||||
code_verifier = secrets.token_urlsafe(32)
|
code_verifier = secrets.token_urlsafe(32)
|
||||||
code_challenge = cls._compute_s256_challenge(code_verifier)
|
code_challenge = cls._compute_s256_challenge(code_verifier)
|
||||||
|
|
||||||
|
|||||||
@@ -188,11 +188,10 @@ class OrganizationService:
|
|||||||
Raises:
|
Raises:
|
||||||
ConflictError: If user is already a member
|
ConflictError: If user is already a member
|
||||||
"""
|
"""
|
||||||
# Check if already a member
|
# Check if already a member (active or soft-deleted — both blocked by DB unique constraint)
|
||||||
existing = OrganizationMember.query.filter_by(
|
existing = OrganizationMember.query.filter_by(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
organization_id=org.id,
|
organization_id=org.id,
|
||||||
deleted_at=None,
|
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
# Development-only debug logging for membership validation
|
# Development-only debug logging for membership validation
|
||||||
@@ -200,6 +199,25 @@ class OrganizationService:
|
|||||||
logger.debug(f"[Org] Member check: org_id={org.id}, user_id={user_id}, already_member={existing is not None}")
|
logger.debug(f"[Org] Member check: org_id={org.id}, user_id={user_id}, already_member={existing is not None}")
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
|
if existing.deleted_at is not None:
|
||||||
|
# Reactivate the soft-deleted membership with the new role
|
||||||
|
existing.deleted_at = None
|
||||||
|
existing.role = role
|
||||||
|
existing.invited_by_id = inviter_id
|
||||||
|
existing.invited_at = datetime.now(timezone.utc)
|
||||||
|
existing.joined_at = datetime.now(timezone.utc)
|
||||||
|
existing.save()
|
||||||
|
|
||||||
|
AuditService.log_action(
|
||||||
|
action=AuditAction.ORG_MEMBER_ADD,
|
||||||
|
user_id=inviter_id,
|
||||||
|
organization_id=org.id,
|
||||||
|
resource_type="organization_member",
|
||||||
|
resource_id=existing.id,
|
||||||
|
metadata={"added_user_id": user_id, "role": role.value},
|
||||||
|
description=f"Member re-added to organization with role: {role.value}",
|
||||||
|
)
|
||||||
|
return existing
|
||||||
raise ConflictError("User is already a member of this organization")
|
raise ConflictError("User is already a member of this organization")
|
||||||
|
|
||||||
# Create membership
|
# Create membership
|
||||||
|
|||||||
@@ -192,13 +192,19 @@ class SSHCASigningService:
|
|||||||
self,
|
self,
|
||||||
signing_request: SSHCertificateSigningRequest,
|
signing_request: SSHCertificateSigningRequest,
|
||||||
ca_private_key: Optional[str] = None,
|
ca_private_key: Optional[str] = None,
|
||||||
) -> SSHCertificateSigningResponse:
|
ca_obj=None,
|
||||||
|
) -> "SSHCertificateSigningResponse":
|
||||||
"""Sign an SSH certificate.
|
"""Sign an SSH certificate.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
signing_request: SSHCertificateSigningRequest instance
|
signing_request: SSHCertificateSigningRequest instance
|
||||||
ca_private_key: CA private key in PEM format. If not provided,
|
ca_private_key: CA private key in PEM format. If not provided,
|
||||||
loaded from config (ca_key_path or SSH_CA_PRIVATE_KEY env var)
|
loaded from config (ca_key_path or SSH_CA_PRIVATE_KEY env var)
|
||||||
|
ca_obj: Optional CA model instance. When supplied its monotonic
|
||||||
|
serial counter is incremented atomically (SELECT FOR UPDATE)
|
||||||
|
and the resulting integer is embedded in the certificate's
|
||||||
|
serial field. This ensures every issued cert has a unique,
|
||||||
|
ordered, auditable serial number.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SSHCertificateSigningResponse with signed certificate
|
SSHCertificateSigningResponse with signed certificate
|
||||||
@@ -245,7 +251,8 @@ class SSHCASigningService:
|
|||||||
valid_before = now + timedelta(hours=expiry_hours)
|
valid_before = now + timedelta(hours=expiry_hours)
|
||||||
|
|
||||||
# Set certificate fields
|
# Set certificate fields
|
||||||
cert_type = 1 if signing_request.cert_type == "user" else 0
|
# sshkey-tools: user=1, host=2 (not 0)
|
||||||
|
cert_type = 1 if signing_request.cert_type == "user" else 2
|
||||||
|
|
||||||
certificate.fields.cert_type = cert_type
|
certificate.fields.cert_type = cert_type
|
||||||
certificate.fields.key_id = signing_request.key_id
|
certificate.fields.key_id = signing_request.key_id
|
||||||
@@ -253,6 +260,19 @@ class SSHCASigningService:
|
|||||||
certificate.fields.valid_after = now
|
certificate.fields.valid_after = now
|
||||||
certificate.fields.valid_before = valid_before
|
certificate.fields.valid_before = valid_before
|
||||||
|
|
||||||
|
# ── Serial number ────────────────────────────────────────────────
|
||||||
|
# If a CA object is provided, use its monotonic counter so every
|
||||||
|
# certificate gets a unique, ordered, auditable serial. The
|
||||||
|
# counter increment is flushed inside get_next_serial(); the
|
||||||
|
# caller's commit() persists it atomically with the cert record.
|
||||||
|
if ca_obj is not None:
|
||||||
|
assigned_serial = ca_obj.get_next_serial()
|
||||||
|
certificate.fields.serial = assigned_serial
|
||||||
|
self.logger.debug(
|
||||||
|
f"Assigned serial {assigned_serial} from CA {ca_obj.id}"
|
||||||
|
)
|
||||||
|
# ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
# Set extensions — prefer policy-provided list, fall back to standard set
|
# Set extensions — prefer policy-provided list, fall back to standard set
|
||||||
extensions = signing_request.extensions
|
extensions = signing_request.extensions
|
||||||
if not extensions:
|
if not extensions:
|
||||||
@@ -276,7 +296,12 @@ class SSHCASigningService:
|
|||||||
self.logger.error(f"Certificate verification failed: {str(e)}")
|
self.logger.error(f"Certificate verification failed: {str(e)}")
|
||||||
raise SSHCASigningError(f"Certificate verification failed: {str(e)}")
|
raise SSHCASigningError(f"Certificate verification failed: {str(e)}")
|
||||||
|
|
||||||
# Extract serial from certificate
|
# Extract serial from certificate — use the integer we assigned
|
||||||
|
# when ca_obj was provided, otherwise fall back to whatever the
|
||||||
|
# library generated.
|
||||||
|
if ca_obj is not None:
|
||||||
|
serial = str(assigned_serial)
|
||||||
|
else:
|
||||||
serial = str(certificate.fields.serial).split(":")[-1].strip() if hasattr(certificate.fields.serial, '__str__') else str(certificate.fields.serial)
|
serial = str(certificate.fields.serial).split(":")[-1].strip() if hasattr(certificate.fields.serial, '__str__') else str(certificate.fields.serial)
|
||||||
|
|
||||||
# Build response
|
# Build response
|
||||||
|
|||||||
@@ -11,10 +11,51 @@ from gatehouse_app.extensions import bcrypt
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# TOTP codes are valid for at most (2*window + 1) * 30s steps.
|
||||||
|
# With window=1 that's 3 steps = 90 seconds. We use a slightly
|
||||||
|
# generous TTL of 95 seconds to account for clock skew at boundaries.
|
||||||
|
_TOTP_USED_CODE_TTL = 95
|
||||||
|
|
||||||
|
|
||||||
class TOTPService:
|
class TOTPService:
|
||||||
"""Service for TOTP operations."""
|
"""Service for TOTP operations."""
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Replay-attack prevention helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _used_key(user_id: str, code: str) -> str:
|
||||||
|
return f"totp:used:{user_id}:{code}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_code_already_used(user_id: str, code: str) -> bool:
|
||||||
|
"""Return True if *code* has already been accepted for *user_id*
|
||||||
|
within the current validity window (prevents replay attacks)."""
|
||||||
|
try:
|
||||||
|
from gatehouse_app.extensions import redis_client
|
||||||
|
if redis_client is None:
|
||||||
|
return False
|
||||||
|
return redis_client.exists(TOTPService._used_key(user_id, code)) == 1
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Redis unavailable for TOTP replay check; allowing code")
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def mark_code_used(user_id: str, code: str) -> None:
|
||||||
|
"""Record *code* as consumed for *user_id* so it cannot be reused."""
|
||||||
|
try:
|
||||||
|
from gatehouse_app.extensions import redis_client
|
||||||
|
if redis_client is None:
|
||||||
|
return
|
||||||
|
redis_client.setex(
|
||||||
|
TOTPService._used_key(user_id, code),
|
||||||
|
_TOTP_USED_CODE_TTL,
|
||||||
|
"1",
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Redis unavailable; TOTP used-code not recorded")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_secret() -> str:
|
def generate_secret() -> str:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -642,6 +642,26 @@ class WebAuthnService:
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def credential_belongs_to_user(cls, credential_id: str, user: User) -> bool:
|
||||||
|
"""Check whether *credential_id* exists and belongs to *user*.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
credential_id: The credential ID to look up
|
||||||
|
user: User instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the credential exists and belongs to this user, False otherwise.
|
||||||
|
"""
|
||||||
|
auth_method = AuthenticationMethod.query.filter_by(
|
||||||
|
user_id=user.id,
|
||||||
|
method_type=AuthMethodType.WEBAUTHN,
|
||||||
|
deleted_at=None,
|
||||||
|
).first()
|
||||||
|
if not auth_method or not auth_method.provider_data:
|
||||||
|
return False
|
||||||
|
return auth_method.provider_data.get("credential_id") == credential_id
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def rename_credential(cls, credential_id: str, user: User, name: str) -> bool:
|
def rename_credential(cls, credential_id: str, user: User, name: str) -> bool:
|
||||||
"""Rename a passkey credential.
|
"""Rename a passkey credential.
|
||||||
|
|||||||
@@ -0,0 +1,206 @@
|
|||||||
|
"""Encryption helpers for CA private keys stored in the database.
|
||||||
|
|
||||||
|
CA private keys are encrypted at rest using Fernet (AES-128-CBC + HMAC-SHA256)
|
||||||
|
from the ``cryptography`` package. The encryption key is derived from the
|
||||||
|
``CA_ENCRYPTION_KEY`` environment variable (or ``Flask.config["CA_ENCRYPTION_KEY"]``).
|
||||||
|
|
||||||
|
Key derivation
|
||||||
|
--------------
|
||||||
|
Fernet requires a URL-safe base64-encoded 32-byte key. We accept any string
|
||||||
|
from the env and derive the actual Fernet key using SHA-256 so that operators
|
||||||
|
can supply human-readable secrets without having to pre-encode them.
|
||||||
|
|
||||||
|
Envelope format
|
||||||
|
---------------
|
||||||
|
Encrypted values are stored as the string::
|
||||||
|
|
||||||
|
$fernet$<fernet_token>
|
||||||
|
|
||||||
|
The ``$fernet$`` prefix lets the code distinguish already-encrypted values from
|
||||||
|
legacy plaintext PEM keys so that the migration path is safe and idempotent.
|
||||||
|
|
||||||
|
Usage
|
||||||
|
-----
|
||||||
|
Encrypt before storing::
|
||||||
|
|
||||||
|
from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key
|
||||||
|
ca.private_key = encrypt_ca_key(private_key_pem)
|
||||||
|
|
||||||
|
Decrypt before use::
|
||||||
|
|
||||||
|
from gatehouse_app.utils.ca_key_encryption import decrypt_ca_key
|
||||||
|
plaintext_pem = decrypt_ca_key(ca.private_key)
|
||||||
|
"""
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet, InvalidToken
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Prefix that marks a stored value as Fernet-encrypted
|
||||||
|
_FERNET_PREFIX = "$fernet$"
|
||||||
|
|
||||||
|
|
||||||
|
class CAKeyEncryptionError(Exception):
|
||||||
|
"""Raised when CA key encryption or decryption fails."""
|
||||||
|
|
||||||
|
|
||||||
|
def _get_fernet() -> Fernet:
|
||||||
|
"""Build a Fernet instance from the configured encryption key.
|
||||||
|
|
||||||
|
Looks up ``CA_ENCRYPTION_KEY`` in the environment first, then falls back to
|
||||||
|
the Flask app config (if a request context is active).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
CAKeyEncryptionError: if no key is configured or it is the insecure
|
||||||
|
placeholder value in a production-like environment.
|
||||||
|
"""
|
||||||
|
raw_key = os.environ.get("CA_ENCRYPTION_KEY")
|
||||||
|
|
||||||
|
if not raw_key:
|
||||||
|
# Try Flask config if we're inside an app context
|
||||||
|
try:
|
||||||
|
from flask import current_app
|
||||||
|
raw_key = current_app.config.get("CA_ENCRYPTION_KEY")
|
||||||
|
except RuntimeError:
|
||||||
|
pass # No app context
|
||||||
|
|
||||||
|
if not raw_key:
|
||||||
|
raise CAKeyEncryptionError(
|
||||||
|
"CA_ENCRYPTION_KEY is not set. "
|
||||||
|
"Set this environment variable before starting the application."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Warn loudly when running with the placeholder in a non-test environment
|
||||||
|
env_name = os.environ.get("FLASK_ENV", "").lower()
|
||||||
|
if raw_key.startswith("dev-") and env_name not in ("development", "testing", "test"):
|
||||||
|
logger.warning(
|
||||||
|
"CA_ENCRYPTION_KEY appears to be a development placeholder. "
|
||||||
|
"Set a strong random key for production environments."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Derive a 32-byte key from the raw secret via SHA-256, then URL-safe base64
|
||||||
|
key_bytes = hashlib.sha256(raw_key.encode()).digest()
|
||||||
|
fernet_key = base64.urlsafe_b64encode(key_bytes)
|
||||||
|
return Fernet(fernet_key)
|
||||||
|
|
||||||
|
|
||||||
|
def encrypt_ca_key(plaintext_pem: str) -> str:
|
||||||
|
"""Encrypt a CA private key PEM string.
|
||||||
|
|
||||||
|
Idempotent: already-encrypted values are returned unchanged.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plaintext_pem: CA private key in OpenSSH/PEM format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Encrypted string with ``$fernet$`` prefix, safe for database storage.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
CAKeyEncryptionError: if the key cannot be encrypted.
|
||||||
|
"""
|
||||||
|
if not plaintext_pem:
|
||||||
|
raise CAKeyEncryptionError("Cannot encrypt an empty key")
|
||||||
|
|
||||||
|
# Already encrypted — do not double-encrypt
|
||||||
|
if plaintext_pem.startswith(_FERNET_PREFIX):
|
||||||
|
return plaintext_pem
|
||||||
|
|
||||||
|
try:
|
||||||
|
fernet = _get_fernet()
|
||||||
|
token = fernet.encrypt(plaintext_pem.encode()).decode()
|
||||||
|
return f"{_FERNET_PREFIX}{token}"
|
||||||
|
except CAKeyEncryptionError:
|
||||||
|
raise
|
||||||
|
except Exception as exc:
|
||||||
|
raise CAKeyEncryptionError(f"Failed to encrypt CA key: {exc}") from exc
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_ca_key(stored_value: str) -> str:
|
||||||
|
"""Decrypt a CA private key retrieved from the database.
|
||||||
|
|
||||||
|
Idempotent: plaintext (legacy) values are returned unchanged so that the
|
||||||
|
system continues to work while a migration encrypts existing rows.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stored_value: Value from ``CA.private_key`` column.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Plaintext PEM string ready for use with ``sshkey_tools``.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
CAKeyEncryptionError: if decryption fails (wrong key, corrupted data).
|
||||||
|
"""
|
||||||
|
if not stored_value:
|
||||||
|
raise CAKeyEncryptionError("Cannot decrypt an empty value")
|
||||||
|
|
||||||
|
# Legacy plaintext key — return as-is
|
||||||
|
if not stored_value.startswith(_FERNET_PREFIX):
|
||||||
|
logger.warning(
|
||||||
|
"CA private key appears to be stored as plaintext. "
|
||||||
|
"Run the migration to encrypt existing keys."
|
||||||
|
)
|
||||||
|
return stored_value
|
||||||
|
|
||||||
|
token = stored_value[len(_FERNET_PREFIX):]
|
||||||
|
try:
|
||||||
|
fernet = _get_fernet()
|
||||||
|
return fernet.decrypt(token.encode()).decode()
|
||||||
|
except InvalidToken as exc:
|
||||||
|
raise CAKeyEncryptionError(
|
||||||
|
"CA key decryption failed — the CA_ENCRYPTION_KEY may be incorrect "
|
||||||
|
"or the stored key is corrupted."
|
||||||
|
) from exc
|
||||||
|
except CAKeyEncryptionError:
|
||||||
|
raise
|
||||||
|
except Exception as exc:
|
||||||
|
raise CAKeyEncryptionError(f"Unexpected decryption error: {exc}") from exc
|
||||||
|
|
||||||
|
|
||||||
|
def is_encrypted(stored_value: str) -> bool:
|
||||||
|
"""Return True if the stored value has the ``$fernet$`` envelope.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stored_value: Value from ``CA.private_key`` column.
|
||||||
|
"""
|
||||||
|
return bool(stored_value and stored_value.startswith(_FERNET_PREFIX))
|
||||||
|
|
||||||
|
|
||||||
|
def reencrypt_ca_key(stored_value: str, old_raw_key: str, new_raw_key: str) -> str:
|
||||||
|
"""Re-encrypt a CA key with a new encryption key (for key rotation).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stored_value: Current value from ``CA.private_key`` (may or may not be encrypted).
|
||||||
|
old_raw_key: The current ``CA_ENCRYPTION_KEY`` value (raw secret string).
|
||||||
|
new_raw_key: The new ``CA_ENCRYPTION_KEY`` value to encrypt with.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New encrypted envelope string.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
CAKeyEncryptionError: if decryption or re-encryption fails.
|
||||||
|
"""
|
||||||
|
# Decrypt with old key
|
||||||
|
if stored_value.startswith(_FERNET_PREFIX):
|
||||||
|
token = stored_value[len(_FERNET_PREFIX):]
|
||||||
|
old_key_bytes = base64.urlsafe_b64encode(hashlib.sha256(old_raw_key.encode()).digest())
|
||||||
|
try:
|
||||||
|
plaintext = Fernet(old_key_bytes).decrypt(token.encode()).decode()
|
||||||
|
except InvalidToken as exc:
|
||||||
|
raise CAKeyEncryptionError(
|
||||||
|
"Re-encryption failed: could not decrypt with the old key."
|
||||||
|
) from exc
|
||||||
|
else:
|
||||||
|
# Plaintext
|
||||||
|
plaintext = stored_value
|
||||||
|
|
||||||
|
# Re-encrypt with new key
|
||||||
|
new_key_bytes = base64.urlsafe_b64encode(hashlib.sha256(new_raw_key.encode()).digest())
|
||||||
|
try:
|
||||||
|
token = Fernet(new_key_bytes).encrypt(plaintext.encode()).decode()
|
||||||
|
return f"{_FERNET_PREFIX}{token}"
|
||||||
|
except Exception as exc:
|
||||||
|
raise CAKeyEncryptionError(f"Re-encryption with new key failed: {exc}") from exc
|
||||||
@@ -61,6 +61,7 @@ class AuditAction(str, Enum):
|
|||||||
USER_REGISTER = "user.register"
|
USER_REGISTER = "user.register"
|
||||||
USER_UPDATE = "user.update"
|
USER_UPDATE = "user.update"
|
||||||
USER_DELETE = "user.delete"
|
USER_DELETE = "user.delete"
|
||||||
|
USER_HARD_DELETE = "user.hard_delete"
|
||||||
USER_SUSPEND = "user.suspend"
|
USER_SUSPEND = "user.suspend"
|
||||||
USER_UNSUSPEND = "user.unsuspend"
|
USER_UNSUSPEND = "user.unsuspend"
|
||||||
PASSWORD_CHANGE = "user.password_change"
|
PASSWORD_CHANGE = "user.password_change"
|
||||||
@@ -73,6 +74,7 @@ class AuditAction(str, Enum):
|
|||||||
ORG_MEMBER_ADD = "org.member.add"
|
ORG_MEMBER_ADD = "org.member.add"
|
||||||
ORG_MEMBER_REMOVE = "org.member.remove"
|
ORG_MEMBER_REMOVE = "org.member.remove"
|
||||||
ORG_MEMBER_ROLE_CHANGE = "org.member.role_change"
|
ORG_MEMBER_ROLE_CHANGE = "org.member.role_change"
|
||||||
|
ORG_OWNERSHIP_TRANSFERRED = "org.ownership.transferred"
|
||||||
|
|
||||||
# Session actions
|
# Session actions
|
||||||
SESSION_CREATE = "session.create"
|
SESSION_CREATE = "session.create"
|
||||||
|
|||||||
@@ -0,0 +1,37 @@
|
|||||||
|
"""Add USER_SUSPEND and USER_UNSUSPEND to auditaction enum.
|
||||||
|
|
||||||
|
Revision ID: 015_add_user_suspend_audit_actions
|
||||||
|
Revises: 014_add_dept_cert_policy
|
||||||
|
Create Date: 2026-03-02
|
||||||
|
|
||||||
|
USER_SUSPEND and USER_UNSUSPEND were added to the Python AuditAction enum
|
||||||
|
but were never synced to the PostgreSQL auditaction type, causing a
|
||||||
|
DataError (invalid enum value) whenever an admin suspends or unsuspends a user.
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision = "015_user_suspend_audit"
|
||||||
|
down_revision = "014_add_dept_cert_policy"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
for val in ("USER_SUSPEND", "USER_UNSUSPEND"):
|
||||||
|
op.execute(f"""
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM pg_enum
|
||||||
|
WHERE enumlabel = '{val}'
|
||||||
|
AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'auditaction')
|
||||||
|
) THEN
|
||||||
|
ALTER TYPE auditaction ADD VALUE '{val}';
|
||||||
|
END IF;
|
||||||
|
END$$;
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# PostgreSQL does not support removing enum values; downgrade is a no-op.
|
||||||
|
pass
|
||||||
@@ -0,0 +1,168 @@
|
|||||||
|
"""Encrypt existing plaintext CA private keys at rest.
|
||||||
|
|
||||||
|
Revision ID: 016_encrypt_existing_ca_keys
|
||||||
|
Revises: 015_add_user_suspend_audit_actions
|
||||||
|
Create Date: 2026-03-02
|
||||||
|
|
||||||
|
All CA private keys created before this migration were stored as plaintext PEM
|
||||||
|
strings in the ``cas.private_key`` column. This migration detects those rows
|
||||||
|
(by checking for the absence of the ``$fernet$`` prefix that encrypted values
|
||||||
|
carry) and re-encrypts them with the key derived from ``CA_ENCRYPTION_KEY``.
|
||||||
|
|
||||||
|
The migration is safe to re-run: already-encrypted rows are left untouched.
|
||||||
|
|
||||||
|
Prerequisites
|
||||||
|
-------------
|
||||||
|
``CA_ENCRYPTION_KEY`` must be set in the environment before running this
|
||||||
|
migration. The same value must be configured for the running application.
|
||||||
|
|
||||||
|
To roll back to plaintext (downgrade):
|
||||||
|
The ``downgrade()`` function decrypts all rows back to plaintext PEM. This is
|
||||||
|
provided only for emergency rollback and should not be used in production once
|
||||||
|
the system has been running with encrypted keys.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Alembic revision identifiers
|
||||||
|
revision = "016_encrypt_ca_keys"
|
||||||
|
down_revision = "015_user_suspend_audit"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
_FERNET_PREFIX = "$fernet$"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_fernet():
|
||||||
|
"""Build a Fernet instance from CA_ENCRYPTION_KEY env var."""
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
|
||||||
|
raw_key = os.environ.get("CA_ENCRYPTION_KEY")
|
||||||
|
if not raw_key:
|
||||||
|
raise RuntimeError(
|
||||||
|
"CA_ENCRYPTION_KEY environment variable is not set. "
|
||||||
|
"Set it before running this migration."
|
||||||
|
)
|
||||||
|
key_bytes = base64.urlsafe_b64encode(hashlib.sha256(raw_key.encode()).digest())
|
||||||
|
return Fernet(key_bytes)
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
"""Encrypt plaintext CA private keys."""
|
||||||
|
bind = op.get_bind()
|
||||||
|
session = Session(bind=bind)
|
||||||
|
|
||||||
|
try:
|
||||||
|
fernet = _get_fernet()
|
||||||
|
except RuntimeError as exc:
|
||||||
|
raise RuntimeError(str(exc)) from exc
|
||||||
|
|
||||||
|
# Fetch all non-deleted CA rows
|
||||||
|
rows = session.execute(
|
||||||
|
sa.text("SELECT id, private_key FROM cas WHERE deleted_at IS NULL")
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
encrypted_count = 0
|
||||||
|
skipped_count = 0
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
ca_id, private_key = row[0], row[1]
|
||||||
|
|
||||||
|
if not private_key:
|
||||||
|
logger.warning(f"CA {ca_id} has empty private_key — skipping")
|
||||||
|
skipped_count += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
if private_key.startswith(_FERNET_PREFIX):
|
||||||
|
# Already encrypted
|
||||||
|
skipped_count += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Encrypt
|
||||||
|
try:
|
||||||
|
token = fernet.encrypt(private_key.encode()).decode()
|
||||||
|
encrypted_value = f"{_FERNET_PREFIX}{token}"
|
||||||
|
session.execute(
|
||||||
|
sa.text("UPDATE cas SET private_key = :pk WHERE id = :id"),
|
||||||
|
{"pk": encrypted_value, "id": ca_id},
|
||||||
|
)
|
||||||
|
encrypted_count += 1
|
||||||
|
logger.info(f"Encrypted private key for CA {ca_id}")
|
||||||
|
except Exception as exc:
|
||||||
|
session.rollback()
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Failed to encrypt private key for CA {ca_id}: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
logger.info(
|
||||||
|
f"CA key encryption migration complete: "
|
||||||
|
f"{encrypted_count} encrypted, {skipped_count} skipped"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" [016_encrypt_ca_keys] {encrypted_count} CA private key(s) encrypted, "
|
||||||
|
f"{skipped_count} already encrypted or empty."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
"""Decrypt CA private keys back to plaintext (emergency rollback only)."""
|
||||||
|
bind = op.get_bind()
|
||||||
|
session = Session(bind=bind)
|
||||||
|
|
||||||
|
try:
|
||||||
|
fernet = _get_fernet()
|
||||||
|
except RuntimeError as exc:
|
||||||
|
raise RuntimeError(str(exc)) from exc
|
||||||
|
|
||||||
|
rows = session.execute(
|
||||||
|
sa.text("SELECT id, private_key FROM cas WHERE deleted_at IS NULL")
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
decrypted_count = 0
|
||||||
|
skipped_count = 0
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
ca_id, private_key = row[0], row[1]
|
||||||
|
|
||||||
|
if not private_key or not private_key.startswith(_FERNET_PREFIX):
|
||||||
|
skipped_count += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
token = private_key[len(_FERNET_PREFIX):]
|
||||||
|
try:
|
||||||
|
from cryptography.fernet import InvalidToken
|
||||||
|
try:
|
||||||
|
plaintext = fernet.decrypt(token.encode()).decode()
|
||||||
|
except InvalidToken as exc:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Downgrade failed: cannot decrypt CA {ca_id} — wrong key or corrupted data."
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
session.execute(
|
||||||
|
sa.text("UPDATE cas SET private_key = :pk WHERE id = :id"),
|
||||||
|
{"pk": plaintext, "id": ca_id},
|
||||||
|
)
|
||||||
|
decrypted_count += 1
|
||||||
|
logger.warning(f"Decrypted (plaintext restore) private key for CA {ca_id}")
|
||||||
|
except RuntimeError:
|
||||||
|
session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
logger.warning(
|
||||||
|
f"CA key decryption (downgrade) complete: "
|
||||||
|
f"{decrypted_count} decrypted, {skipped_count} skipped"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" [016_encrypt_ca_keys] DOWNGRADE: {decrypted_count} CA private key(s) "
|
||||||
|
f"decrypted to plaintext. WARNING: keys are now unencrypted at rest."
|
||||||
|
)
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
"""Add monotonic serial counter to CAs table.
|
||||||
|
|
||||||
|
Each CA now owns a `next_serial_number` (BigInteger) that is atomically
|
||||||
|
incremented every time a certificate is signed. This guarantees:
|
||||||
|
- Serials are unique per CA
|
||||||
|
- Serials are monotonically increasing (auditable, no gaps by accident)
|
||||||
|
- The value embedded in the OpenSSH certificate matches what is stored
|
||||||
|
in the `ssh_certificates.serial` column
|
||||||
|
|
||||||
|
Revision ID: 017_add_ca_serial_counter
|
||||||
|
Revises: 016_encrypt_ca_keys
|
||||||
|
Create Date: 2026-03-02
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
revision = "017_add_ca_serial_counter"
|
||||||
|
down_revision = "016_encrypt_ca_keys"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
with op.batch_alter_table("cas", schema=None) as batch_op:
|
||||||
|
batch_op.add_column(
|
||||||
|
sa.Column(
|
||||||
|
"next_serial_number",
|
||||||
|
sa.BigInteger(),
|
||||||
|
nullable=False,
|
||||||
|
server_default="1",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
with op.batch_alter_table("cas", schema=None) as batch_op:
|
||||||
|
batch_op.drop_column("next_serial_number")
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
"""Add ORG_OWNERSHIP_TRANSFERRED and USER_HARD_DELETE to auditaction enum.
|
||||||
|
|
||||||
|
Revision ID: 018_audit_enum_values
|
||||||
|
Revises: 017_add_ca_serial_counter
|
||||||
|
Create Date: 2026-03-02
|
||||||
|
|
||||||
|
ORG_OWNERSHIP_TRANSFERRED and USER_HARD_DELETE were added to the Python
|
||||||
|
AuditAction enum but were never synced to the PostgreSQL auditaction type,
|
||||||
|
causing a DataError (invalid enum value) when transferring org ownership
|
||||||
|
or hard-deleting a user.
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision = "018_audit_enum_values"
|
||||||
|
down_revision = "017_add_ca_serial_counter"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ALTER TYPE ... ADD VALUE cannot run inside a transaction block in PostgreSQL.
|
||||||
|
# Alembic has already opened a transaction on the connection by the time our
|
||||||
|
# upgrade() runs, so we must:
|
||||||
|
# 1. Roll back that open transaction on the raw psycopg2 connection.
|
||||||
|
# 2. Switch to autocommit so the ALTER TYPE runs outside any transaction.
|
||||||
|
# 3. Restore the previous state afterwards.
|
||||||
|
conn = op.get_bind()
|
||||||
|
# SQLAlchemy 2.x: conn.connection is a _ConnectionFairy; .driver_connection is psycopg2
|
||||||
|
fairy = conn.connection
|
||||||
|
raw = getattr(fairy, "driver_connection", None) or getattr(fairy, "dbapi_connection", fairy)
|
||||||
|
# Roll back the open transaction so psycopg2 allows us to change autocommit.
|
||||||
|
raw.rollback()
|
||||||
|
old_autocommit = raw.autocommit
|
||||||
|
raw.autocommit = True
|
||||||
|
try:
|
||||||
|
with raw.cursor() as cur:
|
||||||
|
for val in ("ORG_OWNERSHIP_TRANSFERRED", "USER_HARD_DELETE"):
|
||||||
|
cur.execute(
|
||||||
|
"SELECT 1 FROM pg_enum "
|
||||||
|
"WHERE enumlabel = %s "
|
||||||
|
"AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'auditaction')",
|
||||||
|
(val,),
|
||||||
|
)
|
||||||
|
if not cur.fetchone():
|
||||||
|
cur.execute(f"ALTER TYPE auditaction ADD VALUE '{val}'")
|
||||||
|
finally:
|
||||||
|
raw.autocommit = old_autocommit
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# PostgreSQL does not support removing enum values; downgrade is a no-op.
|
||||||
|
pass
|
||||||
Reference in New Issue
Block a user