From 7cb522b5907f7166f59d0bec3ab16106904fc35c Mon Sep 17 00:00:00 2001 From: James Bhattarai Date: Wed, 4 Mar 2026 18:49:04 +0545 Subject: [PATCH] Feat(Chore, Fix): Refractor, Half Baked Deletion + Admin Privilege Refractor Codes into sub file/folders Admin can remove users'/members mfa/2fa, unlink account from oauth provider Admin can add/reset password Different Email (OIDC + Manual)-Same Account; (Block Linking and authorize if available) --- .env.example | 7 + config/development.py | 3 + gatehouse_app/api/oidc.py | 5 +- gatehouse_app/api/v1/auth.py | 1691 --------------- gatehouse_app/api/v1/auth/__init__.py | 2 + gatehouse_app/api/v1/auth/core.py | 251 +++ gatehouse_app/api/v1/auth/password.py | 218 ++ gatehouse_app/api/v1/auth/totp.py | 183 ++ gatehouse_app/api/v1/auth/webauthn.py | 217 ++ gatehouse_app/api/v1/external_auth.py | 1449 ------------- .../api/v1/external_auth/__init__.py | 2 + .../api/v1/external_auth/_helpers.py | 94 + gatehouse_app/api/v1/external_auth/admin.py | 109 + gatehouse_app/api/v1/external_auth/cli.py | 68 + gatehouse_app/api/v1/external_auth/oauth.py | 244 +++ .../api/v1/external_auth/providers.py | 201 ++ gatehouse_app/api/v1/organizations.py | 1888 ----------------- .../api/v1/organizations/__init__.py | 4 + .../api/v1/organizations/_helpers.py | 52 + gatehouse_app/api/v1/organizations/audit.py | 175 ++ gatehouse_app/api/v1/organizations/cas.py | 261 +++ gatehouse_app/api/v1/organizations/clients.py | 110 + gatehouse_app/api/v1/organizations/core.py | 85 + gatehouse_app/api/v1/organizations/invites.py | 256 +++ gatehouse_app/api/v1/organizations/members.py | 176 ++ gatehouse_app/api/v1/organizations/roles.py | 85 + gatehouse_app/api/v1/ssh.py | 1418 ------------- gatehouse_app/api/v1/ssh/__init__.py | 3 + gatehouse_app/api/v1/ssh/_helpers.py | 174 ++ gatehouse_app/api/v1/ssh/admin.py | 111 + gatehouse_app/api/v1/ssh/certs.py | 391 ++++ gatehouse_app/api/v1/ssh/keys.py | 125 ++ gatehouse_app/api/v1/users.py | 879 -------- gatehouse_app/api/v1/users/__init__.py | 2 + gatehouse_app/api/v1/users/admin.py | 842 ++++++++ gatehouse_app/api/v1/users/me.py | 299 +++ gatehouse_app/jobs/__init__.py | 2 +- gatehouse_app/jobs/mfa_compliance_job.py | 19 +- gatehouse_app/models/auth/audit_log.py | 2 +- gatehouse_app/schemas/auth_schema.py | 3 +- gatehouse_app/services/__init__.py | 2 +- gatehouse_app/services/auth_service.py | 17 +- .../services/external_auth/__init__.py | 168 ++ .../services/external_auth/_helpers.py | 183 ++ .../services/external_auth/app_provider.py | 125 ++ .../services/external_auth/linking.py | 339 +++ .../services/external_auth/models.py | 173 ++ .../services/external_auth/org_override.py | 147 ++ .../services/external_auth_service.py | 1328 ------------ .../services/notification_service.py | 31 +- gatehouse_app/services/oauth_flow/__init__.py | 209 ++ gatehouse_app/services/oauth_flow/code.py | 141 ++ gatehouse_app/services/oauth_flow/login.py | 410 ++++ gatehouse_app/services/oauth_flow/register.py | 248 +++ gatehouse_app/services/oauth_flow_service.py | 1152 ---------- gatehouse_app/services/oidc/__init__.py | 150 ++ gatehouse_app/services/oidc/auth_code.py | 196 ++ gatehouse_app/services/oidc/tokens.py | 321 +++ gatehouse_app/services/oidc/userinfo.py | 65 + gatehouse_app/services/oidc_service.py | 1025 --------- .../services/organization_service.py | 76 +- gatehouse_app/utils/constants.py | 4 + ...019_convert_auditaction_enum_to_varchar.py | 143 ++ 63 files changed, 7896 insertions(+), 10863 deletions(-) delete mode 100644 gatehouse_app/api/v1/auth.py create mode 100644 gatehouse_app/api/v1/auth/__init__.py create mode 100644 gatehouse_app/api/v1/auth/core.py create mode 100644 gatehouse_app/api/v1/auth/password.py create mode 100644 gatehouse_app/api/v1/auth/totp.py create mode 100644 gatehouse_app/api/v1/auth/webauthn.py delete mode 100644 gatehouse_app/api/v1/external_auth.py create mode 100644 gatehouse_app/api/v1/external_auth/__init__.py create mode 100644 gatehouse_app/api/v1/external_auth/_helpers.py create mode 100644 gatehouse_app/api/v1/external_auth/admin.py create mode 100644 gatehouse_app/api/v1/external_auth/cli.py create mode 100644 gatehouse_app/api/v1/external_auth/oauth.py create mode 100644 gatehouse_app/api/v1/external_auth/providers.py delete mode 100644 gatehouse_app/api/v1/organizations.py create mode 100644 gatehouse_app/api/v1/organizations/__init__.py create mode 100644 gatehouse_app/api/v1/organizations/_helpers.py create mode 100644 gatehouse_app/api/v1/organizations/audit.py create mode 100644 gatehouse_app/api/v1/organizations/cas.py create mode 100644 gatehouse_app/api/v1/organizations/clients.py create mode 100644 gatehouse_app/api/v1/organizations/core.py create mode 100644 gatehouse_app/api/v1/organizations/invites.py create mode 100644 gatehouse_app/api/v1/organizations/members.py create mode 100644 gatehouse_app/api/v1/organizations/roles.py delete mode 100644 gatehouse_app/api/v1/ssh.py create mode 100644 gatehouse_app/api/v1/ssh/__init__.py create mode 100644 gatehouse_app/api/v1/ssh/_helpers.py create mode 100644 gatehouse_app/api/v1/ssh/admin.py create mode 100644 gatehouse_app/api/v1/ssh/certs.py create mode 100644 gatehouse_app/api/v1/ssh/keys.py delete mode 100644 gatehouse_app/api/v1/users.py create mode 100644 gatehouse_app/api/v1/users/__init__.py create mode 100644 gatehouse_app/api/v1/users/admin.py create mode 100644 gatehouse_app/api/v1/users/me.py create mode 100644 gatehouse_app/services/external_auth/__init__.py create mode 100644 gatehouse_app/services/external_auth/_helpers.py create mode 100644 gatehouse_app/services/external_auth/app_provider.py create mode 100644 gatehouse_app/services/external_auth/linking.py create mode 100644 gatehouse_app/services/external_auth/models.py create mode 100644 gatehouse_app/services/external_auth/org_override.py delete mode 100644 gatehouse_app/services/external_auth_service.py create mode 100644 gatehouse_app/services/oauth_flow/__init__.py create mode 100644 gatehouse_app/services/oauth_flow/code.py create mode 100644 gatehouse_app/services/oauth_flow/login.py create mode 100644 gatehouse_app/services/oauth_flow/register.py delete mode 100644 gatehouse_app/services/oauth_flow_service.py create mode 100644 gatehouse_app/services/oidc/__init__.py create mode 100644 gatehouse_app/services/oidc/auth_code.py create mode 100644 gatehouse_app/services/oidc/tokens.py create mode 100644 gatehouse_app/services/oidc/userinfo.py delete mode 100644 gatehouse_app/services/oidc_service.py create mode 100644 migrations/versions/019_convert_auditaction_enum_to_varchar.py diff --git a/.env.example b/.env.example index 9e97462..4d1bf81 100644 --- a/.env.example +++ b/.env.example @@ -46,3 +46,10 @@ SSH_CA_KEY_PATH=/path/to/ca-users # Or set the key content directly (takes priority over SSH_CA_KEY_PATH): # SSH_CA_PRIVATE_KEY= +EMAIL_ENABLED= +SMTP_HOST= +SMTP_PORT= +SMTP_USERNAME= +SMTP_PASSWORD= +FROM_ADDRESS= +WEBAUTHN_ORIGIN= diff --git a/config/development.py b/config/development.py index 5436adf..a622243 100644 --- a/config/development.py +++ b/config/development.py @@ -9,6 +9,9 @@ class DevelopmentConfig(BaseConfig): # Use environment variable like BaseConfig does SQLALCHEMY_ECHO = os.getenv("SQLALCHEMY_ECHO", "False").lower() == "true" SESSION_COOKIE_SECURE = False + # SameSite=None requires Secure=True — browsers silently drop the cookie otherwise. + # In dev (http://localhost) use Lax so the TOTP/WebAuthn session cookie is actually sent. + SESSION_COOKIE_SAMESITE = "Lax" # More verbose logging in development LOG_LEVEL = "DEBUG" diff --git a/gatehouse_app/api/oidc.py b/gatehouse_app/api/oidc.py index 43bf5b4..f09e4cb 100644 --- a/gatehouse_app/api/oidc.py +++ b/gatehouse_app/api/oidc.py @@ -12,7 +12,7 @@ from flask import Blueprint, request, redirect, jsonify, session, g, current_app logger = logging.getLogger(__name__) from gatehouse_app.utils.response import api_response -from gatehouse_app.services.oidc_service import ( +from gatehouse_app.services.oidc import ( OIDCService, InvalidClientError, InvalidGrantError, InvalidRequestError ) from gatehouse_app.services.auth_service import AuthService @@ -756,7 +756,8 @@ def _show_login_page(client_id, redirect_uri, scope, state, nonce, response_type if error: params["error"] = error - return redirect(f"{ui_base_url}/login?{urlencode(params)}") + # /oidc-login is the dedicated OIDC bridge UI (not the main /login page) + return redirect(f"{ui_base_url}/oidc-login?{urlencode(params)}") # ============================================================================ diff --git a/gatehouse_app/api/v1/auth.py b/gatehouse_app/api/v1/auth.py deleted file mode 100644 index 993f213..0000000 --- a/gatehouse_app/api/v1/auth.py +++ /dev/null @@ -1,1691 +0,0 @@ -"""Authentication endpoints.""" -import json -import logging -from flask import request, session, g, jsonify, current_app -from marshmallow import ValidationError -from gatehouse_app.api.v1 import api_v1_bp -from gatehouse_app.extensions import limiter -from gatehouse_app.utils.response import api_response -from gatehouse_app.schemas.auth_schema import ( - RegisterSchema, - LoginSchema, - TOTPVerifyEnrollmentSchema, - TOTPVerifySchema, - TOTPDisableSchema, - TOTPRegenerateBackupCodesSchema, -) -from gatehouse_app.schemas.webauthn_schema import ( - WebAuthnRegistrationBeginSchema, - WebAuthnRegistrationCompleteSchema, - WebAuthnLoginBeginSchema, - WebAuthnLoginCompleteSchema, - WebAuthnCredentialRenameSchema, -) -from gatehouse_app.services.auth_service import AuthService -from gatehouse_app.services.webauthn_service import WebAuthnService -from gatehouse_app.services.user_service import UserService -from gatehouse_app.services.mfa_policy_service import MfaPolicyService -from gatehouse_app.services.notification_service import NotificationService -from gatehouse_app.utils.decorators import login_required -from gatehouse_app.utils.constants import AuditAction -from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError -from gatehouse_app.exceptions.validation_exceptions import ConflictError, NotFoundError - - -@api_v1_bp.route("/auth/register", methods=["POST"]) -@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_REGISTER"]) -def register(): - """ - Register a new user. - - Request body: - email: User email - password: User password - password_confirm: Password confirmation - full_name: Optional full name - - Returns: - 201: User created successfully - 400: Validation error - 409: Email already exists - """ - try: - # Validate request data - schema = RegisterSchema() - data = schema.load(request.json) - - # Register user - user = AuthService.register_user( - email=data["email"], - password=data["password"], - full_name=data.get("full_name"), - ) - - # Send verification email - try: - from gatehouse_app.models import EmailVerificationToken - verify_token = EmailVerificationToken.generate(user_id=user.id) - app_url = current_app.config.get("APP_URL", "http://localhost:8080") - verify_link = f"{app_url}/verify-email?token={verify_token.token}" - subject = "Verify your Gatehouse email address" - body = ( - f"Hi {user.full_name or user.email},\n\n" - f"Welcome to Gatehouse! Please verify your email address by clicking the link below (valid for 24 hours):\n" - f"{verify_link}\n\n" - f"Gatehouse Security Team" - ) - NotificationService._send_email(to_address=user.email, subject=subject, body=body) - except Exception as exc: - logging.getLogger(__name__).warning(f"Failed to send verification email on register: {exc}") - - # Create session - user_session = AuthService.create_session(user) - - # ── Post-registration hints ───────────────────────────────────────── - from gatehouse_app.models.organization.org_invite_token import OrgInviteToken - from gatehouse_app.models.user.user import User as _User - from datetime import datetime, timezone as _tz - - now = datetime.now(_tz.utc) - pending_invites = OrgInviteToken.query.filter( - OrgInviteToken.email == user.email, - OrgInviteToken.accepted_at.is_(None), - OrgInviteToken.expires_at > now, - OrgInviteToken.deleted_at.is_(None), - ).all() - - # Determine if this is the very first user ever registered on this - # instance (exactly 1 active user means it must be this one). - total_users = _User.query.filter(_User.deleted_at.is_(None)).count() - is_first_user = total_users == 1 - - expires_str = user_session.expires_at.isoformat() - if expires_str[-1] != "Z": - expires_str += "Z" - - return api_response( - data={ - "user": user.to_dict(), - "token": user_session.token, - "expires_at": expires_str, - "is_first_user": is_first_user, - "pending_invites": [ - { - "token": inv.token, - "organization": { - "id": str(inv.organization_id), - "name": inv.organization.name, - }, - "role": inv.role, - "expires_at": inv.expires_at.isoformat(), - } - for inv in pending_invites - ], - }, - message="Registration successful", - status=201, - ) - - except ValidationError as e: - return api_response( - success=False, - message="Validation failed", - status=400, - error_type="VALIDATION_ERROR", - error_details=e.messages, - ) - - -@api_v1_bp.route("/auth/login", methods=["POST"]) -@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_LOGIN"]) -def login(): - """ - Login user. - - Request body: - email: User email - password: User password - remember_me: Optional boolean for extended session - - Returns: - 200: Login successful or TOTP code required - 400: Validation error - 401: Invalid credentials - """ - import logging - logger = logging.getLogger(__name__) - - try: - # Validate request data - schema = LoginSchema() - data = schema.load(request.json) - - # Authenticate user with email and password - user = AuthService.authenticate( - email=data["email"], - password=data["password"], - ) - - # Check MFA enrollment status - has_totp = user.has_totp_enabled() - has_webauthn = user.has_webauthn_enabled() - logger.info(f"Login attempt for user {user.email} - TOTP enabled: {has_totp}, WebAuthn enabled: {has_webauthn}") - - # MFA Enforcement: Check WebAuthn first (most secure), then TOTP fallback - # Priority: WebAuthn > TOTP > No MFA - if has_webauthn: - # User has WebAuthn enrolled - require WebAuthn verification - # Store user_id in session for WebAuthn verification - # The /auth/webauthn/login/complete endpoint will retrieve this user_id - session["webauthn_pending_user_id"] = user.id - - # Return response indicating WebAuthn verification is required - return api_response( - data={ - "requires_webauthn": True, - }, - message="Passkey verification required. Please use your passkey to complete login.", - ) - - # Check if user has TOTP enabled for two-factor authentication - if has_totp: - # TOTP is enabled - store user_id in session for TOTP verification - # The /auth/totp/verify endpoint will retrieve this user_id - session["totp_pending_user_id"] = user.id - - # Return response indicating TOTP code is required - # Do NOT create session or return token yet - wait for TOTP verification - return api_response( - data={ - "requires_totp": True, - }, - message="TOTP code required. Please enter your 6-digit code from your authenticator app.", - ) - - # Evaluate MFA policy after primary authentication - remember_me = data.get("remember_me", False) - policy_result = MfaPolicyService.after_primary_auth_success(user, remember_me) - - # Create session with appropriate duration based on remember_me preference - duration = 2592000 if remember_me else 86400 # 30 days vs 1 day - - # Determine if this should be a compliance-only session - is_compliance_only = policy_result.create_compliance_only_session - - user_session = AuthService.create_session( - user, - duration_seconds=duration, - is_compliance_only=is_compliance_only - ) - - # Build response data - response_data = { - "user": user.to_dict(), - "token": user_session.token, - "expires_at": user_session.expires_at.isoformat() + "Z" if user_session.expires_at.isoformat()[-1] != "Z" else user_session.expires_at.isoformat(), - } - - # Add MFA compliance information - if policy_result.compliance_summary: - response_data["mfa_compliance"] = { - "overall_status": policy_result.compliance_summary.overall_status, - "missing_methods": policy_result.compliance_summary.missing_methods, - "deadline_at": policy_result.compliance_summary.deadline_at, - "orgs": [ - { - "organization_id": org.organization_id, - "organization_name": org.organization_name, - "status": org.status, - "effective_mode": org.effective_mode, - "deadline_at": org.deadline_at, - "applied_at": org.applied_at, - } - for org in policy_result.compliance_summary.orgs - ], - } - - # Add requires_mfa_enrollment flag if compliance-only session - if is_compliance_only: - response_data["requires_mfa_enrollment"] = True - - # ── Org-setup hint for org-less users ──────────────────────────────── - # If the user has no organisation memberships, surface any pending - # invitations so the UI can redirect straight to /org-setup instead of - # showing an empty dashboard. - user_orgs = user.get_organizations() - if not user_orgs: - from gatehouse_app.models.organization.org_invite_token import OrgInviteToken - from datetime import datetime, timezone as _tz - _now = datetime.now(_tz.utc) - pending_invites = OrgInviteToken.query.filter( - OrgInviteToken.email == user.email, - OrgInviteToken.accepted_at.is_(None), - OrgInviteToken.expires_at > _now, - OrgInviteToken.deleted_at.is_(None), - ).all() - response_data["pending_invites"] = [ - { - "token": inv.token, - "organization": { - "id": str(inv.organization_id), - "name": inv.organization.name, - }, - "role": inv.role, - "expires_at": inv.expires_at.isoformat(), - } - for inv in pending_invites - ] - # Flag so the UI knows to send this user through org-setup - response_data["requires_org_setup"] = True - - return api_response( - data=response_data, - message="Login successful", - ) - - except ValidationError as e: - return api_response( - success=False, - message="Validation failed", - status=400, - error_type="VALIDATION_ERROR", - error_details=e.messages, - ) - - -@api_v1_bp.route("/auth/logout", methods=["POST"]) -@login_required -def logout(): - """ - Logout current user. - - Returns: - 200: Logout successful - 401: Not authenticated - """ - # Revoke current session (g.current_session is set by login_required decorator) - if g.current_session: - AuthService.revoke_session(g.current_session.id, reason="User logout") - - return api_response( - message="Logout successful", - ) - - -@api_v1_bp.route("/auth/me", methods=["GET"]) -@login_required -def get_current_user(): - """ - Get current authenticated user. - - Returns: - 200: User data - 401: Not authenticated - """ - user = g.current_user - - return api_response( - data={ - "user": user.to_dict(), - "organizations": [ - { - "id": membership.organization.id, - "name": membership.organization.name, - "slug": membership.organization.slug, - "role": membership.role, - } - for membership in user.organization_memberships - ], - }, - message="User retrieved successfully", - ) - - -@api_v1_bp.route("/auth/sessions", methods=["GET"]) -@login_required -def get_user_sessions(): - """ - Get all active sessions for current user. - - Returns: - 200: List of active sessions - 401: Not authenticated - """ - from gatehouse_app.services.session_service import SessionService - - sessions = SessionService.get_user_sessions(g.current_user.id, active_only=True) - - return api_response( - data={ - "sessions": [session.to_dict() for session in sessions], - "count": len(sessions), - }, - message="Sessions retrieved successfully", - ) - - -@api_v1_bp.route("/auth/sessions/", methods=["DELETE"]) -@login_required -def revoke_session(session_id): - """ - Revoke a specific session. - - Args: - session_id: ID of session to revoke - - Returns: - 200: Session revoked - 401: Not authenticated - 404: Session not found - """ - from gatehouse_app.models.user.session import Session - - # Ensure session belongs to current user - user_session = Session.query.filter_by( - id=session_id, user_id=g.current_user.id, deleted_at=None - ).first() - - if not user_session: - return api_response( - success=False, - message="Session not found", - status=404, - error_type="NOT_FOUND", - ) - - AuthService.revoke_session(session_id, reason="Revoked by user") - - return api_response( - message="Session revoked successfully", - ) - - -@api_v1_bp.route("/auth/totp/enroll", methods=["POST"]) -@login_required -def enroll_totp(): - """ - Initiate TOTP enrollment for the current user. - - Returns: - 201: TOTP enrollment initiated with secret, provisioning_uri, qr_code, and backup_codes - 401: Not authenticated - 409: TOTP already enabled - """ - try: - # Initiate TOTP enrollment - result = AuthService.enroll_totp(g.current_user) - - return api_response( - data={ - "secret": result["secret"], - "provisioning_uri": result["provisioning_uri"], - "qr_code": result["qr_code"], - "backup_codes": result["backup_codes"], - }, - message="TOTP enrollment initiated. Please verify with your authenticator app.", - status=201, - ) - - except ConflictError as e: - return api_response( - success=False, - message=e.message, - status=e.status_code, - error_type=e.error_type, - ) - - -@api_v1_bp.route("/auth/totp/verify-enrollment", methods=["POST"]) -@login_required -def verify_totp_enrollment(): - """ - Complete TOTP enrollment by verifying the first TOTP code. - - Request body: - code: 6-digit TOTP code from authenticator app - client_timestamp: Optional client UTC timestamp in seconds since epoch - - Returns: - 200: TOTP enrollment completed successfully - 400: Validation error - 401: Not authenticated - 401: Invalid TOTP code - """ - try: - # Validate request data - schema = TOTPVerifyEnrollmentSchema() - data = schema.load(request.json) - - # Verify TOTP enrollment - AuthService.verify_totp_enrollment( - g.current_user, - data["code"], - client_utc_timestamp=data.get("client_timestamp"), - ) - - return api_response( - message="TOTP enrollment completed successfully", - ) - - except ValidationError as e: - return api_response( - success=False, - message="Validation failed", - status=400, - error_type="VALIDATION_ERROR", - error_details=e.messages, - ) - - except InvalidCredentialsError as e: - return api_response( - success=False, - message=e.message, - status=e.status_code, - error_type=e.error_type, - ) - - -@api_v1_bp.route("/auth/totp/verify", methods=["POST"]) -@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_TOTP_VERIFY"]) -def verify_totp(): - """ - Verify TOTP code during login. - - Request body: - code: 6-digit TOTP code or backup code - is_backup_code: True if code is a backup code, False if TOTP code (default: False) - client_timestamp: Optional client UTC timestamp in seconds since epoch - - Returns: - 200: TOTP code verified successfully with session token - 400: Validation error - 401: Invalid TOTP code or session not found - """ - try: - # Validate request data - schema = TOTPVerifySchema() - data = schema.load(request.json) - - # Get user from temporary session (stored in Flask session by login endpoint) - # Check totp_pending_user_id first, then fall back to webauthn_pending_user_id - # This allows TOTP to be used as a fallback when WebAuthn was the primary MFA method - user_id = session.get("totp_pending_user_id") or session.get("webauthn_pending_user_id") - if not user_id: - return api_response( - success=False, - message="No pending TOTP verification. Please login first.", - status=401, - error_type="AUTHENTICATION_ERROR", - ) - - # Get user from database - from gatehouse_app.models.user.user import User - user = User.query.get(user_id) - if not user: - return api_response( - success=False, - message="User not found", - status=401, - error_type="AUTHENTICATION_ERROR", - ) - - # Check account suspension before completing TOTP verification - from gatehouse_app.utils.constants import UserStatus - if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED): - session.pop("totp_pending_user_id", None) - session.pop("webauthn_pending_user_id", None) - return api_response( - success=False, - message="Account is suspended. Contact an administrator.", - status=403, - error_type="ACCOUNT_SUSPENDED", - ) - - # Verify TOTP code - AuthService.authenticate_with_totp( - user, - data["code"], - data.get("is_backup_code", False), - client_utc_timestamp=data.get("client_timestamp"), - ) - - # Evaluate MFA policy after primary authentication - policy_result = MfaPolicyService.after_primary_auth_success(user, remember_me=False) - - # Determine if this should be a compliance-only session - is_compliance_only = policy_result.create_compliance_only_session - - # Create session - user_session = AuthService.create_session(user, is_compliance_only=is_compliance_only) - - # Clear temporary session - clear both pending user IDs - session.pop("totp_pending_user_id", None) - session.pop("webauthn_pending_user_id", None) - - # Build response data - response_data = { - "user": user.to_dict(), - "token": user_session.token, - "expires_at": user_session.expires_at.isoformat() + "Z" - if user_session.expires_at.isoformat()[-1] != "Z" - else user_session.expires_at.isoformat(), - } - - # Add MFA compliance information - if policy_result.compliance_summary: - response_data["mfa_compliance"] = { - "overall_status": policy_result.compliance_summary.overall_status, - "missing_methods": policy_result.compliance_summary.missing_methods, - "deadline_at": policy_result.compliance_summary.deadline_at, - "orgs": [ - { - "organization_id": org.organization_id, - "organization_name": org.organization_name, - "status": org.status, - "effective_mode": org.effective_mode, - "deadline_at": org.deadline_at, - "applied_at": org.applied_at, - } - for org in policy_result.compliance_summary.orgs - ], - } - - # Add requires_mfa_enrollment flag if compliance-only session - if is_compliance_only: - response_data["requires_mfa_enrollment"] = True - - return api_response( - data=response_data, - message="TOTP verification successful", - ) - - except ValidationError as e: - return api_response( - success=False, - message="Validation failed", - status=400, - error_type="VALIDATION_ERROR", - error_details=e.messages, - ) - - except InvalidCredentialsError as e: - return api_response( - success=False, - message=e.message, - status=e.status_code, - error_type=e.error_type, - ) - - -@api_v1_bp.route("/auth/totp/disable", methods=["DELETE"]) -@login_required -def disable_totp(): - """ - Disable TOTP for the current user. - - Request body: - password: User's current password for verification - - Returns: - 200: TOTP disabled successfully - 400: Validation error - 401: Not authenticated or invalid password - 401: TOTP not enabled - """ - try: - # Validate request data - schema = TOTPDisableSchema() - data = schema.load(request.json) - - # Disable TOTP - AuthService.disable_totp(g.current_user, data["password"]) - - return api_response( - message="TOTP disabled successfully", - ) - - except ValidationError as e: - return api_response( - success=False, - message="Validation failed", - status=400, - error_type="VALIDATION_ERROR", - error_details=e.messages, - ) - - except InvalidCredentialsError as e: - return api_response( - success=False, - message=e.message, - status=e.status_code, - error_type=e.error_type, - ) - - -@api_v1_bp.route("/auth/totp/status", methods=["GET"]) -@login_required -def get_totp_status(): - """ - Get TOTP status for the current user. - - Returns: - 200: TOTP status with totp_enabled, verified_at, and backup_codes_remaining - 401: Not authenticated - """ - user = g.current_user - - # Check if TOTP is enabled - totp_enabled = user.has_totp_enabled() - - # Get TOTP method to check backup codes remaining - backup_codes_remaining = 0 - verified_at = None - - if totp_enabled: - totp_method = user.get_totp_method() - if totp_method and totp_method.provider_data: - backup_codes = totp_method.provider_data.get("backup_codes", []) - backup_codes_remaining = len(backup_codes) - if totp_method and totp_method.totp_verified_at: - verified_at = totp_method.totp_verified_at.isoformat() + "Z" if totp_method.totp_verified_at.isoformat()[-1] != "Z" else totp_method.totp_verified_at.isoformat() - - return api_response( - data={ - "totp_enabled": totp_enabled, - "verified_at": verified_at, - "backup_codes_remaining": backup_codes_remaining, - }, - message="TOTP status retrieved successfully", - ) - - -@api_v1_bp.route("/auth/totp/regenerate-backup-codes", methods=["POST"]) -@login_required -def regenerate_totp_backup_codes(): - """ - Generate new backup codes for TOTP. - - Request body: - password: User's current password for verification - - Returns: - 200: New backup codes generated successfully - 400: Validation error - 401: Not authenticated or invalid password - 401: TOTP not enabled - """ - try: - # Validate request data - schema = TOTPRegenerateBackupCodesSchema() - data = schema.load(request.json) - - # Regenerate backup codes - backup_codes = AuthService.regenerate_totp_backup_codes( - g.current_user, data["password"] - ) - - return api_response( - data={ - "backup_codes": backup_codes, - }, - message="Backup codes regenerated successfully", - ) - - except ValidationError as e: - return api_response( - success=False, - message="Validation failed", - status=400, - error_type="VALIDATION_ERROR", - error_details=e.messages, - ) - - except InvalidCredentialsError as e: - return api_response( - success=False, - message=e.message, - status=e.status_code, - error_type=e.error_type, - ) - - -# ============================================================================= -# WebAuthn Passkey Endpoints -# ============================================================================= - - -@api_v1_bp.route("/auth/webauthn/register/begin", methods=["POST"]) -@login_required -def begin_webauthn_registration(): - """ - Begin WebAuthn passkey registration. - - Returns: - 200: PublicKeyCredentialCreationOptions (raw JSON, no wrapper) - 401: Not authenticated - """ - user = g.current_user - - # Generate registration challenge - options = WebAuthnService.generate_registration_challenge(user) - - # Return unwrapped JSON for WebAuthn - return jsonify(options), 200 - - -@api_v1_bp.route("/auth/webauthn/register/complete", methods=["POST"]) -@login_required -def complete_webauthn_registration(): - """ - Complete WebAuthn passkey registration. - - Request body: - id: Credential ID - rawId: Base64URL-encoded credential ID - type: "public-key" - response: Attestation response data - transports: List of transport types - - Returns: - 200: Registration successful - 400: Validation error - 401: Not authenticated - 409: Credential already exists - """ - import base64 - import logging - logger = logging.getLogger(__name__) - - user_email = g.current_user.email - logger.info(f"WebAuthn registration completion started for user: {user_email}") - - try: - # Validate request data - schema = WebAuthnRegistrationCompleteSchema() - data = schema.load(request.json) - - # Extract challenge from client data - client_data_json_b64 = data.get("response", {}).get("clientDataJSON", "") - - if not client_data_json_b64: - logger.error(f"WebAuthn registration failed - missing clientDataJSON for user: {user_email}") - return api_response( - success=False, - message="Missing clientDataJSON in response", - status=400, - error_type="VALIDATION_ERROR", - ) - - try: - # Add padding if needed - padding = 4 - (len(client_data_json_b64) % 4) - if padding != 4: - client_data_json_b64_padded = client_data_json_b64 + '=' * padding - else: - client_data_json_b64_padded = client_data_json_b64 - - client_data_json = base64.urlsafe_b64decode(client_data_json_b64_padded) - client_data_dict = json.loads(client_data_json) - - except Exception as e: - logger.error(f"WebAuthn registration failed - client data decode error for user {user_email}: {e}") - return api_response( - success=False, - message=f"Failed to decode client data JSON: {str(e)}", - status=400, - error_type="VALIDATION_ERROR", - ) - - challenge = client_data_dict.get("challenge") - - if not challenge: - logger.error(f"WebAuthn registration failed - no challenge in client data for user: {user_email}") - return api_response( - success=False, - message="Invalid challenge in client data", - status=400, - error_type="VALIDATION_ERROR", - ) - - # Verify registration response - auth_method = WebAuthnService.verify_registration_response( - g.current_user, - data, - challenge - ) - - logger.info(f"WebAuthn registration completed successfully for user: {user_email}") - - return api_response( - data={ - "credential": auth_method.to_webauthn_dict(), - }, - message="Passkey registered successfully", - status=201, - ) - - except ValidationError as e: - logger.error(f"WebAuthn registration validation error for user {user_email}: {e.messages}") - return api_response( - success=False, - message="Validation failed", - status=400, - error_type="VALIDATION_ERROR", - error_details=e.messages, - ) - - except InvalidCredentialsError as e: - logger.warning(f"WebAuthn registration failed for user {user_email}: {e.message}") - return api_response( - success=False, - message=e.message, - status=e.status_code, - error_type=e.error_type, - ) - - except Exception as e: - logger.exception(f"WebAuthn registration unexpected error for user {user_email}: {e}") - return api_response( - success=False, - message="An unexpected error occurred during registration", - status=500, - error_type="INTERNAL_ERROR", - ) - - -@api_v1_bp.route("/auth/webauthn/login/begin", methods=["POST"]) -def begin_webauthn_login(): - """ - Begin WebAuthn passkey login. - - Request body: - email: User email address - - Returns: - 200: PublicKeyCredentialRequestOptions (raw JSON, no wrapper) - 400: Validation error - 404: User not found - """ - import logging - logger = logging.getLogger(__name__) - - try: - # Validate request data - schema = WebAuthnLoginBeginSchema() - data = schema.load(request.json) - - # Find user by email - from gatehouse_app.models.user.user import User - user = User.query.filter_by( - email=data["email"].lower(), - deleted_at=None - ).first() - - if not user: - logger.warning(f"WebAuthn login begin - user not found: {data['email']}") - return api_response( - success=False, - message="User not found", - status=404, - error_type="NOT_FOUND", - ) - - # Check account suspension before proceeding - from gatehouse_app.utils.constants import UserStatus - if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED): - logger.warning(f"WebAuthn login begin - suspended account attempt: {user.email}") - return api_response( - success=False, - message="Account is suspended. Contact an administrator.", - status=403, - error_type="ACCOUNT_SUSPENDED", - ) - - # Check if user has any WebAuthn credentials - if not user.has_webauthn_enabled(): - logger.warning(f"WebAuthn login begin - no credentials for user: {user.email}") - return api_response( - success=False, - message="No passkeys found for this account", - status=404, - error_type="NOT_FOUND", - ) - - logger.info(f"WebAuthn login challenge generated for user: {user.email}") - - # Generate authentication challenge - options = WebAuthnService.generate_authentication_challenge(user) - - # Store user_id in Flask session for WebAuthn verification - session["webauthn_pending_user_id"] = user.id - - # Return unwrapped JSON for WebAuthn - return jsonify(options), 200 - - except ValidationError as e: - logger.error(f"WebAuthn login begin validation error: {e.messages}") - return api_response( - success=False, - message="Validation failed", - status=400, - error_type="VALIDATION_ERROR", - error_details=e.messages, - ) - except Exception as e: - logger.exception(f"WebAuthn login begin unexpected error: {e}") - raise - - -@api_v1_bp.route("/auth/webauthn/login/complete", methods=["POST"]) -def complete_webauthn_login(): - """ - Complete WebAuthn passkey login. - - Request body: - id: Credential ID - rawId: Base64URL-encoded credential ID - type: "public-key" - response: Assertion response data - - Returns: - 200: Login successful with session token - 400: Validation error - 401: Authentication failed - """ - import logging - import base64 - logger = logging.getLogger(__name__) - - try: - # Get user from Flask session (stored by /begin endpoint) - user_id = session.get("webauthn_pending_user_id") - if not user_id: - logger.error("WebAuthn login complete - no pending verification in session") - return api_response( - success=False, - message="No pending WebAuthn verification. Please initiate login first.", - status=401, - error_type="AUTHENTICATION_ERROR", - ) - - # Validate request data - schema = WebAuthnLoginCompleteSchema() - data = schema.load(request.json) - - # Get user from database - from gatehouse_app.models.user.user import User - user = User.query.get(user_id) - if not user: - logger.error(f"WebAuthn login complete - user not found: {user_id}") - return api_response( - success=False, - message="User not found", - status=401, - error_type="AUTHENTICATION_ERROR", - ) - - # Check account suspension before completing login - from gatehouse_app.utils.constants import UserStatus - if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED): - session.pop("webauthn_pending_user_id", None) - logger.warning(f"WebAuthn login complete - suspended account attempt: {user.email}") - return api_response( - success=False, - message="Account is suspended. Contact an administrator.", - status=403, - error_type="ACCOUNT_SUSPENDED", - ) - - # Extract challenge from client data - client_data = data.get("response", {}).get("clientDataJSON", "") - - client_data_json = base64.urlsafe_b64decode(client_data + "==") - client_data_dict = json.loads(client_data_json) - - challenge = client_data_dict.get("challenge") - - if not challenge: - logger.error(f"WebAuthn login complete - no challenge in client data for user: {user.email}") - return api_response( - success=False, - message="Invalid challenge in client data", - status=400, - error_type="VALIDATION_ERROR", - ) - - # Verify authentication response - WebAuthnService.verify_authentication_response( - user, - data, - challenge - ) - - # Evaluate MFA policy after primary authentication - policy_result = MfaPolicyService.after_primary_auth_success(user, remember_me=False) - - # Determine if this should be a compliance-only session - is_compliance_only = policy_result.create_compliance_only_session - - # Create session - user_session = AuthService.create_session(user, is_compliance_only=is_compliance_only) - - # Clear pending session - session.pop("webauthn_pending_user_id", None) - - logger.info(f"WebAuthn login completed successfully for user: {user.email}") - - # Build response data - response_data = { - "user": user.to_dict(), - "token": user_session.token, - "expires_at": user_session.expires_at.isoformat() + "Z" - if user_session.expires_at.isoformat()[-1] != "Z" - else user_session.expires_at.isoformat(), - } - - # Add MFA compliance information - if policy_result.compliance_summary: - response_data["mfa_compliance"] = { - "overall_status": policy_result.compliance_summary.overall_status, - "missing_methods": policy_result.compliance_summary.missing_methods, - "deadline_at": policy_result.compliance_summary.deadline_at, - "orgs": [ - { - "organization_id": org.organization_id, - "organization_name": org.organization_name, - "status": org.status, - "effective_mode": org.effective_mode, - "deadline_at": org.deadline_at, - "applied_at": org.applied_at, - } - for org in policy_result.compliance_summary.orgs - ], - } - - # Add requires_mfa_enrollment flag if compliance-only session - if is_compliance_only: - response_data["requires_mfa_enrollment"] = True - - return api_response( - data=response_data, - message="Login successful", - ) - - except ValidationError as e: - logger.error(f"WebAuthn login complete validation error: {e.messages}") - return api_response( - success=False, - message="Validation failed", - status=400, - error_type="VALIDATION_ERROR", - error_details=e.messages, - ) - - except InvalidCredentialsError as e: - logger.warning(f"WebAuthn login complete authentication failed: {e.message}") - return api_response( - success=False, - message=e.message, - status=e.status_code, - error_type=e.error_type, - ) - - except Exception as e: - logger.exception(f"WebAuthn login complete unexpected error: {e}") - raise - - -@api_v1_bp.route("/auth/webauthn/credentials", methods=["GET"]) -@login_required -def list_webauthn_credentials(): - """ - List all WebAuthn passkey credentials for the current user. - - Returns: - 200: List of credentials - 401: Not authenticated - """ - user = g.current_user - credentials = WebAuthnService.get_user_credentials(user) - - return api_response( - data={ - "credentials": [cred.to_webauthn_dict() for cred in credentials], - "count": len(credentials), - }, - message="Credentials retrieved successfully", - ) - - -@api_v1_bp.route("/auth/webauthn/credentials/", methods=["DELETE"]) -@login_required -def delete_webauthn_credential(credential_id): - """ - Delete a WebAuthn passkey credential. - - Args: - credential_id: ID of the credential to delete - - Returns: - 200: Credential deleted successfully - 401: Not authenticated - 404: Credential not found - """ - user = g.current_user - - # First check that the specific credential actually belongs to this user. - # Only then check whether it is the last one — otherwise a user with zero - # credentials gets a misleading "Cannot delete the last passkey" error - # instead of a 404. - credential_exists = WebAuthnService.credential_belongs_to_user(credential_id, user) - if not credential_exists: - return api_response( - success=False, - message="Credential not found", - status=404, - error_type="NOT_FOUND", - ) - - # Check if this is the last credential - credential_count = user.get_webauthn_credential_count() - if credential_count <= 1: - return api_response( - success=False, - message="Cannot delete the last passkey. Add another passkey first.", - status=400, - error_type="BAD_REQUEST", - ) - - # Delete the credential - success = WebAuthnService.delete_credential(credential_id, user) - - if not success: - return api_response( - success=False, - message="Credential not found", - status=404, - error_type="NOT_FOUND", - ) - - return api_response( - message="Passkey deleted successfully", - ) - - -@api_v1_bp.route("/auth/webauthn/credentials/", methods=["PATCH"]) -@login_required -def rename_webauthn_credential(credential_id): - """ - Rename a WebAuthn passkey credential. - - Args: - credential_id: ID of the credential to rename - - Request body: - name: New name for the credential - - Returns: - 200: Credential renamed successfully - 400: Validation error - 401: Not authenticated - 404: Credential not found - """ - try: - # Validate request data - schema = WebAuthnCredentialRenameSchema() - data = schema.load(request.json) - - # Rename the credential - success = WebAuthnService.rename_credential( - credential_id, - g.current_user, - data["name"] - ) - - if not success: - return api_response( - success=False, - message="Credential not found", - status=404, - error_type="NOT_FOUND", - ) - - # Get updated credential - credential = WebAuthnService.get_credential_by_id(credential_id, g.current_user) - - return api_response( - data={ - "credential": credential.to_webauthn_dict() if credential else None, - }, - message="Passkey renamed successfully", - ) - - except ValidationError as e: - return api_response( - success=False, - message="Validation failed", - status=400, - error_type="VALIDATION_ERROR", - error_details=e.messages, - ) - - -@api_v1_bp.route("/auth/webauthn/status", methods=["GET"]) -@login_required -def get_webauthn_status(): - """ - Get WebAuthn status for the current user. - - Returns: - 200: WebAuthn status with webauthn_enabled and credential_count - 401: Not authenticated - """ - user = g.current_user - - return api_response( - data={ - "webauthn_enabled": user.has_webauthn_enabled(), - "credential_count": user.get_webauthn_credential_count(), - }, - message="WebAuthn status retrieved successfully", - ) - - -_pw_logger = logging.getLogger(__name__) - - -@api_v1_bp.route("/auth/forgot-password", methods=["POST"]) -@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_FORGOT_PASSWORD"]) -def forgot_password(): - """Request a password reset email. - - Always returns 200 to avoid leaking account existence. - - Request body: - email: User email address - - Returns: - 200: Password reset email sent (or silently no-op if email not found) - """ - from gatehouse_app.models import User, PasswordResetToken - - data = request.get_json() or {} - email = (data.get("email") or "").strip().lower() - - if not email: - return api_response( - success=False, - message="Email is required", - status=400, - error_type="VALIDATION_ERROR", - ) - - # Always return 200 — don't leak whether the email exists - user = User.query.filter_by(email=email, deleted_at=None).first() - if user: - try: - reset_token = PasswordResetToken.generate(user_id=user.id) - app_url = current_app.config.get("APP_URL", "http://localhost:8080") - reset_link = f"{app_url}/reset-password?token={reset_token.token}" - subject = "Reset your Gatehouse password" - body = ( - f"Hi {user.full_name or user.email},\n\n" - f"You requested a password reset for your Gatehouse account.\n\n" - f"Click the link below to reset your password (valid for 2 hours):\n" - f"{reset_link}\n\n" - f"If you did not request this, you can safely ignore this email.\n\n" - f"Gatehouse Security Team" - ) - NotificationService._send_email( - to_address=user.email, - subject=subject, - body=body, - ) - _pw_logger.info(f"Password reset token generated for user {user.id}") - except Exception as exc: - _pw_logger.exception(f"Error generating password reset token: {exc}") - - return api_response( - data={}, - message="If an account exists for this email, you will receive a password reset link shortly.", - ) - - -@api_v1_bp.route("/auth/reset-password", methods=["POST"]) -@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_RESET_PASSWORD"]) -def reset_password(): - """Reset a user's password using a reset token. - - Request body: - token: Password reset token from email - password: New password - password_confirm: Password confirmation - - Returns: - 200: Password reset successfully - 400: Invalid or expired token / validation error - """ - import bcrypt as _bcrypt - from gatehouse_app.extensions import bcrypt - from gatehouse_app.models import PasswordResetToken, AuthenticationMethod - from gatehouse_app.utils.constants import AuthMethodType - - data = request.get_json() or {} - token_value = (data.get("token") or "").strip() - new_password = data.get("password") or "" - password_confirm = data.get("password_confirm") or "" - - if not token_value or not new_password: - return api_response( - success=False, - message="Token and new password are required", - status=400, - error_type="VALIDATION_ERROR", - ) - - if new_password != password_confirm: - return api_response( - success=False, - message="Passwords do not match", - status=400, - error_type="VALIDATION_ERROR", - ) - - if len(new_password) < 8: - return api_response( - success=False, - message="Password must be at least 8 characters", - status=400, - error_type="VALIDATION_ERROR", - ) - - reset_token = PasswordResetToken.query.filter_by(token=token_value).first() - if not reset_token or not reset_token.is_valid: - return api_response( - success=False, - message="This password reset link is invalid or has expired.", - status=400, - error_type="INVALID_TOKEN", - ) - - try: - user = reset_token.user - # Update the password hash on the authentication method - auth_method = AuthenticationMethod.query.filter_by( - user_id=user.id, - method_type=AuthMethodType.PASSWORD, - deleted_at=None, - ).first() - if auth_method: - auth_method.password_hash = bcrypt.generate_password_hash(new_password).decode("utf-8") - from gatehouse_app.extensions import db - db.session.add(auth_method) - - reset_token.consume() - _pw_logger.info(f"Password reset for user {user.id}") - - return api_response( - data={}, - message="Your password has been reset. You can now sign in with your new password.", - ) - except Exception as exc: - _pw_logger.exception(f"Error resetting password: {exc}") - return api_response( - success=False, - message="An error occurred while resetting your password.", - status=500, - error_type="INTERNAL_ERROR", - ) - - -@api_v1_bp.route("/auth/verify-email", methods=["POST"]) -def verify_email(): - """Verify a user's email address using a verification token. - - Request body: - token: Email verification token - - Returns: - 200: Email verified successfully - 400: Invalid or expired token - """ - from gatehouse_app.models import EmailVerificationToken - - data = request.get_json() or {} - token_value = (data.get("token") or "").strip() - - if not token_value: - return api_response( - success=False, - message="Verification token is required", - status=400, - error_type="VALIDATION_ERROR", - ) - - verify_token = EmailVerificationToken.query.filter_by(token=token_value).first() - if not verify_token or not verify_token.is_valid: - return api_response( - success=False, - message="This verification link is invalid or has expired.", - status=400, - error_type="INVALID_TOKEN", - ) - - try: - user = verify_token.user - user.email_verified = True - from gatehouse_app.extensions import db - db.session.add(user) - verify_token.consume() - _pw_logger.info(f"Email verified for user {user.id}") - - return api_response( - data={}, - message="Your email has been verified. You can now sign in.", - ) - except Exception as exc: - _pw_logger.exception(f"Error verifying email: {exc}") - return api_response( - success=False, - message="An error occurred while verifying your email.", - status=500, - error_type="INTERNAL_ERROR", - ) - - -@api_v1_bp.route("/auth/resend-verification", methods=["POST"]) -def resend_verification(): - """Resend email verification link. - - Always returns 200 to avoid leaking account existence. - - Request body: - email: User email address - - Returns: - 200: Verification email sent (or silently no-op) - """ - from gatehouse_app.models import User, EmailVerificationToken - - data = request.get_json() or {} - email = (data.get("email") or "").strip().lower() - - if not email: - return api_response( - success=False, - message="Email is required", - status=400, - error_type="VALIDATION_ERROR", - ) - - user = User.query.filter_by(email=email, deleted_at=None).first() - if user and not user.email_verified: - try: - verify_token = EmailVerificationToken.generate(user_id=user.id) - app_url = current_app.config.get("APP_URL", "http://localhost:8080") - verify_link = f"{app_url}/verify-email?token={verify_token.token}" - subject = "Verify your Gatehouse email address" - body = ( - f"Hi {user.full_name or user.email},\n\n" - f"Please verify your email address by clicking the link below (valid for 24 hours):\n" - f"{verify_link}\n\n" - f"Gatehouse Security Team" - ) - NotificationService._send_email( - to_address=user.email, - subject=subject, - body=body, - ) - _pw_logger.info(f"Verification email sent for user {user.id}") - except Exception as exc: - _pw_logger.exception(f"Error sending verification email: {exc}") - - return api_response( - data={}, - message="If an account exists for this email and is not yet verified, you will receive a verification link shortly.", - ) - - -# ============================================================================= -# Account Activation (separate from email-verification) -# ============================================================================= - -@api_v1_bp.route("/auth/activate", methods=["POST"]) -def activate_account(): - """Activate a user account via a one-time activation code. - - Request body: - code – the activation_key from the welcome email - - Returns: - 200: Account activated, session token returned - 400: Missing code - 404: Invalid or already-used code - """ - import secrets - from gatehouse_app.models.user.user import User - from gatehouse_app.extensions import db - - data = request.get_json() or {} - code = (data.get("code") or "").strip() - if not code: - return api_response(success=False, message="Activation code is required", status=400, error_type="VALIDATION_ERROR") - - user = User.query.filter_by(activation_key=code, deleted_at=None).first() - if not user: - return api_response(success=False, message="Invalid or expired activation code", status=404, error_type="NOT_FOUND") - - user.activated = True - user.activation_key = None # one-time use - db.session.add(user) - db.session.commit() - - user_session = AuthService.create_session(user) - _pw_logger.info(f"Account activated for user {user.id}") - - return api_response( - data={ - "user": user.to_dict(), - "token": user_session.token, - "expires_at": user_session.expires_at.isoformat() + "Z" - if user_session.expires_at.isoformat()[-1] != "Z" - else user_session.expires_at.isoformat(), - }, - message="Account activated successfully", - ) - - -@api_v1_bp.route("/auth/resend-activation", methods=["POST"]) -def resend_activation(): - """Re-send an account activation email. - - Always returns 200 to avoid leaking whether an account exists. - - Request body: - email – user email address - """ - import secrets - from gatehouse_app.models.user.user import User - from gatehouse_app.extensions import db - - data = request.get_json() or {} - email = (data.get("email") or "").strip().lower() - if not email: - return api_response(success=False, message="Email is required", status=400, error_type="VALIDATION_ERROR") - - user = User.query.filter_by(email=email, deleted_at=None).first() - if user and not user.activated: - try: - code = secrets.token_urlsafe(32) - user.activation_key = code - db.session.add(user) - db.session.commit() - - app_url = current_app.config.get("APP_URL", current_app.config.get("FRONTEND_URL", "http://localhost:8080")) - activate_link = f"{app_url}/activate?code={code}" - subject = "Activate your Gatehouse account" - body = ( - f"Hi {user.full_name or user.email},\n\n" - f"Please activate your Gatehouse account by clicking the link below:\n" - f"{activate_link}\n\n" - f"If you did not create an account, you can safely ignore this email.\n\n" - f"Gatehouse Security Team" - ) - NotificationService._send_email(to_address=user.email, subject=subject, body=body) - _pw_logger.info(f"Activation email re-sent to {user.id}") - except Exception as exc: - _pw_logger.exception(f"Error re-sending activation email: {exc}") - - return api_response( - data={}, - message="If an unactivated account exists for this email, you will receive a new activation link shortly.", - ) - - -# ============================================================================= -# Token retrieval / redirect (for CLI / external tools) -# ============================================================================= - -@api_v1_bp.route("/auth/token", methods=["GET"]) -@login_required -def get_token(): - """Return the current session token, optionally redirecting to a URL. - - Query parameters: - redirect – optional URL to redirect to with the token appended as - a query param: ``?token=`` - - Returns: - 200: JSON ``{"token": ""}`` (no redirect given) - 302: Redirect to ``?token=`` - """ - from flask import redirect as flask_redirect - from urllib.parse import urlparse - - token = g.current_session.token - redirect_url = request.args.get("redirect", "").strip() - - if redirect_url: - # Validate redirect URL against allowed origins to prevent open-redirect - # token exfiltration attacks (CWE-601). - allowed_origins = set(current_app.config.get("CORS_ORIGINS", [])) - frontend_url = current_app.config.get("FRONTEND_URL", "") - if frontend_url: - parsed = urlparse(frontend_url) - allowed_origins.add(f"{parsed.scheme}://{parsed.netloc}") - - parsed_redirect = urlparse(redirect_url) - redirect_origin = f"{parsed_redirect.scheme}://{parsed_redirect.netloc}" - - if redirect_origin not in allowed_origins: - return api_response( - success=False, - message="Redirect URL is not allowed.", - status=400, - error_type="INVALID_REDIRECT", - ) - - sep = "&" if "?" in redirect_url else "?" - return flask_redirect(f"{redirect_url}{sep}token={token}", code=302) - - return api_response(data={"token": token}, message="Token retrieved") diff --git a/gatehouse_app/api/v1/auth/__init__.py b/gatehouse_app/api/v1/auth/__init__.py new file mode 100644 index 0000000..8518cdf --- /dev/null +++ b/gatehouse_app/api/v1/auth/__init__.py @@ -0,0 +1,2 @@ +"""Auth blueprint subpackage.""" +from gatehouse_app.api.v1.auth import core, totp, webauthn, password diff --git a/gatehouse_app/api/v1/auth/core.py b/gatehouse_app/api/v1/auth/core.py new file mode 100644 index 0000000..f401834 --- /dev/null +++ b/gatehouse_app/api/v1/auth/core.py @@ -0,0 +1,251 @@ +"""Core auth endpoints: register, login, logout, sessions.""" +import logging +from flask import request, session, g, current_app +from marshmallow import ValidationError +from gatehouse_app.api.v1 import api_v1_bp +from gatehouse_app.extensions import limiter +from gatehouse_app.utils.response import api_response +from gatehouse_app.schemas.auth_schema import RegisterSchema, LoginSchema +from gatehouse_app.services.auth_service import AuthService +from gatehouse_app.services.mfa_policy_service import MfaPolicyService +from gatehouse_app.services.notification_service import NotificationService +from gatehouse_app.utils.decorators import login_required +from gatehouse_app.utils.constants import AuditAction +from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError + + +@api_v1_bp.route("/auth/register", methods=["POST"]) +@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_REGISTER"]) +def register(): + try: + schema = RegisterSchema() + data = schema.load(request.json) + + user = AuthService.register_user( + email=data["email"], + password=data["password"], + full_name=data.get("full_name"), + ) + + try: + from gatehouse_app.models import EmailVerificationToken + verify_token = EmailVerificationToken.generate(user_id=user.id) + app_url = current_app.config.get("APP_URL", "http://localhost:8080") + verify_link = f"{app_url}/verify-email?token={verify_token.token}" + subject = "Verify your Gatehouse email address" + body = ( + f"Hi {user.full_name or user.email},\n\n" + f"Welcome to Gatehouse! Please verify your email address by clicking the link below (valid for 24 hours):\n" + f"{verify_link}\n\n" + f"Gatehouse Security Team" + ) + NotificationService._send_email(to_address=user.email, subject=subject, body=body) + except Exception as exc: + logging.getLogger(__name__).warning(f"Failed to send verification email on register: {exc}") + + user_session = AuthService.create_session(user) + + from gatehouse_app.models.organization.org_invite_token import OrgInviteToken + from gatehouse_app.models.user.user import User as _User + from datetime import datetime, timezone as _tz + + now = datetime.now(_tz.utc) + pending_invites = OrgInviteToken.query.filter( + OrgInviteToken.email == user.email, + OrgInviteToken.accepted_at.is_(None), + OrgInviteToken.expires_at > now, + OrgInviteToken.deleted_at.is_(None), + ).all() + + total_users = _User.query.filter(_User.deleted_at.is_(None)).count() + is_first_user = total_users == 1 + + expires_str = user_session.expires_at.isoformat() + if expires_str[-1] != "Z": + expires_str += "Z" + + return api_response( + data={ + "user": user.to_dict(), + "token": user_session.token, + "expires_at": expires_str, + "is_first_user": is_first_user, + "pending_invites": [ + { + "token": inv.token, + "organization": {"id": str(inv.organization_id), "name": inv.organization.name}, + "role": inv.role, + "expires_at": inv.expires_at.isoformat(), + } + for inv in pending_invites + ], + }, + message="Registration successful", + status=201, + ) + except ValidationError as e: + return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) + + +@api_v1_bp.route("/auth/login", methods=["POST"]) +@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_LOGIN"]) +def login(): + logger = logging.getLogger(__name__) + + try: + schema = LoginSchema() + data = schema.load(request.json) + + user = AuthService.authenticate(email=data["email"], password=data["password"]) + + has_totp = user.has_totp_enabled() + has_webauthn = user.has_webauthn_enabled() + logger.info(f"Login attempt for user {user.email} - TOTP enabled: {has_totp}, WebAuthn enabled: {has_webauthn}") + + if has_webauthn: + session["webauthn_pending_user_id"] = user.id + return api_response(data={"requires_webauthn": True}, message="Passkey verification required. Please use your passkey to complete login.") + + if has_totp: + session["totp_pending_user_id"] = user.id + return api_response(data={"requires_totp": True}, message="TOTP code required. Please enter your 6-digit code from your authenticator app.") + + remember_me = data.get("remember_me", False) + policy_result = MfaPolicyService.after_primary_auth_success(user, remember_me) + duration = 2592000 if remember_me else 86400 + is_compliance_only = policy_result.create_compliance_only_session + + user_session = AuthService.create_session(user, duration_seconds=duration, is_compliance_only=is_compliance_only) + + response_data = { + "user": user.to_dict(), + "token": user_session.token, + "expires_at": user_session.expires_at.isoformat() + "Z" if user_session.expires_at.isoformat()[-1] != "Z" else user_session.expires_at.isoformat(), + } + + if policy_result.compliance_summary: + response_data["mfa_compliance"] = { + "overall_status": policy_result.compliance_summary.overall_status, + "missing_methods": policy_result.compliance_summary.missing_methods, + "deadline_at": policy_result.compliance_summary.deadline_at, + "orgs": [ + { + "organization_id": org.organization_id, + "organization_name": org.organization_name, + "status": org.status, + "effective_mode": org.effective_mode, + "deadline_at": org.deadline_at, + "applied_at": org.applied_at, + } + for org in policy_result.compliance_summary.orgs + ], + } + + if is_compliance_only: + response_data["requires_mfa_enrollment"] = True + + user_orgs = user.get_organizations() + if not user_orgs: + from gatehouse_app.models.organization.org_invite_token import OrgInviteToken + from datetime import datetime, timezone as _tz + _now = datetime.now(_tz.utc) + pending_invites = OrgInviteToken.query.filter( + OrgInviteToken.email == user.email, + OrgInviteToken.accepted_at.is_(None), + OrgInviteToken.expires_at > _now, + OrgInviteToken.deleted_at.is_(None), + ).all() + response_data["pending_invites"] = [ + { + "token": inv.token, + "organization": {"id": str(inv.organization_id), "name": inv.organization.name}, + "role": inv.role, + "expires_at": inv.expires_at.isoformat(), + } + for inv in pending_invites + ] + response_data["requires_org_setup"] = True + + return api_response(data=response_data, message="Login successful") + except ValidationError as e: + return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) + + +@api_v1_bp.route("/auth/logout", methods=["POST"]) +@login_required +def logout(): + if g.current_session: + AuthService.revoke_session(g.current_session.id, reason="User logout") + return api_response(message="Logout successful") + + +@api_v1_bp.route("/auth/me", methods=["GET"]) +@login_required +def get_current_user(): + user = g.current_user + return api_response( + data={ + "user": user.to_dict(), + "organizations": [ + { + "id": membership.organization.id, + "name": membership.organization.name, + "slug": membership.organization.slug, + "role": membership.role.value if hasattr(membership.role, "value") else str(membership.role), + } + for membership in user.organization_memberships + if membership.deleted_at is None and membership.organization and not membership.organization.deleted_at + ], + }, + message="User retrieved successfully", + ) + + +@api_v1_bp.route("/auth/sessions", methods=["GET"]) +@login_required +def get_user_sessions(): + from gatehouse_app.services.session_service import SessionService + + sessions = SessionService.get_user_sessions(g.current_user.id, active_only=True) + return api_response(data={"sessions": [s.to_dict() for s in sessions], "count": len(sessions)}, message="Sessions retrieved successfully") + + +@api_v1_bp.route("/auth/sessions/", methods=["DELETE"]) +@login_required +def revoke_session(session_id): + from gatehouse_app.models.user.session import Session + + user_session = Session.query.filter_by(id=session_id, user_id=g.current_user.id, deleted_at=None).first() + if not user_session: + return api_response(success=False, message="Session not found", status=404, error_type="NOT_FOUND") + + AuthService.revoke_session(session_id, reason="Revoked by user") + return api_response(message="Session revoked successfully") + + +@api_v1_bp.route("/auth/token", methods=["GET"]) +@login_required +def get_token(): + from flask import redirect as flask_redirect + from urllib.parse import urlparse + + token = g.current_session.token + redirect_url = request.args.get("redirect", "").strip() + + if redirect_url: + allowed_origins = set(current_app.config.get("CORS_ORIGINS", [])) + frontend_url = current_app.config.get("FRONTEND_URL", "") + if frontend_url: + parsed = urlparse(frontend_url) + allowed_origins.add(f"{parsed.scheme}://{parsed.netloc}") + + parsed_redirect = urlparse(redirect_url) + redirect_origin = f"{parsed_redirect.scheme}://{parsed_redirect.netloc}" + + if redirect_origin not in allowed_origins: + return api_response(success=False, message="Redirect URL is not allowed.", status=400, error_type="INVALID_REDIRECT") + + sep = "&" if "?" in redirect_url else "?" + return flask_redirect(f"{redirect_url}{sep}token={token}", code=302) + + return api_response(data={"token": token}, message="Token retrieved") diff --git a/gatehouse_app/api/v1/auth/password.py b/gatehouse_app/api/v1/auth/password.py new file mode 100644 index 0000000..b972cbd --- /dev/null +++ b/gatehouse_app/api/v1/auth/password.py @@ -0,0 +1,218 @@ +"""Password reset, email verification, and account activation endpoints.""" +import logging +from flask import request, current_app +from gatehouse_app.api.v1 import api_v1_bp +from gatehouse_app.extensions import limiter +from gatehouse_app.utils.response import api_response +from gatehouse_app.services.auth_service import AuthService +from gatehouse_app.services.notification_service import NotificationService + +_logger = logging.getLogger(__name__) + + +@api_v1_bp.route("/auth/forgot-password", methods=["POST"]) +@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_FORGOT_PASSWORD"]) +def forgot_password(): + from gatehouse_app.models import User, PasswordResetToken + + data = request.get_json() or {} + email = (data.get("email") or "").strip().lower() + + if not email: + return api_response(success=False, message="Email is required", status=400, error_type="VALIDATION_ERROR") + + user = User.query.filter_by(email=email, deleted_at=None).first() + if user: + try: + reset_token = PasswordResetToken.generate(user_id=user.id) + app_url = current_app.config.get("APP_URL", "http://localhost:8080") + reset_link = f"{app_url}/reset-password?token={reset_token.token}" + NotificationService._send_email( + to_address=user.email, + subject="Reset your Gatehouse password", + body=( + f"Hi {user.full_name or user.email},\n\n" + f"You requested a password reset for your Gatehouse account.\n\n" + f"Click the link below to reset your password (valid for 2 hours):\n" + f"{reset_link}\n\n" + f"If you did not request this, you can safely ignore this email.\n\n" + f"Gatehouse Security Team" + ), + ) + _logger.info(f"Password reset token generated for user {user.id}") + except Exception as exc: + _logger.exception(f"Error generating password reset token: {exc}") + + return api_response(data={}, message="If an account exists for this email, you will receive a password reset link shortly.") + + +@api_v1_bp.route("/auth/reset-password", methods=["POST"]) +@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_RESET_PASSWORD"]) +def reset_password(): + from gatehouse_app.extensions import bcrypt + from gatehouse_app.models import PasswordResetToken, AuthenticationMethod + from gatehouse_app.utils.constants import AuthMethodType + from gatehouse_app.extensions import db + + data = request.get_json() or {} + token_value = (data.get("token") or "").strip() + new_password = data.get("password") or "" + password_confirm = data.get("password_confirm") or "" + + if not token_value or not new_password: + return api_response(success=False, message="Token and new password are required", status=400, error_type="VALIDATION_ERROR") + + if new_password != password_confirm: + return api_response(success=False, message="Passwords do not match", status=400, error_type="VALIDATION_ERROR") + + if len(new_password) < 8: + return api_response(success=False, message="Password must be at least 8 characters", status=400, error_type="VALIDATION_ERROR") + + reset_token = PasswordResetToken.query.filter_by(token=token_value).first() + if not reset_token or not reset_token.is_valid: + return api_response(success=False, message="This password reset link is invalid or has expired.", status=400, error_type="INVALID_TOKEN") + + try: + user = reset_token.user + auth_method = AuthenticationMethod.query.filter_by(user_id=user.id, method_type=AuthMethodType.PASSWORD, deleted_at=None).first() + if auth_method: + auth_method.password_hash = bcrypt.generate_password_hash(new_password).decode("utf-8") + db.session.add(auth_method) + reset_token.consume() + _logger.info(f"Password reset for user {user.id}") + return api_response(data={}, message="Your password has been reset. You can now sign in with your new password.") + except Exception as exc: + _logger.exception(f"Error resetting password: {exc}") + return api_response(success=False, message="An error occurred while resetting your password.", status=500, error_type="INTERNAL_ERROR") + + +@api_v1_bp.route("/auth/verify-email", methods=["POST"]) +def verify_email(): + from gatehouse_app.models import EmailVerificationToken + from gatehouse_app.extensions import db + + data = request.get_json() or {} + token_value = (data.get("token") or "").strip() + + if not token_value: + return api_response(success=False, message="Verification token is required", status=400, error_type="VALIDATION_ERROR") + + verify_token = EmailVerificationToken.query.filter_by(token=token_value).first() + if not verify_token or not verify_token.is_valid: + return api_response(success=False, message="This verification link is invalid or has expired.", status=400, error_type="INVALID_TOKEN") + + try: + user = verify_token.user + user.email_verified = True + db.session.add(user) + verify_token.consume() + _logger.info(f"Email verified for user {user.id}") + return api_response(data={}, message="Your email has been verified. You can now sign in.") + except Exception as exc: + _logger.exception(f"Error verifying email: {exc}") + return api_response(success=False, message="An error occurred while verifying your email.", status=500, error_type="INTERNAL_ERROR") + + +@api_v1_bp.route("/auth/resend-verification", methods=["POST"]) +def resend_verification(): + from gatehouse_app.models import User, EmailVerificationToken + + data = request.get_json() or {} + email = (data.get("email") or "").strip().lower() + + if not email: + return api_response(success=False, message="Email is required", status=400, error_type="VALIDATION_ERROR") + + user = User.query.filter_by(email=email, deleted_at=None).first() + if user and not user.email_verified: + try: + verify_token = EmailVerificationToken.generate(user_id=user.id) + app_url = current_app.config.get("APP_URL", "http://localhost:8080") + verify_link = f"{app_url}/verify-email?token={verify_token.token}" + NotificationService._send_email( + to_address=user.email, + subject="Verify your Gatehouse email address", + body=( + f"Hi {user.full_name or user.email},\n\n" + f"Please verify your email address by clicking the link below (valid for 24 hours):\n" + f"{verify_link}\n\n" + f"Gatehouse Security Team" + ), + ) + _logger.info(f"Verification email sent for user {user.id}") + except Exception as exc: + _logger.exception(f"Error sending verification email: {exc}") + + return api_response(data={}, message="If an account exists for this email and is not yet verified, you will receive a verification link shortly.") + + +@api_v1_bp.route("/auth/activate", methods=["POST"]) +def activate_account(): + import secrets + from gatehouse_app.models.user.user import User + from gatehouse_app.extensions import db + + data = request.get_json() or {} + code = (data.get("code") or "").strip() + if not code: + return api_response(success=False, message="Activation code is required", status=400, error_type="VALIDATION_ERROR") + + user = User.query.filter_by(activation_key=code, deleted_at=None).first() + if not user: + return api_response(success=False, message="Invalid or expired activation code", status=404, error_type="NOT_FOUND") + + user.activated = True + user.activation_key = None + db.session.add(user) + db.session.commit() + + user_session = AuthService.create_session(user) + _logger.info(f"Account activated for user {user.id}") + + return api_response( + data={ + "user": user.to_dict(), + "token": user_session.token, + "expires_at": user_session.expires_at.isoformat() + "Z" if user_session.expires_at.isoformat()[-1] != "Z" else user_session.expires_at.isoformat(), + }, + message="Account activated successfully", + ) + + +@api_v1_bp.route("/auth/resend-activation", methods=["POST"]) +def resend_activation(): + import secrets + from gatehouse_app.models.user.user import User + from gatehouse_app.extensions import db + + data = request.get_json() or {} + email = (data.get("email") or "").strip().lower() + if not email: + return api_response(success=False, message="Email is required", status=400, error_type="VALIDATION_ERROR") + + user = User.query.filter_by(email=email, deleted_at=None).first() + if user and not user.activated: + try: + code = secrets.token_urlsafe(32) + user.activation_key = code + db.session.add(user) + db.session.commit() + + app_url = current_app.config.get("APP_URL", current_app.config.get("FRONTEND_URL", "http://localhost:8080")) + activate_link = f"{app_url}/activate?code={code}" + NotificationService._send_email( + to_address=user.email, + subject="Activate your Gatehouse account", + body=( + f"Hi {user.full_name or user.email},\n\n" + f"Please activate your Gatehouse account by clicking the link below:\n" + f"{activate_link}\n\n" + f"If you did not create an account, you can safely ignore this email.\n\n" + f"Gatehouse Security Team" + ), + ) + _logger.info(f"Activation email re-sent to {user.id}") + except Exception as exc: + _logger.exception(f"Error re-sending activation email: {exc}") + + return api_response(data={}, message="If an unactivated account exists for this email, you will receive a new activation link shortly.") diff --git a/gatehouse_app/api/v1/auth/totp.py b/gatehouse_app/api/v1/auth/totp.py new file mode 100644 index 0000000..0bb8505 --- /dev/null +++ b/gatehouse_app/api/v1/auth/totp.py @@ -0,0 +1,183 @@ +"""TOTP authentication endpoints.""" +from flask import request, session, g, current_app +from marshmallow import ValidationError +from gatehouse_app.api.v1 import api_v1_bp +from gatehouse_app.extensions import limiter +from gatehouse_app.utils.response import api_response +from gatehouse_app.schemas.auth_schema import ( + TOTPVerifyEnrollmentSchema, + TOTPVerifySchema, + TOTPDisableSchema, + TOTPRegenerateBackupCodesSchema, +) +from gatehouse_app.services.auth_service import AuthService +from gatehouse_app.services.mfa_policy_service import MfaPolicyService +from gatehouse_app.utils.decorators import login_required +from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError +from gatehouse_app.exceptions.validation_exceptions import ConflictError + + +@api_v1_bp.route("/auth/totp/enroll", methods=["POST"]) +@login_required +def enroll_totp(): + try: + result = AuthService.enroll_totp(g.current_user) + return api_response( + data={ + "secret": result["secret"], + "provisioning_uri": result["provisioning_uri"], + "qr_code": result["qr_code"], + "backup_codes": result["backup_codes"], + }, + message="TOTP enrollment initiated. Please verify with your authenticator app.", + status=201, + ) + except ConflictError as e: + return api_response(success=False, message=e.message, status=e.status_code, error_type=e.error_type) + + +@api_v1_bp.route("/auth/totp/verify-enrollment", methods=["POST"]) +@login_required +def verify_totp_enrollment(): + try: + schema = TOTPVerifyEnrollmentSchema() + data = schema.load(request.json) + AuthService.verify_totp_enrollment(g.current_user, data["code"], client_utc_timestamp=data.get("client_timestamp")) + return api_response(message="TOTP enrollment completed successfully") + except ValidationError as e: + return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) + except InvalidCredentialsError as e: + return api_response(success=False, message=e.message, status=e.status_code, error_type=e.error_type) + + +@api_v1_bp.route("/auth/totp/verify", methods=["POST"]) +@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_TOTP_VERIFY"]) +def verify_totp(): + try: + schema = TOTPVerifySchema() + data = schema.load(request.json) + + user_id = session.get("totp_pending_user_id") or session.get("webauthn_pending_user_id") + if not user_id: + return api_response(success=False, message="No pending TOTP verification. Please login first.", status=401, error_type="AUTHENTICATION_ERROR") + + from gatehouse_app.models.user.user import User + user = User.query.get(user_id) + if not user: + return api_response(success=False, message="User not found", status=401, error_type="AUTHENTICATION_ERROR") + + from gatehouse_app.utils.constants import UserStatus + if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED): + session.pop("totp_pending_user_id", None) + session.pop("webauthn_pending_user_id", None) + return api_response(success=False, message="Account is suspended. Contact an administrator.", status=403, error_type="ACCOUNT_SUSPENDED") + + AuthService.authenticate_with_totp(user, data["code"], data.get("is_backup_code", False), client_utc_timestamp=data.get("client_timestamp")) + + policy_result = MfaPolicyService.after_primary_auth_success(user, remember_me=False) + is_compliance_only = policy_result.create_compliance_only_session + user_session = AuthService.create_session(user, is_compliance_only=is_compliance_only) + + session.pop("totp_pending_user_id", None) + session.pop("webauthn_pending_user_id", None) + + response_data = { + "user": user.to_dict(), + "token": user_session.token, + "expires_at": user_session.expires_at.isoformat() + "Z" if user_session.expires_at.isoformat()[-1] != "Z" else user_session.expires_at.isoformat(), + } + + if policy_result.compliance_summary: + response_data["mfa_compliance"] = { + "overall_status": policy_result.compliance_summary.overall_status, + "missing_methods": policy_result.compliance_summary.missing_methods, + "deadline_at": policy_result.compliance_summary.deadline_at, + "orgs": [ + { + "organization_id": org.organization_id, + "organization_name": org.organization_name, + "status": org.status, + "effective_mode": org.effective_mode, + "deadline_at": org.deadline_at, + "applied_at": org.applied_at, + } + for org in policy_result.compliance_summary.orgs + ], + } + + if is_compliance_only: + response_data["requires_mfa_enrollment"] = True + + return api_response(data=response_data, message="TOTP verification successful") + except ValidationError as e: + return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) + except InvalidCredentialsError as e: + return api_response(success=False, message=e.message, status=e.status_code, error_type=e.error_type) + + +@api_v1_bp.route("/auth/totp/disable", methods=["DELETE"]) +@login_required +def disable_totp(): + try: + schema = TOTPDisableSchema() + data = schema.load(request.json) + AuthService.disable_totp(g.current_user, data["password"]) + return api_response(message="TOTP disabled successfully") + except ValidationError as e: + return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) + except InvalidCredentialsError as e: + return api_response(success=False, message=e.message, status=e.status_code, error_type=e.error_type) + + +@api_v1_bp.route("/auth/totp/status", methods=["GET"]) +@login_required +def get_totp_status(): + from gatehouse_app.models.auth.authentication_method import AuthenticationMethod + from gatehouse_app.utils.constants import AuthMethodType + from gatehouse_app.extensions import db as _db + from datetime import datetime, timezone + + user = g.current_user + + stale = AuthenticationMethod.query.filter_by(user_id=user.id, method_type=AuthMethodType.TOTP, verified=False, deleted_at=None).all() + for s in stale: + secret = (s.provider_data or {}).get("secret") if s.provider_data else None + if not secret: + s.deleted_at = datetime.now(timezone.utc) + _db.session.add(s) + if stale: + _db.session.commit() + + totp_enabled = user.has_totp_enabled() + backup_codes_remaining = 0 + verified_at = None + + if totp_enabled: + totp_method = AuthenticationMethod.query.filter_by( + user_id=user.id, method_type=AuthMethodType.TOTP, verified=True, deleted_at=None + ).order_by(AuthenticationMethod.created_at.desc()).first() + + if totp_method and totp_method.provider_data: + backup_codes_remaining = len(totp_method.provider_data.get("backup_codes", [])) + if totp_method and totp_method.totp_verified_at: + ts = totp_method.totp_verified_at.isoformat() + verified_at = ts if ts.endswith("Z") else ts + "Z" + + return api_response( + data={"totp_enabled": totp_enabled, "verified_at": verified_at, "backup_codes_remaining": backup_codes_remaining}, + message="TOTP status retrieved successfully", + ) + + +@api_v1_bp.route("/auth/totp/regenerate-backup-codes", methods=["POST"]) +@login_required +def regenerate_totp_backup_codes(): + try: + schema = TOTPRegenerateBackupCodesSchema() + data = schema.load(request.json) + backup_codes = AuthService.regenerate_totp_backup_codes(g.current_user, data["password"]) + return api_response(data={"backup_codes": backup_codes}, message="Backup codes regenerated successfully") + except ValidationError as e: + return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) + except InvalidCredentialsError as e: + return api_response(success=False, message=e.message, status=e.status_code, error_type=e.error_type) diff --git a/gatehouse_app/api/v1/auth/webauthn.py b/gatehouse_app/api/v1/auth/webauthn.py new file mode 100644 index 0000000..dc5cf8b --- /dev/null +++ b/gatehouse_app/api/v1/auth/webauthn.py @@ -0,0 +1,217 @@ +"""WebAuthn passkey authentication endpoints.""" +import json +import base64 +import logging +from flask import request, session, g, jsonify +from marshmallow import ValidationError +from gatehouse_app.api.v1 import api_v1_bp +from gatehouse_app.utils.response import api_response +from gatehouse_app.schemas.webauthn_schema import ( + WebAuthnRegistrationBeginSchema, + WebAuthnRegistrationCompleteSchema, + WebAuthnLoginBeginSchema, + WebAuthnLoginCompleteSchema, + WebAuthnCredentialRenameSchema, +) +from gatehouse_app.services.auth_service import AuthService +from gatehouse_app.services.webauthn_service import WebAuthnService +from gatehouse_app.services.mfa_policy_service import MfaPolicyService +from gatehouse_app.utils.decorators import login_required +from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError + +logger = logging.getLogger(__name__) + + +@api_v1_bp.route("/auth/webauthn/register/begin", methods=["POST"]) +@login_required +def begin_webauthn_registration(): + options = WebAuthnService.generate_registration_challenge(g.current_user) + return jsonify(options), 200 + + +@api_v1_bp.route("/auth/webauthn/register/complete", methods=["POST"]) +@login_required +def complete_webauthn_registration(): + user_email = g.current_user.email + logger.info(f"WebAuthn registration completion started for user: {user_email}") + + try: + schema = WebAuthnRegistrationCompleteSchema() + data = schema.load(request.json) + + client_data_json_b64 = data.get("response", {}).get("clientDataJSON", "") + if not client_data_json_b64: + return api_response(success=False, message="Missing clientDataJSON in response", status=400, error_type="VALIDATION_ERROR") + + try: + padding = 4 - (len(client_data_json_b64) % 4) + padded = client_data_json_b64 + ("=" * padding if padding != 4 else "") + client_data_dict = json.loads(base64.urlsafe_b64decode(padded)) + except Exception as e: + return api_response(success=False, message=f"Failed to decode client data JSON: {str(e)}", status=400, error_type="VALIDATION_ERROR") + + challenge = client_data_dict.get("challenge") + if not challenge: + return api_response(success=False, message="Invalid challenge in client data", status=400, error_type="VALIDATION_ERROR") + + auth_method = WebAuthnService.verify_registration_response(g.current_user, data, challenge) + logger.info(f"WebAuthn registration completed successfully for user: {user_email}") + return api_response(data={"credential": auth_method.to_webauthn_dict()}, message="Passkey registered successfully", status=201) + except ValidationError as e: + return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) + except InvalidCredentialsError as e: + return api_response(success=False, message=e.message, status=e.status_code, error_type=e.error_type) + except Exception as e: + logger.exception(f"WebAuthn registration unexpected error for user {user_email}: {e}") + return api_response(success=False, message="An unexpected error occurred during registration", status=500, error_type="INTERNAL_ERROR") + + +@api_v1_bp.route("/auth/webauthn/login/begin", methods=["POST"]) +def begin_webauthn_login(): + try: + schema = WebAuthnLoginBeginSchema() + data = schema.load(request.json) + + from gatehouse_app.models.user.user import User + user = User.query.filter_by(email=data["email"].lower(), deleted_at=None).first() + if not user: + return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") + + from gatehouse_app.utils.constants import UserStatus + if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED): + return api_response(success=False, message="Account is suspended. Contact an administrator.", status=403, error_type="ACCOUNT_SUSPENDED") + + if not user.has_webauthn_enabled(): + return api_response(success=False, message="No passkeys found for this account", status=404, error_type="NOT_FOUND") + + options = WebAuthnService.generate_authentication_challenge(user) + session["webauthn_pending_user_id"] = user.id + return jsonify(options), 200 + except ValidationError as e: + return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) + except Exception as e: + logger.exception(f"WebAuthn login begin unexpected error: {e}") + raise + + +@api_v1_bp.route("/auth/webauthn/login/complete", methods=["POST"]) +def complete_webauthn_login(): + try: + user_id = session.get("webauthn_pending_user_id") + if not user_id: + return api_response(success=False, message="No pending WebAuthn verification. Please initiate login first.", status=401, error_type="AUTHENTICATION_ERROR") + + schema = WebAuthnLoginCompleteSchema() + data = schema.load(request.json) + + from gatehouse_app.models.user.user import User + user = User.query.get(user_id) + if not user: + return api_response(success=False, message="User not found", status=401, error_type="AUTHENTICATION_ERROR") + + from gatehouse_app.utils.constants import UserStatus + if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED): + session.pop("webauthn_pending_user_id", None) + return api_response(success=False, message="Account is suspended. Contact an administrator.", status=403, error_type="ACCOUNT_SUSPENDED") + + client_data = data.get("response", {}).get("clientDataJSON", "") + client_data_dict = json.loads(base64.urlsafe_b64decode(client_data + "==")) + challenge = client_data_dict.get("challenge") + + if not challenge: + return api_response(success=False, message="Invalid challenge in client data", status=400, error_type="VALIDATION_ERROR") + + WebAuthnService.verify_authentication_response(user, data, challenge) + + policy_result = MfaPolicyService.after_primary_auth_success(user, remember_me=False) + is_compliance_only = policy_result.create_compliance_only_session + user_session = AuthService.create_session(user, is_compliance_only=is_compliance_only) + session.pop("webauthn_pending_user_id", None) + + logger.info(f"WebAuthn login completed successfully for user: {user.email}") + + response_data = { + "user": user.to_dict(), + "token": user_session.token, + "expires_at": user_session.expires_at.isoformat() + "Z" if user_session.expires_at.isoformat()[-1] != "Z" else user_session.expires_at.isoformat(), + } + + if policy_result.compliance_summary: + response_data["mfa_compliance"] = { + "overall_status": policy_result.compliance_summary.overall_status, + "missing_methods": policy_result.compliance_summary.missing_methods, + "deadline_at": policy_result.compliance_summary.deadline_at, + "orgs": [ + { + "organization_id": org.organization_id, + "organization_name": org.organization_name, + "status": org.status, + "effective_mode": org.effective_mode, + "deadline_at": org.deadline_at, + "applied_at": org.applied_at, + } + for org in policy_result.compliance_summary.orgs + ], + } + + if is_compliance_only: + response_data["requires_mfa_enrollment"] = True + + return api_response(data=response_data, message="Login successful") + except ValidationError as e: + return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) + except InvalidCredentialsError as e: + return api_response(success=False, message=e.message, status=e.status_code, error_type=e.error_type) + except Exception as e: + logger.exception(f"WebAuthn login complete unexpected error: {e}") + raise + + +@api_v1_bp.route("/auth/webauthn/credentials", methods=["GET"]) +@login_required +def list_webauthn_credentials(): + credentials = WebAuthnService.get_user_credentials(g.current_user) + return api_response(data={"credentials": [c.to_webauthn_dict() for c in credentials], "count": len(credentials)}, message="Credentials retrieved successfully") + + +@api_v1_bp.route("/auth/webauthn/credentials/", methods=["DELETE"]) +@login_required +def delete_webauthn_credential(credential_id): + user = g.current_user + + if not WebAuthnService.credential_belongs_to_user(credential_id, user): + return api_response(success=False, message="Credential not found", status=404, error_type="NOT_FOUND") + + if user.get_webauthn_credential_count() <= 1: + return api_response(success=False, message="Cannot delete the last passkey. Add another passkey first.", status=400, error_type="BAD_REQUEST") + + if not WebAuthnService.delete_credential(credential_id, user): + return api_response(success=False, message="Credential not found", status=404, error_type="NOT_FOUND") + + return api_response(message="Passkey deleted successfully") + + +@api_v1_bp.route("/auth/webauthn/credentials/", methods=["PATCH"]) +@login_required +def rename_webauthn_credential(credential_id): + try: + schema = WebAuthnCredentialRenameSchema() + data = schema.load(request.json) + + if not WebAuthnService.rename_credential(credential_id, g.current_user, data["name"]): + return api_response(success=False, message="Credential not found", status=404, error_type="NOT_FOUND") + + credential = WebAuthnService.get_credential_by_id(credential_id, g.current_user) + return api_response(data={"credential": credential.to_webauthn_dict() if credential else None}, message="Passkey renamed successfully") + except ValidationError as e: + return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) + + +@api_v1_bp.route("/auth/webauthn/status", methods=["GET"]) +@login_required +def get_webauthn_status(): + user = g.current_user + return api_response( + data={"webauthn_enabled": user.has_webauthn_enabled(), "credential_count": user.get_webauthn_credential_count()}, + message="WebAuthn status retrieved successfully", + ) diff --git a/gatehouse_app/api/v1/external_auth.py b/gatehouse_app/api/v1/external_auth.py deleted file mode 100644 index a23101c..0000000 --- a/gatehouse_app/api/v1/external_auth.py +++ /dev/null @@ -1,1449 +0,0 @@ -"""External authentication provider endpoints.""" -import json -import logging -from flask import request, g -from marshmallow import ValidationError -from gatehouse_app.api.v1 import api_v1_bp -from gatehouse_app.utils.response import api_response -from gatehouse_app.utils.decorators import login_required -from gatehouse_app.utils.constants import AuthMethodType -from gatehouse_app.services.external_auth_service import ( - ExternalAuthService, - ExternalAuthError, -) -from gatehouse_app.services.oauth_flow_service import ( - OAuthFlowService, - OAuthFlowError, -) -from gatehouse_app.services.audit_service import AuditService - -_OAUTH_BRIDGE_TTL = 600 # 10 minutes - - -def _store_oidc_bridge(oauth_state: str, oidc_session_id: str) -> None: - """Store oidc_session_id keyed by OAuth state for retrieval in callback.""" - try: - import gatehouse_app.extensions as _ext - rc = _ext.redis_client - if rc is not None: - rc.setex(f"oauth_oidc_bridge:{oauth_state}", _OAUTH_BRIDGE_TTL, oidc_session_id) - except Exception: - pass - - -def _pop_oidc_bridge(oauth_state: str) -> str | None: - """Retrieve and delete oidc_session_id for the given OAuth state.""" - try: - import gatehouse_app.extensions as _ext - rc = _ext.redis_client - if rc is not None: - key = f"oauth_oidc_bridge:{oauth_state}" - val = rc.get(key) - if val: - rc.delete(key) - return val.decode() if isinstance(val, bytes) else val - except Exception: - pass - return None - - -def _store_cli_redirect(oauth_state: str, redirect_url: str) -> None: - """Store CLI redirect_url keyed by OAuth state (for /token_please flow).""" - try: - import gatehouse_app.extensions as _ext - rc = _ext.redis_client - if rc is not None: - rc.setex(f"oauth_cli_redirect:{oauth_state}", _OAUTH_BRIDGE_TTL, redirect_url) - except Exception: - pass - - -def _pop_cli_redirect(oauth_state: str) -> str | None: - """Retrieve and delete CLI redirect_url for the given OAuth state.""" - try: - import gatehouse_app.extensions as _ext - rc = _ext.redis_client - if rc is not None: - key = f"oauth_cli_redirect:{oauth_state}" - val = rc.get(key) - if val: - rc.delete(key) - return val.decode() if isinstance(val, bytes) else val - except Exception: - pass - return None - -logger = logging.getLogger(__name__) - - -# Provider type mapping -PROVIDER_TYPE_MAP = { - "google": AuthMethodType.GOOGLE, - "github": AuthMethodType.GITHUB, - "microsoft": AuthMethodType.MICROSOFT, -} - - -def get_provider_type(provider: str) -> AuthMethodType: - """Get AuthMethodType from provider string.""" - provider_lower = provider.lower() - if provider_lower not in PROVIDER_TYPE_MAP: - raise ExternalAuthError( - f"Unsupported provider: {provider}", - "UNSUPPORTED_PROVIDER", - 400, - ) - return PROVIDER_TYPE_MAP[provider_lower] - - -@api_v1_bp.route("/token_please", methods=["GET"]) -def token_please(): - """ - CLI token acquisition endpoint. - - Redirects the user's browser to the Gatehouse login page so they can - authenticate using any method (password, OAuth, passkey, TOTP, etc.). - On successful login the frontend delivers the session token directly to - the CLI's local callback server. - - This endpoint is designed for CLI clients that: - 1. Start a local HTTP server on LISTENER_SERVER_PORT (e.g. 8250) - 2. Open a browser to /api/v1/token_please?redirect_url=http://127.0.0.1:8250/?token= - 3. Wait for the browser to deliver the token to their local server - - Query parameters: - redirect_url: Local callback URL where the token will be appended - """ - import secrets - from urllib.parse import urlencode, quote - from flask import current_app, redirect as flask_redirect - - redirect_url = request.args.get("redirect_url", "").strip() - - if not redirect_url: - return api_response( - success=False, - message="redirect_url query parameter is required", - status=400, - error_type="MISSING_REDIRECT_URL", - ) - - # Validate redirect_url is localhost/127.0.0.1 (security: prevent open redirect) - from urllib.parse import urlparse as _urlparse - parsed = _urlparse(redirect_url) - if parsed.hostname not in ("localhost", "127.0.0.1"): - return api_response( - success=False, - message="redirect_url must point to localhost", - status=400, - error_type="INVALID_REDIRECT_URL", - ) - - # Store the CLI redirect URL in Redis keyed by a short-lived token so the - # frontend can retrieve it after login without it being visible in the URL. - cli_token = secrets.token_urlsafe(32) - try: - import gatehouse_app.extensions as _ext - rc = _ext.redis_client - if rc is not None: - rc.setex(f"cli_redirect:{cli_token}", _OAUTH_BRIDGE_TTL, redirect_url) - else: - logger.warning("Redis not available; passing cli_redirect directly in URL") - cli_token = None - except Exception: - cli_token = None - - frontend_url = current_app.config.get("FRONTEND_URL", "http://localhost:8080") - - if cli_token: - # Pass an opaque token; the frontend exchanges it for the real URL via - # GET /api/v1/cli/redirect-url?token= - login_url = f"{frontend_url}/login?cli_token={cli_token}" - else: - # Fallback: put the redirect URL directly (still localhost-only, validated above) - login_url = f"{frontend_url}/login?cli_redirect={quote(redirect_url, safe='')}" - - logger.info(f"CLI token_please: redirecting browser to Gatehouse login page") - return flask_redirect(login_url, code=302) - - -@api_v1_bp.route("/cli/redirect-url", methods=["GET"]) -def cli_redirect_url_lookup(): - """ - Exchange a short-lived cli_token for the CLI's local redirect URL. - - Called by the frontend LoginPage after it detects the cli_token query - param so it can obtain the actual CLI callback URL from Redis without - exposing it in the browser URL bar. - - Query parameters: - token: The cli_token issued by /token_please - - Returns: - 200: { "redirect_url": "http://127.0.0.1:8250/?token=" } - 400: Missing token - 404: Token not found or expired - """ - cli_token = request.args.get("token", "").strip() - if not cli_token: - return api_response( - success=False, - message="token query parameter is required", - status=400, - error_type="MISSING_TOKEN", - ) - - try: - import gatehouse_app.extensions as _ext - rc = _ext.redis_client - if rc is not None: - key = f"cli_redirect:{cli_token}" - val = rc.get(key) - if val is None: - return api_response( - success=False, - message="CLI token not found or expired", - status=404, - error_type="TOKEN_NOT_FOUND", - ) - # Keep the key alive until the login actually completes (consume on use - # would break multi-step auth like TOTP), so we leave it as-is. - redirect_url = val.decode() if isinstance(val, bytes) else val - return api_response(data={"redirect_url": redirect_url}) - except Exception as e: - logger.error(f"cli_redirect_url_lookup error: {e}") - return api_response( - success=False, - message="Internal error looking up CLI token", - status=500, - error_type="INTERNAL_ERROR", - ) - - return api_response( - success=False, - message="Redis not available", - status=503, - error_type="SERVICE_UNAVAILABLE", - ) - - -# ============================================================================= -# Provider Configuration Endpoints (Admin) -# ============================================================================= - -@api_v1_bp.route("/auth/external/providers", methods=["GET"]) -@login_required -def list_providers(): - """ - List available external authentication providers for current organization. - - Returns: - 200: List of providers with their configuration status - 401: Not authenticated - """ - from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig - from gatehouse_app.services.external_auth_service import ExternalProviderConfig - - # Check app-level provider configs (ApplicationProviderConfig) - app_configs = { - c.provider_type.lower(): c - for c in ApplicationProviderConfig.query.filter_by(is_enabled=True).all() - } - - # Get user's primary organization — check for org-level overrides too - user_orgs = g.current_user.get_organizations() - org_configs = {} - if user_orgs: - organization_id = user_orgs[0].id - org_level = ExternalProviderConfig.query.filter_by( - organization_id=organization_id, - ).all() - org_configs = {c.provider_type.lower(): c for c in org_level} - - def provider_info(provider_id: str, name: str) -> dict: - app_cfg = app_configs.get(provider_id) - org_cfg = org_configs.get(provider_id) - is_configured = app_cfg is not None or org_cfg is not None - is_active = False - if app_cfg: - is_active = bool(app_cfg.is_enabled) - if org_cfg and hasattr(org_cfg, "is_active"): - is_active = bool(org_cfg.is_active) - return { - "id": provider_id, - "name": name, - "type": provider_id, - "is_configured": is_configured, - "is_active": is_active, - "settings": { - "requires_domain": False, - "supports_refresh_tokens": True, - }, - } - - providers = [ - provider_info("google", "Google"), - provider_info("github", "GitHub"), - provider_info("microsoft", "Microsoft"), - ] - - return api_response( - data={"providers": providers}, - message="Providers retrieved successfully", - ) - - -@api_v1_bp.route("/auth/external/providers//config", methods=["GET"]) -@login_required -def get_provider_config(provider: str): - """ - Get provider configuration (admin only). - - Args: - provider: Provider type (google, github, microsoft) - - Returns: - 200: Provider configuration - 401: Not authenticated - 403: Not authorized (not admin) - 404: Provider not configured - """ - from gatehouse_app.models import OrganizationMember - from gatehouse_app.utils.constants import OrganizationRole - from gatehouse_app.services.external_auth_service import ExternalProviderConfig - - provider_type = get_provider_type(provider) - - # Get user's primary organization - user_orgs = g.current_user.get_organizations() - if not user_orgs: - return api_response( - success=False, - message="No organizations found for user", - status=400, - error_type="BAD_REQUEST", - ) - - organization_id = user_orgs[0].id - - # Check if user is admin - member = OrganizationMember.query.filter_by( - user_id=g.current_user.id, - organization_id=organization_id, - ).first() - - if not member or member.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]: - return api_response( - success=False, - message="Admin access required", - status=403, - error_type="FORBIDDEN", - ) - - # Get provider config - config = ExternalProviderConfig.query.filter_by( - organization_id=organization_id, - provider_type=provider_type.value, - ).first() - - if not config: - return api_response( - success=False, - message=f"{provider.title()} OAuth is not configured", - status=404, - error_type="NOT_FOUND", - ) - - return api_response( - data=config.to_dict(include_secrets=False), - message="Provider configuration retrieved successfully", - ) - - -@api_v1_bp.route("/auth/external/providers//config", methods=["POST"]) -@login_required -def create_or_update_provider_config(provider: str): - """ - Create or update provider configuration (admin only). - - Args: - provider: Provider type (google, github, microsoft) - - Request body: - client_id: OAuth client ID - client_secret: OAuth client secret - scopes: List of OAuth scopes - redirect_uris: List of allowed redirect URIs - settings: Provider-specific settings - is_active: Whether the provider is active - - Returns: - 200: Provider configuration updated - 201: Provider configuration created - 400: Validation error - 401: Not authenticated - 403: Not authorized (not admin) - """ - from gatehouse_app.models import OrganizationMember - from gatehouse_app.utils.constants import OrganizationRole - from gatehouse_app.services.external_auth_service import ExternalProviderConfig - - provider_type = get_provider_type(provider) - - # Get user's primary organization - user_orgs = g.current_user.get_organizations() - if not user_orgs: - return api_response( - success=False, - message="No organizations found for user", - status=400, - error_type="BAD_REQUEST", - ) - - organization_id = user_orgs[0].id - - # Check if user is admin - member = OrganizationMember.query.filter_by( - user_id=g.current_user.id, - organization_id=organization_id, - ).first() - - if not member or member.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]: - return api_response( - success=False, - message="Admin access required", - status=403, - error_type="FORBIDDEN", - ) - - # Validate request data - data = request.json or {} - client_id = data.get("client_id") - client_secret = data.get("client_secret") - - if not client_id: - return api_response( - success=False, - message="client_id is required", - status=400, - error_type="VALIDATION_ERROR", - ) - - # Get or create config - config = ExternalProviderConfig.query.filter_by( - organization_id=organization_id, - provider_type=provider_type.value, - ).first() - - is_new = config is None - - if config: - # Update existing - config.client_id = client_id - if client_secret: - config.set_client_secret(client_secret) - config.scopes = data.get("scopes", ["openid", "profile", "email"]) - config.redirect_uris = data.get("redirect_uris", []) - config.settings = data.get("settings", {}) - config.is_active = data.get("is_active", True) - config.save() - - # Audit log - config update - AuditService.log_external_auth_config_update( - user_id=g.current_user.id, - organization_id=organization_id, - provider_type=provider_type.value, - config_id=config.id, - changes={ - "client_id": "updated", - "client_secret": "updated" if client_secret else None, - "scopes": data.get("scopes"), - "redirect_uris": data.get("redirect_uris"), - "is_active": config.is_active, - }, - ) - else: - # Create new - get provider endpoints - auth_url, token_url, userinfo_url = _get_provider_endpoints(provider_type) - - config = ExternalProviderConfig( - organization_id=organization_id, - provider_type=provider_type.value, - client_id=client_id, - client_secret_encrypted=None, - auth_url=auth_url, - token_url=token_url, - userinfo_url=userinfo_url, - scopes=data.get("scopes", ["openid", "profile", "email"]), - redirect_uris=data.get("redirect_uris", []), - settings=data.get("settings", {}), - is_active=data.get("is_active", True), - ) - - if client_secret: - config.set_client_secret(client_secret) - - config.save() - - # Audit log - config create - AuditService.log_external_auth_config_create( - user_id=g.current_user.id, - organization_id=organization_id, - provider_type=provider_type.value, - config_id=config.id, - ) - - return api_response( - data=config.to_dict(include_secrets=False), - message="Provider configuration saved successfully", - status=201 if is_new else 200, - ) - - -@api_v1_bp.route("/auth/external/providers//config", methods=["DELETE"]) -@login_required -def delete_provider_config(provider: str): - """ - Delete provider configuration (admin only). - - Args: - provider: Provider type (google, github, microsoft) - - Returns: - 200: Provider configuration deleted - 401: Not authenticated - 403: Not authorized (not admin) - 404: Provider not configured - """ - from gatehouse_app.models import OrganizationMember - from gatehouse_app.utils.constants import OrganizationRole - from gatehouse_app.services.external_auth_service import ExternalProviderConfig - - provider_type = get_provider_type(provider) - - # Get user's primary organization - user_orgs = g.current_user.get_organizations() - if not user_orgs: - return api_response( - success=False, - message="No organizations found for user", - status=400, - error_type="BAD_REQUEST", - ) - - organization_id = user_orgs[0].id - - # Check if user is admin - member = OrganizationMember.query.filter_by( - user_id=g.current_user.id, - organization_id=organization_id, - ).first() - - if not member or member.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]: - return api_response( - success=False, - message="Admin access required", - status=403, - error_type="FORBIDDEN", - ) - - # Get and delete config - config = ExternalProviderConfig.query.filter_by( - organization_id=organization_id, - provider_type=provider_type.value, - ).first() - - if not config: - return api_response( - success=False, - message=f"{provider.title()} OAuth is not configured", - status=404, - error_type="NOT_FOUND", - ) - - config_id = config.id - config.delete() - - # Audit log - config delete - AuditService.log_external_auth_config_delete( - user_id=g.current_user.id, - organization_id=organization_id, - provider_type=provider_type.value, - config_id=config_id, - ) - - return api_response( - message=f"{provider.title()} provider configuration deleted successfully", - ) - - -# ============================================================================= -# Account Linking Endpoints -# ============================================================================= - -@api_v1_bp.route("/auth/external/linked-accounts", methods=["GET"]) -@login_required -def list_linked_accounts(): - """ - List all linked external accounts for the current user. - - Returns: - 200: List of linked accounts - 401: Not authenticated - """ - linked_accounts = ExternalAuthService.get_linked_accounts(g.current_user.id) - - # Check if user has other auth methods (for unlink availability) - from gatehouse_app.models import AuthenticationMethod - other_methods = AuthenticationMethod.query.filter_by( - user_id=g.current_user.id, - ).count() - - return api_response( - data={ - "linked_accounts": linked_accounts, - "unlink_available": other_methods > 1, - }, - message="Linked accounts retrieved successfully", - ) - - -@api_v1_bp.route("/auth/external//link", methods=["POST"]) -@login_required -def initiate_link_account(provider: str): - """ - Initiate OAuth flow to link an external account. - - Args: - provider: Provider type (google, github, microsoft) - - Request body: - redirect_uri: Optional redirect URI after linking - - Returns: - 302: Redirect to provider authorization page - 400: Validation error or provider not configured - 401: Not authenticated - """ - provider_type = get_provider_type(provider) - - # Get user's organization - user_orgs = g.current_user.get_organizations() - organization_id = user_orgs[0].id if user_orgs else None - - # Get optional redirect URI - data = request.json or {} - redirect_uri = data.get("redirect_uri") - - try: - # Initiate link flow - auth_url, state = ExternalAuthService.initiate_link_flow( - user_id=g.current_user.id, - provider_type=provider_type, - organization_id=organization_id, - redirect_uri=redirect_uri, - ) - - return api_response( - data={ - "authorization_url": auth_url, - "state": state, - }, - message="Link flow initiated. Redirect to authorization URL.", - ) - - except ExternalAuthError as e: - return api_response( - success=False, - message=e.message, - status=e.status_code, - error_type=e.error_type, - ) - - -@api_v1_bp.route("/auth/external//unlink", methods=["DELETE"]) -@login_required -def unlink_account(provider: str): - """ - Unlink an external account from the user's profile. - - Args: - provider: Provider type (google, github, microsoft) - - Returns: - 200: Account unlinked successfully - 400: Validation error or cannot unlink last method - 401: Not authenticated - 404: Provider not linked - """ - provider_type = get_provider_type(provider) - - # Get user's organization - user_orgs = g.current_user.get_organizations() - organization_id = user_orgs[0].id if user_orgs else None - - try: - ExternalAuthService.unlink_provider( - user_id=g.current_user.id, - provider_type=provider_type, - organization_id=organization_id, - ) - - return api_response( - message=f"{provider.title()} account unlinked successfully", - ) - - except ExternalAuthError as e: - return api_response( - success=False, - message=e.message, - status=e.status_code, - error_type=e.error_type, - ) - - -# ============================================================================= -# OAuth Flow Endpoints -# ============================================================================= - -@api_v1_bp.route("/auth/external//authorize", methods=["GET"]) -def initiate_oauth_authorize(provider: str): - """ - Initiate OAuth authentication or account registration flow. - - This endpoint initiates OAuth flows without requiring organization_id upfront. - The organization context is determined after successful authentication based on - the user's memberships. - - Args: - provider: Provider type (google, github, microsoft) - - Query parameters: - flow: 'login' or 'register' (default: 'login') - redirect_uri: Optional redirect URI after OAuth completion - organization_id: Optional organization hint (for SSO discovery) - - Returns: - 200: Authorization URL and state token - 400: Validation error or provider not configured at application level - - Response: - { - "authorization_url": "https://...", - "state": "state_token" - } - """ - # Get query parameters - organization_id is now optional - flow = request.args.get("flow", "login") - redirect_uri = request.args.get("redirect_uri") - organization_id = request.args.get("organization_id") # Optional hint - oidc_session_id = request.args.get("oidc_session_id") # OIDC bridge passthrough - - if flow not in ["login", "register"]: - return api_response( - success=False, - message="Invalid flow type. Must be 'login' or 'register'", - status=400, - error_type="VALIDATION_ERROR", - ) - - try: - provider_type = get_provider_type(provider) - if flow == "login": - auth_url, state = OAuthFlowService.initiate_login_flow( - provider_type=provider_type, - organization_id=organization_id, # Optional hint - redirect_uri=redirect_uri, - ) - else: - auth_url, state = OAuthFlowService.initiate_register_flow( - provider_type=provider_type, - organization_id=organization_id, # Optional hint - redirect_uri=redirect_uri, - ) - - # If this authorize was triggered during an OIDC bridge flow, remember - # the oidc_session_id so we can hand it back in the callback. - if oidc_session_id: - _store_oidc_bridge(state, oidc_session_id) - - return api_response( - data={ - "authorization_url": auth_url, - "state": state, - }, - message=f"OAuth {flow} flow initiated", - ) - - except OAuthFlowError as e: - return api_response( - success=False, - message=e.message, - status=e.status_code, - error_type=e.error_type, - ) - except ExternalAuthError as e: - return api_response( - success=False, - message=e.message, - status=e.status_code, - error_type=e.error_type, - ) - - -@api_v1_bp.route("/auth/external//callback", methods=["GET"]) -def handle_oauth_callback(provider: str): - """ - Handle OAuth callback from provider. - - Google (and other providers) redirect the browser here after authentication. - On success, this endpoint redirects the browser to the frontend - /oauth/callback page carrying the session token as a URL parameter so the - frontend SPA can store it without needing a second API call. - - Success redirect: - {FRONTEND_URL}/oauth/callback?token=TOKEN&expires_in=86400&state=STATE&flow=login&provider=google - - Error redirect: - {FRONTEND_URL}/oauth/callback?error=MESSAGE&error_type=TYPE&state=STATE - - Args: - provider: Provider type (google, github, microsoft) - - Query parameters from provider: - code: Authorization code - state: State parameter (CSRF token from OAuth flow) - error: Error code if auth failed at provider - error_description: Human-readable error description - """ - from urllib.parse import urlencode - from flask import current_app, redirect as flask_redirect - - provider_type = get_provider_type(provider) - - state = request.args.get("state") - authorization_code = request.args.get("code") - error = request.args.get("error") - error_description = request.args.get("error_description") - - frontend_url = current_app.config.get("FRONTEND_URL", "http://localhost:8080") - frontend_callback = f"{frontend_url}/oauth/callback" - - # Check if this is a CLI /token_please flow — retrieve stored redirect_url - cli_redirect_url = _pop_cli_redirect(state) if state else None - - def redirect_error(message: str, error_type: str = "OAUTH_ERROR"): - """Redirect to frontend (or CLI) with error params.""" - if cli_redirect_url: - # CLI flow: return a plain error page instead of redirecting back - from flask import make_response - return make_response( - f"

Authentication Error

{message}

" - f"

You may close this window.

", - 400, - ) - params = {"error": message, "error_type": error_type} - if state: - params["state"] = state - return flask_redirect(f"{frontend_callback}?{urlencode(params)}", code=302) - - # Handle errors returned by the provider (e.g. user denied) - if error: - msg = error_description or f"Authorization failed: {error}" - return redirect_error(msg, error.upper()) - - if not authorization_code or not state: - return redirect_error("Missing authorization code or state parameter.") - - try: - result = OAuthFlowService.handle_callback( - provider_type=provider_type, - authorization_code=authorization_code, - state=state, - redirect_uri=None, # backend handles the full flow - error=None, - error_description=None, - ) - - if not result.get("success"): - return redirect_error("Authentication failed.", "AUTH_FAILED") - - flow_type = result.get("flow_type", "login") - - # ── Link flow: redirect to linked-accounts page ────────────────────── - if flow_type == "link": - params = {"flow": "link", "provider": provider, "linked": "1"} - return flask_redirect(f"{frontend_url}/linked-accounts?{urlencode(params)}", code=302) - - # ── Login / Register flow ───────────────────────────────────────────── - - # Recover oidc_session_id if this was triggered from an OIDC bridge flow - oidc_session_id = _pop_oidc_bridge(state) - - # Organization selection / creation flows are not supported in CLI mode - # (fall through to token redirect with whatever session we have) - - # Organization selection needed (user belongs to multiple orgs) - if result.get("requires_org_selection") and not cli_redirect_url: - import json - orgs = json.dumps(result.get("available_organizations", [])) - params = { - "requires_org_selection": "1", - "state": result["state"], - "provider": provider, - "flow": flow_type, - "orgs": orgs, - } - if oidc_session_id: - params["oidc_session_id"] = oidc_session_id - return flask_redirect(f"{frontend_callback}?{urlencode(params)}", code=302) - - # Organization creation needed (new user via OAuth with no org) - if result.get("requires_org_creation") and not cli_redirect_url: - import json as _json - session_data = result.get("session", {}) - token = session_data.get("token", "") - expires_in = session_data.get("expires_in", 86400) - pending_invites = result.get("pending_invites", []) - params = { - "requires_org_creation": "1", - "state": result["state"], - "provider": provider, - "flow": flow_type, - "token": token, - "expires_in": str(expires_in), - "pending_invites": _json.dumps(pending_invites), - } - if oidc_session_id: - params["oidc_session_id"] = oidc_session_id - return flask_redirect(f"{frontend_callback}?{urlencode(params)}", code=302) - - # Normal success — carry token to frontend via URL - session_data = result.get("session", {}) - token = session_data.get("token") - expires_in = session_data.get("expires_in", 86400) - - if not token: - return redirect_error("No session token returned by server.", "NO_TOKEN") - - params = { - "token": token, - "expires_in": str(expires_in), - "flow": flow_type, - "provider": provider, - "state": state, - } - user_info = result.get("user", {}) - if user_info.get("email"): - params["email"] = user_info["email"] - - # ── CLI /token_please flow: redirect to the CLI's local callback ───── - if cli_redirect_url: - # The CLI expects: http://127.0.0.1:8250/?token= - # cli_redirect_url already ends with "token=" so just append the value - cli_final_url = cli_redirect_url + token - logger.info( - f"CLI token_please success: provider={provider}, user={user_info.get('email')}, " - f"redirecting to CLI callback" - ) - return flask_redirect(cli_final_url, code=302) - - # ── Frontend flow ───────────────────────────────────────────────────── - # Pass oidc_session_id through so the frontend can complete the OIDC flow - if oidc_session_id: - params["oidc_session_id"] = oidc_session_id - - logger.info( - f"OAuth callback success: provider={provider}, flow={flow_type}, " - f"user={user_info.get('email')}, redirecting to frontend" - ) - return flask_redirect(f"{frontend_callback}?{urlencode(params)}", code=302) - - except OAuthFlowError as e: - logger.warning(f"OAuth callback OAuthFlowError: {e.message}") - return redirect_error(e.message, e.error_type) - except Exception as e: - logger.error(f"OAuth callback unexpected error: {str(e)}", exc_info=True) - return redirect_error("An unexpected error occurred. Please try again.", "INTERNAL_ERROR") - - -@api_v1_bp.route("/auth/external/select-organization", methods=["POST"]) -def select_organization(): - """ - Complete OAuth flow by selecting an organization. - - This endpoint is called after OAuth callback when the user needs to select - which organization to log in to (when user belongs to multiple orgs). - - Request body: - state: The state token from the OAuth callback - organization_id: The selected organization ID - - Returns: - 200: Session created successfully - 400: Invalid state or organization - 404: Organization not found or user not a member - - Response: - { - "token": "session_token", - "expires_in": 86400, - "token_type": "Bearer", - "user": { - "id": "...", - "email": "...", - "full_name": "...", - "organization_id": "..." - } - } - """ - data = request.json or {} - state_token = data.get("state") - organization_id = data.get("organization_id") - - if not state_token: - return api_response( - success=False, - message="state is required", - status=400, - error_type="VALIDATION_ERROR", - ) - - if not organization_id: - return api_response( - success=False, - message="organization_id is required", - status=400, - error_type="VALIDATION_ERROR", - ) - - try: - # Validate state and get OAuth state record - state_record = OAuthFlowService.validate_state(state_token) - if not state_record or state_record.used: - return api_response( - success=False, - message="Invalid or expired state token", - status=400, - error_type="INVALID_STATE", - ) - - # The state should have user information from the OAuth callback - # We need to find the user that was authenticated - from gatehouse_app.models import User, AuthenticationMethod, Organization, OrganizationMember - - # Find user by provider authentication - # The state record should have provider info in extra_data if set by callback - # Otherwise, we need to find the most recently created auth method - auth_method = AuthenticationMethod.query.filter_by( - method_type=state_record.provider_type, - ).order_by(AuthenticationMethod.created_at.desc()).first() - - if not auth_method: - return api_response( - success=False, - message="Authentication session not found", - status=400, - error_type="SESSION_NOT_FOUND", - ) - - user = auth_method.user - - # Verify user is member of selected organization - org = Organization.query.get(organization_id) - if not org: - return api_response( - success=False, - message="Organization not found", - status=404, - error_type="NOT_FOUND", - ) - - member = OrganizationMember.query.filter_by( - user_id=user.id, - organization_id=organization_id, - ).first() - - if not member: - return api_response( - success=False, - message="You are not a member of this organization", - status=403, - error_type="FORBIDDEN", - ) - - # Create session for the selected organization - from gatehouse_app.services.session_service import SessionService - session = SessionService.create_session( - user=user, - organization_id=organization_id, - ) - - # Mark state as used - state_record.mark_used() - - # Audit log - login success with org selection - AuditService.log_external_auth_login( - user_id=user.id, - organization_id=organization_id, - provider_type=state_record.provider_type.value if isinstance(state_record.provider_type, AuthMethodType) else state_record.provider_type, - provider_user_id=auth_method.provider_user_id, - auth_method_id=auth_method.id, - session_id=session.id, - ) - - return api_response( - data={ - "token": session.token, - "expires_in": session.lifetime_seconds, - "token_type": "Bearer", - "user": { - "id": user.id, - "email": user.email, - "full_name": user.full_name, - "organization_id": organization_id, - }, - }, - message="Organization selected and session created successfully", - ) - - except Exception as e: - logger.error(f"Error in select_organization: {str(e)}", exc_info=True) - return api_response( - success=False, - message="An error occurred while selecting organization", - status=500, - error_type="INTERNAL_ERROR", - ) - - -# ============================================================================= -# Authorization Code Exchange Endpoint -# ============================================================================= - -@api_v1_bp.route("/auth/external/token", methods=["POST"]) -def exchange_authorization_code(): - """ - Exchange an authorization code for a session token. - - This endpoint is used by external applications (like oauth2-proxy, BookStack) - to exchange the authorization code received from the OAuth callback for a - session token. - - Request body (form-encoded or JSON): - grant_type: Must be "authorization_code" - code: The authorization code from the callback - redirect_uri: The redirect URI used in the original request - client_id: The client ID (optional, defaults to "external-app") - - Returns: - 200: Session token exchanged successfully - 400: Invalid or expired authorization code - 404: User not found - - Response: - { - "token": "session_token", - "expires_in": 86400, - "token_type": "Bearer", - "user": { - "id": "...", - "email": "...", - "full_name": "...", - "organization_id": "..." - } - } - """ - # Support both JSON and form-encoded requests - if request.is_json: - data = request.json or {} - else: - data = request.form or {} - - grant_type = data.get("grant_type") - code = data.get("code") - redirect_uri = data.get("redirect_uri") - client_id = data.get("client_id", "external-app") - - # Validate required parameters - if grant_type and grant_type != "authorization_code": - return api_response( - success=False, - message="Invalid grant_type. Must be 'authorization_code'", - status=400, - error_type="INVALID_GRANT_TYPE", - ) - - if not code: - return api_response( - success=False, - message="code is required", - status=400, - error_type="VALIDATION_ERROR", - ) - - if not redirect_uri: - return api_response( - success=False, - message="redirect_uri is required", - status=400, - error_type="VALIDATION_ERROR", - ) - - try: - result = OAuthFlowService.exchange_authorization_code( - code=code, - client_id=client_id, - redirect_uri=redirect_uri, - ip_address=request.remote_addr, - ) - - return api_response( - data={ - "token": result["token"], - "expires_in": result["expires_in"], - "token_type": result["token_type"], - "user": result["user"], - }, - message="Token exchanged successfully", - ) - - except OAuthFlowError as e: - return api_response( - success=False, - message=e.message, - status=e.status_code, - error_type=e.error_type, - ) - - -# ============================================================================= -# Helper Functions -# ============================================================================= - -def _get_provider_endpoints(provider_type: AuthMethodType): - """Get OAuth endpoints for a provider.""" - if provider_type == AuthMethodType.GOOGLE: - return ( - "https://accounts.google.com/o/oauth2/v2/auth", - "https://oauth2.googleapis.com/token", - "https://www.googleapis.com/oauth2/v3/userinfo", - ) - elif provider_type == AuthMethodType.GITHUB: - return ( - "https://github.com/login/oauth/authorize", - "https://github.com/login/oauth/access_token", - "https://api.github.com/user", - ) - elif provider_type == AuthMethodType.MICROSOFT: - return ( - "https://login.microsoftonline.com/common/oauth2/v2.0/authorize", - "https://login.microsoftonline.com/common/oauth2/v2.0/token", - "https://graph.microsoft.com/oidc/userinfo", - ) - else: - raise ExternalAuthError( - f"Unsupported provider: {provider_type}", - "UNSUPPORTED_PROVIDER", - 400, - ) - - -# ============================================================================= -# Admin: Application-level OAuth Provider Management -# ============================================================================= - -@api_v1_bp.route("/admin/oauth/providers", methods=["GET"]) -@login_required -def admin_list_app_providers(): - """List all application-level OAuth provider configurations (admin only). - - Returns: - 200: List of providers with client_id and enabled status - 401: Not authenticated - 403: Not an admin - """ - from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig - from gatehouse_app.models import OrganizationMember - from gatehouse_app.utils.constants import OrganizationRole - - # Verify caller is admin in any org - admin_memberships = OrganizationMember.query.filter( - OrganizationMember.user_id == g.current_user.id, - OrganizationMember.role.in_([OrganizationRole.OWNER, OrganizationRole.ADMIN]), - ).all() - - if not admin_memberships: - return api_response( - success=False, - message="Admin access required", - status=403, - error_type="FORBIDDEN", - ) - - PROVIDERS = [ - {"id": "google", "name": "Google"}, - {"id": "github", "name": "GitHub"}, - {"id": "microsoft", "name": "Microsoft"}, - ] - - db_configs = { - c.provider_type: c - for c in ApplicationProviderConfig.query.all() - } - - result = [] - for p in PROVIDERS: - cfg = db_configs.get(p["id"]) - result.append({ - "id": p["id"], - "name": p["name"], - "is_configured": cfg is not None, - "is_enabled": cfg.is_enabled if cfg else False, - "client_id": cfg.client_id if cfg else None, - }) - - return api_response( - data={"providers": result}, - message="OAuth providers retrieved successfully", - ) - - -@api_v1_bp.route("/admin/oauth/providers/", methods=["PUT"]) -@login_required -def admin_configure_app_provider(provider: str): - """Create or update an application-level OAuth provider config (admin only). - - Args: - provider: Provider type (google, github, microsoft) - - Request body: - client_id: OAuth client ID - client_secret: OAuth client secret (optional — omit to keep existing) - is_enabled: Whether the provider is enabled (default: true) - - Returns: - 200: Provider configuration updated - 400: Validation error - 401: Not authenticated - 403: Not an admin - """ - from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig - from gatehouse_app.models import OrganizationMember - from gatehouse_app.utils.constants import OrganizationRole - from gatehouse_app.extensions import db - - SUPPORTED = ["google", "github", "microsoft"] - if provider not in SUPPORTED: - return api_response( - success=False, - message=f"Unsupported provider. Must be one of: {', '.join(SUPPORTED)}", - status=400, - error_type="VALIDATION_ERROR", - ) - - # Verify caller is admin in any org - admin_memberships = OrganizationMember.query.filter( - OrganizationMember.user_id == g.current_user.id, - OrganizationMember.role.in_([OrganizationRole.OWNER, OrganizationRole.ADMIN]), - ).all() - - if not admin_memberships: - return api_response( - success=False, - message="Admin access required", - status=403, - error_type="FORBIDDEN", - ) - - data = request.json or {} - client_id = (data.get("client_id") or "").strip() - client_secret = (data.get("client_secret") or "").strip() - is_enabled = data.get("is_enabled", True) - - if not client_id: - return api_response( - success=False, - message="client_id is required", - status=400, - error_type="VALIDATION_ERROR", - ) - - cfg = ApplicationProviderConfig.query.filter_by(provider_type=provider).first() - if cfg: - cfg.client_id = client_id - if client_secret: - cfg.set_client_secret(client_secret) - cfg.is_enabled = bool(is_enabled) - db.session.commit() - else: - cfg = ApplicationProviderConfig( - provider_type=provider, - client_id=client_id, - is_enabled=bool(is_enabled), - ) - if client_secret: - cfg.set_client_secret(client_secret) - db.session.add(cfg) - db.session.commit() - - return api_response( - data={ - "provider": { - "id": provider, - "client_id": cfg.client_id, - "is_enabled": cfg.is_enabled, - } - }, - message=f"{provider.capitalize()} OAuth provider configured successfully", - ) - - -@api_v1_bp.route("/admin/oauth/providers/", methods=["DELETE"]) -@login_required -def admin_delete_app_provider(provider: str): - """Delete an application-level OAuth provider config (admin only). - - Args: - provider: Provider type (google, github, microsoft) - - Returns: - 200: Provider configuration deleted - 404: Provider not found - 401: Not authenticated - 403: Not an admin - """ - from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig - from gatehouse_app.models import OrganizationMember - from gatehouse_app.utils.constants import OrganizationRole - from gatehouse_app.extensions import db - - # Verify caller is admin in any org - admin_memberships = OrganizationMember.query.filter( - OrganizationMember.user_id == g.current_user.id, - OrganizationMember.role.in_([OrganizationRole.OWNER, OrganizationRole.ADMIN]), - ).all() - - if not admin_memberships: - return api_response( - success=False, - message="Admin access required", - status=403, - error_type="FORBIDDEN", - ) - - cfg = ApplicationProviderConfig.query.filter_by(provider_type=provider).first() - if not cfg: - return api_response( - success=False, - message=f"Provider '{provider}' is not configured", - status=404, - error_type="NOT_FOUND", - ) - - db.session.delete(cfg) - db.session.commit() - - return api_response( - message=f"{provider.capitalize()} OAuth provider configuration removed", - ) diff --git a/gatehouse_app/api/v1/external_auth/__init__.py b/gatehouse_app/api/v1/external_auth/__init__.py new file mode 100644 index 0000000..24261f6 --- /dev/null +++ b/gatehouse_app/api/v1/external_auth/__init__.py @@ -0,0 +1,2 @@ +"""External auth blueprint subpackage.""" +from gatehouse_app.api.v1.external_auth import cli, providers, oauth, admin diff --git a/gatehouse_app/api/v1/external_auth/_helpers.py b/gatehouse_app/api/v1/external_auth/_helpers.py new file mode 100644 index 0000000..f85be2a --- /dev/null +++ b/gatehouse_app/api/v1/external_auth/_helpers.py @@ -0,0 +1,94 @@ +"""Shared helpers for external_auth subpackage.""" +import logging +from gatehouse_app.utils.constants import AuthMethodType +from gatehouse_app.services.external_auth.models import ExternalAuthError + +_OAUTH_BRIDGE_TTL = 600 # 10 minutes + +logger = logging.getLogger(__name__) + +PROVIDER_TYPE_MAP = { + "google": AuthMethodType.GOOGLE, + "github": AuthMethodType.GITHUB, + "microsoft": AuthMethodType.MICROSOFT, +} + + +def get_provider_type(provider: str) -> AuthMethodType: + provider_lower = provider.lower() + if provider_lower not in PROVIDER_TYPE_MAP: + raise ExternalAuthError(f"Unsupported provider: {provider}", "UNSUPPORTED_PROVIDER", 400) + return PROVIDER_TYPE_MAP[provider_lower] + + +def _get_provider_endpoints(provider_type: AuthMethodType): + if provider_type == AuthMethodType.GOOGLE: + return ( + "https://accounts.google.com/o/oauth2/v2/auth", + "https://oauth2.googleapis.com/token", + "https://www.googleapis.com/oauth2/v3/userinfo", + ) + elif provider_type == AuthMethodType.GITHUB: + return ( + "https://github.com/login/oauth/authorize", + "https://github.com/login/oauth/access_token", + "https://api.github.com/user", + ) + elif provider_type == AuthMethodType.MICROSOFT: + return ( + "https://login.microsoftonline.com/common/oauth2/v2.0/authorize", + "https://login.microsoftonline.com/common/oauth2/v2.0/token", + "https://graph.microsoft.com/oidc/userinfo", + ) + else: + raise ExternalAuthError(f"Unsupported provider: {provider_type}", "UNSUPPORTED_PROVIDER", 400) + + +def _store_oidc_bridge(oauth_state: str, oidc_session_id: str) -> None: + try: + import gatehouse_app.extensions as _ext + rc = _ext.redis_client + if rc is not None: + rc.setex(f"oauth_oidc_bridge:{oauth_state}", _OAUTH_BRIDGE_TTL, oidc_session_id) + except Exception: + pass + + +def _pop_oidc_bridge(oauth_state: str) -> str | None: + try: + import gatehouse_app.extensions as _ext + rc = _ext.redis_client + if rc is not None: + key = f"oauth_oidc_bridge:{oauth_state}" + val = rc.get(key) + if val: + rc.delete(key) + return val.decode() if isinstance(val, bytes) else val + except Exception: + pass + return None + + +def _store_cli_redirect(oauth_state: str, redirect_url: str) -> None: + try: + import gatehouse_app.extensions as _ext + rc = _ext.redis_client + if rc is not None: + rc.setex(f"oauth_cli_redirect:{oauth_state}", _OAUTH_BRIDGE_TTL, redirect_url) + except Exception: + pass + + +def _pop_cli_redirect(oauth_state: str) -> str | None: + try: + import gatehouse_app.extensions as _ext + rc = _ext.redis_client + if rc is not None: + key = f"oauth_cli_redirect:{oauth_state}" + val = rc.get(key) + if val: + rc.delete(key) + return val.decode() if isinstance(val, bytes) else val + except Exception: + pass + return None diff --git a/gatehouse_app/api/v1/external_auth/admin.py b/gatehouse_app/api/v1/external_auth/admin.py new file mode 100644 index 0000000..437fd37 --- /dev/null +++ b/gatehouse_app/api/v1/external_auth/admin.py @@ -0,0 +1,109 @@ +"""Admin application-level OAuth provider management.""" +from flask import g, request +from gatehouse_app.api.v1 import api_v1_bp +from gatehouse_app.utils.response import api_response +from gatehouse_app.utils.decorators import login_required + + +@api_v1_bp.route("/admin/oauth/providers", methods=["GET"]) +@login_required +def admin_list_app_providers(): + from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig + from gatehouse_app.models import OrganizationMember + from gatehouse_app.utils.constants import OrganizationRole + + admin_memberships = OrganizationMember.query.filter( + OrganizationMember.user_id == g.current_user.id, + OrganizationMember.role.in_([OrganizationRole.OWNER, OrganizationRole.ADMIN]), + ).all() + + if not admin_memberships: + return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN") + + PROVIDERS = [{"id": "google", "name": "Google"}, {"id": "github", "name": "GitHub"}, {"id": "microsoft", "name": "Microsoft"}] + db_configs = {c.provider_type: c for c in ApplicationProviderConfig.query.all()} + + result = [] + for p in PROVIDERS: + cfg = db_configs.get(p["id"]) + result.append({ + "id": p["id"], "name": p["name"], + "is_configured": cfg is not None, + "is_enabled": cfg.is_enabled if cfg else False, + "client_id": cfg.client_id if cfg else None, + }) + + return api_response(data={"providers": result}, message="OAuth providers retrieved successfully") + + +@api_v1_bp.route("/admin/oauth/providers/", methods=["PUT"]) +@login_required +def admin_configure_app_provider(provider: str): + from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig + from gatehouse_app.models import OrganizationMember + from gatehouse_app.utils.constants import OrganizationRole + from gatehouse_app.extensions import db + + SUPPORTED = ["google", "github", "microsoft"] + if provider not in SUPPORTED: + return api_response(success=False, message=f"Unsupported provider. Must be one of: {', '.join(SUPPORTED)}", status=400, error_type="VALIDATION_ERROR") + + admin_memberships = OrganizationMember.query.filter( + OrganizationMember.user_id == g.current_user.id, + OrganizationMember.role.in_([OrganizationRole.OWNER, OrganizationRole.ADMIN]), + ).all() + + if not admin_memberships: + return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN") + + data = request.json or {} + client_id = (data.get("client_id") or "").strip() + client_secret = (data.get("client_secret") or "").strip() + is_enabled = data.get("is_enabled", True) + + if not client_id: + return api_response(success=False, message="client_id is required", status=400, error_type="VALIDATION_ERROR") + + cfg = ApplicationProviderConfig.query.filter_by(provider_type=provider).first() + if cfg: + cfg.client_id = client_id + if client_secret: + cfg.set_client_secret(client_secret) + cfg.is_enabled = bool(is_enabled) + db.session.commit() + else: + cfg = ApplicationProviderConfig(provider_type=provider, client_id=client_id, is_enabled=bool(is_enabled)) + if client_secret: + cfg.set_client_secret(client_secret) + db.session.add(cfg) + db.session.commit() + + return api_response( + data={"provider": {"id": provider, "client_id": cfg.client_id, "is_enabled": cfg.is_enabled}}, + message=f"{provider.capitalize()} OAuth provider configured successfully", + ) + + +@api_v1_bp.route("/admin/oauth/providers/", methods=["DELETE"]) +@login_required +def admin_delete_app_provider(provider: str): + from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig + from gatehouse_app.models import OrganizationMember + from gatehouse_app.utils.constants import OrganizationRole + from gatehouse_app.extensions import db + + admin_memberships = OrganizationMember.query.filter( + OrganizationMember.user_id == g.current_user.id, + OrganizationMember.role.in_([OrganizationRole.OWNER, OrganizationRole.ADMIN]), + ).all() + + if not admin_memberships: + return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN") + + cfg = ApplicationProviderConfig.query.filter_by(provider_type=provider).first() + if not cfg: + return api_response(success=False, message=f"Provider '{provider}' is not configured", status=404, error_type="NOT_FOUND") + + db.session.delete(cfg) + db.session.commit() + return api_response(message=f"{provider.capitalize()} OAuth provider configuration removed") diff --git a/gatehouse_app/api/v1/external_auth/cli.py b/gatehouse_app/api/v1/external_auth/cli.py new file mode 100644 index 0000000..e012756 --- /dev/null +++ b/gatehouse_app/api/v1/external_auth/cli.py @@ -0,0 +1,68 @@ +"""CLI token acquisition endpoints.""" +import secrets +import logging +from urllib.parse import quote +from flask import request, current_app, redirect as flask_redirect +from gatehouse_app.api.v1 import api_v1_bp +from gatehouse_app.utils.response import api_response +from gatehouse_app.api.v1.external_auth._helpers import _OAUTH_BRIDGE_TTL + +logger = logging.getLogger(__name__) + + +@api_v1_bp.route("/token_please", methods=["GET"]) +def token_please(): + redirect_url = request.args.get("redirect_url", "").strip() + + if not redirect_url: + return api_response(success=False, message="redirect_url query parameter is required", status=400, error_type="MISSING_REDIRECT_URL") + + from urllib.parse import urlparse as _urlparse + parsed = _urlparse(redirect_url) + if parsed.hostname not in ("localhost", "127.0.0.1"): + return api_response(success=False, message="redirect_url must point to localhost", status=400, error_type="INVALID_REDIRECT_URL") + + cli_token = secrets.token_urlsafe(32) + try: + import gatehouse_app.extensions as _ext + rc = _ext.redis_client + if rc is not None: + rc.setex(f"cli_redirect:{cli_token}", _OAUTH_BRIDGE_TTL, redirect_url) + else: + logger.warning("Redis not available; passing cli_redirect directly in URL") + cli_token = None + except Exception: + cli_token = None + + frontend_url = current_app.config.get("FRONTEND_URL", "http://localhost:8080") + + if cli_token: + login_url = f"{frontend_url}/login?cli_token={cli_token}" + else: + login_url = f"{frontend_url}/login?cli_redirect={quote(redirect_url, safe='')}" + + logger.info("CLI token_please: redirecting browser to Gatehouse login page") + return flask_redirect(login_url, code=302) + + +@api_v1_bp.route("/cli/redirect-url", methods=["GET"]) +def cli_redirect_url_lookup(): + cli_token = request.args.get("token", "").strip() + if not cli_token: + return api_response(success=False, message="token query parameter is required", status=400, error_type="MISSING_TOKEN") + + try: + import gatehouse_app.extensions as _ext + rc = _ext.redis_client + if rc is not None: + key = f"cli_redirect:{cli_token}" + val = rc.get(key) + if val is None: + return api_response(success=False, message="CLI token not found or expired", status=404, error_type="TOKEN_NOT_FOUND") + redirect_url = val.decode() if isinstance(val, bytes) else val + return api_response(data={"redirect_url": redirect_url}) + except Exception as e: + logger.error(f"cli_redirect_url_lookup error: {e}") + return api_response(success=False, message="Internal error looking up CLI token", status=500, error_type="INTERNAL_ERROR") + + return api_response(success=False, message="Redis not available", status=503, error_type="SERVICE_UNAVAILABLE") diff --git a/gatehouse_app/api/v1/external_auth/oauth.py b/gatehouse_app/api/v1/external_auth/oauth.py new file mode 100644 index 0000000..a5fea12 --- /dev/null +++ b/gatehouse_app/api/v1/external_auth/oauth.py @@ -0,0 +1,244 @@ +"""OAuth authorization and callback endpoints.""" +import json +import logging +from urllib.parse import urlencode +from flask import request, current_app, redirect as flask_redirect +from gatehouse_app.api.v1 import api_v1_bp +from gatehouse_app.utils.response import api_response +from gatehouse_app.services.external_auth.models import ExternalAuthError +from gatehouse_app.services.oauth_flow import OAuthFlowService, OAuthFlowError +from gatehouse_app.services.audit_service import AuditService +from gatehouse_app.api.v1.external_auth._helpers import ( + get_provider_type, _store_oidc_bridge, _pop_oidc_bridge, _pop_cli_redirect, +) + +logger = logging.getLogger(__name__) + + +@api_v1_bp.route("/auth/external//authorize", methods=["GET"]) +def initiate_oauth_authorize(provider: str): + flow = request.args.get("flow", "login") + redirect_uri = request.args.get("redirect_uri") + organization_id = request.args.get("organization_id") + oidc_session_id = request.args.get("oidc_session_id") + + if flow not in ["login", "register"]: + return api_response(success=False, message="Invalid flow type. Must be 'login' or 'register'", status=400, error_type="VALIDATION_ERROR") + + try: + provider_type = get_provider_type(provider) + if flow == "login": + auth_url, state = OAuthFlowService.initiate_login_flow( + provider_type=provider_type, organization_id=organization_id, redirect_uri=redirect_uri, + ) + else: + auth_url, state = OAuthFlowService.initiate_register_flow( + provider_type=provider_type, organization_id=organization_id, redirect_uri=redirect_uri, + ) + + if oidc_session_id: + _store_oidc_bridge(state, oidc_session_id) + + return api_response(data={"authorization_url": auth_url, "state": state}, message=f"OAuth {flow} flow initiated") + + except OAuthFlowError as e: + return api_response(success=False, message=e.message, status=e.status_code, error_type=e.error_type) + except ExternalAuthError as e: + return api_response(success=False, message=e.message, status=e.status_code, error_type=e.error_type) + + +@api_v1_bp.route("/auth/external//callback", methods=["GET"]) +def handle_oauth_callback(provider: str): + provider_type = get_provider_type(provider) + + state = request.args.get("state") + authorization_code = request.args.get("code") + error = request.args.get("error") + error_description = request.args.get("error_description") + + frontend_url = current_app.config.get("FRONTEND_URL", "http://localhost:8080") + frontend_callback = f"{frontend_url}/oauth/callback" + + cli_redirect_url = _pop_cli_redirect(state) if state else None + + def redirect_error(message: str, error_type: str = "OAUTH_ERROR"): + if cli_redirect_url: + from flask import make_response + return make_response( + f"

Authentication Error

{message}

" + f"

You may close this window.

", 400, + ) + params = {"error": message, "error_type": error_type} + if state: + params["state"] = state + return flask_redirect(f"{frontend_callback}?{urlencode(params)}", code=302) + + if error: + msg = error_description or f"Authorization failed: {error}" + return redirect_error(msg, error.upper()) + + if not authorization_code or not state: + return redirect_error("Missing authorization code or state parameter.") + + try: + result = OAuthFlowService.handle_callback( + provider_type=provider_type, + authorization_code=authorization_code, + state=state, + redirect_uri=None, + error=None, + error_description=None, + ) + + if not result.get("success"): + return redirect_error("Authentication failed.", "AUTH_FAILED") + + flow_type = result.get("flow_type", "login") + + if flow_type == "link": + params = {"flow": "link", "provider": provider, "linked": "1"} + return flask_redirect(f"{frontend_url}/linked-accounts?{urlencode(params)}", code=302) + + oidc_session_id = _pop_oidc_bridge(state) + + if result.get("requires_org_selection") and not cli_redirect_url: + orgs = json.dumps(result.get("available_organizations", [])) + params = {"requires_org_selection": "1", "state": result["state"], "provider": provider, "flow": flow_type, "orgs": orgs} + if oidc_session_id: + params["oidc_session_id"] = oidc_session_id + return flask_redirect(f"{frontend_callback}?{urlencode(params)}", code=302) + + if result.get("requires_org_creation") and not cli_redirect_url: + import json as _json + session_data = result.get("session", {}) + token = session_data.get("token", "") + expires_in = session_data.get("expires_in", 86400) + pending_invites = result.get("pending_invites", []) + params = { + "requires_org_creation": "1", "state": result["state"], "provider": provider, + "flow": flow_type, "token": token, "expires_in": str(expires_in), + "pending_invites": _json.dumps(pending_invites), + } + if oidc_session_id: + params["oidc_session_id"] = oidc_session_id + return flask_redirect(f"{frontend_callback}?{urlencode(params)}", code=302) + + session_data = result.get("session", {}) + token = session_data.get("token") + expires_in = session_data.get("expires_in", 86400) + + if not token: + return redirect_error("No session token returned by server.", "NO_TOKEN") + + params = {"token": token, "expires_in": str(expires_in), "flow": flow_type, "provider": provider, "state": state} + user_info = result.get("user", {}) + if user_info.get("email"): + params["email"] = user_info["email"] + + if cli_redirect_url: + cli_final_url = cli_redirect_url + token + logger.info(f"CLI token_please success: provider={provider}, user={user_info.get('email')}, redirecting to CLI callback") + return flask_redirect(cli_final_url, code=302) + + if oidc_session_id: + params["oidc_session_id"] = oidc_session_id + + logger.info(f"OAuth callback success: provider={provider}, flow={flow_type}, user={user_info.get('email')}, redirecting to frontend") + return flask_redirect(f"{frontend_callback}?{urlencode(params)}", code=302) + + except OAuthFlowError as e: + logger.warning(f"OAuth callback OAuthFlowError: {e.message}") + return redirect_error(e.message, e.error_type) + except Exception as e: + logger.error(f"OAuth callback unexpected error: {str(e)}", exc_info=True) + return redirect_error("An unexpected error occurred. Please try again.", "INTERNAL_ERROR") + + +@api_v1_bp.route("/auth/external/select-organization", methods=["POST"]) +def select_organization(): + from gatehouse_app.utils.constants import AuthMethodType as _AuthMethodType + from gatehouse_app.models import User, AuthenticationMethod, Organization, OrganizationMember + + data = request.json or {} + state_token = data.get("state") + organization_id = data.get("organization_id") + + if not state_token: + return api_response(success=False, message="state is required", status=400, error_type="VALIDATION_ERROR") + if not organization_id: + return api_response(success=False, message="organization_id is required", status=400, error_type="VALIDATION_ERROR") + + try: + state_record = OAuthFlowService.validate_state(state_token) + if not state_record or state_record.used: + return api_response(success=False, message="Invalid or expired state token", status=400, error_type="INVALID_STATE") + + auth_method = AuthenticationMethod.query.filter_by( + method_type=state_record.provider_type, + ).order_by(AuthenticationMethod.created_at.desc()).first() + + if not auth_method: + return api_response(success=False, message="Authentication session not found", status=400, error_type="SESSION_NOT_FOUND") + + user = auth_method.user + + org = Organization.query.get(organization_id) + if not org: + return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND") + + member = OrganizationMember.query.filter_by(user_id=user.id, organization_id=organization_id).first() + if not member: + return api_response(success=False, message="You are not a member of this organization", status=403, error_type="FORBIDDEN") + + from gatehouse_app.services.session_service import SessionService + session = SessionService.create_session(user=user, organization_id=organization_id) + state_record.mark_used() + + provider_type_val = state_record.provider_type.value if isinstance(state_record.provider_type, _AuthMethodType) else state_record.provider_type + AuditService.log_external_auth_login( + user_id=user.id, organization_id=organization_id, provider_type=provider_type_val, + provider_user_id=auth_method.provider_user_id, + auth_method_id=auth_method.id, session_id=session.id, + ) + + return api_response( + data={ + "token": session.token, "expires_in": session.lifetime_seconds, "token_type": "Bearer", + "user": {"id": user.id, "email": user.email, "full_name": user.full_name, "organization_id": organization_id}, + }, + message="Organization selected and session created successfully", + ) + except Exception as e: + logger.error(f"Error in select_organization: {str(e)}", exc_info=True) + return api_response(success=False, message="An error occurred while selecting organization", status=500, error_type="INTERNAL_ERROR") + + +@api_v1_bp.route("/auth/external/token", methods=["POST"]) +def exchange_authorization_code(): + if request.is_json: + data = request.json or {} + else: + data = request.form or {} + + grant_type = data.get("grant_type") + code = data.get("code") + redirect_uri = data.get("redirect_uri") + client_id = data.get("client_id", "external-app") + + if grant_type and grant_type != "authorization_code": + return api_response(success=False, message="Invalid grant_type. Must be 'authorization_code'", status=400, error_type="INVALID_GRANT_TYPE") + if not code: + return api_response(success=False, message="code is required", status=400, error_type="VALIDATION_ERROR") + if not redirect_uri: + return api_response(success=False, message="redirect_uri is required", status=400, error_type="VALIDATION_ERROR") + + try: + result = OAuthFlowService.exchange_authorization_code( + code=code, client_id=client_id, redirect_uri=redirect_uri, ip_address=request.remote_addr, + ) + return api_response( + data={"token": result["token"], "expires_in": result["expires_in"], "token_type": result["token_type"], "user": result["user"]}, + message="Token exchanged successfully", + ) + except OAuthFlowError as e: + return api_response(success=False, message=e.message, status=e.status_code, error_type=e.error_type) diff --git a/gatehouse_app/api/v1/external_auth/providers.py b/gatehouse_app/api/v1/external_auth/providers.py new file mode 100644 index 0000000..cd078c0 --- /dev/null +++ b/gatehouse_app/api/v1/external_auth/providers.py @@ -0,0 +1,201 @@ +"""External auth provider config endpoints (admin and user).""" +from flask import g, request +from gatehouse_app.api.v1 import api_v1_bp +from gatehouse_app.utils.response import api_response +from gatehouse_app.utils.decorators import login_required +from gatehouse_app.services.external_auth import ExternalAuthService +from gatehouse_app.services.external_auth.models import ExternalAuthError, ExternalProviderConfig +from gatehouse_app.services.audit_service import AuditService +from gatehouse_app.api.v1.external_auth._helpers import get_provider_type, _get_provider_endpoints + + +@api_v1_bp.route("/auth/external/providers", methods=["GET"]) +@login_required +def list_providers(): + from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig + + app_configs = {c.provider_type.lower(): c for c in ApplicationProviderConfig.query.filter_by(is_enabled=True).all()} + + user_orgs = g.current_user.get_organizations() + org_configs = {} + if user_orgs: + organization_id = user_orgs[0].id + org_level = ExternalProviderConfig.query.filter_by(organization_id=organization_id).all() + org_configs = {c.provider_type.lower(): c for c in org_level} + + def provider_info(provider_id, name): + app_cfg = app_configs.get(provider_id) + org_cfg = org_configs.get(provider_id) + is_configured = app_cfg is not None or org_cfg is not None + is_active = bool(app_cfg.is_enabled) if app_cfg else False + if org_cfg and hasattr(org_cfg, "is_active"): + is_active = bool(org_cfg.is_active) + return {"id": provider_id, "name": name, "type": provider_id, "is_configured": is_configured, "is_active": is_active, + "settings": {"requires_domain": False, "supports_refresh_tokens": True}} + + providers = [provider_info("google", "Google"), provider_info("github", "GitHub"), provider_info("microsoft", "Microsoft")] + return api_response(data={"providers": providers}, message="Providers retrieved successfully") + + +@api_v1_bp.route("/auth/external/providers//config", methods=["GET"]) +@login_required +def get_provider_config(provider: str): + from gatehouse_app.models import OrganizationMember + from gatehouse_app.utils.constants import OrganizationRole + + provider_type = get_provider_type(provider) + + user_orgs = g.current_user.get_organizations() + if not user_orgs: + return api_response(success=False, message="No organizations found for user", status=400, error_type="BAD_REQUEST") + + organization_id = user_orgs[0].id + member = OrganizationMember.query.filter_by(user_id=g.current_user.id, organization_id=organization_id).first() + if not member or member.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]: + return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN") + + config = ExternalProviderConfig.query.filter_by(organization_id=organization_id, provider_type=provider_type.value).first() + if not config: + return api_response(success=False, message=f"{provider.title()} OAuth is not configured", status=404, error_type="NOT_FOUND") + + return api_response(data=config.to_dict(include_secrets=False), message="Provider configuration retrieved successfully") + + +@api_v1_bp.route("/auth/external/providers//config", methods=["POST"]) +@login_required +def create_or_update_provider_config(provider: str): + from gatehouse_app.models import OrganizationMember + from gatehouse_app.utils.constants import OrganizationRole + + provider_type = get_provider_type(provider) + + user_orgs = g.current_user.get_organizations() + if not user_orgs: + return api_response(success=False, message="No organizations found for user", status=400, error_type="BAD_REQUEST") + + organization_id = user_orgs[0].id + member = OrganizationMember.query.filter_by(user_id=g.current_user.id, organization_id=organization_id).first() + if not member or member.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]: + return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN") + + data = request.json or {} + client_id = data.get("client_id") + client_secret = data.get("client_secret") + + if not client_id: + return api_response(success=False, message="client_id is required", status=400, error_type="VALIDATION_ERROR") + + config = ExternalProviderConfig.query.filter_by(organization_id=organization_id, provider_type=provider_type.value).first() + is_new = config is None + + if config: + config.client_id = client_id + if client_secret: + config.set_client_secret(client_secret) + config.scopes = data.get("scopes", ["openid", "profile", "email"]) + config.redirect_uris = data.get("redirect_uris", []) + config.settings = data.get("settings", {}) + config.is_active = data.get("is_active", True) + config.save() + AuditService.log_external_auth_config_update( + user_id=g.current_user.id, organization_id=organization_id, provider_type=provider_type.value, + config_id=config.id, + changes={"client_id": "updated", "client_secret": "updated" if client_secret else None, + "scopes": data.get("scopes"), "redirect_uris": data.get("redirect_uris"), "is_active": config.is_active}, + ) + else: + auth_url, token_url, userinfo_url = _get_provider_endpoints(provider_type) + config = ExternalProviderConfig( + organization_id=organization_id, provider_type=provider_type.value, + client_id=client_id, client_secret_encrypted=None, + auth_url=auth_url, token_url=token_url, userinfo_url=userinfo_url, + scopes=data.get("scopes", ["openid", "profile", "email"]), + redirect_uris=data.get("redirect_uris", []), settings=data.get("settings", {}), + is_active=data.get("is_active", True), + ) + if client_secret: + config.set_client_secret(client_secret) + config.save() + AuditService.log_external_auth_config_create( + user_id=g.current_user.id, organization_id=organization_id, + provider_type=provider_type.value, config_id=config.id, + ) + + return api_response(data=config.to_dict(include_secrets=False), message="Provider configuration saved successfully", status=201 if is_new else 200) + + +@api_v1_bp.route("/auth/external/providers//config", methods=["DELETE"]) +@login_required +def delete_provider_config(provider: str): + from gatehouse_app.models import OrganizationMember + from gatehouse_app.utils.constants import OrganizationRole + + provider_type = get_provider_type(provider) + + user_orgs = g.current_user.get_organizations() + if not user_orgs: + return api_response(success=False, message="No organizations found for user", status=400, error_type="BAD_REQUEST") + + organization_id = user_orgs[0].id + member = OrganizationMember.query.filter_by(user_id=g.current_user.id, organization_id=organization_id).first() + if not member or member.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]: + return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN") + + config = ExternalProviderConfig.query.filter_by(organization_id=organization_id, provider_type=provider_type.value).first() + if not config: + return api_response(success=False, message=f"{provider.title()} OAuth is not configured", status=404, error_type="NOT_FOUND") + + config_id = config.id + config.delete() + AuditService.log_external_auth_config_delete( + user_id=g.current_user.id, organization_id=organization_id, + provider_type=provider_type.value, config_id=config_id, + ) + return api_response(message=f"{provider.title()} provider configuration deleted successfully") + + +@api_v1_bp.route("/auth/external/linked-accounts", methods=["GET"]) +@login_required +def list_linked_accounts(): + from gatehouse_app.models import AuthenticationMethod + + linked_accounts = ExternalAuthService.get_linked_accounts(g.current_user.id) + other_methods = AuthenticationMethod.query.filter_by(user_id=g.current_user.id, deleted_at=None).count() + return api_response(data={"linked_accounts": linked_accounts, "unlink_available": other_methods > 1}, message="Linked accounts retrieved successfully") + + +@api_v1_bp.route("/auth/external//link", methods=["POST"]) +@login_required +def initiate_link_account(provider: str): + provider_type = get_provider_type(provider) + + user_orgs = g.current_user.get_organizations() + organization_id = user_orgs[0].id if user_orgs else None + data = request.json or {} + redirect_uri = data.get("redirect_uri") + + try: + auth_url, state = ExternalAuthService.initiate_link_flow( + user_id=g.current_user.id, provider_type=provider_type, + organization_id=organization_id, redirect_uri=redirect_uri, + ) + return api_response(data={"authorization_url": auth_url, "state": state}, message="Link flow initiated. Redirect to authorization URL.") + except ExternalAuthError as e: + return api_response(success=False, message=e.message, status=e.status_code, error_type=e.error_type) + + +@api_v1_bp.route("/auth/external//unlink", methods=["DELETE"]) +@login_required +def unlink_account(provider: str): + provider_type = get_provider_type(provider) + + user_orgs = g.current_user.get_organizations() + organization_id = user_orgs[0].id if user_orgs else None + + try: + ExternalAuthService.unlink_provider( + user_id=g.current_user.id, provider_type=provider_type, organization_id=organization_id, + ) + return api_response(message=f"{provider.title()} account unlinked successfully") + except ExternalAuthError as e: + return api_response(success=False, message=e.message, status=e.status_code, error_type=e.error_type) diff --git a/gatehouse_app/api/v1/organizations.py b/gatehouse_app/api/v1/organizations.py deleted file mode 100644 index 447ac0d..0000000 --- a/gatehouse_app/api/v1/organizations.py +++ /dev/null @@ -1,1888 +0,0 @@ -"""Organization endpoints.""" -from flask import g, request, current_app -from marshmallow import ValidationError -from gatehouse_app.api.v1 import api_v1_bp -from gatehouse_app.utils.response import api_response -from gatehouse_app.utils.decorators import login_required, require_admin, require_owner, full_access_required -from gatehouse_app.schemas.organization_schema import ( - OrganizationCreateSchema, - OrganizationUpdateSchema, - InviteMemberSchema, - UpdateMemberRoleSchema, -) -from gatehouse_app.services.organization_service import OrganizationService -from gatehouse_app.services.user_service import UserService -from gatehouse_app.utils.constants import OrganizationRole -from gatehouse_app.extensions import db - - - -def _get_system_ca_dict(): - """Return a synthetic read-only CA dict for the config-file CA, or None. - - This is injected into the org CA list when no DB CA exists for a given - ca_type so that the admin UI correctly shows "configured" rather than - "Not configured" when a system-level CA key is present. - - The returned dict has ``is_system=True`` so the frontend can render it - as read-only (no delete / edit / generate buttons). - """ - import os - try: - from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config - from gatehouse_app.utils.crypto import compute_ssh_fingerprint - - # Check env var first (takes priority over file path) - priv_key = os.environ.get("SSH_CA_PRIVATE_KEY", "").strip() - pub_key = "" - - if not priv_key: - cfg = get_ssh_ca_config() - key_path = cfg.get_str("ca_key_path", "").strip() - if not key_path: - return None - pub_path = key_path + ".pub" - if not os.path.exists(pub_path): - return None - with open(pub_path) as f: - pub_key = f.read().strip() - else: - # Derive the public key from the private key - from sshkey_tools.keys import PrivateKey - pk = PrivateKey.from_string(priv_key) - pub_key = pk.public_key.to_string() - - fingerprint = compute_ssh_fingerprint(pub_key) - return { - "id": f"system-ca-{fingerprint[:16]}", - "organization_id": None, - "name": "System CA (config file)", - "description": ( - "Read-only — this CA is loaded from the server's SSH_CA_PRIVATE_KEY " - "environment variable or etc/ssh_ca.conf. Manage it on the server." - ), - # ca_type is set by the caller - "ca_type": "user", - "key_type": "unknown", - "public_key": pub_key, - "fingerprint": fingerprint, - "is_active": True, - "is_system": True, - "default_cert_validity_hours": 0, - "max_cert_validity_hours": 0, - "total_certs": 0, - "active_certs": 0, - "revoked_certs": 0, - "created_at": None, - "updated_at": None, - } - except Exception: - return None - - - -@api_v1_bp.route("/organizations", methods=["POST"]) -@login_required -@full_access_required -def create_organization(): - """ - Create a new organization. - - Request body: - name: Organization name - slug: Organization slug (unique) - description: Optional description - logo_url: Optional logo URL - - Returns: - 201: Organization created successfully - 400: Validation error - 401: Not authenticated - 409: Slug already exists - """ - try: - # Validate request data - schema = OrganizationCreateSchema() - data = schema.load(request.json) - - # Create organization - org = OrganizationService.create_organization( - name=data["name"], - slug=data["slug"], - owner_user_id=g.current_user.id, - description=data.get("description"), - logo_url=data.get("logo_url"), - ) - - return api_response( - data={"organization": org.to_dict()}, - message="Organization created successfully", - status=201, - ) - - except ValidationError as e: - return api_response( - success=False, - message="Validation failed", - status=400, - error_type="VALIDATION_ERROR", - error_details=e.messages, - ) - - -@api_v1_bp.route("/organizations/", methods=["GET"]) -@login_required -@full_access_required -def get_organization(org_id): - """ - Get organization by ID. - - Args: - org_id: Organization ID - - Returns: - 200: Organization data - 401: Not authenticated - 403: Not a member - 404: Organization not found - """ - org = OrganizationService.get_organization_by_id(org_id) - - # Check if user is a member - if not org.is_member(g.current_user.id): - return api_response( - success=False, - message="You are not a member of this organization", - status=403, - error_type="AUTHORIZATION_ERROR", - ) - - return api_response( - data={ - "organization": org.to_dict(), - "member_count": org.get_member_count(), - }, - message="Organization retrieved successfully", - ) - - -@api_v1_bp.route("/organizations/", methods=["PATCH"]) -@login_required -@require_admin -@full_access_required -def update_organization(org_id): - """ - Update organization. - - Args: - org_id: Organization ID - - Request body: - name: Optional organization name - description: Optional description - logo_url: Optional logo URL - - Returns: - 200: Organization updated successfully - 400: Validation error - 401: Not authenticated - 403: Not an admin - 404: Organization not found - """ - try: - # Validate request data - schema = OrganizationUpdateSchema() - data = schema.load(request.json) - - org = OrganizationService.get_organization_by_id(org_id) - - # Update organization - org = OrganizationService.update_organization( - org=org, - user_id=g.current_user.id, - **data - ) - - return api_response( - data={"organization": org.to_dict()}, - message="Organization updated successfully", - ) - - except ValidationError as e: - return api_response( - success=False, - message="Validation failed", - status=400, - error_type="VALIDATION_ERROR", - error_details=e.messages, - ) - - -@api_v1_bp.route("/organizations/", methods=["DELETE"]) -@login_required -@full_access_required -def delete_organization(org_id): - """ - Delete organization (soft delete). - - Only the OWNER of the organization may call this endpoint. - - When the organization has other active members the caller must explicitly - confirm the deletion by sending ``{"confirm": true}`` in the request body. - All members (and their memberships) are soft-deleted together with the org - in a single atomic transaction so no orphaned data is left behind. - - Args: - org_id: Organization ID - - Request body (JSON, optional): - confirm (bool): Required when the org has other active members. - - Returns: - 200: Organization deleted successfully - 400: Organization has other members but confirm was not true - 401: Not authenticated - 403: Not the owner - 404: Organization not found - """ - from gatehouse_app.models.organization.organization_member import OrganizationMember as _OrgMember - from gatehouse_app.utils.constants import OrganizationRole as _OrgRole - - caller = g.current_user - - org = OrganizationService.get_organization_by_id(org_id) - - # Only the owner may delete the organization. - caller_membership = _OrgMember.query.filter_by( - user_id=caller.id, - organization_id=org.id, - deleted_at=None, - ).first() - - if not caller_membership or caller_membership.role != _OrgRole.OWNER: - return api_response( - success=False, - message="Only the organization owner can delete the organization.", - status=403, - error_type="AUTHORIZATION_ERROR", - ) - - # If other members exist, require explicit confirmation to avoid accidents. - active_member_count = org.get_member_count() - if active_member_count > 1: - data = request.get_json(silent=True) or {} - if not data.get("confirm"): - return api_response( - success=False, - message=( - f"This organization has {active_member_count} active members. " - "Deleting it will remove all members and their data. " - 'Send {"confirm": true} to confirm.' - ), - status=400, - error_type="CONFIRMATION_REQUIRED", - error_details={"member_count": active_member_count}, - ) - - OrganizationService.force_delete_organization( - org=org, - user_id=caller.id, - ) - - return api_response( - message="Organization deleted successfully", - ) - - -@api_v1_bp.route("/organizations//members", methods=["GET"]) -@login_required -@full_access_required -def get_organization_members(org_id): - """ - Get all members of an organization. - - Args: - org_id: Organization ID - - Returns: - 200: List of members - 401: Not authenticated - 403: Not a member - 404: Organization not found - """ - org = OrganizationService.get_organization_by_id(org_id) - - # Check if user is a member - if not org.is_member(g.current_user.id): - return api_response( - success=False, - message="You are not a member of this organization", - status=403, - error_type="AUTHORIZATION_ERROR", - ) - - members_data = [] - for member in org.members: - if member.deleted_at is None: - member_dict = member.to_dict() - member_dict["user"] = member.user.to_dict() - members_data.append(member_dict) - - return api_response( - data={ - "members": members_data, - "count": len(members_data), - }, - message="Members retrieved successfully", - ) - - -@api_v1_bp.route("/organizations//members", methods=["POST"]) -@login_required -@require_admin -@full_access_required -def add_organization_member(org_id): - """ - Add a member to the organization. - - Args: - org_id: Organization ID - - Request body: - email: User email to invite - role: Member role (owner, admin, member, guest) - - Returns: - 201: Member added successfully - 400: Validation error - 401: Not authenticated - 403: Not an admin - 404: Organization or user not found - 409: User already a member - """ - try: - # Validate request data - schema = InviteMemberSchema() - data = schema.load(request.json) - - org = OrganizationService.get_organization_by_id(org_id) - - # Find user by email - user = UserService.get_user_by_email(data["email"]) - if not user: - return api_response( - success=False, - message="User not found", - status=404, - error_type="NOT_FOUND", - ) - - # Add member - role = OrganizationRole(data["role"]) - member = OrganizationService.add_member( - org=org, - user_id=user.id, - role=role, - inviter_id=g.current_user.id, - ) - - member_dict = member.to_dict() - member_dict["user"] = user.to_dict() - - return api_response( - data={"member": member_dict}, - message="Member added successfully", - status=201, - ) - - except ValidationError as e: - return api_response( - success=False, - message="Validation failed", - status=400, - error_type="VALIDATION_ERROR", - error_details=e.messages, - ) - - -@api_v1_bp.route("/organizations//members/", methods=["DELETE"]) -@login_required -@require_admin -@full_access_required -def remove_organization_member(org_id, user_id): - """ - Remove a member from the organization. - - Args: - org_id: Organization ID - user_id: User ID to remove - - Returns: - 200: Member removed successfully - 401: Not authenticated - 403: Not an admin - 404: Organization or member not found - """ - org = OrganizationService.get_organization_by_id(org_id) - - OrganizationService.remove_member( - org=org, - user_id=user_id, - remover_id=g.current_user.id, - ) - - return api_response( - message="Member removed successfully", - ) - - -@api_v1_bp.route("/organizations//members//role", methods=["PATCH"]) -@login_required -@require_admin -@full_access_required -def update_member_role(org_id, user_id): - """ - Update a member's role. - - Args: - org_id: Organization ID - user_id: User ID - - Request body: - role: New role (owner, admin, member, guest) - - Returns: - 200: Role updated successfully - 400: Validation error - 401: Not authenticated - 403: Not an admin - 404: Organization or member not found - """ - try: - # Validate request data - schema = UpdateMemberRoleSchema() - data = schema.load(request.json) - - org = OrganizationService.get_organization_by_id(org_id) - - # Update role - new_role = OrganizationRole(data["role"]) - member = OrganizationService.update_member_role( - org=org, - user_id=user_id, - new_role=new_role, - updater_id=g.current_user.id, - ) - - member_dict = member.to_dict() - member_dict["user"] = member.user.to_dict() - - return api_response( - data={"member": member_dict}, - message="Member role updated successfully", - ) - - except ValidationError as e: - return api_response( - success=False, - message="Validation failed", - status=400, - error_type="VALIDATION_ERROR", - error_details=e.messages, - ) - - -@api_v1_bp.route("/organizations//transfer-ownership", methods=["POST"]) -@login_required -@full_access_required -def transfer_organization_ownership(org_id): - """Transfer organization ownership from the current user to another member. - - Only the current OWNER of the organization may call this endpoint. - The caller will be demoted to ADMIN and the target user will be promoted to OWNER. - - Request body: - new_owner_user_id (str): UUID of the member to promote to OWNER. - - Returns: - 200: Ownership transferred successfully - 400: Validation error / missing fields - 403: Caller is not the OWNER of this org - 404: Organization or target member not found - 409: Target is already the OWNER - """ - from gatehouse_app.models.organization.organization_member import OrganizationMember - from gatehouse_app.utils.constants import OrganizationRole, AuditAction - from gatehouse_app.services.audit_service import AuditService - - caller = g.current_user - - data = request.get_json() or {} - new_owner_user_id = data.get("new_owner_user_id") - if not new_owner_user_id: - return api_response( - success=False, - message="new_owner_user_id is required", - status=400, - error_type="VALIDATION_ERROR", - ) - - if str(new_owner_user_id) == str(caller.id): - return api_response( - success=False, - message="You are already the owner of this organization.", - status=409, - error_type="CONFLICT", - ) - - # Fetch org (raises NotFound internally) - org = OrganizationService.get_organization_by_id(org_id) - - # Confirm caller is the current OWNER - caller_membership = OrganizationMember.query.filter_by( - organization_id=org.id, - user_id=caller.id, - deleted_at=None, - ).first() - if not caller_membership or caller_membership.role != OrganizationRole.OWNER: - return api_response( - success=False, - message="Only the organization owner can transfer ownership.", - status=403, - error_type="AUTHORIZATION_ERROR", - ) - - # Verify the target is an active member - target_membership = OrganizationMember.query.filter_by( - organization_id=org.id, - user_id=new_owner_user_id, - deleted_at=None, - ).first() - if not target_membership: - return api_response( - success=False, - message="Target user is not a member of this organization.", - status=404, - error_type="NOT_FOUND", - ) - - if target_membership.role == OrganizationRole.OWNER: - return api_response( - success=False, - message="Target user is already the owner.", - status=409, - error_type="CONFLICT", - ) - - # ── Atomic role swap ───────────────────────────────────────────────────── - # Demote caller → ADMIN, promote target → OWNER. - # Both updates go through OrganizationService so all hooks/auditing fire. - try: - demoted = OrganizationService.update_member_role( - org=org, - user_id=str(caller.id), - new_role=OrganizationRole.ADMIN, - updater_id=str(caller.id), - ) - promoted = OrganizationService.update_member_role( - org=org, - user_id=str(new_owner_user_id), - new_role=OrganizationRole.OWNER, - updater_id=str(caller.id), - ) - except Exception as exc: - from gatehouse_app.extensions import db as _db - _db.session.rollback() - return api_response( - success=False, - message=f"Failed to transfer ownership: {exc}", - status=500, - error_type="SERVER_ERROR", - ) - - AuditService.log_action( - action=AuditAction.ORG_OWNERSHIP_TRANSFERRED, - user_id=caller.id, - organization_id=org.id, - resource_type="organization", - resource_id=str(org.id), - description=( - f"Ownership of '{org.name}' transferred from {caller.email} " - f"to {target_membership.user.email if target_membership.user else new_owner_user_id}" - ), - metadata={ - "previous_owner_id": str(caller.id), - "previous_owner_email": caller.email, - "new_owner_id": str(new_owner_user_id), - "new_owner_email": ( - target_membership.user.email if target_membership.user else None - ), - }, - ) - - def _member_dict(m): - d = m.to_dict() - if m.user: - d["user"] = m.user.to_dict() - return d - - return api_response( - data={ - "previous_owner": _member_dict(demoted), - "new_owner": _member_dict(promoted), - }, - message=( - f"Ownership of '{org.name}' successfully transferred to " - f"{target_membership.user.email if target_membership.user else new_owner_user_id}." - ), - ) - - -@api_v1_bp.route("/organizations//audit-logs", methods=["GET"]) -@login_required -@require_admin -@full_access_required -def get_organization_audit_logs(org_id): - """ - Get audit logs for an organization. - - Query params: - page: Page number (default 1) - per_page: Results per page (default 50, max 200) - action: Filter by action type - - Returns: - 200: List of audit log entries - 401: Not authenticated - 403: Not a member / insufficient permissions - 404: Organization not found - """ - from gatehouse_app.models.auth.audit_log import AuditLog - - # Ensure org exists and user is a member (full_access_required handles this) - OrganizationService.get_organization_by_id(org_id) - - page = int(request.args.get("page", 1)) - per_page = min(int(request.args.get("per_page", 50)), 200) - action_filter = request.args.get("action") - - query = AuditLog.query.filter_by(organization_id=org_id) - if action_filter: - query = query.filter_by(action=action_filter) - - query = query.order_by(AuditLog.created_at.desc()) - total = query.count() - logs = query.offset((page - 1) * per_page).limit(per_page).all() - - def log_to_dict(log): - return { - "id": log.id, - "action": log.action.value if log.action else None, - "user_id": log.user_id, - "user_email": log.user.email if log.user else None, - "user": {"id": log.user.id, "email": log.user.email, "full_name": log.user.full_name} if log.user else None, - "organization_id": log.organization_id, - "resource_type": log.resource_type, - "resource_id": log.resource_id, - "ip_address": log.ip_address, - "user_agent": log.user_agent, - "request_id": log.request_id, - "description": log.description, - "success": log.success, - "error_message": log.error_message, - "metadata": log.extra_data, - "created_at": log.created_at.isoformat() if log.created_at else None, - "updated_at": log.updated_at.isoformat() if log.updated_at else None, - } - - return api_response( - data={ - "audit_logs": [log_to_dict(log) for log in logs], - "count": total, - "page": page, - "per_page": per_page, - "pages": (total + per_page - 1) // per_page, - }, - message="Audit logs retrieved successfully", - ) - - -# ============================================================================ -# Organization Invite Tokens -# ============================================================================ - -@api_v1_bp.route("/organizations//invites", methods=["POST"]) -@login_required -@require_admin -def create_org_invite(org_id): - """Create an invite token for an organization. - - Request body: - email: Email address to invite - role: Role to assign (default: member) - - Returns: - 201: Invite created - 400: Validation error - 403: Not an admin - 404: Organization not found - """ - from gatehouse_app.models import OrgInviteToken, Organization - from gatehouse_app.services.notification_service import NotificationService - from flask import current_app - - org = Organization.query.filter_by(id=org_id, deleted_at=None).first() - if not org: - return api_response(success=False, message="Organization not found", status=404) - - data = request.get_json() or {} - email = (data.get("email") or "").strip().lower() - role = (data.get("role") or "member").strip() - - if not email: - return api_response(success=False, message="Email is required", status=400, error_type="VALIDATION_ERROR") - - invite = OrgInviteToken.generate( - organization_id=org_id, - email=email, - role=role, - invited_by_id=g.current_user.id, - ) - - app_url = current_app.config.get("APP_URL", "http://localhost:8080") - invite_link = f"{app_url}/invite?token={invite.token}" - - email_sent = NotificationService._send_email( - to_address=email, - subject=f"You're invited to join {org.name} on Gatehouse", - body=( - f"You've been invited to join {org.name} on Gatehouse.\n\n" - f"Click the link below to accept the invitation (valid for 7 days):\n" - f"{invite_link}\n\n" - f"Gatehouse Security Team" - ), - ) - - # In dev mode email may not be configured — always log the link so it's findable - import logging - if not email_sent: - logging.getLogger(__name__).warning( - f"[INVITE LINK] Email not sent (EMAIL_ENABLED=False or SMTP down). " - f"Invite for {email} → {invite_link}" - ) - else: - logging.getLogger(__name__).info( - f"[INVITE] Email sent successfully to {email}" - ) - - response_data = { - "invite": { - "id": invite.id, - "email": invite.email, - "role": invite.role, - "expires_at": invite.expires_at.isoformat() + "Z", - # Only include invite_link when email delivery failed — signals frontend to show copy dialog - **({"invite_link": invite_link} if not email_sent else {}), - } - } - - return api_response( - data=response_data, - message="Invite sent successfully", - status=201, - ) - - -@api_v1_bp.route("/organizations//invites", methods=["GET"]) -@login_required -@require_admin -def list_org_invites(org_id): - """List pending invite tokens for an organization. - - Returns: - 200: List of invites - 403: Not an admin - 404: Organization not found - """ - from gatehouse_app.models import OrgInviteToken, Organization - - org = Organization.query.filter_by(id=org_id, deleted_at=None).first() - if not org: - return api_response(success=False, message="Organization not found", status=404) - - invites = ( - OrgInviteToken.query.filter_by(organization_id=org_id) - .filter(OrgInviteToken.accepted_at == None) - .filter(OrgInviteToken.deleted_at == None) - .all() - ) - - def invite_to_dict(inv): - return { - "id": inv.id, - "email": inv.email, - "role": inv.role, - "invited_by_id": inv.invited_by_id, - "created_at": inv.created_at.isoformat() + "Z", - "expires_at": inv.expires_at.isoformat() + "Z", - } - - return api_response( - data={"invites": [invite_to_dict(i) for i in invites]}, - message="Invites retrieved", - ) - - -@api_v1_bp.route("/organizations//invites/", methods=["DELETE"]) -@login_required -@require_admin -def cancel_org_invite(org_id, invite_id): - """Cancel (soft-delete) an organization invite. - - Returns: - 200: Invite cancelled - 403: Not an admin - 404: Invite not found - """ - from gatehouse_app.models import OrgInviteToken, Organization - - org = Organization.query.filter_by(id=org_id, deleted_at=None).first() - if not org: - return api_response(success=False, message="Organization not found", status=404) - - invite = OrgInviteToken.query.filter_by(id=invite_id, organization_id=org_id, deleted_at=None).first() - if not invite: - return api_response(success=False, message="Invite not found", status=404) - - # Soft delete the invite so it's no longer usable - invite.delete(soft=True) - - return api_response(data={}, message="Invite cancelled") - - -@api_v1_bp.route("/invites/", methods=["GET"]) -def get_invite(token): - """Get invite details by token. - - Returns: - 200: Invite details (org name, email) - 400: Invalid or expired token - """ - from gatehouse_app.models import OrgInviteToken, User - - invite = OrgInviteToken.query.filter_by(token=token).first() - if not invite or not invite.is_valid: - return api_response(success=False, message="This invitation link is invalid or has expired.", status=400, error_type="INVALID_TOKEN") - - user_exists = User.query.filter_by(email=invite.email, deleted_at=None).first() is not None - - return api_response( - data={ - "email": invite.email, - "organization": {"id": invite.organization_id, "name": invite.organization.name}, - "role": invite.role, - "user_exists": user_exists, - }, - message="Invite found", - ) - - -@api_v1_bp.route("/invites//accept", methods=["POST"]) -def accept_invite(token): - """Accept an organization invite. - - Creates the user account (if not already registered) and adds them - to the organization. - - Request body: - full_name: User's display name - password: Password for new account (if not already registered) - password_confirm: Password confirmation - - Returns: - 200: Invite accepted, returns user token - 400: Invalid/expired token or validation error - 409: Already a member - """ - from gatehouse_app.models import OrgInviteToken, User - from gatehouse_app.services.auth_service import AuthService - from gatehouse_app.services.organization_service import OrganizationService - from gatehouse_app.utils.constants import OrganizationRole - - invite = OrgInviteToken.query.filter_by(token=token).first() - if not invite or not invite.is_valid: - return api_response(success=False, message="This invitation link is invalid or has expired.", status=400, error_type="INVALID_TOKEN") - - data = request.get_json() or {} - full_name = data.get("full_name") or "" - password = data.get("password") or "" - password_confirm = data.get("password_confirm") or "" - - user = User.query.filter_by(email=invite.email, deleted_at=None).first() - - if not user: - # Register a new user - if not password: - return api_response(success=False, message="Password is required for new accounts.", status=400, error_type="VALIDATION_ERROR") - if password != password_confirm: - return api_response(success=False, message="Passwords do not match.", status=400, error_type="VALIDATION_ERROR") - if len(password) < 8: - return api_response(success=False, message="Password must be at least 8 characters.", status=400, error_type="VALIDATION_ERROR") - try: - user = AuthService.register_user(email=invite.email, password=password, full_name=full_name or None) - except Exception as exc: - return api_response(success=False, message=str(exc), status=400, error_type="REGISTRATION_ERROR") - - # Add to org - role_value = invite.role - try: - org_role = OrganizationRole(role_value) - except ValueError: - org_role = OrganizationRole.MEMBER - - try: - OrganizationService.add_member( - org=invite.organization, - user_id=user.id, - role=org_role, - inviter_id=invite.invited_by_id, - ) - except Exception: - from gatehouse_app.extensions import db - db.session.rollback() # Clear broken transaction so invite.accept() can commit - - invite.accept() - - has_webauthn = user.has_webauthn_enabled() - has_totp = user.has_totp_enabled() - - if has_webauthn: - from flask import session as flask_session - flask_session["webauthn_pending_user_id"] = user.id - return api_response( - data={"requires_webauthn": True}, - message="Passkey verification required. Please use your passkey to complete sign-in.", - ) - - if has_totp: - from flask import session as flask_session - flask_session["totp_pending_user_id"] = user.id - return api_response( - data={"requires_totp": True}, - message="TOTP code required. Please enter your 6-digit code from your authenticator app.", - ) - - user_session = AuthService.create_session(user) - - return api_response( - data={ - "user": user.to_dict(), - "token": user_session.token, - "expires_at": user_session.expires_at.isoformat() + "Z", - }, - message="Invitation accepted. Welcome!", - ) - - -# ============================================================================ -# Organization OIDC Clients -# ============================================================================ - -@api_v1_bp.route("/organizations//clients", methods=["GET"]) -@login_required -@require_admin -@full_access_required -def list_org_clients(org_id): - """List OIDC clients for an organization. - - Returns: - 200: List of OIDC clients - 403: Not an admin - 404: Organization not found - """ - from gatehouse_app.models import OIDCClient, Organization - - org = Organization.query.filter_by(id=org_id, deleted_at=None).first() - if not org: - return api_response(success=False, message="Organization not found", status=404) - - clients = OIDCClient.query.filter_by(organization_id=org_id, is_active=True).all() - - def client_to_dict(c): - return { - "id": c.id, - "name": c.name, - "client_id": c.client_id, - "redirect_uris": c.redirect_uris, - "scopes": c.scopes, - "grant_types": c.grant_types, - "is_active": c.is_active, - "created_at": c.created_at.isoformat() + "Z", - } - - return api_response( - data={"clients": [client_to_dict(c) for c in clients], "count": len(clients)}, - message="Clients retrieved successfully", - ) - - -@api_v1_bp.route("/organizations//clients", methods=["POST"]) -@login_required -@require_admin -def create_org_client(org_id): - """Create a new OIDC client for an organization. - - Request body: - name: Client name - redirect_uris: List of allowed redirect URIs (newline or comma separated string) - - Returns: - 201: Client created with client_id and client_secret - 403: Not an admin - 404: Organization not found - """ - import secrets as _secrets - from gatehouse_app.extensions import bcrypt - from gatehouse_app.models import OIDCClient, Organization - - org = Organization.query.filter_by(id=org_id, deleted_at=None).first() - if not org: - return api_response(success=False, message="Organization not found", status=404) - - data = request.get_json() or {} - name = (data.get("name") or "").strip() - redirect_uris_raw = data.get("redirect_uris") or [] - - if not name: - return api_response(success=False, message="Client name is required", status=400, error_type="VALIDATION_ERROR") - - if isinstance(redirect_uris_raw, str): - redirect_uris = [u.strip() for u in redirect_uris_raw.replace(",", "\n").splitlines() if u.strip()] - else: - redirect_uris = [u.strip() for u in redirect_uris_raw if isinstance(u, str) and u.strip()] - - if not redirect_uris: - return api_response(success=False, message="At least one redirect URI is required", status=400, error_type="VALIDATION_ERROR") - - client_id = _secrets.token_hex(16) - client_secret = _secrets.token_urlsafe(32) - - client = OIDCClient( - organization_id=org_id, - name=name, - client_id=client_id, - client_secret_hash=bcrypt.generate_password_hash(client_secret).decode("utf-8"), - redirect_uris=redirect_uris, - grant_types=["authorization_code", "refresh_token"], - response_types=["code"], - scopes=["openid", "profile", "email"], - is_active=True, - is_confidential=True, - ) - from gatehouse_app.extensions import db - db.session.add(client) - db.session.commit() - - return api_response( - data={ - "client": { - "id": client.id, - "name": client.name, - "client_id": client.client_id, - "client_secret": client_secret, # Only returned once - "redirect_uris": client.redirect_uris, - "scopes": client.scopes, - "created_at": client.created_at.isoformat() + "Z", - } - }, - message="OIDC client created successfully", - status=201, - ) - - -@api_v1_bp.route("/organizations//clients/", methods=["DELETE"]) -@login_required -@require_admin -def delete_org_client(org_id, client_id): - """Deactivate an OIDC client. - - Returns: - 200: Client deactivated - 403: Not an admin - 404: Client not found - """ - from gatehouse_app.models import OIDCClient - from gatehouse_app.extensions import db - - client = OIDCClient.query.filter_by(id=client_id, organization_id=org_id).first() - if not client: - return api_response(success=False, message="Client not found", status=404) - - client.is_active = False - db.session.commit() - - return api_response(data={}, message="Client deactivated successfully") - - -@api_v1_bp.route("/organizations//members//send-mfa-reminder", methods=["POST"]) -@login_required -@require_admin -def send_mfa_reminder(org_id, user_id): - """Send an MFA reminder email to a specific member. - - Returns: - 200: Reminder sent (or silently skipped if no deadline record) - 403: Not an admin - 404: Member not found - """ - from gatehouse_app.models import User, MfaPolicyCompliance, OrganizationSecurityPolicy - from gatehouse_app.services.notification_service import NotificationService - - user = User.query.filter_by(id=user_id, deleted_at=None).first() - if not user: - return api_response(success=False, message="User not found", status=404) - - compliance = MfaPolicyCompliance.query.filter_by( - user_id=user_id, organization_id=org_id - ).first() - policy = OrganizationSecurityPolicy.query.filter_by(organization_id=org_id).first() - - if compliance and policy and compliance.deadline_at: - NotificationService.send_mfa_deadline_reminder(user, compliance, policy) - else: - # No compliance deadline — send a generic nudge - NotificationService._send_email( - to_address=user.email, - subject="Reminder: Set up multi-factor authentication", - body=( - f"Hi {user.full_name or user.email},\n\n" - "Your organization administrator has asked you to set up " - "multi-factor authentication (MFA) on your Gatehouse account.\n\n" - "Please log in and configure MFA as soon as possible.\n\n" - "Gatehouse Security Team" - ), - ) - - return api_response(data={}, message="Reminder sent successfully") - - -# ============================================================================= -# System-wide Audit Log (admin view) + User self audit -# ============================================================================= - -def _audit_log_to_dict(log): - """Serialize an AuditLog record to a dict.""" - return { - "id": log.id, - "action": log.action.value if log.action else None, - "user_id": log.user_id, - "user": ( - {"id": log.user.id, "email": log.user.email, "full_name": log.user.full_name} - if log.user else None - ), - "organization_id": log.organization_id, - "resource_type": log.resource_type, - "resource_id": log.resource_id, - "ip_address": log.ip_address, - "user_agent": log.user_agent, - "request_id": log.request_id, - "description": log.description, - "success": log.success, - "error_message": log.error_message, - "metadata": log.extra_data, - "created_at": log.created_at.isoformat() if log.created_at else None, - "updated_at": log.updated_at.isoformat() if log.updated_at else None, - } - - -@api_v1_bp.route("/audit-logs", methods=["GET"]) -@login_required -def get_system_audit_logs(): - """ - Get all audit logs (system-wide). Any authenticated user can query - their own logs; org owners/admins also see org-scoped logs; this - endpoint returns ALL logs for users who own at least one org - (acting as an admin view). - - Query params: - page – page number (default 1) - per_page – results per page (default 50, max 200) - action – filter by AuditAction value - user_id – filter by user id - resource_type – filter by resource type - success – "true"/"false" - q – free-text search on description - """ - from gatehouse_app.models.auth.audit_log import AuditLog - from gatehouse_app.models.organization.organization_member import OrganizationMember - - current_user = g.current_user - page = max(1, int(request.args.get("page", 1))) - per_page = min(int(request.args.get("per_page", 50)), 200) - - # Check if the user is an admin or owner of any org to grant admin-level access - is_admin = OrganizationMember.query.filter( - OrganizationMember.user_id == current_user.id, - OrganizationMember.role.in_(["OWNER", "ADMIN"]), - OrganizationMember.deleted_at == None, - ).first() is not None - - query = AuditLog.query - - if not is_admin: - # Non-admins can only see their own logs - query = query.filter(AuditLog.user_id == current_user.id) - - # Optional filters - action_filter = request.args.get("action") - if action_filter: - query = query.filter(AuditLog.action == action_filter) - - user_id_filter = request.args.get("user_id") - if user_id_filter: - query = query.filter(AuditLog.user_id == user_id_filter) - - resource_type_filter = request.args.get("resource_type") - if resource_type_filter: - query = query.filter(AuditLog.resource_type == resource_type_filter) - - success_filter = request.args.get("success") - if success_filter is not None: - query = query.filter(AuditLog.success == (success_filter.lower() == "true")) - - q = request.args.get("q", "").strip() - if q: - query = query.filter(AuditLog.description.ilike(f"%{q}%")) - - query = query.order_by(AuditLog.created_at.desc()) - total = query.count() - logs = query.offset((page - 1) * per_page).limit(per_page).all() - - return api_response( - data={ - "audit_logs": [_audit_log_to_dict(log) for log in logs], - "count": total, - "page": page, - "per_page": per_page, - "pages": (total + per_page - 1) // per_page, - "is_admin_view": is_admin, - }, - message="Audit logs retrieved", - ) - - -@api_v1_bp.route("/auth/audit-logs", methods=["GET"]) -@login_required -def get_my_audit_logs(): - """ - Get audit logs for the currently authenticated user only. - - Query params: - page – page number (default 1) - per_page – results per page (default 50, max 200) - action – filter by AuditAction value - """ - from gatehouse_app.models.auth.audit_log import AuditLog - - current_user = g.current_user - page = max(1, int(request.args.get("page", 1))) - per_page = min(int(request.args.get("per_page", 50)), 200) - - query = AuditLog.query.filter(AuditLog.user_id == current_user.id) - - action_filter = request.args.get("action") - if action_filter: - query = query.filter(AuditLog.action == action_filter) - - query = query.order_by(AuditLog.created_at.desc()) - total = query.count() - logs = query.offset((page - 1) * per_page).limit(per_page).all() - - return api_response( - data={ - "audit_logs": [_audit_log_to_dict(log) for log in logs], - "count": total, - "page": page, - "per_page": per_page, - "pages": (total + per_page - 1) // per_page, - }, - message="Activity retrieved", - ) - - - -@api_v1_bp.route("/organizations//roles", methods=["GET"]) -@login_required -def list_organization_roles(org_id): - """List the available roles for an organization. - - Returns the canonical set of OrganizationRole values together with every - current member assigned to each role. - - Returns: - 200: roles list with member counts - 401: Not authenticated - 404: Organization not found - """ - from gatehouse_app.models.organization.organization import Organization - from gatehouse_app.models.organization.organization_member import OrganizationMember - - org = Organization.query.filter_by(id=org_id, deleted_at=None).first() - if not org: - return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND") - - # Load all active members grouped by role - members = OrganizationMember.query.filter_by(organization_id=org_id, deleted_at=None).all() - by_role: dict = {r.value: [] for r in OrganizationRole} - for m in members: - role_key = m.role.value if hasattr(m.role, "value") else str(m.role) - if role_key in by_role: - by_role[role_key].append({ - "user_id": m.user_id, - "email": m.user.email if m.user else None, - "full_name": m.user.full_name if m.user else None, - "joined_at": m.created_at.isoformat() if m.created_at else None, - }) - - roles = [ - { - "role": r.value, - "member_count": len(by_role[r.value]), - "members": by_role[r.value], - } - for r in OrganizationRole - ] - return api_response(data={"roles": roles, "organization_id": org_id}, message="Roles retrieved") - - -@api_v1_bp.route("/organizations//roles//members", methods=["POST"]) -@login_required -@require_admin -def assign_role_to_member(org_id, role_name): - """Assign a role to a user in the organization (admin/owner only). - - This is a convenience endpoint equivalent to PATCH - /organizations//members//role but driven by role name. - - Request body: - user_id – UUID of the member to assign - - Returns: - 200: Role assigned - 400: Invalid role / missing user_id - 403: Not an admin/owner - 404: Org or member not found - """ - from gatehouse_app.models.organization.organization_member import OrganizationMember - from gatehouse_app.extensions import db - - try: - new_role = OrganizationRole(role_name.lower()) - except ValueError: - valid = [r.value for r in OrganizationRole] - return api_response(success=False, message=f"Invalid role. Must be one of: {valid}", status=400, error_type="VALIDATION_ERROR") - - data = request.get_json() or {} - target_user_id = data.get("user_id") - if not target_user_id: - return api_response(success=False, message="user_id is required", status=400, error_type="VALIDATION_ERROR") - - membership = OrganizationMember.query.filter_by( - organization_id=org_id, user_id=target_user_id, deleted_at=None - ).first() - if not membership: - return api_response(success=False, message="Member not found in this organization", status=404, error_type="NOT_FOUND") - - membership.role = new_role - db.session.commit() - return api_response( - data={"user_id": target_user_id, "role": new_role.value}, - message=f"Role updated to {new_role.value}", - ) - - -@api_v1_bp.route("/organizations//roles//members/", methods=["DELETE"]) -@login_required -@require_admin -def remove_role_from_member(org_id, role_name, user_id): - """Demote a member to GUEST (effectively removing a named role). - - Removing a role downgrades the member to GUEST rather than removing them - from the organization entirely. Use the existing DELETE - /organizations//members/ endpoint to fully remove. - - Returns: - 200: Role removed (member demoted to GUEST) - 400: Invalid role name - 403: Not an admin/owner - 404: Org or member not found - """ - from gatehouse_app.models.organization.organization_member import OrganizationMember - from gatehouse_app.extensions import db - - try: - OrganizationRole(role_name.lower()) # validate the name - except ValueError: - valid = [r.value for r in OrganizationRole] - return api_response(success=False, message=f"Invalid role. Must be one of: {valid}", status=400, error_type="VALIDATION_ERROR") - - membership = OrganizationMember.query.filter_by( - organization_id=org_id, user_id=user_id, deleted_at=None - ).first() - if not membership: - return api_response(success=False, message="Member not found in this organization", status=404, error_type="NOT_FOUND") - - membership.role = OrganizationRole.GUEST - db.session.commit() - return api_response( - data={"user_id": user_id, "role": OrganizationRole.GUEST.value}, - message="Role removed; member demoted to GUEST", - ) - - -@api_v1_bp.route("/organizations//cas", methods=["GET"]) -@login_required -@require_admin -def list_org_cas(org_id): - """List all Certificate Authorities for an organization. - - If the system config-file CA is configured (via SSH_CA_PRIVATE_KEY env var - or ca_key_path in etc/ssh_ca.conf) and no DB CA exists for a given ca_type, - a synthetic read-only entry is injected so the UI correctly shows the - system CA as configured rather than "Not configured". - - Returns: - 200: List of CAs (private_key excluded) - 403: Not admin/owner - 404: Org not found - """ - from gatehouse_app.models.ssh_ca.ca import CA, CaType - from gatehouse_app.models.organization.organization import Organization - - org = Organization.query.filter_by(id=org_id, deleted_at=None).first() - if not org: - return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND") - - cas = CA.query.filter_by(organization_id=org_id, deleted_at=None).all() - ca_list = [ca.to_dict() for ca in cas] - - # Determine which ca_types are already covered by a DB CA - covered_types = {ca.ca_type for ca in cas} - - # Check whether a system config-file CA is available - system_ca_dict = _get_system_ca_dict() - if system_ca_dict: - # Inject a synthetic entry for each ca_type NOT covered by a real DB CA. - # The system CA only signs user certs (cert_type="user"), so we only - # inject it for the user slot. Host signing always needs a DB CA. - if CaType.USER not in covered_types: - ca_list.append({**system_ca_dict, "ca_type": "user"}) - - return api_response( - data={"cas": ca_list, "count": len(ca_list)}, - message="CAs retrieved", - ) - - -@api_v1_bp.route("/organizations//cas/", methods=["PATCH"]) -@login_required -@require_admin -def update_org_ca(org_id, ca_id): - """Update CA configuration (validity hours). - - Request body: - default_cert_validity_hours: Default validity in hours (optional) - max_cert_validity_hours: Maximum validity in hours (optional) - - Returns: - 200: CA updated successfully - 400: Validation error - 403: Not admin/owner - 404: Org or CA not found - """ - from gatehouse_app.models.ssh_ca.ca import CA - from gatehouse_app.models.organization.organization import Organization - from marshmallow import Schema, fields, validate, ValidationError - - org = Organization.query.filter_by(id=org_id, deleted_at=None).first() - if not org: - return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND") - - ca = CA.query.filter_by(id=ca_id, organization_id=org_id, deleted_at=None).first() - if not ca: - return api_response(success=False, message="CA not found", status=404, error_type="NOT_FOUND") - - try: - class CAUpdateSchema(Schema): - default_cert_validity_hours = fields.Int( - validate=validate.Range(min=1), - required=False - ) - max_cert_validity_hours = fields.Int( - validate=validate.Range(min=1), - required=False - ) - - schema = CAUpdateSchema() - data = schema.load(request.json or {}) - - # Validate that max >= default if both are provided - default_hours = data.get('default_cert_validity_hours', ca.default_cert_validity_hours) - max_hours = data.get('max_cert_validity_hours', ca.max_cert_validity_hours) - - if default_hours > max_hours: - return api_response( - success=False, - message="Default validity must be less than or equal to maximum validity", - status=400, - error_type="VALIDATION_ERROR", - ) - - # Update fields - if 'default_cert_validity_hours' in data: - ca.default_cert_validity_hours = data['default_cert_validity_hours'] - if 'max_cert_validity_hours' in data: - ca.max_cert_validity_hours = data['max_cert_validity_hours'] - - db.session.commit() - - return api_response( - data={"ca": ca.to_dict()}, - message="CA updated successfully", - ) - - except ValidationError as e: - return api_response( - success=False, - message="Validation failed", - status=400, - error_type="VALIDATION_ERROR", - error_details=e.messages, - ) - except Exception as e: - db.session.rollback() - return api_response( - success=False, - message="Failed to update CA", - status=500, - error_type="SERVER_ERROR", - ) - - -@api_v1_bp.route("/organizations//cas", methods=["POST"]) -@login_required -@require_admin -def create_org_ca(org_id): - """Create a new Certificate Authority for an organization. - - Request body: - name: CA display name (required) - description: Optional description - key_type: "ed25519" (default), "rsa", or "ecdsa" - default_cert_validity_hours: Default cert validity in hours (optional) - max_cert_validity_hours: Max cert validity in hours (optional) - - Returns: - 201: CA created successfully - 400: Validation error or name already taken - 403: Not admin/owner - 404: Org not found - """ - from gatehouse_app.models.ssh_ca.ca import CA, KeyType - from gatehouse_app.models.organization.organization import Organization - from gatehouse_app.utils.crypto import compute_ssh_fingerprint - from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key - from marshmallow import Schema, fields as ma_fields, validate, ValidationError as MaValidationError - from sshkey_tools.keys import Ed25519PrivateKey, RsaPrivateKey, EcdsaPrivateKey - - org = Organization.query.filter_by(id=org_id, deleted_at=None).first() - if not org: - return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND") - - class CreateCASchema(Schema): - name = ma_fields.Str(required=True, validate=validate.Length(min=1, max=255)) - description = ma_fields.Str(load_default=None, allow_none=True) - ca_type = ma_fields.Str(load_default="user", validate=validate.OneOf(["user", "host"])) - key_type = ma_fields.Str(load_default="ed25519", validate=validate.OneOf(["ed25519", "rsa", "ecdsa"])) - default_cert_validity_hours = ma_fields.Int(load_default=8, validate=validate.Range(min=1)) - max_cert_validity_hours = ma_fields.Int(load_default=720, validate=validate.Range(min=1)) - - try: - schema = CreateCASchema() - data = schema.load(request.get_json() or {}) - - # Check name uniqueness within org - existing = CA.query.filter_by( - organization_id=org_id, name=data["name"], deleted_at=None - ).first() - if existing: - return api_response( - success=False, - message="A CA with that name already exists in this organization", - status=400, - error_type="DUPLICATE_NAME", - ) - - # Enforce one CA per type per org - from gatehouse_app.models.ssh_ca.ca import CaType - ca_type_val = data["ca_type"] - existing_type = CA.query.filter_by( - organization_id=org_id, deleted_at=None - ).filter(CA.ca_type == CaType(ca_type_val)).first() - if existing_type: - type_label = "User" if ca_type_val == "user" else "Host" - return api_response( - success=False, - message=f"A {type_label} CA already exists for this organization. " - f"You can only have one {type_label} CA per organization.", - status=400, - error_type="DUPLICATE_CA_TYPE", - ) - - # Validate cross-field - if data["default_cert_validity_hours"] > data["max_cert_validity_hours"]: - return api_response( - success=False, - message="Default validity must be less than or equal to maximum validity", - status=400, - error_type="VALIDATION_ERROR", - ) - - # Generate key pair - key_type = data["key_type"] - if key_type == "ed25519": - private_key_obj = Ed25519PrivateKey.generate() - elif key_type == "rsa": - private_key_obj = RsaPrivateKey.generate(4096) - else: # ecdsa - private_key_obj = EcdsaPrivateKey.generate() - - private_key_pem = private_key_obj.to_string() - public_key_str = private_key_obj.public_key.to_string() - fingerprint = compute_ssh_fingerprint(public_key_str) - - # Encrypt the private key before storing in the database - encrypted_private_key = encrypt_ca_key(private_key_pem) - - ca = CA( - organization_id=org_id, - name=data["name"], - description=data["description"], - ca_type=CaType(ca_type_val), - key_type=KeyType(key_type), - private_key=encrypted_private_key, - public_key=public_key_str, - fingerprint=fingerprint, - default_cert_validity_hours=data["default_cert_validity_hours"], - max_cert_validity_hours=data["max_cert_validity_hours"], - is_active=True, - ) - db.session.add(ca) - try: - db.session.commit() - except Exception as commit_exc: - db.session.rollback() - # Surface unique-constraint violations (soft-deleted record with same name) as a - # user-friendly 400 instead of a 500. - exc_str = str(commit_exc).lower() - if "uix_org_ca_name" in exc_str or "unique" in exc_str: - return api_response( - success=False, - message=( - "A CA with that name already exists in this organization " - "(it may have been recently deleted — choose a different name)." - ), - status=400, - error_type="DUPLICATE_NAME", - ) - raise - - return api_response( - data={"ca": ca.to_dict()}, - message="CA created successfully", - status=201, - ) - - except MaValidationError as e: - return api_response( - success=False, - message="Validation failed", - status=400, - error_type="VALIDATION_ERROR", - error_details=e.messages, - ) - except Exception as e: - db.session.rollback() - current_app.logger.exception("Failed to create CA") - return api_response( - success=False, - message="Failed to create CA", - status=500, - error_type="SERVER_ERROR", - ) - - -@api_v1_bp.route("/organizations//cas/", methods=["DELETE"]) -@login_required -@require_admin -def delete_org_ca(org_id, ca_id): - """Soft-delete a Certificate Authority. - - Deactivates the CA so no new certificates can be signed with it. - Existing certificates remain valid until they expire. - - Returns: - 200: CA deleted successfully - 403: Not admin/owner - 404: Org or CA not found - """ - from gatehouse_app.models.ssh_ca.ca import CA - from gatehouse_app.models.organization.organization import Organization - from gatehouse_app.utils.constants import AuditAction - from gatehouse_app.models import AuditLog - - org = Organization.query.filter_by(id=org_id, deleted_at=None).first() - if not org: - return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND") - - ca = CA.query.filter_by(id=ca_id, organization_id=org_id, deleted_at=None).first() - if not ca: - return api_response(success=False, message="CA not found", status=404, error_type="NOT_FOUND") - - try: - ca_name = ca.name - ca_type = ca.ca_type.value if hasattr(ca.ca_type, "value") else str(ca.ca_type) - ca.is_active = False - ca.delete(soft=True) - - AuditLog.log( - action=AuditAction.CA_DELETED, - user_id=g.current_user.id, - resource_type="CA", - resource_id=ca_id, - organization_id=org_id, - ip_address=request.remote_addr, - description=f"CA '{ca_name}' ({ca_type}) deleted", - ) - - return api_response( - data={"ca_id": ca_id}, - message="CA deleted successfully", - ) - except Exception as e: - db.session.rollback() - current_app.logger.exception("Failed to delete CA") - return api_response( - success=False, - message="Failed to delete CA", - status=500, - error_type="SERVER_ERROR", - ) - - -@api_v1_bp.route("/organizations//cas//rotate", methods=["POST"]) -@login_required -@require_admin -def rotate_org_ca(org_id, ca_id): - """Rotate (replace) a CA's key pair. - - Generates a new key pair of the same or different type. The old public key - fingerprint is returned so admins can update TrustedUserCAKeys / known_hosts - on their servers. All previously-issued certificates remain valid until they - expire but no new certificates will be signed with the old key. - - Request body (all optional): - key_type: "ed25519" (default keeps current), "rsa", or "ecdsa" - reason: Human-readable reason for the rotation - - Returns: - 200: CA rotated — { ca, old_fingerprint } - 403: Not admin/owner - 404: Org or CA not found - """ - from gatehouse_app.models.ssh_ca.ca import CA, KeyType - from gatehouse_app.models.organization.organization import Organization - from gatehouse_app.utils.crypto import compute_ssh_fingerprint - from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key - from gatehouse_app.utils.constants import AuditAction - from gatehouse_app.models import AuditLog - from sshkey_tools.keys import Ed25519PrivateKey, RsaPrivateKey, EcdsaPrivateKey - - org = Organization.query.filter_by(id=org_id, deleted_at=None).first() - if not org: - return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND") - - ca = CA.query.filter_by(id=ca_id, organization_id=org_id, deleted_at=None).first() - if not ca: - return api_response(success=False, message="CA not found", status=404, error_type="NOT_FOUND") - - data = request.get_json() or {} - new_key_type = data.get("key_type") or (ca.key_type.value if hasattr(ca.key_type, "value") else str(ca.key_type)) - reason = data.get("reason", "Admin-initiated key rotation") - - if new_key_type not in ("ed25519", "rsa", "ecdsa"): - return api_response( - success=False, - message="Invalid key_type. Must be one of: ed25519, rsa, ecdsa", - status=400, - error_type="VALIDATION_ERROR", - ) - - try: - old_fingerprint = ca.fingerprint - - # Generate new key pair - if new_key_type == "ed25519": - private_key_obj = Ed25519PrivateKey.generate() - elif new_key_type == "rsa": - private_key_obj = RsaPrivateKey.generate(4096) - else: # ecdsa - private_key_obj = EcdsaPrivateKey.generate() - - new_private_key = private_key_obj.to_string() - new_public_key = private_key_obj.public_key.to_string() - new_fingerprint = compute_ssh_fingerprint(new_public_key) - - # Encrypt the new private key before storing - encrypted_new_private_key = encrypt_ca_key(new_private_key) - - ca.rotate_key( - new_private_key=encrypted_new_private_key, - new_public_key=new_public_key, - new_fingerprint=new_fingerprint, - reason=reason, - ) - ca.key_type = KeyType(new_key_type) - db.session.commit() - - AuditLog.log( - action=AuditAction.CA_KEY_ROTATED, - user_id=g.current_user.id, - resource_type="CA", - resource_id=ca_id, - organization_id=org_id, - ip_address=request.remote_addr, - description=( - f"CA '{ca.name}' key rotated. " - f"Old fingerprint: {old_fingerprint}, New fingerprint: {new_fingerprint}. " - f"Reason: {reason}" - ), - ) - - return api_response( - data={ - "ca": ca.to_dict(), - "old_fingerprint": old_fingerprint, - }, - message="CA key rotated successfully. Update TrustedUserCAKeys / known_hosts on your servers.", - ) - except Exception as e: - db.session.rollback() - current_app.logger.exception("Failed to rotate CA key") - return api_response( - success=False, - message="Failed to rotate CA key", - status=500, - error_type="SERVER_ERROR", - ) - diff --git a/gatehouse_app/api/v1/organizations/__init__.py b/gatehouse_app/api/v1/organizations/__init__.py new file mode 100644 index 0000000..76f6fdd --- /dev/null +++ b/gatehouse_app/api/v1/organizations/__init__.py @@ -0,0 +1,4 @@ +"""Organization routes package.""" +from gatehouse_app.api.v1.organizations import core, members, invites, clients, cas, audit, roles + +__all__ = ["core", "members", "invites", "clients", "cas", "audit", "roles"] diff --git a/gatehouse_app/api/v1/organizations/_helpers.py b/gatehouse_app/api/v1/organizations/_helpers.py new file mode 100644 index 0000000..3463023 --- /dev/null +++ b/gatehouse_app/api/v1/organizations/_helpers.py @@ -0,0 +1,52 @@ +"""Shared helpers for organization endpoints.""" +import os + + +def _get_system_ca_dict(): + try: + from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config + from gatehouse_app.utils.crypto import compute_ssh_fingerprint + + priv_key = os.environ.get("SSH_CA_PRIVATE_KEY", "").strip() + pub_key = "" + + if not priv_key: + cfg = get_ssh_ca_config() + key_path = cfg.get_str("ca_key_path", "").strip() + if not key_path: + return None + pub_path = key_path + ".pub" + if not os.path.exists(pub_path): + return None + with open(pub_path) as f: + pub_key = f.read().strip() + else: + from sshkey_tools.keys import PrivateKey + pk = PrivateKey.from_string(priv_key) + pub_key = pk.public_key.to_string() + + fingerprint = compute_ssh_fingerprint(pub_key) + return { + "id": f"system-ca-{fingerprint[:16]}", + "organization_id": None, + "name": "System CA (config file)", + "description": ( + "Read-only — this CA is loaded from the server's SSH_CA_PRIVATE_KEY " + "environment variable or etc/ssh_ca.conf. Manage it on the server." + ), + "ca_type": "user", + "key_type": "unknown", + "public_key": pub_key, + "fingerprint": fingerprint, + "is_active": True, + "is_system": True, + "default_cert_validity_hours": 0, + "max_cert_validity_hours": 0, + "total_certs": 0, + "active_certs": 0, + "revoked_certs": 0, + "created_at": None, + "updated_at": None, + } + except Exception: + return None diff --git a/gatehouse_app/api/v1/organizations/audit.py b/gatehouse_app/api/v1/organizations/audit.py new file mode 100644 index 0000000..0ddd315 --- /dev/null +++ b/gatehouse_app/api/v1/organizations/audit.py @@ -0,0 +1,175 @@ +"""Organization audit log endpoints.""" +from flask import g, request +from gatehouse_app.api.v1 import api_v1_bp +from gatehouse_app.utils.response import api_response +from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required +from gatehouse_app.services.organization_service import OrganizationService + + +def _audit_log_to_dict(log): + action = log.action + return { + "id": log.id, + "action": action.value if hasattr(action, "value") else action, + "user_id": log.user_id, + "user": ( + {"id": log.user.id, "email": log.user.email, "full_name": log.user.full_name} + if log.user else None + ), + "organization_id": log.organization_id, + "resource_type": log.resource_type, + "resource_id": log.resource_id, + "ip_address": log.ip_address, + "user_agent": log.user_agent, + "request_id": log.request_id, + "description": log.description, + "success": log.success, + "error_message": log.error_message, + "metadata": log.extra_data, + "created_at": log.created_at.isoformat() if log.created_at else None, + "updated_at": log.updated_at.isoformat() if log.updated_at else None, + } + + +@api_v1_bp.route("/organizations//audit-logs", methods=["GET"]) +@login_required +@require_admin +@full_access_required +def get_organization_audit_logs(org_id): + from gatehouse_app.models.auth.audit_log import AuditLog + + OrganizationService.get_organization_by_id(org_id) + + page = int(request.args.get("page", 1)) + per_page = min(int(request.args.get("per_page", 50)), 200) + action_filter = request.args.get("action") + + query = AuditLog.query.filter_by(organization_id=org_id) + if action_filter: + query = query.filter_by(action=action_filter) + + query = query.order_by(AuditLog.created_at.desc()) + total = query.count() + logs = query.offset((page - 1) * per_page).limit(per_page).all() + + def log_to_dict(log): + action = log.action + return { + "id": log.id, + "action": action.value if hasattr(action, "value") else action, + "user_id": log.user_id, + "user_email": log.user.email if log.user else None, + "user": {"id": log.user.id, "email": log.user.email, "full_name": log.user.full_name} if log.user else None, + "organization_id": log.organization_id, + "resource_type": log.resource_type, + "resource_id": log.resource_id, + "ip_address": log.ip_address, + "user_agent": log.user_agent, + "request_id": log.request_id, + "description": log.description, + "success": log.success, + "error_message": log.error_message, + "metadata": log.extra_data, + "created_at": log.created_at.isoformat() if log.created_at else None, + "updated_at": log.updated_at.isoformat() if log.updated_at else None, + } + + return api_response( + data={ + "audit_logs": [log_to_dict(log) for log in logs], + "count": total, + "page": page, + "per_page": per_page, + "pages": (total + per_page - 1) // per_page, + }, + message="Audit logs retrieved successfully", + ) + + +@api_v1_bp.route("/audit-logs", methods=["GET"]) +@login_required +def get_system_audit_logs(): + from gatehouse_app.models.auth.audit_log import AuditLog + from gatehouse_app.models.organization.organization_member import OrganizationMember + + current_user = g.current_user + page = max(1, int(request.args.get("page", 1))) + per_page = min(int(request.args.get("per_page", 50)), 200) + + is_admin = OrganizationMember.query.filter( + OrganizationMember.user_id == current_user.id, + OrganizationMember.role.in_(["OWNER", "ADMIN"]), + OrganizationMember.deleted_at == None, + ).first() is not None + + query = AuditLog.query + + if not is_admin: + query = query.filter(AuditLog.user_id == current_user.id) + + action_filter = request.args.get("action") + if action_filter: + query = query.filter(AuditLog.action == action_filter) + + user_id_filter = request.args.get("user_id") + if user_id_filter: + query = query.filter(AuditLog.user_id == user_id_filter) + + resource_type_filter = request.args.get("resource_type") + if resource_type_filter: + query = query.filter(AuditLog.resource_type == resource_type_filter) + + success_filter = request.args.get("success") + if success_filter is not None: + query = query.filter(AuditLog.success == (success_filter.lower() == "true")) + + q = request.args.get("q", "").strip() + if q: + query = query.filter(AuditLog.description.ilike(f"%{q}%")) + + query = query.order_by(AuditLog.created_at.desc()) + total = query.count() + logs = query.offset((page - 1) * per_page).limit(per_page).all() + + return api_response( + data={ + "audit_logs": [_audit_log_to_dict(log) for log in logs], + "count": total, + "page": page, + "per_page": per_page, + "pages": (total + per_page - 1) // per_page, + "is_admin_view": is_admin, + }, + message="Audit logs retrieved", + ) + + +@api_v1_bp.route("/auth/audit-logs", methods=["GET"]) +@login_required +def get_my_audit_logs(): + from gatehouse_app.models.auth.audit_log import AuditLog + + current_user = g.current_user + page = max(1, int(request.args.get("page", 1))) + per_page = min(int(request.args.get("per_page", 50)), 200) + + query = AuditLog.query.filter(AuditLog.user_id == current_user.id) + + action_filter = request.args.get("action") + if action_filter: + query = query.filter(AuditLog.action == action_filter) + + query = query.order_by(AuditLog.created_at.desc()) + total = query.count() + logs = query.offset((page - 1) * per_page).limit(per_page).all() + + return api_response( + data={ + "audit_logs": [_audit_log_to_dict(log) for log in logs], + "count": total, + "page": page, + "per_page": per_page, + "pages": (total + per_page - 1) // per_page, + }, + message="Activity retrieved", + ) diff --git a/gatehouse_app/api/v1/organizations/cas.py b/gatehouse_app/api/v1/organizations/cas.py new file mode 100644 index 0000000..ad1ac71 --- /dev/null +++ b/gatehouse_app/api/v1/organizations/cas.py @@ -0,0 +1,261 @@ +"""Organization Certificate Authority endpoints.""" +from flask import g, request, current_app +from marshmallow import ValidationError +from gatehouse_app.api.v1 import api_v1_bp +from gatehouse_app.utils.response import api_response +from gatehouse_app.utils.decorators import login_required, require_admin +from gatehouse_app.extensions import db +from gatehouse_app.api.v1.organizations._helpers import _get_system_ca_dict + + +@api_v1_bp.route("/organizations//cas", methods=["GET"]) +@login_required +@require_admin +def list_org_cas(org_id): + from gatehouse_app.models.ssh_ca.ca import CA, CaType + from gatehouse_app.models.organization.organization import Organization + + org = Organization.query.filter_by(id=org_id, deleted_at=None).first() + if not org: + return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND") + + cas = CA.query.filter_by(organization_id=org_id, deleted_at=None).all() + ca_list = [ca.to_dict() for ca in cas] + covered_types = {ca.ca_type for ca in cas} + + system_ca_dict = _get_system_ca_dict() + if system_ca_dict and CaType.USER not in covered_types: + ca_list.append({**system_ca_dict, "ca_type": "user"}) + + return api_response(data={"cas": ca_list, "count": len(ca_list)}, message="CAs retrieved") + + +@api_v1_bp.route("/organizations//cas/", methods=["PATCH"]) +@login_required +@require_admin +def update_org_ca(org_id, ca_id): + from gatehouse_app.models.ssh_ca.ca import CA + from gatehouse_app.models.organization.organization import Organization + from marshmallow import Schema, fields, validate + + org = Organization.query.filter_by(id=org_id, deleted_at=None).first() + if not org: + return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND") + + ca = CA.query.filter_by(id=ca_id, organization_id=org_id, deleted_at=None).first() + if not ca: + return api_response(success=False, message="CA not found", status=404, error_type="NOT_FOUND") + + try: + class CAUpdateSchema(Schema): + default_cert_validity_hours = fields.Int(validate=validate.Range(min=1), required=False) + max_cert_validity_hours = fields.Int(validate=validate.Range(min=1), required=False) + + schema = CAUpdateSchema() + data = schema.load(request.json or {}) + + default_hours = data.get("default_cert_validity_hours", ca.default_cert_validity_hours) + max_hours = data.get("max_cert_validity_hours", ca.max_cert_validity_hours) + + if default_hours > max_hours: + return api_response(success=False, message="Default validity must be less than or equal to maximum validity", status=400, error_type="VALIDATION_ERROR") + + if "default_cert_validity_hours" in data: + ca.default_cert_validity_hours = data["default_cert_validity_hours"] + if "max_cert_validity_hours" in data: + ca.max_cert_validity_hours = data["max_cert_validity_hours"] + + db.session.commit() + return api_response(data={"ca": ca.to_dict()}, message="CA updated successfully") + except ValidationError as e: + return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) + except Exception: + db.session.rollback() + return api_response(success=False, message="Failed to update CA", status=500, error_type="SERVER_ERROR") + + +@api_v1_bp.route("/organizations//cas", methods=["POST"]) +@login_required +@require_admin +def create_org_ca(org_id): + from gatehouse_app.models.ssh_ca.ca import CA, KeyType, CaType + from gatehouse_app.models.organization.organization import Organization + from gatehouse_app.utils.crypto import compute_ssh_fingerprint + from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key + from marshmallow import Schema, fields as ma_fields, validate, ValidationError as MaValidationError + from sshkey_tools.keys import Ed25519PrivateKey, RsaPrivateKey, EcdsaPrivateKey + + org = Organization.query.filter_by(id=org_id, deleted_at=None).first() + if not org: + return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND") + + class CreateCASchema(Schema): + name = ma_fields.Str(required=True, validate=validate.Length(min=1, max=255)) + description = ma_fields.Str(load_default=None, allow_none=True) + ca_type = ma_fields.Str(load_default="user", validate=validate.OneOf(["user", "host"])) + key_type = ma_fields.Str(load_default="ed25519", validate=validate.OneOf(["ed25519", "rsa", "ecdsa"])) + default_cert_validity_hours = ma_fields.Int(load_default=8, validate=validate.Range(min=1)) + max_cert_validity_hours = ma_fields.Int(load_default=720, validate=validate.Range(min=1)) + + try: + schema = CreateCASchema() + data = schema.load(request.get_json() or {}) + + existing = CA.query.filter_by(organization_id=org_id, name=data["name"], deleted_at=None).first() + if existing: + return api_response(success=False, message="A CA with that name already exists in this organization", status=400, error_type="DUPLICATE_NAME") + + ca_type_val = data["ca_type"] + existing_type = CA.query.filter_by(organization_id=org_id, deleted_at=None).filter(CA.ca_type == CaType(ca_type_val)).first() + if existing_type: + type_label = "User" if ca_type_val == "user" else "Host" + return api_response(success=False, message=f"A {type_label} CA already exists for this organization. You can only have one {type_label} CA per organization.", status=400, error_type="DUPLICATE_CA_TYPE") + + if data["default_cert_validity_hours"] > data["max_cert_validity_hours"]: + return api_response(success=False, message="Default validity must be less than or equal to maximum validity", status=400, error_type="VALIDATION_ERROR") + + key_type = data["key_type"] + if key_type == "ed25519": + private_key_obj = Ed25519PrivateKey.generate() + elif key_type == "rsa": + private_key_obj = RsaPrivateKey.generate(4096) + else: + private_key_obj = EcdsaPrivateKey.generate() + + private_key_pem = private_key_obj.to_string() + public_key_str = private_key_obj.public_key.to_string() + fingerprint = compute_ssh_fingerprint(public_key_str) + encrypted_private_key = encrypt_ca_key(private_key_pem) + + ca = CA( + organization_id=org_id, + name=data["name"], + description=data["description"], + ca_type=CaType(ca_type_val), + key_type=KeyType(key_type), + private_key=encrypted_private_key, + public_key=public_key_str, + fingerprint=fingerprint, + default_cert_validity_hours=data["default_cert_validity_hours"], + max_cert_validity_hours=data["max_cert_validity_hours"], + is_active=True, + ) + db.session.add(ca) + try: + db.session.commit() + except Exception as commit_exc: + db.session.rollback() + exc_str = str(commit_exc).lower() + if "uix_org_ca_name" in exc_str or "unique" in exc_str: + return api_response(success=False, message="A CA with that name already exists in this organization (it may have been recently deleted — choose a different name).", status=400, error_type="DUPLICATE_NAME") + raise + + return api_response(data={"ca": ca.to_dict()}, message="CA created successfully", status=201) + except MaValidationError as e: + return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) + except Exception: + db.session.rollback() + current_app.logger.exception("Failed to create CA") + return api_response(success=False, message="Failed to create CA", status=500, error_type="SERVER_ERROR") + + +@api_v1_bp.route("/organizations//cas/", methods=["DELETE"]) +@login_required +@require_admin +def delete_org_ca(org_id, ca_id): + from gatehouse_app.models.ssh_ca.ca import CA + from gatehouse_app.models.organization.organization import Organization + from gatehouse_app.utils.constants import AuditAction + from gatehouse_app.models import AuditLog + + org = Organization.query.filter_by(id=org_id, deleted_at=None).first() + if not org: + return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND") + + ca = CA.query.filter_by(id=ca_id, organization_id=org_id, deleted_at=None).first() + if not ca: + return api_response(success=False, message="CA not found", status=404, error_type="NOT_FOUND") + + try: + ca_name = ca.name + ca_type = ca.ca_type.value if hasattr(ca.ca_type, "value") else str(ca.ca_type) + ca.is_active = False + ca.delete(soft=True) + + AuditLog.log( + action=AuditAction.CA_DELETED, + user_id=g.current_user.id, + resource_type="CA", + resource_id=ca_id, + organization_id=org_id, + ip_address=request.remote_addr, + description=f"CA '{ca_name}' ({ca_type}) deleted", + ) + return api_response(data={"ca_id": ca_id}, message="CA deleted successfully") + except Exception: + db.session.rollback() + current_app.logger.exception("Failed to delete CA") + return api_response(success=False, message="Failed to delete CA", status=500, error_type="SERVER_ERROR") + + +@api_v1_bp.route("/organizations//cas//rotate", methods=["POST"]) +@login_required +@require_admin +def rotate_org_ca(org_id, ca_id): + from gatehouse_app.models.ssh_ca.ca import CA, KeyType + from gatehouse_app.models.organization.organization import Organization + from gatehouse_app.utils.crypto import compute_ssh_fingerprint + from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key + from gatehouse_app.utils.constants import AuditAction + from gatehouse_app.models import AuditLog + from sshkey_tools.keys import Ed25519PrivateKey, RsaPrivateKey, EcdsaPrivateKey + + org = Organization.query.filter_by(id=org_id, deleted_at=None).first() + if not org: + return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND") + + ca = CA.query.filter_by(id=ca_id, organization_id=org_id, deleted_at=None).first() + if not ca: + return api_response(success=False, message="CA not found", status=404, error_type="NOT_FOUND") + + data = request.get_json() or {} + new_key_type = data.get("key_type") or (ca.key_type.value if hasattr(ca.key_type, "value") else str(ca.key_type)) + reason = data.get("reason", "Admin-initiated key rotation") + + if new_key_type not in ("ed25519", "rsa", "ecdsa"): + return api_response(success=False, message="Invalid key_type. Must be one of: ed25519, rsa, ecdsa", status=400, error_type="VALIDATION_ERROR") + + try: + old_fingerprint = ca.fingerprint + + if new_key_type == "ed25519": + private_key_obj = Ed25519PrivateKey.generate() + elif new_key_type == "rsa": + private_key_obj = RsaPrivateKey.generate(4096) + else: + private_key_obj = EcdsaPrivateKey.generate() + + new_private_key = private_key_obj.to_string() + new_public_key = private_key_obj.public_key.to_string() + new_fingerprint = compute_ssh_fingerprint(new_public_key) + encrypted_new_private_key = encrypt_ca_key(new_private_key) + + ca.rotate_key(new_private_key=encrypted_new_private_key, new_public_key=new_public_key, new_fingerprint=new_fingerprint, reason=reason) + ca.key_type = KeyType(new_key_type) + db.session.commit() + + AuditLog.log( + action=AuditAction.CA_KEY_ROTATED, + user_id=g.current_user.id, + resource_type="CA", + resource_id=ca_id, + organization_id=org_id, + ip_address=request.remote_addr, + description=(f"CA '{ca.name}' key rotated. Old fingerprint: {old_fingerprint}, New fingerprint: {new_fingerprint}. Reason: {reason}"), + ) + + return api_response(data={"ca": ca.to_dict(), "old_fingerprint": old_fingerprint}, message="CA key rotated successfully. Update TrustedUserCAKeys / known_hosts on your servers.") + except Exception: + db.session.rollback() + current_app.logger.exception("Failed to rotate CA key") + return api_response(success=False, message="Failed to rotate CA key", status=500, error_type="SERVER_ERROR") diff --git a/gatehouse_app/api/v1/organizations/clients.py b/gatehouse_app/api/v1/organizations/clients.py new file mode 100644 index 0000000..553837a --- /dev/null +++ b/gatehouse_app/api/v1/organizations/clients.py @@ -0,0 +1,110 @@ +"""Organization OIDC client endpoints.""" +import secrets as _secrets +from flask import g, request +from gatehouse_app.api.v1 import api_v1_bp +from gatehouse_app.utils.response import api_response +from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required +from gatehouse_app.extensions import db, bcrypt + + +@api_v1_bp.route("/organizations//clients", methods=["GET"]) +@login_required +@require_admin +@full_access_required +def list_org_clients(org_id): + from gatehouse_app.models import OIDCClient, Organization + + org = Organization.query.filter_by(id=org_id, deleted_at=None).first() + if not org: + return api_response(success=False, message="Organization not found", status=404) + + clients = OIDCClient.query.filter_by(organization_id=org_id, is_active=True).all() + + def client_to_dict(c): + return { + "id": c.id, + "name": c.name, + "client_id": c.client_id, + "redirect_uris": c.redirect_uris, + "scopes": c.scopes, + "grant_types": c.grant_types, + "is_active": c.is_active, + "created_at": c.created_at.isoformat() + "Z", + } + + return api_response(data={"clients": [client_to_dict(c) for c in clients], "count": len(clients)}, message="Clients retrieved successfully") + + +@api_v1_bp.route("/organizations//clients", methods=["POST"]) +@login_required +@require_admin +def create_org_client(org_id): + from gatehouse_app.models import OIDCClient, Organization + + org = Organization.query.filter_by(id=org_id, deleted_at=None).first() + if not org: + return api_response(success=False, message="Organization not found", status=404) + + data = request.get_json() or {} + name = (data.get("name") or "").strip() + redirect_uris_raw = data.get("redirect_uris") or [] + + if not name: + return api_response(success=False, message="Client name is required", status=400, error_type="VALIDATION_ERROR") + + if isinstance(redirect_uris_raw, str): + redirect_uris = [u.strip() for u in redirect_uris_raw.replace(",", "\n").splitlines() if u.strip()] + else: + redirect_uris = [u.strip() for u in redirect_uris_raw if isinstance(u, str) and u.strip()] + + if not redirect_uris: + return api_response(success=False, message="At least one redirect URI is required", status=400, error_type="VALIDATION_ERROR") + + client_id = _secrets.token_hex(16) + client_secret = _secrets.token_urlsafe(32) + + client = OIDCClient( + organization_id=org_id, + name=name, + client_id=client_id, + client_secret_hash=bcrypt.generate_password_hash(client_secret).decode("utf-8"), + redirect_uris=redirect_uris, + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + scopes=["openid", "profile", "email"], + is_active=True, + is_confidential=True, + ) + db.session.add(client) + db.session.commit() + + return api_response( + data={ + "client": { + "id": client.id, + "name": client.name, + "client_id": client.client_id, + "client_secret": client_secret, + "redirect_uris": client.redirect_uris, + "scopes": client.scopes, + "created_at": client.created_at.isoformat() + "Z", + } + }, + message="OIDC client created successfully", + status=201, + ) + + +@api_v1_bp.route("/organizations//clients/", methods=["DELETE"]) +@login_required +@require_admin +def delete_org_client(org_id, client_id): + from gatehouse_app.models import OIDCClient + + client = OIDCClient.query.filter_by(id=client_id, organization_id=org_id).first() + if not client: + return api_response(success=False, message="Client not found", status=404) + + client.is_active = False + db.session.commit() + return api_response(data={}, message="Client deactivated successfully") diff --git a/gatehouse_app/api/v1/organizations/core.py b/gatehouse_app/api/v1/organizations/core.py new file mode 100644 index 0000000..e56bd14 --- /dev/null +++ b/gatehouse_app/api/v1/organizations/core.py @@ -0,0 +1,85 @@ +"""Organization core CRUD endpoints.""" +from flask import g, request +from marshmallow import ValidationError +from gatehouse_app.api.v1 import api_v1_bp +from gatehouse_app.utils.response import api_response +from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required +from gatehouse_app.schemas.organization_schema import OrganizationCreateSchema, OrganizationUpdateSchema +from gatehouse_app.services.organization_service import OrganizationService + + +@api_v1_bp.route("/organizations", methods=["POST"]) +@login_required +@full_access_required +def create_organization(): + try: + schema = OrganizationCreateSchema() + data = schema.load(request.json) + org = OrganizationService.create_organization( + name=data["name"], + slug=data["slug"], + owner_user_id=g.current_user.id, + description=data.get("description"), + logo_url=data.get("logo_url"), + ) + return api_response(data={"organization": org.to_dict()}, message="Organization created successfully", status=201) + except ValidationError as e: + return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) + + +@api_v1_bp.route("/organizations/", methods=["GET"]) +@login_required +@full_access_required +def get_organization(org_id): + org = OrganizationService.get_organization_by_id(org_id) + if not org.is_member(g.current_user.id): + return api_response(success=False, message="You are not a member of this organization", status=403, error_type="AUTHORIZATION_ERROR") + return api_response( + data={"organization": org.to_dict(), "member_count": org.get_member_count()}, + message="Organization retrieved successfully", + ) + + +@api_v1_bp.route("/organizations/", methods=["PATCH"]) +@login_required +@require_admin +@full_access_required +def update_organization(org_id): + try: + schema = OrganizationUpdateSchema() + data = schema.load(request.json) + org = OrganizationService.get_organization_by_id(org_id) + org = OrganizationService.update_organization(org=org, user_id=g.current_user.id, **data) + return api_response(data={"organization": org.to_dict()}, message="Organization updated successfully") + except ValidationError as e: + return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) + + +@api_v1_bp.route("/organizations/", methods=["DELETE"]) +@login_required +@full_access_required +def delete_organization(org_id): + from gatehouse_app.models.organization.organization_member import OrganizationMember as _OrgMember + from gatehouse_app.utils.constants import OrganizationRole as _OrgRole + + caller = g.current_user + org = OrganizationService.get_organization_by_id(org_id) + + caller_membership = _OrgMember.query.filter_by(user_id=caller.id, organization_id=org.id, deleted_at=None).first() + if not caller_membership or caller_membership.role != _OrgRole.OWNER: + return api_response(success=False, message="Only the organization owner can delete the organization.", status=403, error_type="AUTHORIZATION_ERROR") + + active_member_count = org.get_member_count() + if active_member_count > 1: + data = request.get_json(silent=True) or {} + if not data.get("confirm"): + return api_response( + success=False, + message=(f"This organization has {active_member_count} active members. Deleting it will remove all members and their data. Send {{\"confirm\": true}} to confirm."), + status=400, + error_type="CONFIRMATION_REQUIRED", + error_details={"member_count": active_member_count}, + ) + + OrganizationService.force_delete_organization(org=org, user_id=caller.id) + return api_response(message="Organization deleted successfully") diff --git a/gatehouse_app/api/v1/organizations/invites.py b/gatehouse_app/api/v1/organizations/invites.py new file mode 100644 index 0000000..2920bc4 --- /dev/null +++ b/gatehouse_app/api/v1/organizations/invites.py @@ -0,0 +1,256 @@ +"""Organization invite token endpoints.""" +import logging +from flask import g, request, current_app +from gatehouse_app.api.v1 import api_v1_bp +from gatehouse_app.utils.response import api_response +from gatehouse_app.utils.decorators import login_required, require_admin +from gatehouse_app.services.notification_service import NotificationService +from gatehouse_app.services.auth_service import AuthService +from gatehouse_app.services.organization_service import OrganizationService +from gatehouse_app.utils.constants import OrganizationRole + + +@api_v1_bp.route("/organizations//invites", methods=["POST"]) +@login_required +@require_admin +def create_org_invite(org_id): + from gatehouse_app.models import OrgInviteToken, Organization + + org = Organization.query.filter_by(id=org_id, deleted_at=None).first() + if not org: + return api_response(success=False, message="Organization not found", status=404) + + data = request.get_json() or {} + email = (data.get("email") or "").strip().lower() + role = (data.get("role") or "member").strip() + + if not email: + return api_response(success=False, message="Email is required", status=400, error_type="VALIDATION_ERROR") + + invite = OrgInviteToken.generate( + organization_id=org_id, + email=email, + role=role, + invited_by_id=g.current_user.id, + ) + + app_url = current_app.config.get("APP_URL", "http://localhost:8080") + invite_link = f"{app_url}/invite?token={invite.token}" + + email_sent = NotificationService._send_email( + to_address=email, + subject=f"You're invited to join {org.name} on Gatehouse", + body=( + f"You've been invited to join {org.name} on Gatehouse.\n\n" + f"Click the link below to accept the invitation (valid for 7 days):\n" + f"{invite_link}\n\n" + f"Gatehouse Security Team" + ), + ) + + if not email_sent: + logging.getLogger(__name__).warning( + f"[INVITE LINK] Email not sent (EMAIL_ENABLED=False or SMTP down). " + f"Invite for {email} → {invite_link}" + ) + else: + logging.getLogger(__name__).info( + f"[INVITE] Email sent successfully to {email}" + ) + + response_data = { + "invite": { + "id": invite.id, + "email": invite.email, + "role": invite.role, + "expires_at": invite.expires_at.isoformat() + "Z", + # Only include invite_link when email delivery failed — signals frontend to show copy dialog + **({"invite_link": invite_link} if not email_sent else {}), + } + } + + return api_response( + data=response_data, + message="Invite sent successfully", + status=201, + ) + + +@api_v1_bp.route("/organizations//invites", methods=["GET"]) +@login_required +@require_admin +def list_org_invites(org_id): + from gatehouse_app.models import OrgInviteToken, Organization + + org = Organization.query.filter_by(id=org_id, deleted_at=None).first() + if not org: + return api_response(success=False, message="Organization not found", status=404) + + invites = ( + OrgInviteToken.query.filter_by(organization_id=org_id) + .filter(OrgInviteToken.accepted_at == None) + .filter(OrgInviteToken.deleted_at == None) + .all() + ) + + def invite_to_dict(inv): + return { + "id": inv.id, + "email": inv.email, + "role": inv.role, + "invited_by_id": inv.invited_by_id, + "expires_at": inv.expires_at.isoformat() + "Z", + "token": inv.token, + } + + return api_response( + data={"invites": [invite_to_dict(i) for i in invites]}, + message="Invites retrieved", + ) + + +@api_v1_bp.route("/organizations//invites/", methods=["DELETE"]) +@login_required +@require_admin +def cancel_org_invite(org_id, invite_id): + from gatehouse_app.models import OrgInviteToken, Organization + + org = Organization.query.filter_by(id=org_id, deleted_at=None).first() + if not org: + return api_response(success=False, message="Organization not found", status=404) + + invite = OrgInviteToken.query.filter_by(id=invite_id, organization_id=org_id, deleted_at=None).first() + if not invite: + return api_response(success=False, message="Invite not found", status=404) + + invite.delete(soft=True) + return api_response(data={}, message="Invite cancelled") + + +@api_v1_bp.route("/invites/", methods=["GET"]) +def get_invite(token): + from gatehouse_app.models import OrgInviteToken, User + + invite = OrgInviteToken.query.filter_by(token=token).first() + if not invite or not invite.is_valid: + return api_response(success=False, message="This invitation link is invalid or has expired.", status=400, error_type="INVALID_TOKEN") + + user_exists = User.query.filter_by(email=invite.email, deleted_at=None).first() is not None + + return api_response( + data={ + "email": invite.email, + "organization": {"id": invite.organization_id, "name": invite.organization.name}, + "role": invite.role, + "user_exists": user_exists, + }, + message="Invite found", + ) + + +@api_v1_bp.route("/invites//accept", methods=["POST"]) +def accept_invite(token): + """Accept an organization invite. + + """ + from gatehouse_app.models import OrgInviteToken, User + from gatehouse_app.services.session_service import SessionService + + invite = OrgInviteToken.query.filter_by(token=token).first() + if not invite or not invite.is_valid: + return api_response(success=False, message="This invitation link is invalid or has expired.", status=400, error_type="INVALID_TOKEN") + + # --- Resolve the user ----------------------------------------------- + # If the request carries a valid session token the user is already + # authenticated (e.g. via Google OAuth). Use that identity and skip + # any password / registration logic entirely. + user = None + auth_header = request.headers.get("Authorization", "") + if auth_header.lower().startswith("bearer "): + bearer_token = auth_header.split(None, 1)[1].strip() + session = SessionService.get_active_session_by_token(bearer_token) + if session and session.is_active(): + session_user = session.user + # Verify the authenticated user's email matches the invite + if session_user.email.lower() != invite.email.lower(): + return api_response( + success=False, + message="This invite was sent to a different email address.", + status=403, + error_type="EMAIL_MISMATCH", + ) + user = session_user + + data = request.get_json() or {} + full_name = data.get("full_name") or "" + password = data.get("password") or "" + password_confirm = data.get("password_confirm") or "" + + if user is None: + # Fall back to email lookup (existing account created by any method) + user = User.query.filter( + User.email.ilike(invite.email), + User.deleted_at.is_(None), + ).first() + + if not user: + # Brand-new account — password registration required + if not password: + return api_response(success=False, message="Password is required for new accounts.", status=400, error_type="VALIDATION_ERROR") + if password != password_confirm: + return api_response(success=False, message="Passwords do not match.", status=400, error_type="VALIDATION_ERROR") + if len(password) < 8: + return api_response(success=False, message="Password must be at least 8 characters.", status=400, error_type="VALIDATION_ERROR") + try: + user = AuthService.register_user(email=invite.email, password=password, full_name=full_name or None) + except Exception as exc: + return api_response(success=False, message=str(exc), status=400, error_type="REGISTRATION_ERROR") + + # Add to org + try: + org_role = OrganizationRole(invite.role) + except ValueError: + org_role = OrganizationRole.MEMBER + + try: + OrganizationService.add_member( + org=invite.organization, + user_id=user.id, + role=org_role, + inviter_id=invite.invited_by_id, + ) + except Exception: + from gatehouse_app.extensions import db + db.session.rollback() + return api_response( + success=False, + message="Failed to add you to the organization. You may already be a member.", + status=409, + error_type="CONFLICT", + ) + + invite.accept() + + has_webauthn = user.has_webauthn_enabled() + has_totp = user.has_totp_enabled() + + if has_webauthn: + from flask import session as flask_session + flask_session["webauthn_pending_user_id"] = user.id + return api_response(data={"requires_webauthn": True}, message="Passkey verification required. Please use your passkey to complete sign-in.") + + if has_totp: + from flask import session as flask_session + flask_session["totp_pending_user_id"] = user.id + return api_response(data={"requires_totp": True}, message="TOTP code required. Please enter your 6-digit code from your authenticator app.") + + user_session = AuthService.create_session(user) + + return api_response( + data={ + "user": user.to_dict(), + "token": user_session.token, + "expires_at": user_session.expires_at.isoformat() + "Z", + }, + message="Invitation accepted. Welcome!", + ) diff --git a/gatehouse_app/api/v1/organizations/members.py b/gatehouse_app/api/v1/organizations/members.py new file mode 100644 index 0000000..df594f5 --- /dev/null +++ b/gatehouse_app/api/v1/organizations/members.py @@ -0,0 +1,176 @@ +"""Organization member management endpoints.""" +from flask import g, request +from marshmallow import ValidationError +from gatehouse_app.api.v1 import api_v1_bp +from gatehouse_app.utils.response import api_response +from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required +from gatehouse_app.schemas.organization_schema import InviteMemberSchema, UpdateMemberRoleSchema +from gatehouse_app.services.organization_service import OrganizationService +from gatehouse_app.services.user_service import UserService +from gatehouse_app.utils.constants import OrganizationRole + + +@api_v1_bp.route("/organizations//members", methods=["GET"]) +@login_required +@full_access_required +def get_organization_members(org_id): + org = OrganizationService.get_organization_by_id(org_id) + if not org.is_member(g.current_user.id): + return api_response(success=False, message="You are not a member of this organization", status=403, error_type="AUTHORIZATION_ERROR") + + members_data = [] + for member in org.members: + if member.deleted_at is None: + member_dict = member.to_dict() + member_dict["user"] = member.user.to_dict() + members_data.append(member_dict) + + return api_response(data={"members": members_data, "count": len(members_data)}, message="Members retrieved successfully") + + +@api_v1_bp.route("/organizations//members", methods=["POST"]) +@login_required +@require_admin +@full_access_required +def add_organization_member(org_id): + try: + schema = InviteMemberSchema() + data = schema.load(request.json) + org = OrganizationService.get_organization_by_id(org_id) + user = UserService.get_user_by_email(data["email"]) + if not user: + return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") + + role = OrganizationRole(data["role"]) + member = OrganizationService.add_member(org=org, user_id=user.id, role=role, inviter_id=g.current_user.id) + member_dict = member.to_dict() + member_dict["user"] = user.to_dict() + return api_response(data={"member": member_dict}, message="Member added successfully", status=201) + except ValidationError as e: + return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) + + +@api_v1_bp.route("/organizations//members/", methods=["DELETE"]) +@login_required +@require_admin +@full_access_required +def remove_organization_member(org_id, user_id): + org = OrganizationService.get_organization_by_id(org_id) + OrganizationService.remove_member(org=org, user_id=user_id, remover_id=g.current_user.id) + return api_response(message="Member removed successfully") + + +@api_v1_bp.route("/organizations//members//role", methods=["PATCH"]) +@login_required +@require_admin +@full_access_required +def update_member_role(org_id, user_id): + try: + schema = UpdateMemberRoleSchema() + data = schema.load(request.json) + org = OrganizationService.get_organization_by_id(org_id) + new_role = OrganizationRole(data["role"]) + member = OrganizationService.update_member_role(org=org, user_id=user_id, new_role=new_role, updater_id=g.current_user.id) + member_dict = member.to_dict() + member_dict["user"] = member.user.to_dict() + return api_response(data={"member": member_dict}, message="Member role updated successfully") + except ValidationError as e: + return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) + + +@api_v1_bp.route("/organizations//transfer-ownership", methods=["POST"]) +@login_required +@full_access_required +def transfer_organization_ownership(org_id): + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.utils.constants import AuditAction + from gatehouse_app.services.audit_service import AuditService + + caller = g.current_user + data = request.get_json() or {} + new_owner_user_id = data.get("new_owner_user_id") + + if not new_owner_user_id: + return api_response(success=False, message="new_owner_user_id is required", status=400, error_type="VALIDATION_ERROR") + + if str(new_owner_user_id) == str(caller.id): + return api_response(success=False, message="You are already the owner of this organization.", status=409, error_type="CONFLICT") + + org = OrganizationService.get_organization_by_id(org_id) + + caller_membership = OrganizationMember.query.filter_by(organization_id=org.id, user_id=caller.id, deleted_at=None).first() + if not caller_membership or caller_membership.role != OrganizationRole.OWNER: + return api_response(success=False, message="Only the organization owner can transfer ownership.", status=403, error_type="AUTHORIZATION_ERROR") + + target_membership = OrganizationMember.query.filter_by(organization_id=org.id, user_id=new_owner_user_id, deleted_at=None).first() + if not target_membership: + return api_response(success=False, message="Target user is not a member of this organization.", status=404, error_type="NOT_FOUND") + + if target_membership.role == OrganizationRole.OWNER: + return api_response(success=False, message="Target user is already the owner.", status=409, error_type="CONFLICT") + + try: + demoted = OrganizationService.update_member_role(org=org, user_id=str(caller.id), new_role=OrganizationRole.ADMIN, updater_id=str(caller.id)) + promoted = OrganizationService.update_member_role(org=org, user_id=str(new_owner_user_id), new_role=OrganizationRole.OWNER, updater_id=str(caller.id)) + except Exception as exc: + from gatehouse_app.extensions import db as _db + _db.session.rollback() + return api_response(success=False, message=f"Failed to transfer ownership: {exc}", status=500, error_type="SERVER_ERROR") + + AuditService.log_action( + action=AuditAction.ORG_OWNERSHIP_TRANSFERRED, + user_id=caller.id, + organization_id=org.id, + resource_type="organization", + resource_id=str(org.id), + description=(f"Ownership of '{org.name}' transferred from {caller.email} to {target_membership.user.email if target_membership.user else new_owner_user_id}"), + metadata={ + "previous_owner_id": str(caller.id), + "previous_owner_email": caller.email, + "new_owner_id": str(new_owner_user_id), + "new_owner_email": target_membership.user.email if target_membership.user else None, + }, + ) + + def _member_dict(m): + d = m.to_dict() + if m.user: + d["user"] = m.user.to_dict() + return d + + return api_response( + data={"previous_owner": _member_dict(demoted), "new_owner": _member_dict(promoted)}, + message=(f"Ownership of '{org.name}' successfully transferred to {target_membership.user.email if target_membership.user else new_owner_user_id}."), + ) + + +@api_v1_bp.route("/organizations//members//send-mfa-reminder", methods=["POST"]) +@login_required +@require_admin +def send_mfa_reminder(org_id, user_id): + from gatehouse_app.models import User, MfaPolicyCompliance, OrganizationSecurityPolicy + from gatehouse_app.services.notification_service import NotificationService + + user = User.query.filter_by(id=user_id, deleted_at=None).first() + if not user: + return api_response(success=False, message="User not found", status=404) + + compliance = MfaPolicyCompliance.query.filter_by(user_id=user_id, organization_id=org_id).first() + policy = OrganizationSecurityPolicy.query.filter_by(organization_id=org_id).first() + + if compliance and policy and compliance.deadline_at: + NotificationService.send_mfa_deadline_reminder(user, compliance, policy) + else: + NotificationService._send_email( + to_address=user.email, + subject="Reminder: Set up multi-factor authentication", + body=( + f"Hi {user.full_name or user.email},\n\n" + "Your organization administrator has asked you to set up " + "multi-factor authentication (MFA) on your Gatehouse account.\n\n" + "Please log in and configure MFA as soon as possible.\n\n" + "Gatehouse Security Team" + ), + ) + + return api_response(data={}, message="Reminder sent successfully") diff --git a/gatehouse_app/api/v1/organizations/roles.py b/gatehouse_app/api/v1/organizations/roles.py new file mode 100644 index 0000000..3c982e9 --- /dev/null +++ b/gatehouse_app/api/v1/organizations/roles.py @@ -0,0 +1,85 @@ +"""Organization role management endpoints.""" +from flask import g, request +from gatehouse_app.api.v1 import api_v1_bp +from gatehouse_app.utils.response import api_response +from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required +from gatehouse_app.utils.constants import OrganizationRole +from gatehouse_app.extensions import db + + +@api_v1_bp.route("/organizations//roles", methods=["GET"]) +@login_required +def list_organization_roles(org_id): + from gatehouse_app.models.organization.organization import Organization + from gatehouse_app.models.organization.organization_member import OrganizationMember + + org = Organization.query.filter_by(id=org_id, deleted_at=None).first() + if not org: + return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND") + + members = OrganizationMember.query.filter_by(organization_id=org_id, deleted_at=None).all() + by_role: dict = {r.value: [] for r in OrganizationRole} + for m in members: + role_key = m.role.value if hasattr(m.role, "value") else str(m.role) + if role_key in by_role: + by_role[role_key].append({ + "user_id": m.user_id, + "email": m.user.email if m.user else None, + "full_name": m.user.full_name if m.user else None, + "joined_at": m.created_at.isoformat() if m.created_at else None, + }) + + roles = [ + {"role": r.value, "member_count": len(by_role[r.value]), "members": by_role[r.value]} + for r in OrganizationRole + ] + return api_response(data={"roles": roles, "organization_id": org_id}, message="Roles retrieved") + + +@api_v1_bp.route("/organizations//roles//members", methods=["POST"]) +@login_required +@require_admin +def assign_role_to_member(org_id, role_name): + from gatehouse_app.models.organization.organization_member import OrganizationMember + + try: + new_role = OrganizationRole(role_name.lower()) + except ValueError: + valid = [r.value for r in OrganizationRole] + return api_response(success=False, message=f"Invalid role. Must be one of: {valid}", status=400, error_type="VALIDATION_ERROR") + + data = request.get_json() or {} + target_user_id = data.get("user_id") + if not target_user_id: + return api_response(success=False, message="user_id is required", status=400, error_type="VALIDATION_ERROR") + + membership = OrganizationMember.query.filter_by(organization_id=org_id, user_id=target_user_id, deleted_at=None).first() + if not membership: + return api_response(success=False, message="Member not found in this organization", status=404, error_type="NOT_FOUND") + + membership.role = new_role + db.session.commit() + return api_response(data={"user_id": target_user_id, "role": new_role.value}, message=f"Role updated to {new_role.value}") + + +@api_v1_bp.route("/organizations//roles//members/", methods=["DELETE"]) +@login_required +@require_admin +@full_access_required +def remove_role_from_member(org_id, role_name, user_id): + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.services.organization_service import OrganizationService + + try: + OrganizationRole(role_name.lower()) + except ValueError: + valid = [r.value for r in OrganizationRole] + return api_response(success=False, message=f"Invalid role. Must be one of: {valid}", status=400, error_type="VALIDATION_ERROR") + + membership = OrganizationMember.query.filter_by(organization_id=org_id, user_id=user_id, deleted_at=None).first() + if not membership: + return api_response(success=False, message="Member not found in this organization", status=404, error_type="NOT_FOUND") + + org = OrganizationService.get_organization_by_id(org_id) + OrganizationService.remove_member(org=org, user_id=user_id, remover_id=g.current_user.id) + return api_response(data={"user_id": user_id}, message="Member removed from organization") diff --git a/gatehouse_app/api/v1/ssh.py b/gatehouse_app/api/v1/ssh.py deleted file mode 100644 index 35d2ac6..0000000 --- a/gatehouse_app/api/v1/ssh.py +++ /dev/null @@ -1,1418 +0,0 @@ -"""SSH Key and Certificate API routes.""" -from flask import Blueprint, request, jsonify, g -from sqlalchemy.exc import IntegrityError -from gatehouse_app.services.ssh_key_service import SSHKeyService -from gatehouse_app.services.ssh_ca_signing_service import ( - SSHCASigningService, - SSHCertificateSigningRequest, -) -from gatehouse_app.exceptions import ( - SSHKeyError, - SSHKeyNotFoundError, - SSHCertificateError, - ValidationError, - SSHKeyAlreadyExistsError, -) -from gatehouse_app.utils.constants import AuditAction -from gatehouse_app.models import AuditLog -from gatehouse_app.models.ssh_ca.certificate_audit_log import CertificateAuditLog -from gatehouse_app.utils.decorators import login_required -from gatehouse_app.utils.response import api_response - -ssh_bp = Blueprint('ssh', __name__, url_prefix='/ssh') -ssh_key_service = SSHKeyService() -ssh_ca_service = SSHCASigningService() - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -def _get_org_ca_for_user(user, ca_type: str = "user"): - """Return the active DB CA of the given type for the user's first org, or None. - - Args: - user: The current user object. - ca_type: ``"user"`` (default) or ``"host"`` — selects the CA that signs - the corresponding certificate type. - """ - try: - from gatehouse_app.models.ssh_ca.ca import CA, CaType - org_ids = [m.organization_id for m in user.organization_memberships] - if not org_ids: - return None - return CA.query.filter( - CA.organization_id.in_(org_ids), - CA.ca_type == CaType(ca_type), - CA.is_active == True, # noqa: E712 - ).first() - except Exception: - return None - - -def _get_or_create_system_ca(): - """ - Return a CA DB record representing the config-file CA. - - This is used as the ``ca_id`` FK when persisting certificates that were - signed by the globally-configured CA key (not an org-specific DB CA). - The record is created on first use and has no ``organization_id``. - """ - from gatehouse_app.extensions import db - from gatehouse_app.models.ssh_ca.ca import CA, KeyType - from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config - from gatehouse_app.utils.crypto import compute_ssh_fingerprint - import os - - try: - existing = CA.query.filter_by(name="system-config-ca").first() - if existing: - return existing - - cfg = get_ssh_ca_config() - key_path = cfg.get_str("ca_key_path", "").strip() - pub_key_path = key_path + ".pub" - - if not os.path.exists(pub_key_path): - return None - - with open(pub_key_path) as f: - pub_key = f.read().strip() - - # Load private key for the record (encrypt before storing in DB) - priv_key = "" - if os.path.exists(key_path): - with open(key_path) as f: - raw_priv_key = f.read() - try: - from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key - priv_key = encrypt_ca_key(raw_priv_key) - except Exception: - priv_key = raw_priv_key # fallback: store as-is if encryption unavailable - - fingerprint = compute_ssh_fingerprint(pub_key) - - # Check by fingerprint in case it was created under a different name - existing_by_fp = CA.query.filter_by(fingerprint=fingerprint).first() - if existing_by_fp: - return existing_by_fp - - system_ca = CA( - name="system-config-ca", - description="Global CA loaded from etc/ssh_ca.conf (ca_key_path)", - key_type=KeyType.ED25519, - private_key=priv_key, - public_key=pub_key, - fingerprint=fingerprint, - is_active=True, - default_cert_validity_hours=24, - max_cert_validity_hours=720, - ) - # organization_id is nullable=False in schema — we need a dummy org or - # need to allow NULL. Use None; the DB constraint will tell us quickly. - # If the migration enforces NOT NULL we'll catch the error gracefully. - db.session.add(system_ca) - db.session.commit() - return system_ca - except Exception as exc: - import logging - logging.getLogger(__name__).warning( - f"Could not upsert system-config-ca: {exc}" - ) - try: - db.session.rollback() - except Exception: - pass - return None - - -def _persist_certificate(user_id, ssh_key_id, ca, signing_response, request_ip=None, cert_type_str='user', cert_identity=None): - """Save a signed certificate to the ssh_certificates table. - - Args: - user_id: UUID of the user - ssh_key_id: UUID of the SSH key that was signed. May be None for host - certificates issued against a raw public key (no pre-registered - SSHKey DB record). When None the record is still persisted - but ``ssh_key_id`` is left NULL (requires nullable FK migration). - ca: CA model instance (may be None — cert still returned but not persisted) - signing_response: SSHCertificateSigningResponse - request_ip: Client IP address - cert_type_str: 'user' or 'host' (from the sign request) - cert_identity: Rich OpenSSH key_id string (e.g. "user@host (Name) [org:slug]"). - Falls back to str(ssh_key_id) when not provided. - - Returns: - SSHCertificate instance or None if persistence failed - """ - if ca is None: - return None - - try: - from gatehouse_app.extensions import db - from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus - from gatehouse_app.models.ssh_ca.ca import CertType - - try: - resolved_cert_type = CertType(cert_type_str) - except ValueError: - resolved_cert_type = CertType.USER - - cert_record = SSHCertificate( - ca_id=ca.id, - user_id=user_id, - ssh_key_id=ssh_key_id, # None is OK for host certs (nullable FK) - certificate=signing_response.certificate, - serial=signing_response.serial, - key_id=cert_identity or (str(ssh_key_id) if ssh_key_id else "host-cert"), - cert_type=resolved_cert_type, - principals=signing_response.principals, - valid_after=signing_response.valid_after, - valid_before=signing_response.valid_before, - revoked=False, - status=CertificateStatus.ISSUED, - request_ip=request_ip, - ) - db.session.add(cert_record) - db.session.commit() - return cert_record - except Exception as exc: - import logging - logging.getLogger(__name__).warning( - f"Failed to persist certificate to DB: {exc}" - ) - try: - from gatehouse_app.extensions import db as _db - _db.session.rollback() - except Exception: - pass - return None - - - -def _get_merged_dept_cert_policy(user_id): - """Return a merged cert policy view for the given user across all their departments. - - Rules for merging when a user belongs to multiple departments: - - ``allow_user_expiry``: True only if ALL departments allow it. - - ``default_expiry_hours``: minimum across departments (most restrictive). - - ``max_expiry_hours``: minimum across departments (most restrictive). - - ``extensions``: intersection — only extensions allowed by ALL departments. - - Returns a plain dict with keys: - allow_user_expiry, default_expiry_hours, max_expiry_hours, extensions - Or None if the user has no department memberships or no policies are configured. - """ - from gatehouse_app.models.organization.department import DepartmentMembership - from gatehouse_app.models.organization.department_cert_policy import DepartmentCertPolicy, STANDARD_EXTENSIONS - - memberships = DepartmentMembership.query.filter_by(user_id=user_id, deleted_at=None).all() - dept_ids = [m.department_id for m in memberships if m.department and m.department.deleted_at is None] - if not dept_ids: - return None - - policies = DepartmentCertPolicy.query.filter( - DepartmentCertPolicy.department_id.in_(dept_ids), - DepartmentCertPolicy.deleted_at.is_(None), - ).all() - if not policies: - return None - - allow_user_expiry = all(p.allow_user_expiry for p in policies) - default_expiry_hours = min(p.default_expiry_hours for p in policies) - max_expiry_hours = min(p.max_expiry_hours for p in policies) - - # Intersection of all_extensions() across policies - ext_sets = [set(p.all_extensions()) for p in policies] - extensions = list(ext_sets[0].intersection(*ext_sets[1:])) - - return { - "allow_user_expiry": allow_user_expiry, - "default_expiry_hours": default_expiry_hours, - "max_expiry_hours": max_expiry_hours, - "extensions": extensions, - } - - -@ssh_bp.route('/dept-cert-policy', methods=['GET']) -@login_required -def get_my_dept_cert_policy(): - """Return the merged department certificate policy for the current user. - - Admins always get allow_user_expiry=True so the frontend shows the expiry - picker for them regardless of the member-facing toggle setting. - """ - from gatehouse_app.models.organization.organization_member import OrganizationMember - from gatehouse_app.models.organization.department_cert_policy import STANDARD_EXTENSIONS - from gatehouse_app.utils.constants import OrganizationRole - - user = g.current_user - user_id = user.id - - # Check if caller is an org admin/owner - is_org_admin = OrganizationMember.query.filter( - OrganizationMember.user_id == user_id, - OrganizationMember.role.in_(["OWNER", "ADMIN"]), - OrganizationMember.deleted_at == None, - ).first() is not None - - policy = _get_merged_dept_cert_policy(user_id) - if policy is None: - policy = { - "allow_user_expiry": is_org_admin, # admins default to True even without a dept policy - "default_expiry_hours": 1, - "max_expiry_hours": 24, - "extensions": list(STANDARD_EXTENSIONS), - } - elif is_org_admin: - # Override allow_user_expiry for admins — they can always pick - policy = {**policy, "allow_user_expiry": True} - - return api_response(data={"policy": policy}, message="Certificate policy retrieved") - - -@ssh_bp.route('/keys', methods=['GET']) -@login_required -def list_ssh_keys(): - """Get all SSH keys for current user.""" - user_id = g.current_user.id - - keys = ssh_key_service.get_user_ssh_keys(user_id) - return api_response( - data={ - 'keys': [k.to_dict() for k in keys], - 'count': len(keys), - }, - message="SSH keys retrieved successfully" - ) - - -@ssh_bp.route('/keys', methods=['POST']) -@login_required -def add_ssh_key(): - """Add a new SSH public key for current user.""" - user_id = g.current_user.id - - data = request.get_json() - if not data: - return api_response(success=False, message='No JSON data provided', status=400, error_type='BAD_REQUEST') - - public_key = data.get('public_key') or data.get('key') - description = data.get('description') - - if not public_key: - return api_response(success=False, message='public_key is required', status=400, error_type='BAD_REQUEST') - - try: - ssh_key = ssh_key_service.add_ssh_key( - user_id=user_id, - public_key=public_key, - description=description, - ) - - AuditLog.log( - action=AuditAction.SSH_KEY_ADDED, - user_id=user_id, - resource_type='SSHKey', - resource_id=ssh_key.id, - ip_address=request.remote_addr, - ) - - return api_response(success=True, message='SSH key added', data=ssh_key.to_dict(), status=201) - - except SSHKeyAlreadyExistsError as e: - return api_response(success=False, message=e.message, status=409, error_type='SSH_KEY_ALREADY_EXISTS') - except IntegrityError: - return api_response(success=False, message='SSH key already exists', status=409, error_type='SSH_KEY_ALREADY_EXISTS') - except SSHKeyError as e: - return api_response(success=False, message=str(e), status=400, error_type='SSH_KEY_ERROR') - except ValidationError as e: - return api_response(success=False, message=str(e), status=400, error_type='VALIDATION_ERROR') - - -@ssh_bp.route('/keys/', methods=['GET']) -@login_required -def get_ssh_key(key_id): - """Get a specific SSH key.""" - user_id = g.current_user.id - - try: - ssh_key = ssh_key_service.get_ssh_key(key_id) - - if ssh_key.user_id != user_id: - return api_response(success=False, message='Forbidden', status=403, error_type='FORBIDDEN') - - return api_response(success=True, message='SSH key retrieved', data=ssh_key.to_dict(), status=200) - - except SSHKeyNotFoundError: - return api_response(success=False, message='SSH key not found', status=404, error_type='NOT_FOUND') - - -@ssh_bp.route('/keys/', methods=['DELETE']) -@login_required -def delete_ssh_key(key_id): - """Delete an SSH key.""" - user_id = g.current_user.id - - try: - ssh_key = ssh_key_service.get_ssh_key(key_id) - - if ssh_key.user_id != user_id: - return api_response(success=False, message='Forbidden', status=403, error_type='FORBIDDEN') - - ssh_key_service.delete_ssh_key(key_id) - - AuditLog.log( - action=AuditAction.SSH_KEY_DELETED, - user_id=user_id, - resource_type='SSHKey', - resource_id=key_id, - ip_address=request.remote_addr, - ) - - return api_response(success=True, message='SSH key deleted', data={'status': 'deleted'}, status=200) - - except SSHKeyNotFoundError: - return api_response(success=False, message='SSH key not found', status=404, error_type='NOT_FOUND') - - -@ssh_bp.route('/keys//verify', methods=['GET', 'POST']) -@login_required -def verify_ssh_key(key_id): - """Generate or verify SSH key ownership challenge.""" - user_id = g.current_user.id - - try: - ssh_key = ssh_key_service.get_ssh_key(key_id) - - if ssh_key.user_id != user_id: - return api_response(success=False, message='Forbidden', status=403, error_type='FORBIDDEN') - - # GET — return a fresh challenge - if request.method == 'GET': - challenge = ssh_key_service.generate_verification_challenge(key_id) - return api_response(success=True, message='Challenge generated', data={ - 'challenge_text': challenge, - 'validationText': challenge, - 'key_id': key_id, - }, status=200) - - # POST — verify signature or generate challenge - data = request.get_json() or {} - action = data.get('action', 'verify_signature') - - if action == 'verify_signature': - signature = data.get('signature') - if not signature: - return api_response(success=False, message='signature is required', status=400, error_type='BAD_REQUEST') - - try: - verified = ssh_key_service.verify_ssh_key_ownership(key_id, signature) - - AuditLog.log( - action=AuditAction.SSH_KEY_VERIFIED, - user_id=user_id, - resource_type='SSHKey', - resource_id=key_id, - ip_address=request.remote_addr, - success=verified, - ) - - return api_response(success=True, message='Verification complete', data={'verified': verified}, status=200) - - except Exception as e: - AuditLog.log( - action=AuditAction.SSH_KEY_VALIDATION_FAILED, - user_id=user_id, - resource_type='SSHKey', - resource_id=key_id, - ip_address=request.remote_addr, - success=False, - error_message=str(e), - ) - return api_response(success=False, message=str(e), status=400, error_type='VERIFICATION_FAILED') - - else: # generate_challenge - challenge = ssh_key_service.generate_verification_challenge(key_id) - return api_response(success=True, message='Challenge generated', data={ - 'challenge_text': challenge, - 'challenge': challenge, - }, status=200) - - except SSHKeyNotFoundError: - return api_response(success=False, message='SSH key not found', status=404, error_type='NOT_FOUND') - - -@ssh_bp.route('/keys//update-description', methods=['PATCH']) -@login_required -def update_ssh_key_description(key_id): - """Update SSH key description.""" - user_id = g.current_user.id - - data = request.get_json() - if not data or 'description' not in data: - return api_response(success=False, message='description is required', status=400, error_type='BAD_REQUEST') - - try: - ssh_key = ssh_key_service.get_ssh_key(key_id) - - if ssh_key.user_id != user_id: - return api_response(success=False, message='Forbidden', status=403, error_type='FORBIDDEN') - - updated_key = ssh_key_service.update_ssh_key_description(key_id, data['description']) - - return api_response(success=True, message='Description updated', data=updated_key.to_dict(), status=200) - - except SSHKeyNotFoundError: - return api_response(success=False, message='SSH key not found', status=404, error_type='NOT_FOUND') - - -@ssh_bp.route('/sign', methods=['POST']) -@login_required -def sign_certificate(): - """Sign an SSH certificate for the current user.""" - user = g.current_user - user_id = user.id - - # ── Check account suspension ────────────────────────────────────────────── - from gatehouse_app.utils.constants import UserStatus - if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED): - return api_response( - success=False, - message="Your account is suspended. Contact an administrator.", - status=403, - error_type="ACCOUNT_SUSPENDED", - ) - - data = request.get_json() - if not data: - return api_response(success=False, message="No JSON data provided", status=400, error_type="BAD_REQUEST") - - requested_principals = data.get('principals') or [] - cert_type = data.get('cert_type', 'user') - key_id = data.get('key_id') or data.get('cert_id') - expiry_hours = data.get('expiry_hours') - - # ── Log the request ─────────────────────────────────────────────────────── - AuditLog.log( - action=AuditAction.SSH_CERT_REQUESTED, - user_id=user_id, - resource_type='SSHCertificate', - ip_address=request.remote_addr, - description=( - f'{user.email} requested a certificate' - + (f' for principals: {", ".join(requested_principals)}' if requested_principals else '') - ), - ) - - # ── Resolve which principals the user is allowed to use ────────────────── - from gatehouse_app.models.organization.organization_member import OrganizationMember - from gatehouse_app.models.organization.principal import Principal, PrincipalMembership - from gatehouse_app.models.organization.department import DepartmentMembership, DepartmentPrincipal - from gatehouse_app.utils.constants import OrganizationRole - - allowed_principal_names = set() - - memberships = OrganizationMember.query.filter_by(user_id=user_id).all() - for om in memberships: - org = om.organization - if not org or org.deleted_at is not None: - continue - role = om.role - if role in (OrganizationRole.ADMIN, OrganizationRole.OWNER): - # Admin/owner can use any principal in the org - for p in Principal.query.filter_by(organization_id=org.id, deleted_at=None).all(): - allowed_principal_names.add(p.name) - else: - # Direct memberships - for pm in PrincipalMembership.query.filter_by(user_id=user_id, deleted_at=None).all(): - if pm.principal and pm.principal.organization_id == org.id and pm.principal.deleted_at is None: - allowed_principal_names.add(pm.principal.name) - # Via department - for dm in DepartmentMembership.query.filter_by(user_id=user_id, deleted_at=None).all(): - if dm.department and dm.department.organization_id == org.id and dm.department.deleted_at is None: - for dp in DepartmentPrincipal.query.filter_by(department_id=dm.department_id, deleted_at=None).all(): - if dp.principal and dp.principal.deleted_at is None: - allowed_principal_names.add(dp.principal.name) - - # ── Determine final principals list ───────────────────────────────────── - if not requested_principals: - # Auto-resolve: use all principals the user is assigned to - principals = list(allowed_principal_names) - if not principals: - return api_response( - success=False, - message="You have no principals assigned. Ask an admin to add you to a principal.", - status=400, - error_type="NO_PRINCIPALS", - ) - else: - # Validate each requested principal is within the user's allowed set - invalid = [p for p in requested_principals if p not in allowed_principal_names] - if invalid: - return api_response( - success=False, - message=f"You are not authorised to request principals: {', '.join(invalid)}", - status=403, - error_type="UNAUTHORIZED_PRINCIPALS", - ) - principals = requested_principals - - # ── Key resolution ──────────────────────────────────────────────────────── - if not key_id: - verified_keys = ssh_key_service.get_user_verified_ssh_keys(user_id) - if not verified_keys: - return api_response( - success=False, - message="No verified SSH keys found. Verify a key before requesting a certificate.", - status=400, - error_type="NO_VERIFIED_KEYS", - ) - key_id = verified_keys[0].id - - try: - ssh_key = ssh_key_service.get_ssh_key(key_id) - except SSHKeyNotFoundError: - return api_response(success=False, message="SSH key not found", status=404, error_type="NOT_FOUND") - - if ssh_key.user_id != user_id: - return api_response(success=False, message="Forbidden", status=403, error_type="FORBIDDEN") - - if not ssh_key.verified: - return api_response( - success=False, - message="SSH key is not verified. Verify it before requesting a certificate.", - status=400, - error_type="KEY_NOT_VERIFIED", - ) - - db_ca = _get_org_ca_for_user(user, ca_type=cert_type) - if db_ca is None: - return api_response( - success=False, - message=( - "No active Certificate Authority is configured for your organization. " - "An admin must generate a CA on the Certificate Authorities page before " - "certificates can be issued." - ), - status=503, - error_type="CA_NOT_CONFIGURED", - ) - - # Determine if the caller is an org admin/owner (admins can always choose expiry) - is_org_admin = any( - om.role in (OrganizationRole.ADMIN, OrganizationRole.OWNER) - for om in memberships - if om.organization and om.organization.deleted_at is None - ) - - # ── Apply department certificate policy ─────────────────────────────────── - dept_policy = _get_merged_dept_cert_policy(user_id) - if dept_policy: - if is_org_admin: - # Admins can always choose their own expiry, but still capped at dept max - if expiry_hours is not None: - expiry_hours = min(int(expiry_hours), dept_policy["max_expiry_hours"]) - elif not dept_policy["allow_user_expiry"]: - # Regular members: ignore user-requested expiry; use dept default - expiry_hours = dept_policy["default_expiry_hours"] - else: - # Regular members allowed to pick, cap at dept maximum - if expiry_hours is not None: - expiry_hours = min(int(expiry_hours), dept_policy["max_expiry_hours"]) - policy_extensions = dept_policy["extensions"] - else: - policy_extensions = None # let signing service use its own defaults - - # ── Build rich key_id identity for the OpenSSH cert ───────────────────── - # This appears in `ssh-keygen -L -f cert.pub` as the Key ID field and - # is stored in the DB cert record so it's auditable. - org_slugs = sorted({ - om.organization.slug - for om in memberships - if om.organization and om.organization.deleted_at is None - and getattr(om.organization, 'slug', None) - }) - org_slug = org_slugs[0] if org_slugs else "unknown" - full_name = getattr(user, 'full_name', None) or getattr(user, 'name', None) or "unknown" - cert_identity = f"{user.email} ({full_name}) [org:{org_slug}]" - - signing_request = SSHCertificateSigningRequest( - ssh_public_key=ssh_key.payload, - principals=principals, - cert_type=cert_type, - key_id=cert_identity, - expiry_hours=int(expiry_hours) if expiry_hours else None, - extensions=policy_extensions, - ) - validation_errors = signing_request.validate() - if validation_errors: - return api_response( - success=False, - message="Invalid signing request", - status=400, - error_type="VALIDATION_ERROR", - error_details={"errors": validation_errors}, - ) - - try: - from gatehouse_app.utils.ca_key_encryption import decrypt_ca_key - ca_private_key_pem = decrypt_ca_key(db_ca.private_key) - response = ssh_ca_service.sign_certificate( - signing_request, ca_private_key=ca_private_key_pem, ca_obj=db_ca - ) - except SSHCertificateError as e: - AuditLog.log( - action=AuditAction.SSH_CERT_FAILED, - user_id=user_id, - resource_type='SSHCertificate', - ip_address=request.remote_addr, - success=False, - error_message=str(e), - ) - return api_response(success=False, message=str(e), status=400, error_type="SIGNING_FAILED") - except Exception as e: - AuditLog.log( - action=AuditAction.SSH_CERT_FAILED, - user_id=user_id, - resource_type='SSHCertificate', - ip_address=request.remote_addr, - success=False, - error_message=str(e), - ) - return api_response(success=False, message="Certificate signing failed", status=500, error_type="SERVER_ERROR") - - cert_record = _persist_certificate( - user_id=user_id, - ssh_key_id=key_id, - ca=db_ca, - signing_response=response, - request_ip=request.remote_addr, - cert_type_str=cert_type, - cert_identity=cert_identity, - ) - - AuditLog.log( - action=AuditAction.SSH_CERT_ISSUED, - user_id=user_id, - resource_type='SSHCertificate', - resource_id=cert_record.id if cert_record else key_id, - ip_address=request.remote_addr, - description=( - f'Certificate serial={response.serial} issued for {user.email}; ' - f'principals: {", ".join(principals)}' - ), - extra_data={ - 'serial': response.serial, - 'key_id': cert_identity, - 'principals': principals, - 'ca_id': str(db_ca.id), - 'ssh_key_id': str(key_id), - }, - ) - - if cert_record: - CertificateAuditLog.log( - certificate_id=cert_record.id, - action='issued', - user_id=user_id, - ip_address=request.remote_addr, - user_agent=request.headers.get('User-Agent'), - message=( - f'Certificate serial={response.serial} issued for {user.email}; ' - f'principals: {", ".join(principals)}' - ), - extra_data={ - 'serial': response.serial, - 'key_id': cert_identity, - 'principals': principals, - 'ca_id': str(db_ca.id), - 'ssh_key_id': str(key_id), - 'valid_after': response.valid_after.isoformat() if response.valid_after else None, - 'valid_before': response.valid_before.isoformat() if response.valid_before else None, - }, - success=True, - ) - - result = { - 'certificate': response.certificate, - 'serial': response.serial, - 'principals': response.principals, - 'valid_after': response.valid_after.isoformat() if response.valid_after else None, - 'valid_before': response.valid_before.isoformat() if response.valid_before else None, - } - if cert_record: - result['cert_id'] = str(cert_record.id) - - return api_response(data=result, message="Certificate signed successfully", status=201) - - -# --------------------------------------------------------------------------- -# Host certificate issuance (admin-only) -# --------------------------------------------------------------------------- - -def _classify_ssh_key_material(raw: str) -> str: - """Classify a raw SSH key string. - - Returns one of: 'certificate', 'public_key', 'private_key', 'unknown'. - This mirrors the frontend ``classifySshKeyMaterial`` helper so that the - API produces the same guardrails even when called directly (e.g. via CLI). - """ - import re - line = raw.strip().split()[0] if raw.strip() else "" - if re.search(r"-cert-v01@openssh\.com$", line): - return "certificate" - if re.match( - r"^(ssh-ed25519|ssh-rsa|ssh-dss|ecdsa-sha2-nistp\d+|sk-ssh-ed25519@openssh\.com)$", - line, - ): - return "public_key" - if "BEGIN OPENSSH PRIVATE KEY" in raw or "BEGIN RSA PRIVATE KEY" in raw: - return "private_key" - return "unknown" - - -@ssh_bp.route('/sign/host', methods=['POST']) -@login_required -def sign_host_certificate(): - """Issue a host certificate for a server's host public key. - - This endpoint is admin-only. It accepts a raw OpenSSH host public key - (the kind found in ``/etc/ssh/ssh_host_ed25519_key.pub``), signs it with - the organisation's Host CA, and returns the signed host certificate. - - The certificate should be saved on the server as - ``/etc/ssh/ssh_host_ed25519_key-cert.pub`` and referenced in - ``sshd_config`` as ``HostCertificate``. - - Clients trust the host because they have the Host CA *public key* in their - ``known_hosts`` (via ``@cert-authority``). That key is different from — - and must never be confused with — the certificate returned here. - - Request body (JSON): - host_public_key (str, required): - Raw OpenSSH host public key, e.g. - "ssh-ed25519 AAAA... root@server". - Must NOT be a certificate (ssh-*-cert-v01@openssh.com) or a - private key. - principals (list[str], required): - Hostnames / FQDNs the server is known by, e.g. - ["prod.example.com", "web01.internal"]. - These must match what SSH clients use in their connection target. - validity_hours (int, optional, default=720): - Certificate validity in hours. Host certs are typically - 30 days (720 h) to 1 year (8760 h). - ca_id (str, required): - UUID of the Host CA to sign with. Must be a ``ca_type=host`` CA - belonging to the caller's organisation. - - Returns (201): - certificate, serial, principals, valid_after, valid_before - - Errors: - 400 BAD_REQUEST — pasted material is a cert / private key / unknown - 403 FORBIDDEN — caller is not an org admin/owner - 404 CA_NOT_FOUND — ca_id does not exist or is not a host CA - 422 VALIDATION_ERROR — invalid principals, validity, or public key - 503 CA_NOT_CONFIGURED - """ - from gatehouse_app.models.organization.organization_member import OrganizationMember - from gatehouse_app.models.ssh_ca.ca import CA, CaType - from gatehouse_app.utils.constants import OrganizationRole - from gatehouse_app.utils.ca_key_encryption import decrypt_ca_key - - user = g.current_user - user_id = user.id - - # ── Admin-only gate ─────────────────────────────────────────────────────── - is_admin = OrganizationMember.query.filter( - OrganizationMember.user_id == user_id, - OrganizationMember.role.in_([OrganizationRole.ADMIN, OrganizationRole.OWNER]), - OrganizationMember.deleted_at.is_(None), - ).first() is not None - - if not is_admin: - return api_response( - success=False, - message="Issuing host certificates requires org admin or owner role.", - status=403, - error_type="FORBIDDEN", - ) - - data = request.get_json() - if not data: - return api_response(success=False, message="No JSON data provided", status=400, error_type="BAD_REQUEST") - - host_public_key = (data.get("host_public_key") or "").strip() - principals = data.get("principals") or [] - validity_hours = data.get("validity_hours", 720) - ca_id = (data.get("ca_id") or "").strip() - - # ── Validate host public key material ───────────────────────────────────── - if not host_public_key: - return api_response( - success=False, - message="host_public_key is required.", - status=400, - error_type="BAD_REQUEST", - ) - - key_kind = _classify_ssh_key_material(host_public_key) - if key_kind == "certificate": - return api_response( - success=False, - message=( - "You submitted a certificate (ssh-…-cert-v01@openssh.com), not a host public key. " - "Retrieve the server's host public key with: " - "cat /etc/ssh/ssh_host_ed25519_key.pub" - ), - status=400, - error_type="WRONG_KEY_MATERIAL", - ) - if key_kind == "private_key": - return api_response( - success=False, - message="Private keys must never be submitted here. Use the .pub file.", - status=400, - error_type="WRONG_KEY_MATERIAL", - ) - if key_kind == "unknown": - return api_response( - success=False, - message=( - "Unrecognised key format. " - "Expected an OpenSSH public key starting with ssh-ed25519, ssh-rsa, or ecdsa-sha2-*." - ), - status=400, - error_type="WRONG_KEY_MATERIAL", - ) - - # ── Validate principals ─────────────────────────────────────────────────── - if not principals or not isinstance(principals, list): - return api_response( - success=False, - message="principals must be a non-empty list of hostnames.", - status=422, - error_type="VALIDATION_ERROR", - ) - principals = [str(p).strip() for p in principals if str(p).strip()] - if not principals: - return api_response( - success=False, - message="At least one principal (hostname/FQDN) is required.", - status=422, - error_type="VALIDATION_ERROR", - ) - - # ── Validate validity ───────────────────────────────────────────────────── - try: - validity_hours = int(validity_hours) - if validity_hours < 1: - raise ValueError - except (TypeError, ValueError): - return api_response( - success=False, - message="validity_hours must be a positive integer.", - status=422, - error_type="VALIDATION_ERROR", - ) - - # ── Resolve CA ──────────────────────────────────────────────────────────── - if not ca_id: - return api_response( - success=False, - message="ca_id is required.", - status=400, - error_type="BAD_REQUEST", - ) - - org_ids = [ - m.organization_id - for m in OrganizationMember.query.filter_by(user_id=user_id, deleted_at=None).all() - ] - - # First: find the CA by id (ignoring type) so we can give a specific error - # if it exists but is the wrong type. - any_ca = CA.query.filter( - CA.id == ca_id, - CA.is_active.is_(True), - CA.organization_id.in_(org_ids), - CA.deleted_at.is_(None), - ).first() - - if any_ca and any_ca.ca_type != CaType.HOST: - return api_response( - success=False, - message=( - f"The CA '{any_ca.name}' is a {any_ca.ca_type.value} CA. " - "Host certificates must be signed by a ca_type='host' CA." - ), - status=400, - error_type="WRONG_CA_TYPE", - ) - - host_ca = any_ca # already filtered for org + active + not-deleted above - - if not host_ca: - return api_response( - success=False, - message=( - "Host CA not found, inactive, or you do not have permission to use it. " - "Ensure the CA exists and ca_type is 'host'." - ), - status=404, - error_type="CA_NOT_FOUND", - ) - - # ── Build key_id for the OpenSSH cert Key ID field ──────────────────────── - # Format: "host: [signed-by:]" - primary_principal = principals[0] - cert_identity = f"host:{primary_principal} [signed-by:{user.email}]" - - signing_request = SSHCertificateSigningRequest( - ssh_public_key=host_public_key, - principals=principals, - cert_type="host", - key_id=cert_identity, - expiry_hours=validity_hours, - extensions=[], # Host certs carry no extensions (OpenSSH spec) - critical_options={}, - ) - - validation_errors = signing_request.validate() - if validation_errors: - return api_response( - success=False, - message="Invalid signing request: " + "; ".join(validation_errors), - status=422, - error_type="VALIDATION_ERROR", - ) - - try: - ca_private_key_pem = decrypt_ca_key(host_ca.private_key) - response = ssh_ca_service.sign_certificate( - signing_request, ca_private_key=ca_private_key_pem, ca_obj=host_ca - ) - except Exception as exc: - AuditLog.log( - action=AuditAction.SSH_CERT_FAILED, - user_id=user_id, - resource_type="SSHCertificate", - ip_address=request.remote_addr, - success=False, - error_message=str(exc), - ) - return api_response( - success=False, - message=f"Host certificate signing failed: {exc}", - status=500, - error_type="SIGNING_FAILED", - ) - - # Persist a cert record linked to the issuing admin (no ssh_key_id FK - # because this was a raw key, not a registered user key). - # We reuse _persist_certificate with ssh_key_id=ca_id as a stable sentinel. - cert_record = _persist_certificate( - user_id=user_id, - ssh_key_id=None, # host certs are not tied to a user SSH key record - ca=host_ca, - signing_response=response, - request_ip=request.remote_addr, - cert_type_str="host", - cert_identity=cert_identity, - ) - - AuditLog.log( - action=AuditAction.SSH_CERT_ISSUED, - user_id=user_id, - resource_type="SSHCertificate", - resource_id=cert_record.id if cert_record else None, - ip_address=request.remote_addr, - description=( - f"Host certificate serial={response.serial} issued for " - f"{primary_principal} by {user.email}" - ), - extra_data={ - "serial": response.serial, - "principals": principals, - "ca_id": str(host_ca.id), - "cert_type": "host", - }, - ) - - result = { - "certificate": response.certificate, - "serial": response.serial, - "principals": response.principals, - "valid_after": response.valid_after.isoformat() if response.valid_after else None, - "valid_before": response.valid_before.isoformat() if response.valid_before else None, - } - if cert_record: - result["cert_id"] = str(cert_record.id) - - return api_response(data=result, message="Host certificate issued successfully", status=201) - - -@ssh_bp.route('/certificates', methods=['GET']) -@login_required -def list_certificates(): - """List all SSH certificates issued for the current user.""" - user_id = g.current_user.id - - try: - from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate - certs = ( - SSHCertificate.query - .filter_by(user_id=user_id, deleted_at=None) - .order_by(SSHCertificate.created_at.desc()) - .all() - ) - - return api_response( - data={ - 'certificates': [c.to_dict() for c in certs], - 'count': len(certs), - }, - message="Certificates retrieved successfully" - ) - except Exception as e: - return api_response( - success=False, - message=str(e), - status=500, - error_type='INTERNAL_ERROR' - ) - - -@ssh_bp.route('/certificates/', methods=['GET']) -@login_required -def get_certificate(cert_id): - """Get a specific issued certificate (metadata only).""" - user_id = g.current_user.id - - try: - from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate - cert = SSHCertificate.query.filter_by(id=cert_id, deleted_at=None).first() - if not cert: - return api_response(success=False, message='Certificate not found', status=404, error_type='NOT_FOUND') - if cert.user_id != user_id: - return api_response(success=False, message='Forbidden', status=403, error_type='FORBIDDEN') - data = cert.to_dict() - data['certificate'] = cert.certificate - return api_response(success=True, message='Certificate retrieved', data=data, status=200) - except Exception as e: - return api_response(success=False, message=str(e), status=500, error_type='INTERNAL_ERROR') - - -@ssh_bp.route('/certificates//revoke', methods=['POST']) -@login_required -def revoke_certificate(cert_id): - """Revoke an issued certificate.""" - user_id = g.current_user.id - - data = request.get_json() or {} - reason = data.get('reason', 'User requested revocation') - - try: - from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate - cert = SSHCertificate.query.filter_by(id=cert_id, deleted_at=None).first() - if not cert: - return api_response(success=False, message='Certificate not found', status=404, error_type='NOT_FOUND') - if cert.user_id != user_id: - return api_response(success=False, message='Forbidden', status=403, error_type='FORBIDDEN') - if cert.revoked: - return api_response(success=False, message='Certificate is already revoked', status=409, error_type='ALREADY_REVOKED') - - cert.revoke(reason=reason) - - AuditLog.log( - action=AuditAction.SSH_CERT_REVOKED, - user_id=user_id, - resource_type='SSHCertificate', - resource_id=cert_id, - ip_address=request.remote_addr, - description=f'Revoked: {reason}', - ) - - CertificateAuditLog.log( - certificate_id=cert_id, - action='revoked', - user_id=user_id, - ip_address=request.remote_addr, - user_agent=request.headers.get('User-Agent'), - message=f'Certificate revoked: {reason}', - success=True, - ) - - return api_response( - success=True, - message='Certificate revoked successfully', - data={'status': 'revoked', 'cert_id': cert_id, 'reason': reason}, - status=200, - ) - except Exception as e: - return api_response(success=False, message=str(e), status=500, error_type='INTERNAL_ERROR') - - -@ssh_bp.route('/ca/public-key', methods=['GET']) -@login_required -def get_ca_public_key(): - """ - Return the CA public key for this user's organization. - - Server admins should add this key to their host's ``TrustedUserCAKeys`` - directive so that certificates issued by gatehouse are trusted. - - Query parameters: - ca_type: 'user' (default) or 'host' — which CA's public key to return - format: 'openssh' (default) or 'text' — affects Content-Type only - - Returns: - { "public_key": "ssh-ed25519 AAAA...", - "fingerprint": "SHA256:...", - "ca_name": "..." } - """ - user = g.current_user - ca_type = request.args.get("ca_type", "user") - if ca_type not in ("user", "host"): - return api_response( - success=False, - message="ca_type must be 'user' or 'host'", - status=400, - error_type="BAD_REQUEST", - ) - - db_ca = _get_org_ca_for_user(user, ca_type=ca_type) - if db_ca: - return api_response( - data={ - 'public_key': db_ca.public_key, - 'fingerprint': db_ca.fingerprint, - 'ca_name': db_ca.name, - 'ca_type': ca_type, - 'source': 'db', - }, - message="CA public key retrieved successfully" - ) - - return api_response( - success=False, - message=( - f"No {ca_type} CA is configured for your organization. " - "An admin must generate one on the Certificate Authorities page." - ), - status=404, - error_type="CA_NOT_CONFIGURED", - ) - - -# --------------------------------------------------------------------------- -# CA Permissions -# --------------------------------------------------------------------------- - -@ssh_bp.route('/ca//permissions', methods=['GET']) -@login_required -def list_ca_permissions(ca_id): - """List permissions for a Certificate Authority. - - Returns: - 200: { ca_id, permissions: [...], open_to_all: bool } - 403: Not admin/owner - 404: CA not found - """ - from gatehouse_app.models.ssh_ca.ca import CA, CAPermission - from gatehouse_app.models.organization.organization_member import OrganizationMember - from gatehouse_app.utils.constants import OrganizationRole - - user = g.current_user - - ca = CA.query.filter_by(id=ca_id, deleted_at=None).first() - if not ca: - return api_response(success=False, message="CA not found", status=404, error_type="NOT_FOUND") - - # Verify user is admin/owner of the CA's org - if ca.organization_id: - membership = OrganizationMember.query.filter_by( - organization_id=ca.organization_id, - user_id=user.id, - deleted_at=None, - ).first() - if not membership or membership.role not in (OrganizationRole.ADMIN, OrganizationRole.OWNER): - return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN") - - perms = CAPermission.query.filter_by(ca_id=ca_id, deleted_at=None).all() - perm_list = [] - for p in perms: - d = p.to_dict() - d["user_email"] = p.user.email if p.user else None - perm_list.append(d) - - return api_response( - data={ - "ca_id": ca_id, - "permissions": perm_list, - "open_to_all": len(perms) == 0, - }, - message="CA permissions retrieved", - ) - - -@ssh_bp.route('/ca//permissions', methods=['POST']) -@login_required -def add_ca_permission(ca_id): - """Grant a user permission on a Certificate Authority. - - Request body: - user_id: UUID of the user to grant access - permission: "sign" or "admin" (default: "sign") - - Returns: - 201: Permission granted - 400: Validation error - 403: Not admin/owner - 404: CA or user not found - 409: Permission already exists - """ - from gatehouse_app.models.ssh_ca.ca import CA, CAPermission - from gatehouse_app.models.organization.organization_member import OrganizationMember - from gatehouse_app.models.user import User - from gatehouse_app.utils.constants import OrganizationRole, AuditAction - from gatehouse_app.models import AuditLog - from gatehouse_app.extensions import db - - user = g.current_user - - ca = CA.query.filter_by(id=ca_id, deleted_at=None).first() - if not ca: - return api_response(success=False, message="CA not found", status=404, error_type="NOT_FOUND") - - # Verify user is admin/owner of the CA's org - if ca.organization_id: - membership = OrganizationMember.query.filter_by( - organization_id=ca.organization_id, - user_id=user.id, - deleted_at=None, - ).first() - if not membership or membership.role not in (OrganizationRole.ADMIN, OrganizationRole.OWNER): - return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN") - - data = request.get_json() or {} - target_user_id = (data.get("user_id") or "").strip() - permission = data.get("permission", "sign") - - if not target_user_id: - return api_response(success=False, message="user_id is required", status=400, error_type="VALIDATION_ERROR") - if permission not in ("sign", "admin"): - return api_response( - success=False, - message="permission must be 'sign' or 'admin'", - status=400, - error_type="VALIDATION_ERROR", - ) - - target_user = User.query.filter_by(id=target_user_id, deleted_at=None).first() - if not target_user: - return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") - - # Check for duplicate - existing = CAPermission.query.filter_by( - ca_id=ca_id, user_id=target_user_id, deleted_at=None - ).first() - if existing: - # Update permission level if different - if existing.permission != permission: - existing.permission = permission - db.session.commit() - d = existing.to_dict() - d["user_email"] = target_user.email - return api_response( - data={"message": "Permission updated", "permission": d}, - message="Permission updated", - ) - return api_response( - success=False, - message="User already has this permission on the CA", - status=409, - error_type="DUPLICATE", - ) - - perm = CAPermission( - ca_id=ca_id, - user_id=target_user_id, - permission=permission, - ) - db.session.add(perm) - db.session.commit() - - AuditLog.log( - action=AuditAction.CA_UPDATED, - user_id=user.id, - resource_type="CAPermission", - resource_id=perm.id, - ip_address=request.remote_addr, - description=f"Granted '{permission}' on CA '{ca.name}' to user {target_user.email}", - ) - - d = perm.to_dict() - d["user_email"] = target_user.email - return api_response( - data={"message": "Permission granted", "permission": d}, - message="Permission granted", - status=201, - ) - - -@ssh_bp.route('/ca//permissions/', methods=['DELETE']) -@login_required -def remove_ca_permission(ca_id, target_user_id): - """Revoke a user's permission on a Certificate Authority. - - Returns: - 200: Permission revoked - 403: Not admin/owner - 404: CA or permission not found - """ - from gatehouse_app.models.ssh_ca.ca import CA, CAPermission - from gatehouse_app.models.organization.organization_member import OrganizationMember - from gatehouse_app.utils.constants import OrganizationRole, AuditAction - from gatehouse_app.models import AuditLog - from gatehouse_app.extensions import db - - user = g.current_user - - ca = CA.query.filter_by(id=ca_id, deleted_at=None).first() - if not ca: - return api_response(success=False, message="CA not found", status=404, error_type="NOT_FOUND") - - # Verify user is admin/owner of the CA's org - if ca.organization_id: - membership = OrganizationMember.query.filter_by( - organization_id=ca.organization_id, - user_id=user.id, - deleted_at=None, - ).first() - if not membership or membership.role not in (OrganizationRole.ADMIN, OrganizationRole.OWNER): - return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN") - - perm = CAPermission.query.filter_by( - ca_id=ca_id, user_id=target_user_id, deleted_at=None - ).first() - if not perm: - return api_response(success=False, message="Permission not found", status=404, error_type="NOT_FOUND") - - perm.delete(soft=True) - - AuditLog.log( - action=AuditAction.CA_UPDATED, - user_id=user.id, - resource_type="CAPermission", - resource_id=perm.id, - ip_address=request.remote_addr, - description=f"Revoked permission on CA '{ca.name}' from user {target_user_id}", - ) - - return api_response( - data={}, - message="Permission revoked", - ) - diff --git a/gatehouse_app/api/v1/ssh/__init__.py b/gatehouse_app/api/v1/ssh/__init__.py new file mode 100644 index 0000000..5368075 --- /dev/null +++ b/gatehouse_app/api/v1/ssh/__init__.py @@ -0,0 +1,3 @@ +"""SSH blueprint subpackage. Exports ssh_bp for registration.""" +from gatehouse_app.api.v1.ssh._helpers import ssh_bp +from gatehouse_app.api.v1.ssh import keys, certs, admin diff --git a/gatehouse_app/api/v1/ssh/_helpers.py b/gatehouse_app/api/v1/ssh/_helpers.py new file mode 100644 index 0000000..4e244b1 --- /dev/null +++ b/gatehouse_app/api/v1/ssh/_helpers.py @@ -0,0 +1,174 @@ +"""Shared helpers for the SSH subpackage.""" +import logging +from flask import Blueprint, request, g +from gatehouse_app.services.ssh_key_service import SSHKeyService +from gatehouse_app.services.ssh_ca_signing_service import SSHCASigningService + +ssh_bp = Blueprint('ssh', __name__, url_prefix='/ssh') +ssh_key_service = SSHKeyService() +ssh_ca_service = SSHCASigningService() + +_logger = logging.getLogger(__name__) + + +def _get_org_ca_for_user(user, ca_type: str = "user"): + try: + from gatehouse_app.models.ssh_ca.ca import CA, CaType + org_ids = [m.organization_id for m in user.organization_memberships] + if not org_ids: + return None + return CA.query.filter( + CA.organization_id.in_(org_ids), + CA.ca_type == CaType(ca_type), + CA.is_active == True, # noqa: E712 + ).first() + except Exception: + return None + + +def _get_or_create_system_ca(): + from gatehouse_app.extensions import db + from gatehouse_app.models.ssh_ca.ca import CA, KeyType + from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config + from gatehouse_app.utils.crypto import compute_ssh_fingerprint + import os + + try: + existing = CA.query.filter_by(name="system-config-ca").first() + if existing: + return existing + + cfg = get_ssh_ca_config() + key_path = cfg.get_str("ca_key_path", "").strip() + pub_key_path = key_path + ".pub" + + if not os.path.exists(pub_key_path): + return None + + with open(pub_key_path) as f: + pub_key = f.read().strip() + + priv_key = "" + if os.path.exists(key_path): + with open(key_path) as f: + raw_priv_key = f.read() + try: + from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key + priv_key = encrypt_ca_key(raw_priv_key) + except Exception: + priv_key = raw_priv_key + + fingerprint = compute_ssh_fingerprint(pub_key) + + existing_by_fp = CA.query.filter_by(fingerprint=fingerprint).first() + if existing_by_fp: + return existing_by_fp + + system_ca = CA( + name="system-config-ca", + description="Global CA loaded from etc/ssh_ca.conf (ca_key_path)", + key_type=KeyType.ED25519, + private_key=priv_key, + public_key=pub_key, + fingerprint=fingerprint, + is_active=True, + default_cert_validity_hours=24, + max_cert_validity_hours=720, + ) + db.session.add(system_ca) + db.session.commit() + return system_ca + except Exception as exc: + _logger.warning(f"Could not upsert system-config-ca: {exc}") + try: + db.session.rollback() + except Exception: + pass + return None + + +def _persist_certificate(user_id, ssh_key_id, ca, signing_response, request_ip=None, cert_type_str='user', cert_identity=None): + if ca is None: + return None + + try: + from gatehouse_app.extensions import db + from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus + from gatehouse_app.models.ssh_ca.ca import CertType + + try: + resolved_cert_type = CertType(cert_type_str) + except ValueError: + resolved_cert_type = CertType.USER + + cert_record = SSHCertificate( + ca_id=ca.id, + user_id=user_id, + ssh_key_id=ssh_key_id, + certificate=signing_response.certificate, + serial=signing_response.serial, + key_id=cert_identity or (str(ssh_key_id) if ssh_key_id else "host-cert"), + cert_type=resolved_cert_type, + principals=signing_response.principals, + valid_after=signing_response.valid_after, + valid_before=signing_response.valid_before, + revoked=False, + status=CertificateStatus.ISSUED, + request_ip=request_ip, + ) + db.session.add(cert_record) + db.session.commit() + return cert_record + except Exception as exc: + _logger.warning(f"Failed to persist certificate to DB: {exc}") + try: + from gatehouse_app.extensions import db as _db + _db.session.rollback() + except Exception: + pass + return None + + +def _get_merged_dept_cert_policy(user_id): + from gatehouse_app.models.organization.department import DepartmentMembership + from gatehouse_app.models.organization.department_cert_policy import DepartmentCertPolicy + + memberships = DepartmentMembership.query.filter_by(user_id=user_id, deleted_at=None).all() + dept_ids = [m.department_id for m in memberships if m.department and m.department.deleted_at is None] + if not dept_ids: + return None + + policies = DepartmentCertPolicy.query.filter( + DepartmentCertPolicy.department_id.in_(dept_ids), + DepartmentCertPolicy.deleted_at.is_(None), + ).all() + if not policies: + return None + + allow_user_expiry = all(p.allow_user_expiry for p in policies) + default_expiry_hours = min(p.default_expiry_hours for p in policies) + max_expiry_hours = min(p.max_expiry_hours for p in policies) + ext_sets = [set(p.all_extensions()) for p in policies] + extensions = list(ext_sets[0].intersection(*ext_sets[1:])) + + return { + "allow_user_expiry": allow_user_expiry, + "default_expiry_hours": default_expiry_hours, + "max_expiry_hours": max_expiry_hours, + "extensions": extensions, + } + + +def _classify_ssh_key_material(raw: str) -> str: + import re + line = raw.strip().split()[0] if raw.strip() else "" + if re.search(r"-cert-v01@openssh\.com$", line): + return "certificate" + if re.match( + r"^(ssh-ed25519|ssh-rsa|ssh-dss|ecdsa-sha2-nistp\d+|sk-ssh-ed25519@openssh\.com)$", + line, + ): + return "public_key" + if "BEGIN OPENSSH PRIVATE KEY" in raw or "BEGIN RSA PRIVATE KEY" in raw: + return "private_key" + return "unknown" diff --git a/gatehouse_app/api/v1/ssh/admin.py b/gatehouse_app/api/v1/ssh/admin.py new file mode 100644 index 0000000..1676339 --- /dev/null +++ b/gatehouse_app/api/v1/ssh/admin.py @@ -0,0 +1,111 @@ +"""SSH CA permissions admin endpoints.""" +from flask import request, g +from gatehouse_app.api.v1.ssh._helpers import ssh_bp +from gatehouse_app.utils.constants import AuditAction, OrganizationRole +from gatehouse_app.models import AuditLog +from gatehouse_app.utils.decorators import login_required +from gatehouse_app.utils.response import api_response + + +@ssh_bp.route('/ca//permissions', methods=['GET']) +@login_required +def list_ca_permissions(ca_id): + from gatehouse_app.models.ssh_ca.ca import CA, CAPermission + from gatehouse_app.models.organization.organization_member import OrganizationMember + + user = g.current_user + ca = CA.query.filter_by(id=ca_id, deleted_at=None).first() + if not ca: + return api_response(success=False, message="CA not found", status=404, error_type="NOT_FOUND") + + if ca.organization_id: + membership = OrganizationMember.query.filter_by(organization_id=ca.organization_id, user_id=user.id, deleted_at=None).first() + if not membership or membership.role not in (OrganizationRole.ADMIN, OrganizationRole.OWNER): + return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN") + + perms = CAPermission.query.filter_by(ca_id=ca_id, deleted_at=None).all() + perm_list = [] + for p in perms: + d = p.to_dict() + d["user_email"] = p.user.email if p.user else None + perm_list.append(d) + + return api_response(data={"ca_id": ca_id, "permissions": perm_list, "open_to_all": len(perms) == 0}, message="CA permissions retrieved") + + +@ssh_bp.route('/ca//permissions', methods=['POST']) +@login_required +def add_ca_permission(ca_id): + from gatehouse_app.models.ssh_ca.ca import CA, CAPermission + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.models.user import User + from gatehouse_app.extensions import db + + user = g.current_user + ca = CA.query.filter_by(id=ca_id, deleted_at=None).first() + if not ca: + return api_response(success=False, message="CA not found", status=404, error_type="NOT_FOUND") + + if ca.organization_id: + membership = OrganizationMember.query.filter_by(organization_id=ca.organization_id, user_id=user.id, deleted_at=None).first() + if not membership or membership.role not in (OrganizationRole.ADMIN, OrganizationRole.OWNER): + return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN") + + data = request.get_json() or {} + target_user_id = (data.get("user_id") or "").strip() + permission = data.get("permission", "sign") + + if not target_user_id: + return api_response(success=False, message="user_id is required", status=400, error_type="VALIDATION_ERROR") + if permission not in ("sign", "admin"): + return api_response(success=False, message="permission must be 'sign' or 'admin'", status=400, error_type="VALIDATION_ERROR") + + target_user = User.query.filter_by(id=target_user_id, deleted_at=None).first() + if not target_user: + return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") + + existing = CAPermission.query.filter_by(ca_id=ca_id, user_id=target_user_id, deleted_at=None).first() + if existing: + if existing.permission != permission: + existing.permission = permission + db.session.commit() + d = existing.to_dict() + d["user_email"] = target_user.email + return api_response(data={"message": "Permission updated", "permission": d}, message="Permission updated") + return api_response(success=False, message="User already has this permission on the CA", status=409, error_type="DUPLICATE") + + perm = CAPermission(ca_id=ca_id, user_id=target_user_id, permission=permission) + db.session.add(perm) + db.session.commit() + + AuditLog.log(action=AuditAction.CA_UPDATED, user_id=user.id, resource_type="CAPermission", resource_id=perm.id, ip_address=request.remote_addr, description=f"Granted '{permission}' on CA '{ca.name}' to user {target_user.email}") + + d = perm.to_dict() + d["user_email"] = target_user.email + return api_response(data={"message": "Permission granted", "permission": d}, message="Permission granted", status=201) + + +@ssh_bp.route('/ca//permissions/', methods=['DELETE']) +@login_required +def remove_ca_permission(ca_id, target_user_id): + from gatehouse_app.models.ssh_ca.ca import CA, CAPermission + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.extensions import db + + user = g.current_user + ca = CA.query.filter_by(id=ca_id, deleted_at=None).first() + if not ca: + return api_response(success=False, message="CA not found", status=404, error_type="NOT_FOUND") + + if ca.organization_id: + membership = OrganizationMember.query.filter_by(organization_id=ca.organization_id, user_id=user.id, deleted_at=None).first() + if not membership or membership.role not in (OrganizationRole.ADMIN, OrganizationRole.OWNER): + return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN") + + perm = CAPermission.query.filter_by(ca_id=ca_id, user_id=target_user_id, deleted_at=None).first() + if not perm: + return api_response(success=False, message="Permission not found", status=404, error_type="NOT_FOUND") + + perm.delete(soft=True) + AuditLog.log(action=AuditAction.CA_UPDATED, user_id=user.id, resource_type="CAPermission", resource_id=perm.id, ip_address=request.remote_addr, description=f"Revoked permission on CA '{ca.name}' from user {target_user_id}") + return api_response(data={}, message="Permission revoked") diff --git a/gatehouse_app/api/v1/ssh/certs.py b/gatehouse_app/api/v1/ssh/certs.py new file mode 100644 index 0000000..429f6f0 --- /dev/null +++ b/gatehouse_app/api/v1/ssh/certs.py @@ -0,0 +1,391 @@ +"""SSH certificate signing and listing endpoints.""" +from flask import request, g +from gatehouse_app.api.v1.ssh._helpers import ( + ssh_bp, ssh_key_service, ssh_ca_service, + _get_org_ca_for_user, _persist_certificate, + _get_merged_dept_cert_policy, _classify_ssh_key_material, +) +from gatehouse_app.services.ssh_ca_signing_service import SSHCertificateSigningRequest +from gatehouse_app.exceptions import SSHKeyNotFoundError, SSHCertificateError +from gatehouse_app.utils.constants import AuditAction, OrganizationRole +from gatehouse_app.models import AuditLog +from gatehouse_app.models.ssh_ca.certificate_audit_log import CertificateAuditLog +from gatehouse_app.utils.decorators import login_required +from gatehouse_app.utils.response import api_response + + +@ssh_bp.route('/dept-cert-policy', methods=['GET']) +@login_required +def get_my_dept_cert_policy(): + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.models.organization.department_cert_policy import STANDARD_EXTENSIONS + + user = g.current_user + user_id = user.id + + is_org_admin = OrganizationMember.query.filter( + OrganizationMember.user_id == user_id, + OrganizationMember.role.in_(["OWNER", "ADMIN"]), + OrganizationMember.deleted_at == None, + ).first() is not None + + policy = _get_merged_dept_cert_policy(user_id) + if policy is None: + policy = {"allow_user_expiry": is_org_admin, "default_expiry_hours": 1, "max_expiry_hours": 24, "extensions": list(STANDARD_EXTENSIONS)} + elif is_org_admin: + policy = {**policy, "allow_user_expiry": True} + + return api_response(data={"policy": policy}, message="Certificate policy retrieved") + + +@ssh_bp.route('/sign', methods=['POST']) +@login_required +def sign_certificate(): + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.models.organization.principal import Principal, PrincipalMembership + from gatehouse_app.models.organization.department import DepartmentMembership, DepartmentPrincipal + from gatehouse_app.utils.constants import UserStatus + + user = g.current_user + user_id = user.id + + if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED): + return api_response(success=False, message="Your account is suspended. Contact an administrator.", status=403, error_type="ACCOUNT_SUSPENDED") + + data = request.get_json() + if not data: + return api_response(success=False, message="No JSON data provided", status=400, error_type="BAD_REQUEST") + + requested_principals = data.get('principals') or [] + cert_type = data.get('cert_type', 'user') + key_id = data.get('key_id') or data.get('cert_id') + expiry_hours = data.get('expiry_hours') + + AuditLog.log( + action=AuditAction.SSH_CERT_REQUESTED, + user_id=user_id, resource_type='SSHCertificate', ip_address=request.remote_addr, + description=(f'{user.email} requested a certificate' + (f' for principals: {", ".join(requested_principals)}' if requested_principals else '')), + ) + + allowed_principal_names = set() + memberships = OrganizationMember.query.filter_by(user_id=user_id).all() + for om in memberships: + org = om.organization + if not org or org.deleted_at is not None: + continue + role = om.role + if role in (OrganizationRole.ADMIN, OrganizationRole.OWNER): + for p in Principal.query.filter_by(organization_id=org.id, deleted_at=None).all(): + allowed_principal_names.add(p.name) + else: + for pm in PrincipalMembership.query.filter_by(user_id=user_id, deleted_at=None).all(): + if pm.principal and pm.principal.organization_id == org.id and pm.principal.deleted_at is None: + allowed_principal_names.add(pm.principal.name) + for dm in DepartmentMembership.query.filter_by(user_id=user_id, deleted_at=None).all(): + if dm.department and dm.department.organization_id == org.id and dm.department.deleted_at is None: + for dp in DepartmentPrincipal.query.filter_by(department_id=dm.department_id, deleted_at=None).all(): + if dp.principal and dp.principal.deleted_at is None: + allowed_principal_names.add(dp.principal.name) + + if not requested_principals: + principals = list(allowed_principal_names) + if not principals: + return api_response(success=False, message="You have no principals assigned. Ask an admin to add you to a principal.", status=400, error_type="NO_PRINCIPALS") + else: + invalid = [p for p in requested_principals if p not in allowed_principal_names] + if invalid: + return api_response(success=False, message=f"You are not authorised to request principals: {', '.join(invalid)}", status=403, error_type="UNAUTHORIZED_PRINCIPALS") + principals = requested_principals + + if not key_id: + verified_keys = ssh_key_service.get_user_verified_ssh_keys(user_id) + if not verified_keys: + return api_response(success=False, message="No verified SSH keys found. Verify a key before requesting a certificate.", status=400, error_type="NO_VERIFIED_KEYS") + key_id = verified_keys[0].id + + try: + ssh_key = ssh_key_service.get_ssh_key(key_id) + except SSHKeyNotFoundError: + return api_response(success=False, message="SSH key not found", status=404, error_type="NOT_FOUND") + + if ssh_key.user_id != user_id: + return api_response(success=False, message="Forbidden", status=403, error_type="FORBIDDEN") + + if not ssh_key.verified: + return api_response(success=False, message="SSH key is not verified. Verify it before requesting a certificate.", status=400, error_type="KEY_NOT_VERIFIED") + + db_ca = _get_org_ca_for_user(user, ca_type=cert_type) + if db_ca is None: + return api_response( + success=False, + message="No active Certificate Authority is configured for your organization. An admin must generate a CA on the Certificate Authorities page before certificates can be issued.", + status=503, error_type="CA_NOT_CONFIGURED", + ) + + is_org_admin = any( + om.role in (OrganizationRole.ADMIN, OrganizationRole.OWNER) + for om in memberships + if om.organization and om.organization.deleted_at is None + ) + + dept_policy = _get_merged_dept_cert_policy(user_id) + if dept_policy: + if is_org_admin: + if expiry_hours is not None: + expiry_hours = min(int(expiry_hours), dept_policy["max_expiry_hours"]) + elif not dept_policy["allow_user_expiry"]: + expiry_hours = dept_policy["default_expiry_hours"] + else: + if expiry_hours is not None: + expiry_hours = min(int(expiry_hours), dept_policy["max_expiry_hours"]) + policy_extensions = dept_policy["extensions"] + else: + policy_extensions = None + + org_slugs = sorted({ + om.organization.slug for om in memberships + if om.organization and om.organization.deleted_at is None and getattr(om.organization, 'slug', None) + }) + org_slug = org_slugs[0] if org_slugs else "unknown" + full_name = getattr(user, 'full_name', None) or getattr(user, 'name', None) or "unknown" + cert_identity = f"{user.email} ({full_name}) [org:{org_slug}]" + + signing_request = SSHCertificateSigningRequest( + ssh_public_key=ssh_key.payload, principals=principals, cert_type=cert_type, + key_id=cert_identity, expiry_hours=int(expiry_hours) if expiry_hours else None, + extensions=policy_extensions, + ) + validation_errors = signing_request.validate() + if validation_errors: + return api_response(success=False, message="Invalid signing request", status=400, error_type="VALIDATION_ERROR", error_details={"errors": validation_errors}) + + try: + from gatehouse_app.utils.ca_key_encryption import decrypt_ca_key + ca_private_key_pem = decrypt_ca_key(db_ca.private_key) + response = ssh_ca_service.sign_certificate(signing_request, ca_private_key=ca_private_key_pem, ca_obj=db_ca) + except SSHCertificateError as e: + AuditLog.log(action=AuditAction.SSH_CERT_FAILED, user_id=user_id, resource_type='SSHCertificate', ip_address=request.remote_addr, success=False, error_message=str(e)) + return api_response(success=False, message=str(e), status=400, error_type="SIGNING_FAILED") + except Exception as e: + AuditLog.log(action=AuditAction.SSH_CERT_FAILED, user_id=user_id, resource_type='SSHCertificate', ip_address=request.remote_addr, success=False, error_message=str(e)) + return api_response(success=False, message="Certificate signing failed", status=500, error_type="SERVER_ERROR") + + cert_record = _persist_certificate( + user_id=user_id, ssh_key_id=key_id, ca=db_ca, + signing_response=response, request_ip=request.remote_addr, + cert_type_str=cert_type, cert_identity=cert_identity, + ) + + AuditLog.log( + action=AuditAction.SSH_CERT_ISSUED, user_id=user_id, + resource_type='SSHCertificate', resource_id=cert_record.id if cert_record else key_id, + ip_address=request.remote_addr, + description=f'Certificate serial={response.serial} issued for {user.email}; principals: {", ".join(principals)}', + extra_data={'serial': response.serial, 'key_id': cert_identity, 'principals': principals, 'ca_id': str(db_ca.id), 'ssh_key_id': str(key_id)}, + ) + + if cert_record: + CertificateAuditLog.log( + certificate_id=cert_record.id, action='issued', user_id=user_id, + ip_address=request.remote_addr, user_agent=request.headers.get('User-Agent'), + message=f'Certificate serial={response.serial} issued for {user.email}; principals: {", ".join(principals)}', + extra_data={ + 'serial': response.serial, 'key_id': cert_identity, 'principals': principals, + 'ca_id': str(db_ca.id), 'ssh_key_id': str(key_id), + 'valid_after': response.valid_after.isoformat() if response.valid_after else None, + 'valid_before': response.valid_before.isoformat() if response.valid_before else None, + }, + success=True, + ) + + result = { + 'certificate': response.certificate, 'serial': response.serial, + 'principals': response.principals, + 'valid_after': response.valid_after.isoformat() if response.valid_after else None, + 'valid_before': response.valid_before.isoformat() if response.valid_before else None, + } + if cert_record: + result['cert_id'] = str(cert_record.id) + + return api_response(data=result, message="Certificate signed successfully", status=201) + + +@ssh_bp.route('/sign/host', methods=['POST']) +@login_required +def sign_host_certificate(): + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.models.ssh_ca.ca import CA, CaType + from gatehouse_app.utils.ca_key_encryption import decrypt_ca_key + + user = g.current_user + user_id = user.id + + is_admin = OrganizationMember.query.filter( + OrganizationMember.user_id == user_id, + OrganizationMember.role.in_([OrganizationRole.ADMIN, OrganizationRole.OWNER]), + OrganizationMember.deleted_at.is_(None), + ).first() is not None + + if not is_admin: + return api_response(success=False, message="Issuing host certificates requires org admin or owner role.", status=403, error_type="FORBIDDEN") + + data = request.get_json() + if not data: + return api_response(success=False, message="No JSON data provided", status=400, error_type="BAD_REQUEST") + + host_public_key = (data.get("host_public_key") or "").strip() + principals = data.get("principals") or [] + validity_hours = data.get("validity_hours", 720) + ca_id = (data.get("ca_id") or "").strip() + + if not host_public_key: + return api_response(success=False, message="host_public_key is required.", status=400, error_type="BAD_REQUEST") + + key_kind = _classify_ssh_key_material(host_public_key) + if key_kind == "certificate": + return api_response(success=False, message="You submitted a certificate (ssh-…-cert-v01@openssh.com), not a host public key. Retrieve the server's host public key with: cat /etc/ssh/ssh_host_ed25519_key.pub", status=400, error_type="WRONG_KEY_MATERIAL") + if key_kind == "private_key": + return api_response(success=False, message="Private keys must never be submitted here. Use the .pub file.", status=400, error_type="WRONG_KEY_MATERIAL") + if key_kind == "unknown": + return api_response(success=False, message="Unrecognised key format. Expected an OpenSSH public key starting with ssh-ed25519, ssh-rsa, or ecdsa-sha2-*.", status=400, error_type="WRONG_KEY_MATERIAL") + + if not principals or not isinstance(principals, list): + return api_response(success=False, message="principals must be a non-empty list of hostnames.", status=422, error_type="VALIDATION_ERROR") + principals = [str(p).strip() for p in principals if str(p).strip()] + if not principals: + return api_response(success=False, message="At least one principal (hostname/FQDN) is required.", status=422, error_type="VALIDATION_ERROR") + + try: + validity_hours = int(validity_hours) + if validity_hours < 1: + raise ValueError + except (TypeError, ValueError): + return api_response(success=False, message="validity_hours must be a positive integer.", status=422, error_type="VALIDATION_ERROR") + + if not ca_id: + return api_response(success=False, message="ca_id is required.", status=400, error_type="BAD_REQUEST") + + org_ids = [m.organization_id for m in OrganizationMember.query.filter_by(user_id=user_id, deleted_at=None).all()] + + any_ca = CA.query.filter(CA.id == ca_id, CA.is_active.is_(True), CA.organization_id.in_(org_ids), CA.deleted_at.is_(None)).first() + + if any_ca and any_ca.ca_type != CaType.HOST: + return api_response(success=False, message=f"The CA '{any_ca.name}' is a {any_ca.ca_type.value} CA. Host certificates must be signed by a ca_type='host' CA.", status=400, error_type="WRONG_CA_TYPE") + + host_ca = any_ca + if not host_ca: + return api_response(success=False, message="Host CA not found, inactive, or you do not have permission to use it. Ensure the CA exists and ca_type is 'host'.", status=404, error_type="CA_NOT_FOUND") + + primary_principal = principals[0] + cert_identity = f"host:{primary_principal} [signed-by:{user.email}]" + + signing_request = SSHCertificateSigningRequest( + ssh_public_key=host_public_key, principals=principals, cert_type="host", + key_id=cert_identity, expiry_hours=validity_hours, extensions=[], critical_options={}, + ) + validation_errors = signing_request.validate() + if validation_errors: + return api_response(success=False, message="Invalid signing request: " + "; ".join(validation_errors), status=422, error_type="VALIDATION_ERROR") + + try: + ca_private_key_pem = decrypt_ca_key(host_ca.private_key) + response = ssh_ca_service.sign_certificate(signing_request, ca_private_key=ca_private_key_pem, ca_obj=host_ca) + except Exception as exc: + AuditLog.log(action=AuditAction.SSH_CERT_FAILED, user_id=user_id, resource_type="SSHCertificate", ip_address=request.remote_addr, success=False, error_message=str(exc)) + return api_response(success=False, message=f"Host certificate signing failed: {exc}", status=500, error_type="SIGNING_FAILED") + + cert_record = _persist_certificate( + user_id=user_id, ssh_key_id=None, ca=host_ca, + signing_response=response, request_ip=request.remote_addr, + cert_type_str="host", cert_identity=cert_identity, + ) + + AuditLog.log( + action=AuditAction.SSH_CERT_ISSUED, user_id=user_id, + resource_type="SSHCertificate", resource_id=cert_record.id if cert_record else None, + ip_address=request.remote_addr, + description=f"Host certificate serial={response.serial} issued for {primary_principal} by {user.email}", + extra_data={"serial": response.serial, "principals": principals, "ca_id": str(host_ca.id), "cert_type": "host"}, + ) + + result = { + "certificate": response.certificate, "serial": response.serial, "principals": response.principals, + "valid_after": response.valid_after.isoformat() if response.valid_after else None, + "valid_before": response.valid_before.isoformat() if response.valid_before else None, + } + if cert_record: + result["cert_id"] = str(cert_record.id) + + return api_response(data=result, message="Host certificate issued successfully", status=201) + + +@ssh_bp.route('/certificates', methods=['GET']) +@login_required +def list_certificates(): + user_id = g.current_user.id + try: + from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate + certs = SSHCertificate.query.filter_by(user_id=user_id, deleted_at=None).order_by(SSHCertificate.created_at.desc()).all() + return api_response(data={'certificates': [c.to_dict() for c in certs], 'count': len(certs)}, message="Certificates retrieved successfully") + except Exception as e: + return api_response(success=False, message=str(e), status=500, error_type='INTERNAL_ERROR') + + +@ssh_bp.route('/certificates/', methods=['GET']) +@login_required +def get_certificate(cert_id): + user_id = g.current_user.id + try: + from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate + cert = SSHCertificate.query.filter_by(id=cert_id, deleted_at=None).first() + if not cert: + return api_response(success=False, message='Certificate not found', status=404, error_type='NOT_FOUND') + if cert.user_id != user_id: + return api_response(success=False, message='Forbidden', status=403, error_type='FORBIDDEN') + data = cert.to_dict() + data['certificate'] = cert.certificate + return api_response(success=True, message='Certificate retrieved', data=data, status=200) + except Exception as e: + return api_response(success=False, message=str(e), status=500, error_type='INTERNAL_ERROR') + + +@ssh_bp.route('/certificates//revoke', methods=['POST']) +@login_required +def revoke_certificate(cert_id): + user_id = g.current_user.id + data = request.get_json() or {} + reason = data.get('reason', 'User requested revocation') + try: + from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate + cert = SSHCertificate.query.filter_by(id=cert_id, deleted_at=None).first() + if not cert: + return api_response(success=False, message='Certificate not found', status=404, error_type='NOT_FOUND') + if cert.user_id != user_id: + return api_response(success=False, message='Forbidden', status=403, error_type='FORBIDDEN') + if cert.revoked: + return api_response(success=False, message='Certificate is already revoked', status=409, error_type='ALREADY_REVOKED') + + cert.revoke(reason=reason) + AuditLog.log(action=AuditAction.SSH_CERT_REVOKED, user_id=user_id, resource_type='SSHCertificate', resource_id=cert_id, ip_address=request.remote_addr, description=f'Revoked: {reason}') + CertificateAuditLog.log(certificate_id=cert_id, action='revoked', user_id=user_id, ip_address=request.remote_addr, user_agent=request.headers.get('User-Agent'), message=f'Certificate revoked: {reason}', success=True) + + return api_response(success=True, message='Certificate revoked successfully', data={'status': 'revoked', 'cert_id': cert_id, 'reason': reason}, status=200) + except Exception as e: + return api_response(success=False, message=str(e), status=500, error_type='INTERNAL_ERROR') + + +@ssh_bp.route('/ca/public-key', methods=['GET']) +@login_required +def get_ca_public_key(): + user = g.current_user + ca_type = request.args.get("ca_type", "user") + if ca_type not in ("user", "host"): + return api_response(success=False, message="ca_type must be 'user' or 'host'", status=400, error_type="BAD_REQUEST") + + db_ca = _get_org_ca_for_user(user, ca_type=ca_type) + if db_ca: + return api_response( + data={'public_key': db_ca.public_key, 'fingerprint': db_ca.fingerprint, 'ca_name': db_ca.name, 'ca_type': ca_type, 'source': 'db'}, + message="CA public key retrieved successfully", + ) + return api_response(success=False, message=f"No {ca_type} CA is configured for your organization. An admin must generate one on the Certificate Authorities page.", status=404, error_type="CA_NOT_CONFIGURED") diff --git a/gatehouse_app/api/v1/ssh/keys.py b/gatehouse_app/api/v1/ssh/keys.py new file mode 100644 index 0000000..e074586 --- /dev/null +++ b/gatehouse_app/api/v1/ssh/keys.py @@ -0,0 +1,125 @@ +"""SSH key management endpoints.""" +from sqlalchemy.exc import IntegrityError +from flask import request, g +from gatehouse_app.api.v1.ssh._helpers import ssh_bp, ssh_key_service +from gatehouse_app.exceptions import SSHKeyError, SSHKeyNotFoundError, ValidationError, SSHKeyAlreadyExistsError +from gatehouse_app.utils.constants import AuditAction +from gatehouse_app.models import AuditLog +from gatehouse_app.utils.decorators import login_required +from gatehouse_app.utils.response import api_response + + +@ssh_bp.route('/keys', methods=['GET']) +@login_required +def list_ssh_keys(): + user_id = g.current_user.id + keys = ssh_key_service.get_user_ssh_keys(user_id) + return api_response(data={'keys': [k.to_dict() for k in keys], 'count': len(keys)}, message="SSH keys retrieved successfully") + + +@ssh_bp.route('/keys', methods=['POST']) +@login_required +def add_ssh_key(): + user_id = g.current_user.id + data = request.get_json() + if not data: + return api_response(success=False, message='No JSON data provided', status=400, error_type='BAD_REQUEST') + + public_key = data.get('public_key') or data.get('key') + description = data.get('description') + + if not public_key: + return api_response(success=False, message='public_key is required', status=400, error_type='BAD_REQUEST') + + try: + ssh_key = ssh_key_service.add_ssh_key(user_id=user_id, public_key=public_key, description=description) + AuditLog.log(action=AuditAction.SSH_KEY_ADDED, user_id=user_id, resource_type='SSHKey', resource_id=ssh_key.id, ip_address=request.remote_addr) + return api_response(success=True, message='SSH key added', data=ssh_key.to_dict(), status=201) + except SSHKeyAlreadyExistsError as e: + return api_response(success=False, message=e.message, status=409, error_type='SSH_KEY_ALREADY_EXISTS') + except IntegrityError: + return api_response(success=False, message='SSH key already exists', status=409, error_type='SSH_KEY_ALREADY_EXISTS') + except SSHKeyError as e: + return api_response(success=False, message=str(e), status=400, error_type='SSH_KEY_ERROR') + except ValidationError as e: + return api_response(success=False, message=str(e), status=400, error_type='VALIDATION_ERROR') + + +@ssh_bp.route('/keys/', methods=['GET']) +@login_required +def get_ssh_key(key_id): + user_id = g.current_user.id + try: + ssh_key = ssh_key_service.get_ssh_key(key_id) + if ssh_key.user_id != user_id: + return api_response(success=False, message='Forbidden', status=403, error_type='FORBIDDEN') + return api_response(success=True, message='SSH key retrieved', data=ssh_key.to_dict(), status=200) + except SSHKeyNotFoundError: + return api_response(success=False, message='SSH key not found', status=404, error_type='NOT_FOUND') + + +@ssh_bp.route('/keys/', methods=['DELETE']) +@login_required +def delete_ssh_key(key_id): + user_id = g.current_user.id + try: + ssh_key = ssh_key_service.get_ssh_key(key_id) + if ssh_key.user_id != user_id: + return api_response(success=False, message='Forbidden', status=403, error_type='FORBIDDEN') + ssh_key_service.delete_ssh_key(key_id) + AuditLog.log(action=AuditAction.SSH_KEY_DELETED, user_id=user_id, resource_type='SSHKey', resource_id=key_id, ip_address=request.remote_addr) + return api_response(success=True, message='SSH key deleted', data={'status': 'deleted'}, status=200) + except SSHKeyNotFoundError: + return api_response(success=False, message='SSH key not found', status=404, error_type='NOT_FOUND') + + +@ssh_bp.route('/keys//verify', methods=['GET', 'POST']) +@login_required +def verify_ssh_key(key_id): + user_id = g.current_user.id + try: + ssh_key = ssh_key_service.get_ssh_key(key_id) + if ssh_key.user_id != user_id: + return api_response(success=False, message='Forbidden', status=403, error_type='FORBIDDEN') + + if request.method == 'GET': + challenge = ssh_key_service.generate_verification_challenge(key_id) + return api_response(success=True, message='Challenge generated', data={'challenge_text': challenge, 'validationText': challenge, 'key_id': key_id}, status=200) + + data = request.get_json() or {} + action = data.get('action', 'verify_signature') + + if action == 'verify_signature': + signature = data.get('signature') + if not signature: + return api_response(success=False, message='signature is required', status=400, error_type='BAD_REQUEST') + try: + verified = ssh_key_service.verify_ssh_key_ownership(key_id, signature) + AuditLog.log(action=AuditAction.SSH_KEY_VERIFIED, user_id=user_id, resource_type='SSHKey', resource_id=key_id, ip_address=request.remote_addr, success=verified) + return api_response(success=True, message='Verification complete', data={'verified': verified}, status=200) + except Exception as e: + AuditLog.log(action=AuditAction.SSH_KEY_VALIDATION_FAILED, user_id=user_id, resource_type='SSHKey', resource_id=key_id, ip_address=request.remote_addr, success=False, error_message=str(e)) + return api_response(success=False, message=str(e), status=400, error_type='VERIFICATION_FAILED') + else: + challenge = ssh_key_service.generate_verification_challenge(key_id) + return api_response(success=True, message='Challenge generated', data={'challenge_text': challenge, 'challenge': challenge}, status=200) + + except SSHKeyNotFoundError: + return api_response(success=False, message='SSH key not found', status=404, error_type='NOT_FOUND') + + +@ssh_bp.route('/keys//update-description', methods=['PATCH']) +@login_required +def update_ssh_key_description(key_id): + user_id = g.current_user.id + data = request.get_json() + if not data or 'description' not in data: + return api_response(success=False, message='description is required', status=400, error_type='BAD_REQUEST') + try: + ssh_key = ssh_key_service.get_ssh_key(key_id) + if ssh_key.user_id != user_id: + return api_response(success=False, message='Forbidden', status=403, error_type='FORBIDDEN') + updated_key = ssh_key_service.update_ssh_key_description(key_id, data['description']) + return api_response(success=True, message='Description updated', data=updated_key.to_dict(), status=200) + except SSHKeyNotFoundError: + return api_response(success=False, message='SSH key not found', status=404, error_type='NOT_FOUND') diff --git a/gatehouse_app/api/v1/users.py b/gatehouse_app/api/v1/users.py deleted file mode 100644 index 112ea1f..0000000 --- a/gatehouse_app/api/v1/users.py +++ /dev/null @@ -1,879 +0,0 @@ -"""User endpoints.""" -from flask import g, request -from marshmallow import ValidationError -from gatehouse_app.api.v1 import api_v1_bp -from gatehouse_app.utils.response import api_response -from gatehouse_app.utils.decorators import login_required, full_access_required -from gatehouse_app.schemas.user_schema import UserUpdateSchema, ChangePasswordSchema -from gatehouse_app.services.user_service import UserService -from gatehouse_app.services.auth_service import AuthService - - -@api_v1_bp.route("/users/me", methods=["GET"]) -@login_required -def get_me(): - """ - Get current user profile. - - Returns: - 200: User profile data - 401: Not authenticated - """ - user = g.current_user - - return api_response( - data={"user": user.to_dict()}, - message="User profile retrieved successfully", - ) - - -@api_v1_bp.route("/users/me", methods=["PATCH"]) -@login_required -@full_access_required -def update_me(): - """ - Update current user profile. - - Request body: - full_name: Optional full name - avatar_url: Optional avatar URL - - Returns: - 200: User updated successfully - 400: Validation error - 401: Not authenticated - """ - try: - # Validate request data - schema = UserUpdateSchema() - data = schema.load(request.json) - - # Update user - user = UserService.update_user(g.current_user, **data) - - return api_response( - data={"user": user.to_dict()}, - message="Profile updated successfully", - ) - - except ValidationError as e: - return api_response( - success=False, - message="Validation failed", - status=400, - error_type="VALIDATION_ERROR", - error_details=e.messages, - ) - - -@api_v1_bp.route("/users/me", methods=["DELETE"]) -@login_required -@full_access_required -def delete_me(): - """ - Delete current user account (soft delete). - - Behaviour for owned organizations: - - If the org has other active members → blocked; user must transfer ownership first. - - If they are the sole member → org is automatically cascade-deleted (no orphan risk). - - Returns: - 200: Account deleted successfully (sole-member orgs auto-deleted) - 401: Not authenticated - 409: USER_IS_SOLE_OWNER — user owns orgs that still have other members - """ - from gatehouse_app.models.organization.organization_member import OrganizationMember - from gatehouse_app.utils.constants import OrganizationRole - from gatehouse_app.services.organization_service import OrganizationService - - user = g.current_user - - # Find all orgs where this user is the owner. - owned_memberships = OrganizationMember.query.filter_by( - user_id=user.id, - role=OrganizationRole.OWNER, - deleted_at=None, - ).all() - - # Separate into two buckets depending on whether other members exist. - transfer_needed = [] # org has other members → must transfer ownership first - auto_delete = [] # user is sole member → safe to cascade-delete automatically - - for membership in owned_memberships: - org = membership.organization - if org.deleted_at is not None: - continue - member_count = org.get_member_count() - if member_count > 1: - transfer_needed.append(org.name) - else: - auto_delete.append(org) - - # Hard block: user owns orgs with other members — must transfer first. - if transfer_needed: - names = ", ".join(f'"{n}"' for n in transfer_needed) - return api_response( - success=False, - message=( - f"You are the owner of {len(transfer_needed)} organization" - f"{'s' if len(transfer_needed) > 1 else ''} that still " - f"{'have' if len(transfer_needed) > 1 else 'has'} other members " - f"({names}). Transfer ownership to another member first." - ), - status=409, - error_type="USER_IS_SOLE_OWNER", - error_details={"transfer_ownership": transfer_needed}, - ) - - # Auto-delete any sole-member orgs so no orphaned org rows can ever be left behind. - for org in auto_delete: - OrganizationService.force_delete_organization(org, user_id=user.id) - - UserService.delete_user(user, soft=True) - - return api_response( - message="Account deleted successfully", - ) - - -@api_v1_bp.route("/users/me/password", methods=["POST"]) -@login_required -@full_access_required -def change_password(): - """ - Change current user password. - - Request body: - current_password: Current password - new_password: New password - new_password_confirm: New password confirmation - - Returns: - 200: Password changed successfully - 400: Validation error - 401: Not authenticated or invalid current password - """ - try: - # Validate request data - schema = ChangePasswordSchema() - data = schema.load(request.json) - - # Verify passwords match - if data["new_password"] != data["new_password_confirm"]: - return api_response( - success=False, - message="New passwords do not match", - status=400, - error_type="VALIDATION_ERROR", - error_details={"new_password_confirm": ["Passwords do not match"]}, - ) - - # Change password - AuthService.change_password( - user=g.current_user, - current_password=data["current_password"], - new_password=data["new_password"], - ) - - return api_response( - message="Password changed successfully", - ) - - except ValidationError as e: - return api_response( - success=False, - message="Validation failed", - status=400, - error_type="VALIDATION_ERROR", - error_details=e.messages, - ) - - -@api_v1_bp.route("/users/me/organizations", methods=["GET"]) -@login_required -@full_access_required -def get_my_organizations(): - """ - Get all organizations current user is a member of, including the user's role. - - Returns: - 200: List of organizations with role - 401: Not authenticated - """ - from gatehouse_app.models.organization.organization_member import OrganizationMember - - user = g.current_user - memberships = OrganizationMember.query.filter_by( - user_id=user.id, - deleted_at=None, - ).all() - - orgs = [] - for membership in memberships: - org = membership.organization - if not org or org.deleted_at is not None: - continue - org_dict = org.to_dict() - org_dict["role"] = membership.role.value if hasattr(membership.role, "value") else str(membership.role) - orgs.append(org_dict) - - return api_response( - data={ - "organizations": orgs, - "count": len(orgs), - }, - message="Organizations retrieved successfully", - ) - - -@api_v1_bp.route("/users/me/principals", methods=["GET"]) -@login_required -@full_access_required -def get_my_principals(): - """Return all principals the current user can sign certificates for. - - For each organization the user belongs to, returns: - - Their effective principals (direct membership + via department) - - Their role in that org (so the frontend can offer admin-mode selection) - - All principals in the org (admin/owner only — so they can pick any) - - Returns: - 200: { - orgs: [{ - org_id, org_name, role, - my_principals: [{id, name, description}], - all_principals: [{id, name, description}] # populated for admin/owner only - }] - } - """ - from gatehouse_app.models.organization.organization_member import OrganizationMember - from gatehouse_app.models.organization.principal import Principal, PrincipalMembership - from gatehouse_app.models.organization.department import DepartmentMembership, DepartmentPrincipal - from gatehouse_app.utils.constants import OrganizationRole - - user = g.current_user - user_id = user.id - - # Get all org memberships - memberships = OrganizationMember.query.filter_by( - user_id=user_id, - ).all() - - orgs_result = [] - for membership in memberships: - org = membership.organization - if not org or org.deleted_at is not None: - continue - - role = membership.role - is_admin = role in (OrganizationRole.ADMIN, OrganizationRole.OWNER) - - # Collect the user's effective principals for this org - # Track direct vs via-department separately - direct_principal_ids = set() - via_dept_principal_ids = set() - - # Direct memberships - direct = PrincipalMembership.query.filter_by( - user_id=user_id, - deleted_at=None, - ).all() - for pm in direct: - if pm.principal and pm.principal.organization_id == org.id and pm.principal.deleted_at is None: - direct_principal_ids.add(pm.principal_id) - - # Via department - dept_memberships = DepartmentMembership.query.filter_by( - user_id=user_id, - deleted_at=None, - ).all() - for dm in dept_memberships: - if dm.department and dm.department.organization_id == org.id and dm.department.deleted_at is None: - dept_principals = DepartmentPrincipal.query.filter_by( - department_id=dm.department_id, - deleted_at=None, - ).all() - for dp in dept_principals: - if dp.principal and dp.principal.deleted_at is None: - via_dept_principal_ids.add(dp.principal_id) - - effective_principal_ids = direct_principal_ids | via_dept_principal_ids - - # Fetch principal objects - my_principals = [] - if effective_principal_ids: - my_p = Principal.query.filter( - Principal.id.in_(list(effective_principal_ids)), - Principal.deleted_at == None, - ).all() - my_principals = [ - { - "id": p.id, - "name": p.name, - "description": p.description, - # direct=True means removable via API; False=inherited via department - "direct": p.id in direct_principal_ids, - } - for p in my_p - ] - - # For admins/owners: also return all principals in the org - all_principals = [] - if is_admin: - all_p = Principal.query.filter_by( - organization_id=org.id, - deleted_at=None, - ).all() - all_principals = [{"id": p.id, "name": p.name, "description": p.description} for p in all_p] - - orgs_result.append({ - "org_id": org.id, - "org_name": org.name, - "role": role.value if hasattr(role, "value") else role, - "is_admin": is_admin, - "my_principals": my_principals, - "all_principals": all_principals, - }) - - return api_response( - data={"orgs": orgs_result}, - message="Principals retrieved successfully", - ) - - -@api_v1_bp.route("/admin/users", methods=["GET"]) -@login_required -@full_access_required -def admin_list_users(): - """List all users the caller has admin rights to see. - - The caller must be an OWNER or ADMIN of at least one organization. - Returns users that share an organization with the caller and where the - caller holds admin/owner role in that organization. - - Query params: - q – optional search string (matched against name/email) - page – page number (default 1) - per_page – page size (default 50, max 200) - """ - from gatehouse_app.models.organization.organization_member import OrganizationMember - from gatehouse_app.models.user.user import User as _User - from gatehouse_app.extensions import db as _db - from sqlalchemy import or_ - - caller = g.current_user - - # Find orgs where caller is admin/owner - admin_memberships = OrganizationMember.query.filter( - OrganizationMember.user_id == caller.id, - OrganizationMember.role.in_(["OWNER", "ADMIN"]), - OrganizationMember.deleted_at == None, - ).all() - - if not admin_memberships: - return api_response( - success=False, - message="Admin or owner role required", - status=403, - error_type="AUTHORIZATION_ERROR", - ) - - admin_org_ids = [m.organization_id for m in admin_memberships] - - # Collect user IDs in those orgs - member_rows = OrganizationMember.query.filter( - OrganizationMember.organization_id.in_(admin_org_ids), - OrganizationMember.deleted_at == None, - ).all() - visible_user_ids = list({row.user_id for row in member_rows}) - - # Optional search - q = request.args.get("q", "").strip() - try: - page = max(1, int(request.args.get("page", 1))) - per_page = min(200, max(1, int(request.args.get("per_page", 50)))) - except ValueError: - page, per_page = 1, 50 - - query = _User.query.filter( - _User.id.in_(visible_user_ids), - _User.deleted_at == None, - ) - if q: - like = f"%{q}%" - query = query.filter(or_(_User.email.ilike(like), _User.full_name.ilike(like))) - - total = query.count() - users = query.order_by(_User.email).offset((page - 1) * per_page).limit(per_page).all() - - member_lookup: dict = {} - for row in member_rows: - if row.user_id not in member_lookup: - member_lookup[row.user_id] = { - "organization_id": row.organization_id, - "role": row.role.value if hasattr(row.role, "value") else row.role, - } - - users_data = [] - for u in users: - d = u.to_dict() - m = member_lookup.get(u.id, {}) - d["org_role"] = m.get("role", "member") - d["org_id"] = m.get("organization_id") - users_data.append(d) - - return api_response( - data={ - "users": users_data, - "count": total, - "page": page, - "per_page": per_page, - "pages": (total + per_page - 1) // per_page, - }, - message="Users retrieved successfully", - ) - - -@api_v1_bp.route("/admin/users/", methods=["GET"]) -@login_required -@full_access_required -def admin_get_user(user_id): - """Get a single user's profile (admin view with SSH keys).""" - from gatehouse_app.models.organization.organization_member import OrganizationMember - from gatehouse_app.models.user.user import User as _User - from gatehouse_app.models.ssh_ca.ssh_key import SSHKey - - caller = g.current_user - - target = _User.query.filter_by(id=user_id, deleted_at=None).first() - if not target: - return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") - - # Verify caller has admin access to a shared org - target_org_ids = {m.organization_id for m in target.organization_memberships if m.deleted_at is None} - has_access = OrganizationMember.query.filter( - OrganizationMember.user_id == caller.id, - OrganizationMember.organization_id.in_(target_org_ids), - OrganizationMember.role.in_(["OWNER", "ADMIN"]), - OrganizationMember.deleted_at == None, - ).first() is not None - - if not has_access: - return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR") - - ssh_keys = SSHKey.query.filter_by(user_id=user_id, deleted_at=None).all() - - return api_response( - data={ - "user": target.to_dict(), - "ssh_keys": [k.to_dict() for k in ssh_keys], - }, - message="User retrieved", - ) - - -@api_v1_bp.route("/admin/users//suspend", methods=["POST"]) -@login_required -@full_access_required -def admin_suspend_user(user_id): - """Suspend a user account (blocks CA issuance and login). - - The caller must be an OWNER or ADMIN of an organization the target user belongs to. - """ - from gatehouse_app.models.organization.organization_member import OrganizationMember - from gatehouse_app.models.user.user import User as _User - from gatehouse_app.extensions import db as _db - from gatehouse_app.utils.constants import UserStatus, AuditAction - from gatehouse_app.services.audit_service import AuditService - - caller = g.current_user - target = _User.query.filter_by(id=user_id, deleted_at=None).first() - if not target: - return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") - - if target.id == caller.id: - return api_response(success=False, message="Cannot suspend yourself", status=400, error_type="BAD_REQUEST") - - # Verify caller has admin access to a shared org - target_org_ids = {m.organization_id for m in target.organization_memberships if m.deleted_at is None} - admin_in_shared_org = OrganizationMember.query.filter( - OrganizationMember.user_id == caller.id, - OrganizationMember.organization_id.in_(target_org_ids), - OrganizationMember.role.in_(["OWNER", "ADMIN"]), - OrganizationMember.deleted_at == None, - ).first() - - if not admin_in_shared_org: - return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR") - - # ── Owner protection ────────────────────────────────────────────────────── - # An org owner cannot be suspended until they transfer ownership. - from gatehouse_app.utils.constants import OrganizationRole - owner_memberships = OrganizationMember.query.filter( - OrganizationMember.user_id == target.id, - OrganizationMember.role == OrganizationRole.OWNER, - OrganizationMember.deleted_at == None, - ).all() - if owner_memberships: - org_names = [ - m.organization.name - for m in owner_memberships - if m.organization and not m.organization.deleted_at - ] - return api_response( - success=False, - message=( - f"Cannot suspend an organization owner. " - f"{target.email} is the owner of: {', '.join(org_names)}. " - "Transfer ownership to another member first." - ), - status=403, - error_type="OWNER_PROTECTION", - ) - - if target.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED): - return api_response(success=False, message="User is already suspended", status=409, error_type="CONFLICT") - - target.status = UserStatus.SUSPENDED - _db.session.commit() - - AuditService.log_action( - action=AuditAction.USER_SUSPEND, - user_id=caller.id, - organization_id=admin_in_shared_org.organization_id, - resource_type="user", - resource_id=str(target.id), - description=f"Admin suspended user {target.email}", - metadata={"target_user_id": str(target.id), "target_email": target.email}, - ) - - return api_response(data={"user": target.to_dict()}, message="User suspended successfully") - - -@api_v1_bp.route("/admin/users//unsuspend", methods=["POST"]) -@login_required -@full_access_required -def admin_unsuspend_user(user_id): - """Restore a suspended user account to active status.""" - from gatehouse_app.models.organization.organization_member import OrganizationMember - from gatehouse_app.models.user.user import User as _User - from gatehouse_app.extensions import db as _db - from gatehouse_app.utils.constants import UserStatus, AuditAction - from gatehouse_app.services.audit_service import AuditService - - caller = g.current_user - target = _User.query.filter_by(id=user_id, deleted_at=None).first() - if not target: - return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") - - # Verify caller has admin access to a shared org - target_org_ids = {m.organization_id for m in target.organization_memberships if m.deleted_at is None} - admin_in_shared_org = OrganizationMember.query.filter( - OrganizationMember.user_id == caller.id, - OrganizationMember.organization_id.in_(target_org_ids), - OrganizationMember.role.in_(["OWNER", "ADMIN"]), - OrganizationMember.deleted_at == None, - ).first() - - if not admin_in_shared_org: - return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR") - - if target.status not in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED): - return api_response(success=False, message="User is not suspended", status=409, error_type="CONFLICT") - - target.status = UserStatus.ACTIVE - _db.session.commit() - - AuditService.log_action( - action=AuditAction.USER_UNSUSPEND, - user_id=caller.id, - organization_id=admin_in_shared_org.organization_id, - resource_type="user", - resource_id=str(target.id), - description=f"Admin unsuspended user {target.email}", - metadata={"target_user_id": str(target.id), "target_email": target.email}, - ) - - return api_response(data={"user": target.to_dict()}, message="User unsuspended successfully") - - -@api_v1_bp.route("/users/me/invites", methods=["GET"]) -@login_required -def get_my_pending_invites(): - """Return pending (unaccepted, non-expired) invitations for the current user's email.""" - from gatehouse_app.models.organization.org_invite_token import OrgInviteToken - from datetime import datetime, timezone - - user = g.current_user - now = datetime.now(timezone.utc) - - invites = OrgInviteToken.query.filter( - OrgInviteToken.email == user.email, - OrgInviteToken.accepted_at.is_(None), - OrgInviteToken.expires_at > now, - OrgInviteToken.deleted_at.is_(None), - ).all() - - return api_response( - data={ - "invites": [ - { - "token": i.token, - "organization": {"id": str(i.organization_id), "name": i.organization.name}, - "role": i.role, - "expires_at": i.expires_at.isoformat(), - } - for i in invites - ] - }, - message="Pending invitations retrieved", - ) - - -@api_v1_bp.route("/users/me/memberships", methods=["GET"]) -@login_required -def get_my_memberships(): - """Return the current user's department and principal memberships across all orgs. - - Returns: - 200: { - orgs: [{ - org_id, org_name, role, - departments: [{id, name, description}], - principals: [{id, name, description, via_department: bool}] - }] - } - """ - from gatehouse_app.models.organization.organization_member import OrganizationMember - from gatehouse_app.models.organization.department import DepartmentMembership, DepartmentPrincipal, Department - from gatehouse_app.models.organization.principal import Principal, PrincipalMembership - - user = g.current_user - - memberships = OrganizationMember.query.filter_by( - user_id=user.id, - deleted_at=None, - ).all() - - orgs_result = [] - for membership in memberships: - org = membership.organization - if not org or org.deleted_at is not None: - continue - - # Departments the user belongs to - dept_memberships = DepartmentMembership.query.filter_by( - user_id=user.id, - deleted_at=None, - ).all() - user_depts = [ - dm.department for dm in dept_memberships - if dm.department - and dm.department.organization_id == org.id - and dm.department.deleted_at is None - ] - - # Principals: direct - direct_pm = PrincipalMembership.query.filter_by( - user_id=user.id, - deleted_at=None, - ).all() - direct_principal_ids = { - pm.principal_id for pm in direct_pm - if pm.principal - and pm.principal.organization_id == org.id - and pm.principal.deleted_at is None - } - - # Principals: via department - via_dept_principal_ids = set() - for dept in user_depts: - for dp in DepartmentPrincipal.query.filter_by(department_id=dept.id, deleted_at=None).all(): - if dp.principal and dp.principal.deleted_at is None: - via_dept_principal_ids.add(dp.principal_id) - - all_principal_ids = direct_principal_ids | via_dept_principal_ids - principals_list = [] - if all_principal_ids: - for p in Principal.query.filter( - Principal.id.in_(list(all_principal_ids)), - Principal.deleted_at == None, - ).all(): - principals_list.append({ - "id": str(p.id), - "name": p.name, - "description": p.description, - "via_department": p.id not in direct_principal_ids, - }) - - role = membership.role - orgs_result.append({ - "org_id": str(org.id), - "org_name": org.name, - "role": role.value if hasattr(role, "value") else role, - "departments": [ - {"id": str(d.id), "name": d.name, "description": d.description} - for d in user_depts - ], - "principals": principals_list, - }) - - return api_response( - data={"orgs": orgs_result}, - message="Memberships retrieved", - ) - - -@api_v1_bp.route("/admin/users//delete", methods=["POST"]) -@login_required -@full_access_required -def admin_hard_delete_user(user_id): - """Permanently delete a user and ALL associated data (hard delete, irreversible). - - Required body: {"confirm": true} - - Pre-conditions: - - Caller is OWNER or ADMIN of a shared org with the target. - - Cannot delete yourself. - - Target must not be the OWNER of any active organization (transfer first). - - Side-effects: - - All active SSH certificates are revoked before deletion. - - The user row and all cascaded rows are hard-deleted from the database. - - An audit log entry is written by the *caller* (so it is not lost with the user). - """ - from gatehouse_app.models.organization.organization_member import OrganizationMember - from gatehouse_app.models.user.user import User as _User - from gatehouse_app.extensions import db as _db - from gatehouse_app.utils.constants import UserStatus, AuditAction, OrganizationRole - from gatehouse_app.services.audit_service import AuditService - - caller = g.current_user - data = request.get_json() or {} - - if not data.get("confirm"): - return api_response( - success=False, - message="Deletion requires explicit confirmation. Send {\"confirm\": true} to proceed.", - status=400, - error_type="CONFIRMATION_REQUIRED", - ) - - target = _User.query.filter_by(id=user_id).first() - if not target: - return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") - - if target.id == caller.id: - return api_response( - success=False, - message="Cannot delete your own account via this endpoint.", - status=400, - error_type="BAD_REQUEST", - ) - - # Caller must be OWNER/ADMIN of a shared org. - # Include soft-deleted memberships so that already-soft-deleted users can - # still be hard-deleted by an admin who shared an org with them. - target_org_ids = {m.organization_id for m in target.organization_memberships} - admin_in_shared_org = OrganizationMember.query.filter( - OrganizationMember.user_id == caller.id, - OrganizationMember.organization_id.in_(target_org_ids), - OrganizationMember.role.in_(["OWNER", "ADMIN"]), - OrganizationMember.deleted_at == None, - ).first() - if not admin_in_shared_org: - return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR") - - # Block deletion if target is an org owner — they must transfer first - owner_memberships = OrganizationMember.query.filter( - OrganizationMember.user_id == target.id, - OrganizationMember.role == OrganizationRole.OWNER, - OrganizationMember.deleted_at == None, - ).all() - if owner_memberships: - org_names = [ - m.organization.name - for m in owner_memberships - if m.organization and not m.organization.deleted_at - ] - return api_response( - success=False, - message=( - f"Cannot delete an organization owner. " - f"{target.email} is the owner of: {', '.join(org_names)}. " - "Transfer ownership to another member first." - ), - status=403, - error_type="OWNER_PROTECTION", - ) - - # ── Collect counts for audit metadata ──────────────────────────────────── - from gatehouse_app.models.ssh_ca.ssh_key import SSHKey - from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus - - ssh_key_count = SSHKey.query.filter_by(user_id=target.id, deleted_at=None).count() - active_cert_count = SSHCertificate.query.filter_by( - user_id=target.id, revoked=False - ).filter(SSHCertificate.deleted_at == None).count() - - # ── Revoke all active SSH certificates before deletion ─────────────────── - active_certs = SSHCertificate.query.filter_by( - user_id=target.id, revoked=False - ).filter(SSHCertificate.deleted_at == None).all() - for cert in active_certs: - try: - cert.revoke("account_deleted") - except Exception: - pass - - if active_certs: - try: - _db.session.flush() - except Exception: - pass - - # ── Hard delete ─────────────────────────────────────────────────────────── - target_email = target.email # capture before deletion - target_id_str = str(target.id) - - try: - _db.session.delete(target) # cascades to all child tables - _db.session.flush() - except Exception as exc: - _db.session.rollback() - import logging - logging.getLogger(__name__).error(f"Hard delete failed for {target_id_str}: {exc}") - return api_response( - success=False, - message="Failed to delete user account. Please try again.", - status=500, - error_type="SERVER_ERROR", - ) - - # ── Audit log (written as the caller so it survives the deletion) ───────── - AuditService.log_action( - action=AuditAction.USER_HARD_DELETE, - user_id=caller.id, - organization_id=admin_in_shared_org.organization_id, - resource_type="user", - resource_id=target_id_str, - description=f"Admin permanently deleted user account: {target_email}", - metadata={ - "deleted_user_id": target_id_str, - "deleted_user_email": target_email, - "ssh_keys_deleted": ssh_key_count, - "certs_revoked": active_cert_count, - }, - ) - - _db.session.commit() - - return api_response( - message=f"User account {target_email} has been permanently deleted.", - data={ - "deleted_user_id": target_id_str, - "deleted_user_email": target_email, - "ssh_keys_deleted": ssh_key_count, - "certs_revoked": active_cert_count, - }, - ) diff --git a/gatehouse_app/api/v1/users/__init__.py b/gatehouse_app/api/v1/users/__init__.py new file mode 100644 index 0000000..35beff9 --- /dev/null +++ b/gatehouse_app/api/v1/users/__init__.py @@ -0,0 +1,2 @@ +"""Users blueprint subpackage.""" +from gatehouse_app.api.v1.users import me, admin diff --git a/gatehouse_app/api/v1/users/admin.py b/gatehouse_app/api/v1/users/admin.py new file mode 100644 index 0000000..4efcf00 --- /dev/null +++ b/gatehouse_app/api/v1/users/admin.py @@ -0,0 +1,842 @@ +"""Admin user management endpoints.""" +import logging +from datetime import datetime, timezone +from flask import g, request +from gatehouse_app.api.v1 import api_v1_bp +from gatehouse_app.utils.response import api_response +from gatehouse_app.utils.decorators import login_required, full_access_required + +_logger = logging.getLogger(__name__) + + +def _get_admin_access(caller, target): + """Return the first OrganizationMember row where caller is OWNER/ADMIN in a shared org with target, or None. + + Works even when the target user has been soft-deleted, as long as the + OrganizationMember row is still active (deleted_at IS NULL). + """ + from gatehouse_app.models.organization.organization_member import OrganizationMember + + # Query directly — don't rely on the ORM relationship which may be stale + # when the user row is soft-deleted. + target_memberships = OrganizationMember.query.filter_by( + user_id=target.id, deleted_at=None + ).all() + target_org_ids = {m.organization_id for m in target_memberships} + if not target_org_ids: + return None + return OrganizationMember.query.filter( + OrganizationMember.user_id == caller.id, + OrganizationMember.organization_id.in_(target_org_ids), + OrganizationMember.role.in_(["OWNER", "ADMIN"]), + OrganizationMember.deleted_at == None, + ).first() + + +def _find_user_for_admin(user_id): + """Look up a user by ID for admin use. + + Returns the User row whether or not it has been soft-deleted, so that + admins can manage accounts that the user themselves deleted but that still + have an active org membership. + """ + from gatehouse_app.models.user.user import User as _User + return _User.query.filter_by(id=user_id).first() + + +@api_v1_bp.route("/admin/users", methods=["GET"]) +@login_required +@full_access_required +def admin_list_users(): + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.models.user.user import User as _User + from sqlalchemy import or_ + + caller = g.current_user + + admin_memberships = OrganizationMember.query.filter( + OrganizationMember.user_id == caller.id, + OrganizationMember.role.in_(["OWNER", "ADMIN"]), + OrganizationMember.deleted_at == None, + ).all() + + if not admin_memberships: + return api_response(success=False, message="Admin or owner role required", status=403, error_type="AUTHORIZATION_ERROR") + + admin_org_ids = [m.organization_id for m in admin_memberships] + + member_rows = OrganizationMember.query.filter( + OrganizationMember.organization_id.in_(admin_org_ids), + OrganizationMember.deleted_at == None, + ).all() + visible_user_ids = list({row.user_id for row in member_rows}) + + q = request.args.get("q", "").strip() + try: + page = max(1, int(request.args.get("page", 1))) + per_page = min(200, max(1, int(request.args.get("per_page", 50)))) + except ValueError: + page, per_page = 1, 50 + + query = _User.query.filter(_User.id.in_(visible_user_ids)) + if q: + like = f"%{q}%" + query = query.filter(or_(_User.email.ilike(like), _User.full_name.ilike(like))) + + total = query.count() + users = query.order_by(_User.email).offset((page - 1) * per_page).limit(per_page).all() + + member_lookup = {} + for row in member_rows: + if row.user_id not in member_lookup: + member_lookup[row.user_id] = { + "organization_id": row.organization_id, + "role": row.role.value if hasattr(row.role, "value") else row.role, + } + + users_data = [] + for u in users: + d = u.to_dict() + m = member_lookup.get(u.id, {}) + d["org_role"] = m.get("role", "member") + d["org_id"] = m.get("organization_id") + d["is_deleted"] = u.deleted_at is not None + users_data.append(d) + + return api_response( + data={ + "users": users_data, "count": total, + "page": page, "per_page": per_page, + "pages": (total + per_page - 1) // per_page, + }, + message="Users retrieved successfully", + ) + + +@api_v1_bp.route("/admin/users/", methods=["GET"]) +@login_required +@full_access_required +def admin_get_user(user_id): + from gatehouse_app.models.ssh_ca.ssh_key import SSHKey + from gatehouse_app.models.auth.authentication_method import AuthenticationMethod + from gatehouse_app.utils.constants import AuthMethodType + + caller = g.current_user + target = _find_user_for_admin(user_id) + if not target: + return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") + + if not _get_admin_access(caller, target): + return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR") + + OAUTH_TYPES = { + AuthMethodType.GOOGLE, AuthMethodType.GITHUB, + AuthMethodType.MICROSOFT, AuthMethodType.OIDC, + } + auth_methods = AuthenticationMethod.query.filter_by(user_id=user_id, deleted_at=None).all() + + has_password = any( + m.method_type == AuthMethodType.PASSWORD and m.password_hash + for m in auth_methods + ) + totp_method = next( + (m for m in auth_methods if m.method_type == AuthMethodType.TOTP and m.verified), + None, + ) + totp_enabled = totp_method is not None + linked_providers = [ + { + "provider": m.method_type.value, + "email": (m.provider_data or {}).get("email"), + "name": (m.provider_data or {}).get("name"), + "connected_since": m.created_at.isoformat() if m.created_at else None, + } + for m in auth_methods if m.method_type in OAUTH_TYPES + ] + + user_dict = target.to_dict() + user_dict["has_password"] = has_password + user_dict["totp_enabled"] = totp_enabled + user_dict["totp_enabled_at"] = ( + totp_method.totp_verified_at.isoformat() + if totp_method and totp_method.totp_verified_at + else (totp_method.created_at.isoformat() if totp_method and totp_method.created_at else None) + ) + user_dict["linked_providers"] = linked_providers + user_dict["is_deleted"] = target.deleted_at is not None + + ssh_keys = SSHKey.query.filter_by(user_id=user_id, deleted_at=None).all() + return api_response( + data={"user": user_dict, "ssh_keys": [k.to_dict() for k in ssh_keys]}, + message="User retrieved", + ) + + +@api_v1_bp.route("/admin/users//suspend", methods=["POST"]) +@login_required +@full_access_required +def admin_suspend_user(user_id): + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.extensions import db as _db + from gatehouse_app.utils.constants import UserStatus, AuditAction, OrganizationRole + from gatehouse_app.services.audit_service import AuditService + + caller = g.current_user + target = _find_user_for_admin(user_id) + if not target: + return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") + + if target.id == caller.id: + return api_response(success=False, message="Cannot suspend yourself", status=400, error_type="BAD_REQUEST") + + admin_in_shared_org = _get_admin_access(caller, target) + if not admin_in_shared_org: + return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR") + + owner_memberships = OrganizationMember.query.filter( + OrganizationMember.user_id == target.id, + OrganizationMember.role == OrganizationRole.OWNER, + OrganizationMember.deleted_at == None, + ).all() + if owner_memberships: + org_names = [m.organization.name for m in owner_memberships if m.organization and not m.organization.deleted_at] + return api_response( + success=False, + message=( + f"Cannot suspend an organization owner. {target.email} is the owner of: {', '.join(org_names)}. " + "Transfer ownership to another member first." + ), + status=403, error_type="OWNER_PROTECTION", + ) + + if target.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED): + return api_response(success=False, message="User is already suspended", status=409, error_type="CONFLICT") + + target.status = UserStatus.SUSPENDED + _db.session.commit() + + AuditService.log_action( + action=AuditAction.USER_SUSPEND, + user_id=caller.id, + organization_id=admin_in_shared_org.organization_id, + resource_type="user", resource_id=str(target.id), + description=f"Admin suspended user {target.email}", + metadata={"target_user_id": str(target.id), "target_email": target.email}, + ) + return api_response(data={"user": target.to_dict()}, message="User suspended successfully") + + +@api_v1_bp.route("/admin/users//unsuspend", methods=["POST"]) +@login_required +@full_access_required +def admin_unsuspend_user(user_id): + from gatehouse_app.extensions import db as _db + from gatehouse_app.utils.constants import UserStatus, AuditAction + from gatehouse_app.services.audit_service import AuditService + + caller = g.current_user + target = _find_user_for_admin(user_id) + if not target: + return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") + + admin_in_shared_org = _get_admin_access(caller, target) + if not admin_in_shared_org: + return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR") + + if target.status not in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED): + return api_response(success=False, message="User is not suspended", status=409, error_type="CONFLICT") + + target.status = UserStatus.ACTIVE + _db.session.commit() + + AuditService.log_action( + action=AuditAction.USER_UNSUSPEND, + user_id=caller.id, + organization_id=admin_in_shared_org.organization_id, + resource_type="user", resource_id=str(target.id), + description=f"Admin unsuspended user {target.email}", + metadata={"target_user_id": str(target.id), "target_email": target.email}, + ) + return api_response(data={"user": target.to_dict()}, message="User unsuspended successfully") + + +@api_v1_bp.route("/admin/users//verify-email", methods=["POST"]) +@login_required +@full_access_required +def admin_verify_user_email(user_id): + from gatehouse_app.models.auth.email_verification_token import EmailVerificationToken + from gatehouse_app.extensions import db as _db + from gatehouse_app.utils.constants import UserStatus, AuditAction + from gatehouse_app.services.audit_service import AuditService + + caller = g.current_user + target = _find_user_for_admin(user_id) + if not target: + return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") + + admin_in_shared_org = _get_admin_access(caller, target) + if not admin_in_shared_org: + return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR") + + target.email_verified = True + was_inactive = target.status == UserStatus.INACTIVE + if was_inactive: + target.status = UserStatus.ACTIVE + + EmailVerificationToken.query.filter_by(user_id=target.id, used_at=None).delete() + _db.session.commit() + + AuditService.log_action( + action=AuditAction.ADMIN_EMAIL_VERIFY, + user_id=caller.id, + organization_id=admin_in_shared_org.organization_id, + resource_type="user", resource_id=str(target.id), + description=f"Admin force-verified email for {target.email}", + metadata={"target_user_id": str(target.id), "target_email": target.email, "was_inactive": was_inactive}, + ) + return api_response(data={"user": target.to_dict()}, message="Email verified and account activated successfully") + + +@api_v1_bp.route("/admin/users//delete", methods=["POST"]) +@login_required +@full_access_required +def admin_hard_delete_user(user_id): + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.models.user.user import User as _User + from gatehouse_app.models.ssh_ca.ssh_key import SSHKey + from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate + from gatehouse_app.models.ssh_ca.certificate_audit_log import CertificateAuditLog + from gatehouse_app.models.auth.authentication_method import OAuthState + from gatehouse_app.models.security.organization_security_policy import OrganizationSecurityPolicy + from gatehouse_app.extensions import db as _db + from gatehouse_app.utils.constants import AuditAction, OrganizationRole + from gatehouse_app.services.audit_service import AuditService + + caller = g.current_user + data = request.get_json() or {} + + if not data.get("confirm"): + return api_response( + success=False, + message='Deletion requires explicit confirmation. Send {"confirm": true} to proceed.', + status=400, error_type="CONFIRMATION_REQUIRED", + ) + + target = _User.query.filter_by(id=user_id).first() + if not target: + return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") + + if target.id == caller.id: + return api_response(success=False, message="Cannot delete your own account via this endpoint.", status=400, error_type="BAD_REQUEST") + + target_org_ids = {m.organization_id for m in target.organization_memberships} + admin_in_shared_org = OrganizationMember.query.filter( + OrganizationMember.user_id == caller.id, + OrganizationMember.organization_id.in_(target_org_ids), + OrganizationMember.role.in_(["OWNER", "ADMIN"]), + OrganizationMember.deleted_at == None, + ).first() + if not admin_in_shared_org: + return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR") + + owner_memberships = OrganizationMember.query.filter( + OrganizationMember.user_id == target.id, + OrganizationMember.role == OrganizationRole.OWNER, + OrganizationMember.deleted_at == None, + ).all() + if owner_memberships: + org_names = [m.organization.name for m in owner_memberships if m.organization and not m.organization.deleted_at] + return api_response( + success=False, + message=( + f"Cannot delete an organization owner. {target.email} is the owner of: {', '.join(org_names)}. " + "Transfer ownership to another member first." + ), + status=403, error_type="OWNER_PROTECTION", + ) + + ssh_key_count = SSHKey.query.filter_by(user_id=target.id, deleted_at=None).count() + active_certs = SSHCertificate.query.filter_by(user_id=target.id, revoked=False).filter(SSHCertificate.deleted_at == None).all() + active_cert_count = len(active_certs) + + for cert in active_certs: + try: + cert.revoke("account_deleted") + except Exception: + pass + + if active_certs: + try: + _db.session.flush() + except Exception: + pass + + target_email = target.email + target_id_str = str(target.id) + + try: + # NULL out FK references that don't cascade on delete so the + # session.delete() below doesn't hit FK constraint violations. + + # org_invite_tokens.invited_by_id — SET NULL is already on the FK column, + # but OrganizationMember.invited_by_id has no ondelete clause. + _db.session.execute( + _db.text("UPDATE organization_members SET invited_by_id = NULL WHERE invited_by_id = :uid"), + {"uid": target_id_str}, + ) + + # certificate_audit_logs.user_id — nullable, no ondelete clause. + CertificateAuditLog.query.filter_by(user_id=target_id_str).update( + {"user_id": None}, synchronize_session=False + ) + + # organization_security_policies.updated_by_user_id — nullable, no ondelete. + OrganizationSecurityPolicy.query.filter_by(updated_by_user_id=target_id_str).update( + {"updated_by_user_id": None}, synchronize_session=False + ) + + # oauth_states.user_id — nullable, no ondelete. + OAuthState.query.filter_by(user_id=target_id_str).delete(synchronize_session=False) + + _db.session.delete(target) + _db.session.flush() + except Exception as exc: + _db.session.rollback() + _logger.error(f"Hard delete failed for {target_id_str}: {exc}") + return api_response(success=False, message="Failed to delete user account. Please try again.", status=500, error_type="SERVER_ERROR") + + AuditService.log_action( + action=AuditAction.USER_HARD_DELETE, + user_id=caller.id, + organization_id=admin_in_shared_org.organization_id, + resource_type="user", resource_id=target_id_str, + description=f"Admin permanently deleted user account: {target_email}", + metadata={ + "deleted_user_id": target_id_str, "deleted_user_email": target_email, + "ssh_keys_deleted": ssh_key_count, "certs_revoked": active_cert_count, + }, + ) + + _db.session.commit() + return api_response( + message=f"User account {target_email} has been permanently deleted.", + data={"deleted_user_id": target_id_str, "deleted_user_email": target_email, + "ssh_keys_deleted": ssh_key_count, "certs_revoked": active_cert_count}, + ) + + +@api_v1_bp.route("/admin/users//restore", methods=["POST"]) +@login_required +@full_access_required +def admin_restore_user(user_id): + """Restore a soft-deleted user account. + + A user who self-deleted but still has an active org membership (and active + auth methods) can be restored by an admin. Clearing ``deleted_at`` makes + the account usable again without touching any auth methods. + """ + from gatehouse_app.extensions import db as _db + from gatehouse_app.utils.constants import UserStatus, AuditAction + from gatehouse_app.services.audit_service import AuditService + + caller = g.current_user + target = _find_user_for_admin(user_id) + if not target: + return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") + + if not _get_admin_access(caller, target): + return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR") + + if target.deleted_at is None: + return api_response( + success=False, message="User account is not deleted — nothing to restore.", + status=409, error_type="CONFLICT", + ) + + target.deleted_at = None + if target.status not in (UserStatus.ACTIVE, UserStatus.INACTIVE): + target.status = UserStatus.ACTIVE + _db.session.commit() + + AuditService.log_action( + action=AuditAction.USER_UNSUSPEND, # closest existing action + user_id=caller.id, + organization_id=_get_admin_access(caller, target).organization_id, + resource_type="user", resource_id=str(target.id), + description=f"Admin restored soft-deleted user account {target.email}", + metadata={"target_user_id": str(target.id), "target_email": target.email, "admin_email": caller.email}, + ) + return api_response( + data={"user": target.to_dict()}, + message=f"User account {target.email} has been restored successfully.", + ) + + +@api_v1_bp.route("/admin/users//mfa", methods=["GET"]) +@login_required +@full_access_required +def admin_get_user_mfa(user_id): + from gatehouse_app.models.auth.authentication_method import AuthenticationMethod + from gatehouse_app.utils.constants import AuthMethodType + + caller = g.current_user + target = _find_user_for_admin(user_id) + if not target: + return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") + + if not _get_admin_access(caller, target): + return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR") + + mfa_methods = [] + + totp_method = AuthenticationMethod.query.filter_by( + user_id=user_id, method_type=AuthMethodType.TOTP, verified=True, deleted_at=None, + ).first() + if totp_method: + enabled_at = ( + totp_method.totp_verified_at.isoformat() + if totp_method.totp_verified_at + else (totp_method.created_at.isoformat() if totp_method.created_at else None) + ) + mfa_methods.append({ + "id": str(totp_method.id), + "type": "totp", + "name": "Authenticator app (TOTP)", + "verified": totp_method.verified, + "enabled_at": enabled_at, + "created_at": totp_method.created_at.isoformat() if totp_method.created_at else None, + "last_used_at": totp_method.last_used_at.isoformat() if totp_method.last_used_at else None, + }) + + webauthn_method = AuthenticationMethod.query.filter_by( + user_id=user_id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None, + ).first() + if webauthn_method and webauthn_method.provider_data: + for cred in webauthn_method.provider_data.get("credentials", []): + if not cred.get("deleted_at"): + mfa_methods.append({ + "id": cred.get("id") or cred.get("credential_id"), + "type": "webauthn", + "name": cred.get("name") or cred.get("device_type") or "Passkey", + "device_type": cred.get("device_type", ""), + "transports": cred.get("transports", []), + "verified": True, + "created_at": cred.get("created_at"), + "last_used_at": cred.get("last_used_at"), + }) + + return api_response( + data={"user": {"id": str(target.id), "email": target.email, "full_name": target.full_name}, "mfa_methods": mfa_methods}, + message="MFA methods retrieved", + ) + + +@api_v1_bp.route("/admin/users//mfa/", methods=["DELETE"]) +@login_required +@full_access_required +def admin_remove_user_mfa(user_id, method_type): + from sqlalchemy.orm.attributes import flag_modified + from gatehouse_app.models.auth.authentication_method import AuthenticationMethod + from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance + from gatehouse_app.models.security.organization_security_policy import OrganizationSecurityPolicy + from gatehouse_app.extensions import db as _db + from gatehouse_app.utils.constants import AuthMethodType, AuditAction, MfaComplianceStatus, UserStatus as _UserStatus + from gatehouse_app.services.audit_service import AuditService + from datetime import timedelta + + caller = g.current_user + now = datetime.now(timezone.utc) + + VALID_TYPES = {"totp", "webauthn", "all"} + method_type = method_type.lower().strip() + if method_type not in VALID_TYPES: + return api_response( + success=False, + message=f"Invalid method_type '{method_type}'. Must be one of: {', '.join(sorted(VALID_TYPES))}", + status=400, error_type="VALIDATION_ERROR", + ) + + target = _find_user_for_admin(user_id) + if not target: + return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") + + if target.id == caller.id: + return api_response(success=False, message="Use the regular MFA management endpoints to modify your own MFA methods.", status=400, error_type="BAD_REQUEST") + + admin_in_shared_org = _get_admin_access(caller, target) + if not admin_in_shared_org: + return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR") + + removed = [] + + if method_type in ("totp", "all"): + totp_methods = AuthenticationMethod.query.filter_by(user_id=user_id, method_type=AuthMethodType.TOTP, deleted_at=None).all() + if totp_methods: + for totp_method in totp_methods: + totp_method.deleted_at = now + totp_method.totp_secret = None + totp_method.totp_backup_codes = None + totp_method.totp_verified_at = None + _db.session.add(totp_method) + removed.append("totp") + elif method_type == "totp": + return api_response(success=False, message="User does not have TOTP configured", status=404, error_type="NOT_FOUND") + + if method_type in ("webauthn", "all"): + webauthn_method = AuthenticationMethod.query.filter_by(user_id=user_id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None).first() + if webauthn_method: + credential_id = request.args.get("credential_id") + if credential_id: + credentials = (webauthn_method.provider_data or {}).get("credentials", []) + found = False + new_credentials = [] + for cred in credentials: + cid = cred.get("id") or cred.get("credential_id") + if cid == credential_id and not cred.get("deleted_at"): + cred["deleted_at"] = now.isoformat() + found = True + removed.append(f"webauthn:{credential_id[:16]}") + new_credentials.append(cred) + if not found: + return api_response(success=False, message=f"WebAuthn credential '{credential_id}' not found", status=404, error_type="NOT_FOUND") + active_remaining = sum(1 for c in new_credentials if not c.get("deleted_at")) + if active_remaining == 0: + webauthn_method.deleted_at = now + else: + if webauthn_method.provider_data is None: + webauthn_method.provider_data = {} + webauthn_method.provider_data["credentials"] = new_credentials + flag_modified(webauthn_method, "provider_data") + _db.session.add(webauthn_method) + else: + webauthn_method.deleted_at = now + if webauthn_method.provider_data: + for cred in webauthn_method.provider_data.get("credentials", []): + cred["deleted_at"] = now.isoformat() + flag_modified(webauthn_method, "provider_data") + _db.session.add(webauthn_method) + removed.append("webauthn") + elif method_type == "webauthn": + return api_response(success=False, message="User does not have any WebAuthn passkeys configured", status=404, error_type="NOT_FOUND") + + if not removed: + return api_response(success=False, message="No MFA methods found to remove", status=404, error_type="NOT_FOUND") + + compliance_records = MfaPolicyCompliance.query.filter_by(user_id=user_id).filter(MfaPolicyCompliance.deleted_at == None).all() + for record in compliance_records: + if record.status in (MfaComplianceStatus.COMPLIANT, MfaComplianceStatus.PAST_DUE, MfaComplianceStatus.SUSPENDED): + record.status = MfaComplianceStatus.IN_GRACE + record.compliant_at = None + record.suspended_at = None + org_policy = OrganizationSecurityPolicy.query.filter_by(organization_id=record.organization_id, deleted_at=None).first() + grace_days = org_policy.mfa_grace_period_days if org_policy else 14 + record.deadline_at = now + timedelta(days=grace_days) + record.applied_at = now + record.notification_count = 0 + record.last_notified_at = None + + if target.status == _UserStatus.COMPLIANCE_SUSPENDED: + target.status = _UserStatus.ACTIVE + _db.session.add(target) + + _db.session.commit() + + AuditService.log_action( + action=AuditAction.ADMIN_MFA_REMOVE, + user_id=caller.id, + organization_id=admin_in_shared_org.organization_id, + resource_type="user", resource_id=str(target.id), + description=f"Admin removed MFA method(s) [{', '.join(removed)}] for user {target.email}", + metadata={"target_user_id": str(target.id), "target_user_email": target.email, "removed_methods": removed, "admin_email": caller.email}, + ) + + return api_response( + data={"removed_methods": removed, "removed_count": len(removed), "user": {"id": str(target.id), "email": target.email}}, + message=f"Removed {len(removed)} MFA method(s) for {target.email}", + ) + + +@api_v1_bp.route("/admin/users//password", methods=["POST"]) +@login_required +@full_access_required +def admin_set_user_password(user_id): + from flask_bcrypt import Bcrypt + from gatehouse_app.models.auth.authentication_method import AuthenticationMethod + from gatehouse_app.extensions import db as _db + from gatehouse_app.utils.constants import AuthMethodType, AuditAction + from gatehouse_app.services.audit_service import AuditService + + caller = g.current_user + data = request.get_json() or {} + new_password = data.get("password", "").strip() + + if len(new_password) < 8: + return api_response(success=False, message="Password must be at least 8 characters", status=400, error_type="VALIDATION_ERROR") + + target = _find_user_for_admin(user_id) + if not target: + return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") + + if target.id == caller.id: + return api_response(success=False, message="Use the regular password change endpoint to update your own password.", status=400, error_type="BAD_REQUEST") + + admin_in_shared_org = _get_admin_access(caller, target) + if not admin_in_shared_org: + return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR") + + bcrypt = Bcrypt() + password_hash = bcrypt.generate_password_hash(new_password).decode("utf-8") + now = datetime.now(timezone.utc) + + pw_method = AuthenticationMethod.query.filter_by(user_id=user_id, method_type=AuthMethodType.PASSWORD, deleted_at=None).first() + method_was_created = False + if pw_method: + pw_method.password_hash = password_hash + pw_method.updated_at = now + _db.session.add(pw_method) + action_description = f"Admin reset password for user {target.email}" + else: + method_was_created = True + pw_method = AuthenticationMethod( + user_id=user_id, method_type=AuthMethodType.PASSWORD, + password_hash=password_hash, verified=True, created_at=now, + ) + _db.session.add(pw_method) + action_description = f"Admin set password for user {target.email} (new method created)" + + _db.session.commit() + + AuditService.log_action( + action=AuditAction.ADMIN_PASSWORD_SET, + user_id=caller.id, + organization_id=admin_in_shared_org.organization_id, + resource_type="user", resource_id=str(target.id), + description=action_description, + metadata={"target_user_id": str(target.id), "target_user_email": target.email, "admin_email": caller.email, "method_created": method_was_created}, + ) + return api_response(data={"user": {"id": str(target.id), "email": target.email}}, message=f"Password updated for {target.email}") + + +@api_v1_bp.route("/admin/users//linked-accounts", methods=["GET"]) +@login_required +@full_access_required +def admin_get_user_linked_accounts(user_id): + from gatehouse_app.models.auth.authentication_method import AuthenticationMethod + from gatehouse_app.utils.constants import AuthMethodType + + caller = g.current_user + target = _find_user_for_admin(user_id) + if not target: + return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") + + if not _get_admin_access(caller, target): + return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR") + + OAUTH_TYPES = {AuthMethodType.GOOGLE, AuthMethodType.GITHUB, AuthMethodType.MICROSOFT, AuthMethodType.OIDC} + + oauth_methods = AuthenticationMethod.query.filter( + AuthenticationMethod.user_id == user_id, + AuthenticationMethod.method_type.in_(OAUTH_TYPES), + AuthenticationMethod.deleted_at == None, + ).all() + + linked_accounts = [] + for method in oauth_methods: + pd = method.provider_data or {} + connected_since = method.created_at.isoformat() if method.created_at else None + linked_accounts.append({ + "id": str(method.id), + "provider_type": method.method_type.value, + "email": pd.get("email"), + "name": pd.get("name"), + "provider_user_id": method.provider_user_id, + # both names so old and new clients both work + "linked_at": connected_since, + "connected_since": connected_since, + "verified": method.verified, + }) + + all_active_methods = AuthenticationMethod.query.filter_by(user_id=user_id, deleted_at=None).count() + + return api_response( + data={ + "user": {"id": str(target.id), "email": target.email, "full_name": target.full_name}, + "linked_accounts": linked_accounts, + "total_auth_methods": all_active_methods, + }, + message="Linked accounts retrieved", + ) + + +@api_v1_bp.route("/admin/users//linked-accounts/", methods=["DELETE"]) +@login_required +@full_access_required +def admin_unlink_user_provider(user_id, provider): + from gatehouse_app.models.auth.authentication_method import AuthenticationMethod + from gatehouse_app.extensions import db as _db + from gatehouse_app.utils.constants import AuthMethodType, AuditAction + from gatehouse_app.services.audit_service import AuditService + + caller = g.current_user + + OAUTH_TYPES = {AuthMethodType.GOOGLE, AuthMethodType.GITHUB, AuthMethodType.MICROSOFT, AuthMethodType.OIDC} + PROVIDER_MAP = {t.value: t for t in OAUTH_TYPES} + + target = _find_user_for_admin(user_id) + if not target: + return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") + + if target.id == caller.id: + return api_response(success=False, message="Use the regular account settings to unlink your own providers.", status=400, error_type="BAD_REQUEST") + + admin_in_shared_org = _get_admin_access(caller, target) + if not admin_in_shared_org: + return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR") + + provider_lower = provider.lower().strip() + method_to_unlink = None + if provider_lower in PROVIDER_MAP: + method_to_unlink = AuthenticationMethod.query.filter_by( + user_id=user_id, method_type=PROVIDER_MAP[provider_lower], deleted_at=None, + ).first() + else: + method_to_unlink = AuthenticationMethod.query.filter( + AuthenticationMethod.id == provider, + AuthenticationMethod.user_id == user_id, + AuthenticationMethod.method_type.in_(OAUTH_TYPES), + AuthenticationMethod.deleted_at == None, + ).first() + + if not method_to_unlink: + return api_response(success=False, message=f"Provider '{provider}' is not linked to this user's account", status=404, error_type="NOT_FOUND") + + all_active = AuthenticationMethod.query.filter_by(user_id=user_id, deleted_at=None).all() + remaining = [m for m in all_active if m.id != method_to_unlink.id] + has_password_remaining = any(m.method_type == AuthMethodType.PASSWORD and m.password_hash for m in remaining) + has_other_oauth_remaining = any(m.method_type in OAUTH_TYPES for m in remaining) + + if not has_password_remaining and not has_other_oauth_remaining: + return api_response( + success=False, + message="Cannot unlink this provider — it is the user's only sign-in method. Ensure the user has a password or another linked provider before unlinking.", + status=400, error_type="VALIDATION_ERROR", + ) + + now = datetime.now(timezone.utc) + provider_name = method_to_unlink.method_type.value + method_to_unlink.deleted_at = now + _db.session.add(method_to_unlink) + _db.session.commit() + + AuditService.log_action( + action=AuditAction.ADMIN_OAUTH_UNLINK, + user_id=caller.id, + organization_id=admin_in_shared_org.organization_id, + resource_type="user", resource_id=str(target.id), + description=f"Admin unlinked {provider_name} OAuth provider from user {target.email}", + metadata={"target_user_id": str(target.id), "target_user_email": target.email, "provider": provider_name, "admin_email": caller.email}, + ) + return api_response( + data={"provider": provider_name, "user": {"id": str(target.id), "email": target.email}}, + message=f"Successfully unlinked {provider_name} from {target.email}", + ) diff --git a/gatehouse_app/api/v1/users/me.py b/gatehouse_app/api/v1/users/me.py new file mode 100644 index 0000000..8b9983c --- /dev/null +++ b/gatehouse_app/api/v1/users/me.py @@ -0,0 +1,299 @@ +"""Current user (self-service) endpoints.""" +from flask import g, request +from marshmallow import ValidationError +from gatehouse_app.api.v1 import api_v1_bp +from gatehouse_app.utils.response import api_response +from gatehouse_app.utils.decorators import login_required, full_access_required +from gatehouse_app.schemas.user_schema import UserUpdateSchema, ChangePasswordSchema +from gatehouse_app.services.user_service import UserService +from gatehouse_app.services.auth_service import AuthService + + +@api_v1_bp.route("/users/me", methods=["GET"]) +@login_required +def get_me(): + from gatehouse_app.models.auth.authentication_method import AuthenticationMethod + from gatehouse_app.utils.constants import AuthMethodType + + user = g.current_user + user_dict = user.to_dict() + + OAUTH_TYPES = { + AuthMethodType.GOOGLE, AuthMethodType.GITHUB, + AuthMethodType.MICROSOFT, AuthMethodType.OIDC, + } + auth_methods = AuthenticationMethod.query.filter_by(user_id=user.id, deleted_at=None).all() + + has_password = any(m.method_type == AuthMethodType.PASSWORD and m.password_hash for m in auth_methods) + totp_enabled = any(m.method_type == AuthMethodType.TOTP and m.verified for m in auth_methods) + linked_providers = [m.method_type.value for m in auth_methods if m.method_type in OAUTH_TYPES] + + user_dict["has_password"] = has_password + user_dict["totp_enabled"] = totp_enabled + user_dict["linked_providers"] = linked_providers + + return api_response(data={"user": user_dict}, message="User profile retrieved successfully") + + +@api_v1_bp.route("/users/me", methods=["PATCH"]) +@login_required +@full_access_required +def update_me(): + try: + schema = UserUpdateSchema() + data = schema.load(request.json) + user = UserService.update_user(g.current_user, **data) + return api_response(data={"user": user.to_dict()}, message="Profile updated successfully") + except ValidationError as e: + return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) + + +@api_v1_bp.route("/users/me", methods=["DELETE"]) +@login_required +@full_access_required +def delete_me(): + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.utils.constants import OrganizationRole + from gatehouse_app.services.organization_service import OrganizationService + + user = g.current_user + + owned_memberships = OrganizationMember.query.filter_by( + user_id=user.id, role=OrganizationRole.OWNER, deleted_at=None, + ).all() + + transfer_needed = [] + auto_delete = [] + + for membership in owned_memberships: + org = membership.organization + if org.deleted_at is not None: + continue + if org.get_member_count() > 1: + transfer_needed.append(org.name) + else: + auto_delete.append(org) + + if transfer_needed: + names = ", ".join(f'"{n}"' for n in transfer_needed) + return api_response( + success=False, + message=( + f"You are the owner of {len(transfer_needed)} organization" + f"{'s' if len(transfer_needed) > 1 else ''} that still " + f"{'have' if len(transfer_needed) > 1 else 'has'} other members " + f"({names}). Transfer ownership to another member first." + ), + status=409, + error_type="USER_IS_SOLE_OWNER", + error_details={"transfer_ownership": transfer_needed}, + ) + + for org in auto_delete: + OrganizationService.force_delete_organization(org, user_id=user.id) + + UserService.delete_user(user, soft=True) + return api_response(message="Account deleted successfully") + + +@api_v1_bp.route("/users/me/password", methods=["POST"]) +@login_required +@full_access_required +def change_password(): + try: + schema = ChangePasswordSchema() + data = schema.load(request.json) + + if data["new_password"] != data["new_password_confirm"]: + return api_response( + success=False, message="New passwords do not match", status=400, + error_type="VALIDATION_ERROR", + error_details={"new_password_confirm": ["Passwords do not match"]}, + ) + + AuthService.change_password( + user=g.current_user, + current_password=data["current_password"], + new_password=data["new_password"], + ) + return api_response(message="Password changed successfully") + except ValidationError as e: + return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) + + +@api_v1_bp.route("/users/me/organizations", methods=["GET"]) +@login_required +@full_access_required +def get_my_organizations(): + from gatehouse_app.models.organization.organization_member import OrganizationMember + + user = g.current_user + memberships = OrganizationMember.query.filter_by(user_id=user.id, deleted_at=None).all() + + orgs = [] + for membership in memberships: + org = membership.organization + if not org or org.deleted_at is not None: + continue + org_dict = org.to_dict() + org_dict["role"] = membership.role.value if hasattr(membership.role, "value") else str(membership.role) + orgs.append(org_dict) + + return api_response(data={"organizations": orgs, "count": len(orgs)}, message="Organizations retrieved successfully") + + +@api_v1_bp.route("/users/me/principals", methods=["GET"]) +@login_required +@full_access_required +def get_my_principals(): + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.models.organization.principal import Principal, PrincipalMembership + from gatehouse_app.models.organization.department import DepartmentMembership, DepartmentPrincipal + from gatehouse_app.utils.constants import OrganizationRole + + user = g.current_user + user_id = user.id + + memberships = OrganizationMember.query.filter_by(user_id=user_id, deleted_at=None).all() + + orgs_result = [] + for membership in memberships: + org = membership.organization + if not org or org.deleted_at is not None: + continue + + role = membership.role + is_admin = role in (OrganizationRole.ADMIN, OrganizationRole.OWNER) + + direct_principal_ids = set() + via_dept_principal_ids = set() + + for pm in PrincipalMembership.query.filter_by(user_id=user_id, deleted_at=None).all(): + if pm.principal and pm.principal.organization_id == org.id and pm.principal.deleted_at is None: + direct_principal_ids.add(pm.principal_id) + + for dm in DepartmentMembership.query.filter_by(user_id=user_id, deleted_at=None).all(): + if dm.department and dm.department.organization_id == org.id and dm.department.deleted_at is None: + for dp in DepartmentPrincipal.query.filter_by(department_id=dm.department_id, deleted_at=None).all(): + if dp.principal and dp.principal.deleted_at is None: + via_dept_principal_ids.add(dp.principal_id) + + effective_principal_ids = direct_principal_ids | via_dept_principal_ids + + my_principals = [] + if effective_principal_ids: + for p in Principal.query.filter( + Principal.id.in_(list(effective_principal_ids)), + Principal.deleted_at == None, + ).all(): + my_principals.append({ + "id": p.id, "name": p.name, "description": p.description, + "direct": p.id in direct_principal_ids, + }) + + all_principals = [] + if is_admin: + for p in Principal.query.filter_by(organization_id=org.id, deleted_at=None).all(): + all_principals.append({"id": p.id, "name": p.name, "description": p.description}) + + orgs_result.append({ + "org_id": org.id, "org_name": org.name, + "role": role.value if hasattr(role, "value") else role, + "is_admin": is_admin, + "my_principals": my_principals, + "all_principals": all_principals, + }) + + return api_response(data={"orgs": orgs_result}, message="Principals retrieved successfully") + + +@api_v1_bp.route("/users/me/invites", methods=["GET"]) +@login_required +def get_my_pending_invites(): + from gatehouse_app.models.organization.org_invite_token import OrgInviteToken + from datetime import datetime, timezone + + user = g.current_user + now = datetime.now(timezone.utc) + + invites = OrgInviteToken.query.filter( + OrgInviteToken.email == user.email, + OrgInviteToken.accepted_at.is_(None), + OrgInviteToken.expires_at > now, + OrgInviteToken.deleted_at.is_(None), + ).all() + + return api_response( + data={ + "invites": [ + { + "token": i.token, + "organization": {"id": str(i.organization_id), "name": i.organization.name}, + "role": i.role, + "expires_at": i.expires_at.isoformat(), + } + for i in invites + ] + }, + message="Pending invitations retrieved", + ) + + +@api_v1_bp.route("/users/me/memberships", methods=["GET"]) +@login_required +def get_my_memberships(): + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.models.organization.department import DepartmentMembership, DepartmentPrincipal, Department + from gatehouse_app.models.organization.principal import Principal, PrincipalMembership + + user = g.current_user + + memberships = OrganizationMember.query.filter_by(user_id=user.id, deleted_at=None).all() + + orgs_result = [] + for membership in memberships: + org = membership.organization + if not org or org.deleted_at is not None: + continue + + dept_memberships = DepartmentMembership.query.filter_by(user_id=user.id, deleted_at=None).all() + user_depts = [ + dm.department for dm in dept_memberships + if dm.department + and dm.department.organization_id == org.id + and dm.department.deleted_at is None + ] + + direct_pm = PrincipalMembership.query.filter_by(user_id=user.id, deleted_at=None).all() + direct_principal_ids = { + pm.principal_id for pm in direct_pm + if pm.principal and pm.principal.organization_id == org.id and pm.principal.deleted_at is None + } + + via_dept_principal_ids = set() + for dept in user_depts: + for dp in DepartmentPrincipal.query.filter_by(department_id=dept.id, deleted_at=None).all(): + if dp.principal and dp.principal.deleted_at is None: + via_dept_principal_ids.add(dp.principal_id) + + all_principal_ids = direct_principal_ids | via_dept_principal_ids + principals_list = [] + if all_principal_ids: + for p in Principal.query.filter( + Principal.id.in_(list(all_principal_ids)), + Principal.deleted_at == None, + ).all(): + principals_list.append({ + "id": str(p.id), "name": p.name, "description": p.description, + "via_department": p.id not in direct_principal_ids, + }) + + role = membership.role + orgs_result.append({ + "org_id": str(org.id), "org_name": org.name, + "role": role.value if hasattr(role, "value") else role, + "departments": [{"id": str(d.id), "name": d.name, "description": d.description} for d in user_depts], + "principals": principals_list, + }) + + return api_response(data={"orgs": orgs_result}, message="Memberships retrieved") diff --git a/gatehouse_app/jobs/__init__.py b/gatehouse_app/jobs/__init__.py index 68eef1a..155944d 100644 --- a/gatehouse_app/jobs/__init__.py +++ b/gatehouse_app/jobs/__init__.py @@ -1 +1 @@ -Jobs module for scheduled tasks. \ No newline at end of file +"""Jobs module for scheduled tasks.""" \ No newline at end of file diff --git a/gatehouse_app/jobs/mfa_compliance_job.py b/gatehouse_app/jobs/mfa_compliance_job.py index a8548c8..9c6a882 100644 --- a/gatehouse_app/jobs/mfa_compliance_job.py +++ b/gatehouse_app/jobs/mfa_compliance_job.py @@ -20,9 +20,9 @@ from typing import Optional, Dict, Any, List import logging from gatehouse_app.extensions import db -from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance -from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy -from gatehouse_app.models.user import User +from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance +from gatehouse_app.models.security.organization_security_policy import OrganizationSecurityPolicy +from gatehouse_app.models.user.user import User from gatehouse_app.services.mfa_policy_service import MfaPolicyService from gatehouse_app.services.notification_service import NotificationService from gatehouse_app.utils.constants import MfaComplianceStatus @@ -203,6 +203,19 @@ def _evaluate_pending_compliance(now: datetime) -> int: if not user: continue + # Skip records for deleted organizations + from gatehouse_app.models.organization.organization import Organization + org = Organization.query.get(record.organization_id) + if not org or org.deleted_at is not None: + # Soft-delete orphaned compliance record + record.deleted_at = now or datetime.now(timezone.utc) + db.session.commit() + logger.info( + f"Cleaned up orphaned compliance record {record.id} " + f"for deleted org {record.organization_id}" + ) + continue + # Re-evaluate compliance status # This handles cases where policy changed or user enrolled in MFA from gatehouse_app.services.mfa_policy_service import MfaPolicyService diff --git a/gatehouse_app/models/auth/audit_log.py b/gatehouse_app/models/auth/audit_log.py index 849f915..7ccf106 100644 --- a/gatehouse_app/models/auth/audit_log.py +++ b/gatehouse_app/models/auth/audit_log.py @@ -10,7 +10,7 @@ class AuditLog(BaseModel): __tablename__ = "audit_logs" user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True, index=True) - action = db.Column(db.Enum(AuditAction), nullable=False, index=True) + action = db.Column(db.String(100), nullable=False, index=True) # Context resource_type = db.Column(db.String(50), nullable=True, index=True) diff --git a/gatehouse_app/schemas/auth_schema.py b/gatehouse_app/schemas/auth_schema.py index dff1758..51ecd90 100644 --- a/gatehouse_app/schemas/auth_schema.py +++ b/gatehouse_app/schemas/auth_schema.py @@ -113,7 +113,8 @@ class TOTPVerifySchema(Schema): class TOTPDisableSchema(Schema): """Schema for disabling TOTP.""" - password = fields.Str(required=True, validate=validate.Length(min=1)) + # Password is optional: OAuth-only users have no password and skip verification. + password = fields.Str(load_default=None, allow_none=True) class TOTPRegenerateBackupCodesSchema(Schema): diff --git a/gatehouse_app/services/__init__.py b/gatehouse_app/services/__init__.py index 46213f4..8a27413 100644 --- a/gatehouse_app/services/__init__.py +++ b/gatehouse_app/services/__init__.py @@ -4,7 +4,7 @@ from gatehouse_app.services.user_service import UserService from gatehouse_app.services.organization_service import OrganizationService from gatehouse_app.services.session_service import SessionService from gatehouse_app.services.audit_service import AuditService -from gatehouse_app.services.oidc_service import OIDCService, OIDCError +from gatehouse_app.services.oidc import OIDCService, OIDCError from gatehouse_app.services.oidc_jwks_service import OIDCJWKSService from gatehouse_app.services.oidc_token_service import OIDCTokenService from gatehouse_app.services.oidc_session_service import OIDCSessionService diff --git a/gatehouse_app/services/auth_service.py b/gatehouse_app/services/auth_service.py index 9061f33..fdf09b7 100644 --- a/gatehouse_app/services/auth_service.py +++ b/gatehouse_app/services/auth_service.py @@ -388,7 +388,7 @@ class AuthService: Args: user: User instance - password: User's current password for verification + password: User's current password for verification (ignored for OAuth-only users) Returns: True if TOTP disabled successfully @@ -396,18 +396,21 @@ class AuthService: Raises: InvalidCredentialsError: If password is invalid or TOTP method not found """ - # Verify user's password + # Verify user's password — only required when the user actually has one. + # OAuth-only users have no PASSWORD auth method; they authenticate via their + # identity provider so there is nothing to check here. auth_method = AuthenticationMethod.query.filter_by( user_id=user.id, method_type=AuthMethodType.PASSWORD, deleted_at=None, ).first() - if not auth_method or not auth_method.password_hash: - raise InvalidCredentialsError("No password authentication method found") - - if not bcrypt.check_password_hash(auth_method.password_hash, password): - raise InvalidCredentialsError("Invalid password") + if auth_method and auth_method.password_hash: + # Password-based account: a password must be supplied and must match. + if not password: + raise InvalidCredentialsError("Password is required") + if not bcrypt.check_password_hash(auth_method.password_hash, password): + raise InvalidCredentialsError("Invalid password") # Get user's TOTP authentication method totp_method = user.get_totp_method() diff --git a/gatehouse_app/services/external_auth/__init__.py b/gatehouse_app/services/external_auth/__init__.py new file mode 100644 index 0000000..a09e00e --- /dev/null +++ b/gatehouse_app/services/external_auth/__init__.py @@ -0,0 +1,168 @@ +"""ExternalAuthService — public facade re-exporting the full API.""" +import logging +from typing import Optional, Tuple + +from gatehouse_app.models import AuthenticationMethod, User +from gatehouse_app.models.auth.authentication_method import ( + ApplicationProviderConfig, + OrganizationProviderOverride, + OAuthState, +) +from gatehouse_app.utils.constants import AuthMethodType + +from gatehouse_app.services.external_auth.models import ( + ExternalAuthError, + ExternalProviderConfig, + ProviderConfigAdapter, +) +from gatehouse_app.services.external_auth import app_provider, org_override, linking +from gatehouse_app.services.external_auth._helpers import ( + _compute_s256_challenge, + _build_authorization_url, + _exchange_code, + _get_user_info, + _encrypt_provider_data, + _decrypt_provider_data, +) + +logger = logging.getLogger(__name__) + + +class ExternalAuthService: + """Service for external authentication operations.""" + + # ── Provider config lookup ────────────────────────────────────────────── + + @classmethod + def get_provider_config( + cls, + provider_type: AuthMethodType, + organization_id: Optional[str] = None, + ) -> ProviderConfigAdapter: + provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type + + app_config = ApplicationProviderConfig.query.filter_by( + provider_type=provider_type_str + ).first() + + if not app_config: + raise ExternalAuthError( + f"{provider_type_str.title()} OAuth is not configured for this application", + "PROVIDER_NOT_CONFIGURED", + 400, + ) + + if not app_config.is_enabled: + raise ExternalAuthError( + f"{provider_type_str.title()} OAuth is currently disabled", + "PROVIDER_DISABLED", + 400, + ) + + org_override_obj = None + if organization_id: + org_override_obj = OrganizationProviderOverride.query.filter_by( + organization_id=organization_id, + provider_type=provider_type_str, + ).first() + + if org_override_obj and not org_override_obj.is_enabled: + raise ExternalAuthError( + f"{provider_type_str.title()} OAuth is disabled for this organization", + "PROVIDER_DISABLED_FOR_ORG", + 400, + ) + + return ProviderConfigAdapter(app_config, org_override_obj) + + # ── App-wide provider config ──────────────────────────────────────────── + + @classmethod + def create_app_provider_config(cls, provider_type, client_id, client_secret, **kwargs): + return app_provider.create_app_provider_config(provider_type, client_id, client_secret, **kwargs) + + @classmethod + def update_app_provider_config(cls, provider_type, **updates): + return app_provider.update_app_provider_config(provider_type, **updates) + + @classmethod + def get_app_provider_config(cls, provider_type): + return app_provider.get_app_provider_config(provider_type) + + @classmethod + def list_app_provider_configs(cls): + return app_provider.list_app_provider_configs() + + @classmethod + def delete_app_provider_config(cls, provider_type): + return app_provider.delete_app_provider_config(provider_type) + + # ── Org override management ───────────────────────────────────────────── + + @classmethod + def create_org_provider_override(cls, organization_id, provider_type, **kwargs): + return org_override.create_org_provider_override(organization_id, provider_type, **kwargs) + + @classmethod + def update_org_provider_override(cls, organization_id, provider_type, **updates): + return org_override.update_org_provider_override(organization_id, provider_type, **updates) + + @classmethod + def get_org_provider_override(cls, organization_id, provider_type): + return org_override.get_org_provider_override(organization_id, provider_type) + + @classmethod + def list_org_provider_overrides(cls, organization_id): + return org_override.list_org_provider_overrides(organization_id) + + @classmethod + def delete_org_provider_override(cls, organization_id, provider_type): + return org_override.delete_org_provider_override(organization_id, provider_type) + + # ── OAuth link / auth flows ───────────────────────────────────────────── + + @classmethod + def initiate_link_flow(cls, user_id, provider_type, organization_id, redirect_uri=None): + return linking.initiate_link_flow(cls.get_provider_config, user_id, provider_type, organization_id, redirect_uri) + + @classmethod + def complete_link_flow(cls, provider_type, authorization_code, state, redirect_uri): + return linking.complete_link_flow(cls.get_provider_config, provider_type, authorization_code, state, redirect_uri) + + @classmethod + def authenticate_with_provider(cls, provider_type, organization_id, authorization_code, state, redirect_uri): + return linking.authenticate_with_provider(cls.get_provider_config, provider_type, organization_id, authorization_code, state, redirect_uri) + + @classmethod + def unlink_provider(cls, user_id, provider_type, organization_id=None): + return linking.unlink_provider(user_id, provider_type, organization_id) + + @classmethod + def get_linked_accounts(cls, user_id): + return linking.get_linked_accounts(user_id) + + # ── Static helpers (kept as class methods for backward compatibility) ─── + + @staticmethod + def _compute_s256_challenge(verifier: str) -> str: + return _compute_s256_challenge(verifier) + + @staticmethod + def _build_authorization_url(config, state) -> str: + return _build_authorization_url(config, state) + + @staticmethod + def _exchange_code(config, code, redirect_uri, code_verifier=None) -> dict: + return _exchange_code(config, code, redirect_uri, code_verifier) + + @staticmethod + def _get_user_info(config, access_token) -> dict: + return _get_user_info(config, access_token) + + @staticmethod + def _encrypt_provider_data(tokens, user_info) -> dict: + return _encrypt_provider_data(tokens, user_info) + + @staticmethod + def _decrypt_provider_data(provider_data) -> dict: + return _decrypt_provider_data(provider_data) diff --git a/gatehouse_app/services/external_auth/_helpers.py b/gatehouse_app/services/external_auth/_helpers.py new file mode 100644 index 0000000..20ad3cd --- /dev/null +++ b/gatehouse_app/services/external_auth/_helpers.py @@ -0,0 +1,183 @@ +"""Static helper methods for OAuth flows.""" +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + + +def _compute_s256_challenge(verifier: str) -> str: + import hashlib + import base64 + digest = hashlib.sha256(verifier.encode()).digest() + return base64.urlsafe_b64encode(digest).decode().rstrip("=") + + +def _build_authorization_url(config, state) -> str: + from urllib.parse import urlencode + provider = (config.provider_type or "").lower() + + params = { + "client_id": config.client_id, + "redirect_uri": state.redirect_uri, + "response_type": "code", + "scope": " ".join(config.scopes or ["openid", "profile", "email"]), + "state": state.state, + } + + if provider == "google": + params["access_type"] = ( + config.settings.get("access_type", "offline") if config.settings else "offline" + ) + params["prompt"] = ( + config.settings.get("prompt", "consent") if config.settings else "consent" + ) + elif provider == "microsoft": + params["prompt"] = ( + config.settings.get("prompt", "select_account") if config.settings else "select_account" + ) + else: + if config.settings: + if "prompt" in config.settings: + params["prompt"] = config.settings["prompt"] + if "access_type" in config.settings: + params["access_type"] = config.settings["access_type"] + + if state.nonce: + params["nonce"] = state.nonce + + if state.code_challenge: + params["code_challenge"] = state.code_challenge + params["code_challenge_method"] = "S256" + + full_url = f"{config.auth_url}?{urlencode(params)}" + + logger.info( + f"[PKCE DEBUG] Building authorization URL:\n" + f" provider_type: {config.provider_type}\n" + f" state.code_challenge: {state.code_challenge[:20] if state.code_challenge else 'None'}...\n" + f" params has code_challenge: {'code_challenge' in params}\n" + f" Full URL: {full_url}" + ) + + return full_url + + +def _exchange_code(config, code: str, redirect_uri: str, code_verifier: str = None) -> dict: + import requests + + data = { + "client_id": config.client_id, + "client_secret": config.get_client_secret(), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + } + + if code_verifier: + data["code_verifier"] = code_verifier + + logger.debug( + f"Token exchange request: url={config.token_url}, " + f"client_id={config.client_id}, redirect_uri={redirect_uri}, " + f"has_code_verifier={bool(code_verifier)}" + ) + + response = requests.post(config.token_url, data=data) + + if response.status_code != 200: + logger.error( + f"Token exchange failed: status={response.status_code}, " + f"response={response.text}" + ) + + response.raise_for_status() + return response.json() + + +def _get_user_info(config, access_token: str) -> dict: + import re + import requests + + provider = (config.provider_type or "").lower() + headers = {"Authorization": f"Bearer {access_token}"} + response = requests.get(config.userinfo_url, headers=headers) + response.raise_for_status() + + data = response.json() + + if provider == "microsoft": + email_verified = data.get("email_verified", True) + else: + email_verified = data.get("email_verified", False) + + sub = data.get("sub") + + raw_email = data.get("email") + if not raw_email and sub: + if re.match(r"^[^@\s]+@[^@\s]+\.[^@\s]+$", sub): + raw_email = sub + email_verified = True + else: + raw_email = f"{sub}@{provider or 'oauth'}.local" + email_verified = False + + raw_name = data.get("name") or data.get("display_name") + if not raw_name and raw_email: + raw_name = raw_email.split("@")[0] + + return { + "provider_user_id": sub, + "email": raw_email, + "email_verified": email_verified, + "name": raw_name, + "first_name": data.get("given_name"), + "last_name": data.get("family_name"), + "picture": data.get("picture"), + "raw_data": data, + } + + +def _encrypt_provider_data(tokens: dict, user_info: dict) -> dict: + from gatehouse_app.utils.encryption import encrypt + + return { + "access_token": encrypt(tokens.get("access_token")) if tokens.get("access_token") else None, + "token_type": tokens.get("token_type", "Bearer"), + "expires_in": tokens.get("expires_in"), + "refresh_token": encrypt(tokens.get("refresh_token")) if tokens.get("refresh_token") else None, + "scope": tokens.get("scope", []), + "id_token": encrypt(tokens.get("id_token")) if tokens.get("id_token") else None, + "email": user_info.get("email"), + "name": user_info.get("name"), + "picture": user_info.get("picture"), + "raw_data": user_info.get("raw_data", {}), + } + + +def _decrypt_provider_data(provider_data: dict) -> dict: + from gatehouse_app.utils.encryption import decrypt + + if not provider_data: + return {} + + result = { + "token_type": provider_data.get("token_type", "Bearer"), + "expires_in": provider_data.get("expires_in"), + "scope": provider_data.get("scope", []), + "email": provider_data.get("email"), + "name": provider_data.get("name"), + "picture": provider_data.get("picture"), + "raw_data": provider_data.get("raw_data", {}), + } + + for field in ("access_token", "refresh_token", "id_token"): + value = provider_data.get(field) + if value: + try: + result[field] = decrypt(value) + except Exception: + result[field] = value + else: + result[field] = None + + return result diff --git a/gatehouse_app/services/external_auth/app_provider.py b/gatehouse_app/services/external_auth/app_provider.py new file mode 100644 index 0000000..97c0e97 --- /dev/null +++ b/gatehouse_app/services/external_auth/app_provider.py @@ -0,0 +1,125 @@ +"""Application-wide provider configuration management.""" +import logging + +from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig +from gatehouse_app.services.external_auth.models import ExternalAuthError + +logger = logging.getLogger(__name__) + + +def create_app_provider_config( + provider_type: str, + client_id: str, + client_secret: str, + **kwargs, +) -> ApplicationProviderConfig: + existing = ApplicationProviderConfig.query.filter_by( + provider_type=provider_type + ).first() + + if existing: + raise ExternalAuthError( + f"Provider {provider_type} already exists", + "PROVIDER_EXISTS", + 400, + ) + + additional_config = {} + for key in ['auth_url', 'token_url', 'userinfo_url', 'jwks_url', 'scopes']: + if key in kwargs: + additional_config[key] = kwargs.pop(key) + + if 'settings' in kwargs: + additional_config.update(kwargs.pop('settings')) + + config = ApplicationProviderConfig( + provider_type=provider_type, + client_id=client_id, + is_enabled=kwargs.get('is_enabled', True), + default_redirect_url=kwargs.get('default_redirect_url'), + additional_config=additional_config, + ) + config.set_client_secret(client_secret) + config.save() + + logger.info(f"Created application provider config for {provider_type}") + return config + + +def update_app_provider_config( + provider_type: str, + **updates, +) -> ApplicationProviderConfig: + config = ApplicationProviderConfig.query.filter_by( + provider_type=provider_type + ).first() + + if not config: + raise ExternalAuthError( + f"Provider {provider_type} not found", + "PROVIDER_NOT_FOUND", + 404, + ) + + if 'client_id' in updates: + config.client_id = updates['client_id'] + + if 'client_secret' in updates: + config.set_client_secret(updates['client_secret']) + + if 'is_enabled' in updates: + config.is_enabled = updates['is_enabled'] + + if 'default_redirect_url' in updates: + config.default_redirect_url = updates['default_redirect_url'] + + if config.additional_config is None: + config.additional_config = {} + + for key in ['auth_url', 'token_url', 'userinfo_url', 'jwks_url', 'scopes']: + if key in updates: + config.additional_config[key] = updates[key] + + if 'settings' in updates: + config.additional_config.update(updates['settings']) + + config.save() + logger.info(f"Updated application provider config for {provider_type}") + return config + + +def get_app_provider_config(provider_type: str) -> ApplicationProviderConfig: + config = ApplicationProviderConfig.query.filter_by( + provider_type=provider_type + ).first() + + if not config: + raise ExternalAuthError( + f"Provider {provider_type} not found", + "PROVIDER_NOT_FOUND", + 404, + ) + + return config + + +def list_app_provider_configs() -> list: + configs = ApplicationProviderConfig.query.all() + return [config.to_dict() for config in configs] + + +def delete_app_provider_config(provider_type: str) -> bool: + config = ApplicationProviderConfig.query.filter_by( + provider_type=provider_type + ).first() + + if not config: + raise ExternalAuthError( + f"Provider {provider_type} not found", + "PROVIDER_NOT_FOUND", + 404, + ) + + config.delete() + logger.info(f"Deleted application provider config for {provider_type}") + return True diff --git a/gatehouse_app/services/external_auth/linking.py b/gatehouse_app/services/external_auth/linking.py new file mode 100644 index 0000000..e9fc5f4 --- /dev/null +++ b/gatehouse_app/services/external_auth/linking.py @@ -0,0 +1,339 @@ +"""Account linking, authentication, and unlinking flows.""" +import logging +import secrets +from datetime import datetime +from typing import Optional, Tuple + +from gatehouse_app.models import User, AuthenticationMethod +from gatehouse_app.models.auth.authentication_method import OAuthState +from gatehouse_app.utils.constants import AuthMethodType +from gatehouse_app.services.audit_service import AuditService +from gatehouse_app.services.external_auth.models import ExternalAuthError + +logger = logging.getLogger(__name__) + + +def initiate_link_flow( + get_provider_config, + user_id: str, + provider_type: AuthMethodType, + organization_id: str, + redirect_uri: str = None, +) -> Tuple[str, str]: + from gatehouse_app.services.external_auth._helpers import ( + _compute_s256_challenge, + _build_authorization_url, + ) + + provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type + config = get_provider_config(provider_type, organization_id) + + if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri): + raise ExternalAuthError("Invalid redirect URI", "INVALID_REDIRECT_URI", 400) + + code_verifier = None + code_challenge = None + if provider_type_str not in ('google', 'microsoft'): + code_verifier = secrets.token_urlsafe(32) + code_challenge = _compute_s256_challenge(code_verifier) + + state = OAuthState.create_state( + flow_type="link", + provider_type=provider_type, + user_id=user_id, + organization_id=organization_id, + redirect_uri=redirect_uri or (config.redirect_uris[0] if config.redirect_uris else None), + code_verifier=code_verifier, + code_challenge=code_challenge, + lifetime_seconds=600, + ) + + auth_url = _build_authorization_url(config=config, state=state) + + AuditService.log_external_auth_link_initiated( + user_id=user_id, + organization_id=organization_id, + provider_type=provider_type_str, + state_id=state.id, + ) + + return auth_url, state.state + + +def complete_link_flow( + get_provider_config, + provider_type: AuthMethodType, + authorization_code: str, + state: str, + redirect_uri: str, +) -> AuthenticationMethod: + from gatehouse_app.services.external_auth._helpers import ( + _exchange_code, + _get_user_info, + _encrypt_provider_data, + ) + + provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type + + state_record = OAuthState.query.filter_by(state=state).first() + if not state_record or not state_record.is_valid(): + AuditService.log_external_auth_link_failed( + user_id=None, + organization_id=None, + provider_type=provider_type_str, + error_message="Invalid or expired OAuth state", + failure_reason="invalid_state", + ) + raise ExternalAuthError("Invalid or expired OAuth state", "INVALID_STATE", 400) + + if state_record.flow_type != "link": + AuditService.log_external_auth_link_failed( + user_id=state_record.user_id, + organization_id=state_record.organization_id, + provider_type=provider_type_str, + error_message="Invalid flow type for this operation", + failure_reason="invalid_flow_type", + ) + raise ExternalAuthError("Invalid flow type for this operation", "INVALID_FLOW_TYPE", 400) + + if state_record.provider_type != provider_type_str: + AuditService.log_external_auth_link_failed( + user_id=state_record.user_id, + organization_id=state_record.organization_id, + provider_type=provider_type_str, + error_message="Provider mismatch", + failure_reason="provider_mismatch", + ) + raise ExternalAuthError("Provider mismatch", "PROVIDER_MISMATCH", 400) + + config = get_provider_config(provider_type, state_record.organization_id) + + tokens = _exchange_code( + config=config, + code=authorization_code, + redirect_uri=redirect_uri, + code_verifier=state_record.code_verifier, + ) + + user_info = _get_user_info(config=config, access_token=tokens["access_token"]) + + user = User.query.get(state_record.user_id) + if not user: + AuditService.log_external_auth_link_failed( + user_id=None, + organization_id=state_record.organization_id, + provider_type=provider_type_str, + error_message="User not found", + failure_reason="user_not_found", + ) + raise ExternalAuthError("User not found", "USER_NOT_FOUND", 400) + + conflicting = AuthenticationMethod.query.filter( + AuthenticationMethod.method_type == provider_type, + AuthenticationMethod.provider_user_id == user_info["provider_user_id"], + AuthenticationMethod.user_id != user.id, + AuthenticationMethod.deleted_at == None, + ).first() + if conflicting: + raise ExternalAuthError( + f"This {provider_type_str} account is already linked to a different Gatehouse user.", + "PROVIDER_ALREADY_LINKED", + 409, + ) + + auth_method = AuthenticationMethod.query.filter_by( + user_id=user.id, + method_type=provider_type, + provider_user_id=user_info["provider_user_id"], + ).first() + + if auth_method: + # Restore the row if it was previously soft-deleted (re-linking after admin unlink) + auth_method.deleted_at = None + auth_method.provider_data = _encrypt_provider_data(tokens, user_info) + auth_method.verified = user_info.get("email_verified", False) + auth_method.last_used_at = datetime.utcnow() + auth_method.save() + else: + auth_method = AuthenticationMethod( + user_id=user.id, + method_type=provider_type, + provider_user_id=user_info["provider_user_id"], + provider_data=_encrypt_provider_data(tokens, user_info), + verified=user_info.get("email_verified", False), + is_primary=False, + last_used_at=datetime.utcnow(), + ) + auth_method.save() + + state_record.mark_used() + + AuditService.log_external_auth_link_completed( + user_id=user.id, + organization_id=state_record.organization_id, + provider_type=provider_type_str, + provider_user_id=user_info["provider_user_id"], + auth_method_id=auth_method.id, + ) + + return auth_method + + +def authenticate_with_provider( + get_provider_config, + provider_type: AuthMethodType, + organization_id: str, + authorization_code: str, + state: str, + redirect_uri: str, +) -> Tuple[User, dict]: + from gatehouse_app.services.external_auth._helpers import ( + _exchange_code, + _get_user_info, + _encrypt_provider_data, + ) + + provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type + + state_record = OAuthState.query.filter_by(state=state).first() + if not state_record or not state_record.is_valid(): + AuditService.log_external_auth_login_failed( + organization_id=organization_id, + provider_type=provider_type_str, + failure_reason="invalid_state", + error_message="Invalid or expired OAuth state", + ) + raise ExternalAuthError("Invalid or expired OAuth state", "INVALID_STATE", 400) + + config = get_provider_config(provider_type, organization_id) + + tokens = _exchange_code( + config=config, + code=authorization_code, + redirect_uri=redirect_uri, + code_verifier=state_record.code_verifier, + ) + + user_info = _get_user_info(config=config, access_token=tokens["access_token"]) + + auth_method = AuthenticationMethod.query.filter_by( + method_type=provider_type, + provider_user_id=user_info["provider_user_id"], + ).first() + + if not auth_method: + existing_user = User.query.filter_by(email=user_info["email"]).first() + + if existing_user: + AuditService.log_external_auth_login_failed( + organization_id=organization_id, + provider_type=provider_type_str, + provider_user_id=user_info["provider_user_id"], + email=user_info["email"], + failure_reason="email_exists", + error_message=f"An account with email {user_info['email']} already exists", + ) + raise ExternalAuthError( + f"An account with email {user_info['email']} already exists. " + "Please log in with your password and link your Google account from settings.", + "EMAIL_EXISTS", + 400, + ) + + AuditService.log_external_auth_login_failed( + organization_id=organization_id, + provider_type=provider_type_str, + provider_user_id=user_info["provider_user_id"], + email=user_info["email"], + failure_reason="account_not_found", + error_message="No Gatehouse account matches this external account", + ) + raise ExternalAuthError( + "No Gatehouse account matches this external account. Please register first.", + "ACCOUNT_NOT_FOUND", + 400, + ) + + user = auth_method.user + auth_method.provider_data = _encrypt_provider_data(tokens, user_info) + auth_method.last_used_at = datetime.utcnow() + auth_method.save() + + state_record.mark_used() + + from gatehouse_app.services.auth_service import AuthService + session = AuthService.create_session(user=user, organization_id=organization_id) + + AuditService.log_external_auth_login( + user_id=user.id, + organization_id=organization_id, + provider_type=provider_type_str, + provider_user_id=user_info["provider_user_id"], + auth_method_id=auth_method.id, + session_id=session.id, + ) + + return user, session.to_dict() + + +def unlink_provider( + user_id: str, + provider_type: AuthMethodType, + organization_id: str = None, +) -> bool: + provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type + + auth_method = AuthenticationMethod.query.filter_by( + user_id=user_id, + method_type=provider_type, + ).first() + + if not auth_method: + raise ExternalAuthError("Provider not linked", "PROVIDER_NOT_LINKED", 400) + + other_methods = AuthenticationMethod.query.filter_by(user_id=user_id).count() + if other_methods <= 1: + raise ExternalAuthError( + "Cannot unlink the last authentication method", + "CANNOT_UNLINK_LAST", + 400, + ) + + provider_user_id = auth_method.provider_user_id + auth_method_id = auth_method.id + auth_method.delete() + + AuditService.log_external_auth_unlink( + user_id=user_id, + organization_id=organization_id, + provider_type=provider_type_str, + provider_user_id=provider_user_id, + auth_method_id=auth_method_id, + ) + + return True + + +def get_linked_accounts(user_id: str) -> list: + from gatehouse_app.utils.constants import AuthMethodType as AMT + + methods = AuthenticationMethod.query.filter_by(user_id=user_id, deleted_at=None).all() + + external_providers = [AMT.GOOGLE, AMT.GITHUB, AMT.MICROSOFT] + + return [ + { + "id": m.id, + "provider_type": m.method_type.value if hasattr(m.method_type, 'value') else str(m.method_type), + "provider_user_id": m.provider_user_id, + "email": m.provider_data.get("email") if m.provider_data else None, + "name": m.provider_data.get("name") if m.provider_data else None, + "picture": m.provider_data.get("picture") if m.provider_data else None, + "verified": m.verified, + "linked_at": m.created_at.isoformat() if m.created_at else None, + "last_used_at": m.last_used_at.isoformat() if m.last_used_at else None, + } + for m in methods + if m.method_type in external_providers + or str(m.method_type) in [p.value for p in external_providers] + ] diff --git a/gatehouse_app/services/external_auth/models.py b/gatehouse_app/services/external_auth/models.py new file mode 100644 index 0000000..5482b67 --- /dev/null +++ b/gatehouse_app/services/external_auth/models.py @@ -0,0 +1,173 @@ +"""External auth models and adapter classes.""" +from typing import Optional + +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel +from gatehouse_app.models.auth.authentication_method import ( + ApplicationProviderConfig, + OrganizationProviderOverride, +) + + +class ExternalAuthError(Exception): + """Base exception for external auth errors.""" + + def __init__(self, message: str, error_type: str, status_code: int = 400): + self.message = message + self.error_type = error_type + self.status_code = status_code + super().__init__(message) + + +class ExternalProviderConfig(BaseModel): + """OAuth provider configuration per organization. + + DEPRECATED: This model is maintained for backward compatibility only. + Use ApplicationProviderConfig and OrganizationProviderOverride instead. + """ + + __tablename__ = "external_provider_configs" + + organization_id = db.Column( + db.String(36), db.ForeignKey("organizations.id"), nullable=False, index=True + ) + provider_type = db.Column(db.String(50), nullable=False, index=True) + client_id = db.Column(db.String(255), nullable=False) + client_secret_encrypted = db.Column(db.String(512), nullable=True) + auth_url = db.Column(db.String(2048), nullable=False) + token_url = db.Column(db.String(2048), nullable=False) + userinfo_url = db.Column(db.String(2048), nullable=True) + jwks_url = db.Column(db.String(2048), nullable=True) + scopes = db.Column(db.JSON, nullable=False, default=list) + redirect_uris = db.Column(db.JSON, nullable=False, default=list) + settings = db.Column(db.JSON, nullable=True) + is_active = db.Column(db.Boolean, default=True, nullable=False) + + organization = db.relationship( + "Organization", back_populates="external_provider_configs" + ) + + __table_args__ = ( + db.Index("idx_provider_config_org", "organization_id", "provider_type"), + db.UniqueConstraint( + "organization_id", + "provider_type", + name="uix_org_provider_type", + ), + ) + + def get_client_secret(self) -> str: + from gatehouse_app.utils.encryption import decrypt + if self.client_secret_encrypted: + return decrypt(self.client_secret_encrypted) + return None + + def set_client_secret(self, secret: str): + from gatehouse_app.utils.encryption import encrypt + self.client_secret_encrypted = encrypt(secret) + + def is_redirect_uri_allowed(self, uri: str) -> bool: + return uri in (self.redirect_uris or []) + + def to_dict(self, include_secrets: bool = False) -> dict: + data = { + "id": self.id, + "organization_id": self.organization_id, + "provider_type": self.provider_type, + "client_id": self.client_id, + "auth_url": self.auth_url, + "token_url": self.token_url, + "userinfo_url": self.userinfo_url, + "scopes": self.scopes, + "redirect_uris": self.redirect_uris, + "is_active": self.is_active, + "settings": self.settings, + "created_at": self.created_at.isoformat() if self.created_at else None, + "updated_at": self.updated_at.isoformat() if self.updated_at else None, + } + if include_secrets and self.client_secret_encrypted: + data["client_secret"] = self.get_client_secret() + return data + + +class ProviderConfigAdapter: + """Unified interface for provider configuration. + + Merges application-level config with optional organization overrides. + """ + + def __init__( + self, + app_config: ApplicationProviderConfig, + org_override: Optional[OrganizationProviderOverride] = None, + ): + self.app_config = app_config + self.org_override = org_override + self.provider_type = app_config.provider_type + + @property + def client_id(self) -> str: + if self.org_override and self.org_override.client_id: + return self.org_override.client_id + return self.app_config.client_id + + def get_client_secret(self) -> str: + if self.org_override and self.org_override.client_secret_encrypted: + return self.org_override.get_client_secret() + return self.app_config.get_client_secret() + + @property + def auth_url(self) -> str: + return self._get_provider_endpoint('auth_url') + + @property + def token_url(self) -> str: + return self._get_provider_endpoint('token_url') + + @property + def userinfo_url(self) -> str: + return self._get_provider_endpoint('userinfo_url') + + @property + def jwks_url(self) -> str: + return self._get_provider_endpoint('jwks_url') + + @property + def scopes(self) -> list: + base_scopes = self.app_config.additional_config.get('scopes', []) if self.app_config.additional_config else [] + if self.org_override and self.org_override.additional_config: + override_scopes = self.org_override.additional_config.get('scopes') + if override_scopes is not None: + return override_scopes + return base_scopes or ['openid', 'profile', 'email'] + + @property + def redirect_uris(self) -> list: + if self.org_override and self.org_override.redirect_url_override: + return [self.org_override.redirect_url_override] + if self.app_config.default_redirect_url: + return [self.app_config.default_redirect_url] + return [] + + @property + def settings(self) -> dict: + settings = {} + if self.app_config.additional_config: + settings.update(self.app_config.additional_config) + if self.org_override and self.org_override.additional_config: + settings.update(self.org_override.additional_config) + return settings + + @property + def is_active(self) -> bool: + app_enabled = self.app_config.is_enabled + org_enabled = True if not self.org_override else self.org_override.is_enabled + return app_enabled and org_enabled + + def is_redirect_uri_allowed(self, uri: str) -> bool: + return uri in self.redirect_uris + + def _get_provider_endpoint(self, endpoint_name: str) -> Optional[str]: + if not self.app_config.additional_config: + return None + return self.app_config.additional_config.get(endpoint_name) diff --git a/gatehouse_app/services/external_auth/org_override.py b/gatehouse_app/services/external_auth/org_override.py new file mode 100644 index 0000000..e302b07 --- /dev/null +++ b/gatehouse_app/services/external_auth/org_override.py @@ -0,0 +1,147 @@ +"""Organization-specific provider override management.""" +import logging + +from gatehouse_app.models.auth.authentication_method import ( + ApplicationProviderConfig, + OrganizationProviderOverride, +) +from gatehouse_app.services.external_auth.models import ExternalAuthError + +logger = logging.getLogger(__name__) + + +def create_org_provider_override( + organization_id: str, + provider_type: str, + **kwargs, +) -> OrganizationProviderOverride: + app_config = ApplicationProviderConfig.query.filter_by( + provider_type=provider_type + ).first() + + if not app_config: + raise ExternalAuthError( + f"Application provider {provider_type} must be configured first", + "PROVIDER_NOT_CONFIGURED", + 400, + ) + + existing = OrganizationProviderOverride.query.filter_by( + organization_id=organization_id, + provider_type=provider_type, + ).first() + + if existing: + raise ExternalAuthError( + f"Override for {provider_type} already exists for this organization", + "OVERRIDE_EXISTS", + 400, + ) + + additional_config = {} + if 'settings' in kwargs: + additional_config.update(kwargs.pop('settings')) + if 'scopes' in kwargs: + additional_config['scopes'] = kwargs.pop('scopes') + + override = OrganizationProviderOverride( + organization_id=organization_id, + provider_type=provider_type, + client_id=kwargs.get('client_id'), + is_enabled=kwargs.get('is_enabled', True), + redirect_url_override=kwargs.get('redirect_url_override'), + additional_config=additional_config if additional_config else None, + ) + + if 'client_secret' in kwargs: + override.set_client_secret(kwargs['client_secret']) + + override.save() + logger.info(f"Created org override for {provider_type} in org {organization_id}") + return override + + +def update_org_provider_override( + organization_id: str, + provider_type: str, + **updates, +) -> OrganizationProviderOverride: + override = OrganizationProviderOverride.query.filter_by( + organization_id=organization_id, + provider_type=provider_type, + ).first() + + if not override: + raise ExternalAuthError( + f"Override for {provider_type} not found for this organization", + "OVERRIDE_NOT_FOUND", + 404, + ) + + if 'client_id' in updates: + override.client_id = updates['client_id'] + + if 'client_secret' in updates: + override.set_client_secret(updates['client_secret']) + + if 'is_enabled' in updates: + override.is_enabled = updates['is_enabled'] + + if 'redirect_url_override' in updates: + override.redirect_url_override = updates['redirect_url_override'] + + if 'settings' in updates or 'scopes' in updates: + if override.additional_config is None: + override.additional_config = {} + if 'settings' in updates: + override.additional_config.update(updates['settings']) + if 'scopes' in updates: + override.additional_config['scopes'] = updates['scopes'] + + override.save() + logger.info(f"Updated org override for {provider_type} in org {organization_id}") + return override + + +def get_org_provider_override( + organization_id: str, + provider_type: str, +) -> OrganizationProviderOverride: + override = OrganizationProviderOverride.query.filter_by( + organization_id=organization_id, + provider_type=provider_type, + ).first() + + if not override: + raise ExternalAuthError( + f"Override for {provider_type} not found for this organization", + "OVERRIDE_NOT_FOUND", + 404, + ) + + return override + + +def list_org_provider_overrides(organization_id: str) -> list: + overrides = OrganizationProviderOverride.query.filter_by( + organization_id=organization_id + ).all() + return [override.to_dict() for override in overrides] + + +def delete_org_provider_override(organization_id: str, provider_type: str) -> bool: + override = OrganizationProviderOverride.query.filter_by( + organization_id=organization_id, + provider_type=provider_type, + ).first() + + if not override: + raise ExternalAuthError( + f"Override for {provider_type} not found for this organization", + "OVERRIDE_NOT_FOUND", + 404, + ) + + override.delete() + logger.info(f"Deleted org override for {provider_type} in org {organization_id}") + return True diff --git a/gatehouse_app/services/external_auth_service.py b/gatehouse_app/services/external_auth_service.py deleted file mode 100644 index 57cad76..0000000 --- a/gatehouse_app/services/external_auth_service.py +++ /dev/null @@ -1,1328 +0,0 @@ -"""External authentication provider service.""" -import logging -import secrets -from datetime import datetime, timedelta, timezone -from typing import Optional, Tuple, Dict, Any - -from flask import current_app - -from gatehouse_app.extensions import db -from gatehouse_app.models import User, AuthenticationMethod -from gatehouse_app.models.auth.authentication_method import ( - OAuthState, - ApplicationProviderConfig, - OrganizationProviderOverride -) -from gatehouse_app.models.base import BaseModel -from gatehouse_app.utils.constants import AuthMethodType -from gatehouse_app.services.audit_service import AuditService - -logger = logging.getLogger(__name__) - - -class ExternalAuthError(Exception): - """Base exception for external auth errors.""" - - def __init__(self, message: str, error_type: str, status_code: int = 400): - self.message = message - self.error_type = error_type - self.status_code = status_code - super().__init__(message) - - -class ExternalProviderConfig(BaseModel): - """OAuth provider configuration per organization. - - DEPRECATED: This model is maintained for backward compatibility only. - Use ApplicationProviderConfig and OrganizationProviderOverride instead. - """ - - __tablename__ = "external_provider_configs" - - # Organization reference - organization_id = db.Column( - db.String(36), db.ForeignKey("organizations.id"), nullable=False, index=True - ) - - # Provider type - provider_type = db.Column(db.String(50), nullable=False, index=True) - - # OAuth credentials (client_secret is encrypted) - client_id = db.Column(db.String(255), nullable=False) - client_secret_encrypted = db.Column(db.String(512), nullable=True) - - # Provider endpoints - auth_url = db.Column(db.String(2048), nullable=False) - token_url = db.Column(db.String(2048), nullable=False) - userinfo_url = db.Column(db.String(2048), nullable=True) - jwks_url = db.Column(db.String(2048), nullable=True) - - # Configuration - scopes = db.Column(db.JSON, nullable=False, default=list) - redirect_uris = db.Column(db.JSON, nullable=False, default=list) - - # Provider-specific settings - settings = db.Column(db.JSON, nullable=True) - - # Status - is_active = db.Column(db.Boolean, default=True, nullable=False) - - # Relationships - organization = db.relationship( - "Organization", back_populates="external_provider_configs" - ) - - # Indexes - __table_args__ = ( - db.Index("idx_provider_config_org", "organization_id", "provider_type"), - db.UniqueConstraint( - "organization_id", - "provider_type", - name="uix_org_provider_type", - ), - ) - - def get_client_secret(self) -> str: - """Decrypt and return client secret.""" - from gatehouse_app.utils.encryption import decrypt - if self.client_secret_encrypted: - return decrypt(self.client_secret_encrypted) - return None - - def set_client_secret(self, secret: str): - """Encrypt and store client secret.""" - from gatehouse_app.utils.encryption import encrypt - self.client_secret_encrypted = encrypt(secret) - - def is_redirect_uri_allowed(self, uri: str) -> bool: - """Check if redirect URI is allowed.""" - return uri in (self.redirect_uris or []) - - def to_dict(self, include_secrets: bool = False) -> dict: - """Convert to dictionary.""" - data = { - "id": self.id, - "organization_id": self.organization_id, - "provider_type": self.provider_type, - "client_id": self.client_id, - "auth_url": self.auth_url, - "token_url": self.token_url, - "userinfo_url": self.userinfo_url, - "scopes": self.scopes, - "redirect_uris": self.redirect_uris, - "is_active": self.is_active, - "settings": self.settings, - "created_at": self.created_at.isoformat() if self.created_at else None, - "updated_at": self.updated_at.isoformat() if self.updated_at else None, - } - if include_secrets and self.client_secret_encrypted: - data["client_secret"] = self.get_client_secret() - return data - - -class ProviderConfigAdapter: - """ - Adapter to provide a unified interface for provider configuration. - - This merges application-level config with optional organization overrides, - presenting a single config object that works with existing OAuth flow code. - """ - - def __init__( - self, - app_config: ApplicationProviderConfig, - org_override: Optional[OrganizationProviderOverride] = None - ): - """ - Initialize adapter with app config and optional org override. - - Args: - app_config: Application-level provider configuration - org_override: Optional organization-specific override - """ - self.app_config = app_config - self.org_override = org_override - self.provider_type = app_config.provider_type - - @property - def client_id(self) -> str: - """Get effective client ID (override takes precedence).""" - if self.org_override and self.org_override.client_id: - return self.org_override.client_id - return self.app_config.client_id - - def get_client_secret(self) -> str: - """Get effective client secret (override takes precedence).""" - if self.org_override and self.org_override.client_secret_encrypted: - return self.org_override.get_client_secret() - return self.app_config.get_client_secret() - - @property - def auth_url(self) -> str: - """Get authorization URL from app config.""" - # Provider endpoints are not overridable - return self._get_provider_endpoint('auth_url') - - @property - def token_url(self) -> str: - """Get token URL from app config.""" - return self._get_provider_endpoint('token_url') - - @property - def userinfo_url(self) -> str: - """Get userinfo URL from app config.""" - return self._get_provider_endpoint('userinfo_url') - - @property - def jwks_url(self) -> str: - """Get JWKS URL from app config.""" - return self._get_provider_endpoint('jwks_url') - - @property - def scopes(self) -> list: - """Get effective scopes (merged from app config and override).""" - base_scopes = self.app_config.additional_config.get('scopes', []) if self.app_config.additional_config else [] - if self.org_override and self.org_override.additional_config: - override_scopes = self.org_override.additional_config.get('scopes') - if override_scopes is not None: - return override_scopes - return base_scopes or ['openid', 'profile', 'email'] - - @property - def redirect_uris(self) -> list: - """Get effective redirect URIs.""" - # Use override redirect URL if present, otherwise app default - if self.org_override and self.org_override.redirect_url_override: - return [self.org_override.redirect_url_override] - if self.app_config.default_redirect_url: - return [self.app_config.default_redirect_url] - return [] - - @property - def settings(self) -> dict: - """Get merged settings (app config + org override).""" - settings = {} - if self.app_config.additional_config: - settings.update(self.app_config.additional_config) - if self.org_override and self.org_override.additional_config: - settings.update(self.org_override.additional_config) - return settings - - @property - def is_active(self) -> bool: - """Check if provider is active (both app and org must be enabled).""" - app_enabled = self.app_config.is_enabled - org_enabled = True if not self.org_override else self.org_override.is_enabled - return app_enabled and org_enabled - - def is_redirect_uri_allowed(self, uri: str) -> bool: - """Check if redirect URI is allowed.""" - return uri in self.redirect_uris - - def _get_provider_endpoint(self, endpoint_name: str) -> Optional[str]: - """ - Get provider endpoint from app config additional_config. - - For application-wide configs, endpoints are stored in additional_config JSON. - """ - if not self.app_config.additional_config: - return None - return self.app_config.additional_config.get(endpoint_name) - - -class ExternalAuthService: - """Service for external authentication operations.""" - - @classmethod - def get_provider_config( - cls, - provider_type: AuthMethodType, - organization_id: Optional[str] = None, - ) -> ProviderConfigAdapter: - """ - Get provider configuration for authentication. - - This method retrieves application-wide provider configuration and merges - it with organization-specific overrides if present. Both the application - config and organization override (if present) must be enabled for the - provider to be considered active. - - Configuration Precedence: - 1. Application-level config provides the baseline configuration - 2. Organization override can override client_id and client_secret (for SSO) - 3. Both must be enabled for the provider to work - - Args: - provider_type: The OAuth provider type (google, github, etc.) - organization_id: Optional organization ID for override lookup - - Returns: - ProviderConfigAdapter: Unified config object with merged settings - - Raises: - ExternalAuthError: If provider is not configured or disabled - """ - provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type - - # Get application-wide config - app_config = ApplicationProviderConfig.query.filter_by( - provider_type=provider_type_str - ).first() - - if not app_config: - raise ExternalAuthError( - f"{provider_type_str.title()} OAuth is not configured for this application", - "PROVIDER_NOT_CONFIGURED", - 400, - ) - - if not app_config.is_enabled: - raise ExternalAuthError( - f"{provider_type_str.title()} OAuth is currently disabled", - "PROVIDER_DISABLED", - 400, - ) - - # Check for organization-specific override - org_override = None - if organization_id: - org_override = OrganizationProviderOverride.query.filter_by( - organization_id=organization_id, - provider_type=provider_type_str - ).first() - - # If override exists but is disabled, provider is not available for this org - if org_override and not org_override.is_enabled: - raise ExternalAuthError( - f"{provider_type_str.title()} OAuth is disabled for this organization", - "PROVIDER_DISABLED_FOR_ORG", - 400, - ) - - # Return adapter with merged configuration - return ProviderConfigAdapter(app_config, org_override) - - # ==================== Application-Wide Provider Management ==================== - - @classmethod - def create_app_provider_config( - cls, - provider_type: str, - client_id: str, - client_secret: str, - **kwargs - ) -> ApplicationProviderConfig: - """ - Create application-wide provider configuration. - - Args: - provider_type: Provider type (google, github, etc.) - client_id: OAuth client ID - client_secret: OAuth client secret - **kwargs: Additional config (auth_url, token_url, userinfo_url, scopes, etc.) - - Returns: - ApplicationProviderConfig: Created configuration - - Raises: - ExternalAuthError: If provider already exists - """ - # Check if provider already exists - existing = ApplicationProviderConfig.query.filter_by( - provider_type=provider_type - ).first() - - if existing: - raise ExternalAuthError( - f"Provider {provider_type} already exists", - "PROVIDER_EXISTS", - 400 - ) - - # Build additional_config with endpoints and settings - additional_config = {} - for key in ['auth_url', 'token_url', 'userinfo_url', 'jwks_url', 'scopes']: - if key in kwargs: - additional_config[key] = kwargs.pop(key) - - # Add any extra settings - if 'settings' in kwargs: - additional_config.update(kwargs.pop('settings')) - - # Create new config - config = ApplicationProviderConfig( - provider_type=provider_type, - client_id=client_id, - is_enabled=kwargs.get('is_enabled', True), - default_redirect_url=kwargs.get('default_redirect_url'), - additional_config=additional_config - ) - - # Set encrypted secret - config.set_client_secret(client_secret) - config.save() - - logger.info(f"Created application provider config for {provider_type}") - return config - - @classmethod - def update_app_provider_config( - cls, - provider_type: str, - **updates - ) -> ApplicationProviderConfig: - """ - Update application-wide provider configuration. - - Args: - provider_type: Provider type to update - **updates: Fields to update (client_id, client_secret, is_enabled, etc.) - - Returns: - ApplicationProviderConfig: Updated configuration - - Raises: - ExternalAuthError: If provider not found - """ - config = ApplicationProviderConfig.query.filter_by( - provider_type=provider_type - ).first() - - if not config: - raise ExternalAuthError( - f"Provider {provider_type} not found", - "PROVIDER_NOT_FOUND", - 404 - ) - - # Update simple fields - if 'client_id' in updates: - config.client_id = updates['client_id'] - - if 'client_secret' in updates: - config.set_client_secret(updates['client_secret']) - - if 'is_enabled' in updates: - config.is_enabled = updates['is_enabled'] - - if 'default_redirect_url' in updates: - config.default_redirect_url = updates['default_redirect_url'] - - # Update additional_config JSON fields - if config.additional_config is None: - config.additional_config = {} - - for key in ['auth_url', 'token_url', 'userinfo_url', 'jwks_url', 'scopes']: - if key in updates: - config.additional_config[key] = updates[key] - - if 'settings' in updates: - config.additional_config.update(updates['settings']) - - config.save() - logger.info(f"Updated application provider config for {provider_type}") - return config - - @classmethod - def get_app_provider_config(cls, provider_type: str) -> ApplicationProviderConfig: - """ - Get application-wide provider configuration. - - Args: - provider_type: Provider type to retrieve - - Returns: - ApplicationProviderConfig: Provider configuration - - Raises: - ExternalAuthError: If provider not found - """ - config = ApplicationProviderConfig.query.filter_by( - provider_type=provider_type - ).first() - - if not config: - raise ExternalAuthError( - f"Provider {provider_type} not found", - "PROVIDER_NOT_FOUND", - 404 - ) - - return config - - @classmethod - def list_app_provider_configs(cls) -> list: - """ - List all application-wide provider configurations. - - Returns: - list: List of provider configuration dictionaries - """ - configs = ApplicationProviderConfig.query.all() - return [config.to_dict() for config in configs] - - @classmethod - def delete_app_provider_config(cls, provider_type: str) -> bool: - """ - Delete application-wide provider configuration. - - Args: - provider_type: Provider type to delete - - Returns: - bool: True if deleted successfully - - Raises: - ExternalAuthError: If provider not found - """ - config = ApplicationProviderConfig.query.filter_by( - provider_type=provider_type - ).first() - - if not config: - raise ExternalAuthError( - f"Provider {provider_type} not found", - "PROVIDER_NOT_FOUND", - 404 - ) - - config.delete() - logger.info(f"Deleted application provider config for {provider_type}") - return True - - # ==================== Organization Provider Override Management ==================== - - @classmethod - def create_org_provider_override( - cls, - organization_id: str, - provider_type: str, - **kwargs - ) -> OrganizationProviderOverride: - """ - Create organization-specific provider override (for SSO scenarios). - - Args: - organization_id: Organization ID - provider_type: Provider type to override - **kwargs: Override fields (client_id, client_secret, redirect_url, etc.) - - Returns: - OrganizationProviderOverride: Created override - - Raises: - ExternalAuthError: If provider doesn't exist or override already exists - """ - # Verify app-level provider exists - app_config = ApplicationProviderConfig.query.filter_by( - provider_type=provider_type - ).first() - - if not app_config: - raise ExternalAuthError( - f"Application provider {provider_type} must be configured first", - "PROVIDER_NOT_CONFIGURED", - 400 - ) - - # Check if override already exists - existing = OrganizationProviderOverride.query.filter_by( - organization_id=organization_id, - provider_type=provider_type - ).first() - - if existing: - raise ExternalAuthError( - f"Override for {provider_type} already exists for this organization", - "OVERRIDE_EXISTS", - 400 - ) - - # Build additional_config from kwargs - additional_config = {} - if 'settings' in kwargs: - additional_config.update(kwargs.pop('settings')) - if 'scopes' in kwargs: - additional_config['scopes'] = kwargs.pop('scopes') - - # Create override - override = OrganizationProviderOverride( - organization_id=organization_id, - provider_type=provider_type, - client_id=kwargs.get('client_id'), - is_enabled=kwargs.get('is_enabled', True), - redirect_url_override=kwargs.get('redirect_url_override'), - additional_config=additional_config if additional_config else None - ) - - # Set encrypted secret if provided - if 'client_secret' in kwargs: - override.set_client_secret(kwargs['client_secret']) - - override.save() - logger.info(f"Created org override for {provider_type} in org {organization_id}") - return override - - @classmethod - def update_org_provider_override( - cls, - organization_id: str, - provider_type: str, - **updates - ) -> OrganizationProviderOverride: - """ - Update organization-specific provider override. - - Args: - organization_id: Organization ID - provider_type: Provider type - **updates: Fields to update - - Returns: - OrganizationProviderOverride: Updated override - - Raises: - ExternalAuthError: If override not found - """ - override = OrganizationProviderOverride.query.filter_by( - organization_id=organization_id, - provider_type=provider_type - ).first() - - if not override: - raise ExternalAuthError( - f"Override for {provider_type} not found for this organization", - "OVERRIDE_NOT_FOUND", - 404 - ) - - # Update simple fields - if 'client_id' in updates: - override.client_id = updates['client_id'] - - if 'client_secret' in updates: - override.set_client_secret(updates['client_secret']) - - if 'is_enabled' in updates: - override.is_enabled = updates['is_enabled'] - - if 'redirect_url_override' in updates: - override.redirect_url_override = updates['redirect_url_override'] - - # Update additional_config - if 'settings' in updates or 'scopes' in updates: - if override.additional_config is None: - override.additional_config = {} - - if 'settings' in updates: - override.additional_config.update(updates['settings']) - if 'scopes' in updates: - override.additional_config['scopes'] = updates['scopes'] - - override.save() - logger.info(f"Updated org override for {provider_type} in org {organization_id}") - return override - - @classmethod - def get_org_provider_override( - cls, - organization_id: str, - provider_type: str - ) -> OrganizationProviderOverride: - """ - Get organization-specific provider override. - - Args: - organization_id: Organization ID - provider_type: Provider type - - Returns: - OrganizationProviderOverride: Provider override - - Raises: - ExternalAuthError: If override not found - """ - override = OrganizationProviderOverride.query.filter_by( - organization_id=organization_id, - provider_type=provider_type - ).first() - - if not override: - raise ExternalAuthError( - f"Override for {provider_type} not found for this organization", - "OVERRIDE_NOT_FOUND", - 404 - ) - - return override - - @classmethod - def list_org_provider_overrides(cls, organization_id: str) -> list: - """ - List all provider overrides for an organization. - - Args: - organization_id: Organization ID - - Returns: - list: List of override configuration dictionaries - """ - overrides = OrganizationProviderOverride.query.filter_by( - organization_id=organization_id - ).all() - return [override.to_dict() for override in overrides] - - @classmethod - def delete_org_provider_override( - cls, - organization_id: str, - provider_type: str - ) -> bool: - """ - Delete organization-specific provider override. - - Args: - organization_id: Organization ID - provider_type: Provider type - - Returns: - bool: True if deleted successfully - - Raises: - ExternalAuthError: If override not found - """ - override = OrganizationProviderOverride.query.filter_by( - organization_id=organization_id, - provider_type=provider_type - ).first() - - if not override: - raise ExternalAuthError( - f"Override for {provider_type} not found for this organization", - "OVERRIDE_NOT_FOUND", - 404 - ) - - override.delete() - logger.info(f"Deleted org override for {provider_type} in org {organization_id}") - return True - - # ==================== OAuth Flow Methods (Updated for New Architecture) ==================== - - @classmethod - def initiate_link_flow( - cls, - user_id: str, - provider_type: AuthMethodType, - organization_id: str, - redirect_uri: str = None, - ) -> Tuple[str, str]: - """ - Initiate account linking flow. - - Returns: - Tuple of (redirect_url, state) - """ - provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type - - # Get provider config (with org override if applicable) - config = cls.get_provider_config(provider_type, organization_id) - - # Validate redirect URI - if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri): - raise ExternalAuthError( - "Invalid redirect URI", - "INVALID_REDIRECT_URI", - 400, - ) - - # Generate PKCE — skip for confidential clients (Google, Microsoft) that use a - # client_secret. Sending code_challenge to Microsoft causes it to enforce PKCE on - # the token exchange, which then fails. Matches the behaviour of initiate_login_flow. - code_verifier = None - code_challenge = None - if provider_type_str not in ('google', 'microsoft'): - code_verifier = secrets.token_urlsafe(32) - code_challenge = cls._compute_s256_challenge(code_verifier) - - # Create OAuth state - state = OAuthState.create_state( - flow_type="link", - provider_type=provider_type, - user_id=user_id, - organization_id=organization_id, - redirect_uri=redirect_uri or config.redirect_uris[0] if config.redirect_uris else None, - code_verifier=code_verifier, - code_challenge=code_challenge, - lifetime_seconds=600, - ) - - # Build authorization URL - auth_url = cls._build_authorization_url( - config=config, - state=state, - ) - - # Audit log - link initiated - AuditService.log_external_auth_link_initiated( - user_id=user_id, - organization_id=organization_id, - provider_type=provider_type_str, - state_id=state.id, - ) - - return auth_url, state.state - - @classmethod - def complete_link_flow( - cls, - provider_type: AuthMethodType, - authorization_code: str, - state: str, - redirect_uri: str, - ) -> AuthenticationMethod: - """Complete account linking flow.""" - provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type - - # Validate state - state_record = OAuthState.query.filter_by(state=state).first() - if not state_record or not state_record.is_valid(): - AuditService.log_external_auth_link_failed( - user_id=None, - organization_id=None, - provider_type=provider_type_str, - error_message="Invalid or expired OAuth state", - failure_reason="invalid_state", - ) - raise ExternalAuthError( - "Invalid or expired OAuth state", - "INVALID_STATE", - 400, - ) - - if state_record.flow_type != "link": - AuditService.log_external_auth_link_failed( - user_id=state_record.user_id, - organization_id=state_record.organization_id, - provider_type=provider_type_str, - error_message="Invalid flow type for this operation", - failure_reason="invalid_flow_type", - ) - raise ExternalAuthError( - "Invalid flow type for this operation", - "INVALID_FLOW_TYPE", - 400, - ) - - if state_record.provider_type != provider_type_str: - AuditService.log_external_auth_link_failed( - user_id=state_record.user_id, - organization_id=state_record.organization_id, - provider_type=provider_type_str, - error_message="Provider mismatch", - failure_reason="provider_mismatch", - ) - raise ExternalAuthError( - "Provider mismatch", - "PROVIDER_MISMATCH", - 400, - ) - - # Get provider config (with org override if applicable) - config = cls.get_provider_config( - provider_type, state_record.organization_id - ) - - # Exchange code for tokens - tokens = cls._exchange_code( - config=config, - code=authorization_code, - redirect_uri=redirect_uri, - code_verifier=state_record.code_verifier, - ) - - # Get user info - user_info = cls._get_user_info( - config=config, - access_token=tokens["access_token"], - ) - - # Get user - user = User.query.get(state_record.user_id) - if not user: - AuditService.log_external_auth_link_failed( - user_id=None, - organization_id=state_record.organization_id, - provider_type=provider_type_str, - error_message="User not found", - failure_reason="user_not_found", - ) - raise ExternalAuthError( - "User not found", - "USER_NOT_FOUND", - 400, - ) - - # Create or update authentication method - auth_method = AuthenticationMethod.query.filter_by( - user_id=user.id, - method_type=provider_type, - provider_user_id=user_info["provider_user_id"], - ).first() - - if auth_method: - # Update existing - auth_method.provider_data = cls._encrypt_provider_data(tokens, user_info) - auth_method.verified = user_info.get("email_verified", False) - auth_method.last_used_at = datetime.utcnow() - auth_method.save() - else: - # Create new - auth_method = AuthenticationMethod( - user_id=user.id, - method_type=provider_type, - provider_user_id=user_info["provider_user_id"], - provider_data=cls._encrypt_provider_data(tokens, user_info), - verified=user_info.get("email_verified", False), - is_primary=False, - last_used_at=datetime.utcnow(), - ) - auth_method.save() - - # Mark state as used - state_record.mark_used() - - # Audit log - link completed - AuditService.log_external_auth_link_completed( - user_id=user.id, - organization_id=state_record.organization_id, - provider_type=provider_type_str, - provider_user_id=user_info["provider_user_id"], - auth_method_id=auth_method.id, - ) - - return auth_method - - @classmethod - def authenticate_with_provider( - cls, - provider_type: AuthMethodType, - organization_id: str, - authorization_code: str, - state: str, - redirect_uri: str, - ) -> Tuple[User, dict]: - """Authenticate user with external provider and return tokens.""" - provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type - - # Validate state - state_record = OAuthState.query.filter_by(state=state).first() - if not state_record or not state_record.is_valid(): - AuditService.log_external_auth_login_failed( - organization_id=organization_id, - provider_type=provider_type_str, - failure_reason="invalid_state", - error_message="Invalid or expired OAuth state", - ) - raise ExternalAuthError( - "Invalid or expired OAuth state", - "INVALID_STATE", - 400, - ) - - # Get provider config (with org override if applicable) - config = cls.get_provider_config(provider_type, organization_id) - - # Exchange code for tokens - tokens = cls._exchange_code( - config=config, - code=authorization_code, - redirect_uri=redirect_uri, - code_verifier=state_record.code_verifier, - ) - - # Get user info - user_info = cls._get_user_info( - config=config, - access_token=tokens["access_token"], - ) - - # Look up user by provider_user_id - auth_method = AuthenticationMethod.query.filter_by( - method_type=provider_type, - provider_user_id=user_info["provider_user_id"], - ).first() - - if not auth_method: - # Check if email matches existing user - existing_user = User.query.filter_by( - email=user_info["email"] - ).first() - - if existing_user: - AuditService.log_external_auth_login_failed( - organization_id=organization_id, - provider_type=provider_type_str, - provider_user_id=user_info["provider_user_id"], - email=user_info["email"], - failure_reason="email_exists", - error_message=f"An account with email {user_info['email']} already exists", - ) - raise ExternalAuthError( - f"An account with email {user_info['email']} already exists. " - "Please log in with your password and link your Google account from settings.", - "EMAIL_EXISTS", - 400, - ) - - AuditService.log_external_auth_login_failed( - organization_id=organization_id, - provider_type=provider_type_str, - provider_user_id=user_info["provider_user_id"], - email=user_info["email"], - failure_reason="account_not_found", - error_message="No Gatehouse account matches this external account", - ) - raise ExternalAuthError( - "No Gatehouse account matches this external account. Please register first.", - "ACCOUNT_NOT_FOUND", - 400, - ) - - user = auth_method.user - - # Update tokens - auth_method.provider_data = cls._encrypt_provider_data(tokens, user_info) - auth_method.last_used_at = datetime.utcnow() - auth_method.save() - - # Mark state as used - state_record.mark_used() - - # Create session - from gatehouse_app.services.auth_service import AuthService - session = AuthService.create_session( - user=user, - organization_id=organization_id, - ) - - # Audit log - login success - AuditService.log_external_auth_login( - user_id=user.id, - organization_id=organization_id, - provider_type=provider_type_str, - provider_user_id=user_info["provider_user_id"], - auth_method_id=auth_method.id, - session_id=session.id, - ) - - return user, session.to_dict() - - @classmethod - def unlink_provider( - cls, - user_id: str, - provider_type: AuthMethodType, - organization_id: str = None, - ) -> bool: - """Unlink external provider from user account.""" - provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type - - auth_method = AuthenticationMethod.query.filter_by( - user_id=user_id, - method_type=provider_type, - ).first() - - if not auth_method: - raise ExternalAuthError( - f"Provider not linked", - "PROVIDER_NOT_LINKED", - 400, - ) - - # Check if this is the last auth method - other_methods = AuthenticationMethod.query.filter_by( - user_id=user_id, - ).count() - - if other_methods <= 1: - raise ExternalAuthError( - "Cannot unlink the last authentication method", - "CANNOT_UNLINK_LAST", - 400, - ) - - provider_user_id = auth_method.provider_user_id - auth_method_id = auth_method.id - auth_method.delete() - - # Audit log - unlink - AuditService.log_external_auth_unlink( - user_id=user_id, - organization_id=organization_id, - provider_type=provider_type_str, - provider_user_id=provider_user_id, - auth_method_id=auth_method_id, - ) - - return True - - @classmethod - def get_linked_accounts(cls, user_id: str) -> list: - """Get all linked external accounts for user.""" - methods = AuthenticationMethod.query.filter_by( - user_id=user_id, - ).all() - - external_providers = [ - AuthMethodType.GOOGLE, - AuthMethodType.GITHUB, - AuthMethodType.MICROSOFT, - ] - - return [ - { - "id": m.id, - "provider_type": m.method_type.value if hasattr(m.method_type, 'value') else str(m.method_type), - "provider_user_id": m.provider_user_id, - "email": m.provider_data.get("email") if m.provider_data else None, - "name": m.provider_data.get("name") if m.provider_data else None, - "picture": m.provider_data.get("picture") if m.provider_data else None, - "verified": m.verified, - "linked_at": m.created_at.isoformat() if m.created_at else None, - "last_used_at": m.last_used_at.isoformat() if m.last_used_at else None, - } - for m in methods - if m.method_type in external_providers or str(m.method_type) in [p.value for p in external_providers] - ] - - # ==================== Helper Methods ==================== - - @staticmethod - def _compute_s256_challenge(verifier: str) -> str: - """Compute S256 code challenge from verifier.""" - import hashlib - import base64 - - digest = hashlib.sha256(verifier.encode()).digest() - return base64.urlsafe_b64encode(digest).decode().rstrip("=") - - @staticmethod - def _build_authorization_url(config: ProviderConfigAdapter, state: OAuthState) -> str: - """Build authorization URL using the provider config adapter.""" - from urllib.parse import urlencode - provider = (config.provider_type or "").lower() - - params = { - "client_id": config.client_id, - "redirect_uri": state.redirect_uri, - "response_type": "code", - "scope": " ".join(config.scopes or ["openid", "profile", "email"]), - "state": state.state, - } - - if provider == "google": - params["access_type"] = ( - config.settings.get("access_type", "offline") if config.settings else "offline" - ) - params["prompt"] = ( - config.settings.get("prompt", "consent") if config.settings else "consent" - ) - elif provider == "microsoft": - params["prompt"] = ( - config.settings.get("prompt", "select_account") if config.settings else "select_account" - ) - else: - if config.settings: - if "prompt" in config.settings: - params["prompt"] = config.settings["prompt"] - if "access_type" in config.settings: - params["access_type"] = config.settings["access_type"] - - if state.nonce: - params["nonce"] = state.nonce - - if state.code_challenge: - params["code_challenge"] = state.code_challenge - params["code_challenge_method"] = "S256" - - full_url = f"{config.auth_url}?{urlencode(params)}" - - # DIAGNOSTIC LOGGING: Show exact URL being built - logger.info( - f"[PKCE DEBUG] Building authorization URL:\n" - f" provider_type: {config.provider_type}\n" - f" state.code_challenge: {state.code_challenge[:20] if state.code_challenge else 'None'}...\n" - f" params has code_challenge: {'code_challenge' in params}\n" - f" Full URL: {full_url}" - ) - - return full_url - - @staticmethod - def _exchange_code(config: ProviderConfigAdapter, code: str, redirect_uri: str, code_verifier: str = None) -> dict: - """Exchange authorization code for tokens using the provider config adapter.""" - import requests - - data = { - "client_id": config.client_id, - "client_secret": config.get_client_secret(), - "code": code, - "grant_type": "authorization_code", - "redirect_uri": redirect_uri, - } - - if code_verifier: - data["code_verifier"] = code_verifier - - # Log token exchange request (without secrets) - logger.debug( - f"Token exchange request: url={config.token_url}, " - f"client_id={config.client_id}, redirect_uri={redirect_uri}, " - f"has_code_verifier={bool(code_verifier)}" - ) - - response = requests.post(config.token_url, data=data) - - # Log response details for debugging - if response.status_code != 200: - logger.error( - f"Token exchange failed: status={response.status_code}, " - f"response={response.text}" - ) - - response.raise_for_status() - - return response.json() - - @staticmethod - def _get_user_info(config: ProviderConfigAdapter, access_token: str) -> dict: - """Get user info from provider using the provider config adapter.""" - import requests - - provider = (config.provider_type or "").lower() - headers = {"Authorization": f"Bearer {access_token}"} - response = requests.get(config.userinfo_url, headers=headers) - response.raise_for_status() - - data = response.json() - - # Microsoft's /oidc/userinfo endpoint returns verified email addresses - # (all AAD accounts are verified) but may omit the email_verified claim. - # Default to True for Microsoft so users aren't stuck with unverified state. - if provider == "microsoft": - email_verified = data.get("email_verified", True) - else: - email_verified = data.get("email_verified", False) - - sub = data.get("sub") - - # Derive email from sub when the provider omits the email claim. - # This happens with some OIDC servers (including the nav-security mock) - # that only return the minimal {sub, iss, iat, exp} set. - # Rule: if sub looks like an email address, use it directly. - # Otherwise, construct a deterministic fallback so we never get NULL. - raw_email = data.get("email") - if not raw_email and sub: - import re as _re - if _re.match(r"^[^@\s]+@[^@\s]+\.[^@\s]+$", sub): - raw_email = sub - email_verified = True # if sub IS the email it's already verified - else: - # e.g. "12345" → "12345@google.local" so we can store it - raw_email = f"{sub}@{provider or 'oauth'}.local" - email_verified = False - - # Derive display name when omitted - raw_name = data.get("name") or data.get("display_name") - if not raw_name and raw_email: - raw_name = raw_email.split("@")[0] - - # Standardize user info - return { - "provider_user_id": sub, - "email": raw_email, - "email_verified": email_verified, - "name": raw_name, - "first_name": data.get("given_name"), - "last_name": data.get("family_name"), - "picture": data.get("picture"), - "raw_data": data, - } - - @staticmethod - def _encrypt_provider_data(tokens: dict, user_info: dict) -> dict: - """Encrypt and store provider tokens and user info.""" - from gatehouse_app.utils.encryption import encrypt - - result = { - "access_token": encrypt(tokens.get("access_token")) if tokens.get("access_token") else None, - "token_type": tokens.get("token_type", "Bearer"), - "expires_in": tokens.get("expires_in"), - "refresh_token": encrypt(tokens.get("refresh_token")) if tokens.get("refresh_token") else None, - "scope": tokens.get("scope", []), - "id_token": encrypt(tokens.get("id_token")) if tokens.get("id_token") else None, - "email": user_info.get("email"), - "name": user_info.get("name"), - "picture": user_info.get("picture"), - "raw_data": user_info.get("raw_data", {}), - } - - return result - - @staticmethod - def _decrypt_provider_data(provider_data: dict) -> dict: - """ - Decrypt provider tokens from stored data. - - This method handles backward compatibility with existing data where - access_token may be stored in plain text (unencrypted). - """ - from gatehouse_app.utils.encryption import decrypt - - if not provider_data: - return {} - - result = { - "token_type": provider_data.get("token_type", "Bearer"), - "expires_in": provider_data.get("expires_in"), - "scope": provider_data.get("scope", []), - "email": provider_data.get("email"), - "name": provider_data.get("name"), - "picture": provider_data.get("picture"), - "raw_data": provider_data.get("raw_data", {}), - } - - # Decrypt access_token with backward compatibility - access_token = provider_data.get("access_token") - if access_token: - # Try to decrypt - if it fails, assume it's plain text (old data) - try: - result["access_token"] = decrypt(access_token) - except Exception: - # Access token is plain text (pre-encryption data) - result["access_token"] = access_token - else: - result["access_token"] = None - - # Decrypt refresh_token - refresh_token = provider_data.get("refresh_token") - if refresh_token: - try: - result["refresh_token"] = decrypt(refresh_token) - except Exception: - result["refresh_token"] = refresh_token - else: - result["refresh_token"] = None - - # Decrypt id_token - id_token = provider_data.get("id_token") - if id_token: - try: - result["id_token"] = decrypt(id_token) - except Exception: - result["id_token"] = id_token - else: - result["id_token"] = None - - return result diff --git a/gatehouse_app/services/notification_service.py b/gatehouse_app/services/notification_service.py index fc9bbe0..fa942c0 100644 --- a/gatehouse_app/services/notification_service.py +++ b/gatehouse_app/services/notification_service.py @@ -295,6 +295,7 @@ Gatehouse Security Team Returns True if the email was sent successfully, False otherwise. If EMAIL_ENABLED is False, logs the email body instead (simulation mode). + All SMTP exceptions are caught and logged — this method never raises. """ import smtplib from email.mime.multipart import MIMEMultipart @@ -310,17 +311,37 @@ Gatehouse Security Team ) return False - smtp_host = current_app.config.get(NotificationService.SMTP_HOST_KEY, "localhost") - smtp_port = int(current_app.config.get(NotificationService.SMTP_PORT_KEY, 587)) + smtp_host = current_app.config.get(NotificationService.SMTP_HOST_KEY, "") + smtp_port_raw = current_app.config.get(NotificationService.SMTP_PORT_KEY, 587) smtp_username = current_app.config.get(NotificationService.SMTP_USERNAME_KEY) smtp_password = current_app.config.get(NotificationService.SMTP_PASSWORD_KEY) + from_address = current_app.config.get( + NotificationService.FROM_ADDRESS_KEY, "" + ) + + # Guard: refuse to attempt a connection when critical config is missing. + # This surfaces a clear log message instead of a confusing socket error. + missing = [k for k, v in [ + ("SMTP_HOST", smtp_host), + ("FROM_ADDRESS", from_address), + ] if not v] + if missing: + logger.error( + f"[EMAIL] Cannot send — missing config: {', '.join(missing)}. " + f"Would have sent to: {to_address} | Subject: {subject}" + ) + return False + + try: + smtp_port = int(smtp_port_raw) + except (TypeError, ValueError): + logger.error(f"[EMAIL] Invalid SMTP_PORT value: {smtp_port_raw!r}") + return False + smtp_use_tls = current_app.config.get( NotificationService.SMTP_USE_TLS_KEY, smtp_port not in (25, 1025), ) - from_address = current_app.config.get( - NotificationService.FROM_ADDRESS_KEY, "noreply@gatehouse.local" - ) try: msg = MIMEMultipart("alternative") diff --git a/gatehouse_app/services/oauth_flow/__init__.py b/gatehouse_app/services/oauth_flow/__init__.py new file mode 100644 index 0000000..71ea135 --- /dev/null +++ b/gatehouse_app/services/oauth_flow/__init__.py @@ -0,0 +1,209 @@ +"""OAuthFlowService — public facade and handle_callback dispatcher.""" +import logging +from typing import Optional, Tuple + +from gatehouse_app.models.auth.authentication_method import OAuthState +from gatehouse_app.utils.constants import AuthMethodType +from gatehouse_app.services.audit_service import AuditService +from gatehouse_app.services.external_auth import ExternalAuthService +from gatehouse_app.services.external_auth.models import ExternalAuthError + +from gatehouse_app.services.oauth_flow.login import OAuthFlowError, initiate_login_flow, handle_login_callback +from gatehouse_app.services.oauth_flow.register import initiate_register_flow, handle_register_callback +from gatehouse_app.services.oauth_flow.code import ( + generate_authorization_code, + exchange_authorization_code, + create_redirect_response, +) + +logger = logging.getLogger(__name__) + + +class OAuthFlowService: + """Service for managing OAuth authentication flows.""" + + @classmethod + def initiate_login_flow( + cls, + provider_type: AuthMethodType, + organization_id: str = None, + redirect_uri: str = None, + state_data: dict = None, + ) -> Tuple[str, str]: + return initiate_login_flow(provider_type, organization_id, redirect_uri, state_data) + + @classmethod + def initiate_register_flow( + cls, + provider_type: AuthMethodType, + organization_id: str = None, + redirect_uri: str = None, + ) -> Tuple[str, str]: + return initiate_register_flow(provider_type, organization_id, redirect_uri) + + @classmethod + def handle_callback( + cls, + provider_type: AuthMethodType, + authorization_code: str, + state: str, + redirect_uri: str = None, + error: str = None, + error_description: str = None, + ) -> dict: + provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type + + try: + from flask import request + ip_address = request.remote_addr if request else None + user_agent = request.headers.get("User-Agent") if request else None + except RuntimeError: + ip_address = None + user_agent = None + + if error: + AuditService.log_external_auth_login_failed( + organization_id=None, + provider_type=provider_type_str, + failure_reason=error, + error_message=error_description or error, + ) + raise OAuthFlowError( + error_description or f"OAuth error: {error}", + error.upper() if error else "OAUTH_ERROR", + 400, + ) + + state_record = OAuthState.query.filter_by(state=state).first() + + if state_record: + logger.debug( + f"State validation: found=True, used={state_record.used}, " + f"expires_at={state_record.expires_at}, is_valid={state_record.is_valid()}" + ) + else: + logger.warning(f"State validation: state token not found in database: {state}") + + if not state_record or not state_record.is_valid(): + AuditService.log_external_auth_login_failed( + organization_id=state_record.organization_id if state_record else None, + provider_type=provider_type_str, + failure_reason="invalid_state", + error_message="Invalid or expired OAuth state", + ) + raise OAuthFlowError("Invalid or expired OAuth state", "INVALID_STATE", 400) + + effective_redirect = redirect_uri or state_record.redirect_uri + + if state_record.flow_type == "login": + return handle_login_callback( + provider_type=provider_type, + state_record=state_record, + authorization_code=authorization_code, + redirect_uri=effective_redirect, + ip_address=ip_address, + user_agent=user_agent, + ) + elif state_record.flow_type == "link": + return cls._handle_link_callback( + provider_type=provider_type, + state_record=state_record, + authorization_code=authorization_code, + redirect_uri=effective_redirect, + ) + elif state_record.flow_type == "register": + return handle_register_callback( + provider_type=provider_type, + state_record=state_record, + authorization_code=authorization_code, + redirect_uri=effective_redirect, + ) + else: + raise OAuthFlowError( + f"Unknown flow type: {state_record.flow_type}", + "INVALID_FLOW_TYPE", + 400, + ) + + @classmethod + def _handle_link_callback( + cls, + provider_type: AuthMethodType, + state_record: OAuthState, + authorization_code: str, + redirect_uri: str, + ) -> dict: + provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type + + try: + auth_method = ExternalAuthService.complete_link_flow( + provider_type=provider_type, + authorization_code=authorization_code, + state=state_record.state, + redirect_uri=redirect_uri, + ) + + logger.info( + f"OAuth link successful for user={state_record.user_id}, " + f"provider={provider_type_str}, auth_method_id={auth_method.id}" + ) + + return { + "success": True, + "flow_type": "link", + "linked_account": { + "id": auth_method.id, + "provider_type": provider_type_str, + "provider_user_id": auth_method.provider_user_id, + "verified": auth_method.verified, + }, + } + + except ExternalAuthError as e: + logger.warning( + f"OAuth link failed for state={state_record.id}, " + f"provider={provider_type_str}, error={e.message}" + ) + raise + + @classmethod + def validate_state(cls, state: str) -> Optional[OAuthState]: + state_record = OAuthState.query.filter_by(state=state).first() + if state_record and state_record.is_valid(): + return state_record + return None + + @classmethod + def cleanup_expired_states(cls): + OAuthState.cleanup_expired() + logger.info("Expired OAuth states cleaned up") + + @classmethod + def generate_authorization_code( + cls, + user_id: str, + client_id: str, + redirect_uri: str, + scope: list = None, + nonce: str = None, + ip_address: str = None, + user_agent: str = None, + lifetime_seconds: int = 600, + ) -> str: + return generate_authorization_code( + user_id, client_id, redirect_uri, scope, nonce, ip_address, user_agent, lifetime_seconds + ) + + @classmethod + def exchange_authorization_code( + cls, + code: str, + client_id: str, + redirect_uri: str, + ip_address: str = None, + ) -> dict: + return exchange_authorization_code(code, client_id, redirect_uri, ip_address) + + @classmethod + def create_redirect_response(cls, redirect_uri: str, authorization_code: str, state: str = None): + return create_redirect_response(redirect_uri, authorization_code, state) diff --git a/gatehouse_app/services/oauth_flow/code.py b/gatehouse_app/services/oauth_flow/code.py new file mode 100644 index 0000000..d298045 --- /dev/null +++ b/gatehouse_app/services/oauth_flow/code.py @@ -0,0 +1,141 @@ +"""Authorization code generation, exchange, and redirect helpers.""" +import hashlib +import logging +import secrets +from datetime import datetime, timezone +from typing import Optional + +from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode +from gatehouse_app.services.oauth_flow.login import OAuthFlowError + +logger = logging.getLogger(__name__) + + +def generate_authorization_code( + user_id: str, + client_id: str, + redirect_uri: str, + scope: list = None, + nonce: str = None, + ip_address: str = None, + user_agent: str = None, + lifetime_seconds: int = 600, +) -> str: + code = secrets.token_urlsafe(32) + code_hash = hashlib.sha256(code.encode()).hexdigest() + + OIDCAuthCode.create_code( + client_id=client_id, + user_id=user_id, + code_hash=code_hash, + redirect_uri=redirect_uri, + scope=scope, + nonce=nonce, + ip_address=ip_address, + user_agent=user_agent, + lifetime_seconds=lifetime_seconds, + ) + + logger.info(f"Generated authorization code for user={user_id}, client={client_id}") + return code + + +def exchange_authorization_code( + code: str, + client_id: str, + redirect_uri: str, + ip_address: str = None, +) -> dict: + code_hash = hashlib.sha256(code.encode()).hexdigest() + + auth_code = OIDCAuthCode.query.filter_by( + client_id=client_id, + code_hash=code_hash, + ).first() + + if not auth_code: + raise OAuthFlowError("Invalid authorization code", "INVALID_CODE", 400) + + if not auth_code.is_valid(): + if auth_code.is_used: + raise OAuthFlowError( + "Authorization code has already been used", "CODE_USED", 400 + ) + else: + raise OAuthFlowError("Authorization code has expired", "CODE_EXPIRED", 400) + + if auth_code.redirect_uri != redirect_uri: + raise OAuthFlowError("Redirect URI mismatch", "INVALID_REDIRECT_URI", 400) + + from gatehouse_app.models import User + user = User.query.get(auth_code.user_id) + if not user: + raise OAuthFlowError("User not found", "USER_NOT_FOUND", 404) + + user_orgs = user.get_organizations() + target_org = None + if len(user_orgs) == 1: + target_org = user_orgs[0] + + if not target_org: + raise OAuthFlowError( + "User does not have a default organization. Organization selection required.", + "ORG_SELECTION_REQUIRED", + 400, + ) + + from gatehouse_app.services.auth_service import AuthService + session = AuthService.create_session(user=user, is_compliance_only=False) + auth_code.mark_as_used() + + session_dict = session.to_dict() + session_dict["token"] = session.token + expires_at = session.expires_at + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + now = datetime.now(timezone.utc) + session_dict["expires_in"] = int((expires_at - now).total_seconds()) + + logger.info( + f"Authorization code exchanged for session: user={user.id}, " + f"org_id={target_org.id}, client={client_id}" + ) + + return { + "success": True, + "token": session_dict["token"], + "expires_in": session_dict["expires_in"], + "token_type": "Bearer", + "user": { + "id": user.id, + "email": user.email, + "full_name": user.full_name, + "organization_id": target_org.id, + }, + } + + +def create_redirect_response( + redirect_uri: str, + authorization_code: str, + state: str = None, +): + from urllib.parse import urlencode, urlparse, urlunparse + from flask import redirect + + parsed = urlparse(redirect_uri) + params = {"code": authorization_code} + if state: + params["state"] = state + + redirect_url = urlunparse(( + parsed.scheme, + parsed.netloc, + parsed.path, + parsed.params, + urlencode(params), + parsed.fragment, + )) + + logger.info(f"Redirecting to {parsed.scheme}://{parsed.netloc} with authorization code") + return redirect(redirect_url) diff --git a/gatehouse_app/services/oauth_flow/login.py b/gatehouse_app/services/oauth_flow/login.py new file mode 100644 index 0000000..c5fd4c0 --- /dev/null +++ b/gatehouse_app/services/oauth_flow/login.py @@ -0,0 +1,410 @@ +"""Login flow: initiate and handle OAuth login callback.""" +import logging +import secrets +from datetime import datetime, timezone +from typing import Optional, Tuple + +from gatehouse_app.models import User, AuthenticationMethod +from gatehouse_app.models.auth.authentication_method import OAuthState +from gatehouse_app.utils.constants import AuthMethodType, AuditAction +from gatehouse_app.services.audit_service import AuditService +from gatehouse_app.services.external_auth import ExternalAuthService +from gatehouse_app.services.external_auth.models import ExternalAuthError + +logger = logging.getLogger(__name__) + + +class OAuthFlowError(Exception): + def __init__(self, message: str, error_type: str, status_code: int = 400): + self.message = message + self.error_type = error_type + self.status_code = status_code + super().__init__(message) + + +def initiate_login_flow( + provider_type: AuthMethodType, + organization_id: str = None, + redirect_uri: str = None, + state_data: dict = None, +) -> Tuple[str, str]: + try: + from flask import request + except Exception: + request = None + + try: + ip_address = request.remote_addr if request else None + user_agent = request.headers.get("User-Agent") if request else None + except RuntimeError: + ip_address = None + user_agent = None + + provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type + + try: + config = ExternalAuthService.get_provider_config(provider_type, organization_id) + + if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri): + raise OAuthFlowError("Invalid redirect URI", "INVALID_REDIRECT_URI", 400) + + code_verifier = None + code_challenge = None + if provider_type_str not in ['google', 'microsoft']: + code_verifier = secrets.token_urlsafe(32) + code_challenge = ExternalAuthService._compute_s256_challenge(code_verifier) + + logger.info( + f"[PKCE DEBUG] Provider type check: provider_type_str='{provider_type_str}', " + f"is_google={provider_type_str in ['google']}, " + f"will_skip_pkce={provider_type_str in ['google', 'microsoft']}" + ) + + state = OAuthState.create_state( + flow_type="login", + provider_type=provider_type, + organization_id=organization_id, + redirect_uri=redirect_uri or (config.redirect_uris[0] if config.redirect_uris else None), + code_verifier=code_verifier, + code_challenge=code_challenge, + extra_data=state_data, + lifetime_seconds=600, + ) + + logger.info( + f"[PKCE DEBUG] Created OAuthState object:\n" + f" state.id: {state.id}\n" + f" state.provider_type: {state.provider_type}\n" + f" state.code_challenge: {state.code_challenge}\n" + f" state.code_verifier: {state.code_verifier[:20] if state.code_verifier else None}..." + ) + + auth_url = ExternalAuthService._build_authorization_url(config=config, state=state) + + logger.info( + f"OAuth login flow initiated for provider={provider_type_str}, " + f"org_id={organization_id}, state_token={state.state}, state_record_id={state.id}" + ) + logger.info( + f"[PKCE DEBUG] FINAL CHECK: code_challenge={code_challenge}, " + f"code_verifier={code_verifier[:20] if code_verifier else None}..., " + f"auth_url_has_challenge={'code_challenge=' in auth_url}, " + f"returned_auth_url={auth_url}" + ) + + return auth_url, state.state + + except ExternalAuthError as e: + AuditService.log_action( + action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED, + organization_id=organization_id, + metadata={ + "provider_type": provider_type_str, + "failure_reason": e.error_type, + "ip_address": ip_address, + }, + description=f"OAuth login initiation failed: {e.message}", + success=False, + error_message=e.message, + ) + raise + + +def handle_login_callback( + provider_type: AuthMethodType, + state_record: OAuthState, + authorization_code: str, + redirect_uri: str, + ip_address: str = None, + user_agent: str = None, +) -> dict: + from gatehouse_app.services.external_auth._helpers import _encrypt_provider_data + + provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type + + try: + config = ExternalAuthService.get_provider_config( + provider_type, state_record.organization_id + ) + + logger.debug( + f"Exchanging code with PKCE: state_record.code_verifier=" + f"{state_record.code_verifier[:20] if state_record.code_verifier else None}..." + ) + + tokens = ExternalAuthService._exchange_code( + config=config, + code=authorization_code, + redirect_uri=redirect_uri, + code_verifier=state_record.code_verifier, + ) + + user_info = ExternalAuthService._get_user_info( + config=config, + access_token=tokens["access_token"], + ) + + if not user_info.get("provider_user_id"): + raise OAuthFlowError( + "Provider did not return a user identifier (sub claim). " + "Cannot complete authentication.", + "MISSING_PROVIDER_USER_ID", + 400, + ) + + if not user_info.get("email"): + raise OAuthFlowError( + "Provider did not return an email address. " + "Cannot complete authentication.", + "MISSING_EMAIL", + 400, + ) + + logger.debug( + f"Got user_info from provider: sub={user_info['provider_user_id']}, " + f"email={user_info['email']}, email_verified={user_info.get('email_verified')}" + ) + + # Find the active auth method for this provider identity. + # Order by created_at DESC so that an explicitly linked (newer) row wins + # over an older auto-created primary row when the same Google identity + # was linked to a second profile. + auth_method = ( + AuthenticationMethod.query + .filter_by( + method_type=provider_type, + provider_user_id=user_info["provider_user_id"], + deleted_at=None, + ) + .order_by(AuthenticationMethod.created_at.desc()) + .first() + ) + + if not auth_method: + deleted_method = ( + AuthenticationMethod.query + .filter_by( + method_type=provider_type, + provider_user_id=user_info["provider_user_id"], + ) + .order_by(AuthenticationMethod.created_at.desc()) + .first() + ) + + if deleted_method: + logger.info( + f"OAuth login: restoring previously unlinked {provider_type_str} " + f"auth method for user {deleted_method.user_id}" + ) + deleted_method.deleted_at = None + deleted_method.provider_data = _encrypt_provider_data(tokens, user_info) + deleted_method.last_used_at = datetime.utcnow() + deleted_method.save() + auth_method = deleted_method + + else: + existing_user = User.query.filter_by(email=user_info["email"], deleted_at=None).first() + + if existing_user: + logger.info( + f"OAuth login: email {user_info['email']} matches existing user " + f"{existing_user.id}, auto-linking {provider_type_str} account" + ) + auth_method = AuthenticationMethod( + user_id=existing_user.id, + method_type=provider_type, + provider_user_id=user_info["provider_user_id"], + provider_data=_encrypt_provider_data(tokens, user_info), + verified=user_info.get("email_verified", False), + is_primary=False, + last_used_at=datetime.utcnow(), + ) + auth_method.save() + user = existing_user + else: + logger.info( + f"OAuth login: no account for {user_info['email']}, " + f"auto-creating user via {provider_type_str}" + ) + user = User( + email=user_info["email"], + full_name=user_info.get("name", ""), + status="active", + email_verified=user_info.get("email_verified", False), + ) + user.save() + + auth_method = AuthenticationMethod( + user_id=user.id, + method_type=provider_type, + provider_user_id=user_info["provider_user_id"], + provider_data=_encrypt_provider_data(tokens, user_info), + verified=user_info.get("email_verified", False), + is_primary=True, + last_used_at=datetime.utcnow(), + ) + auth_method.save() + + AuditService.log_action( + action="user.register", + user_id=user.id, + organization_id=state_record.organization_id, + resource_type="user", + resource_id=user.id, + metadata={ + "provider_type": provider_type_str, + "provider_user_id": user_info["provider_user_id"], + "auto_registered": True, + }, + description=f"User auto-registered via {provider_type_str} OAuth", + success=True, + ) + else: + auth_method.provider_data = _encrypt_provider_data(tokens, user_info) + auth_method.last_used_at = datetime.utcnow() + auth_method.save() + + user = auth_method.user + + user_orgs = user.get_organizations() + target_org = None + + if state_record.organization_id: + target_org = next( + (org for org in user_orgs if org.id == state_record.organization_id), + None, + ) + + if not target_org and len(user_orgs) == 1: + target_org = user_orgs[0] + + if not target_org and len(user_orgs) > 1: + # Multiple orgs and none specified in the OAuth state — pick the one the + # user joined most recently (highest created_at on their membership row). + # Users can switch organisations inside the app after logging in. + from gatehouse_app.models.organization.organization_member import OrganizationMember as _OM + latest_membership = ( + _OM.query + .filter_by(user_id=user.id, deleted_at=None) + .order_by(_OM.created_at.desc()) + .first() + ) + if latest_membership: + target_org = latest_membership.organization + else: + target_org = user_orgs[0] + + if not target_org and len(user_orgs) == 0: + from gatehouse_app.models.organization.org_invite_token import OrgInviteToken + from gatehouse_app.services.auth_service import AuthService as _AS + _now = datetime.now(timezone.utc) + _session = _AS.create_session(user=user, is_compliance_only=False) + _session_dict = _session.to_dict() + _session_dict["token"] = _session.token + _expires_at = _session.expires_at + if _expires_at.tzinfo is None: + _expires_at = _expires_at.replace(tzinfo=timezone.utc) + _session_dict["expires_in"] = int((_expires_at - _now).total_seconds()) + + _pending = OrgInviteToken.query.filter( + OrgInviteToken.email == user.email, + OrgInviteToken.accepted_at.is_(None), + OrgInviteToken.expires_at > _now, + OrgInviteToken.deleted_at.is_(None), + ).all() + _pending_list = [ + { + "token": inv.token, + "organization": {"id": str(inv.organization_id), "name": inv.organization.name}, + "role": inv.role, + "expires_at": inv.expires_at.isoformat(), + } + for inv in _pending + ] + + state_record.mark_used() + logger.info( + f"OAuth login: user {user.id} has no org, redirecting to org-setup " + f"(pending_invites={len(_pending_list)})" + ) + return { + "success": True, + "flow_type": "login", + "requires_org_creation": True, + "user": {"id": user.id, "email": user.email, "full_name": user.full_name}, + "session": _session_dict, + "pending_invites": _pending_list, + "state": state_record.state, + } + + if not target_org: + state_record.mark_used() + logger.info( + f"OAuth login requires org selection for user={user.id}, " + f"provider={provider_type_str}, org_count={len(user_orgs)}" + ) + return { + "success": True, + "flow_type": "login", + "requires_org_selection": True, + "user": {"id": user.id, "email": user.email, "full_name": user.full_name}, + "available_organizations": [ + { + "id": org.id, + "name": org.name, + "slug": org.slug if hasattr(org, "slug") else None, + } + for org in user_orgs + ], + "state": state_record.state, + } + + from gatehouse_app.services.auth_service import AuthService + session = AuthService.create_session(user=user, is_compliance_only=False) + state_record.mark_used() + + AuditService.log_external_auth_login( + user_id=user.id, + organization_id=target_org.id, + provider_type=provider_type_str, + provider_user_id=user_info["provider_user_id"], + auth_method_id=auth_method.id, + session_id=session.id, + ) + + logger.info( + f"OAuth login successful for user={user.id}, " + f"provider={provider_type_str}, org_id={target_org.id}" + ) + + session_dict = session.to_dict() + session_dict["token"] = session.token + expires_at = session.expires_at + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + now = datetime.now(timezone.utc) + session_dict["expires_in"] = int((expires_at - now).total_seconds()) + + return { + "success": True, + "flow_type": "login", + "user": { + "id": user.id, + "email": user.email, + "full_name": user.full_name, + "organization_id": target_org.id, + }, + "session": session_dict, + } + + except ExternalAuthError as e: + logger.warning( + f"OAuth login failed for state={state_record.id}, " + f"provider={provider_type_str}, error={e.message}" + ) + raise + except OAuthFlowError: + raise + except Exception as e: + logger.error(f"Unexpected error in OAuth login callback: {str(e)}", exc_info=True) + raise OAuthFlowError("An unexpected error occurred during login", "INTERNAL_ERROR", 500) diff --git a/gatehouse_app/services/oauth_flow/register.py b/gatehouse_app/services/oauth_flow/register.py new file mode 100644 index 0000000..069c247 --- /dev/null +++ b/gatehouse_app/services/oauth_flow/register.py @@ -0,0 +1,248 @@ +"""Registration flow: initiate and handle OAuth register callback.""" +import logging +import secrets +from datetime import datetime, timezone +from typing import Optional, Tuple + +from gatehouse_app.models import User, AuthenticationMethod +from gatehouse_app.models.auth.authentication_method import OAuthState +from gatehouse_app.utils.constants import AuthMethodType, AuditAction +from gatehouse_app.services.audit_service import AuditService +from gatehouse_app.services.external_auth import ExternalAuthService +from gatehouse_app.services.external_auth.models import ExternalAuthError +from gatehouse_app.services.oauth_flow.login import OAuthFlowError + +logger = logging.getLogger(__name__) + + +def initiate_register_flow( + provider_type: AuthMethodType, + organization_id: str = None, + redirect_uri: str = None, +) -> Tuple[str, str]: + provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type + + try: + config = ExternalAuthService.get_provider_config(provider_type, organization_id) + + if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri): + raise OAuthFlowError("Invalid redirect URI", "INVALID_REDIRECT_URI", 400) + + code_verifier = None + code_challenge = None + if provider_type_str not in ['google', 'microsoft']: + code_verifier = secrets.token_urlsafe(32) + code_challenge = ExternalAuthService._compute_s256_challenge(code_verifier) + + logger.info( + f"[PKCE DEBUG] Register flow - Provider type check: provider_type_str='{provider_type_str}', " + f"is_google={provider_type_str in ['google']}, " + f"will_skip_pkce={provider_type_str in ['google', 'microsoft']}" + ) + + state = OAuthState.create_state( + flow_type="register", + provider_type=provider_type, + organization_id=organization_id, + redirect_uri=redirect_uri or (config.redirect_uris[0] if config.redirect_uris else None), + code_verifier=code_verifier, + code_challenge=code_challenge, + lifetime_seconds=600, + ) + + logger.info( + f"[PKCE DEBUG] Register flow - Created OAuthState:\n" + f" state.id: {state.id}\n" + f" state.code_challenge: {state.code_challenge}\n" + f" state.code_verifier: {state.code_verifier[:20] if state.code_verifier else None}..." + ) + + auth_url = ExternalAuthService._build_authorization_url(config=config, state=state) + + logger.info( + f"OAuth register flow initiated for provider={provider_type_str}, " + f"org_id={organization_id}, state_id={state.id}" + ) + logger.info( + f"[PKCE DEBUG] Register flow - FINAL: auth_url_has_challenge={'code_challenge=' in auth_url}" + ) + + return auth_url, state.state + + except ExternalAuthError as e: + AuditService.log_action( + action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED, + organization_id=organization_id, + metadata={ + "provider_type": provider_type_str, + "failure_reason": e.error_type, + }, + description=f"OAuth registration initiation failed: {e.message}", + success=False, + error_message=e.message, + ) + raise + + +def handle_register_callback( + provider_type: AuthMethodType, + state_record: OAuthState, + authorization_code: str, + redirect_uri: str, +) -> dict: + from gatehouse_app.services.external_auth._helpers import _encrypt_provider_data + + provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type + + try: + config = ExternalAuthService.get_provider_config( + provider_type, state_record.organization_id + ) + + tokens = ExternalAuthService._exchange_code( + config=config, + code=authorization_code, + redirect_uri=redirect_uri, + code_verifier=state_record.code_verifier, + ) + + user_info = ExternalAuthService._get_user_info( + config=config, + access_token=tokens["access_token"], + ) + + existing_user = User.query.filter_by(email=user_info["email"]).first() + if existing_user: + raise OAuthFlowError( + f"An account with email {user_info['email']} already exists. " + "Please log in with your password and link your Google account from settings.", + "EMAIL_EXISTS", + 400, + ) + + user = User( + email=user_info["email"], + full_name=user_info.get("name", ""), + status="active", + email_verified=user_info.get("email_verified", False), + ) + user.save() + + auth_method = AuthenticationMethod( + user_id=user.id, + method_type=provider_type, + provider_user_id=user_info["provider_user_id"], + provider_data=_encrypt_provider_data(tokens, user_info), + verified=user_info.get("email_verified", False), + is_primary=True, + last_used_at=datetime.utcnow(), + ) + auth_method.save() + + state_record.mark_used() + + AuditService.log_action( + action="user.register", + user_id=user.id, + organization_id=state_record.organization_id, + resource_type="user", + resource_id=user.id, + metadata={ + "provider_type": provider_type_str, + "provider_user_id": user_info["provider_user_id"], + "auth_method_id": auth_method.id, + }, + description=f"User registered via {provider_type_str}", + success=True, + ) + + AuditService.log_external_auth_link_completed( + user_id=user.id, + organization_id=state_record.organization_id, + provider_type=provider_type_str, + provider_user_id=user_info["provider_user_id"], + auth_method_id=auth_method.id, + ) + + logger.info( + f"OAuth registration successful for email={user_info['email']}, " + f"provider={provider_type_str}, user_id={user.id}" + ) + + if state_record.organization_id: + from gatehouse_app.models.organization.organization import Organization + org = Organization.query.get(state_record.organization_id) + if org: + from gatehouse_app.services.auth_service import AuthService + session = AuthService.create_session(user=user, is_compliance_only=False) + session_dict = session.to_dict() + session_dict["token"] = session.token + expires_at = session.expires_at + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + now = datetime.now(timezone.utc) + session_dict["expires_in"] = int((expires_at - now).total_seconds()) + return { + "success": True, + "flow_type": "register", + "user": { + "id": user.id, + "email": user.email, + "full_name": user.full_name, + "organization_id": org.id, + }, + "session": session_dict, + } + + from gatehouse_app.services.auth_service import AuthService as _AS + from gatehouse_app.models.organization.org_invite_token import OrgInviteToken + _session = _AS.create_session(user=user, is_compliance_only=False) + _session_dict = _session.to_dict() + _session_dict["token"] = _session.token + _expires_at = _session.expires_at + if _expires_at.tzinfo is None: + _expires_at = _expires_at.replace(tzinfo=timezone.utc) + _now = datetime.now(timezone.utc) + _session_dict["expires_in"] = int((_expires_at - _now).total_seconds()) + + _pending = OrgInviteToken.query.filter( + OrgInviteToken.email == user.email, + OrgInviteToken.accepted_at.is_(None), + OrgInviteToken.expires_at > _now, + OrgInviteToken.deleted_at.is_(None), + ).all() + _pending_list = [ + { + "token": inv.token, + "organization": {"id": str(inv.organization_id), "name": inv.organization.name}, + "role": inv.role, + "expires_at": inv.expires_at.isoformat(), + } + for inv in _pending + ] + + return { + "success": True, + "flow_type": "register", + "requires_org_creation": True, + "user": {"id": user.id, "email": user.email, "full_name": user.full_name}, + "session": _session_dict, + "pending_invites": _pending_list, + "state": state_record.state, + } + + except ExternalAuthError as e: + logger.warning( + f"OAuth registration failed for state={state_record.id}, " + f"provider={provider_type_str}, error={e.message}" + ) + raise + except OAuthFlowError: + raise + except Exception as e: + logger.error(f"Unexpected error in OAuth registration callback: {str(e)}", exc_info=True) + raise OAuthFlowError( + "An unexpected error occurred during registration", + "INTERNAL_ERROR", + 500, + ) diff --git a/gatehouse_app/services/oauth_flow_service.py b/gatehouse_app/services/oauth_flow_service.py deleted file mode 100644 index 2889a7c..0000000 --- a/gatehouse_app/services/oauth_flow_service.py +++ /dev/null @@ -1,1152 +0,0 @@ -"""OAuth flow service for handling external authentication flows.""" -import hashlib -import logging -import secrets -from datetime import datetime, timedelta, timezone -from typing import Optional, Tuple - -from flask import current_app, request, g, redirect - -from gatehouse_app.extensions import db -from gatehouse_app.models import User, AuthenticationMethod -from gatehouse_app.models.auth.authentication_method import OAuthState -from gatehouse_app.models.base import BaseModel -from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode -from gatehouse_app.utils.constants import AuthMethodType, AuditAction -from gatehouse_app.services.audit_service import AuditService -from gatehouse_app.services.external_auth_service import ( - ExternalAuthService, - ExternalAuthError, - ExternalProviderConfig, -) - -logger = logging.getLogger(__name__) - - -class OAuthFlowError(Exception): - """Exception for OAuth flow errors.""" - - def __init__(self, message: str, error_type: str, status_code: int = 400): - self.message = message - self.error_type = error_type - self.status_code = status_code - super().__init__(message) - - -class OAuthFlowService: - """Service for managing OAuth authentication flows.""" - - @classmethod - def initiate_login_flow( - cls, - provider_type: AuthMethodType, - organization_id: str = None, - redirect_uri: str = None, - state_data: dict = None, - ) -> Tuple[str, str]: - """ - Initiate OAuth login flow without requiring organization_id upfront. - - This method initiates the OAuth flow using application-wide provider configuration. - The organization context is determined after successful authentication. - - Args: - provider_type: The authentication provider type - organization_id: Optional organization hint for SSO discovery - redirect_uri: Optional custom redirect URI - state_data: Additional state data to include - - Returns: - Tuple of (authorization_url, state) - """ - # Get request context for audit logging - try: - ip_address = request.remote_addr if request else None - user_agent = request.headers.get("User-Agent") if request else None - except RuntimeError: - ip_address = None - user_agent = None - - provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type - - try: - # Get provider config (application-wide, no organization required) - config = ExternalAuthService.get_provider_config(provider_type, organization_id) - - # Validate redirect URI - if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri): - raise OAuthFlowError( - "Invalid redirect URI", - "INVALID_REDIRECT_URI", - 400, - ) - - # Generate PKCE parameters (Google and Microsoft web applications don't use PKCE - # when a client_secret is present — they are confidential clients) - code_verifier = None - code_challenge = None - if provider_type_str not in ['google', 'microsoft']: - code_verifier = secrets.token_urlsafe(32) - code_challenge = ExternalAuthService._compute_s256_challenge(code_verifier) - - # DIAGNOSTIC LOGGING: Show PKCE decision - logger.info( - f"[PKCE DEBUG] Provider type check: provider_type_str='{provider_type_str}', " - f"is_google={provider_type_str in ['google']}, " - f"will_skip_pkce={provider_type_str in ['google', 'microsoft']}" - ) - - # Create OAuth state for login flow - state = OAuthState.create_state( - flow_type="login", - provider_type=provider_type, - organization_id=organization_id, - redirect_uri=redirect_uri or (config.redirect_uris[0] if config.redirect_uris else None), - code_verifier=code_verifier, - code_challenge=code_challenge, - extra_data=state_data, - lifetime_seconds=600, - ) - - # DIAGNOSTIC LOGGING: Verify state object - logger.info( - f"[PKCE DEBUG] Created OAuthState object:\n" - f" state.id: {state.id}\n" - f" state.provider_type: {state.provider_type}\n" - f" state.code_challenge: {state.code_challenge}\n" - f" state.code_verifier: {state.code_verifier[:20] if state.code_verifier else None}..." - ) - - # Build authorization URL - auth_url = ExternalAuthService._build_authorization_url( - config=config, - state=state, - ) - - logger.info( - f"OAuth login flow initiated for provider={provider_type_str}, " - f"org_id={organization_id}, state_token={state.state}, state_record_id={state.id}" - ) - logger.info( - f"[PKCE DEBUG] FINAL CHECK: code_challenge={code_challenge}, " - f"code_verifier={code_verifier[:20] if code_verifier else None}..., " - f"auth_url_has_challenge={'code_challenge=' in auth_url}, " - f"returned_auth_url={auth_url}" - ) - - return auth_url, state.state - - except ExternalAuthError as e: - # Log failed initiation - AuditService.log_action( - action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED, - organization_id=organization_id, - metadata={ - "provider_type": provider_type_str, - "failure_reason": e.error_type, - "ip_address": ip_address, - }, - description=f"OAuth login initiation failed: {e.message}", - success=False, - error_message=e.message, - ) - raise - - @classmethod - def initiate_register_flow( - cls, - provider_type: AuthMethodType, - organization_id: str = None, - redirect_uri: str = None, - ) -> Tuple[str, str]: - """ - Initiate OAuth registration flow without requiring organization_id upfront. - - Args: - provider_type: The authentication provider type - organization_id: Optional organization hint - redirect_uri: Optional custom redirect URI - - Returns: - Tuple of (authorization_url, state) - """ - provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type - - try: - # Get provider config (application-wide, no organization required) - config = ExternalAuthService.get_provider_config(provider_type, organization_id) - - # Validate redirect URI - if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri): - raise OAuthFlowError( - "Invalid redirect URI", - "INVALID_REDIRECT_URI", - 400, - ) - - # Generate PKCE parameters (Google and Microsoft web applications don't use PKCE - # when a client_secret is present — they are confidential clients) - code_verifier = None - code_challenge = None - if provider_type_str not in ['google', 'microsoft']: - code_verifier = secrets.token_urlsafe(32) - code_challenge = ExternalAuthService._compute_s256_challenge(code_verifier) - - # DIAGNOSTIC LOGGING: Show PKCE decision for register flow - logger.info( - f"[PKCE DEBUG] Register flow - Provider type check: provider_type_str='{provider_type_str}', " - f"is_google={provider_type_str in ['google']}, " - f"will_skip_pkce={provider_type_str in ['google', 'microsoft']}" - ) - - # Create OAuth state for register flow - state = OAuthState.create_state( - flow_type="register", - provider_type=provider_type, - organization_id=organization_id, - redirect_uri=redirect_uri or (config.redirect_uris[0] if config.redirect_uris else None), - code_verifier=code_verifier, - code_challenge=code_challenge, - lifetime_seconds=600, - ) - - # DIAGNOSTIC LOGGING: Verify state object for register flow - logger.info( - f"[PKCE DEBUG] Register flow - Created OAuthState:\n" - f" state.id: {state.id}\n" - f" state.code_challenge: {state.code_challenge}\n" - f" state.code_verifier: {state.code_verifier[:20] if state.code_verifier else None}..." - ) - - # Build authorization URL - auth_url = ExternalAuthService._build_authorization_url( - config=config, - state=state, - ) - - logger.info( - f"OAuth register flow initiated for provider={provider_type_str}, " - f"org_id={organization_id}, state_id={state.id}" - ) - logger.info( - f"[PKCE DEBUG] Register flow - FINAL: auth_url_has_challenge={'code_challenge=' in auth_url}" - ) - - return auth_url, state.state - - except ExternalAuthError as e: - AuditService.log_action( - action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED, - organization_id=organization_id, - metadata={ - "provider_type": provider_type_str, - "failure_reason": e.error_type, - }, - description=f"OAuth registration initiation failed: {e.message}", - success=False, - error_message=e.message, - ) - raise - - @classmethod - def handle_callback( - cls, - provider_type: AuthMethodType, - authorization_code: str, - state: str, - redirect_uri: str = None, - error: str = None, - error_description: str = None, - ) -> dict: - """ - Handle OAuth callback from provider. - - Args: - provider_type: The authentication provider type - authorization_code: Authorization code from provider - state: State parameter from provider - redirect_uri: Redirect URI used in the flow - error: Error code if auth failed - error_description: Human-readable error description - - Returns: - Dict with flow result - """ - provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type - - # Get request context for audit logging - try: - ip_address = request.remote_addr if request else None - user_agent = request.headers.get("User-Agent") if request else None - except RuntimeError: - ip_address = None - user_agent = None - - # Handle error response from provider - if error: - AuditService.log_external_auth_login_failed( - organization_id=None, - provider_type=provider_type_str, - failure_reason=error, - error_message=error_description or error, - ) - raise OAuthFlowError( - error_description or f"OAuth error: {error}", - error.upper() if error else "OAUTH_ERROR", - 400, - ) - - # Validate state - state_record = OAuthState.query.filter_by(state=state).first() - - # Log validation details for debugging - if state_record: - logger.debug( - f"State validation: found=True, used={state_record.used}, " - f"expires_at={state_record.expires_at}, now={datetime.now(timezone.utc)}, " - f"is_valid={state_record.is_valid()}" - ) - else: - logger.warning(f"State validation: state token not found in database: {state}") - - if not state_record or not state_record.is_valid(): - AuditService.log_external_auth_login_failed( - organization_id=state_record.organization_id if state_record else None, - provider_type=provider_type_str, - failure_reason="invalid_state", - error_message="Invalid or expired OAuth state", - ) - raise OAuthFlowError( - "Invalid or expired OAuth state", - "INVALID_STATE", - 400, - ) - - # Route to appropriate handler based on flow type - if state_record.flow_type == "login": - return cls._handle_login_callback( - provider_type=provider_type, - state_record=state_record, - authorization_code=authorization_code, - redirect_uri=redirect_uri or state_record.redirect_uri, - ip_address=ip_address, - user_agent=user_agent, - ) - elif state_record.flow_type == "link": - return cls._handle_link_callback( - provider_type=provider_type, - state_record=state_record, - authorization_code=authorization_code, - redirect_uri=redirect_uri or state_record.redirect_uri, - ) - elif state_record.flow_type == "register": - return cls._handle_register_callback( - provider_type=provider_type, - state_record=state_record, - authorization_code=authorization_code, - redirect_uri=redirect_uri or state_record.redirect_uri, - ) - else: - raise OAuthFlowError( - f"Unknown flow type: {state_record.flow_type}", - "INVALID_FLOW_TYPE", - 400, - ) - - @classmethod - def _handle_login_callback( - cls, - provider_type: AuthMethodType, - state_record: OAuthState, - authorization_code: str, - redirect_uri: str, - ip_address: str = None, - user_agent: str = None, - ) -> dict: - """ - Handle login flow callback with organization discovery. - - This method: - 1. Exchanges the authorization code for tokens - 2. Gets user info from the OAuth provider - 3. Looks up the user by provider_user_id - 4. Determines which organization(s) the user belongs to - 5. Creates a session or returns org selection needed - """ - provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type - - try: - # Get provider config (application-wide) - config = ExternalAuthService.get_provider_config( - provider_type, state_record.organization_id - ) - - logger.debug( - f"Exchanging code with PKCE: state_record.code_verifier={state_record.code_verifier[:20] if state_record.code_verifier else None}..." - ) - - # Exchange code for tokens - tokens = ExternalAuthService._exchange_code( - config=config, - code=authorization_code, - redirect_uri=redirect_uri, - code_verifier=state_record.code_verifier, - ) - - # Get user info from provider - user_info = ExternalAuthService._get_user_info( - config=config, - access_token=tokens["access_token"], - ) - - if not user_info.get("provider_user_id"): - raise OAuthFlowError( - "Provider did not return a user identifier (sub claim). " - "Cannot complete authentication.", - "MISSING_PROVIDER_USER_ID", - 400, - ) - - if not user_info.get("email"): - raise OAuthFlowError( - "Provider did not return an email address. " - "Cannot complete authentication.", - "MISSING_EMAIL", - 400, - ) - - logger.debug( - f"Got user_info from provider: sub={user_info['provider_user_id']}, " - f"email={user_info['email']}, email_verified={user_info.get('email_verified')}" - ) - - # Look up user by provider_user_id - auth_method = AuthenticationMethod.query.filter_by( - method_type=provider_type, - provider_user_id=user_info["provider_user_id"], - ).first() - - if not auth_method: - # No linked account found — check if email matches an existing user - existing_user = User.query.filter_by( - email=user_info["email"] - ).first() - - if existing_user: - # Email exists but no OAuth link — auto-link and log in - logger.info( - f"OAuth login: email {user_info['email']} matches existing user " - f"{existing_user.id}, auto-linking {provider_type_str} account" - ) - auth_method = AuthenticationMethod( - user_id=existing_user.id, - method_type=provider_type, - provider_user_id=user_info["provider_user_id"], - provider_data=ExternalAuthService._encrypt_provider_data(tokens, user_info), - verified=user_info.get("email_verified", False), - is_primary=False, - last_used_at=datetime.utcnow(), - ) - auth_method.save() - user = existing_user - else: - # Brand-new user — auto-register via OAuth (standard behaviour) - logger.info( - f"OAuth login: no account for {user_info['email']}, " - f"auto-creating user via {provider_type_str}" - ) - user = User( - email=user_info["email"], - full_name=user_info.get("name", ""), - status="active", - email_verified=user_info.get("email_verified", False), - ) - user.save() - - auth_method = AuthenticationMethod( - user_id=user.id, - method_type=provider_type, - provider_user_id=user_info["provider_user_id"], - provider_data=ExternalAuthService._encrypt_provider_data(tokens, user_info), - verified=user_info.get("email_verified", False), - is_primary=True, - last_used_at=datetime.utcnow(), - ) - auth_method.save() - - AuditService.log_action( - action="user.register", - user_id=user.id, - organization_id=state_record.organization_id, - resource_type="user", - resource_id=user.id, - metadata={ - "provider_type": provider_type_str, - "provider_user_id": user_info["provider_user_id"], - "auto_registered": True, - }, - description=f"User auto-registered via {provider_type_str} OAuth", - success=True, - ) - else: - # Existing linked account — update provider data - auth_method.provider_data = ExternalAuthService._encrypt_provider_data( - tokens, user_info - ) - auth_method.last_used_at = datetime.utcnow() - auth_method.save() - - user = auth_method.user - - # Get user's organizations - user_orgs = user.get_organizations() - - # Determine target organization - target_org = None - - # Priority 1: Use organization_id from state if provided (org hint) - if state_record.organization_id: - target_org = next( - (org for org in user_orgs if org.id == state_record.organization_id), - None - ) - - # Priority 2: If user has exactly one organization, use it - if not target_org and len(user_orgs) == 1: - target_org = user_orgs[0] - - # Priority 3: No orgs at all — send to org-setup instead of auto-creating - if not target_org and len(user_orgs) == 0: - from gatehouse_app.models.organization.org_invite_token import OrgInviteToken - from gatehouse_app.services.auth_service import AuthService as _AS - _now = datetime.now(timezone.utc) - _session = _AS.create_session(user=user, is_compliance_only=False) - _session_dict = _session.to_dict() - _session_dict["token"] = _session.token - _expires_at = _session.expires_at - if _expires_at.tzinfo is None: - _expires_at = _expires_at.replace(tzinfo=timezone.utc) - _session_dict["expires_in"] = int((_expires_at - _now).total_seconds()) - - _pending = OrgInviteToken.query.filter( - OrgInviteToken.email == user.email, - OrgInviteToken.accepted_at.is_(None), - OrgInviteToken.expires_at > _now, - OrgInviteToken.deleted_at.is_(None), - ).all() - _pending_list = [ - { - "token": inv.token, - "organization": { - "id": str(inv.organization_id), - "name": inv.organization.name, - }, - "role": inv.role, - "expires_at": inv.expires_at.isoformat(), - } - for inv in _pending - ] - - state_record.mark_used() - logger.info( - f"OAuth login: user {user.id} has no org, redirecting to org-setup " - f"(pending_invites={len(_pending_list)})" - ) - return { - "success": True, - "flow_type": "login", - "requires_org_creation": True, - "user": {"id": user.id, "email": user.email, "full_name": user.full_name}, - "session": _session_dict, - "pending_invites": _pending_list, - "state": state_record.state, - } - - # Priority 4: Multiple orgs — need user to pick one - if not target_org: - state_record.mark_used() - logger.info( - f"OAuth login requires org selection for user={user.id}, " - f"provider={provider_type_str}, org_count={len(user_orgs)}" - ) - return { - "success": True, - "flow_type": "login", - "requires_org_selection": True, - "user": { - "id": user.id, - "email": user.email, - "full_name": user.full_name, - }, - "available_organizations": [ - { - "id": org.id, - "name": org.name, - "slug": org.slug if hasattr(org, "slug") else None, - } - for org in user_orgs - ], - "state": state_record.state, - } - - # Create session for the target org - from gatehouse_app.services.auth_service import AuthService - session = AuthService.create_session( - user=user, - is_compliance_only=False, - ) - - # Mark state as used - state_record.mark_used() - - # Audit log - login success - AuditService.log_external_auth_login( - user_id=user.id, - organization_id=target_org.id, - provider_type=provider_type_str, - provider_user_id=user_info["provider_user_id"], - auth_method_id=auth_method.id, - session_id=session.id, - ) - - logger.info( - f"OAuth login successful for user={user.id}, " - f"provider={provider_type_str}, org_id={target_org.id}" - ) - - # Build session dict with token (to_dict() excludes token for security) - session_dict = session.to_dict() - session_dict["token"] = session.token - # Calculate expires_in handling naive datetime from database - expires_at = session.expires_at - if expires_at.tzinfo is None: - expires_at = expires_at.replace(tzinfo=timezone.utc) - now = datetime.now(timezone.utc) - session_dict["expires_in"] = int((expires_at - now).total_seconds()) - - return { - "success": True, - "flow_type": "login", - "user": { - "id": user.id, - "email": user.email, - "full_name": user.full_name, - "organization_id": target_org.id, - }, - "session": session_dict, - } - - except ExternalAuthError as e: - logger.warning( - f"OAuth login failed for state={state_record.id}, " - f"provider={provider_type_str}, error={e.message}" - ) - raise - except OAuthFlowError: - # Re-raise OAuthFlowError as-is - raise - except Exception as e: - logger.error( - f"Unexpected error in OAuth login callback: {str(e)}", - exc_info=True - ) - raise OAuthFlowError( - "An unexpected error occurred during login", - "INTERNAL_ERROR", - 500, - ) - - @classmethod - def _handle_link_callback( - cls, - provider_type: AuthMethodType, - state_record: OAuthState, - authorization_code: str, - redirect_uri: str, - ) -> dict: - """Handle account linking flow callback.""" - provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type - - try: - # Complete link flow - auth_method = ExternalAuthService.complete_link_flow( - provider_type=provider_type, - authorization_code=authorization_code, - state=state_record.state, - redirect_uri=redirect_uri, - ) - - logger.info( - f"OAuth link successful for user={state_record.user_id}, " - f"provider={provider_type_str}, auth_method_id={auth_method.id}" - ) - - return { - "success": True, - "flow_type": "link", - "linked_account": { - "id": auth_method.id, - "provider_type": provider_type_str, - "provider_user_id": auth_method.provider_user_id, - "verified": auth_method.verified, - }, - } - - except ExternalAuthError as e: - logger.warning( - f"OAuth link failed for state={state_record.id}, " - f"provider={provider_type_str}, error={e.message}" - ) - raise - - @classmethod - def _handle_register_callback( - cls, - provider_type: AuthMethodType, - state_record: OAuthState, - authorization_code: str, - redirect_uri: str, - ) -> dict: - """ - Handle registration flow callback. - - Creates a new user account and prompts for organization creation/selection. - """ - provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type - - try: - # Get provider config (application-wide) - config = ExternalAuthService.get_provider_config( - provider_type, state_record.organization_id - ) - - # Exchange code for tokens - tokens = ExternalAuthService._exchange_code( - config=config, - code=authorization_code, - redirect_uri=redirect_uri, - code_verifier=state_record.code_verifier, - ) - - # Get user info - user_info = ExternalAuthService._get_user_info( - config=config, - access_token=tokens["access_token"], - ) - - # Check if user already exists by email - existing_user = User.query.filter_by( - email=user_info["email"] - ).first() - - if existing_user: - # User exists - suggest linking - raise OAuthFlowError( - f"An account with email {user_info['email']} already exists. " - "Please log in with your password and link your Google account from settings.", - "EMAIL_EXISTS", - 400, - ) - - # Create new user - user = User( - email=user_info["email"], - full_name=user_info.get("name", ""), - status="active", - email_verified=user_info.get("email_verified", False), - ) - user.save() - - # Create authentication method - auth_method = AuthenticationMethod( - user_id=user.id, - method_type=provider_type, - provider_user_id=user_info["provider_user_id"], - provider_data=ExternalAuthService._encrypt_provider_data(tokens, user_info), - verified=user_info.get("email_verified", False), - is_primary=True, - last_used_at=datetime.utcnow(), - ) - auth_method.save() - - # Mark state as used - state_record.mark_used() - - # Audit log - registration success - AuditService.log_action( - action="user.register", - user_id=user.id, - organization_id=state_record.organization_id, - resource_type="user", - resource_id=user.id, - metadata={ - "provider_type": provider_type_str, - "provider_user_id": user_info["provider_user_id"], - "auth_method_id": auth_method.id, - }, - description=f"User registered via {provider_type_str}", - success=True, - ) - - AuditService.log_external_auth_link_completed( - user_id=user.id, - organization_id=state_record.organization_id, - provider_type=provider_type_str, - provider_user_id=user_info["provider_user_id"], - auth_method_id=auth_method.id, - ) - - logger.info( - f"OAuth registration successful for email={user_info['email']}, " - f"provider={provider_type_str}, user_id={user.id}" - ) - - # If organization_id hint was provided and valid, create session for that org - if state_record.organization_id: - from gatehouse_app.models.organization.organization import Organization - org = Organization.query.get(state_record.organization_id) - if org: - from gatehouse_app.services.auth_service import AuthService - session = AuthService.create_session( - user=user, - is_compliance_only=False, - ) - # Build session dict with token (to_dict() excludes token for security) - session_dict = session.to_dict() - session_dict["token"] = session.token - # Calculate expires_in handling naive datetime from database - expires_at = session.expires_at - if expires_at.tzinfo is None: - expires_at = expires_at.replace(tzinfo=timezone.utc) - now = datetime.now(timezone.utc) - session_dict["expires_in"] = int((expires_at - now).total_seconds()) - return { - "success": True, - "flow_type": "register", - "user": { - "id": user.id, - "email": user.email, - "full_name": user.full_name, - "organization_id": org.id, - }, - "session": session_dict, - } - - # No organization hint or invalid - need to create/select org. - # Still create a session so the frontend can call /organizations - # and /invites after redirecting to /org-setup. - from gatehouse_app.services.auth_service import AuthService as _AS - from gatehouse_app.models.organization.org_invite_token import OrgInviteToken - _session = _AS.create_session(user=user, is_compliance_only=False) - _session_dict = _session.to_dict() - _session_dict["token"] = _session.token - _expires_at = _session.expires_at - if _expires_at.tzinfo is None: - _expires_at = _expires_at.replace(tzinfo=timezone.utc) - _now = datetime.now(timezone.utc) - _session_dict["expires_in"] = int((_expires_at - _now).total_seconds()) - - # Surface pending invitations so the UI can offer "join vs create" - _pending = OrgInviteToken.query.filter( - OrgInviteToken.email == user.email, - OrgInviteToken.accepted_at.is_(None), - OrgInviteToken.expires_at > _now, - OrgInviteToken.deleted_at.is_(None), - ).all() - _pending_list = [ - { - "token": inv.token, - "organization": { - "id": str(inv.organization_id), - "name": inv.organization.name, - }, - "role": inv.role, - "expires_at": inv.expires_at.isoformat(), - } - for inv in _pending - ] - - return { - "success": True, - "flow_type": "register", - "requires_org_creation": True, - "user": { - "id": user.id, - "email": user.email, - "full_name": user.full_name, - }, - "session": _session_dict, - "pending_invites": _pending_list, - "state": state_record.state, - } - - except ExternalAuthError as e: - logger.warning( - f"OAuth registration failed for state={state_record.id}, " - f"provider={provider_type_str}, error={e.message}" - ) - raise - except OAuthFlowError: - # Re-raise OAuthFlowError as-is - raise - except Exception as e: - logger.error( - f"Unexpected error in OAuth registration callback: {str(e)}", - exc_info=True - ) - raise OAuthFlowError( - "An unexpected error occurred during registration", - "INTERNAL_ERROR", - 500, - ) - - @classmethod - def validate_state(cls, state: str) -> Optional[OAuthState]: - """ - Validate and return OAuth state. - - Args: - state: The state parameter to validate - - Returns: - OAuthState if valid, None otherwise - """ - state_record = OAuthState.query.filter_by(state=state).first() - if state_record and state_record.is_valid(): - return state_record - return None - - @classmethod - def cleanup_expired_states(cls): - """Remove expired OAuth states.""" - OAuthState.cleanup_expired() - logger.info("Expired OAuth states cleaned up") - - @classmethod - def generate_authorization_code( - cls, - user_id: str, - client_id: str, - redirect_uri: str, - scope: list = None, - nonce: str = None, - ip_address: str = None, - user_agent: str = None, - lifetime_seconds: int = 600, - ) -> str: - """ - Generate an authorization code for external OAuth applications. - - This method creates a short-lived, single-use authorization code that can be - exchanged for a session token by external applications like oauth2-proxy. - - Args: - user_id: The user ID - client_id: The client ID (e.g., 'oauth2-proxy', 'bookstack') - redirect_uri: The redirect URI - scope: Requested scopes - nonce: OIDC nonce for validation - ip_address: Client IP address - user_agent: Client user agent - lifetime_seconds: Code lifetime in seconds (default 10 minutes) - - Returns: - The authorization code (plain text, not hashed) - """ - # Generate a secure random code - code = secrets.token_urlsafe(32) - code_hash = hashlib.sha256(code.encode()).hexdigest() - - # Create the authorization code record - OIDCAuthCode.create_code( - client_id=client_id, - user_id=user_id, - code_hash=code_hash, - redirect_uri=redirect_uri, - scope=scope, - nonce=nonce, - ip_address=ip_address, - user_agent=user_agent, - lifetime_seconds=lifetime_seconds, - ) - - logger.info( - f"Generated authorization code for user={user_id}, client={client_id}" - ) - - return code - - @classmethod - def exchange_authorization_code( - cls, - code: str, - client_id: str, - redirect_uri: str, - ip_address: str = None, - ) -> dict: - """ - Exchange an authorization code for a session token. - - This method validates and consumes the authorization code, then creates - a session for the user. - - Args: - code: The authorization code - client_id: The client ID - redirect_uri: The redirect URI (must match original request) - ip_address: Client IP address - - Returns: - Dict with session token and user info - """ - # Hash the provided code for lookup - code_hash = hashlib.sha256(code.encode()).hexdigest() - - # Find the authorization code record - auth_code = OIDCAuthCode.query.filter_by( - client_id=client_id, - code_hash=code_hash, - ).first() - - if not auth_code: - raise OAuthFlowError( - "Invalid authorization code", - "INVALID_CODE", - 400, - ) - - # Validate the code - if not auth_code.is_valid(): - if auth_code.is_used: - raise OAuthFlowError( - "Authorization code has already been used", - "CODE_USED", - 400, - ) - else: - raise OAuthFlowError( - "Authorization code has expired", - "CODE_EXPIRED", - 400, - ) - - # Validate redirect URI - if auth_code.redirect_uri != redirect_uri: - raise OAuthFlowError( - "Redirect URI mismatch", - "INVALID_REDIRECT_URI", - 400, - ) - - # Get the user - from gatehouse_app.models import User - user = User.query.get(auth_code.user_id) - if not user: - raise OAuthFlowError( - "User not found", - "USER_NOT_FOUND", - 404, - ) - - # Determine organization - from gatehouse_app.models.organization.organization import Organization - from gatehouse_app.models.organization.organization_member import OrganizationMember - - # Get user's organizations - user_orgs = user.get_organizations() - - # Determine target organization - target_org = None - - # Priority 1: Use organization_id from auth code if available - # Priority 2: If user has exactly one organization, use it - if not target_org and len(user_orgs) == 1: - target_org = user_orgs[0] - - if not target_org: - raise OAuthFlowError( - "User does not have a default organization. Organization selection required.", - "ORG_SELECTION_REQUIRED", - 400, - ) - - # Create session - from gatehouse_app.services.auth_service import AuthService - session = AuthService.create_session( - user=user, - is_compliance_only=False, - ) - - # Mark the code as used - auth_code.mark_as_used() - - # Build session dict - session_dict = session.to_dict() - session_dict["token"] = session.token - expires_at = session.expires_at - if expires_at.tzinfo is None: - expires_at = expires_at.replace(tzinfo=timezone.utc) - now = datetime.now(timezone.utc) - session_dict["expires_in"] = int((expires_at - now).total_seconds()) - - logger.info( - f"Authorization code exchanged for session: user={user.id}, " - f"org_id={target_org.id}, client={client_id}" - ) - - return { - "success": True, - "token": session_dict["token"], - "expires_in": session_dict["expires_in"], - "token_type": "Bearer", - "user": { - "id": user.id, - "email": user.email, - "full_name": user.full_name, - "organization_id": target_org.id, - }, - } - - @classmethod - def create_redirect_response( - cls, - redirect_uri: str, - authorization_code: str, - state: str = None, - ): - """ - Create a redirect response with authorization code. - - Args: - redirect_uri: The redirect URI - authorization_code: The authorization code - state: Optional state parameter - - Returns: - Flask redirect response - """ - from urllib.parse import urlencode, urlparse, urlunparse - - # Parse the redirect URI - parsed = urlparse(redirect_uri) - - # Build query parameters - params = {"code": authorization_code} - if state: - params["state"] = state - - # Reconstruct URL with query parameters - redirect_url = urlunparse(( - parsed.scheme, - parsed.netloc, - parsed.path, - parsed.params, - urlencode(params), - parsed.fragment, - )) - - logger.info( - f"Redirecting to {parsed.scheme}://{parsed.netloc} with authorization code" - ) - - return redirect(redirect_url) diff --git a/gatehouse_app/services/oidc/__init__.py b/gatehouse_app/services/oidc/__init__.py new file mode 100644 index 0000000..7aae242 --- /dev/null +++ b/gatehouse_app/services/oidc/__init__.py @@ -0,0 +1,150 @@ +"""OIDCService — public facade over the oidc sub-package.""" +import logging +from typing import Dict, List, Optional, Tuple + +from gatehouse_app.exceptions.auth_exceptions import InvalidTokenError + +logger = logging.getLogger(__name__) + + +class OIDCError(Exception): + def __init__(self, error: str, error_description: str = None, status_code: int = 400): + self.error = error + self.error_description = error_description + self.status_code = status_code + + +class InvalidClientError(OIDCError): + def __init__(self, error_description: str = "Invalid client"): + super().__init__("invalid_client", error_description, 401) + + +class InvalidGrantError(OIDCError): + def __init__(self, error_description: str = "Invalid grant"): + super().__init__("invalid_grant", error_description, 400) + + +class InvalidRequestError(OIDCError): + def __init__(self, error_description: str = "Invalid request"): + super().__init__("invalid_request", error_description, 400) + + +from gatehouse_app.services.oidc import auth_code as _auth_code +from gatehouse_app.services.oidc import tokens as _tokens +from gatehouse_app.services.oidc import userinfo as _userinfo + + +class OIDCService: + """Main OIDC service handling all OpenID Connect operations.""" + + @staticmethod + def _generate_code() -> str: + import secrets + return secrets.token_urlsafe(32) + + @staticmethod + def _hash_value(value: str) -> str: + import hashlib + return hashlib.sha256(value.encode()).hexdigest() + + @classmethod + def generate_authorization_code( + cls, + client_id: str, + user_id: str, + redirect_uri: str, + scope: list, + state: str, + nonce: str, + code_challenge: str = None, + code_challenge_method: str = None, + ip_address: str = None, + user_agent: str = None, + ) -> str: + return _auth_code.generate_authorization_code( + client_id, user_id, redirect_uri, scope, state, nonce, + code_challenge, code_challenge_method, ip_address, user_agent, + ) + + @classmethod + def validate_authorization_code( + cls, + code: str, + client_id: str, + redirect_uri: str, + code_verifier: str = None, + ip_address: str = None, + user_agent: str = None, + ) -> Tuple[Dict, object]: + return _auth_code.validate_authorization_code( + code, client_id, redirect_uri, code_verifier, ip_address, user_agent + ) + + @classmethod + def _compute_code_challenge(cls, verifier: str, method: str = "S256") -> str: + return _auth_code._compute_code_challenge(verifier, method) + + @classmethod + def generate_tokens( + cls, + client_id: str, + user_id: str, + scope: list, + nonce: str = None, + refresh_token: str = None, + ip_address: str = None, + user_agent: str = None, + auth_time: int = None, + ) -> Dict: + return _tokens.generate_tokens( + client_id, user_id, scope, nonce, refresh_token, ip_address, user_agent, auth_time + ) + + @classmethod + def refresh_access_token( + cls, + refresh_token: str, + client_id: str, + scope: list = None, + ip_address: str = None, + user_agent: str = None, + ) -> Dict: + return _tokens.refresh_access_token(refresh_token, client_id, scope, ip_address, user_agent) + + @classmethod + def validate_access_token(cls, token: str, client_id: str = None) -> Dict: + return _tokens.validate_access_token(token, client_id) + + @classmethod + def revoke_token( + cls, + token: str, + client_id: str, + token_type_hint: str = None, + ip_address: str = None, + user_agent: str = None, + ) -> bool: + return _tokens.revoke_token(token, client_id, token_type_hint, ip_address, user_agent) + + @classmethod + def introspect_token( + cls, + token: str, + client_id: str = None, + ip_address: str = None, + user_agent: str = None, + ) -> Dict: + return _tokens.introspect_token(token, client_id, ip_address, user_agent) + + @classmethod + def get_jwks(cls) -> Dict: + from gatehouse_app.services.oidc_jwks_service import OIDCJWKSService + return OIDCJWKSService().get_jwks() + + @classmethod + def get_userinfo(cls, access_token: str) -> Dict: + return _userinfo.get_userinfo(access_token, cls.validate_access_token) + + @staticmethod + def _get_user_roles(user) -> list: + return _userinfo._get_user_roles(user) diff --git a/gatehouse_app/services/oidc/auth_code.py b/gatehouse_app/services/oidc/auth_code.py new file mode 100644 index 0000000..1009b56 --- /dev/null +++ b/gatehouse_app/services/oidc/auth_code.py @@ -0,0 +1,196 @@ +"""OIDC authorization code generation and validation.""" +import logging +from datetime import datetime, timezone +from typing import Dict, Tuple + +from flask import current_app + +from gatehouse_app.models import User, OIDCAuthCode +from gatehouse_app.exceptions.validation_exceptions import ValidationError, NotFoundError +from gatehouse_app.services.oidc_audit_service import OIDCAuditService + +logger = logging.getLogger(__name__) + + +def _hash_value(value: str) -> str: + import hashlib + return hashlib.sha256(value.encode()).hexdigest() + + +def _compute_code_challenge(verifier: str, method: str = "S256") -> str: + import hashlib + import base64 + if method == "S256": + digest = hashlib.sha256(verifier.encode()).digest() + return base64.urlsafe_b64encode(digest).decode().rstrip("=") + return verifier + + +def generate_authorization_code( + client_id: str, + user_id: str, + redirect_uri: str, + scope: list, + state: str, + nonce: str, + code_challenge: str = None, + code_challenge_method: str = None, + ip_address: str = None, + user_agent: str = None, +) -> str: + import secrets + + from gatehouse_app.models import OIDCClient + + logger.debug("[OIDC SERVICE] generate_authorization_code called") + logger.debug("[OIDC SERVICE] client_id=%s, user_id=%s", client_id, user_id) + + client = OIDCClient.query.filter_by(client_id=client_id).first() + + if current_app.config.get('ENV') == 'development': + logger.debug(f"[OIDC] Generate auth code - Client validation: client_id={client_id}, exists={client is not None}") + + if not client: + raise NotFoundError("Client not found") + + if not client.is_active: + raise ValidationError("Client is not active") + + if not client.is_redirect_uri_allowed(redirect_uri): + raise ValidationError("Invalid redirect_uri") + + allowed_scopes = client.scopes or [] + valid_scopes = [s for s in scope if s in allowed_scopes] + + if not valid_scopes: + raise ValidationError("Invalid scopes") + + code = secrets.token_urlsafe(32) + code_hash = _hash_value(code) + + auth_code = OIDCAuthCode.create_code( + client_id=client.id, + user_id=user_id, + code_hash=code_hash, + redirect_uri=redirect_uri, + scope=valid_scopes, + nonce=nonce, + code_verifier=code_challenge, + ip_address=ip_address, + user_agent=user_agent, + lifetime_seconds=600, + ) + logger.debug("[OIDC SERVICE] Auth code created, expires_at=%s", auth_code.expires_at.isoformat()) + + OIDCAuditService.log_authorization_event( + client_id=client.id, + user_id=user_id, + success=True, + redirect_uri=redirect_uri, + scope=valid_scopes, + ) + + return code + + +def validate_authorization_code( + code: str, + client_id: str, + redirect_uri: str, + code_verifier: str = None, + ip_address: str = None, + user_agent: str = None, +) -> Tuple[Dict, User]: + from gatehouse_app.models import OIDCClient + from gatehouse_app.exceptions.auth_exceptions import InvalidTokenError + + logger.debug("[OIDC SERVICE] validate_authorization_code called, client_id=%s", client_id) + + client = OIDCClient.query.filter_by(client_id=client_id).first() + if not client: + logger.error(f"[OIDC] Validate auth code - Client not found: client_id={client_id}") + from gatehouse_app.services.oidc import InvalidGrantError + raise InvalidGrantError("Invalid client") + + code_hash = _hash_value(code) + auth_code = OIDCAuthCode.query.filter_by( + code_hash=code_hash, + client_id=client.id, + deleted_at=None, + ).first() + + if not auth_code: + OIDCAuditService.log_authorization_event( + client_id=client.id, + success=False, + error_code="invalid_grant", + error_description="Invalid or expired authorization code", + ) + from gatehouse_app.services.oidc import InvalidGrantError + raise InvalidGrantError("Invalid or expired authorization code") + + if auth_code.is_used: + OIDCAuditService.log_authorization_event( + client_id=client.id, + user_id=auth_code.user_id, + success=False, + error_code="invalid_grant", + error_description="Authorization code already used", + ) + from gatehouse_app.services.oidc import InvalidGrantError + raise InvalidGrantError("Authorization code already used") + + expires_at = auth_code.expires_at + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + logger.debug( + "[OIDC SERVICE] Time until expiration (seconds): %s", + (expires_at - datetime.now(timezone.utc)).total_seconds(), + ) + + if auth_code.is_expired(): + OIDCAuditService.log_authorization_event( + client_id=client.id, + user_id=auth_code.user_id, + success=False, + error_code="invalid_grant", + error_description="Authorization code expired", + ) + from gatehouse_app.services.oidc import InvalidGrantError + raise InvalidGrantError("Authorization code expired") + + if auth_code.redirect_uri != redirect_uri: + from gatehouse_app.services.oidc import InvalidGrantError + raise InvalidGrantError("Invalid redirect_uri") + + if client.require_pkce and auth_code.code_verifier: + if not code_verifier: + raise ValidationError("code_verifier is required") + expected_challenge = _compute_code_challenge(code_verifier, "S256") + if expected_challenge != auth_code.code_verifier: + OIDCAuditService.log_authorization_event( + client_id=client.id, + user_id=auth_code.user_id, + success=False, + error_code="invalid_grant", + error_description="Invalid code_verifier", + ) + from gatehouse_app.services.oidc import InvalidGrantError + raise InvalidGrantError("Invalid code_verifier") + + auth_code.mark_as_used() + + user = User.query.get(auth_code.user_id) + if not user: + from gatehouse_app.services.oidc import InvalidGrantError + raise InvalidGrantError("User not found") + + claims = { + "user_id": auth_code.user_id, + "client_id": client_id, + "redirect_uri": redirect_uri, + "scope": auth_code.scope, + "nonce": auth_code.nonce, + } + + return claims, user diff --git a/gatehouse_app/services/oidc/tokens.py b/gatehouse_app/services/oidc/tokens.py new file mode 100644 index 0000000..b540ad5 --- /dev/null +++ b/gatehouse_app/services/oidc/tokens.py @@ -0,0 +1,321 @@ +"""OIDC token generation, refresh, validation, revocation, and introspection.""" +import hashlib +import logging +from datetime import datetime, timedelta, timezone +from typing import Dict, Optional + +from flask import current_app + +from gatehouse_app.models import OIDCClient, OIDCRefreshToken, OIDCTokenMetadata +from gatehouse_app.services.oidc_token_service import OIDCTokenService +from gatehouse_app.services.oidc_audit_service import OIDCAuditService +from gatehouse_app.exceptions.auth_exceptions import InvalidTokenError + +logger = logging.getLogger(__name__) + + +def generate_tokens( + client_id: str, + user_id: str, + scope: list, + nonce: str = None, + refresh_token: str = None, + ip_address: str = None, + user_agent: str = None, + auth_time: int = None, +) -> Dict: + logger.debug("[OIDC SERVICE] generate_tokens called: client_id=%s, user_id=%s", client_id, user_id) + + client = OIDCClient.query.filter_by(client_id=client_id).first() + if not client: + from gatehouse_app.services.oidc import InvalidClientError + raise InvalidClientError() + + access_token_jti = OIDCTokenService._generate_jti() + access_token = OIDCTokenService.create_access_token( + client_id=client_id, + user_id=user_id, + scope=scope, + jti=access_token_jti, + ) + + id_token = OIDCTokenService.create_id_token( + client_id=client_id, + user_id=user_id, + nonce=nonce, + scope=scope, + access_token=access_token, + auth_time=auth_time, + ) + + final_refresh_token = None + if "refresh_token" in (client.grant_types or []): + if refresh_token: + refresh_token_obj = OIDCRefreshToken.query.filter_by( + token_hash=hashlib.sha256(refresh_token.encode()).hexdigest(), + deleted_at=None, + ).first() + if refresh_token_obj and refresh_token_obj.is_valid(): + new_refresh, new_hash = OIDCTokenService.create_refresh_token( + client_id=client_id, + user_id=user_id, + scope=scope, + access_token_id=access_token_jti, + ) + refresh_token_obj.rotate(new_hash) + final_refresh_token = new_refresh + else: + final_refresh_token, refresh_hash = OIDCTokenService.create_refresh_token( + client_id=client_id, + user_id=user_id, + scope=scope, + access_token_id=access_token_jti, + ) + OIDCRefreshToken.create_token( + client_id=client.id, + user_id=user_id, + token_hash=refresh_hash, + scope=scope, + access_token_id=access_token_jti, + ip_address=ip_address, + user_agent=user_agent, + lifetime_seconds=client.refresh_token_lifetime or 2592000, + ) + + access_token_expires_at = datetime.now(timezone.utc) + timedelta( + seconds=client.access_token_lifetime or 3600 + ) + OIDCTokenMetadata.create_metadata( + client_id=client.id, + user_id=user_id, + token_type="access_token", + token_jti=access_token_jti, + expires_at=access_token_expires_at, + ) + + id_token_jti = OIDCTokenService._generate_jti() + id_token_expires_at = datetime.now(timezone.utc) + timedelta( + seconds=client.id_token_lifetime or 3600 + ) + OIDCTokenMetadata.create_metadata( + client_id=client.id, + user_id=user_id, + token_type="id_token", + token_jti=id_token_jti, + expires_at=id_token_expires_at, + ) + + OIDCAuditService.log_token_event( + client_id=client.id, + user_id=user_id, + token_type="access_token", + success=True, + grant_type="authorization_code", + scopes=scope, + ) + + result = { + "access_token": access_token, + "token_type": "Bearer", + "expires_in": client.access_token_lifetime or 3600, + "id_token": id_token, + } + if final_refresh_token: + result["refresh_token"] = final_refresh_token + + return result + + +def refresh_access_token( + refresh_token: str, + client_id: str, + scope: list = None, + ip_address: str = None, + user_agent: str = None, +) -> Dict: + logger.debug("[OIDC SERVICE] refresh_access_token called, client_id=%s", client_id) + + client = OIDCClient.query.filter_by(client_id=client_id).first() + if not client: + from gatehouse_app.services.oidc import InvalidClientError + raise InvalidClientError() + + token_hash = hashlib.sha256(refresh_token.encode()).hexdigest() + refresh_token_obj = OIDCRefreshToken.query.filter_by( + token_hash=token_hash, + deleted_at=None, + ).first() + + if not refresh_token_obj: + OIDCAuditService.log_token_event( + client_id=client.id, + success=False, + error_code="invalid_grant", + error_description="Invalid refresh token", + ) + from gatehouse_app.services.oidc import InvalidGrantError + raise InvalidGrantError("Invalid refresh token") + + if not refresh_token_obj.is_valid(): + OIDCAuditService.log_token_event( + client_id=client.id, + user_id=refresh_token_obj.user_id, + success=False, + error_code="invalid_grant", + error_description="Refresh token expired or revoked", + ) + from gatehouse_app.services.oidc import InvalidGrantError + raise InvalidGrantError("Refresh token expired or revoked") + + if refresh_token_obj.client_id != client.id: + from gatehouse_app.services.oidc import InvalidGrantError + raise InvalidGrantError("Client mismatch") + + granted_scope = scope or (refresh_token_obj.scope or []) + + access_token_jti = OIDCTokenService._generate_jti() + access_token = OIDCTokenService.create_access_token( + client_id=client_id, + user_id=refresh_token_obj.user_id, + scope=granted_scope, + jti=access_token_jti, + ) + + id_token = OIDCTokenService.create_id_token( + client_id=client_id, + user_id=refresh_token_obj.user_id, + scope=granted_scope, + access_token=access_token, + ) + + new_refresh, new_hash = OIDCTokenService.create_refresh_token( + client_id=client_id, + user_id=refresh_token_obj.user_id, + scope=granted_scope, + access_token_id=access_token_jti, + ) + refresh_token_obj.rotate(new_hash) + + access_token_expires_at = datetime.now(timezone.utc) + timedelta( + seconds=client.access_token_lifetime or 3600 + ) + OIDCTokenMetadata.create_metadata( + client_id=client.id, + user_id=refresh_token_obj.user_id, + token_type="access_token", + token_jti=access_token_jti, + expires_at=access_token_expires_at, + ) + + OIDCAuditService.log_token_event( + client_id=client.id, + user_id=refresh_token_obj.user_id, + token_type="access_token", + success=True, + grant_type="refresh_token", + scopes=granted_scope, + ) + + return { + "access_token": access_token, + "token_type": "Bearer", + "expires_in": client.access_token_lifetime or 3600, + "id_token": id_token, + "refresh_token": new_refresh, + } + + +def validate_access_token(token: str, client_id: str = None) -> Dict: + logger.debug("[OIDC SERVICE] validate_access_token() called") + + try: + claims = OIDCTokenService.validate_access_token(token, client_id) + logger.debug("[OIDC SERVICE] Token validation successful") + return claims + except Exception as e: + logger.error("[OIDC SERVICE] Token validation failed: %s: %s", type(e).__name__, str(e)) + _client_db_id = None + if client_id: + _c = OIDCClient.query.filter_by(client_id=client_id).first() + _client_db_id = _c.id if _c else None + OIDCAuditService.log_event( + event_type="token_validation", + client_id=_client_db_id, + success=False, + error_code="invalid_token", + error_description=str(e), + ) + raise InvalidTokenError(str(e)) + + +def revoke_token( + token: str, + client_id: str, + token_type_hint: str = None, + ip_address: str = None, + user_agent: str = None, +) -> bool: + client = OIDCClient.query.filter_by(client_id=client_id).first() + if not client: + from gatehouse_app.services.oidc import InvalidClientError + raise InvalidClientError() + + revoked = False + token_hash = hashlib.sha256(token.encode()).hexdigest() + + if token_type_hint in (None, "refresh_token"): + refresh_token_obj = OIDCRefreshToken.query.filter_by( + token_hash=token_hash, + deleted_at=None, + ).first() + if refresh_token_obj: + refresh_token_obj.revoke(reason="revoked_by_client") + revoked = True + OIDCAuditService.log_token_revocation_event( + client_id=client.id, + user_id=refresh_token_obj.user_id, + token_type="refresh_token", + reason="revoked_by_client", + ) + + if not revoked or token_type_hint in (None, "access_token"): + try: + claims = OIDCTokenService.decode_token(token) + jti = claims.get("jti") + if jti: + revoked_at = OIDCTokenMetadata.revoke_by_jti(jti, reason="revoked_by_client") + if revoked_at: + revoked = True + OIDCAuditService.log_token_revocation_event( + client_id=client.id, + user_id=claims.get("sub"), + token_type="access_token", + reason="revoked_by_client", + ) + except Exception: + pass + + return revoked + + +def introspect_token( + token: str, + client_id: str = None, + ip_address: str = None, + user_agent: str = None, +) -> Dict: + result = OIDCTokenService.introspect_token(token, client_id) + + _client_db_id = None + if client_id: + _ic = OIDCClient.query.filter_by(client_id=client_id).first() + _client_db_id = _ic.id if _ic else None + OIDCAuditService.log_event( + event_type="token_introspection", + client_id=_client_db_id, + user_id=result.get("sub"), + success=result.get("active", False), + metadata={"active": result.get("active")}, + ) + + return result diff --git a/gatehouse_app/services/oidc/userinfo.py b/gatehouse_app/services/oidc/userinfo.py new file mode 100644 index 0000000..2e46705 --- /dev/null +++ b/gatehouse_app/services/oidc/userinfo.py @@ -0,0 +1,65 @@ +"""OIDC userinfo endpoint logic.""" +import logging +from typing import Dict + +from gatehouse_app.models import User +from gatehouse_app.exceptions.validation_exceptions import NotFoundError +from gatehouse_app.services.oidc_audit_service import OIDCAuditService + +logger = logging.getLogger(__name__) + + +def get_userinfo(access_token: str, validate_access_token_fn) -> Dict: + logger.debug("[OIDC SERVICE] get_userinfo() called") + + claims = validate_access_token_fn(access_token) + user_id = claims.get("sub") + + user = User.query.get(user_id) + if not user: + logger.error("[OIDC SERVICE] User not found in database: user_id=%s", user_id) + raise NotFoundError("User not found") + + scope_str = claims.get("scope", "") + scopes = scope_str.split() if scope_str else [] + + userinfo = {"sub": user_id} + + if "profile" in scopes and user.full_name: + userinfo["name"] = user.full_name + + if "email" in scopes: + userinfo["email"] = user.email + userinfo["email_verified"] = user.email_verified + + if "roles" in scopes: + userinfo["roles"] = _get_user_roles(user) + + _userinfo_client_id_str = claims.get("client_id") + _userinfo_client_db_id = None + if _userinfo_client_id_str: + from gatehouse_app.models import OIDCClient + _uc = OIDCClient.query.filter_by(client_id=_userinfo_client_id_str).first() + _userinfo_client_db_id = _uc.id if _uc else None + + OIDCAuditService.log_userinfo_event( + access_token=access_token, + user_id=user_id, + client_id=_userinfo_client_db_id, + success=True, + scopes_claimed=scopes, + ) + + return userinfo + + +def _get_user_roles(user: User) -> list: + roles = [] + if not user or not user.organization_memberships: + return roles + for member in user.organization_memberships: + roles.append({ + "organization_id": str(member.organization_id), + "role": member.role.value, + }) + return roles diff --git a/gatehouse_app/services/oidc_service.py b/gatehouse_app/services/oidc_service.py deleted file mode 100644 index fc6e317..0000000 --- a/gatehouse_app/services/oidc_service.py +++ /dev/null @@ -1,1025 +0,0 @@ -"""OIDC Service - Main OIDC service layer.""" -import logging -import secrets -import hashlib -from datetime import datetime, timedelta, timezone -from typing import Dict, List, Optional, Tuple - -from flask import current_app, g - -logger = logging.getLogger(__name__) - -from gatehouse_app.extensions import db -from gatehouse_app.models import ( - User, OIDCClient, OIDCAuthCode, OIDCRefreshToken, - OIDCSession, OIDCTokenMetadata -) -from gatehouse_app.models.organization.organization_member import OrganizationMember -from gatehouse_app.exceptions.validation_exceptions import ( - ValidationError, NotFoundError, BadRequestError -) -from gatehouse_app.exceptions.auth_exceptions import UnauthorizedError, InvalidTokenError -from gatehouse_app.services.oidc_token_service import OIDCTokenService -from gatehouse_app.services.oidc_session_service import OIDCSessionService -from gatehouse_app.services.oidc_audit_service import OIDCAuditService -from gatehouse_app.services.oidc_jwks_service import OIDCJWKSService - - -class OIDCError(Exception): - """Base exception for OIDC errors.""" - - def __init__(self, error: str, error_description: str = None, status_code: int = 400): - self.error = error - self.error_description = error_description - self.status_code = status_code - - -class InvalidClientError(OIDCError): - """Raised when client authentication fails.""" - - def __init__(self, error_description: str = "Invalid client"): - super().__init__("invalid_client", error_description, 401) - - -class InvalidGrantError(OIDCError): - """Raised when grant is invalid.""" - - def __init__(self, error_description: str = "Invalid grant"): - super().__init__("invalid_grant", error_description, 400) - - -class InvalidRequestError(OIDCError): - """Raised when request is malformed.""" - - def __init__(self, error_description: str = "Invalid request"): - super().__init__("invalid_request", error_description, 400) - - -class OIDCService: - """Main OIDC service handling all OpenID Connect operations. - - This service provides: - - Authorization code generation and validation - - Token generation (access, refresh, ID tokens) - - Token refresh with rotation - - Token validation and introspection - - Token revocation - """ - - @staticmethod - def _generate_code() -> str: - """Generate a secure authorization code. - - Returns: - URL-safe base64 encoded code - """ - return secrets.token_urlsafe(32) - - @staticmethod - def _hash_value(value: str) -> str: - """Hash a value for secure storage. - - Args: - value: Value to hash - - Returns: - SHA256 hash - """ - return hashlib.sha256(value.encode()).hexdigest() - - @classmethod - def generate_authorization_code( - cls, - client_id: str, - user_id: str, - redirect_uri: str, - scope: list, - state: str, - nonce: str, - code_challenge: str = None, - code_challenge_method: str = None, - ip_address: str = None, - user_agent: str = None - ) -> str: - """Generate an authorization code for the auth code flow. - - Args: - client_id: OIDC client ID - user_id: User ID - redirect_uri: Redirect URI - scope: Requested scopes - state: State parameter - nonce: Nonce for ID token - code_challenge: PKCE code challenge - code_challenge_method: PKCE method ("S256" or "plain") - ip_address: Client IP address - user_agent: Client user agent - - Returns: - Authorization code string - - Raises: - ValidationError: If parameters are invalid - NotFoundError: If client not found - """ - logger.debug("[OIDC SERVICE] ===========================================") - logger.debug("[OIDC SERVICE] generate_authorization_code called") - logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z") - logger.debug("[OIDC SERVICE] client_id=%s, user_id=%s", client_id, user_id) - logger.debug("[OIDC SERVICE] redirect_uri=%s", redirect_uri) - logger.debug("[OIDC SERVICE] scope=%s", scope) - logger.debug("[OIDC SERVICE] state=%s, nonce=%s", state, nonce) - - # Validate client exists and is active - client = OIDCClient.query.filter_by(client_id=client_id).first() - - # Development-only debug logging for client validation - if current_app.config.get('ENV') == 'development': - logger.debug(f"[OIDC] Generate auth code - Client validation: client_id={client_id}, exists={client is not None}") - - if not client: - raise NotFoundError("Client not found") - - if current_app.config.get('ENV') == 'development': - logger.debug(f"[OIDC] Generate auth code - Client active validation: client_id={client_id}, is_active={client.is_active}") - - if not client.is_active: - raise ValidationError("Client is not active") - - # Validate redirect URI - if current_app.config.get('ENV') == 'development': - logger.debug(f"[OIDC] Generate auth code - Redirect URI validation: client_id={client_id}, redirect_uri={redirect_uri}") - - if not client.is_redirect_uri_allowed(redirect_uri): - raise ValidationError("Invalid redirect_uri") - - # Validate scopes - allowed_scopes = client.scopes or [] - valid_scopes = [s for s in scope if s in allowed_scopes] - - if not valid_scopes: - raise ValidationError("Invalid scopes") - - # Generate authorization code - logger.debug("[OIDC SERVICE] Generating authorization code...") - logger.debug("[OIDC SERVICE] Current UTC time before code generation: %s", datetime.now(timezone.utc).isoformat() + "Z") - code = cls._generate_code() - code_hash = cls._hash_value(code) - logger.debug("[OIDC SERVICE] Code generated: %s...", code[:20] if code else None) - - # Development-only debug logging for PKCE in code creation - if current_app.config.get('ENV') == 'development': - logger.debug(f"[OIDC] Generate auth code - PKCE: code_challenge={code_challenge is not None}, code_challenge_method={code_challenge_method}") - - # Create auth code record - logger.debug("[OIDC SERVICE] Creating auth code record with lifetime_seconds=600 (10 minutes)") - logger.debug("[OIDC SERVICE] Current UTC time before creating auth code: %s", datetime.now(timezone.utc).isoformat() + "Z") - auth_code = OIDCAuthCode.create_code( - client_id=client.id, - user_id=user_id, - code_hash=code_hash, - redirect_uri=redirect_uri, - scope=valid_scopes, - nonce=nonce, - code_verifier=code_challenge, # Store for validation - ip_address=ip_address, - user_agent=user_agent, - lifetime_seconds=600, # 10 minutes - ) - logger.debug("[OIDC SERVICE] Auth code created successfully") - logger.debug("[OIDC SERVICE] Auth code expires_at (UTC): %s", auth_code.expires_at.isoformat() + "Z") - logger.debug("[OIDC SERVICE] Current UTC time after creating auth code: %s", datetime.now(timezone.utc).isoformat() + "Z") - - # Log authorization event — use client.id (UUID) not client_id (string) for FK - OIDCAuditService.log_authorization_event( - client_id=client.id, - user_id=user_id, - success=True, - redirect_uri=redirect_uri, - scope=valid_scopes, - ) - - logger.debug("[OIDC SERVICE] generate_authorization_code completed successfully") - logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z") - logger.debug("[OIDC SERVICE] ===========================================") - return code - - @classmethod - def validate_authorization_code( - cls, - code: str, - client_id: str, - redirect_uri: str, - code_verifier: str = None, - ip_address: str = None, - user_agent: str = None - ) -> Tuple[Dict, User]: - """Validate and consume an authorization code. - - Args: - code: Authorization code - client_id: OIDC client ID - redirect_uri: Redirect URI - code_verifier: PKCE code verifier (required if PKCE was used) - ip_address: Client IP address - user_agent: Client user agent - - Returns: - Tuple of (claims dict, User instance) - - Raises: - InvalidGrantError: If code is invalid - ValidationError: If PKCE validation fails - """ - logger.debug("[OIDC SERVICE] ===========================================") - logger.debug("[OIDC SERVICE] validate_authorization_code called") - logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z") - logger.debug("[OIDC SERVICE] client_id=%s, redirect_uri=%s", client_id, redirect_uri) - logger.debug("[OIDC SERVICE] code_verifier provided: %s", bool(code_verifier)) - - # Get client - client = OIDCClient.query.filter_by(client_id=client_id).first() - - # Development-only debug logging for client validation in code validation - if current_app.config.get('ENV') == 'development': - logger.debug(f"[OIDC] Validate auth code - Client validation: client_id={client_id}, exists={client is not None}") - - if not client: - logger.error(f"[OIDC] Validate auth code - Client not found: client_id={client_id}") - raise InvalidGrantError("Invalid client") - - # Hash the provided code and find matching auth code - logger.debug("[OIDC SERVICE] Looking up authorization code...") - logger.debug("[OIDC SERVICE] Current UTC time before code lookup: %s", datetime.now(timezone.utc).isoformat() + "Z") - code_hash = cls._hash_value(code) - auth_code = OIDCAuthCode.query.filter_by( - code_hash=code_hash, - client_id=client.id, - deleted_at=None - ).first() - - if current_app.config.get('ENV') == 'development': - logger.debug(f"[OIDC] Validate auth code - Code lookup: code_hash={code_hash[:20]}..., found={auth_code is not None}") - - if not auth_code: - logger.error(f"[OIDC] Validate auth code - Code not found or deleted: code_hash={code_hash[:20]}...") - OIDCAuditService.log_authorization_event( - client_id=client.id, - success=False, - error_code="invalid_grant", - error_description="Invalid or expired authorization code", - ) - raise InvalidGrantError("Invalid or expired authorization code") - - # Check if already used - if auth_code.is_used: - logger.error(f"[OIDC] Validate auth code - Code already used: code_hash={code_hash[:20]}..., user_id={auth_code.user_id}") - OIDCAuditService.log_authorization_event( - client_id=client.id, - user_id=auth_code.user_id, - success=False, - error_code="invalid_grant", - error_description="Authorization code already used", - ) - raise InvalidGrantError("Authorization code already used") - - # Check expiration - logger.debug("[OIDC SERVICE] Checking if authorization code is expired...") - logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z") - logger.debug("[OIDC SERVICE] Auth code expires_at (UTC): %s", auth_code.expires_at.isoformat() + "Z") - # Handle timezone-naive expires_at from database - expires_at = auth_code.expires_at - if expires_at.tzinfo is None: - expires_at = expires_at.replace(tzinfo=timezone.utc) - logger.debug("[OIDC SERVICE] Time until expiration (seconds): %s", (expires_at - datetime.now(timezone.utc)).total_seconds()) - - if auth_code.is_expired(): - logger.error("[OIDC] Validate auth code - Code expired: code_hash=%s..., expires_at (UTC)=%s, current UTC time=%s", - code_hash[:20], auth_code.expires_at.isoformat() + "Z", datetime.now(timezone.utc).isoformat() + "Z") - OIDCAuditService.log_authorization_event( - client_id=client.id, - user_id=auth_code.user_id, - success=False, - error_code="invalid_grant", - error_description="Authorization code expired", - ) - raise InvalidGrantError("Authorization code expired") - - # Validate redirect URI - if auth_code.redirect_uri != redirect_uri: - logger.error(f"[OIDC] Validate auth code - Redirect URI mismatch: expected={auth_code.redirect_uri}, got={redirect_uri}") - raise InvalidGrantError("Invalid redirect_uri") - - # Validate PKCE if required - if current_app.config.get('ENV') == 'development': - logger.debug(f"[OIDC] Validate auth code - PKCE: require_pkce={client.require_pkce}, has_verifier={bool(auth_code.code_verifier)}, provided_verifier={bool(code_verifier)}") - - if client.require_pkce and auth_code.code_verifier: - if not code_verifier: - logger.error(f"[OIDC] Validate auth code - PKCE required but no code_verifier provided") - raise ValidationError("code_verifier is required") - - # Verify code verifier - expected_challenge = cls._compute_code_challenge(code_verifier, "S256") - if expected_challenge != auth_code.code_verifier: - logger.error(f"[OIDC] Validate auth code - Invalid code_verifier: expected={expected_challenge[:20]}..., got={auth_code.code_verifier[:20]}...") - OIDCAuditService.log_authorization_event( - client_id=client.id, - user_id=auth_code.user_id, - success=False, - error_code="invalid_grant", - error_description="Invalid code_verifier", - ) - raise InvalidGrantError("Invalid code_verifier") - - # Mark code as used - auth_code.mark_as_used() - - # Get user - user = User.query.get(auth_code.user_id) - - # Development-only debug logging for user validation - if current_app.config.get('ENV') == 'development': - logger.debug(f"[OIDC] Validate auth code - User validation: user_id={auth_code.user_id}, exists={user is not None}") - - if not user: - logger.error(f"[OIDC] Validate auth code - User not found: user_id={auth_code.user_id}") - raise InvalidGrantError("User not found") - - claims = { - "user_id": auth_code.user_id, - "client_id": client_id, - "redirect_uri": redirect_uri, - "scope": auth_code.scope, - "nonce": auth_code.nonce, - } - - logger.debug("[OIDC SERVICE] validate_authorization_code completed successfully") - logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z") - logger.debug("[OIDC SERVICE] ===========================================") - return claims, user - - @classmethod - def _compute_code_challenge(cls, verifier: str, method: str = "S256") -> str: - """Compute PKCE code challenge from verifier. - - Args: - verifier: Code verifier - method: Challenge method - - Returns: - Code challenge - """ - import hashlib - import base64 - - if method == "S256": - digest = hashlib.sha256(verifier.encode()).digest() - return base64.urlsafe_b64encode(digest).decode().rstrip("=") - return verifier - - @classmethod - def generate_tokens( - cls, - client_id: str, - user_id: str, - scope: list, - nonce: str = None, - refresh_token: str = None, - ip_address: str = None, - user_agent: str = None, - auth_time: int = None - ) -> Dict: - """Generate access token, ID token, and refresh token. - - Args: - client_id: OIDC client ID - user_id: User ID - scope: Granted scopes - nonce: Nonce for ID token - refresh_token: Existing refresh token (for rotation) - ip_address: Client IP address - user_agent: Client user agent - auth_time: Authentication time - - Returns: - Dictionary with tokens - """ - import hashlib - - logger.debug("[OIDC SERVICE] ===========================================") - logger.debug("[OIDC SERVICE] generate_tokens called") - logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z") - logger.debug("[OIDC SERVICE] client_id=%s, user_id=%s, scope=%s", client_id, user_id, scope) - logger.debug("[OIDC SERVICE] nonce=%s, auth_time=%s", nonce, auth_time) - - # Get client - client = OIDCClient.query.filter_by(client_id=client_id).first() - - # Development-only debug logging for token generation client validation - if current_app.config.get('ENV') == 'development': - logger.debug(f"[OIDC] Generate tokens - Client validation: client_id={client_id}, exists={client is not None}") - - if not client: - raise InvalidClientError() - - # Generate access token - logger.debug("[OIDC SERVICE] Generating access token...") - logger.debug("[OIDC SERVICE] Current UTC time before access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z") - logger.debug("[OIDC SERVICE] Access token lifetime (seconds): %s", client.access_token_lifetime or 3600) - access_token_jti = OIDCTokenService._generate_jti() - access_token = OIDCTokenService.create_access_token( - client_id=client_id, - user_id=user_id, - scope=scope, - jti=access_token_jti, - ) - logger.debug("[OIDC SERVICE] Access token generated successfully") - logger.debug("[OIDC SERVICE] Current UTC time after access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z") - - # Generate ID token - logger.debug("[OIDC SERVICE] Generating ID token...") - logger.debug("[OIDC SERVICE] Current UTC time before ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z") - logger.debug("[OIDC SERVICE] ID token lifetime (seconds): %s", client.id_token_lifetime or 3600) - id_token = OIDCTokenService.create_id_token( - client_id=client_id, - user_id=user_id, - nonce=nonce, - scope=scope, - access_token=access_token, - auth_time=auth_time, - ) - logger.debug("[OIDC SERVICE] ID token generated successfully") - logger.debug("[OIDC SERVICE] Current UTC time after ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z") - - # Generate or rotate refresh token - if "refresh_token" in (client.grant_types or []): - if refresh_token: - # Rotate existing refresh token - refresh_token_obj = OIDCRefreshToken.query.filter_by( - token_hash=hashlib.sha256(refresh_token.encode()).hexdigest(), - deleted_at=None - ).first() - - if refresh_token_obj and refresh_token_obj.is_valid(): - # Create new refresh token - new_refresh, new_hash = OIDCTokenService.create_refresh_token( - client_id=client_id, - user_id=user_id, - scope=scope, - access_token_id=access_token_jti, - ) - - # Rotate in database - refresh_token_obj.rotate(new_hash) - final_refresh_token = new_refresh - else: - final_refresh_token = None - else: - # Create new refresh token - final_refresh_token, refresh_hash = OIDCTokenService.create_refresh_token( - client_id=client_id, - user_id=user_id, - scope=scope, - access_token_id=access_token_jti, - ) - - # Store refresh token - OIDCRefreshToken.create_token( - client_id=client.id, - user_id=user_id, - token_hash=refresh_hash, - scope=scope, - access_token_id=access_token_jti, - ip_address=ip_address, - user_agent=user_agent, - lifetime_seconds=client.refresh_token_lifetime or 2592000, - ) - else: - final_refresh_token = None - - # Store token metadata - client_db_id = client.id - - # Access token metadata - logger.debug("[OIDC SERVICE] Creating access token metadata...") - access_token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=client.access_token_lifetime or 3600) - logger.debug("[OIDC SERVICE] Access token expires_at (UTC): %s", access_token_expires_at.isoformat() + "Z") - OIDCTokenMetadata.create_metadata( - client_id=client_db_id, - user_id=user_id, - token_type="access_token", - token_jti=access_token_jti, - expires_at=access_token_expires_at, - ) - - # ID token metadata (using access token JTI as reference) - logger.debug("[OIDC SERVICE] Creating ID token metadata...") - id_token_jti = OIDCTokenService._generate_jti() - id_token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=client.id_token_lifetime or 3600) - logger.debug("[OIDC SERVICE] ID token expires_at (UTC): %s", id_token_expires_at.isoformat() + "Z") - OIDCTokenMetadata.create_metadata( - client_id=client_db_id, - user_id=user_id, - token_type="id_token", - token_jti=id_token_jti, - expires_at=id_token_expires_at, - ) - - # Log token event — use client.id (UUID) not client_id (string) for FK - OIDCAuditService.log_token_event( - client_id=client.id, - user_id=user_id, - token_type="access_token", - success=True, - grant_type="authorization_code", - scopes=scope, - ) - - result = { - "access_token": access_token, - "token_type": "Bearer", - "expires_in": client.access_token_lifetime or 3600, - "id_token": id_token, - } - - if final_refresh_token: - result["refresh_token"] = final_refresh_token - - logger.debug("[OIDC SERVICE] generate_tokens completed successfully") - logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z") - logger.debug("[OIDC SERVICE] ===========================================") - return result - - @classmethod - def refresh_access_token( - cls, - refresh_token: str, - client_id: str, - scope: list = None, - ip_address: str = None, - user_agent: str = None - ) -> Dict: - """Refresh an access token with token rotation. - - Args: - refresh_token: The refresh token - client_id: OIDC client ID - scope: Optional scope override - ip_address: Client IP address - user_agent: Client user agent - - Returns: - Dictionary with new tokens - - Raises: - InvalidGrantError: If refresh token is invalid - """ - import hashlib - - logger.debug("[OIDC SERVICE] ===========================================") - logger.debug("[OIDC SERVICE] refresh_access_token called") - logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z") - logger.debug("[OIDC SERVICE] client_id=%s, scope=%s", client_id, scope) - - # Get client - client = OIDCClient.query.filter_by(client_id=client_id).first() - - # Development-only debug logging for refresh token client validation - if current_app.config.get('ENV') == 'development': - logger.debug(f"[OIDC] Refresh token - Client validation: client_id={client_id}, exists={client is not None}") - - if not client: - raise InvalidClientError() - - # Find refresh token - logger.debug("[OIDC SERVICE] Looking up refresh token...") - logger.debug("[OIDC SERVICE] Current UTC time before refresh token lookup: %s", datetime.now(timezone.utc).isoformat() + "Z") - token_hash = hashlib.sha256(refresh_token.encode()).hexdigest() - refresh_token_obj = OIDCRefreshToken.query.filter_by( - token_hash=token_hash, - deleted_at=None - ).first() - - # Development-only debug logging for refresh token validation - if current_app.config.get('ENV') == 'development': - logger.debug(f"[OIDC] Refresh token - Token validation: user_id={refresh_token_obj.user_id if refresh_token_obj else None}, found={refresh_token_obj is not None}") - - if not refresh_token_obj: - OIDCAuditService.log_token_event( - client_id=client.id, - success=False, - error_code="invalid_grant", - error_description="Invalid refresh token", - ) - raise InvalidGrantError("Invalid refresh token") - - # Check if valid - logger.debug("[OIDC SERVICE] Checking if refresh token is valid...") - logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z") - if refresh_token_obj: - logger.debug("[OIDC SERVICE] Refresh token expires_at (UTC): %s", refresh_token_obj.expires_at.isoformat() + "Z") - # Handle timezone-naive expires_at from database - rt_expires_at = refresh_token_obj.expires_at - if rt_expires_at.tzinfo is None: - rt_expires_at = rt_expires_at.replace(tzinfo=timezone.utc) - logger.debug("[OIDC SERVICE] Time until expiration (seconds): %s", (rt_expires_at - datetime.now(timezone.utc)).total_seconds()) - - if not refresh_token_obj.is_valid(): - OIDCAuditService.log_token_event( - client_id=client.id, - user_id=refresh_token_obj.user_id, - success=False, - error_code="invalid_grant", - error_description="Refresh token expired or revoked", - ) - raise InvalidGrantError("Refresh token expired or revoked") - - # Validate client matches - if current_app.config.get('ENV') == 'development': - logger.debug(f"[OIDC] Refresh token - Client match validation: expected={client.id}, actual={refresh_token_obj.client_id}, match={refresh_token_obj.client_id == client.id}") - - if refresh_token_obj.client_id != client.id: - raise InvalidGrantError("Client mismatch") - - # Get original scope or use provided - granted_scope = scope or (refresh_token_obj.scope or []) - - # Generate new access token - logger.debug("[OIDC SERVICE] Generating new access token...") - logger.debug("[OIDC SERVICE] Current UTC time before access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z") - logger.debug("[OIDC SERVICE] Access token lifetime (seconds): %s", client.access_token_lifetime or 3600) - access_token_jti = OIDCTokenService._generate_jti() - access_token = OIDCTokenService.create_access_token( - client_id=client_id, - user_id=refresh_token_obj.user_id, - scope=granted_scope, - jti=access_token_jti, - ) - logger.debug("[OIDC SERVICE] Access token generated successfully") - logger.debug("[OIDC SERVICE] Current UTC time after access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z") - - # Generate new ID token - logger.debug("[OIDC SERVICE] Generating new ID token...") - logger.debug("[OIDC SERVICE] Current UTC time before ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z") - logger.debug("[OIDC SERVICE] ID token lifetime (seconds): %s", client.id_token_lifetime or 3600) - id_token = OIDCTokenService.create_id_token( - client_id=client_id, - user_id=refresh_token_obj.user_id, - scope=granted_scope, - access_token=access_token, - ) - logger.debug("[OIDC SERVICE] ID token generated successfully") - logger.debug("[OIDC SERVICE] Current UTC time after ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z") - - # Rotate refresh token - new_refresh, new_hash = OIDCTokenService.create_refresh_token( - client_id=client_id, - user_id=refresh_token_obj.user_id, - scope=granted_scope, - access_token_id=access_token_jti, - ) - - refresh_token_obj.rotate(new_hash) - - # Store new token metadata - logger.debug("[OIDC SERVICE] Creating access token metadata...") - access_token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=client.access_token_lifetime or 3600) - logger.debug("[OIDC SERVICE] Access token expires_at (UTC): %s", access_token_expires_at.isoformat() + "Z") - OIDCTokenMetadata.create_metadata( - client_id=client.id, - user_id=refresh_token_obj.user_id, - token_type="access_token", - token_jti=access_token_jti, - expires_at=access_token_expires_at, - ) - - # Log refresh event — use client.id (UUID) not client_id (string) for FK - OIDCAuditService.log_token_event( - client_id=client.id, - user_id=refresh_token_obj.user_id, - token_type="access_token", - success=True, - grant_type="refresh_token", - scopes=granted_scope, - ) - - return { - "access_token": access_token, - "token_type": "Bearer", - "expires_in": client.access_token_lifetime or 3600, - "id_token": id_token, - "refresh_token": new_refresh, - } - - logger.debug("[OIDC SERVICE] refresh_access_token completed successfully") - logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z") - logger.debug("[OIDC SERVICE] ===========================================") - return { - "access_token": access_token, - "token_type": "Bearer", - "expires_in": client.access_token_lifetime or 3600, - "id_token": id_token, - "refresh_token": new_refresh, - } - - @classmethod - def validate_access_token(cls, token: str, client_id: str = None) -> Dict: - """Validate an access token and return its claims. - - Args: - token: JWT access token - client_id: Optional client ID to validate audience - - Returns: - Token claims - - Raises: - InvalidTokenError: If token is invalid - """ - logger.debug("[OIDC SERVICE] ===========================================") - logger.debug("[OIDC SERVICE] validate_access_token() called") - logger.debug("[OIDC SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token) - logger.debug("[OIDC SERVICE] Token length: %d", len(token)) - logger.debug("[OIDC SERVICE] Client ID: %s", client_id) - - try: - logger.debug("[OIDC SERVICE] Calling OIDCTokenService.validate_access_token()...") - claims = OIDCTokenService.validate_access_token(token, client_id) - logger.debug("[OIDC SERVICE] Token validation successful") - logger.debug("[OIDC SERVICE] Token claims: %s", claims) - logger.debug("[OIDC SERVICE] ===========================================") - return claims - except Exception as e: - logger.error("[OIDC SERVICE] Token validation failed: %s: %s", type(e).__name__, str(e)) - import traceback - logger.error("[OIDC SERVICE] Traceback: %s", traceback.format_exc()) - # Resolve internal client UUID for FK if possible - _client_db_id = None - if client_id: - _c = OIDCClient.query.filter_by(client_id=client_id).first() - _client_db_id = _c.id if _c else None - OIDCAuditService.log_event( - event_type="token_validation", - client_id=_client_db_id, - success=False, - error_code="invalid_token", - error_description=str(e), - ) - raise InvalidTokenError(str(e)) - - @classmethod - def revoke_token( - cls, - token: str, - client_id: str, - token_type_hint: str = None, - ip_address: str = None, - user_agent: str = None - ) -> bool: - """Revoke a token. - - Args: - token: Token to revoke - client_id: OIDC client ID - token_type_hint: Hint about token type - ip_address: Client IP address - user_agent: Client user agent - - Returns: - True if token was revoked - """ - import hashlib - - # Get client - client = OIDCClient.query.filter_by(client_id=client_id).first() - if not client: - raise InvalidClientError() - - revoked = False - token_hash = hashlib.sha256(token.encode()).hexdigest() - - # Try to revoke as refresh token - if token_type_hint in (None, "refresh_token"): - refresh_token = OIDCRefreshToken.query.filter_by( - token_hash=token_hash, - deleted_at=None - ).first() - - if refresh_token: - refresh_token.revoke(reason="revoked_by_client") - revoked = True - - OIDCAuditService.log_token_revocation_event( - client_id=client.id, - user_id=refresh_token.user_id, - token_type="refresh_token", - reason="revoked_by_client", - ) - - # Try to revoke as access token (JTI lookup) - if not revoked or token_type_hint in (None, "access_token"): - try: - # Decode token to get JTI - claims = OIDCTokenService.decode_token(token) - jti = claims.get("jti") - - if jti: - revoked_at = OIDCTokenMetadata.revoke_by_jti( - jti, - reason="revoked_by_client" - ) - if revoked_at: - revoked = True - - OIDCAuditService.log_token_revocation_event( - client_id=client.id, - user_id=claims.get("sub"), - token_type="access_token", - reason="revoked_by_client", - ) - except Exception: - pass - - return revoked - - @classmethod - def introspect_token( - cls, - token: str, - client_id: str = None, - ip_address: str = None, - user_agent: str = None - ) -> Dict: - """Introspect a token and return its status and claims. - - Args: - token: Token to introspect - client_id: Client ID for validation - ip_address: Client IP address - user_agent: Client user agent - - Returns: - Introspection response - """ - result = OIDCTokenService.introspect_token(token, client_id) - - # Log introspection — resolve internal client UUID for FK - _introspect_client_db_id = None - if client_id: - _ic = OIDCClient.query.filter_by(client_id=client_id).first() - _introspect_client_db_id = _ic.id if _ic else None - OIDCAuditService.log_event( - event_type="token_introspection", - client_id=_introspect_client_db_id, - user_id=result.get("sub"), - success=result.get("active", False), - metadata={"active": result.get("active")}, - ) - - return result - - @classmethod - def get_jwks(cls) -> Dict: - """Get the JWKS document. - - Returns: - JWKS document - """ - jwks_service = OIDCJWKSService() - return jwks_service.get_jwks() - - @classmethod - def get_userinfo(cls, access_token: str) -> Dict: - """Get user information using access token. - - Args: - access_token: Access token - - Returns: - User information dictionary - """ - logger.debug("[OIDC SERVICE] ===========================================") - logger.debug("[OIDC SERVICE] get_userinfo() called") - logger.debug("[OIDC SERVICE] Access token (first 50 chars): %s...", access_token[:50] if len(access_token) > 50 else access_token) - logger.debug("[OIDC SERVICE] Access token length: %d", len(access_token)) - - # Validate access token - logger.debug("[OIDC SERVICE] Validating access token...") - claims = cls.validate_access_token(access_token) - logger.debug("[OIDC SERVICE] Access token validated successfully") - logger.debug("[OIDC SERVICE] Token claims: %s", claims) - - user_id = claims.get("sub") - logger.debug("[OIDC SERVICE] User ID from token: %s", user_id) - - logger.debug("[OIDC SERVICE] Querying user from database...") - user = User.query.get(user_id) - logger.debug("[OIDC SERVICE] User query result: %s", user) - - if not user: - logger.error("[OIDC SERVICE] User not found in database: user_id=%s", user_id) - raise NotFoundError("User not found") - - logger.debug("[OIDC SERVICE] User found: user_id=%s, email=%s, full_name=%s", user.id, user.email, user.full_name) - - # Get scopes from token - scope_str = claims.get("scope", "") - scopes = scope_str.split() if scope_str else [] - logger.debug("[OIDC SERVICE] Scope string from token: '%s'", scope_str) - logger.debug("[OIDC SERVICE] Parsed scopes: %s", scopes) - - userinfo = {"sub": user_id} - logger.debug("[OIDC SERVICE] Initial userinfo: %s", userinfo) - - # Add claims based on scope - if "profile" in scopes and user.full_name: - logger.debug("[OIDC SERVICE] Found 'profile' in scope, adding name claim") - userinfo["name"] = user.full_name - logger.debug("[OIDC SERVICE] Added name: %s", user.full_name) - else: - logger.debug("[OIDC SERVICE] 'profile' not in scope or user.full_name is None: profile_in_scope=%s, full_name=%s", "profile" in scopes, user.full_name) - - if "email" in scopes: - logger.debug("[OIDC SERVICE] Found 'email' in scope, adding email claims") - userinfo["email"] = user.email - userinfo["email_verified"] = user.email_verified - logger.debug("[OIDC SERVICE] Added email: %s, email_verified: %s", user.email, user.email_verified) - else: - logger.debug("[OIDC SERVICE] 'email' not in scope") - - if "roles" in scopes: - logger.debug("[OIDC SERVICE] Found 'roles' in scope, adding roles claim") - user_roles = cls._get_user_roles(user) - userinfo["roles"] = user_roles - logger.debug("[OIDC SERVICE] Added roles: %s", user_roles) - else: - logger.debug("[OIDC SERVICE] 'roles' not in scope") - - logger.debug("[OIDC SERVICE] Final userinfo: %s", userinfo) - - # Log userinfo access — resolve internal client UUID for FK - logger.debug("[OIDC SERVICE] Logging userinfo access event...") - _userinfo_client_id_str = claims.get("client_id") - _userinfo_client_db_id = None - if _userinfo_client_id_str: - _uc = OIDCClient.query.filter_by(client_id=_userinfo_client_id_str).first() - _userinfo_client_db_id = _uc.id if _uc else None - OIDCAuditService.log_userinfo_event( - access_token=access_token, - user_id=user_id, - client_id=_userinfo_client_db_id, - success=True, - scopes_claimed=scopes, - ) - logger.debug("[OIDC SERVICE] Userinfo access event logged") - - logger.debug("[OIDC SERVICE] get_userinfo() completed successfully") - logger.debug("[OIDC SERVICE] ===========================================") - - return userinfo - - @staticmethod - def _get_user_roles(user: User) -> list: - """Get user's organization roles. - - Args: - user: User instance - - Returns: - List of role objects with organization_id and role - """ - logger.debug("[OIDC SERVICE] _get_user_roles() called") - logger.debug("[OIDC SERVICE] User: %s", user) - - roles = [] - - if not user: - logger.debug("[OIDC SERVICE] User is None, returning empty roles list") - return roles - - logger.debug("[OIDC SERVICE] User ID: %s", user.id) - logger.debug("[OIDC SERVICE] User email: %s", user.email) - logger.debug("[OIDC SERVICE] User organization_memberships: %s", user.organization_memberships) - - if user.organization_memberships: - logger.debug("[OIDC SERVICE] User has %d organization memberships", len(user.organization_memberships)) - for idx, member in enumerate(user.organization_memberships): - logger.debug("[OIDC SERVICE] Processing membership %d: member=%s", idx, member) - logger.debug("[OIDC SERVICE] organization_id: %s", member.organization_id) - logger.debug("[OIDC SERVICE] role: %s", member.role) - logger.debug("[OIDC SERVICE] role.value: %s", member.role.value) - - role_entry = { - "organization_id": str(member.organization_id), - "role": member.role.value - } - roles.append(role_entry) - logger.debug("[OIDC SERVICE] Added role entry: %s", role_entry) - else: - logger.debug("[OIDC SERVICE] User has no organization memberships") - - logger.debug("[OIDC SERVICE] Final roles list: %s", roles) - logger.debug("[OIDC SERVICE] _get_user_roles() completed") - - return roles diff --git a/gatehouse_app/services/organization_service.py b/gatehouse_app/services/organization_service.py index c071a9d..3df2db7 100644 --- a/gatehouse_app/services/organization_service.py +++ b/gatehouse_app/services/organization_service.py @@ -1,5 +1,6 @@ """Organization service.""" import logging +import uuid from datetime import datetime, timezone from flask import current_app from gatehouse_app.extensions import db @@ -157,6 +158,12 @@ class OrganizationService: Returns: Deleted Organization instance """ + if soft: + # Mangle slug so it can be reused + original_slug = org.slug + org.slug = f"{original_slug}__deleted_{uuid.uuid4().hex[:8]}" + org.is_active = False + org.delete(soft=soft) # Log organization deletion @@ -174,11 +181,16 @@ class OrganizationService: @staticmethod def force_delete_organization(org, user_id): """ - Force-delete an organization and all its members in a single atomic operation. + Force-delete an organization and ALL associated data in a single atomic + operation. - All active memberships are soft-deleted before the organization itself - is soft-deleted, preventing orphaned membership rows and avoiding any - cascade deadlocks. + Cleans up: + - All active memberships (soft-deleted) + - MFA policy compliance records for this org + - User security policy overrides for this org + - Pending invite tokens for this org + - OIDC clients for this org + - The organization slug is mangled so the same slug can be reused Args: org: Organization instance @@ -187,31 +199,71 @@ class OrganizationService: Returns: Deleted Organization instance """ - from datetime import datetime, timezone + from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance + from gatehouse_app.models.security.user_security_policy import UserSecurityPolicy + from gatehouse_app.models.organization.org_invite_token import OrgInviteToken now = datetime.now(timezone.utc) member_count = 0 + cleanup_counts = {} - # Soft-delete all active memberships first. + # 1. Soft-delete all active memberships first. for member in org.members: if member.deleted_at is None: member.deleted_at = now member_count += 1 - # Now soft-delete the organization itself. - org.delete(soft=True) + # 2. Remove MFA compliance records for this org so the compliance job + # doesn't accidentally process stale records for a deleted org. + compliance_records = MfaPolicyCompliance.query.filter_by( + organization_id=org.id, + ).filter(MfaPolicyCompliance.deleted_at == None).all() + for record in compliance_records: + record.deleted_at = now + cleanup_counts["compliance_records"] = len(compliance_records) - # Log with member count for audit trail. + # 3. Remove user security policy overrides for this org. + user_policies = UserSecurityPolicy.query.filter_by( + organization_id=org.id, + ).filter(UserSecurityPolicy.deleted_at == None).all() + for policy in user_policies: + policy.deleted_at = now + cleanup_counts["user_security_policies"] = len(user_policies) + + # 4. Remove pending invite tokens for this org. + pending_invites = OrgInviteToken.query.filter_by( + organization_id=org.id, + ).filter(OrgInviteToken.accepted_at == None, OrgInviteToken.deleted_at == None).all() + for invite in pending_invites: + invite.deleted_at = now + cleanup_counts["pending_invites"] = len(pending_invites) + + # 5. Mangle the slug so the same slug can be reused for a new org. + # Format: "original-slug__deleted_" + original_slug = org.slug + org.slug = f"{original_slug}__deleted_{uuid.uuid4().hex[:8]}" + + # 6. Now soft-delete the organization itself. + org.deleted_at = now + org.is_active = False + db.session.commit() + + # Log with member count and cleanup summary for audit trail. AuditService.log_action( action=AuditAction.ORG_DELETE, user_id=user_id, organization_id=org.id, resource_type="organization", resource_id=org.id, - metadata={"members_removed": member_count}, + metadata={ + "members_removed": member_count, + "original_slug": original_slug, + **cleanup_counts, + }, description=( - f"Organization deleted by owner; " - f"{member_count} membership(s) removed." + f"Organization '{original_slug}' deleted by owner; " + f"{member_count} membership(s) removed, " + f"{cleanup_counts.get('compliance_records', 0)} compliance record(s) cleaned." ), ) diff --git a/gatehouse_app/utils/constants.py b/gatehouse_app/utils/constants.py index 2a99825..201803b 100644 --- a/gatehouse_app/utils/constants.py +++ b/gatehouse_app/utils/constants.py @@ -90,6 +90,10 @@ class AuditAction(str, Enum): TOTP_DISABLED = "totp.disabled" TOTP_BACKUP_CODE_USED = "totp.backup_code.used" TOTP_BACKUP_CODES_REGENERATED = "totp.backup_codes.regenerated" + ADMIN_MFA_REMOVE = "admin.mfa.remove" + ADMIN_OAUTH_UNLINK = "admin.oauth.unlink" + ADMIN_PASSWORD_SET = "admin.password.set" + ADMIN_EMAIL_VERIFY = "admin.email.verify" # WebAuthn actions WEBAUTHN_REGISTER_INITIATED = "webauthn.register.initiated" diff --git a/migrations/versions/019_convert_auditaction_enum_to_varchar.py b/migrations/versions/019_convert_auditaction_enum_to_varchar.py new file mode 100644 index 0000000..1b419ac --- /dev/null +++ b/migrations/versions/019_convert_auditaction_enum_to_varchar.py @@ -0,0 +1,143 @@ +"""Convert audit_logs.action from auditaction enum to VARCHAR(100). + +Revision ID: 019_audit_varchar +Revises: 018_audit_enum_values, db15faee1fb8 +Create Date: 2026-03-04 + +WHY +--- +The PostgreSQL `auditaction` ENUM type must be explicitly altered every time a +new AuditAction is added to the Python enum, otherwise the INSERT fails with: + + psycopg2.errors.InvalidTextRepresentation: + invalid input value for enum auditaction: "admin.mfa.remove" + +The Python enum was refactored from UPPER_SNAKE_CASE to lower.dot.case string +values, but only the UPPER_SNAKE_CASE values exist in the DB type. Rather +than add every new value forever, we convert the column to VARCHAR(100) which +accepts any string — the Python layer already validates the value via the Enum. + +DATA MIGRATION +-------------- +All existing rows store UPPER_SNAKE_CASE values. We map each one to the +corresponding new lower.dot.case string so historical audit logs remain +queryable with the current enum. +""" +from alembic import op +import sqlalchemy as sa + +revision = "019_audit_varchar" +down_revision = ("018_audit_enum_values", "db15faee1fb8") +branch_labels = None +depends_on = None + +# Map every UPPER_SNAKE_CASE DB value → its new lower.dot.case Python value. +VALUE_MAP = { + "USER_LOGIN": "user.login", + "USER_LOGOUT": "user.logout", + "USER_REGISTER": "user.register", + "USER_UPDATE": "user.update", + "USER_DELETE": "user.delete", + "USER_HARD_DELETE": "user.hard_delete", + "USER_SUSPEND": "user.suspend", + "USER_UNSUSPEND": "user.unsuspend", + "PASSWORD_CHANGE": "user.password_change", + "PASSWORD_RESET": "user.password_reset", + "ORG_CREATE": "org.create", + "ORG_UPDATE": "org.update", + "ORG_DELETE": "org.delete", + "ORG_MEMBER_ADD": "org.member.add", + "ORG_MEMBER_REMOVE": "org.member.remove", + "ORG_MEMBER_ROLE_CHANGE": "org.member.role_change", + "ORG_OWNERSHIP_TRANSFERRED": "org.ownership.transferred", + "SESSION_CREATE": "session.create", + "SESSION_REVOKE": "session.revoke", + "AUTH_METHOD_ADD": "auth.method.add", + "AUTH_METHOD_REMOVE": "auth.method.remove", + "TOTP_ENROLL_INITIATED": "totp.enroll.initiated", + "TOTP_ENROLL_COMPLETED": "totp.enroll.completed", + "TOTP_VERIFY_SUCCESS": "totp.verify.success", + "TOTP_VERIFY_FAILED": "totp.verify.failed", + "TOTP_DISABLED": "totp.disabled", + "TOTP_BACKUP_CODE_USED": "totp.backup_code.used", + "TOTP_BACKUP_CODES_REGENERATED": "totp.backup_codes.regenerated", + "WEBAUTHN_REGISTER_INITIATED": "webauthn.register.initiated", + "WEBAUTHN_REGISTER_COMPLETED": "webauthn.register.completed", + "WEBAUTHN_REGISTER_FAILED": "webauthn.register.failed", + "WEBAUTHN_LOGIN_INITIATED": "webauthn.login.initiated", + "WEBAUTHN_LOGIN_SUCCESS": "webauthn.login.success", + "WEBAUTHN_LOGIN_FAILED": "webauthn.login.failed", + "WEBAUTHN_CREDENTIAL_DELETED": "webauthn.credential.deleted", + "WEBAUTHN_CREDENTIAL_RENAMED": "webauthn.credential.renamed", + "ORG_SECURITY_POLICY_UPDATE": "org.security_policy.update", + "USER_SECURITY_POLICY_OVERRIDE_UPDATE":"user.security_policy.override_update", + "MFA_POLICY_USER_SUSPENDED": "mfa.policy.user_suspended", + "MFA_POLICY_USER_COMPLIANT": "mfa.policy.user_compliant", + "EXTERNAL_AUTH_LINK_INITIATED": "external_auth.link.initiated", + "EXTERNAL_AUTH_LINK_COMPLETED": "external_auth.link.completed", + "EXTERNAL_AUTH_LINK_FAILED": "external_auth.link.failed", + "EXTERNAL_AUTH_UNLINK": "external_auth.unlink", + "EXTERNAL_AUTH_LOGIN": "external_auth.login", + "EXTERNAL_AUTH_LOGIN_FAILED": "external_auth.login.failed", + "EXTERNAL_AUTH_TOKEN_REFRESH": "external_auth.token_refresh", + "EXTERNAL_AUTH_CONFIG_CREATE": "external_auth.config.create", + "EXTERNAL_AUTH_CONFIG_UPDATE": "external_auth.config.update", + "EXTERNAL_AUTH_CONFIG_DELETE": "external_auth.config.delete", + "SSH_KEY_ADDED": "ssh.key.added", + "SSH_KEY_VERIFIED": "ssh.key.verified", + "SSH_KEY_DELETED": "ssh.key.deleted", + "SSH_KEY_VALIDATION_FAILED": "ssh.key.validation.failed", + "SSH_CERT_REQUESTED": "ssh.cert.requested", + "SSH_CERT_ISSUED": "ssh.cert.issued", + "SSH_CERT_FAILED": "ssh.cert.failed", + "SSH_CERT_REVOKED": "ssh.cert.revoked", + "SSH_CERT_EXPIRED": "ssh.cert.expired", + "CA_CREATED": "ca.created", + "CA_UPDATED": "ca.updated", + "CA_DELETED": "ca.deleted", + "CA_KEY_ROTATED": "ca.key.rotated", + "PRINCIPAL_CREATED": "principal.created", + "PRINCIPAL_UPDATED": "principal.updated", + "PRINCIPAL_DELETED": "principal.deleted", + "PRINCIPAL_MEMBER_ADDED": "principal.member.added", + "PRINCIPAL_MEMBER_REMOVED": "principal.member.removed", + "DEPARTMENT_CREATED": "department.created", + "DEPARTMENT_UPDATED": "department.updated", + "DEPARTMENT_DELETED": "department.deleted", + "DEPARTMENT_MEMBER_ADDED": "department.member.added", + "DEPARTMENT_MEMBER_REMOVED": "department.member.removed", +} + + +def upgrade(): + conn = op.get_bind() + + # 1. Add a temporary VARCHAR column + op.add_column("audit_logs", sa.Column("action_new", sa.String(100), nullable=True)) + + # 2. Populate it: map old UPPER_SNAKE_CASE to new lower.dot.case + for old_val, new_val in VALUE_MAP.items(): + conn.execute( + sa.text("UPDATE audit_logs SET action_new = :new WHERE action::text = :old"), + {"new": new_val, "old": old_val}, + ) + + # 3. Any unmapped rows (shouldn't exist, but be safe): copy as-is + conn.execute(sa.text("UPDATE audit_logs SET action_new = action::text WHERE action_new IS NULL")) + + # 4. Drop the old enum column, rename the new one + op.drop_column("audit_logs", "action") + op.alter_column("audit_logs", "action_new", new_column_name="action", nullable=False) + + # 5. Recreate the index (was on the old column) + op.create_index("ix_audit_logs_action", "audit_logs", ["action"]) + op.create_index("idx_audit_user_action", "audit_logs", ["user_id", "action"]) + + # 6. Drop the now-unused auditaction enum type + op.execute("DROP TYPE IF EXISTS auditaction") + + +def downgrade(): + # Converting VARCHAR back to a custom enum is complex and lossy for new + # values — provide a no-op downgrade. Run a previous backup to revert. + pass