enable policies
This commit is contained in:
@@ -286,7 +286,6 @@ For issues and questions:
|
|||||||
|
|
||||||
# Boostrap db
|
# Boostrap db
|
||||||
python manage.py db upgrade
|
python manage.py db upgrade
|
||||||
python manage.py db migrate
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from gatehouse_app.services.oidc_service import (
|
|||||||
OIDCService, InvalidClientError, InvalidGrantError, InvalidRequestError
|
OIDCService, InvalidClientError, InvalidGrantError, InvalidRequestError
|
||||||
)
|
)
|
||||||
from gatehouse_app.services.auth_service import AuthService
|
from gatehouse_app.services.auth_service import AuthService
|
||||||
|
from gatehouse_app.services.mfa_policy_service import MfaPolicyService
|
||||||
from gatehouse_app.extensions import db
|
from gatehouse_app.extensions import db
|
||||||
from gatehouse_app.extensions import bcrypt as flask_bcrypt
|
from gatehouse_app.extensions import bcrypt as flask_bcrypt
|
||||||
from gatehouse_app.models import User, OIDCClient
|
from gatehouse_app.models import User, OIDCClient
|
||||||
@@ -372,6 +373,23 @@ def oidc_authorize():
|
|||||||
logger.debug("[OIDC] Attempting user authentication for email: %s", email)
|
logger.debug("[OIDC] Attempting user authentication for email: %s", email)
|
||||||
try:
|
try:
|
||||||
user = AuthService.authenticate(email, password)
|
user = AuthService.authenticate(email, password)
|
||||||
|
|
||||||
|
# Evaluate MFA policy after primary authentication
|
||||||
|
policy_result = MfaPolicyService.after_primary_auth_success(user, remember_me=False)
|
||||||
|
|
||||||
|
# Check if user can create full session
|
||||||
|
if not policy_result.can_create_full_session:
|
||||||
|
logger.debug("[OIDC] User cannot create full session due to MFA compliance: user_id=%s, email=%s", user.id, email)
|
||||||
|
return _show_login_page(
|
||||||
|
client_id=client_id,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
scope=scope,
|
||||||
|
state=state,
|
||||||
|
nonce=nonce,
|
||||||
|
response_type=response_type,
|
||||||
|
error="Your account requires multi factor enrollment before using single sign on"
|
||||||
|
)
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
session["oidc_user_id"] = user_id
|
session["oidc_user_id"] = user_id
|
||||||
|
|
||||||
|
|||||||
@@ -5,4 +5,4 @@ from flask import Blueprint
|
|||||||
api_v1_bp = Blueprint("api_v1", __name__)
|
api_v1_bp = Blueprint("api_v1", __name__)
|
||||||
|
|
||||||
# Import route modules to register them
|
# Import route modules to register them
|
||||||
from gatehouse_app.api.v1 import auth, users, organizations
|
from gatehouse_app.api.v1 import auth, users, organizations, policies
|
||||||
|
|||||||
+133
-24
@@ -22,6 +22,7 @@ from gatehouse_app.schemas.webauthn_schema import (
|
|||||||
from gatehouse_app.services.auth_service import AuthService
|
from gatehouse_app.services.auth_service import AuthService
|
||||||
from gatehouse_app.services.webauthn_service import WebAuthnService
|
from gatehouse_app.services.webauthn_service import WebAuthnService
|
||||||
from gatehouse_app.services.user_service import UserService
|
from gatehouse_app.services.user_service import UserService
|
||||||
|
from gatehouse_app.services.mfa_policy_service import MfaPolicyService
|
||||||
from gatehouse_app.utils.decorators import login_required
|
from gatehouse_app.utils.decorators import login_required
|
||||||
from gatehouse_app.utils.constants import AuditAction
|
from gatehouse_app.utils.constants import AuditAction
|
||||||
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
|
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
|
||||||
@@ -94,6 +95,9 @@ def login():
|
|||||||
400: Validation error
|
400: Validation error
|
||||||
401: Invalid credentials
|
401: Invalid credentials
|
||||||
"""
|
"""
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Validate request data
|
# Validate request data
|
||||||
schema = LoginSchema()
|
schema = LoginSchema()
|
||||||
@@ -105,6 +109,11 @@ def login():
|
|||||||
password=data["password"],
|
password=data["password"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# SECURITY CHECK: Log MFA enrollment status to validate the vulnerability
|
||||||
|
has_totp = user.has_totp_enabled()
|
||||||
|
has_webauthn = user.has_webauthn_enabled()
|
||||||
|
logger.warning(f"[SECURITY DIAGNOSTIC] Login attempt for user {user.email} - TOTP enabled: {has_totp}, WebAuthn enabled: {has_webauthn}")
|
||||||
|
|
||||||
# Check if user has TOTP enabled for two-factor authentication
|
# Check if user has TOTP enabled for two-factor authentication
|
||||||
if user.has_totp_enabled():
|
if user.has_totp_enabled():
|
||||||
# TOTP is enabled - store user_id in session for TOTP verification
|
# TOTP is enabled - store user_id in session for TOTP verification
|
||||||
@@ -121,16 +130,56 @@ def login():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# TOTP is NOT enabled - proceed with normal login flow
|
# TOTP is NOT enabled - proceed with normal login flow
|
||||||
|
# SECURITY DIAGNOSTIC: This is where the vulnerability occurs - no WebAuthn check!
|
||||||
|
if has_webauthn:
|
||||||
|
logger.error(f"[SECURITY VULNERABILITY DETECTED] User {user.email} has WebAuthn enrolled but is bypassing it! Creating session without MFA verification.")
|
||||||
|
|
||||||
|
# Evaluate MFA policy after primary authentication
|
||||||
|
remember_me = data.get("remember_me", False)
|
||||||
|
policy_result = MfaPolicyService.after_primary_auth_success(user, remember_me)
|
||||||
|
|
||||||
# Create session with appropriate duration based on remember_me preference
|
# Create session with appropriate duration based on remember_me preference
|
||||||
duration = 2592000 if data.get("remember_me") else 86400 # 30 days vs 1 day
|
duration = 2592000 if remember_me else 86400 # 30 days vs 1 day
|
||||||
user_session = AuthService.create_session(user, duration_seconds=duration)
|
|
||||||
|
# Determine if this should be a compliance-only session
|
||||||
|
is_compliance_only = policy_result.create_compliance_only_session
|
||||||
|
|
||||||
|
user_session = AuthService.create_session(
|
||||||
|
user,
|
||||||
|
duration_seconds=duration,
|
||||||
|
is_compliance_only=is_compliance_only
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build response data
|
||||||
|
response_data = {
|
||||||
|
"user": user.to_dict(),
|
||||||
|
"token": user_session.token,
|
||||||
|
"expires_at": user_session.expires_at.isoformat() + "Z" if user_session.expires_at.isoformat()[-1] != "Z" else user_session.expires_at.isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add MFA compliance information
|
||||||
|
if policy_result.compliance_summary:
|
||||||
|
response_data["mfa_compliance"] = {
|
||||||
|
"overall_status": policy_result.compliance_summary.overall_status,
|
||||||
|
"missing_methods": policy_result.compliance_summary.missing_methods,
|
||||||
|
"deadline_at": policy_result.compliance_summary.deadline_at,
|
||||||
|
"orgs": [
|
||||||
|
{
|
||||||
|
"organization_id": org.organization_id,
|
||||||
|
"organization_name": org.organization_name,
|
||||||
|
"status": org.status,
|
||||||
|
"deadline_at": org.deadline_at,
|
||||||
|
}
|
||||||
|
for org in policy_result.compliance_summary.orgs
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add requires_mfa_enrollment flag if compliance-only session
|
||||||
|
if is_compliance_only:
|
||||||
|
response_data["requires_mfa_enrollment"] = True
|
||||||
|
|
||||||
return api_response(
|
return api_response(
|
||||||
data={
|
data=response_data,
|
||||||
"user": user.to_dict(),
|
|
||||||
"token": user_session.token,
|
|
||||||
"expires_at": user_session.expires_at.isoformat() + "Z" if user_session.expires_at.isoformat()[-1] != "Z" else user_session.expires_at.isoformat(),
|
|
||||||
},
|
|
||||||
message="Login successful",
|
message="Login successful",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -380,20 +429,50 @@ def verify_totp():
|
|||||||
client_utc_timestamp=data.get("client_timestamp"),
|
client_utc_timestamp=data.get("client_timestamp"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create full session
|
# Evaluate MFA policy after primary authentication
|
||||||
user_session = AuthService.create_session(user)
|
policy_result = MfaPolicyService.after_primary_auth_success(user, remember_me=False)
|
||||||
|
|
||||||
|
# Determine if this should be a compliance-only session
|
||||||
|
is_compliance_only = policy_result.create_compliance_only_session
|
||||||
|
|
||||||
|
# Create session
|
||||||
|
user_session = AuthService.create_session(user, is_compliance_only=is_compliance_only)
|
||||||
|
|
||||||
# Clear temporary session
|
# Clear temporary session
|
||||||
session.pop("totp_pending_user_id", None)
|
session.pop("totp_pending_user_id", None)
|
||||||
|
|
||||||
|
# Build response data
|
||||||
|
response_data = {
|
||||||
|
"user": user.to_dict(),
|
||||||
|
"token": user_session.token,
|
||||||
|
"expires_at": user_session.expires_at.isoformat() + "Z"
|
||||||
|
if user_session.expires_at.isoformat()[-1] != "Z"
|
||||||
|
else user_session.expires_at.isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add MFA compliance information
|
||||||
|
if policy_result.compliance_summary:
|
||||||
|
response_data["mfa_compliance"] = {
|
||||||
|
"overall_status": policy_result.compliance_summary.overall_status,
|
||||||
|
"missing_methods": policy_result.compliance_summary.missing_methods,
|
||||||
|
"deadline_at": policy_result.compliance_summary.deadline_at,
|
||||||
|
"orgs": [
|
||||||
|
{
|
||||||
|
"organization_id": org.organization_id,
|
||||||
|
"organization_name": org.organization_name,
|
||||||
|
"status": org.status,
|
||||||
|
"deadline_at": org.deadline_at,
|
||||||
|
}
|
||||||
|
for org in policy_result.compliance_summary.orgs
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add requires_mfa_enrollment flag if compliance-only session
|
||||||
|
if is_compliance_only:
|
||||||
|
response_data["requires_mfa_enrollment"] = True
|
||||||
|
|
||||||
return api_response(
|
return api_response(
|
||||||
data={
|
data=response_data,
|
||||||
"user": user.to_dict(),
|
|
||||||
"token": user_session.token,
|
|
||||||
"expires_at": user_session.expires_at.isoformat() + "Z"
|
|
||||||
if user_session.expires_at.isoformat()[-1] != "Z"
|
|
||||||
else user_session.expires_at.isoformat(),
|
|
||||||
},
|
|
||||||
message="TOTP verification successful",
|
message="TOTP verification successful",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -835,22 +914,52 @@ def complete_webauthn_login():
|
|||||||
challenge
|
challenge
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Evaluate MFA policy after primary authentication
|
||||||
|
policy_result = MfaPolicyService.after_primary_auth_success(user, remember_me=False)
|
||||||
|
|
||||||
|
# Determine if this should be a compliance-only session
|
||||||
|
is_compliance_only = policy_result.create_compliance_only_session
|
||||||
|
|
||||||
# Create session
|
# Create session
|
||||||
user_session = AuthService.create_session(user)
|
user_session = AuthService.create_session(user, is_compliance_only=is_compliance_only)
|
||||||
|
|
||||||
# Clear pending session
|
# Clear pending session
|
||||||
session.pop("webauthn_pending_user_id", None)
|
session.pop("webauthn_pending_user_id", None)
|
||||||
|
|
||||||
logger.info(f"WebAuthn login completed successfully for user: {user.email}")
|
logger.info(f"WebAuthn login completed successfully for user: {user.email}")
|
||||||
|
|
||||||
|
# Build response data
|
||||||
|
response_data = {
|
||||||
|
"user": user.to_dict(),
|
||||||
|
"token": user_session.token,
|
||||||
|
"expires_at": user_session.expires_at.isoformat() + "Z"
|
||||||
|
if user_session.expires_at.isoformat()[-1] != "Z"
|
||||||
|
else user_session.expires_at.isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add MFA compliance information
|
||||||
|
if policy_result.compliance_summary:
|
||||||
|
response_data["mfa_compliance"] = {
|
||||||
|
"overall_status": policy_result.compliance_summary.overall_status,
|
||||||
|
"missing_methods": policy_result.compliance_summary.missing_methods,
|
||||||
|
"deadline_at": policy_result.compliance_summary.deadline_at,
|
||||||
|
"orgs": [
|
||||||
|
{
|
||||||
|
"organization_id": org.organization_id,
|
||||||
|
"organization_name": org.organization_name,
|
||||||
|
"status": org.status,
|
||||||
|
"deadline_at": org.deadline_at,
|
||||||
|
}
|
||||||
|
for org in policy_result.compliance_summary.orgs
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add requires_mfa_enrollment flag if compliance-only session
|
||||||
|
if is_compliance_only:
|
||||||
|
response_data["requires_mfa_enrollment"] = True
|
||||||
|
|
||||||
return api_response(
|
return api_response(
|
||||||
data={
|
data=response_data,
|
||||||
"user": user.to_dict(),
|
|
||||||
"token": user_session.token,
|
|
||||||
"expires_at": user_session.expires_at.isoformat() + "Z"
|
|
||||||
if user_session.expires_at.isoformat()[-1] != "Z"
|
|
||||||
else user_session.expires_at.isoformat(),
|
|
||||||
},
|
|
||||||
message="Login successful",
|
message="Login successful",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from flask import g, request
|
|||||||
from marshmallow import ValidationError
|
from marshmallow import ValidationError
|
||||||
from gatehouse_app.api.v1 import api_v1_bp
|
from gatehouse_app.api.v1 import api_v1_bp
|
||||||
from gatehouse_app.utils.response import api_response
|
from gatehouse_app.utils.response import api_response
|
||||||
from gatehouse_app.utils.decorators import login_required, require_admin, require_owner
|
from gatehouse_app.utils.decorators import login_required, require_admin, require_owner, full_access_required
|
||||||
from gatehouse_app.schemas.organization_schema import (
|
from gatehouse_app.schemas.organization_schema import (
|
||||||
OrganizationCreateSchema,
|
OrganizationCreateSchema,
|
||||||
OrganizationUpdateSchema,
|
OrganizationUpdateSchema,
|
||||||
@@ -17,6 +17,7 @@ from gatehouse_app.utils.constants import OrganizationRole
|
|||||||
|
|
||||||
@api_v1_bp.route("/organizations", methods=["POST"])
|
@api_v1_bp.route("/organizations", methods=["POST"])
|
||||||
@login_required
|
@login_required
|
||||||
|
@full_access_required
|
||||||
def create_organization():
|
def create_organization():
|
||||||
"""
|
"""
|
||||||
Create a new organization.
|
Create a new organization.
|
||||||
@@ -65,6 +66,7 @@ def create_organization():
|
|||||||
|
|
||||||
@api_v1_bp.route("/organizations/<org_id>", methods=["GET"])
|
@api_v1_bp.route("/organizations/<org_id>", methods=["GET"])
|
||||||
@login_required
|
@login_required
|
||||||
|
@full_access_required
|
||||||
def get_organization(org_id):
|
def get_organization(org_id):
|
||||||
"""
|
"""
|
||||||
Get organization by ID.
|
Get organization by ID.
|
||||||
@@ -101,6 +103,7 @@ def get_organization(org_id):
|
|||||||
@api_v1_bp.route("/organizations/<org_id>", methods=["PATCH"])
|
@api_v1_bp.route("/organizations/<org_id>", methods=["PATCH"])
|
||||||
@login_required
|
@login_required
|
||||||
@require_admin
|
@require_admin
|
||||||
|
@full_access_required
|
||||||
def update_organization(org_id):
|
def update_organization(org_id):
|
||||||
"""
|
"""
|
||||||
Update organization.
|
Update organization.
|
||||||
@@ -152,6 +155,7 @@ def update_organization(org_id):
|
|||||||
@api_v1_bp.route("/organizations/<org_id>", methods=["DELETE"])
|
@api_v1_bp.route("/organizations/<org_id>", methods=["DELETE"])
|
||||||
@login_required
|
@login_required
|
||||||
@require_owner
|
@require_owner
|
||||||
|
@full_access_required
|
||||||
def delete_organization(org_id):
|
def delete_organization(org_id):
|
||||||
"""
|
"""
|
||||||
Delete organization (soft delete).
|
Delete organization (soft delete).
|
||||||
@@ -180,6 +184,7 @@ def delete_organization(org_id):
|
|||||||
|
|
||||||
@api_v1_bp.route("/organizations/<org_id>/members", methods=["GET"])
|
@api_v1_bp.route("/organizations/<org_id>/members", methods=["GET"])
|
||||||
@login_required
|
@login_required
|
||||||
|
@full_access_required
|
||||||
def get_organization_members(org_id):
|
def get_organization_members(org_id):
|
||||||
"""
|
"""
|
||||||
Get all members of an organization.
|
Get all members of an organization.
|
||||||
@@ -223,6 +228,7 @@ def get_organization_members(org_id):
|
|||||||
@api_v1_bp.route("/organizations/<org_id>/members", methods=["POST"])
|
@api_v1_bp.route("/organizations/<org_id>/members", methods=["POST"])
|
||||||
@login_required
|
@login_required
|
||||||
@require_admin
|
@require_admin
|
||||||
|
@full_access_required
|
||||||
def add_organization_member(org_id):
|
def add_organization_member(org_id):
|
||||||
"""
|
"""
|
||||||
Add a member to the organization.
|
Add a member to the organization.
|
||||||
@@ -290,6 +296,7 @@ def add_organization_member(org_id):
|
|||||||
@api_v1_bp.route("/organizations/<org_id>/members/<user_id>", methods=["DELETE"])
|
@api_v1_bp.route("/organizations/<org_id>/members/<user_id>", methods=["DELETE"])
|
||||||
@login_required
|
@login_required
|
||||||
@require_admin
|
@require_admin
|
||||||
|
@full_access_required
|
||||||
def remove_organization_member(org_id, user_id):
|
def remove_organization_member(org_id, user_id):
|
||||||
"""
|
"""
|
||||||
Remove a member from the organization.
|
Remove a member from the organization.
|
||||||
@@ -320,6 +327,7 @@ def remove_organization_member(org_id, user_id):
|
|||||||
@api_v1_bp.route("/organizations/<org_id>/members/<user_id>/role", methods=["PATCH"])
|
@api_v1_bp.route("/organizations/<org_id>/members/<user_id>/role", methods=["PATCH"])
|
||||||
@login_required
|
@login_required
|
||||||
@require_admin
|
@require_admin
|
||||||
|
@full_access_required
|
||||||
def update_member_role(org_id, user_id):
|
def update_member_role(org_id, user_id):
|
||||||
"""
|
"""
|
||||||
Update a member's role.
|
Update a member's role.
|
||||||
|
|||||||
@@ -0,0 +1,336 @@
|
|||||||
|
"""Security policy endpoints."""
|
||||||
|
from flask import g, request
|
||||||
|
from marshmallow import Schema, fields, validate, ValidationError
|
||||||
|
from gatehouse_app.api.v1 import api_v1_bp
|
||||||
|
from gatehouse_app.utils.response import api_response
|
||||||
|
from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required
|
||||||
|
from gatehouse_app.services.mfa_policy_service import MfaPolicyService
|
||||||
|
from gatehouse_app.services.organization_service import OrganizationService
|
||||||
|
from gatehouse_app.services.audit_service import AuditService
|
||||||
|
from gatehouse_app.utils.constants import MfaPolicyMode, MfaRequirementOverride, MfaComplianceStatus, AuditAction
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateOrgPolicySchema(Schema):
|
||||||
|
"""Schema for updating organization security policy."""
|
||||||
|
mfa_policy_mode = fields.String(
|
||||||
|
required=False,
|
||||||
|
validate=validate.OneOf([m.value for m in MfaPolicyMode])
|
||||||
|
)
|
||||||
|
mfa_grace_period_days = fields.Integer(
|
||||||
|
required=False,
|
||||||
|
validate=validate.Range(min=0, max=365)
|
||||||
|
)
|
||||||
|
notify_days_before = fields.Integer(
|
||||||
|
required=False,
|
||||||
|
validate=validate.Range(min=0, max=30)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateUserPolicySchema(Schema):
|
||||||
|
"""Schema for updating user security policy override."""
|
||||||
|
mfa_override_mode = fields.String(
|
||||||
|
required=True,
|
||||||
|
validate=validate.OneOf([m.value for m in MfaRequirementOverride])
|
||||||
|
)
|
||||||
|
force_totp = fields.Boolean(required=False, load_default=False)
|
||||||
|
force_webauthn = fields.Boolean(required=False, load_default=False)
|
||||||
|
|
||||||
|
|
||||||
|
class ComplianceListQuerySchema(Schema):
|
||||||
|
"""Schema for compliance list query parameters."""
|
||||||
|
status = fields.String(required=False)
|
||||||
|
limit = fields.Integer(required=False, load_default=100)
|
||||||
|
offset = fields.Integer(required=False, load_default=0)
|
||||||
|
|
||||||
|
|
||||||
|
@api_v1_bp.route("/organizations/<org_id>/security-policy", methods=["GET"])
|
||||||
|
@login_required
|
||||||
|
def get_org_security_policy(org_id):
|
||||||
|
"""
|
||||||
|
Get organization security policy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
org_id: Organization ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
200: Organization security policy
|
||||||
|
401: Not authenticated
|
||||||
|
403: Not a member
|
||||||
|
404: Organization not found
|
||||||
|
"""
|
||||||
|
org = OrganizationService.get_organization_by_id(org_id)
|
||||||
|
|
||||||
|
# Check if user is a member
|
||||||
|
if not org.is_member(g.current_user.id):
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="You are not a member of this organization",
|
||||||
|
status=403,
|
||||||
|
error_type="AUTHORIZATION_ERROR",
|
||||||
|
)
|
||||||
|
|
||||||
|
policy_dto = MfaPolicyService.get_org_policy(org_id)
|
||||||
|
|
||||||
|
if policy_dto:
|
||||||
|
data = {
|
||||||
|
"organization_id": policy_dto.organization_id,
|
||||||
|
"mfa_policy_mode": policy_dto.mfa_policy_mode,
|
||||||
|
"mfa_grace_period_days": policy_dto.mfa_grace_period_days,
|
||||||
|
"notify_days_before": policy_dto.notify_days_before,
|
||||||
|
"policy_version": policy_dto.policy_version,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# Return default policy if none exists
|
||||||
|
data = {
|
||||||
|
"organization_id": org_id,
|
||||||
|
"mfa_policy_mode": MfaPolicyMode.OPTIONAL.value,
|
||||||
|
"mfa_grace_period_days": 14,
|
||||||
|
"notify_days_before": 7,
|
||||||
|
"policy_version": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
return api_response(
|
||||||
|
data={"security_policy": data},
|
||||||
|
message="Security policy retrieved successfully",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@api_v1_bp.route("/organizations/<org_id>/security-policy", methods=["PUT"])
|
||||||
|
@login_required
|
||||||
|
@require_admin
|
||||||
|
@full_access_required
|
||||||
|
def update_org_security_policy(org_id):
|
||||||
|
"""
|
||||||
|
Update organization security policy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
org_id: Organization ID
|
||||||
|
|
||||||
|
Request body:
|
||||||
|
mfa_policy_mode: MFA policy mode (disabled, optional, require_totp, require_webauthn, require_totp_or_webauthn)
|
||||||
|
mfa_grace_period_days: Grace period in days (0-365)
|
||||||
|
notify_days_before: Days before deadline to notify (0-30)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
200: Security policy updated successfully
|
||||||
|
400: Validation error
|
||||||
|
401: Not authenticated
|
||||||
|
403: Not an admin
|
||||||
|
404: Organization not found
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
schema = UpdateOrgPolicySchema()
|
||||||
|
data = schema.load(request.json)
|
||||||
|
|
||||||
|
org = OrganizationService.get_organization_by_id(org_id)
|
||||||
|
|
||||||
|
# Update policy
|
||||||
|
policy = MfaPolicyService.create_org_policy(
|
||||||
|
organization_id=org_id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode(data.get("mfa_policy_mode", MfaPolicyMode.OPTIONAL.value)),
|
||||||
|
mfa_grace_period_days=data.get("mfa_grace_period_days", 14),
|
||||||
|
notify_days_before=data.get("notify_days_before", 7),
|
||||||
|
updated_by_user_id=g.current_user.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return api_response(
|
||||||
|
data={
|
||||||
|
"security_policy": {
|
||||||
|
"organization_id": policy.organization_id,
|
||||||
|
"mfa_policy_mode": policy.mfa_policy_mode.value,
|
||||||
|
"mfa_grace_period_days": policy.mfa_grace_period_days,
|
||||||
|
"notify_days_before": policy.notify_days_before,
|
||||||
|
"policy_version": policy.policy_version,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
message="Security policy updated successfully",
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValidationError as e:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="Validation failed",
|
||||||
|
status=400,
|
||||||
|
error_type="VALIDATION_ERROR",
|
||||||
|
error_details=e.messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@api_v1_bp.route("/organizations/<org_id>/mfa-compliance", methods=["GET"])
|
||||||
|
@login_required
|
||||||
|
@require_admin
|
||||||
|
@full_access_required
|
||||||
|
def get_org_mfa_compliance(org_id):
|
||||||
|
"""
|
||||||
|
Get MFA compliance list for an organization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
org_id: Organization ID
|
||||||
|
|
||||||
|
Query params:
|
||||||
|
status: Optional status filter (not_applicable, pending, in_grace, compliant, past_due, suspended)
|
||||||
|
limit: Maximum records to return (default 100)
|
||||||
|
offset: Offset for pagination (default 0)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
200: List of compliance records
|
||||||
|
401: Not authenticated
|
||||||
|
403: Not an admin
|
||||||
|
404: Organization not found
|
||||||
|
"""
|
||||||
|
org = OrganizationService.get_organization_by_id(org_id)
|
||||||
|
|
||||||
|
# Parse query params
|
||||||
|
status = None
|
||||||
|
if request.args.get("status"):
|
||||||
|
try:
|
||||||
|
status = MfaComplianceStatus(request.args.get("status"))
|
||||||
|
except ValueError:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="Invalid status value",
|
||||||
|
status=400,
|
||||||
|
error_type="VALIDATION_ERROR",
|
||||||
|
)
|
||||||
|
|
||||||
|
limit = min(int(request.args.get("limit", 100)), 100)
|
||||||
|
offset = int(request.args.get("offset", 0))
|
||||||
|
|
||||||
|
compliance_list = MfaPolicyService.get_org_compliance_list(
|
||||||
|
organization_id=org_id,
|
||||||
|
status=status,
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
return api_response(
|
||||||
|
data={
|
||||||
|
"compliance": compliance_list,
|
||||||
|
"count": len(compliance_list),
|
||||||
|
"limit": limit,
|
||||||
|
"offset": offset,
|
||||||
|
},
|
||||||
|
message="Compliance records retrieved successfully",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@api_v1_bp.route(
|
||||||
|
"/organizations/<org_id>/users/<user_id>/security-policy", methods=["PATCH"]
|
||||||
|
)
|
||||||
|
@login_required
|
||||||
|
@require_admin
|
||||||
|
@full_access_required
|
||||||
|
def update_user_security_policy(org_id, user_id):
|
||||||
|
"""
|
||||||
|
Update user security policy override.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
org_id: Organization ID
|
||||||
|
user_id: User ID
|
||||||
|
|
||||||
|
Request body:
|
||||||
|
mfa_override_mode: Override mode (inherit, required, exempt)
|
||||||
|
force_totp: Force TOTP requirement (default false)
|
||||||
|
force_webauthn: Force WebAuthn requirement (default false)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
200: User policy updated successfully
|
||||||
|
400: Validation error
|
||||||
|
401: Not authenticated
|
||||||
|
403: Not an admin
|
||||||
|
404: Organization or user not found
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
schema = UpdateUserPolicySchema()
|
||||||
|
data = schema.load(request.json)
|
||||||
|
|
||||||
|
org = OrganizationService.get_organization_by_id(org_id)
|
||||||
|
|
||||||
|
# Check if user is a member of the organization
|
||||||
|
if not org.is_member(user_id):
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="User is not a member of this organization",
|
||||||
|
status=404,
|
||||||
|
error_type="NOT_FOUND",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update user policy
|
||||||
|
policy = MfaPolicyService.set_user_override(
|
||||||
|
user_id=user_id,
|
||||||
|
organization_id=org_id,
|
||||||
|
mfa_override_mode=MfaRequirementOverride(data["mfa_override_mode"]),
|
||||||
|
force_totp=data.get("force_totp", False),
|
||||||
|
force_webauthn=data.get("force_webauthn", False),
|
||||||
|
updated_by_user_id=g.current_user.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log the override change with details
|
||||||
|
AuditService.log_action(
|
||||||
|
action=AuditAction.USER_SECURITY_POLICY_OVERRIDE_UPDATE,
|
||||||
|
user_id=g.current_user.id,
|
||||||
|
organization_id=org_id,
|
||||||
|
resource_type="user",
|
||||||
|
resource_id=user_id,
|
||||||
|
description=f"User security policy override changed to {data['mfa_override_mode']} for user {user_id}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return api_response(
|
||||||
|
data={
|
||||||
|
"user_security_policy": {
|
||||||
|
"user_id": policy.user_id,
|
||||||
|
"organization_id": policy.organization_id,
|
||||||
|
"mfa_override_mode": policy.mfa_override_mode.value,
|
||||||
|
"force_totp": policy.force_totp,
|
||||||
|
"force_webauthn": policy.force_webauthn,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
message="User security policy updated successfully",
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValidationError as e:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="Validation failed",
|
||||||
|
status=400,
|
||||||
|
error_type="VALIDATION_ERROR",
|
||||||
|
error_details=e.messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@api_v1_bp.route("/users/me/mfa-compliance", methods=["GET"])
|
||||||
|
@login_required
|
||||||
|
def get_my_mfa_compliance():
|
||||||
|
"""
|
||||||
|
Get current user's MFA compliance across all organizations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
200: MFA compliance summary
|
||||||
|
401: Not authenticated
|
||||||
|
"""
|
||||||
|
user = g.current_user
|
||||||
|
|
||||||
|
compliance_summary = MfaPolicyService.evaluate_user_mfa_state(user)
|
||||||
|
|
||||||
|
orgs = []
|
||||||
|
for org_state in compliance_summary.orgs:
|
||||||
|
orgs.append({
|
||||||
|
"organization_id": org_state.organization_id,
|
||||||
|
"organization_name": org_state.organization_name,
|
||||||
|
"status": org_state.status,
|
||||||
|
"effective_mode": org_state.effective_mode,
|
||||||
|
"deadline_at": org_state.deadline_at,
|
||||||
|
"applied_at": org_state.applied_at,
|
||||||
|
})
|
||||||
|
|
||||||
|
return api_response(
|
||||||
|
data={
|
||||||
|
"mfa_compliance": {
|
||||||
|
"overall_status": compliance_summary.overall_status,
|
||||||
|
"missing_methods": compliance_summary.missing_methods,
|
||||||
|
"deadline_at": compliance_summary.deadline_at,
|
||||||
|
"orgs": orgs,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
message="MFA compliance retrieved successfully",
|
||||||
|
)
|
||||||
@@ -3,7 +3,7 @@ from flask import g, request
|
|||||||
from marshmallow import ValidationError
|
from marshmallow import ValidationError
|
||||||
from gatehouse_app.api.v1 import api_v1_bp
|
from gatehouse_app.api.v1 import api_v1_bp
|
||||||
from gatehouse_app.utils.response import api_response
|
from gatehouse_app.utils.response import api_response
|
||||||
from gatehouse_app.utils.decorators import login_required
|
from gatehouse_app.utils.decorators import login_required, full_access_required
|
||||||
from gatehouse_app.schemas.user_schema import UserUpdateSchema, ChangePasswordSchema
|
from gatehouse_app.schemas.user_schema import UserUpdateSchema, ChangePasswordSchema
|
||||||
from gatehouse_app.services.user_service import UserService
|
from gatehouse_app.services.user_service import UserService
|
||||||
from gatehouse_app.services.auth_service import AuthService
|
from gatehouse_app.services.auth_service import AuthService
|
||||||
@@ -29,6 +29,7 @@ def get_me():
|
|||||||
|
|
||||||
@api_v1_bp.route("/users/me", methods=["PATCH"])
|
@api_v1_bp.route("/users/me", methods=["PATCH"])
|
||||||
@login_required
|
@login_required
|
||||||
|
@full_access_required
|
||||||
def update_me():
|
def update_me():
|
||||||
"""
|
"""
|
||||||
Update current user profile.
|
Update current user profile.
|
||||||
@@ -67,6 +68,7 @@ def update_me():
|
|||||||
|
|
||||||
@api_v1_bp.route("/users/me", methods=["DELETE"])
|
@api_v1_bp.route("/users/me", methods=["DELETE"])
|
||||||
@login_required
|
@login_required
|
||||||
|
@full_access_required
|
||||||
def delete_me():
|
def delete_me():
|
||||||
"""
|
"""
|
||||||
Delete current user account (soft delete).
|
Delete current user account (soft delete).
|
||||||
@@ -84,6 +86,7 @@ def delete_me():
|
|||||||
|
|
||||||
@api_v1_bp.route("/users/me/password", methods=["POST"])
|
@api_v1_bp.route("/users/me/password", methods=["POST"])
|
||||||
@login_required
|
@login_required
|
||||||
|
@full_access_required
|
||||||
def change_password():
|
def change_password():
|
||||||
"""
|
"""
|
||||||
Change current user password.
|
Change current user password.
|
||||||
@@ -136,6 +139,7 @@ def change_password():
|
|||||||
|
|
||||||
@api_v1_bp.route("/users/me/organizations", methods=["GET"])
|
@api_v1_bp.route("/users/me/organizations", methods=["GET"])
|
||||||
@login_required
|
@login_required
|
||||||
|
@full_access_required
|
||||||
def get_my_organizations():
|
def get_my_organizations():
|
||||||
"""
|
"""
|
||||||
Get all organizations current user is a member of.
|
Get all organizations current user is a member of.
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
Jobs module for scheduled tasks.
|
||||||
@@ -0,0 +1,279 @@
|
|||||||
|
"""MFA Compliance Scheduled Job.
|
||||||
|
|
||||||
|
This module implements the scheduled job for processing MFA compliance transitions,
|
||||||
|
sending notifications to users approaching deadlines, and handling edge cases.
|
||||||
|
|
||||||
|
The job is designed to be run periodically (e.g., via cron) to:
|
||||||
|
1. Transition users from PAST_DUE to SUSPENDED status
|
||||||
|
2. Send deadline reminder notifications to users in grace period
|
||||||
|
3. Update notification tracking metadata
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python manage.py run_mfa_compliance_job
|
||||||
|
|
||||||
|
Or call directly:
|
||||||
|
from gatehouse_app.jobs.mfa_compliance_job import process_mfa_compliance
|
||||||
|
process_mfa_compliance()
|
||||||
|
"""
|
||||||
|
from datetime import datetime, timezone, timedelta
|
||||||
|
from typing import Optional, Dict, Any, List
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from gatehouse_app.extensions import db
|
||||||
|
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
|
||||||
|
from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy
|
||||||
|
from gatehouse_app.models.user import User
|
||||||
|
from gatehouse_app.services.mfa_policy_service import MfaPolicyService
|
||||||
|
from gatehouse_app.services.notification_service import NotificationService
|
||||||
|
from gatehouse_app.utils.constants import MfaComplianceStatus
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def process_mfa_compliance(now: Optional[datetime] = None) -> Dict[str, Any]:
|
||||||
|
"""Process MFA compliance transitions and send notifications.
|
||||||
|
|
||||||
|
This scheduled job performs the following operations:
|
||||||
|
1. Transitions users from PAST_DUE to SUSPENDED status
|
||||||
|
2. Identifies users approaching deadline (within notify_days_before)
|
||||||
|
3. Sends deadline reminder notifications
|
||||||
|
4. Updates notification tracking metadata
|
||||||
|
|
||||||
|
Args:
|
||||||
|
now: Current time, defaults to now (UTC)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with job execution statistics:
|
||||||
|
- suspended_count: Number of users transitioned to suspended
|
||||||
|
- notified_count: Number of notifications sent
|
||||||
|
- processed_count: Total compliance records processed
|
||||||
|
"""
|
||||||
|
if now is None:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
logger.info(f"Starting MFA compliance job at {now.isoformat()}")
|
||||||
|
|
||||||
|
stats = {
|
||||||
|
"suspended_count": 0,
|
||||||
|
"notified_count": 0,
|
||||||
|
"processed_count": 0,
|
||||||
|
"errors": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Step 1: Transition past-due users to suspended
|
||||||
|
suspended_count = MfaPolicyService.transition_to_suspended_if_past_due(now)
|
||||||
|
stats["suspended_count"] = suspended_count
|
||||||
|
logger.info(f"Transitioned {suspended_count} users to suspended status")
|
||||||
|
|
||||||
|
# Step 2: Send notifications to users approaching deadline
|
||||||
|
notified_count = _send_deadline_reminders(now)
|
||||||
|
stats["notified_count"] = notified_count
|
||||||
|
logger.info(f"Sent {notified_count} deadline reminder notifications")
|
||||||
|
|
||||||
|
# Step 3: Process any pending compliance evaluations
|
||||||
|
processed_count = _evaluate_pending_compliance(now)
|
||||||
|
stats["processed_count"] = processed_count
|
||||||
|
logger.info(f"Processed {processed_count} compliance records")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Error during MFA compliance job: {e}")
|
||||||
|
stats["errors"].append(str(e))
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"MFA compliance job completed: suspended={stats['suspended_count']}, "
|
||||||
|
f"notified={stats['notified_count']}, processed={stats['processed_count']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
|
||||||
|
def _send_deadline_reminders(now: datetime) -> int:
|
||||||
|
"""Send deadline reminder notifications to users approaching deadline.
|
||||||
|
|
||||||
|
Identifies users in grace period who are within their organization's
|
||||||
|
notify_days_before threshold and sends them reminder notifications.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
now: Current time (UTC)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of notifications sent
|
||||||
|
"""
|
||||||
|
notified_count = 0
|
||||||
|
|
||||||
|
# Find all compliance records in grace period
|
||||||
|
grace_records = MfaPolicyCompliance.query.filter(
|
||||||
|
MfaPolicyCompliance.status == MfaComplianceStatus.IN_GRACE,
|
||||||
|
MfaPolicyCompliance.deadline_at != None,
|
||||||
|
MfaPolicyCompliance.deleted_at == None,
|
||||||
|
).all()
|
||||||
|
|
||||||
|
for record in grace_records:
|
||||||
|
try:
|
||||||
|
# Get organization policy for notify_days_before
|
||||||
|
org_policy = OrganizationSecurityPolicy.query.filter_by(
|
||||||
|
organization_id=record.organization_id, deleted_at=None
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not org_policy:
|
||||||
|
continue
|
||||||
|
|
||||||
|
notify_threshold = org_policy.notify_days_before
|
||||||
|
deadline = record.deadline_at
|
||||||
|
|
||||||
|
# Ensure deadline has timezone
|
||||||
|
if deadline.tzinfo is None:
|
||||||
|
deadline = deadline.replace(tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
# Calculate time until deadline
|
||||||
|
time_until_deadline = deadline - now
|
||||||
|
days_until_deadline = time_until_deadline.total_seconds() / 86400
|
||||||
|
|
||||||
|
# Check if we should send a reminder
|
||||||
|
should_notify = False
|
||||||
|
if days_until_deadline <= notify_threshold:
|
||||||
|
# Check if we've already notified recently (within last 24 hours)
|
||||||
|
if record.last_notified_at:
|
||||||
|
hours_since_notification = (
|
||||||
|
now - record.last_notified_at
|
||||||
|
).total_seconds() / 3600
|
||||||
|
if hours_since_notification < 24:
|
||||||
|
continue # Already notified recently
|
||||||
|
|
||||||
|
should_notify = True
|
||||||
|
|
||||||
|
if should_notify:
|
||||||
|
# Get user
|
||||||
|
user = User.query.get(record.user_id)
|
||||||
|
if not user:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Send notification
|
||||||
|
success = NotificationService.send_mfa_deadline_reminder(
|
||||||
|
user=user,
|
||||||
|
compliance=record,
|
||||||
|
org_policy=org_policy,
|
||||||
|
)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
# Update notification tracking
|
||||||
|
record.last_notified_at = now
|
||||||
|
record.notification_count += 1
|
||||||
|
db.session.commit()
|
||||||
|
notified_count += 1
|
||||||
|
logger.info(
|
||||||
|
f"Sent deadline reminder to user {user.email} "
|
||||||
|
f"(days until deadline: {days_until_deadline:.1f})"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Error sending reminder for compliance record "
|
||||||
|
f"{record.id}: {e}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
return notified_count
|
||||||
|
|
||||||
|
|
||||||
|
def _evaluate_pending_compliance(now: datetime) -> int:
|
||||||
|
"""Evaluate and update pending compliance records.
|
||||||
|
|
||||||
|
This handles edge cases where compliance records may need
|
||||||
|
status updates due to policy changes or other factors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
now: Current time (UTC)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of records processed
|
||||||
|
"""
|
||||||
|
processed_count = 0
|
||||||
|
|
||||||
|
# Find all non-deleted compliance records
|
||||||
|
records = MfaPolicyCompliance.query.filter(
|
||||||
|
MfaPolicyCompliance.deleted_at == None,
|
||||||
|
).all()
|
||||||
|
|
||||||
|
for record in records:
|
||||||
|
try:
|
||||||
|
# Get the user and evaluate their current state
|
||||||
|
user = User.query.get(record.user_id)
|
||||||
|
if not user:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Re-evaluate compliance status
|
||||||
|
# This handles cases where policy changed or user enrolled in MFA
|
||||||
|
from gatehouse_app.services.mfa_policy_service import MfaPolicyService
|
||||||
|
|
||||||
|
effective_policy = MfaPolicyService.get_effective_user_policy(
|
||||||
|
user.id, record.organization_id
|
||||||
|
)
|
||||||
|
|
||||||
|
new_status = MfaPolicyService._evaluate_compliance_status(
|
||||||
|
user, effective_policy, record
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update status if changed
|
||||||
|
if record.status != new_status:
|
||||||
|
old_status = record.status.value if hasattr(record.status, 'value') else str(record.status)
|
||||||
|
record.status = MfaComplianceStatus(new_status)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Updated compliance status for user {user.email} "
|
||||||
|
f"in org {record.organization_id}: {old_status} -> {new_status}"
|
||||||
|
)
|
||||||
|
|
||||||
|
processed_count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Error evaluating compliance record {record.id}: {e}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
return processed_count
|
||||||
|
|
||||||
|
|
||||||
|
def get_job_status(now: Optional[datetime] = None) -> Dict[str, Any]:
|
||||||
|
"""Get current status of MFA compliance for monitoring.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
now: Current time, defaults to now (UTC)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with compliance statistics
|
||||||
|
"""
|
||||||
|
if now is None:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
# Count records by status
|
||||||
|
status_counts = {}
|
||||||
|
for status in MfaComplianceStatus:
|
||||||
|
count = MfaPolicyCompliance.query.filter(
|
||||||
|
MfaPolicyCompliance.status == status,
|
||||||
|
MfaPolicyCompliance.deleted_at == None,
|
||||||
|
).count()
|
||||||
|
status_counts[status.value] = count
|
||||||
|
|
||||||
|
# Count users approaching deadline (within 7 days by default)
|
||||||
|
approaching_deadline = MfaPolicyCompliance.query.filter(
|
||||||
|
MfaPolicyCompliance.status == MfaComplianceStatus.IN_GRACE,
|
||||||
|
MfaPolicyCompliance.deadline_at != None,
|
||||||
|
MfaPolicyCompliance.deleted_at == None,
|
||||||
|
).count()
|
||||||
|
|
||||||
|
# Count past-due records
|
||||||
|
past_due_count = MfaPolicyCompliance.query.filter(
|
||||||
|
MfaPolicyCompliance.status == MfaComplianceStatus.PAST_DUE,
|
||||||
|
MfaPolicyCompliance.deleted_at == None,
|
||||||
|
).count()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status_counts": status_counts,
|
||||||
|
"approaching_deadline_count": approaching_deadline,
|
||||||
|
"past_due_count": past_due_count,
|
||||||
|
"timestamp": now.isoformat(),
|
||||||
|
}
|
||||||
@@ -12,6 +12,9 @@ from gatehouse_app.models.oidc_refresh_token import OIDCRefreshToken
|
|||||||
from gatehouse_app.models.oidc_session import OIDCSession
|
from gatehouse_app.models.oidc_session import OIDCSession
|
||||||
from gatehouse_app.models.oidc_token_metadata import OIDCTokenMetadata
|
from gatehouse_app.models.oidc_token_metadata import OIDCTokenMetadata
|
||||||
from gatehouse_app.models.oidc_audit_log import OIDCAuditLog
|
from gatehouse_app.models.oidc_audit_log import OIDCAuditLog
|
||||||
|
from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy
|
||||||
|
from gatehouse_app.models.user_security_policy import UserSecurityPolicy
|
||||||
|
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseModel",
|
"BaseModel",
|
||||||
@@ -27,4 +30,7 @@ __all__ = [
|
|||||||
"OIDCSession",
|
"OIDCSession",
|
||||||
"OIDCTokenMetadata",
|
"OIDCTokenMetadata",
|
||||||
"OIDCAuditLog",
|
"OIDCAuditLog",
|
||||||
|
"OrganizationSecurityPolicy",
|
||||||
|
"UserSecurityPolicy",
|
||||||
|
"MfaPolicyCompliance",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -0,0 +1,66 @@
|
|||||||
|
"""MfaPolicyCompliance model."""
|
||||||
|
from gatehouse_app.extensions import db
|
||||||
|
from gatehouse_app.models.base import BaseModel
|
||||||
|
from gatehouse_app.utils.constants import MfaComplianceStatus
|
||||||
|
|
||||||
|
|
||||||
|
class MfaPolicyCompliance(BaseModel):
|
||||||
|
"""MFA policy compliance tracking per user per organization.
|
||||||
|
|
||||||
|
Tracks each user's MFA compliance state separately for each organization membership.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "mfa_policy_compliance"
|
||||||
|
|
||||||
|
user_id = db.Column(
|
||||||
|
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
organization_id = db.Column(
|
||||||
|
db.String(36), db.ForeignKey("organizations.id"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
|
||||||
|
status = db.Column(
|
||||||
|
db.Enum(MfaComplianceStatus),
|
||||||
|
nullable=False,
|
||||||
|
default=MfaComplianceStatus.NOT_APPLICABLE,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Snapshot of org policy at the time this record became active
|
||||||
|
policy_version = db.Column(db.Integer, nullable=False)
|
||||||
|
|
||||||
|
# When policy started applying to this user
|
||||||
|
applied_at = db.Column(db.DateTime, nullable=True)
|
||||||
|
|
||||||
|
# Final deadline for this user to comply (per user, not global)
|
||||||
|
deadline_at = db.Column(db.DateTime, nullable=True)
|
||||||
|
|
||||||
|
# When they became compliant under this policy_version
|
||||||
|
compliant_at = db.Column(db.DateTime, nullable=True)
|
||||||
|
|
||||||
|
# When suspended enforcement started for this user
|
||||||
|
suspended_at = db.Column(db.DateTime, nullable=True)
|
||||||
|
|
||||||
|
# Notification tracking
|
||||||
|
last_notified_at = db.Column(db.DateTime, nullable=True)
|
||||||
|
notification_count = db.Column(db.Integer, nullable=False, default=0)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
db.UniqueConstraint(
|
||||||
|
"user_id", "organization_id", name="uix_user_org_compliance"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
user = db.relationship(
|
||||||
|
"User", back_populates="mfa_compliance", foreign_keys=[user_id]
|
||||||
|
)
|
||||||
|
organization = db.relationship("Organization", foreign_keys=[organization_id])
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
"""String representation of MfaPolicyCompliance."""
|
||||||
|
return f"<MfaPolicyCompliance user={self.user_id} org={self.organization_id} status={self.status}>"
|
||||||
|
|
||||||
|
def to_dict(self, exclude=None):
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
exclude = exclude or []
|
||||||
|
return super().to_dict(exclude=exclude)
|
||||||
@@ -24,6 +24,13 @@ class Organization(BaseModel):
|
|||||||
oidc_clients = db.relationship(
|
oidc_clients = db.relationship(
|
||||||
"OIDCClient", back_populates="organization", cascade="all, delete-orphan"
|
"OIDCClient", back_populates="organization", cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
|
security_policy = db.relationship(
|
||||||
|
"OrganizationSecurityPolicy",
|
||||||
|
back_populates="organization",
|
||||||
|
uselist=False,
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
foreign_keys="OrganizationSecurityPolicy.organization_id",
|
||||||
|
)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
"""String representation of Organization."""
|
"""String representation of Organization."""
|
||||||
|
|||||||
@@ -0,0 +1,53 @@
|
|||||||
|
"""OrganizationSecurityPolicy model."""
|
||||||
|
from gatehouse_app.extensions import db
|
||||||
|
from gatehouse_app.models.base import BaseModel
|
||||||
|
from gatehouse_app.utils.constants import MfaPolicyMode
|
||||||
|
|
||||||
|
|
||||||
|
class OrganizationSecurityPolicy(BaseModel):
|
||||||
|
"""Organization security policy model for MFA configuration.
|
||||||
|
|
||||||
|
One row per organization capturing its current security requirements.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "organization_security_policies"
|
||||||
|
|
||||||
|
organization_id = db.Column(
|
||||||
|
db.String(36),
|
||||||
|
db.ForeignKey("organizations.id"),
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# MFA policy configuration
|
||||||
|
mfa_policy_mode = db.Column(
|
||||||
|
db.Enum(MfaPolicyMode), nullable=False, default=MfaPolicyMode.OPTIONAL
|
||||||
|
)
|
||||||
|
|
||||||
|
# Grace period for members in days
|
||||||
|
mfa_grace_period_days = db.Column(db.Integer, nullable=False, default=14)
|
||||||
|
|
||||||
|
# Notification settings (in days before individual user deadline)
|
||||||
|
notify_days_before = db.Column(db.Integer, nullable=False, default=7)
|
||||||
|
|
||||||
|
# Versioning for compatibility tracking
|
||||||
|
policy_version = db.Column(db.Integer, nullable=False, default=1)
|
||||||
|
|
||||||
|
# Audit metadata
|
||||||
|
updated_by_user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
organization = db.relationship(
|
||||||
|
"Organization", back_populates="security_policy", foreign_keys=[organization_id]
|
||||||
|
)
|
||||||
|
updated_by_user = db.relationship("User", foreign_keys=[updated_by_user_id])
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
"""String representation of OrganizationSecurityPolicy."""
|
||||||
|
return f"<OrganizationSecurityPolicy org={self.organization_id} mode={self.mfa_policy_mode}>"
|
||||||
|
|
||||||
|
def to_dict(self, exclude=None):
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
exclude = exclude or []
|
||||||
|
return super().to_dict(exclude=exclude)
|
||||||
@@ -25,6 +25,9 @@ class Session(BaseModel):
|
|||||||
revoked_at = db.Column(db.DateTime, nullable=True)
|
revoked_at = db.Column(db.DateTime, nullable=True)
|
||||||
revoked_reason = db.Column(db.String(255), nullable=True)
|
revoked_reason = db.Column(db.String(255), nullable=True)
|
||||||
|
|
||||||
|
# Compliance session flag
|
||||||
|
is_compliance_only = db.Column(db.Boolean, nullable=False, default=False)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
user = db.relationship("User", back_populates="sessions")
|
user = db.relationship("User", back_populates="sessions")
|
||||||
|
|
||||||
|
|||||||
@@ -31,6 +31,18 @@ class User(BaseModel):
|
|||||||
foreign_keys="OrganizationMember.user_id",
|
foreign_keys="OrganizationMember.user_id",
|
||||||
)
|
)
|
||||||
audit_logs = db.relationship("AuditLog", back_populates="user", cascade="all, delete-orphan")
|
audit_logs = db.relationship("AuditLog", back_populates="user", cascade="all, delete-orphan")
|
||||||
|
security_policies = db.relationship(
|
||||||
|
"UserSecurityPolicy",
|
||||||
|
back_populates="user",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
foreign_keys="UserSecurityPolicy.user_id",
|
||||||
|
)
|
||||||
|
mfa_compliance = db.relationship(
|
||||||
|
"MfaPolicyCompliance",
|
||||||
|
back_populates="user",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
foreign_keys="MfaPolicyCompliance.user_id",
|
||||||
|
)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
"""String representation of User."""
|
"""String representation of User."""
|
||||||
|
|||||||
@@ -0,0 +1,53 @@
|
|||||||
|
"""UserSecurityPolicy model."""
|
||||||
|
from gatehouse_app.extensions import db
|
||||||
|
from gatehouse_app.models.base import BaseModel
|
||||||
|
from gatehouse_app.utils.constants import MfaRequirementOverride
|
||||||
|
|
||||||
|
|
||||||
|
class UserSecurityPolicy(BaseModel):
|
||||||
|
"""User security policy model for per-user MFA overrides.
|
||||||
|
|
||||||
|
Stores per user overrides of organization level MFA requirements.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "user_security_policies"
|
||||||
|
|
||||||
|
user_id = db.Column(
|
||||||
|
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
organization_id = db.Column(
|
||||||
|
db.String(36), db.ForeignKey("organizations.id"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
|
||||||
|
mfa_override_mode = db.Column(
|
||||||
|
db.Enum(MfaRequirementOverride),
|
||||||
|
nullable=False,
|
||||||
|
default=MfaRequirementOverride.INHERIT,
|
||||||
|
)
|
||||||
|
|
||||||
|
# If override is REQUIRED and you want to force a specific factor set
|
||||||
|
force_totp = db.Column(db.Boolean, nullable=False, default=False)
|
||||||
|
force_webauthn = db.Column(db.Boolean, nullable=False, default=False)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
db.UniqueConstraint(
|
||||||
|
"user_id", "organization_id", name="uix_user_org_policy"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
user = db.relationship(
|
||||||
|
"User", back_populates="security_policies", foreign_keys=[user_id]
|
||||||
|
)
|
||||||
|
organization = db.relationship(
|
||||||
|
"Organization", foreign_keys=[organization_id]
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
"""String representation of UserSecurityPolicy."""
|
||||||
|
return f"<UserSecurityPolicy user={self.user_id} org={self.organization_id} mode={self.mfa_override_mode}>"
|
||||||
|
|
||||||
|
def to_dict(self, exclude=None):
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
exclude = exclude or []
|
||||||
|
return super().to_dict(exclude=exclude)
|
||||||
@@ -96,3 +96,29 @@ class TOTPRegenerateBackupCodesSchema(Schema):
|
|||||||
"""Schema for regenerating backup codes."""
|
"""Schema for regenerating backup codes."""
|
||||||
|
|
||||||
password = fields.Str(required=True, validate=validate.Length(min=1))
|
password = fields.Str(required=True, validate=validate.Length(min=1))
|
||||||
|
|
||||||
|
|
||||||
|
class MfaComplianceOrgSchema(Schema):
|
||||||
|
"""Schema for MFA compliance per organization."""
|
||||||
|
organization_id = fields.Str(required=True)
|
||||||
|
organization_name = fields.Str(required=True)
|
||||||
|
status = fields.Str(required=True)
|
||||||
|
deadline_at = fields.Str(allow_none=True)
|
||||||
|
|
||||||
|
|
||||||
|
class MfaComplianceSchema(Schema):
|
||||||
|
"""Schema for MFA compliance summary in login response."""
|
||||||
|
overall_status = fields.Str(required=True)
|
||||||
|
missing_methods = fields.List(fields.Str(), required=True)
|
||||||
|
deadline_at = fields.Str(allow_none=True)
|
||||||
|
orgs = fields.List(fields.Nested(MfaComplianceOrgSchema), required=True)
|
||||||
|
|
||||||
|
|
||||||
|
class LoginResponseSchema(Schema):
|
||||||
|
"""Schema for login response."""
|
||||||
|
user = fields.Dict(required=True)
|
||||||
|
token = fields.Str(required=True)
|
||||||
|
expires_at = fields.Str(required=True)
|
||||||
|
requires_totp = fields.Bool(required=False)
|
||||||
|
requires_mfa_enrollment = fields.Bool(required=False)
|
||||||
|
mfa_compliance = fields.Nested(MfaComplianceSchema, required=False)
|
||||||
|
|||||||
@@ -140,13 +140,14 @@ class AuthService:
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_session(user, duration_seconds=86400):
|
def create_session(user, duration_seconds=86400, is_compliance_only=False):
|
||||||
"""
|
"""
|
||||||
Create a new session for the user.
|
Create a new session for the user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user: User instance
|
user: User instance
|
||||||
duration_seconds: Session duration in seconds
|
duration_seconds: Session duration in seconds
|
||||||
|
is_compliance_only: Whether this is a compliance-only session (limited access)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Session instance
|
Session instance
|
||||||
@@ -163,6 +164,7 @@ class AuthService:
|
|||||||
user_agent=request.headers.get("User-Agent"),
|
user_agent=request.headers.get("User-Agent"),
|
||||||
expires_at=datetime.now(timezone.utc) + timedelta(seconds=duration_seconds),
|
expires_at=datetime.now(timezone.utc) + timedelta(seconds=duration_seconds),
|
||||||
last_activity_at=datetime.now(timezone.utc),
|
last_activity_at=datetime.now(timezone.utc),
|
||||||
|
is_compliance_only=is_compliance_only,
|
||||||
)
|
)
|
||||||
session.save()
|
session.save()
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,978 @@
|
|||||||
|
"""MFA Policy Service."""
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
|
||||||
|
from gatehouse_app.extensions import db
|
||||||
|
from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy
|
||||||
|
from gatehouse_app.models.user_security_policy import UserSecurityPolicy
|
||||||
|
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
|
||||||
|
from gatehouse_app.models.user import User
|
||||||
|
from gatehouse_app.models.organization import Organization
|
||||||
|
from gatehouse_app.services.audit_service import AuditService
|
||||||
|
from gatehouse_app.utils.constants import (
|
||||||
|
MfaPolicyMode,
|
||||||
|
MfaComplianceStatus,
|
||||||
|
MfaRequirementOverride,
|
||||||
|
AuditAction,
|
||||||
|
UserStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OrgPolicyDto:
|
||||||
|
"""DTO for organization policy."""
|
||||||
|
organization_id: str
|
||||||
|
mfa_policy_mode: str
|
||||||
|
mfa_grace_period_days: int
|
||||||
|
notify_days_before: int
|
||||||
|
policy_version: int
|
||||||
|
updated_by_user_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EffectiveUserPolicyDto:
|
||||||
|
"""DTO for effective user policy combining org and user overrides."""
|
||||||
|
organization_id: str
|
||||||
|
effective_mode: str
|
||||||
|
requires_totp: bool
|
||||||
|
requires_webauthn: bool
|
||||||
|
grace_period_days: int
|
||||||
|
is_exempt: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UserMfaStateDto:
|
||||||
|
"""DTO for per-organization MFA state."""
|
||||||
|
organization_id: str
|
||||||
|
organization_name: str
|
||||||
|
status: str
|
||||||
|
effective_mode: str
|
||||||
|
deadline_at: Optional[str] = None
|
||||||
|
applied_at: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AggregateMfaStateDto:
|
||||||
|
"""DTO for aggregate MFA state across all organizations."""
|
||||||
|
overall_status: str
|
||||||
|
missing_methods: List[str] = field(default_factory=list)
|
||||||
|
deadline_at: Optional[str] = None
|
||||||
|
orgs: List[UserMfaStateDto] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LoginPolicyResult:
|
||||||
|
"""Result of policy evaluation after primary auth success."""
|
||||||
|
can_create_full_session: bool
|
||||||
|
create_compliance_only_session: bool
|
||||||
|
compliance_summary: AggregateMfaStateDto
|
||||||
|
|
||||||
|
|
||||||
|
class MfaPolicyService:
|
||||||
|
"""Service for MFA policy evaluation and compliance tracking."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_org_policy(org_id: str) -> Optional[OrgPolicyDto]:
|
||||||
|
"""Get organization security policy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
org_id: Organization ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OrgPolicyDto or None if not found
|
||||||
|
"""
|
||||||
|
policy = OrganizationSecurityPolicy.query.filter_by(
|
||||||
|
organization_id=org_id, deleted_at=None
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not policy:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return OrgPolicyDto(
|
||||||
|
organization_id=policy.organization_id,
|
||||||
|
mfa_policy_mode=policy.mfa_policy_mode.value,
|
||||||
|
mfa_grace_period_days=policy.mfa_grace_period_days,
|
||||||
|
notify_days_before=policy.notify_days_before,
|
||||||
|
policy_version=policy.policy_version,
|
||||||
|
updated_by_user_id=policy.updated_by_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_effective_user_policy(
|
||||||
|
user_id: str, org_id: str
|
||||||
|
) -> EffectiveUserPolicyDto:
|
||||||
|
"""Get effective user policy combining org policy with user overrides.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID
|
||||||
|
org_id: Organization ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EffectiveUserPolicyDto
|
||||||
|
"""
|
||||||
|
# Get org policy
|
||||||
|
org_policy = OrganizationSecurityPolicy.query.filter_by(
|
||||||
|
organization_id=org_id, deleted_at=None
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not org_policy:
|
||||||
|
# No org policy means no requirements
|
||||||
|
return EffectiveUserPolicyDto(
|
||||||
|
organization_id=org_id,
|
||||||
|
effective_mode=MfaPolicyMode.DISABLED.value,
|
||||||
|
requires_totp=False,
|
||||||
|
requires_webauthn=False,
|
||||||
|
grace_period_days=0,
|
||||||
|
is_exempt=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get user override
|
||||||
|
user_override = UserSecurityPolicy.query.filter_by(
|
||||||
|
user_id=user_id, organization_id=org_id, deleted_at=None
|
||||||
|
).first()
|
||||||
|
|
||||||
|
# Determine effective mode
|
||||||
|
if user_override:
|
||||||
|
override_mode = user_override.mfa_override_mode
|
||||||
|
if override_mode == MfaRequirementOverride.EXEMPT:
|
||||||
|
return EffectiveUserPolicyDto(
|
||||||
|
organization_id=org_id,
|
||||||
|
effective_mode=MfaPolicyMode.DISABLED.value,
|
||||||
|
requires_totp=False,
|
||||||
|
requires_webauthn=False,
|
||||||
|
grace_period_days=org_policy.mfa_grace_period_days,
|
||||||
|
is_exempt=True,
|
||||||
|
)
|
||||||
|
elif override_mode == MfaRequirementOverride.REQUIRED:
|
||||||
|
# User is required to have MFA even if org is optional
|
||||||
|
effective_mode = MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN
|
||||||
|
else:
|
||||||
|
effective_mode = org_policy.mfa_policy_mode
|
||||||
|
else:
|
||||||
|
effective_mode = org_policy.mfa_policy_mode
|
||||||
|
|
||||||
|
# Determine required methods based on mode
|
||||||
|
requires_totp = effective_mode in (
|
||||||
|
MfaPolicyMode.REQUIRE_TOTP,
|
||||||
|
MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
||||||
|
)
|
||||||
|
requires_webauthn = effective_mode in (
|
||||||
|
MfaPolicyMode.REQUIRE_WEBAUTHN,
|
||||||
|
MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
||||||
|
)
|
||||||
|
|
||||||
|
return EffectiveUserPolicyDto(
|
||||||
|
organization_id=org_id,
|
||||||
|
effective_mode=effective_mode.value,
|
||||||
|
requires_totp=requires_totp,
|
||||||
|
requires_webauthn=requires_webauthn,
|
||||||
|
grace_period_days=org_policy.mfa_grace_period_days,
|
||||||
|
is_exempt=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def evaluate_user_mfa_state(user: User) -> AggregateMfaStateDto:
|
||||||
|
"""Evaluate user's MFA state across all organizations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AggregateMfaStateDto with overall status and per-org breakdown
|
||||||
|
"""
|
||||||
|
org_states: List[UserMfaStateDto] = []
|
||||||
|
overall_status = MfaComplianceStatus.COMPLIANT.value
|
||||||
|
earliest_deadline: Optional[datetime] = None
|
||||||
|
missing_methods: set = set()
|
||||||
|
|
||||||
|
for membership in user.organization_memberships:
|
||||||
|
if membership.deleted_at is not None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
org = membership.organization
|
||||||
|
if org.deleted_at is not None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
effective_policy = MfaPolicyService.get_effective_user_policy(
|
||||||
|
user.id, org.id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get or create compliance record
|
||||||
|
compliance = MfaPolicyCompliance.query.filter_by(
|
||||||
|
user_id=user.id, organization_id=org.id, deleted_at=None
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not compliance:
|
||||||
|
# Create initial compliance record
|
||||||
|
compliance = MfaPolicyCompliance(
|
||||||
|
user_id=user.id,
|
||||||
|
organization_id=org.id,
|
||||||
|
status=MfaComplianceStatus.NOT_APPLICABLE,
|
||||||
|
policy_version=0,
|
||||||
|
)
|
||||||
|
compliance.save()
|
||||||
|
|
||||||
|
# Determine status based on policy and user MFA state
|
||||||
|
status = MfaPolicyService._evaluate_compliance_status(
|
||||||
|
user, effective_policy, compliance
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update compliance record if needed
|
||||||
|
if compliance.status != status:
|
||||||
|
compliance.status = status
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Track missing methods
|
||||||
|
if status not in (
|
||||||
|
MfaComplianceStatus.COMPLIANT.value,
|
||||||
|
MfaComplianceStatus.NOT_APPLICABLE.value,
|
||||||
|
):
|
||||||
|
if effective_policy.requires_totp and not user.has_totp_enabled():
|
||||||
|
missing_methods.add("totp")
|
||||||
|
if effective_policy.requires_webauthn and not user.has_webauthn_enabled():
|
||||||
|
missing_methods.add("webauthn")
|
||||||
|
|
||||||
|
# Track earliest deadline
|
||||||
|
if compliance.deadline_at:
|
||||||
|
if earliest_deadline is None or compliance.deadline_at < earliest_deadline:
|
||||||
|
earliest_deadline = compliance.deadline_at
|
||||||
|
|
||||||
|
# Determine overall status (most restrictive)
|
||||||
|
if status == MfaComplianceStatus.SUSPENDED.value:
|
||||||
|
overall_status = MfaComplianceStatus.SUSPENDED.value
|
||||||
|
elif (
|
||||||
|
status == MfaComplianceStatus.PAST_DUE.value
|
||||||
|
and overall_status != MfaComplianceStatus.SUSPENDED.value
|
||||||
|
):
|
||||||
|
overall_status = MfaComplianceStatus.PAST_DUE.value
|
||||||
|
elif (
|
||||||
|
status == MfaComplianceStatus.IN_GRACE.value
|
||||||
|
and overall_status
|
||||||
|
not in (
|
||||||
|
MfaComplianceStatus.SUSPENDED.value,
|
||||||
|
MfaComplianceStatus.PAST_DUE.value,
|
||||||
|
)
|
||||||
|
):
|
||||||
|
overall_status = MfaComplianceStatus.IN_GRACE.value
|
||||||
|
elif (
|
||||||
|
status == MfaComplianceStatus.PENDING.value
|
||||||
|
and overall_status == MfaComplianceStatus.COMPLIANT.value
|
||||||
|
):
|
||||||
|
overall_status = MfaComplianceStatus.PENDING.value
|
||||||
|
|
||||||
|
org_states.append(
|
||||||
|
UserMfaStateDto(
|
||||||
|
organization_id=org.id,
|
||||||
|
organization_name=org.name,
|
||||||
|
status=status,
|
||||||
|
effective_mode=effective_policy.effective_mode,
|
||||||
|
deadline_at=compliance.deadline_at.isoformat() if compliance.deadline_at else None,
|
||||||
|
applied_at=compliance.applied_at.isoformat() if compliance.applied_at else None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return AggregateMfaStateDto(
|
||||||
|
overall_status=overall_status,
|
||||||
|
missing_methods=list(missing_methods),
|
||||||
|
deadline_at=earliest_deadline.isoformat() if earliest_deadline else None,
|
||||||
|
orgs=org_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _evaluate_compliance_status(
|
||||||
|
user: User,
|
||||||
|
effective_policy: EffectiveUserPolicyDto,
|
||||||
|
compliance: MfaPolicyCompliance,
|
||||||
|
) -> str:
|
||||||
|
"""Evaluate compliance status for a user in an organization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User instance
|
||||||
|
effective_policy: EffectiveUserPolicyDto
|
||||||
|
compliance: MfaPolicyCompliance instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Status string
|
||||||
|
"""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
# If exempt or disabled, mark as not applicable
|
||||||
|
if effective_policy.is_exempt:
|
||||||
|
return MfaComplianceStatus.NOT_APPLICABLE.value
|
||||||
|
|
||||||
|
if effective_policy.effective_mode == MfaPolicyMode.DISABLED.value:
|
||||||
|
return MfaComplianceStatus.NOT_APPLICABLE.value
|
||||||
|
|
||||||
|
# Check if user has required MFA methods
|
||||||
|
has_totp = user.has_totp_enabled()
|
||||||
|
has_webauthn = user.has_webauthn_enabled()
|
||||||
|
|
||||||
|
has_required = (
|
||||||
|
(not effective_policy.requires_totp or has_totp)
|
||||||
|
and (not effective_policy.requires_webauthn or has_webauthn)
|
||||||
|
)
|
||||||
|
|
||||||
|
if has_required:
|
||||||
|
return MfaComplianceStatus.COMPLIANT.value
|
||||||
|
|
||||||
|
# User is missing required MFA
|
||||||
|
# If no deadline set, set it now
|
||||||
|
if not compliance.deadline_at and effective_policy.grace_period_days > 0:
|
||||||
|
compliance.applied_at = now
|
||||||
|
compliance.deadline_at = now.replace(
|
||||||
|
tzinfo=None
|
||||||
|
) + __import__("datetime").timedelta(
|
||||||
|
days=effective_policy.grace_period_days
|
||||||
|
)
|
||||||
|
db.session.commit()
|
||||||
|
return MfaComplianceStatus.IN_GRACE.value
|
||||||
|
|
||||||
|
# Check deadline
|
||||||
|
if compliance.deadline_at:
|
||||||
|
deadline = compliance.deadline_at
|
||||||
|
if deadline.tzinfo is None:
|
||||||
|
deadline = deadline.replace(tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
if now < deadline:
|
||||||
|
return MfaComplianceStatus.IN_GRACE.value
|
||||||
|
else:
|
||||||
|
return MfaComplianceStatus.PAST_DUE.value
|
||||||
|
|
||||||
|
return MfaComplianceStatus.PENDING.value
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def after_primary_auth_success(
|
||||||
|
user: User, remember_me: bool = False
|
||||||
|
) -> LoginPolicyResult:
|
||||||
|
"""Determine session type based on compliance after primary auth success.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User instance
|
||||||
|
remember_me: Whether this is a remember-me session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LoginPolicyResult with session type and compliance summary
|
||||||
|
"""
|
||||||
|
compliance_summary = MfaPolicyService.evaluate_user_mfa_state(user)
|
||||||
|
|
||||||
|
# Check if there are any REQUIRED policies affecting this user
|
||||||
|
has_required_policy = False
|
||||||
|
for org_state in compliance_summary.orgs:
|
||||||
|
if org_state.effective_mode in (
|
||||||
|
MfaPolicyMode.REQUIRE_TOTP.value,
|
||||||
|
MfaPolicyMode.REQUIRE_WEBAUTHN.value,
|
||||||
|
MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value,
|
||||||
|
):
|
||||||
|
has_required_policy = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not has_required_policy:
|
||||||
|
# No required policies, full session allowed
|
||||||
|
return LoginPolicyResult(
|
||||||
|
can_create_full_session=True,
|
||||||
|
create_compliance_only_session=False,
|
||||||
|
compliance_summary=compliance_summary,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if user is compliant
|
||||||
|
if compliance_summary.overall_status == MfaComplianceStatus.COMPLIANT.value:
|
||||||
|
return LoginPolicyResult(
|
||||||
|
can_create_full_session=True,
|
||||||
|
create_compliance_only_session=False,
|
||||||
|
compliance_summary=compliance_summary,
|
||||||
|
)
|
||||||
|
|
||||||
|
# User is not compliant
|
||||||
|
if compliance_summary.overall_status in (
|
||||||
|
MfaComplianceStatus.IN_GRACE.value,
|
||||||
|
MfaComplianceStatus.PENDING.value,
|
||||||
|
):
|
||||||
|
# Can proceed with full session but warnings
|
||||||
|
return LoginPolicyResult(
|
||||||
|
can_create_full_session=True,
|
||||||
|
create_compliance_only_session=False,
|
||||||
|
compliance_summary=compliance_summary,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Past due or suspended - compliance only session
|
||||||
|
return LoginPolicyResult(
|
||||||
|
can_create_full_session=False,
|
||||||
|
create_compliance_only_session=True,
|
||||||
|
compliance_summary=compliance_summary,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def transition_to_suspended_if_past_due(now: Optional[datetime] = None) -> int:
|
||||||
|
"""Scheduled job to transition past-due users to suspended status.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
now: Current time, defaults to now
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of users transitioned to suspended
|
||||||
|
"""
|
||||||
|
if now is None:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
suspended_count = 0
|
||||||
|
|
||||||
|
# Find all compliance records that are past due
|
||||||
|
past_due_records = MfaPolicyCompliance.query.filter(
|
||||||
|
MfaPolicyCompliance.status == MfaComplianceStatus.PAST_DUE,
|
||||||
|
MfaPolicyCompliance.deadline_at != None,
|
||||||
|
MfaPolicyCompliance.deleted_at == None,
|
||||||
|
).all()
|
||||||
|
|
||||||
|
for record in past_due_records:
|
||||||
|
deadline = record.deadline_at
|
||||||
|
if deadline.tzinfo is None:
|
||||||
|
deadline = deadline.replace(tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
if now >= deadline:
|
||||||
|
# Transition to suspended
|
||||||
|
record.status = MfaComplianceStatus.SUSPENDED
|
||||||
|
record.suspended_at = now
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Update user status
|
||||||
|
user = User.query.get(record.user_id)
|
||||||
|
if user and user.status != UserStatus.COMPLIANCE_SUSPENDED:
|
||||||
|
user.status = UserStatus.COMPLIANCE_SUSPENDED
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Audit log
|
||||||
|
AuditService.log_action(
|
||||||
|
action=AuditAction.MFA_POLICY_USER_SUSPENDED,
|
||||||
|
user_id=record.user_id,
|
||||||
|
organization_id=record.organization_id,
|
||||||
|
description=f"User suspended due to MFA compliance deadline passed",
|
||||||
|
)
|
||||||
|
|
||||||
|
suspended_count += 1
|
||||||
|
|
||||||
|
return suspended_count
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_org_policy(
|
||||||
|
organization_id: str,
|
||||||
|
mfa_policy_mode: MfaPolicyMode,
|
||||||
|
mfa_grace_period_days: int = 14,
|
||||||
|
notify_days_before: int = 7,
|
||||||
|
updated_by_user_id: Optional[str] = None,
|
||||||
|
) -> OrganizationSecurityPolicy:
|
||||||
|
"""Create or update organization security policy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
organization_id: Organization ID
|
||||||
|
mfa_policy_mode: MFA policy mode
|
||||||
|
mfa_grace_period_days: Grace period in days
|
||||||
|
notify_days_before: Days before deadline to notify
|
||||||
|
updated_by_user_id: User making the change
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OrganizationSecurityPolicy instance
|
||||||
|
"""
|
||||||
|
policy = OrganizationSecurityPolicy.query.filter_by(
|
||||||
|
organization_id=organization_id, deleted_at=None
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if policy:
|
||||||
|
# Update existing
|
||||||
|
old_mode = policy.mfa_policy_mode
|
||||||
|
policy.mfa_policy_mode = mfa_policy_mode
|
||||||
|
policy.mfa_grace_period_days = mfa_grace_period_days
|
||||||
|
policy.notify_days_before = notify_days_before
|
||||||
|
policy.policy_version += 1
|
||||||
|
policy.updated_by_user_id = updated_by_user_id
|
||||||
|
policy.save()
|
||||||
|
|
||||||
|
# Audit log
|
||||||
|
AuditService.log_action(
|
||||||
|
action=AuditAction.ORG_SECURITY_POLICY_UPDATE,
|
||||||
|
user_id=updated_by_user_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
description=f"Security policy updated from {old_mode.value} to {mfa_policy_mode.value}",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Create new
|
||||||
|
policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=organization_id,
|
||||||
|
mfa_policy_mode=mfa_policy_mode,
|
||||||
|
mfa_grace_period_days=mfa_grace_period_days,
|
||||||
|
notify_days_before=notify_days_before,
|
||||||
|
policy_version=1,
|
||||||
|
updated_by_user_id=updated_by_user_id,
|
||||||
|
)
|
||||||
|
policy.save()
|
||||||
|
|
||||||
|
# Audit log
|
||||||
|
AuditService.log_action(
|
||||||
|
action=AuditAction.ORG_SECURITY_POLICY_UPDATE,
|
||||||
|
user_id=updated_by_user_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
description=f"Security policy created with mode {mfa_policy_mode.value}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return policy
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_user_override(
|
||||||
|
user_id: str,
|
||||||
|
organization_id: str,
|
||||||
|
mfa_override_mode: MfaRequirementOverride,
|
||||||
|
force_totp: bool = False,
|
||||||
|
force_webauthn: bool = False,
|
||||||
|
updated_by_user_id: Optional[str] = None,
|
||||||
|
) -> UserSecurityPolicy:
|
||||||
|
"""Set user security policy override.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID
|
||||||
|
organization_id: Organization ID
|
||||||
|
mfa_override_mode: Override mode
|
||||||
|
force_totp: Force TOTP requirement
|
||||||
|
force_webauthn: Force WebAuthn requirement
|
||||||
|
updated_by_user_id: User making the change
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserSecurityPolicy instance
|
||||||
|
"""
|
||||||
|
override = UserSecurityPolicy.query.filter_by(
|
||||||
|
user_id=user_id, organization_id=organization_id, deleted_at=None
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if override:
|
||||||
|
old_mode = override.mfa_override_mode
|
||||||
|
override.mfa_override_mode = mfa_override_mode
|
||||||
|
override.force_totp = force_totp
|
||||||
|
override.force_webauthn = force_webauthn
|
||||||
|
override.save()
|
||||||
|
|
||||||
|
# Audit log
|
||||||
|
AuditService.log_action(
|
||||||
|
action=AuditAction.USER_SECURITY_POLICY_OVERRIDE_UPDATE,
|
||||||
|
user_id=updated_by_user_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
resource_type="user",
|
||||||
|
resource_id=user_id,
|
||||||
|
description=f"User policy override updated from {old_mode.value} to {mfa_override_mode.value}",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
override = UserSecurityPolicy(
|
||||||
|
user_id=user_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
mfa_override_mode=mfa_override_mode,
|
||||||
|
force_totp=force_totp,
|
||||||
|
force_webauthn=force_webauthn,
|
||||||
|
)
|
||||||
|
override.save()
|
||||||
|
|
||||||
|
# Audit log
|
||||||
|
AuditService.log_action(
|
||||||
|
action=AuditAction.USER_SECURITY_POLICY_OVERRIDE_UPDATE,
|
||||||
|
user_id=updated_by_user_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
resource_type="user",
|
||||||
|
resource_id=user_id,
|
||||||
|
description=f"User policy override created with mode {mfa_override_mode.value}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return override
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_user_compliance(user_id: str, organization_id: str) -> Optional[MfaPolicyCompliance]:
|
||||||
|
"""Get user compliance record for an organization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID
|
||||||
|
organization_id: Organization ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MfaPolicyCompliance or None
|
||||||
|
"""
|
||||||
|
return MfaPolicyCompliance.query.filter_by(
|
||||||
|
user_id=user_id, organization_id=organization_id, deleted_at=None
|
||||||
|
).first()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_org_compliance_list(
|
||||||
|
organization_id: str, status: Optional[MfaComplianceStatus] = None, limit: int = 100, offset: int = 0
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""Get list of user compliance records for an organization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
organization_id: Organization ID
|
||||||
|
status: Optional status filter
|
||||||
|
limit: Maximum records to return
|
||||||
|
offset: Offset for pagination
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of compliance records with user info
|
||||||
|
"""
|
||||||
|
query = db.session.query(
|
||||||
|
MfaPolicyCompliance,
|
||||||
|
User.email,
|
||||||
|
User.full_name,
|
||||||
|
).join(
|
||||||
|
User, User.id == MfaPolicyCompliance.user_id
|
||||||
|
).filter(
|
||||||
|
MfaPolicyCompliance.organization_id == organization_id,
|
||||||
|
MfaPolicyCompliance.deleted_at == None,
|
||||||
|
User.deleted_at == None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if status:
|
||||||
|
query = query.filter(MfaPolicyCompliance.status == status)
|
||||||
|
|
||||||
|
records = query.order_by(
|
||||||
|
MfaPolicyCompliance.created_at.desc()
|
||||||
|
).limit(limit).offset(offset).all()
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for compliance, email, full_name in records:
|
||||||
|
result.append({
|
||||||
|
"user_id": compliance.user_id,
|
||||||
|
"email": email,
|
||||||
|
"full_name": full_name,
|
||||||
|
"status": compliance.status.value,
|
||||||
|
"deadline_at": compliance.deadline_at.isoformat() if compliance.deadline_at else None,
|
||||||
|
"applied_at": compliance.applied_at.isoformat() if compliance.applied_at else None,
|
||||||
|
"compliant_at": compliance.compliant_at.isoformat() if compliance.compliant_at else None,
|
||||||
|
"suspended_at": compliance.suspended_at.isoformat() if compliance.suspended_at else None,
|
||||||
|
"notification_count": compliance.notification_count,
|
||||||
|
})
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Multi-Organization Edge Case Handling
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_strictest_mode(modes: List[str]) -> str:
|
||||||
|
"""Get the strictest MFA policy mode from a list.
|
||||||
|
|
||||||
|
Used for multi-org scenarios where a user belongs to multiple organizations
|
||||||
|
with different policies. "Most secure wins" logic determines the effective
|
||||||
|
requirement.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
modes: List of policy mode strings
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The strictest mode string
|
||||||
|
"""
|
||||||
|
# Define strictness hierarchy (more strict = higher index)
|
||||||
|
strictness_order = [
|
||||||
|
MfaPolicyMode.DISABLED.value,
|
||||||
|
MfaPolicyMode.OPTIONAL.value,
|
||||||
|
MfaPolicyMode.REQUIRE_TOTP.value,
|
||||||
|
MfaPolicyMode.REQUIRE_WEBAUTHN.value,
|
||||||
|
MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value,
|
||||||
|
]
|
||||||
|
|
||||||
|
max_strictness = -1
|
||||||
|
result_mode = MfaPolicyMode.OPTIONAL.value
|
||||||
|
|
||||||
|
for mode in modes:
|
||||||
|
if mode in strictness_order:
|
||||||
|
idx = strictness_order.index(mode)
|
||||||
|
if idx > max_strictness:
|
||||||
|
max_strictness = idx
|
||||||
|
result_mode = mode
|
||||||
|
|
||||||
|
return result_mode
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def reevaluate_all_org_compliance(organization_id: str, now: Optional[datetime] = None) -> int:
|
||||||
|
"""Reevaluate compliance for all users in an organization.
|
||||||
|
|
||||||
|
Called when org policy changes to ensure all users are properly evaluated
|
||||||
|
under the new policy. This handles the edge case where policy becomes
|
||||||
|
more restrictive (e.g., OPTIONAL -> REQUIRE_TOTP).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
organization_id: Organization ID
|
||||||
|
now: Current time, defaults to now
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of compliance records updated
|
||||||
|
"""
|
||||||
|
if now is None:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
from gatehouse_app.models.organization_member import OrganizationMember
|
||||||
|
|
||||||
|
updated_count = 0
|
||||||
|
|
||||||
|
# Get all active members of the organization
|
||||||
|
memberships = OrganizationMember.query.filter_by(
|
||||||
|
organization_id=organization_id, deleted_at=None
|
||||||
|
).all()
|
||||||
|
|
||||||
|
for membership in memberships:
|
||||||
|
user = membership.user
|
||||||
|
if not user or user.deleted_at is not None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get or create compliance record
|
||||||
|
compliance = MfaPolicyCompliance.query.filter_by(
|
||||||
|
user_id=user.id, organization_id=organization_id, deleted_at=None
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not compliance:
|
||||||
|
compliance = MfaPolicyCompliance(
|
||||||
|
user_id=user.id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
status=MfaComplianceStatus.NOT_APPLICABLE,
|
||||||
|
policy_version=0,
|
||||||
|
)
|
||||||
|
compliance.save()
|
||||||
|
|
||||||
|
# Reevaluate under new policy
|
||||||
|
effective_policy = MfaPolicyService.get_effective_user_policy(
|
||||||
|
user.id, organization_id
|
||||||
|
)
|
||||||
|
|
||||||
|
old_status = compliance.status.value if hasattr(compliance.status, 'value') else str(compliance.status)
|
||||||
|
new_status = MfaPolicyService._evaluate_compliance_status(
|
||||||
|
user, effective_policy, compliance
|
||||||
|
)
|
||||||
|
|
||||||
|
if old_status != new_status:
|
||||||
|
compliance.status = MfaComplianceStatus(new_status)
|
||||||
|
# Reset deadline if transitioning to in_grace from a non-grace state
|
||||||
|
if new_status == MfaComplianceStatus.IN_GRACE.value and not compliance.deadline_at:
|
||||||
|
compliance.applied_at = now
|
||||||
|
compliance.deadline_at = now.replace(tzinfo=None) + __import__("datetime").timedelta(
|
||||||
|
days=effective_policy.grace_period_days
|
||||||
|
)
|
||||||
|
db.session.commit()
|
||||||
|
updated_count += 1
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Reevaluated compliance for user {user.email} in org {organization_id}: "
|
||||||
|
f"{old_status} -> {new_status}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return updated_count
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_and_restore_user_status(user_id: str) -> bool:
|
||||||
|
"""Check if user should be restored to ACTIVE status.
|
||||||
|
|
||||||
|
Called after compliance changes to determine if a COMPLIANCE_SUSPENDED
|
||||||
|
user should be restored to ACTIVE status. This happens when:
|
||||||
|
- All org policies are now compliant
|
||||||
|
- User overrides were changed to EXEMPT
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if user status was restored, False otherwise
|
||||||
|
"""
|
||||||
|
user = User.query.get(user_id)
|
||||||
|
if not user:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if user.status != UserStatus.COMPLIANCE_SUSPENDED:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Evaluate user's overall compliance state
|
||||||
|
compliance_summary = MfaPolicyService.evaluate_user_mfa_state(user)
|
||||||
|
|
||||||
|
# If now compliant across all orgs, restore status
|
||||||
|
if compliance_summary.overall_status == MfaComplianceStatus.COMPLIANT.value:
|
||||||
|
user.status = UserStatus.ACTIVE
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Audit log
|
||||||
|
AuditService.log_action(
|
||||||
|
action=AuditAction.MFA_POLICY_USER_COMPLIANT,
|
||||||
|
user_id=user_id,
|
||||||
|
description="User restored to ACTIVE status after becoming MFA compliant",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"User {user.email} restored to ACTIVE status")
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# User Override Edge Case Handling
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_override_summary(user_id: str, organization_id: str) -> Dict[str, Any]:
|
||||||
|
"""Get a summary of user override for an organization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID
|
||||||
|
organization_id: Organization ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with override information
|
||||||
|
"""
|
||||||
|
user_override = UserSecurityPolicy.query.filter_by(
|
||||||
|
user_id=user_id, organization_id=organization_id, deleted_at=None
|
||||||
|
).first()
|
||||||
|
|
||||||
|
org_policy = MfaPolicyService.get_org_policy(organization_id)
|
||||||
|
|
||||||
|
if not user_override:
|
||||||
|
return {
|
||||||
|
"has_override": False,
|
||||||
|
"mode": "inherit",
|
||||||
|
"org_policy_mode": org_policy.mfa_policy_mode if org_policy else "none",
|
||||||
|
"effective_mode": org_policy.mfa_policy_mode if org_policy else "disabled",
|
||||||
|
}
|
||||||
|
|
||||||
|
effective_policy = MfaPolicyService.get_effective_user_policy(
|
||||||
|
user_id, organization_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"has_override": True,
|
||||||
|
"mode": user_override.mfa_override_mode.value,
|
||||||
|
"force_totp": user_override.force_totp,
|
||||||
|
"force_webauthn": user_override.force_webauthn,
|
||||||
|
"org_policy_mode": org_policy.mfa_policy_mode if org_policy else "none",
|
||||||
|
"effective_mode": effective_policy.effective_mode,
|
||||||
|
"is_exempt": effective_policy.is_exempt,
|
||||||
|
}
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Security Audit Logging
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def log_suspended_login_attempt(user: User, ip_address: str = None, user_agent: str = None):
|
||||||
|
"""Log a login attempt by a compliance-suspended user.
|
||||||
|
|
||||||
|
This provides audit trail for potential security incidents where
|
||||||
|
suspended users attempt to access the system.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User instance
|
||||||
|
ip_address: Client IP address
|
||||||
|
user_agent: Client user agent
|
||||||
|
"""
|
||||||
|
# Get current compliance summary
|
||||||
|
compliance_summary = MfaPolicyService.evaluate_user_mfa_state(user)
|
||||||
|
|
||||||
|
# Find which org(s) caused suspension
|
||||||
|
suspended_orgs = [
|
||||||
|
org for org in compliance_summary.orgs
|
||||||
|
if org.status == MfaComplianceStatus.SUSPENDED.value
|
||||||
|
]
|
||||||
|
|
||||||
|
org_ids = [org.organization_id for org in suspended_orgs]
|
||||||
|
|
||||||
|
AuditService.log_action(
|
||||||
|
action=AuditAction.USER_LOGIN,
|
||||||
|
user_id=user.id,
|
||||||
|
organization_id=org_ids[0] if org_ids else None,
|
||||||
|
ip_address=ip_address,
|
||||||
|
user_agent=user_agent,
|
||||||
|
description=f"Login attempt while compliance suspended. Suspended orgs: {org_ids}",
|
||||||
|
success=False,
|
||||||
|
error_message="MFA compliance required",
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def log_policy_bypass_attempt(
|
||||||
|
user: User,
|
||||||
|
endpoint: str,
|
||||||
|
ip_address: str = None,
|
||||||
|
user_agent: str = None,
|
||||||
|
):
|
||||||
|
"""Log a potential policy bypass attempt.
|
||||||
|
|
||||||
|
Called when a compliance-only session attempts to access a
|
||||||
|
full-access endpoint. This could indicate security issues.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User instance
|
||||||
|
endpoint: Requested endpoint
|
||||||
|
ip_address: Client IP address
|
||||||
|
user_agent: Client user agent
|
||||||
|
"""
|
||||||
|
AuditService.log_action(
|
||||||
|
action=AuditAction.USER_LOGIN, # Reusing USER_LOGIN for audit
|
||||||
|
user_id=user.id,
|
||||||
|
ip_address=ip_address,
|
||||||
|
user_agent=user_agent,
|
||||||
|
resource_type="endpoint",
|
||||||
|
resource_id=endpoint,
|
||||||
|
description=f"Policy bypass attempt - compliance-only session accessed {endpoint}",
|
||||||
|
success=False,
|
||||||
|
error_message="MFA compliance required",
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_multi_org_aggregate_state(user: User) -> Dict[str, Any]:
|
||||||
|
"""Get aggregate MFA state for a user across all organizations.
|
||||||
|
|
||||||
|
This provides detailed breakdown of how multi-org membership affects
|
||||||
|
compliance status, useful for debugging and admin reporting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with aggregate state details
|
||||||
|
"""
|
||||||
|
compliance_summary = MfaPolicyService.evaluate_user_mfa_state(user)
|
||||||
|
|
||||||
|
# Calculate strictest requirement
|
||||||
|
modes = [org.effective_mode for org in compliance_summary.orgs]
|
||||||
|
strictest_mode = MfaPolicyService.get_strictest_mode(modes)
|
||||||
|
|
||||||
|
# Find organizations requiring MFA
|
||||||
|
requiring_orgs = [
|
||||||
|
{
|
||||||
|
"organization_id": org.organization_id,
|
||||||
|
"organization_name": org.organization_name,
|
||||||
|
"mode": org.effective_mode,
|
||||||
|
"status": org.status,
|
||||||
|
}
|
||||||
|
for org in compliance_summary.orgs
|
||||||
|
if org.effective_mode not in (
|
||||||
|
MfaPolicyMode.DISABLED.value,
|
||||||
|
MfaPolicyMode.OPTIONAL.value,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Find exempt organizations
|
||||||
|
for org in compliance_summary.orgs:
|
||||||
|
override_summary = MfaPolicyService.get_override_summary(
|
||||||
|
user.id, org.organization_id
|
||||||
|
)
|
||||||
|
if override_summary.get("is_exempt"):
|
||||||
|
requiring_orgs = [
|
||||||
|
o for o in requiring_orgs
|
||||||
|
if o["organization_id"] != org.organization_id
|
||||||
|
]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"overall_status": compliance_summary.overall_status,
|
||||||
|
"strictest_mode": strictest_mode,
|
||||||
|
"missing_methods": compliance_summary.missing_methods,
|
||||||
|
"deadline_at": compliance_summary.deadline_at,
|
||||||
|
"requiring_org_count": len(requiring_orgs),
|
||||||
|
"requiring_orgs": requiring_orgs,
|
||||||
|
"total_org_count": len(compliance_summary.orgs),
|
||||||
|
"per_org_details": [
|
||||||
|
{
|
||||||
|
"organization_id": org.organization_id,
|
||||||
|
"organization_name": org.organization_name,
|
||||||
|
"effective_mode": org.effective_mode,
|
||||||
|
"status": org.status,
|
||||||
|
"deadline_at": org.deadline_at,
|
||||||
|
"applied_at": org.applied_at,
|
||||||
|
}
|
||||||
|
for org in compliance_summary.orgs
|
||||||
|
],
|
||||||
|
}
|
||||||
@@ -0,0 +1,430 @@
|
|||||||
|
"""Notification Service for MFA compliance notifications.
|
||||||
|
|
||||||
|
This service handles sending MFA-related notifications to users, including:
|
||||||
|
- Deadline reminder emails
|
||||||
|
- Suspension notifications
|
||||||
|
- Compliance status updates
|
||||||
|
|
||||||
|
The service is designed to work with or without email infrastructure:
|
||||||
|
- If email is configured, it sends actual emails
|
||||||
|
- If email is not available, it logs notifications for debugging/auditing
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from gatehouse_app.services.notification_service import NotificationService
|
||||||
|
NotificationService.send_mfa_deadline_reminder(user, compliance, org_policy)
|
||||||
|
"""
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
|
||||||
|
from gatehouse_app.extensions import db
|
||||||
|
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
|
||||||
|
from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy
|
||||||
|
from gatehouse_app.models.user import User
|
||||||
|
from gatehouse_app.services.audit_service import AuditService
|
||||||
|
from gatehouse_app.utils.constants import AuditAction
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class NotificationService:
|
||||||
|
"""Service for sending MFA compliance notifications."""
|
||||||
|
|
||||||
|
# Configuration keys for email settings
|
||||||
|
EMAIL_ENABLED_KEY = "EMAIL_ENABLED"
|
||||||
|
SMTP_HOST_KEY = "SMTP_HOST"
|
||||||
|
SMTP_PORT_KEY = "SMTP_PORT"
|
||||||
|
SMTP_USERNAME_KEY = "SMTP_USERNAME"
|
||||||
|
SMTP_PASSWORD_KEY = "SMTP_PASSWORD"
|
||||||
|
FROM_ADDRESS_KEY = "FROM_ADDRESS"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def send_mfa_deadline_reminder(
|
||||||
|
user: User,
|
||||||
|
compliance: MfaPolicyCompliance,
|
||||||
|
org_policy: OrganizationSecurityPolicy,
|
||||||
|
) -> bool:
|
||||||
|
"""Send MFA deadline reminder notification to user.
|
||||||
|
|
||||||
|
Sends a reminder email to users who are approaching their MFA
|
||||||
|
compliance deadline. The reminder includes:
|
||||||
|
- Days remaining until deadline
|
||||||
|
- Required MFA methods
|
||||||
|
- Link to MFA enrollment
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User to notify
|
||||||
|
compliance: User's compliance record
|
||||||
|
org_policy: Organization's MFA policy
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if notification was sent successfully, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Calculate days until deadline
|
||||||
|
deadline = compliance.deadline_at
|
||||||
|
if deadline.tzinfo is None:
|
||||||
|
deadline = deadline.replace(tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
days_until_deadline = (deadline - now).days
|
||||||
|
|
||||||
|
# Build notification content
|
||||||
|
subject = f"Action Required: MFA enrollment deadline in {days_until_deadline} days"
|
||||||
|
body = NotificationService._build_deadline_reminder_body(
|
||||||
|
user, compliance, org_policy, days_until_deadline
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send the notification
|
||||||
|
success = NotificationService._send_email(
|
||||||
|
to_address=user.email,
|
||||||
|
subject=subject,
|
||||||
|
body=body,
|
||||||
|
)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
logger.info(
|
||||||
|
f"Sent MFA deadline reminder to {user.email} "
|
||||||
|
f"({days_until_deadline} days remaining # Audit log
|
||||||
|
)"
|
||||||
|
)
|
||||||
|
AuditService.log_action(
|
||||||
|
action=AuditAction.MFA_POLICY_USER_COMPLIANT,
|
||||||
|
user_id=user.id,
|
||||||
|
organization_id=compliance.organization_id,
|
||||||
|
description=f"MFA deadline reminder sent. Days remaining: {days_until_deadline}",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to send MFA deadline reminder to {user.email}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return success
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Error sending MFA deadline reminder to {user.email}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def send_mfa_suspended_notification(
|
||||||
|
user: User,
|
||||||
|
compliance: MfaPolicyCompliance,
|
||||||
|
org_policy: OrganizationSecurityPolicy,
|
||||||
|
) -> bool:
|
||||||
|
"""Send MFA suspension notification to user.
|
||||||
|
|
||||||
|
Notifies users that their account has been suspended due to
|
||||||
|
failure to comply with MFA requirements. The notification includes:
|
||||||
|
- Explanation of suspension
|
||||||
|
- Steps to restore access
|
||||||
|
- Link to MFA enrollment
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User to notify
|
||||||
|
compliance: User's compliance record
|
||||||
|
org_policy: Organization's MFA policy
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if notification was sent successfully, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Build notification content
|
||||||
|
subject = "Account Access Restricted - MFA Enrollment Required"
|
||||||
|
body = NotificationService._build_suspension_body(
|
||||||
|
user, compliance, org_policy
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send the notification
|
||||||
|
success = NotificationService._send_email(
|
||||||
|
to_address=user.email,
|
||||||
|
subject=subject,
|
||||||
|
body=body,
|
||||||
|
)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
logger.info(f"Sent MFA suspension notification to {user.email}")
|
||||||
|
# Audit log
|
||||||
|
AuditService.log_action(
|
||||||
|
action=AuditAction.MFA_POLICY_USER_SUSPENDED,
|
||||||
|
user_id=user.id,
|
||||||
|
organization_id=compliance.organization_id,
|
||||||
|
description="MFA compliance suspension notification sent",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to send MFA suspension notification to {user.email}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return success
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(
|
||||||
|
f"Error sending MFA suspension notification to {user.email}: {e}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_deadline_reminder_body(
|
||||||
|
user: User,
|
||||||
|
compliance: MfaPolicyCompliance,
|
||||||
|
org_policy: OrganizationSecurityPolicy,
|
||||||
|
days_until_deadline: int,
|
||||||
|
) -> str:
|
||||||
|
"""Build the email body for deadline reminder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User being notified
|
||||||
|
compliance: Compliance record
|
||||||
|
org_policy: Organization policy
|
||||||
|
days_until_deadline: Days remaining until deadline
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Email body string
|
||||||
|
"""
|
||||||
|
org_name = compliance.organization_id # In real impl, fetch org name
|
||||||
|
|
||||||
|
body = f"""
|
||||||
|
Dear {user.full_name or user.email},
|
||||||
|
|
||||||
|
This is a reminder that you need to set up multi-factor authentication (MFA)
|
||||||
|
to maintain access to your account in the organization "{org_name}".
|
||||||
|
|
||||||
|
**Important Details:**
|
||||||
|
- Days remaining: {days_until_deadline}
|
||||||
|
- Deadline: {compliance.deadline_at.strftime('%Y-%m-%d %H:%M UTC') if compliance.deadline_at else 'Not set'}
|
||||||
|
|
||||||
|
**Required MFA Methods:**
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Add required methods based on policy mode
|
||||||
|
from gatehouse_app.utils.constants import MfaPolicyMode
|
||||||
|
|
||||||
|
mode = org_policy.mfa_policy_mode
|
||||||
|
if mode == MfaPolicyMode.REQUIRE_TOTP:
|
||||||
|
body += "- Authenticator app (TOTP)\n"
|
||||||
|
elif mode == MfaPolicyMode.REQUIRE_WEBAUTHN:
|
||||||
|
body += "- Passkey (WebAuthn)\n"
|
||||||
|
elif mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN:
|
||||||
|
body += "- Authenticator app (TOTP) OR Passkey (WebAuthn)\n"
|
||||||
|
else:
|
||||||
|
body += "- Multi-factor authentication\n"
|
||||||
|
|
||||||
|
body += """
|
||||||
|
**How to Set Up MFA:**
|
||||||
|
1. Log in to your account
|
||||||
|
2. Navigate to Settings > Security
|
||||||
|
3. Follow the prompts to set up an authenticator app or passkey
|
||||||
|
|
||||||
|
If you do not set up MFA by the deadline, your account access will be restricted.
|
||||||
|
|
||||||
|
If you have any questions, please contact your organization administrator.
|
||||||
|
|
||||||
|
Best regards,
|
||||||
|
Gatehouse Security Team
|
||||||
|
"""
|
||||||
|
return body
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_suspension_body(
|
||||||
|
user: User,
|
||||||
|
compliance: MfaPolicyCompliance,
|
||||||
|
org_policy: OrganizationSecurityPolicy,
|
||||||
|
) -> str:
|
||||||
|
"""Build the email body for suspension notification.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User being notified
|
||||||
|
compliance: Compliance record
|
||||||
|
org_policy: Organization policy
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Email body string
|
||||||
|
"""
|
||||||
|
org_name = compliance.organization_id # In real impl, fetch org name
|
||||||
|
|
||||||
|
body = f"""
|
||||||
|
Dear {user.full_name or user.email},
|
||||||
|
|
||||||
|
Your account access has been restricted because you did not set up
|
||||||
|
multi-factor authentication (MFA) within the required timeframe for
|
||||||
|
the organization "{org_name}".
|
||||||
|
|
||||||
|
**What Happened:**
|
||||||
|
Your MFA compliance deadline passed without MFA being configured.
|
||||||
|
As a result, your account has been placed in a suspended state.
|
||||||
|
|
||||||
|
**How to Restore Access:**
|
||||||
|
1. Log in to your account (you will see a compliance enrollment screen)
|
||||||
|
2. Follow the prompts to set up an authenticator app or passkey
|
||||||
|
3. Once MFA is configured, your access will be restored
|
||||||
|
|
||||||
|
**Required MFA Methods:
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Add required methods based on policy mode
|
||||||
|
from gatehouse_app.utils.constants import MfaPolicyMode
|
||||||
|
|
||||||
|
mode = org_policy.mfa_policy_mode
|
||||||
|
if mode == MfaPolicyMode.REQUIRE_TOTP:
|
||||||
|
body += "- Authenticator app (TOTP)\n"
|
||||||
|
elif mode == MfaPolicyMode.REQUIRE_WEBAUTHN:
|
||||||
|
body += "- Passkey (WebAuthn)\n"
|
||||||
|
elif mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN:
|
||||||
|
body += "- Authenticator app (TOTP) OR Passkey (WebAuthn)\n"
|
||||||
|
else:
|
||||||
|
body += "- Multi-factor authentication\n"
|
||||||
|
|
||||||
|
body += """
|
||||||
|
**Need Help?**
|
||||||
|
Contact your organization administrator if you have questions.
|
||||||
|
|
||||||
|
Best regards,
|
||||||
|
Gatehouse Security Team
|
||||||
|
"""
|
||||||
|
return body
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _send_email(
|
||||||
|
to_address: str,
|
||||||
|
subject: str,
|
||||||
|
body: str,
|
||||||
|
html_body: Optional[str] = None,
|
||||||
|
) -> bool:
|
||||||
|
"""Send an email notification.
|
||||||
|
|
||||||
|
This method attempts to send an email using configured SMTP settings.
|
||||||
|
If email is not configured, it logs the notification instead.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
to_address: Recipient email address
|
||||||
|
subject: Email subject
|
||||||
|
body: Plain text email body
|
||||||
|
html_body: Optional HTML email body
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if email was sent (or logged), False on error
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from flask import current_app
|
||||||
|
|
||||||
|
# Check if email is configured
|
||||||
|
email_enabled = current_app.config.get(
|
||||||
|
NotificationService.EMAIL_ENABLED_KEY, False
|
||||||
|
)
|
||||||
|
|
||||||
|
if not email_enabled:
|
||||||
|
# Log the notification instead of sending
|
||||||
|
logger.info(
|
||||||
|
f"[EMAIL SIMULATION] To: {to_address}\n"
|
||||||
|
f"Subject: {subject}\n"
|
||||||
|
f"Body: {body[:200]}..." if len(body) > 200 else f"Body: {body}"
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Get email configuration
|
||||||
|
smtp_host = current_app.config.get(NotificationService.SMTP_HOST_KEY)
|
||||||
|
smtp_port = current_app.config.get(NotificationService.SMTP_PORT_KEY, 587)
|
||||||
|
smtp_username = current_app.config.get(NotificationService.SMTP_USERNAME_KEY)
|
||||||
|
smtp_password = current_app.config.get(NotificationService.SMTP_PASSWORD_KEY)
|
||||||
|
from_address = current_app.config.get(
|
||||||
|
NotificationService.FROM_ADDRESS_KEY, "noreply@gatehouse.local"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import send_email based on available mail library
|
||||||
|
try:
|
||||||
|
from flask_mail import Message
|
||||||
|
|
||||||
|
from gatehouse_app import mail
|
||||||
|
|
||||||
|
msg = Message(
|
||||||
|
subject=subject,
|
||||||
|
recipients=[to_address],
|
||||||
|
body=body,
|
||||||
|
html=html_body,
|
||||||
|
sender=from_address,
|
||||||
|
)
|
||||||
|
mail.send(msg)
|
||||||
|
logger.info(f"Email sent successfully to {to_address}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
# Flask-Mail not available, use SMTP directly
|
||||||
|
import smtplib
|
||||||
|
from email.mime.text import MIMEText
|
||||||
|
from email.mime.multipart import MIMEMultipart
|
||||||
|
|
||||||
|
msg = MIMEMultipart("alternative")
|
||||||
|
msg["Subject"] = subject
|
||||||
|
msg["From"] = from_address
|
||||||
|
msg["To"] = to_address
|
||||||
|
|
||||||
|
# Attach plain text and HTML versions
|
||||||
|
part1 = MIMEText(body, "plain")
|
||||||
|
msg.attach(part1)
|
||||||
|
|
||||||
|
if html_body:
|
||||||
|
part2 = MIMEText(html_body, "html")
|
||||||
|
msg.attach(part2)
|
||||||
|
|
||||||
|
# Send via SMTP
|
||||||
|
with smtplib.SMTP(smtp_host, smtp_port) as server:
|
||||||
|
server.starttls()
|
||||||
|
if smtp_username and smtp_password:
|
||||||
|
server.login(smtp_username, smtp_password)
|
||||||
|
server.send_message(msg)
|
||||||
|
|
||||||
|
logger.info(f"Email sent successfully to {to_address}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Failed to send email to {to_address}: {e}")
|
||||||
|
# Log the notification as fallback
|
||||||
|
logger.info(
|
||||||
|
f"[EMAIL FALLBACK] To: {to_address}\n"
|
||||||
|
f"Subject: {subject}\n"
|
||||||
|
f"Body: {body[:500]}..." if len(body) > 500 else f"Body: {body}"
|
||||||
|
)
|
||||||
|
return True # Return True to continue processing
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_notification_stats(user_id: str) -> Dict[str, Any]:
|
||||||
|
"""Get notification statistics for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with notification statistics
|
||||||
|
"""
|
||||||
|
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
|
||||||
|
|
||||||
|
stats = {
|
||||||
|
"total_notifications": 0,
|
||||||
|
"last_notification": None,
|
||||||
|
"by_organization": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
compliance_records = MfaPolicyCompliance.query.filter_by(
|
||||||
|
user_id=user_id, deleted_at=None
|
||||||
|
).all()
|
||||||
|
|
||||||
|
total_notifications = 0
|
||||||
|
last_notification = None
|
||||||
|
|
||||||
|
for record in compliance_records:
|
||||||
|
total_notifications += record.notification_count
|
||||||
|
if record.last_notified_at:
|
||||||
|
if last_notification is None or record.last_notified_at > last_notification:
|
||||||
|
last_notification = record.last_notified_at
|
||||||
|
|
||||||
|
stats["by_organization"].append({
|
||||||
|
"organization_id": record.organization_id,
|
||||||
|
"notification_count": record.notification_count,
|
||||||
|
"last_notified_at": record.last_notified_at.isoformat() if record.last_notified_at else None,
|
||||||
|
})
|
||||||
|
|
||||||
|
stats["total_notifications"] = total_notifications
|
||||||
|
stats["last_notification"] = last_notification.isoformat() if last_notification else None
|
||||||
|
|
||||||
|
return stats
|
||||||
@@ -9,6 +9,7 @@ class UserStatus(str, Enum):
|
|||||||
INACTIVE = "inactive"
|
INACTIVE = "inactive"
|
||||||
SUSPENDED = "suspended"
|
SUSPENDED = "suspended"
|
||||||
PENDING = "pending"
|
PENDING = "pending"
|
||||||
|
COMPLIANCE_SUSPENDED = "compliance_suspended"
|
||||||
|
|
||||||
|
|
||||||
class OrganizationRole(str, Enum):
|
class OrganizationRole(str, Enum):
|
||||||
@@ -86,6 +87,12 @@ class AuditAction(str, Enum):
|
|||||||
WEBAUTHN_CREDENTIAL_DELETED = "webauthn.credential.deleted"
|
WEBAUTHN_CREDENTIAL_DELETED = "webauthn.credential.deleted"
|
||||||
WEBAUTHN_CREDENTIAL_RENAMED = "webauthn.credential.renamed"
|
WEBAUTHN_CREDENTIAL_RENAMED = "webauthn.credential.renamed"
|
||||||
|
|
||||||
|
# Security policy actions
|
||||||
|
ORG_SECURITY_POLICY_UPDATE = "org.security_policy.update"
|
||||||
|
USER_SECURITY_POLICY_OVERRIDE_UPDATE = "user.security_policy.override_update"
|
||||||
|
MFA_POLICY_USER_SUSPENDED = "mfa.policy.user_suspended"
|
||||||
|
MFA_POLICY_USER_COMPLIANT = "mfa.policy.user_compliant"
|
||||||
|
|
||||||
|
|
||||||
class OIDCGrantType(str, Enum):
|
class OIDCGrantType(str, Enum):
|
||||||
"""OIDC grant types."""
|
"""OIDC grant types."""
|
||||||
@@ -116,3 +123,32 @@ class ErrorType:
|
|||||||
RATE_LIMIT_EXCEEDED = "RATE_LIMIT_EXCEEDED"
|
RATE_LIMIT_EXCEEDED = "RATE_LIMIT_EXCEEDED"
|
||||||
INTERNAL_ERROR = "INTERNAL_ERROR"
|
INTERNAL_ERROR = "INTERNAL_ERROR"
|
||||||
BAD_REQUEST = "BAD_REQUEST"
|
BAD_REQUEST = "BAD_REQUEST"
|
||||||
|
|
||||||
|
|
||||||
|
class MfaPolicyMode(str, Enum):
|
||||||
|
"""MFA policy mode for organizations."""
|
||||||
|
|
||||||
|
DISABLED = "disabled"
|
||||||
|
OPTIONAL = "optional"
|
||||||
|
REQUIRE_TOTP = "require_totp"
|
||||||
|
REQUIRE_WEBAUTHN = "require_webauthn"
|
||||||
|
REQUIRE_TOTP_OR_WEBAUTHN = "require_totp_or_webauthn"
|
||||||
|
|
||||||
|
|
||||||
|
class MfaComplianceStatus(str, Enum):
|
||||||
|
"""MFA compliance status for users per organization."""
|
||||||
|
|
||||||
|
NOT_APPLICABLE = "not_applicable"
|
||||||
|
PENDING = "pending"
|
||||||
|
IN_GRACE = "in_grace"
|
||||||
|
COMPLIANT = "compliant"
|
||||||
|
PAST_DUE = "past_due"
|
||||||
|
SUSPENDED = "suspended"
|
||||||
|
|
||||||
|
|
||||||
|
class MfaRequirementOverride(str, Enum):
|
||||||
|
"""User override for organization MFA requirements."""
|
||||||
|
|
||||||
|
INHERIT = "inherit"
|
||||||
|
REQUIRED = "required"
|
||||||
|
EXEMPT = "exempt"
|
||||||
|
|||||||
@@ -2,7 +2,8 @@
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
from flask import request, g
|
from flask import request, g
|
||||||
from gatehouse_app.utils.response import api_response
|
from gatehouse_app.utils.response import api_response
|
||||||
from gatehouse_app.utils.constants import OrganizationRole
|
from gatehouse_app.utils.constants import OrganizationRole, UserStatus
|
||||||
|
from gatehouse_app.exceptions.auth_exceptions import UnauthorizedError, ForbiddenError
|
||||||
|
|
||||||
|
|
||||||
def login_required(f):
|
def login_required(f):
|
||||||
@@ -127,3 +128,41 @@ def require_owner(f):
|
|||||||
def require_admin(f):
|
def require_admin(f):
|
||||||
"""Decorator to require organization admin or owner role."""
|
"""Decorator to require organization admin or owner role."""
|
||||||
return require_role(OrganizationRole.OWNER, OrganizationRole.ADMIN)(f)
|
return require_role(OrganizationRole.OWNER, OrganizationRole.ADMIN)(f)
|
||||||
|
|
||||||
|
|
||||||
|
def full_access_required(f):
|
||||||
|
"""Decorator to require full access session (not compliance-only).
|
||||||
|
|
||||||
|
This decorator checks if the user has a compliance-only session or
|
||||||
|
is in COMPLIANCE_SUSPENDED status. If so, it returns a 403 error
|
||||||
|
with error_type "MFA_COMPLIANCE_REQUIRED".
|
||||||
|
|
||||||
|
Use this decorator on endpoints that require full MFA compliance.
|
||||||
|
Endpoints for MFA enrollment, status, and logout should NOT use this decorator.
|
||||||
|
"""
|
||||||
|
@wraps(f)
|
||||||
|
def decorated_function(*args, **kwargs):
|
||||||
|
user = getattr(g, "current_user", None)
|
||||||
|
session = getattr(g, "current_session", None)
|
||||||
|
|
||||||
|
if not user or not session:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="Authentication required",
|
||||||
|
status=401,
|
||||||
|
error_type="AUTH_REQUIRED",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for compliance-only session or compliance suspended status
|
||||||
|
if session.is_compliance_only or user.status == UserStatus.COMPLIANCE_SUSPENDED:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="MFA compliance required to access this resource",
|
||||||
|
status=403,
|
||||||
|
error_type="MFA_COMPLIANCE_REQUIRED",
|
||||||
|
error_details={"overall_status": "suspended"},
|
||||||
|
)
|
||||||
|
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
|
return decorated_function
|
||||||
|
|||||||
@@ -14,5 +14,102 @@ app = create_app(os.getenv("FLASK_ENV", "development"))
|
|||||||
# Create Flask CLI group
|
# Create Flask CLI group
|
||||||
cli = FlaskGroup(create_app=lambda: app)
|
cli = FlaskGroup(create_app=lambda: app)
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command("run_mfa_compliance_job")
|
||||||
|
def run_mfa_compliance_job():
|
||||||
|
"""Run the MFA compliance scheduled job.
|
||||||
|
|
||||||
|
This command processes MFA compliance transitions:
|
||||||
|
- Transitions users from PAST_DUE to SUSPENDED status
|
||||||
|
- Sends deadline reminder notifications
|
||||||
|
- Updates notification tracking metadata
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python manage.py run_mfa_compliance_job
|
||||||
|
|
||||||
|
This can be called via cron or a task scheduler:
|
||||||
|
0 * * * * cd /path/to/app && python manage.py run_mfa_compliance_job
|
||||||
|
"""
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from gatehouse_app.jobs.mfa_compliance_job import process_mfa_compliance, get_job_status
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("MFA Compliance Job")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
print(f"Start time: {now.isoformat()}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Show current status before processing
|
||||||
|
print("Current Compliance Status:")
|
||||||
|
status = get_job_status(now)
|
||||||
|
for status_name, count in status["status_counts"].items():
|
||||||
|
print(f" {status_name}: {count}")
|
||||||
|
print(f" Approaching deadline: {status['approaching_deadline_count']}")
|
||||||
|
print(f" Past due: {status['past_due_count']}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Run the job
|
||||||
|
print("Processing compliance...")
|
||||||
|
result = process_mfa_compliance(now)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("Job Results:")
|
||||||
|
print(f" Users suspended: {result['suspended_count']}")
|
||||||
|
print(f" Notifications sent: {result['notified_count']}")
|
||||||
|
print(f" Records processed: {result['processed_count']}")
|
||||||
|
|
||||||
|
if result['errors']:
|
||||||
|
print()
|
||||||
|
print("Errors:")
|
||||||
|
for error in result['errors']:
|
||||||
|
print(f" - {error}")
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("=" * 60)
|
||||||
|
print("Job completed successfully")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command("mfa_compliance_status")
|
||||||
|
def mfa_compliance_status():
|
||||||
|
"""Show current MFA compliance status.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python manage.py mfa_compliance_status
|
||||||
|
"""
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from gatehouse_app.jobs.mfa_compliance_job import get_job_status
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("MFA Compliance Status Report")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
status = get_job_status(now)
|
||||||
|
|
||||||
|
print(f"Report time: {status['timestamp']}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("Compliance Records by Status:")
|
||||||
|
for status_name, count in sorted(status["status_counts"].items()):
|
||||||
|
bar = "█" * min(count, 50)
|
||||||
|
print(f" {status_name:20s}: {count:5d} {bar}")
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("Summary:")
|
||||||
|
print(f" Approaching deadline: {status['approaching_deadline_count']}")
|
||||||
|
print(f" Past due (pending suspension): {status['past_due_count']}")
|
||||||
|
|
||||||
|
total = sum(status["status_counts"].values())
|
||||||
|
compliant = status["status_counts"].get("compliant", 0)
|
||||||
|
if total > 0:
|
||||||
|
compliance_rate = (compliant / total) * 100
|
||||||
|
print(f" Compliance rate: {compliance_rate:.1f}%")
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
cli()
|
cli()
|
||||||
|
|||||||
@@ -1,150 +0,0 @@
|
|||||||
"""Database migration: Create OIDC tables.
|
|
||||||
|
|
||||||
Revision ID: 001
|
|
||||||
Revises:
|
|
||||||
Create Date: 2024-01-01 00:00:00
|
|
||||||
|
|
||||||
This migration creates all OIDC-related tables for the authorization code flow,
|
|
||||||
refresh token management, OIDC session tracking, token metadata, and audit logging.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from sqlalchemy.dialects import postgresql
|
|
||||||
|
|
||||||
# Revision identifiers
|
|
||||||
revision = '001'
|
|
||||||
down_revision = None
|
|
||||||
branch_labels = None
|
|
||||||
depends_on = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
|
||||||
"""Create OIDC tables."""
|
|
||||||
|
|
||||||
# OIDC Authorization Codes table
|
|
||||||
op.create_table(
|
|
||||||
'oidc_authorization_codes',
|
|
||||||
sa.Column('id', sa.String(36), primary_key=True),
|
|
||||||
sa.Column('created_at', sa.DateTime, nullable=False),
|
|
||||||
sa.Column('updated_at', sa.DateTime, nullable=False),
|
|
||||||
sa.Column('deleted_at', sa.DateTime, nullable=True),
|
|
||||||
sa.Column('client_id', sa.String(255), sa.ForeignKey('oidc_clients.id'), nullable=False),
|
|
||||||
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id'), nullable=False),
|
|
||||||
sa.Column('code_hash', sa.String(255), nullable=False),
|
|
||||||
sa.Column('redirect_uri', sa.String(512), nullable=False),
|
|
||||||
sa.Column('scope', postgresql.JSON, nullable=True),
|
|
||||||
sa.Column('nonce', sa.String(255), nullable=True),
|
|
||||||
sa.Column('code_verifier', sa.String(255), nullable=True),
|
|
||||||
sa.Column('expires_at', sa.DateTime, nullable=False),
|
|
||||||
sa.Column('used_at', sa.DateTime, nullable=True),
|
|
||||||
sa.Column('is_used', sa.Boolean, default=False, nullable=False),
|
|
||||||
sa.Column('ip_address', sa.String(45), nullable=True),
|
|
||||||
sa.Column('user_agent', sa.Text, nullable=True),
|
|
||||||
)
|
|
||||||
op.create_index('ix_oidc_authorization_codes_client_id', 'oidc_authorization_codes', ['client_id'])
|
|
||||||
op.create_index('ix_oidc_authorization_codes_user_id', 'oidc_authorization_codes', ['user_id'])
|
|
||||||
op.create_index('ix_oidc_authorization_codes_expires_at', 'oidc_authorization_codes', ['expires_at'])
|
|
||||||
|
|
||||||
# OIDC Refresh Tokens table
|
|
||||||
op.create_table(
|
|
||||||
'oidc_refresh_tokens',
|
|
||||||
sa.Column('id', sa.String(36), primary_key=True),
|
|
||||||
sa.Column('created_at', sa.DateTime, nullable=False),
|
|
||||||
sa.Column('updated_at', sa.DateTime, nullable=False),
|
|
||||||
sa.Column('deleted_at', sa.DateTime, nullable=True),
|
|
||||||
sa.Column('client_id', sa.String(255), sa.ForeignKey('oidc_clients.id'), nullable=False),
|
|
||||||
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id'), nullable=False),
|
|
||||||
sa.Column('token_hash', sa.String(255), nullable=False),
|
|
||||||
sa.Column('access_token_id', sa.String(36), sa.ForeignKey('sessions.id'), nullable=True),
|
|
||||||
sa.Column('scope', postgresql.JSON, nullable=True),
|
|
||||||
sa.Column('expires_at', sa.DateTime, nullable=False),
|
|
||||||
sa.Column('revoked_at', sa.DateTime, nullable=True),
|
|
||||||
sa.Column('revoked_reason', sa.String(255), nullable=True),
|
|
||||||
sa.Column('previous_token_hash', sa.String(255), nullable=True),
|
|
||||||
sa.Column('rotation_count', sa.Integer, default=0, nullable=False),
|
|
||||||
sa.Column('ip_address', sa.String(45), nullable=True),
|
|
||||||
sa.Column('user_agent', sa.Text, nullable=True),
|
|
||||||
)
|
|
||||||
op.create_index('ix_oidc_refresh_tokens_client_id', 'oidc_refresh_tokens', ['client_id'])
|
|
||||||
op.create_index('ix_oidc_refresh_tokens_user_id', 'oidc_refresh_tokens', ['user_id'])
|
|
||||||
op.create_index('ix_oidc_refresh_tokens_token_hash', 'oidc_refresh_tokens', ['token_hash'], unique=True)
|
|
||||||
op.create_index('ix_oidc_refresh_tokens_access_token_id', 'oidc_refresh_tokens', ['access_token_id'])
|
|
||||||
op.create_index('ix_oidc_refresh_tokens_expires_at', 'oidc_refresh_tokens', ['expires_at'])
|
|
||||||
|
|
||||||
# OIDC Sessions table
|
|
||||||
op.create_table(
|
|
||||||
'oidc_sessions',
|
|
||||||
sa.Column('id', sa.String(36), primary_key=True),
|
|
||||||
sa.Column('created_at', sa.DateTime, nullable=False),
|
|
||||||
sa.Column('updated_at', sa.DateTime, nullable=False),
|
|
||||||
sa.Column('deleted_at', sa.DateTime, nullable=True),
|
|
||||||
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id'), nullable=False),
|
|
||||||
sa.Column('client_id', sa.String(255), sa.ForeignKey('oidc_clients.id'), nullable=False),
|
|
||||||
sa.Column('state', sa.String(255), nullable=False),
|
|
||||||
sa.Column('nonce', sa.String(255), nullable=True),
|
|
||||||
sa.Column('redirect_uri', sa.String(512), nullable=False),
|
|
||||||
sa.Column('scope', postgresql.JSON, nullable=True),
|
|
||||||
sa.Column('code_challenge', sa.String(255), nullable=True),
|
|
||||||
sa.Column('code_challenge_method', sa.String(10), nullable=True),
|
|
||||||
sa.Column('expires_at', sa.DateTime, nullable=False),
|
|
||||||
sa.Column('authenticated_at', sa.DateTime, nullable=True),
|
|
||||||
)
|
|
||||||
op.create_index('ix_oidc_sessions_user_id', 'oidc_sessions', ['user_id'])
|
|
||||||
op.create_index('ix_oidc_sessions_client_id', 'oidc_sessions', ['client_id'])
|
|
||||||
op.create_index('ix_oidc_sessions_state', 'oidc_sessions', ['state'])
|
|
||||||
op.create_index('ix_oidc_sessions_expires_at', 'oidc_sessions', ['expires_at'])
|
|
||||||
|
|
||||||
# OIDC Token Metadata table
|
|
||||||
op.create_table(
|
|
||||||
'oidc_token_metadata',
|
|
||||||
sa.Column('id', sa.String(36), primary_key=True),
|
|
||||||
sa.Column('created_at', sa.DateTime, nullable=False),
|
|
||||||
sa.Column('updated_at', sa.DateTime, nullable=False),
|
|
||||||
sa.Column('deleted_at', sa.DateTime, nullable=True),
|
|
||||||
sa.Column('client_id', sa.String(255), sa.ForeignKey('oidc_clients.id'), nullable=False),
|
|
||||||
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id'), nullable=False),
|
|
||||||
sa.Column('token_type', sa.String(50), nullable=False),
|
|
||||||
sa.Column('token_jti', sa.String(255), nullable=False),
|
|
||||||
sa.Column('expires_at', sa.DateTime, nullable=False),
|
|
||||||
sa.Column('revoked_at', sa.DateTime, nullable=True),
|
|
||||||
sa.Column('revoked_reason', sa.String(255), nullable=True),
|
|
||||||
)
|
|
||||||
op.create_index('ix_oidc_token_metadata_client_id', 'oidc_token_metadata', ['client_id'])
|
|
||||||
op.create_index('ix_oidc_token_metadata_user_id', 'oidc_token_metadata', ['user_id'])
|
|
||||||
op.create_index('ix_oidc_token_metadata_token_jti', 'oidc_token_metadata', ['token_jti'])
|
|
||||||
op.create_index('ix_oidc_token_metadata_expires_at', 'oidc_token_metadata', ['expires_at'])
|
|
||||||
|
|
||||||
# OIDC Audit Logs table
|
|
||||||
op.create_table(
|
|
||||||
'oidc_audit_logs',
|
|
||||||
sa.Column('id', sa.String(36), primary_key=True),
|
|
||||||
sa.Column('created_at', sa.DateTime, nullable=False),
|
|
||||||
sa.Column('updated_at', sa.DateTime, nullable=False),
|
|
||||||
sa.Column('deleted_at', sa.DateTime, nullable=True),
|
|
||||||
sa.Column('event_type', sa.String(100), nullable=False),
|
|
||||||
sa.Column('client_id', sa.String(255), sa.ForeignKey('oidc_clients.id'), nullable=True),
|
|
||||||
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id'), nullable=True),
|
|
||||||
sa.Column('success', sa.Boolean, default=True, nullable=False),
|
|
||||||
sa.Column('error_code', sa.String(100), nullable=True),
|
|
||||||
sa.Column('error_description', sa.Text, nullable=True),
|
|
||||||
sa.Column('ip_address', sa.String(45), nullable=True),
|
|
||||||
sa.Column('user_agent', sa.Text, nullable=True),
|
|
||||||
sa.Column('request_id', sa.String(36), nullable=True),
|
|
||||||
sa.Column('event_metadata', postgresql.JSON, nullable=True),
|
|
||||||
)
|
|
||||||
op.create_index('ix_oidc_audit_logs_event_type', 'oidc_audit_logs', ['event_type'])
|
|
||||||
op.create_index('ix_oidc_audit_logs_client_id', 'oidc_audit_logs', ['client_id'])
|
|
||||||
op.create_index('ix_oidc_audit_logs_user_id', 'oidc_audit_logs', ['user_id'])
|
|
||||||
op.create_index('ix_oidc_audit_logs_success', 'oidc_audit_logs', ['success'])
|
|
||||||
op.create_index('ix_oidc_audit_logs_ip_address', 'oidc_audit_logs', ['ip_address'])
|
|
||||||
op.create_index('ix_oidc_audit_logs_request_id', 'oidc_audit_logs', ['request_id'])
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
|
||||||
"""Drop OIDC tables."""
|
|
||||||
op.drop_table('oidc_audit_logs')
|
|
||||||
op.drop_table('oidc_token_metadata')
|
|
||||||
op.drop_table('oidc_sessions')
|
|
||||||
op.drop_table('oidc_refresh_tokens')
|
|
||||||
op.drop_table('oidc_authorization_codes')
|
|
||||||
@@ -1,44 +0,0 @@
|
|||||||
"""Database migration: Add WebAuthn support.
|
|
||||||
|
|
||||||
Revision ID: 002
|
|
||||||
Revises: 001
|
|
||||||
Create Date: 2024-01-15 00:00:00
|
|
||||||
|
|
||||||
This migration adds support for WebAuthn passkey authentication by:
|
|
||||||
- Adding WEBAUTHN to the AuthMethodType enum (handled in application code)
|
|
||||||
- No schema changes required (uses existing provider_data JSON field)
|
|
||||||
"""
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from sqlalchemy.dialects import postgresql
|
|
||||||
|
|
||||||
# Revision identifiers
|
|
||||||
revision = '002'
|
|
||||||
down_revision = '001'
|
|
||||||
branch_labels = None
|
|
||||||
depends_on = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
|
||||||
"""Add WebAuthn support - no schema changes needed."""
|
|
||||||
# WebAuthn credentials are stored in the existing provider_data JSON field
|
|
||||||
# of the authentication_methods table. No schema changes are required.
|
|
||||||
|
|
||||||
# Create an index for faster lookups of WebAuthn credentials by user
|
|
||||||
# This is optional but recommended for performance
|
|
||||||
# op.create_index(
|
|
||||||
# 'ix_authentication_methods_webauthn_user',
|
|
||||||
# 'authentication_methods',
|
|
||||||
# ['user_id'],
|
|
||||||
# postgresql_where=(sa.text("method_type = 'webauthn'")),
|
|
||||||
# if_not_exists=True
|
|
||||||
# )
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
|
||||||
"""Remove WebAuthn support - no schema changes needed."""
|
|
||||||
# No schema changes to revert
|
|
||||||
pass
|
|
||||||
@@ -0,0 +1,933 @@
|
|||||||
|
"""Integration tests for MFA compliance enforcement."""
|
||||||
|
import pytest
|
||||||
|
import json
|
||||||
|
from datetime import datetime, timezone, timedelta
|
||||||
|
from gatehouse_app.models.user import User
|
||||||
|
from gatehouse_app.models.organization import Organization
|
||||||
|
from gatehouse_app.models.organization_member import OrganizationMember
|
||||||
|
from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy
|
||||||
|
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
|
||||||
|
from gatehouse_app.models.user_security_policy import UserSecurityPolicy
|
||||||
|
from gatehouse_app.models.session import Session
|
||||||
|
from gatehouse_app.utils.constants import MfaPolicyMode, MfaComplianceStatus, UserStatus, MfaRequirementOverride
|
||||||
|
from gatehouse_app.services.mfa_policy_service import MfaPolicyService
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestMfaComplianceLogin:
|
||||||
|
"""Integration tests for MFA compliance during login."""
|
||||||
|
|
||||||
|
def test_login_with_no_policy(self, client, db, test_user):
|
||||||
|
"""Test login with no MFA policy (should work normally)."""
|
||||||
|
login_data = {
|
||||||
|
"email": test_user.email,
|
||||||
|
"password": "TestPassword123!",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
data=json.dumps(login_data),
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.get_json()
|
||||||
|
assert data["success"] is True
|
||||||
|
assert "user" in data["data"]
|
||||||
|
assert "token" in data["data"]
|
||||||
|
# No MFA compliance info should be present when no policy exists
|
||||||
|
assert "mfa_compliance" not in data["data"]
|
||||||
|
assert "requires_mfa_enrollment" not in data["data"]
|
||||||
|
|
||||||
|
def test_login_with_optional_policy(self, client, db, test_user, test_organization):
|
||||||
|
"""Test login with optional MFA policy (should work normally)."""
|
||||||
|
# Create an optional MFA policy
|
||||||
|
policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
db.session.add(policy)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
login_data = {
|
||||||
|
"email": test_user.email,
|
||||||
|
"password": "TestPassword123!",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
data=json.dumps(login_data),
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.get_json()
|
||||||
|
assert data["success"] is True
|
||||||
|
assert "user" in data["data"]
|
||||||
|
assert "token" in data["data"]
|
||||||
|
# MFA compliance should be present but status should be not_applicable
|
||||||
|
assert "mfa_compliance" in data["data"]
|
||||||
|
assert data["data"]["mfa_compliance"]["overall_status"] == "not_applicable"
|
||||||
|
assert "requires_mfa_enrollment" not in data["data"]
|
||||||
|
|
||||||
|
def test_login_with_required_policy_in_grace_period(self, client, db, test_user, test_organization):
|
||||||
|
"""Test login with required policy within grace period (should work with warning)."""
|
||||||
|
# Create a required MFA policy
|
||||||
|
policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
db.session.add(policy)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
login_data = {
|
||||||
|
"email": test_user.email,
|
||||||
|
"password": "TestPassword123!",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
data=json.dumps(login_data),
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.get_json()
|
||||||
|
assert data["success"] is True
|
||||||
|
assert "user" in data["data"]
|
||||||
|
assert "token" in data["data"]
|
||||||
|
# MFA compliance should be present with in_grace status
|
||||||
|
assert "mfa_compliance" in data["data"]
|
||||||
|
assert data["data"]["mfa_compliance"]["overall_status"] == "in_grace"
|
||||||
|
assert "requires_mfa_enrollment" not in data["data"]
|
||||||
|
assert "totp" in data["data"]["mfa_compliance"]["missing_methods"]
|
||||||
|
|
||||||
|
def test_login_with_required_policy_after_deadline(self, client, db, test_user, test_organization):
|
||||||
|
"""Test login with required policy after deadline (should get compliance-only session)."""
|
||||||
|
# Create a required MFA policy with past deadline
|
||||||
|
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
|
||||||
|
policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
db.session.add(policy)
|
||||||
|
|
||||||
|
# Create compliance record past due
|
||||||
|
compliance = MfaPolicyCompliance(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
status=MfaComplianceStatus.PAST_DUE,
|
||||||
|
policy_version=1,
|
||||||
|
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
|
||||||
|
deadline_at=past_deadline,
|
||||||
|
)
|
||||||
|
db.session.add(compliance)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
login_data = {
|
||||||
|
"email": test_user.email,
|
||||||
|
"password": "TestPassword123!",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
data=json.dumps(login_data),
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.get_json()
|
||||||
|
assert data["success"] is True
|
||||||
|
assert "user" in data["data"]
|
||||||
|
assert "token" in data["data"]
|
||||||
|
# Should have compliance-only session
|
||||||
|
assert data["data"]["requires_mfa_enrollment"] is True
|
||||||
|
assert "mfa_compliance" in data["data"]
|
||||||
|
assert data["data"]["mfa_compliance"]["overall_status"] in ["past_due", "suspended"]
|
||||||
|
|
||||||
|
def test_login_with_suspended_user(self, client, db, test_user, test_organization):
|
||||||
|
"""Test login with compliance suspended user (should get compliance-only session)."""
|
||||||
|
# Set user status to compliance suspended
|
||||||
|
test_user.status = UserStatus.COMPLIANCE_SUSPENDED
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
login_data = {
|
||||||
|
"email": test_user.email,
|
||||||
|
"password": "TestPassword123!",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
data=json.dumps(login_data),
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.get_json()
|
||||||
|
assert data["success"] is True
|
||||||
|
assert "user" in data["data"]
|
||||||
|
assert "token" in data["data"]
|
||||||
|
# Should have compliance-only session
|
||||||
|
assert data["data"]["requires_mfa_enrollment"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestMfaComplianceAccess:
|
||||||
|
"""Integration tests for MFA compliance access control."""
|
||||||
|
|
||||||
|
def test_compliance_only_session_denied_full_access(self, client, db, test_user, test_organization):
|
||||||
|
"""Test that compliance-only session cannot access full access endpoints."""
|
||||||
|
# Create a required MFA policy with past deadline
|
||||||
|
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
|
||||||
|
policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
db.session.add(policy)
|
||||||
|
|
||||||
|
# Create compliance record past due
|
||||||
|
compliance = MfaPolicyCompliance(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
status=MfaComplianceStatus.PAST_DUE,
|
||||||
|
policy_version=1,
|
||||||
|
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
|
||||||
|
deadline_at=past_deadline,
|
||||||
|
)
|
||||||
|
db.session.add(compliance)
|
||||||
|
|
||||||
|
# Create a compliance-only session
|
||||||
|
session = Session(
|
||||||
|
user_id=test_user.id,
|
||||||
|
token="compliance_only_token",
|
||||||
|
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
|
||||||
|
is_compliance_only=True,
|
||||||
|
)
|
||||||
|
db.session.add(session)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Try to access a full-access endpoint (get_my_organizations)
|
||||||
|
response = client.get(
|
||||||
|
"/api/v1/users/me/organizations",
|
||||||
|
headers={"Authorization": "Bearer compliance_only_token"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 403
|
||||||
|
data = response.get_json()
|
||||||
|
assert data["success"] is False
|
||||||
|
assert data["error_type"] == "MFA_COMPLIANCE_REQUIRED"
|
||||||
|
|
||||||
|
def test_compliance_only_session_can_access_mfa_enrollment(self, client, db, test_user, test_organization):
|
||||||
|
"""Test that compliance-only session can access MFA enrollment endpoints."""
|
||||||
|
# Create a required MFA policy with past deadline
|
||||||
|
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
|
||||||
|
policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
db.session.add(policy)
|
||||||
|
|
||||||
|
# Create compliance record past due
|
||||||
|
compliance = MfaPolicyCompliance(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
status=MfaComplianceStatus.PAST_DUE,
|
||||||
|
policy_version=1,
|
||||||
|
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
|
||||||
|
deadline_at=past_deadline,
|
||||||
|
)
|
||||||
|
db.session.add(compliance)
|
||||||
|
|
||||||
|
# Create a compliance-only session
|
||||||
|
session = Session(
|
||||||
|
user_id=test_user.id,
|
||||||
|
token="compliance_only_token",
|
||||||
|
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
|
||||||
|
is_compliance_only=True,
|
||||||
|
)
|
||||||
|
db.session.add(session)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Try to access MFA enrollment endpoint (should work)
|
||||||
|
response = client.get(
|
||||||
|
"/api/v1/auth/totp/status",
|
||||||
|
headers={"Authorization": "Bearer compliance_only_token"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.get_json()
|
||||||
|
assert data["success"] is True
|
||||||
|
|
||||||
|
def test_compliance_only_session_can_access_logout(self, client, db, test_user, test_organization):
|
||||||
|
"""Test that compliance-only session can access logout endpoint."""
|
||||||
|
# Create a required MFA policy with past deadline
|
||||||
|
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
|
||||||
|
policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
db.session.add(policy)
|
||||||
|
|
||||||
|
# Create compliance record past due
|
||||||
|
compliance = MfaPolicyCompliance(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
status=MfaComplianceStatus.PAST_DUE,
|
||||||
|
policy_version=1,
|
||||||
|
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
|
||||||
|
deadline_at=past_deadline,
|
||||||
|
)
|
||||||
|
db.session.add(compliance)
|
||||||
|
|
||||||
|
# Create a compliance-only session
|
||||||
|
session = Session(
|
||||||
|
user_id=test_user.id,
|
||||||
|
token="compliance_only_token",
|
||||||
|
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
|
||||||
|
is_compliance_only=True,
|
||||||
|
)
|
||||||
|
db.session.add(session)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Try to access logout endpoint (should work)
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/auth/logout",
|
||||||
|
headers={"Authorization": "Bearer compliance_only_token"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.get_json()
|
||||||
|
assert data["success"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestMfaComplianceWebAuthn:
|
||||||
|
"""Integration tests for MFA compliance with WebAuthn login."""
|
||||||
|
|
||||||
|
def test_webauthn_login_with_required_policy_in_grace_period(self, client, db, test_user, test_organization):
|
||||||
|
"""Test WebAuthn login with required policy within grace period."""
|
||||||
|
# Create a required MFA policy
|
||||||
|
policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
db.session.add(policy)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Note: Full WebAuthn login test would require WebAuthn setup
|
||||||
|
# This test verifies the compliance response structure
|
||||||
|
login_data = {
|
||||||
|
"email": test_user.email,
|
||||||
|
"password": "TestPassword123!",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
data=json.dumps(login_data),
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.get_json()
|
||||||
|
assert data["success"] is True
|
||||||
|
assert "mfa_compliance" in data["data"]
|
||||||
|
assert data["data"]["mfa_compliance"]["overall_status"] == "in_grace"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestMfaComplianceOIDC:
|
||||||
|
"""Integration tests for MFA compliance with OIDC authorization."""
|
||||||
|
|
||||||
|
def test_oidc_authorize_with_compliance_required(self, client, db, test_user, test_organization, app):
|
||||||
|
"""Test OIDC authorize with compliance required (should show error)."""
|
||||||
|
# Create a required MFA policy with past deadline
|
||||||
|
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
|
||||||
|
policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
db.session.add(policy)
|
||||||
|
|
||||||
|
# Create compliance record past due
|
||||||
|
compliance = MfaPolicyCompliance(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
status=MfaComplianceStatus.PAST_DUE,
|
||||||
|
policy_version=1,
|
||||||
|
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
|
||||||
|
deadline_at=past_deadline,
|
||||||
|
)
|
||||||
|
db.session.add(compliance)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Try OIDC authorize with credentials
|
||||||
|
response = client.post(
|
||||||
|
"/oidc/authorize",
|
||||||
|
data={
|
||||||
|
"client_id": "test_client",
|
||||||
|
"redirect_uri": "http://localhost:8080/callback",
|
||||||
|
"response_type": "code",
|
||||||
|
"scope": "openid profile email",
|
||||||
|
"state": "test_state",
|
||||||
|
"email": test_user.email,
|
||||||
|
"password": "TestPassword123!",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return login page with error
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert b"Your account requires multi factor enrollment before using single sign on" in response.data
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Phase 4: Edge Case Tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestMfaComplianceMultiOrg:
|
||||||
|
"""Integration tests for multi-organization MFA compliance edge cases."""
|
||||||
|
|
||||||
|
def test_user_with_multiple_orgs_different_policies(self, client, db, test_user):
|
||||||
|
"""Test user belonging to multiple orgs with different MFA policies."""
|
||||||
|
# Create two organizations
|
||||||
|
org1 = Organization(
|
||||||
|
name="Org1",
|
||||||
|
slug="org1-test-multi",
|
||||||
|
)
|
||||||
|
org2 = Organization(
|
||||||
|
name="Org2",
|
||||||
|
slug="org2-test-multi",
|
||||||
|
)
|
||||||
|
db.session.add_all([org1, org2])
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Add user to both orgs
|
||||||
|
membership1 = OrganizationMember(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=org1.id,
|
||||||
|
role="member",
|
||||||
|
)
|
||||||
|
membership2 = OrganizationMember(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=org2.id,
|
||||||
|
role="member",
|
||||||
|
)
|
||||||
|
db.session.add_all([membership1, membership2])
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Create different policies for each org
|
||||||
|
# Org1: OPTIONAL (no requirement)
|
||||||
|
policy1 = OrganizationSecurityPolicy(
|
||||||
|
organization_id=org1.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
# Org2: REQUIRE_TOTP (strictest)
|
||||||
|
policy2 = OrganizationSecurityPolicy(
|
||||||
|
organization_id=org2.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
db.session.add_all([policy1, policy2])
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Evaluate user MFA state
|
||||||
|
compliance_summary = MfaPolicyService.evaluate_user_mfa_state(test_user)
|
||||||
|
|
||||||
|
# Overall status should reflect the strictest policy (REQUIRE_TOTP from org2)
|
||||||
|
assert compliance_summary.overall_status == MfaComplianceStatus.IN_GRACE.value
|
||||||
|
assert "totp" in compliance_summary.missing_methods
|
||||||
|
|
||||||
|
# Verify per-org breakdown
|
||||||
|
assert len(compliance_summary.orgs) == 2
|
||||||
|
org1_status = next((o for o in compliance_summary.orgs if o.organization_id == org1.id), None)
|
||||||
|
org2_status = next((o for o in compliance_summary.orgs if o.organization_id == org2.id), None)
|
||||||
|
|
||||||
|
assert org1_status is not None
|
||||||
|
assert org2_status is not None
|
||||||
|
assert org1_status.status == MfaComplianceStatus.NOT_APPLICABLE.value
|
||||||
|
assert org2_status.status == MfaComplianceStatus.IN_GRACE.value
|
||||||
|
|
||||||
|
def test_user_with_multiple_orgs_all_suspended(self, client, db, test_user):
|
||||||
|
"""Test user with multiple orgs where all require MFA and are past due."""
|
||||||
|
# Create two organizations
|
||||||
|
org1 = Organization(
|
||||||
|
name="Org1",
|
||||||
|
slug="org1-test-suspended",
|
||||||
|
)
|
||||||
|
org2 = Organization(
|
||||||
|
name="Org2",
|
||||||
|
slug="org2-test-suspended",
|
||||||
|
)
|
||||||
|
db.session.add_all([org1, org2])
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Add user to both orgs
|
||||||
|
membership1 = OrganizationMember(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=org1.id,
|
||||||
|
role="member",
|
||||||
|
)
|
||||||
|
membership2 = OrganizationMember(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=org2.id,
|
||||||
|
role="member",
|
||||||
|
)
|
||||||
|
db.session.add_all([membership1, membership2])
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Create required policies
|
||||||
|
policy1 = OrganizationSecurityPolicy(
|
||||||
|
organization_id=org1.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
policy2 = OrganizationSecurityPolicy(
|
||||||
|
organization_id=org2.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_WEBAUTHN,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
db.session.add_all([policy1, policy2])
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Create past-due compliance records for both
|
||||||
|
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
|
||||||
|
compliance1 = MfaPolicyCompliance(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=org1.id,
|
||||||
|
status=MfaComplianceStatus.SUSPENDED,
|
||||||
|
policy_version=1,
|
||||||
|
applied_at=datetime.now(timezone.utc) - timedelta(days=30),
|
||||||
|
deadline_at=past_deadline,
|
||||||
|
suspended_at=past_deadline,
|
||||||
|
)
|
||||||
|
compliance2 = MfaPolicyCompliance(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=org2.id,
|
||||||
|
status=MfaComplianceStatus.SUSPENDED,
|
||||||
|
policy_version=1,
|
||||||
|
applied_at=datetime.now(timezone.utc) - timedelta(days=30),
|
||||||
|
deadline_at=past_deadline,
|
||||||
|
suspended_at=past_deadline,
|
||||||
|
)
|
||||||
|
db.session.add_all([compliance1, compliance2])
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Evaluate user MFA state
|
||||||
|
compliance_summary = MfaPolicyService.evaluate_user_mfa_state(test_user)
|
||||||
|
|
||||||
|
# Overall status should be SUSPENDED
|
||||||
|
assert compliance_summary.overall_status == MfaComplianceStatus.SUSPENDED.value
|
||||||
|
|
||||||
|
def test_strictest_mode_selection(self):
|
||||||
|
"""Test that get_strictest_mode returns the most restrictive policy."""
|
||||||
|
modes = [
|
||||||
|
MfaPolicyMode.DISABLED.value,
|
||||||
|
MfaPolicyMode.OPTIONAL.value,
|
||||||
|
MfaPolicyMode.REQUIRE_TOTP.value,
|
||||||
|
]
|
||||||
|
result = MfaPolicyService.get_strictest_mode(modes)
|
||||||
|
assert result == MfaPolicyMode.REQUIRE_TOTP.value
|
||||||
|
|
||||||
|
# Test with REQUIRE_TOTP_OR_WEBAUTHN (strictest)
|
||||||
|
modes_strictest = [
|
||||||
|
MfaPolicyMode.REQUIRE_TOTP.value,
|
||||||
|
MfaPolicyMode.REQUIRE_WEBAUTHN.value,
|
||||||
|
MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value,
|
||||||
|
]
|
||||||
|
result = MfaPolicyService.get_strictest_mode(modes_strictest)
|
||||||
|
assert result == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestMfaComplianceUserOverrides:
|
||||||
|
"""Integration tests for user override edge cases."""
|
||||||
|
|
||||||
|
def test_user_override_inherit_mode(self, client, db, test_user, test_organization):
|
||||||
|
"""Test INHERIT mode - org policy applies as is."""
|
||||||
|
# Create a required policy
|
||||||
|
policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
db.session.add(policy)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Create INHERIT override (default behavior)
|
||||||
|
override = UserSecurityPolicy(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_override_mode=MfaRequirementOverride.INHERIT,
|
||||||
|
)
|
||||||
|
db.session.add(override)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Get effective policy
|
||||||
|
effective = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
|
||||||
|
|
||||||
|
# Should inherit org policy
|
||||||
|
assert effective.effective_mode == MfaPolicyMode.REQUIRE_TOTP.value
|
||||||
|
assert effective.requires_totp is True
|
||||||
|
assert effective.is_exempt is False
|
||||||
|
|
||||||
|
def test_user_override_required_mode(self, client, db, test_user, test_organization):
|
||||||
|
"""Test REQUIRED mode - user always required to have MFA."""
|
||||||
|
# Create an optional policy
|
||||||
|
policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
db.session.add(policy)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Create REQUIRED override
|
||||||
|
override = UserSecurityPolicy(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_override_mode=MfaRequirementOverride.REQUIRED,
|
||||||
|
)
|
||||||
|
db.session.add(override)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Get effective policy
|
||||||
|
effective = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
|
||||||
|
|
||||||
|
# Should be upgraded to REQUIRE_TOTP_OR_WEBAUTHN
|
||||||
|
assert effective.effective_mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value
|
||||||
|
assert effective.requires_totp is True
|
||||||
|
assert effective.requires_webauthn is True
|
||||||
|
assert effective.is_exempt is False
|
||||||
|
|
||||||
|
def test_user_override_exempt_mode(self, client, db, test_user, test_organization):
|
||||||
|
"""Test EXEMPT mode - org policy does not apply."""
|
||||||
|
# Create a required policy
|
||||||
|
policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
db.session.add(policy)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Create EXEMPT override
|
||||||
|
override = UserSecurityPolicy(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_override_mode=MfaRequirementOverride.EXEMPT,
|
||||||
|
)
|
||||||
|
db.session.add(override)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Get effective policy
|
||||||
|
effective = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
|
||||||
|
|
||||||
|
# Should be exempt from policy
|
||||||
|
assert effective.is_exempt is True
|
||||||
|
assert effective.effective_mode == MfaPolicyMode.DISABLED.value
|
||||||
|
assert effective.requires_totp is False
|
||||||
|
assert effective.requires_webauthn is False
|
||||||
|
|
||||||
|
def test_get_override_summary(self, client, db, test_user, test_organization):
|
||||||
|
"""Test getting override summary for a user."""
|
||||||
|
# No override exists
|
||||||
|
summary = MfaPolicyService.get_override_summary(test_user.id, test_organization.id)
|
||||||
|
|
||||||
|
assert summary["has_override"] is False
|
||||||
|
assert summary["mode"] == "inherit"
|
||||||
|
|
||||||
|
# Create an override
|
||||||
|
override = UserSecurityPolicy(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_override_mode=MfaRequirementOverride.EXEMPT,
|
||||||
|
)
|
||||||
|
db.session.add(override)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Get summary again
|
||||||
|
summary = MfaPolicyService.get_override_summary(test_user.id, test_organization.id)
|
||||||
|
|
||||||
|
assert summary["has_override"] is True
|
||||||
|
assert summary["mode"] == "exempt"
|
||||||
|
assert summary["is_exempt"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestMfaCompliancePolicyChanges:
|
||||||
|
"""Integration tests for policy changes affecting existing users."""
|
||||||
|
|
||||||
|
def test_policy_change_triggers_compliance_reevaluation(self, client, db, test_user, test_organization):
|
||||||
|
"""Test that policy change triggers compliance reevaluation."""
|
||||||
|
# Create initial optional policy
|
||||||
|
policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
db.session.add(policy)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Create compliance record (should be NOT_APPLICABLE)
|
||||||
|
compliance = MfaPolicyCompliance(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
status=MfaComplianceStatus.NOT_APPLICABLE,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
db.session.add(compliance)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Update policy to REQUIRE_TOTP
|
||||||
|
MfaPolicyService.create_org_policy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
updated_by_user_id=test_user.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reevaluate all compliance
|
||||||
|
updated_count = MfaPolicyService.reevaluate_all_org_compliance(test_organization.id)
|
||||||
|
|
||||||
|
# Should have updated at least one record
|
||||||
|
assert updated_count >= 1
|
||||||
|
|
||||||
|
# Check compliance status was updated
|
||||||
|
updated_compliance = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id)
|
||||||
|
assert updated_compliance.status == MfaComplianceStatus.IN_GRACE.value
|
||||||
|
assert updated_compliance.deadline_at is not None
|
||||||
|
|
||||||
|
def test_policy_relaxation_clears_requirements(self, client, db, test_user, test_organization):
|
||||||
|
"""Test that relaxing policy clears compliance requirements."""
|
||||||
|
# Create required policy
|
||||||
|
policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
db.session.add(policy)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Create IN_GRACE compliance record
|
||||||
|
compliance = MfaPolicyCompliance(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
status=MfaComplianceStatus.IN_GRACE,
|
||||||
|
policy_version=1,
|
||||||
|
applied_at=datetime.now(timezone.utc),
|
||||||
|
deadline_at=datetime.now(timezone.utc) + timedelta(days=14),
|
||||||
|
)
|
||||||
|
db.session.add(compliance)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Update policy to OPTIONAL
|
||||||
|
MfaPolicyService.create_org_policy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
updated_by_user_id=test_user.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reevaluate compliance
|
||||||
|
MfaPolicyService.reevaluate_all_org_compliance(test_organization.id)
|
||||||
|
|
||||||
|
# Check compliance status was updated to NOT_APPLICABLE
|
||||||
|
updated_compliance = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id)
|
||||||
|
assert updated_compliance.status == MfaComplianceStatus.NOT_APPLICABLE.value
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestMfaComplianceScheduledJob:
|
||||||
|
"""Integration tests for the MFA compliance scheduled job."""
|
||||||
|
|
||||||
|
def test_transition_to_suspended(self, client, db, test_user, test_organization):
|
||||||
|
"""Test that past-due users are transitioned to suspended."""
|
||||||
|
# Create required policy
|
||||||
|
policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
db.session.add(policy)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Create past-due compliance record
|
||||||
|
past_deadline = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||||
|
compliance = MfaPolicyCompliance(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
status=MfaComplianceStatus.PAST_DUE,
|
||||||
|
policy_version=1,
|
||||||
|
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
|
||||||
|
deadline_at=past_deadline,
|
||||||
|
)
|
||||||
|
db.session.add(compliance)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Run the job
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
suspended_count = MfaPolicyService.transition_to_suspended_if_past_due(now)
|
||||||
|
|
||||||
|
# Should have suspended the user
|
||||||
|
assert suspended_count >= 1
|
||||||
|
|
||||||
|
# Check compliance status
|
||||||
|
updated_compliance = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id)
|
||||||
|
assert updated_compliance.status == MfaComplianceStatus.SUSPENDED.value
|
||||||
|
assert updated_compliance.suspended_at is not None
|
||||||
|
|
||||||
|
# Check user status
|
||||||
|
db.refresh(test_user)
|
||||||
|
assert test_user.status == UserStatus.COMPLIANCE_SUSPENDED
|
||||||
|
|
||||||
|
def test_check_and_restore_user_status(self, client, db, test_user, test_organization):
|
||||||
|
"""Test that suspended users are restored when they become compliant."""
|
||||||
|
# Create required policy
|
||||||
|
policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
db.session.add(policy)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# User is suspended
|
||||||
|
test_user.status = UserStatus.COMPLIANCE_SUSPENDED
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Create EXEMPT override to clear requirement
|
||||||
|
override = UserSecurityPolicy(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_override_mode=MfaRequirementOverride.EXEMPT,
|
||||||
|
)
|
||||||
|
db.session.add(override)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Check and restore status
|
||||||
|
restored = MfaPolicyService.check_and_restore_user_status(test_user.id)
|
||||||
|
|
||||||
|
# Should have restored user
|
||||||
|
assert restored is True
|
||||||
|
db.refresh(test_user)
|
||||||
|
assert test_user.status == UserStatus.ACTIVE
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestMfaComplianceMultiOrgAggregate:
|
||||||
|
"""Integration tests for multi-org aggregate state calculation."""
|
||||||
|
|
||||||
|
def test_get_multi_org_aggregate_state(self, client, db, test_user):
|
||||||
|
"""Test aggregate state calculation for multi-org user."""
|
||||||
|
# Create two organizations
|
||||||
|
org1 = Organization(
|
||||||
|
name="AggOrg1",
|
||||||
|
slug="agg-org1-test",
|
||||||
|
)
|
||||||
|
org2 = Organization(
|
||||||
|
name="AggOrg2",
|
||||||
|
slug="agg-org2-test",
|
||||||
|
)
|
||||||
|
db.session.add_all([org1, org2])
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Add user to both
|
||||||
|
membership1 = OrganizationMember(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=org1.id,
|
||||||
|
role="member",
|
||||||
|
)
|
||||||
|
membership2 = OrganizationMember(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=org2.id,
|
||||||
|
role="member",
|
||||||
|
)
|
||||||
|
db.session.add_all([membership1, membership2])
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Create policies
|
||||||
|
policy1 = OrganizationSecurityPolicy(
|
||||||
|
organization_id=org1.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
policy2 = OrganizationSecurityPolicy(
|
||||||
|
organization_id=org2.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_WEBAUTHN,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
db.session.add_all([policy1, policy2])
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Get aggregate state
|
||||||
|
aggregate = MfaPolicyService.get_multi_org_aggregate_state(test_user)
|
||||||
|
|
||||||
|
# Verify structure
|
||||||
|
assert "overall_status" in aggregate
|
||||||
|
assert "strictest_mode" in aggregate
|
||||||
|
assert "missing_methods" in aggregate
|
||||||
|
assert "requiring_org_count" in aggregate
|
||||||
|
assert "requiring_orgs" in aggregate
|
||||||
|
assert "per_org_details" in aggregate
|
||||||
|
|
||||||
|
# Strictest mode should be REQUIRE_TOTP_OR_WEBAUTHN
|
||||||
|
assert aggregate["strictest_mode"] == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value
|
||||||
|
|
||||||
|
# Both orgs should require MFA
|
||||||
|
assert aggregate["requiring_org_count"] == 2
|
||||||
|
assert len(aggregate["requiring_orgs"]) == 2
|
||||||
|
assert len(aggregate["per_org_details"]) == 2
|
||||||
@@ -0,0 +1,295 @@
|
|||||||
|
"""Unit tests for MFA policy models."""
|
||||||
|
import pytest
|
||||||
|
from datetime import datetime, timezone, timedelta
|
||||||
|
from gatehouse_app.models import (
|
||||||
|
User,
|
||||||
|
Organization,
|
||||||
|
OrganizationMember,
|
||||||
|
OrganizationSecurityPolicy,
|
||||||
|
UserSecurityPolicy,
|
||||||
|
MfaPolicyCompliance,
|
||||||
|
Session,
|
||||||
|
)
|
||||||
|
from gatehouse_app.utils.constants import (
|
||||||
|
UserStatus,
|
||||||
|
MfaPolicyMode,
|
||||||
|
MfaComplianceStatus,
|
||||||
|
MfaRequirementOverride,
|
||||||
|
SessionStatus,
|
||||||
|
OrganizationRole,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestOrganizationSecurityPolicyModel:
|
||||||
|
"""Tests for OrganizationSecurityPolicy model."""
|
||||||
|
|
||||||
|
def test_create_org_security_policy(self, db, test_organization):
|
||||||
|
"""Test creating an organization security policy."""
|
||||||
|
policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
)
|
||||||
|
policy.save()
|
||||||
|
|
||||||
|
assert policy.id is not None
|
||||||
|
assert policy.organization_id == test_organization.id
|
||||||
|
assert policy.mfa_policy_mode == MfaPolicyMode.OPTIONAL
|
||||||
|
assert policy.mfa_grace_period_days == 14
|
||||||
|
assert policy.notify_days_before == 7
|
||||||
|
assert policy.policy_version == 1
|
||||||
|
assert policy.created_at is not None
|
||||||
|
|
||||||
|
def test_org_security_policy_to_dict(self, db, test_organization):
|
||||||
|
"""Test organization security policy to_dict method."""
|
||||||
|
policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
||||||
|
mfa_grace_period_days=7,
|
||||||
|
notify_days_before=3,
|
||||||
|
)
|
||||||
|
policy.save()
|
||||||
|
|
||||||
|
policy_dict = policy.to_dict()
|
||||||
|
|
||||||
|
assert "id" in policy_dict
|
||||||
|
assert "organization_id" in policy_dict
|
||||||
|
assert policy_dict["organization_id"] == test_organization.id
|
||||||
|
assert "mfa_policy_mode" in policy_dict
|
||||||
|
assert "mfa_grace_period_days" in policy_dict
|
||||||
|
|
||||||
|
def test_org_security_policy_relationships(self, db, test_organization):
|
||||||
|
"""Test organization security policy relationships."""
|
||||||
|
policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
|
||||||
|
)
|
||||||
|
policy.save()
|
||||||
|
|
||||||
|
# Test relationship
|
||||||
|
assert policy.organization is not None
|
||||||
|
assert policy.organization.id == test_organization.id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestUserSecurityPolicyModel:
|
||||||
|
"""Tests for UserSecurityPolicy model."""
|
||||||
|
|
||||||
|
def test_create_user_security_policy(self, db, test_user, test_organization):
|
||||||
|
"""Test creating a user security policy."""
|
||||||
|
policy = UserSecurityPolicy(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_override_mode=MfaRequirementOverride.INHERIT,
|
||||||
|
)
|
||||||
|
policy.save()
|
||||||
|
|
||||||
|
assert policy.id is not None
|
||||||
|
assert policy.user_id == test_user.id
|
||||||
|
assert policy.organization_id == test_organization.id
|
||||||
|
assert policy.mfa_override_mode == MfaRequirementOverride.INHERIT
|
||||||
|
assert policy.force_totp is False
|
||||||
|
assert policy.force_webauthn is False
|
||||||
|
|
||||||
|
def test_user_security_policy_with_overrides(self, db, test_user, test_organization):
|
||||||
|
"""Test user security policy with override settings."""
|
||||||
|
policy = UserSecurityPolicy(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_override_mode=MfaRequirementOverride.REQUIRED,
|
||||||
|
force_totp=True,
|
||||||
|
force_webauthn=False,
|
||||||
|
)
|
||||||
|
policy.save()
|
||||||
|
|
||||||
|
assert policy.mfa_override_mode == MfaRequirementOverride.REQUIRED
|
||||||
|
assert policy.force_totp is True
|
||||||
|
assert policy.force_webauthn is False
|
||||||
|
|
||||||
|
def test_user_security_policy_exempt(self, db, test_user, test_organization):
|
||||||
|
"""Test user security policy with exempt override."""
|
||||||
|
policy = UserSecurityPolicy(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_override_mode=MfaRequirementOverride.EXEMPT,
|
||||||
|
)
|
||||||
|
policy.save()
|
||||||
|
|
||||||
|
assert policy.mfa_override_mode == MfaRequirementOverride.EXEMPT
|
||||||
|
|
||||||
|
def test_user_security_policy_relationships(self, db, test_user, test_organization):
|
||||||
|
"""Test user security policy relationships."""
|
||||||
|
policy = UserSecurityPolicy(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_override_mode=MfaRequirementOverride.INHERIT,
|
||||||
|
)
|
||||||
|
policy.save()
|
||||||
|
|
||||||
|
# Test relationships
|
||||||
|
assert policy.user is not None
|
||||||
|
assert policy.user.id == test_user.id
|
||||||
|
assert policy.organization is not None
|
||||||
|
assert policy.organization.id == test_organization.id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestMfaPolicyComplianceModel:
|
||||||
|
"""Tests for MfaPolicyCompliance model."""
|
||||||
|
|
||||||
|
def test_create_mfa_policy_compliance(self, db, test_user, test_organization):
|
||||||
|
"""Test creating an MFA policy compliance record."""
|
||||||
|
compliance = MfaPolicyCompliance(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
status=MfaComplianceStatus.NOT_APPLICABLE,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
compliance.save()
|
||||||
|
|
||||||
|
assert compliance.id is not None
|
||||||
|
assert compliance.user_id == test_user.id
|
||||||
|
assert compliance.organization_id == test_organization.id
|
||||||
|
assert compliance.status == MfaComplianceStatus.NOT_APPLICABLE
|
||||||
|
assert compliance.policy_version == 1
|
||||||
|
assert compliance.notification_count == 0
|
||||||
|
|
||||||
|
def test_mfa_policy_compliance_in_grace(self, db, test_user, test_organization):
|
||||||
|
"""Test MFA compliance record in grace period."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
compliance = MfaPolicyCompliance(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
status=MfaComplianceStatus.IN_GRACE,
|
||||||
|
policy_version=1,
|
||||||
|
applied_at=now,
|
||||||
|
deadline_at=now + timedelta(days=14),
|
||||||
|
)
|
||||||
|
compliance.save()
|
||||||
|
|
||||||
|
assert compliance.status == MfaComplianceStatus.IN_GRACE
|
||||||
|
assert compliance.applied_at is not None
|
||||||
|
assert compliance.deadline_at is not None
|
||||||
|
assert compliance.deadline_at > now
|
||||||
|
|
||||||
|
def test_mfa_policy_compliance_compliant(self, db, test_user, test_organization):
|
||||||
|
"""Test MFA compliance record when compliant."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
compliance = MfaPolicyCompliance(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
status=MfaComplianceStatus.COMPLIANT,
|
||||||
|
policy_version=1,
|
||||||
|
applied_at=now - timedelta(days=30),
|
||||||
|
deadline_at=now - timedelta(days=16),
|
||||||
|
compliant_at=now - timedelta(days=16),
|
||||||
|
)
|
||||||
|
compliance.save()
|
||||||
|
|
||||||
|
assert compliance.status == MfaComplianceStatus.COMPLIANT
|
||||||
|
assert compliance.compliant_at is not None
|
||||||
|
|
||||||
|
def test_mfa_policy_compliance_suspended(self, db, test_user, test_organization):
|
||||||
|
"""Test MFA compliance record when suspended."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
compliance = MfaPolicyCompliance(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
status=MfaComplianceStatus.SUSPENDED,
|
||||||
|
policy_version=1,
|
||||||
|
applied_at=now - timedelta(days=30),
|
||||||
|
deadline_at=now - timedelta(days=16),
|
||||||
|
suspended_at=now - timedelta(days=16),
|
||||||
|
)
|
||||||
|
compliance.save()
|
||||||
|
|
||||||
|
assert compliance.status == MfaComplianceStatus.SUSPENDED
|
||||||
|
assert compliance.suspended_at is not None
|
||||||
|
|
||||||
|
def test_mfa_policy_compliance_relationships(self, db, test_user, test_organization):
|
||||||
|
"""Test MFA compliance relationships."""
|
||||||
|
compliance = MfaPolicyCompliance(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
status=MfaComplianceStatus.NOT_APPLICABLE,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
compliance.save()
|
||||||
|
|
||||||
|
# Test relationships
|
||||||
|
assert compliance.user is not None
|
||||||
|
assert compliance.user.id == test_user.id
|
||||||
|
assert compliance.organization is not None
|
||||||
|
assert compliance.organization.id == test_organization.id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestSessionModelComplianceFlag:
|
||||||
|
"""Tests for Session model compliance flag."""
|
||||||
|
|
||||||
|
def test_session_default_not_compliance_only(self, db, test_user):
|
||||||
|
"""Test that sessions are not compliance only by default."""
|
||||||
|
session = Session(
|
||||||
|
user_id=test_user.id,
|
||||||
|
token="test-token-123",
|
||||||
|
status=SessionStatus.ACTIVE,
|
||||||
|
expires_at=datetime.now(timezone.utc) + timedelta(hours=8),
|
||||||
|
last_activity_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
session.save()
|
||||||
|
|
||||||
|
assert session.is_compliance_only is False
|
||||||
|
|
||||||
|
def test_session_compliance_only(self, db, test_user):
|
||||||
|
"""Test creating a compliance-only session."""
|
||||||
|
session = Session(
|
||||||
|
user_id=test_user.id,
|
||||||
|
token="compliance-token-123",
|
||||||
|
status=SessionStatus.ACTIVE,
|
||||||
|
expires_at=datetime.now(timezone.utc) + timedelta(hours=8),
|
||||||
|
last_activity_at=datetime.now(timezone.utc),
|
||||||
|
is_compliance_only=True,
|
||||||
|
)
|
||||||
|
session.save()
|
||||||
|
|
||||||
|
assert session.is_compliance_only is True
|
||||||
|
|
||||||
|
def test_session_to_dict_excludes_token(self, db, test_user):
|
||||||
|
"""Test that session to_dict excludes the token."""
|
||||||
|
session = Session(
|
||||||
|
user_id=test_user.id,
|
||||||
|
token="test-token-456",
|
||||||
|
status=SessionStatus.ACTIVE,
|
||||||
|
expires_at=datetime.now(timezone.utc) + timedelta(hours=8),
|
||||||
|
last_activity_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
session.save()
|
||||||
|
|
||||||
|
session_dict = session.to_dict()
|
||||||
|
|
||||||
|
assert "id" in session_dict
|
||||||
|
assert "user_id" in session_dict
|
||||||
|
assert "is_compliance_only" in session_dict
|
||||||
|
assert session_dict["is_compliance_only"] is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestUserStatusComplianceSuspended:
|
||||||
|
"""Tests for UserStatus.COMPLIANCE_SUSPENDED."""
|
||||||
|
|
||||||
|
def test_compliance_suspended_status_exists(self):
|
||||||
|
"""Test that COMPLIANCE_SUSPENDED status exists."""
|
||||||
|
assert UserStatus.COMPLIANCE_SUSPENDED.value == "compliance_suspended"
|
||||||
|
|
||||||
|
def test_create_compliance_suspended_user(self, db):
|
||||||
|
"""Test creating a compliance suspended user."""
|
||||||
|
user = User(
|
||||||
|
email="suspended@example.com",
|
||||||
|
full_name="Suspended User",
|
||||||
|
status=UserStatus.COMPLIANCE_SUSPENDED,
|
||||||
|
)
|
||||||
|
user.save()
|
||||||
|
|
||||||
|
assert user.status == UserStatus.COMPLIANCE_SUSPENDED
|
||||||
@@ -0,0 +1,476 @@
|
|||||||
|
"""Unit tests for MfaPolicyService."""
|
||||||
|
import pytest
|
||||||
|
from datetime import datetime, timezone, timedelta
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
from gatehouse_app.models import (
|
||||||
|
User,
|
||||||
|
Organization,
|
||||||
|
OrganizationMember,
|
||||||
|
OrganizationSecurityPolicy,
|
||||||
|
UserSecurityPolicy,
|
||||||
|
MfaPolicyCompliance,
|
||||||
|
Session,
|
||||||
|
)
|
||||||
|
from gatehouse_app.services.mfa_policy_service import (
|
||||||
|
MfaPolicyService,
|
||||||
|
OrgPolicyDto,
|
||||||
|
EffectiveUserPolicyDto,
|
||||||
|
AggregateMfaStateDto,
|
||||||
|
LoginPolicyResult,
|
||||||
|
)
|
||||||
|
from gatehouse_app.utils.constants import (
|
||||||
|
UserStatus,
|
||||||
|
MfaPolicyMode,
|
||||||
|
MfaComplianceStatus,
|
||||||
|
MfaRequirementOverride,
|
||||||
|
SessionStatus,
|
||||||
|
OrganizationRole,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestMfaPolicyService:
|
||||||
|
"""Tests for MfaPolicyService."""
|
||||||
|
|
||||||
|
def test_get_org_policy_not_found(self, db, test_organization):
|
||||||
|
"""Test getting organization policy when none exists."""
|
||||||
|
policy = MfaPolicyService.get_org_policy(test_organization.id)
|
||||||
|
assert policy is None
|
||||||
|
|
||||||
|
def test_get_org_policy_found(self, db, test_organization):
|
||||||
|
"""Test getting organization policy when it exists."""
|
||||||
|
# Create policy
|
||||||
|
org_policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
)
|
||||||
|
org_policy.save()
|
||||||
|
|
||||||
|
policy = MfaPolicyService.get_org_policy(test_organization.id)
|
||||||
|
|
||||||
|
assert policy is not None
|
||||||
|
assert policy.organization_id == test_organization.id
|
||||||
|
assert policy.mfa_policy_mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value
|
||||||
|
assert policy.mfa_grace_period_days == 14
|
||||||
|
assert policy.notify_days_before == 7
|
||||||
|
assert policy.policy_version == 1
|
||||||
|
|
||||||
|
def test_get_effective_user_policy_no_org_policy(self, db, test_user, test_organization):
|
||||||
|
"""Test effective user policy when no org policy exists."""
|
||||||
|
policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
|
||||||
|
|
||||||
|
assert policy is not None
|
||||||
|
assert policy.organization_id == test_organization.id
|
||||||
|
assert policy.effective_mode == MfaPolicyMode.DISABLED.value
|
||||||
|
assert policy.requires_totp is False
|
||||||
|
assert policy.requires_webauthn is False
|
||||||
|
assert policy.is_exempt is True
|
||||||
|
|
||||||
|
def test_get_effective_user_policy_with_org_policy(self, db, test_user, test_organization):
|
||||||
|
"""Test effective user policy with org policy and no override."""
|
||||||
|
# Create org policy
|
||||||
|
org_policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
)
|
||||||
|
org_policy.save()
|
||||||
|
|
||||||
|
policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
|
||||||
|
|
||||||
|
assert policy is not None
|
||||||
|
assert policy.effective_mode == MfaPolicyMode.REQUIRE_TOTP.value
|
||||||
|
assert policy.requires_totp is True
|
||||||
|
assert policy.requires_webauthn is False
|
||||||
|
assert policy.is_exempt is False
|
||||||
|
|
||||||
|
def test_get_effective_user_policy_with_override_inherit(self, db, test_user, test_organization):
|
||||||
|
"""Test effective user policy with INHERIT override."""
|
||||||
|
# Create org policy
|
||||||
|
org_policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_WEBAUTHN,
|
||||||
|
mfa_grace_period_days=7,
|
||||||
|
)
|
||||||
|
org_policy.save()
|
||||||
|
|
||||||
|
# Create user override
|
||||||
|
user_override = UserSecurityPolicy(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_override_mode=MfaRequirementOverride.INHERIT,
|
||||||
|
)
|
||||||
|
user_override.save()
|
||||||
|
|
||||||
|
policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
|
||||||
|
|
||||||
|
assert policy.effective_mode == MfaPolicyMode.REQUIRE_WEBAUTHN.value
|
||||||
|
assert policy.requires_webauthn is True
|
||||||
|
|
||||||
|
def test_get_effective_user_policy_with_override_exempt(self, db, test_user, test_organization):
|
||||||
|
"""Test effective user policy with EXEMPT override."""
|
||||||
|
# Create org policy
|
||||||
|
org_policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
)
|
||||||
|
org_policy.save()
|
||||||
|
|
||||||
|
# Create user override
|
||||||
|
user_override = UserSecurityPolicy(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_override_mode=MfaRequirementOverride.EXEMPT,
|
||||||
|
)
|
||||||
|
user_override.save()
|
||||||
|
|
||||||
|
policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
|
||||||
|
|
||||||
|
assert policy.effective_mode == MfaPolicyMode.DISABLED.value
|
||||||
|
assert policy.is_exempt is True
|
||||||
|
|
||||||
|
def test_get_effective_user_policy_with_override_required(self, db, test_user, test_organization):
|
||||||
|
"""Test effective user policy with REQUIRED override."""
|
||||||
|
# Create org policy
|
||||||
|
org_policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
)
|
||||||
|
org_policy.save()
|
||||||
|
|
||||||
|
# Create user override
|
||||||
|
user_override = UserSecurityPolicy(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_override_mode=MfaRequirementOverride.REQUIRED,
|
||||||
|
)
|
||||||
|
user_override.save()
|
||||||
|
|
||||||
|
policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
|
||||||
|
|
||||||
|
assert policy.effective_mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value
|
||||||
|
assert policy.requires_totp is True
|
||||||
|
assert policy.requires_webauthn is True
|
||||||
|
assert policy.is_exempt is False
|
||||||
|
|
||||||
|
def test_evaluate_user_mfa_state_no_policy(self, db, test_user, test_organization):
|
||||||
|
"""Test evaluating user MFA state with no policy."""
|
||||||
|
# Create membership
|
||||||
|
membership = OrganizationMember(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
role=OrganizationRole.MEMBER,
|
||||||
|
)
|
||||||
|
membership.save()
|
||||||
|
|
||||||
|
state = MfaPolicyService.evaluate_user_mfa_state(test_user)
|
||||||
|
|
||||||
|
assert state is not None
|
||||||
|
assert state.overall_status == MfaComplianceStatus.COMPLIANT.value
|
||||||
|
assert len(state.missing_methods) == 0
|
||||||
|
assert len(state.orgs) == 1
|
||||||
|
|
||||||
|
def test_evaluate_user_mfa_state_with_policy(self, db, test_user, test_organization):
|
||||||
|
"""Test evaluating user MFA state with policy."""
|
||||||
|
# Create membership
|
||||||
|
membership = OrganizationMember(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
role=OrganizationRole.MEMBER,
|
||||||
|
)
|
||||||
|
membership.save()
|
||||||
|
|
||||||
|
# Create org policy
|
||||||
|
org_policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
)
|
||||||
|
org_policy.save()
|
||||||
|
|
||||||
|
state = MfaPolicyService.evaluate_user_mfa_state(test_user)
|
||||||
|
|
||||||
|
assert state is not None
|
||||||
|
assert state.overall_status == MfaComplianceStatus.IN_GRACE.value
|
||||||
|
assert "totp" in state.missing_methods
|
||||||
|
assert len(state.orgs) == 1
|
||||||
|
assert state.orgs[0].effective_mode == MfaPolicyMode.REQUIRE_TOTP.value
|
||||||
|
|
||||||
|
def test_after_primary_auth_success_no_required_policy(self, db, test_user, test_organization):
|
||||||
|
"""Test after_primary_auth_success with no required policy."""
|
||||||
|
# Create membership
|
||||||
|
membership = OrganizationMember(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
role=OrganizationRole.MEMBER,
|
||||||
|
)
|
||||||
|
membership.save()
|
||||||
|
|
||||||
|
result = MfaPolicyService.after_primary_auth_success(test_user)
|
||||||
|
|
||||||
|
assert result.can_create_full_session is True
|
||||||
|
assert result.create_compliance_only_session is False
|
||||||
|
assert result.compliance_summary.overall_status == MfaComplianceStatus.COMPLIANT.value
|
||||||
|
|
||||||
|
def test_after_primary_auth_success_in_grace(self, db, test_user, test_organization):
|
||||||
|
"""Test after_primary_auth_success when user is in grace period."""
|
||||||
|
# Create membership
|
||||||
|
membership = OrganizationMember(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
role=OrganizationRole.MEMBER,
|
||||||
|
)
|
||||||
|
membership.save()
|
||||||
|
|
||||||
|
# Create org policy
|
||||||
|
org_policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
)
|
||||||
|
org_policy.save()
|
||||||
|
|
||||||
|
result = MfaPolicyService.after_primary_auth_success(test_user)
|
||||||
|
|
||||||
|
assert result.can_create_full_session is True
|
||||||
|
assert result.create_compliance_only_session is False
|
||||||
|
assert result.compliance_summary.overall_status == MfaComplianceStatus.IN_GRACE.value
|
||||||
|
|
||||||
|
def test_after_primary_auth_success_past_due(self, db, test_user, test_organization):
|
||||||
|
"""Test after_primary_auth_success when user is past due."""
|
||||||
|
# Create membership
|
||||||
|
membership = OrganizationMember(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
role=OrganizationRole.MEMBER,
|
||||||
|
)
|
||||||
|
membership.save()
|
||||||
|
|
||||||
|
# Create org policy
|
||||||
|
org_policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
)
|
||||||
|
org_policy.save()
|
||||||
|
|
||||||
|
# Create compliance record past due
|
||||||
|
compliance = MfaPolicyCompliance(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
status=MfaComplianceStatus.PAST_DUE,
|
||||||
|
policy_version=1,
|
||||||
|
applied_at=datetime.now(timezone.utc) - timedelta(days=30),
|
||||||
|
deadline_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||||
|
)
|
||||||
|
compliance.save()
|
||||||
|
|
||||||
|
result = MfaPolicyService.after_primary_auth_success(test_user)
|
||||||
|
|
||||||
|
assert result.can_create_full_session is False
|
||||||
|
assert result.create_compliance_only_session is True
|
||||||
|
|
||||||
|
def test_create_org_policy_new(self, db, test_organization):
|
||||||
|
"""Test creating a new organization policy."""
|
||||||
|
policy = MfaPolicyService.create_org_policy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
updated_by_user_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert policy is not None
|
||||||
|
assert policy.organization_id == test_organization.id
|
||||||
|
assert policy.mfa_policy_mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN
|
||||||
|
assert policy.policy_version == 1
|
||||||
|
|
||||||
|
def test_create_org_policy_update(self, db, test_organization):
|
||||||
|
"""Test updating an existing organization policy."""
|
||||||
|
# Create initial policy
|
||||||
|
initial_policy = OrganizationSecurityPolicy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
)
|
||||||
|
initial_policy.save()
|
||||||
|
|
||||||
|
# Update policy
|
||||||
|
updated_policy = MfaPolicyService.create_org_policy(
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
|
||||||
|
mfa_grace_period_days=7,
|
||||||
|
updated_by_user_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert updated_policy.mfa_policy_mode == MfaPolicyMode.REQUIRE_TOTP
|
||||||
|
assert updated_policy.mfa_grace_period_days == 7
|
||||||
|
assert updated_policy.policy_version == 2
|
||||||
|
|
||||||
|
def test_set_user_override_new(self, db, test_user, test_organization):
|
||||||
|
"""Test setting a new user override."""
|
||||||
|
override = MfaPolicyService.set_user_override(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_override_mode=MfaRequirementOverride.REQUIRED,
|
||||||
|
force_totp=True,
|
||||||
|
force_webauthn=False,
|
||||||
|
updated_by_user_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert override is not None
|
||||||
|
assert override.user_id == test_user.id
|
||||||
|
assert override.organization_id == test_organization.id
|
||||||
|
assert override.mfa_override_mode == MfaRequirementOverride.REQUIRED
|
||||||
|
assert override.force_totp is True
|
||||||
|
|
||||||
|
def test_set_user_override_update(self, db, test_user, test_organization):
|
||||||
|
"""Test updating an existing user override."""
|
||||||
|
# Create initial override
|
||||||
|
initial_override = UserSecurityPolicy(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_override_mode=MfaRequirementOverride.INHERIT,
|
||||||
|
)
|
||||||
|
initial_override.save()
|
||||||
|
|
||||||
|
# Update override
|
||||||
|
updated_override = MfaPolicyService.set_user_override(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
mfa_override_mode=MfaRequirementOverride.EXEMPT,
|
||||||
|
updated_by_user_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert updated_override.mfa_override_mode == MfaRequirementOverride.EXEMPT
|
||||||
|
|
||||||
|
def test_get_user_compliance(self, db, test_user, test_organization):
|
||||||
|
"""Test getting user compliance record."""
|
||||||
|
# Create compliance record
|
||||||
|
compliance = MfaPolicyCompliance(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
status=MfaComplianceStatus.COMPLIANT,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
compliance.save()
|
||||||
|
|
||||||
|
result = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.status == MfaComplianceStatus.COMPLIANT
|
||||||
|
|
||||||
|
def test_get_user_compliance_not_found(self, db, test_user, test_organization):
|
||||||
|
"""Test getting user compliance record when none exists."""
|
||||||
|
result = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_get_org_compliance_list(self, db, test_user, test_organization):
|
||||||
|
"""Test getting organization compliance list."""
|
||||||
|
# Create compliance record
|
||||||
|
compliance = MfaPolicyCompliance(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
status=MfaComplianceStatus.IN_GRACE,
|
||||||
|
policy_version=1,
|
||||||
|
deadline_at=datetime.now(timezone.utc) + timedelta(days=14),
|
||||||
|
)
|
||||||
|
compliance.save()
|
||||||
|
|
||||||
|
results = MfaPolicyService.get_org_compliance_list(test_organization.id)
|
||||||
|
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0]["user_id"] == test_user.id
|
||||||
|
assert results[0]["status"] == MfaComplianceStatus.IN_GRACE.value
|
||||||
|
|
||||||
|
def test_get_org_compliance_list_with_status_filter(self, db, test_user, test_organization):
|
||||||
|
"""Test getting organization compliance list with status filter."""
|
||||||
|
# Create compliance record
|
||||||
|
compliance = MfaPolicyCompliance(
|
||||||
|
user_id=test_user.id,
|
||||||
|
organization_id=test_organization.id,
|
||||||
|
status=MfaComplianceStatus.COMPLIANT,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
compliance.save()
|
||||||
|
|
||||||
|
# Filter by different status
|
||||||
|
results = MfaPolicyService.get_org_compliance_list(
|
||||||
|
test_organization.id, status=MfaComplianceStatus.IN_GRACE
|
||||||
|
)
|
||||||
|
assert len(results) == 0
|
||||||
|
|
||||||
|
# Filter by correct status
|
||||||
|
results = MfaPolicyService.get_org_compliance_list(
|
||||||
|
test_organization.id, status=MfaComplianceStatus.COMPLIANT
|
||||||
|
)
|
||||||
|
assert len(results) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestMfaPolicyServiceDto:
|
||||||
|
"""Tests for MfaPolicyService DTOs."""
|
||||||
|
|
||||||
|
def test_org_policy_dto(self):
|
||||||
|
"""Test OrgPolicyDto creation."""
|
||||||
|
dto = OrgPolicyDto(
|
||||||
|
organization_id="org-123",
|
||||||
|
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP.value,
|
||||||
|
mfa_grace_period_days=14,
|
||||||
|
notify_days_before=7,
|
||||||
|
policy_version=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert dto.organization_id == "org-123"
|
||||||
|
assert dto.mfa_policy_mode == "require_totp"
|
||||||
|
assert dto.mfa_grace_period_days == 14
|
||||||
|
|
||||||
|
def test_effective_user_policy_dto(self):
|
||||||
|
"""Test EffectiveUserPolicyDto creation."""
|
||||||
|
dto = EffectiveUserPolicyDto(
|
||||||
|
organization_id="org-123",
|
||||||
|
effective_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value,
|
||||||
|
requires_totp=True,
|
||||||
|
requires_webauthn=True,
|
||||||
|
grace_period_days=14,
|
||||||
|
is_exempt=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert dto.requires_totp is True
|
||||||
|
assert dto.requires_webauthn is True
|
||||||
|
assert dto.is_exempt is False
|
||||||
|
|
||||||
|
def test_aggregate_mfa_state_dto(self):
|
||||||
|
"""Test AggregateMfaStateDto creation."""
|
||||||
|
dto = AggregateMfaStateDto(
|
||||||
|
overall_status=MfaComplianceStatus.IN_GRACE.value,
|
||||||
|
missing_methods=["totp"],
|
||||||
|
deadline_at="2025-02-01T00:00:00Z",
|
||||||
|
orgs=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert dto.overall_status == "in_grace"
|
||||||
|
assert "totp" in dto.missing_methods
|
||||||
|
assert dto.deadline_at == "2025-02-01T00:00:00Z"
|
||||||
|
|
||||||
|
def test_login_policy_result(self):
|
||||||
|
"""Test LoginPolicyResult creation."""
|
||||||
|
summary = AggregateMfaStateDto(
|
||||||
|
overall_status=MfaComplianceStatus.IN_GRACE.value,
|
||||||
|
missing_methods=["totp"],
|
||||||
|
orgs=[],
|
||||||
|
)
|
||||||
|
result = LoginPolicyResult(
|
||||||
|
can_create_full_session=True,
|
||||||
|
create_compliance_only_session=False,
|
||||||
|
compliance_summary=summary,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.can_create_full_session is True
|
||||||
|
assert result.create_compliance_only_session is False
|
||||||
|
assert result.compliance_summary.overall_status == "in_grace"
|
||||||
Reference in New Issue
Block a user