enable policies

This commit is contained in:
2026-01-16 17:31:20 +10:30
parent b2e084db33
commit d063a0ca81
28 changed files with 4296 additions and 224 deletions
+933
View File
@@ -0,0 +1,933 @@
"""Integration tests for MFA compliance enforcement."""
import pytest
import json
from datetime import datetime, timezone, timedelta
from gatehouse_app.models.user import User
from gatehouse_app.models.organization import Organization
from gatehouse_app.models.organization_member import OrganizationMember
from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
from gatehouse_app.models.user_security_policy import UserSecurityPolicy
from gatehouse_app.models.session import Session
from gatehouse_app.utils.constants import MfaPolicyMode, MfaComplianceStatus, UserStatus, MfaRequirementOverride
from gatehouse_app.services.mfa_policy_service import MfaPolicyService
@pytest.mark.integration
class TestMfaComplianceLogin:
"""Integration tests for MFA compliance during login."""
def test_login_with_no_policy(self, client, db, test_user):
"""Test login with no MFA policy (should work normally)."""
login_data = {
"email": test_user.email,
"password": "TestPassword123!",
}
response = client.post(
"/api/v1/auth/login",
data=json.dumps(login_data),
content_type="application/json",
)
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
assert "user" in data["data"]
assert "token" in data["data"]
# No MFA compliance info should be present when no policy exists
assert "mfa_compliance" not in data["data"]
assert "requires_mfa_enrollment" not in data["data"]
def test_login_with_optional_policy(self, client, db, test_user, test_organization):
"""Test login with optional MFA policy (should work normally)."""
# Create an optional MFA policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
login_data = {
"email": test_user.email,
"password": "TestPassword123!",
}
response = client.post(
"/api/v1/auth/login",
data=json.dumps(login_data),
content_type="application/json",
)
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
assert "user" in data["data"]
assert "token" in data["data"]
# MFA compliance should be present but status should be not_applicable
assert "mfa_compliance" in data["data"]
assert data["data"]["mfa_compliance"]["overall_status"] == "not_applicable"
assert "requires_mfa_enrollment" not in data["data"]
def test_login_with_required_policy_in_grace_period(self, client, db, test_user, test_organization):
"""Test login with required policy within grace period (should work with warning)."""
# Create a required MFA policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
login_data = {
"email": test_user.email,
"password": "TestPassword123!",
}
response = client.post(
"/api/v1/auth/login",
data=json.dumps(login_data),
content_type="application/json",
)
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
assert "user" in data["data"]
assert "token" in data["data"]
# MFA compliance should be present with in_grace status
assert "mfa_compliance" in data["data"]
assert data["data"]["mfa_compliance"]["overall_status"] == "in_grace"
assert "requires_mfa_enrollment" not in data["data"]
assert "totp" in data["data"]["mfa_compliance"]["missing_methods"]
def test_login_with_required_policy_after_deadline(self, client, db, test_user, test_organization):
"""Test login with required policy after deadline (should get compliance-only session)."""
# Create a required MFA policy with past deadline
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
# Create compliance record past due
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.PAST_DUE,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
deadline_at=past_deadline,
)
db.session.add(compliance)
db.session.commit()
login_data = {
"email": test_user.email,
"password": "TestPassword123!",
}
response = client.post(
"/api/v1/auth/login",
data=json.dumps(login_data),
content_type="application/json",
)
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
assert "user" in data["data"]
assert "token" in data["data"]
# Should have compliance-only session
assert data["data"]["requires_mfa_enrollment"] is True
assert "mfa_compliance" in data["data"]
assert data["data"]["mfa_compliance"]["overall_status"] in ["past_due", "suspended"]
def test_login_with_suspended_user(self, client, db, test_user, test_organization):
"""Test login with compliance suspended user (should get compliance-only session)."""
# Set user status to compliance suspended
test_user.status = UserStatus.COMPLIANCE_SUSPENDED
db.session.commit()
login_data = {
"email": test_user.email,
"password": "TestPassword123!",
}
response = client.post(
"/api/v1/auth/login",
data=json.dumps(login_data),
content_type="application/json",
)
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
assert "user" in data["data"]
assert "token" in data["data"]
# Should have compliance-only session
assert data["data"]["requires_mfa_enrollment"] is True
@pytest.mark.integration
class TestMfaComplianceAccess:
"""Integration tests for MFA compliance access control."""
def test_compliance_only_session_denied_full_access(self, client, db, test_user, test_organization):
"""Test that compliance-only session cannot access full access endpoints."""
# Create a required MFA policy with past deadline
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
# Create compliance record past due
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.PAST_DUE,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
deadline_at=past_deadline,
)
db.session.add(compliance)
# Create a compliance-only session
session = Session(
user_id=test_user.id,
token="compliance_only_token",
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
is_compliance_only=True,
)
db.session.add(session)
db.session.commit()
# Try to access a full-access endpoint (get_my_organizations)
response = client.get(
"/api/v1/users/me/organizations",
headers={"Authorization": "Bearer compliance_only_token"},
)
assert response.status_code == 403
data = response.get_json()
assert data["success"] is False
assert data["error_type"] == "MFA_COMPLIANCE_REQUIRED"
def test_compliance_only_session_can_access_mfa_enrollment(self, client, db, test_user, test_organization):
"""Test that compliance-only session can access MFA enrollment endpoints."""
# Create a required MFA policy with past deadline
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
# Create compliance record past due
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.PAST_DUE,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
deadline_at=past_deadline,
)
db.session.add(compliance)
# Create a compliance-only session
session = Session(
user_id=test_user.id,
token="compliance_only_token",
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
is_compliance_only=True,
)
db.session.add(session)
db.session.commit()
# Try to access MFA enrollment endpoint (should work)
response = client.get(
"/api/v1/auth/totp/status",
headers={"Authorization": "Bearer compliance_only_token"},
)
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
def test_compliance_only_session_can_access_logout(self, client, db, test_user, test_organization):
"""Test that compliance-only session can access logout endpoint."""
# Create a required MFA policy with past deadline
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
# Create compliance record past due
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.PAST_DUE,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
deadline_at=past_deadline,
)
db.session.add(compliance)
# Create a compliance-only session
session = Session(
user_id=test_user.id,
token="compliance_only_token",
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
is_compliance_only=True,
)
db.session.add(session)
db.session.commit()
# Try to access logout endpoint (should work)
response = client.post(
"/api/v1/auth/logout",
headers={"Authorization": "Bearer compliance_only_token"},
)
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
@pytest.mark.integration
class TestMfaComplianceWebAuthn:
"""Integration tests for MFA compliance with WebAuthn login."""
def test_webauthn_login_with_required_policy_in_grace_period(self, client, db, test_user, test_organization):
"""Test WebAuthn login with required policy within grace period."""
# Create a required MFA policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
# Note: Full WebAuthn login test would require WebAuthn setup
# This test verifies the compliance response structure
login_data = {
"email": test_user.email,
"password": "TestPassword123!",
}
response = client.post(
"/api/v1/auth/login",
data=json.dumps(login_data),
content_type="application/json",
)
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
assert "mfa_compliance" in data["data"]
assert data["data"]["mfa_compliance"]["overall_status"] == "in_grace"
@pytest.mark.integration
class TestMfaComplianceOIDC:
"""Integration tests for MFA compliance with OIDC authorization."""
def test_oidc_authorize_with_compliance_required(self, client, db, test_user, test_organization, app):
"""Test OIDC authorize with compliance required (should show error)."""
# Create a required MFA policy with past deadline
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
# Create compliance record past due
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.PAST_DUE,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
deadline_at=past_deadline,
)
db.session.add(compliance)
db.session.commit()
# Try OIDC authorize with credentials
response = client.post(
"/oidc/authorize",
data={
"client_id": "test_client",
"redirect_uri": "http://localhost:8080/callback",
"response_type": "code",
"scope": "openid profile email",
"state": "test_state",
"email": test_user.email,
"password": "TestPassword123!",
},
)
# Should return login page with error
assert response.status_code == 200
assert b"Your account requires multi factor enrollment before using single sign on" in response.data
# =============================================================================
# Phase 4: Edge Case Tests
# =============================================================================
@pytest.mark.integration
class TestMfaComplianceMultiOrg:
"""Integration tests for multi-organization MFA compliance edge cases."""
def test_user_with_multiple_orgs_different_policies(self, client, db, test_user):
"""Test user belonging to multiple orgs with different MFA policies."""
# Create two organizations
org1 = Organization(
name="Org1",
slug="org1-test-multi",
)
org2 = Organization(
name="Org2",
slug="org2-test-multi",
)
db.session.add_all([org1, org2])
db.session.commit()
# Add user to both orgs
membership1 = OrganizationMember(
user_id=test_user.id,
organization_id=org1.id,
role="member",
)
membership2 = OrganizationMember(
user_id=test_user.id,
organization_id=org2.id,
role="member",
)
db.session.add_all([membership1, membership2])
db.session.commit()
# Create different policies for each org
# Org1: OPTIONAL (no requirement)
policy1 = OrganizationSecurityPolicy(
organization_id=org1.id,
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
# Org2: REQUIRE_TOTP (strictest)
policy2 = OrganizationSecurityPolicy(
organization_id=org2.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add_all([policy1, policy2])
db.session.commit()
# Evaluate user MFA state
compliance_summary = MfaPolicyService.evaluate_user_mfa_state(test_user)
# Overall status should reflect the strictest policy (REQUIRE_TOTP from org2)
assert compliance_summary.overall_status == MfaComplianceStatus.IN_GRACE.value
assert "totp" in compliance_summary.missing_methods
# Verify per-org breakdown
assert len(compliance_summary.orgs) == 2
org1_status = next((o for o in compliance_summary.orgs if o.organization_id == org1.id), None)
org2_status = next((o for o in compliance_summary.orgs if o.organization_id == org2.id), None)
assert org1_status is not None
assert org2_status is not None
assert org1_status.status == MfaComplianceStatus.NOT_APPLICABLE.value
assert org2_status.status == MfaComplianceStatus.IN_GRACE.value
def test_user_with_multiple_orgs_all_suspended(self, client, db, test_user):
"""Test user with multiple orgs where all require MFA and are past due."""
# Create two organizations
org1 = Organization(
name="Org1",
slug="org1-test-suspended",
)
org2 = Organization(
name="Org2",
slug="org2-test-suspended",
)
db.session.add_all([org1, org2])
db.session.commit()
# Add user to both orgs
membership1 = OrganizationMember(
user_id=test_user.id,
organization_id=org1.id,
role="member",
)
membership2 = OrganizationMember(
user_id=test_user.id,
organization_id=org2.id,
role="member",
)
db.session.add_all([membership1, membership2])
db.session.commit()
# Create required policies
policy1 = OrganizationSecurityPolicy(
organization_id=org1.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
policy2 = OrganizationSecurityPolicy(
organization_id=org2.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add_all([policy1, policy2])
db.session.commit()
# Create past-due compliance records for both
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
compliance1 = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=org1.id,
status=MfaComplianceStatus.SUSPENDED,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=30),
deadline_at=past_deadline,
suspended_at=past_deadline,
)
compliance2 = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=org2.id,
status=MfaComplianceStatus.SUSPENDED,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=30),
deadline_at=past_deadline,
suspended_at=past_deadline,
)
db.session.add_all([compliance1, compliance2])
db.session.commit()
# Evaluate user MFA state
compliance_summary = MfaPolicyService.evaluate_user_mfa_state(test_user)
# Overall status should be SUSPENDED
assert compliance_summary.overall_status == MfaComplianceStatus.SUSPENDED.value
def test_strictest_mode_selection(self):
"""Test that get_strictest_mode returns the most restrictive policy."""
modes = [
MfaPolicyMode.DISABLED.value,
MfaPolicyMode.OPTIONAL.value,
MfaPolicyMode.REQUIRE_TOTP.value,
]
result = MfaPolicyService.get_strictest_mode(modes)
assert result == MfaPolicyMode.REQUIRE_TOTP.value
# Test with REQUIRE_TOTP_OR_WEBAUTHN (strictest)
modes_strictest = [
MfaPolicyMode.REQUIRE_TOTP.value,
MfaPolicyMode.REQUIRE_WEBAUTHN.value,
MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value,
]
result = MfaPolicyService.get_strictest_mode(modes_strictest)
assert result == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value
@pytest.mark.integration
class TestMfaComplianceUserOverrides:
"""Integration tests for user override edge cases."""
def test_user_override_inherit_mode(self, client, db, test_user, test_organization):
"""Test INHERIT mode - org policy applies as is."""
# Create a required policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
# Create INHERIT override (default behavior)
override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.INHERIT,
)
db.session.add(override)
db.session.commit()
# Get effective policy
effective = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
# Should inherit org policy
assert effective.effective_mode == MfaPolicyMode.REQUIRE_TOTP.value
assert effective.requires_totp is True
assert effective.is_exempt is False
def test_user_override_required_mode(self, client, db, test_user, test_organization):
"""Test REQUIRED mode - user always required to have MFA."""
# Create an optional policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
# Create REQUIRED override
override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.REQUIRED,
)
db.session.add(override)
db.session.commit()
# Get effective policy
effective = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
# Should be upgraded to REQUIRE_TOTP_OR_WEBAUTHN
assert effective.effective_mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value
assert effective.requires_totp is True
assert effective.requires_webauthn is True
assert effective.is_exempt is False
def test_user_override_exempt_mode(self, client, db, test_user, test_organization):
"""Test EXEMPT mode - org policy does not apply."""
# Create a required policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
# Create EXEMPT override
override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.EXEMPT,
)
db.session.add(override)
db.session.commit()
# Get effective policy
effective = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
# Should be exempt from policy
assert effective.is_exempt is True
assert effective.effective_mode == MfaPolicyMode.DISABLED.value
assert effective.requires_totp is False
assert effective.requires_webauthn is False
def test_get_override_summary(self, client, db, test_user, test_organization):
"""Test getting override summary for a user."""
# No override exists
summary = MfaPolicyService.get_override_summary(test_user.id, test_organization.id)
assert summary["has_override"] is False
assert summary["mode"] == "inherit"
# Create an override
override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.EXEMPT,
)
db.session.add(override)
db.session.commit()
# Get summary again
summary = MfaPolicyService.get_override_summary(test_user.id, test_organization.id)
assert summary["has_override"] is True
assert summary["mode"] == "exempt"
assert summary["is_exempt"] is True
@pytest.mark.integration
class TestMfaCompliancePolicyChanges:
"""Integration tests for policy changes affecting existing users."""
def test_policy_change_triggers_compliance_reevaluation(self, client, db, test_user, test_organization):
"""Test that policy change triggers compliance reevaluation."""
# Create initial optional policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
# Create compliance record (should be NOT_APPLICABLE)
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.NOT_APPLICABLE,
policy_version=1,
)
db.session.add(compliance)
db.session.commit()
# Update policy to REQUIRE_TOTP
MfaPolicyService.create_org_policy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
notify_days_before=7,
updated_by_user_id=test_user.id,
)
# Reevaluate all compliance
updated_count = MfaPolicyService.reevaluate_all_org_compliance(test_organization.id)
# Should have updated at least one record
assert updated_count >= 1
# Check compliance status was updated
updated_compliance = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id)
assert updated_compliance.status == MfaComplianceStatus.IN_GRACE.value
assert updated_compliance.deadline_at is not None
def test_policy_relaxation_clears_requirements(self, client, db, test_user, test_organization):
"""Test that relaxing policy clears compliance requirements."""
# Create required policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
# Create IN_GRACE compliance record
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.IN_GRACE,
policy_version=1,
applied_at=datetime.now(timezone.utc),
deadline_at=datetime.now(timezone.utc) + timedelta(days=14),
)
db.session.add(compliance)
db.session.commit()
# Update policy to OPTIONAL
MfaPolicyService.create_org_policy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
mfa_grace_period_days=14,
notify_days_before=7,
updated_by_user_id=test_user.id,
)
# Reevaluate compliance
MfaPolicyService.reevaluate_all_org_compliance(test_organization.id)
# Check compliance status was updated to NOT_APPLICABLE
updated_compliance = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id)
assert updated_compliance.status == MfaComplianceStatus.NOT_APPLICABLE.value
@pytest.mark.integration
class TestMfaComplianceScheduledJob:
"""Integration tests for the MFA compliance scheduled job."""
def test_transition_to_suspended(self, client, db, test_user, test_organization):
"""Test that past-due users are transitioned to suspended."""
# Create required policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
# Create past-due compliance record
past_deadline = datetime.now(timezone.utc) - timedelta(hours=1)
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.PAST_DUE,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
deadline_at=past_deadline,
)
db.session.add(compliance)
db.session.commit()
# Run the job
now = datetime.now(timezone.utc)
suspended_count = MfaPolicyService.transition_to_suspended_if_past_due(now)
# Should have suspended the user
assert suspended_count >= 1
# Check compliance status
updated_compliance = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id)
assert updated_compliance.status == MfaComplianceStatus.SUSPENDED.value
assert updated_compliance.suspended_at is not None
# Check user status
db.refresh(test_user)
assert test_user.status == UserStatus.COMPLIANCE_SUSPENDED
def test_check_and_restore_user_status(self, client, db, test_user, test_organization):
"""Test that suspended users are restored when they become compliant."""
# Create required policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
# User is suspended
test_user.status = UserStatus.COMPLIANCE_SUSPENDED
db.session.commit()
# Create EXEMPT override to clear requirement
override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.EXEMPT,
)
db.session.add(override)
db.session.commit()
# Check and restore status
restored = MfaPolicyService.check_and_restore_user_status(test_user.id)
# Should have restored user
assert restored is True
db.refresh(test_user)
assert test_user.status == UserStatus.ACTIVE
@pytest.mark.integration
class TestMfaComplianceMultiOrgAggregate:
"""Integration tests for multi-org aggregate state calculation."""
def test_get_multi_org_aggregate_state(self, client, db, test_user):
"""Test aggregate state calculation for multi-org user."""
# Create two organizations
org1 = Organization(
name="AggOrg1",
slug="agg-org1-test",
)
org2 = Organization(
name="AggOrg2",
slug="agg-org2-test",
)
db.session.add_all([org1, org2])
db.session.commit()
# Add user to both
membership1 = OrganizationMember(
user_id=test_user.id,
organization_id=org1.id,
role="member",
)
membership2 = OrganizationMember(
user_id=test_user.id,
organization_id=org2.id,
role="member",
)
db.session.add_all([membership1, membership2])
db.session.commit()
# Create policies
policy1 = OrganizationSecurityPolicy(
organization_id=org1.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
policy2 = OrganizationSecurityPolicy(
organization_id=org2.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add_all([policy1, policy2])
db.session.commit()
# Get aggregate state
aggregate = MfaPolicyService.get_multi_org_aggregate_state(test_user)
# Verify structure
assert "overall_status" in aggregate
assert "strictest_mode" in aggregate
assert "missing_methods" in aggregate
assert "requiring_org_count" in aggregate
assert "requiring_orgs" in aggregate
assert "per_org_details" in aggregate
# Strictest mode should be REQUIRE_TOTP_OR_WEBAUTHN
assert aggregate["strictest_mode"] == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value
# Both orgs should require MFA
assert aggregate["requiring_org_count"] == 2
assert len(aggregate["requiring_orgs"]) == 2
assert len(aggregate["per_org_details"]) == 2
+295
View File
@@ -0,0 +1,295 @@
"""Unit tests for MFA policy models."""
import pytest
from datetime import datetime, timezone, timedelta
from gatehouse_app.models import (
User,
Organization,
OrganizationMember,
OrganizationSecurityPolicy,
UserSecurityPolicy,
MfaPolicyCompliance,
Session,
)
from gatehouse_app.utils.constants import (
UserStatus,
MfaPolicyMode,
MfaComplianceStatus,
MfaRequirementOverride,
SessionStatus,
OrganizationRole,
)
@pytest.mark.unit
class TestOrganizationSecurityPolicyModel:
"""Tests for OrganizationSecurityPolicy model."""
def test_create_org_security_policy(self, db, test_organization):
"""Test creating an organization security policy."""
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
mfa_grace_period_days=14,
notify_days_before=7,
)
policy.save()
assert policy.id is not None
assert policy.organization_id == test_organization.id
assert policy.mfa_policy_mode == MfaPolicyMode.OPTIONAL
assert policy.mfa_grace_period_days == 14
assert policy.notify_days_before == 7
assert policy.policy_version == 1
assert policy.created_at is not None
def test_org_security_policy_to_dict(self, db, test_organization):
"""Test organization security policy to_dict method."""
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=7,
notify_days_before=3,
)
policy.save()
policy_dict = policy.to_dict()
assert "id" in policy_dict
assert "organization_id" in policy_dict
assert policy_dict["organization_id"] == test_organization.id
assert "mfa_policy_mode" in policy_dict
assert "mfa_grace_period_days" in policy_dict
def test_org_security_policy_relationships(self, db, test_organization):
"""Test organization security policy relationships."""
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
)
policy.save()
# Test relationship
assert policy.organization is not None
assert policy.organization.id == test_organization.id
@pytest.mark.unit
class TestUserSecurityPolicyModel:
"""Tests for UserSecurityPolicy model."""
def test_create_user_security_policy(self, db, test_user, test_organization):
"""Test creating a user security policy."""
policy = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.INHERIT,
)
policy.save()
assert policy.id is not None
assert policy.user_id == test_user.id
assert policy.organization_id == test_organization.id
assert policy.mfa_override_mode == MfaRequirementOverride.INHERIT
assert policy.force_totp is False
assert policy.force_webauthn is False
def test_user_security_policy_with_overrides(self, db, test_user, test_organization):
"""Test user security policy with override settings."""
policy = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.REQUIRED,
force_totp=True,
force_webauthn=False,
)
policy.save()
assert policy.mfa_override_mode == MfaRequirementOverride.REQUIRED
assert policy.force_totp is True
assert policy.force_webauthn is False
def test_user_security_policy_exempt(self, db, test_user, test_organization):
"""Test user security policy with exempt override."""
policy = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.EXEMPT,
)
policy.save()
assert policy.mfa_override_mode == MfaRequirementOverride.EXEMPT
def test_user_security_policy_relationships(self, db, test_user, test_organization):
"""Test user security policy relationships."""
policy = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.INHERIT,
)
policy.save()
# Test relationships
assert policy.user is not None
assert policy.user.id == test_user.id
assert policy.organization is not None
assert policy.organization.id == test_organization.id
@pytest.mark.unit
class TestMfaPolicyComplianceModel:
"""Tests for MfaPolicyCompliance model."""
def test_create_mfa_policy_compliance(self, db, test_user, test_organization):
"""Test creating an MFA policy compliance record."""
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.NOT_APPLICABLE,
policy_version=1,
)
compliance.save()
assert compliance.id is not None
assert compliance.user_id == test_user.id
assert compliance.organization_id == test_organization.id
assert compliance.status == MfaComplianceStatus.NOT_APPLICABLE
assert compliance.policy_version == 1
assert compliance.notification_count == 0
def test_mfa_policy_compliance_in_grace(self, db, test_user, test_organization):
"""Test MFA compliance record in grace period."""
now = datetime.now(timezone.utc)
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.IN_GRACE,
policy_version=1,
applied_at=now,
deadline_at=now + timedelta(days=14),
)
compliance.save()
assert compliance.status == MfaComplianceStatus.IN_GRACE
assert compliance.applied_at is not None
assert compliance.deadline_at is not None
assert compliance.deadline_at > now
def test_mfa_policy_compliance_compliant(self, db, test_user, test_organization):
"""Test MFA compliance record when compliant."""
now = datetime.now(timezone.utc)
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.COMPLIANT,
policy_version=1,
applied_at=now - timedelta(days=30),
deadline_at=now - timedelta(days=16),
compliant_at=now - timedelta(days=16),
)
compliance.save()
assert compliance.status == MfaComplianceStatus.COMPLIANT
assert compliance.compliant_at is not None
def test_mfa_policy_compliance_suspended(self, db, test_user, test_organization):
"""Test MFA compliance record when suspended."""
now = datetime.now(timezone.utc)
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.SUSPENDED,
policy_version=1,
applied_at=now - timedelta(days=30),
deadline_at=now - timedelta(days=16),
suspended_at=now - timedelta(days=16),
)
compliance.save()
assert compliance.status == MfaComplianceStatus.SUSPENDED
assert compliance.suspended_at is not None
def test_mfa_policy_compliance_relationships(self, db, test_user, test_organization):
"""Test MFA compliance relationships."""
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.NOT_APPLICABLE,
policy_version=1,
)
compliance.save()
# Test relationships
assert compliance.user is not None
assert compliance.user.id == test_user.id
assert compliance.organization is not None
assert compliance.organization.id == test_organization.id
@pytest.mark.unit
class TestSessionModelComplianceFlag:
"""Tests for Session model compliance flag."""
def test_session_default_not_compliance_only(self, db, test_user):
"""Test that sessions are not compliance only by default."""
session = Session(
user_id=test_user.id,
token="test-token-123",
status=SessionStatus.ACTIVE,
expires_at=datetime.now(timezone.utc) + timedelta(hours=8),
last_activity_at=datetime.now(timezone.utc),
)
session.save()
assert session.is_compliance_only is False
def test_session_compliance_only(self, db, test_user):
"""Test creating a compliance-only session."""
session = Session(
user_id=test_user.id,
token="compliance-token-123",
status=SessionStatus.ACTIVE,
expires_at=datetime.now(timezone.utc) + timedelta(hours=8),
last_activity_at=datetime.now(timezone.utc),
is_compliance_only=True,
)
session.save()
assert session.is_compliance_only is True
def test_session_to_dict_excludes_token(self, db, test_user):
"""Test that session to_dict excludes the token."""
session = Session(
user_id=test_user.id,
token="test-token-456",
status=SessionStatus.ACTIVE,
expires_at=datetime.now(timezone.utc) + timedelta(hours=8),
last_activity_at=datetime.now(timezone.utc),
)
session.save()
session_dict = session.to_dict()
assert "id" in session_dict
assert "user_id" in session_dict
assert "is_compliance_only" in session_dict
assert session_dict["is_compliance_only"] is False
@pytest.mark.unit
class TestUserStatusComplianceSuspended:
"""Tests for UserStatus.COMPLIANCE_SUSPENDED."""
def test_compliance_suspended_status_exists(self):
"""Test that COMPLIANCE_SUSPENDED status exists."""
assert UserStatus.COMPLIANCE_SUSPENDED.value == "compliance_suspended"
def test_create_compliance_suspended_user(self, db):
"""Test creating a compliance suspended user."""
user = User(
email="suspended@example.com",
full_name="Suspended User",
status=UserStatus.COMPLIANCE_SUSPENDED,
)
user.save()
assert user.status == UserStatus.COMPLIANCE_SUSPENDED
@@ -0,0 +1,476 @@
"""Unit tests for MfaPolicyService."""
import pytest
from datetime import datetime, timezone, timedelta
from unittest.mock import patch, MagicMock
from gatehouse_app.models import (
User,
Organization,
OrganizationMember,
OrganizationSecurityPolicy,
UserSecurityPolicy,
MfaPolicyCompliance,
Session,
)
from gatehouse_app.services.mfa_policy_service import (
MfaPolicyService,
OrgPolicyDto,
EffectiveUserPolicyDto,
AggregateMfaStateDto,
LoginPolicyResult,
)
from gatehouse_app.utils.constants import (
UserStatus,
MfaPolicyMode,
MfaComplianceStatus,
MfaRequirementOverride,
SessionStatus,
OrganizationRole,
)
@pytest.mark.unit
class TestMfaPolicyService:
"""Tests for MfaPolicyService."""
def test_get_org_policy_not_found(self, db, test_organization):
"""Test getting organization policy when none exists."""
policy = MfaPolicyService.get_org_policy(test_organization.id)
assert policy is None
def test_get_org_policy_found(self, db, test_organization):
"""Test getting organization policy when it exists."""
# Create policy
org_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
)
org_policy.save()
policy = MfaPolicyService.get_org_policy(test_organization.id)
assert policy is not None
assert policy.organization_id == test_organization.id
assert policy.mfa_policy_mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value
assert policy.mfa_grace_period_days == 14
assert policy.notify_days_before == 7
assert policy.policy_version == 1
def test_get_effective_user_policy_no_org_policy(self, db, test_user, test_organization):
"""Test effective user policy when no org policy exists."""
policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
assert policy is not None
assert policy.organization_id == test_organization.id
assert policy.effective_mode == MfaPolicyMode.DISABLED.value
assert policy.requires_totp is False
assert policy.requires_webauthn is False
assert policy.is_exempt is True
def test_get_effective_user_policy_with_org_policy(self, db, test_user, test_organization):
"""Test effective user policy with org policy and no override."""
# Create org policy
org_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
)
org_policy.save()
policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
assert policy is not None
assert policy.effective_mode == MfaPolicyMode.REQUIRE_TOTP.value
assert policy.requires_totp is True
assert policy.requires_webauthn is False
assert policy.is_exempt is False
def test_get_effective_user_policy_with_override_inherit(self, db, test_user, test_organization):
"""Test effective user policy with INHERIT override."""
# Create org policy
org_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_WEBAUTHN,
mfa_grace_period_days=7,
)
org_policy.save()
# Create user override
user_override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.INHERIT,
)
user_override.save()
policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
assert policy.effective_mode == MfaPolicyMode.REQUIRE_WEBAUTHN.value
assert policy.requires_webauthn is True
def test_get_effective_user_policy_with_override_exempt(self, db, test_user, test_organization):
"""Test effective user policy with EXEMPT override."""
# Create org policy
org_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
)
org_policy.save()
# Create user override
user_override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.EXEMPT,
)
user_override.save()
policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
assert policy.effective_mode == MfaPolicyMode.DISABLED.value
assert policy.is_exempt is True
def test_get_effective_user_policy_with_override_required(self, db, test_user, test_organization):
"""Test effective user policy with REQUIRED override."""
# Create org policy
org_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
mfa_grace_period_days=14,
)
org_policy.save()
# Create user override
user_override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.REQUIRED,
)
user_override.save()
policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
assert policy.effective_mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value
assert policy.requires_totp is True
assert policy.requires_webauthn is True
assert policy.is_exempt is False
def test_evaluate_user_mfa_state_no_policy(self, db, test_user, test_organization):
"""Test evaluating user MFA state with no policy."""
# Create membership
membership = OrganizationMember(
user_id=test_user.id,
organization_id=test_organization.id,
role=OrganizationRole.MEMBER,
)
membership.save()
state = MfaPolicyService.evaluate_user_mfa_state(test_user)
assert state is not None
assert state.overall_status == MfaComplianceStatus.COMPLIANT.value
assert len(state.missing_methods) == 0
assert len(state.orgs) == 1
def test_evaluate_user_mfa_state_with_policy(self, db, test_user, test_organization):
"""Test evaluating user MFA state with policy."""
# Create membership
membership = OrganizationMember(
user_id=test_user.id,
organization_id=test_organization.id,
role=OrganizationRole.MEMBER,
)
membership.save()
# Create org policy
org_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
)
org_policy.save()
state = MfaPolicyService.evaluate_user_mfa_state(test_user)
assert state is not None
assert state.overall_status == MfaComplianceStatus.IN_GRACE.value
assert "totp" in state.missing_methods
assert len(state.orgs) == 1
assert state.orgs[0].effective_mode == MfaPolicyMode.REQUIRE_TOTP.value
def test_after_primary_auth_success_no_required_policy(self, db, test_user, test_organization):
"""Test after_primary_auth_success with no required policy."""
# Create membership
membership = OrganizationMember(
user_id=test_user.id,
organization_id=test_organization.id,
role=OrganizationRole.MEMBER,
)
membership.save()
result = MfaPolicyService.after_primary_auth_success(test_user)
assert result.can_create_full_session is True
assert result.create_compliance_only_session is False
assert result.compliance_summary.overall_status == MfaComplianceStatus.COMPLIANT.value
def test_after_primary_auth_success_in_grace(self, db, test_user, test_organization):
"""Test after_primary_auth_success when user is in grace period."""
# Create membership
membership = OrganizationMember(
user_id=test_user.id,
organization_id=test_organization.id,
role=OrganizationRole.MEMBER,
)
membership.save()
# Create org policy
org_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
)
org_policy.save()
result = MfaPolicyService.after_primary_auth_success(test_user)
assert result.can_create_full_session is True
assert result.create_compliance_only_session is False
assert result.compliance_summary.overall_status == MfaComplianceStatus.IN_GRACE.value
def test_after_primary_auth_success_past_due(self, db, test_user, test_organization):
"""Test after_primary_auth_success when user is past due."""
# Create membership
membership = OrganizationMember(
user_id=test_user.id,
organization_id=test_organization.id,
role=OrganizationRole.MEMBER,
)
membership.save()
# Create org policy
org_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
)
org_policy.save()
# Create compliance record past due
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.PAST_DUE,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=30),
deadline_at=datetime.now(timezone.utc) - timedelta(days=1),
)
compliance.save()
result = MfaPolicyService.after_primary_auth_success(test_user)
assert result.can_create_full_session is False
assert result.create_compliance_only_session is True
def test_create_org_policy_new(self, db, test_organization):
"""Test creating a new organization policy."""
policy = MfaPolicyService.create_org_policy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
updated_by_user_id=None,
)
assert policy is not None
assert policy.organization_id == test_organization.id
assert policy.mfa_policy_mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN
assert policy.policy_version == 1
def test_create_org_policy_update(self, db, test_organization):
"""Test updating an existing organization policy."""
# Create initial policy
initial_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
mfa_grace_period_days=14,
)
initial_policy.save()
# Update policy
updated_policy = MfaPolicyService.create_org_policy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=7,
updated_by_user_id=None,
)
assert updated_policy.mfa_policy_mode == MfaPolicyMode.REQUIRE_TOTP
assert updated_policy.mfa_grace_period_days == 7
assert updated_policy.policy_version == 2
def test_set_user_override_new(self, db, test_user, test_organization):
"""Test setting a new user override."""
override = MfaPolicyService.set_user_override(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.REQUIRED,
force_totp=True,
force_webauthn=False,
updated_by_user_id=None,
)
assert override is not None
assert override.user_id == test_user.id
assert override.organization_id == test_organization.id
assert override.mfa_override_mode == MfaRequirementOverride.REQUIRED
assert override.force_totp is True
def test_set_user_override_update(self, db, test_user, test_organization):
"""Test updating an existing user override."""
# Create initial override
initial_override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.INHERIT,
)
initial_override.save()
# Update override
updated_override = MfaPolicyService.set_user_override(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.EXEMPT,
updated_by_user_id=None,
)
assert updated_override.mfa_override_mode == MfaRequirementOverride.EXEMPT
def test_get_user_compliance(self, db, test_user, test_organization):
"""Test getting user compliance record."""
# Create compliance record
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.COMPLIANT,
policy_version=1,
)
compliance.save()
result = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id)
assert result is not None
assert result.status == MfaComplianceStatus.COMPLIANT
def test_get_user_compliance_not_found(self, db, test_user, test_organization):
"""Test getting user compliance record when none exists."""
result = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id)
assert result is None
def test_get_org_compliance_list(self, db, test_user, test_organization):
"""Test getting organization compliance list."""
# Create compliance record
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.IN_GRACE,
policy_version=1,
deadline_at=datetime.now(timezone.utc) + timedelta(days=14),
)
compliance.save()
results = MfaPolicyService.get_org_compliance_list(test_organization.id)
assert len(results) == 1
assert results[0]["user_id"] == test_user.id
assert results[0]["status"] == MfaComplianceStatus.IN_GRACE.value
def test_get_org_compliance_list_with_status_filter(self, db, test_user, test_organization):
"""Test getting organization compliance list with status filter."""
# Create compliance record
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.COMPLIANT,
policy_version=1,
)
compliance.save()
# Filter by different status
results = MfaPolicyService.get_org_compliance_list(
test_organization.id, status=MfaComplianceStatus.IN_GRACE
)
assert len(results) == 0
# Filter by correct status
results = MfaPolicyService.get_org_compliance_list(
test_organization.id, status=MfaComplianceStatus.COMPLIANT
)
assert len(results) == 1
@pytest.mark.unit
class TestMfaPolicyServiceDto:
"""Tests for MfaPolicyService DTOs."""
def test_org_policy_dto(self):
"""Test OrgPolicyDto creation."""
dto = OrgPolicyDto(
organization_id="org-123",
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP.value,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
assert dto.organization_id == "org-123"
assert dto.mfa_policy_mode == "require_totp"
assert dto.mfa_grace_period_days == 14
def test_effective_user_policy_dto(self):
"""Test EffectiveUserPolicyDto creation."""
dto = EffectiveUserPolicyDto(
organization_id="org-123",
effective_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value,
requires_totp=True,
requires_webauthn=True,
grace_period_days=14,
is_exempt=False,
)
assert dto.requires_totp is True
assert dto.requires_webauthn is True
assert dto.is_exempt is False
def test_aggregate_mfa_state_dto(self):
"""Test AggregateMfaStateDto creation."""
dto = AggregateMfaStateDto(
overall_status=MfaComplianceStatus.IN_GRACE.value,
missing_methods=["totp"],
deadline_at="2025-02-01T00:00:00Z",
orgs=[],
)
assert dto.overall_status == "in_grace"
assert "totp" in dto.missing_methods
assert dto.deadline_at == "2025-02-01T00:00:00Z"
def test_login_policy_result(self):
"""Test LoginPolicyResult creation."""
summary = AggregateMfaStateDto(
overall_status=MfaComplianceStatus.IN_GRACE.value,
missing_methods=["totp"],
orgs=[],
)
result = LoginPolicyResult(
can_create_full_session=True,
create_compliance_only_session=False,
compliance_summary=summary,
)
assert result.can_create_full_session is True
assert result.create_compliance_only_session is False
assert result.compliance_summary.overall_status == "in_grace"