feat(auth): implement TOTP two-factor authentication with enrollment and verification

Adds TOTP (Time-based One-Time Password) two-factor authentication support including:
- New TOTP service with secret generation, QR code provisioning, and code verification
- New auth endpoints for enrollment, verification, status, and backup code management
- New TOTP authentication method type and user methods for TOTP management
- Backup codes generation and verification for account recovery
- Updated OIDC endpoints with timezone-aware datetime handling and RFC-compliant responses
- Added "roles" scope support for OIDC userinfo and ID tokens
- New pyotp dependency for TOTP operations
- Comprehensive unit tests for TOTP service
This commit is contained in:
2026-01-14 18:06:17 +10:30
parent 977abf66df
commit cfd79190ee
26 changed files with 2176 additions and 263 deletions
+191 -37
View File
@@ -2,15 +2,20 @@
import hashlib
import base64
import secrets
from datetime import datetime, timedelta
import logging
import time
from datetime import datetime, timedelta, timezone
from typing import Dict, Optional, Any
import jwt
from flask import current_app, g
from app.models import User, OIDCClient
from app.models.organization_member import OrganizationMember
from app.services.oidc_jwks_service import OIDCJWKSService
logger = logging.getLogger(__name__)
class OIDCTokenService:
"""Service for generating and validating OIDC tokens.
@@ -134,7 +139,7 @@ class OIDCTokenService:
return lifetimes.get(token_type, 3600)
@classmethod
def create_access_token(cls, client_id: str, user_id: str, scope: list,
def create_access_token(cls, client_id: str, user_id: str, scope: list,
jti: str = None) -> str:
"""Create a JWT access token.
@@ -147,25 +152,44 @@ class OIDCTokenService:
Returns:
JWT access token string
"""
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
logger.debug("[OIDC TOKEN SERVICE] create_access_token called")
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
logger.debug("[OIDC TOKEN SERVICE] scope=%s", scope)
jti = jti or cls._generate_jti()
now = datetime.utcnow()
now_timestamp = int(time.time())
now = datetime.now(timezone.utc)
logger.debug("[OIDC TOKEN SERVICE] Token creation time (UTC): %s", now.isoformat())
logger.debug("[OIDC TOKEN SERVICE] Token creation timestamp: %s", now_timestamp)
# Get client for token lifetime
client = OIDCClient.query.filter_by(client_id=client_id).first()
lifetime = cls._get_token_lifetime(client, "access_token") if client else 3600
logger.debug("[OIDC TOKEN SERVICE] Access token lifetime (seconds): %s", lifetime)
exp_timestamp = now_timestamp + lifetime
exp_time = now + timedelta(seconds=lifetime)
logger.debug("[OIDC TOKEN SERVICE] Access token expiration time (UTC): %s", exp_time.isoformat())
logger.debug("[OIDC TOKEN SERVICE] Access token expiration timestamp: %s", exp_timestamp)
logger.debug("[OIDC TOKEN SERVICE] Time until expiration (seconds): %s", lifetime)
claims = {
"iss": cls._get_issuer(),
"sub": user_id,
"aud": client_id,
"exp": int((now + timedelta(seconds=lifetime)).timestamp()),
"iat": int(now.timestamp()),
"nbf": int(now.timestamp()),
"exp": exp_timestamp,
"iat": now_timestamp,
"nbf": now_timestamp,
"jti": jti,
"client_id": client_id,
"scope": " ".join(scope) if isinstance(scope, list) else scope,
}
logger.debug("[OIDC TOKEN SERVICE] Token claims: exp=%s, iat=%s, nbf=%s",
claims["exp"], claims["iat"], claims["nbf"])
# Get signing key
jwks_service = OIDCJWKSService()
signing_key = jwks_service.get_signing_key()
@@ -174,6 +198,7 @@ class OIDCTokenService:
raise ValueError("No signing key available")
# Sign with RS256
logger.debug("[OIDC TOKEN SERVICE] Signing token with RS256...")
token = jwt.encode(
claims,
signing_key.private_key,
@@ -181,6 +206,9 @@ class OIDCTokenService:
headers={"kid": signing_key.kid}
)
logger.debug("[OIDC TOKEN SERVICE] Access token created successfully")
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
return token
@classmethod
@@ -200,12 +228,30 @@ class OIDCTokenService:
Returns:
JWT ID token string
"""
now = datetime.utcnow()
auth_time = auth_time or int(now.timestamp())
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
logger.debug("[OIDC TOKEN SERVICE] create_id_token called")
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
logger.debug("[OIDC TOKEN SERVICE] nonce=%s, auth_time=%s", nonce, auth_time)
logger.debug("[OIDC TOKEN SERVICE] scope=%s", scope)
now_timestamp = int(time.time())
now = datetime.now(timezone.utc)
logger.debug("[OIDC TOKEN SERVICE] Token creation time (UTC): %s", now.isoformat())
logger.debug("[OIDC TOKEN SERVICE] Token creation timestamp: %s", now_timestamp)
auth_time = auth_time or now_timestamp
logger.debug("[OIDC TOKEN SERVICE] auth_time (Unix timestamp): %s", auth_time)
# Get client for token lifetime
client = OIDCClient.query.filter_by(client_id=client_id).first()
lifetime = cls._get_token_lifetime(client, "id_token") if client else 3600
logger.debug("[OIDC TOKEN SERVICE] ID token lifetime (seconds): %s", lifetime)
exp_timestamp = now_timestamp + lifetime
exp_time = now + timedelta(seconds=lifetime)
logger.debug("[OIDC TOKEN SERVICE] ID token expiration time (UTC): %s", exp_time.isoformat())
logger.debug("[OIDC TOKEN SERVICE] ID token expiration timestamp: %s", exp_timestamp)
logger.debug("[OIDC TOKEN SERVICE] Time until expiration (seconds): %s", lifetime)
# Get user for claims
user = User.query.get(user_id)
@@ -214,11 +260,14 @@ class OIDCTokenService:
"iss": cls._get_issuer(),
"sub": user_id,
"aud": client_id,
"exp": int((now + timedelta(seconds=lifetime)).timestamp()),
"iat": int(now.timestamp()),
"exp": exp_timestamp,
"iat": now_timestamp,
"auth_time": auth_time,
}
logger.debug("[OIDC TOKEN SERVICE] Token claims: exp=%s, iat=%s, auth_time=%s",
claims["exp"], claims["iat"], claims["auth_time"])
# Add nonce if provided
if nonce:
claims["nonce"] = nonce
@@ -235,6 +284,10 @@ class OIDCTokenService:
if user.full_name:
claims["name"] = user.full_name
# Add roles claim if scope is granted
if scope and "roles" in scope:
claims["roles"] = cls._get_user_roles(user)
# Add scope if provided
if scope:
claims["scope"] = " ".join(scope) if isinstance(scope, list) else scope
@@ -247,6 +300,7 @@ class OIDCTokenService:
raise ValueError("No signing key available")
# Sign with RS256
logger.debug("[OIDC TOKEN SERVICE] Signing token with RS256...")
token = jwt.encode(
claims,
signing_key.private_key,
@@ -254,10 +308,32 @@ class OIDCTokenService:
headers={"kid": signing_key.kid}
)
logger.debug("[OIDC TOKEN SERVICE] ID token created successfully")
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
return token
@staticmethod
def _get_user_roles(user: User) -> list:
"""Get user's organization roles.
Args:
user: User instance
Returns:
List of role objects with organization_id and role
"""
roles = []
if user and user.organization_memberships:
for member in user.organization_memberships:
roles.append({
"organization_id": str(member.organization_id),
"role": member.role.value
})
return roles
@classmethod
def create_refresh_token(cls, client_id: str, user_id: str,
def create_refresh_token(cls, client_id: str, user_id: str,
scope: list = None, access_token_id: str = None) -> str:
"""Create an opaque refresh token.
@@ -270,11 +346,21 @@ class OIDCTokenService:
Returns:
Opaque refresh token string
"""
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
logger.debug("[OIDC TOKEN SERVICE] create_refresh_token called")
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
logger.debug("[OIDC TOKEN SERVICE] scope=%s, access_token_id=%s", scope, access_token_id)
token = cls._generate_opaque_token()
logger.debug("[OIDC TOKEN SERVICE] Refresh token generated: %s...", token[:20] if token else None)
# Hash for storage
token_hash = cls._hash_token(token)
logger.debug("[OIDC TOKEN SERVICE] Refresh token created successfully")
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
return token, token_hash
@classmethod
@@ -292,54 +378,91 @@ class OIDCTokenService:
jwt.ExpiredSignatureError: If token is expired
jwt.InvalidTokenError: If token is invalid
"""
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
logger.debug("[OIDC TOKEN SERVICE] verify_token_signature() called")
logger.debug("[OIDC TOKEN SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token)
logger.debug("[OIDC TOKEN SERVICE] Token length: %d", len(token))
# Get the JWKS with public keys
logger.debug("[OIDC TOKEN SERVICE] Getting JWKS...")
jwks_service = OIDCJWKSService()
jwks = jwks_service.get_jwks()
jwks = jwks_service.get_jwks(include_private_keys=True)
logger.debug("[OIDC TOKEN SERVICE] JWKS retrieved: %d keys", len(jwks.get("keys", [])))
# Get the key ID from token header
try:
logger.debug("[OIDC TOKEN SERVICE] Getting unverified token header...")
unverified_header = jwt.get_unverified_header(token)
except jwt.DecodeError:
logger.debug("[OIDC TOKEN SERVICE] Unverified header: %s", unverified_header)
except jwt.DecodeError as e:
logger.error("[OIDC TOKEN SERVICE] Failed to decode token header: %s", str(e))
raise jwt.InvalidTokenError("Invalid token header")
kid = unverified_header.get("kid")
logger.debug("[OIDC TOKEN SERVICE] Key ID (kid) from token header: %s", kid)
# Find the matching public key
logger.debug("[OIDC TOKEN SERVICE] Searching for matching public key...")
public_key = None
for key in jwks.get("keys", []):
for idx, key in enumerate(jwks.get("keys", [])):
logger.debug("[OIDC TOKEN SERVICE] Checking key %d: kid=%s", idx, key.get("kid"))
if key.get("kid") == kid:
logger.debug("[OIDC TOKEN SERVICE] Found matching key at index %d", idx)
try:
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
logger.debug("[OIDC TOKEN SERVICE] Loading PEM public key...")
public_key = serialization.load_pem_public_key(
key["public_key"].encode() if isinstance(key["public_key"], str)
key["public_key"].encode() if isinstance(key["public_key"], str)
else key["public_key"],
backend=default_backend()
)
logger.debug("[OIDC TOKEN SERVICE] Public key loaded successfully")
break
except (ImportError, Exception):
except (ImportError, Exception) as e:
logger.error("[OIDC TOKEN SERVICE] Failed to load public key: %s: %s", type(e).__name__, str(e))
continue
if not public_key:
logger.error("[OIDC TOKEN SERVICE] No matching public key found for kid=%s", kid)
raise jwt.InvalidSignatureError(f"Key with kid={kid} not found")
# Verify the signature
claims = jwt.decode(
token,
public_key,
algorithms=["RS256"],
audience=None, # We'll validate audience separately
issuer=cls._get_issuer(),
options={
"verify_signature": True,
"verify_exp": True,
"verify_aud": False, # Handle audience manually
"verify_iss": False, # Handle issuer manually
}
)
logger.debug("[OIDC TOKEN SERVICE] Public key found, verifying signature...")
return claims
# Verify the signature
try:
claims = jwt.decode(
token,
public_key,
algorithms=["RS256"],
audience=None, # We'll validate audience separately
issuer=cls._get_issuer(),
options={
"verify_signature": True,
"verify_exp": True,
"verify_aud": False, # Handle audience manually
"verify_iss": False, # Handle issuer manually
}
)
logger.debug("[OIDC TOKEN SERVICE] Signature verification successful")
logger.debug("[OIDC TOKEN SERVICE] Decoded claims: %s", claims)
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
return claims
except jwt.ExpiredSignatureError as e:
logger.error("[OIDC TOKEN SERVICE] Token has expired: %s", str(e))
raise
except jwt.InvalidSignatureError as e:
logger.error("[OIDC TOKEN SERVICE] Invalid token signature: %s", str(e))
raise
except jwt.InvalidTokenError as e:
logger.error("[OIDC TOKEN SERVICE] Invalid token: %s: %s", type(e).__name__, str(e))
raise
except Exception as e:
logger.error("[OIDC TOKEN SERVICE] Unexpected error during token verification: %s: %s", type(e).__name__, str(e))
import traceback
logger.error("[OIDC TOKEN SERVICE] Traceback: %s", traceback.format_exc())
raise
@classmethod
def decode_token(cls, token: str, verify: bool = False) -> Dict:
@@ -378,16 +501,41 @@ class OIDCTokenService:
jwt.InvalidTokenError: If token is invalid
ValueError: If token is expired or audience mismatch
"""
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
logger.debug("[OIDC TOKEN SERVICE] validate_access_token() called")
logger.debug("[OIDC TOKEN SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token)
logger.debug("[OIDC TOKEN SERVICE] Token length: %d", len(token))
logger.debug("[OIDC TOKEN SERVICE] Client ID: %s", client_id)
# Verify token signature
logger.debug("[OIDC TOKEN SERVICE] Verifying token signature...")
claims = cls.verify_token_signature(token)
logger.debug("[OIDC TOKEN SERVICE] Token signature verified")
logger.debug("[OIDC TOKEN SERVICE] Claims: %s", claims)
# Check expiration
if claims.get("exp", 0) < datetime.utcnow().timestamp():
exp = claims.get("exp", 0)
now_timestamp = int(time.time())
if exp < now_timestamp:
logger.error("[OIDC TOKEN SERVICE] Token has expired")
raise ValueError("Token has expired")
# Validate audience if client_id provided
aud = claims.get("aud")
logger.debug("[OIDC TOKEN SERVICE] Token audience (aud): %s", aud)
logger.debug("[OIDC TOKEN SERVICE] Expected client_id: %s", client_id)
if client_id:
if claims.get("aud") != client_id:
if aud != client_id:
logger.error("[OIDC TOKEN SERVICE] Audience mismatch: expected=%s, got=%s", client_id, aud)
raise ValueError("Invalid audience")
logger.debug("[OIDC TOKEN SERVICE] Audience validation passed")
else:
logger.debug("[OIDC TOKEN SERVICE] No client_id provided, skipping audience validation")
logger.debug("[OIDC TOKEN SERVICE] validate_access_token() completed successfully")
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
return claims
@@ -410,11 +558,17 @@ class OIDCTokenService:
claims = cls.validate_access_token(token, client_id)
# Calculate remaining time
now = datetime.utcnow().timestamp()
now_timestamp = int(time.time())
now = datetime.now(timezone.utc)
exp = claims.get("exp", 0)
iat = claims.get("iat", 0)
result["active"] = exp > now
logger.debug("[OIDC TOKEN SERVICE] Introspection - Current UTC time: %s", now.isoformat())
logger.debug("[OIDC TOKEN SERVICE] Introspection - Token expiration timestamp: %s", exp)
logger.debug("[OIDC TOKEN SERVICE] Introspection - Token expiration datetime (UTC): %s", datetime.fromtimestamp(exp, tz=timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] Introspection - Time until expiration: %s seconds", exp - now_timestamp)
result["active"] = exp > now_timestamp
result.update({
"iss": claims.get("iss"),
"sub": claims.get("sub"),
@@ -429,8 +583,8 @@ class OIDCTokenService:
})
# Add expiry in seconds
if exp > now:
result["exp"] = int(exp - now)
if exp > now_timestamp:
result["exp"] = int(exp - now_timestamp)
except (jwt.InvalidTokenError, ValueError) as e:
result["active"] = False