feat(auth): implement TOTP two-factor authentication with enrollment and verification

Adds TOTP (Time-based One-Time Password) two-factor authentication support including:
- New TOTP service with secret generation, QR code provisioning, and code verification
- New auth endpoints for enrollment, verification, status, and backup code management
- New TOTP authentication method type and user methods for TOTP management
- Backup codes generation and verification for account recovery
- Updated OIDC endpoints with timezone-aware datetime handling and RFC-compliant responses
- Added "roles" scope support for OIDC userinfo and ID tokens
- New pyotp dependency for TOTP operations
- Comprehensive unit tests for TOTP service
This commit is contained in:
2026-01-14 18:06:17 +10:30
parent 977abf66df
commit cfd79190ee
26 changed files with 2176 additions and 263 deletions
+212 -5
View File
@@ -2,7 +2,7 @@
import logging
import secrets
import hashlib
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, Tuple
from flask import current_app, g
@@ -14,6 +14,7 @@ from app.models import (
User, OIDCClient, OIDCAuthCode, OIDCRefreshToken,
OIDCSession, OIDCTokenMetadata
)
from app.models.organization_member import OrganizationMember
from app.exceptions.validation_exceptions import (
ValidationError, NotFoundError, BadRequestError
)
@@ -121,6 +122,14 @@ class OIDCService:
ValidationError: If parameters are invalid
NotFoundError: If client not found
"""
logger.debug("[OIDC SERVICE] ===========================================")
logger.debug("[OIDC SERVICE] generate_authorization_code called")
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] client_id=%s, user_id=%s", client_id, user_id)
logger.debug("[OIDC SERVICE] redirect_uri=%s", redirect_uri)
logger.debug("[OIDC SERVICE] scope=%s", scope)
logger.debug("[OIDC SERVICE] state=%s, nonce=%s", state, nonce)
# Validate client exists and is active
client = OIDCClient.query.filter_by(client_id=client_id).first()
@@ -152,14 +161,19 @@ class OIDCService:
raise ValidationError("Invalid scopes")
# Generate authorization code
logger.debug("[OIDC SERVICE] Generating authorization code...")
logger.debug("[OIDC SERVICE] Current UTC time before code generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
code = cls._generate_code()
code_hash = cls._hash_value(code)
logger.debug("[OIDC SERVICE] Code generated: %s...", code[:20] if code else None)
# Development-only debug logging for PKCE in code creation
if current_app.config.get('ENV') == 'development':
logger.debug(f"[OIDC] Generate auth code - PKCE: code_challenge={code_challenge is not None}, code_challenge_method={code_challenge_method}")
# Create auth code record
logger.debug("[OIDC SERVICE] Creating auth code record with lifetime_seconds=600 (10 minutes)")
logger.debug("[OIDC SERVICE] Current UTC time before creating auth code: %s", datetime.now(timezone.utc).isoformat() + "Z")
auth_code = OIDCAuthCode.create_code(
client_id=client.id,
user_id=user_id,
@@ -172,6 +186,9 @@ class OIDCService:
user_agent=user_agent,
lifetime_seconds=600, # 10 minutes
)
logger.debug("[OIDC SERVICE] Auth code created successfully")
logger.debug("[OIDC SERVICE] Auth code expires_at (UTC): %s", auth_code.expires_at.isoformat() + "Z")
logger.debug("[OIDC SERVICE] Current UTC time after creating auth code: %s", datetime.now(timezone.utc).isoformat() + "Z")
# Log authorization event
OIDCAuditService.log_authorization_event(
@@ -182,6 +199,9 @@ class OIDCService:
scope=valid_scopes,
)
logger.debug("[OIDC SERVICE] generate_authorization_code completed successfully")
logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] ===========================================")
return code
@classmethod
@@ -211,6 +231,12 @@ class OIDCService:
InvalidGrantError: If code is invalid
ValidationError: If PKCE validation fails
"""
logger.debug("[OIDC SERVICE] ===========================================")
logger.debug("[OIDC SERVICE] validate_authorization_code called")
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] client_id=%s, redirect_uri=%s", client_id, redirect_uri)
logger.debug("[OIDC SERVICE] code_verifier provided: %s", bool(code_verifier))
# Get client
client = OIDCClient.query.filter_by(client_id=client_id).first()
@@ -223,6 +249,8 @@ class OIDCService:
raise InvalidGrantError("Invalid client")
# Hash the provided code and find matching auth code
logger.debug("[OIDC SERVICE] Looking up authorization code...")
logger.debug("[OIDC SERVICE] Current UTC time before code lookup: %s", datetime.now(timezone.utc).isoformat() + "Z")
code_hash = cls._hash_value(code)
auth_code = OIDCAuthCode.query.filter_by(
code_hash=code_hash,
@@ -256,8 +284,18 @@ class OIDCService:
raise InvalidGrantError("Authorization code already used")
# Check expiration
logger.debug("[OIDC SERVICE] Checking if authorization code is expired...")
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] Auth code expires_at (UTC): %s", auth_code.expires_at.isoformat() + "Z")
# Handle timezone-naive expires_at from database
expires_at = auth_code.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
logger.debug("[OIDC SERVICE] Time until expiration (seconds): %s", (expires_at - datetime.now(timezone.utc)).total_seconds())
if auth_code.is_expired():
logger.error(f"[OIDC] Validate auth code - Code expired: code_hash={code_hash[:20]}..., expires_at={auth_code.expires_at}")
logger.error("[OIDC] Validate auth code - Code expired: code_hash=%s..., expires_at (UTC)=%s, current UTC time=%s",
code_hash[:20], auth_code.expires_at.isoformat() + "Z", datetime.now(timezone.utc).isoformat() + "Z")
OIDCAuditService.log_authorization_event(
client_id=client_id,
user_id=auth_code.user_id,
@@ -316,6 +354,9 @@ class OIDCService:
"nonce": auth_code.nonce,
}
logger.debug("[OIDC SERVICE] validate_authorization_code completed successfully")
logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] ===========================================")
return claims, user
@classmethod
@@ -366,6 +407,12 @@ class OIDCService:
"""
import hashlib
logger.debug("[OIDC SERVICE] ===========================================")
logger.debug("[OIDC SERVICE] generate_tokens called")
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] client_id=%s, user_id=%s, scope=%s", client_id, user_id, scope)
logger.debug("[OIDC SERVICE] nonce=%s, auth_time=%s", nonce, auth_time)
# Get client
client = OIDCClient.query.filter_by(client_id=client_id).first()
@@ -377,6 +424,9 @@ class OIDCService:
raise InvalidClientError()
# Generate access token
logger.debug("[OIDC SERVICE] Generating access token...")
logger.debug("[OIDC SERVICE] Current UTC time before access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] Access token lifetime (seconds): %s", client.access_token_lifetime or 3600)
access_token_jti = OIDCTokenService._generate_jti()
access_token = OIDCTokenService.create_access_token(
client_id=client_id,
@@ -384,8 +434,13 @@ class OIDCService:
scope=scope,
jti=access_token_jti,
)
logger.debug("[OIDC SERVICE] Access token generated successfully")
logger.debug("[OIDC SERVICE] Current UTC time after access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
# Generate ID token
logger.debug("[OIDC SERVICE] Generating ID token...")
logger.debug("[OIDC SERVICE] Current UTC time before ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] ID token lifetime (seconds): %s", client.id_token_lifetime or 3600)
id_token = OIDCTokenService.create_id_token(
client_id=client_id,
user_id=user_id,
@@ -394,6 +449,8 @@ class OIDCService:
access_token=access_token,
auth_time=auth_time,
)
logger.debug("[OIDC SERVICE] ID token generated successfully")
logger.debug("[OIDC SERVICE] Current UTC time after ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
# Generate or rotate refresh token
if "refresh_token" in (client.grant_types or []):
@@ -445,22 +502,28 @@ class OIDCService:
client_db_id = client.id
# Access token metadata
logger.debug("[OIDC SERVICE] Creating access token metadata...")
access_token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=client.access_token_lifetime or 3600)
logger.debug("[OIDC SERVICE] Access token expires_at (UTC): %s", access_token_expires_at.isoformat() + "Z")
OIDCTokenMetadata.create_metadata(
client_id=client_db_id,
user_id=user_id,
token_type="access_token",
token_jti=access_token_jti,
expires_at=datetime.utcnow() + timedelta(seconds=client.access_token_lifetime or 3600),
expires_at=access_token_expires_at,
)
# ID token metadata (using access token JTI as reference)
logger.debug("[OIDC SERVICE] Creating ID token metadata...")
id_token_jti = OIDCTokenService._generate_jti()
id_token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=client.id_token_lifetime or 3600)
logger.debug("[OIDC SERVICE] ID token expires_at (UTC): %s", id_token_expires_at.isoformat() + "Z")
OIDCTokenMetadata.create_metadata(
client_id=client_db_id,
user_id=user_id,
token_type="id_token",
token_jti=id_token_jti,
expires_at=datetime.utcnow() + timedelta(seconds=client.id_token_lifetime or 3600),
expires_at=id_token_expires_at,
)
# Log token event
@@ -483,6 +546,9 @@ class OIDCService:
if final_refresh_token:
result["refresh_token"] = final_refresh_token
logger.debug("[OIDC SERVICE] generate_tokens completed successfully")
logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] ===========================================")
return result
@classmethod
@@ -511,6 +577,11 @@ class OIDCService:
"""
import hashlib
logger.debug("[OIDC SERVICE] ===========================================")
logger.debug("[OIDC SERVICE] refresh_access_token called")
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] client_id=%s, scope=%s", client_id, scope)
# Get client
client = OIDCClient.query.filter_by(client_id=client_id).first()
@@ -522,6 +593,8 @@ class OIDCService:
raise InvalidClientError()
# Find refresh token
logger.debug("[OIDC SERVICE] Looking up refresh token...")
logger.debug("[OIDC SERVICE] Current UTC time before refresh token lookup: %s", datetime.now(timezone.utc).isoformat() + "Z")
token_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
refresh_token_obj = OIDCRefreshToken.query.filter_by(
token_hash=token_hash,
@@ -542,6 +615,16 @@ class OIDCService:
raise InvalidGrantError("Invalid refresh token")
# Check if valid
logger.debug("[OIDC SERVICE] Checking if refresh token is valid...")
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
if refresh_token_obj:
logger.debug("[OIDC SERVICE] Refresh token expires_at (UTC): %s", refresh_token_obj.expires_at.isoformat() + "Z")
# Handle timezone-naive expires_at from database
rt_expires_at = refresh_token_obj.expires_at
if rt_expires_at.tzinfo is None:
rt_expires_at = rt_expires_at.replace(tzinfo=timezone.utc)
logger.debug("[OIDC SERVICE] Time until expiration (seconds): %s", (rt_expires_at - datetime.now(timezone.utc)).total_seconds())
if not refresh_token_obj.is_valid():
OIDCAuditService.log_token_event(
client_id=client_id,
@@ -563,6 +646,9 @@ class OIDCService:
granted_scope = scope or (refresh_token_obj.scope or [])
# Generate new access token
logger.debug("[OIDC SERVICE] Generating new access token...")
logger.debug("[OIDC SERVICE] Current UTC time before access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] Access token lifetime (seconds): %s", client.access_token_lifetime or 3600)
access_token_jti = OIDCTokenService._generate_jti()
access_token = OIDCTokenService.create_access_token(
client_id=client_id,
@@ -570,14 +656,21 @@ class OIDCService:
scope=granted_scope,
jti=access_token_jti,
)
logger.debug("[OIDC SERVICE] Access token generated successfully")
logger.debug("[OIDC SERVICE] Current UTC time after access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
# Generate new ID token
logger.debug("[OIDC SERVICE] Generating new ID token...")
logger.debug("[OIDC SERVICE] Current UTC time before ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] ID token lifetime (seconds): %s", client.id_token_lifetime or 3600)
id_token = OIDCTokenService.create_id_token(
client_id=client_id,
user_id=refresh_token_obj.user_id,
scope=granted_scope,
access_token=access_token,
)
logger.debug("[OIDC SERVICE] ID token generated successfully")
logger.debug("[OIDC SERVICE] Current UTC time after ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
# Rotate refresh token
new_refresh, new_hash = OIDCTokenService.create_refresh_token(
@@ -590,12 +683,15 @@ class OIDCService:
refresh_token_obj.rotate(new_hash)
# Store new token metadata
logger.debug("[OIDC SERVICE] Creating access token metadata...")
access_token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=client.access_token_lifetime or 3600)
logger.debug("[OIDC SERVICE] Access token expires_at (UTC): %s", access_token_expires_at.isoformat() + "Z")
OIDCTokenMetadata.create_metadata(
client_id=client.id,
user_id=refresh_token_obj.user_id,
token_type="access_token",
token_jti=access_token_jti,
expires_at=datetime.utcnow() + timedelta(seconds=client.access_token_lifetime or 3600),
expires_at=access_token_expires_at,
)
# Log refresh event
@@ -615,6 +711,17 @@ class OIDCService:
"id_token": id_token,
"refresh_token": new_refresh,
}
logger.debug("[OIDC SERVICE] refresh_access_token completed successfully")
logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] ===========================================")
return {
"access_token": access_token,
"token_type": "Bearer",
"expires_in": client.access_token_lifetime or 3600,
"id_token": id_token,
"refresh_token": new_refresh,
}
@classmethod
def validate_access_token(cls, token: str, client_id: str = None) -> Dict:
@@ -630,10 +737,23 @@ class OIDCService:
Raises:
InvalidTokenError: If token is invalid
"""
logger.debug("[OIDC SERVICE] ===========================================")
logger.debug("[OIDC SERVICE] validate_access_token() called")
logger.debug("[OIDC SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token)
logger.debug("[OIDC SERVICE] Token length: %d", len(token))
logger.debug("[OIDC SERVICE] Client ID: %s", client_id)
try:
logger.debug("[OIDC SERVICE] Calling OIDCTokenService.validate_access_token()...")
claims = OIDCTokenService.validate_access_token(token, client_id)
logger.debug("[OIDC SERVICE] Token validation successful")
logger.debug("[OIDC SERVICE] Token claims: %s", claims)
logger.debug("[OIDC SERVICE] ===========================================")
return claims
except Exception as e:
logger.error("[OIDC SERVICE] Token validation failed: %s: %s", type(e).__name__, str(e))
import traceback
logger.error("[OIDC SERVICE] Traceback: %s", traceback.format_exc())
OIDCAuditService.log_event(
event_type="token_validation",
client_id=client_id,
@@ -770,29 +890,67 @@ class OIDCService:
Returns:
User information dictionary
"""
logger.debug("[OIDC SERVICE] ===========================================")
logger.debug("[OIDC SERVICE] get_userinfo() called")
logger.debug("[OIDC SERVICE] Access token (first 50 chars): %s...", access_token[:50] if len(access_token) > 50 else access_token)
logger.debug("[OIDC SERVICE] Access token length: %d", len(access_token))
# Validate access token
logger.debug("[OIDC SERVICE] Validating access token...")
claims = cls.validate_access_token(access_token)
logger.debug("[OIDC SERVICE] Access token validated successfully")
logger.debug("[OIDC SERVICE] Token claims: %s", claims)
user_id = claims.get("sub")
logger.debug("[OIDC SERVICE] User ID from token: %s", user_id)
logger.debug("[OIDC SERVICE] Querying user from database...")
user = User.query.get(user_id)
logger.debug("[OIDC SERVICE] User query result: %s", user)
if not user:
logger.error("[OIDC SERVICE] User not found in database: user_id=%s", user_id)
raise NotFoundError("User not found")
logger.debug("[OIDC SERVICE] User found: user_id=%s, email=%s, full_name=%s", user.id, user.email, user.full_name)
# Get scopes from token
scope_str = claims.get("scope", "")
scopes = scope_str.split() if scope_str else []
logger.debug("[OIDC SERVICE] Scope string from token: '%s'", scope_str)
logger.debug("[OIDC SERVICE] Parsed scopes: %s", scopes)
userinfo = {"sub": user_id}
logger.debug("[OIDC SERVICE] Initial userinfo: %s", userinfo)
# Add claims based on scope
if "profile" in scopes and user.full_name:
logger.debug("[OIDC SERVICE] Found 'profile' in scope, adding name claim")
userinfo["name"] = user.full_name
logger.debug("[OIDC SERVICE] Added name: %s", user.full_name)
else:
logger.debug("[OIDC SERVICE] 'profile' not in scope or user.full_name is None: profile_in_scope=%s, full_name=%s", "profile" in scopes, user.full_name)
if "email" in scopes:
logger.debug("[OIDC SERVICE] Found 'email' in scope, adding email claims")
userinfo["email"] = user.email
userinfo["email_verified"] = user.email_verified
logger.debug("[OIDC SERVICE] Added email: %s, email_verified: %s", user.email, user.email_verified)
else:
logger.debug("[OIDC SERVICE] 'email' not in scope")
if "roles" in scopes:
logger.debug("[OIDC SERVICE] Found 'roles' in scope, adding roles claim")
user_roles = cls._get_user_roles(user)
userinfo["roles"] = user_roles
logger.debug("[OIDC SERVICE] Added roles: %s", user_roles)
else:
logger.debug("[OIDC SERVICE] 'roles' not in scope")
logger.debug("[OIDC SERVICE] Final userinfo: %s", userinfo)
# Log userinfo access
logger.debug("[OIDC SERVICE] Logging userinfo access event...")
OIDCAuditService.log_userinfo_event(
access_token=access_token,
user_id=user_id,
@@ -800,5 +958,54 @@ class OIDCService:
success=True,
scopes_claimed=scopes,
)
logger.debug("[OIDC SERVICE] Userinfo access event logged")
logger.debug("[OIDC SERVICE] get_userinfo() completed successfully")
logger.debug("[OIDC SERVICE] ===========================================")
return userinfo
@staticmethod
def _get_user_roles(user: User) -> list:
"""Get user's organization roles.
Args:
user: User instance
Returns:
List of role objects with organization_id and role
"""
logger.debug("[OIDC SERVICE] _get_user_roles() called")
logger.debug("[OIDC SERVICE] User: %s", user)
roles = []
if not user:
logger.debug("[OIDC SERVICE] User is None, returning empty roles list")
return roles
logger.debug("[OIDC SERVICE] User ID: %s", user.id)
logger.debug("[OIDC SERVICE] User email: %s", user.email)
logger.debug("[OIDC SERVICE] User organization_memberships: %s", user.organization_memberships)
if user.organization_memberships:
logger.debug("[OIDC SERVICE] User has %d organization memberships", len(user.organization_memberships))
for idx, member in enumerate(user.organization_memberships):
logger.debug("[OIDC SERVICE] Processing membership %d: member=%s", idx, member)
logger.debug("[OIDC SERVICE] organization_id: %s", member.organization_id)
logger.debug("[OIDC SERVICE] role: %s", member.role)
logger.debug("[OIDC SERVICE] role.value: %s", member.role.value)
role_entry = {
"organization_id": str(member.organization_id),
"role": member.role.value
}
roles.append(role_entry)
logger.debug("[OIDC SERVICE] Added role entry: %s", role_entry)
else:
logger.debug("[OIDC SERVICE] User has no organization memberships")
logger.debug("[OIDC SERVICE] Final roles list: %s", roles)
logger.debug("[OIDC SERVICE] _get_user_roles() completed")
return roles