feat(auth): implement TOTP two-factor authentication with enrollment and verification

Adds TOTP (Time-based One-Time Password) two-factor authentication support including:
- New TOTP service with secret generation, QR code provisioning, and code verification
- New auth endpoints for enrollment, verification, status, and backup code management
- New TOTP authentication method type and user methods for TOTP management
- Backup codes generation and verification for account recovery
- Updated OIDC endpoints with timezone-aware datetime handling and RFC-compliant responses
- Added "roles" scope support for OIDC userinfo and ID tokens
- New pyotp dependency for TOTP operations
- Comprehensive unit tests for TOTP service
This commit is contained in:
2026-01-14 18:06:17 +10:30
parent 977abf66df
commit cfd79190ee
26 changed files with 2176 additions and 263 deletions
+313
View File
@@ -11,6 +11,7 @@ from app.utils.constants import AuthMethodType, SessionStatus, UserStatus, Audit
from app.exceptions.auth_exceptions import InvalidCredentialsError, AccountSuspendedError, AccountInactiveError
from app.exceptions.validation_exceptions import EmailAlreadyExistsError
from app.services.audit_service import AuditService
from app.services.totp_service import TOTPService
logger = logging.getLogger(__name__)
@@ -234,3 +235,315 @@ class AuthService:
resource_id=session.id,
description=f"Session revoked: {reason or 'User logout'}",
)
@staticmethod
def enroll_totp(user: User) -> dict:
"""
Initiate TOTP enrollment for a user.
Args:
user: User instance
Returns:
Dictionary containing:
- secret: TOTP secret (base32 encoded)
- provisioning_uri: otpauth:// URI for QR code
- qr_code: Base64 encoded QR code as data URI
- backup_codes: List of plain text backup codes
Raises:
ConflictError: If user already has TOTP enabled
"""
from app.exceptions.validation_exceptions import ConflictError
# Check if user already has TOTP enabled
if user.has_totp_enabled():
raise ConflictError("TOTP is already enabled for this account")
# Generate TOTP secret
secret = TOTPService.generate_secret()
# Generate provisioning URI
provisioning_uri = TOTPService.generate_provisioning_uri(
user_email=user.email,
secret=secret,
issuer="Gatehouse",
)
# Generate QR code data URI
qr_code = TOTPService.generate_qr_code_data_uri(provisioning_uri)
# Generate backup codes
backup_codes, hashed_backup_codes = TOTPService.generate_backup_codes()
# Create unverified TOTP authentication method
auth_method = AuthenticationMethod(
user_id=user.id,
method_type=AuthMethodType.TOTP,
verified=False,
is_primary=False,
)
auth_method.save()
# Store TOTP data in provider_data (since totp_secret field is commented out)
auth_method.provider_data = {
"secret": secret,
"backup_codes": hashed_backup_codes,
}
db.session.commit()
# Log TOTP enrollment initiation
AuditService.log_action(
action=AuditAction.TOTP_ENROLL_INITIATED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="TOTP enrollment initiated",
)
return {
"secret": secret,
"provisioning_uri": provisioning_uri,
"qr_code": qr_code,
"backup_codes": backup_codes,
}
@staticmethod
def verify_totp_enrollment(user: User, code: str) -> bool:
"""
Complete TOTP enrollment by verifying the first TOTP code.
Args:
user: User instance
code: 6-digit TOTP code from authenticator app
Returns:
True if verification successful
Raises:
InvalidCredentialsError: If code is invalid or TOTP method not found
"""
# Get user's TOTP authentication method
auth_method = user.get_totp_method()
if not auth_method:
raise InvalidCredentialsError("TOTP enrollment not found")
# Get secret from provider_data
secret = auth_method.provider_data.get("secret") if auth_method.provider_data else None
if not secret:
raise InvalidCredentialsError("TOTP secret not found")
# Verify the code
if not TOTPService.verify_code(secret, code):
raise InvalidCredentialsError("Invalid TOTP code")
# Mark TOTP as verified
auth_method.verified = True
auth_method.totp_verified_at = datetime.utcnow()
db.session.commit()
# Log TOTP enrollment completion
AuditService.log_action(
action=AuditAction.TOTP_ENROLL_COMPLETED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="TOTP enrollment completed",
)
return True
@staticmethod
def disable_totp(user: User, password: str) -> bool:
"""
Disable TOTP for a user.
Args:
user: User instance
password: User's current password for verification
Returns:
True if TOTP disabled successfully
Raises:
InvalidCredentialsError: If password is invalid or TOTP method not found
"""
# Verify user's password
auth_method = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.PASSWORD,
deleted_at=None,
).first()
if not auth_method or not auth_method.password_hash:
raise InvalidCredentialsError("No password authentication method found")
if not bcrypt.check_password_hash(auth_method.password_hash, password):
raise InvalidCredentialsError("Invalid password")
# Get user's TOTP authentication method
totp_method = user.get_totp_method()
if not totp_method:
raise InvalidCredentialsError("TOTP is not enabled for this account")
# Soft-delete the TOTP authentication method
totp_method.delete(soft=True)
# Log TOTP disabled
AuditService.log_action(
action=AuditAction.TOTP_DISABLED,
user_id=user.id,
resource_type="authentication_method",
resource_id=totp_method.id,
description="TOTP disabled",
)
return True
@staticmethod
def authenticate_with_totp(user: User, code: str, is_backup_code: bool = False) -> bool:
"""
Verify TOTP code during login.
Args:
user: User instance
code: 6-digit TOTP code or backup code
is_backup_code: True if code is a backup code, False if TOTP code
Returns:
True if code is valid
Raises:
InvalidCredentialsError: If code is invalid or TOTP method not found
"""
# Get user's TOTP authentication method
auth_method = user.get_totp_method()
if not auth_method:
raise InvalidCredentialsError("TOTP is not enabled for this account")
if is_backup_code:
# Verify backup code
backup_codes = (
auth_method.provider_data.get("backup_codes")
if auth_method.provider_data
else []
)
is_valid, remaining_codes = TOTPService.verify_backup_code(backup_codes, code)
if is_valid:
# Update remaining backup codes
auth_method.provider_data = {
"secret": auth_method.provider_data.get("secret"),
"backup_codes": remaining_codes,
}
auth_method.last_used_at = datetime.utcnow()
db.session.commit()
# Log backup code usage
AuditService.log_action(
action=AuditAction.TOTP_BACKUP_CODE_USED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="Backup code used for authentication",
)
else:
# Log failed verification
AuditService.log_action(
action=AuditAction.TOTP_VERIFY_FAILED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="Invalid backup code provided",
)
raise InvalidCredentialsError("Invalid backup code")
else:
# Verify TOTP code
secret = (
auth_method.provider_data.get("secret")
if auth_method.provider_data
else None
)
if not secret:
raise InvalidCredentialsError("TOTP secret not found")
is_valid = TOTPService.verify_code(secret, code)
if is_valid:
auth_method.last_used_at = datetime.utcnow()
db.session.commit()
# Log successful verification
AuditService.log_action(
action=AuditAction.TOTP_VERIFY_SUCCESS,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="TOTP code verified successfully",
)
else:
# Log failed verification
AuditService.log_action(
action=AuditAction.TOTP_VERIFY_FAILED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="Invalid TOTP code provided",
)
raise InvalidCredentialsError("Invalid TOTP code")
return True
@staticmethod
def regenerate_totp_backup_codes(user: User, password: str) -> list[str]:
"""
Generate new backup codes for TOTP.
Args:
user: User instance
password: User's current password for verification
Returns:
List of new plain text backup codes
Raises:
InvalidCredentialsError: If password is invalid or TOTP method not found
"""
# Verify user's password
auth_method = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.PASSWORD,
deleted_at=None,
).first()
if not auth_method or not auth_method.password_hash:
raise InvalidCredentialsError("No password authentication method found")
if not bcrypt.check_password_hash(auth_method.password_hash, password):
raise InvalidCredentialsError("Invalid password")
# Get user's TOTP authentication method
totp_method = user.get_totp_method()
if not totp_method:
raise InvalidCredentialsError("TOTP is not enabled for this account")
# Generate new backup codes
backup_codes, hashed_backup_codes = TOTPService.generate_backup_codes()
# Update the authentication method with new backup codes
totp_method.provider_data = {
"secret": totp_method.provider_data.get("secret"),
"backup_codes": hashed_backup_codes,
}
db.session.commit()
# Log backup codes regeneration
AuditService.log_action(
action=AuditAction.TOTP_BACKUP_CODES_REGENERATED,
user_id=user.id,
resource_type="authentication_method",
resource_id=totp_method.id,
description="TOTP backup codes regenerated",
)
return backup_codes
+212 -5
View File
@@ -2,7 +2,7 @@
import logging
import secrets
import hashlib
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, Tuple
from flask import current_app, g
@@ -14,6 +14,7 @@ from app.models import (
User, OIDCClient, OIDCAuthCode, OIDCRefreshToken,
OIDCSession, OIDCTokenMetadata
)
from app.models.organization_member import OrganizationMember
from app.exceptions.validation_exceptions import (
ValidationError, NotFoundError, BadRequestError
)
@@ -121,6 +122,14 @@ class OIDCService:
ValidationError: If parameters are invalid
NotFoundError: If client not found
"""
logger.debug("[OIDC SERVICE] ===========================================")
logger.debug("[OIDC SERVICE] generate_authorization_code called")
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] client_id=%s, user_id=%s", client_id, user_id)
logger.debug("[OIDC SERVICE] redirect_uri=%s", redirect_uri)
logger.debug("[OIDC SERVICE] scope=%s", scope)
logger.debug("[OIDC SERVICE] state=%s, nonce=%s", state, nonce)
# Validate client exists and is active
client = OIDCClient.query.filter_by(client_id=client_id).first()
@@ -152,14 +161,19 @@ class OIDCService:
raise ValidationError("Invalid scopes")
# Generate authorization code
logger.debug("[OIDC SERVICE] Generating authorization code...")
logger.debug("[OIDC SERVICE] Current UTC time before code generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
code = cls._generate_code()
code_hash = cls._hash_value(code)
logger.debug("[OIDC SERVICE] Code generated: %s...", code[:20] if code else None)
# Development-only debug logging for PKCE in code creation
if current_app.config.get('ENV') == 'development':
logger.debug(f"[OIDC] Generate auth code - PKCE: code_challenge={code_challenge is not None}, code_challenge_method={code_challenge_method}")
# Create auth code record
logger.debug("[OIDC SERVICE] Creating auth code record with lifetime_seconds=600 (10 minutes)")
logger.debug("[OIDC SERVICE] Current UTC time before creating auth code: %s", datetime.now(timezone.utc).isoformat() + "Z")
auth_code = OIDCAuthCode.create_code(
client_id=client.id,
user_id=user_id,
@@ -172,6 +186,9 @@ class OIDCService:
user_agent=user_agent,
lifetime_seconds=600, # 10 minutes
)
logger.debug("[OIDC SERVICE] Auth code created successfully")
logger.debug("[OIDC SERVICE] Auth code expires_at (UTC): %s", auth_code.expires_at.isoformat() + "Z")
logger.debug("[OIDC SERVICE] Current UTC time after creating auth code: %s", datetime.now(timezone.utc).isoformat() + "Z")
# Log authorization event
OIDCAuditService.log_authorization_event(
@@ -182,6 +199,9 @@ class OIDCService:
scope=valid_scopes,
)
logger.debug("[OIDC SERVICE] generate_authorization_code completed successfully")
logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] ===========================================")
return code
@classmethod
@@ -211,6 +231,12 @@ class OIDCService:
InvalidGrantError: If code is invalid
ValidationError: If PKCE validation fails
"""
logger.debug("[OIDC SERVICE] ===========================================")
logger.debug("[OIDC SERVICE] validate_authorization_code called")
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] client_id=%s, redirect_uri=%s", client_id, redirect_uri)
logger.debug("[OIDC SERVICE] code_verifier provided: %s", bool(code_verifier))
# Get client
client = OIDCClient.query.filter_by(client_id=client_id).first()
@@ -223,6 +249,8 @@ class OIDCService:
raise InvalidGrantError("Invalid client")
# Hash the provided code and find matching auth code
logger.debug("[OIDC SERVICE] Looking up authorization code...")
logger.debug("[OIDC SERVICE] Current UTC time before code lookup: %s", datetime.now(timezone.utc).isoformat() + "Z")
code_hash = cls._hash_value(code)
auth_code = OIDCAuthCode.query.filter_by(
code_hash=code_hash,
@@ -256,8 +284,18 @@ class OIDCService:
raise InvalidGrantError("Authorization code already used")
# Check expiration
logger.debug("[OIDC SERVICE] Checking if authorization code is expired...")
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] Auth code expires_at (UTC): %s", auth_code.expires_at.isoformat() + "Z")
# Handle timezone-naive expires_at from database
expires_at = auth_code.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
logger.debug("[OIDC SERVICE] Time until expiration (seconds): %s", (expires_at - datetime.now(timezone.utc)).total_seconds())
if auth_code.is_expired():
logger.error(f"[OIDC] Validate auth code - Code expired: code_hash={code_hash[:20]}..., expires_at={auth_code.expires_at}")
logger.error("[OIDC] Validate auth code - Code expired: code_hash=%s..., expires_at (UTC)=%s, current UTC time=%s",
code_hash[:20], auth_code.expires_at.isoformat() + "Z", datetime.now(timezone.utc).isoformat() + "Z")
OIDCAuditService.log_authorization_event(
client_id=client_id,
user_id=auth_code.user_id,
@@ -316,6 +354,9 @@ class OIDCService:
"nonce": auth_code.nonce,
}
logger.debug("[OIDC SERVICE] validate_authorization_code completed successfully")
logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] ===========================================")
return claims, user
@classmethod
@@ -366,6 +407,12 @@ class OIDCService:
"""
import hashlib
logger.debug("[OIDC SERVICE] ===========================================")
logger.debug("[OIDC SERVICE] generate_tokens called")
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] client_id=%s, user_id=%s, scope=%s", client_id, user_id, scope)
logger.debug("[OIDC SERVICE] nonce=%s, auth_time=%s", nonce, auth_time)
# Get client
client = OIDCClient.query.filter_by(client_id=client_id).first()
@@ -377,6 +424,9 @@ class OIDCService:
raise InvalidClientError()
# Generate access token
logger.debug("[OIDC SERVICE] Generating access token...")
logger.debug("[OIDC SERVICE] Current UTC time before access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] Access token lifetime (seconds): %s", client.access_token_lifetime or 3600)
access_token_jti = OIDCTokenService._generate_jti()
access_token = OIDCTokenService.create_access_token(
client_id=client_id,
@@ -384,8 +434,13 @@ class OIDCService:
scope=scope,
jti=access_token_jti,
)
logger.debug("[OIDC SERVICE] Access token generated successfully")
logger.debug("[OIDC SERVICE] Current UTC time after access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
# Generate ID token
logger.debug("[OIDC SERVICE] Generating ID token...")
logger.debug("[OIDC SERVICE] Current UTC time before ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] ID token lifetime (seconds): %s", client.id_token_lifetime or 3600)
id_token = OIDCTokenService.create_id_token(
client_id=client_id,
user_id=user_id,
@@ -394,6 +449,8 @@ class OIDCService:
access_token=access_token,
auth_time=auth_time,
)
logger.debug("[OIDC SERVICE] ID token generated successfully")
logger.debug("[OIDC SERVICE] Current UTC time after ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
# Generate or rotate refresh token
if "refresh_token" in (client.grant_types or []):
@@ -445,22 +502,28 @@ class OIDCService:
client_db_id = client.id
# Access token metadata
logger.debug("[OIDC SERVICE] Creating access token metadata...")
access_token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=client.access_token_lifetime or 3600)
logger.debug("[OIDC SERVICE] Access token expires_at (UTC): %s", access_token_expires_at.isoformat() + "Z")
OIDCTokenMetadata.create_metadata(
client_id=client_db_id,
user_id=user_id,
token_type="access_token",
token_jti=access_token_jti,
expires_at=datetime.utcnow() + timedelta(seconds=client.access_token_lifetime or 3600),
expires_at=access_token_expires_at,
)
# ID token metadata (using access token JTI as reference)
logger.debug("[OIDC SERVICE] Creating ID token metadata...")
id_token_jti = OIDCTokenService._generate_jti()
id_token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=client.id_token_lifetime or 3600)
logger.debug("[OIDC SERVICE] ID token expires_at (UTC): %s", id_token_expires_at.isoformat() + "Z")
OIDCTokenMetadata.create_metadata(
client_id=client_db_id,
user_id=user_id,
token_type="id_token",
token_jti=id_token_jti,
expires_at=datetime.utcnow() + timedelta(seconds=client.id_token_lifetime or 3600),
expires_at=id_token_expires_at,
)
# Log token event
@@ -483,6 +546,9 @@ class OIDCService:
if final_refresh_token:
result["refresh_token"] = final_refresh_token
logger.debug("[OIDC SERVICE] generate_tokens completed successfully")
logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] ===========================================")
return result
@classmethod
@@ -511,6 +577,11 @@ class OIDCService:
"""
import hashlib
logger.debug("[OIDC SERVICE] ===========================================")
logger.debug("[OIDC SERVICE] refresh_access_token called")
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] client_id=%s, scope=%s", client_id, scope)
# Get client
client = OIDCClient.query.filter_by(client_id=client_id).first()
@@ -522,6 +593,8 @@ class OIDCService:
raise InvalidClientError()
# Find refresh token
logger.debug("[OIDC SERVICE] Looking up refresh token...")
logger.debug("[OIDC SERVICE] Current UTC time before refresh token lookup: %s", datetime.now(timezone.utc).isoformat() + "Z")
token_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
refresh_token_obj = OIDCRefreshToken.query.filter_by(
token_hash=token_hash,
@@ -542,6 +615,16 @@ class OIDCService:
raise InvalidGrantError("Invalid refresh token")
# Check if valid
logger.debug("[OIDC SERVICE] Checking if refresh token is valid...")
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
if refresh_token_obj:
logger.debug("[OIDC SERVICE] Refresh token expires_at (UTC): %s", refresh_token_obj.expires_at.isoformat() + "Z")
# Handle timezone-naive expires_at from database
rt_expires_at = refresh_token_obj.expires_at
if rt_expires_at.tzinfo is None:
rt_expires_at = rt_expires_at.replace(tzinfo=timezone.utc)
logger.debug("[OIDC SERVICE] Time until expiration (seconds): %s", (rt_expires_at - datetime.now(timezone.utc)).total_seconds())
if not refresh_token_obj.is_valid():
OIDCAuditService.log_token_event(
client_id=client_id,
@@ -563,6 +646,9 @@ class OIDCService:
granted_scope = scope or (refresh_token_obj.scope or [])
# Generate new access token
logger.debug("[OIDC SERVICE] Generating new access token...")
logger.debug("[OIDC SERVICE] Current UTC time before access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] Access token lifetime (seconds): %s", client.access_token_lifetime or 3600)
access_token_jti = OIDCTokenService._generate_jti()
access_token = OIDCTokenService.create_access_token(
client_id=client_id,
@@ -570,14 +656,21 @@ class OIDCService:
scope=granted_scope,
jti=access_token_jti,
)
logger.debug("[OIDC SERVICE] Access token generated successfully")
logger.debug("[OIDC SERVICE] Current UTC time after access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
# Generate new ID token
logger.debug("[OIDC SERVICE] Generating new ID token...")
logger.debug("[OIDC SERVICE] Current UTC time before ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] ID token lifetime (seconds): %s", client.id_token_lifetime or 3600)
id_token = OIDCTokenService.create_id_token(
client_id=client_id,
user_id=refresh_token_obj.user_id,
scope=granted_scope,
access_token=access_token,
)
logger.debug("[OIDC SERVICE] ID token generated successfully")
logger.debug("[OIDC SERVICE] Current UTC time after ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
# Rotate refresh token
new_refresh, new_hash = OIDCTokenService.create_refresh_token(
@@ -590,12 +683,15 @@ class OIDCService:
refresh_token_obj.rotate(new_hash)
# Store new token metadata
logger.debug("[OIDC SERVICE] Creating access token metadata...")
access_token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=client.access_token_lifetime or 3600)
logger.debug("[OIDC SERVICE] Access token expires_at (UTC): %s", access_token_expires_at.isoformat() + "Z")
OIDCTokenMetadata.create_metadata(
client_id=client.id,
user_id=refresh_token_obj.user_id,
token_type="access_token",
token_jti=access_token_jti,
expires_at=datetime.utcnow() + timedelta(seconds=client.access_token_lifetime or 3600),
expires_at=access_token_expires_at,
)
# Log refresh event
@@ -615,6 +711,17 @@ class OIDCService:
"id_token": id_token,
"refresh_token": new_refresh,
}
logger.debug("[OIDC SERVICE] refresh_access_token completed successfully")
logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] ===========================================")
return {
"access_token": access_token,
"token_type": "Bearer",
"expires_in": client.access_token_lifetime or 3600,
"id_token": id_token,
"refresh_token": new_refresh,
}
@classmethod
def validate_access_token(cls, token: str, client_id: str = None) -> Dict:
@@ -630,10 +737,23 @@ class OIDCService:
Raises:
InvalidTokenError: If token is invalid
"""
logger.debug("[OIDC SERVICE] ===========================================")
logger.debug("[OIDC SERVICE] validate_access_token() called")
logger.debug("[OIDC SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token)
logger.debug("[OIDC SERVICE] Token length: %d", len(token))
logger.debug("[OIDC SERVICE] Client ID: %s", client_id)
try:
logger.debug("[OIDC SERVICE] Calling OIDCTokenService.validate_access_token()...")
claims = OIDCTokenService.validate_access_token(token, client_id)
logger.debug("[OIDC SERVICE] Token validation successful")
logger.debug("[OIDC SERVICE] Token claims: %s", claims)
logger.debug("[OIDC SERVICE] ===========================================")
return claims
except Exception as e:
logger.error("[OIDC SERVICE] Token validation failed: %s: %s", type(e).__name__, str(e))
import traceback
logger.error("[OIDC SERVICE] Traceback: %s", traceback.format_exc())
OIDCAuditService.log_event(
event_type="token_validation",
client_id=client_id,
@@ -770,29 +890,67 @@ class OIDCService:
Returns:
User information dictionary
"""
logger.debug("[OIDC SERVICE] ===========================================")
logger.debug("[OIDC SERVICE] get_userinfo() called")
logger.debug("[OIDC SERVICE] Access token (first 50 chars): %s...", access_token[:50] if len(access_token) > 50 else access_token)
logger.debug("[OIDC SERVICE] Access token length: %d", len(access_token))
# Validate access token
logger.debug("[OIDC SERVICE] Validating access token...")
claims = cls.validate_access_token(access_token)
logger.debug("[OIDC SERVICE] Access token validated successfully")
logger.debug("[OIDC SERVICE] Token claims: %s", claims)
user_id = claims.get("sub")
logger.debug("[OIDC SERVICE] User ID from token: %s", user_id)
logger.debug("[OIDC SERVICE] Querying user from database...")
user = User.query.get(user_id)
logger.debug("[OIDC SERVICE] User query result: %s", user)
if not user:
logger.error("[OIDC SERVICE] User not found in database: user_id=%s", user_id)
raise NotFoundError("User not found")
logger.debug("[OIDC SERVICE] User found: user_id=%s, email=%s, full_name=%s", user.id, user.email, user.full_name)
# Get scopes from token
scope_str = claims.get("scope", "")
scopes = scope_str.split() if scope_str else []
logger.debug("[OIDC SERVICE] Scope string from token: '%s'", scope_str)
logger.debug("[OIDC SERVICE] Parsed scopes: %s", scopes)
userinfo = {"sub": user_id}
logger.debug("[OIDC SERVICE] Initial userinfo: %s", userinfo)
# Add claims based on scope
if "profile" in scopes and user.full_name:
logger.debug("[OIDC SERVICE] Found 'profile' in scope, adding name claim")
userinfo["name"] = user.full_name
logger.debug("[OIDC SERVICE] Added name: %s", user.full_name)
else:
logger.debug("[OIDC SERVICE] 'profile' not in scope or user.full_name is None: profile_in_scope=%s, full_name=%s", "profile" in scopes, user.full_name)
if "email" in scopes:
logger.debug("[OIDC SERVICE] Found 'email' in scope, adding email claims")
userinfo["email"] = user.email
userinfo["email_verified"] = user.email_verified
logger.debug("[OIDC SERVICE] Added email: %s, email_verified: %s", user.email, user.email_verified)
else:
logger.debug("[OIDC SERVICE] 'email' not in scope")
if "roles" in scopes:
logger.debug("[OIDC SERVICE] Found 'roles' in scope, adding roles claim")
user_roles = cls._get_user_roles(user)
userinfo["roles"] = user_roles
logger.debug("[OIDC SERVICE] Added roles: %s", user_roles)
else:
logger.debug("[OIDC SERVICE] 'roles' not in scope")
logger.debug("[OIDC SERVICE] Final userinfo: %s", userinfo)
# Log userinfo access
logger.debug("[OIDC SERVICE] Logging userinfo access event...")
OIDCAuditService.log_userinfo_event(
access_token=access_token,
user_id=user_id,
@@ -800,5 +958,54 @@ class OIDCService:
success=True,
scopes_claimed=scopes,
)
logger.debug("[OIDC SERVICE] Userinfo access event logged")
logger.debug("[OIDC SERVICE] get_userinfo() completed successfully")
logger.debug("[OIDC SERVICE] ===========================================")
return userinfo
@staticmethod
def _get_user_roles(user: User) -> list:
"""Get user's organization roles.
Args:
user: User instance
Returns:
List of role objects with organization_id and role
"""
logger.debug("[OIDC SERVICE] _get_user_roles() called")
logger.debug("[OIDC SERVICE] User: %s", user)
roles = []
if not user:
logger.debug("[OIDC SERVICE] User is None, returning empty roles list")
return roles
logger.debug("[OIDC SERVICE] User ID: %s", user.id)
logger.debug("[OIDC SERVICE] User email: %s", user.email)
logger.debug("[OIDC SERVICE] User organization_memberships: %s", user.organization_memberships)
if user.organization_memberships:
logger.debug("[OIDC SERVICE] User has %d organization memberships", len(user.organization_memberships))
for idx, member in enumerate(user.organization_memberships):
logger.debug("[OIDC SERVICE] Processing membership %d: member=%s", idx, member)
logger.debug("[OIDC SERVICE] organization_id: %s", member.organization_id)
logger.debug("[OIDC SERVICE] role: %s", member.role)
logger.debug("[OIDC SERVICE] role.value: %s", member.role.value)
role_entry = {
"organization_id": str(member.organization_id),
"role": member.role.value
}
roles.append(role_entry)
logger.debug("[OIDC SERVICE] Added role entry: %s", role_entry)
else:
logger.debug("[OIDC SERVICE] User has no organization memberships")
logger.debug("[OIDC SERVICE] Final roles list: %s", roles)
logger.debug("[OIDC SERVICE] _get_user_roles() completed")
return roles
+3 -2
View File
@@ -3,6 +3,7 @@ import secrets
from datetime import datetime, timedelta
from typing import Dict, Optional, Tuple
from datetime import timezone
from flask import current_app, g
from app.extensions import db
@@ -219,11 +220,11 @@ class OIDCSessionService:
"""
from datetime import timedelta
cutoff = datetime.utcnow() - timedelta(hours=older_than_hours)
cutoff = datetime.now(timezone.utc) - timedelta(hours=older_than_hours)
# Get expired sessions
expired_sessions = OIDCSession.query.filter(
OIDCSession.expires_at < datetime.utcnow(),
OIDCSession.expires_at < datetime.now(timezone.utc),
OIDCSession.deleted_at == None
).all()
+191 -37
View File
@@ -2,15 +2,20 @@
import hashlib
import base64
import secrets
from datetime import datetime, timedelta
import logging
import time
from datetime import datetime, timedelta, timezone
from typing import Dict, Optional, Any
import jwt
from flask import current_app, g
from app.models import User, OIDCClient
from app.models.organization_member import OrganizationMember
from app.services.oidc_jwks_service import OIDCJWKSService
logger = logging.getLogger(__name__)
class OIDCTokenService:
"""Service for generating and validating OIDC tokens.
@@ -134,7 +139,7 @@ class OIDCTokenService:
return lifetimes.get(token_type, 3600)
@classmethod
def create_access_token(cls, client_id: str, user_id: str, scope: list,
def create_access_token(cls, client_id: str, user_id: str, scope: list,
jti: str = None) -> str:
"""Create a JWT access token.
@@ -147,25 +152,44 @@ class OIDCTokenService:
Returns:
JWT access token string
"""
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
logger.debug("[OIDC TOKEN SERVICE] create_access_token called")
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
logger.debug("[OIDC TOKEN SERVICE] scope=%s", scope)
jti = jti or cls._generate_jti()
now = datetime.utcnow()
now_timestamp = int(time.time())
now = datetime.now(timezone.utc)
logger.debug("[OIDC TOKEN SERVICE] Token creation time (UTC): %s", now.isoformat())
logger.debug("[OIDC TOKEN SERVICE] Token creation timestamp: %s", now_timestamp)
# Get client for token lifetime
client = OIDCClient.query.filter_by(client_id=client_id).first()
lifetime = cls._get_token_lifetime(client, "access_token") if client else 3600
logger.debug("[OIDC TOKEN SERVICE] Access token lifetime (seconds): %s", lifetime)
exp_timestamp = now_timestamp + lifetime
exp_time = now + timedelta(seconds=lifetime)
logger.debug("[OIDC TOKEN SERVICE] Access token expiration time (UTC): %s", exp_time.isoformat())
logger.debug("[OIDC TOKEN SERVICE] Access token expiration timestamp: %s", exp_timestamp)
logger.debug("[OIDC TOKEN SERVICE] Time until expiration (seconds): %s", lifetime)
claims = {
"iss": cls._get_issuer(),
"sub": user_id,
"aud": client_id,
"exp": int((now + timedelta(seconds=lifetime)).timestamp()),
"iat": int(now.timestamp()),
"nbf": int(now.timestamp()),
"exp": exp_timestamp,
"iat": now_timestamp,
"nbf": now_timestamp,
"jti": jti,
"client_id": client_id,
"scope": " ".join(scope) if isinstance(scope, list) else scope,
}
logger.debug("[OIDC TOKEN SERVICE] Token claims: exp=%s, iat=%s, nbf=%s",
claims["exp"], claims["iat"], claims["nbf"])
# Get signing key
jwks_service = OIDCJWKSService()
signing_key = jwks_service.get_signing_key()
@@ -174,6 +198,7 @@ class OIDCTokenService:
raise ValueError("No signing key available")
# Sign with RS256
logger.debug("[OIDC TOKEN SERVICE] Signing token with RS256...")
token = jwt.encode(
claims,
signing_key.private_key,
@@ -181,6 +206,9 @@ class OIDCTokenService:
headers={"kid": signing_key.kid}
)
logger.debug("[OIDC TOKEN SERVICE] Access token created successfully")
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
return token
@classmethod
@@ -200,12 +228,30 @@ class OIDCTokenService:
Returns:
JWT ID token string
"""
now = datetime.utcnow()
auth_time = auth_time or int(now.timestamp())
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
logger.debug("[OIDC TOKEN SERVICE] create_id_token called")
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
logger.debug("[OIDC TOKEN SERVICE] nonce=%s, auth_time=%s", nonce, auth_time)
logger.debug("[OIDC TOKEN SERVICE] scope=%s", scope)
now_timestamp = int(time.time())
now = datetime.now(timezone.utc)
logger.debug("[OIDC TOKEN SERVICE] Token creation time (UTC): %s", now.isoformat())
logger.debug("[OIDC TOKEN SERVICE] Token creation timestamp: %s", now_timestamp)
auth_time = auth_time or now_timestamp
logger.debug("[OIDC TOKEN SERVICE] auth_time (Unix timestamp): %s", auth_time)
# Get client for token lifetime
client = OIDCClient.query.filter_by(client_id=client_id).first()
lifetime = cls._get_token_lifetime(client, "id_token") if client else 3600
logger.debug("[OIDC TOKEN SERVICE] ID token lifetime (seconds): %s", lifetime)
exp_timestamp = now_timestamp + lifetime
exp_time = now + timedelta(seconds=lifetime)
logger.debug("[OIDC TOKEN SERVICE] ID token expiration time (UTC): %s", exp_time.isoformat())
logger.debug("[OIDC TOKEN SERVICE] ID token expiration timestamp: %s", exp_timestamp)
logger.debug("[OIDC TOKEN SERVICE] Time until expiration (seconds): %s", lifetime)
# Get user for claims
user = User.query.get(user_id)
@@ -214,11 +260,14 @@ class OIDCTokenService:
"iss": cls._get_issuer(),
"sub": user_id,
"aud": client_id,
"exp": int((now + timedelta(seconds=lifetime)).timestamp()),
"iat": int(now.timestamp()),
"exp": exp_timestamp,
"iat": now_timestamp,
"auth_time": auth_time,
}
logger.debug("[OIDC TOKEN SERVICE] Token claims: exp=%s, iat=%s, auth_time=%s",
claims["exp"], claims["iat"], claims["auth_time"])
# Add nonce if provided
if nonce:
claims["nonce"] = nonce
@@ -235,6 +284,10 @@ class OIDCTokenService:
if user.full_name:
claims["name"] = user.full_name
# Add roles claim if scope is granted
if scope and "roles" in scope:
claims["roles"] = cls._get_user_roles(user)
# Add scope if provided
if scope:
claims["scope"] = " ".join(scope) if isinstance(scope, list) else scope
@@ -247,6 +300,7 @@ class OIDCTokenService:
raise ValueError("No signing key available")
# Sign with RS256
logger.debug("[OIDC TOKEN SERVICE] Signing token with RS256...")
token = jwt.encode(
claims,
signing_key.private_key,
@@ -254,10 +308,32 @@ class OIDCTokenService:
headers={"kid": signing_key.kid}
)
logger.debug("[OIDC TOKEN SERVICE] ID token created successfully")
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
return token
@staticmethod
def _get_user_roles(user: User) -> list:
"""Get user's organization roles.
Args:
user: User instance
Returns:
List of role objects with organization_id and role
"""
roles = []
if user and user.organization_memberships:
for member in user.organization_memberships:
roles.append({
"organization_id": str(member.organization_id),
"role": member.role.value
})
return roles
@classmethod
def create_refresh_token(cls, client_id: str, user_id: str,
def create_refresh_token(cls, client_id: str, user_id: str,
scope: list = None, access_token_id: str = None) -> str:
"""Create an opaque refresh token.
@@ -270,11 +346,21 @@ class OIDCTokenService:
Returns:
Opaque refresh token string
"""
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
logger.debug("[OIDC TOKEN SERVICE] create_refresh_token called")
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
logger.debug("[OIDC TOKEN SERVICE] scope=%s, access_token_id=%s", scope, access_token_id)
token = cls._generate_opaque_token()
logger.debug("[OIDC TOKEN SERVICE] Refresh token generated: %s...", token[:20] if token else None)
# Hash for storage
token_hash = cls._hash_token(token)
logger.debug("[OIDC TOKEN SERVICE] Refresh token created successfully")
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
return token, token_hash
@classmethod
@@ -292,54 +378,91 @@ class OIDCTokenService:
jwt.ExpiredSignatureError: If token is expired
jwt.InvalidTokenError: If token is invalid
"""
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
logger.debug("[OIDC TOKEN SERVICE] verify_token_signature() called")
logger.debug("[OIDC TOKEN SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token)
logger.debug("[OIDC TOKEN SERVICE] Token length: %d", len(token))
# Get the JWKS with public keys
logger.debug("[OIDC TOKEN SERVICE] Getting JWKS...")
jwks_service = OIDCJWKSService()
jwks = jwks_service.get_jwks()
jwks = jwks_service.get_jwks(include_private_keys=True)
logger.debug("[OIDC TOKEN SERVICE] JWKS retrieved: %d keys", len(jwks.get("keys", [])))
# Get the key ID from token header
try:
logger.debug("[OIDC TOKEN SERVICE] Getting unverified token header...")
unverified_header = jwt.get_unverified_header(token)
except jwt.DecodeError:
logger.debug("[OIDC TOKEN SERVICE] Unverified header: %s", unverified_header)
except jwt.DecodeError as e:
logger.error("[OIDC TOKEN SERVICE] Failed to decode token header: %s", str(e))
raise jwt.InvalidTokenError("Invalid token header")
kid = unverified_header.get("kid")
logger.debug("[OIDC TOKEN SERVICE] Key ID (kid) from token header: %s", kid)
# Find the matching public key
logger.debug("[OIDC TOKEN SERVICE] Searching for matching public key...")
public_key = None
for key in jwks.get("keys", []):
for idx, key in enumerate(jwks.get("keys", [])):
logger.debug("[OIDC TOKEN SERVICE] Checking key %d: kid=%s", idx, key.get("kid"))
if key.get("kid") == kid:
logger.debug("[OIDC TOKEN SERVICE] Found matching key at index %d", idx)
try:
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
logger.debug("[OIDC TOKEN SERVICE] Loading PEM public key...")
public_key = serialization.load_pem_public_key(
key["public_key"].encode() if isinstance(key["public_key"], str)
key["public_key"].encode() if isinstance(key["public_key"], str)
else key["public_key"],
backend=default_backend()
)
logger.debug("[OIDC TOKEN SERVICE] Public key loaded successfully")
break
except (ImportError, Exception):
except (ImportError, Exception) as e:
logger.error("[OIDC TOKEN SERVICE] Failed to load public key: %s: %s", type(e).__name__, str(e))
continue
if not public_key:
logger.error("[OIDC TOKEN SERVICE] No matching public key found for kid=%s", kid)
raise jwt.InvalidSignatureError(f"Key with kid={kid} not found")
# Verify the signature
claims = jwt.decode(
token,
public_key,
algorithms=["RS256"],
audience=None, # We'll validate audience separately
issuer=cls._get_issuer(),
options={
"verify_signature": True,
"verify_exp": True,
"verify_aud": False, # Handle audience manually
"verify_iss": False, # Handle issuer manually
}
)
logger.debug("[OIDC TOKEN SERVICE] Public key found, verifying signature...")
return claims
# Verify the signature
try:
claims = jwt.decode(
token,
public_key,
algorithms=["RS256"],
audience=None, # We'll validate audience separately
issuer=cls._get_issuer(),
options={
"verify_signature": True,
"verify_exp": True,
"verify_aud": False, # Handle audience manually
"verify_iss": False, # Handle issuer manually
}
)
logger.debug("[OIDC TOKEN SERVICE] Signature verification successful")
logger.debug("[OIDC TOKEN SERVICE] Decoded claims: %s", claims)
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
return claims
except jwt.ExpiredSignatureError as e:
logger.error("[OIDC TOKEN SERVICE] Token has expired: %s", str(e))
raise
except jwt.InvalidSignatureError as e:
logger.error("[OIDC TOKEN SERVICE] Invalid token signature: %s", str(e))
raise
except jwt.InvalidTokenError as e:
logger.error("[OIDC TOKEN SERVICE] Invalid token: %s: %s", type(e).__name__, str(e))
raise
except Exception as e:
logger.error("[OIDC TOKEN SERVICE] Unexpected error during token verification: %s: %s", type(e).__name__, str(e))
import traceback
logger.error("[OIDC TOKEN SERVICE] Traceback: %s", traceback.format_exc())
raise
@classmethod
def decode_token(cls, token: str, verify: bool = False) -> Dict:
@@ -378,16 +501,41 @@ class OIDCTokenService:
jwt.InvalidTokenError: If token is invalid
ValueError: If token is expired or audience mismatch
"""
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
logger.debug("[OIDC TOKEN SERVICE] validate_access_token() called")
logger.debug("[OIDC TOKEN SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token)
logger.debug("[OIDC TOKEN SERVICE] Token length: %d", len(token))
logger.debug("[OIDC TOKEN SERVICE] Client ID: %s", client_id)
# Verify token signature
logger.debug("[OIDC TOKEN SERVICE] Verifying token signature...")
claims = cls.verify_token_signature(token)
logger.debug("[OIDC TOKEN SERVICE] Token signature verified")
logger.debug("[OIDC TOKEN SERVICE] Claims: %s", claims)
# Check expiration
if claims.get("exp", 0) < datetime.utcnow().timestamp():
exp = claims.get("exp", 0)
now_timestamp = int(time.time())
if exp < now_timestamp:
logger.error("[OIDC TOKEN SERVICE] Token has expired")
raise ValueError("Token has expired")
# Validate audience if client_id provided
aud = claims.get("aud")
logger.debug("[OIDC TOKEN SERVICE] Token audience (aud): %s", aud)
logger.debug("[OIDC TOKEN SERVICE] Expected client_id: %s", client_id)
if client_id:
if claims.get("aud") != client_id:
if aud != client_id:
logger.error("[OIDC TOKEN SERVICE] Audience mismatch: expected=%s, got=%s", client_id, aud)
raise ValueError("Invalid audience")
logger.debug("[OIDC TOKEN SERVICE] Audience validation passed")
else:
logger.debug("[OIDC TOKEN SERVICE] No client_id provided, skipping audience validation")
logger.debug("[OIDC TOKEN SERVICE] validate_access_token() completed successfully")
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
return claims
@@ -410,11 +558,17 @@ class OIDCTokenService:
claims = cls.validate_access_token(token, client_id)
# Calculate remaining time
now = datetime.utcnow().timestamp()
now_timestamp = int(time.time())
now = datetime.now(timezone.utc)
exp = claims.get("exp", 0)
iat = claims.get("iat", 0)
result["active"] = exp > now
logger.debug("[OIDC TOKEN SERVICE] Introspection - Current UTC time: %s", now.isoformat())
logger.debug("[OIDC TOKEN SERVICE] Introspection - Token expiration timestamp: %s", exp)
logger.debug("[OIDC TOKEN SERVICE] Introspection - Token expiration datetime (UTC): %s", datetime.fromtimestamp(exp, tz=timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] Introspection - Time until expiration: %s seconds", exp - now_timestamp)
result["active"] = exp > now_timestamp
result.update({
"iss": claims.get("iss"),
"sub": claims.get("sub"),
@@ -429,8 +583,8 @@ class OIDCTokenService:
})
# Add expiry in seconds
if exp > now:
result["exp"] = int(exp - now)
if exp > now_timestamp:
result["exp"] = int(exp - now_timestamp)
except (jwt.InvalidTokenError, ValueError) as e:
result["active"] = False
+188
View File
@@ -0,0 +1,188 @@
"""TOTP (Time-based One-Time Password) service."""
import base64
import io
import logging
import secrets
from typing import Tuple
import pyotp
from app.extensions import bcrypt
logger = logging.getLogger(__name__)
class TOTPService:
"""Service for TOTP operations."""
@staticmethod
def generate_secret() -> str:
"""
Generate a new TOTP secret.
Returns:
Base32 encoded secret (32 characters)
Note:
The secret is generated using cryptographically secure random bytes
and encoded in base32 format for compatibility with authenticator apps.
"""
# Generate 20 random bytes (160 bits) and encode as base32
random_bytes = secrets.token_bytes(20)
secret = base64.b32encode(random_bytes).decode("utf-8")
logger.debug(f"Generated new TOTP secret: {secret[:8]}...")
return secret
@staticmethod
def generate_provisioning_uri(user_email: str, secret: str, issuer: str = "Gatehouse") -> str:
"""
Generate provisioning URI for QR code.
Args:
user_email: User's email address
secret: TOTP secret (base32 encoded)
issuer: Issuer name (default: "Gatehouse")
Returns:
otpauth:// URI for QR code generation
Example:
>>> uri = TOTPService.generate_provisioning_uri("user@example.com", "JBSWY3DPEHPK3PXP")
>>> print(uri)
otpauth://totp/Gatehouse:user@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse
"""
totp = pyotp.TOTP(secret)
uri = totp.provisioning_uri(name=user_email, issuer_name=issuer)
logger.debug(f"Generated provisioning URI for user: {user_email}")
return uri
@staticmethod
def verify_code(secret: str, code: str, window: int = 1) -> bool:
"""
Verify a TOTP code against the secret.
Args:
secret: TOTP secret (base32 encoded)
code: 6-digit TOTP code to verify
window: Time window for code validation (default: 1, allows codes from previous/next time steps)
Returns:
True if code is valid, False otherwise
Note:
The window parameter allows for clock skew between the server
and the authenticator app. A window of 1 allows codes from
the previous, current, and next 30-second intervals.
"""
totp = pyotp.TOTP(secret)
is_valid = totp.verify(code, valid_window=window)
logger.debug(f"TOTP code verification: valid={is_valid}, window={window}")
return is_valid
@staticmethod
def generate_backup_codes(count: int = 10) -> Tuple[list[str], list[str]]:
"""
Generate backup codes for TOTP recovery.
Args:
count: Number of backup codes to generate (default: 10)
Returns:
Tuple of (plain_codes, hashed_codes)
- plain_codes: List of plain text backup codes (for display to user)
- hashed_codes: List of bcrypt hashed backup codes (for storage)
Note:
Backup codes are 16-character alphanumeric codes that can be used
to recover access if the TOTP device is lost. Each code can only
be used once.
"""
plain_codes = []
hashed_codes = []
for _ in range(count):
# Generate a 16-character alphanumeric code
code = secrets.token_hex(8).upper()
plain_codes.append(code)
# Hash the code using bcrypt
hashed_code = bcrypt.generate_password_hash(code).decode("utf-8")
hashed_codes.append(hashed_code)
logger.debug(f"Generated {count} backup codes")
return plain_codes, hashed_codes
@staticmethod
def verify_backup_code(hashed_codes: list[str], code: str) -> Tuple[bool, list[str]]:
"""
Verify and consume a backup code.
Args:
hashed_codes: List of bcrypt hashed backup codes
code: Plain text backup code to verify
Returns:
Tuple of (is_valid, remaining_codes)
- is_valid: True if code was valid and consumed, False otherwise
- remaining_codes: List of remaining hashed codes (with consumed code removed)
Note:
Once a backup code is used, it is removed from the list and cannot
be used again. This ensures each code is single-use.
"""
remaining_codes = []
for hashed_code in hashed_codes:
if bcrypt.check_password_hash(hashed_code, code):
# Code found and valid - don't add to remaining codes (consumed)
logger.debug("Backup code verified and consumed")
return True, remaining_codes
else:
# Code doesn't match - keep it in remaining codes
remaining_codes.append(hashed_code)
logger.debug("Backup code verification failed")
return False, remaining_codes
@staticmethod
def generate_qr_code_data_uri(provisioning_uri: str) -> str:
"""
Generate QR code as data URI for frontend display.
Args:
provisioning_uri: otpauth:// URI to encode in QR code
Returns:
Base64 encoded PNG image as data URI (data:image/png;base64,...)
Note:
If the qrcode library is not installed, returns a placeholder message.
Install with: pip install qrcode[pil]
"""
try:
import qrcode
# Create QR code
qr = qrcode.QRCode(
version=1,
error_correction=qrcode.constants.ERROR_CORRECT_L,
box_size=10,
border=4,
)
qr.add_data(provisioning_uri)
qr.make(fit=True)
# Generate image
img = qr.make_image(fill_color="black", back_color="white")
# Convert to base64
buffer = io.BytesIO()
img.save(buffer, format="PNG")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
data_uri = f"data:image/png;base64,{img_base64}"
logger.debug("Generated QR code data URI")
return data_uri
except ImportError:
logger.warning("qrcode library not installed, returning placeholder")
return "QR code generation requires the qrcode library. Install with: pip install qrcode[pil]"