feat(auth): implement TOTP two-factor authentication with enrollment and verification
Adds TOTP (Time-based One-Time Password) two-factor authentication support including: - New TOTP service with secret generation, QR code provisioning, and code verification - New auth endpoints for enrollment, verification, status, and backup code management - New TOTP authentication method type and user methods for TOTP management - Backup codes generation and verification for account recovery - Updated OIDC endpoints with timezone-aware datetime handling and RFC-compliant responses - Added "roles" scope support for OIDC userinfo and ID tokens - New pyotp dependency for TOTP operations - Comprehensive unit tests for TOTP service
This commit is contained in:
@@ -11,6 +11,7 @@ from app.utils.constants import AuthMethodType, SessionStatus, UserStatus, Audit
|
||||
from app.exceptions.auth_exceptions import InvalidCredentialsError, AccountSuspendedError, AccountInactiveError
|
||||
from app.exceptions.validation_exceptions import EmailAlreadyExistsError
|
||||
from app.services.audit_service import AuditService
|
||||
from app.services.totp_service import TOTPService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -234,3 +235,315 @@ class AuthService:
|
||||
resource_id=session.id,
|
||||
description=f"Session revoked: {reason or 'User logout'}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def enroll_totp(user: User) -> dict:
|
||||
"""
|
||||
Initiate TOTP enrollment for a user.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- secret: TOTP secret (base32 encoded)
|
||||
- provisioning_uri: otpauth:// URI for QR code
|
||||
- qr_code: Base64 encoded QR code as data URI
|
||||
- backup_codes: List of plain text backup codes
|
||||
|
||||
Raises:
|
||||
ConflictError: If user already has TOTP enabled
|
||||
"""
|
||||
from app.exceptions.validation_exceptions import ConflictError
|
||||
|
||||
# Check if user already has TOTP enabled
|
||||
if user.has_totp_enabled():
|
||||
raise ConflictError("TOTP is already enabled for this account")
|
||||
|
||||
# Generate TOTP secret
|
||||
secret = TOTPService.generate_secret()
|
||||
|
||||
# Generate provisioning URI
|
||||
provisioning_uri = TOTPService.generate_provisioning_uri(
|
||||
user_email=user.email,
|
||||
secret=secret,
|
||||
issuer="Gatehouse",
|
||||
)
|
||||
|
||||
# Generate QR code data URI
|
||||
qr_code = TOTPService.generate_qr_code_data_uri(provisioning_uri)
|
||||
|
||||
# Generate backup codes
|
||||
backup_codes, hashed_backup_codes = TOTPService.generate_backup_codes()
|
||||
|
||||
# Create unverified TOTP authentication method
|
||||
auth_method = AuthenticationMethod(
|
||||
user_id=user.id,
|
||||
method_type=AuthMethodType.TOTP,
|
||||
verified=False,
|
||||
is_primary=False,
|
||||
)
|
||||
auth_method.save()
|
||||
|
||||
# Store TOTP data in provider_data (since totp_secret field is commented out)
|
||||
auth_method.provider_data = {
|
||||
"secret": secret,
|
||||
"backup_codes": hashed_backup_codes,
|
||||
}
|
||||
db.session.commit()
|
||||
|
||||
# Log TOTP enrollment initiation
|
||||
AuditService.log_action(
|
||||
action=AuditAction.TOTP_ENROLL_INITIATED,
|
||||
user_id=user.id,
|
||||
resource_type="authentication_method",
|
||||
resource_id=auth_method.id,
|
||||
description="TOTP enrollment initiated",
|
||||
)
|
||||
|
||||
return {
|
||||
"secret": secret,
|
||||
"provisioning_uri": provisioning_uri,
|
||||
"qr_code": qr_code,
|
||||
"backup_codes": backup_codes,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def verify_totp_enrollment(user: User, code: str) -> bool:
|
||||
"""
|
||||
Complete TOTP enrollment by verifying the first TOTP code.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
code: 6-digit TOTP code from authenticator app
|
||||
|
||||
Returns:
|
||||
True if verification successful
|
||||
|
||||
Raises:
|
||||
InvalidCredentialsError: If code is invalid or TOTP method not found
|
||||
"""
|
||||
# Get user's TOTP authentication method
|
||||
auth_method = user.get_totp_method()
|
||||
if not auth_method:
|
||||
raise InvalidCredentialsError("TOTP enrollment not found")
|
||||
|
||||
# Get secret from provider_data
|
||||
secret = auth_method.provider_data.get("secret") if auth_method.provider_data else None
|
||||
if not secret:
|
||||
raise InvalidCredentialsError("TOTP secret not found")
|
||||
|
||||
# Verify the code
|
||||
if not TOTPService.verify_code(secret, code):
|
||||
raise InvalidCredentialsError("Invalid TOTP code")
|
||||
|
||||
# Mark TOTP as verified
|
||||
auth_method.verified = True
|
||||
auth_method.totp_verified_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
# Log TOTP enrollment completion
|
||||
AuditService.log_action(
|
||||
action=AuditAction.TOTP_ENROLL_COMPLETED,
|
||||
user_id=user.id,
|
||||
resource_type="authentication_method",
|
||||
resource_id=auth_method.id,
|
||||
description="TOTP enrollment completed",
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def disable_totp(user: User, password: str) -> bool:
|
||||
"""
|
||||
Disable TOTP for a user.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
password: User's current password for verification
|
||||
|
||||
Returns:
|
||||
True if TOTP disabled successfully
|
||||
|
||||
Raises:
|
||||
InvalidCredentialsError: If password is invalid or TOTP method not found
|
||||
"""
|
||||
# Verify user's password
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
user_id=user.id,
|
||||
method_type=AuthMethodType.PASSWORD,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
|
||||
if not auth_method or not auth_method.password_hash:
|
||||
raise InvalidCredentialsError("No password authentication method found")
|
||||
|
||||
if not bcrypt.check_password_hash(auth_method.password_hash, password):
|
||||
raise InvalidCredentialsError("Invalid password")
|
||||
|
||||
# Get user's TOTP authentication method
|
||||
totp_method = user.get_totp_method()
|
||||
if not totp_method:
|
||||
raise InvalidCredentialsError("TOTP is not enabled for this account")
|
||||
|
||||
# Soft-delete the TOTP authentication method
|
||||
totp_method.delete(soft=True)
|
||||
|
||||
# Log TOTP disabled
|
||||
AuditService.log_action(
|
||||
action=AuditAction.TOTP_DISABLED,
|
||||
user_id=user.id,
|
||||
resource_type="authentication_method",
|
||||
resource_id=totp_method.id,
|
||||
description="TOTP disabled",
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def authenticate_with_totp(user: User, code: str, is_backup_code: bool = False) -> bool:
|
||||
"""
|
||||
Verify TOTP code during login.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
code: 6-digit TOTP code or backup code
|
||||
is_backup_code: True if code is a backup code, False if TOTP code
|
||||
|
||||
Returns:
|
||||
True if code is valid
|
||||
|
||||
Raises:
|
||||
InvalidCredentialsError: If code is invalid or TOTP method not found
|
||||
"""
|
||||
# Get user's TOTP authentication method
|
||||
auth_method = user.get_totp_method()
|
||||
if not auth_method:
|
||||
raise InvalidCredentialsError("TOTP is not enabled for this account")
|
||||
|
||||
if is_backup_code:
|
||||
# Verify backup code
|
||||
backup_codes = (
|
||||
auth_method.provider_data.get("backup_codes")
|
||||
if auth_method.provider_data
|
||||
else []
|
||||
)
|
||||
is_valid, remaining_codes = TOTPService.verify_backup_code(backup_codes, code)
|
||||
|
||||
if is_valid:
|
||||
# Update remaining backup codes
|
||||
auth_method.provider_data = {
|
||||
"secret": auth_method.provider_data.get("secret"),
|
||||
"backup_codes": remaining_codes,
|
||||
}
|
||||
auth_method.last_used_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
# Log backup code usage
|
||||
AuditService.log_action(
|
||||
action=AuditAction.TOTP_BACKUP_CODE_USED,
|
||||
user_id=user.id,
|
||||
resource_type="authentication_method",
|
||||
resource_id=auth_method.id,
|
||||
description="Backup code used for authentication",
|
||||
)
|
||||
else:
|
||||
# Log failed verification
|
||||
AuditService.log_action(
|
||||
action=AuditAction.TOTP_VERIFY_FAILED,
|
||||
user_id=user.id,
|
||||
resource_type="authentication_method",
|
||||
resource_id=auth_method.id,
|
||||
description="Invalid backup code provided",
|
||||
)
|
||||
raise InvalidCredentialsError("Invalid backup code")
|
||||
else:
|
||||
# Verify TOTP code
|
||||
secret = (
|
||||
auth_method.provider_data.get("secret")
|
||||
if auth_method.provider_data
|
||||
else None
|
||||
)
|
||||
if not secret:
|
||||
raise InvalidCredentialsError("TOTP secret not found")
|
||||
|
||||
is_valid = TOTPService.verify_code(secret, code)
|
||||
|
||||
if is_valid:
|
||||
auth_method.last_used_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
# Log successful verification
|
||||
AuditService.log_action(
|
||||
action=AuditAction.TOTP_VERIFY_SUCCESS,
|
||||
user_id=user.id,
|
||||
resource_type="authentication_method",
|
||||
resource_id=auth_method.id,
|
||||
description="TOTP code verified successfully",
|
||||
)
|
||||
else:
|
||||
# Log failed verification
|
||||
AuditService.log_action(
|
||||
action=AuditAction.TOTP_VERIFY_FAILED,
|
||||
user_id=user.id,
|
||||
resource_type="authentication_method",
|
||||
resource_id=auth_method.id,
|
||||
description="Invalid TOTP code provided",
|
||||
)
|
||||
raise InvalidCredentialsError("Invalid TOTP code")
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def regenerate_totp_backup_codes(user: User, password: str) -> list[str]:
|
||||
"""
|
||||
Generate new backup codes for TOTP.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
password: User's current password for verification
|
||||
|
||||
Returns:
|
||||
List of new plain text backup codes
|
||||
|
||||
Raises:
|
||||
InvalidCredentialsError: If password is invalid or TOTP method not found
|
||||
"""
|
||||
# Verify user's password
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
user_id=user.id,
|
||||
method_type=AuthMethodType.PASSWORD,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
|
||||
if not auth_method or not auth_method.password_hash:
|
||||
raise InvalidCredentialsError("No password authentication method found")
|
||||
|
||||
if not bcrypt.check_password_hash(auth_method.password_hash, password):
|
||||
raise InvalidCredentialsError("Invalid password")
|
||||
|
||||
# Get user's TOTP authentication method
|
||||
totp_method = user.get_totp_method()
|
||||
if not totp_method:
|
||||
raise InvalidCredentialsError("TOTP is not enabled for this account")
|
||||
|
||||
# Generate new backup codes
|
||||
backup_codes, hashed_backup_codes = TOTPService.generate_backup_codes()
|
||||
|
||||
# Update the authentication method with new backup codes
|
||||
totp_method.provider_data = {
|
||||
"secret": totp_method.provider_data.get("secret"),
|
||||
"backup_codes": hashed_backup_codes,
|
||||
}
|
||||
db.session.commit()
|
||||
|
||||
# Log backup codes regeneration
|
||||
AuditService.log_action(
|
||||
action=AuditAction.TOTP_BACKUP_CODES_REGENERATED,
|
||||
user_id=user.id,
|
||||
resource_type="authentication_method",
|
||||
resource_id=totp_method.id,
|
||||
description="TOTP backup codes regenerated",
|
||||
)
|
||||
|
||||
return backup_codes
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import logging
|
||||
import secrets
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from flask import current_app, g
|
||||
@@ -14,6 +14,7 @@ from app.models import (
|
||||
User, OIDCClient, OIDCAuthCode, OIDCRefreshToken,
|
||||
OIDCSession, OIDCTokenMetadata
|
||||
)
|
||||
from app.models.organization_member import OrganizationMember
|
||||
from app.exceptions.validation_exceptions import (
|
||||
ValidationError, NotFoundError, BadRequestError
|
||||
)
|
||||
@@ -121,6 +122,14 @@ class OIDCService:
|
||||
ValidationError: If parameters are invalid
|
||||
NotFoundError: If client not found
|
||||
"""
|
||||
logger.debug("[OIDC SERVICE] ===========================================")
|
||||
logger.debug("[OIDC SERVICE] generate_authorization_code called")
|
||||
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||
logger.debug("[OIDC SERVICE] client_id=%s, user_id=%s", client_id, user_id)
|
||||
logger.debug("[OIDC SERVICE] redirect_uri=%s", redirect_uri)
|
||||
logger.debug("[OIDC SERVICE] scope=%s", scope)
|
||||
logger.debug("[OIDC SERVICE] state=%s, nonce=%s", state, nonce)
|
||||
|
||||
# Validate client exists and is active
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
|
||||
@@ -152,14 +161,19 @@ class OIDCService:
|
||||
raise ValidationError("Invalid scopes")
|
||||
|
||||
# Generate authorization code
|
||||
logger.debug("[OIDC SERVICE] Generating authorization code...")
|
||||
logger.debug("[OIDC SERVICE] Current UTC time before code generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||
code = cls._generate_code()
|
||||
code_hash = cls._hash_value(code)
|
||||
logger.debug("[OIDC SERVICE] Code generated: %s...", code[:20] if code else None)
|
||||
|
||||
# Development-only debug logging for PKCE in code creation
|
||||
if current_app.config.get('ENV') == 'development':
|
||||
logger.debug(f"[OIDC] Generate auth code - PKCE: code_challenge={code_challenge is not None}, code_challenge_method={code_challenge_method}")
|
||||
|
||||
# Create auth code record
|
||||
logger.debug("[OIDC SERVICE] Creating auth code record with lifetime_seconds=600 (10 minutes)")
|
||||
logger.debug("[OIDC SERVICE] Current UTC time before creating auth code: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||
auth_code = OIDCAuthCode.create_code(
|
||||
client_id=client.id,
|
||||
user_id=user_id,
|
||||
@@ -172,6 +186,9 @@ class OIDCService:
|
||||
user_agent=user_agent,
|
||||
lifetime_seconds=600, # 10 minutes
|
||||
)
|
||||
logger.debug("[OIDC SERVICE] Auth code created successfully")
|
||||
logger.debug("[OIDC SERVICE] Auth code expires_at (UTC): %s", auth_code.expires_at.isoformat() + "Z")
|
||||
logger.debug("[OIDC SERVICE] Current UTC time after creating auth code: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||
|
||||
# Log authorization event
|
||||
OIDCAuditService.log_authorization_event(
|
||||
@@ -182,6 +199,9 @@ class OIDCService:
|
||||
scope=valid_scopes,
|
||||
)
|
||||
|
||||
logger.debug("[OIDC SERVICE] generate_authorization_code completed successfully")
|
||||
logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||
logger.debug("[OIDC SERVICE] ===========================================")
|
||||
return code
|
||||
|
||||
@classmethod
|
||||
@@ -211,6 +231,12 @@ class OIDCService:
|
||||
InvalidGrantError: If code is invalid
|
||||
ValidationError: If PKCE validation fails
|
||||
"""
|
||||
logger.debug("[OIDC SERVICE] ===========================================")
|
||||
logger.debug("[OIDC SERVICE] validate_authorization_code called")
|
||||
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||
logger.debug("[OIDC SERVICE] client_id=%s, redirect_uri=%s", client_id, redirect_uri)
|
||||
logger.debug("[OIDC SERVICE] code_verifier provided: %s", bool(code_verifier))
|
||||
|
||||
# Get client
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
|
||||
@@ -223,6 +249,8 @@ class OIDCService:
|
||||
raise InvalidGrantError("Invalid client")
|
||||
|
||||
# Hash the provided code and find matching auth code
|
||||
logger.debug("[OIDC SERVICE] Looking up authorization code...")
|
||||
logger.debug("[OIDC SERVICE] Current UTC time before code lookup: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||
code_hash = cls._hash_value(code)
|
||||
auth_code = OIDCAuthCode.query.filter_by(
|
||||
code_hash=code_hash,
|
||||
@@ -256,8 +284,18 @@ class OIDCService:
|
||||
raise InvalidGrantError("Authorization code already used")
|
||||
|
||||
# Check expiration
|
||||
logger.debug("[OIDC SERVICE] Checking if authorization code is expired...")
|
||||
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||
logger.debug("[OIDC SERVICE] Auth code expires_at (UTC): %s", auth_code.expires_at.isoformat() + "Z")
|
||||
# Handle timezone-naive expires_at from database
|
||||
expires_at = auth_code.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
logger.debug("[OIDC SERVICE] Time until expiration (seconds): %s", (expires_at - datetime.now(timezone.utc)).total_seconds())
|
||||
|
||||
if auth_code.is_expired():
|
||||
logger.error(f"[OIDC] Validate auth code - Code expired: code_hash={code_hash[:20]}..., expires_at={auth_code.expires_at}")
|
||||
logger.error("[OIDC] Validate auth code - Code expired: code_hash=%s..., expires_at (UTC)=%s, current UTC time=%s",
|
||||
code_hash[:20], auth_code.expires_at.isoformat() + "Z", datetime.now(timezone.utc).isoformat() + "Z")
|
||||
OIDCAuditService.log_authorization_event(
|
||||
client_id=client_id,
|
||||
user_id=auth_code.user_id,
|
||||
@@ -316,6 +354,9 @@ class OIDCService:
|
||||
"nonce": auth_code.nonce,
|
||||
}
|
||||
|
||||
logger.debug("[OIDC SERVICE] validate_authorization_code completed successfully")
|
||||
logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||
logger.debug("[OIDC SERVICE] ===========================================")
|
||||
return claims, user
|
||||
|
||||
@classmethod
|
||||
@@ -366,6 +407,12 @@ class OIDCService:
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
logger.debug("[OIDC SERVICE] ===========================================")
|
||||
logger.debug("[OIDC SERVICE] generate_tokens called")
|
||||
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||
logger.debug("[OIDC SERVICE] client_id=%s, user_id=%s, scope=%s", client_id, user_id, scope)
|
||||
logger.debug("[OIDC SERVICE] nonce=%s, auth_time=%s", nonce, auth_time)
|
||||
|
||||
# Get client
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
|
||||
@@ -377,6 +424,9 @@ class OIDCService:
|
||||
raise InvalidClientError()
|
||||
|
||||
# Generate access token
|
||||
logger.debug("[OIDC SERVICE] Generating access token...")
|
||||
logger.debug("[OIDC SERVICE] Current UTC time before access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||
logger.debug("[OIDC SERVICE] Access token lifetime (seconds): %s", client.access_token_lifetime or 3600)
|
||||
access_token_jti = OIDCTokenService._generate_jti()
|
||||
access_token = OIDCTokenService.create_access_token(
|
||||
client_id=client_id,
|
||||
@@ -384,8 +434,13 @@ class OIDCService:
|
||||
scope=scope,
|
||||
jti=access_token_jti,
|
||||
)
|
||||
logger.debug("[OIDC SERVICE] Access token generated successfully")
|
||||
logger.debug("[OIDC SERVICE] Current UTC time after access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||
|
||||
# Generate ID token
|
||||
logger.debug("[OIDC SERVICE] Generating ID token...")
|
||||
logger.debug("[OIDC SERVICE] Current UTC time before ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||
logger.debug("[OIDC SERVICE] ID token lifetime (seconds): %s", client.id_token_lifetime or 3600)
|
||||
id_token = OIDCTokenService.create_id_token(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
@@ -394,6 +449,8 @@ class OIDCService:
|
||||
access_token=access_token,
|
||||
auth_time=auth_time,
|
||||
)
|
||||
logger.debug("[OIDC SERVICE] ID token generated successfully")
|
||||
logger.debug("[OIDC SERVICE] Current UTC time after ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||
|
||||
# Generate or rotate refresh token
|
||||
if "refresh_token" in (client.grant_types or []):
|
||||
@@ -445,22 +502,28 @@ class OIDCService:
|
||||
client_db_id = client.id
|
||||
|
||||
# Access token metadata
|
||||
logger.debug("[OIDC SERVICE] Creating access token metadata...")
|
||||
access_token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=client.access_token_lifetime or 3600)
|
||||
logger.debug("[OIDC SERVICE] Access token expires_at (UTC): %s", access_token_expires_at.isoformat() + "Z")
|
||||
OIDCTokenMetadata.create_metadata(
|
||||
client_id=client_db_id,
|
||||
user_id=user_id,
|
||||
token_type="access_token",
|
||||
token_jti=access_token_jti,
|
||||
expires_at=datetime.utcnow() + timedelta(seconds=client.access_token_lifetime or 3600),
|
||||
expires_at=access_token_expires_at,
|
||||
)
|
||||
|
||||
# ID token metadata (using access token JTI as reference)
|
||||
logger.debug("[OIDC SERVICE] Creating ID token metadata...")
|
||||
id_token_jti = OIDCTokenService._generate_jti()
|
||||
id_token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=client.id_token_lifetime or 3600)
|
||||
logger.debug("[OIDC SERVICE] ID token expires_at (UTC): %s", id_token_expires_at.isoformat() + "Z")
|
||||
OIDCTokenMetadata.create_metadata(
|
||||
client_id=client_db_id,
|
||||
user_id=user_id,
|
||||
token_type="id_token",
|
||||
token_jti=id_token_jti,
|
||||
expires_at=datetime.utcnow() + timedelta(seconds=client.id_token_lifetime or 3600),
|
||||
expires_at=id_token_expires_at,
|
||||
)
|
||||
|
||||
# Log token event
|
||||
@@ -483,6 +546,9 @@ class OIDCService:
|
||||
if final_refresh_token:
|
||||
result["refresh_token"] = final_refresh_token
|
||||
|
||||
logger.debug("[OIDC SERVICE] generate_tokens completed successfully")
|
||||
logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||
logger.debug("[OIDC SERVICE] ===========================================")
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
@@ -511,6 +577,11 @@ class OIDCService:
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
logger.debug("[OIDC SERVICE] ===========================================")
|
||||
logger.debug("[OIDC SERVICE] refresh_access_token called")
|
||||
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||
logger.debug("[OIDC SERVICE] client_id=%s, scope=%s", client_id, scope)
|
||||
|
||||
# Get client
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
|
||||
@@ -522,6 +593,8 @@ class OIDCService:
|
||||
raise InvalidClientError()
|
||||
|
||||
# Find refresh token
|
||||
logger.debug("[OIDC SERVICE] Looking up refresh token...")
|
||||
logger.debug("[OIDC SERVICE] Current UTC time before refresh token lookup: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||
token_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
|
||||
refresh_token_obj = OIDCRefreshToken.query.filter_by(
|
||||
token_hash=token_hash,
|
||||
@@ -542,6 +615,16 @@ class OIDCService:
|
||||
raise InvalidGrantError("Invalid refresh token")
|
||||
|
||||
# Check if valid
|
||||
logger.debug("[OIDC SERVICE] Checking if refresh token is valid...")
|
||||
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||
if refresh_token_obj:
|
||||
logger.debug("[OIDC SERVICE] Refresh token expires_at (UTC): %s", refresh_token_obj.expires_at.isoformat() + "Z")
|
||||
# Handle timezone-naive expires_at from database
|
||||
rt_expires_at = refresh_token_obj.expires_at
|
||||
if rt_expires_at.tzinfo is None:
|
||||
rt_expires_at = rt_expires_at.replace(tzinfo=timezone.utc)
|
||||
logger.debug("[OIDC SERVICE] Time until expiration (seconds): %s", (rt_expires_at - datetime.now(timezone.utc)).total_seconds())
|
||||
|
||||
if not refresh_token_obj.is_valid():
|
||||
OIDCAuditService.log_token_event(
|
||||
client_id=client_id,
|
||||
@@ -563,6 +646,9 @@ class OIDCService:
|
||||
granted_scope = scope or (refresh_token_obj.scope or [])
|
||||
|
||||
# Generate new access token
|
||||
logger.debug("[OIDC SERVICE] Generating new access token...")
|
||||
logger.debug("[OIDC SERVICE] Current UTC time before access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||
logger.debug("[OIDC SERVICE] Access token lifetime (seconds): %s", client.access_token_lifetime or 3600)
|
||||
access_token_jti = OIDCTokenService._generate_jti()
|
||||
access_token = OIDCTokenService.create_access_token(
|
||||
client_id=client_id,
|
||||
@@ -570,14 +656,21 @@ class OIDCService:
|
||||
scope=granted_scope,
|
||||
jti=access_token_jti,
|
||||
)
|
||||
logger.debug("[OIDC SERVICE] Access token generated successfully")
|
||||
logger.debug("[OIDC SERVICE] Current UTC time after access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||
|
||||
# Generate new ID token
|
||||
logger.debug("[OIDC SERVICE] Generating new ID token...")
|
||||
logger.debug("[OIDC SERVICE] Current UTC time before ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||
logger.debug("[OIDC SERVICE] ID token lifetime (seconds): %s", client.id_token_lifetime or 3600)
|
||||
id_token = OIDCTokenService.create_id_token(
|
||||
client_id=client_id,
|
||||
user_id=refresh_token_obj.user_id,
|
||||
scope=granted_scope,
|
||||
access_token=access_token,
|
||||
)
|
||||
logger.debug("[OIDC SERVICE] ID token generated successfully")
|
||||
logger.debug("[OIDC SERVICE] Current UTC time after ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||
|
||||
# Rotate refresh token
|
||||
new_refresh, new_hash = OIDCTokenService.create_refresh_token(
|
||||
@@ -590,12 +683,15 @@ class OIDCService:
|
||||
refresh_token_obj.rotate(new_hash)
|
||||
|
||||
# Store new token metadata
|
||||
logger.debug("[OIDC SERVICE] Creating access token metadata...")
|
||||
access_token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=client.access_token_lifetime or 3600)
|
||||
logger.debug("[OIDC SERVICE] Access token expires_at (UTC): %s", access_token_expires_at.isoformat() + "Z")
|
||||
OIDCTokenMetadata.create_metadata(
|
||||
client_id=client.id,
|
||||
user_id=refresh_token_obj.user_id,
|
||||
token_type="access_token",
|
||||
token_jti=access_token_jti,
|
||||
expires_at=datetime.utcnow() + timedelta(seconds=client.access_token_lifetime or 3600),
|
||||
expires_at=access_token_expires_at,
|
||||
)
|
||||
|
||||
# Log refresh event
|
||||
@@ -615,6 +711,17 @@ class OIDCService:
|
||||
"id_token": id_token,
|
||||
"refresh_token": new_refresh,
|
||||
}
|
||||
|
||||
logger.debug("[OIDC SERVICE] refresh_access_token completed successfully")
|
||||
logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||
logger.debug("[OIDC SERVICE] ===========================================")
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": client.access_token_lifetime or 3600,
|
||||
"id_token": id_token,
|
||||
"refresh_token": new_refresh,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def validate_access_token(cls, token: str, client_id: str = None) -> Dict:
|
||||
@@ -630,10 +737,23 @@ class OIDCService:
|
||||
Raises:
|
||||
InvalidTokenError: If token is invalid
|
||||
"""
|
||||
logger.debug("[OIDC SERVICE] ===========================================")
|
||||
logger.debug("[OIDC SERVICE] validate_access_token() called")
|
||||
logger.debug("[OIDC SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token)
|
||||
logger.debug("[OIDC SERVICE] Token length: %d", len(token))
|
||||
logger.debug("[OIDC SERVICE] Client ID: %s", client_id)
|
||||
|
||||
try:
|
||||
logger.debug("[OIDC SERVICE] Calling OIDCTokenService.validate_access_token()...")
|
||||
claims = OIDCTokenService.validate_access_token(token, client_id)
|
||||
logger.debug("[OIDC SERVICE] Token validation successful")
|
||||
logger.debug("[OIDC SERVICE] Token claims: %s", claims)
|
||||
logger.debug("[OIDC SERVICE] ===========================================")
|
||||
return claims
|
||||
except Exception as e:
|
||||
logger.error("[OIDC SERVICE] Token validation failed: %s: %s", type(e).__name__, str(e))
|
||||
import traceback
|
||||
logger.error("[OIDC SERVICE] Traceback: %s", traceback.format_exc())
|
||||
OIDCAuditService.log_event(
|
||||
event_type="token_validation",
|
||||
client_id=client_id,
|
||||
@@ -770,29 +890,67 @@ class OIDCService:
|
||||
Returns:
|
||||
User information dictionary
|
||||
"""
|
||||
logger.debug("[OIDC SERVICE] ===========================================")
|
||||
logger.debug("[OIDC SERVICE] get_userinfo() called")
|
||||
logger.debug("[OIDC SERVICE] Access token (first 50 chars): %s...", access_token[:50] if len(access_token) > 50 else access_token)
|
||||
logger.debug("[OIDC SERVICE] Access token length: %d", len(access_token))
|
||||
|
||||
# Validate access token
|
||||
logger.debug("[OIDC SERVICE] Validating access token...")
|
||||
claims = cls.validate_access_token(access_token)
|
||||
logger.debug("[OIDC SERVICE] Access token validated successfully")
|
||||
logger.debug("[OIDC SERVICE] Token claims: %s", claims)
|
||||
|
||||
user_id = claims.get("sub")
|
||||
logger.debug("[OIDC SERVICE] User ID from token: %s", user_id)
|
||||
|
||||
logger.debug("[OIDC SERVICE] Querying user from database...")
|
||||
user = User.query.get(user_id)
|
||||
logger.debug("[OIDC SERVICE] User query result: %s", user)
|
||||
|
||||
if not user:
|
||||
logger.error("[OIDC SERVICE] User not found in database: user_id=%s", user_id)
|
||||
raise NotFoundError("User not found")
|
||||
|
||||
logger.debug("[OIDC SERVICE] User found: user_id=%s, email=%s, full_name=%s", user.id, user.email, user.full_name)
|
||||
|
||||
# Get scopes from token
|
||||
scope_str = claims.get("scope", "")
|
||||
scopes = scope_str.split() if scope_str else []
|
||||
logger.debug("[OIDC SERVICE] Scope string from token: '%s'", scope_str)
|
||||
logger.debug("[OIDC SERVICE] Parsed scopes: %s", scopes)
|
||||
|
||||
userinfo = {"sub": user_id}
|
||||
logger.debug("[OIDC SERVICE] Initial userinfo: %s", userinfo)
|
||||
|
||||
# Add claims based on scope
|
||||
if "profile" in scopes and user.full_name:
|
||||
logger.debug("[OIDC SERVICE] Found 'profile' in scope, adding name claim")
|
||||
userinfo["name"] = user.full_name
|
||||
logger.debug("[OIDC SERVICE] Added name: %s", user.full_name)
|
||||
else:
|
||||
logger.debug("[OIDC SERVICE] 'profile' not in scope or user.full_name is None: profile_in_scope=%s, full_name=%s", "profile" in scopes, user.full_name)
|
||||
|
||||
if "email" in scopes:
|
||||
logger.debug("[OIDC SERVICE] Found 'email' in scope, adding email claims")
|
||||
userinfo["email"] = user.email
|
||||
userinfo["email_verified"] = user.email_verified
|
||||
logger.debug("[OIDC SERVICE] Added email: %s, email_verified: %s", user.email, user.email_verified)
|
||||
else:
|
||||
logger.debug("[OIDC SERVICE] 'email' not in scope")
|
||||
|
||||
if "roles" in scopes:
|
||||
logger.debug("[OIDC SERVICE] Found 'roles' in scope, adding roles claim")
|
||||
user_roles = cls._get_user_roles(user)
|
||||
userinfo["roles"] = user_roles
|
||||
logger.debug("[OIDC SERVICE] Added roles: %s", user_roles)
|
||||
else:
|
||||
logger.debug("[OIDC SERVICE] 'roles' not in scope")
|
||||
|
||||
logger.debug("[OIDC SERVICE] Final userinfo: %s", userinfo)
|
||||
|
||||
# Log userinfo access
|
||||
logger.debug("[OIDC SERVICE] Logging userinfo access event...")
|
||||
OIDCAuditService.log_userinfo_event(
|
||||
access_token=access_token,
|
||||
user_id=user_id,
|
||||
@@ -800,5 +958,54 @@ class OIDCService:
|
||||
success=True,
|
||||
scopes_claimed=scopes,
|
||||
)
|
||||
logger.debug("[OIDC SERVICE] Userinfo access event logged")
|
||||
|
||||
logger.debug("[OIDC SERVICE] get_userinfo() completed successfully")
|
||||
logger.debug("[OIDC SERVICE] ===========================================")
|
||||
|
||||
return userinfo
|
||||
|
||||
@staticmethod
|
||||
def _get_user_roles(user: User) -> list:
|
||||
"""Get user's organization roles.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
|
||||
Returns:
|
||||
List of role objects with organization_id and role
|
||||
"""
|
||||
logger.debug("[OIDC SERVICE] _get_user_roles() called")
|
||||
logger.debug("[OIDC SERVICE] User: %s", user)
|
||||
|
||||
roles = []
|
||||
|
||||
if not user:
|
||||
logger.debug("[OIDC SERVICE] User is None, returning empty roles list")
|
||||
return roles
|
||||
|
||||
logger.debug("[OIDC SERVICE] User ID: %s", user.id)
|
||||
logger.debug("[OIDC SERVICE] User email: %s", user.email)
|
||||
logger.debug("[OIDC SERVICE] User organization_memberships: %s", user.organization_memberships)
|
||||
|
||||
if user.organization_memberships:
|
||||
logger.debug("[OIDC SERVICE] User has %d organization memberships", len(user.organization_memberships))
|
||||
for idx, member in enumerate(user.organization_memberships):
|
||||
logger.debug("[OIDC SERVICE] Processing membership %d: member=%s", idx, member)
|
||||
logger.debug("[OIDC SERVICE] organization_id: %s", member.organization_id)
|
||||
logger.debug("[OIDC SERVICE] role: %s", member.role)
|
||||
logger.debug("[OIDC SERVICE] role.value: %s", member.role.value)
|
||||
|
||||
role_entry = {
|
||||
"organization_id": str(member.organization_id),
|
||||
"role": member.role.value
|
||||
}
|
||||
roles.append(role_entry)
|
||||
logger.debug("[OIDC SERVICE] Added role entry: %s", role_entry)
|
||||
else:
|
||||
logger.debug("[OIDC SERVICE] User has no organization memberships")
|
||||
|
||||
logger.debug("[OIDC SERVICE] Final roles list: %s", roles)
|
||||
logger.debug("[OIDC SERVICE] _get_user_roles() completed")
|
||||
|
||||
return roles
|
||||
|
||||
@@ -3,6 +3,7 @@ import secrets
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from datetime import timezone
|
||||
from flask import current_app, g
|
||||
|
||||
from app.extensions import db
|
||||
@@ -219,11 +220,11 @@ class OIDCSessionService:
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
cutoff = datetime.utcnow() - timedelta(hours=older_than_hours)
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(hours=older_than_hours)
|
||||
|
||||
# Get expired sessions
|
||||
expired_sessions = OIDCSession.query.filter(
|
||||
OIDCSession.expires_at < datetime.utcnow(),
|
||||
OIDCSession.expires_at < datetime.now(timezone.utc),
|
||||
OIDCSession.deleted_at == None
|
||||
).all()
|
||||
|
||||
|
||||
@@ -2,15 +2,20 @@
|
||||
import hashlib
|
||||
import base64
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, Optional, Any
|
||||
|
||||
import jwt
|
||||
from flask import current_app, g
|
||||
|
||||
from app.models import User, OIDCClient
|
||||
from app.models.organization_member import OrganizationMember
|
||||
from app.services.oidc_jwks_service import OIDCJWKSService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OIDCTokenService:
|
||||
"""Service for generating and validating OIDC tokens.
|
||||
@@ -134,7 +139,7 @@ class OIDCTokenService:
|
||||
return lifetimes.get(token_type, 3600)
|
||||
|
||||
@classmethod
|
||||
def create_access_token(cls, client_id: str, user_id: str, scope: list,
|
||||
def create_access_token(cls, client_id: str, user_id: str, scope: list,
|
||||
jti: str = None) -> str:
|
||||
"""Create a JWT access token.
|
||||
|
||||
@@ -147,25 +152,44 @@ class OIDCTokenService:
|
||||
Returns:
|
||||
JWT access token string
|
||||
"""
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
logger.debug("[OIDC TOKEN SERVICE] create_access_token called")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
|
||||
logger.debug("[OIDC TOKEN SERVICE] scope=%s", scope)
|
||||
|
||||
jti = jti or cls._generate_jti()
|
||||
now = datetime.utcnow()
|
||||
now_timestamp = int(time.time())
|
||||
now = datetime.now(timezone.utc)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token creation time (UTC): %s", now.isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token creation timestamp: %s", now_timestamp)
|
||||
|
||||
# Get client for token lifetime
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
lifetime = cls._get_token_lifetime(client, "access_token") if client else 3600
|
||||
logger.debug("[OIDC TOKEN SERVICE] Access token lifetime (seconds): %s", lifetime)
|
||||
|
||||
exp_timestamp = now_timestamp + lifetime
|
||||
exp_time = now + timedelta(seconds=lifetime)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Access token expiration time (UTC): %s", exp_time.isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] Access token expiration timestamp: %s", exp_timestamp)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Time until expiration (seconds): %s", lifetime)
|
||||
|
||||
claims = {
|
||||
"iss": cls._get_issuer(),
|
||||
"sub": user_id,
|
||||
"aud": client_id,
|
||||
"exp": int((now + timedelta(seconds=lifetime)).timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"nbf": int(now.timestamp()),
|
||||
"exp": exp_timestamp,
|
||||
"iat": now_timestamp,
|
||||
"nbf": now_timestamp,
|
||||
"jti": jti,
|
||||
"client_id": client_id,
|
||||
"scope": " ".join(scope) if isinstance(scope, list) else scope,
|
||||
}
|
||||
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token claims: exp=%s, iat=%s, nbf=%s",
|
||||
claims["exp"], claims["iat"], claims["nbf"])
|
||||
|
||||
# Get signing key
|
||||
jwks_service = OIDCJWKSService()
|
||||
signing_key = jwks_service.get_signing_key()
|
||||
@@ -174,6 +198,7 @@ class OIDCTokenService:
|
||||
raise ValueError("No signing key available")
|
||||
|
||||
# Sign with RS256
|
||||
logger.debug("[OIDC TOKEN SERVICE] Signing token with RS256...")
|
||||
token = jwt.encode(
|
||||
claims,
|
||||
signing_key.private_key,
|
||||
@@ -181,6 +206,9 @@ class OIDCTokenService:
|
||||
headers={"kid": signing_key.kid}
|
||||
)
|
||||
|
||||
logger.debug("[OIDC TOKEN SERVICE] Access token created successfully")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
@@ -200,12 +228,30 @@ class OIDCTokenService:
|
||||
Returns:
|
||||
JWT ID token string
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
auth_time = auth_time or int(now.timestamp())
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
logger.debug("[OIDC TOKEN SERVICE] create_id_token called")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
|
||||
logger.debug("[OIDC TOKEN SERVICE] nonce=%s, auth_time=%s", nonce, auth_time)
|
||||
logger.debug("[OIDC TOKEN SERVICE] scope=%s", scope)
|
||||
|
||||
now_timestamp = int(time.time())
|
||||
now = datetime.now(timezone.utc)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token creation time (UTC): %s", now.isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token creation timestamp: %s", now_timestamp)
|
||||
auth_time = auth_time or now_timestamp
|
||||
logger.debug("[OIDC TOKEN SERVICE] auth_time (Unix timestamp): %s", auth_time)
|
||||
|
||||
# Get client for token lifetime
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
lifetime = cls._get_token_lifetime(client, "id_token") if client else 3600
|
||||
logger.debug("[OIDC TOKEN SERVICE] ID token lifetime (seconds): %s", lifetime)
|
||||
|
||||
exp_timestamp = now_timestamp + lifetime
|
||||
exp_time = now + timedelta(seconds=lifetime)
|
||||
logger.debug("[OIDC TOKEN SERVICE] ID token expiration time (UTC): %s", exp_time.isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] ID token expiration timestamp: %s", exp_timestamp)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Time until expiration (seconds): %s", lifetime)
|
||||
|
||||
# Get user for claims
|
||||
user = User.query.get(user_id)
|
||||
@@ -214,11 +260,14 @@ class OIDCTokenService:
|
||||
"iss": cls._get_issuer(),
|
||||
"sub": user_id,
|
||||
"aud": client_id,
|
||||
"exp": int((now + timedelta(seconds=lifetime)).timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"exp": exp_timestamp,
|
||||
"iat": now_timestamp,
|
||||
"auth_time": auth_time,
|
||||
}
|
||||
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token claims: exp=%s, iat=%s, auth_time=%s",
|
||||
claims["exp"], claims["iat"], claims["auth_time"])
|
||||
|
||||
# Add nonce if provided
|
||||
if nonce:
|
||||
claims["nonce"] = nonce
|
||||
@@ -235,6 +284,10 @@ class OIDCTokenService:
|
||||
if user.full_name:
|
||||
claims["name"] = user.full_name
|
||||
|
||||
# Add roles claim if scope is granted
|
||||
if scope and "roles" in scope:
|
||||
claims["roles"] = cls._get_user_roles(user)
|
||||
|
||||
# Add scope if provided
|
||||
if scope:
|
||||
claims["scope"] = " ".join(scope) if isinstance(scope, list) else scope
|
||||
@@ -247,6 +300,7 @@ class OIDCTokenService:
|
||||
raise ValueError("No signing key available")
|
||||
|
||||
# Sign with RS256
|
||||
logger.debug("[OIDC TOKEN SERVICE] Signing token with RS256...")
|
||||
token = jwt.encode(
|
||||
claims,
|
||||
signing_key.private_key,
|
||||
@@ -254,10 +308,32 @@ class OIDCTokenService:
|
||||
headers={"kid": signing_key.kid}
|
||||
)
|
||||
|
||||
logger.debug("[OIDC TOKEN SERVICE] ID token created successfully")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def _get_user_roles(user: User) -> list:
|
||||
"""Get user's organization roles.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
|
||||
Returns:
|
||||
List of role objects with organization_id and role
|
||||
"""
|
||||
roles = []
|
||||
if user and user.organization_memberships:
|
||||
for member in user.organization_memberships:
|
||||
roles.append({
|
||||
"organization_id": str(member.organization_id),
|
||||
"role": member.role.value
|
||||
})
|
||||
return roles
|
||||
|
||||
@classmethod
|
||||
def create_refresh_token(cls, client_id: str, user_id: str,
|
||||
def create_refresh_token(cls, client_id: str, user_id: str,
|
||||
scope: list = None, access_token_id: str = None) -> str:
|
||||
"""Create an opaque refresh token.
|
||||
|
||||
@@ -270,11 +346,21 @@ class OIDCTokenService:
|
||||
Returns:
|
||||
Opaque refresh token string
|
||||
"""
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
logger.debug("[OIDC TOKEN SERVICE] create_refresh_token called")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
|
||||
logger.debug("[OIDC TOKEN SERVICE] scope=%s, access_token_id=%s", scope, access_token_id)
|
||||
|
||||
token = cls._generate_opaque_token()
|
||||
logger.debug("[OIDC TOKEN SERVICE] Refresh token generated: %s...", token[:20] if token else None)
|
||||
|
||||
# Hash for storage
|
||||
token_hash = cls._hash_token(token)
|
||||
|
||||
logger.debug("[OIDC TOKEN SERVICE] Refresh token created successfully")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
return token, token_hash
|
||||
|
||||
@classmethod
|
||||
@@ -292,54 +378,91 @@ class OIDCTokenService:
|
||||
jwt.ExpiredSignatureError: If token is expired
|
||||
jwt.InvalidTokenError: If token is invalid
|
||||
"""
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
logger.debug("[OIDC TOKEN SERVICE] verify_token_signature() called")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token length: %d", len(token))
|
||||
|
||||
# Get the JWKS with public keys
|
||||
logger.debug("[OIDC TOKEN SERVICE] Getting JWKS...")
|
||||
jwks_service = OIDCJWKSService()
|
||||
jwks = jwks_service.get_jwks()
|
||||
jwks = jwks_service.get_jwks(include_private_keys=True)
|
||||
logger.debug("[OIDC TOKEN SERVICE] JWKS retrieved: %d keys", len(jwks.get("keys", [])))
|
||||
|
||||
# Get the key ID from token header
|
||||
try:
|
||||
logger.debug("[OIDC TOKEN SERVICE] Getting unverified token header...")
|
||||
unverified_header = jwt.get_unverified_header(token)
|
||||
except jwt.DecodeError:
|
||||
logger.debug("[OIDC TOKEN SERVICE] Unverified header: %s", unverified_header)
|
||||
except jwt.DecodeError as e:
|
||||
logger.error("[OIDC TOKEN SERVICE] Failed to decode token header: %s", str(e))
|
||||
raise jwt.InvalidTokenError("Invalid token header")
|
||||
|
||||
kid = unverified_header.get("kid")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Key ID (kid) from token header: %s", kid)
|
||||
|
||||
# Find the matching public key
|
||||
logger.debug("[OIDC TOKEN SERVICE] Searching for matching public key...")
|
||||
public_key = None
|
||||
for key in jwks.get("keys", []):
|
||||
for idx, key in enumerate(jwks.get("keys", [])):
|
||||
logger.debug("[OIDC TOKEN SERVICE] Checking key %d: kid=%s", idx, key.get("kid"))
|
||||
if key.get("kid") == kid:
|
||||
logger.debug("[OIDC TOKEN SERVICE] Found matching key at index %d", idx)
|
||||
try:
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
logger.debug("[OIDC TOKEN SERVICE] Loading PEM public key...")
|
||||
public_key = serialization.load_pem_public_key(
|
||||
key["public_key"].encode() if isinstance(key["public_key"], str)
|
||||
key["public_key"].encode() if isinstance(key["public_key"], str)
|
||||
else key["public_key"],
|
||||
backend=default_backend()
|
||||
)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Public key loaded successfully")
|
||||
break
|
||||
except (ImportError, Exception):
|
||||
except (ImportError, Exception) as e:
|
||||
logger.error("[OIDC TOKEN SERVICE] Failed to load public key: %s: %s", type(e).__name__, str(e))
|
||||
continue
|
||||
|
||||
if not public_key:
|
||||
logger.error("[OIDC TOKEN SERVICE] No matching public key found for kid=%s", kid)
|
||||
raise jwt.InvalidSignatureError(f"Key with kid={kid} not found")
|
||||
|
||||
# Verify the signature
|
||||
claims = jwt.decode(
|
||||
token,
|
||||
public_key,
|
||||
algorithms=["RS256"],
|
||||
audience=None, # We'll validate audience separately
|
||||
issuer=cls._get_issuer(),
|
||||
options={
|
||||
"verify_signature": True,
|
||||
"verify_exp": True,
|
||||
"verify_aud": False, # Handle audience manually
|
||||
"verify_iss": False, # Handle issuer manually
|
||||
}
|
||||
)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Public key found, verifying signature...")
|
||||
|
||||
return claims
|
||||
# Verify the signature
|
||||
try:
|
||||
claims = jwt.decode(
|
||||
token,
|
||||
public_key,
|
||||
algorithms=["RS256"],
|
||||
audience=None, # We'll validate audience separately
|
||||
issuer=cls._get_issuer(),
|
||||
options={
|
||||
"verify_signature": True,
|
||||
"verify_exp": True,
|
||||
"verify_aud": False, # Handle audience manually
|
||||
"verify_iss": False, # Handle issuer manually
|
||||
}
|
||||
)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Signature verification successful")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Decoded claims: %s", claims)
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
return claims
|
||||
except jwt.ExpiredSignatureError as e:
|
||||
logger.error("[OIDC TOKEN SERVICE] Token has expired: %s", str(e))
|
||||
raise
|
||||
except jwt.InvalidSignatureError as e:
|
||||
logger.error("[OIDC TOKEN SERVICE] Invalid token signature: %s", str(e))
|
||||
raise
|
||||
except jwt.InvalidTokenError as e:
|
||||
logger.error("[OIDC TOKEN SERVICE] Invalid token: %s: %s", type(e).__name__, str(e))
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("[OIDC TOKEN SERVICE] Unexpected error during token verification: %s: %s", type(e).__name__, str(e))
|
||||
import traceback
|
||||
logger.error("[OIDC TOKEN SERVICE] Traceback: %s", traceback.format_exc())
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def decode_token(cls, token: str, verify: bool = False) -> Dict:
|
||||
@@ -378,16 +501,41 @@ class OIDCTokenService:
|
||||
jwt.InvalidTokenError: If token is invalid
|
||||
ValueError: If token is expired or audience mismatch
|
||||
"""
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
logger.debug("[OIDC TOKEN SERVICE] validate_access_token() called")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token length: %d", len(token))
|
||||
logger.debug("[OIDC TOKEN SERVICE] Client ID: %s", client_id)
|
||||
|
||||
# Verify token signature
|
||||
logger.debug("[OIDC TOKEN SERVICE] Verifying token signature...")
|
||||
claims = cls.verify_token_signature(token)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token signature verified")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Claims: %s", claims)
|
||||
|
||||
# Check expiration
|
||||
if claims.get("exp", 0) < datetime.utcnow().timestamp():
|
||||
exp = claims.get("exp", 0)
|
||||
now_timestamp = int(time.time())
|
||||
|
||||
if exp < now_timestamp:
|
||||
logger.error("[OIDC TOKEN SERVICE] Token has expired")
|
||||
raise ValueError("Token has expired")
|
||||
|
||||
# Validate audience if client_id provided
|
||||
aud = claims.get("aud")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token audience (aud): %s", aud)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Expected client_id: %s", client_id)
|
||||
|
||||
if client_id:
|
||||
if claims.get("aud") != client_id:
|
||||
if aud != client_id:
|
||||
logger.error("[OIDC TOKEN SERVICE] Audience mismatch: expected=%s, got=%s", client_id, aud)
|
||||
raise ValueError("Invalid audience")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Audience validation passed")
|
||||
else:
|
||||
logger.debug("[OIDC TOKEN SERVICE] No client_id provided, skipping audience validation")
|
||||
|
||||
logger.debug("[OIDC TOKEN SERVICE] validate_access_token() completed successfully")
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
|
||||
return claims
|
||||
|
||||
@@ -410,11 +558,17 @@ class OIDCTokenService:
|
||||
claims = cls.validate_access_token(token, client_id)
|
||||
|
||||
# Calculate remaining time
|
||||
now = datetime.utcnow().timestamp()
|
||||
now_timestamp = int(time.time())
|
||||
now = datetime.now(timezone.utc)
|
||||
exp = claims.get("exp", 0)
|
||||
iat = claims.get("iat", 0)
|
||||
|
||||
result["active"] = exp > now
|
||||
logger.debug("[OIDC TOKEN SERVICE] Introspection - Current UTC time: %s", now.isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] Introspection - Token expiration timestamp: %s", exp)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Introspection - Token expiration datetime (UTC): %s", datetime.fromtimestamp(exp, tz=timezone.utc).isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] Introspection - Time until expiration: %s seconds", exp - now_timestamp)
|
||||
|
||||
result["active"] = exp > now_timestamp
|
||||
result.update({
|
||||
"iss": claims.get("iss"),
|
||||
"sub": claims.get("sub"),
|
||||
@@ -429,8 +583,8 @@ class OIDCTokenService:
|
||||
})
|
||||
|
||||
# Add expiry in seconds
|
||||
if exp > now:
|
||||
result["exp"] = int(exp - now)
|
||||
if exp > now_timestamp:
|
||||
result["exp"] = int(exp - now_timestamp)
|
||||
|
||||
except (jwt.InvalidTokenError, ValueError) as e:
|
||||
result["active"] = False
|
||||
|
||||
@@ -0,0 +1,188 @@
|
||||
"""TOTP (Time-based One-Time Password) service."""
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import secrets
|
||||
from typing import Tuple
|
||||
|
||||
import pyotp
|
||||
from app.extensions import bcrypt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TOTPService:
|
||||
"""Service for TOTP operations."""
|
||||
|
||||
@staticmethod
|
||||
def generate_secret() -> str:
|
||||
"""
|
||||
Generate a new TOTP secret.
|
||||
|
||||
Returns:
|
||||
Base32 encoded secret (32 characters)
|
||||
|
||||
Note:
|
||||
The secret is generated using cryptographically secure random bytes
|
||||
and encoded in base32 format for compatibility with authenticator apps.
|
||||
"""
|
||||
# Generate 20 random bytes (160 bits) and encode as base32
|
||||
random_bytes = secrets.token_bytes(20)
|
||||
secret = base64.b32encode(random_bytes).decode("utf-8")
|
||||
logger.debug(f"Generated new TOTP secret: {secret[:8]}...")
|
||||
return secret
|
||||
|
||||
@staticmethod
|
||||
def generate_provisioning_uri(user_email: str, secret: str, issuer: str = "Gatehouse") -> str:
|
||||
"""
|
||||
Generate provisioning URI for QR code.
|
||||
|
||||
Args:
|
||||
user_email: User's email address
|
||||
secret: TOTP secret (base32 encoded)
|
||||
issuer: Issuer name (default: "Gatehouse")
|
||||
|
||||
Returns:
|
||||
otpauth:// URI for QR code generation
|
||||
|
||||
Example:
|
||||
>>> uri = TOTPService.generate_provisioning_uri("user@example.com", "JBSWY3DPEHPK3PXP")
|
||||
>>> print(uri)
|
||||
otpauth://totp/Gatehouse:user@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse
|
||||
"""
|
||||
totp = pyotp.TOTP(secret)
|
||||
uri = totp.provisioning_uri(name=user_email, issuer_name=issuer)
|
||||
logger.debug(f"Generated provisioning URI for user: {user_email}")
|
||||
return uri
|
||||
|
||||
@staticmethod
|
||||
def verify_code(secret: str, code: str, window: int = 1) -> bool:
|
||||
"""
|
||||
Verify a TOTP code against the secret.
|
||||
|
||||
Args:
|
||||
secret: TOTP secret (base32 encoded)
|
||||
code: 6-digit TOTP code to verify
|
||||
window: Time window for code validation (default: 1, allows codes from previous/next time steps)
|
||||
|
||||
Returns:
|
||||
True if code is valid, False otherwise
|
||||
|
||||
Note:
|
||||
The window parameter allows for clock skew between the server
|
||||
and the authenticator app. A window of 1 allows codes from
|
||||
the previous, current, and next 30-second intervals.
|
||||
"""
|
||||
totp = pyotp.TOTP(secret)
|
||||
is_valid = totp.verify(code, valid_window=window)
|
||||
logger.debug(f"TOTP code verification: valid={is_valid}, window={window}")
|
||||
return is_valid
|
||||
|
||||
@staticmethod
|
||||
def generate_backup_codes(count: int = 10) -> Tuple[list[str], list[str]]:
|
||||
"""
|
||||
Generate backup codes for TOTP recovery.
|
||||
|
||||
Args:
|
||||
count: Number of backup codes to generate (default: 10)
|
||||
|
||||
Returns:
|
||||
Tuple of (plain_codes, hashed_codes)
|
||||
- plain_codes: List of plain text backup codes (for display to user)
|
||||
- hashed_codes: List of bcrypt hashed backup codes (for storage)
|
||||
|
||||
Note:
|
||||
Backup codes are 16-character alphanumeric codes that can be used
|
||||
to recover access if the TOTP device is lost. Each code can only
|
||||
be used once.
|
||||
"""
|
||||
plain_codes = []
|
||||
hashed_codes = []
|
||||
|
||||
for _ in range(count):
|
||||
# Generate a 16-character alphanumeric code
|
||||
code = secrets.token_hex(8).upper()
|
||||
plain_codes.append(code)
|
||||
|
||||
# Hash the code using bcrypt
|
||||
hashed_code = bcrypt.generate_password_hash(code).decode("utf-8")
|
||||
hashed_codes.append(hashed_code)
|
||||
|
||||
logger.debug(f"Generated {count} backup codes")
|
||||
return plain_codes, hashed_codes
|
||||
|
||||
@staticmethod
|
||||
def verify_backup_code(hashed_codes: list[str], code: str) -> Tuple[bool, list[str]]:
|
||||
"""
|
||||
Verify and consume a backup code.
|
||||
|
||||
Args:
|
||||
hashed_codes: List of bcrypt hashed backup codes
|
||||
code: Plain text backup code to verify
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, remaining_codes)
|
||||
- is_valid: True if code was valid and consumed, False otherwise
|
||||
- remaining_codes: List of remaining hashed codes (with consumed code removed)
|
||||
|
||||
Note:
|
||||
Once a backup code is used, it is removed from the list and cannot
|
||||
be used again. This ensures each code is single-use.
|
||||
"""
|
||||
remaining_codes = []
|
||||
|
||||
for hashed_code in hashed_codes:
|
||||
if bcrypt.check_password_hash(hashed_code, code):
|
||||
# Code found and valid - don't add to remaining codes (consumed)
|
||||
logger.debug("Backup code verified and consumed")
|
||||
return True, remaining_codes
|
||||
else:
|
||||
# Code doesn't match - keep it in remaining codes
|
||||
remaining_codes.append(hashed_code)
|
||||
|
||||
logger.debug("Backup code verification failed")
|
||||
return False, remaining_codes
|
||||
|
||||
@staticmethod
|
||||
def generate_qr_code_data_uri(provisioning_uri: str) -> str:
|
||||
"""
|
||||
Generate QR code as data URI for frontend display.
|
||||
|
||||
Args:
|
||||
provisioning_uri: otpauth:// URI to encode in QR code
|
||||
|
||||
Returns:
|
||||
Base64 encoded PNG image as data URI (data:image/png;base64,...)
|
||||
|
||||
Note:
|
||||
If the qrcode library is not installed, returns a placeholder message.
|
||||
Install with: pip install qrcode[pil]
|
||||
"""
|
||||
try:
|
||||
import qrcode
|
||||
|
||||
# Create QR code
|
||||
qr = qrcode.QRCode(
|
||||
version=1,
|
||||
error_correction=qrcode.constants.ERROR_CORRECT_L,
|
||||
box_size=10,
|
||||
border=4,
|
||||
)
|
||||
qr.add_data(provisioning_uri)
|
||||
qr.make(fit=True)
|
||||
|
||||
# Generate image
|
||||
img = qr.make_image(fill_color="black", back_color="white")
|
||||
|
||||
# Convert to base64
|
||||
buffer = io.BytesIO()
|
||||
img.save(buffer, format="PNG")
|
||||
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
data_uri = f"data:image/png;base64,{img_base64}"
|
||||
logger.debug("Generated QR code data URI")
|
||||
return data_uri
|
||||
|
||||
except ImportError:
|
||||
logger.warning("qrcode library not installed, returning placeholder")
|
||||
return "QR code generation requires the qrcode library. Install with: pip install qrcode[pil]"
|
||||
Reference in New Issue
Block a user