diff --git a/README.md b/README.md index fabcf2f..b93c081 100644 --- a/README.md +++ b/README.md @@ -286,7 +286,6 @@ For issues and questions: # Boostrap db python manage.py db upgrade -python manage.py db migrate diff --git a/gatehouse_app/api/oidc.py b/gatehouse_app/api/oidc.py index cd579c3..f670cce 100644 --- a/gatehouse_app/api/oidc.py +++ b/gatehouse_app/api/oidc.py @@ -16,6 +16,7 @@ from gatehouse_app.services.oidc_service import ( OIDCService, InvalidClientError, InvalidGrantError, InvalidRequestError ) from gatehouse_app.services.auth_service import AuthService +from gatehouse_app.services.mfa_policy_service import MfaPolicyService from gatehouse_app.extensions import db from gatehouse_app.extensions import bcrypt as flask_bcrypt from gatehouse_app.models import User, OIDCClient @@ -372,6 +373,23 @@ def oidc_authorize(): logger.debug("[OIDC] Attempting user authentication for email: %s", email) try: user = AuthService.authenticate(email, password) + + # Evaluate MFA policy after primary authentication + policy_result = MfaPolicyService.after_primary_auth_success(user, remember_me=False) + + # Check if user can create full session + if not policy_result.can_create_full_session: + logger.debug("[OIDC] User cannot create full session due to MFA compliance: user_id=%s, email=%s", user.id, email) + return _show_login_page( + client_id=client_id, + redirect_uri=redirect_uri, + scope=scope, + state=state, + nonce=nonce, + response_type=response_type, + error="Your account requires multi factor enrollment before using single sign on" + ) + user_id = user.id session["oidc_user_id"] = user_id diff --git a/gatehouse_app/api/v1/__init__.py b/gatehouse_app/api/v1/__init__.py index 7e81ffe..56b7c7b 100644 --- a/gatehouse_app/api/v1/__init__.py +++ b/gatehouse_app/api/v1/__init__.py @@ -5,4 +5,4 @@ from flask import Blueprint api_v1_bp = Blueprint("api_v1", __name__) # Import route modules to register them -from gatehouse_app.api.v1 import auth, users, organizations +from gatehouse_app.api.v1 import auth, users, organizations, policies diff --git a/gatehouse_app/api/v1/auth.py b/gatehouse_app/api/v1/auth.py index 36fdff9..8495c2a 100644 --- a/gatehouse_app/api/v1/auth.py +++ b/gatehouse_app/api/v1/auth.py @@ -22,6 +22,7 @@ from gatehouse_app.schemas.webauthn_schema import ( from gatehouse_app.services.auth_service import AuthService from gatehouse_app.services.webauthn_service import WebAuthnService from gatehouse_app.services.user_service import UserService +from gatehouse_app.services.mfa_policy_service import MfaPolicyService from gatehouse_app.utils.decorators import login_required from gatehouse_app.utils.constants import AuditAction from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError @@ -94,6 +95,9 @@ def login(): 400: Validation error 401: Invalid credentials """ + import logging + logger = logging.getLogger(__name__) + try: # Validate request data schema = LoginSchema() @@ -105,6 +109,11 @@ def login(): password=data["password"], ) + # SECURITY CHECK: Log MFA enrollment status to validate the vulnerability + has_totp = user.has_totp_enabled() + has_webauthn = user.has_webauthn_enabled() + logger.warning(f"[SECURITY DIAGNOSTIC] Login attempt for user {user.email} - TOTP enabled: {has_totp}, WebAuthn enabled: {has_webauthn}") + # Check if user has TOTP enabled for two-factor authentication if user.has_totp_enabled(): # TOTP is enabled - store user_id in session for TOTP verification @@ -121,16 +130,56 @@ def login(): ) # TOTP is NOT enabled - proceed with normal login flow + # SECURITY DIAGNOSTIC: This is where the vulnerability occurs - no WebAuthn check! + if has_webauthn: + logger.error(f"[SECURITY VULNERABILITY DETECTED] User {user.email} has WebAuthn enrolled but is bypassing it! Creating session without MFA verification.") + + # Evaluate MFA policy after primary authentication + remember_me = data.get("remember_me", False) + policy_result = MfaPolicyService.after_primary_auth_success(user, remember_me) + # Create session with appropriate duration based on remember_me preference - duration = 2592000 if data.get("remember_me") else 86400 # 30 days vs 1 day - user_session = AuthService.create_session(user, duration_seconds=duration) + duration = 2592000 if remember_me else 86400 # 30 days vs 1 day + + # Determine if this should be a compliance-only session + is_compliance_only = policy_result.create_compliance_only_session + + user_session = AuthService.create_session( + user, + duration_seconds=duration, + is_compliance_only=is_compliance_only + ) + + # Build response data + response_data = { + "user": user.to_dict(), + "token": user_session.token, + "expires_at": user_session.expires_at.isoformat() + "Z" if user_session.expires_at.isoformat()[-1] != "Z" else user_session.expires_at.isoformat(), + } + + # Add MFA compliance information + if policy_result.compliance_summary: + response_data["mfa_compliance"] = { + "overall_status": policy_result.compliance_summary.overall_status, + "missing_methods": policy_result.compliance_summary.missing_methods, + "deadline_at": policy_result.compliance_summary.deadline_at, + "orgs": [ + { + "organization_id": org.organization_id, + "organization_name": org.organization_name, + "status": org.status, + "deadline_at": org.deadline_at, + } + for org in policy_result.compliance_summary.orgs + ], + } + + # Add requires_mfa_enrollment flag if compliance-only session + if is_compliance_only: + response_data["requires_mfa_enrollment"] = True return api_response( - data={ - "user": user.to_dict(), - "token": user_session.token, - "expires_at": user_session.expires_at.isoformat() + "Z" if user_session.expires_at.isoformat()[-1] != "Z" else user_session.expires_at.isoformat(), - }, + data=response_data, message="Login successful", ) @@ -380,20 +429,50 @@ def verify_totp(): client_utc_timestamp=data.get("client_timestamp"), ) - # Create full session - user_session = AuthService.create_session(user) + # Evaluate MFA policy after primary authentication + policy_result = MfaPolicyService.after_primary_auth_success(user, remember_me=False) + + # Determine if this should be a compliance-only session + is_compliance_only = policy_result.create_compliance_only_session + + # Create session + user_session = AuthService.create_session(user, is_compliance_only=is_compliance_only) # Clear temporary session session.pop("totp_pending_user_id", None) + # Build response data + response_data = { + "user": user.to_dict(), + "token": user_session.token, + "expires_at": user_session.expires_at.isoformat() + "Z" + if user_session.expires_at.isoformat()[-1] != "Z" + else user_session.expires_at.isoformat(), + } + + # Add MFA compliance information + if policy_result.compliance_summary: + response_data["mfa_compliance"] = { + "overall_status": policy_result.compliance_summary.overall_status, + "missing_methods": policy_result.compliance_summary.missing_methods, + "deadline_at": policy_result.compliance_summary.deadline_at, + "orgs": [ + { + "organization_id": org.organization_id, + "organization_name": org.organization_name, + "status": org.status, + "deadline_at": org.deadline_at, + } + for org in policy_result.compliance_summary.orgs + ], + } + + # Add requires_mfa_enrollment flag if compliance-only session + if is_compliance_only: + response_data["requires_mfa_enrollment"] = True + return api_response( - data={ - "user": user.to_dict(), - "token": user_session.token, - "expires_at": user_session.expires_at.isoformat() + "Z" - if user_session.expires_at.isoformat()[-1] != "Z" - else user_session.expires_at.isoformat(), - }, + data=response_data, message="TOTP verification successful", ) @@ -835,22 +914,52 @@ def complete_webauthn_login(): challenge ) + # Evaluate MFA policy after primary authentication + policy_result = MfaPolicyService.after_primary_auth_success(user, remember_me=False) + + # Determine if this should be a compliance-only session + is_compliance_only = policy_result.create_compliance_only_session + # Create session - user_session = AuthService.create_session(user) + user_session = AuthService.create_session(user, is_compliance_only=is_compliance_only) # Clear pending session session.pop("webauthn_pending_user_id", None) logger.info(f"WebAuthn login completed successfully for user: {user.email}") + # Build response data + response_data = { + "user": user.to_dict(), + "token": user_session.token, + "expires_at": user_session.expires_at.isoformat() + "Z" + if user_session.expires_at.isoformat()[-1] != "Z" + else user_session.expires_at.isoformat(), + } + + # Add MFA compliance information + if policy_result.compliance_summary: + response_data["mfa_compliance"] = { + "overall_status": policy_result.compliance_summary.overall_status, + "missing_methods": policy_result.compliance_summary.missing_methods, + "deadline_at": policy_result.compliance_summary.deadline_at, + "orgs": [ + { + "organization_id": org.organization_id, + "organization_name": org.organization_name, + "status": org.status, + "deadline_at": org.deadline_at, + } + for org in policy_result.compliance_summary.orgs + ], + } + + # Add requires_mfa_enrollment flag if compliance-only session + if is_compliance_only: + response_data["requires_mfa_enrollment"] = True + return api_response( - data={ - "user": user.to_dict(), - "token": user_session.token, - "expires_at": user_session.expires_at.isoformat() + "Z" - if user_session.expires_at.isoformat()[-1] != "Z" - else user_session.expires_at.isoformat(), - }, + data=response_data, message="Login successful", ) diff --git a/gatehouse_app/api/v1/organizations.py b/gatehouse_app/api/v1/organizations.py index 4b005aa..98c39e1 100644 --- a/gatehouse_app/api/v1/organizations.py +++ b/gatehouse_app/api/v1/organizations.py @@ -3,7 +3,7 @@ from flask import g, request from marshmallow import ValidationError from gatehouse_app.api.v1 import api_v1_bp from gatehouse_app.utils.response import api_response -from gatehouse_app.utils.decorators import login_required, require_admin, require_owner +from gatehouse_app.utils.decorators import login_required, require_admin, require_owner, full_access_required from gatehouse_app.schemas.organization_schema import ( OrganizationCreateSchema, OrganizationUpdateSchema, @@ -17,6 +17,7 @@ from gatehouse_app.utils.constants import OrganizationRole @api_v1_bp.route("/organizations", methods=["POST"]) @login_required +@full_access_required def create_organization(): """ Create a new organization. @@ -65,6 +66,7 @@ def create_organization(): @api_v1_bp.route("/organizations/", methods=["GET"]) @login_required +@full_access_required def get_organization(org_id): """ Get organization by ID. @@ -101,6 +103,7 @@ def get_organization(org_id): @api_v1_bp.route("/organizations/", methods=["PATCH"]) @login_required @require_admin +@full_access_required def update_organization(org_id): """ Update organization. @@ -152,6 +155,7 @@ def update_organization(org_id): @api_v1_bp.route("/organizations/", methods=["DELETE"]) @login_required @require_owner +@full_access_required def delete_organization(org_id): """ Delete organization (soft delete). @@ -180,6 +184,7 @@ def delete_organization(org_id): @api_v1_bp.route("/organizations//members", methods=["GET"]) @login_required +@full_access_required def get_organization_members(org_id): """ Get all members of an organization. @@ -223,6 +228,7 @@ def get_organization_members(org_id): @api_v1_bp.route("/organizations//members", methods=["POST"]) @login_required @require_admin +@full_access_required def add_organization_member(org_id): """ Add a member to the organization. @@ -290,6 +296,7 @@ def add_organization_member(org_id): @api_v1_bp.route("/organizations//members/", methods=["DELETE"]) @login_required @require_admin +@full_access_required def remove_organization_member(org_id, user_id): """ Remove a member from the organization. @@ -320,6 +327,7 @@ def remove_organization_member(org_id, user_id): @api_v1_bp.route("/organizations//members//role", methods=["PATCH"]) @login_required @require_admin +@full_access_required def update_member_role(org_id, user_id): """ Update a member's role. diff --git a/gatehouse_app/api/v1/policies.py b/gatehouse_app/api/v1/policies.py new file mode 100644 index 0000000..ddf16f2 --- /dev/null +++ b/gatehouse_app/api/v1/policies.py @@ -0,0 +1,336 @@ +"""Security policy endpoints.""" +from flask import g, request +from marshmallow import Schema, fields, validate, ValidationError +from gatehouse_app.api.v1 import api_v1_bp +from gatehouse_app.utils.response import api_response +from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required +from gatehouse_app.services.mfa_policy_service import MfaPolicyService +from gatehouse_app.services.organization_service import OrganizationService +from gatehouse_app.services.audit_service import AuditService +from gatehouse_app.utils.constants import MfaPolicyMode, MfaRequirementOverride, MfaComplianceStatus, AuditAction + + +class UpdateOrgPolicySchema(Schema): + """Schema for updating organization security policy.""" + mfa_policy_mode = fields.String( + required=False, + validate=validate.OneOf([m.value for m in MfaPolicyMode]) + ) + mfa_grace_period_days = fields.Integer( + required=False, + validate=validate.Range(min=0, max=365) + ) + notify_days_before = fields.Integer( + required=False, + validate=validate.Range(min=0, max=30) + ) + + +class UpdateUserPolicySchema(Schema): + """Schema for updating user security policy override.""" + mfa_override_mode = fields.String( + required=True, + validate=validate.OneOf([m.value for m in MfaRequirementOverride]) + ) + force_totp = fields.Boolean(required=False, load_default=False) + force_webauthn = fields.Boolean(required=False, load_default=False) + + +class ComplianceListQuerySchema(Schema): + """Schema for compliance list query parameters.""" + status = fields.String(required=False) + limit = fields.Integer(required=False, load_default=100) + offset = fields.Integer(required=False, load_default=0) + + +@api_v1_bp.route("/organizations//security-policy", methods=["GET"]) +@login_required +def get_org_security_policy(org_id): + """ + Get organization security policy. + + Args: + org_id: Organization ID + + Returns: + 200: Organization security policy + 401: Not authenticated + 403: Not a member + 404: Organization not found + """ + org = OrganizationService.get_organization_by_id(org_id) + + # Check if user is a member + if not org.is_member(g.current_user.id): + return api_response( + success=False, + message="You are not a member of this organization", + status=403, + error_type="AUTHORIZATION_ERROR", + ) + + policy_dto = MfaPolicyService.get_org_policy(org_id) + + if policy_dto: + data = { + "organization_id": policy_dto.organization_id, + "mfa_policy_mode": policy_dto.mfa_policy_mode, + "mfa_grace_period_days": policy_dto.mfa_grace_period_days, + "notify_days_before": policy_dto.notify_days_before, + "policy_version": policy_dto.policy_version, + } + else: + # Return default policy if none exists + data = { + "organization_id": org_id, + "mfa_policy_mode": MfaPolicyMode.OPTIONAL.value, + "mfa_grace_period_days": 14, + "notify_days_before": 7, + "policy_version": 0, + } + + return api_response( + data={"security_policy": data}, + message="Security policy retrieved successfully", + ) + + +@api_v1_bp.route("/organizations//security-policy", methods=["PUT"]) +@login_required +@require_admin +@full_access_required +def update_org_security_policy(org_id): + """ + Update organization security policy. + + Args: + org_id: Organization ID + + Request body: + mfa_policy_mode: MFA policy mode (disabled, optional, require_totp, require_webauthn, require_totp_or_webauthn) + mfa_grace_period_days: Grace period in days (0-365) + notify_days_before: Days before deadline to notify (0-30) + + Returns: + 200: Security policy updated successfully + 400: Validation error + 401: Not authenticated + 403: Not an admin + 404: Organization not found + """ + try: + schema = UpdateOrgPolicySchema() + data = schema.load(request.json) + + org = OrganizationService.get_organization_by_id(org_id) + + # Update policy + policy = MfaPolicyService.create_org_policy( + organization_id=org_id, + mfa_policy_mode=MfaPolicyMode(data.get("mfa_policy_mode", MfaPolicyMode.OPTIONAL.value)), + mfa_grace_period_days=data.get("mfa_grace_period_days", 14), + notify_days_before=data.get("notify_days_before", 7), + updated_by_user_id=g.current_user.id, + ) + + return api_response( + data={ + "security_policy": { + "organization_id": policy.organization_id, + "mfa_policy_mode": policy.mfa_policy_mode.value, + "mfa_grace_period_days": policy.mfa_grace_period_days, + "notify_days_before": policy.notify_days_before, + "policy_version": policy.policy_version, + } + }, + message="Security policy updated successfully", + ) + + except ValidationError as e: + return api_response( + success=False, + message="Validation failed", + status=400, + error_type="VALIDATION_ERROR", + error_details=e.messages, + ) + + +@api_v1_bp.route("/organizations//mfa-compliance", methods=["GET"]) +@login_required +@require_admin +@full_access_required +def get_org_mfa_compliance(org_id): + """ + Get MFA compliance list for an organization. + + Args: + org_id: Organization ID + + Query params: + status: Optional status filter (not_applicable, pending, in_grace, compliant, past_due, suspended) + limit: Maximum records to return (default 100) + offset: Offset for pagination (default 0) + + Returns: + 200: List of compliance records + 401: Not authenticated + 403: Not an admin + 404: Organization not found + """ + org = OrganizationService.get_organization_by_id(org_id) + + # Parse query params + status = None + if request.args.get("status"): + try: + status = MfaComplianceStatus(request.args.get("status")) + except ValueError: + return api_response( + success=False, + message="Invalid status value", + status=400, + error_type="VALIDATION_ERROR", + ) + + limit = min(int(request.args.get("limit", 100)), 100) + offset = int(request.args.get("offset", 0)) + + compliance_list = MfaPolicyService.get_org_compliance_list( + organization_id=org_id, + status=status, + limit=limit, + offset=offset, + ) + + return api_response( + data={ + "compliance": compliance_list, + "count": len(compliance_list), + "limit": limit, + "offset": offset, + }, + message="Compliance records retrieved successfully", + ) + + +@api_v1_bp.route( + "/organizations//users//security-policy", methods=["PATCH"] +) +@login_required +@require_admin +@full_access_required +def update_user_security_policy(org_id, user_id): + """ + Update user security policy override. + + Args: + org_id: Organization ID + user_id: User ID + + Request body: + mfa_override_mode: Override mode (inherit, required, exempt) + force_totp: Force TOTP requirement (default false) + force_webauthn: Force WebAuthn requirement (default false) + + Returns: + 200: User policy updated successfully + 400: Validation error + 401: Not authenticated + 403: Not an admin + 404: Organization or user not found + """ + try: + schema = UpdateUserPolicySchema() + data = schema.load(request.json) + + org = OrganizationService.get_organization_by_id(org_id) + + # Check if user is a member of the organization + if not org.is_member(user_id): + return api_response( + success=False, + message="User is not a member of this organization", + status=404, + error_type="NOT_FOUND", + ) + + # Update user policy + policy = MfaPolicyService.set_user_override( + user_id=user_id, + organization_id=org_id, + mfa_override_mode=MfaRequirementOverride(data["mfa_override_mode"]), + force_totp=data.get("force_totp", False), + force_webauthn=data.get("force_webauthn", False), + updated_by_user_id=g.current_user.id, + ) + + # Log the override change with details + AuditService.log_action( + action=AuditAction.USER_SECURITY_POLICY_OVERRIDE_UPDATE, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="user", + resource_id=user_id, + description=f"User security policy override changed to {data['mfa_override_mode']} for user {user_id}", + ) + + return api_response( + data={ + "user_security_policy": { + "user_id": policy.user_id, + "organization_id": policy.organization_id, + "mfa_override_mode": policy.mfa_override_mode.value, + "force_totp": policy.force_totp, + "force_webauthn": policy.force_webauthn, + } + }, + message="User security policy updated successfully", + ) + + except ValidationError as e: + return api_response( + success=False, + message="Validation failed", + status=400, + error_type="VALIDATION_ERROR", + error_details=e.messages, + ) + + +@api_v1_bp.route("/users/me/mfa-compliance", methods=["GET"]) +@login_required +def get_my_mfa_compliance(): + """ + Get current user's MFA compliance across all organizations. + + Returns: + 200: MFA compliance summary + 401: Not authenticated + """ + user = g.current_user + + compliance_summary = MfaPolicyService.evaluate_user_mfa_state(user) + + orgs = [] + for org_state in compliance_summary.orgs: + orgs.append({ + "organization_id": org_state.organization_id, + "organization_name": org_state.organization_name, + "status": org_state.status, + "effective_mode": org_state.effective_mode, + "deadline_at": org_state.deadline_at, + "applied_at": org_state.applied_at, + }) + + return api_response( + data={ + "mfa_compliance": { + "overall_status": compliance_summary.overall_status, + "missing_methods": compliance_summary.missing_methods, + "deadline_at": compliance_summary.deadline_at, + "orgs": orgs, + } + }, + message="MFA compliance retrieved successfully", + ) \ No newline at end of file diff --git a/gatehouse_app/api/v1/users.py b/gatehouse_app/api/v1/users.py index b6a9c8d..407e373 100644 --- a/gatehouse_app/api/v1/users.py +++ b/gatehouse_app/api/v1/users.py @@ -3,7 +3,7 @@ from flask import g, request from marshmallow import ValidationError from gatehouse_app.api.v1 import api_v1_bp from gatehouse_app.utils.response import api_response -from gatehouse_app.utils.decorators import login_required +from gatehouse_app.utils.decorators import login_required, full_access_required from gatehouse_app.schemas.user_schema import UserUpdateSchema, ChangePasswordSchema from gatehouse_app.services.user_service import UserService from gatehouse_app.services.auth_service import AuthService @@ -29,6 +29,7 @@ def get_me(): @api_v1_bp.route("/users/me", methods=["PATCH"]) @login_required +@full_access_required def update_me(): """ Update current user profile. @@ -67,6 +68,7 @@ def update_me(): @api_v1_bp.route("/users/me", methods=["DELETE"]) @login_required +@full_access_required def delete_me(): """ Delete current user account (soft delete). @@ -84,6 +86,7 @@ def delete_me(): @api_v1_bp.route("/users/me/password", methods=["POST"]) @login_required +@full_access_required def change_password(): """ Change current user password. @@ -136,6 +139,7 @@ def change_password(): @api_v1_bp.route("/users/me/organizations", methods=["GET"]) @login_required +@full_access_required def get_my_organizations(): """ Get all organizations current user is a member of. diff --git a/gatehouse_app/jobs/__init__.py b/gatehouse_app/jobs/__init__.py new file mode 100644 index 0000000..68eef1a --- /dev/null +++ b/gatehouse_app/jobs/__init__.py @@ -0,0 +1 @@ +Jobs module for scheduled tasks. \ No newline at end of file diff --git a/gatehouse_app/jobs/mfa_compliance_job.py b/gatehouse_app/jobs/mfa_compliance_job.py new file mode 100644 index 0000000..a8548c8 --- /dev/null +++ b/gatehouse_app/jobs/mfa_compliance_job.py @@ -0,0 +1,279 @@ +"""MFA Compliance Scheduled Job. + +This module implements the scheduled job for processing MFA compliance transitions, +sending notifications to users approaching deadlines, and handling edge cases. + +The job is designed to be run periodically (e.g., via cron) to: +1. Transition users from PAST_DUE to SUSPENDED status +2. Send deadline reminder notifications to users in grace period +3. Update notification tracking metadata + +Usage: + python manage.py run_mfa_compliance_job + +Or call directly: + from gatehouse_app.jobs.mfa_compliance_job import process_mfa_compliance + process_mfa_compliance() +""" +from datetime import datetime, timezone, timedelta +from typing import Optional, Dict, Any, List +import logging + +from gatehouse_app.extensions import db +from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance +from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy +from gatehouse_app.models.user import User +from gatehouse_app.services.mfa_policy_service import MfaPolicyService +from gatehouse_app.services.notification_service import NotificationService +from gatehouse_app.utils.constants import MfaComplianceStatus + +logger = logging.getLogger(__name__) + + +def process_mfa_compliance(now: Optional[datetime] = None) -> Dict[str, Any]: + """Process MFA compliance transitions and send notifications. + + This scheduled job performs the following operations: + 1. Transitions users from PAST_DUE to SUSPENDED status + 2. Identifies users approaching deadline (within notify_days_before) + 3. Sends deadline reminder notifications + 4. Updates notification tracking metadata + + Args: + now: Current time, defaults to now (UTC) + + Returns: + Dictionary with job execution statistics: + - suspended_count: Number of users transitioned to suspended + - notified_count: Number of notifications sent + - processed_count: Total compliance records processed + """ + if now is None: + now = datetime.now(timezone.utc) + + logger.info(f"Starting MFA compliance job at {now.isoformat()}") + + stats = { + "suspended_count": 0, + "notified_count": 0, + "processed_count": 0, + "errors": [], + } + + try: + # Step 1: Transition past-due users to suspended + suspended_count = MfaPolicyService.transition_to_suspended_if_past_due(now) + stats["suspended_count"] = suspended_count + logger.info(f"Transitioned {suspended_count} users to suspended status") + + # Step 2: Send notifications to users approaching deadline + notified_count = _send_deadline_reminders(now) + stats["notified_count"] = notified_count + logger.info(f"Sent {notified_count} deadline reminder notifications") + + # Step 3: Process any pending compliance evaluations + processed_count = _evaluate_pending_compliance(now) + stats["processed_count"] = processed_count + logger.info(f"Processed {processed_count} compliance records") + + except Exception as e: + logger.exception(f"Error during MFA compliance job: {e}") + stats["errors"].append(str(e)) + + logger.info( + f"MFA compliance job completed: suspended={stats['suspended_count']}, " + f"notified={stats['notified_count']}, processed={stats['processed_count']}" + ) + + return stats + + +def _send_deadline_reminders(now: datetime) -> int: + """Send deadline reminder notifications to users approaching deadline. + + Identifies users in grace period who are within their organization's + notify_days_before threshold and sends them reminder notifications. + + Args: + now: Current time (UTC) + + Returns: + Number of notifications sent + """ + notified_count = 0 + + # Find all compliance records in grace period + grace_records = MfaPolicyCompliance.query.filter( + MfaPolicyCompliance.status == MfaComplianceStatus.IN_GRACE, + MfaPolicyCompliance.deadline_at != None, + MfaPolicyCompliance.deleted_at == None, + ).all() + + for record in grace_records: + try: + # Get organization policy for notify_days_before + org_policy = OrganizationSecurityPolicy.query.filter_by( + organization_id=record.organization_id, deleted_at=None + ).first() + + if not org_policy: + continue + + notify_threshold = org_policy.notify_days_before + deadline = record.deadline_at + + # Ensure deadline has timezone + if deadline.tzinfo is None: + deadline = deadline.replace(tzinfo=timezone.utc) + + # Calculate time until deadline + time_until_deadline = deadline - now + days_until_deadline = time_until_deadline.total_seconds() / 86400 + + # Check if we should send a reminder + should_notify = False + if days_until_deadline <= notify_threshold: + # Check if we've already notified recently (within last 24 hours) + if record.last_notified_at: + hours_since_notification = ( + now - record.last_notified_at + ).total_seconds() / 3600 + if hours_since_notification < 24: + continue # Already notified recently + + should_notify = True + + if should_notify: + # Get user + user = User.query.get(record.user_id) + if not user: + continue + + # Send notification + success = NotificationService.send_mfa_deadline_reminder( + user=user, + compliance=record, + org_policy=org_policy, + ) + + if success: + # Update notification tracking + record.last_notified_at = now + record.notification_count += 1 + db.session.commit() + notified_count += 1 + logger.info( + f"Sent deadline reminder to user {user.email} " + f"(days until deadline: {days_until_deadline:.1f})" + ) + + except Exception as e: + logger.warning( + f"Error sending reminder for compliance record " + f"{record.id}: {e}" + ) + continue + + return notified_count + + +def _evaluate_pending_compliance(now: datetime) -> int: + """Evaluate and update pending compliance records. + + This handles edge cases where compliance records may need + status updates due to policy changes or other factors. + + Args: + now: Current time (UTC) + + Returns: + Number of records processed + """ + processed_count = 0 + + # Find all non-deleted compliance records + records = MfaPolicyCompliance.query.filter( + MfaPolicyCompliance.deleted_at == None, + ).all() + + for record in records: + try: + # Get the user and evaluate their current state + user = User.query.get(record.user_id) + if not user: + continue + + # Re-evaluate compliance status + # This handles cases where policy changed or user enrolled in MFA + from gatehouse_app.services.mfa_policy_service import MfaPolicyService + + effective_policy = MfaPolicyService.get_effective_user_policy( + user.id, record.organization_id + ) + + new_status = MfaPolicyService._evaluate_compliance_status( + user, effective_policy, record + ) + + # Update status if changed + if record.status != new_status: + old_status = record.status.value if hasattr(record.status, 'value') else str(record.status) + record.status = MfaComplianceStatus(new_status) + db.session.commit() + + logger.info( + f"Updated compliance status for user {user.email} " + f"in org {record.organization_id}: {old_status} -> {new_status}" + ) + + processed_count += 1 + + except Exception as e: + logger.warning( + f"Error evaluating compliance record {record.id}: {e}" + ) + continue + + return processed_count + + +def get_job_status(now: Optional[datetime] = None) -> Dict[str, Any]: + """Get current status of MFA compliance for monitoring. + + Args: + now: Current time, defaults to now (UTC) + + Returns: + Dictionary with compliance statistics + """ + if now is None: + now = datetime.now(timezone.utc) + + # Count records by status + status_counts = {} + for status in MfaComplianceStatus: + count = MfaPolicyCompliance.query.filter( + MfaPolicyCompliance.status == status, + MfaPolicyCompliance.deleted_at == None, + ).count() + status_counts[status.value] = count + + # Count users approaching deadline (within 7 days by default) + approaching_deadline = MfaPolicyCompliance.query.filter( + MfaPolicyCompliance.status == MfaComplianceStatus.IN_GRACE, + MfaPolicyCompliance.deadline_at != None, + MfaPolicyCompliance.deleted_at == None, + ).count() + + # Count past-due records + past_due_count = MfaPolicyCompliance.query.filter( + MfaPolicyCompliance.status == MfaComplianceStatus.PAST_DUE, + MfaPolicyCompliance.deleted_at == None, + ).count() + + return { + "status_counts": status_counts, + "approaching_deadline_count": approaching_deadline, + "past_due_count": past_due_count, + "timestamp": now.isoformat(), + } \ No newline at end of file diff --git a/gatehouse_app/models/__init__.py b/gatehouse_app/models/__init__.py index 5a21c78..b114718 100644 --- a/gatehouse_app/models/__init__.py +++ b/gatehouse_app/models/__init__.py @@ -12,6 +12,9 @@ from gatehouse_app.models.oidc_refresh_token import OIDCRefreshToken from gatehouse_app.models.oidc_session import OIDCSession from gatehouse_app.models.oidc_token_metadata import OIDCTokenMetadata from gatehouse_app.models.oidc_audit_log import OIDCAuditLog +from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy +from gatehouse_app.models.user_security_policy import UserSecurityPolicy +from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance __all__ = [ "BaseModel", @@ -27,4 +30,7 @@ __all__ = [ "OIDCSession", "OIDCTokenMetadata", "OIDCAuditLog", + "OrganizationSecurityPolicy", + "UserSecurityPolicy", + "MfaPolicyCompliance", ] diff --git a/gatehouse_app/models/mfa_policy_compliance.py b/gatehouse_app/models/mfa_policy_compliance.py new file mode 100644 index 0000000..6ecd217 --- /dev/null +++ b/gatehouse_app/models/mfa_policy_compliance.py @@ -0,0 +1,66 @@ +"""MfaPolicyCompliance model.""" +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel +from gatehouse_app.utils.constants import MfaComplianceStatus + + +class MfaPolicyCompliance(BaseModel): + """MFA policy compliance tracking per user per organization. + + Tracks each user's MFA compliance state separately for each organization membership. + """ + + __tablename__ = "mfa_policy_compliance" + + user_id = db.Column( + db.String(36), db.ForeignKey("users.id"), nullable=False, index=True + ) + organization_id = db.Column( + db.String(36), db.ForeignKey("organizations.id"), nullable=False, index=True + ) + + status = db.Column( + db.Enum(MfaComplianceStatus), + nullable=False, + default=MfaComplianceStatus.NOT_APPLICABLE, + ) + + # Snapshot of org policy at the time this record became active + policy_version = db.Column(db.Integer, nullable=False) + + # When policy started applying to this user + applied_at = db.Column(db.DateTime, nullable=True) + + # Final deadline for this user to comply (per user, not global) + deadline_at = db.Column(db.DateTime, nullable=True) + + # When they became compliant under this policy_version + compliant_at = db.Column(db.DateTime, nullable=True) + + # When suspended enforcement started for this user + suspended_at = db.Column(db.DateTime, nullable=True) + + # Notification tracking + last_notified_at = db.Column(db.DateTime, nullable=True) + notification_count = db.Column(db.Integer, nullable=False, default=0) + + __table_args__ = ( + db.UniqueConstraint( + "user_id", "organization_id", name="uix_user_org_compliance" + ), + ) + + # Relationships + user = db.relationship( + "User", back_populates="mfa_compliance", foreign_keys=[user_id] + ) + organization = db.relationship("Organization", foreign_keys=[organization_id]) + + def __repr__(self): + """String representation of MfaPolicyCompliance.""" + return f"" + + def to_dict(self, exclude=None): + """Convert to dictionary.""" + exclude = exclude or [] + return super().to_dict(exclude=exclude) \ No newline at end of file diff --git a/gatehouse_app/models/organization.py b/gatehouse_app/models/organization.py index f81fcc3..88fae01 100644 --- a/gatehouse_app/models/organization.py +++ b/gatehouse_app/models/organization.py @@ -24,6 +24,13 @@ class Organization(BaseModel): oidc_clients = db.relationship( "OIDCClient", back_populates="organization", cascade="all, delete-orphan" ) + security_policy = db.relationship( + "OrganizationSecurityPolicy", + back_populates="organization", + uselist=False, + cascade="all, delete-orphan", + foreign_keys="OrganizationSecurityPolicy.organization_id", + ) def __repr__(self): """String representation of Organization.""" diff --git a/gatehouse_app/models/organization_security_policy.py b/gatehouse_app/models/organization_security_policy.py new file mode 100644 index 0000000..991b72d --- /dev/null +++ b/gatehouse_app/models/organization_security_policy.py @@ -0,0 +1,53 @@ +"""OrganizationSecurityPolicy model.""" +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel +from gatehouse_app.utils.constants import MfaPolicyMode + + +class OrganizationSecurityPolicy(BaseModel): + """Organization security policy model for MFA configuration. + + One row per organization capturing its current security requirements. + """ + + __tablename__ = "organization_security_policies" + + organization_id = db.Column( + db.String(36), + db.ForeignKey("organizations.id"), + nullable=False, + index=True, + unique=True, + ) + + # MFA policy configuration + mfa_policy_mode = db.Column( + db.Enum(MfaPolicyMode), nullable=False, default=MfaPolicyMode.OPTIONAL + ) + + # Grace period for members in days + mfa_grace_period_days = db.Column(db.Integer, nullable=False, default=14) + + # Notification settings (in days before individual user deadline) + notify_days_before = db.Column(db.Integer, nullable=False, default=7) + + # Versioning for compatibility tracking + policy_version = db.Column(db.Integer, nullable=False, default=1) + + # Audit metadata + updated_by_user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True) + + # Relationships + organization = db.relationship( + "Organization", back_populates="security_policy", foreign_keys=[organization_id] + ) + updated_by_user = db.relationship("User", foreign_keys=[updated_by_user_id]) + + def __repr__(self): + """String representation of OrganizationSecurityPolicy.""" + return f"" + + def to_dict(self, exclude=None): + """Convert to dictionary.""" + exclude = exclude or [] + return super().to_dict(exclude=exclude) \ No newline at end of file diff --git a/gatehouse_app/models/session.py b/gatehouse_app/models/session.py index 2bc6261..0290a20 100644 --- a/gatehouse_app/models/session.py +++ b/gatehouse_app/models/session.py @@ -25,6 +25,9 @@ class Session(BaseModel): revoked_at = db.Column(db.DateTime, nullable=True) revoked_reason = db.Column(db.String(255), nullable=True) + # Compliance session flag + is_compliance_only = db.Column(db.Boolean, nullable=False, default=False) + # Relationships user = db.relationship("User", back_populates="sessions") diff --git a/gatehouse_app/models/user.py b/gatehouse_app/models/user.py index 47670e6..474269f 100644 --- a/gatehouse_app/models/user.py +++ b/gatehouse_app/models/user.py @@ -31,6 +31,18 @@ class User(BaseModel): foreign_keys="OrganizationMember.user_id", ) audit_logs = db.relationship("AuditLog", back_populates="user", cascade="all, delete-orphan") + security_policies = db.relationship( + "UserSecurityPolicy", + back_populates="user", + cascade="all, delete-orphan", + foreign_keys="UserSecurityPolicy.user_id", + ) + mfa_compliance = db.relationship( + "MfaPolicyCompliance", + back_populates="user", + cascade="all, delete-orphan", + foreign_keys="MfaPolicyCompliance.user_id", + ) def __repr__(self): """String representation of User.""" diff --git a/gatehouse_app/models/user_security_policy.py b/gatehouse_app/models/user_security_policy.py new file mode 100644 index 0000000..d765575 --- /dev/null +++ b/gatehouse_app/models/user_security_policy.py @@ -0,0 +1,53 @@ +"""UserSecurityPolicy model.""" +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel +from gatehouse_app.utils.constants import MfaRequirementOverride + + +class UserSecurityPolicy(BaseModel): + """User security policy model for per-user MFA overrides. + + Stores per user overrides of organization level MFA requirements. + """ + + __tablename__ = "user_security_policies" + + user_id = db.Column( + db.String(36), db.ForeignKey("users.id"), nullable=False, index=True + ) + organization_id = db.Column( + db.String(36), db.ForeignKey("organizations.id"), nullable=False, index=True + ) + + mfa_override_mode = db.Column( + db.Enum(MfaRequirementOverride), + nullable=False, + default=MfaRequirementOverride.INHERIT, + ) + + # If override is REQUIRED and you want to force a specific factor set + force_totp = db.Column(db.Boolean, nullable=False, default=False) + force_webauthn = db.Column(db.Boolean, nullable=False, default=False) + + __table_args__ = ( + db.UniqueConstraint( + "user_id", "organization_id", name="uix_user_org_policy" + ), + ) + + # Relationships + user = db.relationship( + "User", back_populates="security_policies", foreign_keys=[user_id] + ) + organization = db.relationship( + "Organization", foreign_keys=[organization_id] + ) + + def __repr__(self): + """String representation of UserSecurityPolicy.""" + return f"" + + def to_dict(self, exclude=None): + """Convert to dictionary.""" + exclude = exclude or [] + return super().to_dict(exclude=exclude) \ No newline at end of file diff --git a/gatehouse_app/schemas/auth_schema.py b/gatehouse_app/schemas/auth_schema.py index f865671..b2042f7 100644 --- a/gatehouse_app/schemas/auth_schema.py +++ b/gatehouse_app/schemas/auth_schema.py @@ -96,3 +96,29 @@ class TOTPRegenerateBackupCodesSchema(Schema): """Schema for regenerating backup codes.""" password = fields.Str(required=True, validate=validate.Length(min=1)) + + +class MfaComplianceOrgSchema(Schema): + """Schema for MFA compliance per organization.""" + organization_id = fields.Str(required=True) + organization_name = fields.Str(required=True) + status = fields.Str(required=True) + deadline_at = fields.Str(allow_none=True) + + +class MfaComplianceSchema(Schema): + """Schema for MFA compliance summary in login response.""" + overall_status = fields.Str(required=True) + missing_methods = fields.List(fields.Str(), required=True) + deadline_at = fields.Str(allow_none=True) + orgs = fields.List(fields.Nested(MfaComplianceOrgSchema), required=True) + + +class LoginResponseSchema(Schema): + """Schema for login response.""" + user = fields.Dict(required=True) + token = fields.Str(required=True) + expires_at = fields.Str(required=True) + requires_totp = fields.Bool(required=False) + requires_mfa_enrollment = fields.Bool(required=False) + mfa_compliance = fields.Nested(MfaComplianceSchema, required=False) diff --git a/gatehouse_app/services/auth_service.py b/gatehouse_app/services/auth_service.py index b1d833e..3df5bf3 100644 --- a/gatehouse_app/services/auth_service.py +++ b/gatehouse_app/services/auth_service.py @@ -140,13 +140,14 @@ class AuthService: return user @staticmethod - def create_session(user, duration_seconds=86400): + def create_session(user, duration_seconds=86400, is_compliance_only=False): """ Create a new session for the user. Args: user: User instance duration_seconds: Session duration in seconds + is_compliance_only: Whether this is a compliance-only session (limited access) Returns: Session instance @@ -163,6 +164,7 @@ class AuthService: user_agent=request.headers.get("User-Agent"), expires_at=datetime.now(timezone.utc) + timedelta(seconds=duration_seconds), last_activity_at=datetime.now(timezone.utc), + is_compliance_only=is_compliance_only, ) session.save() diff --git a/gatehouse_app/services/mfa_policy_service.py b/gatehouse_app/services/mfa_policy_service.py new file mode 100644 index 0000000..d553fc1 --- /dev/null +++ b/gatehouse_app/services/mfa_policy_service.py @@ -0,0 +1,978 @@ +"""MFA Policy Service.""" +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Optional, List, Dict, Any + +from gatehouse_app.extensions import db +from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy +from gatehouse_app.models.user_security_policy import UserSecurityPolicy +from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance +from gatehouse_app.models.user import User +from gatehouse_app.models.organization import Organization +from gatehouse_app.services.audit_service import AuditService +from gatehouse_app.utils.constants import ( + MfaPolicyMode, + MfaComplianceStatus, + MfaRequirementOverride, + AuditAction, + UserStatus, +) + + +@dataclass +class OrgPolicyDto: + """DTO for organization policy.""" + organization_id: str + mfa_policy_mode: str + mfa_grace_period_days: int + notify_days_before: int + policy_version: int + updated_by_user_id: Optional[str] = None + + +@dataclass +class EffectiveUserPolicyDto: + """DTO for effective user policy combining org and user overrides.""" + organization_id: str + effective_mode: str + requires_totp: bool + requires_webauthn: bool + grace_period_days: int + is_exempt: bool = False + + +@dataclass +class UserMfaStateDto: + """DTO for per-organization MFA state.""" + organization_id: str + organization_name: str + status: str + effective_mode: str + deadline_at: Optional[str] = None + applied_at: Optional[str] = None + + +@dataclass +class AggregateMfaStateDto: + """DTO for aggregate MFA state across all organizations.""" + overall_status: str + missing_methods: List[str] = field(default_factory=list) + deadline_at: Optional[str] = None + orgs: List[UserMfaStateDto] = field(default_factory=list) + + +@dataclass +class LoginPolicyResult: + """Result of policy evaluation after primary auth success.""" + can_create_full_session: bool + create_compliance_only_session: bool + compliance_summary: AggregateMfaStateDto + + +class MfaPolicyService: + """Service for MFA policy evaluation and compliance tracking.""" + + @staticmethod + def get_org_policy(org_id: str) -> Optional[OrgPolicyDto]: + """Get organization security policy. + + Args: + org_id: Organization ID + + Returns: + OrgPolicyDto or None if not found + """ + policy = OrganizationSecurityPolicy.query.filter_by( + organization_id=org_id, deleted_at=None + ).first() + + if not policy: + return None + + return OrgPolicyDto( + organization_id=policy.organization_id, + mfa_policy_mode=policy.mfa_policy_mode.value, + mfa_grace_period_days=policy.mfa_grace_period_days, + notify_days_before=policy.notify_days_before, + policy_version=policy.policy_version, + updated_by_user_id=policy.updated_by_user_id, + ) + + @staticmethod + def get_effective_user_policy( + user_id: str, org_id: str + ) -> EffectiveUserPolicyDto: + """Get effective user policy combining org policy with user overrides. + + Args: + user_id: User ID + org_id: Organization ID + + Returns: + EffectiveUserPolicyDto + """ + # Get org policy + org_policy = OrganizationSecurityPolicy.query.filter_by( + organization_id=org_id, deleted_at=None + ).first() + + if not org_policy: + # No org policy means no requirements + return EffectiveUserPolicyDto( + organization_id=org_id, + effective_mode=MfaPolicyMode.DISABLED.value, + requires_totp=False, + requires_webauthn=False, + grace_period_days=0, + is_exempt=True, + ) + + # Get user override + user_override = UserSecurityPolicy.query.filter_by( + user_id=user_id, organization_id=org_id, deleted_at=None + ).first() + + # Determine effective mode + if user_override: + override_mode = user_override.mfa_override_mode + if override_mode == MfaRequirementOverride.EXEMPT: + return EffectiveUserPolicyDto( + organization_id=org_id, + effective_mode=MfaPolicyMode.DISABLED.value, + requires_totp=False, + requires_webauthn=False, + grace_period_days=org_policy.mfa_grace_period_days, + is_exempt=True, + ) + elif override_mode == MfaRequirementOverride.REQUIRED: + # User is required to have MFA even if org is optional + effective_mode = MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN + else: + effective_mode = org_policy.mfa_policy_mode + else: + effective_mode = org_policy.mfa_policy_mode + + # Determine required methods based on mode + requires_totp = effective_mode in ( + MfaPolicyMode.REQUIRE_TOTP, + MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN, + ) + requires_webauthn = effective_mode in ( + MfaPolicyMode.REQUIRE_WEBAUTHN, + MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN, + ) + + return EffectiveUserPolicyDto( + organization_id=org_id, + effective_mode=effective_mode.value, + requires_totp=requires_totp, + requires_webauthn=requires_webauthn, + grace_period_days=org_policy.mfa_grace_period_days, + is_exempt=False, + ) + + @staticmethod + def evaluate_user_mfa_state(user: User) -> AggregateMfaStateDto: + """Evaluate user's MFA state across all organizations. + + Args: + user: User instance + + Returns: + AggregateMfaStateDto with overall status and per-org breakdown + """ + org_states: List[UserMfaStateDto] = [] + overall_status = MfaComplianceStatus.COMPLIANT.value + earliest_deadline: Optional[datetime] = None + missing_methods: set = set() + + for membership in user.organization_memberships: + if membership.deleted_at is not None: + continue + + org = membership.organization + if org.deleted_at is not None: + continue + + effective_policy = MfaPolicyService.get_effective_user_policy( + user.id, org.id + ) + + # Get or create compliance record + compliance = MfaPolicyCompliance.query.filter_by( + user_id=user.id, organization_id=org.id, deleted_at=None + ).first() + + if not compliance: + # Create initial compliance record + compliance = MfaPolicyCompliance( + user_id=user.id, + organization_id=org.id, + status=MfaComplianceStatus.NOT_APPLICABLE, + policy_version=0, + ) + compliance.save() + + # Determine status based on policy and user MFA state + status = MfaPolicyService._evaluate_compliance_status( + user, effective_policy, compliance + ) + + # Update compliance record if needed + if compliance.status != status: + compliance.status = status + db.session.commit() + + # Track missing methods + if status not in ( + MfaComplianceStatus.COMPLIANT.value, + MfaComplianceStatus.NOT_APPLICABLE.value, + ): + if effective_policy.requires_totp and not user.has_totp_enabled(): + missing_methods.add("totp") + if effective_policy.requires_webauthn and not user.has_webauthn_enabled(): + missing_methods.add("webauthn") + + # Track earliest deadline + if compliance.deadline_at: + if earliest_deadline is None or compliance.deadline_at < earliest_deadline: + earliest_deadline = compliance.deadline_at + + # Determine overall status (most restrictive) + if status == MfaComplianceStatus.SUSPENDED.value: + overall_status = MfaComplianceStatus.SUSPENDED.value + elif ( + status == MfaComplianceStatus.PAST_DUE.value + and overall_status != MfaComplianceStatus.SUSPENDED.value + ): + overall_status = MfaComplianceStatus.PAST_DUE.value + elif ( + status == MfaComplianceStatus.IN_GRACE.value + and overall_status + not in ( + MfaComplianceStatus.SUSPENDED.value, + MfaComplianceStatus.PAST_DUE.value, + ) + ): + overall_status = MfaComplianceStatus.IN_GRACE.value + elif ( + status == MfaComplianceStatus.PENDING.value + and overall_status == MfaComplianceStatus.COMPLIANT.value + ): + overall_status = MfaComplianceStatus.PENDING.value + + org_states.append( + UserMfaStateDto( + organization_id=org.id, + organization_name=org.name, + status=status, + effective_mode=effective_policy.effective_mode, + deadline_at=compliance.deadline_at.isoformat() if compliance.deadline_at else None, + applied_at=compliance.applied_at.isoformat() if compliance.applied_at else None, + ) + ) + + return AggregateMfaStateDto( + overall_status=overall_status, + missing_methods=list(missing_methods), + deadline_at=earliest_deadline.isoformat() if earliest_deadline else None, + orgs=org_states, + ) + + @staticmethod + def _evaluate_compliance_status( + user: User, + effective_policy: EffectiveUserPolicyDto, + compliance: MfaPolicyCompliance, + ) -> str: + """Evaluate compliance status for a user in an organization. + + Args: + user: User instance + effective_policy: EffectiveUserPolicyDto + compliance: MfaPolicyCompliance instance + + Returns: + Status string + """ + now = datetime.now(timezone.utc) + + # If exempt or disabled, mark as not applicable + if effective_policy.is_exempt: + return MfaComplianceStatus.NOT_APPLICABLE.value + + if effective_policy.effective_mode == MfaPolicyMode.DISABLED.value: + return MfaComplianceStatus.NOT_APPLICABLE.value + + # Check if user has required MFA methods + has_totp = user.has_totp_enabled() + has_webauthn = user.has_webauthn_enabled() + + has_required = ( + (not effective_policy.requires_totp or has_totp) + and (not effective_policy.requires_webauthn or has_webauthn) + ) + + if has_required: + return MfaComplianceStatus.COMPLIANT.value + + # User is missing required MFA + # If no deadline set, set it now + if not compliance.deadline_at and effective_policy.grace_period_days > 0: + compliance.applied_at = now + compliance.deadline_at = now.replace( + tzinfo=None + ) + __import__("datetime").timedelta( + days=effective_policy.grace_period_days + ) + db.session.commit() + return MfaComplianceStatus.IN_GRACE.value + + # Check deadline + if compliance.deadline_at: + deadline = compliance.deadline_at + if deadline.tzinfo is None: + deadline = deadline.replace(tzinfo=timezone.utc) + + if now < deadline: + return MfaComplianceStatus.IN_GRACE.value + else: + return MfaComplianceStatus.PAST_DUE.value + + return MfaComplianceStatus.PENDING.value + + @staticmethod + def after_primary_auth_success( + user: User, remember_me: bool = False + ) -> LoginPolicyResult: + """Determine session type based on compliance after primary auth success. + + Args: + user: User instance + remember_me: Whether this is a remember-me session + + Returns: + LoginPolicyResult with session type and compliance summary + """ + compliance_summary = MfaPolicyService.evaluate_user_mfa_state(user) + + # Check if there are any REQUIRED policies affecting this user + has_required_policy = False + for org_state in compliance_summary.orgs: + if org_state.effective_mode in ( + MfaPolicyMode.REQUIRE_TOTP.value, + MfaPolicyMode.REQUIRE_WEBAUTHN.value, + MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value, + ): + has_required_policy = True + break + + if not has_required_policy: + # No required policies, full session allowed + return LoginPolicyResult( + can_create_full_session=True, + create_compliance_only_session=False, + compliance_summary=compliance_summary, + ) + + # Check if user is compliant + if compliance_summary.overall_status == MfaComplianceStatus.COMPLIANT.value: + return LoginPolicyResult( + can_create_full_session=True, + create_compliance_only_session=False, + compliance_summary=compliance_summary, + ) + + # User is not compliant + if compliance_summary.overall_status in ( + MfaComplianceStatus.IN_GRACE.value, + MfaComplianceStatus.PENDING.value, + ): + # Can proceed with full session but warnings + return LoginPolicyResult( + can_create_full_session=True, + create_compliance_only_session=False, + compliance_summary=compliance_summary, + ) + + # Past due or suspended - compliance only session + return LoginPolicyResult( + can_create_full_session=False, + create_compliance_only_session=True, + compliance_summary=compliance_summary, + ) + + @staticmethod + def transition_to_suspended_if_past_due(now: Optional[datetime] = None) -> int: + """Scheduled job to transition past-due users to suspended status. + + Args: + now: Current time, defaults to now + + Returns: + Number of users transitioned to suspended + """ + if now is None: + now = datetime.now(timezone.utc) + + suspended_count = 0 + + # Find all compliance records that are past due + past_due_records = MfaPolicyCompliance.query.filter( + MfaPolicyCompliance.status == MfaComplianceStatus.PAST_DUE, + MfaPolicyCompliance.deadline_at != None, + MfaPolicyCompliance.deleted_at == None, + ).all() + + for record in past_due_records: + deadline = record.deadline_at + if deadline.tzinfo is None: + deadline = deadline.replace(tzinfo=timezone.utc) + + if now >= deadline: + # Transition to suspended + record.status = MfaComplianceStatus.SUSPENDED + record.suspended_at = now + db.session.commit() + + # Update user status + user = User.query.get(record.user_id) + if user and user.status != UserStatus.COMPLIANCE_SUSPENDED: + user.status = UserStatus.COMPLIANCE_SUSPENDED + db.session.commit() + + # Audit log + AuditService.log_action( + action=AuditAction.MFA_POLICY_USER_SUSPENDED, + user_id=record.user_id, + organization_id=record.organization_id, + description=f"User suspended due to MFA compliance deadline passed", + ) + + suspended_count += 1 + + return suspended_count + + @staticmethod + def create_org_policy( + organization_id: str, + mfa_policy_mode: MfaPolicyMode, + mfa_grace_period_days: int = 14, + notify_days_before: int = 7, + updated_by_user_id: Optional[str] = None, + ) -> OrganizationSecurityPolicy: + """Create or update organization security policy. + + Args: + organization_id: Organization ID + mfa_policy_mode: MFA policy mode + mfa_grace_period_days: Grace period in days + notify_days_before: Days before deadline to notify + updated_by_user_id: User making the change + + Returns: + OrganizationSecurityPolicy instance + """ + policy = OrganizationSecurityPolicy.query.filter_by( + organization_id=organization_id, deleted_at=None + ).first() + + if policy: + # Update existing + old_mode = policy.mfa_policy_mode + policy.mfa_policy_mode = mfa_policy_mode + policy.mfa_grace_period_days = mfa_grace_period_days + policy.notify_days_before = notify_days_before + policy.policy_version += 1 + policy.updated_by_user_id = updated_by_user_id + policy.save() + + # Audit log + AuditService.log_action( + action=AuditAction.ORG_SECURITY_POLICY_UPDATE, + user_id=updated_by_user_id, + organization_id=organization_id, + description=f"Security policy updated from {old_mode.value} to {mfa_policy_mode.value}", + ) + else: + # Create new + policy = OrganizationSecurityPolicy( + organization_id=organization_id, + mfa_policy_mode=mfa_policy_mode, + mfa_grace_period_days=mfa_grace_period_days, + notify_days_before=notify_days_before, + policy_version=1, + updated_by_user_id=updated_by_user_id, + ) + policy.save() + + # Audit log + AuditService.log_action( + action=AuditAction.ORG_SECURITY_POLICY_UPDATE, + user_id=updated_by_user_id, + organization_id=organization_id, + description=f"Security policy created with mode {mfa_policy_mode.value}", + ) + + return policy + + @staticmethod + def set_user_override( + user_id: str, + organization_id: str, + mfa_override_mode: MfaRequirementOverride, + force_totp: bool = False, + force_webauthn: bool = False, + updated_by_user_id: Optional[str] = None, + ) -> UserSecurityPolicy: + """Set user security policy override. + + Args: + user_id: User ID + organization_id: Organization ID + mfa_override_mode: Override mode + force_totp: Force TOTP requirement + force_webauthn: Force WebAuthn requirement + updated_by_user_id: User making the change + + Returns: + UserSecurityPolicy instance + """ + override = UserSecurityPolicy.query.filter_by( + user_id=user_id, organization_id=organization_id, deleted_at=None + ).first() + + if override: + old_mode = override.mfa_override_mode + override.mfa_override_mode = mfa_override_mode + override.force_totp = force_totp + override.force_webauthn = force_webauthn + override.save() + + # Audit log + AuditService.log_action( + action=AuditAction.USER_SECURITY_POLICY_OVERRIDE_UPDATE, + user_id=updated_by_user_id, + organization_id=organization_id, + resource_type="user", + resource_id=user_id, + description=f"User policy override updated from {old_mode.value} to {mfa_override_mode.value}", + ) + else: + override = UserSecurityPolicy( + user_id=user_id, + organization_id=organization_id, + mfa_override_mode=mfa_override_mode, + force_totp=force_totp, + force_webauthn=force_webauthn, + ) + override.save() + + # Audit log + AuditService.log_action( + action=AuditAction.USER_SECURITY_POLICY_OVERRIDE_UPDATE, + user_id=updated_by_user_id, + organization_id=organization_id, + resource_type="user", + resource_id=user_id, + description=f"User policy override created with mode {mfa_override_mode.value}", + ) + + return override + + @staticmethod + def get_user_compliance(user_id: str, organization_id: str) -> Optional[MfaPolicyCompliance]: + """Get user compliance record for an organization. + + Args: + user_id: User ID + organization_id: Organization ID + + Returns: + MfaPolicyCompliance or None + """ + return MfaPolicyCompliance.query.filter_by( + user_id=user_id, organization_id=organization_id, deleted_at=None + ).first() + + @staticmethod + def get_org_compliance_list( + organization_id: str, status: Optional[MfaComplianceStatus] = None, limit: int = 100, offset: int = 0 + ) -> List[Dict[str, Any]]: + """Get list of user compliance records for an organization. + + Args: + organization_id: Organization ID + status: Optional status filter + limit: Maximum records to return + offset: Offset for pagination + + Returns: + List of compliance records with user info + """ + query = db.session.query( + MfaPolicyCompliance, + User.email, + User.full_name, + ).join( + User, User.id == MfaPolicyCompliance.user_id + ).filter( + MfaPolicyCompliance.organization_id == organization_id, + MfaPolicyCompliance.deleted_at == None, + User.deleted_at == None, + ) + + if status: + query = query.filter(MfaPolicyCompliance.status == status) + + records = query.order_by( + MfaPolicyCompliance.created_at.desc() + ).limit(limit).offset(offset).all() + + result = [] + for compliance, email, full_name in records: + result.append({ + "user_id": compliance.user_id, + "email": email, + "full_name": full_name, + "status": compliance.status.value, + "deadline_at": compliance.deadline_at.isoformat() if compliance.deadline_at else None, + "applied_at": compliance.applied_at.isoformat() if compliance.applied_at else None, + "compliant_at": compliance.compliant_at.isoformat() if compliance.compliant_at else None, + "suspended_at": compliance.suspended_at.isoformat() if compliance.suspended_at else None, + "notification_count": compliance.notification_count, + }) + + return result + + # ========================================================================= + # Multi-Organization Edge Case Handling + # ========================================================================= + + @staticmethod + def get_strictest_mode(modes: List[str]) -> str: + """Get the strictest MFA policy mode from a list. + + Used for multi-org scenarios where a user belongs to multiple organizations + with different policies. "Most secure wins" logic determines the effective + requirement. + + Args: + modes: List of policy mode strings + + Returns: + The strictest mode string + """ + # Define strictness hierarchy (more strict = higher index) + strictness_order = [ + MfaPolicyMode.DISABLED.value, + MfaPolicyMode.OPTIONAL.value, + MfaPolicyMode.REQUIRE_TOTP.value, + MfaPolicyMode.REQUIRE_WEBAUTHN.value, + MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value, + ] + + max_strictness = -1 + result_mode = MfaPolicyMode.OPTIONAL.value + + for mode in modes: + if mode in strictness_order: + idx = strictness_order.index(mode) + if idx > max_strictness: + max_strictness = idx + result_mode = mode + + return result_mode + + @staticmethod + def reevaluate_all_org_compliance(organization_id: str, now: Optional[datetime] = None) -> int: + """Reevaluate compliance for all users in an organization. + + Called when org policy changes to ensure all users are properly evaluated + under the new policy. This handles the edge case where policy becomes + more restrictive (e.g., OPTIONAL -> REQUIRE_TOTP). + + Args: + organization_id: Organization ID + now: Current time, defaults to now + + Returns: + Number of compliance records updated + """ + if now is None: + now = datetime.now(timezone.utc) + + from gatehouse_app.models.organization_member import OrganizationMember + + updated_count = 0 + + # Get all active members of the organization + memberships = OrganizationMember.query.filter_by( + organization_id=organization_id, deleted_at=None + ).all() + + for membership in memberships: + user = membership.user + if not user or user.deleted_at is not None: + continue + + # Get or create compliance record + compliance = MfaPolicyCompliance.query.filter_by( + user_id=user.id, organization_id=organization_id, deleted_at=None + ).first() + + if not compliance: + compliance = MfaPolicyCompliance( + user_id=user.id, + organization_id=organization_id, + status=MfaComplianceStatus.NOT_APPLICABLE, + policy_version=0, + ) + compliance.save() + + # Reevaluate under new policy + effective_policy = MfaPolicyService.get_effective_user_policy( + user.id, organization_id + ) + + old_status = compliance.status.value if hasattr(compliance.status, 'value') else str(compliance.status) + new_status = MfaPolicyService._evaluate_compliance_status( + user, effective_policy, compliance + ) + + if old_status != new_status: + compliance.status = MfaComplianceStatus(new_status) + # Reset deadline if transitioning to in_grace from a non-grace state + if new_status == MfaComplianceStatus.IN_GRACE.value and not compliance.deadline_at: + compliance.applied_at = now + compliance.deadline_at = now.replace(tzinfo=None) + __import__("datetime").timedelta( + days=effective_policy.grace_period_days + ) + db.session.commit() + updated_count += 1 + + logger.info( + f"Reevaluated compliance for user {user.email} in org {organization_id}: " + f"{old_status} -> {new_status}" + ) + + return updated_count + + @staticmethod + def check_and_restore_user_status(user_id: str) -> bool: + """Check if user should be restored to ACTIVE status. + + Called after compliance changes to determine if a COMPLIANCE_SUSPENDED + user should be restored to ACTIVE status. This happens when: + - All org policies are now compliant + - User overrides were changed to EXEMPT + + Args: + user_id: User ID + + Returns: + True if user status was restored, False otherwise + """ + user = User.query.get(user_id) + if not user: + return False + + if user.status != UserStatus.COMPLIANCE_SUSPENDED: + return False + + # Evaluate user's overall compliance state + compliance_summary = MfaPolicyService.evaluate_user_mfa_state(user) + + # If now compliant across all orgs, restore status + if compliance_summary.overall_status == MfaComplianceStatus.COMPLIANT.value: + user.status = UserStatus.ACTIVE + db.session.commit() + + # Audit log + AuditService.log_action( + action=AuditAction.MFA_POLICY_USER_COMPLIANT, + user_id=user_id, + description="User restored to ACTIVE status after becoming MFA compliant", + ) + + logger.info(f"User {user.email} restored to ACTIVE status") + return True + + return False + + # ========================================================================= + # User Override Edge Case Handling + # ========================================================================= + + @staticmethod + def get_override_summary(user_id: str, organization_id: str) -> Dict[str, Any]: + """Get a summary of user override for an organization. + + Args: + user_id: User ID + organization_id: Organization ID + + Returns: + Dictionary with override information + """ + user_override = UserSecurityPolicy.query.filter_by( + user_id=user_id, organization_id=organization_id, deleted_at=None + ).first() + + org_policy = MfaPolicyService.get_org_policy(organization_id) + + if not user_override: + return { + "has_override": False, + "mode": "inherit", + "org_policy_mode": org_policy.mfa_policy_mode if org_policy else "none", + "effective_mode": org_policy.mfa_policy_mode if org_policy else "disabled", + } + + effective_policy = MfaPolicyService.get_effective_user_policy( + user_id, organization_id + ) + + return { + "has_override": True, + "mode": user_override.mfa_override_mode.value, + "force_totp": user_override.force_totp, + "force_webauthn": user_override.force_webauthn, + "org_policy_mode": org_policy.mfa_policy_mode if org_policy else "none", + "effective_mode": effective_policy.effective_mode, + "is_exempt": effective_policy.is_exempt, + } + + # ========================================================================= + # Security Audit Logging + # ========================================================================= + + @staticmethod + def log_suspended_login_attempt(user: User, ip_address: str = None, user_agent: str = None): + """Log a login attempt by a compliance-suspended user. + + This provides audit trail for potential security incidents where + suspended users attempt to access the system. + + Args: + user: User instance + ip_address: Client IP address + user_agent: Client user agent + """ + # Get current compliance summary + compliance_summary = MfaPolicyService.evaluate_user_mfa_state(user) + + # Find which org(s) caused suspension + suspended_orgs = [ + org for org in compliance_summary.orgs + if org.status == MfaComplianceStatus.SUSPENDED.value + ] + + org_ids = [org.organization_id for org in suspended_orgs] + + AuditService.log_action( + action=AuditAction.USER_LOGIN, + user_id=user.id, + organization_id=org_ids[0] if org_ids else None, + ip_address=ip_address, + user_agent=user_agent, + description=f"Login attempt while compliance suspended. Suspended orgs: {org_ids}", + success=False, + error_message="MFA compliance required", + ) + + @staticmethod + def log_policy_bypass_attempt( + user: User, + endpoint: str, + ip_address: str = None, + user_agent: str = None, + ): + """Log a potential policy bypass attempt. + + Called when a compliance-only session attempts to access a + full-access endpoint. This could indicate security issues. + + Args: + user: User instance + endpoint: Requested endpoint + ip_address: Client IP address + user_agent: Client user agent + """ + AuditService.log_action( + action=AuditAction.USER_LOGIN, # Reusing USER_LOGIN for audit + user_id=user.id, + ip_address=ip_address, + user_agent=user_agent, + resource_type="endpoint", + resource_id=endpoint, + description=f"Policy bypass attempt - compliance-only session accessed {endpoint}", + success=False, + error_message="MFA compliance required", + ) + + @staticmethod + def get_multi_org_aggregate_state(user: User) -> Dict[str, Any]: + """Get aggregate MFA state for a user across all organizations. + + This provides detailed breakdown of how multi-org membership affects + compliance status, useful for debugging and admin reporting. + + Args: + user: User instance + + Returns: + Dictionary with aggregate state details + """ + compliance_summary = MfaPolicyService.evaluate_user_mfa_state(user) + + # Calculate strictest requirement + modes = [org.effective_mode for org in compliance_summary.orgs] + strictest_mode = MfaPolicyService.get_strictest_mode(modes) + + # Find organizations requiring MFA + requiring_orgs = [ + { + "organization_id": org.organization_id, + "organization_name": org.organization_name, + "mode": org.effective_mode, + "status": org.status, + } + for org in compliance_summary.orgs + if org.effective_mode not in ( + MfaPolicyMode.DISABLED.value, + MfaPolicyMode.OPTIONAL.value, + ) + ] + + # Find exempt organizations + for org in compliance_summary.orgs: + override_summary = MfaPolicyService.get_override_summary( + user.id, org.organization_id + ) + if override_summary.get("is_exempt"): + requiring_orgs = [ + o for o in requiring_orgs + if o["organization_id"] != org.organization_id + ] + + return { + "overall_status": compliance_summary.overall_status, + "strictest_mode": strictest_mode, + "missing_methods": compliance_summary.missing_methods, + "deadline_at": compliance_summary.deadline_at, + "requiring_org_count": len(requiring_orgs), + "requiring_orgs": requiring_orgs, + "total_org_count": len(compliance_summary.orgs), + "per_org_details": [ + { + "organization_id": org.organization_id, + "organization_name": org.organization_name, + "effective_mode": org.effective_mode, + "status": org.status, + "deadline_at": org.deadline_at, + "applied_at": org.applied_at, + } + for org in compliance_summary.orgs + ], + } \ No newline at end of file diff --git a/gatehouse_app/services/notification_service.py b/gatehouse_app/services/notification_service.py new file mode 100644 index 0000000..d3afdca --- /dev/null +++ b/gatehouse_app/services/notification_service.py @@ -0,0 +1,430 @@ +"""Notification Service for MFA compliance notifications. + +This service handles sending MFA-related notifications to users, including: +- Deadline reminder emails +- Suspension notifications +- Compliance status updates + +The service is designed to work with or without email infrastructure: +- If email is configured, it sends actual emails +- If email is not available, it logs notifications for debugging/auditing + +Usage: + from gatehouse_app.services.notification_service import NotificationService + NotificationService.send_mfa_deadline_reminder(user, compliance, org_policy) +""" +from datetime import datetime, timezone +from typing import Optional, Dict, Any +import logging +import json + +from gatehouse_app.extensions import db +from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance +from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy +from gatehouse_app.models.user import User +from gatehouse_app.services.audit_service import AuditService +from gatehouse_app.utils.constants import AuditAction + +logger = logging.getLogger(__name__) + + +class NotificationService: + """Service for sending MFA compliance notifications.""" + + # Configuration keys for email settings + EMAIL_ENABLED_KEY = "EMAIL_ENABLED" + SMTP_HOST_KEY = "SMTP_HOST" + SMTP_PORT_KEY = "SMTP_PORT" + SMTP_USERNAME_KEY = "SMTP_USERNAME" + SMTP_PASSWORD_KEY = "SMTP_PASSWORD" + FROM_ADDRESS_KEY = "FROM_ADDRESS" + + @staticmethod + def send_mfa_deadline_reminder( + user: User, + compliance: MfaPolicyCompliance, + org_policy: OrganizationSecurityPolicy, + ) -> bool: + """Send MFA deadline reminder notification to user. + + Sends a reminder email to users who are approaching their MFA + compliance deadline. The reminder includes: + - Days remaining until deadline + - Required MFA methods + - Link to MFA enrollment + + Args: + user: User to notify + compliance: User's compliance record + org_policy: Organization's MFA policy + + Returns: + True if notification was sent successfully, False otherwise + """ + try: + # Calculate days until deadline + deadline = compliance.deadline_at + if deadline.tzinfo is None: + deadline = deadline.replace(tzinfo=timezone.utc) + + now = datetime.now(timezone.utc) + days_until_deadline = (deadline - now).days + + # Build notification content + subject = f"Action Required: MFA enrollment deadline in {days_until_deadline} days" + body = NotificationService._build_deadline_reminder_body( + user, compliance, org_policy, days_until_deadline + ) + + # Send the notification + success = NotificationService._send_email( + to_address=user.email, + subject=subject, + body=body, + ) + + if success: + logger.info( + f"Sent MFA deadline reminder to {user.email} " + f"({days_until_deadline} days remaining # Audit log +)" + ) + AuditService.log_action( + action=AuditAction.MFA_POLICY_USER_COMPLIANT, + user_id=user.id, + organization_id=compliance.organization_id, + description=f"MFA deadline reminder sent. Days remaining: {days_until_deadline}", + ) + else: + logger.warning( + f"Failed to send MFA deadline reminder to {user.email}" + ) + + return success + + except Exception as e: + logger.exception(f"Error sending MFA deadline reminder to {user.email}: {e}") + return False + + @staticmethod + def send_mfa_suspended_notification( + user: User, + compliance: MfaPolicyCompliance, + org_policy: OrganizationSecurityPolicy, + ) -> bool: + """Send MFA suspension notification to user. + + Notifies users that their account has been suspended due to + failure to comply with MFA requirements. The notification includes: + - Explanation of suspension + - Steps to restore access + - Link to MFA enrollment + + Args: + user: User to notify + compliance: User's compliance record + org_policy: Organization's MFA policy + + Returns: + True if notification was sent successfully, False otherwise + """ + try: + # Build notification content + subject = "Account Access Restricted - MFA Enrollment Required" + body = NotificationService._build_suspension_body( + user, compliance, org_policy + ) + + # Send the notification + success = NotificationService._send_email( + to_address=user.email, + subject=subject, + body=body, + ) + + if success: + logger.info(f"Sent MFA suspension notification to {user.email}") + # Audit log + AuditService.log_action( + action=AuditAction.MFA_POLICY_USER_SUSPENDED, + user_id=user.id, + organization_id=compliance.organization_id, + description="MFA compliance suspension notification sent", + ) + else: + logger.warning( + f"Failed to send MFA suspension notification to {user.email}" + ) + + return success + + except Exception as e: + logger.exception( + f"Error sending MFA suspension notification to {user.email}: {e}" + ) + return False + + @staticmethod + def _build_deadline_reminder_body( + user: User, + compliance: MfaPolicyCompliance, + org_policy: OrganizationSecurityPolicy, + days_until_deadline: int, + ) -> str: + """Build the email body for deadline reminder. + + Args: + user: User being notified + compliance: Compliance record + org_policy: Organization policy + days_until_deadline: Days remaining until deadline + + Returns: + Email body string + """ + org_name = compliance.organization_id # In real impl, fetch org name + + body = f""" +Dear {user.full_name or user.email}, + +This is a reminder that you need to set up multi-factor authentication (MFA) +to maintain access to your account in the organization "{org_name}". + +**Important Details:** +- Days remaining: {days_until_deadline} +- Deadline: {compliance.deadline_at.strftime('%Y-%m-%d %H:%M UTC') if compliance.deadline_at else 'Not set'} + +**Required MFA Methods:** +""" + + # Add required methods based on policy mode + from gatehouse_app.utils.constants import MfaPolicyMode + + mode = org_policy.mfa_policy_mode + if mode == MfaPolicyMode.REQUIRE_TOTP: + body += "- Authenticator app (TOTP)\n" + elif mode == MfaPolicyMode.REQUIRE_WEBAUTHN: + body += "- Passkey (WebAuthn)\n" + elif mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN: + body += "- Authenticator app (TOTP) OR Passkey (WebAuthn)\n" + else: + body += "- Multi-factor authentication\n" + + body += """ +**How to Set Up MFA:** +1. Log in to your account +2. Navigate to Settings > Security +3. Follow the prompts to set up an authenticator app or passkey + +If you do not set up MFA by the deadline, your account access will be restricted. + +If you have any questions, please contact your organization administrator. + +Best regards, +Gatehouse Security Team +""" + return body + + @staticmethod + def _build_suspension_body( + user: User, + compliance: MfaPolicyCompliance, + org_policy: OrganizationSecurityPolicy, + ) -> str: + """Build the email body for suspension notification. + + Args: + user: User being notified + compliance: Compliance record + org_policy: Organization policy + + Returns: + Email body string + """ + org_name = compliance.organization_id # In real impl, fetch org name + + body = f""" +Dear {user.full_name or user.email}, + +Your account access has been restricted because you did not set up +multi-factor authentication (MFA) within the required timeframe for +the organization "{org_name}". + +**What Happened:** +Your MFA compliance deadline passed without MFA being configured. +As a result, your account has been placed in a suspended state. + +**How to Restore Access:** +1. Log in to your account (you will see a compliance enrollment screen) +2. Follow the prompts to set up an authenticator app or passkey +3. Once MFA is configured, your access will be restored + +**Required MFA Methods: +""" + + # Add required methods based on policy mode + from gatehouse_app.utils.constants import MfaPolicyMode + + mode = org_policy.mfa_policy_mode + if mode == MfaPolicyMode.REQUIRE_TOTP: + body += "- Authenticator app (TOTP)\n" + elif mode == MfaPolicyMode.REQUIRE_WEBAUTHN: + body += "- Passkey (WebAuthn)\n" + elif mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN: + body += "- Authenticator app (TOTP) OR Passkey (WebAuthn)\n" + else: + body += "- Multi-factor authentication\n" + + body += """ +**Need Help?** +Contact your organization administrator if you have questions. + +Best regards, +Gatehouse Security Team +""" + return body + + @staticmethod + def _send_email( + to_address: str, + subject: str, + body: str, + html_body: Optional[str] = None, + ) -> bool: + """Send an email notification. + + This method attempts to send an email using configured SMTP settings. + If email is not configured, it logs the notification instead. + + Args: + to_address: Recipient email address + subject: Email subject + body: Plain text email body + html_body: Optional HTML email body + + Returns: + True if email was sent (or logged), False on error + """ + try: + from flask import current_app + + # Check if email is configured + email_enabled = current_app.config.get( + NotificationService.EMAIL_ENABLED_KEY, False + ) + + if not email_enabled: + # Log the notification instead of sending + logger.info( + f"[EMAIL SIMULATION] To: {to_address}\n" + f"Subject: {subject}\n" + f"Body: {body[:200]}..." if len(body) > 200 else f"Body: {body}" + ) + return True + + # Get email configuration + smtp_host = current_app.config.get(NotificationService.SMTP_HOST_KEY) + smtp_port = current_app.config.get(NotificationService.SMTP_PORT_KEY, 587) + smtp_username = current_app.config.get(NotificationService.SMTP_USERNAME_KEY) + smtp_password = current_app.config.get(NotificationService.SMTP_PASSWORD_KEY) + from_address = current_app.config.get( + NotificationService.FROM_ADDRESS_KEY, "noreply@gatehouse.local" + ) + + # Import send_email based on available mail library + try: + from flask_mail import Message + + from gatehouse_app import mail + + msg = Message( + subject=subject, + recipients=[to_address], + body=body, + html=html_body, + sender=from_address, + ) + mail.send(msg) + logger.info(f"Email sent successfully to {to_address}") + return True + + except ImportError: + # Flask-Mail not available, use SMTP directly + import smtplib + from email.mime.text import MIMEText + from email.mime.multipart import MIMEMultipart + + msg = MIMEMultipart("alternative") + msg["Subject"] = subject + msg["From"] = from_address + msg["To"] = to_address + + # Attach plain text and HTML versions + part1 = MIMEText(body, "plain") + msg.attach(part1) + + if html_body: + part2 = MIMEText(html_body, "html") + msg.attach(part2) + + # Send via SMTP + with smtplib.SMTP(smtp_host, smtp_port) as server: + server.starttls() + if smtp_username and smtp_password: + server.login(smtp_username, smtp_password) + server.send_message(msg) + + logger.info(f"Email sent successfully to {to_address}") + return True + + except Exception as e: + logger.exception(f"Failed to send email to {to_address}: {e}") + # Log the notification as fallback + logger.info( + f"[EMAIL FALLBACK] To: {to_address}\n" + f"Subject: {subject}\n" + f"Body: {body[:500]}..." if len(body) > 500 else f"Body: {body}" + ) + return True # Return True to continue processing + + @staticmethod + def get_notification_stats(user_id: str) -> Dict[str, Any]: + """Get notification statistics for a user. + + Args: + user_id: User ID + + Returns: + Dictionary with notification statistics + """ + from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance + + stats = { + "total_notifications": 0, + "last_notification": None, + "by_organization": [], + } + + compliance_records = MfaPolicyCompliance.query.filter_by( + user_id=user_id, deleted_at=None + ).all() + + total_notifications = 0 + last_notification = None + + for record in compliance_records: + total_notifications += record.notification_count + if record.last_notified_at: + if last_notification is None or record.last_notified_at > last_notification: + last_notification = record.last_notified_at + + stats["by_organization"].append({ + "organization_id": record.organization_id, + "notification_count": record.notification_count, + "last_notified_at": record.last_notified_at.isoformat() if record.last_notified_at else None, + }) + + stats["total_notifications"] = total_notifications + stats["last_notification"] = last_notification.isoformat() if last_notification else None + + return stats \ No newline at end of file diff --git a/gatehouse_app/utils/constants.py b/gatehouse_app/utils/constants.py index 801a652..d5866c2 100644 --- a/gatehouse_app/utils/constants.py +++ b/gatehouse_app/utils/constants.py @@ -9,6 +9,7 @@ class UserStatus(str, Enum): INACTIVE = "inactive" SUSPENDED = "suspended" PENDING = "pending" + COMPLIANCE_SUSPENDED = "compliance_suspended" class OrganizationRole(str, Enum): @@ -86,6 +87,12 @@ class AuditAction(str, Enum): WEBAUTHN_CREDENTIAL_DELETED = "webauthn.credential.deleted" WEBAUTHN_CREDENTIAL_RENAMED = "webauthn.credential.renamed" + # Security policy actions + ORG_SECURITY_POLICY_UPDATE = "org.security_policy.update" + USER_SECURITY_POLICY_OVERRIDE_UPDATE = "user.security_policy.override_update" + MFA_POLICY_USER_SUSPENDED = "mfa.policy.user_suspended" + MFA_POLICY_USER_COMPLIANT = "mfa.policy.user_compliant" + class OIDCGrantType(str, Enum): """OIDC grant types.""" @@ -116,3 +123,32 @@ class ErrorType: RATE_LIMIT_EXCEEDED = "RATE_LIMIT_EXCEEDED" INTERNAL_ERROR = "INTERNAL_ERROR" BAD_REQUEST = "BAD_REQUEST" + + +class MfaPolicyMode(str, Enum): + """MFA policy mode for organizations.""" + + DISABLED = "disabled" + OPTIONAL = "optional" + REQUIRE_TOTP = "require_totp" + REQUIRE_WEBAUTHN = "require_webauthn" + REQUIRE_TOTP_OR_WEBAUTHN = "require_totp_or_webauthn" + + +class MfaComplianceStatus(str, Enum): + """MFA compliance status for users per organization.""" + + NOT_APPLICABLE = "not_applicable" + PENDING = "pending" + IN_GRACE = "in_grace" + COMPLIANT = "compliant" + PAST_DUE = "past_due" + SUSPENDED = "suspended" + + +class MfaRequirementOverride(str, Enum): + """User override for organization MFA requirements.""" + + INHERIT = "inherit" + REQUIRED = "required" + EXEMPT = "exempt" diff --git a/gatehouse_app/utils/decorators.py b/gatehouse_app/utils/decorators.py index 5b571f2..e3b0085 100644 --- a/gatehouse_app/utils/decorators.py +++ b/gatehouse_app/utils/decorators.py @@ -2,7 +2,8 @@ from functools import wraps from flask import request, g from gatehouse_app.utils.response import api_response -from gatehouse_app.utils.constants import OrganizationRole +from gatehouse_app.utils.constants import OrganizationRole, UserStatus +from gatehouse_app.exceptions.auth_exceptions import UnauthorizedError, ForbiddenError def login_required(f): @@ -127,3 +128,41 @@ def require_owner(f): def require_admin(f): """Decorator to require organization admin or owner role.""" return require_role(OrganizationRole.OWNER, OrganizationRole.ADMIN)(f) + + +def full_access_required(f): + """Decorator to require full access session (not compliance-only). + + This decorator checks if the user has a compliance-only session or + is in COMPLIANCE_SUSPENDED status. If so, it returns a 403 error + with error_type "MFA_COMPLIANCE_REQUIRED". + + Use this decorator on endpoints that require full MFA compliance. + Endpoints for MFA enrollment, status, and logout should NOT use this decorator. + """ + @wraps(f) + def decorated_function(*args, **kwargs): + user = getattr(g, "current_user", None) + session = getattr(g, "current_session", None) + + if not user or not session: + return api_response( + success=False, + message="Authentication required", + status=401, + error_type="AUTH_REQUIRED", + ) + + # Check for compliance-only session or compliance suspended status + if session.is_compliance_only or user.status == UserStatus.COMPLIANCE_SUSPENDED: + return api_response( + success=False, + message="MFA compliance required to access this resource", + status=403, + error_type="MFA_COMPLIANCE_REQUIRED", + error_details={"overall_status": "suspended"}, + ) + + return f(*args, **kwargs) + + return decorated_function diff --git a/manage.py b/manage.py index 1e2c91d..088447e 100644 --- a/manage.py +++ b/manage.py @@ -14,5 +14,102 @@ app = create_app(os.getenv("FLASK_ENV", "development")) # Create Flask CLI group cli = FlaskGroup(create_app=lambda: app) + +@cli.command("run_mfa_compliance_job") +def run_mfa_compliance_job(): + """Run the MFA compliance scheduled job. + + This command processes MFA compliance transitions: + - Transitions users from PAST_DUE to SUSPENDED status + - Sends deadline reminder notifications + - Updates notification tracking metadata + + Usage: + python manage.py run_mfa_compliance_job + + This can be called via cron or a task scheduler: + 0 * * * * cd /path/to/app && python manage.py run_mfa_compliance_job + """ + from datetime import datetime, timezone + from gatehouse_app.jobs.mfa_compliance_job import process_mfa_compliance, get_job_status + + print("=" * 60) + print("MFA Compliance Job") + print("=" * 60) + + now = datetime.now(timezone.utc) + print(f"Start time: {now.isoformat()}") + print() + + # Show current status before processing + print("Current Compliance Status:") + status = get_job_status(now) + for status_name, count in status["status_counts"].items(): + print(f" {status_name}: {count}") + print(f" Approaching deadline: {status['approaching_deadline_count']}") + print(f" Past due: {status['past_due_count']}") + print() + + # Run the job + print("Processing compliance...") + result = process_mfa_compliance(now) + + print() + print("Job Results:") + print(f" Users suspended: {result['suspended_count']}") + print(f" Notifications sent: {result['notified_count']}") + print(f" Records processed: {result['processed_count']}") + + if result['errors']: + print() + print("Errors:") + for error in result['errors']: + print(f" - {error}") + + print() + print("=" * 60) + print("Job completed successfully") + print("=" * 60) + + +@cli.command("mfa_compliance_status") +def mfa_compliance_status(): + """Show current MFA compliance status. + + Usage: + python manage.py mfa_compliance_status + """ + from datetime import datetime, timezone + from gatehouse_app.jobs.mfa_compliance_job import get_job_status + + print("=" * 60) + print("MFA Compliance Status Report") + print("=" * 60) + + now = datetime.now(timezone.utc) + status = get_job_status(now) + + print(f"Report time: {status['timestamp']}") + print() + + print("Compliance Records by Status:") + for status_name, count in sorted(status["status_counts"].items()): + bar = "â–ˆ" * min(count, 50) + print(f" {status_name:20s}: {count:5d} {bar}") + + print() + print("Summary:") + print(f" Approaching deadline: {status['approaching_deadline_count']}") + print(f" Past due (pending suspension): {status['past_due_count']}") + + total = sum(status["status_counts"].values()) + compliant = status["status_counts"].get("compliant", 0) + if total > 0: + compliance_rate = (compliant / total) * 100 + print(f" Compliance rate: {compliance_rate:.1f}%") + + print("=" * 60) + + if __name__ == "__main__": cli() diff --git a/migrations/001_create_oidc_tables.py b/migrations/001_create_oidc_tables.py deleted file mode 100644 index 9c6fba3..0000000 --- a/migrations/001_create_oidc_tables.py +++ /dev/null @@ -1,150 +0,0 @@ -"""Database migration: Create OIDC tables. - -Revision ID: 001 -Revises: -Create Date: 2024-01-01 00:00:00 - -This migration creates all OIDC-related tables for the authorization code flow, -refresh token management, OIDC session tracking, token metadata, and audit logging. -""" - -from alembic import op -import sqlalchemy as sa -from sqlalchemy.dialects import postgresql - -# Revision identifiers -revision = '001' -down_revision = None -branch_labels = None -depends_on = None - - -def upgrade(): - """Create OIDC tables.""" - - # OIDC Authorization Codes table - op.create_table( - 'oidc_authorization_codes', - sa.Column('id', sa.String(36), primary_key=True), - sa.Column('created_at', sa.DateTime, nullable=False), - sa.Column('updated_at', sa.DateTime, nullable=False), - sa.Column('deleted_at', sa.DateTime, nullable=True), - sa.Column('client_id', sa.String(255), sa.ForeignKey('oidc_clients.id'), nullable=False), - sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id'), nullable=False), - sa.Column('code_hash', sa.String(255), nullable=False), - sa.Column('redirect_uri', sa.String(512), nullable=False), - sa.Column('scope', postgresql.JSON, nullable=True), - sa.Column('nonce', sa.String(255), nullable=True), - sa.Column('code_verifier', sa.String(255), nullable=True), - sa.Column('expires_at', sa.DateTime, nullable=False), - sa.Column('used_at', sa.DateTime, nullable=True), - sa.Column('is_used', sa.Boolean, default=False, nullable=False), - sa.Column('ip_address', sa.String(45), nullable=True), - sa.Column('user_agent', sa.Text, nullable=True), - ) - op.create_index('ix_oidc_authorization_codes_client_id', 'oidc_authorization_codes', ['client_id']) - op.create_index('ix_oidc_authorization_codes_user_id', 'oidc_authorization_codes', ['user_id']) - op.create_index('ix_oidc_authorization_codes_expires_at', 'oidc_authorization_codes', ['expires_at']) - - # OIDC Refresh Tokens table - op.create_table( - 'oidc_refresh_tokens', - sa.Column('id', sa.String(36), primary_key=True), - sa.Column('created_at', sa.DateTime, nullable=False), - sa.Column('updated_at', sa.DateTime, nullable=False), - sa.Column('deleted_at', sa.DateTime, nullable=True), - sa.Column('client_id', sa.String(255), sa.ForeignKey('oidc_clients.id'), nullable=False), - sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id'), nullable=False), - sa.Column('token_hash', sa.String(255), nullable=False), - sa.Column('access_token_id', sa.String(36), sa.ForeignKey('sessions.id'), nullable=True), - sa.Column('scope', postgresql.JSON, nullable=True), - sa.Column('expires_at', sa.DateTime, nullable=False), - sa.Column('revoked_at', sa.DateTime, nullable=True), - sa.Column('revoked_reason', sa.String(255), nullable=True), - sa.Column('previous_token_hash', sa.String(255), nullable=True), - sa.Column('rotation_count', sa.Integer, default=0, nullable=False), - sa.Column('ip_address', sa.String(45), nullable=True), - sa.Column('user_agent', sa.Text, nullable=True), - ) - op.create_index('ix_oidc_refresh_tokens_client_id', 'oidc_refresh_tokens', ['client_id']) - op.create_index('ix_oidc_refresh_tokens_user_id', 'oidc_refresh_tokens', ['user_id']) - op.create_index('ix_oidc_refresh_tokens_token_hash', 'oidc_refresh_tokens', ['token_hash'], unique=True) - op.create_index('ix_oidc_refresh_tokens_access_token_id', 'oidc_refresh_tokens', ['access_token_id']) - op.create_index('ix_oidc_refresh_tokens_expires_at', 'oidc_refresh_tokens', ['expires_at']) - - # OIDC Sessions table - op.create_table( - 'oidc_sessions', - sa.Column('id', sa.String(36), primary_key=True), - sa.Column('created_at', sa.DateTime, nullable=False), - sa.Column('updated_at', sa.DateTime, nullable=False), - sa.Column('deleted_at', sa.DateTime, nullable=True), - sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id'), nullable=False), - sa.Column('client_id', sa.String(255), sa.ForeignKey('oidc_clients.id'), nullable=False), - sa.Column('state', sa.String(255), nullable=False), - sa.Column('nonce', sa.String(255), nullable=True), - sa.Column('redirect_uri', sa.String(512), nullable=False), - sa.Column('scope', postgresql.JSON, nullable=True), - sa.Column('code_challenge', sa.String(255), nullable=True), - sa.Column('code_challenge_method', sa.String(10), nullable=True), - sa.Column('expires_at', sa.DateTime, nullable=False), - sa.Column('authenticated_at', sa.DateTime, nullable=True), - ) - op.create_index('ix_oidc_sessions_user_id', 'oidc_sessions', ['user_id']) - op.create_index('ix_oidc_sessions_client_id', 'oidc_sessions', ['client_id']) - op.create_index('ix_oidc_sessions_state', 'oidc_sessions', ['state']) - op.create_index('ix_oidc_sessions_expires_at', 'oidc_sessions', ['expires_at']) - - # OIDC Token Metadata table - op.create_table( - 'oidc_token_metadata', - sa.Column('id', sa.String(36), primary_key=True), - sa.Column('created_at', sa.DateTime, nullable=False), - sa.Column('updated_at', sa.DateTime, nullable=False), - sa.Column('deleted_at', sa.DateTime, nullable=True), - sa.Column('client_id', sa.String(255), sa.ForeignKey('oidc_clients.id'), nullable=False), - sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id'), nullable=False), - sa.Column('token_type', sa.String(50), nullable=False), - sa.Column('token_jti', sa.String(255), nullable=False), - sa.Column('expires_at', sa.DateTime, nullable=False), - sa.Column('revoked_at', sa.DateTime, nullable=True), - sa.Column('revoked_reason', sa.String(255), nullable=True), - ) - op.create_index('ix_oidc_token_metadata_client_id', 'oidc_token_metadata', ['client_id']) - op.create_index('ix_oidc_token_metadata_user_id', 'oidc_token_metadata', ['user_id']) - op.create_index('ix_oidc_token_metadata_token_jti', 'oidc_token_metadata', ['token_jti']) - op.create_index('ix_oidc_token_metadata_expires_at', 'oidc_token_metadata', ['expires_at']) - - # OIDC Audit Logs table - op.create_table( - 'oidc_audit_logs', - sa.Column('id', sa.String(36), primary_key=True), - sa.Column('created_at', sa.DateTime, nullable=False), - sa.Column('updated_at', sa.DateTime, nullable=False), - sa.Column('deleted_at', sa.DateTime, nullable=True), - sa.Column('event_type', sa.String(100), nullable=False), - sa.Column('client_id', sa.String(255), sa.ForeignKey('oidc_clients.id'), nullable=True), - sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id'), nullable=True), - sa.Column('success', sa.Boolean, default=True, nullable=False), - sa.Column('error_code', sa.String(100), nullable=True), - sa.Column('error_description', sa.Text, nullable=True), - sa.Column('ip_address', sa.String(45), nullable=True), - sa.Column('user_agent', sa.Text, nullable=True), - sa.Column('request_id', sa.String(36), nullable=True), - sa.Column('event_metadata', postgresql.JSON, nullable=True), - ) - op.create_index('ix_oidc_audit_logs_event_type', 'oidc_audit_logs', ['event_type']) - op.create_index('ix_oidc_audit_logs_client_id', 'oidc_audit_logs', ['client_id']) - op.create_index('ix_oidc_audit_logs_user_id', 'oidc_audit_logs', ['user_id']) - op.create_index('ix_oidc_audit_logs_success', 'oidc_audit_logs', ['success']) - op.create_index('ix_oidc_audit_logs_ip_address', 'oidc_audit_logs', ['ip_address']) - op.create_index('ix_oidc_audit_logs_request_id', 'oidc_audit_logs', ['request_id']) - - -def downgrade(): - """Drop OIDC tables.""" - op.drop_table('oidc_audit_logs') - op.drop_table('oidc_token_metadata') - op.drop_table('oidc_sessions') - op.drop_table('oidc_refresh_tokens') - op.drop_table('oidc_authorization_codes') diff --git a/migrations/002_add_webauthn_support.py b/migrations/002_add_webauthn_support.py deleted file mode 100644 index 45d21e1..0000000 --- a/migrations/002_add_webauthn_support.py +++ /dev/null @@ -1,44 +0,0 @@ -"""Database migration: Add WebAuthn support. - -Revision ID: 002 -Revises: 001 -Create Date: 2024-01-15 00:00:00 - -This migration adds support for WebAuthn passkey authentication by: -- Adding WEBAUTHN to the AuthMethodType enum (handled in application code) -- No schema changes required (uses existing provider_data JSON field) -""" - -from alembic import op -import sqlalchemy as sa -from sqlalchemy.dialects import postgresql - -# Revision identifiers -revision = '002' -down_revision = '001' -branch_labels = None -depends_on = None - - -def upgrade(): - """Add WebAuthn support - no schema changes needed.""" - # WebAuthn credentials are stored in the existing provider_data JSON field - # of the authentication_methods table. No schema changes are required. - - # Create an index for faster lookups of WebAuthn credentials by user - # This is optional but recommended for performance - # op.create_index( - # 'ix_authentication_methods_webauthn_user', - # 'authentication_methods', - # ['user_id'], - # postgresql_where=(sa.text("method_type = 'webauthn'")), - # if_not_exists=True - # ) - - pass - - -def downgrade(): - """Remove WebAuthn support - no schema changes needed.""" - # No schema changes to revert - pass \ No newline at end of file diff --git a/tests/integration/test_mfa_compliance.py b/tests/integration/test_mfa_compliance.py new file mode 100644 index 0000000..5e6bbfe --- /dev/null +++ b/tests/integration/test_mfa_compliance.py @@ -0,0 +1,933 @@ +"""Integration tests for MFA compliance enforcement.""" +import pytest +import json +from datetime import datetime, timezone, timedelta +from gatehouse_app.models.user import User +from gatehouse_app.models.organization import Organization +from gatehouse_app.models.organization_member import OrganizationMember +from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy +from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance +from gatehouse_app.models.user_security_policy import UserSecurityPolicy +from gatehouse_app.models.session import Session +from gatehouse_app.utils.constants import MfaPolicyMode, MfaComplianceStatus, UserStatus, MfaRequirementOverride +from gatehouse_app.services.mfa_policy_service import MfaPolicyService + + +@pytest.mark.integration +class TestMfaComplianceLogin: + """Integration tests for MFA compliance during login.""" + + def test_login_with_no_policy(self, client, db, test_user): + """Test login with no MFA policy (should work normally).""" + login_data = { + "email": test_user.email, + "password": "TestPassword123!", + } + + response = client.post( + "/api/v1/auth/login", + data=json.dumps(login_data), + content_type="application/json", + ) + + assert response.status_code == 200 + data = response.get_json() + assert data["success"] is True + assert "user" in data["data"] + assert "token" in data["data"] + # No MFA compliance info should be present when no policy exists + assert "mfa_compliance" not in data["data"] + assert "requires_mfa_enrollment" not in data["data"] + + def test_login_with_optional_policy(self, client, db, test_user, test_organization): + """Test login with optional MFA policy (should work normally).""" + # Create an optional MFA policy + policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.OPTIONAL, + mfa_grace_period_days=14, + notify_days_before=7, + policy_version=1, + ) + db.session.add(policy) + db.session.commit() + + login_data = { + "email": test_user.email, + "password": "TestPassword123!", + } + + response = client.post( + "/api/v1/auth/login", + data=json.dumps(login_data), + content_type="application/json", + ) + + assert response.status_code == 200 + data = response.get_json() + assert data["success"] is True + assert "user" in data["data"] + assert "token" in data["data"] + # MFA compliance should be present but status should be not_applicable + assert "mfa_compliance" in data["data"] + assert data["data"]["mfa_compliance"]["overall_status"] == "not_applicable" + assert "requires_mfa_enrollment" not in data["data"] + + def test_login_with_required_policy_in_grace_period(self, client, db, test_user, test_organization): + """Test login with required policy within grace period (should work with warning).""" + # Create a required MFA policy + policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN, + mfa_grace_period_days=14, + notify_days_before=7, + policy_version=1, + ) + db.session.add(policy) + db.session.commit() + + login_data = { + "email": test_user.email, + "password": "TestPassword123!", + } + + response = client.post( + "/api/v1/auth/login", + data=json.dumps(login_data), + content_type="application/json", + ) + + assert response.status_code == 200 + data = response.get_json() + assert data["success"] is True + assert "user" in data["data"] + assert "token" in data["data"] + # MFA compliance should be present with in_grace status + assert "mfa_compliance" in data["data"] + assert data["data"]["mfa_compliance"]["overall_status"] == "in_grace" + assert "requires_mfa_enrollment" not in data["data"] + assert "totp" in data["data"]["mfa_compliance"]["missing_methods"] + + def test_login_with_required_policy_after_deadline(self, client, db, test_user, test_organization): + """Test login with required policy after deadline (should get compliance-only session).""" + # Create a required MFA policy with past deadline + past_deadline = datetime.now(timezone.utc) - timedelta(days=1) + policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN, + mfa_grace_period_days=14, + notify_days_before=7, + policy_version=1, + ) + db.session.add(policy) + + # Create compliance record past due + compliance = MfaPolicyCompliance( + user_id=test_user.id, + organization_id=test_organization.id, + status=MfaComplianceStatus.PAST_DUE, + policy_version=1, + applied_at=datetime.now(timezone.utc) - timedelta(days=15), + deadline_at=past_deadline, + ) + db.session.add(compliance) + db.session.commit() + + login_data = { + "email": test_user.email, + "password": "TestPassword123!", + } + + response = client.post( + "/api/v1/auth/login", + data=json.dumps(login_data), + content_type="application/json", + ) + + assert response.status_code == 200 + data = response.get_json() + assert data["success"] is True + assert "user" in data["data"] + assert "token" in data["data"] + # Should have compliance-only session + assert data["data"]["requires_mfa_enrollment"] is True + assert "mfa_compliance" in data["data"] + assert data["data"]["mfa_compliance"]["overall_status"] in ["past_due", "suspended"] + + def test_login_with_suspended_user(self, client, db, test_user, test_organization): + """Test login with compliance suspended user (should get compliance-only session).""" + # Set user status to compliance suspended + test_user.status = UserStatus.COMPLIANCE_SUSPENDED + db.session.commit() + + login_data = { + "email": test_user.email, + "password": "TestPassword123!", + } + + response = client.post( + "/api/v1/auth/login", + data=json.dumps(login_data), + content_type="application/json", + ) + + assert response.status_code == 200 + data = response.get_json() + assert data["success"] is True + assert "user" in data["data"] + assert "token" in data["data"] + # Should have compliance-only session + assert data["data"]["requires_mfa_enrollment"] is True + + +@pytest.mark.integration +class TestMfaComplianceAccess: + """Integration tests for MFA compliance access control.""" + + def test_compliance_only_session_denied_full_access(self, client, db, test_user, test_organization): + """Test that compliance-only session cannot access full access endpoints.""" + # Create a required MFA policy with past deadline + past_deadline = datetime.now(timezone.utc) - timedelta(days=1) + policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN, + mfa_grace_period_days=14, + notify_days_before=7, + policy_version=1, + ) + db.session.add(policy) + + # Create compliance record past due + compliance = MfaPolicyCompliance( + user_id=test_user.id, + organization_id=test_organization.id, + status=MfaComplianceStatus.PAST_DUE, + policy_version=1, + applied_at=datetime.now(timezone.utc) - timedelta(days=15), + deadline_at=past_deadline, + ) + db.session.add(compliance) + + # Create a compliance-only session + session = Session( + user_id=test_user.id, + token="compliance_only_token", + expires_at=datetime.now(timezone.utc) + timedelta(hours=1), + is_compliance_only=True, + ) + db.session.add(session) + db.session.commit() + + # Try to access a full-access endpoint (get_my_organizations) + response = client.get( + "/api/v1/users/me/organizations", + headers={"Authorization": "Bearer compliance_only_token"}, + ) + + assert response.status_code == 403 + data = response.get_json() + assert data["success"] is False + assert data["error_type"] == "MFA_COMPLIANCE_REQUIRED" + + def test_compliance_only_session_can_access_mfa_enrollment(self, client, db, test_user, test_organization): + """Test that compliance-only session can access MFA enrollment endpoints.""" + # Create a required MFA policy with past deadline + past_deadline = datetime.now(timezone.utc) - timedelta(days=1) + policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN, + mfa_grace_period_days=14, + notify_days_before=7, + policy_version=1, + ) + db.session.add(policy) + + # Create compliance record past due + compliance = MfaPolicyCompliance( + user_id=test_user.id, + organization_id=test_organization.id, + status=MfaComplianceStatus.PAST_DUE, + policy_version=1, + applied_at=datetime.now(timezone.utc) - timedelta(days=15), + deadline_at=past_deadline, + ) + db.session.add(compliance) + + # Create a compliance-only session + session = Session( + user_id=test_user.id, + token="compliance_only_token", + expires_at=datetime.now(timezone.utc) + timedelta(hours=1), + is_compliance_only=True, + ) + db.session.add(session) + db.session.commit() + + # Try to access MFA enrollment endpoint (should work) + response = client.get( + "/api/v1/auth/totp/status", + headers={"Authorization": "Bearer compliance_only_token"}, + ) + + assert response.status_code == 200 + data = response.get_json() + assert data["success"] is True + + def test_compliance_only_session_can_access_logout(self, client, db, test_user, test_organization): + """Test that compliance-only session can access logout endpoint.""" + # Create a required MFA policy with past deadline + past_deadline = datetime.now(timezone.utc) - timedelta(days=1) + policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN, + mfa_grace_period_days=14, + notify_days_before=7, + policy_version=1, + ) + db.session.add(policy) + + # Create compliance record past due + compliance = MfaPolicyCompliance( + user_id=test_user.id, + organization_id=test_organization.id, + status=MfaComplianceStatus.PAST_DUE, + policy_version=1, + applied_at=datetime.now(timezone.utc) - timedelta(days=15), + deadline_at=past_deadline, + ) + db.session.add(compliance) + + # Create a compliance-only session + session = Session( + user_id=test_user.id, + token="compliance_only_token", + expires_at=datetime.now(timezone.utc) + timedelta(hours=1), + is_compliance_only=True, + ) + db.session.add(session) + db.session.commit() + + # Try to access logout endpoint (should work) + response = client.post( + "/api/v1/auth/logout", + headers={"Authorization": "Bearer compliance_only_token"}, + ) + + assert response.status_code == 200 + data = response.get_json() + assert data["success"] is True + + +@pytest.mark.integration +class TestMfaComplianceWebAuthn: + """Integration tests for MFA compliance with WebAuthn login.""" + + def test_webauthn_login_with_required_policy_in_grace_period(self, client, db, test_user, test_organization): + """Test WebAuthn login with required policy within grace period.""" + # Create a required MFA policy + policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN, + mfa_grace_period_days=14, + notify_days_before=7, + policy_version=1, + ) + db.session.add(policy) + db.session.commit() + + # Note: Full WebAuthn login test would require WebAuthn setup + # This test verifies the compliance response structure + login_data = { + "email": test_user.email, + "password": "TestPassword123!", + } + + response = client.post( + "/api/v1/auth/login", + data=json.dumps(login_data), + content_type="application/json", + ) + + assert response.status_code == 200 + data = response.get_json() + assert data["success"] is True + assert "mfa_compliance" in data["data"] + assert data["data"]["mfa_compliance"]["overall_status"] == "in_grace" + + +@pytest.mark.integration +class TestMfaComplianceOIDC: + """Integration tests for MFA compliance with OIDC authorization.""" + + def test_oidc_authorize_with_compliance_required(self, client, db, test_user, test_organization, app): + """Test OIDC authorize with compliance required (should show error).""" + # Create a required MFA policy with past deadline + past_deadline = datetime.now(timezone.utc) - timedelta(days=1) + policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN, + mfa_grace_period_days=14, + notify_days_before=7, + policy_version=1, + ) + db.session.add(policy) + + # Create compliance record past due + compliance = MfaPolicyCompliance( + user_id=test_user.id, + organization_id=test_organization.id, + status=MfaComplianceStatus.PAST_DUE, + policy_version=1, + applied_at=datetime.now(timezone.utc) - timedelta(days=15), + deadline_at=past_deadline, + ) + db.session.add(compliance) + db.session.commit() + + # Try OIDC authorize with credentials + response = client.post( + "/oidc/authorize", + data={ + "client_id": "test_client", + "redirect_uri": "http://localhost:8080/callback", + "response_type": "code", + "scope": "openid profile email", + "state": "test_state", + "email": test_user.email, + "password": "TestPassword123!", + }, + ) + + # Should return login page with error + assert response.status_code == 200 + assert b"Your account requires multi factor enrollment before using single sign on" in response.data + + +# ============================================================================= +# Phase 4: Edge Case Tests +# ============================================================================= + + +@pytest.mark.integration +class TestMfaComplianceMultiOrg: + """Integration tests for multi-organization MFA compliance edge cases.""" + + def test_user_with_multiple_orgs_different_policies(self, client, db, test_user): + """Test user belonging to multiple orgs with different MFA policies.""" + # Create two organizations + org1 = Organization( + name="Org1", + slug="org1-test-multi", + ) + org2 = Organization( + name="Org2", + slug="org2-test-multi", + ) + db.session.add_all([org1, org2]) + db.session.commit() + + # Add user to both orgs + membership1 = OrganizationMember( + user_id=test_user.id, + organization_id=org1.id, + role="member", + ) + membership2 = OrganizationMember( + user_id=test_user.id, + organization_id=org2.id, + role="member", + ) + db.session.add_all([membership1, membership2]) + db.session.commit() + + # Create different policies for each org + # Org1: OPTIONAL (no requirement) + policy1 = OrganizationSecurityPolicy( + organization_id=org1.id, + mfa_policy_mode=MfaPolicyMode.OPTIONAL, + mfa_grace_period_days=14, + notify_days_before=7, + policy_version=1, + ) + # Org2: REQUIRE_TOTP (strictest) + policy2 = OrganizationSecurityPolicy( + organization_id=org2.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP, + mfa_grace_period_days=14, + notify_days_before=7, + policy_version=1, + ) + db.session.add_all([policy1, policy2]) + db.session.commit() + + # Evaluate user MFA state + compliance_summary = MfaPolicyService.evaluate_user_mfa_state(test_user) + + # Overall status should reflect the strictest policy (REQUIRE_TOTP from org2) + assert compliance_summary.overall_status == MfaComplianceStatus.IN_GRACE.value + assert "totp" in compliance_summary.missing_methods + + # Verify per-org breakdown + assert len(compliance_summary.orgs) == 2 + org1_status = next((o for o in compliance_summary.orgs if o.organization_id == org1.id), None) + org2_status = next((o for o in compliance_summary.orgs if o.organization_id == org2.id), None) + + assert org1_status is not None + assert org2_status is not None + assert org1_status.status == MfaComplianceStatus.NOT_APPLICABLE.value + assert org2_status.status == MfaComplianceStatus.IN_GRACE.value + + def test_user_with_multiple_orgs_all_suspended(self, client, db, test_user): + """Test user with multiple orgs where all require MFA and are past due.""" + # Create two organizations + org1 = Organization( + name="Org1", + slug="org1-test-suspended", + ) + org2 = Organization( + name="Org2", + slug="org2-test-suspended", + ) + db.session.add_all([org1, org2]) + db.session.commit() + + # Add user to both orgs + membership1 = OrganizationMember( + user_id=test_user.id, + organization_id=org1.id, + role="member", + ) + membership2 = OrganizationMember( + user_id=test_user.id, + organization_id=org2.id, + role="member", + ) + db.session.add_all([membership1, membership2]) + db.session.commit() + + # Create required policies + policy1 = OrganizationSecurityPolicy( + organization_id=org1.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN, + mfa_grace_period_days=14, + notify_days_before=7, + policy_version=1, + ) + policy2 = OrganizationSecurityPolicy( + organization_id=org2.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_WEBAUTHN, + mfa_grace_period_days=14, + notify_days_before=7, + policy_version=1, + ) + db.session.add_all([policy1, policy2]) + db.session.commit() + + # Create past-due compliance records for both + past_deadline = datetime.now(timezone.utc) - timedelta(days=1) + compliance1 = MfaPolicyCompliance( + user_id=test_user.id, + organization_id=org1.id, + status=MfaComplianceStatus.SUSPENDED, + policy_version=1, + applied_at=datetime.now(timezone.utc) - timedelta(days=30), + deadline_at=past_deadline, + suspended_at=past_deadline, + ) + compliance2 = MfaPolicyCompliance( + user_id=test_user.id, + organization_id=org2.id, + status=MfaComplianceStatus.SUSPENDED, + policy_version=1, + applied_at=datetime.now(timezone.utc) - timedelta(days=30), + deadline_at=past_deadline, + suspended_at=past_deadline, + ) + db.session.add_all([compliance1, compliance2]) + db.session.commit() + + # Evaluate user MFA state + compliance_summary = MfaPolicyService.evaluate_user_mfa_state(test_user) + + # Overall status should be SUSPENDED + assert compliance_summary.overall_status == MfaComplianceStatus.SUSPENDED.value + + def test_strictest_mode_selection(self): + """Test that get_strictest_mode returns the most restrictive policy.""" + modes = [ + MfaPolicyMode.DISABLED.value, + MfaPolicyMode.OPTIONAL.value, + MfaPolicyMode.REQUIRE_TOTP.value, + ] + result = MfaPolicyService.get_strictest_mode(modes) + assert result == MfaPolicyMode.REQUIRE_TOTP.value + + # Test with REQUIRE_TOTP_OR_WEBAUTHN (strictest) + modes_strictest = [ + MfaPolicyMode.REQUIRE_TOTP.value, + MfaPolicyMode.REQUIRE_WEBAUTHN.value, + MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value, + ] + result = MfaPolicyService.get_strictest_mode(modes_strictest) + assert result == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value + + +@pytest.mark.integration +class TestMfaComplianceUserOverrides: + """Integration tests for user override edge cases.""" + + def test_user_override_inherit_mode(self, client, db, test_user, test_organization): + """Test INHERIT mode - org policy applies as is.""" + # Create a required policy + policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP, + mfa_grace_period_days=14, + notify_days_before=7, + policy_version=1, + ) + db.session.add(policy) + db.session.commit() + + # Create INHERIT override (default behavior) + override = UserSecurityPolicy( + user_id=test_user.id, + organization_id=test_organization.id, + mfa_override_mode=MfaRequirementOverride.INHERIT, + ) + db.session.add(override) + db.session.commit() + + # Get effective policy + effective = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id) + + # Should inherit org policy + assert effective.effective_mode == MfaPolicyMode.REQUIRE_TOTP.value + assert effective.requires_totp is True + assert effective.is_exempt is False + + def test_user_override_required_mode(self, client, db, test_user, test_organization): + """Test REQUIRED mode - user always required to have MFA.""" + # Create an optional policy + policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.OPTIONAL, + mfa_grace_period_days=14, + notify_days_before=7, + policy_version=1, + ) + db.session.add(policy) + db.session.commit() + + # Create REQUIRED override + override = UserSecurityPolicy( + user_id=test_user.id, + organization_id=test_organization.id, + mfa_override_mode=MfaRequirementOverride.REQUIRED, + ) + db.session.add(override) + db.session.commit() + + # Get effective policy + effective = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id) + + # Should be upgraded to REQUIRE_TOTP_OR_WEBAUTHN + assert effective.effective_mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value + assert effective.requires_totp is True + assert effective.requires_webauthn is True + assert effective.is_exempt is False + + def test_user_override_exempt_mode(self, client, db, test_user, test_organization): + """Test EXEMPT mode - org policy does not apply.""" + # Create a required policy + policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN, + mfa_grace_period_days=14, + notify_days_before=7, + policy_version=1, + ) + db.session.add(policy) + db.session.commit() + + # Create EXEMPT override + override = UserSecurityPolicy( + user_id=test_user.id, + organization_id=test_organization.id, + mfa_override_mode=MfaRequirementOverride.EXEMPT, + ) + db.session.add(override) + db.session.commit() + + # Get effective policy + effective = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id) + + # Should be exempt from policy + assert effective.is_exempt is True + assert effective.effective_mode == MfaPolicyMode.DISABLED.value + assert effective.requires_totp is False + assert effective.requires_webauthn is False + + def test_get_override_summary(self, client, db, test_user, test_organization): + """Test getting override summary for a user.""" + # No override exists + summary = MfaPolicyService.get_override_summary(test_user.id, test_organization.id) + + assert summary["has_override"] is False + assert summary["mode"] == "inherit" + + # Create an override + override = UserSecurityPolicy( + user_id=test_user.id, + organization_id=test_organization.id, + mfa_override_mode=MfaRequirementOverride.EXEMPT, + ) + db.session.add(override) + db.session.commit() + + # Get summary again + summary = MfaPolicyService.get_override_summary(test_user.id, test_organization.id) + + assert summary["has_override"] is True + assert summary["mode"] == "exempt" + assert summary["is_exempt"] is True + + +@pytest.mark.integration +class TestMfaCompliancePolicyChanges: + """Integration tests for policy changes affecting existing users.""" + + def test_policy_change_triggers_compliance_reevaluation(self, client, db, test_user, test_organization): + """Test that policy change triggers compliance reevaluation.""" + # Create initial optional policy + policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.OPTIONAL, + mfa_grace_period_days=14, + notify_days_before=7, + policy_version=1, + ) + db.session.add(policy) + db.session.commit() + + # Create compliance record (should be NOT_APPLICABLE) + compliance = MfaPolicyCompliance( + user_id=test_user.id, + organization_id=test_organization.id, + status=MfaComplianceStatus.NOT_APPLICABLE, + policy_version=1, + ) + db.session.add(compliance) + db.session.commit() + + # Update policy to REQUIRE_TOTP + MfaPolicyService.create_org_policy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP, + mfa_grace_period_days=14, + notify_days_before=7, + updated_by_user_id=test_user.id, + ) + + # Reevaluate all compliance + updated_count = MfaPolicyService.reevaluate_all_org_compliance(test_organization.id) + + # Should have updated at least one record + assert updated_count >= 1 + + # Check compliance status was updated + updated_compliance = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id) + assert updated_compliance.status == MfaComplianceStatus.IN_GRACE.value + assert updated_compliance.deadline_at is not None + + def test_policy_relaxation_clears_requirements(self, client, db, test_user, test_organization): + """Test that relaxing policy clears compliance requirements.""" + # Create required policy + policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP, + mfa_grace_period_days=14, + notify_days_before=7, + policy_version=1, + ) + db.session.add(policy) + db.session.commit() + + # Create IN_GRACE compliance record + compliance = MfaPolicyCompliance( + user_id=test_user.id, + organization_id=test_organization.id, + status=MfaComplianceStatus.IN_GRACE, + policy_version=1, + applied_at=datetime.now(timezone.utc), + deadline_at=datetime.now(timezone.utc) + timedelta(days=14), + ) + db.session.add(compliance) + db.session.commit() + + # Update policy to OPTIONAL + MfaPolicyService.create_org_policy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.OPTIONAL, + mfa_grace_period_days=14, + notify_days_before=7, + updated_by_user_id=test_user.id, + ) + + # Reevaluate compliance + MfaPolicyService.reevaluate_all_org_compliance(test_organization.id) + + # Check compliance status was updated to NOT_APPLICABLE + updated_compliance = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id) + assert updated_compliance.status == MfaComplianceStatus.NOT_APPLICABLE.value + + +@pytest.mark.integration +class TestMfaComplianceScheduledJob: + """Integration tests for the MFA compliance scheduled job.""" + + def test_transition_to_suspended(self, client, db, test_user, test_organization): + """Test that past-due users are transitioned to suspended.""" + # Create required policy + policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP, + mfa_grace_period_days=14, + notify_days_before=7, + policy_version=1, + ) + db.session.add(policy) + db.session.commit() + + # Create past-due compliance record + past_deadline = datetime.now(timezone.utc) - timedelta(hours=1) + compliance = MfaPolicyCompliance( + user_id=test_user.id, + organization_id=test_organization.id, + status=MfaComplianceStatus.PAST_DUE, + policy_version=1, + applied_at=datetime.now(timezone.utc) - timedelta(days=15), + deadline_at=past_deadline, + ) + db.session.add(compliance) + db.session.commit() + + # Run the job + now = datetime.now(timezone.utc) + suspended_count = MfaPolicyService.transition_to_suspended_if_past_due(now) + + # Should have suspended the user + assert suspended_count >= 1 + + # Check compliance status + updated_compliance = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id) + assert updated_compliance.status == MfaComplianceStatus.SUSPENDED.value + assert updated_compliance.suspended_at is not None + + # Check user status + db.refresh(test_user) + assert test_user.status == UserStatus.COMPLIANCE_SUSPENDED + + def test_check_and_restore_user_status(self, client, db, test_user, test_organization): + """Test that suspended users are restored when they become compliant.""" + # Create required policy + policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN, + mfa_grace_period_days=14, + notify_days_before=7, + policy_version=1, + ) + db.session.add(policy) + db.session.commit() + + # User is suspended + test_user.status = UserStatus.COMPLIANCE_SUSPENDED + db.session.commit() + + # Create EXEMPT override to clear requirement + override = UserSecurityPolicy( + user_id=test_user.id, + organization_id=test_organization.id, + mfa_override_mode=MfaRequirementOverride.EXEMPT, + ) + db.session.add(override) + db.session.commit() + + # Check and restore status + restored = MfaPolicyService.check_and_restore_user_status(test_user.id) + + # Should have restored user + assert restored is True + db.refresh(test_user) + assert test_user.status == UserStatus.ACTIVE + + +@pytest.mark.integration +class TestMfaComplianceMultiOrgAggregate: + """Integration tests for multi-org aggregate state calculation.""" + + def test_get_multi_org_aggregate_state(self, client, db, test_user): + """Test aggregate state calculation for multi-org user.""" + # Create two organizations + org1 = Organization( + name="AggOrg1", + slug="agg-org1-test", + ) + org2 = Organization( + name="AggOrg2", + slug="agg-org2-test", + ) + db.session.add_all([org1, org2]) + db.session.commit() + + # Add user to both + membership1 = OrganizationMember( + user_id=test_user.id, + organization_id=org1.id, + role="member", + ) + membership2 = OrganizationMember( + user_id=test_user.id, + organization_id=org2.id, + role="member", + ) + db.session.add_all([membership1, membership2]) + db.session.commit() + + # Create policies + policy1 = OrganizationSecurityPolicy( + organization_id=org1.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP, + mfa_grace_period_days=14, + notify_days_before=7, + policy_version=1, + ) + policy2 = OrganizationSecurityPolicy( + organization_id=org2.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_WEBAUTHN, + mfa_grace_period_days=14, + notify_days_before=7, + policy_version=1, + ) + db.session.add_all([policy1, policy2]) + db.session.commit() + + # Get aggregate state + aggregate = MfaPolicyService.get_multi_org_aggregate_state(test_user) + + # Verify structure + assert "overall_status" in aggregate + assert "strictest_mode" in aggregate + assert "missing_methods" in aggregate + assert "requiring_org_count" in aggregate + assert "requiring_orgs" in aggregate + assert "per_org_details" in aggregate + + # Strictest mode should be REQUIRE_TOTP_OR_WEBAUTHN + assert aggregate["strictest_mode"] == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value + + # Both orgs should require MFA + assert aggregate["requiring_org_count"] == 2 + assert len(aggregate["requiring_orgs"]) == 2 + assert len(aggregate["per_org_details"]) == 2 \ No newline at end of file diff --git a/tests/unit/test_mfa_policy_models.py b/tests/unit/test_mfa_policy_models.py new file mode 100644 index 0000000..07dcdd1 --- /dev/null +++ b/tests/unit/test_mfa_policy_models.py @@ -0,0 +1,295 @@ +"""Unit tests for MFA policy models.""" +import pytest +from datetime import datetime, timezone, timedelta +from gatehouse_app.models import ( + User, + Organization, + OrganizationMember, + OrganizationSecurityPolicy, + UserSecurityPolicy, + MfaPolicyCompliance, + Session, +) +from gatehouse_app.utils.constants import ( + UserStatus, + MfaPolicyMode, + MfaComplianceStatus, + MfaRequirementOverride, + SessionStatus, + OrganizationRole, +) + + +@pytest.mark.unit +class TestOrganizationSecurityPolicyModel: + """Tests for OrganizationSecurityPolicy model.""" + + def test_create_org_security_policy(self, db, test_organization): + """Test creating an organization security policy.""" + policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.OPTIONAL, + mfa_grace_period_days=14, + notify_days_before=7, + ) + policy.save() + + assert policy.id is not None + assert policy.organization_id == test_organization.id + assert policy.mfa_policy_mode == MfaPolicyMode.OPTIONAL + assert policy.mfa_grace_period_days == 14 + assert policy.notify_days_before == 7 + assert policy.policy_version == 1 + assert policy.created_at is not None + + def test_org_security_policy_to_dict(self, db, test_organization): + """Test organization security policy to_dict method.""" + policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN, + mfa_grace_period_days=7, + notify_days_before=3, + ) + policy.save() + + policy_dict = policy.to_dict() + + assert "id" in policy_dict + assert "organization_id" in policy_dict + assert policy_dict["organization_id"] == test_organization.id + assert "mfa_policy_mode" in policy_dict + assert "mfa_grace_period_days" in policy_dict + + def test_org_security_policy_relationships(self, db, test_organization): + """Test organization security policy relationships.""" + policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP, + ) + policy.save() + + # Test relationship + assert policy.organization is not None + assert policy.organization.id == test_organization.id + + +@pytest.mark.unit +class TestUserSecurityPolicyModel: + """Tests for UserSecurityPolicy model.""" + + def test_create_user_security_policy(self, db, test_user, test_organization): + """Test creating a user security policy.""" + policy = UserSecurityPolicy( + user_id=test_user.id, + organization_id=test_organization.id, + mfa_override_mode=MfaRequirementOverride.INHERIT, + ) + policy.save() + + assert policy.id is not None + assert policy.user_id == test_user.id + assert policy.organization_id == test_organization.id + assert policy.mfa_override_mode == MfaRequirementOverride.INHERIT + assert policy.force_totp is False + assert policy.force_webauthn is False + + def test_user_security_policy_with_overrides(self, db, test_user, test_organization): + """Test user security policy with override settings.""" + policy = UserSecurityPolicy( + user_id=test_user.id, + organization_id=test_organization.id, + mfa_override_mode=MfaRequirementOverride.REQUIRED, + force_totp=True, + force_webauthn=False, + ) + policy.save() + + assert policy.mfa_override_mode == MfaRequirementOverride.REQUIRED + assert policy.force_totp is True + assert policy.force_webauthn is False + + def test_user_security_policy_exempt(self, db, test_user, test_organization): + """Test user security policy with exempt override.""" + policy = UserSecurityPolicy( + user_id=test_user.id, + organization_id=test_organization.id, + mfa_override_mode=MfaRequirementOverride.EXEMPT, + ) + policy.save() + + assert policy.mfa_override_mode == MfaRequirementOverride.EXEMPT + + def test_user_security_policy_relationships(self, db, test_user, test_organization): + """Test user security policy relationships.""" + policy = UserSecurityPolicy( + user_id=test_user.id, + organization_id=test_organization.id, + mfa_override_mode=MfaRequirementOverride.INHERIT, + ) + policy.save() + + # Test relationships + assert policy.user is not None + assert policy.user.id == test_user.id + assert policy.organization is not None + assert policy.organization.id == test_organization.id + + +@pytest.mark.unit +class TestMfaPolicyComplianceModel: + """Tests for MfaPolicyCompliance model.""" + + def test_create_mfa_policy_compliance(self, db, test_user, test_organization): + """Test creating an MFA policy compliance record.""" + compliance = MfaPolicyCompliance( + user_id=test_user.id, + organization_id=test_organization.id, + status=MfaComplianceStatus.NOT_APPLICABLE, + policy_version=1, + ) + compliance.save() + + assert compliance.id is not None + assert compliance.user_id == test_user.id + assert compliance.organization_id == test_organization.id + assert compliance.status == MfaComplianceStatus.NOT_APPLICABLE + assert compliance.policy_version == 1 + assert compliance.notification_count == 0 + + def test_mfa_policy_compliance_in_grace(self, db, test_user, test_organization): + """Test MFA compliance record in grace period.""" + now = datetime.now(timezone.utc) + compliance = MfaPolicyCompliance( + user_id=test_user.id, + organization_id=test_organization.id, + status=MfaComplianceStatus.IN_GRACE, + policy_version=1, + applied_at=now, + deadline_at=now + timedelta(days=14), + ) + compliance.save() + + assert compliance.status == MfaComplianceStatus.IN_GRACE + assert compliance.applied_at is not None + assert compliance.deadline_at is not None + assert compliance.deadline_at > now + + def test_mfa_policy_compliance_compliant(self, db, test_user, test_organization): + """Test MFA compliance record when compliant.""" + now = datetime.now(timezone.utc) + compliance = MfaPolicyCompliance( + user_id=test_user.id, + organization_id=test_organization.id, + status=MfaComplianceStatus.COMPLIANT, + policy_version=1, + applied_at=now - timedelta(days=30), + deadline_at=now - timedelta(days=16), + compliant_at=now - timedelta(days=16), + ) + compliance.save() + + assert compliance.status == MfaComplianceStatus.COMPLIANT + assert compliance.compliant_at is not None + + def test_mfa_policy_compliance_suspended(self, db, test_user, test_organization): + """Test MFA compliance record when suspended.""" + now = datetime.now(timezone.utc) + compliance = MfaPolicyCompliance( + user_id=test_user.id, + organization_id=test_organization.id, + status=MfaComplianceStatus.SUSPENDED, + policy_version=1, + applied_at=now - timedelta(days=30), + deadline_at=now - timedelta(days=16), + suspended_at=now - timedelta(days=16), + ) + compliance.save() + + assert compliance.status == MfaComplianceStatus.SUSPENDED + assert compliance.suspended_at is not None + + def test_mfa_policy_compliance_relationships(self, db, test_user, test_organization): + """Test MFA compliance relationships.""" + compliance = MfaPolicyCompliance( + user_id=test_user.id, + organization_id=test_organization.id, + status=MfaComplianceStatus.NOT_APPLICABLE, + policy_version=1, + ) + compliance.save() + + # Test relationships + assert compliance.user is not None + assert compliance.user.id == test_user.id + assert compliance.organization is not None + assert compliance.organization.id == test_organization.id + + +@pytest.mark.unit +class TestSessionModelComplianceFlag: + """Tests for Session model compliance flag.""" + + def test_session_default_not_compliance_only(self, db, test_user): + """Test that sessions are not compliance only by default.""" + session = Session( + user_id=test_user.id, + token="test-token-123", + status=SessionStatus.ACTIVE, + expires_at=datetime.now(timezone.utc) + timedelta(hours=8), + last_activity_at=datetime.now(timezone.utc), + ) + session.save() + + assert session.is_compliance_only is False + + def test_session_compliance_only(self, db, test_user): + """Test creating a compliance-only session.""" + session = Session( + user_id=test_user.id, + token="compliance-token-123", + status=SessionStatus.ACTIVE, + expires_at=datetime.now(timezone.utc) + timedelta(hours=8), + last_activity_at=datetime.now(timezone.utc), + is_compliance_only=True, + ) + session.save() + + assert session.is_compliance_only is True + + def test_session_to_dict_excludes_token(self, db, test_user): + """Test that session to_dict excludes the token.""" + session = Session( + user_id=test_user.id, + token="test-token-456", + status=SessionStatus.ACTIVE, + expires_at=datetime.now(timezone.utc) + timedelta(hours=8), + last_activity_at=datetime.now(timezone.utc), + ) + session.save() + + session_dict = session.to_dict() + + assert "id" in session_dict + assert "user_id" in session_dict + assert "is_compliance_only" in session_dict + assert session_dict["is_compliance_only"] is False + + +@pytest.mark.unit +class TestUserStatusComplianceSuspended: + """Tests for UserStatus.COMPLIANCE_SUSPENDED.""" + + def test_compliance_suspended_status_exists(self): + """Test that COMPLIANCE_SUSPENDED status exists.""" + assert UserStatus.COMPLIANCE_SUSPENDED.value == "compliance_suspended" + + def test_create_compliance_suspended_user(self, db): + """Test creating a compliance suspended user.""" + user = User( + email="suspended@example.com", + full_name="Suspended User", + status=UserStatus.COMPLIANCE_SUSPENDED, + ) + user.save() + + assert user.status == UserStatus.COMPLIANCE_SUSPENDED diff --git a/tests/unit/test_services/test_mfa_policy_service.py b/tests/unit/test_services/test_mfa_policy_service.py new file mode 100644 index 0000000..6fcb749 --- /dev/null +++ b/tests/unit/test_services/test_mfa_policy_service.py @@ -0,0 +1,476 @@ +"""Unit tests for MfaPolicyService.""" +import pytest +from datetime import datetime, timezone, timedelta +from unittest.mock import patch, MagicMock + +from gatehouse_app.models import ( + User, + Organization, + OrganizationMember, + OrganizationSecurityPolicy, + UserSecurityPolicy, + MfaPolicyCompliance, + Session, +) +from gatehouse_app.services.mfa_policy_service import ( + MfaPolicyService, + OrgPolicyDto, + EffectiveUserPolicyDto, + AggregateMfaStateDto, + LoginPolicyResult, +) +from gatehouse_app.utils.constants import ( + UserStatus, + MfaPolicyMode, + MfaComplianceStatus, + MfaRequirementOverride, + SessionStatus, + OrganizationRole, +) + + +@pytest.mark.unit +class TestMfaPolicyService: + """Tests for MfaPolicyService.""" + + def test_get_org_policy_not_found(self, db, test_organization): + """Test getting organization policy when none exists.""" + policy = MfaPolicyService.get_org_policy(test_organization.id) + assert policy is None + + def test_get_org_policy_found(self, db, test_organization): + """Test getting organization policy when it exists.""" + # Create policy + org_policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN, + mfa_grace_period_days=14, + notify_days_before=7, + ) + org_policy.save() + + policy = MfaPolicyService.get_org_policy(test_organization.id) + + assert policy is not None + assert policy.organization_id == test_organization.id + assert policy.mfa_policy_mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value + assert policy.mfa_grace_period_days == 14 + assert policy.notify_days_before == 7 + assert policy.policy_version == 1 + + def test_get_effective_user_policy_no_org_policy(self, db, test_user, test_organization): + """Test effective user policy when no org policy exists.""" + policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id) + + assert policy is not None + assert policy.organization_id == test_organization.id + assert policy.effective_mode == MfaPolicyMode.DISABLED.value + assert policy.requires_totp is False + assert policy.requires_webauthn is False + assert policy.is_exempt is True + + def test_get_effective_user_policy_with_org_policy(self, db, test_user, test_organization): + """Test effective user policy with org policy and no override.""" + # Create org policy + org_policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP, + mfa_grace_period_days=14, + ) + org_policy.save() + + policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id) + + assert policy is not None + assert policy.effective_mode == MfaPolicyMode.REQUIRE_TOTP.value + assert policy.requires_totp is True + assert policy.requires_webauthn is False + assert policy.is_exempt is False + + def test_get_effective_user_policy_with_override_inherit(self, db, test_user, test_organization): + """Test effective user policy with INHERIT override.""" + # Create org policy + org_policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_WEBAUTHN, + mfa_grace_period_days=7, + ) + org_policy.save() + + # Create user override + user_override = UserSecurityPolicy( + user_id=test_user.id, + organization_id=test_organization.id, + mfa_override_mode=MfaRequirementOverride.INHERIT, + ) + user_override.save() + + policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id) + + assert policy.effective_mode == MfaPolicyMode.REQUIRE_WEBAUTHN.value + assert policy.requires_webauthn is True + + def test_get_effective_user_policy_with_override_exempt(self, db, test_user, test_organization): + """Test effective user policy with EXEMPT override.""" + # Create org policy + org_policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN, + mfa_grace_period_days=14, + ) + org_policy.save() + + # Create user override + user_override = UserSecurityPolicy( + user_id=test_user.id, + organization_id=test_organization.id, + mfa_override_mode=MfaRequirementOverride.EXEMPT, + ) + user_override.save() + + policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id) + + assert policy.effective_mode == MfaPolicyMode.DISABLED.value + assert policy.is_exempt is True + + def test_get_effective_user_policy_with_override_required(self, db, test_user, test_organization): + """Test effective user policy with REQUIRED override.""" + # Create org policy + org_policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.OPTIONAL, + mfa_grace_period_days=14, + ) + org_policy.save() + + # Create user override + user_override = UserSecurityPolicy( + user_id=test_user.id, + organization_id=test_organization.id, + mfa_override_mode=MfaRequirementOverride.REQUIRED, + ) + user_override.save() + + policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id) + + assert policy.effective_mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value + assert policy.requires_totp is True + assert policy.requires_webauthn is True + assert policy.is_exempt is False + + def test_evaluate_user_mfa_state_no_policy(self, db, test_user, test_organization): + """Test evaluating user MFA state with no policy.""" + # Create membership + membership = OrganizationMember( + user_id=test_user.id, + organization_id=test_organization.id, + role=OrganizationRole.MEMBER, + ) + membership.save() + + state = MfaPolicyService.evaluate_user_mfa_state(test_user) + + assert state is not None + assert state.overall_status == MfaComplianceStatus.COMPLIANT.value + assert len(state.missing_methods) == 0 + assert len(state.orgs) == 1 + + def test_evaluate_user_mfa_state_with_policy(self, db, test_user, test_organization): + """Test evaluating user MFA state with policy.""" + # Create membership + membership = OrganizationMember( + user_id=test_user.id, + organization_id=test_organization.id, + role=OrganizationRole.MEMBER, + ) + membership.save() + + # Create org policy + org_policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP, + mfa_grace_period_days=14, + ) + org_policy.save() + + state = MfaPolicyService.evaluate_user_mfa_state(test_user) + + assert state is not None + assert state.overall_status == MfaComplianceStatus.IN_GRACE.value + assert "totp" in state.missing_methods + assert len(state.orgs) == 1 + assert state.orgs[0].effective_mode == MfaPolicyMode.REQUIRE_TOTP.value + + def test_after_primary_auth_success_no_required_policy(self, db, test_user, test_organization): + """Test after_primary_auth_success with no required policy.""" + # Create membership + membership = OrganizationMember( + user_id=test_user.id, + organization_id=test_organization.id, + role=OrganizationRole.MEMBER, + ) + membership.save() + + result = MfaPolicyService.after_primary_auth_success(test_user) + + assert result.can_create_full_session is True + assert result.create_compliance_only_session is False + assert result.compliance_summary.overall_status == MfaComplianceStatus.COMPLIANT.value + + def test_after_primary_auth_success_in_grace(self, db, test_user, test_organization): + """Test after_primary_auth_success when user is in grace period.""" + # Create membership + membership = OrganizationMember( + user_id=test_user.id, + organization_id=test_organization.id, + role=OrganizationRole.MEMBER, + ) + membership.save() + + # Create org policy + org_policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP, + mfa_grace_period_days=14, + ) + org_policy.save() + + result = MfaPolicyService.after_primary_auth_success(test_user) + + assert result.can_create_full_session is True + assert result.create_compliance_only_session is False + assert result.compliance_summary.overall_status == MfaComplianceStatus.IN_GRACE.value + + def test_after_primary_auth_success_past_due(self, db, test_user, test_organization): + """Test after_primary_auth_success when user is past due.""" + # Create membership + membership = OrganizationMember( + user_id=test_user.id, + organization_id=test_organization.id, + role=OrganizationRole.MEMBER, + ) + membership.save() + + # Create org policy + org_policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP, + mfa_grace_period_days=14, + ) + org_policy.save() + + # Create compliance record past due + compliance = MfaPolicyCompliance( + user_id=test_user.id, + organization_id=test_organization.id, + status=MfaComplianceStatus.PAST_DUE, + policy_version=1, + applied_at=datetime.now(timezone.utc) - timedelta(days=30), + deadline_at=datetime.now(timezone.utc) - timedelta(days=1), + ) + compliance.save() + + result = MfaPolicyService.after_primary_auth_success(test_user) + + assert result.can_create_full_session is False + assert result.create_compliance_only_session is True + + def test_create_org_policy_new(self, db, test_organization): + """Test creating a new organization policy.""" + policy = MfaPolicyService.create_org_policy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN, + mfa_grace_period_days=14, + notify_days_before=7, + updated_by_user_id=None, + ) + + assert policy is not None + assert policy.organization_id == test_organization.id + assert policy.mfa_policy_mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN + assert policy.policy_version == 1 + + def test_create_org_policy_update(self, db, test_organization): + """Test updating an existing organization policy.""" + # Create initial policy + initial_policy = OrganizationSecurityPolicy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.OPTIONAL, + mfa_grace_period_days=14, + ) + initial_policy.save() + + # Update policy + updated_policy = MfaPolicyService.create_org_policy( + organization_id=test_organization.id, + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP, + mfa_grace_period_days=7, + updated_by_user_id=None, + ) + + assert updated_policy.mfa_policy_mode == MfaPolicyMode.REQUIRE_TOTP + assert updated_policy.mfa_grace_period_days == 7 + assert updated_policy.policy_version == 2 + + def test_set_user_override_new(self, db, test_user, test_organization): + """Test setting a new user override.""" + override = MfaPolicyService.set_user_override( + user_id=test_user.id, + organization_id=test_organization.id, + mfa_override_mode=MfaRequirementOverride.REQUIRED, + force_totp=True, + force_webauthn=False, + updated_by_user_id=None, + ) + + assert override is not None + assert override.user_id == test_user.id + assert override.organization_id == test_organization.id + assert override.mfa_override_mode == MfaRequirementOverride.REQUIRED + assert override.force_totp is True + + def test_set_user_override_update(self, db, test_user, test_organization): + """Test updating an existing user override.""" + # Create initial override + initial_override = UserSecurityPolicy( + user_id=test_user.id, + organization_id=test_organization.id, + mfa_override_mode=MfaRequirementOverride.INHERIT, + ) + initial_override.save() + + # Update override + updated_override = MfaPolicyService.set_user_override( + user_id=test_user.id, + organization_id=test_organization.id, + mfa_override_mode=MfaRequirementOverride.EXEMPT, + updated_by_user_id=None, + ) + + assert updated_override.mfa_override_mode == MfaRequirementOverride.EXEMPT + + def test_get_user_compliance(self, db, test_user, test_organization): + """Test getting user compliance record.""" + # Create compliance record + compliance = MfaPolicyCompliance( + user_id=test_user.id, + organization_id=test_organization.id, + status=MfaComplianceStatus.COMPLIANT, + policy_version=1, + ) + compliance.save() + + result = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id) + + assert result is not None + assert result.status == MfaComplianceStatus.COMPLIANT + + def test_get_user_compliance_not_found(self, db, test_user, test_organization): + """Test getting user compliance record when none exists.""" + result = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id) + assert result is None + + def test_get_org_compliance_list(self, db, test_user, test_organization): + """Test getting organization compliance list.""" + # Create compliance record + compliance = MfaPolicyCompliance( + user_id=test_user.id, + organization_id=test_organization.id, + status=MfaComplianceStatus.IN_GRACE, + policy_version=1, + deadline_at=datetime.now(timezone.utc) + timedelta(days=14), + ) + compliance.save() + + results = MfaPolicyService.get_org_compliance_list(test_organization.id) + + assert len(results) == 1 + assert results[0]["user_id"] == test_user.id + assert results[0]["status"] == MfaComplianceStatus.IN_GRACE.value + + def test_get_org_compliance_list_with_status_filter(self, db, test_user, test_organization): + """Test getting organization compliance list with status filter.""" + # Create compliance record + compliance = MfaPolicyCompliance( + user_id=test_user.id, + organization_id=test_organization.id, + status=MfaComplianceStatus.COMPLIANT, + policy_version=1, + ) + compliance.save() + + # Filter by different status + results = MfaPolicyService.get_org_compliance_list( + test_organization.id, status=MfaComplianceStatus.IN_GRACE + ) + assert len(results) == 0 + + # Filter by correct status + results = MfaPolicyService.get_org_compliance_list( + test_organization.id, status=MfaComplianceStatus.COMPLIANT + ) + assert len(results) == 1 + + +@pytest.mark.unit +class TestMfaPolicyServiceDto: + """Tests for MfaPolicyService DTOs.""" + + def test_org_policy_dto(self): + """Test OrgPolicyDto creation.""" + dto = OrgPolicyDto( + organization_id="org-123", + mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP.value, + mfa_grace_period_days=14, + notify_days_before=7, + policy_version=1, + ) + + assert dto.organization_id == "org-123" + assert dto.mfa_policy_mode == "require_totp" + assert dto.mfa_grace_period_days == 14 + + def test_effective_user_policy_dto(self): + """Test EffectiveUserPolicyDto creation.""" + dto = EffectiveUserPolicyDto( + organization_id="org-123", + effective_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value, + requires_totp=True, + requires_webauthn=True, + grace_period_days=14, + is_exempt=False, + ) + + assert dto.requires_totp is True + assert dto.requires_webauthn is True + assert dto.is_exempt is False + + def test_aggregate_mfa_state_dto(self): + """Test AggregateMfaStateDto creation.""" + dto = AggregateMfaStateDto( + overall_status=MfaComplianceStatus.IN_GRACE.value, + missing_methods=["totp"], + deadline_at="2025-02-01T00:00:00Z", + orgs=[], + ) + + assert dto.overall_status == "in_grace" + assert "totp" in dto.missing_methods + assert dto.deadline_at == "2025-02-01T00:00:00Z" + + def test_login_policy_result(self): + """Test LoginPolicyResult creation.""" + summary = AggregateMfaStateDto( + overall_status=MfaComplianceStatus.IN_GRACE.value, + missing_methods=["totp"], + orgs=[], + ) + result = LoginPolicyResult( + can_create_full_session=True, + create_compliance_only_session=False, + compliance_summary=summary, + ) + + assert result.can_create_full_session is True + assert result.create_compliance_only_session is False + assert result.compliance_summary.overall_status == "in_grace"