enable policies

This commit is contained in:
2026-01-16 17:31:20 +10:30
parent b2e084db33
commit d063a0ca81
28 changed files with 4296 additions and 224 deletions
-1
View File
@@ -286,7 +286,6 @@ For issues and questions:
# Boostrap db # Boostrap db
python manage.py db upgrade python manage.py db upgrade
python manage.py db migrate
+18
View File
@@ -16,6 +16,7 @@ from gatehouse_app.services.oidc_service import (
OIDCService, InvalidClientError, InvalidGrantError, InvalidRequestError OIDCService, InvalidClientError, InvalidGrantError, InvalidRequestError
) )
from gatehouse_app.services.auth_service import AuthService from gatehouse_app.services.auth_service import AuthService
from gatehouse_app.services.mfa_policy_service import MfaPolicyService
from gatehouse_app.extensions import db from gatehouse_app.extensions import db
from gatehouse_app.extensions import bcrypt as flask_bcrypt from gatehouse_app.extensions import bcrypt as flask_bcrypt
from gatehouse_app.models import User, OIDCClient from gatehouse_app.models import User, OIDCClient
@@ -372,6 +373,23 @@ def oidc_authorize():
logger.debug("[OIDC] Attempting user authentication for email: %s", email) logger.debug("[OIDC] Attempting user authentication for email: %s", email)
try: try:
user = AuthService.authenticate(email, password) user = AuthService.authenticate(email, password)
# Evaluate MFA policy after primary authentication
policy_result = MfaPolicyService.after_primary_auth_success(user, remember_me=False)
# Check if user can create full session
if not policy_result.can_create_full_session:
logger.debug("[OIDC] User cannot create full session due to MFA compliance: user_id=%s, email=%s", user.id, email)
return _show_login_page(
client_id=client_id,
redirect_uri=redirect_uri,
scope=scope,
state=state,
nonce=nonce,
response_type=response_type,
error="Your account requires multi factor enrollment before using single sign on"
)
user_id = user.id user_id = user.id
session["oidc_user_id"] = user_id session["oidc_user_id"] = user_id
+1 -1
View File
@@ -5,4 +5,4 @@ from flask import Blueprint
api_v1_bp = Blueprint("api_v1", __name__) api_v1_bp = Blueprint("api_v1", __name__)
# Import route modules to register them # Import route modules to register them
from gatehouse_app.api.v1 import auth, users, organizations from gatehouse_app.api.v1 import auth, users, organizations, policies
+124 -15
View File
@@ -22,6 +22,7 @@ from gatehouse_app.schemas.webauthn_schema import (
from gatehouse_app.services.auth_service import AuthService from gatehouse_app.services.auth_service import AuthService
from gatehouse_app.services.webauthn_service import WebAuthnService from gatehouse_app.services.webauthn_service import WebAuthnService
from gatehouse_app.services.user_service import UserService from gatehouse_app.services.user_service import UserService
from gatehouse_app.services.mfa_policy_service import MfaPolicyService
from gatehouse_app.utils.decorators import login_required from gatehouse_app.utils.decorators import login_required
from gatehouse_app.utils.constants import AuditAction from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
@@ -94,6 +95,9 @@ def login():
400: Validation error 400: Validation error
401: Invalid credentials 401: Invalid credentials
""" """
import logging
logger = logging.getLogger(__name__)
try: try:
# Validate request data # Validate request data
schema = LoginSchema() schema = LoginSchema()
@@ -105,6 +109,11 @@ def login():
password=data["password"], password=data["password"],
) )
# SECURITY CHECK: Log MFA enrollment status to validate the vulnerability
has_totp = user.has_totp_enabled()
has_webauthn = user.has_webauthn_enabled()
logger.warning(f"[SECURITY DIAGNOSTIC] Login attempt for user {user.email} - TOTP enabled: {has_totp}, WebAuthn enabled: {has_webauthn}")
# Check if user has TOTP enabled for two-factor authentication # Check if user has TOTP enabled for two-factor authentication
if user.has_totp_enabled(): if user.has_totp_enabled():
# TOTP is enabled - store user_id in session for TOTP verification # TOTP is enabled - store user_id in session for TOTP verification
@@ -121,16 +130,56 @@ def login():
) )
# TOTP is NOT enabled - proceed with normal login flow # TOTP is NOT enabled - proceed with normal login flow
# Create session with appropriate duration based on remember_me preference # SECURITY DIAGNOSTIC: This is where the vulnerability occurs - no WebAuthn check!
duration = 2592000 if data.get("remember_me") else 86400 # 30 days vs 1 day if has_webauthn:
user_session = AuthService.create_session(user, duration_seconds=duration) logger.error(f"[SECURITY VULNERABILITY DETECTED] User {user.email} has WebAuthn enrolled but is bypassing it! Creating session without MFA verification.")
return api_response( # Evaluate MFA policy after primary authentication
data={ remember_me = data.get("remember_me", False)
policy_result = MfaPolicyService.after_primary_auth_success(user, remember_me)
# Create session with appropriate duration based on remember_me preference
duration = 2592000 if remember_me else 86400 # 30 days vs 1 day
# Determine if this should be a compliance-only session
is_compliance_only = policy_result.create_compliance_only_session
user_session = AuthService.create_session(
user,
duration_seconds=duration,
is_compliance_only=is_compliance_only
)
# Build response data
response_data = {
"user": user.to_dict(), "user": user.to_dict(),
"token": user_session.token, "token": user_session.token,
"expires_at": user_session.expires_at.isoformat() + "Z" if user_session.expires_at.isoformat()[-1] != "Z" else user_session.expires_at.isoformat(), "expires_at": user_session.expires_at.isoformat() + "Z" if user_session.expires_at.isoformat()[-1] != "Z" else user_session.expires_at.isoformat(),
}, }
# Add MFA compliance information
if policy_result.compliance_summary:
response_data["mfa_compliance"] = {
"overall_status": policy_result.compliance_summary.overall_status,
"missing_methods": policy_result.compliance_summary.missing_methods,
"deadline_at": policy_result.compliance_summary.deadline_at,
"orgs": [
{
"organization_id": org.organization_id,
"organization_name": org.organization_name,
"status": org.status,
"deadline_at": org.deadline_at,
}
for org in policy_result.compliance_summary.orgs
],
}
# Add requires_mfa_enrollment flag if compliance-only session
if is_compliance_only:
response_data["requires_mfa_enrollment"] = True
return api_response(
data=response_data,
message="Login successful", message="Login successful",
) )
@@ -380,20 +429,50 @@ def verify_totp():
client_utc_timestamp=data.get("client_timestamp"), client_utc_timestamp=data.get("client_timestamp"),
) )
# Create full session # Evaluate MFA policy after primary authentication
user_session = AuthService.create_session(user) policy_result = MfaPolicyService.after_primary_auth_success(user, remember_me=False)
# Determine if this should be a compliance-only session
is_compliance_only = policy_result.create_compliance_only_session
# Create session
user_session = AuthService.create_session(user, is_compliance_only=is_compliance_only)
# Clear temporary session # Clear temporary session
session.pop("totp_pending_user_id", None) session.pop("totp_pending_user_id", None)
return api_response( # Build response data
data={ response_data = {
"user": user.to_dict(), "user": user.to_dict(),
"token": user_session.token, "token": user_session.token,
"expires_at": user_session.expires_at.isoformat() + "Z" "expires_at": user_session.expires_at.isoformat() + "Z"
if user_session.expires_at.isoformat()[-1] != "Z" if user_session.expires_at.isoformat()[-1] != "Z"
else user_session.expires_at.isoformat(), else user_session.expires_at.isoformat(),
}, }
# Add MFA compliance information
if policy_result.compliance_summary:
response_data["mfa_compliance"] = {
"overall_status": policy_result.compliance_summary.overall_status,
"missing_methods": policy_result.compliance_summary.missing_methods,
"deadline_at": policy_result.compliance_summary.deadline_at,
"orgs": [
{
"organization_id": org.organization_id,
"organization_name": org.organization_name,
"status": org.status,
"deadline_at": org.deadline_at,
}
for org in policy_result.compliance_summary.orgs
],
}
# Add requires_mfa_enrollment flag if compliance-only session
if is_compliance_only:
response_data["requires_mfa_enrollment"] = True
return api_response(
data=response_data,
message="TOTP verification successful", message="TOTP verification successful",
) )
@@ -835,22 +914,52 @@ def complete_webauthn_login():
challenge challenge
) )
# Evaluate MFA policy after primary authentication
policy_result = MfaPolicyService.after_primary_auth_success(user, remember_me=False)
# Determine if this should be a compliance-only session
is_compliance_only = policy_result.create_compliance_only_session
# Create session # Create session
user_session = AuthService.create_session(user) user_session = AuthService.create_session(user, is_compliance_only=is_compliance_only)
# Clear pending session # Clear pending session
session.pop("webauthn_pending_user_id", None) session.pop("webauthn_pending_user_id", None)
logger.info(f"WebAuthn login completed successfully for user: {user.email}") logger.info(f"WebAuthn login completed successfully for user: {user.email}")
return api_response( # Build response data
data={ response_data = {
"user": user.to_dict(), "user": user.to_dict(),
"token": user_session.token, "token": user_session.token,
"expires_at": user_session.expires_at.isoformat() + "Z" "expires_at": user_session.expires_at.isoformat() + "Z"
if user_session.expires_at.isoformat()[-1] != "Z" if user_session.expires_at.isoformat()[-1] != "Z"
else user_session.expires_at.isoformat(), else user_session.expires_at.isoformat(),
}, }
# Add MFA compliance information
if policy_result.compliance_summary:
response_data["mfa_compliance"] = {
"overall_status": policy_result.compliance_summary.overall_status,
"missing_methods": policy_result.compliance_summary.missing_methods,
"deadline_at": policy_result.compliance_summary.deadline_at,
"orgs": [
{
"organization_id": org.organization_id,
"organization_name": org.organization_name,
"status": org.status,
"deadline_at": org.deadline_at,
}
for org in policy_result.compliance_summary.orgs
],
}
# Add requires_mfa_enrollment flag if compliance-only session
if is_compliance_only:
response_data["requires_mfa_enrollment"] = True
return api_response(
data=response_data,
message="Login successful", message="Login successful",
) )
+9 -1
View File
@@ -3,7 +3,7 @@ from flask import g, request
from marshmallow import ValidationError from marshmallow import ValidationError
from gatehouse_app.api.v1 import api_v1_bp from gatehouse_app.api.v1 import api_v1_bp
from gatehouse_app.utils.response import api_response from gatehouse_app.utils.response import api_response
from gatehouse_app.utils.decorators import login_required, require_admin, require_owner from gatehouse_app.utils.decorators import login_required, require_admin, require_owner, full_access_required
from gatehouse_app.schemas.organization_schema import ( from gatehouse_app.schemas.organization_schema import (
OrganizationCreateSchema, OrganizationCreateSchema,
OrganizationUpdateSchema, OrganizationUpdateSchema,
@@ -17,6 +17,7 @@ from gatehouse_app.utils.constants import OrganizationRole
@api_v1_bp.route("/organizations", methods=["POST"]) @api_v1_bp.route("/organizations", methods=["POST"])
@login_required @login_required
@full_access_required
def create_organization(): def create_organization():
""" """
Create a new organization. Create a new organization.
@@ -65,6 +66,7 @@ def create_organization():
@api_v1_bp.route("/organizations/<org_id>", methods=["GET"]) @api_v1_bp.route("/organizations/<org_id>", methods=["GET"])
@login_required @login_required
@full_access_required
def get_organization(org_id): def get_organization(org_id):
""" """
Get organization by ID. Get organization by ID.
@@ -101,6 +103,7 @@ def get_organization(org_id):
@api_v1_bp.route("/organizations/<org_id>", methods=["PATCH"]) @api_v1_bp.route("/organizations/<org_id>", methods=["PATCH"])
@login_required @login_required
@require_admin @require_admin
@full_access_required
def update_organization(org_id): def update_organization(org_id):
""" """
Update organization. Update organization.
@@ -152,6 +155,7 @@ def update_organization(org_id):
@api_v1_bp.route("/organizations/<org_id>", methods=["DELETE"]) @api_v1_bp.route("/organizations/<org_id>", methods=["DELETE"])
@login_required @login_required
@require_owner @require_owner
@full_access_required
def delete_organization(org_id): def delete_organization(org_id):
""" """
Delete organization (soft delete). Delete organization (soft delete).
@@ -180,6 +184,7 @@ def delete_organization(org_id):
@api_v1_bp.route("/organizations/<org_id>/members", methods=["GET"]) @api_v1_bp.route("/organizations/<org_id>/members", methods=["GET"])
@login_required @login_required
@full_access_required
def get_organization_members(org_id): def get_organization_members(org_id):
""" """
Get all members of an organization. Get all members of an organization.
@@ -223,6 +228,7 @@ def get_organization_members(org_id):
@api_v1_bp.route("/organizations/<org_id>/members", methods=["POST"]) @api_v1_bp.route("/organizations/<org_id>/members", methods=["POST"])
@login_required @login_required
@require_admin @require_admin
@full_access_required
def add_organization_member(org_id): def add_organization_member(org_id):
""" """
Add a member to the organization. Add a member to the organization.
@@ -290,6 +296,7 @@ def add_organization_member(org_id):
@api_v1_bp.route("/organizations/<org_id>/members/<user_id>", methods=["DELETE"]) @api_v1_bp.route("/organizations/<org_id>/members/<user_id>", methods=["DELETE"])
@login_required @login_required
@require_admin @require_admin
@full_access_required
def remove_organization_member(org_id, user_id): def remove_organization_member(org_id, user_id):
""" """
Remove a member from the organization. Remove a member from the organization.
@@ -320,6 +327,7 @@ def remove_organization_member(org_id, user_id):
@api_v1_bp.route("/organizations/<org_id>/members/<user_id>/role", methods=["PATCH"]) @api_v1_bp.route("/organizations/<org_id>/members/<user_id>/role", methods=["PATCH"])
@login_required @login_required
@require_admin @require_admin
@full_access_required
def update_member_role(org_id, user_id): def update_member_role(org_id, user_id):
""" """
Update a member's role. Update a member's role.
+336
View File
@@ -0,0 +1,336 @@
"""Security policy endpoints."""
from flask import g, request
from marshmallow import Schema, fields, validate, ValidationError
from gatehouse_app.api.v1 import api_v1_bp
from gatehouse_app.utils.response import api_response
from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required
from gatehouse_app.services.mfa_policy_service import MfaPolicyService
from gatehouse_app.services.organization_service import OrganizationService
from gatehouse_app.services.audit_service import AuditService
from gatehouse_app.utils.constants import MfaPolicyMode, MfaRequirementOverride, MfaComplianceStatus, AuditAction
class UpdateOrgPolicySchema(Schema):
"""Schema for updating organization security policy."""
mfa_policy_mode = fields.String(
required=False,
validate=validate.OneOf([m.value for m in MfaPolicyMode])
)
mfa_grace_period_days = fields.Integer(
required=False,
validate=validate.Range(min=0, max=365)
)
notify_days_before = fields.Integer(
required=False,
validate=validate.Range(min=0, max=30)
)
class UpdateUserPolicySchema(Schema):
"""Schema for updating user security policy override."""
mfa_override_mode = fields.String(
required=True,
validate=validate.OneOf([m.value for m in MfaRequirementOverride])
)
force_totp = fields.Boolean(required=False, load_default=False)
force_webauthn = fields.Boolean(required=False, load_default=False)
class ComplianceListQuerySchema(Schema):
"""Schema for compliance list query parameters."""
status = fields.String(required=False)
limit = fields.Integer(required=False, load_default=100)
offset = fields.Integer(required=False, load_default=0)
@api_v1_bp.route("/organizations/<org_id>/security-policy", methods=["GET"])
@login_required
def get_org_security_policy(org_id):
"""
Get organization security policy.
Args:
org_id: Organization ID
Returns:
200: Organization security policy
401: Not authenticated
403: Not a member
404: Organization not found
"""
org = OrganizationService.get_organization_by_id(org_id)
# Check if user is a member
if not org.is_member(g.current_user.id):
return api_response(
success=False,
message="You are not a member of this organization",
status=403,
error_type="AUTHORIZATION_ERROR",
)
policy_dto = MfaPolicyService.get_org_policy(org_id)
if policy_dto:
data = {
"organization_id": policy_dto.organization_id,
"mfa_policy_mode": policy_dto.mfa_policy_mode,
"mfa_grace_period_days": policy_dto.mfa_grace_period_days,
"notify_days_before": policy_dto.notify_days_before,
"policy_version": policy_dto.policy_version,
}
else:
# Return default policy if none exists
data = {
"organization_id": org_id,
"mfa_policy_mode": MfaPolicyMode.OPTIONAL.value,
"mfa_grace_period_days": 14,
"notify_days_before": 7,
"policy_version": 0,
}
return api_response(
data={"security_policy": data},
message="Security policy retrieved successfully",
)
@api_v1_bp.route("/organizations/<org_id>/security-policy", methods=["PUT"])
@login_required
@require_admin
@full_access_required
def update_org_security_policy(org_id):
"""
Update organization security policy.
Args:
org_id: Organization ID
Request body:
mfa_policy_mode: MFA policy mode (disabled, optional, require_totp, require_webauthn, require_totp_or_webauthn)
mfa_grace_period_days: Grace period in days (0-365)
notify_days_before: Days before deadline to notify (0-30)
Returns:
200: Security policy updated successfully
400: Validation error
401: Not authenticated
403: Not an admin
404: Organization not found
"""
try:
schema = UpdateOrgPolicySchema()
data = schema.load(request.json)
org = OrganizationService.get_organization_by_id(org_id)
# Update policy
policy = MfaPolicyService.create_org_policy(
organization_id=org_id,
mfa_policy_mode=MfaPolicyMode(data.get("mfa_policy_mode", MfaPolicyMode.OPTIONAL.value)),
mfa_grace_period_days=data.get("mfa_grace_period_days", 14),
notify_days_before=data.get("notify_days_before", 7),
updated_by_user_id=g.current_user.id,
)
return api_response(
data={
"security_policy": {
"organization_id": policy.organization_id,
"mfa_policy_mode": policy.mfa_policy_mode.value,
"mfa_grace_period_days": policy.mfa_grace_period_days,
"notify_days_before": policy.notify_days_before,
"policy_version": policy.policy_version,
}
},
message="Security policy updated successfully",
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
@api_v1_bp.route("/organizations/<org_id>/mfa-compliance", methods=["GET"])
@login_required
@require_admin
@full_access_required
def get_org_mfa_compliance(org_id):
"""
Get MFA compliance list for an organization.
Args:
org_id: Organization ID
Query params:
status: Optional status filter (not_applicable, pending, in_grace, compliant, past_due, suspended)
limit: Maximum records to return (default 100)
offset: Offset for pagination (default 0)
Returns:
200: List of compliance records
401: Not authenticated
403: Not an admin
404: Organization not found
"""
org = OrganizationService.get_organization_by_id(org_id)
# Parse query params
status = None
if request.args.get("status"):
try:
status = MfaComplianceStatus(request.args.get("status"))
except ValueError:
return api_response(
success=False,
message="Invalid status value",
status=400,
error_type="VALIDATION_ERROR",
)
limit = min(int(request.args.get("limit", 100)), 100)
offset = int(request.args.get("offset", 0))
compliance_list = MfaPolicyService.get_org_compliance_list(
organization_id=org_id,
status=status,
limit=limit,
offset=offset,
)
return api_response(
data={
"compliance": compliance_list,
"count": len(compliance_list),
"limit": limit,
"offset": offset,
},
message="Compliance records retrieved successfully",
)
@api_v1_bp.route(
"/organizations/<org_id>/users/<user_id>/security-policy", methods=["PATCH"]
)
@login_required
@require_admin
@full_access_required
def update_user_security_policy(org_id, user_id):
"""
Update user security policy override.
Args:
org_id: Organization ID
user_id: User ID
Request body:
mfa_override_mode: Override mode (inherit, required, exempt)
force_totp: Force TOTP requirement (default false)
force_webauthn: Force WebAuthn requirement (default false)
Returns:
200: User policy updated successfully
400: Validation error
401: Not authenticated
403: Not an admin
404: Organization or user not found
"""
try:
schema = UpdateUserPolicySchema()
data = schema.load(request.json)
org = OrganizationService.get_organization_by_id(org_id)
# Check if user is a member of the organization
if not org.is_member(user_id):
return api_response(
success=False,
message="User is not a member of this organization",
status=404,
error_type="NOT_FOUND",
)
# Update user policy
policy = MfaPolicyService.set_user_override(
user_id=user_id,
organization_id=org_id,
mfa_override_mode=MfaRequirementOverride(data["mfa_override_mode"]),
force_totp=data.get("force_totp", False),
force_webauthn=data.get("force_webauthn", False),
updated_by_user_id=g.current_user.id,
)
# Log the override change with details
AuditService.log_action(
action=AuditAction.USER_SECURITY_POLICY_OVERRIDE_UPDATE,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="user",
resource_id=user_id,
description=f"User security policy override changed to {data['mfa_override_mode']} for user {user_id}",
)
return api_response(
data={
"user_security_policy": {
"user_id": policy.user_id,
"organization_id": policy.organization_id,
"mfa_override_mode": policy.mfa_override_mode.value,
"force_totp": policy.force_totp,
"force_webauthn": policy.force_webauthn,
}
},
message="User security policy updated successfully",
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
@api_v1_bp.route("/users/me/mfa-compliance", methods=["GET"])
@login_required
def get_my_mfa_compliance():
"""
Get current user's MFA compliance across all organizations.
Returns:
200: MFA compliance summary
401: Not authenticated
"""
user = g.current_user
compliance_summary = MfaPolicyService.evaluate_user_mfa_state(user)
orgs = []
for org_state in compliance_summary.orgs:
orgs.append({
"organization_id": org_state.organization_id,
"organization_name": org_state.organization_name,
"status": org_state.status,
"effective_mode": org_state.effective_mode,
"deadline_at": org_state.deadline_at,
"applied_at": org_state.applied_at,
})
return api_response(
data={
"mfa_compliance": {
"overall_status": compliance_summary.overall_status,
"missing_methods": compliance_summary.missing_methods,
"deadline_at": compliance_summary.deadline_at,
"orgs": orgs,
}
},
message="MFA compliance retrieved successfully",
)
+5 -1
View File
@@ -3,7 +3,7 @@ from flask import g, request
from marshmallow import ValidationError from marshmallow import ValidationError
from gatehouse_app.api.v1 import api_v1_bp from gatehouse_app.api.v1 import api_v1_bp
from gatehouse_app.utils.response import api_response from gatehouse_app.utils.response import api_response
from gatehouse_app.utils.decorators import login_required from gatehouse_app.utils.decorators import login_required, full_access_required
from gatehouse_app.schemas.user_schema import UserUpdateSchema, ChangePasswordSchema from gatehouse_app.schemas.user_schema import UserUpdateSchema, ChangePasswordSchema
from gatehouse_app.services.user_service import UserService from gatehouse_app.services.user_service import UserService
from gatehouse_app.services.auth_service import AuthService from gatehouse_app.services.auth_service import AuthService
@@ -29,6 +29,7 @@ def get_me():
@api_v1_bp.route("/users/me", methods=["PATCH"]) @api_v1_bp.route("/users/me", methods=["PATCH"])
@login_required @login_required
@full_access_required
def update_me(): def update_me():
""" """
Update current user profile. Update current user profile.
@@ -67,6 +68,7 @@ def update_me():
@api_v1_bp.route("/users/me", methods=["DELETE"]) @api_v1_bp.route("/users/me", methods=["DELETE"])
@login_required @login_required
@full_access_required
def delete_me(): def delete_me():
""" """
Delete current user account (soft delete). Delete current user account (soft delete).
@@ -84,6 +86,7 @@ def delete_me():
@api_v1_bp.route("/users/me/password", methods=["POST"]) @api_v1_bp.route("/users/me/password", methods=["POST"])
@login_required @login_required
@full_access_required
def change_password(): def change_password():
""" """
Change current user password. Change current user password.
@@ -136,6 +139,7 @@ def change_password():
@api_v1_bp.route("/users/me/organizations", methods=["GET"]) @api_v1_bp.route("/users/me/organizations", methods=["GET"])
@login_required @login_required
@full_access_required
def get_my_organizations(): def get_my_organizations():
""" """
Get all organizations current user is a member of. Get all organizations current user is a member of.
+1
View File
@@ -0,0 +1 @@
Jobs module for scheduled tasks.
+279
View File
@@ -0,0 +1,279 @@
"""MFA Compliance Scheduled Job.
This module implements the scheduled job for processing MFA compliance transitions,
sending notifications to users approaching deadlines, and handling edge cases.
The job is designed to be run periodically (e.g., via cron) to:
1. Transition users from PAST_DUE to SUSPENDED status
2. Send deadline reminder notifications to users in grace period
3. Update notification tracking metadata
Usage:
python manage.py run_mfa_compliance_job
Or call directly:
from gatehouse_app.jobs.mfa_compliance_job import process_mfa_compliance
process_mfa_compliance()
"""
from datetime import datetime, timezone, timedelta
from typing import Optional, Dict, Any, List
import logging
from gatehouse_app.extensions import db
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy
from gatehouse_app.models.user import User
from gatehouse_app.services.mfa_policy_service import MfaPolicyService
from gatehouse_app.services.notification_service import NotificationService
from gatehouse_app.utils.constants import MfaComplianceStatus
logger = logging.getLogger(__name__)
def process_mfa_compliance(now: Optional[datetime] = None) -> Dict[str, Any]:
"""Process MFA compliance transitions and send notifications.
This scheduled job performs the following operations:
1. Transitions users from PAST_DUE to SUSPENDED status
2. Identifies users approaching deadline (within notify_days_before)
3. Sends deadline reminder notifications
4. Updates notification tracking metadata
Args:
now: Current time, defaults to now (UTC)
Returns:
Dictionary with job execution statistics:
- suspended_count: Number of users transitioned to suspended
- notified_count: Number of notifications sent
- processed_count: Total compliance records processed
"""
if now is None:
now = datetime.now(timezone.utc)
logger.info(f"Starting MFA compliance job at {now.isoformat()}")
stats = {
"suspended_count": 0,
"notified_count": 0,
"processed_count": 0,
"errors": [],
}
try:
# Step 1: Transition past-due users to suspended
suspended_count = MfaPolicyService.transition_to_suspended_if_past_due(now)
stats["suspended_count"] = suspended_count
logger.info(f"Transitioned {suspended_count} users to suspended status")
# Step 2: Send notifications to users approaching deadline
notified_count = _send_deadline_reminders(now)
stats["notified_count"] = notified_count
logger.info(f"Sent {notified_count} deadline reminder notifications")
# Step 3: Process any pending compliance evaluations
processed_count = _evaluate_pending_compliance(now)
stats["processed_count"] = processed_count
logger.info(f"Processed {processed_count} compliance records")
except Exception as e:
logger.exception(f"Error during MFA compliance job: {e}")
stats["errors"].append(str(e))
logger.info(
f"MFA compliance job completed: suspended={stats['suspended_count']}, "
f"notified={stats['notified_count']}, processed={stats['processed_count']}"
)
return stats
def _send_deadline_reminders(now: datetime) -> int:
"""Send deadline reminder notifications to users approaching deadline.
Identifies users in grace period who are within their organization's
notify_days_before threshold and sends them reminder notifications.
Args:
now: Current time (UTC)
Returns:
Number of notifications sent
"""
notified_count = 0
# Find all compliance records in grace period
grace_records = MfaPolicyCompliance.query.filter(
MfaPolicyCompliance.status == MfaComplianceStatus.IN_GRACE,
MfaPolicyCompliance.deadline_at != None,
MfaPolicyCompliance.deleted_at == None,
).all()
for record in grace_records:
try:
# Get organization policy for notify_days_before
org_policy = OrganizationSecurityPolicy.query.filter_by(
organization_id=record.organization_id, deleted_at=None
).first()
if not org_policy:
continue
notify_threshold = org_policy.notify_days_before
deadline = record.deadline_at
# Ensure deadline has timezone
if deadline.tzinfo is None:
deadline = deadline.replace(tzinfo=timezone.utc)
# Calculate time until deadline
time_until_deadline = deadline - now
days_until_deadline = time_until_deadline.total_seconds() / 86400
# Check if we should send a reminder
should_notify = False
if days_until_deadline <= notify_threshold:
# Check if we've already notified recently (within last 24 hours)
if record.last_notified_at:
hours_since_notification = (
now - record.last_notified_at
).total_seconds() / 3600
if hours_since_notification < 24:
continue # Already notified recently
should_notify = True
if should_notify:
# Get user
user = User.query.get(record.user_id)
if not user:
continue
# Send notification
success = NotificationService.send_mfa_deadline_reminder(
user=user,
compliance=record,
org_policy=org_policy,
)
if success:
# Update notification tracking
record.last_notified_at = now
record.notification_count += 1
db.session.commit()
notified_count += 1
logger.info(
f"Sent deadline reminder to user {user.email} "
f"(days until deadline: {days_until_deadline:.1f})"
)
except Exception as e:
logger.warning(
f"Error sending reminder for compliance record "
f"{record.id}: {e}"
)
continue
return notified_count
def _evaluate_pending_compliance(now: datetime) -> int:
"""Evaluate and update pending compliance records.
This handles edge cases where compliance records may need
status updates due to policy changes or other factors.
Args:
now: Current time (UTC)
Returns:
Number of records processed
"""
processed_count = 0
# Find all non-deleted compliance records
records = MfaPolicyCompliance.query.filter(
MfaPolicyCompliance.deleted_at == None,
).all()
for record in records:
try:
# Get the user and evaluate their current state
user = User.query.get(record.user_id)
if not user:
continue
# Re-evaluate compliance status
# This handles cases where policy changed or user enrolled in MFA
from gatehouse_app.services.mfa_policy_service import MfaPolicyService
effective_policy = MfaPolicyService.get_effective_user_policy(
user.id, record.organization_id
)
new_status = MfaPolicyService._evaluate_compliance_status(
user, effective_policy, record
)
# Update status if changed
if record.status != new_status:
old_status = record.status.value if hasattr(record.status, 'value') else str(record.status)
record.status = MfaComplianceStatus(new_status)
db.session.commit()
logger.info(
f"Updated compliance status for user {user.email} "
f"in org {record.organization_id}: {old_status} -> {new_status}"
)
processed_count += 1
except Exception as e:
logger.warning(
f"Error evaluating compliance record {record.id}: {e}"
)
continue
return processed_count
def get_job_status(now: Optional[datetime] = None) -> Dict[str, Any]:
"""Get current status of MFA compliance for monitoring.
Args:
now: Current time, defaults to now (UTC)
Returns:
Dictionary with compliance statistics
"""
if now is None:
now = datetime.now(timezone.utc)
# Count records by status
status_counts = {}
for status in MfaComplianceStatus:
count = MfaPolicyCompliance.query.filter(
MfaPolicyCompliance.status == status,
MfaPolicyCompliance.deleted_at == None,
).count()
status_counts[status.value] = count
# Count users approaching deadline (within 7 days by default)
approaching_deadline = MfaPolicyCompliance.query.filter(
MfaPolicyCompliance.status == MfaComplianceStatus.IN_GRACE,
MfaPolicyCompliance.deadline_at != None,
MfaPolicyCompliance.deleted_at == None,
).count()
# Count past-due records
past_due_count = MfaPolicyCompliance.query.filter(
MfaPolicyCompliance.status == MfaComplianceStatus.PAST_DUE,
MfaPolicyCompliance.deleted_at == None,
).count()
return {
"status_counts": status_counts,
"approaching_deadline_count": approaching_deadline,
"past_due_count": past_due_count,
"timestamp": now.isoformat(),
}
+6
View File
@@ -12,6 +12,9 @@ from gatehouse_app.models.oidc_refresh_token import OIDCRefreshToken
from gatehouse_app.models.oidc_session import OIDCSession from gatehouse_app.models.oidc_session import OIDCSession
from gatehouse_app.models.oidc_token_metadata import OIDCTokenMetadata from gatehouse_app.models.oidc_token_metadata import OIDCTokenMetadata
from gatehouse_app.models.oidc_audit_log import OIDCAuditLog from gatehouse_app.models.oidc_audit_log import OIDCAuditLog
from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy
from gatehouse_app.models.user_security_policy import UserSecurityPolicy
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
__all__ = [ __all__ = [
"BaseModel", "BaseModel",
@@ -27,4 +30,7 @@ __all__ = [
"OIDCSession", "OIDCSession",
"OIDCTokenMetadata", "OIDCTokenMetadata",
"OIDCAuditLog", "OIDCAuditLog",
"OrganizationSecurityPolicy",
"UserSecurityPolicy",
"MfaPolicyCompliance",
] ]
@@ -0,0 +1,66 @@
"""MfaPolicyCompliance model."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import MfaComplianceStatus
class MfaPolicyCompliance(BaseModel):
"""MFA policy compliance tracking per user per organization.
Tracks each user's MFA compliance state separately for each organization membership.
"""
__tablename__ = "mfa_policy_compliance"
user_id = db.Column(
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
)
organization_id = db.Column(
db.String(36), db.ForeignKey("organizations.id"), nullable=False, index=True
)
status = db.Column(
db.Enum(MfaComplianceStatus),
nullable=False,
default=MfaComplianceStatus.NOT_APPLICABLE,
)
# Snapshot of org policy at the time this record became active
policy_version = db.Column(db.Integer, nullable=False)
# When policy started applying to this user
applied_at = db.Column(db.DateTime, nullable=True)
# Final deadline for this user to comply (per user, not global)
deadline_at = db.Column(db.DateTime, nullable=True)
# When they became compliant under this policy_version
compliant_at = db.Column(db.DateTime, nullable=True)
# When suspended enforcement started for this user
suspended_at = db.Column(db.DateTime, nullable=True)
# Notification tracking
last_notified_at = db.Column(db.DateTime, nullable=True)
notification_count = db.Column(db.Integer, nullable=False, default=0)
__table_args__ = (
db.UniqueConstraint(
"user_id", "organization_id", name="uix_user_org_compliance"
),
)
# Relationships
user = db.relationship(
"User", back_populates="mfa_compliance", foreign_keys=[user_id]
)
organization = db.relationship("Organization", foreign_keys=[organization_id])
def __repr__(self):
"""String representation of MfaPolicyCompliance."""
return f"<MfaPolicyCompliance user={self.user_id} org={self.organization_id} status={self.status}>"
def to_dict(self, exclude=None):
"""Convert to dictionary."""
exclude = exclude or []
return super().to_dict(exclude=exclude)
+7
View File
@@ -24,6 +24,13 @@ class Organization(BaseModel):
oidc_clients = db.relationship( oidc_clients = db.relationship(
"OIDCClient", back_populates="organization", cascade="all, delete-orphan" "OIDCClient", back_populates="organization", cascade="all, delete-orphan"
) )
security_policy = db.relationship(
"OrganizationSecurityPolicy",
back_populates="organization",
uselist=False,
cascade="all, delete-orphan",
foreign_keys="OrganizationSecurityPolicy.organization_id",
)
def __repr__(self): def __repr__(self):
"""String representation of Organization.""" """String representation of Organization."""
@@ -0,0 +1,53 @@
"""OrganizationSecurityPolicy model."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import MfaPolicyMode
class OrganizationSecurityPolicy(BaseModel):
"""Organization security policy model for MFA configuration.
One row per organization capturing its current security requirements.
"""
__tablename__ = "organization_security_policies"
organization_id = db.Column(
db.String(36),
db.ForeignKey("organizations.id"),
nullable=False,
index=True,
unique=True,
)
# MFA policy configuration
mfa_policy_mode = db.Column(
db.Enum(MfaPolicyMode), nullable=False, default=MfaPolicyMode.OPTIONAL
)
# Grace period for members in days
mfa_grace_period_days = db.Column(db.Integer, nullable=False, default=14)
# Notification settings (in days before individual user deadline)
notify_days_before = db.Column(db.Integer, nullable=False, default=7)
# Versioning for compatibility tracking
policy_version = db.Column(db.Integer, nullable=False, default=1)
# Audit metadata
updated_by_user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True)
# Relationships
organization = db.relationship(
"Organization", back_populates="security_policy", foreign_keys=[organization_id]
)
updated_by_user = db.relationship("User", foreign_keys=[updated_by_user_id])
def __repr__(self):
"""String representation of OrganizationSecurityPolicy."""
return f"<OrganizationSecurityPolicy org={self.organization_id} mode={self.mfa_policy_mode}>"
def to_dict(self, exclude=None):
"""Convert to dictionary."""
exclude = exclude or []
return super().to_dict(exclude=exclude)
+3
View File
@@ -25,6 +25,9 @@ class Session(BaseModel):
revoked_at = db.Column(db.DateTime, nullable=True) revoked_at = db.Column(db.DateTime, nullable=True)
revoked_reason = db.Column(db.String(255), nullable=True) revoked_reason = db.Column(db.String(255), nullable=True)
# Compliance session flag
is_compliance_only = db.Column(db.Boolean, nullable=False, default=False)
# Relationships # Relationships
user = db.relationship("User", back_populates="sessions") user = db.relationship("User", back_populates="sessions")
+12
View File
@@ -31,6 +31,18 @@ class User(BaseModel):
foreign_keys="OrganizationMember.user_id", foreign_keys="OrganizationMember.user_id",
) )
audit_logs = db.relationship("AuditLog", back_populates="user", cascade="all, delete-orphan") audit_logs = db.relationship("AuditLog", back_populates="user", cascade="all, delete-orphan")
security_policies = db.relationship(
"UserSecurityPolicy",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="UserSecurityPolicy.user_id",
)
mfa_compliance = db.relationship(
"MfaPolicyCompliance",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="MfaPolicyCompliance.user_id",
)
def __repr__(self): def __repr__(self):
"""String representation of User.""" """String representation of User."""
@@ -0,0 +1,53 @@
"""UserSecurityPolicy model."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import MfaRequirementOverride
class UserSecurityPolicy(BaseModel):
"""User security policy model for per-user MFA overrides.
Stores per user overrides of organization level MFA requirements.
"""
__tablename__ = "user_security_policies"
user_id = db.Column(
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
)
organization_id = db.Column(
db.String(36), db.ForeignKey("organizations.id"), nullable=False, index=True
)
mfa_override_mode = db.Column(
db.Enum(MfaRequirementOverride),
nullable=False,
default=MfaRequirementOverride.INHERIT,
)
# If override is REQUIRED and you want to force a specific factor set
force_totp = db.Column(db.Boolean, nullable=False, default=False)
force_webauthn = db.Column(db.Boolean, nullable=False, default=False)
__table_args__ = (
db.UniqueConstraint(
"user_id", "organization_id", name="uix_user_org_policy"
),
)
# Relationships
user = db.relationship(
"User", back_populates="security_policies", foreign_keys=[user_id]
)
organization = db.relationship(
"Organization", foreign_keys=[organization_id]
)
def __repr__(self):
"""String representation of UserSecurityPolicy."""
return f"<UserSecurityPolicy user={self.user_id} org={self.organization_id} mode={self.mfa_override_mode}>"
def to_dict(self, exclude=None):
"""Convert to dictionary."""
exclude = exclude or []
return super().to_dict(exclude=exclude)
+26
View File
@@ -96,3 +96,29 @@ class TOTPRegenerateBackupCodesSchema(Schema):
"""Schema for regenerating backup codes.""" """Schema for regenerating backup codes."""
password = fields.Str(required=True, validate=validate.Length(min=1)) password = fields.Str(required=True, validate=validate.Length(min=1))
class MfaComplianceOrgSchema(Schema):
"""Schema for MFA compliance per organization."""
organization_id = fields.Str(required=True)
organization_name = fields.Str(required=True)
status = fields.Str(required=True)
deadline_at = fields.Str(allow_none=True)
class MfaComplianceSchema(Schema):
"""Schema for MFA compliance summary in login response."""
overall_status = fields.Str(required=True)
missing_methods = fields.List(fields.Str(), required=True)
deadline_at = fields.Str(allow_none=True)
orgs = fields.List(fields.Nested(MfaComplianceOrgSchema), required=True)
class LoginResponseSchema(Schema):
"""Schema for login response."""
user = fields.Dict(required=True)
token = fields.Str(required=True)
expires_at = fields.Str(required=True)
requires_totp = fields.Bool(required=False)
requires_mfa_enrollment = fields.Bool(required=False)
mfa_compliance = fields.Nested(MfaComplianceSchema, required=False)
+3 -1
View File
@@ -140,13 +140,14 @@ class AuthService:
return user return user
@staticmethod @staticmethod
def create_session(user, duration_seconds=86400): def create_session(user, duration_seconds=86400, is_compliance_only=False):
""" """
Create a new session for the user. Create a new session for the user.
Args: Args:
user: User instance user: User instance
duration_seconds: Session duration in seconds duration_seconds: Session duration in seconds
is_compliance_only: Whether this is a compliance-only session (limited access)
Returns: Returns:
Session instance Session instance
@@ -163,6 +164,7 @@ class AuthService:
user_agent=request.headers.get("User-Agent"), user_agent=request.headers.get("User-Agent"),
expires_at=datetime.now(timezone.utc) + timedelta(seconds=duration_seconds), expires_at=datetime.now(timezone.utc) + timedelta(seconds=duration_seconds),
last_activity_at=datetime.now(timezone.utc), last_activity_at=datetime.now(timezone.utc),
is_compliance_only=is_compliance_only,
) )
session.save() session.save()
@@ -0,0 +1,978 @@
"""MFA Policy Service."""
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Optional, List, Dict, Any
from gatehouse_app.extensions import db
from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy
from gatehouse_app.models.user_security_policy import UserSecurityPolicy
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
from gatehouse_app.models.user import User
from gatehouse_app.models.organization import Organization
from gatehouse_app.services.audit_service import AuditService
from gatehouse_app.utils.constants import (
MfaPolicyMode,
MfaComplianceStatus,
MfaRequirementOverride,
AuditAction,
UserStatus,
)
@dataclass
class OrgPolicyDto:
"""DTO for organization policy."""
organization_id: str
mfa_policy_mode: str
mfa_grace_period_days: int
notify_days_before: int
policy_version: int
updated_by_user_id: Optional[str] = None
@dataclass
class EffectiveUserPolicyDto:
"""DTO for effective user policy combining org and user overrides."""
organization_id: str
effective_mode: str
requires_totp: bool
requires_webauthn: bool
grace_period_days: int
is_exempt: bool = False
@dataclass
class UserMfaStateDto:
"""DTO for per-organization MFA state."""
organization_id: str
organization_name: str
status: str
effective_mode: str
deadline_at: Optional[str] = None
applied_at: Optional[str] = None
@dataclass
class AggregateMfaStateDto:
"""DTO for aggregate MFA state across all organizations."""
overall_status: str
missing_methods: List[str] = field(default_factory=list)
deadline_at: Optional[str] = None
orgs: List[UserMfaStateDto] = field(default_factory=list)
@dataclass
class LoginPolicyResult:
"""Result of policy evaluation after primary auth success."""
can_create_full_session: bool
create_compliance_only_session: bool
compliance_summary: AggregateMfaStateDto
class MfaPolicyService:
"""Service for MFA policy evaluation and compliance tracking."""
@staticmethod
def get_org_policy(org_id: str) -> Optional[OrgPolicyDto]:
"""Get organization security policy.
Args:
org_id: Organization ID
Returns:
OrgPolicyDto or None if not found
"""
policy = OrganizationSecurityPolicy.query.filter_by(
organization_id=org_id, deleted_at=None
).first()
if not policy:
return None
return OrgPolicyDto(
organization_id=policy.organization_id,
mfa_policy_mode=policy.mfa_policy_mode.value,
mfa_grace_period_days=policy.mfa_grace_period_days,
notify_days_before=policy.notify_days_before,
policy_version=policy.policy_version,
updated_by_user_id=policy.updated_by_user_id,
)
@staticmethod
def get_effective_user_policy(
user_id: str, org_id: str
) -> EffectiveUserPolicyDto:
"""Get effective user policy combining org policy with user overrides.
Args:
user_id: User ID
org_id: Organization ID
Returns:
EffectiveUserPolicyDto
"""
# Get org policy
org_policy = OrganizationSecurityPolicy.query.filter_by(
organization_id=org_id, deleted_at=None
).first()
if not org_policy:
# No org policy means no requirements
return EffectiveUserPolicyDto(
organization_id=org_id,
effective_mode=MfaPolicyMode.DISABLED.value,
requires_totp=False,
requires_webauthn=False,
grace_period_days=0,
is_exempt=True,
)
# Get user override
user_override = UserSecurityPolicy.query.filter_by(
user_id=user_id, organization_id=org_id, deleted_at=None
).first()
# Determine effective mode
if user_override:
override_mode = user_override.mfa_override_mode
if override_mode == MfaRequirementOverride.EXEMPT:
return EffectiveUserPolicyDto(
organization_id=org_id,
effective_mode=MfaPolicyMode.DISABLED.value,
requires_totp=False,
requires_webauthn=False,
grace_period_days=org_policy.mfa_grace_period_days,
is_exempt=True,
)
elif override_mode == MfaRequirementOverride.REQUIRED:
# User is required to have MFA even if org is optional
effective_mode = MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN
else:
effective_mode = org_policy.mfa_policy_mode
else:
effective_mode = org_policy.mfa_policy_mode
# Determine required methods based on mode
requires_totp = effective_mode in (
MfaPolicyMode.REQUIRE_TOTP,
MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
)
requires_webauthn = effective_mode in (
MfaPolicyMode.REQUIRE_WEBAUTHN,
MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
)
return EffectiveUserPolicyDto(
organization_id=org_id,
effective_mode=effective_mode.value,
requires_totp=requires_totp,
requires_webauthn=requires_webauthn,
grace_period_days=org_policy.mfa_grace_period_days,
is_exempt=False,
)
@staticmethod
def evaluate_user_mfa_state(user: User) -> AggregateMfaStateDto:
"""Evaluate user's MFA state across all organizations.
Args:
user: User instance
Returns:
AggregateMfaStateDto with overall status and per-org breakdown
"""
org_states: List[UserMfaStateDto] = []
overall_status = MfaComplianceStatus.COMPLIANT.value
earliest_deadline: Optional[datetime] = None
missing_methods: set = set()
for membership in user.organization_memberships:
if membership.deleted_at is not None:
continue
org = membership.organization
if org.deleted_at is not None:
continue
effective_policy = MfaPolicyService.get_effective_user_policy(
user.id, org.id
)
# Get or create compliance record
compliance = MfaPolicyCompliance.query.filter_by(
user_id=user.id, organization_id=org.id, deleted_at=None
).first()
if not compliance:
# Create initial compliance record
compliance = MfaPolicyCompliance(
user_id=user.id,
organization_id=org.id,
status=MfaComplianceStatus.NOT_APPLICABLE,
policy_version=0,
)
compliance.save()
# Determine status based on policy and user MFA state
status = MfaPolicyService._evaluate_compliance_status(
user, effective_policy, compliance
)
# Update compliance record if needed
if compliance.status != status:
compliance.status = status
db.session.commit()
# Track missing methods
if status not in (
MfaComplianceStatus.COMPLIANT.value,
MfaComplianceStatus.NOT_APPLICABLE.value,
):
if effective_policy.requires_totp and not user.has_totp_enabled():
missing_methods.add("totp")
if effective_policy.requires_webauthn and not user.has_webauthn_enabled():
missing_methods.add("webauthn")
# Track earliest deadline
if compliance.deadline_at:
if earliest_deadline is None or compliance.deadline_at < earliest_deadline:
earliest_deadline = compliance.deadline_at
# Determine overall status (most restrictive)
if status == MfaComplianceStatus.SUSPENDED.value:
overall_status = MfaComplianceStatus.SUSPENDED.value
elif (
status == MfaComplianceStatus.PAST_DUE.value
and overall_status != MfaComplianceStatus.SUSPENDED.value
):
overall_status = MfaComplianceStatus.PAST_DUE.value
elif (
status == MfaComplianceStatus.IN_GRACE.value
and overall_status
not in (
MfaComplianceStatus.SUSPENDED.value,
MfaComplianceStatus.PAST_DUE.value,
)
):
overall_status = MfaComplianceStatus.IN_GRACE.value
elif (
status == MfaComplianceStatus.PENDING.value
and overall_status == MfaComplianceStatus.COMPLIANT.value
):
overall_status = MfaComplianceStatus.PENDING.value
org_states.append(
UserMfaStateDto(
organization_id=org.id,
organization_name=org.name,
status=status,
effective_mode=effective_policy.effective_mode,
deadline_at=compliance.deadline_at.isoformat() if compliance.deadline_at else None,
applied_at=compliance.applied_at.isoformat() if compliance.applied_at else None,
)
)
return AggregateMfaStateDto(
overall_status=overall_status,
missing_methods=list(missing_methods),
deadline_at=earliest_deadline.isoformat() if earliest_deadline else None,
orgs=org_states,
)
@staticmethod
def _evaluate_compliance_status(
user: User,
effective_policy: EffectiveUserPolicyDto,
compliance: MfaPolicyCompliance,
) -> str:
"""Evaluate compliance status for a user in an organization.
Args:
user: User instance
effective_policy: EffectiveUserPolicyDto
compliance: MfaPolicyCompliance instance
Returns:
Status string
"""
now = datetime.now(timezone.utc)
# If exempt or disabled, mark as not applicable
if effective_policy.is_exempt:
return MfaComplianceStatus.NOT_APPLICABLE.value
if effective_policy.effective_mode == MfaPolicyMode.DISABLED.value:
return MfaComplianceStatus.NOT_APPLICABLE.value
# Check if user has required MFA methods
has_totp = user.has_totp_enabled()
has_webauthn = user.has_webauthn_enabled()
has_required = (
(not effective_policy.requires_totp or has_totp)
and (not effective_policy.requires_webauthn or has_webauthn)
)
if has_required:
return MfaComplianceStatus.COMPLIANT.value
# User is missing required MFA
# If no deadline set, set it now
if not compliance.deadline_at and effective_policy.grace_period_days > 0:
compliance.applied_at = now
compliance.deadline_at = now.replace(
tzinfo=None
) + __import__("datetime").timedelta(
days=effective_policy.grace_period_days
)
db.session.commit()
return MfaComplianceStatus.IN_GRACE.value
# Check deadline
if compliance.deadline_at:
deadline = compliance.deadline_at
if deadline.tzinfo is None:
deadline = deadline.replace(tzinfo=timezone.utc)
if now < deadline:
return MfaComplianceStatus.IN_GRACE.value
else:
return MfaComplianceStatus.PAST_DUE.value
return MfaComplianceStatus.PENDING.value
@staticmethod
def after_primary_auth_success(
user: User, remember_me: bool = False
) -> LoginPolicyResult:
"""Determine session type based on compliance after primary auth success.
Args:
user: User instance
remember_me: Whether this is a remember-me session
Returns:
LoginPolicyResult with session type and compliance summary
"""
compliance_summary = MfaPolicyService.evaluate_user_mfa_state(user)
# Check if there are any REQUIRED policies affecting this user
has_required_policy = False
for org_state in compliance_summary.orgs:
if org_state.effective_mode in (
MfaPolicyMode.REQUIRE_TOTP.value,
MfaPolicyMode.REQUIRE_WEBAUTHN.value,
MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value,
):
has_required_policy = True
break
if not has_required_policy:
# No required policies, full session allowed
return LoginPolicyResult(
can_create_full_session=True,
create_compliance_only_session=False,
compliance_summary=compliance_summary,
)
# Check if user is compliant
if compliance_summary.overall_status == MfaComplianceStatus.COMPLIANT.value:
return LoginPolicyResult(
can_create_full_session=True,
create_compliance_only_session=False,
compliance_summary=compliance_summary,
)
# User is not compliant
if compliance_summary.overall_status in (
MfaComplianceStatus.IN_GRACE.value,
MfaComplianceStatus.PENDING.value,
):
# Can proceed with full session but warnings
return LoginPolicyResult(
can_create_full_session=True,
create_compliance_only_session=False,
compliance_summary=compliance_summary,
)
# Past due or suspended - compliance only session
return LoginPolicyResult(
can_create_full_session=False,
create_compliance_only_session=True,
compliance_summary=compliance_summary,
)
@staticmethod
def transition_to_suspended_if_past_due(now: Optional[datetime] = None) -> int:
"""Scheduled job to transition past-due users to suspended status.
Args:
now: Current time, defaults to now
Returns:
Number of users transitioned to suspended
"""
if now is None:
now = datetime.now(timezone.utc)
suspended_count = 0
# Find all compliance records that are past due
past_due_records = MfaPolicyCompliance.query.filter(
MfaPolicyCompliance.status == MfaComplianceStatus.PAST_DUE,
MfaPolicyCompliance.deadline_at != None,
MfaPolicyCompliance.deleted_at == None,
).all()
for record in past_due_records:
deadline = record.deadline_at
if deadline.tzinfo is None:
deadline = deadline.replace(tzinfo=timezone.utc)
if now >= deadline:
# Transition to suspended
record.status = MfaComplianceStatus.SUSPENDED
record.suspended_at = now
db.session.commit()
# Update user status
user = User.query.get(record.user_id)
if user and user.status != UserStatus.COMPLIANCE_SUSPENDED:
user.status = UserStatus.COMPLIANCE_SUSPENDED
db.session.commit()
# Audit log
AuditService.log_action(
action=AuditAction.MFA_POLICY_USER_SUSPENDED,
user_id=record.user_id,
organization_id=record.organization_id,
description=f"User suspended due to MFA compliance deadline passed",
)
suspended_count += 1
return suspended_count
@staticmethod
def create_org_policy(
organization_id: str,
mfa_policy_mode: MfaPolicyMode,
mfa_grace_period_days: int = 14,
notify_days_before: int = 7,
updated_by_user_id: Optional[str] = None,
) -> OrganizationSecurityPolicy:
"""Create or update organization security policy.
Args:
organization_id: Organization ID
mfa_policy_mode: MFA policy mode
mfa_grace_period_days: Grace period in days
notify_days_before: Days before deadline to notify
updated_by_user_id: User making the change
Returns:
OrganizationSecurityPolicy instance
"""
policy = OrganizationSecurityPolicy.query.filter_by(
organization_id=organization_id, deleted_at=None
).first()
if policy:
# Update existing
old_mode = policy.mfa_policy_mode
policy.mfa_policy_mode = mfa_policy_mode
policy.mfa_grace_period_days = mfa_grace_period_days
policy.notify_days_before = notify_days_before
policy.policy_version += 1
policy.updated_by_user_id = updated_by_user_id
policy.save()
# Audit log
AuditService.log_action(
action=AuditAction.ORG_SECURITY_POLICY_UPDATE,
user_id=updated_by_user_id,
organization_id=organization_id,
description=f"Security policy updated from {old_mode.value} to {mfa_policy_mode.value}",
)
else:
# Create new
policy = OrganizationSecurityPolicy(
organization_id=organization_id,
mfa_policy_mode=mfa_policy_mode,
mfa_grace_period_days=mfa_grace_period_days,
notify_days_before=notify_days_before,
policy_version=1,
updated_by_user_id=updated_by_user_id,
)
policy.save()
# Audit log
AuditService.log_action(
action=AuditAction.ORG_SECURITY_POLICY_UPDATE,
user_id=updated_by_user_id,
organization_id=organization_id,
description=f"Security policy created with mode {mfa_policy_mode.value}",
)
return policy
@staticmethod
def set_user_override(
user_id: str,
organization_id: str,
mfa_override_mode: MfaRequirementOverride,
force_totp: bool = False,
force_webauthn: bool = False,
updated_by_user_id: Optional[str] = None,
) -> UserSecurityPolicy:
"""Set user security policy override.
Args:
user_id: User ID
organization_id: Organization ID
mfa_override_mode: Override mode
force_totp: Force TOTP requirement
force_webauthn: Force WebAuthn requirement
updated_by_user_id: User making the change
Returns:
UserSecurityPolicy instance
"""
override = UserSecurityPolicy.query.filter_by(
user_id=user_id, organization_id=organization_id, deleted_at=None
).first()
if override:
old_mode = override.mfa_override_mode
override.mfa_override_mode = mfa_override_mode
override.force_totp = force_totp
override.force_webauthn = force_webauthn
override.save()
# Audit log
AuditService.log_action(
action=AuditAction.USER_SECURITY_POLICY_OVERRIDE_UPDATE,
user_id=updated_by_user_id,
organization_id=organization_id,
resource_type="user",
resource_id=user_id,
description=f"User policy override updated from {old_mode.value} to {mfa_override_mode.value}",
)
else:
override = UserSecurityPolicy(
user_id=user_id,
organization_id=organization_id,
mfa_override_mode=mfa_override_mode,
force_totp=force_totp,
force_webauthn=force_webauthn,
)
override.save()
# Audit log
AuditService.log_action(
action=AuditAction.USER_SECURITY_POLICY_OVERRIDE_UPDATE,
user_id=updated_by_user_id,
organization_id=organization_id,
resource_type="user",
resource_id=user_id,
description=f"User policy override created with mode {mfa_override_mode.value}",
)
return override
@staticmethod
def get_user_compliance(user_id: str, organization_id: str) -> Optional[MfaPolicyCompliance]:
"""Get user compliance record for an organization.
Args:
user_id: User ID
organization_id: Organization ID
Returns:
MfaPolicyCompliance or None
"""
return MfaPolicyCompliance.query.filter_by(
user_id=user_id, organization_id=organization_id, deleted_at=None
).first()
@staticmethod
def get_org_compliance_list(
organization_id: str, status: Optional[MfaComplianceStatus] = None, limit: int = 100, offset: int = 0
) -> List[Dict[str, Any]]:
"""Get list of user compliance records for an organization.
Args:
organization_id: Organization ID
status: Optional status filter
limit: Maximum records to return
offset: Offset for pagination
Returns:
List of compliance records with user info
"""
query = db.session.query(
MfaPolicyCompliance,
User.email,
User.full_name,
).join(
User, User.id == MfaPolicyCompliance.user_id
).filter(
MfaPolicyCompliance.organization_id == organization_id,
MfaPolicyCompliance.deleted_at == None,
User.deleted_at == None,
)
if status:
query = query.filter(MfaPolicyCompliance.status == status)
records = query.order_by(
MfaPolicyCompliance.created_at.desc()
).limit(limit).offset(offset).all()
result = []
for compliance, email, full_name in records:
result.append({
"user_id": compliance.user_id,
"email": email,
"full_name": full_name,
"status": compliance.status.value,
"deadline_at": compliance.deadline_at.isoformat() if compliance.deadline_at else None,
"applied_at": compliance.applied_at.isoformat() if compliance.applied_at else None,
"compliant_at": compliance.compliant_at.isoformat() if compliance.compliant_at else None,
"suspended_at": compliance.suspended_at.isoformat() if compliance.suspended_at else None,
"notification_count": compliance.notification_count,
})
return result
# =========================================================================
# Multi-Organization Edge Case Handling
# =========================================================================
@staticmethod
def get_strictest_mode(modes: List[str]) -> str:
"""Get the strictest MFA policy mode from a list.
Used for multi-org scenarios where a user belongs to multiple organizations
with different policies. "Most secure wins" logic determines the effective
requirement.
Args:
modes: List of policy mode strings
Returns:
The strictest mode string
"""
# Define strictness hierarchy (more strict = higher index)
strictness_order = [
MfaPolicyMode.DISABLED.value,
MfaPolicyMode.OPTIONAL.value,
MfaPolicyMode.REQUIRE_TOTP.value,
MfaPolicyMode.REQUIRE_WEBAUTHN.value,
MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value,
]
max_strictness = -1
result_mode = MfaPolicyMode.OPTIONAL.value
for mode in modes:
if mode in strictness_order:
idx = strictness_order.index(mode)
if idx > max_strictness:
max_strictness = idx
result_mode = mode
return result_mode
@staticmethod
def reevaluate_all_org_compliance(organization_id: str, now: Optional[datetime] = None) -> int:
"""Reevaluate compliance for all users in an organization.
Called when org policy changes to ensure all users are properly evaluated
under the new policy. This handles the edge case where policy becomes
more restrictive (e.g., OPTIONAL -> REQUIRE_TOTP).
Args:
organization_id: Organization ID
now: Current time, defaults to now
Returns:
Number of compliance records updated
"""
if now is None:
now = datetime.now(timezone.utc)
from gatehouse_app.models.organization_member import OrganizationMember
updated_count = 0
# Get all active members of the organization
memberships = OrganizationMember.query.filter_by(
organization_id=organization_id, deleted_at=None
).all()
for membership in memberships:
user = membership.user
if not user or user.deleted_at is not None:
continue
# Get or create compliance record
compliance = MfaPolicyCompliance.query.filter_by(
user_id=user.id, organization_id=organization_id, deleted_at=None
).first()
if not compliance:
compliance = MfaPolicyCompliance(
user_id=user.id,
organization_id=organization_id,
status=MfaComplianceStatus.NOT_APPLICABLE,
policy_version=0,
)
compliance.save()
# Reevaluate under new policy
effective_policy = MfaPolicyService.get_effective_user_policy(
user.id, organization_id
)
old_status = compliance.status.value if hasattr(compliance.status, 'value') else str(compliance.status)
new_status = MfaPolicyService._evaluate_compliance_status(
user, effective_policy, compliance
)
if old_status != new_status:
compliance.status = MfaComplianceStatus(new_status)
# Reset deadline if transitioning to in_grace from a non-grace state
if new_status == MfaComplianceStatus.IN_GRACE.value and not compliance.deadline_at:
compliance.applied_at = now
compliance.deadline_at = now.replace(tzinfo=None) + __import__("datetime").timedelta(
days=effective_policy.grace_period_days
)
db.session.commit()
updated_count += 1
logger.info(
f"Reevaluated compliance for user {user.email} in org {organization_id}: "
f"{old_status} -> {new_status}"
)
return updated_count
@staticmethod
def check_and_restore_user_status(user_id: str) -> bool:
"""Check if user should be restored to ACTIVE status.
Called after compliance changes to determine if a COMPLIANCE_SUSPENDED
user should be restored to ACTIVE status. This happens when:
- All org policies are now compliant
- User overrides were changed to EXEMPT
Args:
user_id: User ID
Returns:
True if user status was restored, False otherwise
"""
user = User.query.get(user_id)
if not user:
return False
if user.status != UserStatus.COMPLIANCE_SUSPENDED:
return False
# Evaluate user's overall compliance state
compliance_summary = MfaPolicyService.evaluate_user_mfa_state(user)
# If now compliant across all orgs, restore status
if compliance_summary.overall_status == MfaComplianceStatus.COMPLIANT.value:
user.status = UserStatus.ACTIVE
db.session.commit()
# Audit log
AuditService.log_action(
action=AuditAction.MFA_POLICY_USER_COMPLIANT,
user_id=user_id,
description="User restored to ACTIVE status after becoming MFA compliant",
)
logger.info(f"User {user.email} restored to ACTIVE status")
return True
return False
# =========================================================================
# User Override Edge Case Handling
# =========================================================================
@staticmethod
def get_override_summary(user_id: str, organization_id: str) -> Dict[str, Any]:
"""Get a summary of user override for an organization.
Args:
user_id: User ID
organization_id: Organization ID
Returns:
Dictionary with override information
"""
user_override = UserSecurityPolicy.query.filter_by(
user_id=user_id, organization_id=organization_id, deleted_at=None
).first()
org_policy = MfaPolicyService.get_org_policy(organization_id)
if not user_override:
return {
"has_override": False,
"mode": "inherit",
"org_policy_mode": org_policy.mfa_policy_mode if org_policy else "none",
"effective_mode": org_policy.mfa_policy_mode if org_policy else "disabled",
}
effective_policy = MfaPolicyService.get_effective_user_policy(
user_id, organization_id
)
return {
"has_override": True,
"mode": user_override.mfa_override_mode.value,
"force_totp": user_override.force_totp,
"force_webauthn": user_override.force_webauthn,
"org_policy_mode": org_policy.mfa_policy_mode if org_policy else "none",
"effective_mode": effective_policy.effective_mode,
"is_exempt": effective_policy.is_exempt,
}
# =========================================================================
# Security Audit Logging
# =========================================================================
@staticmethod
def log_suspended_login_attempt(user: User, ip_address: str = None, user_agent: str = None):
"""Log a login attempt by a compliance-suspended user.
This provides audit trail for potential security incidents where
suspended users attempt to access the system.
Args:
user: User instance
ip_address: Client IP address
user_agent: Client user agent
"""
# Get current compliance summary
compliance_summary = MfaPolicyService.evaluate_user_mfa_state(user)
# Find which org(s) caused suspension
suspended_orgs = [
org for org in compliance_summary.orgs
if org.status == MfaComplianceStatus.SUSPENDED.value
]
org_ids = [org.organization_id for org in suspended_orgs]
AuditService.log_action(
action=AuditAction.USER_LOGIN,
user_id=user.id,
organization_id=org_ids[0] if org_ids else None,
ip_address=ip_address,
user_agent=user_agent,
description=f"Login attempt while compliance suspended. Suspended orgs: {org_ids}",
success=False,
error_message="MFA compliance required",
)
@staticmethod
def log_policy_bypass_attempt(
user: User,
endpoint: str,
ip_address: str = None,
user_agent: str = None,
):
"""Log a potential policy bypass attempt.
Called when a compliance-only session attempts to access a
full-access endpoint. This could indicate security issues.
Args:
user: User instance
endpoint: Requested endpoint
ip_address: Client IP address
user_agent: Client user agent
"""
AuditService.log_action(
action=AuditAction.USER_LOGIN, # Reusing USER_LOGIN for audit
user_id=user.id,
ip_address=ip_address,
user_agent=user_agent,
resource_type="endpoint",
resource_id=endpoint,
description=f"Policy bypass attempt - compliance-only session accessed {endpoint}",
success=False,
error_message="MFA compliance required",
)
@staticmethod
def get_multi_org_aggregate_state(user: User) -> Dict[str, Any]:
"""Get aggregate MFA state for a user across all organizations.
This provides detailed breakdown of how multi-org membership affects
compliance status, useful for debugging and admin reporting.
Args:
user: User instance
Returns:
Dictionary with aggregate state details
"""
compliance_summary = MfaPolicyService.evaluate_user_mfa_state(user)
# Calculate strictest requirement
modes = [org.effective_mode for org in compliance_summary.orgs]
strictest_mode = MfaPolicyService.get_strictest_mode(modes)
# Find organizations requiring MFA
requiring_orgs = [
{
"organization_id": org.organization_id,
"organization_name": org.organization_name,
"mode": org.effective_mode,
"status": org.status,
}
for org in compliance_summary.orgs
if org.effective_mode not in (
MfaPolicyMode.DISABLED.value,
MfaPolicyMode.OPTIONAL.value,
)
]
# Find exempt organizations
for org in compliance_summary.orgs:
override_summary = MfaPolicyService.get_override_summary(
user.id, org.organization_id
)
if override_summary.get("is_exempt"):
requiring_orgs = [
o for o in requiring_orgs
if o["organization_id"] != org.organization_id
]
return {
"overall_status": compliance_summary.overall_status,
"strictest_mode": strictest_mode,
"missing_methods": compliance_summary.missing_methods,
"deadline_at": compliance_summary.deadline_at,
"requiring_org_count": len(requiring_orgs),
"requiring_orgs": requiring_orgs,
"total_org_count": len(compliance_summary.orgs),
"per_org_details": [
{
"organization_id": org.organization_id,
"organization_name": org.organization_name,
"effective_mode": org.effective_mode,
"status": org.status,
"deadline_at": org.deadline_at,
"applied_at": org.applied_at,
}
for org in compliance_summary.orgs
],
}
@@ -0,0 +1,430 @@
"""Notification Service for MFA compliance notifications.
This service handles sending MFA-related notifications to users, including:
- Deadline reminder emails
- Suspension notifications
- Compliance status updates
The service is designed to work with or without email infrastructure:
- If email is configured, it sends actual emails
- If email is not available, it logs notifications for debugging/auditing
Usage:
from gatehouse_app.services.notification_service import NotificationService
NotificationService.send_mfa_deadline_reminder(user, compliance, org_policy)
"""
from datetime import datetime, timezone
from typing import Optional, Dict, Any
import logging
import json
from gatehouse_app.extensions import db
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy
from gatehouse_app.models.user import User
from gatehouse_app.services.audit_service import AuditService
from gatehouse_app.utils.constants import AuditAction
logger = logging.getLogger(__name__)
class NotificationService:
"""Service for sending MFA compliance notifications."""
# Configuration keys for email settings
EMAIL_ENABLED_KEY = "EMAIL_ENABLED"
SMTP_HOST_KEY = "SMTP_HOST"
SMTP_PORT_KEY = "SMTP_PORT"
SMTP_USERNAME_KEY = "SMTP_USERNAME"
SMTP_PASSWORD_KEY = "SMTP_PASSWORD"
FROM_ADDRESS_KEY = "FROM_ADDRESS"
@staticmethod
def send_mfa_deadline_reminder(
user: User,
compliance: MfaPolicyCompliance,
org_policy: OrganizationSecurityPolicy,
) -> bool:
"""Send MFA deadline reminder notification to user.
Sends a reminder email to users who are approaching their MFA
compliance deadline. The reminder includes:
- Days remaining until deadline
- Required MFA methods
- Link to MFA enrollment
Args:
user: User to notify
compliance: User's compliance record
org_policy: Organization's MFA policy
Returns:
True if notification was sent successfully, False otherwise
"""
try:
# Calculate days until deadline
deadline = compliance.deadline_at
if deadline.tzinfo is None:
deadline = deadline.replace(tzinfo=timezone.utc)
now = datetime.now(timezone.utc)
days_until_deadline = (deadline - now).days
# Build notification content
subject = f"Action Required: MFA enrollment deadline in {days_until_deadline} days"
body = NotificationService._build_deadline_reminder_body(
user, compliance, org_policy, days_until_deadline
)
# Send the notification
success = NotificationService._send_email(
to_address=user.email,
subject=subject,
body=body,
)
if success:
logger.info(
f"Sent MFA deadline reminder to {user.email} "
f"({days_until_deadline} days remaining # Audit log
)"
)
AuditService.log_action(
action=AuditAction.MFA_POLICY_USER_COMPLIANT,
user_id=user.id,
organization_id=compliance.organization_id,
description=f"MFA deadline reminder sent. Days remaining: {days_until_deadline}",
)
else:
logger.warning(
f"Failed to send MFA deadline reminder to {user.email}"
)
return success
except Exception as e:
logger.exception(f"Error sending MFA deadline reminder to {user.email}: {e}")
return False
@staticmethod
def send_mfa_suspended_notification(
user: User,
compliance: MfaPolicyCompliance,
org_policy: OrganizationSecurityPolicy,
) -> bool:
"""Send MFA suspension notification to user.
Notifies users that their account has been suspended due to
failure to comply with MFA requirements. The notification includes:
- Explanation of suspension
- Steps to restore access
- Link to MFA enrollment
Args:
user: User to notify
compliance: User's compliance record
org_policy: Organization's MFA policy
Returns:
True if notification was sent successfully, False otherwise
"""
try:
# Build notification content
subject = "Account Access Restricted - MFA Enrollment Required"
body = NotificationService._build_suspension_body(
user, compliance, org_policy
)
# Send the notification
success = NotificationService._send_email(
to_address=user.email,
subject=subject,
body=body,
)
if success:
logger.info(f"Sent MFA suspension notification to {user.email}")
# Audit log
AuditService.log_action(
action=AuditAction.MFA_POLICY_USER_SUSPENDED,
user_id=user.id,
organization_id=compliance.organization_id,
description="MFA compliance suspension notification sent",
)
else:
logger.warning(
f"Failed to send MFA suspension notification to {user.email}"
)
return success
except Exception as e:
logger.exception(
f"Error sending MFA suspension notification to {user.email}: {e}"
)
return False
@staticmethod
def _build_deadline_reminder_body(
user: User,
compliance: MfaPolicyCompliance,
org_policy: OrganizationSecurityPolicy,
days_until_deadline: int,
) -> str:
"""Build the email body for deadline reminder.
Args:
user: User being notified
compliance: Compliance record
org_policy: Organization policy
days_until_deadline: Days remaining until deadline
Returns:
Email body string
"""
org_name = compliance.organization_id # In real impl, fetch org name
body = f"""
Dear {user.full_name or user.email},
This is a reminder that you need to set up multi-factor authentication (MFA)
to maintain access to your account in the organization "{org_name}".
**Important Details:**
- Days remaining: {days_until_deadline}
- Deadline: {compliance.deadline_at.strftime('%Y-%m-%d %H:%M UTC') if compliance.deadline_at else 'Not set'}
**Required MFA Methods:**
"""
# Add required methods based on policy mode
from gatehouse_app.utils.constants import MfaPolicyMode
mode = org_policy.mfa_policy_mode
if mode == MfaPolicyMode.REQUIRE_TOTP:
body += "- Authenticator app (TOTP)\n"
elif mode == MfaPolicyMode.REQUIRE_WEBAUTHN:
body += "- Passkey (WebAuthn)\n"
elif mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN:
body += "- Authenticator app (TOTP) OR Passkey (WebAuthn)\n"
else:
body += "- Multi-factor authentication\n"
body += """
**How to Set Up MFA:**
1. Log in to your account
2. Navigate to Settings > Security
3. Follow the prompts to set up an authenticator app or passkey
If you do not set up MFA by the deadline, your account access will be restricted.
If you have any questions, please contact your organization administrator.
Best regards,
Gatehouse Security Team
"""
return body
@staticmethod
def _build_suspension_body(
user: User,
compliance: MfaPolicyCompliance,
org_policy: OrganizationSecurityPolicy,
) -> str:
"""Build the email body for suspension notification.
Args:
user: User being notified
compliance: Compliance record
org_policy: Organization policy
Returns:
Email body string
"""
org_name = compliance.organization_id # In real impl, fetch org name
body = f"""
Dear {user.full_name or user.email},
Your account access has been restricted because you did not set up
multi-factor authentication (MFA) within the required timeframe for
the organization "{org_name}".
**What Happened:**
Your MFA compliance deadline passed without MFA being configured.
As a result, your account has been placed in a suspended state.
**How to Restore Access:**
1. Log in to your account (you will see a compliance enrollment screen)
2. Follow the prompts to set up an authenticator app or passkey
3. Once MFA is configured, your access will be restored
**Required MFA Methods:
"""
# Add required methods based on policy mode
from gatehouse_app.utils.constants import MfaPolicyMode
mode = org_policy.mfa_policy_mode
if mode == MfaPolicyMode.REQUIRE_TOTP:
body += "- Authenticator app (TOTP)\n"
elif mode == MfaPolicyMode.REQUIRE_WEBAUTHN:
body += "- Passkey (WebAuthn)\n"
elif mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN:
body += "- Authenticator app (TOTP) OR Passkey (WebAuthn)\n"
else:
body += "- Multi-factor authentication\n"
body += """
**Need Help?**
Contact your organization administrator if you have questions.
Best regards,
Gatehouse Security Team
"""
return body
@staticmethod
def _send_email(
to_address: str,
subject: str,
body: str,
html_body: Optional[str] = None,
) -> bool:
"""Send an email notification.
This method attempts to send an email using configured SMTP settings.
If email is not configured, it logs the notification instead.
Args:
to_address: Recipient email address
subject: Email subject
body: Plain text email body
html_body: Optional HTML email body
Returns:
True if email was sent (or logged), False on error
"""
try:
from flask import current_app
# Check if email is configured
email_enabled = current_app.config.get(
NotificationService.EMAIL_ENABLED_KEY, False
)
if not email_enabled:
# Log the notification instead of sending
logger.info(
f"[EMAIL SIMULATION] To: {to_address}\n"
f"Subject: {subject}\n"
f"Body: {body[:200]}..." if len(body) > 200 else f"Body: {body}"
)
return True
# Get email configuration
smtp_host = current_app.config.get(NotificationService.SMTP_HOST_KEY)
smtp_port = current_app.config.get(NotificationService.SMTP_PORT_KEY, 587)
smtp_username = current_app.config.get(NotificationService.SMTP_USERNAME_KEY)
smtp_password = current_app.config.get(NotificationService.SMTP_PASSWORD_KEY)
from_address = current_app.config.get(
NotificationService.FROM_ADDRESS_KEY, "noreply@gatehouse.local"
)
# Import send_email based on available mail library
try:
from flask_mail import Message
from gatehouse_app import mail
msg = Message(
subject=subject,
recipients=[to_address],
body=body,
html=html_body,
sender=from_address,
)
mail.send(msg)
logger.info(f"Email sent successfully to {to_address}")
return True
except ImportError:
# Flask-Mail not available, use SMTP directly
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["From"] = from_address
msg["To"] = to_address
# Attach plain text and HTML versions
part1 = MIMEText(body, "plain")
msg.attach(part1)
if html_body:
part2 = MIMEText(html_body, "html")
msg.attach(part2)
# Send via SMTP
with smtplib.SMTP(smtp_host, smtp_port) as server:
server.starttls()
if smtp_username and smtp_password:
server.login(smtp_username, smtp_password)
server.send_message(msg)
logger.info(f"Email sent successfully to {to_address}")
return True
except Exception as e:
logger.exception(f"Failed to send email to {to_address}: {e}")
# Log the notification as fallback
logger.info(
f"[EMAIL FALLBACK] To: {to_address}\n"
f"Subject: {subject}\n"
f"Body: {body[:500]}..." if len(body) > 500 else f"Body: {body}"
)
return True # Return True to continue processing
@staticmethod
def get_notification_stats(user_id: str) -> Dict[str, Any]:
"""Get notification statistics for a user.
Args:
user_id: User ID
Returns:
Dictionary with notification statistics
"""
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
stats = {
"total_notifications": 0,
"last_notification": None,
"by_organization": [],
}
compliance_records = MfaPolicyCompliance.query.filter_by(
user_id=user_id, deleted_at=None
).all()
total_notifications = 0
last_notification = None
for record in compliance_records:
total_notifications += record.notification_count
if record.last_notified_at:
if last_notification is None or record.last_notified_at > last_notification:
last_notification = record.last_notified_at
stats["by_organization"].append({
"organization_id": record.organization_id,
"notification_count": record.notification_count,
"last_notified_at": record.last_notified_at.isoformat() if record.last_notified_at else None,
})
stats["total_notifications"] = total_notifications
stats["last_notification"] = last_notification.isoformat() if last_notification else None
return stats
+36
View File
@@ -9,6 +9,7 @@ class UserStatus(str, Enum):
INACTIVE = "inactive" INACTIVE = "inactive"
SUSPENDED = "suspended" SUSPENDED = "suspended"
PENDING = "pending" PENDING = "pending"
COMPLIANCE_SUSPENDED = "compliance_suspended"
class OrganizationRole(str, Enum): class OrganizationRole(str, Enum):
@@ -86,6 +87,12 @@ class AuditAction(str, Enum):
WEBAUTHN_CREDENTIAL_DELETED = "webauthn.credential.deleted" WEBAUTHN_CREDENTIAL_DELETED = "webauthn.credential.deleted"
WEBAUTHN_CREDENTIAL_RENAMED = "webauthn.credential.renamed" WEBAUTHN_CREDENTIAL_RENAMED = "webauthn.credential.renamed"
# Security policy actions
ORG_SECURITY_POLICY_UPDATE = "org.security_policy.update"
USER_SECURITY_POLICY_OVERRIDE_UPDATE = "user.security_policy.override_update"
MFA_POLICY_USER_SUSPENDED = "mfa.policy.user_suspended"
MFA_POLICY_USER_COMPLIANT = "mfa.policy.user_compliant"
class OIDCGrantType(str, Enum): class OIDCGrantType(str, Enum):
"""OIDC grant types.""" """OIDC grant types."""
@@ -116,3 +123,32 @@ class ErrorType:
RATE_LIMIT_EXCEEDED = "RATE_LIMIT_EXCEEDED" RATE_LIMIT_EXCEEDED = "RATE_LIMIT_EXCEEDED"
INTERNAL_ERROR = "INTERNAL_ERROR" INTERNAL_ERROR = "INTERNAL_ERROR"
BAD_REQUEST = "BAD_REQUEST" BAD_REQUEST = "BAD_REQUEST"
class MfaPolicyMode(str, Enum):
"""MFA policy mode for organizations."""
DISABLED = "disabled"
OPTIONAL = "optional"
REQUIRE_TOTP = "require_totp"
REQUIRE_WEBAUTHN = "require_webauthn"
REQUIRE_TOTP_OR_WEBAUTHN = "require_totp_or_webauthn"
class MfaComplianceStatus(str, Enum):
"""MFA compliance status for users per organization."""
NOT_APPLICABLE = "not_applicable"
PENDING = "pending"
IN_GRACE = "in_grace"
COMPLIANT = "compliant"
PAST_DUE = "past_due"
SUSPENDED = "suspended"
class MfaRequirementOverride(str, Enum):
"""User override for organization MFA requirements."""
INHERIT = "inherit"
REQUIRED = "required"
EXEMPT = "exempt"
+40 -1
View File
@@ -2,7 +2,8 @@
from functools import wraps from functools import wraps
from flask import request, g from flask import request, g
from gatehouse_app.utils.response import api_response from gatehouse_app.utils.response import api_response
from gatehouse_app.utils.constants import OrganizationRole from gatehouse_app.utils.constants import OrganizationRole, UserStatus
from gatehouse_app.exceptions.auth_exceptions import UnauthorizedError, ForbiddenError
def login_required(f): def login_required(f):
@@ -127,3 +128,41 @@ def require_owner(f):
def require_admin(f): def require_admin(f):
"""Decorator to require organization admin or owner role.""" """Decorator to require organization admin or owner role."""
return require_role(OrganizationRole.OWNER, OrganizationRole.ADMIN)(f) return require_role(OrganizationRole.OWNER, OrganizationRole.ADMIN)(f)
def full_access_required(f):
"""Decorator to require full access session (not compliance-only).
This decorator checks if the user has a compliance-only session or
is in COMPLIANCE_SUSPENDED status. If so, it returns a 403 error
with error_type "MFA_COMPLIANCE_REQUIRED".
Use this decorator on endpoints that require full MFA compliance.
Endpoints for MFA enrollment, status, and logout should NOT use this decorator.
"""
@wraps(f)
def decorated_function(*args, **kwargs):
user = getattr(g, "current_user", None)
session = getattr(g, "current_session", None)
if not user or not session:
return api_response(
success=False,
message="Authentication required",
status=401,
error_type="AUTH_REQUIRED",
)
# Check for compliance-only session or compliance suspended status
if session.is_compliance_only or user.status == UserStatus.COMPLIANCE_SUSPENDED:
return api_response(
success=False,
message="MFA compliance required to access this resource",
status=403,
error_type="MFA_COMPLIANCE_REQUIRED",
error_details={"overall_status": "suspended"},
)
return f(*args, **kwargs)
return decorated_function
+97
View File
@@ -14,5 +14,102 @@ app = create_app(os.getenv("FLASK_ENV", "development"))
# Create Flask CLI group # Create Flask CLI group
cli = FlaskGroup(create_app=lambda: app) cli = FlaskGroup(create_app=lambda: app)
@cli.command("run_mfa_compliance_job")
def run_mfa_compliance_job():
"""Run the MFA compliance scheduled job.
This command processes MFA compliance transitions:
- Transitions users from PAST_DUE to SUSPENDED status
- Sends deadline reminder notifications
- Updates notification tracking metadata
Usage:
python manage.py run_mfa_compliance_job
This can be called via cron or a task scheduler:
0 * * * * cd /path/to/app && python manage.py run_mfa_compliance_job
"""
from datetime import datetime, timezone
from gatehouse_app.jobs.mfa_compliance_job import process_mfa_compliance, get_job_status
print("=" * 60)
print("MFA Compliance Job")
print("=" * 60)
now = datetime.now(timezone.utc)
print(f"Start time: {now.isoformat()}")
print()
# Show current status before processing
print("Current Compliance Status:")
status = get_job_status(now)
for status_name, count in status["status_counts"].items():
print(f" {status_name}: {count}")
print(f" Approaching deadline: {status['approaching_deadline_count']}")
print(f" Past due: {status['past_due_count']}")
print()
# Run the job
print("Processing compliance...")
result = process_mfa_compliance(now)
print()
print("Job Results:")
print(f" Users suspended: {result['suspended_count']}")
print(f" Notifications sent: {result['notified_count']}")
print(f" Records processed: {result['processed_count']}")
if result['errors']:
print()
print("Errors:")
for error in result['errors']:
print(f" - {error}")
print()
print("=" * 60)
print("Job completed successfully")
print("=" * 60)
@cli.command("mfa_compliance_status")
def mfa_compliance_status():
"""Show current MFA compliance status.
Usage:
python manage.py mfa_compliance_status
"""
from datetime import datetime, timezone
from gatehouse_app.jobs.mfa_compliance_job import get_job_status
print("=" * 60)
print("MFA Compliance Status Report")
print("=" * 60)
now = datetime.now(timezone.utc)
status = get_job_status(now)
print(f"Report time: {status['timestamp']}")
print()
print("Compliance Records by Status:")
for status_name, count in sorted(status["status_counts"].items()):
bar = "" * min(count, 50)
print(f" {status_name:20s}: {count:5d} {bar}")
print()
print("Summary:")
print(f" Approaching deadline: {status['approaching_deadline_count']}")
print(f" Past due (pending suspension): {status['past_due_count']}")
total = sum(status["status_counts"].values())
compliant = status["status_counts"].get("compliant", 0)
if total > 0:
compliance_rate = (compliant / total) * 100
print(f" Compliance rate: {compliance_rate:.1f}%")
print("=" * 60)
if __name__ == "__main__": if __name__ == "__main__":
cli() cli()
-150
View File
@@ -1,150 +0,0 @@
"""Database migration: Create OIDC tables.
Revision ID: 001
Revises:
Create Date: 2024-01-01 00:00:00
This migration creates all OIDC-related tables for the authorization code flow,
refresh token management, OIDC session tracking, token metadata, and audit logging.
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# Revision identifiers
revision = '001'
down_revision = None
branch_labels = None
depends_on = None
def upgrade():
"""Create OIDC tables."""
# OIDC Authorization Codes table
op.create_table(
'oidc_authorization_codes',
sa.Column('id', sa.String(36), primary_key=True),
sa.Column('created_at', sa.DateTime, nullable=False),
sa.Column('updated_at', sa.DateTime, nullable=False),
sa.Column('deleted_at', sa.DateTime, nullable=True),
sa.Column('client_id', sa.String(255), sa.ForeignKey('oidc_clients.id'), nullable=False),
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id'), nullable=False),
sa.Column('code_hash', sa.String(255), nullable=False),
sa.Column('redirect_uri', sa.String(512), nullable=False),
sa.Column('scope', postgresql.JSON, nullable=True),
sa.Column('nonce', sa.String(255), nullable=True),
sa.Column('code_verifier', sa.String(255), nullable=True),
sa.Column('expires_at', sa.DateTime, nullable=False),
sa.Column('used_at', sa.DateTime, nullable=True),
sa.Column('is_used', sa.Boolean, default=False, nullable=False),
sa.Column('ip_address', sa.String(45), nullable=True),
sa.Column('user_agent', sa.Text, nullable=True),
)
op.create_index('ix_oidc_authorization_codes_client_id', 'oidc_authorization_codes', ['client_id'])
op.create_index('ix_oidc_authorization_codes_user_id', 'oidc_authorization_codes', ['user_id'])
op.create_index('ix_oidc_authorization_codes_expires_at', 'oidc_authorization_codes', ['expires_at'])
# OIDC Refresh Tokens table
op.create_table(
'oidc_refresh_tokens',
sa.Column('id', sa.String(36), primary_key=True),
sa.Column('created_at', sa.DateTime, nullable=False),
sa.Column('updated_at', sa.DateTime, nullable=False),
sa.Column('deleted_at', sa.DateTime, nullable=True),
sa.Column('client_id', sa.String(255), sa.ForeignKey('oidc_clients.id'), nullable=False),
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id'), nullable=False),
sa.Column('token_hash', sa.String(255), nullable=False),
sa.Column('access_token_id', sa.String(36), sa.ForeignKey('sessions.id'), nullable=True),
sa.Column('scope', postgresql.JSON, nullable=True),
sa.Column('expires_at', sa.DateTime, nullable=False),
sa.Column('revoked_at', sa.DateTime, nullable=True),
sa.Column('revoked_reason', sa.String(255), nullable=True),
sa.Column('previous_token_hash', sa.String(255), nullable=True),
sa.Column('rotation_count', sa.Integer, default=0, nullable=False),
sa.Column('ip_address', sa.String(45), nullable=True),
sa.Column('user_agent', sa.Text, nullable=True),
)
op.create_index('ix_oidc_refresh_tokens_client_id', 'oidc_refresh_tokens', ['client_id'])
op.create_index('ix_oidc_refresh_tokens_user_id', 'oidc_refresh_tokens', ['user_id'])
op.create_index('ix_oidc_refresh_tokens_token_hash', 'oidc_refresh_tokens', ['token_hash'], unique=True)
op.create_index('ix_oidc_refresh_tokens_access_token_id', 'oidc_refresh_tokens', ['access_token_id'])
op.create_index('ix_oidc_refresh_tokens_expires_at', 'oidc_refresh_tokens', ['expires_at'])
# OIDC Sessions table
op.create_table(
'oidc_sessions',
sa.Column('id', sa.String(36), primary_key=True),
sa.Column('created_at', sa.DateTime, nullable=False),
sa.Column('updated_at', sa.DateTime, nullable=False),
sa.Column('deleted_at', sa.DateTime, nullable=True),
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id'), nullable=False),
sa.Column('client_id', sa.String(255), sa.ForeignKey('oidc_clients.id'), nullable=False),
sa.Column('state', sa.String(255), nullable=False),
sa.Column('nonce', sa.String(255), nullable=True),
sa.Column('redirect_uri', sa.String(512), nullable=False),
sa.Column('scope', postgresql.JSON, nullable=True),
sa.Column('code_challenge', sa.String(255), nullable=True),
sa.Column('code_challenge_method', sa.String(10), nullable=True),
sa.Column('expires_at', sa.DateTime, nullable=False),
sa.Column('authenticated_at', sa.DateTime, nullable=True),
)
op.create_index('ix_oidc_sessions_user_id', 'oidc_sessions', ['user_id'])
op.create_index('ix_oidc_sessions_client_id', 'oidc_sessions', ['client_id'])
op.create_index('ix_oidc_sessions_state', 'oidc_sessions', ['state'])
op.create_index('ix_oidc_sessions_expires_at', 'oidc_sessions', ['expires_at'])
# OIDC Token Metadata table
op.create_table(
'oidc_token_metadata',
sa.Column('id', sa.String(36), primary_key=True),
sa.Column('created_at', sa.DateTime, nullable=False),
sa.Column('updated_at', sa.DateTime, nullable=False),
sa.Column('deleted_at', sa.DateTime, nullable=True),
sa.Column('client_id', sa.String(255), sa.ForeignKey('oidc_clients.id'), nullable=False),
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id'), nullable=False),
sa.Column('token_type', sa.String(50), nullable=False),
sa.Column('token_jti', sa.String(255), nullable=False),
sa.Column('expires_at', sa.DateTime, nullable=False),
sa.Column('revoked_at', sa.DateTime, nullable=True),
sa.Column('revoked_reason', sa.String(255), nullable=True),
)
op.create_index('ix_oidc_token_metadata_client_id', 'oidc_token_metadata', ['client_id'])
op.create_index('ix_oidc_token_metadata_user_id', 'oidc_token_metadata', ['user_id'])
op.create_index('ix_oidc_token_metadata_token_jti', 'oidc_token_metadata', ['token_jti'])
op.create_index('ix_oidc_token_metadata_expires_at', 'oidc_token_metadata', ['expires_at'])
# OIDC Audit Logs table
op.create_table(
'oidc_audit_logs',
sa.Column('id', sa.String(36), primary_key=True),
sa.Column('created_at', sa.DateTime, nullable=False),
sa.Column('updated_at', sa.DateTime, nullable=False),
sa.Column('deleted_at', sa.DateTime, nullable=True),
sa.Column('event_type', sa.String(100), nullable=False),
sa.Column('client_id', sa.String(255), sa.ForeignKey('oidc_clients.id'), nullable=True),
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id'), nullable=True),
sa.Column('success', sa.Boolean, default=True, nullable=False),
sa.Column('error_code', sa.String(100), nullable=True),
sa.Column('error_description', sa.Text, nullable=True),
sa.Column('ip_address', sa.String(45), nullable=True),
sa.Column('user_agent', sa.Text, nullable=True),
sa.Column('request_id', sa.String(36), nullable=True),
sa.Column('event_metadata', postgresql.JSON, nullable=True),
)
op.create_index('ix_oidc_audit_logs_event_type', 'oidc_audit_logs', ['event_type'])
op.create_index('ix_oidc_audit_logs_client_id', 'oidc_audit_logs', ['client_id'])
op.create_index('ix_oidc_audit_logs_user_id', 'oidc_audit_logs', ['user_id'])
op.create_index('ix_oidc_audit_logs_success', 'oidc_audit_logs', ['success'])
op.create_index('ix_oidc_audit_logs_ip_address', 'oidc_audit_logs', ['ip_address'])
op.create_index('ix_oidc_audit_logs_request_id', 'oidc_audit_logs', ['request_id'])
def downgrade():
"""Drop OIDC tables."""
op.drop_table('oidc_audit_logs')
op.drop_table('oidc_token_metadata')
op.drop_table('oidc_sessions')
op.drop_table('oidc_refresh_tokens')
op.drop_table('oidc_authorization_codes')
-44
View File
@@ -1,44 +0,0 @@
"""Database migration: Add WebAuthn support.
Revision ID: 002
Revises: 001
Create Date: 2024-01-15 00:00:00
This migration adds support for WebAuthn passkey authentication by:
- Adding WEBAUTHN to the AuthMethodType enum (handled in application code)
- No schema changes required (uses existing provider_data JSON field)
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# Revision identifiers
revision = '002'
down_revision = '001'
branch_labels = None
depends_on = None
def upgrade():
"""Add WebAuthn support - no schema changes needed."""
# WebAuthn credentials are stored in the existing provider_data JSON field
# of the authentication_methods table. No schema changes are required.
# Create an index for faster lookups of WebAuthn credentials by user
# This is optional but recommended for performance
# op.create_index(
# 'ix_authentication_methods_webauthn_user',
# 'authentication_methods',
# ['user_id'],
# postgresql_where=(sa.text("method_type = 'webauthn'")),
# if_not_exists=True
# )
pass
def downgrade():
"""Remove WebAuthn support - no schema changes needed."""
# No schema changes to revert
pass
+933
View File
@@ -0,0 +1,933 @@
"""Integration tests for MFA compliance enforcement."""
import pytest
import json
from datetime import datetime, timezone, timedelta
from gatehouse_app.models.user import User
from gatehouse_app.models.organization import Organization
from gatehouse_app.models.organization_member import OrganizationMember
from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
from gatehouse_app.models.user_security_policy import UserSecurityPolicy
from gatehouse_app.models.session import Session
from gatehouse_app.utils.constants import MfaPolicyMode, MfaComplianceStatus, UserStatus, MfaRequirementOverride
from gatehouse_app.services.mfa_policy_service import MfaPolicyService
@pytest.mark.integration
class TestMfaComplianceLogin:
"""Integration tests for MFA compliance during login."""
def test_login_with_no_policy(self, client, db, test_user):
"""Test login with no MFA policy (should work normally)."""
login_data = {
"email": test_user.email,
"password": "TestPassword123!",
}
response = client.post(
"/api/v1/auth/login",
data=json.dumps(login_data),
content_type="application/json",
)
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
assert "user" in data["data"]
assert "token" in data["data"]
# No MFA compliance info should be present when no policy exists
assert "mfa_compliance" not in data["data"]
assert "requires_mfa_enrollment" not in data["data"]
def test_login_with_optional_policy(self, client, db, test_user, test_organization):
"""Test login with optional MFA policy (should work normally)."""
# Create an optional MFA policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
login_data = {
"email": test_user.email,
"password": "TestPassword123!",
}
response = client.post(
"/api/v1/auth/login",
data=json.dumps(login_data),
content_type="application/json",
)
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
assert "user" in data["data"]
assert "token" in data["data"]
# MFA compliance should be present but status should be not_applicable
assert "mfa_compliance" in data["data"]
assert data["data"]["mfa_compliance"]["overall_status"] == "not_applicable"
assert "requires_mfa_enrollment" not in data["data"]
def test_login_with_required_policy_in_grace_period(self, client, db, test_user, test_organization):
"""Test login with required policy within grace period (should work with warning)."""
# Create a required MFA policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
login_data = {
"email": test_user.email,
"password": "TestPassword123!",
}
response = client.post(
"/api/v1/auth/login",
data=json.dumps(login_data),
content_type="application/json",
)
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
assert "user" in data["data"]
assert "token" in data["data"]
# MFA compliance should be present with in_grace status
assert "mfa_compliance" in data["data"]
assert data["data"]["mfa_compliance"]["overall_status"] == "in_grace"
assert "requires_mfa_enrollment" not in data["data"]
assert "totp" in data["data"]["mfa_compliance"]["missing_methods"]
def test_login_with_required_policy_after_deadline(self, client, db, test_user, test_organization):
"""Test login with required policy after deadline (should get compliance-only session)."""
# Create a required MFA policy with past deadline
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
# Create compliance record past due
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.PAST_DUE,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
deadline_at=past_deadline,
)
db.session.add(compliance)
db.session.commit()
login_data = {
"email": test_user.email,
"password": "TestPassword123!",
}
response = client.post(
"/api/v1/auth/login",
data=json.dumps(login_data),
content_type="application/json",
)
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
assert "user" in data["data"]
assert "token" in data["data"]
# Should have compliance-only session
assert data["data"]["requires_mfa_enrollment"] is True
assert "mfa_compliance" in data["data"]
assert data["data"]["mfa_compliance"]["overall_status"] in ["past_due", "suspended"]
def test_login_with_suspended_user(self, client, db, test_user, test_organization):
"""Test login with compliance suspended user (should get compliance-only session)."""
# Set user status to compliance suspended
test_user.status = UserStatus.COMPLIANCE_SUSPENDED
db.session.commit()
login_data = {
"email": test_user.email,
"password": "TestPassword123!",
}
response = client.post(
"/api/v1/auth/login",
data=json.dumps(login_data),
content_type="application/json",
)
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
assert "user" in data["data"]
assert "token" in data["data"]
# Should have compliance-only session
assert data["data"]["requires_mfa_enrollment"] is True
@pytest.mark.integration
class TestMfaComplianceAccess:
"""Integration tests for MFA compliance access control."""
def test_compliance_only_session_denied_full_access(self, client, db, test_user, test_organization):
"""Test that compliance-only session cannot access full access endpoints."""
# Create a required MFA policy with past deadline
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
# Create compliance record past due
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.PAST_DUE,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
deadline_at=past_deadline,
)
db.session.add(compliance)
# Create a compliance-only session
session = Session(
user_id=test_user.id,
token="compliance_only_token",
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
is_compliance_only=True,
)
db.session.add(session)
db.session.commit()
# Try to access a full-access endpoint (get_my_organizations)
response = client.get(
"/api/v1/users/me/organizations",
headers={"Authorization": "Bearer compliance_only_token"},
)
assert response.status_code == 403
data = response.get_json()
assert data["success"] is False
assert data["error_type"] == "MFA_COMPLIANCE_REQUIRED"
def test_compliance_only_session_can_access_mfa_enrollment(self, client, db, test_user, test_organization):
"""Test that compliance-only session can access MFA enrollment endpoints."""
# Create a required MFA policy with past deadline
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
# Create compliance record past due
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.PAST_DUE,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
deadline_at=past_deadline,
)
db.session.add(compliance)
# Create a compliance-only session
session = Session(
user_id=test_user.id,
token="compliance_only_token",
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
is_compliance_only=True,
)
db.session.add(session)
db.session.commit()
# Try to access MFA enrollment endpoint (should work)
response = client.get(
"/api/v1/auth/totp/status",
headers={"Authorization": "Bearer compliance_only_token"},
)
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
def test_compliance_only_session_can_access_logout(self, client, db, test_user, test_organization):
"""Test that compliance-only session can access logout endpoint."""
# Create a required MFA policy with past deadline
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
# Create compliance record past due
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.PAST_DUE,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
deadline_at=past_deadline,
)
db.session.add(compliance)
# Create a compliance-only session
session = Session(
user_id=test_user.id,
token="compliance_only_token",
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
is_compliance_only=True,
)
db.session.add(session)
db.session.commit()
# Try to access logout endpoint (should work)
response = client.post(
"/api/v1/auth/logout",
headers={"Authorization": "Bearer compliance_only_token"},
)
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
@pytest.mark.integration
class TestMfaComplianceWebAuthn:
"""Integration tests for MFA compliance with WebAuthn login."""
def test_webauthn_login_with_required_policy_in_grace_period(self, client, db, test_user, test_organization):
"""Test WebAuthn login with required policy within grace period."""
# Create a required MFA policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
# Note: Full WebAuthn login test would require WebAuthn setup
# This test verifies the compliance response structure
login_data = {
"email": test_user.email,
"password": "TestPassword123!",
}
response = client.post(
"/api/v1/auth/login",
data=json.dumps(login_data),
content_type="application/json",
)
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
assert "mfa_compliance" in data["data"]
assert data["data"]["mfa_compliance"]["overall_status"] == "in_grace"
@pytest.mark.integration
class TestMfaComplianceOIDC:
"""Integration tests for MFA compliance with OIDC authorization."""
def test_oidc_authorize_with_compliance_required(self, client, db, test_user, test_organization, app):
"""Test OIDC authorize with compliance required (should show error)."""
# Create a required MFA policy with past deadline
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
# Create compliance record past due
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.PAST_DUE,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
deadline_at=past_deadline,
)
db.session.add(compliance)
db.session.commit()
# Try OIDC authorize with credentials
response = client.post(
"/oidc/authorize",
data={
"client_id": "test_client",
"redirect_uri": "http://localhost:8080/callback",
"response_type": "code",
"scope": "openid profile email",
"state": "test_state",
"email": test_user.email,
"password": "TestPassword123!",
},
)
# Should return login page with error
assert response.status_code == 200
assert b"Your account requires multi factor enrollment before using single sign on" in response.data
# =============================================================================
# Phase 4: Edge Case Tests
# =============================================================================
@pytest.mark.integration
class TestMfaComplianceMultiOrg:
"""Integration tests for multi-organization MFA compliance edge cases."""
def test_user_with_multiple_orgs_different_policies(self, client, db, test_user):
"""Test user belonging to multiple orgs with different MFA policies."""
# Create two organizations
org1 = Organization(
name="Org1",
slug="org1-test-multi",
)
org2 = Organization(
name="Org2",
slug="org2-test-multi",
)
db.session.add_all([org1, org2])
db.session.commit()
# Add user to both orgs
membership1 = OrganizationMember(
user_id=test_user.id,
organization_id=org1.id,
role="member",
)
membership2 = OrganizationMember(
user_id=test_user.id,
organization_id=org2.id,
role="member",
)
db.session.add_all([membership1, membership2])
db.session.commit()
# Create different policies for each org
# Org1: OPTIONAL (no requirement)
policy1 = OrganizationSecurityPolicy(
organization_id=org1.id,
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
# Org2: REQUIRE_TOTP (strictest)
policy2 = OrganizationSecurityPolicy(
organization_id=org2.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add_all([policy1, policy2])
db.session.commit()
# Evaluate user MFA state
compliance_summary = MfaPolicyService.evaluate_user_mfa_state(test_user)
# Overall status should reflect the strictest policy (REQUIRE_TOTP from org2)
assert compliance_summary.overall_status == MfaComplianceStatus.IN_GRACE.value
assert "totp" in compliance_summary.missing_methods
# Verify per-org breakdown
assert len(compliance_summary.orgs) == 2
org1_status = next((o for o in compliance_summary.orgs if o.organization_id == org1.id), None)
org2_status = next((o for o in compliance_summary.orgs if o.organization_id == org2.id), None)
assert org1_status is not None
assert org2_status is not None
assert org1_status.status == MfaComplianceStatus.NOT_APPLICABLE.value
assert org2_status.status == MfaComplianceStatus.IN_GRACE.value
def test_user_with_multiple_orgs_all_suspended(self, client, db, test_user):
"""Test user with multiple orgs where all require MFA and are past due."""
# Create two organizations
org1 = Organization(
name="Org1",
slug="org1-test-suspended",
)
org2 = Organization(
name="Org2",
slug="org2-test-suspended",
)
db.session.add_all([org1, org2])
db.session.commit()
# Add user to both orgs
membership1 = OrganizationMember(
user_id=test_user.id,
organization_id=org1.id,
role="member",
)
membership2 = OrganizationMember(
user_id=test_user.id,
organization_id=org2.id,
role="member",
)
db.session.add_all([membership1, membership2])
db.session.commit()
# Create required policies
policy1 = OrganizationSecurityPolicy(
organization_id=org1.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
policy2 = OrganizationSecurityPolicy(
organization_id=org2.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add_all([policy1, policy2])
db.session.commit()
# Create past-due compliance records for both
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
compliance1 = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=org1.id,
status=MfaComplianceStatus.SUSPENDED,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=30),
deadline_at=past_deadline,
suspended_at=past_deadline,
)
compliance2 = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=org2.id,
status=MfaComplianceStatus.SUSPENDED,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=30),
deadline_at=past_deadline,
suspended_at=past_deadline,
)
db.session.add_all([compliance1, compliance2])
db.session.commit()
# Evaluate user MFA state
compliance_summary = MfaPolicyService.evaluate_user_mfa_state(test_user)
# Overall status should be SUSPENDED
assert compliance_summary.overall_status == MfaComplianceStatus.SUSPENDED.value
def test_strictest_mode_selection(self):
"""Test that get_strictest_mode returns the most restrictive policy."""
modes = [
MfaPolicyMode.DISABLED.value,
MfaPolicyMode.OPTIONAL.value,
MfaPolicyMode.REQUIRE_TOTP.value,
]
result = MfaPolicyService.get_strictest_mode(modes)
assert result == MfaPolicyMode.REQUIRE_TOTP.value
# Test with REQUIRE_TOTP_OR_WEBAUTHN (strictest)
modes_strictest = [
MfaPolicyMode.REQUIRE_TOTP.value,
MfaPolicyMode.REQUIRE_WEBAUTHN.value,
MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value,
]
result = MfaPolicyService.get_strictest_mode(modes_strictest)
assert result == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value
@pytest.mark.integration
class TestMfaComplianceUserOverrides:
"""Integration tests for user override edge cases."""
def test_user_override_inherit_mode(self, client, db, test_user, test_organization):
"""Test INHERIT mode - org policy applies as is."""
# Create a required policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
# Create INHERIT override (default behavior)
override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.INHERIT,
)
db.session.add(override)
db.session.commit()
# Get effective policy
effective = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
# Should inherit org policy
assert effective.effective_mode == MfaPolicyMode.REQUIRE_TOTP.value
assert effective.requires_totp is True
assert effective.is_exempt is False
def test_user_override_required_mode(self, client, db, test_user, test_organization):
"""Test REQUIRED mode - user always required to have MFA."""
# Create an optional policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
# Create REQUIRED override
override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.REQUIRED,
)
db.session.add(override)
db.session.commit()
# Get effective policy
effective = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
# Should be upgraded to REQUIRE_TOTP_OR_WEBAUTHN
assert effective.effective_mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value
assert effective.requires_totp is True
assert effective.requires_webauthn is True
assert effective.is_exempt is False
def test_user_override_exempt_mode(self, client, db, test_user, test_organization):
"""Test EXEMPT mode - org policy does not apply."""
# Create a required policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
# Create EXEMPT override
override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.EXEMPT,
)
db.session.add(override)
db.session.commit()
# Get effective policy
effective = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
# Should be exempt from policy
assert effective.is_exempt is True
assert effective.effective_mode == MfaPolicyMode.DISABLED.value
assert effective.requires_totp is False
assert effective.requires_webauthn is False
def test_get_override_summary(self, client, db, test_user, test_organization):
"""Test getting override summary for a user."""
# No override exists
summary = MfaPolicyService.get_override_summary(test_user.id, test_organization.id)
assert summary["has_override"] is False
assert summary["mode"] == "inherit"
# Create an override
override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.EXEMPT,
)
db.session.add(override)
db.session.commit()
# Get summary again
summary = MfaPolicyService.get_override_summary(test_user.id, test_organization.id)
assert summary["has_override"] is True
assert summary["mode"] == "exempt"
assert summary["is_exempt"] is True
@pytest.mark.integration
class TestMfaCompliancePolicyChanges:
"""Integration tests for policy changes affecting existing users."""
def test_policy_change_triggers_compliance_reevaluation(self, client, db, test_user, test_organization):
"""Test that policy change triggers compliance reevaluation."""
# Create initial optional policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
# Create compliance record (should be NOT_APPLICABLE)
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.NOT_APPLICABLE,
policy_version=1,
)
db.session.add(compliance)
db.session.commit()
# Update policy to REQUIRE_TOTP
MfaPolicyService.create_org_policy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
notify_days_before=7,
updated_by_user_id=test_user.id,
)
# Reevaluate all compliance
updated_count = MfaPolicyService.reevaluate_all_org_compliance(test_organization.id)
# Should have updated at least one record
assert updated_count >= 1
# Check compliance status was updated
updated_compliance = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id)
assert updated_compliance.status == MfaComplianceStatus.IN_GRACE.value
assert updated_compliance.deadline_at is not None
def test_policy_relaxation_clears_requirements(self, client, db, test_user, test_organization):
"""Test that relaxing policy clears compliance requirements."""
# Create required policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
# Create IN_GRACE compliance record
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.IN_GRACE,
policy_version=1,
applied_at=datetime.now(timezone.utc),
deadline_at=datetime.now(timezone.utc) + timedelta(days=14),
)
db.session.add(compliance)
db.session.commit()
# Update policy to OPTIONAL
MfaPolicyService.create_org_policy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
mfa_grace_period_days=14,
notify_days_before=7,
updated_by_user_id=test_user.id,
)
# Reevaluate compliance
MfaPolicyService.reevaluate_all_org_compliance(test_organization.id)
# Check compliance status was updated to NOT_APPLICABLE
updated_compliance = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id)
assert updated_compliance.status == MfaComplianceStatus.NOT_APPLICABLE.value
@pytest.mark.integration
class TestMfaComplianceScheduledJob:
"""Integration tests for the MFA compliance scheduled job."""
def test_transition_to_suspended(self, client, db, test_user, test_organization):
"""Test that past-due users are transitioned to suspended."""
# Create required policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
# Create past-due compliance record
past_deadline = datetime.now(timezone.utc) - timedelta(hours=1)
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.PAST_DUE,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
deadline_at=past_deadline,
)
db.session.add(compliance)
db.session.commit()
# Run the job
now = datetime.now(timezone.utc)
suspended_count = MfaPolicyService.transition_to_suspended_if_past_due(now)
# Should have suspended the user
assert suspended_count >= 1
# Check compliance status
updated_compliance = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id)
assert updated_compliance.status == MfaComplianceStatus.SUSPENDED.value
assert updated_compliance.suspended_at is not None
# Check user status
db.refresh(test_user)
assert test_user.status == UserStatus.COMPLIANCE_SUSPENDED
def test_check_and_restore_user_status(self, client, db, test_user, test_organization):
"""Test that suspended users are restored when they become compliant."""
# Create required policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
# User is suspended
test_user.status = UserStatus.COMPLIANCE_SUSPENDED
db.session.commit()
# Create EXEMPT override to clear requirement
override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.EXEMPT,
)
db.session.add(override)
db.session.commit()
# Check and restore status
restored = MfaPolicyService.check_and_restore_user_status(test_user.id)
# Should have restored user
assert restored is True
db.refresh(test_user)
assert test_user.status == UserStatus.ACTIVE
@pytest.mark.integration
class TestMfaComplianceMultiOrgAggregate:
"""Integration tests for multi-org aggregate state calculation."""
def test_get_multi_org_aggregate_state(self, client, db, test_user):
"""Test aggregate state calculation for multi-org user."""
# Create two organizations
org1 = Organization(
name="AggOrg1",
slug="agg-org1-test",
)
org2 = Organization(
name="AggOrg2",
slug="agg-org2-test",
)
db.session.add_all([org1, org2])
db.session.commit()
# Add user to both
membership1 = OrganizationMember(
user_id=test_user.id,
organization_id=org1.id,
role="member",
)
membership2 = OrganizationMember(
user_id=test_user.id,
organization_id=org2.id,
role="member",
)
db.session.add_all([membership1, membership2])
db.session.commit()
# Create policies
policy1 = OrganizationSecurityPolicy(
organization_id=org1.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
policy2 = OrganizationSecurityPolicy(
organization_id=org2.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add_all([policy1, policy2])
db.session.commit()
# Get aggregate state
aggregate = MfaPolicyService.get_multi_org_aggregate_state(test_user)
# Verify structure
assert "overall_status" in aggregate
assert "strictest_mode" in aggregate
assert "missing_methods" in aggregate
assert "requiring_org_count" in aggregate
assert "requiring_orgs" in aggregate
assert "per_org_details" in aggregate
# Strictest mode should be REQUIRE_TOTP_OR_WEBAUTHN
assert aggregate["strictest_mode"] == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value
# Both orgs should require MFA
assert aggregate["requiring_org_count"] == 2
assert len(aggregate["requiring_orgs"]) == 2
assert len(aggregate["per_org_details"]) == 2
+295
View File
@@ -0,0 +1,295 @@
"""Unit tests for MFA policy models."""
import pytest
from datetime import datetime, timezone, timedelta
from gatehouse_app.models import (
User,
Organization,
OrganizationMember,
OrganizationSecurityPolicy,
UserSecurityPolicy,
MfaPolicyCompliance,
Session,
)
from gatehouse_app.utils.constants import (
UserStatus,
MfaPolicyMode,
MfaComplianceStatus,
MfaRequirementOverride,
SessionStatus,
OrganizationRole,
)
@pytest.mark.unit
class TestOrganizationSecurityPolicyModel:
"""Tests for OrganizationSecurityPolicy model."""
def test_create_org_security_policy(self, db, test_organization):
"""Test creating an organization security policy."""
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
mfa_grace_period_days=14,
notify_days_before=7,
)
policy.save()
assert policy.id is not None
assert policy.organization_id == test_organization.id
assert policy.mfa_policy_mode == MfaPolicyMode.OPTIONAL
assert policy.mfa_grace_period_days == 14
assert policy.notify_days_before == 7
assert policy.policy_version == 1
assert policy.created_at is not None
def test_org_security_policy_to_dict(self, db, test_organization):
"""Test organization security policy to_dict method."""
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=7,
notify_days_before=3,
)
policy.save()
policy_dict = policy.to_dict()
assert "id" in policy_dict
assert "organization_id" in policy_dict
assert policy_dict["organization_id"] == test_organization.id
assert "mfa_policy_mode" in policy_dict
assert "mfa_grace_period_days" in policy_dict
def test_org_security_policy_relationships(self, db, test_organization):
"""Test organization security policy relationships."""
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
)
policy.save()
# Test relationship
assert policy.organization is not None
assert policy.organization.id == test_organization.id
@pytest.mark.unit
class TestUserSecurityPolicyModel:
"""Tests for UserSecurityPolicy model."""
def test_create_user_security_policy(self, db, test_user, test_organization):
"""Test creating a user security policy."""
policy = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.INHERIT,
)
policy.save()
assert policy.id is not None
assert policy.user_id == test_user.id
assert policy.organization_id == test_organization.id
assert policy.mfa_override_mode == MfaRequirementOverride.INHERIT
assert policy.force_totp is False
assert policy.force_webauthn is False
def test_user_security_policy_with_overrides(self, db, test_user, test_organization):
"""Test user security policy with override settings."""
policy = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.REQUIRED,
force_totp=True,
force_webauthn=False,
)
policy.save()
assert policy.mfa_override_mode == MfaRequirementOverride.REQUIRED
assert policy.force_totp is True
assert policy.force_webauthn is False
def test_user_security_policy_exempt(self, db, test_user, test_organization):
"""Test user security policy with exempt override."""
policy = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.EXEMPT,
)
policy.save()
assert policy.mfa_override_mode == MfaRequirementOverride.EXEMPT
def test_user_security_policy_relationships(self, db, test_user, test_organization):
"""Test user security policy relationships."""
policy = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.INHERIT,
)
policy.save()
# Test relationships
assert policy.user is not None
assert policy.user.id == test_user.id
assert policy.organization is not None
assert policy.organization.id == test_organization.id
@pytest.mark.unit
class TestMfaPolicyComplianceModel:
"""Tests for MfaPolicyCompliance model."""
def test_create_mfa_policy_compliance(self, db, test_user, test_organization):
"""Test creating an MFA policy compliance record."""
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.NOT_APPLICABLE,
policy_version=1,
)
compliance.save()
assert compliance.id is not None
assert compliance.user_id == test_user.id
assert compliance.organization_id == test_organization.id
assert compliance.status == MfaComplianceStatus.NOT_APPLICABLE
assert compliance.policy_version == 1
assert compliance.notification_count == 0
def test_mfa_policy_compliance_in_grace(self, db, test_user, test_organization):
"""Test MFA compliance record in grace period."""
now = datetime.now(timezone.utc)
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.IN_GRACE,
policy_version=1,
applied_at=now,
deadline_at=now + timedelta(days=14),
)
compliance.save()
assert compliance.status == MfaComplianceStatus.IN_GRACE
assert compliance.applied_at is not None
assert compliance.deadline_at is not None
assert compliance.deadline_at > now
def test_mfa_policy_compliance_compliant(self, db, test_user, test_organization):
"""Test MFA compliance record when compliant."""
now = datetime.now(timezone.utc)
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.COMPLIANT,
policy_version=1,
applied_at=now - timedelta(days=30),
deadline_at=now - timedelta(days=16),
compliant_at=now - timedelta(days=16),
)
compliance.save()
assert compliance.status == MfaComplianceStatus.COMPLIANT
assert compliance.compliant_at is not None
def test_mfa_policy_compliance_suspended(self, db, test_user, test_organization):
"""Test MFA compliance record when suspended."""
now = datetime.now(timezone.utc)
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.SUSPENDED,
policy_version=1,
applied_at=now - timedelta(days=30),
deadline_at=now - timedelta(days=16),
suspended_at=now - timedelta(days=16),
)
compliance.save()
assert compliance.status == MfaComplianceStatus.SUSPENDED
assert compliance.suspended_at is not None
def test_mfa_policy_compliance_relationships(self, db, test_user, test_organization):
"""Test MFA compliance relationships."""
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.NOT_APPLICABLE,
policy_version=1,
)
compliance.save()
# Test relationships
assert compliance.user is not None
assert compliance.user.id == test_user.id
assert compliance.organization is not None
assert compliance.organization.id == test_organization.id
@pytest.mark.unit
class TestSessionModelComplianceFlag:
"""Tests for Session model compliance flag."""
def test_session_default_not_compliance_only(self, db, test_user):
"""Test that sessions are not compliance only by default."""
session = Session(
user_id=test_user.id,
token="test-token-123",
status=SessionStatus.ACTIVE,
expires_at=datetime.now(timezone.utc) + timedelta(hours=8),
last_activity_at=datetime.now(timezone.utc),
)
session.save()
assert session.is_compliance_only is False
def test_session_compliance_only(self, db, test_user):
"""Test creating a compliance-only session."""
session = Session(
user_id=test_user.id,
token="compliance-token-123",
status=SessionStatus.ACTIVE,
expires_at=datetime.now(timezone.utc) + timedelta(hours=8),
last_activity_at=datetime.now(timezone.utc),
is_compliance_only=True,
)
session.save()
assert session.is_compliance_only is True
def test_session_to_dict_excludes_token(self, db, test_user):
"""Test that session to_dict excludes the token."""
session = Session(
user_id=test_user.id,
token="test-token-456",
status=SessionStatus.ACTIVE,
expires_at=datetime.now(timezone.utc) + timedelta(hours=8),
last_activity_at=datetime.now(timezone.utc),
)
session.save()
session_dict = session.to_dict()
assert "id" in session_dict
assert "user_id" in session_dict
assert "is_compliance_only" in session_dict
assert session_dict["is_compliance_only"] is False
@pytest.mark.unit
class TestUserStatusComplianceSuspended:
"""Tests for UserStatus.COMPLIANCE_SUSPENDED."""
def test_compliance_suspended_status_exists(self):
"""Test that COMPLIANCE_SUSPENDED status exists."""
assert UserStatus.COMPLIANCE_SUSPENDED.value == "compliance_suspended"
def test_create_compliance_suspended_user(self, db):
"""Test creating a compliance suspended user."""
user = User(
email="suspended@example.com",
full_name="Suspended User",
status=UserStatus.COMPLIANCE_SUSPENDED,
)
user.save()
assert user.status == UserStatus.COMPLIANCE_SUSPENDED
@@ -0,0 +1,476 @@
"""Unit tests for MfaPolicyService."""
import pytest
from datetime import datetime, timezone, timedelta
from unittest.mock import patch, MagicMock
from gatehouse_app.models import (
User,
Organization,
OrganizationMember,
OrganizationSecurityPolicy,
UserSecurityPolicy,
MfaPolicyCompliance,
Session,
)
from gatehouse_app.services.mfa_policy_service import (
MfaPolicyService,
OrgPolicyDto,
EffectiveUserPolicyDto,
AggregateMfaStateDto,
LoginPolicyResult,
)
from gatehouse_app.utils.constants import (
UserStatus,
MfaPolicyMode,
MfaComplianceStatus,
MfaRequirementOverride,
SessionStatus,
OrganizationRole,
)
@pytest.mark.unit
class TestMfaPolicyService:
"""Tests for MfaPolicyService."""
def test_get_org_policy_not_found(self, db, test_organization):
"""Test getting organization policy when none exists."""
policy = MfaPolicyService.get_org_policy(test_organization.id)
assert policy is None
def test_get_org_policy_found(self, db, test_organization):
"""Test getting organization policy when it exists."""
# Create policy
org_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
)
org_policy.save()
policy = MfaPolicyService.get_org_policy(test_organization.id)
assert policy is not None
assert policy.organization_id == test_organization.id
assert policy.mfa_policy_mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value
assert policy.mfa_grace_period_days == 14
assert policy.notify_days_before == 7
assert policy.policy_version == 1
def test_get_effective_user_policy_no_org_policy(self, db, test_user, test_organization):
"""Test effective user policy when no org policy exists."""
policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
assert policy is not None
assert policy.organization_id == test_organization.id
assert policy.effective_mode == MfaPolicyMode.DISABLED.value
assert policy.requires_totp is False
assert policy.requires_webauthn is False
assert policy.is_exempt is True
def test_get_effective_user_policy_with_org_policy(self, db, test_user, test_organization):
"""Test effective user policy with org policy and no override."""
# Create org policy
org_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
)
org_policy.save()
policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
assert policy is not None
assert policy.effective_mode == MfaPolicyMode.REQUIRE_TOTP.value
assert policy.requires_totp is True
assert policy.requires_webauthn is False
assert policy.is_exempt is False
def test_get_effective_user_policy_with_override_inherit(self, db, test_user, test_organization):
"""Test effective user policy with INHERIT override."""
# Create org policy
org_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_WEBAUTHN,
mfa_grace_period_days=7,
)
org_policy.save()
# Create user override
user_override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.INHERIT,
)
user_override.save()
policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
assert policy.effective_mode == MfaPolicyMode.REQUIRE_WEBAUTHN.value
assert policy.requires_webauthn is True
def test_get_effective_user_policy_with_override_exempt(self, db, test_user, test_organization):
"""Test effective user policy with EXEMPT override."""
# Create org policy
org_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
)
org_policy.save()
# Create user override
user_override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.EXEMPT,
)
user_override.save()
policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
assert policy.effective_mode == MfaPolicyMode.DISABLED.value
assert policy.is_exempt is True
def test_get_effective_user_policy_with_override_required(self, db, test_user, test_organization):
"""Test effective user policy with REQUIRED override."""
# Create org policy
org_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
mfa_grace_period_days=14,
)
org_policy.save()
# Create user override
user_override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.REQUIRED,
)
user_override.save()
policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
assert policy.effective_mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value
assert policy.requires_totp is True
assert policy.requires_webauthn is True
assert policy.is_exempt is False
def test_evaluate_user_mfa_state_no_policy(self, db, test_user, test_organization):
"""Test evaluating user MFA state with no policy."""
# Create membership
membership = OrganizationMember(
user_id=test_user.id,
organization_id=test_organization.id,
role=OrganizationRole.MEMBER,
)
membership.save()
state = MfaPolicyService.evaluate_user_mfa_state(test_user)
assert state is not None
assert state.overall_status == MfaComplianceStatus.COMPLIANT.value
assert len(state.missing_methods) == 0
assert len(state.orgs) == 1
def test_evaluate_user_mfa_state_with_policy(self, db, test_user, test_organization):
"""Test evaluating user MFA state with policy."""
# Create membership
membership = OrganizationMember(
user_id=test_user.id,
organization_id=test_organization.id,
role=OrganizationRole.MEMBER,
)
membership.save()
# Create org policy
org_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
)
org_policy.save()
state = MfaPolicyService.evaluate_user_mfa_state(test_user)
assert state is not None
assert state.overall_status == MfaComplianceStatus.IN_GRACE.value
assert "totp" in state.missing_methods
assert len(state.orgs) == 1
assert state.orgs[0].effective_mode == MfaPolicyMode.REQUIRE_TOTP.value
def test_after_primary_auth_success_no_required_policy(self, db, test_user, test_organization):
"""Test after_primary_auth_success with no required policy."""
# Create membership
membership = OrganizationMember(
user_id=test_user.id,
organization_id=test_organization.id,
role=OrganizationRole.MEMBER,
)
membership.save()
result = MfaPolicyService.after_primary_auth_success(test_user)
assert result.can_create_full_session is True
assert result.create_compliance_only_session is False
assert result.compliance_summary.overall_status == MfaComplianceStatus.COMPLIANT.value
def test_after_primary_auth_success_in_grace(self, db, test_user, test_organization):
"""Test after_primary_auth_success when user is in grace period."""
# Create membership
membership = OrganizationMember(
user_id=test_user.id,
organization_id=test_organization.id,
role=OrganizationRole.MEMBER,
)
membership.save()
# Create org policy
org_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
)
org_policy.save()
result = MfaPolicyService.after_primary_auth_success(test_user)
assert result.can_create_full_session is True
assert result.create_compliance_only_session is False
assert result.compliance_summary.overall_status == MfaComplianceStatus.IN_GRACE.value
def test_after_primary_auth_success_past_due(self, db, test_user, test_organization):
"""Test after_primary_auth_success when user is past due."""
# Create membership
membership = OrganizationMember(
user_id=test_user.id,
organization_id=test_organization.id,
role=OrganizationRole.MEMBER,
)
membership.save()
# Create org policy
org_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
)
org_policy.save()
# Create compliance record past due
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.PAST_DUE,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=30),
deadline_at=datetime.now(timezone.utc) - timedelta(days=1),
)
compliance.save()
result = MfaPolicyService.after_primary_auth_success(test_user)
assert result.can_create_full_session is False
assert result.create_compliance_only_session is True
def test_create_org_policy_new(self, db, test_organization):
"""Test creating a new organization policy."""
policy = MfaPolicyService.create_org_policy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
updated_by_user_id=None,
)
assert policy is not None
assert policy.organization_id == test_organization.id
assert policy.mfa_policy_mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN
assert policy.policy_version == 1
def test_create_org_policy_update(self, db, test_organization):
"""Test updating an existing organization policy."""
# Create initial policy
initial_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
mfa_grace_period_days=14,
)
initial_policy.save()
# Update policy
updated_policy = MfaPolicyService.create_org_policy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=7,
updated_by_user_id=None,
)
assert updated_policy.mfa_policy_mode == MfaPolicyMode.REQUIRE_TOTP
assert updated_policy.mfa_grace_period_days == 7
assert updated_policy.policy_version == 2
def test_set_user_override_new(self, db, test_user, test_organization):
"""Test setting a new user override."""
override = MfaPolicyService.set_user_override(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.REQUIRED,
force_totp=True,
force_webauthn=False,
updated_by_user_id=None,
)
assert override is not None
assert override.user_id == test_user.id
assert override.organization_id == test_organization.id
assert override.mfa_override_mode == MfaRequirementOverride.REQUIRED
assert override.force_totp is True
def test_set_user_override_update(self, db, test_user, test_organization):
"""Test updating an existing user override."""
# Create initial override
initial_override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.INHERIT,
)
initial_override.save()
# Update override
updated_override = MfaPolicyService.set_user_override(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.EXEMPT,
updated_by_user_id=None,
)
assert updated_override.mfa_override_mode == MfaRequirementOverride.EXEMPT
def test_get_user_compliance(self, db, test_user, test_organization):
"""Test getting user compliance record."""
# Create compliance record
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.COMPLIANT,
policy_version=1,
)
compliance.save()
result = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id)
assert result is not None
assert result.status == MfaComplianceStatus.COMPLIANT
def test_get_user_compliance_not_found(self, db, test_user, test_organization):
"""Test getting user compliance record when none exists."""
result = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id)
assert result is None
def test_get_org_compliance_list(self, db, test_user, test_organization):
"""Test getting organization compliance list."""
# Create compliance record
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.IN_GRACE,
policy_version=1,
deadline_at=datetime.now(timezone.utc) + timedelta(days=14),
)
compliance.save()
results = MfaPolicyService.get_org_compliance_list(test_organization.id)
assert len(results) == 1
assert results[0]["user_id"] == test_user.id
assert results[0]["status"] == MfaComplianceStatus.IN_GRACE.value
def test_get_org_compliance_list_with_status_filter(self, db, test_user, test_organization):
"""Test getting organization compliance list with status filter."""
# Create compliance record
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.COMPLIANT,
policy_version=1,
)
compliance.save()
# Filter by different status
results = MfaPolicyService.get_org_compliance_list(
test_organization.id, status=MfaComplianceStatus.IN_GRACE
)
assert len(results) == 0
# Filter by correct status
results = MfaPolicyService.get_org_compliance_list(
test_organization.id, status=MfaComplianceStatus.COMPLIANT
)
assert len(results) == 1
@pytest.mark.unit
class TestMfaPolicyServiceDto:
"""Tests for MfaPolicyService DTOs."""
def test_org_policy_dto(self):
"""Test OrgPolicyDto creation."""
dto = OrgPolicyDto(
organization_id="org-123",
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP.value,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
assert dto.organization_id == "org-123"
assert dto.mfa_policy_mode == "require_totp"
assert dto.mfa_grace_period_days == 14
def test_effective_user_policy_dto(self):
"""Test EffectiveUserPolicyDto creation."""
dto = EffectiveUserPolicyDto(
organization_id="org-123",
effective_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value,
requires_totp=True,
requires_webauthn=True,
grace_period_days=14,
is_exempt=False,
)
assert dto.requires_totp is True
assert dto.requires_webauthn is True
assert dto.is_exempt is False
def test_aggregate_mfa_state_dto(self):
"""Test AggregateMfaStateDto creation."""
dto = AggregateMfaStateDto(
overall_status=MfaComplianceStatus.IN_GRACE.value,
missing_methods=["totp"],
deadline_at="2025-02-01T00:00:00Z",
orgs=[],
)
assert dto.overall_status == "in_grace"
assert "totp" in dto.missing_methods
assert dto.deadline_at == "2025-02-01T00:00:00Z"
def test_login_policy_result(self):
"""Test LoginPolicyResult creation."""
summary = AggregateMfaStateDto(
overall_status=MfaComplianceStatus.IN_GRACE.value,
missing_methods=["totp"],
orgs=[],
)
result = LoginPolicyResult(
can_create_full_session=True,
create_compliance_only_session=False,
compliance_summary=summary,
)
assert result.can_create_full_session is True
assert result.create_compliance_only_session is False
assert result.compliance_summary.overall_status == "in_grace"