feat(auth): implement TOTP two-factor authentication with enrollment and verification
Adds TOTP (Time-based One-Time Password) two-factor authentication support including: - New TOTP service with secret generation, QR code provisioning, and code verification - New auth endpoints for enrollment, verification, status, and backup code management - New TOTP authentication method type and user methods for TOTP management - Backup codes generation and verification for account recovery - Updated OIDC endpoints with timezone-aware datetime handling and RFC-compliant responses - Added "roles" scope support for OIDC userinfo and ID tokens - New pyotp dependency for TOTP operations - Comprehensive unit tests for TOTP service
This commit is contained in:
@@ -2,15 +2,20 @@
|
||||
import hashlib
|
||||
import base64
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, Optional, Any
|
||||
|
||||
import jwt
|
||||
from flask import current_app, g
|
||||
|
||||
from app.models import User, OIDCClient
|
||||
from app.models.organization_member import OrganizationMember
|
||||
from app.services.oidc_jwks_service import OIDCJWKSService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OIDCTokenService:
|
||||
"""Service for generating and validating OIDC tokens.
|
||||
@@ -134,7 +139,7 @@ class OIDCTokenService:
|
||||
return lifetimes.get(token_type, 3600)
|
||||
|
||||
@classmethod
|
||||
def create_access_token(cls, client_id: str, user_id: str, scope: list,
|
||||
def create_access_token(cls, client_id: str, user_id: str, scope: list,
|
||||
jti: str = None) -> str:
|
||||
"""Create a JWT access token.
|
||||
|
||||
@@ -147,25 +152,44 @@ class OIDCTokenService:
|
||||
Returns:
|
||||
JWT access token string
|
||||
"""
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
logger.debug("[OIDC TOKEN SERVICE] create_access_token called")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
|
||||
logger.debug("[OIDC TOKEN SERVICE] scope=%s", scope)
|
||||
|
||||
jti = jti or cls._generate_jti()
|
||||
now = datetime.utcnow()
|
||||
now_timestamp = int(time.time())
|
||||
now = datetime.now(timezone.utc)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token creation time (UTC): %s", now.isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token creation timestamp: %s", now_timestamp)
|
||||
|
||||
# Get client for token lifetime
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
lifetime = cls._get_token_lifetime(client, "access_token") if client else 3600
|
||||
logger.debug("[OIDC TOKEN SERVICE] Access token lifetime (seconds): %s", lifetime)
|
||||
|
||||
exp_timestamp = now_timestamp + lifetime
|
||||
exp_time = now + timedelta(seconds=lifetime)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Access token expiration time (UTC): %s", exp_time.isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] Access token expiration timestamp: %s", exp_timestamp)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Time until expiration (seconds): %s", lifetime)
|
||||
|
||||
claims = {
|
||||
"iss": cls._get_issuer(),
|
||||
"sub": user_id,
|
||||
"aud": client_id,
|
||||
"exp": int((now + timedelta(seconds=lifetime)).timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"nbf": int(now.timestamp()),
|
||||
"exp": exp_timestamp,
|
||||
"iat": now_timestamp,
|
||||
"nbf": now_timestamp,
|
||||
"jti": jti,
|
||||
"client_id": client_id,
|
||||
"scope": " ".join(scope) if isinstance(scope, list) else scope,
|
||||
}
|
||||
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token claims: exp=%s, iat=%s, nbf=%s",
|
||||
claims["exp"], claims["iat"], claims["nbf"])
|
||||
|
||||
# Get signing key
|
||||
jwks_service = OIDCJWKSService()
|
||||
signing_key = jwks_service.get_signing_key()
|
||||
@@ -174,6 +198,7 @@ class OIDCTokenService:
|
||||
raise ValueError("No signing key available")
|
||||
|
||||
# Sign with RS256
|
||||
logger.debug("[OIDC TOKEN SERVICE] Signing token with RS256...")
|
||||
token = jwt.encode(
|
||||
claims,
|
||||
signing_key.private_key,
|
||||
@@ -181,6 +206,9 @@ class OIDCTokenService:
|
||||
headers={"kid": signing_key.kid}
|
||||
)
|
||||
|
||||
logger.debug("[OIDC TOKEN SERVICE] Access token created successfully")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
@@ -200,12 +228,30 @@ class OIDCTokenService:
|
||||
Returns:
|
||||
JWT ID token string
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
auth_time = auth_time or int(now.timestamp())
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
logger.debug("[OIDC TOKEN SERVICE] create_id_token called")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
|
||||
logger.debug("[OIDC TOKEN SERVICE] nonce=%s, auth_time=%s", nonce, auth_time)
|
||||
logger.debug("[OIDC TOKEN SERVICE] scope=%s", scope)
|
||||
|
||||
now_timestamp = int(time.time())
|
||||
now = datetime.now(timezone.utc)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token creation time (UTC): %s", now.isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token creation timestamp: %s", now_timestamp)
|
||||
auth_time = auth_time or now_timestamp
|
||||
logger.debug("[OIDC TOKEN SERVICE] auth_time (Unix timestamp): %s", auth_time)
|
||||
|
||||
# Get client for token lifetime
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
lifetime = cls._get_token_lifetime(client, "id_token") if client else 3600
|
||||
logger.debug("[OIDC TOKEN SERVICE] ID token lifetime (seconds): %s", lifetime)
|
||||
|
||||
exp_timestamp = now_timestamp + lifetime
|
||||
exp_time = now + timedelta(seconds=lifetime)
|
||||
logger.debug("[OIDC TOKEN SERVICE] ID token expiration time (UTC): %s", exp_time.isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] ID token expiration timestamp: %s", exp_timestamp)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Time until expiration (seconds): %s", lifetime)
|
||||
|
||||
# Get user for claims
|
||||
user = User.query.get(user_id)
|
||||
@@ -214,11 +260,14 @@ class OIDCTokenService:
|
||||
"iss": cls._get_issuer(),
|
||||
"sub": user_id,
|
||||
"aud": client_id,
|
||||
"exp": int((now + timedelta(seconds=lifetime)).timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"exp": exp_timestamp,
|
||||
"iat": now_timestamp,
|
||||
"auth_time": auth_time,
|
||||
}
|
||||
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token claims: exp=%s, iat=%s, auth_time=%s",
|
||||
claims["exp"], claims["iat"], claims["auth_time"])
|
||||
|
||||
# Add nonce if provided
|
||||
if nonce:
|
||||
claims["nonce"] = nonce
|
||||
@@ -235,6 +284,10 @@ class OIDCTokenService:
|
||||
if user.full_name:
|
||||
claims["name"] = user.full_name
|
||||
|
||||
# Add roles claim if scope is granted
|
||||
if scope and "roles" in scope:
|
||||
claims["roles"] = cls._get_user_roles(user)
|
||||
|
||||
# Add scope if provided
|
||||
if scope:
|
||||
claims["scope"] = " ".join(scope) if isinstance(scope, list) else scope
|
||||
@@ -247,6 +300,7 @@ class OIDCTokenService:
|
||||
raise ValueError("No signing key available")
|
||||
|
||||
# Sign with RS256
|
||||
logger.debug("[OIDC TOKEN SERVICE] Signing token with RS256...")
|
||||
token = jwt.encode(
|
||||
claims,
|
||||
signing_key.private_key,
|
||||
@@ -254,10 +308,32 @@ class OIDCTokenService:
|
||||
headers={"kid": signing_key.kid}
|
||||
)
|
||||
|
||||
logger.debug("[OIDC TOKEN SERVICE] ID token created successfully")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def _get_user_roles(user: User) -> list:
|
||||
"""Get user's organization roles.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
|
||||
Returns:
|
||||
List of role objects with organization_id and role
|
||||
"""
|
||||
roles = []
|
||||
if user and user.organization_memberships:
|
||||
for member in user.organization_memberships:
|
||||
roles.append({
|
||||
"organization_id": str(member.organization_id),
|
||||
"role": member.role.value
|
||||
})
|
||||
return roles
|
||||
|
||||
@classmethod
|
||||
def create_refresh_token(cls, client_id: str, user_id: str,
|
||||
def create_refresh_token(cls, client_id: str, user_id: str,
|
||||
scope: list = None, access_token_id: str = None) -> str:
|
||||
"""Create an opaque refresh token.
|
||||
|
||||
@@ -270,11 +346,21 @@ class OIDCTokenService:
|
||||
Returns:
|
||||
Opaque refresh token string
|
||||
"""
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
logger.debug("[OIDC TOKEN SERVICE] create_refresh_token called")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
|
||||
logger.debug("[OIDC TOKEN SERVICE] scope=%s, access_token_id=%s", scope, access_token_id)
|
||||
|
||||
token = cls._generate_opaque_token()
|
||||
logger.debug("[OIDC TOKEN SERVICE] Refresh token generated: %s...", token[:20] if token else None)
|
||||
|
||||
# Hash for storage
|
||||
token_hash = cls._hash_token(token)
|
||||
|
||||
logger.debug("[OIDC TOKEN SERVICE] Refresh token created successfully")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
return token, token_hash
|
||||
|
||||
@classmethod
|
||||
@@ -292,54 +378,91 @@ class OIDCTokenService:
|
||||
jwt.ExpiredSignatureError: If token is expired
|
||||
jwt.InvalidTokenError: If token is invalid
|
||||
"""
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
logger.debug("[OIDC TOKEN SERVICE] verify_token_signature() called")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token length: %d", len(token))
|
||||
|
||||
# Get the JWKS with public keys
|
||||
logger.debug("[OIDC TOKEN SERVICE] Getting JWKS...")
|
||||
jwks_service = OIDCJWKSService()
|
||||
jwks = jwks_service.get_jwks()
|
||||
jwks = jwks_service.get_jwks(include_private_keys=True)
|
||||
logger.debug("[OIDC TOKEN SERVICE] JWKS retrieved: %d keys", len(jwks.get("keys", [])))
|
||||
|
||||
# Get the key ID from token header
|
||||
try:
|
||||
logger.debug("[OIDC TOKEN SERVICE] Getting unverified token header...")
|
||||
unverified_header = jwt.get_unverified_header(token)
|
||||
except jwt.DecodeError:
|
||||
logger.debug("[OIDC TOKEN SERVICE] Unverified header: %s", unverified_header)
|
||||
except jwt.DecodeError as e:
|
||||
logger.error("[OIDC TOKEN SERVICE] Failed to decode token header: %s", str(e))
|
||||
raise jwt.InvalidTokenError("Invalid token header")
|
||||
|
||||
kid = unverified_header.get("kid")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Key ID (kid) from token header: %s", kid)
|
||||
|
||||
# Find the matching public key
|
||||
logger.debug("[OIDC TOKEN SERVICE] Searching for matching public key...")
|
||||
public_key = None
|
||||
for key in jwks.get("keys", []):
|
||||
for idx, key in enumerate(jwks.get("keys", [])):
|
||||
logger.debug("[OIDC TOKEN SERVICE] Checking key %d: kid=%s", idx, key.get("kid"))
|
||||
if key.get("kid") == kid:
|
||||
logger.debug("[OIDC TOKEN SERVICE] Found matching key at index %d", idx)
|
||||
try:
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
logger.debug("[OIDC TOKEN SERVICE] Loading PEM public key...")
|
||||
public_key = serialization.load_pem_public_key(
|
||||
key["public_key"].encode() if isinstance(key["public_key"], str)
|
||||
key["public_key"].encode() if isinstance(key["public_key"], str)
|
||||
else key["public_key"],
|
||||
backend=default_backend()
|
||||
)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Public key loaded successfully")
|
||||
break
|
||||
except (ImportError, Exception):
|
||||
except (ImportError, Exception) as e:
|
||||
logger.error("[OIDC TOKEN SERVICE] Failed to load public key: %s: %s", type(e).__name__, str(e))
|
||||
continue
|
||||
|
||||
if not public_key:
|
||||
logger.error("[OIDC TOKEN SERVICE] No matching public key found for kid=%s", kid)
|
||||
raise jwt.InvalidSignatureError(f"Key with kid={kid} not found")
|
||||
|
||||
# Verify the signature
|
||||
claims = jwt.decode(
|
||||
token,
|
||||
public_key,
|
||||
algorithms=["RS256"],
|
||||
audience=None, # We'll validate audience separately
|
||||
issuer=cls._get_issuer(),
|
||||
options={
|
||||
"verify_signature": True,
|
||||
"verify_exp": True,
|
||||
"verify_aud": False, # Handle audience manually
|
||||
"verify_iss": False, # Handle issuer manually
|
||||
}
|
||||
)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Public key found, verifying signature...")
|
||||
|
||||
return claims
|
||||
# Verify the signature
|
||||
try:
|
||||
claims = jwt.decode(
|
||||
token,
|
||||
public_key,
|
||||
algorithms=["RS256"],
|
||||
audience=None, # We'll validate audience separately
|
||||
issuer=cls._get_issuer(),
|
||||
options={
|
||||
"verify_signature": True,
|
||||
"verify_exp": True,
|
||||
"verify_aud": False, # Handle audience manually
|
||||
"verify_iss": False, # Handle issuer manually
|
||||
}
|
||||
)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Signature verification successful")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Decoded claims: %s", claims)
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
return claims
|
||||
except jwt.ExpiredSignatureError as e:
|
||||
logger.error("[OIDC TOKEN SERVICE] Token has expired: %s", str(e))
|
||||
raise
|
||||
except jwt.InvalidSignatureError as e:
|
||||
logger.error("[OIDC TOKEN SERVICE] Invalid token signature: %s", str(e))
|
||||
raise
|
||||
except jwt.InvalidTokenError as e:
|
||||
logger.error("[OIDC TOKEN SERVICE] Invalid token: %s: %s", type(e).__name__, str(e))
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("[OIDC TOKEN SERVICE] Unexpected error during token verification: %s: %s", type(e).__name__, str(e))
|
||||
import traceback
|
||||
logger.error("[OIDC TOKEN SERVICE] Traceback: %s", traceback.format_exc())
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def decode_token(cls, token: str, verify: bool = False) -> Dict:
|
||||
@@ -378,16 +501,41 @@ class OIDCTokenService:
|
||||
jwt.InvalidTokenError: If token is invalid
|
||||
ValueError: If token is expired or audience mismatch
|
||||
"""
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
logger.debug("[OIDC TOKEN SERVICE] validate_access_token() called")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token length: %d", len(token))
|
||||
logger.debug("[OIDC TOKEN SERVICE] Client ID: %s", client_id)
|
||||
|
||||
# Verify token signature
|
||||
logger.debug("[OIDC TOKEN SERVICE] Verifying token signature...")
|
||||
claims = cls.verify_token_signature(token)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token signature verified")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Claims: %s", claims)
|
||||
|
||||
# Check expiration
|
||||
if claims.get("exp", 0) < datetime.utcnow().timestamp():
|
||||
exp = claims.get("exp", 0)
|
||||
now_timestamp = int(time.time())
|
||||
|
||||
if exp < now_timestamp:
|
||||
logger.error("[OIDC TOKEN SERVICE] Token has expired")
|
||||
raise ValueError("Token has expired")
|
||||
|
||||
# Validate audience if client_id provided
|
||||
aud = claims.get("aud")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token audience (aud): %s", aud)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Expected client_id: %s", client_id)
|
||||
|
||||
if client_id:
|
||||
if claims.get("aud") != client_id:
|
||||
if aud != client_id:
|
||||
logger.error("[OIDC TOKEN SERVICE] Audience mismatch: expected=%s, got=%s", client_id, aud)
|
||||
raise ValueError("Invalid audience")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Audience validation passed")
|
||||
else:
|
||||
logger.debug("[OIDC TOKEN SERVICE] No client_id provided, skipping audience validation")
|
||||
|
||||
logger.debug("[OIDC TOKEN SERVICE] validate_access_token() completed successfully")
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
|
||||
return claims
|
||||
|
||||
@@ -410,11 +558,17 @@ class OIDCTokenService:
|
||||
claims = cls.validate_access_token(token, client_id)
|
||||
|
||||
# Calculate remaining time
|
||||
now = datetime.utcnow().timestamp()
|
||||
now_timestamp = int(time.time())
|
||||
now = datetime.now(timezone.utc)
|
||||
exp = claims.get("exp", 0)
|
||||
iat = claims.get("iat", 0)
|
||||
|
||||
result["active"] = exp > now
|
||||
logger.debug("[OIDC TOKEN SERVICE] Introspection - Current UTC time: %s", now.isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] Introspection - Token expiration timestamp: %s", exp)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Introspection - Token expiration datetime (UTC): %s", datetime.fromtimestamp(exp, tz=timezone.utc).isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] Introspection - Time until expiration: %s seconds", exp - now_timestamp)
|
||||
|
||||
result["active"] = exp > now_timestamp
|
||||
result.update({
|
||||
"iss": claims.get("iss"),
|
||||
"sub": claims.get("sub"),
|
||||
@@ -429,8 +583,8 @@ class OIDCTokenService:
|
||||
})
|
||||
|
||||
# Add expiry in seconds
|
||||
if exp > now:
|
||||
result["exp"] = int(exp - now)
|
||||
if exp > now_timestamp:
|
||||
result["exp"] = int(exp - now_timestamp)
|
||||
|
||||
except (jwt.InvalidTokenError, ValueError) as e:
|
||||
result["active"] = False
|
||||
|
||||
Reference in New Issue
Block a user