remove junk

This commit is contained in:
2026-02-23 13:25:05 +10:30
parent 7637d7df45
commit cbdf6185b6
24 changed files with 0 additions and 6789 deletions
-1
View File
@@ -1 +0,0 @@
"""Tests package."""
-375
View File
@@ -1,375 +0,0 @@
"""Pytest configuration and fixtures."""
import pytest
from unittest.mock import Mock, patch
from datetime import datetime, timedelta, timezone
from gatehouse_app import create_app
from gatehouse_app.extensions import db as _db
from gatehouse_app.models import User, Organization, OrganizationMember, AuthenticationMethod
from gatehouse_app.services.auth_service import AuthService
from gatehouse_app.utils.constants import OrganizationRole, AuthMethodType
from gatehouse_app.services.external_auth_service import ExternalProviderConfig, OAuthState
@pytest.fixture(scope="session")
def app():
"""Create application for testing."""
app = create_app("testing")
return app
@pytest.fixture(scope="function")
def db(app):
"""Create database for testing."""
with app.app_context():
_db.create_all()
yield _db
_db.session.remove()
_db.drop_all()
@pytest.fixture(scope="function")
def client(app, db):
"""Create test client."""
return app.test_client()
@pytest.fixture(scope="function")
def test_user(db):
"""Create a test user."""
email = "test@example.com"
password = "TestPassword123!"
full_name = "Test User"
user = AuthService.register_user(
email=email,
password=password,
full_name=full_name,
)
# Store password for testing
user._test_password = password
return user
@pytest.fixture(scope="function")
def test_organization(db, test_user):
"""Create a test organization."""
from gatehouse_app.services.organization_service import OrganizationService
org = OrganizationService.create_organization(
name="Test Organization",
slug="test-org",
owner_user_id=test_user.id,
description="A test organization",
)
return org
@pytest.fixture(scope="function")
def authenticated_client(client, test_user):
"""Create authenticated test client."""
# Login
response = client.post(
"/api/v1/auth/login",
json={
"email": test_user.email,
"password": test_user._test_password,
},
)
assert response.status_code == 200
return client
@pytest.fixture(scope="function")
def second_test_user(db):
"""Create a second test user."""
email = "second@example.com"
password = "TestPassword123!"
full_name = "Second User"
user = AuthService.register_user(
email=email,
password=password,
full_name=full_name,
)
user._test_password = password
return user
# =============================================================================
# External Auth Testing Fixtures
# =============================================================================
@pytest.fixture(scope="function")
def google_provider_config(db, test_organization):
"""Create a Google OAuth provider configuration."""
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-google-client-id",
client_secret_encrypted="encrypted-google-secret",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=[
"http://localhost:3000/callback",
"http://localhost:5173/callback",
"https://myapp.example.com/callback",
],
is_active=True,
)
config.save()
return config
@pytest.fixture(scope="function")
def github_provider_config(db, test_organization):
"""Create a GitHub OAuth provider configuration."""
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GITHUB.value,
client_id="test-github-client-id",
client_secret_encrypted="encrypted-github-secret",
auth_url="https://github.com/login/oauth/authorize",
token_url="https://github.com/login/oauth/access_token",
userinfo_url="https://api.github.com/user",
scopes=["read:user", "user:email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
return config
@pytest.fixture(scope="function")
def microsoft_provider_config(db, test_organization):
"""Create a Microsoft OAuth provider configuration."""
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.MICROSOFT.value,
client_id="test-microsoft-client-id",
client_secret_encrypted="encrypted-microsoft-secret",
auth_url="https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
token_url="https://login.microsoftonline.com/common/oauth2/v2.0/token",
userinfo_url="https://graph.microsoft.com/oidc/userinfo",
scopes=["openid", "profile", "email", "User.Read"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
return config
@pytest.fixture(scope="function")
def user_with_google_link(db, test_user):
"""Create a test user with a linked Google account."""
auth_method = AuthenticationMethod(
user_id=test_user.id,
method_type=AuthMethodType.GOOGLE,
provider_user_id="google-123456789",
provider_data={
"email": test_user.email,
"name": "Test User",
"picture": "https://example.com/avatar.jpg",
},
verified=True,
is_primary=False,
)
auth_method.save()
return test_user
@pytest.fixture(scope="function")
def user_with_multiple_providers(db, test_user):
"""Create a test user with multiple linked external accounts."""
# Google account
google_method = AuthenticationMethod(
user_id=test_user.id,
method_type=AuthMethodType.GOOGLE,
provider_user_id="google-123",
provider_data={
"email": test_user.email,
"name": "Test User",
},
verified=True,
)
google_method.save()
# GitHub account
github_method = AuthenticationMethod(
user_id=test_user.id,
method_type=AuthMethodType.GITHUB,
provider_user_id="github-456",
provider_data={
"email": "user@github.com",
"name": "Test User",
},
verified=True,
)
github_method.save()
return test_user
@pytest.fixture
def mock_google_oauth_token_response():
"""Mock Google OAuth token response."""
return {
"access_token": "ya29.mock-access-token",
"refresh_token": "1//mock-refresh-token",
"id_token": "eyJ.mock-id-token",
"token_type": "Bearer",
"expires_in": 3600,
"scope": "openid profile email",
}
@pytest.fixture
def mock_google_oauth_user_info():
"""Mock Google OAuth user info response."""
return {
"sub": "google-123456789",
"name": "Test User",
"given_name": "Test",
"family_name": "User",
"picture": "https://example.com/avatar.jpg",
"email": "testuser@gmail.com",
"email_verified": True,
}
@pytest.fixture
def mock_github_oauth_token_response():
"""Mock GitHub OAuth token response."""
return {
"access_token": "gho_mock-access-token",
"token_type": "bearer",
"scope": "read:user,user:email",
}
@pytest.fixture
def mock_github_oauth_user_info():
"""Mock GitHub OAuth user info response."""
return {
"id": 123456789,
"login": "testuser",
"name": "Test User",
"email": "testuser@github.com",
"avatar_url": "https://example.com/avatar.jpg",
"type": "User",
}
@pytest.fixture
def oauth_login_state(db, test_organization):
"""Create an OAuth state for login flow."""
state = OAuthState.create_state(
flow_type="login",
provider_type=AuthMethodType.GOOGLE,
organization_id=test_organization.id,
redirect_uri="http://localhost:3000/callback",
nonce="mock-nonce",
code_verifier="mock-code-verifier",
code_challenge="mock-code-challenge",
lifetime_seconds=600,
)
return state
@pytest.fixture
def oauth_register_state(db, test_organization):
"""Create an OAuth state for register flow."""
state = OAuthState.create_state(
flow_type="register",
provider_type=AuthMethodType.GOOGLE,
organization_id=test_organization.id,
redirect_uri="http://localhost:3000/callback",
lifetime_seconds=600,
)
return state
@pytest.fixture
def oauth_link_state(db, test_user, test_organization):
"""Create an OAuth state for link flow."""
state = OAuthState.create_state(
flow_type="link",
provider_type=AuthMethodType.GOOGLE,
user_id=test_user.id,
organization_id=test_organization.id,
redirect_uri="http://localhost:3000/callback",
lifetime_seconds=600,
)
return state
@pytest.fixture
def expired_oauth_state(db, test_organization):
"""Create an expired OAuth state."""
state = OAuthState.create_state(
flow_type="login",
provider_type=AuthMethodType.GOOGLE,
organization_id=test_organization.id,
redirect_uri="http://localhost:3000/callback",
lifetime_seconds=-1, # Already expired
)
return state
@pytest.fixture
def used_oauth_state(db, test_organization):
"""Create a used OAuth state."""
state = OAuthState.create_state(
flow_type="login",
provider_type=AuthMethodType.GOOGLE,
organization_id=test_organization.id,
redirect_uri="http://localhost:3000/callback",
lifetime_seconds=600,
)
state.mark_used()
return state
@pytest.fixture
def mock_oauth_flow_mocks():
"""Common mocks for OAuth flow tests."""
with patch.object(
ExternalProviderConfig, 'get_client_secret', return_value='mock-secret'
) as mock_get_secret, patch(
'requests.post'
) as mock_post, patch(
'requests.get'
) as mock_get:
# Mock token exchange response
mock_post.return_value.json.return_value = {
"access_token": "mock-access-token",
"refresh_token": "mock-refresh-token",
"id_token": "mock-id-token",
"expires_in": 3600,
}
mock_post.return_value.raise_for_status = Mock()
# Mock user info response
mock_get.return_value.json.return_value = {
"sub": "google-123",
"email": "testuser@gmail.com",
"email_verified": True,
"name": "Test User",
"picture": "https://example.com/avatar.jpg",
}
mock_get.return_value.raise_for_status = Mock()
yield {
'get_secret': mock_get_secret,
'post': mock_post,
'get': mock_get,
}
-1
View File
@@ -1 +0,0 @@
"""Integration tests package."""
-107
View File
@@ -1,107 +0,0 @@
"""Integration tests for authentication flow."""
import pytest
import json
@pytest.mark.integration
class TestAuthFlow:
"""Integration tests for authentication endpoints."""
def test_register_login_logout_flow(self, client, db):
"""Test complete registration, login, and logout flow."""
# Register
register_data = {
"email": "integration@example.com",
"password": "TestPassword123!",
"password_confirm": "TestPassword123!",
"full_name": "Integration Test",
}
response = client.post(
"/api/v1/auth/register",
data=json.dumps(register_data),
content_type="application/json",
)
assert response.status_code == 201
data = response.get_json()
assert data["success"] is True
assert "user" in data["data"]
assert data["data"]["user"]["email"] == "integration@example.com"
# Logout
response = client.post("/api/v1/auth/logout")
assert response.status_code == 200
# Login
login_data = {
"email": "integration@example.com",
"password": "TestPassword123!",
}
response = client.post(
"/api/v1/auth/login",
data=json.dumps(login_data),
content_type="application/json",
)
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
assert "user" in data["data"]
# Logout again
response = client.post("/api/v1/auth/logout")
assert response.status_code == 200
def test_get_current_user_authenticated(self, authenticated_client):
"""Test getting current user when authenticated."""
response = authenticated_client.get("/api/v1/auth/me")
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
assert "user" in data["data"]
def test_get_current_user_unauthenticated(self, client):
"""Test getting current user when not authenticated."""
response = client.get("/api/v1/auth/me")
assert response.status_code == 401
data = response.get_json()
assert data["success"] is False
def test_invalid_credentials(self, client, test_user):
"""Test login with invalid credentials."""
login_data = {
"email": test_user.email,
"password": "WrongPassword123!",
}
response = client.post(
"/api/v1/auth/login",
data=json.dumps(login_data),
content_type="application/json",
)
assert response.status_code == 401
data = response.get_json()
assert data["success"] is False
def test_duplicate_registration(self, client, test_user):
"""Test registering with existing email."""
register_data = {
"email": test_user.email,
"password": "TestPassword123!",
"password_confirm": "TestPassword123!",
}
response = client.post(
"/api/v1/auth/register",
data=json.dumps(register_data),
content_type="application/json",
)
assert response.status_code == 409
data = response.get_json()
assert data["success"] is False
@@ -1,696 +0,0 @@
"""Integration tests for external authentication API flows."""
import pytest
import json
from unittest.mock import patch, Mock
from gatehouse_app.services.external_auth_service import (
ExternalAuthService,
ExternalProviderConfig,
OAuthState,
)
from gatehouse_app.services.audit_service import AuditService
from gatehouse_app.utils.constants import AuthMethodType, OrganizationRole
from gatehouse_app.models import User, AuthenticationMethod, OrganizationMember
@pytest.mark.integration
class TestExternalAuthApiFlows:
"""Integration tests for external auth API flows."""
def test_complete_account_linking_flow(
self, app, db, client, test_user, test_organization
):
"""Test complete account linking flow: initiate → callback → complete."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
client_secret_encrypted="encrypted-secret",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
# Create organization membership
member = OrganizationMember(
user_id=test_user.id,
organization_id=test_organization.id,
role=OrganizationRole.MEMBER,
)
member.save()
# Login to get token
login_response = client.post(
"/api/v1/auth/login",
json={
"email": test_user.email,
"password": test_user._test_password,
},
)
assert login_response.status_code == 200
token = login_response.get_json()["data"]["token"]
with patch.object(
ExternalAuthService, '_exchange_code'
) as mock_exchange, patch.object(
ExternalAuthService, '_get_user_info'
) as mock_get_user_info:
# Mock external provider responses
mock_exchange.return_value = {
"access_token": "mock-access-token",
"refresh_token": "mock-refresh-token",
"id_token": "mock-id-token",
"expires_in": 3600,
}
mock_get_user_info.return_value = {
"provider_user_id": "google-123",
"email": "user@gmail.com",
"email_verified": True,
"name": "Test User",
"picture": "https://example.com/avatar.jpg",
"raw_data": {},
}
# Step 1: Initiate link flow
initiate_response = client.post(
"/api/v1/auth/external/google/link",
json={},
headers={"Authorization": f"Bearer {token}"},
)
assert initiate_response.status_code == 200
initiate_data = initiate_response.get_json()
assert "authorization_url" in initiate_data["data"]
assert "state" in initiate_data["data"]
state = initiate_data["data"]["state"]
# Step 2: Simulate callback (complete link flow)
with patch.object(AuditService, 'log_external_auth_link_completed'):
complete_response = client.get(
f"/api/v1/auth/external/google/callback",
query_string={
"code": "mock-auth-code",
"state": state,
},
)
# The callback returns 200 on success
assert complete_response.status_code == 200
# Verify account is linked
auth_method = AuthenticationMethod.query.filter_by(
user_id=test_user.id,
method_type=AuthMethodType.GOOGLE,
provider_user_id="google-123",
).first()
assert auth_method is not None
def test_complete_login_flow(
self, app, db, client, test_user, test_organization
):
"""Test complete login flow: initiate → callback → authenticate."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
client_secret_encrypted="encrypted-secret",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
# Create authentication method for user
auth_method = AuthenticationMethod(
user_id=test_user.id,
method_type=AuthMethodType.GOOGLE,
provider_user_id="google-123",
provider_data={"email": test_user.email},
verified=True,
)
auth_method.save()
with patch.object(
ExternalAuthService, '_exchange_code'
) as mock_exchange, patch.object(
ExternalAuthService, '_get_user_info'
) as mock_get_user_info:
# Mock external provider responses
mock_exchange.return_value = {
"access_token": "mock-access-token",
"refresh_token": "mock-refresh-token",
"id_token": "mock-id-token",
"expires_in": 3600,
}
mock_get_user_info.return_value = {
"provider_user_id": "google-123",
"email": test_user.email,
"email_verified": True,
"name": "Test User",
"picture": "https://example.com/avatar.jpg",
"raw_data": {},
}
# Initiate login flow
login_init_response = client.get(
"/api/v1/auth/external/google/authorize",
query_string={"flow": "login"},
)
assert login_init_response.status_code == 200
login_init_data = login_init_response.get_json()
assert "authorization_url" in login_init_data["data"]
state = login_init_data["data"]["state"]
# Simulate callback
callback_response = client.get(
f"/api/v1/auth/external/google/callback",
query_string={
"code": "mock-auth-code",
"state": state,
},
)
assert callback_response.status_code == 200
callback_data = callback_response.get_json()
assert callback_data["success"] is True
assert callback_data["flow_type"] == "login"
assert "token" in callback_data["data"]
assert callback_data["data"]["user"]["id"] == test_user.id
def test_account_unlinking_flow(
self, app, db, client, test_user, test_organization
):
"""Test account unlinking flow."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
# Create organization membership
member = OrganizationMember(
user_id=test_user.id,
organization_id=test_organization.id,
role=OrganizationRole.MEMBER,
)
member.save()
# Create password auth method
password_method = AuthenticationMethod(
user_id=test_user.id,
method_type=AuthMethodType.PASSWORD,
provider_user_id=test_user.id,
)
password_method.save()
# Create Google auth method
google_method = AuthenticationMethod(
user_id=test_user.id,
method_type=AuthMethodType.GOOGLE,
provider_user_id="google-123",
provider_data={"email": test_user.email},
verified=True,
)
google_method.save()
# Login to get token
login_response = client.post(
"/api/v1/auth/login",
json={
"email": test_user.email,
"password": test_user._test_password,
},
)
token = login_response.get_json()["data"]["token"]
# Unlink Google account
with patch.object(AuditService, 'log_external_auth_unlink'):
unlink_response = client.delete(
"/api/v1/auth/external/google/unlink",
headers={"Authorization": f"Bearer {token}"},
)
assert unlink_response.status_code == 200
unlink_data = unlink_response.get_json()
assert "success" in unlink_data or "message" in unlink_data
# Verify account is unlinked
auth_method = AuthenticationMethod.query.filter_by(
user_id=test_user.id,
method_type=AuthMethodType.GOOGLE,
).first()
assert auth_method is None
def test_provider_configuration_crud(
self, app, db, client, test_user, test_organization
):
"""Test provider configuration CRUD operations."""
with app.app_context():
# Create organization membership as admin
member = OrganizationMember(
user_id=test_user.id,
organization_id=test_organization.id,
role=OrganizationRole.ADMIN,
)
member.save()
# Login to get token
login_response = client.post(
"/api/v1/auth/login",
json={
"email": test_user.email,
"password": test_user._test_password,
},
)
token = login_response.get_json()["data"]["token"]
# Step 1: Create provider config
with patch.object(AuditService, 'log_external_auth_config_create'):
create_response = client.post(
"/api/v1/auth/external/google/config",
json={
"client_id": "new-client-id",
"client_secret": "new-client-secret",
"scopes": ["openid", "profile", "email"],
"redirect_uris": ["http://localhost:3000/callback"],
},
headers={"Authorization": f"Bearer {token}"},
)
assert create_response.status_code == 201
create_data = create_response.get_json()
assert create_data["data"]["provider_type"] == "google"
assert create_data["data"]["client_id"] == "new-client-id"
config_id = create_data["data"]["id"]
# Step 2: List providers
list_response = client.get(
"/api/v1/auth/external/providers",
headers={"Authorization": f"Bearer {token}"},
)
assert list_response.status_code == 200
list_data = list_response.get_json()
google_provider = next(
p for p in list_data["data"]["providers"] if p["id"] == "google"
)
assert google_provider["is_configured"] is True
# Step 3: Get provider config
get_response = client.get(
"/api/v1/auth/external/google/config",
headers={"Authorization": f"Bearer {token}"},
)
assert get_response.status_code == 200
get_data = get_response.get_json()
assert get_data["data"]["client_id"] == "new-client-id"
# Step 4: Update provider config
with patch.object(AuditService, 'log_external_auth_config_update'):
update_response = client.post(
"/api/v1/auth/external/google/config",
json={
"client_id": "updated-client-id",
"client_secret": "updated-client-secret",
},
headers={"Authorization": f"Bearer {token}"},
)
assert update_response.status_code == 200
update_data = update_response.get_json()
assert update_data["data"]["client_id"] == "updated-client-id"
# Step 5: Delete provider config
with patch.object(AuditService, 'log_external_auth_config_delete'):
delete_response = client.delete(
"/api/v1/auth/external/google/config",
headers={"Authorization": f"Bearer {token}"},
)
assert delete_response.status_code == 200
# Verify deletion
get_deleted_response = client.get(
"/api/v1/auth/external/google/config",
headers={"Authorization": f"Bearer {token}"},
)
assert get_deleted_response.status_code == 404
def test_invalid_state_error(self, app, db, client, test_user, test_organization):
"""Test error handling for invalid OAuth state."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
# Try callback with invalid state
callback_response = client.get(
"/api/v1/auth/external/google/callback",
query_string={
"code": "mock-auth-code",
"state": "invalid-state",
},
)
assert callback_response.status_code == 400
callback_data = callback_response.get_json()
assert callback_data["error_type"] == "INVALID_STATE"
def test_expired_state_error(self, app, db, client, test_user, test_organization):
"""Test error handling for expired OAuth state."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
# Create expired state
state = OAuthState.create_state(
flow_type="login",
provider_type=AuthMethodType.GOOGLE,
organization_id=test_organization.id,
lifetime_seconds=-1, # Already expired
)
# Try callback with expired state
callback_response = client.get(
"/api/v1/auth/external/google/callback",
query_string={
"code": "mock-auth-code",
"state": state.state,
},
)
assert callback_response.status_code == 400
callback_data = callback_response.get_json()
assert callback_data["error_type"] == "INVALID_STATE"
def test_provider_not_configured_error(
self, app, db, client, test_user, test_organization
):
"""Test error handling when provider is not configured."""
with app.app_context():
# Create organization membership
member = OrganizationMember(
user_id=test_user.id,
organization_id=test_organization.id,
role=OrganizationRole.MEMBER,
)
member.save()
# Login to get token
login_response = client.post(
"/api/v1/auth/login",
json={
"email": test_user.email,
"password": test_user._test_password,
},
)
token = login_response.get_json()["data"]["token"]
# Try to link with unconfigured provider
link_response = client.post(
"/api/v1/auth/external/google/link",
json={},
headers={"Authorization": f"Bearer {token}"},
)
assert link_response.status_code == 400
link_data = link_response.get_json()
assert link_data["error_type"] == "PROVIDER_NOT_CONFIGURED"
def test_linked_accounts_list(self, app, db, client, test_user, test_organization):
"""Test listing linked accounts."""
with app.app_context():
# Create organization membership
member = OrganizationMember(
user_id=test_user.id,
organization_id=test_organization.id,
role=OrganizationRole.MEMBER,
)
member.save()
# Create authentication methods
google_method = AuthenticationMethod(
user_id=test_user.id,
method_type=AuthMethodType.GOOGLE,
provider_user_id="google-123",
provider_data={
"email": test_user.email,
"name": "Test User",
"picture": "https://example.com/avatar.jpg",
},
verified=True,
)
google_method.save()
github_method = AuthenticationMethod(
user_id=test_user.id,
method_type=AuthMethodType.GITHUB,
provider_user_id="github-456",
provider_data={
"email": "user@github.com",
"name": "Test User",
},
verified=True,
)
github_method.save()
# Login to get token
login_response = client.post(
"/api/v1/auth/login",
json={
"email": test_user.email,
"password": test_user._test_password,
},
)
token = login_response.get_json()["data"]["token"]
# List linked accounts
list_response = client.get(
"/api/v1/auth/external/linked-accounts",
headers={"Authorization": f"Bearer {token}"},
)
assert list_response.status_code == 200
list_data = list_response.get_json()
assert len(list_data["data"]["linked_accounts"]) == 2
assert list_data["data"]["unlink_available"] is True
def test_non_admin_cannot_manage_providers(
self, app, db, client, test_user, test_organization
):
"""Test that non-admin users cannot manage provider configurations."""
with app.app_context():
# Create organization membership as regular member
member = OrganizationMember(
user_id=test_user.id,
organization_id=test_organization.id,
role=OrganizationRole.MEMBER,
)
member.save()
# Login to get token
login_response = client.post(
"/api/v1/auth/login",
json={
"email": test_user.email,
"password": test_user._test_password,
},
)
token = login_response.get_json()["data"]["token"]
# Try to create provider config (should fail)
create_response = client.post(
"/api/v1/auth/external/google/config",
json={
"client_id": "client-id",
"client_secret": "client-secret",
},
headers={"Authorization": f"Bearer {token}"},
)
assert create_response.status_code == 403
assert create_response.get_json()["error_type"] == "FORBIDDEN"
def test_unsupported_provider_error(
self, app, db, client, test_user, test_organization
):
"""Test error handling for unsupported provider."""
with app.app_context():
# Create organization membership
member = OrganizationMember(
user_id=test_user.id,
organization_id=test_organization.id,
role=OrganizationRole.MEMBER,
)
member.save()
# Login to get token
login_response = client.post(
"/api/v1/auth/login",
json={
"email": test_user.email,
"password": test_user._test_password,
},
)
token = login_response.get_json()["data"]["token"]
# Try to link with unsupported provider
link_response = client.post(
"/api/v1/auth/external/unsupported/link",
json={},
headers={"Authorization": f"Bearer {token}"},
)
assert link_response.status_code == 400
link_data = link_response.get_json()
assert link_data["error_type"] == "UNSUPPORTED_PROVIDER"
@pytest.mark.integration
class TestExternalAuthAuditLogging:
"""Integration tests for audit logging in external auth flows."""
@patch('gatehouse_app.services.audit_service.AuditService')
def test_audit_log_on_link_initiated(
self, mock_audit, app, db, client, test_user, test_organization
):
"""Test audit log is created when link flow is initiated."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
# Create organization membership
member = OrganizationMember(
user_id=test_user.id,
organization_id=test_organization.id,
role=OrganizationRole.MEMBER,
)
member.save()
# Login to get token
login_response = client.post(
"/api/v1/auth/login",
json={
"email": test_user.email,
"password": test_user._test_password,
},
)
token = login_response.get_json()["data"]["token"]
# Initiate link flow
link_response = client.post(
"/api/v1/auth/external/google/link",
json={},
headers={"Authorization": f"Bearer {token}"},
)
# Verify audit log was called
mock_audit.log_external_auth_link_initiated.assert_called_once()
@patch('gatehouse_app.services.audit_service.AuditService')
def test_audit_log_on_unlink(
self, mock_audit, app, db, client, test_user, test_organization
):
"""Test audit log is created when account is unlinked."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
# Create organization membership
member = OrganizationMember(
user_id=test_user.id,
organization_id=test_organization.id,
role=OrganizationRole.MEMBER,
)
member.save()
# Create password auth method
password_method = AuthenticationMethod(
user_id=test_user.id,
method_type=AuthMethodType.PASSWORD,
provider_user_id=test_user.id,
)
password_method.save()
# Create Google auth method
google_method = AuthenticationMethod(
user_id=test_user.id,
method_type=AuthMethodType.GOOGLE,
provider_user_id="google-123",
provider_data={"email": test_user.email},
verified=True,
)
google_method.save()
# Login to get token
login_response = client.post(
"/api/v1/auth/login",
json={
"email": test_user.email,
"password": test_user._test_password,
},
)
token = login_response.get_json()["data"]["token"]
# Unlink Google account
unlink_response = client.delete(
"/api/v1/auth/external/google/unlink",
headers={"Authorization": f"Bearer {token}"},
)
# Verify audit log was called
mock_audit.log_external_auth_unlink.assert_called_once()
-933
View File
@@ -1,933 +0,0 @@
"""Integration tests for MFA compliance enforcement."""
import pytest
import json
from datetime import datetime, timezone, timedelta
from gatehouse_app.models.user import User
from gatehouse_app.models.organization import Organization
from gatehouse_app.models.organization_member import OrganizationMember
from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
from gatehouse_app.models.user_security_policy import UserSecurityPolicy
from gatehouse_app.models.session import Session
from gatehouse_app.utils.constants import MfaPolicyMode, MfaComplianceStatus, UserStatus, MfaRequirementOverride
from gatehouse_app.services.mfa_policy_service import MfaPolicyService
@pytest.mark.integration
class TestMfaComplianceLogin:
"""Integration tests for MFA compliance during login."""
def test_login_with_no_policy(self, client, db, test_user):
"""Test login with no MFA policy (should work normally)."""
login_data = {
"email": test_user.email,
"password": "TestPassword123!",
}
response = client.post(
"/api/v1/auth/login",
data=json.dumps(login_data),
content_type="application/json",
)
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
assert "user" in data["data"]
assert "token" in data["data"]
# No MFA compliance info should be present when no policy exists
assert "mfa_compliance" not in data["data"]
assert "requires_mfa_enrollment" not in data["data"]
def test_login_with_optional_policy(self, client, db, test_user, test_organization):
"""Test login with optional MFA policy (should work normally)."""
# Create an optional MFA policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
login_data = {
"email": test_user.email,
"password": "TestPassword123!",
}
response = client.post(
"/api/v1/auth/login",
data=json.dumps(login_data),
content_type="application/json",
)
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
assert "user" in data["data"]
assert "token" in data["data"]
# MFA compliance should be present but status should be not_applicable
assert "mfa_compliance" in data["data"]
assert data["data"]["mfa_compliance"]["overall_status"] == "not_applicable"
assert "requires_mfa_enrollment" not in data["data"]
def test_login_with_required_policy_in_grace_period(self, client, db, test_user, test_organization):
"""Test login with required policy within grace period (should work with warning)."""
# Create a required MFA policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
login_data = {
"email": test_user.email,
"password": "TestPassword123!",
}
response = client.post(
"/api/v1/auth/login",
data=json.dumps(login_data),
content_type="application/json",
)
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
assert "user" in data["data"]
assert "token" in data["data"]
# MFA compliance should be present with in_grace status
assert "mfa_compliance" in data["data"]
assert data["data"]["mfa_compliance"]["overall_status"] == "in_grace"
assert "requires_mfa_enrollment" not in data["data"]
assert "totp" in data["data"]["mfa_compliance"]["missing_methods"]
def test_login_with_required_policy_after_deadline(self, client, db, test_user, test_organization):
"""Test login with required policy after deadline (should get compliance-only session)."""
# Create a required MFA policy with past deadline
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
# Create compliance record past due
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.PAST_DUE,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
deadline_at=past_deadline,
)
db.session.add(compliance)
db.session.commit()
login_data = {
"email": test_user.email,
"password": "TestPassword123!",
}
response = client.post(
"/api/v1/auth/login",
data=json.dumps(login_data),
content_type="application/json",
)
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
assert "user" in data["data"]
assert "token" in data["data"]
# Should have compliance-only session
assert data["data"]["requires_mfa_enrollment"] is True
assert "mfa_compliance" in data["data"]
assert data["data"]["mfa_compliance"]["overall_status"] in ["past_due", "suspended"]
def test_login_with_suspended_user(self, client, db, test_user, test_organization):
"""Test login with compliance suspended user (should get compliance-only session)."""
# Set user status to compliance suspended
test_user.status = UserStatus.COMPLIANCE_SUSPENDED
db.session.commit()
login_data = {
"email": test_user.email,
"password": "TestPassword123!",
}
response = client.post(
"/api/v1/auth/login",
data=json.dumps(login_data),
content_type="application/json",
)
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
assert "user" in data["data"]
assert "token" in data["data"]
# Should have compliance-only session
assert data["data"]["requires_mfa_enrollment"] is True
@pytest.mark.integration
class TestMfaComplianceAccess:
"""Integration tests for MFA compliance access control."""
def test_compliance_only_session_denied_full_access(self, client, db, test_user, test_organization):
"""Test that compliance-only session cannot access full access endpoints."""
# Create a required MFA policy with past deadline
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
# Create compliance record past due
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.PAST_DUE,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
deadline_at=past_deadline,
)
db.session.add(compliance)
# Create a compliance-only session
session = Session(
user_id=test_user.id,
token="compliance_only_token",
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
is_compliance_only=True,
)
db.session.add(session)
db.session.commit()
# Try to access a full-access endpoint (get_my_organizations)
response = client.get(
"/api/v1/users/me/organizations",
headers={"Authorization": "Bearer compliance_only_token"},
)
assert response.status_code == 403
data = response.get_json()
assert data["success"] is False
assert data["error_type"] == "MFA_COMPLIANCE_REQUIRED"
def test_compliance_only_session_can_access_mfa_enrollment(self, client, db, test_user, test_organization):
"""Test that compliance-only session can access MFA enrollment endpoints."""
# Create a required MFA policy with past deadline
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
# Create compliance record past due
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.PAST_DUE,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
deadline_at=past_deadline,
)
db.session.add(compliance)
# Create a compliance-only session
session = Session(
user_id=test_user.id,
token="compliance_only_token",
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
is_compliance_only=True,
)
db.session.add(session)
db.session.commit()
# Try to access MFA enrollment endpoint (should work)
response = client.get(
"/api/v1/auth/totp/status",
headers={"Authorization": "Bearer compliance_only_token"},
)
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
def test_compliance_only_session_can_access_logout(self, client, db, test_user, test_organization):
"""Test that compliance-only session can access logout endpoint."""
# Create a required MFA policy with past deadline
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
# Create compliance record past due
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.PAST_DUE,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
deadline_at=past_deadline,
)
db.session.add(compliance)
# Create a compliance-only session
session = Session(
user_id=test_user.id,
token="compliance_only_token",
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
is_compliance_only=True,
)
db.session.add(session)
db.session.commit()
# Try to access logout endpoint (should work)
response = client.post(
"/api/v1/auth/logout",
headers={"Authorization": "Bearer compliance_only_token"},
)
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
@pytest.mark.integration
class TestMfaComplianceWebAuthn:
"""Integration tests for MFA compliance with WebAuthn login."""
def test_webauthn_login_with_required_policy_in_grace_period(self, client, db, test_user, test_organization):
"""Test WebAuthn login with required policy within grace period."""
# Create a required MFA policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
# Note: Full WebAuthn login test would require WebAuthn setup
# This test verifies the compliance response structure
login_data = {
"email": test_user.email,
"password": "TestPassword123!",
}
response = client.post(
"/api/v1/auth/login",
data=json.dumps(login_data),
content_type="application/json",
)
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
assert "mfa_compliance" in data["data"]
assert data["data"]["mfa_compliance"]["overall_status"] == "in_grace"
@pytest.mark.integration
class TestMfaComplianceOIDC:
"""Integration tests for MFA compliance with OIDC authorization."""
def test_oidc_authorize_with_compliance_required(self, client, db, test_user, test_organization, app):
"""Test OIDC authorize with compliance required (should show error)."""
# Create a required MFA policy with past deadline
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
# Create compliance record past due
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.PAST_DUE,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
deadline_at=past_deadline,
)
db.session.add(compliance)
db.session.commit()
# Try OIDC authorize with credentials
response = client.post(
"/oidc/authorize",
data={
"client_id": "test_client",
"redirect_uri": "http://localhost:8080/callback",
"response_type": "code",
"scope": "openid profile email",
"state": "test_state",
"email": test_user.email,
"password": "TestPassword123!",
},
)
# Should return login page with error
assert response.status_code == 200
assert b"Your account requires multi factor enrollment before using single sign on" in response.data
# =============================================================================
# Phase 4: Edge Case Tests
# =============================================================================
@pytest.mark.integration
class TestMfaComplianceMultiOrg:
"""Integration tests for multi-organization MFA compliance edge cases."""
def test_user_with_multiple_orgs_different_policies(self, client, db, test_user):
"""Test user belonging to multiple orgs with different MFA policies."""
# Create two organizations
org1 = Organization(
name="Org1",
slug="org1-test-multi",
)
org2 = Organization(
name="Org2",
slug="org2-test-multi",
)
db.session.add_all([org1, org2])
db.session.commit()
# Add user to both orgs
membership1 = OrganizationMember(
user_id=test_user.id,
organization_id=org1.id,
role="member",
)
membership2 = OrganizationMember(
user_id=test_user.id,
organization_id=org2.id,
role="member",
)
db.session.add_all([membership1, membership2])
db.session.commit()
# Create different policies for each org
# Org1: OPTIONAL (no requirement)
policy1 = OrganizationSecurityPolicy(
organization_id=org1.id,
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
# Org2: REQUIRE_TOTP (strictest)
policy2 = OrganizationSecurityPolicy(
organization_id=org2.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add_all([policy1, policy2])
db.session.commit()
# Evaluate user MFA state
compliance_summary = MfaPolicyService.evaluate_user_mfa_state(test_user)
# Overall status should reflect the strictest policy (REQUIRE_TOTP from org2)
assert compliance_summary.overall_status == MfaComplianceStatus.IN_GRACE.value
assert "totp" in compliance_summary.missing_methods
# Verify per-org breakdown
assert len(compliance_summary.orgs) == 2
org1_status = next((o for o in compliance_summary.orgs if o.organization_id == org1.id), None)
org2_status = next((o for o in compliance_summary.orgs if o.organization_id == org2.id), None)
assert org1_status is not None
assert org2_status is not None
assert org1_status.status == MfaComplianceStatus.NOT_APPLICABLE.value
assert org2_status.status == MfaComplianceStatus.IN_GRACE.value
def test_user_with_multiple_orgs_all_suspended(self, client, db, test_user):
"""Test user with multiple orgs where all require MFA and are past due."""
# Create two organizations
org1 = Organization(
name="Org1",
slug="org1-test-suspended",
)
org2 = Organization(
name="Org2",
slug="org2-test-suspended",
)
db.session.add_all([org1, org2])
db.session.commit()
# Add user to both orgs
membership1 = OrganizationMember(
user_id=test_user.id,
organization_id=org1.id,
role="member",
)
membership2 = OrganizationMember(
user_id=test_user.id,
organization_id=org2.id,
role="member",
)
db.session.add_all([membership1, membership2])
db.session.commit()
# Create required policies
policy1 = OrganizationSecurityPolicy(
organization_id=org1.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
policy2 = OrganizationSecurityPolicy(
organization_id=org2.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add_all([policy1, policy2])
db.session.commit()
# Create past-due compliance records for both
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
compliance1 = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=org1.id,
status=MfaComplianceStatus.SUSPENDED,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=30),
deadline_at=past_deadline,
suspended_at=past_deadline,
)
compliance2 = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=org2.id,
status=MfaComplianceStatus.SUSPENDED,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=30),
deadline_at=past_deadline,
suspended_at=past_deadline,
)
db.session.add_all([compliance1, compliance2])
db.session.commit()
# Evaluate user MFA state
compliance_summary = MfaPolicyService.evaluate_user_mfa_state(test_user)
# Overall status should be SUSPENDED
assert compliance_summary.overall_status == MfaComplianceStatus.SUSPENDED.value
def test_strictest_mode_selection(self):
"""Test that get_strictest_mode returns the most restrictive policy."""
modes = [
MfaPolicyMode.DISABLED.value,
MfaPolicyMode.OPTIONAL.value,
MfaPolicyMode.REQUIRE_TOTP.value,
]
result = MfaPolicyService.get_strictest_mode(modes)
assert result == MfaPolicyMode.REQUIRE_TOTP.value
# Test with REQUIRE_TOTP_OR_WEBAUTHN (strictest)
modes_strictest = [
MfaPolicyMode.REQUIRE_TOTP.value,
MfaPolicyMode.REQUIRE_WEBAUTHN.value,
MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value,
]
result = MfaPolicyService.get_strictest_mode(modes_strictest)
assert result == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value
@pytest.mark.integration
class TestMfaComplianceUserOverrides:
"""Integration tests for user override edge cases."""
def test_user_override_inherit_mode(self, client, db, test_user, test_organization):
"""Test INHERIT mode - org policy applies as is."""
# Create a required policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
# Create INHERIT override (default behavior)
override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.INHERIT,
)
db.session.add(override)
db.session.commit()
# Get effective policy
effective = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
# Should inherit org policy
assert effective.effective_mode == MfaPolicyMode.REQUIRE_TOTP.value
assert effective.requires_totp is True
assert effective.is_exempt is False
def test_user_override_required_mode(self, client, db, test_user, test_organization):
"""Test REQUIRED mode - user always required to have MFA."""
# Create an optional policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
# Create REQUIRED override
override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.REQUIRED,
)
db.session.add(override)
db.session.commit()
# Get effective policy
effective = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
# Should be upgraded to REQUIRE_TOTP_OR_WEBAUTHN
assert effective.effective_mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value
assert effective.requires_totp is True
assert effective.requires_webauthn is True
assert effective.is_exempt is False
def test_user_override_exempt_mode(self, client, db, test_user, test_organization):
"""Test EXEMPT mode - org policy does not apply."""
# Create a required policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
# Create EXEMPT override
override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.EXEMPT,
)
db.session.add(override)
db.session.commit()
# Get effective policy
effective = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
# Should be exempt from policy
assert effective.is_exempt is True
assert effective.effective_mode == MfaPolicyMode.DISABLED.value
assert effective.requires_totp is False
assert effective.requires_webauthn is False
def test_get_override_summary(self, client, db, test_user, test_organization):
"""Test getting override summary for a user."""
# No override exists
summary = MfaPolicyService.get_override_summary(test_user.id, test_organization.id)
assert summary["has_override"] is False
assert summary["mode"] == "inherit"
# Create an override
override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.EXEMPT,
)
db.session.add(override)
db.session.commit()
# Get summary again
summary = MfaPolicyService.get_override_summary(test_user.id, test_organization.id)
assert summary["has_override"] is True
assert summary["mode"] == "exempt"
assert summary["is_exempt"] is True
@pytest.mark.integration
class TestMfaCompliancePolicyChanges:
"""Integration tests for policy changes affecting existing users."""
def test_policy_change_triggers_compliance_reevaluation(self, client, db, test_user, test_organization):
"""Test that policy change triggers compliance reevaluation."""
# Create initial optional policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
# Create compliance record (should be NOT_APPLICABLE)
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.NOT_APPLICABLE,
policy_version=1,
)
db.session.add(compliance)
db.session.commit()
# Update policy to REQUIRE_TOTP
MfaPolicyService.create_org_policy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
notify_days_before=7,
updated_by_user_id=test_user.id,
)
# Reevaluate all compliance
updated_count = MfaPolicyService.reevaluate_all_org_compliance(test_organization.id)
# Should have updated at least one record
assert updated_count >= 1
# Check compliance status was updated
updated_compliance = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id)
assert updated_compliance.status == MfaComplianceStatus.IN_GRACE.value
assert updated_compliance.deadline_at is not None
def test_policy_relaxation_clears_requirements(self, client, db, test_user, test_organization):
"""Test that relaxing policy clears compliance requirements."""
# Create required policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
# Create IN_GRACE compliance record
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.IN_GRACE,
policy_version=1,
applied_at=datetime.now(timezone.utc),
deadline_at=datetime.now(timezone.utc) + timedelta(days=14),
)
db.session.add(compliance)
db.session.commit()
# Update policy to OPTIONAL
MfaPolicyService.create_org_policy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
mfa_grace_period_days=14,
notify_days_before=7,
updated_by_user_id=test_user.id,
)
# Reevaluate compliance
MfaPolicyService.reevaluate_all_org_compliance(test_organization.id)
# Check compliance status was updated to NOT_APPLICABLE
updated_compliance = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id)
assert updated_compliance.status == MfaComplianceStatus.NOT_APPLICABLE.value
@pytest.mark.integration
class TestMfaComplianceScheduledJob:
"""Integration tests for the MFA compliance scheduled job."""
def test_transition_to_suspended(self, client, db, test_user, test_organization):
"""Test that past-due users are transitioned to suspended."""
# Create required policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
# Create past-due compliance record
past_deadline = datetime.now(timezone.utc) - timedelta(hours=1)
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.PAST_DUE,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
deadline_at=past_deadline,
)
db.session.add(compliance)
db.session.commit()
# Run the job
now = datetime.now(timezone.utc)
suspended_count = MfaPolicyService.transition_to_suspended_if_past_due(now)
# Should have suspended the user
assert suspended_count >= 1
# Check compliance status
updated_compliance = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id)
assert updated_compliance.status == MfaComplianceStatus.SUSPENDED.value
assert updated_compliance.suspended_at is not None
# Check user status
db.refresh(test_user)
assert test_user.status == UserStatus.COMPLIANCE_SUSPENDED
def test_check_and_restore_user_status(self, client, db, test_user, test_organization):
"""Test that suspended users are restored when they become compliant."""
# Create required policy
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add(policy)
db.session.commit()
# User is suspended
test_user.status = UserStatus.COMPLIANCE_SUSPENDED
db.session.commit()
# Create EXEMPT override to clear requirement
override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.EXEMPT,
)
db.session.add(override)
db.session.commit()
# Check and restore status
restored = MfaPolicyService.check_and_restore_user_status(test_user.id)
# Should have restored user
assert restored is True
db.refresh(test_user)
assert test_user.status == UserStatus.ACTIVE
@pytest.mark.integration
class TestMfaComplianceMultiOrgAggregate:
"""Integration tests for multi-org aggregate state calculation."""
def test_get_multi_org_aggregate_state(self, client, db, test_user):
"""Test aggregate state calculation for multi-org user."""
# Create two organizations
org1 = Organization(
name="AggOrg1",
slug="agg-org1-test",
)
org2 = Organization(
name="AggOrg2",
slug="agg-org2-test",
)
db.session.add_all([org1, org2])
db.session.commit()
# Add user to both
membership1 = OrganizationMember(
user_id=test_user.id,
organization_id=org1.id,
role="member",
)
membership2 = OrganizationMember(
user_id=test_user.id,
organization_id=org2.id,
role="member",
)
db.session.add_all([membership1, membership2])
db.session.commit()
# Create policies
policy1 = OrganizationSecurityPolicy(
organization_id=org1.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
policy2 = OrganizationSecurityPolicy(
organization_id=org2.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
db.session.add_all([policy1, policy2])
db.session.commit()
# Get aggregate state
aggregate = MfaPolicyService.get_multi_org_aggregate_state(test_user)
# Verify structure
assert "overall_status" in aggregate
assert "strictest_mode" in aggregate
assert "missing_methods" in aggregate
assert "requiring_org_count" in aggregate
assert "requiring_orgs" in aggregate
assert "per_org_details" in aggregate
# Strictest mode should be REQUIRE_TOTP_OR_WEBAUTHN
assert aggregate["strictest_mode"] == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value
# Both orgs should require MFA
assert aggregate["requiring_org_count"] == 2
assert len(aggregate["requiring_orgs"]) == 2
assert len(aggregate["per_org_details"]) == 2
File diff suppressed because it is too large Load Diff
-1
View File
@@ -1 +0,0 @@
"""Unit tests package."""
-295
View File
@@ -1,295 +0,0 @@
"""Unit tests for MFA policy models."""
import pytest
from datetime import datetime, timezone, timedelta
from gatehouse_app.models import (
User,
Organization,
OrganizationMember,
OrganizationSecurityPolicy,
UserSecurityPolicy,
MfaPolicyCompliance,
Session,
)
from gatehouse_app.utils.constants import (
UserStatus,
MfaPolicyMode,
MfaComplianceStatus,
MfaRequirementOverride,
SessionStatus,
OrganizationRole,
)
@pytest.mark.unit
class TestOrganizationSecurityPolicyModel:
"""Tests for OrganizationSecurityPolicy model."""
def test_create_org_security_policy(self, db, test_organization):
"""Test creating an organization security policy."""
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
mfa_grace_period_days=14,
notify_days_before=7,
)
policy.save()
assert policy.id is not None
assert policy.organization_id == test_organization.id
assert policy.mfa_policy_mode == MfaPolicyMode.OPTIONAL
assert policy.mfa_grace_period_days == 14
assert policy.notify_days_before == 7
assert policy.policy_version == 1
assert policy.created_at is not None
def test_org_security_policy_to_dict(self, db, test_organization):
"""Test organization security policy to_dict method."""
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=7,
notify_days_before=3,
)
policy.save()
policy_dict = policy.to_dict()
assert "id" in policy_dict
assert "organization_id" in policy_dict
assert policy_dict["organization_id"] == test_organization.id
assert "mfa_policy_mode" in policy_dict
assert "mfa_grace_period_days" in policy_dict
def test_org_security_policy_relationships(self, db, test_organization):
"""Test organization security policy relationships."""
policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
)
policy.save()
# Test relationship
assert policy.organization is not None
assert policy.organization.id == test_organization.id
@pytest.mark.unit
class TestUserSecurityPolicyModel:
"""Tests for UserSecurityPolicy model."""
def test_create_user_security_policy(self, db, test_user, test_organization):
"""Test creating a user security policy."""
policy = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.INHERIT,
)
policy.save()
assert policy.id is not None
assert policy.user_id == test_user.id
assert policy.organization_id == test_organization.id
assert policy.mfa_override_mode == MfaRequirementOverride.INHERIT
assert policy.force_totp is False
assert policy.force_webauthn is False
def test_user_security_policy_with_overrides(self, db, test_user, test_organization):
"""Test user security policy with override settings."""
policy = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.REQUIRED,
force_totp=True,
force_webauthn=False,
)
policy.save()
assert policy.mfa_override_mode == MfaRequirementOverride.REQUIRED
assert policy.force_totp is True
assert policy.force_webauthn is False
def test_user_security_policy_exempt(self, db, test_user, test_organization):
"""Test user security policy with exempt override."""
policy = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.EXEMPT,
)
policy.save()
assert policy.mfa_override_mode == MfaRequirementOverride.EXEMPT
def test_user_security_policy_relationships(self, db, test_user, test_organization):
"""Test user security policy relationships."""
policy = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.INHERIT,
)
policy.save()
# Test relationships
assert policy.user is not None
assert policy.user.id == test_user.id
assert policy.organization is not None
assert policy.organization.id == test_organization.id
@pytest.mark.unit
class TestMfaPolicyComplianceModel:
"""Tests for MfaPolicyCompliance model."""
def test_create_mfa_policy_compliance(self, db, test_user, test_organization):
"""Test creating an MFA policy compliance record."""
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.NOT_APPLICABLE,
policy_version=1,
)
compliance.save()
assert compliance.id is not None
assert compliance.user_id == test_user.id
assert compliance.organization_id == test_organization.id
assert compliance.status == MfaComplianceStatus.NOT_APPLICABLE
assert compliance.policy_version == 1
assert compliance.notification_count == 0
def test_mfa_policy_compliance_in_grace(self, db, test_user, test_organization):
"""Test MFA compliance record in grace period."""
now = datetime.now(timezone.utc)
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.IN_GRACE,
policy_version=1,
applied_at=now,
deadline_at=now + timedelta(days=14),
)
compliance.save()
assert compliance.status == MfaComplianceStatus.IN_GRACE
assert compliance.applied_at is not None
assert compliance.deadline_at is not None
assert compliance.deadline_at > now
def test_mfa_policy_compliance_compliant(self, db, test_user, test_organization):
"""Test MFA compliance record when compliant."""
now = datetime.now(timezone.utc)
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.COMPLIANT,
policy_version=1,
applied_at=now - timedelta(days=30),
deadline_at=now - timedelta(days=16),
compliant_at=now - timedelta(days=16),
)
compliance.save()
assert compliance.status == MfaComplianceStatus.COMPLIANT
assert compliance.compliant_at is not None
def test_mfa_policy_compliance_suspended(self, db, test_user, test_organization):
"""Test MFA compliance record when suspended."""
now = datetime.now(timezone.utc)
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.SUSPENDED,
policy_version=1,
applied_at=now - timedelta(days=30),
deadline_at=now - timedelta(days=16),
suspended_at=now - timedelta(days=16),
)
compliance.save()
assert compliance.status == MfaComplianceStatus.SUSPENDED
assert compliance.suspended_at is not None
def test_mfa_policy_compliance_relationships(self, db, test_user, test_organization):
"""Test MFA compliance relationships."""
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.NOT_APPLICABLE,
policy_version=1,
)
compliance.save()
# Test relationships
assert compliance.user is not None
assert compliance.user.id == test_user.id
assert compliance.organization is not None
assert compliance.organization.id == test_organization.id
@pytest.mark.unit
class TestSessionModelComplianceFlag:
"""Tests for Session model compliance flag."""
def test_session_default_not_compliance_only(self, db, test_user):
"""Test that sessions are not compliance only by default."""
session = Session(
user_id=test_user.id,
token="test-token-123",
status=SessionStatus.ACTIVE,
expires_at=datetime.now(timezone.utc) + timedelta(hours=8),
last_activity_at=datetime.now(timezone.utc),
)
session.save()
assert session.is_compliance_only is False
def test_session_compliance_only(self, db, test_user):
"""Test creating a compliance-only session."""
session = Session(
user_id=test_user.id,
token="compliance-token-123",
status=SessionStatus.ACTIVE,
expires_at=datetime.now(timezone.utc) + timedelta(hours=8),
last_activity_at=datetime.now(timezone.utc),
is_compliance_only=True,
)
session.save()
assert session.is_compliance_only is True
def test_session_to_dict_excludes_token(self, db, test_user):
"""Test that session to_dict excludes the token."""
session = Session(
user_id=test_user.id,
token="test-token-456",
status=SessionStatus.ACTIVE,
expires_at=datetime.now(timezone.utc) + timedelta(hours=8),
last_activity_at=datetime.now(timezone.utc),
)
session.save()
session_dict = session.to_dict()
assert "id" in session_dict
assert "user_id" in session_dict
assert "is_compliance_only" in session_dict
assert session_dict["is_compliance_only"] is False
@pytest.mark.unit
class TestUserStatusComplianceSuspended:
"""Tests for UserStatus.COMPLIANCE_SUSPENDED."""
def test_compliance_suspended_status_exists(self):
"""Test that COMPLIANCE_SUSPENDED status exists."""
assert UserStatus.COMPLIANCE_SUSPENDED.value == "compliance_suspended"
def test_create_compliance_suspended_user(self, db):
"""Test creating a compliance suspended user."""
user = User(
email="suspended@example.com",
full_name="Suspended User",
status=UserStatus.COMPLIANCE_SUSPENDED,
)
user.save()
assert user.status == UserStatus.COMPLIANCE_SUSPENDED
-76
View File
@@ -1,76 +0,0 @@
"""Unit tests for models."""
import pytest
from datetime import datetime
from gatehouse_app.models import User, Organization
from gatehouse_app.utils.constants import UserStatus
@pytest.mark.unit
class TestUserModel:
"""Tests for User model."""
def test_create_user(self, db):
"""Test creating a user."""
user = User(
email="test@example.com",
full_name="Test User",
status=UserStatus.ACTIVE,
)
user.save()
assert user.id is not None
assert user.email == "test@example.com"
assert user.full_name == "Test User"
assert user.status == UserStatus.ACTIVE
assert user.created_at is not None
assert user.deleted_at is None
def test_user_to_dict(self, test_user):
"""Test user to_dict method."""
user_dict = test_user.to_dict()
assert "id" in user_dict
assert "email" in user_dict
assert user_dict["email"] == test_user.email
assert "created_at" in user_dict
def test_user_soft_delete(self, test_user):
"""Test soft deleting a user."""
test_user.delete(soft=True)
assert test_user.deleted_at is not None
assert isinstance(test_user.deleted_at, datetime)
@pytest.mark.unit
class TestOrganizationModel:
"""Tests for Organization model."""
def test_create_organization(self, db):
"""Test creating an organization."""
org = Organization(
name="Test Org",
slug="test-org",
description="Test organization",
)
org.save()
assert org.id is not None
assert org.name == "Test Org"
assert org.slug == "test-org"
assert org.is_active is True
assert org.created_at is not None
def test_organization_to_dict(self, test_organization):
"""Test organization to_dict method."""
org_dict = test_organization.to_dict()
assert "id" in org_dict
assert "name" in org_dict
assert org_dict["name"] == test_organization.name
assert "slug" in org_dict
def test_get_member_count(self, test_organization):
"""Test getting member count."""
count = test_organization.get_member_count()
assert count == 1 # Only the owner
-1
View File
@@ -1 +0,0 @@
"""Services unit tests package."""
@@ -1,102 +0,0 @@
"""Unit tests for AuthService."""
import pytest
from gatehouse_app.services.auth_service import AuthService
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
from gatehouse_app.exceptions.validation_exceptions import EmailAlreadyExistsError
from gatehouse_app.utils.constants import UserStatus, AuthMethodType
@pytest.mark.unit
class TestAuthService:
"""Tests for AuthService."""
def test_register_user(self, db):
"""Test user registration."""
email = "newuser@example.com"
password = "SecurePassword123!"
full_name = "New User"
user = AuthService.register_user(
email=email,
password=password,
full_name=full_name,
)
assert user.id is not None
assert user.email == email.lower()
assert user.full_name == full_name
assert user.status == UserStatus.ACTIVE
assert user.has_password_auth()
def test_register_duplicate_email(self, db, test_user):
"""Test registering with duplicate email."""
with pytest.raises(EmailAlreadyExistsError):
AuthService.register_user(
email=test_user.email,
password="SomePassword123!",
)
def test_authenticate_success(self, db, test_user):
"""Test successful authentication."""
user = AuthService.authenticate(
email=test_user.email,
password=test_user._test_password,
)
assert user.id == test_user.id
assert user.last_login_at is not None
def test_authenticate_wrong_password(self, db, test_user):
"""Test authentication with wrong password."""
with pytest.raises(InvalidCredentialsError):
AuthService.authenticate(
email=test_user.email,
password="WrongPassword123!",
)
def test_authenticate_nonexistent_user(self, db):
"""Test authentication with non-existent email."""
with pytest.raises(InvalidCredentialsError):
AuthService.authenticate(
email="nonexistent@example.com",
password="SomePassword123!",
)
def test_create_session(self, app, db, test_user):
"""Test creating a session."""
with app.test_request_context():
session = AuthService.create_session(test_user)
assert session.id is not None
assert session.user_id == test_user.id
assert session.token is not None
assert session.is_active()
def test_change_password(self, app, db, test_user):
"""Test changing password."""
with app.test_request_context():
new_password = "NewPassword456!"
AuthService.change_password(
user=test_user,
current_password=test_user._test_password,
new_password=new_password,
)
# Verify can login with new password
user = AuthService.authenticate(
email=test_user.email,
password=new_password,
)
assert user.id == test_user.id
def test_change_password_wrong_current(self, app, db, test_user):
"""Test changing password with wrong current password."""
with app.test_request_context():
with pytest.raises(InvalidCredentialsError):
AuthService.change_password(
user=test_user,
current_password="WrongPassword123!",
new_password="NewPassword456!",
)
@@ -1,698 +0,0 @@
"""Unit tests for ExternalAuthService."""
import pytest
from unittest.mock import Mock, patch, MagicMock
from datetime import datetime, timedelta, timezone
from gatehouse_app.services.external_auth_service import (
ExternalAuthService,
ExternalAuthError,
OAuthState,
ExternalProviderConfig,
)
from gatehouse_app.utils.constants import AuthMethodType
from gatehouse_app.models import User, AuthenticationMethod
@pytest.mark.unit
class TestExternalAuthService:
"""Tests for ExternalAuthService."""
def test_get_provider_config_success(self, app, db, test_organization):
"""Test getting provider configuration successfully."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
client_secret_encrypted="encrypted-secret",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
# Get config
result = ExternalAuthService.get_provider_config(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE,
)
assert result.id == config.id
assert result.client_id == "test-client-id"
assert result.is_active is True
def test_get_provider_config_not_configured(self, app, db, test_organization):
"""Test getting provider configuration when not configured."""
with app.app_context():
with pytest.raises(ExternalAuthError) as exc_info:
ExternalAuthService.get_provider_config(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE,
)
assert exc_info.value.error_type == "PROVIDER_NOT_CONFIGURED"
assert exc_info.value.status_code == 400
def test_get_provider_config_inactive(self, app, db, test_organization):
"""Test getting provider configuration when inactive."""
with app.app_context():
# Create inactive provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=False,
)
config.save()
with pytest.raises(ExternalAuthError) as exc_info:
ExternalAuthService.get_provider_config(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE,
)
assert exc_info.value.error_type == "PROVIDER_NOT_CONFIGURED"
@patch('gatehouse_app.services.external_auth_service.AuditService')
def test_initiate_link_flow_success(self, mock_audit, app, db, test_user, test_organization):
"""Test initiating account linking flow successfully."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
# Initiate link flow
auth_url, state = ExternalAuthService.initiate_link_flow(
user_id=test_user.id,
provider_type=AuthMethodType.GOOGLE,
organization_id=test_organization.id,
)
assert auth_url is not None
assert state is not None
assert len(state) == 43 # Base64 URL-safe token length
# Verify state was created
state_record = OAuthState.query.filter_by(state=state).first()
assert state_record is not None
assert state_record.flow_type == "link"
assert state_record.user_id == test_user.id
assert state_record.provider_type == AuthMethodType.GOOGLE.value
# Verify audit log
mock_audit.log_external_auth_link_initiated.assert_called_once()
def test_initiate_link_flow_invalid_redirect_uri(self, app, db, test_user, test_organization):
"""Test initiating link flow with invalid redirect URI."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
with pytest.raises(ExternalAuthError) as exc_info:
ExternalAuthService.initiate_link_flow(
user_id=test_user.id,
provider_type=AuthMethodType.GOOGLE,
organization_id=test_organization.id,
redirect_uri="http://malicious-site.com/callback",
)
assert exc_info.value.error_type == "INVALID_REDIRECT_URI"
@patch('gatehouse_app.services.external_auth_service.ExternalAuthService._exchange_code')
@patch('gatehouse_app.services.external_auth_service.ExternalAuthService._get_user_info')
@patch('gatehouse_app.services.external_auth_service.AuditService')
def test_complete_link_flow_success(
self, mock_audit, mock_get_user_info, mock_exchange_code,
app, db, test_user, test_organization
):
"""Test completing account linking flow successfully."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
# Create OAuth state
state = OAuthState.create_state(
flow_type="link",
provider_type=AuthMethodType.GOOGLE,
user_id=test_user.id,
organization_id=test_organization.id,
redirect_uri="http://localhost:3000/callback",
)
# Mock external provider responses
mock_exchange_code.return_value = {
"access_token": "mock-access-token",
"refresh_token": "mock-refresh-token",
"id_token": "mock-id-token",
"expires_in": 3600,
}
mock_get_user_info.return_value = {
"provider_user_id": "google-123",
"email": "user@gmail.com",
"email_verified": True,
"name": "Test User",
"picture": "https://example.com/avatar.jpg",
"raw_data": {},
}
# Complete link flow
auth_method = ExternalAuthService.complete_link_flow(
provider_type=AuthMethodType.GOOGLE,
authorization_code="mock-auth-code",
state=state.state,
redirect_uri="http://localhost:3000/callback",
)
assert auth_method is not None
assert auth_method.user_id == test_user.id
assert auth_method.method_type == AuthMethodType.GOOGLE
assert auth_method.provider_user_id == "google-123"
# Verify state is marked as used
state_record = OAuthState.query.get(state.id)
assert state_record.used is True
# Verify audit log
mock_audit.log_external_auth_link_completed.assert_called_once()
def test_complete_link_flow_invalid_state(self, app, db):
"""Test completing link flow with invalid state."""
with app.app_context():
with pytest.raises(ExternalAuthError) as exc_info:
ExternalAuthService.complete_link_flow(
provider_type=AuthMethodType.GOOGLE,
authorization_code="mock-auth-code",
state="invalid-state",
redirect_uri="http://localhost:3000/callback",
)
assert exc_info.value.error_type == "INVALID_STATE"
def test_complete_link_flow_wrong_flow_type(self, app, db, test_organization):
"""Test completing link flow with wrong flow type state."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
# Create login flow state instead of link
state = OAuthState.create_state(
flow_type="login",
provider_type=AuthMethodType.GOOGLE,
organization_id=test_organization.id,
redirect_uri="http://localhost:3000/callback",
)
with pytest.raises(ExternalAuthError) as exc_info:
ExternalAuthService.complete_link_flow(
provider_type=AuthMethodType.GOOGLE,
authorization_code="mock-auth-code",
state=state.state,
redirect_uri="http://localhost:3000/callback",
)
assert exc_info.value.error_type == "INVALID_FLOW_TYPE"
def test_complete_link_flow_provider_mismatch(self, app, db, test_organization):
"""Test completing link flow with provider mismatch."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
# Create state with different provider
state = OAuthState.create_state(
flow_type="link",
provider_type=AuthMethodType.GITHUB,
organization_id=test_organization.id,
redirect_uri="http://localhost:3000/callback",
)
with pytest.raises(ExternalAuthError) as exc_info:
ExternalAuthService.complete_link_flow(
provider_type=AuthMethodType.GOOGLE,
authorization_code="mock-auth-code",
state=state.state,
redirect_uri="http://localhost:3000/callback",
)
assert exc_info.value.error_type == "PROVIDER_MISMATCH"
@patch('gatehouse_app.services.external_auth_service.ExternalAuthService._exchange_code')
@patch('gatehouse_app.services.external_auth_service.ExternalAuthService._get_user_info')
@patch('gatehouse_app.services.external_auth_service.AuditService')
def test_authenticate_with_provider_success(
self, mock_audit, mock_get_user_info, mock_exchange_code,
app, db, test_user, test_organization
):
"""Test authenticating with provider successfully."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
# Create authentication method for user
auth_method = AuthenticationMethod(
user_id=test_user.id,
method_type=AuthMethodType.GOOGLE,
provider_user_id="google-123",
provider_data={"email": test_user.email},
verified=True,
)
auth_method.save()
# Create OAuth state
state = OAuthState.create_state(
flow_type="login",
provider_type=AuthMethodType.GOOGLE,
organization_id=test_organization.id,
redirect_uri="http://localhost:3000/callback",
)
# Mock external provider responses
mock_exchange_code.return_value = {
"access_token": "mock-access-token",
"refresh_token": "mock-refresh-token",
"id_token": "mock-id-token",
"expires_in": 3600,
}
mock_get_user_info.return_value = {
"provider_user_id": "google-123",
"email": test_user.email,
"email_verified": True,
"name": "Test User",
"picture": "https://example.com/avatar.jpg",
"raw_data": {},
}
# Authenticate
user, session_data = ExternalAuthService.authenticate_with_provider(
provider_type=AuthMethodType.GOOGLE,
organization_id=test_organization.id,
authorization_code="mock-auth-code",
state=state.state,
redirect_uri="http://localhost:3000/callback",
)
assert user.id == test_user.id
assert session_data is not None
assert "token" in session_data
@patch('gatehouse_app.services.external_auth_service.ExternalAuthService._exchange_code')
@patch('gatehouse_app.services.external_auth_service.ExternalAuthService._get_user_info')
@patch('gatehouse_app.services.external_auth_service.AuditService')
def test_authenticate_with_provider_account_not_found(
self, mock_audit, mock_get_user_info, mock_exchange_code,
app, db, test_organization
):
"""Test authenticating with provider when account not found."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
# Create OAuth state
state = OAuthState.create_state(
flow_type="login",
provider_type=AuthMethodType.GOOGLE,
organization_id=test_organization.id,
redirect_uri="http://localhost:3000/callback",
)
# Mock external provider responses
mock_exchange_code.return_value = {
"access_token": "mock-access-token",
"refresh_token": "mock-refresh-token",
"id_token": "mock-id-token",
"expires_in": 3600,
}
mock_get_user_info.return_value = {
"provider_user_id": "google-456",
"email": "newuser@gmail.com",
"email_verified": True,
"name": "New User",
"picture": "https://example.com/avatar.jpg",
"raw_data": {},
}
with pytest.raises(ExternalAuthError) as exc_info:
ExternalAuthService.authenticate_with_provider(
provider_type=AuthMethodType.GOOGLE,
organization_id=test_organization.id,
authorization_code="mock-auth-code",
state=state.state,
redirect_uri="http://localhost:3000/callback",
)
assert exc_info.value.error_type == "ACCOUNT_NOT_FOUND"
@patch('gatehouse_app.services.external_auth_service.AuditService')
def test_unlink_provider_success(self, mock_audit, app, db, test_user):
"""Test unlinking provider successfully."""
with app.app_context():
# Create password auth method first (so user has other methods)
password_method = AuthenticationMethod(
user_id=test_user.id,
method_type=AuthMethodType.PASSWORD,
provider_user_id=test_user.id,
)
password_method.save()
# Create Google auth method
google_method = AuthenticationMethod(
user_id=test_user.id,
method_type=AuthMethodType.GOOGLE,
provider_user_id="google-123",
provider_data={"email": test_user.email},
verified=True,
)
google_method.save()
# Unlink Google
result = ExternalAuthService.unlink_provider(
user_id=test_user.id,
provider_type=AuthMethodType.GOOGLE,
)
assert result is True
# Verify auth method is deleted
method = AuthenticationMethod.query.filter_by(
user_id=test_user.id,
method_type=AuthMethodType.GOOGLE,
).first()
assert method is None
# Verify audit log
mock_audit.log_external_auth_unlink.assert_called_once()
def test_unlink_provider_not_linked(self, app, db, test_user):
"""Test unlinking provider that is not linked."""
with app.app_context():
with pytest.raises(ExternalAuthError) as exc_info:
ExternalAuthService.unlink_provider(
user_id=test_user.id,
provider_type=AuthMethodType.GOOGLE,
)
assert exc_info.value.error_type == "PROVIDER_NOT_LINKED"
def test_unlink_provider_last_method(self, app, db, test_user):
"""Test unlinking last authentication method."""
with app.app_context():
# Create only Google auth method
google_method = AuthenticationMethod(
user_id=test_user.id,
method_type=AuthMethodType.GOOGLE,
provider_user_id="google-123",
provider_data={"email": test_user.email},
verified=True,
)
google_method.save()
with pytest.raises(ExternalAuthError) as exc_info:
ExternalAuthService.unlink_provider(
user_id=test_user.id,
provider_type=AuthMethodType.GOOGLE,
)
assert exc_info.value.error_type == "CANNOT_UNLINK_LAST"
def test_get_linked_accounts(self, app, db, test_user):
"""Test getting linked accounts for user."""
with app.app_context():
# Create Google auth method
google_method = AuthenticationMethod(
user_id=test_user.id,
method_type=AuthMethodType.GOOGLE,
provider_user_id="google-123",
provider_data={
"email": test_user.email,
"name": "Test User",
"picture": "https://example.com/avatar.jpg",
},
verified=True,
)
google_method.save()
# Create GitHub auth method
github_method = AuthenticationMethod(
user_id=test_user.id,
method_type=AuthMethodType.GITHUB,
provider_user_id="github-456",
provider_data={
"email": "user@github.com",
"name": "Test User",
},
verified=True,
)
github_method.save()
# Get linked accounts
accounts = ExternalAuthService.get_linked_accounts(test_user.id)
assert len(accounts) == 2
google_account = next(a for a in accounts if a["provider_type"] == "google")
assert google_account["provider_user_id"] == "google-123"
assert google_account["email"] == test_user.email
github_account = next(a for a in accounts if a["provider_type"] == "github")
assert github_account["provider_user_id"] == "github-456"
@pytest.mark.unit
class TestOAuthState:
"""Tests for OAuthState model."""
def test_create_state(self, app, db):
"""Test creating OAuth state."""
with app.app_context():
state = OAuthState.create_state(
flow_type="login",
provider_type=AuthMethodType.GOOGLE,
user_id="user-123",
organization_id="org-456",
redirect_uri="http://localhost:3000/callback",
)
assert state.state is not None
assert len(state.state) == 43
assert state.flow_type == "login"
assert state.provider_type == AuthMethodType.GOOGLE.value
assert state.user_id == "user-123"
assert state.organization_id == "org-456"
assert state.redirect_uri == "http://localhost:3000/callback"
assert state.used is False
assert state.expires_at > datetime.now(timezone.utc)
def test_is_valid(self, app, db):
"""Test OAuth state validity check."""
with app.app_context():
# Create valid state
state = OAuthState.create_state(
flow_type="login",
provider_type=AuthMethodType.GOOGLE,
)
assert state.is_valid() is True
# Mark as used
state.mark_used()
assert state.is_valid() is False
def test_is_valid_expired(self, app, db):
"""Test OAuth state validity with expiration."""
with app.app_context():
# Create expired state
state = OAuthState.create_state(
flow_type="login",
provider_type=AuthMethodType.GOOGLE,
lifetime_seconds=-1, # Already expired
)
assert state.is_valid() is False
@pytest.mark.unit
class TestExternalProviderConfig:
"""Tests for ExternalProviderConfig model."""
def test_is_redirect_uri_allowed(self, app, db, test_organization):
"""Test redirect URI validation."""
with app.app_context():
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=[
"http://localhost:3000/callback",
"https://myapp.com/callback",
],
is_active=True,
)
config.save()
assert config.is_redirect_uri_allowed("http://localhost:3000/callback") is True
assert config.is_redirect_uri_allowed("https://myapp.com/callback") is True
assert config.is_redirect_uri_allowed("http://malicious.com/callback") is False
def test_to_dict(self, app, db, test_organization):
"""Test converting config to dictionary."""
with app.app_context():
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
result = config.to_dict()
assert result["organization_id"] == test_organization.id
assert result["provider_type"] == AuthMethodType.GOOGLE.value
assert result["client_id"] == "test-client-id"
assert "client_secret" not in result
assert result["is_active"] is True
def test_to_dict_include_secrets(self, app, db, test_organization):
"""Test converting config to dictionary with secrets."""
with app.app_context():
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
client_secret_encrypted="encrypted-secret",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
result = config.to_dict(include_secrets=True)
assert "client_secret" in result
@pytest.mark.unit
class TestExternalAuthError:
"""Tests for ExternalAuthError exception."""
def test_error_creation(self):
"""Test creating ExternalAuthError."""
error = ExternalAuthError(
message="Test error message",
error_type="TEST_ERROR",
status_code=400,
)
assert error.message == "Test error message"
assert error.error_type == "TEST_ERROR"
assert error.status_code == 400
def test_error_default_status_code(self):
"""Test ExternalAuthError with default status code."""
error = ExternalAuthError(
message="Test error message",
error_type="TEST_ERROR",
)
assert error.status_code == 400
@@ -1,476 +0,0 @@
"""Unit tests for MfaPolicyService."""
import pytest
from datetime import datetime, timezone, timedelta
from unittest.mock import patch, MagicMock
from gatehouse_app.models import (
User,
Organization,
OrganizationMember,
OrganizationSecurityPolicy,
UserSecurityPolicy,
MfaPolicyCompliance,
Session,
)
from gatehouse_app.services.mfa_policy_service import (
MfaPolicyService,
OrgPolicyDto,
EffectiveUserPolicyDto,
AggregateMfaStateDto,
LoginPolicyResult,
)
from gatehouse_app.utils.constants import (
UserStatus,
MfaPolicyMode,
MfaComplianceStatus,
MfaRequirementOverride,
SessionStatus,
OrganizationRole,
)
@pytest.mark.unit
class TestMfaPolicyService:
"""Tests for MfaPolicyService."""
def test_get_org_policy_not_found(self, db, test_organization):
"""Test getting organization policy when none exists."""
policy = MfaPolicyService.get_org_policy(test_organization.id)
assert policy is None
def test_get_org_policy_found(self, db, test_organization):
"""Test getting organization policy when it exists."""
# Create policy
org_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
)
org_policy.save()
policy = MfaPolicyService.get_org_policy(test_organization.id)
assert policy is not None
assert policy.organization_id == test_organization.id
assert policy.mfa_policy_mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value
assert policy.mfa_grace_period_days == 14
assert policy.notify_days_before == 7
assert policy.policy_version == 1
def test_get_effective_user_policy_no_org_policy(self, db, test_user, test_organization):
"""Test effective user policy when no org policy exists."""
policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
assert policy is not None
assert policy.organization_id == test_organization.id
assert policy.effective_mode == MfaPolicyMode.DISABLED.value
assert policy.requires_totp is False
assert policy.requires_webauthn is False
assert policy.is_exempt is True
def test_get_effective_user_policy_with_org_policy(self, db, test_user, test_organization):
"""Test effective user policy with org policy and no override."""
# Create org policy
org_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
)
org_policy.save()
policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
assert policy is not None
assert policy.effective_mode == MfaPolicyMode.REQUIRE_TOTP.value
assert policy.requires_totp is True
assert policy.requires_webauthn is False
assert policy.is_exempt is False
def test_get_effective_user_policy_with_override_inherit(self, db, test_user, test_organization):
"""Test effective user policy with INHERIT override."""
# Create org policy
org_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_WEBAUTHN,
mfa_grace_period_days=7,
)
org_policy.save()
# Create user override
user_override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.INHERIT,
)
user_override.save()
policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
assert policy.effective_mode == MfaPolicyMode.REQUIRE_WEBAUTHN.value
assert policy.requires_webauthn is True
def test_get_effective_user_policy_with_override_exempt(self, db, test_user, test_organization):
"""Test effective user policy with EXEMPT override."""
# Create org policy
org_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
)
org_policy.save()
# Create user override
user_override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.EXEMPT,
)
user_override.save()
policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
assert policy.effective_mode == MfaPolicyMode.DISABLED.value
assert policy.is_exempt is True
def test_get_effective_user_policy_with_override_required(self, db, test_user, test_organization):
"""Test effective user policy with REQUIRED override."""
# Create org policy
org_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
mfa_grace_period_days=14,
)
org_policy.save()
# Create user override
user_override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.REQUIRED,
)
user_override.save()
policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
assert policy.effective_mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value
assert policy.requires_totp is True
assert policy.requires_webauthn is True
assert policy.is_exempt is False
def test_evaluate_user_mfa_state_no_policy(self, db, test_user, test_organization):
"""Test evaluating user MFA state with no policy."""
# Create membership
membership = OrganizationMember(
user_id=test_user.id,
organization_id=test_organization.id,
role=OrganizationRole.MEMBER,
)
membership.save()
state = MfaPolicyService.evaluate_user_mfa_state(test_user)
assert state is not None
assert state.overall_status == MfaComplianceStatus.COMPLIANT.value
assert len(state.missing_methods) == 0
assert len(state.orgs) == 1
def test_evaluate_user_mfa_state_with_policy(self, db, test_user, test_organization):
"""Test evaluating user MFA state with policy."""
# Create membership
membership = OrganizationMember(
user_id=test_user.id,
organization_id=test_organization.id,
role=OrganizationRole.MEMBER,
)
membership.save()
# Create org policy
org_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
)
org_policy.save()
state = MfaPolicyService.evaluate_user_mfa_state(test_user)
assert state is not None
assert state.overall_status == MfaComplianceStatus.IN_GRACE.value
assert "totp" in state.missing_methods
assert len(state.orgs) == 1
assert state.orgs[0].effective_mode == MfaPolicyMode.REQUIRE_TOTP.value
def test_after_primary_auth_success_no_required_policy(self, db, test_user, test_organization):
"""Test after_primary_auth_success with no required policy."""
# Create membership
membership = OrganizationMember(
user_id=test_user.id,
organization_id=test_organization.id,
role=OrganizationRole.MEMBER,
)
membership.save()
result = MfaPolicyService.after_primary_auth_success(test_user)
assert result.can_create_full_session is True
assert result.create_compliance_only_session is False
assert result.compliance_summary.overall_status == MfaComplianceStatus.COMPLIANT.value
def test_after_primary_auth_success_in_grace(self, db, test_user, test_organization):
"""Test after_primary_auth_success when user is in grace period."""
# Create membership
membership = OrganizationMember(
user_id=test_user.id,
organization_id=test_organization.id,
role=OrganizationRole.MEMBER,
)
membership.save()
# Create org policy
org_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
)
org_policy.save()
result = MfaPolicyService.after_primary_auth_success(test_user)
assert result.can_create_full_session is True
assert result.create_compliance_only_session is False
assert result.compliance_summary.overall_status == MfaComplianceStatus.IN_GRACE.value
def test_after_primary_auth_success_past_due(self, db, test_user, test_organization):
"""Test after_primary_auth_success when user is past due."""
# Create membership
membership = OrganizationMember(
user_id=test_user.id,
organization_id=test_organization.id,
role=OrganizationRole.MEMBER,
)
membership.save()
# Create org policy
org_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=14,
)
org_policy.save()
# Create compliance record past due
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.PAST_DUE,
policy_version=1,
applied_at=datetime.now(timezone.utc) - timedelta(days=30),
deadline_at=datetime.now(timezone.utc) - timedelta(days=1),
)
compliance.save()
result = MfaPolicyService.after_primary_auth_success(test_user)
assert result.can_create_full_session is False
assert result.create_compliance_only_session is True
def test_create_org_policy_new(self, db, test_organization):
"""Test creating a new organization policy."""
policy = MfaPolicyService.create_org_policy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
mfa_grace_period_days=14,
notify_days_before=7,
updated_by_user_id=None,
)
assert policy is not None
assert policy.organization_id == test_organization.id
assert policy.mfa_policy_mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN
assert policy.policy_version == 1
def test_create_org_policy_update(self, db, test_organization):
"""Test updating an existing organization policy."""
# Create initial policy
initial_policy = OrganizationSecurityPolicy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
mfa_grace_period_days=14,
)
initial_policy.save()
# Update policy
updated_policy = MfaPolicyService.create_org_policy(
organization_id=test_organization.id,
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
mfa_grace_period_days=7,
updated_by_user_id=None,
)
assert updated_policy.mfa_policy_mode == MfaPolicyMode.REQUIRE_TOTP
assert updated_policy.mfa_grace_period_days == 7
assert updated_policy.policy_version == 2
def test_set_user_override_new(self, db, test_user, test_organization):
"""Test setting a new user override."""
override = MfaPolicyService.set_user_override(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.REQUIRED,
force_totp=True,
force_webauthn=False,
updated_by_user_id=None,
)
assert override is not None
assert override.user_id == test_user.id
assert override.organization_id == test_organization.id
assert override.mfa_override_mode == MfaRequirementOverride.REQUIRED
assert override.force_totp is True
def test_set_user_override_update(self, db, test_user, test_organization):
"""Test updating an existing user override."""
# Create initial override
initial_override = UserSecurityPolicy(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.INHERIT,
)
initial_override.save()
# Update override
updated_override = MfaPolicyService.set_user_override(
user_id=test_user.id,
organization_id=test_organization.id,
mfa_override_mode=MfaRequirementOverride.EXEMPT,
updated_by_user_id=None,
)
assert updated_override.mfa_override_mode == MfaRequirementOverride.EXEMPT
def test_get_user_compliance(self, db, test_user, test_organization):
"""Test getting user compliance record."""
# Create compliance record
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.COMPLIANT,
policy_version=1,
)
compliance.save()
result = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id)
assert result is not None
assert result.status == MfaComplianceStatus.COMPLIANT
def test_get_user_compliance_not_found(self, db, test_user, test_organization):
"""Test getting user compliance record when none exists."""
result = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id)
assert result is None
def test_get_org_compliance_list(self, db, test_user, test_organization):
"""Test getting organization compliance list."""
# Create compliance record
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.IN_GRACE,
policy_version=1,
deadline_at=datetime.now(timezone.utc) + timedelta(days=14),
)
compliance.save()
results = MfaPolicyService.get_org_compliance_list(test_organization.id)
assert len(results) == 1
assert results[0]["user_id"] == test_user.id
assert results[0]["status"] == MfaComplianceStatus.IN_GRACE.value
def test_get_org_compliance_list_with_status_filter(self, db, test_user, test_organization):
"""Test getting organization compliance list with status filter."""
# Create compliance record
compliance = MfaPolicyCompliance(
user_id=test_user.id,
organization_id=test_organization.id,
status=MfaComplianceStatus.COMPLIANT,
policy_version=1,
)
compliance.save()
# Filter by different status
results = MfaPolicyService.get_org_compliance_list(
test_organization.id, status=MfaComplianceStatus.IN_GRACE
)
assert len(results) == 0
# Filter by correct status
results = MfaPolicyService.get_org_compliance_list(
test_organization.id, status=MfaComplianceStatus.COMPLIANT
)
assert len(results) == 1
@pytest.mark.unit
class TestMfaPolicyServiceDto:
"""Tests for MfaPolicyService DTOs."""
def test_org_policy_dto(self):
"""Test OrgPolicyDto creation."""
dto = OrgPolicyDto(
organization_id="org-123",
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP.value,
mfa_grace_period_days=14,
notify_days_before=7,
policy_version=1,
)
assert dto.organization_id == "org-123"
assert dto.mfa_policy_mode == "require_totp"
assert dto.mfa_grace_period_days == 14
def test_effective_user_policy_dto(self):
"""Test EffectiveUserPolicyDto creation."""
dto = EffectiveUserPolicyDto(
organization_id="org-123",
effective_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value,
requires_totp=True,
requires_webauthn=True,
grace_period_days=14,
is_exempt=False,
)
assert dto.requires_totp is True
assert dto.requires_webauthn is True
assert dto.is_exempt is False
def test_aggregate_mfa_state_dto(self):
"""Test AggregateMfaStateDto creation."""
dto = AggregateMfaStateDto(
overall_status=MfaComplianceStatus.IN_GRACE.value,
missing_methods=["totp"],
deadline_at="2025-02-01T00:00:00Z",
orgs=[],
)
assert dto.overall_status == "in_grace"
assert "totp" in dto.missing_methods
assert dto.deadline_at == "2025-02-01T00:00:00Z"
def test_login_policy_result(self):
"""Test LoginPolicyResult creation."""
summary = AggregateMfaStateDto(
overall_status=MfaComplianceStatus.IN_GRACE.value,
missing_methods=["totp"],
orgs=[],
)
result = LoginPolicyResult(
can_create_full_session=True,
create_compliance_only_session=False,
compliance_summary=summary,
)
assert result.can_create_full_session is True
assert result.create_compliance_only_session is False
assert result.compliance_summary.overall_status == "in_grace"
@@ -1,533 +0,0 @@
"""Unit tests for OAuthFlowService."""
import pytest
from unittest.mock import Mock, patch, MagicMock
from datetime import datetime, timedelta, timezone
from gatehouse_app.services.oauth_flow_service import (
OAuthFlowService,
OAuthFlowError,
)
from gatehouse_app.services.external_auth_service import OAuthState, ExternalProviderConfig
from gatehouse_app.utils.constants import AuthMethodType
from gatehouse_app.models import User, AuthenticationMethod
@pytest.mark.unit
class TestOAuthFlowService:
"""Tests for OAuthFlowService."""
@patch('gatehouse_app.services.oauth_flow_service.AuditService')
def test_initiate_login_flow_success(self, mock_audit, app, db, test_organization):
"""Test initiating login flow successfully."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
with app.test_request_context():
auth_url, state = OAuthFlowService.initiate_login_flow(
provider_type=AuthMethodType.GOOGLE,
organization_id=test_organization.id,
)
assert auth_url is not None
assert state is not None
assert len(state) == 43
# Verify state was created with correct flow type
state_record = OAuthState.query.filter_by(state=state).first()
assert state_record is not None
assert state_record.flow_type == "login"
assert state_record.organization_id == test_organization.id
def test_initiate_login_flow_invalid_redirect_uri(self, app, db, test_organization):
"""Test initiating login flow with invalid redirect URI."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
with app.test_request_context():
with pytest.raises(OAuthFlowError) as exc_info:
OAuthFlowService.initiate_login_flow(
provider_type=AuthMethodType.GOOGLE,
organization_id=test_organization.id,
redirect_uri="http://malicious.com/callback",
)
assert exc_info.value.error_type == "INVALID_REDIRECT_URI"
@patch('gatehouse_app.services.oauth_flow_service.AuditService')
def test_initiate_register_flow_success(self, mock_audit, app, db, test_organization):
"""Test initiating register flow successfully."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
with app.test_request_context():
auth_url, state = OAuthFlowService.initiate_register_flow(
provider_type=AuthMethodType.GOOGLE,
organization_id=test_organization.id,
)
assert auth_url is not None
assert state is not None
# Verify state was created with correct flow type
state_record = OAuthState.query.filter_by(state=state).first()
assert state_record is not None
assert state_record.flow_type == "register"
@patch('gatehouse_app.services.oauth_flow_service.ExternalAuthService.authenticate_with_provider')
@patch('gatehouse_app.services.oauth_flow_service.AuditService')
def test_handle_callback_login_flow(
self, mock_audit, mock_authenticate,
app, db, test_user, test_organization
):
"""Test handling callback for login flow."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
# Create authentication method
auth_method = AuthenticationMethod(
user_id=test_user.id,
method_type=AuthMethodType.GOOGLE,
provider_user_id="google-123",
provider_data={"email": test_user.email},
verified=True,
)
auth_method.save()
# Create login state
state = OAuthState.create_state(
flow_type="login",
provider_type=AuthMethodType.GOOGLE,
organization_id=test_organization.id,
redirect_uri="http://localhost:3000/callback",
)
# Mock authentication
mock_authenticate.return_value = (test_user, {"token": "session-token", "expires_in": 86400})
with app.test_request_context():
result = OAuthFlowService.handle_callback(
provider_type=AuthMethodType.GOOGLE,
authorization_code="mock-auth-code",
state=state.state,
redirect_uri="http://localhost:3000/callback",
)
assert result["success"] is True
assert result["flow_type"] == "login"
assert result["user"]["id"] == test_user.id
assert result["session"]["token"] == "session-token"
@patch('gatehouse_app.services.oauth_flow_service.ExternalAuthService.complete_link_flow')
@patch('gatehouse_app.services.oauth_flow_service.AuditService')
def test_handle_callback_link_flow(
self, mock_audit, mock_complete_link,
app, db, test_user, test_organization
):
"""Test handling callback for link flow."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
# Create link state
state = OAuthState.create_state(
flow_type="link",
provider_type=AuthMethodType.GOOGLE,
user_id=test_user.id,
organization_id=test_organization.id,
redirect_uri="http://localhost:3000/callback",
)
# Mock complete link
mock_auth_method = Mock()
mock_auth_method.id = "auth-method-123"
mock_auth_method.provider_user_id = "google-123"
mock_auth_method.verified = True
mock_complete_link.return_value = mock_auth_method
with app.test_request_context():
result = OAuthFlowService.handle_callback(
provider_type=AuthMethodType.GOOGLE,
authorization_code="mock-auth-code",
state=state.state,
redirect_uri="http://localhost:3000/callback",
)
assert result["success"] is True
assert result["flow_type"] == "link"
assert result["linked_account"]["id"] == "auth-method-123"
@patch('gatehouse_app.services.oauth_flow_service.ExternalAuthService._exchange_code')
@patch('gatehouse_app.services.oauth_flow_service.ExternalAuthService._get_user_info')
@patch('gatehouse_app.services.oauth_flow_service.ExternalAuthService._encrypt_provider_data')
@patch('gatehouse_app.services.oauth_flow_service.AuditService')
@patch('gatehouse_app.services.auth_service.AuthService.create_session')
def test_handle_callback_register_flow(
self, mock_create_session, mock_audit, mock_encrypt,
mock_get_user_info, mock_exchange_code,
app, db, test_organization
):
"""Test handling callback for register flow."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
# Create register state
state = OAuthState.create_state(
flow_type="register",
provider_type=AuthMethodType.GOOGLE,
organization_id=test_organization.id,
redirect_uri="http://localhost:3000/callback",
)
# Mock external provider responses
mock_exchange_code.return_value = {
"access_token": "mock-access-token",
"refresh_token": "mock-refresh-token",
"id_token": "mock-id-token",
"expires_in": 3600,
}
mock_get_user_info.return_value = {
"provider_user_id": "google-new-123",
"email": "newuser@gmail.com",
"email_verified": True,
"name": "New User",
"picture": "https://example.com/avatar.jpg",
"raw_data": {},
}
mock_encrypt.return_value = {
"access_token": "mock-access-token",
"email": "newuser@gmail.com",
"name": "New User",
}
mock_session = Mock()
mock_session.to_dict.return_value = {"token": "session-token", "expires_in": 86400}
mock_create_session.return_value = mock_session
with app.test_request_context():
result = OAuthFlowService.handle_callback(
provider_type=AuthMethodType.GOOGLE,
authorization_code="mock-auth-code",
state=state.state,
redirect_uri="http://localhost:3000/callback",
)
assert result["success"] is True
assert result["flow_type"] == "register"
assert result["user"]["email"] == "newuser@gmail.com"
assert result["session"]["token"] == "session-token"
@patch('gatehouse_app.services.oauth_flow_service.ExternalAuthService._exchange_code')
@patch('gatehouse_app.services.oauth_flow_service.ExternalAuthService._get_user_info')
@patch('gatehouse_app.services.oauth_flow_service.AuditService')
def test_handle_callback_register_flow_email_exists(
self, mock_audit, mock_get_user_info, mock_exchange_code,
app, db, test_user, test_organization
):
"""Test handling callback for register flow when email already exists."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
# Create register state
state = OAuthState.create_state(
flow_type="register",
provider_type=AuthMethodType.GOOGLE,
organization_id=test_organization.id,
redirect_uri="http://localhost:3000/callback",
)
# Mock external provider responses
mock_exchange_code.return_value = {
"access_token": "mock-access-token",
"refresh_token": "mock-refresh-token",
"id_token": "mock-id-token",
"expires_in": 3600,
}
# Return email that matches existing user
mock_get_user_info.return_value = {
"provider_user_id": "google-new-123",
"email": test_user.email, # Existing email
"email_verified": True,
"name": "Test User",
"picture": "https://example.com/avatar.jpg",
"raw_data": {},
}
with app.test_request_context():
with pytest.raises(OAuthFlowError) as exc_info:
OAuthFlowService.handle_callback(
provider_type=AuthMethodType.GOOGLE,
authorization_code="mock-auth-code",
state=state.state,
redirect_uri="http://localhost:3000/callback",
)
assert exc_info.value.error_type == "EMAIL_EXISTS"
def test_handle_callback_invalid_state(self, app, db):
"""Test handling callback with invalid state."""
with app.app_context():
with app.test_request_context():
with pytest.raises(OAuthFlowError) as exc_info:
OAuthFlowService.handle_callback(
provider_type=AuthMethodType.GOOGLE,
authorization_code="mock-auth-code",
state="invalid-state",
)
assert exc_info.value.error_type == "INVALID_STATE"
def test_handle_callback_provider_error(self, app, db):
"""Test handling callback with provider error."""
with app.app_context():
with app.test_request_context():
with pytest.raises(OAuthFlowError) as exc_info:
OAuthFlowService.handle_callback(
provider_type=AuthMethodType.GOOGLE,
authorization_code=None,
state=None,
error="access_denied",
error_description="User denied access",
)
assert exc_info.value.error_type == "ACCESS_DENIED"
def test_handle_callback_unknown_flow_type(self, app, db, test_organization):
"""Test handling callback with unknown flow type."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
# Create state with unknown flow type
state = OAuthState.create_state(
flow_type="unknown",
provider_type=AuthMethodType.GOOGLE,
organization_id=test_organization.id,
redirect_uri="http://localhost:3000/callback",
)
with app.test_request_context():
with pytest.raises(OAuthFlowError) as exc_info:
OAuthFlowService.handle_callback(
provider_type=AuthMethodType.GOOGLE,
authorization_code="mock-auth-code",
state=state.state,
)
assert exc_info.value.error_type == "INVALID_FLOW_TYPE"
def test_validate_state_valid(self, app, db, test_organization):
"""Test validating a valid state."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
# Create state
state = OAuthState.create_state(
flow_type="login",
provider_type=AuthMethodType.GOOGLE,
organization_id=test_organization.id,
)
result = OAuthFlowService.validate_state(state.state)
assert result is not None
assert result.id == state.id
def test_validate_state_invalid(self, app, db):
"""Test validating an invalid state."""
with app.app_context():
result = OAuthFlowService.validate_state("nonexistent-state")
assert result is None
def test_validate_state_expired(self, app, db, test_organization):
"""Test validating an expired state."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
# Create expired state
state = OAuthState.create_state(
flow_type="login",
provider_type=AuthMethodType.GOOGLE,
organization_id=test_organization.id,
lifetime_seconds=-1,
)
result = OAuthFlowService.validate_state(state.state)
assert result is None
def test_validate_state_used(self, app, db, test_organization):
"""Test validating a used state."""
with app.app_context():
# Create provider config
config = ExternalProviderConfig(
organization_id=test_organization.id,
provider_type=AuthMethodType.GOOGLE.value,
client_id="test-client-id",
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
scopes=["openid", "profile", "email"],
redirect_uris=["http://localhost:3000/callback"],
is_active=True,
)
config.save()
# Create and mark state as used
state = OAuthState.create_state(
flow_type="login",
provider_type=AuthMethodType.GOOGLE,
organization_id=test_organization.id,
)
state.mark_used()
result = OAuthFlowService.validate_state(state.state)
assert result is None
@pytest.mark.unit
class TestOAuthFlowError:
"""Tests for OAuthFlowError exception."""
def test_error_creation(self):
"""Test creating OAuthFlowError."""
error = OAuthFlowError(
message="Test error message",
error_type="TEST_ERROR",
status_code=400,
)
assert error.message == "Test error message"
assert error.error_type == "TEST_ERROR"
assert error.status_code == 400
def test_error_default_status_code(self):
"""Test OAuthFlowError with default status code."""
error = OAuthFlowError(
message="Test error message",
error_type="TEST_ERROR",
)
assert error.status_code == 400
@@ -1,285 +0,0 @@
"""Unit tests for TOTPService."""
import base64
import pytest
from gatehouse_app.services.totp_service import TOTPService
@pytest.mark.unit
class TestTOTPService:
"""Tests for TOTPService."""
# Test generate_secret()
def test_generate_secret_returns_string(self):
"""Test that generate_secret returns a string."""
secret = TOTPService.generate_secret()
assert isinstance(secret, str)
def test_generate_secret_length(self):
"""Test that generate_secret returns a 32-character string."""
secret = TOTPService.generate_secret()
assert len(secret) == 32
def test_generate_secret_base32_encoded(self):
"""Test that generate_secret returns a base32 encoded string."""
secret = TOTPService.generate_secret()
# Base32 characters are A-Z and 2-7
valid_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567")
assert all(c in valid_chars for c in secret)
def test_generate_secret_unique(self):
"""Test that generate_secret produces unique secrets."""
secret1 = TOTPService.generate_secret()
secret2 = TOTPService.generate_secret()
assert secret1 != secret2
# Test generate_provisioning_uri()
def test_generate_provisioning_uri_format(self):
"""Test that provisioning URI is generated correctly."""
email = "user@example.com"
secret = "JBSWY3DPEHPK3PXP"
issuer = "Gatehouse"
uri = TOTPService.generate_provisioning_uri(email, secret, issuer)
assert isinstance(uri, str)
assert uri.startswith("otpauth://totp/")
def test_generate_provisioning_uri_contains_email(self):
"""Test that provisioning URI contains the user email."""
email = "user@example.com"
secret = "JBSWY3DPEHPK3PXP"
issuer = "Gatehouse"
uri = TOTPService.generate_provisioning_uri(email, secret, issuer)
assert email in uri
def test_generate_provisioning_uri_contains_secret(self):
"""Test that provisioning URI contains the secret."""
email = "user@example.com"
secret = "JBSWY3DPEHPK3PXP"
issuer = "Gatehouse"
uri = TOTPService.generate_provisioning_uri(email, secret, issuer)
assert secret in uri
def test_generate_provisioning_uri_contains_issuer(self):
"""Test that provisioning URI contains the issuer."""
email = "user@example.com"
secret = "JBSWY3DPEHPK3PXP"
issuer = "Gatehouse"
uri = TOTPService.generate_provisioning_uri(email, secret, issuer)
assert issuer in uri
def test_generate_provisioning_uri_custom_issuer(self):
"""Test that provisioning URI uses custom issuer."""
email = "user@example.com"
secret = "JBSWY3DPEHPK3PXP"
custom_issuer = "MyApp"
uri = TOTPService.generate_provisioning_uri(email, secret, custom_issuer)
assert custom_issuer in uri
# Test verify_code()
def test_verify_code_valid(self):
"""Test that a valid TOTP code is accepted."""
secret = TOTPService.generate_secret()
# Generate a valid code using pyotp
import pyotp
totp = pyotp.TOTP(secret)
valid_code = totp.now()
result = TOTPService.verify_code(secret, valid_code)
assert result is True
def test_verify_code_invalid(self):
"""Test that an invalid TOTP code is rejected."""
secret = TOTPService.generate_secret()
invalid_code = "000000"
result = TOTPService.verify_code(secret, invalid_code)
assert result is False
def test_verify_code_window_parameter(self):
"""Test that the time window parameter works correctly."""
secret = TOTPService.generate_secret()
import pyotp
totp = pyotp.TOTP(secret)
# Get current code
current_code = totp.now()
# Verify with window=1 (default) - should accept current code
result = TOTPService.verify_code(secret, current_code, window=1)
assert result is True
# Verify with window=0 - should only accept exact time match
result = TOTPService.verify_code(secret, current_code, window=0)
assert result is True
def test_verify_code_wrong_length(self):
"""Test that codes with wrong length are rejected."""
secret = TOTPService.generate_secret()
wrong_length_code = "12345" # 5 digits instead of 6
result = TOTPService.verify_code(secret, wrong_length_code)
assert result is False
# Test generate_backup_codes()
def test_generate_backup_codes_default_count(self):
"""Test that generate_backup_codes generates 10 codes by default."""
plain_codes, hashed_codes = TOTPService.generate_backup_codes()
assert len(plain_codes) == 10
assert len(hashed_codes) == 10
def test_generate_backup_codes_custom_count(self):
"""Test that generate_backup_codes generates the specified number of codes."""
count = 5
plain_codes, hashed_codes = TOTPService.generate_backup_codes(count)
assert len(plain_codes) == count
assert len(hashed_codes) == count
def test_generate_backup_codes_plain_are_strings(self):
"""Test that plain backup codes are strings."""
plain_codes, hashed_codes = TOTPService.generate_backup_codes()
assert all(isinstance(code, str) for code in plain_codes)
def test_generate_backup_codes_plain_length(self):
"""Test that plain backup codes are 16 characters long."""
plain_codes, hashed_codes = TOTPService.generate_backup_codes()
assert all(len(code) == 16 for code in plain_codes)
def test_generate_backup_codes_hashed_different_from_plain(self):
"""Test that hashed codes are different from plain codes."""
plain_codes, hashed_codes = TOTPService.generate_backup_codes()
for plain, hashed in zip(plain_codes, hashed_codes):
assert plain != hashed
def test_generate_backup_codes_are_bcrypt_hashes(self):
"""Test that hashed codes are bcrypt hashes."""
plain_codes, hashed_codes = TOTPService.generate_backup_codes()
# Bcrypt hashes start with $2a$, $2b$, or $2y$
for hashed in hashed_codes:
assert hashed.startswith("$2")
def test_generate_backup_codes_unique(self):
"""Test that generated backup codes are unique."""
plain_codes, hashed_codes = TOTPService.generate_backup_codes()
assert len(set(plain_codes)) == len(plain_codes)
assert len(set(hashed_codes)) == len(hashed_codes)
# Test verify_backup_code()
def test_verify_backup_code_valid(self):
"""Test that a valid backup code is accepted and removed."""
plain_codes, hashed_codes = TOTPService.generate_backup_codes(count=3)
code_to_verify = plain_codes[0]
is_valid, remaining_codes = TOTPService.verify_backup_code(hashed_codes, code_to_verify)
assert is_valid is True
assert len(remaining_codes) == 2
def test_verify_backup_code_invalid(self):
"""Test that an invalid backup code is rejected."""
plain_codes, hashed_codes = TOTPService.generate_backup_codes(count=3)
invalid_code = "INVALIDCODE1234"
is_valid, remaining_codes = TOTPService.verify_backup_code(hashed_codes, invalid_code)
assert is_valid is False
assert len(remaining_codes) == 3
def test_verify_backup_code_remaining_updated(self):
"""Test that the remaining codes list is updated correctly."""
plain_codes, hashed_codes = TOTPService.generate_backup_codes(count=5)
code_to_verify = plain_codes[2]
is_valid, remaining_codes = TOTPService.verify_backup_code(hashed_codes, code_to_verify)
assert is_valid is True
# The verified code should be removed
assert len(remaining_codes) == 4
# The remaining codes should not include the verified code's hash
assert hashed_codes[2] not in remaining_codes
def test_verify_backup_code_case_sensitive(self):
"""Test that backup code verification is case sensitive."""
plain_codes, hashed_codes = TOTPService.generate_backup_codes(count=1)
code_to_verify = plain_codes[0].lower() # Convert to lowercase
is_valid, remaining_codes = TOTPService.verify_backup_code(hashed_codes, code_to_verify)
assert is_valid is False
assert len(remaining_codes) == 1
def test_verify_backup_code_single_use(self):
"""Test that a backup code can only be used once."""
plain_codes, hashed_codes = TOTPService.generate_backup_codes(count=1)
code_to_verify = plain_codes[0]
# First use - should succeed
is_valid1, remaining1 = TOTPService.verify_backup_code(hashed_codes, code_to_verify)
assert is_valid1 is True
assert len(remaining1) == 0
# Second use - should fail (code already consumed)
is_valid2, remaining2 = TOTPService.verify_backup_code(remaining1, code_to_verify)
assert is_valid2 is False
assert len(remaining2) == 0
# Test generate_qr_code_data_uri()
def test_generate_qr_code_data_uri_format(self):
"""Test that a data URI is generated."""
provisioning_uri = "otpauth://totp/Gatehouse:user@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse"
data_uri = TOTPService.generate_qr_code_data_uri(provisioning_uri)
assert isinstance(data_uri, str)
def test_generate_qr_code_data_uri_starts_with_prefix(self):
"""Test that the data URI starts with the correct prefix."""
provisioning_uri = "otpauth://totp/Gatehouse:user@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse"
data_uri = TOTPService.generate_qr_code_data_uri(provisioning_uri)
assert data_uri.startswith("data:image/png;base64,")
def test_generate_qr_code_data_uri_contains_base64(self):
"""Test that the data URI contains base64 encoded data."""
provisioning_uri = "otpauth://totp/Gatehouse:user@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse"
data_uri = TOTPService.generate_qr_code_data_uri(provisioning_uri)
# Extract the base64 part (after the prefix)
base64_part = data_uri.split("data:image/png;base64,")[1]
# Verify it's valid base64
try:
base64.b64decode(base64_part)
assert True
except Exception:
assert False, "Data URI does not contain valid base64 data"
def test_generate_qr_code_data_uri_different_uris(self):
"""Test that different provisioning URIs generate different QR codes."""
uri1 = "otpauth://totp/Gatehouse:user1@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse"
uri2 = "otpauth://totp/Gatehouse:user2@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse"
data_uri1 = TOTPService.generate_qr_code_data_uri(uri1)
data_uri2 = TOTPService.generate_qr_code_data_uri(uri2)
assert data_uri1 != data_uri2