feat(auth): implement TOTP two-factor authentication with enrollment and verification

Adds TOTP (Time-based One-Time Password) two-factor authentication support including:
- New TOTP service with secret generation, QR code provisioning, and code verification
- New auth endpoints for enrollment, verification, status, and backup code management
- New TOTP authentication method type and user methods for TOTP management
- Backup codes generation and verification for account recovery
- Updated OIDC endpoints with timezone-aware datetime handling and RFC-compliant responses
- Added "roles" scope support for OIDC userinfo and ID tokens
- New pyotp dependency for TOTP operations
- Comprehensive unit tests for TOTP service
This commit is contained in:
2026-01-14 18:06:17 +10:30
parent 977abf66df
commit cfd79190ee
26 changed files with 2176 additions and 263 deletions
+4
View File
@@ -305,3 +305,7 @@ client_secret: acme_secret_portal_2024
## User ## User
email: bob@acme-corp.com email: bob@acme-corp.com
password: UserPass123! password: UserPass123!
## Sqlite editor
sqlite_web instance/db_file.db --port 9999 --host 0.0.0.0
+297 -187
View File
@@ -3,6 +3,7 @@ import base64
import json import json
import logging import logging
import secrets import secrets
from datetime import datetime, timezone
from urllib.parse import urlencode, urlparse, parse_qs from urllib.parse import urlencode, urlparse, parse_qs
import bcrypt import bcrypt
@@ -42,14 +43,14 @@ def get_oidc_config():
"registration_endpoint": f"{base_url}/oidc/register", "registration_endpoint": f"{base_url}/oidc/register",
"revocation_endpoint": f"{base_url}/oidc/revoke", "revocation_endpoint": f"{base_url}/oidc/revoke",
"introspection_endpoint": f"{base_url}/oidc/introspect", "introspection_endpoint": f"{base_url}/oidc/introspect",
"scopes_supported": ["openid", "profile", "email"], "scopes_supported": ["openid", "profile", "email", "roles"],
"response_types_supported": ["code"], "response_types_supported": ["code"],
"response_modes_supported": ["query"], "response_modes_supported": ["query"],
"grant_types_supported": ["authorization_code", "refresh_token"], "grant_types_supported": ["authorization_code", "refresh_token"],
"token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"], "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"],
"subject_types_supported": ["public"], "subject_types_supported": ["public"],
"id_token_signing_alg_values_supported": ["RS256"], "id_token_signing_alg_values_supported": ["RS256"],
"claims_supported": ["sub", "name", "email", "email_verified"], "claims_supported": ["sub", "name", "email", "email_verified", "roles"],
} }
@@ -94,19 +95,49 @@ def require_valid_token():
Raises: Raises:
InvalidGrantError: If token is invalid InvalidGrantError: If token is invalid
""" """
logger.debug("[OIDC USERINFO] ===========================================")
logger.debug("[OIDC USERINFO] require_valid_token() called")
logger.debug("[OIDC USERINFO] Request method: %s", request.method)
logger.debug("[OIDC USERINFO] Request URL: %s", request.url)
logger.debug("[OIDC USERINFO] Request headers: %s", dict(request.headers))
auth_header = request.headers.get("Authorization", "") auth_header = request.headers.get("Authorization", "")
logger.debug("[OIDC USERINFO] Authorization header: %s", auth_header[:20] + "..." if len(auth_header) > 20 else auth_header)
if not auth_header.startswith("Bearer "): if not auth_header.startswith("Bearer "):
logger.error("[OIDC USERINFO] Invalid Authorization header format - missing 'Bearer ' prefix")
raise InvalidGrantError("Invalid token: Missing or invalid Authorization header") raise InvalidGrantError("Invalid token: Missing or invalid Authorization header")
token = auth_header[7:] token = auth_header[7:]
claims = OIDCService.validate_access_token(token) logger.debug("[OIDC USERINFO] Token extracted (first 50 chars): %s...", token[:50] if len(token) > 50 else token)
g.current_token = claims logger.debug("[OIDC USERINFO] Token length: %d", len(token))
try:
logger.debug("[OIDC USERINFO] Calling OIDCService.validate_access_token()...")
claims = OIDCService.validate_access_token(token)
logger.debug("[OIDC USERINFO] Token validation successful")
logger.debug("[OIDC USERINFO] Token claims: %s", claims)
except Exception as e:
logger.error("[OIDC USERINFO] Token validation failed: %s: %s", type(e).__name__, str(e))
raise
g.current_token = claims
g.access_token = token # Store the original access token for get_userinfo()
logger.debug("[OIDC USERINFO] g.current_token set")
user_id = claims.get("sub")
logger.debug("[OIDC USERINFO] User ID from token: %s", user_id)
user = User.query.get(user_id)
logger.debug("[OIDC USERINFO] User query result: %s", user)
user = User.query.get(claims.get("sub"))
if not user: if not user:
logger.error("[OIDC USERINFO] User not found in database: user_id=%s", user_id)
raise InvalidGrantError("Invalid token: User not found") raise InvalidGrantError("Invalid token: User not found")
g.current_user = user g.current_user = user
logger.debug("[OIDC USERINFO] g.current_user set: user_id=%s, email=%s", user.id, user.email)
logger.debug("[OIDC USERINFO] require_valid_token() completed successfully")
def _check_password_hash(client, password): def _check_password_hash(client, password):
@@ -175,10 +206,11 @@ def oidc_discovery():
No authentication required. No authentication required.
Returns: Returns:
200: OIDC discovery document 200: OIDC discovery document (application/json)
""" """
config = get_oidc_config() config = get_oidc_config()
# Return discovery document as application/json (per OpenID Connect Discovery 1.0)
response = jsonify(config) response = jsonify(config)
response.headers["Cache-Control"] = "max-age=86400" response.headers["Cache-Control"] = "max-age=86400"
return response, 200 return response, 200
@@ -217,8 +249,12 @@ def oidc_authorize():
200: Login page (GET when not authenticated) 200: Login page (GET when not authenticated)
400: Invalid request 400: Invalid request
""" """
logger.debug("[OIDC] ===========================================")
logger.debug("[OIDC] oidc_authorize called") logger.debug("[OIDC] oidc_authorize called")
logger.debug("[OIDC] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC] Request method: %s", request.method)
logger.debug("[OIDC] Request URL: %s", request.url)
logger.debug("[OIDC] Remote address: %s", request.remote_addr)
# Parse request parameters # Parse request parameters
if request.method == "GET": if request.method == "GET":
@@ -227,6 +263,7 @@ def oidc_authorize():
params = request.form.to_dict() params = request.form.to_dict()
logger.debug("[OIDC] Raw request params: %s", params) logger.debug("[OIDC] Raw request params: %s", params)
# Extract required parameters # Extract required parameters
logger.debug("[OIDC] Extracting request parameters...") logger.debug("[OIDC] Extracting request parameters...")
client_id = params.get("client_id") client_id = params.get("client_id")
@@ -367,6 +404,7 @@ def oidc_authorize():
# User is authenticated, generate authorization code # User is authenticated, generate authorization code
logger.debug("[OIDC] User is authenticated, fetching user from database...") logger.debug("[OIDC] User is authenticated, fetching user from database...")
logger.debug("[OIDC] Current UTC time before code generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
user = User.query.get(user_id) user = User.query.get(user_id)
logger.debug("[OIDC] User query result: %s", user) logger.debug("[OIDC] User query result: %s", user)
@@ -393,12 +431,17 @@ def oidc_authorize():
user_agent=request.headers.get("User-Agent"), user_agent=request.headers.get("User-Agent"),
) )
logger.debug("[OIDC] Authorization code generated successfully: %s...", code[:20] if code else None) logger.debug("[OIDC] Authorization code generated successfully: %s...", code[:20] if code else None)
logger.debug("[OIDC] Current UTC time after code generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
except Exception as e: except Exception as e:
logger.debug("[OIDC] Authorization code generation failed: %s", str(e)) logger.error("[OIDC] Authorization code generation failed: %s", str(e))
logger.error("[OIDC] Current UTC time at failure: %s", datetime.now(timezone.utc).isoformat() + "Z")
import traceback
logger.error("[OIDC] Traceback: %s", traceback.format_exc())
return _redirect_with_error(redirect_uri, "server_error", str(e), state) return _redirect_with_error(redirect_uri, "server_error", str(e), state)
# Redirect with authorization code # Redirect with authorization code
logger.debug("[OIDC] Redirecting with authorization code...") logger.debug("[OIDC] Redirecting with authorization code...")
logger.debug("[OIDC] Current UTC time before redirect: %s", datetime.now(timezone.utc).isoformat() + "Z")
redirect_params = {"code": code} redirect_params = {"code": code}
if state: if state:
redirect_params["state"] = state redirect_params["state"] = state
@@ -406,6 +449,7 @@ def oidc_authorize():
redirect_url = f"{redirect_uri}?{urlencode(redirect_params)}" redirect_url = f"{redirect_uri}?{urlencode(redirect_params)}"
logger.debug("[OIDC] Redirect URL: %s...", redirect_url[:100] if redirect_url else None) logger.debug("[OIDC] Redirect URL: %s...", redirect_url[:100] if redirect_url else None)
logger.debug("[OIDC] oidc_authorize completed successfully") logger.debug("[OIDC] oidc_authorize completed successfully")
logger.debug("[OIDC] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC] ===========================================") logger.debug("[OIDC] ===========================================")
return redirect(redirect_url) return redirect(redirect_url)
@@ -544,14 +588,13 @@ def oidc_token():
# Validate grant_type # Validate grant_type
if not grant_type: if not grant_type:
logger.error("[OIDC] grant_type is requred") logger.error("[OIDC] grant_type is required")
return api_response( # RFC 6749 Section 5.2: Error response for invalid request
success=False, response = jsonify({
message="grant_type is required", "error": "invalid_request",
status=400, "error_description": "grant_type is required"
error_type="INVALID_REQUEST", })
error_details={"error": "invalid_request", "error_description": "grant_type is required"}, return response, 400
)
# Authenticate client # Authenticate client
client_id = data.get("client_id") client_id = data.get("client_id")
@@ -600,46 +643,51 @@ def oidc_token():
# Unsupported grant type # Unsupported grant type
else: else:
logger.error("[OIDC] Unsupported grant_type") logger.error("[OIDC] Unsupported grant_type")
return api_response( # RFC 6749 Section 5.2: Error response for unsupported grant type
success=False, response = jsonify({
message="Unsupported grant_type", "error": "unsupported_grant_type",
status=400, "error_description": f"Grant type '{grant_type}' is not supported"
error_type="UNSUPPORTED_GRANT_TYPE", })
error_details={"error": "unsupported_grant_type", "error_description": f"Grant type '{grant_type}' is not supported"}, return response, 400
)
def _handle_authorization_code_grant(data, client): def _handle_authorization_code_grant(data, client):
"""Handle authorization_code grant type.""" """Handle authorization_code grant type."""
logger.debug("[OIDC] ===========================================")
logger.debug("[OIDC] _handle_authorization_code_grant called")
logger.debug("[OIDC] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
code = data.get("code") code = data.get("code")
redirect_uri = data.get("redirect_uri") redirect_uri = data.get("redirect_uri")
code_verifier = data.get("code_verifier") code_verifier = data.get("code_verifier")
logger.debug("[OIDC] Code provided: %s", bool(code))
logger.debug("[OIDC] Redirect URI: %s", redirect_uri)
logger.debug("[OIDC] Code verifier provided: %s", bool(code_verifier))
if not code: if not code:
logger.error("[OIDC] code is required") logger.error("[OIDC] code is required")
return api_response( # RFC 6749 Section 5.2: Error response for invalid request
success=False, response = jsonify({
message="code is required", "error": "invalid_request",
status=400, "error_description": "code is required"
error_type="INVALID_REQUEST", })
error_details={"error": "invalid_request", "error_description": "code is required"}, return response, 400
)
if not redirect_uri: if not redirect_uri:
logger.error("[OIDC] redirect_uri is required") logger.error("[OIDC] redirect_uri is required")
return api_response( response = jsonify({
success=False, "error": "invalid_request",
message="redirect_uri is required", "error_description": "redirect_uri is required"
status=400, })
error_type="INVALID_REQUEST", return response, 400
error_details={"error": "invalid_request", "error_description": "redirect_uri is required"},
)
try: try:
# Development-only debug logging for authorization code validation # Development-only debug logging for authorization code validation
if current_app.config.get('ENV') == 'development': if current_app.config.get('ENV') == 'development':
logger.debug(f"[OIDC] Authorization code validation: client_id={client.client_id}, code_provided=True") logger.debug(f"[OIDC] Authorization code validation: client_id={client.client_id}, code_provided=True")
logger.debug("[OIDC] Current UTC time before code validation: %s", datetime.now(timezone.utc).isoformat() + "Z")
claims, user = OIDCService.validate_authorization_code( claims, user = OIDCService.validate_authorization_code(
code=code, code=code,
client_id=client.client_id, client_id=client.client_id,
@@ -649,23 +697,22 @@ def _handle_authorization_code_grant(data, client):
user_agent=request.headers.get("User-Agent"), user_agent=request.headers.get("User-Agent"),
) )
except InvalidGrantError as e: except InvalidGrantError as e:
logger.error(f"[OIDC] INVALID_GRANT: {str(e)}") logger.error("[OIDC] INVALID_GRANT: %s", str(e))
return api_response( logger.error("[OIDC] Current UTC time at validation failure: %s", datetime.now(timezone.utc).isoformat() + "Z")
success=False, # RFC 6749 Section 5.2: Error response for invalid grant
message=str(e), response = jsonify({
status=400, "error": "invalid_grant",
error_type="INVALID_GRANT", "error_description": str(e)
error_details={"error": "invalid_grant", "error_description": str(e)}, })
) return response, 400
except Exception as e: except Exception as e:
logger.error(f"[OIDC] Authorization code validation error: {type(e).__name__}: {str(e)}") logger.error("[OIDC] Authorization code validation error: %s: %s", type(e).__name__, str(e))
return api_response( logger.error("[OIDC] Current UTC time at validation error: %s", datetime.now(timezone.utc).isoformat() + "Z")
success=False, response = jsonify({
message=str(e), "error": "invalid_grant",
status=400, "error_description": str(e)
error_type="INVALID_GRANT", })
error_details={"error": "invalid_grant", "error_description": str(e)}, return response, 400
)
# Generate tokens # Generate tokens
try: try:
@@ -673,6 +720,8 @@ def _handle_authorization_code_grant(data, client):
if current_app.config.get('ENV') == 'development': if current_app.config.get('ENV') == 'development':
logger.debug(f"[OIDC] Token generation: client_id={client.client_id}, user_id={claims['user_id']}, scope={claims['scope']}") logger.debug(f"[OIDC] Token generation: client_id={client.client_id}, user_id={claims['user_id']}, scope={claims['scope']}")
logger.debug("[OIDC] Current UTC time before token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
tokens = OIDCService.generate_tokens( tokens = OIDCService.generate_tokens(
client_id=client.client_id, client_id=client.client_id,
user_id=claims["user_id"], user_id=claims["user_id"],
@@ -683,40 +732,64 @@ def _handle_authorization_code_grant(data, client):
auth_time=int(__import__("time").time()), auth_time=int(__import__("time").time()),
) )
except Exception as e: except Exception as e:
logger.error(f"[OIDC] Failed to generate tokens {str(e)}") logger.error("[OIDC] Failed to generate tokens: %s", str(e))
return api_response( logger.error("[OIDC] Current UTC time at token generation failure: %s", datetime.now(timezone.utc).isoformat() + "Z")
success=False, response = jsonify({
message="Failed to generate tokens", "error": "server_error",
status=500, "error_description": str(e)
error_type="SERVER_ERROR", })
error_details={"error": "server_error", "error_description": str(e)}, return response, 500
)
return api_response( # Return standard OAuth2/OIDC token response (application/json)
data=tokens, # Per RFC 6749 Section 5.1 and OIDC Core 1.0
message="Tokens issued successfully", logger.debug("[OIDC] Current UTC time after token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
status=200, logger.debug("[OIDC] _handle_authorization_code_grant completed successfully")
)
# Echo tokens to console for diagnostics
print(f"[TOKEN DIAGNOSTICS] Authorization code exchange completed")
print(f"[TOKEN DIAGNOSTICS] Access Token: {tokens.get('access_token', '')}..." if len(tokens.get('access_token', '')) > 50 else f"[TOKEN DIAGNOSTICS] Access Token: {tokens.get('access_token', '')}")
print(f"[TOKEN DIAGNOSTICS] Token Type: {tokens.get('token_type', '')}")
print(f"[TOKEN DIAGNOSTICS] Expires In: {tokens.get('expires_in', '')}")
if 'id_token' in tokens:
print(f"[TOKEN DIAGNOSTICS] ID Token: {tokens['id_token']}..." if len(tokens['id_token']) > 50 else f"[TOKEN DIAGNOSTICS] ID Token: {tokens['id_token']}")
if 'refresh_token' in tokens:
print(f"[TOKEN DIAGNOSTICS] Refresh Token: {tokens['refresh_token'][:50]}..." if len(tokens['refresh_token']) > 50 else f"[TOKEN DIAGNOSTICS] Refresh Token: {tokens['refresh_token']}")
print(f"[TOKEN DIAGNOSTICS] Scope: {tokens.get('scope', '')}")
print(f"[TOKEN DIAGNOSTICS] ===========================================")
logger.debug("[OIDC] ===========================================")
response = jsonify(tokens)
print(tokens)
response.headers["Cache-Control"] = "no-store"
response.headers["Pragma"] = "no-cache"
return response, 200
def _handle_refresh_token_grant(data, client): def _handle_refresh_token_grant(data, client):
"""Handle refresh_token grant type.""" """Handle refresh_token grant type."""
logger.debug("[OIDC] ===========================================")
logger.debug("[OIDC] _handle_refresh_token_grant called")
logger.debug("[OIDC] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
refresh_token = data.get("refresh_token") refresh_token = data.get("refresh_token")
scope = data.get("scope") scope = data.get("scope")
logger.debug("[OIDC] Refresh token provided: %s", bool(refresh_token))
logger.debug("[OIDC] Scope: %s", scope)
if not refresh_token: if not refresh_token:
return api_response( # RFC 6749 Section 5.2: Error response for invalid request
success=False, response = jsonify({
message="refresh_token is required", "error": "invalid_request",
status=400, "error_description": "refresh_token is required"
error_type="INVALID_REQUEST", })
error_details={"error": "invalid_request", "error_description": "refresh_token is required"}, return response, 400
)
# Parse scope if provided # Parse scope if provided
scope_list = scope.split() if scope else None scope_list = scope.split() if scope else None
try: try:
logger.debug("[OIDC] Current UTC time before token refresh: %s", datetime.now(timezone.utc).isoformat() + "Z")
tokens = OIDCService.refresh_access_token( tokens = OIDCService.refresh_access_token(
refresh_token=refresh_token, refresh_token=refresh_token,
client_id=client.client_id, client_id=client.client_id,
@@ -725,19 +798,37 @@ def _handle_refresh_token_grant(data, client):
user_agent=request.headers.get("User-Agent"), user_agent=request.headers.get("User-Agent"),
) )
except InvalidGrantError as e: except InvalidGrantError as e:
return api_response( logger.error("[OIDC] Refresh token error: %s", str(e))
success=False, logger.error("[OIDC] Current UTC time at refresh failure: %s", datetime.now(timezone.utc).isoformat() + "Z")
message=str(e), # RFC 6749 Section 5.2: Error response for invalid grant
status=400, response = jsonify({
error_type="INVALID_GRANT", "error": "invalid_grant",
error_details={"error": "invalid_grant", "error_description": str(e)}, "error_description": str(e)
) })
return response, 400
return api_response( # Return standard OAuth2/OIDC token response (application/json)
data=tokens, # Per RFC 6749 Section 5.1 and OIDC Core 1.0
message="Tokens refreshed successfully", logger.debug("[OIDC] Current UTC time after token refresh: %s", datetime.now(timezone.utc).isoformat() + "Z")
status=200, logger.debug("[OIDC] _handle_refresh_token_grant completed successfully")
)
# Echo tokens to console for diagnostics
print(f"[TOKEN DIAGNOSTICS] Token refresh completed")
print(f"[TOKEN DIAGNOSTICS] Access Token: {tokens.get('access_token', '')[:50]}..." if len(tokens.get('access_token', '')) > 50 else f"[TOKEN DIAGNOSTICS] Access Token: {tokens.get('access_token', '')}")
print(f"[TOKEN DIAGNOSTICS] Token Type: {tokens.get('token_type', '')}")
print(f"[TOKEN DIAGNOSTICS] Expires In: {tokens.get('expires_in', '')}")
if 'id_token' in tokens:
print(f"[TOKEN DIAGNOSTICS] ID Token: {tokens['id_token'][:50]}..." if len(tokens['id_token']) > 50 else f"[TOKEN DIAGNOSTICS] ID Token: {tokens['id_token']}")
if 'refresh_token' in tokens:
print(f"[TOKEN DIAGNOSTICS] Refresh Token: {tokens['refresh_token'][:50]}..." if len(tokens['refresh_token']) > 50 else f"[TOKEN DIAGNOSTICS] Refresh Token: {tokens['refresh_token']}")
print(f"[TOKEN DIAGNOSTICS] Scope: {tokens.get('scope', '')}")
print(f"[TOKEN DIAGNOSTICS] ===========================================")
logger.debug("[OIDC] ===========================================")
response = jsonify(tokens)
response.headers["Cache-Control"] = "no-store"
response.headers["Pragma"] = "no-cache"
return response, 200
# ============================================================================ # ============================================================================
@@ -759,37 +850,72 @@ def oidc_userinfo():
- email_verified: Email verification status (if "email" scope) - email_verified: Email verification status (if "email" scope)
Returns: Returns:
200: User claims 200: User claims in JSON format (application/json)
401: Invalid or insufficient token 401: Invalid or missing token (with WWW-Authenticate header per RFC 6750)
""" """
logger.debug("[OIDC USERINFO] ===========================================")
logger.debug("[OIDC USERINFO] oidc_userinfo() endpoint called")
logger.debug("[OIDC USERINFO] Request method: %s", request.method)
logger.debug("[OIDC USERINFO] Request URL: %s", request.url)
logger.debug("[OIDC USERINFO] Request content_type: %s", request.content_type)
logger.debug("[OIDC USERINFO] Request headers: %s", dict(request.headers))
logger.debug("[OIDC USERINFO] Request args: %s", dict(request.args))
logger.debug("[OIDC USERINFO] Request form: %s", dict(request.form))
request_json = request.get_json(silent=True)
logger.debug("[OIDC USERINFO] Request json: %s", request_json)
logger.debug("[OIDC USERINFO] Request data length: %d bytes", len(request.get_data()))
try: try:
logger.debug("[OIDC USERINFO] Calling require_valid_token()...")
require_valid_token() require_valid_token()
logger.debug("[OIDC USERINFO] Token validation successful")
except InvalidGrantError as e: except InvalidGrantError as e:
return api_response( logger.error("[OIDC USERINFO] Token validation failed: %s", str(e))
success=False, # RFC 6750 Section 3: Return 401 with WWW-Authenticate header for invalid tokens
message=str(e), response = jsonify({
status=401, "error": "invalid_token",
error_type="INVALID_TOKEN", "error_description": str(e)
error_details={"error": "invalid_token", "error_description": str(e)}, })
) response.headers["WWW-Authenticate"] = 'Bearer realm="OIDC UserInfo Endpoint", error="invalid_token", error_description="' + str(e) + '"'
return response, 401
# Get userinfo
try:
userinfo = OIDCService.get_userinfo(g.current_token.get("access_token", ""))
except Exception as e: except Exception as e:
return api_response( logger.error("[OIDC USERINFO] Unexpected error during token validation: %s: %s", type(e).__name__, str(e))
success=False, response = jsonify({
message="Failed to get user info", "error": "server_error",
status=500, "error_description": str(e)
error_type="SERVER_ERROR", })
error_details={"error": "server_error", "error_description": str(e)}, response.headers["WWW-Authenticate"] = 'Bearer realm="OIDC UserInfo Endpoint", error="server_error"'
) return response, 500
return api_response( logger.debug("[OIDC USERINFO] g.current_token: %s", g.current_token)
data=userinfo, logger.debug("[OIDC USERINFO] g.current_user: user_id=%s, email=%s", g.current_user.id, g.current_user.email)
message="User info retrieved successfully",
status=200, # Get userinfo using the original access token
) access_token = g.access_token
logger.debug("[OIDC USERINFO] Access token from g.access_token: %s...", access_token[:50] if len(access_token) > 50 else access_token)
try:
logger.debug("[OIDC USERINFO] Calling OIDCService.get_userinfo()...")
userinfo = OIDCService.get_userinfo(access_token)
logger.debug("[OIDC USERINFO] Userinfo retrieved successfully: %s", userinfo)
except Exception as e:
logger.error("[OIDC USERINFO] Failed to get user info: %s: %s", type(e).__name__, str(e))
import traceback
logger.error("[OIDC USERINFO] Traceback: %s", traceback.format_exc())
response = jsonify({
"error": "server_error",
"error_description": str(e)
})
return response, 500
logger.debug("[OIDC USERINFO] Returning userinfo response")
logger.debug("[OIDC USERINFO] ===========================================")
# Return standard OIDC UserInfo response (application/json)
# Per OpenID Connect Core 1.0 Section 5.3, response is a JSON object
response = jsonify(userinfo)
response.headers["Cache-Control"] = "no-cache, no-store"
return response, 200
# ============================================================================ # ============================================================================
@@ -806,19 +932,18 @@ def oidc_jwks():
No authentication required. No authentication required.
Returns: Returns:
200: JWKS document 200: JWKS document (application/json)
""" """
try: try:
jwks = OIDCService.get_jwks() jwks = OIDCService.get_jwks()
except Exception as e: except Exception as e:
return api_response( response = jsonify({
success=False, "error": "server_error",
message="Failed to get JWKS", "error_description": str(e)
status=500, })
error_type="SERVER_ERROR", return response, 500
error_details={"error": "server_error", "error_description": str(e)},
)
# Return JWKS as application/json (per OpenID Connect Discovery 1.0)
response = jsonify(jwks) response = jsonify(jwks)
response.headers["Cache-Control"] = "max-age=3600" response.headers["Cache-Control"] = "max-age=3600"
return response, 200 return response, 200
@@ -858,13 +983,12 @@ def oidc_revoke():
token = data.get("token") token = data.get("token")
if not token: if not token:
return api_response( # RFC 7009 Section 2.1: Error response for invalid request
success=False, response = jsonify({
message="token is required", "error": "invalid_request",
status=400, "error_description": "token is required"
error_type="INVALID_REQUEST", })
error_details={"error": "invalid_request", "error_description": "token is required"}, return response, 400
)
# Authenticate client # Authenticate client
client_id = data.get("client_id") client_id = data.get("client_id")
@@ -901,13 +1025,11 @@ def oidc_revoke():
user_agent=request.headers.get("User-Agent"), user_agent=request.headers.get("User-Agent"),
) )
except Exception as e: except Exception as e:
# Revocation should succeed even if token is invalid # Revocation should succeed even if token is invalid (RFC 7009)
pass pass
return api_response( # RFC 7009 Section 2.2: Successful revocation returns empty body with 200
message="Token revoked successfully", return "", 200
status=200,
)
# ============================================================================ # ============================================================================
@@ -944,13 +1066,12 @@ def oidc_introspect():
token = data.get("token") token = data.get("token")
if not token: if not token:
return api_response( # RFC 7009 Section 2.1: Error response for invalid request
success=False, response = jsonify({
message="token is required", "error": "invalid_request",
status=400, "error_description": "token is required"
error_type="INVALID_REQUEST", })
error_details={"error": "invalid_request", "error_description": "token is required"}, return response, 400
)
# Authenticate client # Authenticate client
client_id = data.get("client_id") client_id = data.get("client_id")
@@ -986,19 +1107,17 @@ def oidc_introspect():
user_agent=request.headers.get("User-Agent"), user_agent=request.headers.get("User-Agent"),
) )
except Exception as e: except Exception as e:
return api_response( # RFC 7009 Section 2.2: Error response
success=False, response = jsonify({
message="Failed to introspect token", "error": "server_error",
status=500, "error_description": str(e)
error_type="SERVER_ERROR", })
error_details={"error": "server_error", "error_description": str(e)}, return response, 500
)
return api_response( # RFC 7009 Section 2.3: Return introspection response (application/json)
data=result, response = jsonify(result)
message="Token introspection successful", response.headers["Cache-Control"] = "no-cache, no-store"
status=200, return response, 200
)
# ============================================================================ # ============================================================================
@@ -1030,22 +1149,18 @@ def oidc_register():
redirect_uris = data.get("redirect_uris", []) redirect_uris = data.get("redirect_uris", [])
if not client_name: if not client_name:
return api_response( response = jsonify({
success=False, "error": "invalid_request",
message="client_name is required", "error_description": "client_name is required"
status=400, })
error_type="INVALID_REQUEST", return response, 400
error_details={"error": "invalid_request", "error_description": "client_name is required"},
)
if not redirect_uris: if not redirect_uris:
return api_response( response = jsonify({
success=False, "error": "invalid_request",
message="redirect_uris is required", "error_description": "redirect_uris is required"
status=400, })
error_type="INVALID_REQUEST", return response, 400
error_details={"error": "invalid_request", "error_description": "redirect_uris is required"},
)
# Validate redirect_uris # Validate redirect_uris
for uri in redirect_uris: for uri in redirect_uris:
@@ -1054,13 +1169,11 @@ def oidc_register():
if not parsed.scheme or not parsed.netloc: if not parsed.scheme or not parsed.netloc:
raise ValueError(f"Invalid redirect URI: {uri}") raise ValueError(f"Invalid redirect URI: {uri}")
except Exception: except Exception:
return api_response( response = jsonify({
success=False, "error": "invalid_request",
message=f"Invalid redirect_uri: {uri}", "error_description": f"Invalid redirect_uri: {uri}"
status=400, })
error_type="INVALID_REQUEST", return response, 400
error_details={"error": "invalid_request", "error_description": f"Invalid redirect_uri: {uri}"},
)
# Generate client credentials # Generate client credentials
client_id = f"oidc_{secrets.token_urlsafe(16)}" client_id = f"oidc_{secrets.token_urlsafe(16)}"
@@ -1092,7 +1205,7 @@ def oidc_register():
redirect_uris=redirect_uris, redirect_uris=redirect_uris,
grant_types=data.get("grant_types", ["authorization_code", "refresh_token"]), grant_types=data.get("grant_types", ["authorization_code", "refresh_token"]),
response_types=data.get("response_types", ["code"]), response_types=data.get("response_types", ["code"]),
scopes=data.get("scope", "openid profile email").split(), scopes=data.get("scope", "openid profile email roles").split(),
token_endpoint_auth_method=data.get("token_endpoint_auth_method", "client_secret_basic"), token_endpoint_auth_method=data.get("token_endpoint_auth_method", "client_secret_basic"),
is_active=True, is_active=True,
is_confidential=True, is_confidential=True,
@@ -1105,19 +1218,16 @@ def oidc_register():
client.save() client.save()
# Return client credentials # Return client credentials
return api_response( response = jsonify({
data={ "client_id": client_id,
"client_id": client_id, "client_secret": client_secret,
"client_secret": client_secret, "client_id_issued_at": int(__import__("time").time()),
"client_id_issued_at": int(__import__("time").time()), "client_secret_expires_at": 0, # Never expires
"client_secret_expires_at": 0, # Never expires "client_name": client_name,
"client_name": client_name, "redirect_uris": redirect_uris,
"redirect_uris": redirect_uris, "token_endpoint_auth_method": data.get("token_endpoint_auth_method", "client_secret_basic"),
"token_endpoint_auth_method": data.get("token_endpoint_auth_method", "client_secret_basic"), "grant_types": client.grant_types,
"grant_types": client.grant_types, "response_types": client.response_types,
"response_types": client.response_types, "scope": " ".join(client.scopes),
"scope": " ".join(client.scopes), })
}, return response, 201
message="Client registered successfully",
status=201,
)
+321 -4
View File
@@ -3,11 +3,20 @@ from flask import request, session, g
from marshmallow import ValidationError from marshmallow import ValidationError
from app.api.v1 import api_v1_bp from app.api.v1 import api_v1_bp
from app.utils.response import api_response from app.utils.response import api_response
from app.schemas.auth_schema import RegisterSchema, LoginSchema from app.schemas.auth_schema import (
RegisterSchema,
LoginSchema,
TOTPVerifyEnrollmentSchema,
TOTPVerifySchema,
TOTPDisableSchema,
TOTPRegenerateBackupCodesSchema,
)
from app.services.auth_service import AuthService from app.services.auth_service import AuthService
from app.services.user_service import UserService from app.services.user_service import UserService
from app.utils.decorators import login_required from app.utils.decorators import login_required
from app.utils.constants import AuditAction from app.utils.constants import AuditAction
from app.exceptions.auth_exceptions import InvalidCredentialsError
from app.exceptions.validation_exceptions import ConflictError
@api_v1_bp.route("/auth/register", methods=["POST"]) @api_v1_bp.route("/auth/register", methods=["POST"])
@@ -72,7 +81,7 @@ def login():
remember_me: Optional boolean for extended session remember_me: Optional boolean for extended session
Returns: Returns:
200: Login successful 200: Login successful or TOTP code required
400: Validation error 400: Validation error
401: Invalid credentials 401: Invalid credentials
""" """
@@ -81,13 +90,29 @@ def login():
schema = LoginSchema() schema = LoginSchema()
data = schema.load(request.json) data = schema.load(request.json)
# Authenticate user # Authenticate user with email and password
user = AuthService.authenticate( user = AuthService.authenticate(
email=data["email"], email=data["email"],
password=data["password"], password=data["password"],
) )
# Create session # Check if user has TOTP enabled for two-factor authentication
if user.has_totp_enabled():
# TOTP is enabled - store user_id in session for TOTP verification
# The /auth/totp/verify endpoint will retrieve this user_id
session["totp_pending_user_id"] = user.id
# Return response indicating TOTP code is required
# Do NOT create session or return token yet - wait for TOTP verification
return api_response(
data={
"requires_totp": True,
},
message="TOTP code required. Please enter your 6-digit code from your authenticator app.",
)
# TOTP is NOT enabled - proceed with normal login flow
# Create session with appropriate duration based on remember_me preference
duration = 2592000 if data.get("remember_me") else 86400 # 30 days vs 1 day duration = 2592000 if data.get("remember_me") else 86400 # 30 days vs 1 day
user_session = AuthService.create_session(user, duration_seconds=duration) user_session = AuthService.create_session(user, duration_seconds=duration)
@@ -210,3 +235,295 @@ def revoke_session(session_id):
return api_response( return api_response(
message="Session revoked successfully", message="Session revoked successfully",
) )
@api_v1_bp.route("/auth/totp/enroll", methods=["POST"])
@login_required
def enroll_totp():
"""
Initiate TOTP enrollment for the current user.
Returns:
201: TOTP enrollment initiated with secret, provisioning_uri, qr_code, and backup_codes
401: Not authenticated
409: TOTP already enabled
"""
try:
# Initiate TOTP enrollment
result = AuthService.enroll_totp(g.current_user)
return api_response(
data={
"secret": result["secret"],
"provisioning_uri": result["provisioning_uri"],
"qr_code": result["qr_code"],
"backup_codes": result["backup_codes"],
},
message="TOTP enrollment initiated. Please verify with your authenticator app.",
status=201,
)
except ConflictError as e:
return api_response(
success=False,
message=e.message,
status=e.status_code,
error_type=e.error_type,
)
@api_v1_bp.route("/auth/totp/verify-enrollment", methods=["POST"])
@login_required
def verify_totp_enrollment():
"""
Complete TOTP enrollment by verifying the first TOTP code.
Request body:
code: 6-digit TOTP code from authenticator app
Returns:
200: TOTP enrollment completed successfully
400: Validation error
401: Not authenticated
401: Invalid TOTP code
"""
try:
# Validate request data
schema = TOTPVerifyEnrollmentSchema()
data = schema.load(request.json)
# Verify TOTP enrollment
AuthService.verify_totp_enrollment(g.current_user, data["code"])
return api_response(
message="TOTP enrollment completed successfully",
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
except InvalidCredentialsError as e:
return api_response(
success=False,
message=e.message,
status=e.status_code,
error_type=e.error_type,
)
@api_v1_bp.route("/auth/totp/verify", methods=["POST"])
def verify_totp():
"""
Verify TOTP code during login.
Request body:
code: 6-digit TOTP code or backup code
is_backup_code: True if code is a backup code, False if TOTP code (default: False)
Returns:
200: TOTP code verified successfully with session token
400: Validation error
401: Invalid TOTP code or session not found
"""
try:
# Validate request data
schema = TOTPVerifySchema()
data = schema.load(request.json)
# Get user from temporary session (stored in Flask session by login endpoint)
user_id = session.get("totp_pending_user_id")
if not user_id:
return api_response(
success=False,
message="No pending TOTP verification. Please login first.",
status=401,
error_type="AUTHENTICATION_ERROR",
)
# Get user from database
from app.models.user import User
user = User.query.get(user_id)
if not user:
return api_response(
success=False,
message="User not found",
status=401,
error_type="AUTHENTICATION_ERROR",
)
# Verify TOTP code
AuthService.authenticate_with_totp(
user, data["code"], data.get("is_backup_code", False)
)
# Create full session
user_session = AuthService.create_session(user)
# Clear temporary session
session.pop("totp_pending_user_id", None)
return api_response(
data={
"user": user.to_dict(),
"token": user_session.token,
"expires_at": user_session.expires_at.isoformat() + "Z"
if user_session.expires_at.isoformat()[-1] != "Z"
else user_session.expires_at.isoformat(),
},
message="TOTP verification successful",
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
except InvalidCredentialsError as e:
return api_response(
success=False,
message=e.message,
status=e.status_code,
error_type=e.error_type,
)
@api_v1_bp.route("/auth/totp/disable", methods=["DELETE"])
@login_required
def disable_totp():
"""
Disable TOTP for the current user.
Request body:
password: User's current password for verification
Returns:
200: TOTP disabled successfully
400: Validation error
401: Not authenticated or invalid password
401: TOTP not enabled
"""
try:
# Validate request data
schema = TOTPDisableSchema()
data = schema.load(request.json)
# Disable TOTP
AuthService.disable_totp(g.current_user, data["password"])
return api_response(
message="TOTP disabled successfully",
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
except InvalidCredentialsError as e:
return api_response(
success=False,
message=e.message,
status=e.status_code,
error_type=e.error_type,
)
@api_v1_bp.route("/auth/totp/status", methods=["GET"])
@login_required
def get_totp_status():
"""
Get TOTP status for the current user.
Returns:
200: TOTP status with totp_enabled, verified_at, and backup_codes_remaining
401: Not authenticated
"""
user = g.current_user
# Check if TOTP is enabled
totp_enabled = user.has_totp_enabled()
# Get TOTP method to check backup codes remaining
backup_codes_remaining = 0
verified_at = None
if totp_enabled:
totp_method = user.get_totp_method()
if totp_method and totp_method.provider_data:
backup_codes = totp_method.provider_data.get("backup_codes", [])
backup_codes_remaining = len(backup_codes)
if totp_method and totp_method.totp_verified_at:
verified_at = totp_method.totp_verified_at.isoformat() + "Z" if totp_method.totp_verified_at.isoformat()[-1] != "Z" else totp_method.totp_verified_at.isoformat()
return api_response(
data={
"totp_enabled": totp_enabled,
"verified_at": verified_at,
"backup_codes_remaining": backup_codes_remaining,
},
message="TOTP status retrieved successfully",
)
@api_v1_bp.route("/auth/totp/regenerate-backup-codes", methods=["POST"])
@login_required
def regenerate_totp_backup_codes():
"""
Generate new backup codes for TOTP.
Request body:
password: User's current password for verification
Returns:
200: New backup codes generated successfully
400: Validation error
401: Not authenticated or invalid password
401: TOTP not enabled
"""
try:
# Validate request data
schema = TOTPRegenerateBackupCodesSchema()
data = schema.load(request.json)
# Regenerate backup codes
backup_codes = AuthService.regenerate_totp_backup_codes(
g.current_user, data["password"]
)
return api_response(
data={
"backup_codes": backup_codes,
},
message="Backup codes regenerated successfully",
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
except InvalidCredentialsError as e:
return api_response(
success=False,
message=e.message,
status=e.status_code,
error_type=e.error_type,
)
+2
View File
@@ -26,6 +26,7 @@ def setup_cors(app):
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS" response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID" response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID"
response.headers["Access-Control-Max-Age"] = "3600" response.headers["Access-Control-Max-Age"] = "3600"
response.headers["Cache-Control"] = "no-cache, no-store"
return response return response
elif origin and origin in cors_origins: elif origin and origin in cors_origins:
response = make_response("", 204) response = make_response("", 204)
@@ -34,6 +35,7 @@ def setup_cors(app):
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID" response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID"
response.headers["Access-Control-Allow-Credentials"] = "true" response.headers["Access-Control-Allow-Credentials"] = "true"
response.headers["Access-Control-Max-Age"] = "3600" response.headers["Access-Control-Max-Age"] = "3600"
response.headers["Cache-Control"] = "no-cache, no-store"
return response return response
@app.after_request @app.after_request
+11
View File
@@ -51,4 +51,15 @@ class SecurityHeadersMiddleware:
"geolocation=(), microphone=(), camera=()" "geolocation=(), microphone=(), camera=()"
) )
# Cache-Control: Allow OIDC endpoints to set their own Cache-Control
# Only set no-cache for API responses that haven't set their own cache headers
if "Cache-Control" not in response.headers:
# Check if this is a JSON API response (shouldn't be cached)
content_type = response.headers.get("Content-Type", "")
if "application/json" in content_type:
response.headers["Cache-Control"] = "no-cache, no-store"
elif "text/html" not in content_type:
# For non-HTML responses, add Pragma for HTTP/1.0 compatibility
response.headers["Pragma"] = "no-cache"
return response return response
+12 -1
View File
@@ -19,6 +19,11 @@ class AuthenticationMethod(BaseModel):
provider_user_id = db.Column(db.String(255), nullable=True) provider_user_id = db.Column(db.String(255), nullable=True)
provider_data = db.Column(db.JSON, nullable=True) provider_data = db.Column(db.JSON, nullable=True)
# # For TOTP authentication
# totp_secret = db.Column(db.String(32), nullable=True)
# totp_backup_codes = db.Column(db.JSON, nullable=True)
# totp_verified_at = db.Column(db.DateTime, nullable=True)
# Metadata # Metadata
is_primary = db.Column(db.Boolean, default=False, nullable=False) is_primary = db.Column(db.Boolean, default=False, nullable=False)
verified = db.Column(db.Boolean, default=False, nullable=False) verified = db.Column(db.Boolean, default=False, nullable=False)
@@ -51,9 +56,15 @@ class AuthenticationMethod(BaseModel):
AuthMethodType.MICROSOFT, AuthMethodType.MICROSOFT,
] ]
def is_totp(self):
"""Check if this is a TOTP authentication method."""
return self.method_type == AuthMethodType.TOTP
def to_dict(self, exclude=None): def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields.""" """Convert to dictionary, excluding sensitive fields."""
exclude = exclude or [] exclude = exclude or []
# Always exclude password hash # Always exclude password hash and TOTP secrets
exclude.append("password_hash") exclude.append("password_hash")
exclude.append("totp_secret")
exclude.append("totp_backup_codes")
return super().to_dict(exclude=exclude) return super().to_dict(exclude=exclude)
+5 -5
View File
@@ -1,6 +1,6 @@
"""Base model with common fields and functionality.""" """Base model with common fields and functionality."""
import uuid import uuid
from datetime import datetime from datetime import datetime, timezone
from app.extensions import db from app.extensions import db
@@ -16,9 +16,9 @@ class BaseModel(db.Model):
unique=True, unique=True,
nullable=False, nullable=False,
) )
created_at = db.Column(db.DateTime, nullable=False, default=datetime.utcnow) created_at = db.Column(db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc))
updated_at = db.Column( updated_at = db.Column(
db.DateTime, nullable=False, default=datetime.utcnow, onupdate=datetime.utcnow db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)
) )
deleted_at = db.Column(db.DateTime, nullable=True) deleted_at = db.Column(db.DateTime, nullable=True)
@@ -36,7 +36,7 @@ class BaseModel(db.Model):
soft: If True, performs soft delete. If False, hard delete. soft: If True, performs soft delete. If False, hard delete.
""" """
if soft: if soft:
self.deleted_at = datetime.utcnow() self.deleted_at = datetime.now(timezone.utc)
db.session.commit() db.session.commit()
else: else:
db.session.delete(self) db.session.delete(self)
@@ -47,7 +47,7 @@ class BaseModel(db.Model):
for key, value in kwargs.items(): for key, value in kwargs.items():
if hasattr(self, key): if hasattr(self, key):
setattr(self, key, value) setattr(self, key, value)
self.updated_at = datetime.utcnow() self.updated_at = datetime.now(timezone.utc)
db.session.commit() db.session.commit()
return self return self
+9 -4
View File
@@ -1,5 +1,5 @@
"""OIDC Authorization Code model for auth code flow.""" """OIDC Authorization Code model for auth code flow."""
from datetime import datetime, timedelta from datetime import datetime, timedelta, timezone
from app.extensions import db from app.extensions import db
from app.models.base import BaseModel from app.models.base import BaseModel
@@ -49,7 +49,12 @@ class OIDCAuthCode(BaseModel):
def is_expired(self): def is_expired(self):
"""Check if the authorization code has expired.""" """Check if the authorization code has expired."""
return datetime.utcnow() > self.expires_at # Handle both timezone-aware and timezone-naive expires_at values
expires_at = self.expires_at
if expires_at.tzinfo is None:
# Make naive datetime timezone-aware (UTC)
expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at
def is_valid(self): def is_valid(self):
"""Check if the authorization code is valid for use.""" """Check if the authorization code is valid for use."""
@@ -58,7 +63,7 @@ class OIDCAuthCode(BaseModel):
def mark_as_used(self): def mark_as_used(self):
"""Mark the authorization code as used.""" """Mark the authorization code as used."""
self.is_used = True self.is_used = True
self.used_at = datetime.utcnow() self.used_at = datetime.now(timezone.utc)
db.session.commit() db.session.commit()
@classmethod @classmethod
@@ -90,7 +95,7 @@ class OIDCAuthCode(BaseModel):
scope=scope, scope=scope,
nonce=nonce, nonce=nonce,
code_verifier=code_verifier, code_verifier=code_verifier,
expires_at=datetime.utcnow() + timedelta(seconds=lifetime_seconds), expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds),
ip_address=ip_address, ip_address=ip_address,
user_agent=user_agent, user_agent=user_agent,
) )
+9 -5
View File
@@ -1,5 +1,5 @@
"""OIDC Refresh Token model for token rotation.""" """OIDC Refresh Token model for token rotation."""
from datetime import datetime from datetime import datetime, timezone
from app.extensions import db from app.extensions import db
from app.models.base import BaseModel from app.models.base import BaseModel
@@ -58,7 +58,11 @@ class OIDCRefreshToken(BaseModel):
def is_expired(self): def is_expired(self):
"""Check if the refresh token has expired.""" """Check if the refresh token has expired."""
return datetime.utcnow() > self.expires_at # Handle both timezone-aware and timezone-naive expires_at values
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at
def is_revoked(self): def is_revoked(self):
"""Check if the refresh token has been revoked.""" """Check if the refresh token has been revoked."""
@@ -74,7 +78,7 @@ class OIDCRefreshToken(BaseModel):
Args: Args:
reason: Optional reason for revocation reason: Optional reason for revocation
""" """
self.revoked_at = datetime.utcnow() self.revoked_at = datetime.now(timezone.utc)
self.revoked_reason = reason self.revoked_reason = reason
db.session.commit() db.session.commit()
@@ -93,7 +97,7 @@ class OIDCRefreshToken(BaseModel):
self.rotation_count += 1 self.rotation_count += 1
# Extend expiration on rotation # Extend expiration on rotation
from datetime import timedelta from datetime import timedelta
self.expires_at = datetime.utcnow() + timedelta(days=30) self.expires_at = datetime.now(timezone.utc) + timedelta(days=30)
db.session.commit() db.session.commit()
return self return self
@@ -123,7 +127,7 @@ class OIDCRefreshToken(BaseModel):
token_hash=token_hash, token_hash=token_hash,
scope=scope, scope=scope,
access_token_id=access_token_id, access_token_id=access_token_id,
expires_at=datetime.utcnow() + timedelta(seconds=lifetime_seconds), expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds),
ip_address=ip_address, ip_address=ip_address,
user_agent=user_agent, user_agent=user_agent,
) )
+4 -4
View File
@@ -1,5 +1,5 @@
"""OIDC Session model for OIDC session tracking.""" """OIDC Session model for OIDC session tracking."""
from datetime import datetime from datetime import datetime, timezone
from app.extensions import db from app.extensions import db
from app.models.base import BaseModel from app.models.base import BaseModel
@@ -49,7 +49,7 @@ class OIDCSession(BaseModel):
def is_expired(self): def is_expired(self):
"""Check if the OIDC session has expired.""" """Check if the OIDC session has expired."""
return datetime.utcnow() > self.expires_at return datetime.now(timezone.utc) > self.expires_at
def is_authenticated(self): def is_authenticated(self):
"""Check if the user has been authenticated in this session.""" """Check if the user has been authenticated in this session."""
@@ -57,7 +57,7 @@ class OIDCSession(BaseModel):
def mark_authenticated(self): def mark_authenticated(self):
"""Mark the session as authenticated.""" """Mark the session as authenticated."""
self.authenticated_at = datetime.utcnow() self.authenticated_at = datetime.now(timezone.utc)
db.session.commit() db.session.commit()
def validate_nonce(self, expected_nonce): def validate_nonce(self, expected_nonce):
@@ -126,7 +126,7 @@ class OIDCSession(BaseModel):
nonce=nonce, nonce=nonce,
code_challenge=code_challenge, code_challenge=code_challenge,
code_challenge_method=code_challenge_method, code_challenge_method=code_challenge_method,
expires_at=datetime.utcnow() + timedelta(seconds=lifetime_seconds), expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds),
) )
db.session.add(session) db.session.add(session)
db.session.commit() db.session.commit()
+7 -3
View File
@@ -1,6 +1,6 @@
"""OIDC Token Metadata model for token revocation tracking.""" """OIDC Token Metadata model for token revocation tracking."""
import uuid import uuid
from datetime import datetime from datetime import datetime, timezone
from app.extensions import db from app.extensions import db
from app.models.base import BaseModel from app.models.base import BaseModel
@@ -50,7 +50,11 @@ class OIDCTokenMetadata(BaseModel):
def is_expired(self): def is_expired(self):
"""Check if the token has expired.""" """Check if the token has expired."""
return datetime.utcnow() > self.expires_at # Handle both timezone-aware and timezone-naive expires_at values
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at
def is_revoked(self): def is_revoked(self):
"""Check if the token has been revoked.""" """Check if the token has been revoked."""
@@ -66,7 +70,7 @@ class OIDCTokenMetadata(BaseModel):
Args: Args:
reason: Optional reason for revocation reason: Optional reason for revocation
""" """
self.revoked_at = datetime.utcnow() self.revoked_at = datetime.now(timezone.utc)
self.revoked_reason = reason self.revoked_reason = reason
db.session.commit() db.session.commit()
+6 -6
View File
@@ -1,5 +1,5 @@
"""Session model.""" """Session model."""
from datetime import datetime, timedelta from datetime import datetime, timedelta, timezone
from app.extensions import db from app.extensions import db
from app.models.base import BaseModel from app.models.base import BaseModel
from app.utils.constants import SessionStatus from app.utils.constants import SessionStatus
@@ -34,7 +34,7 @@ class Session(BaseModel):
def is_active(self): def is_active(self):
"""Check if session is currently active.""" """Check if session is currently active."""
now = datetime.utcnow() now = datetime.now(timezone.utc)
return ( return (
self.status == SessionStatus.ACTIVE self.status == SessionStatus.ACTIVE
and self.expires_at > now and self.expires_at > now
@@ -43,7 +43,7 @@ class Session(BaseModel):
def is_expired(self): def is_expired(self):
"""Check if session has expired.""" """Check if session has expired."""
return datetime.utcnow() > self.expires_at return datetime.now(timezone.utc) > self.expires_at
def refresh(self, duration_seconds=86400): def refresh(self, duration_seconds=86400):
""" """
@@ -52,8 +52,8 @@ class Session(BaseModel):
Args: Args:
duration_seconds: New session duration in seconds duration_seconds: New session duration in seconds
""" """
self.expires_at = datetime.utcnow() + timedelta(seconds=duration_seconds) self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=duration_seconds)
self.last_activity_at = datetime.utcnow() self.last_activity_at = datetime.now(timezone.utc)
db.session.commit() db.session.commit()
def revoke(self, reason=None): def revoke(self, reason=None):
@@ -64,7 +64,7 @@ class Session(BaseModel):
reason: Optional reason for revocation reason: Optional reason for revocation
""" """
self.status = SessionStatus.REVOKED self.status = SessionStatus.REVOKED
self.revoked_at = datetime.utcnow() self.revoked_at = datetime.now(timezone.utc)
if reason: if reason:
self.revoked_reason = reason self.revoked_reason = reason
db.session.commit() db.session.commit()
+32
View File
@@ -59,3 +59,35 @@ class User(BaseModel):
def get_organizations(self): def get_organizations(self):
"""Get all organizations the user is a member of.""" """Get all organizations the user is a member of."""
return [membership.organization for membership in self.organization_memberships] return [membership.organization for membership in self.organization_memberships]
def has_totp_enabled(self) -> bool:
"""Check if user has TOTP enabled and verified.
Returns:
True if user has a verified TOTP authentication method, False otherwise.
"""
from app.models.authentication_method import AuthenticationMethod
from app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id,
method_type=AuthMethodType.TOTP,
verified=True,
deleted_at=None,
).first()
is not None
)
def get_totp_method(self):
"""Get user's TOTP authentication method.
Returns:
The AuthenticationMethod instance for TOTP or None if not found.
"""
from app.models.authentication_method import AuthenticationMethod
from app.utils.constants import AuthMethodType
return AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.TOTP, deleted_at=None
).first()
+31
View File
@@ -55,3 +55,34 @@ class ResetPasswordSchema(Schema):
"""Validate that passwords match.""" """Validate that passwords match."""
if data.get("password") != data.get("password_confirm"): if data.get("password") != data.get("password_confirm"):
raise ValidationError("Passwords do not match", field_name="password_confirm") raise ValidationError("Passwords do not match", field_name="password_confirm")
class TOTPVerifyEnrollmentSchema(Schema):
"""Schema for TOTP enrollment verification."""
code = fields.Str(
required=True,
validate=validate.Regexp(
r"^\d{6}$",
error="Code must be a 6-digit number",
),
)
class TOTPVerifySchema(Schema):
"""Schema for TOTP code verification during login."""
code = fields.Str(required=True)
is_backup_code = fields.Bool(missing=False)
class TOTPDisableSchema(Schema):
"""Schema for disabling TOTP."""
password = fields.Str(required=True, validate=validate.Length(min=1))
class TOTPRegenerateBackupCodesSchema(Schema):
"""Schema for regenerating backup codes."""
password = fields.Str(required=True, validate=validate.Length(min=1))
+313
View File
@@ -11,6 +11,7 @@ from app.utils.constants import AuthMethodType, SessionStatus, UserStatus, Audit
from app.exceptions.auth_exceptions import InvalidCredentialsError, AccountSuspendedError, AccountInactiveError from app.exceptions.auth_exceptions import InvalidCredentialsError, AccountSuspendedError, AccountInactiveError
from app.exceptions.validation_exceptions import EmailAlreadyExistsError from app.exceptions.validation_exceptions import EmailAlreadyExistsError
from app.services.audit_service import AuditService from app.services.audit_service import AuditService
from app.services.totp_service import TOTPService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -234,3 +235,315 @@ class AuthService:
resource_id=session.id, resource_id=session.id,
description=f"Session revoked: {reason or 'User logout'}", description=f"Session revoked: {reason or 'User logout'}",
) )
@staticmethod
def enroll_totp(user: User) -> dict:
"""
Initiate TOTP enrollment for a user.
Args:
user: User instance
Returns:
Dictionary containing:
- secret: TOTP secret (base32 encoded)
- provisioning_uri: otpauth:// URI for QR code
- qr_code: Base64 encoded QR code as data URI
- backup_codes: List of plain text backup codes
Raises:
ConflictError: If user already has TOTP enabled
"""
from app.exceptions.validation_exceptions import ConflictError
# Check if user already has TOTP enabled
if user.has_totp_enabled():
raise ConflictError("TOTP is already enabled for this account")
# Generate TOTP secret
secret = TOTPService.generate_secret()
# Generate provisioning URI
provisioning_uri = TOTPService.generate_provisioning_uri(
user_email=user.email,
secret=secret,
issuer="Gatehouse",
)
# Generate QR code data URI
qr_code = TOTPService.generate_qr_code_data_uri(provisioning_uri)
# Generate backup codes
backup_codes, hashed_backup_codes = TOTPService.generate_backup_codes()
# Create unverified TOTP authentication method
auth_method = AuthenticationMethod(
user_id=user.id,
method_type=AuthMethodType.TOTP,
verified=False,
is_primary=False,
)
auth_method.save()
# Store TOTP data in provider_data (since totp_secret field is commented out)
auth_method.provider_data = {
"secret": secret,
"backup_codes": hashed_backup_codes,
}
db.session.commit()
# Log TOTP enrollment initiation
AuditService.log_action(
action=AuditAction.TOTP_ENROLL_INITIATED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="TOTP enrollment initiated",
)
return {
"secret": secret,
"provisioning_uri": provisioning_uri,
"qr_code": qr_code,
"backup_codes": backup_codes,
}
@staticmethod
def verify_totp_enrollment(user: User, code: str) -> bool:
"""
Complete TOTP enrollment by verifying the first TOTP code.
Args:
user: User instance
code: 6-digit TOTP code from authenticator app
Returns:
True if verification successful
Raises:
InvalidCredentialsError: If code is invalid or TOTP method not found
"""
# Get user's TOTP authentication method
auth_method = user.get_totp_method()
if not auth_method:
raise InvalidCredentialsError("TOTP enrollment not found")
# Get secret from provider_data
secret = auth_method.provider_data.get("secret") if auth_method.provider_data else None
if not secret:
raise InvalidCredentialsError("TOTP secret not found")
# Verify the code
if not TOTPService.verify_code(secret, code):
raise InvalidCredentialsError("Invalid TOTP code")
# Mark TOTP as verified
auth_method.verified = True
auth_method.totp_verified_at = datetime.utcnow()
db.session.commit()
# Log TOTP enrollment completion
AuditService.log_action(
action=AuditAction.TOTP_ENROLL_COMPLETED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="TOTP enrollment completed",
)
return True
@staticmethod
def disable_totp(user: User, password: str) -> bool:
"""
Disable TOTP for a user.
Args:
user: User instance
password: User's current password for verification
Returns:
True if TOTP disabled successfully
Raises:
InvalidCredentialsError: If password is invalid or TOTP method not found
"""
# Verify user's password
auth_method = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.PASSWORD,
deleted_at=None,
).first()
if not auth_method or not auth_method.password_hash:
raise InvalidCredentialsError("No password authentication method found")
if not bcrypt.check_password_hash(auth_method.password_hash, password):
raise InvalidCredentialsError("Invalid password")
# Get user's TOTP authentication method
totp_method = user.get_totp_method()
if not totp_method:
raise InvalidCredentialsError("TOTP is not enabled for this account")
# Soft-delete the TOTP authentication method
totp_method.delete(soft=True)
# Log TOTP disabled
AuditService.log_action(
action=AuditAction.TOTP_DISABLED,
user_id=user.id,
resource_type="authentication_method",
resource_id=totp_method.id,
description="TOTP disabled",
)
return True
@staticmethod
def authenticate_with_totp(user: User, code: str, is_backup_code: bool = False) -> bool:
"""
Verify TOTP code during login.
Args:
user: User instance
code: 6-digit TOTP code or backup code
is_backup_code: True if code is a backup code, False if TOTP code
Returns:
True if code is valid
Raises:
InvalidCredentialsError: If code is invalid or TOTP method not found
"""
# Get user's TOTP authentication method
auth_method = user.get_totp_method()
if not auth_method:
raise InvalidCredentialsError("TOTP is not enabled for this account")
if is_backup_code:
# Verify backup code
backup_codes = (
auth_method.provider_data.get("backup_codes")
if auth_method.provider_data
else []
)
is_valid, remaining_codes = TOTPService.verify_backup_code(backup_codes, code)
if is_valid:
# Update remaining backup codes
auth_method.provider_data = {
"secret": auth_method.provider_data.get("secret"),
"backup_codes": remaining_codes,
}
auth_method.last_used_at = datetime.utcnow()
db.session.commit()
# Log backup code usage
AuditService.log_action(
action=AuditAction.TOTP_BACKUP_CODE_USED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="Backup code used for authentication",
)
else:
# Log failed verification
AuditService.log_action(
action=AuditAction.TOTP_VERIFY_FAILED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="Invalid backup code provided",
)
raise InvalidCredentialsError("Invalid backup code")
else:
# Verify TOTP code
secret = (
auth_method.provider_data.get("secret")
if auth_method.provider_data
else None
)
if not secret:
raise InvalidCredentialsError("TOTP secret not found")
is_valid = TOTPService.verify_code(secret, code)
if is_valid:
auth_method.last_used_at = datetime.utcnow()
db.session.commit()
# Log successful verification
AuditService.log_action(
action=AuditAction.TOTP_VERIFY_SUCCESS,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="TOTP code verified successfully",
)
else:
# Log failed verification
AuditService.log_action(
action=AuditAction.TOTP_VERIFY_FAILED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="Invalid TOTP code provided",
)
raise InvalidCredentialsError("Invalid TOTP code")
return True
@staticmethod
def regenerate_totp_backup_codes(user: User, password: str) -> list[str]:
"""
Generate new backup codes for TOTP.
Args:
user: User instance
password: User's current password for verification
Returns:
List of new plain text backup codes
Raises:
InvalidCredentialsError: If password is invalid or TOTP method not found
"""
# Verify user's password
auth_method = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.PASSWORD,
deleted_at=None,
).first()
if not auth_method or not auth_method.password_hash:
raise InvalidCredentialsError("No password authentication method found")
if not bcrypt.check_password_hash(auth_method.password_hash, password):
raise InvalidCredentialsError("Invalid password")
# Get user's TOTP authentication method
totp_method = user.get_totp_method()
if not totp_method:
raise InvalidCredentialsError("TOTP is not enabled for this account")
# Generate new backup codes
backup_codes, hashed_backup_codes = TOTPService.generate_backup_codes()
# Update the authentication method with new backup codes
totp_method.provider_data = {
"secret": totp_method.provider_data.get("secret"),
"backup_codes": hashed_backup_codes,
}
db.session.commit()
# Log backup codes regeneration
AuditService.log_action(
action=AuditAction.TOTP_BACKUP_CODES_REGENERATED,
user_id=user.id,
resource_type="authentication_method",
resource_id=totp_method.id,
description="TOTP backup codes regenerated",
)
return backup_codes
+212 -5
View File
@@ -2,7 +2,7 @@
import logging import logging
import secrets import secrets
import hashlib import hashlib
from datetime import datetime, timedelta from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from flask import current_app, g from flask import current_app, g
@@ -14,6 +14,7 @@ from app.models import (
User, OIDCClient, OIDCAuthCode, OIDCRefreshToken, User, OIDCClient, OIDCAuthCode, OIDCRefreshToken,
OIDCSession, OIDCTokenMetadata OIDCSession, OIDCTokenMetadata
) )
from app.models.organization_member import OrganizationMember
from app.exceptions.validation_exceptions import ( from app.exceptions.validation_exceptions import (
ValidationError, NotFoundError, BadRequestError ValidationError, NotFoundError, BadRequestError
) )
@@ -121,6 +122,14 @@ class OIDCService:
ValidationError: If parameters are invalid ValidationError: If parameters are invalid
NotFoundError: If client not found NotFoundError: If client not found
""" """
logger.debug("[OIDC SERVICE] ===========================================")
logger.debug("[OIDC SERVICE] generate_authorization_code called")
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] client_id=%s, user_id=%s", client_id, user_id)
logger.debug("[OIDC SERVICE] redirect_uri=%s", redirect_uri)
logger.debug("[OIDC SERVICE] scope=%s", scope)
logger.debug("[OIDC SERVICE] state=%s, nonce=%s", state, nonce)
# Validate client exists and is active # Validate client exists and is active
client = OIDCClient.query.filter_by(client_id=client_id).first() client = OIDCClient.query.filter_by(client_id=client_id).first()
@@ -152,14 +161,19 @@ class OIDCService:
raise ValidationError("Invalid scopes") raise ValidationError("Invalid scopes")
# Generate authorization code # Generate authorization code
logger.debug("[OIDC SERVICE] Generating authorization code...")
logger.debug("[OIDC SERVICE] Current UTC time before code generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
code = cls._generate_code() code = cls._generate_code()
code_hash = cls._hash_value(code) code_hash = cls._hash_value(code)
logger.debug("[OIDC SERVICE] Code generated: %s...", code[:20] if code else None)
# Development-only debug logging for PKCE in code creation # Development-only debug logging for PKCE in code creation
if current_app.config.get('ENV') == 'development': if current_app.config.get('ENV') == 'development':
logger.debug(f"[OIDC] Generate auth code - PKCE: code_challenge={code_challenge is not None}, code_challenge_method={code_challenge_method}") logger.debug(f"[OIDC] Generate auth code - PKCE: code_challenge={code_challenge is not None}, code_challenge_method={code_challenge_method}")
# Create auth code record # Create auth code record
logger.debug("[OIDC SERVICE] Creating auth code record with lifetime_seconds=600 (10 minutes)")
logger.debug("[OIDC SERVICE] Current UTC time before creating auth code: %s", datetime.now(timezone.utc).isoformat() + "Z")
auth_code = OIDCAuthCode.create_code( auth_code = OIDCAuthCode.create_code(
client_id=client.id, client_id=client.id,
user_id=user_id, user_id=user_id,
@@ -172,6 +186,9 @@ class OIDCService:
user_agent=user_agent, user_agent=user_agent,
lifetime_seconds=600, # 10 minutes lifetime_seconds=600, # 10 minutes
) )
logger.debug("[OIDC SERVICE] Auth code created successfully")
logger.debug("[OIDC SERVICE] Auth code expires_at (UTC): %s", auth_code.expires_at.isoformat() + "Z")
logger.debug("[OIDC SERVICE] Current UTC time after creating auth code: %s", datetime.now(timezone.utc).isoformat() + "Z")
# Log authorization event # Log authorization event
OIDCAuditService.log_authorization_event( OIDCAuditService.log_authorization_event(
@@ -182,6 +199,9 @@ class OIDCService:
scope=valid_scopes, scope=valid_scopes,
) )
logger.debug("[OIDC SERVICE] generate_authorization_code completed successfully")
logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] ===========================================")
return code return code
@classmethod @classmethod
@@ -211,6 +231,12 @@ class OIDCService:
InvalidGrantError: If code is invalid InvalidGrantError: If code is invalid
ValidationError: If PKCE validation fails ValidationError: If PKCE validation fails
""" """
logger.debug("[OIDC SERVICE] ===========================================")
logger.debug("[OIDC SERVICE] validate_authorization_code called")
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] client_id=%s, redirect_uri=%s", client_id, redirect_uri)
logger.debug("[OIDC SERVICE] code_verifier provided: %s", bool(code_verifier))
# Get client # Get client
client = OIDCClient.query.filter_by(client_id=client_id).first() client = OIDCClient.query.filter_by(client_id=client_id).first()
@@ -223,6 +249,8 @@ class OIDCService:
raise InvalidGrantError("Invalid client") raise InvalidGrantError("Invalid client")
# Hash the provided code and find matching auth code # Hash the provided code and find matching auth code
logger.debug("[OIDC SERVICE] Looking up authorization code...")
logger.debug("[OIDC SERVICE] Current UTC time before code lookup: %s", datetime.now(timezone.utc).isoformat() + "Z")
code_hash = cls._hash_value(code) code_hash = cls._hash_value(code)
auth_code = OIDCAuthCode.query.filter_by( auth_code = OIDCAuthCode.query.filter_by(
code_hash=code_hash, code_hash=code_hash,
@@ -256,8 +284,18 @@ class OIDCService:
raise InvalidGrantError("Authorization code already used") raise InvalidGrantError("Authorization code already used")
# Check expiration # Check expiration
logger.debug("[OIDC SERVICE] Checking if authorization code is expired...")
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] Auth code expires_at (UTC): %s", auth_code.expires_at.isoformat() + "Z")
# Handle timezone-naive expires_at from database
expires_at = auth_code.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
logger.debug("[OIDC SERVICE] Time until expiration (seconds): %s", (expires_at - datetime.now(timezone.utc)).total_seconds())
if auth_code.is_expired(): if auth_code.is_expired():
logger.error(f"[OIDC] Validate auth code - Code expired: code_hash={code_hash[:20]}..., expires_at={auth_code.expires_at}") logger.error("[OIDC] Validate auth code - Code expired: code_hash=%s..., expires_at (UTC)=%s, current UTC time=%s",
code_hash[:20], auth_code.expires_at.isoformat() + "Z", datetime.now(timezone.utc).isoformat() + "Z")
OIDCAuditService.log_authorization_event( OIDCAuditService.log_authorization_event(
client_id=client_id, client_id=client_id,
user_id=auth_code.user_id, user_id=auth_code.user_id,
@@ -316,6 +354,9 @@ class OIDCService:
"nonce": auth_code.nonce, "nonce": auth_code.nonce,
} }
logger.debug("[OIDC SERVICE] validate_authorization_code completed successfully")
logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] ===========================================")
return claims, user return claims, user
@classmethod @classmethod
@@ -366,6 +407,12 @@ class OIDCService:
""" """
import hashlib import hashlib
logger.debug("[OIDC SERVICE] ===========================================")
logger.debug("[OIDC SERVICE] generate_tokens called")
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] client_id=%s, user_id=%s, scope=%s", client_id, user_id, scope)
logger.debug("[OIDC SERVICE] nonce=%s, auth_time=%s", nonce, auth_time)
# Get client # Get client
client = OIDCClient.query.filter_by(client_id=client_id).first() client = OIDCClient.query.filter_by(client_id=client_id).first()
@@ -377,6 +424,9 @@ class OIDCService:
raise InvalidClientError() raise InvalidClientError()
# Generate access token # Generate access token
logger.debug("[OIDC SERVICE] Generating access token...")
logger.debug("[OIDC SERVICE] Current UTC time before access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] Access token lifetime (seconds): %s", client.access_token_lifetime or 3600)
access_token_jti = OIDCTokenService._generate_jti() access_token_jti = OIDCTokenService._generate_jti()
access_token = OIDCTokenService.create_access_token( access_token = OIDCTokenService.create_access_token(
client_id=client_id, client_id=client_id,
@@ -384,8 +434,13 @@ class OIDCService:
scope=scope, scope=scope,
jti=access_token_jti, jti=access_token_jti,
) )
logger.debug("[OIDC SERVICE] Access token generated successfully")
logger.debug("[OIDC SERVICE] Current UTC time after access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
# Generate ID token # Generate ID token
logger.debug("[OIDC SERVICE] Generating ID token...")
logger.debug("[OIDC SERVICE] Current UTC time before ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] ID token lifetime (seconds): %s", client.id_token_lifetime or 3600)
id_token = OIDCTokenService.create_id_token( id_token = OIDCTokenService.create_id_token(
client_id=client_id, client_id=client_id,
user_id=user_id, user_id=user_id,
@@ -394,6 +449,8 @@ class OIDCService:
access_token=access_token, access_token=access_token,
auth_time=auth_time, auth_time=auth_time,
) )
logger.debug("[OIDC SERVICE] ID token generated successfully")
logger.debug("[OIDC SERVICE] Current UTC time after ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
# Generate or rotate refresh token # Generate or rotate refresh token
if "refresh_token" in (client.grant_types or []): if "refresh_token" in (client.grant_types or []):
@@ -445,22 +502,28 @@ class OIDCService:
client_db_id = client.id client_db_id = client.id
# Access token metadata # Access token metadata
logger.debug("[OIDC SERVICE] Creating access token metadata...")
access_token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=client.access_token_lifetime or 3600)
logger.debug("[OIDC SERVICE] Access token expires_at (UTC): %s", access_token_expires_at.isoformat() + "Z")
OIDCTokenMetadata.create_metadata( OIDCTokenMetadata.create_metadata(
client_id=client_db_id, client_id=client_db_id,
user_id=user_id, user_id=user_id,
token_type="access_token", token_type="access_token",
token_jti=access_token_jti, token_jti=access_token_jti,
expires_at=datetime.utcnow() + timedelta(seconds=client.access_token_lifetime or 3600), expires_at=access_token_expires_at,
) )
# ID token metadata (using access token JTI as reference) # ID token metadata (using access token JTI as reference)
logger.debug("[OIDC SERVICE] Creating ID token metadata...")
id_token_jti = OIDCTokenService._generate_jti() id_token_jti = OIDCTokenService._generate_jti()
id_token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=client.id_token_lifetime or 3600)
logger.debug("[OIDC SERVICE] ID token expires_at (UTC): %s", id_token_expires_at.isoformat() + "Z")
OIDCTokenMetadata.create_metadata( OIDCTokenMetadata.create_metadata(
client_id=client_db_id, client_id=client_db_id,
user_id=user_id, user_id=user_id,
token_type="id_token", token_type="id_token",
token_jti=id_token_jti, token_jti=id_token_jti,
expires_at=datetime.utcnow() + timedelta(seconds=client.id_token_lifetime or 3600), expires_at=id_token_expires_at,
) )
# Log token event # Log token event
@@ -483,6 +546,9 @@ class OIDCService:
if final_refresh_token: if final_refresh_token:
result["refresh_token"] = final_refresh_token result["refresh_token"] = final_refresh_token
logger.debug("[OIDC SERVICE] generate_tokens completed successfully")
logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] ===========================================")
return result return result
@classmethod @classmethod
@@ -511,6 +577,11 @@ class OIDCService:
""" """
import hashlib import hashlib
logger.debug("[OIDC SERVICE] ===========================================")
logger.debug("[OIDC SERVICE] refresh_access_token called")
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] client_id=%s, scope=%s", client_id, scope)
# Get client # Get client
client = OIDCClient.query.filter_by(client_id=client_id).first() client = OIDCClient.query.filter_by(client_id=client_id).first()
@@ -522,6 +593,8 @@ class OIDCService:
raise InvalidClientError() raise InvalidClientError()
# Find refresh token # Find refresh token
logger.debug("[OIDC SERVICE] Looking up refresh token...")
logger.debug("[OIDC SERVICE] Current UTC time before refresh token lookup: %s", datetime.now(timezone.utc).isoformat() + "Z")
token_hash = hashlib.sha256(refresh_token.encode()).hexdigest() token_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
refresh_token_obj = OIDCRefreshToken.query.filter_by( refresh_token_obj = OIDCRefreshToken.query.filter_by(
token_hash=token_hash, token_hash=token_hash,
@@ -542,6 +615,16 @@ class OIDCService:
raise InvalidGrantError("Invalid refresh token") raise InvalidGrantError("Invalid refresh token")
# Check if valid # Check if valid
logger.debug("[OIDC SERVICE] Checking if refresh token is valid...")
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
if refresh_token_obj:
logger.debug("[OIDC SERVICE] Refresh token expires_at (UTC): %s", refresh_token_obj.expires_at.isoformat() + "Z")
# Handle timezone-naive expires_at from database
rt_expires_at = refresh_token_obj.expires_at
if rt_expires_at.tzinfo is None:
rt_expires_at = rt_expires_at.replace(tzinfo=timezone.utc)
logger.debug("[OIDC SERVICE] Time until expiration (seconds): %s", (rt_expires_at - datetime.now(timezone.utc)).total_seconds())
if not refresh_token_obj.is_valid(): if not refresh_token_obj.is_valid():
OIDCAuditService.log_token_event( OIDCAuditService.log_token_event(
client_id=client_id, client_id=client_id,
@@ -563,6 +646,9 @@ class OIDCService:
granted_scope = scope or (refresh_token_obj.scope or []) granted_scope = scope or (refresh_token_obj.scope or [])
# Generate new access token # Generate new access token
logger.debug("[OIDC SERVICE] Generating new access token...")
logger.debug("[OIDC SERVICE] Current UTC time before access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] Access token lifetime (seconds): %s", client.access_token_lifetime or 3600)
access_token_jti = OIDCTokenService._generate_jti() access_token_jti = OIDCTokenService._generate_jti()
access_token = OIDCTokenService.create_access_token( access_token = OIDCTokenService.create_access_token(
client_id=client_id, client_id=client_id,
@@ -570,14 +656,21 @@ class OIDCService:
scope=granted_scope, scope=granted_scope,
jti=access_token_jti, jti=access_token_jti,
) )
logger.debug("[OIDC SERVICE] Access token generated successfully")
logger.debug("[OIDC SERVICE] Current UTC time after access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
# Generate new ID token # Generate new ID token
logger.debug("[OIDC SERVICE] Generating new ID token...")
logger.debug("[OIDC SERVICE] Current UTC time before ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] ID token lifetime (seconds): %s", client.id_token_lifetime or 3600)
id_token = OIDCTokenService.create_id_token( id_token = OIDCTokenService.create_id_token(
client_id=client_id, client_id=client_id,
user_id=refresh_token_obj.user_id, user_id=refresh_token_obj.user_id,
scope=granted_scope, scope=granted_scope,
access_token=access_token, access_token=access_token,
) )
logger.debug("[OIDC SERVICE] ID token generated successfully")
logger.debug("[OIDC SERVICE] Current UTC time after ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
# Rotate refresh token # Rotate refresh token
new_refresh, new_hash = OIDCTokenService.create_refresh_token( new_refresh, new_hash = OIDCTokenService.create_refresh_token(
@@ -590,12 +683,15 @@ class OIDCService:
refresh_token_obj.rotate(new_hash) refresh_token_obj.rotate(new_hash)
# Store new token metadata # Store new token metadata
logger.debug("[OIDC SERVICE] Creating access token metadata...")
access_token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=client.access_token_lifetime or 3600)
logger.debug("[OIDC SERVICE] Access token expires_at (UTC): %s", access_token_expires_at.isoformat() + "Z")
OIDCTokenMetadata.create_metadata( OIDCTokenMetadata.create_metadata(
client_id=client.id, client_id=client.id,
user_id=refresh_token_obj.user_id, user_id=refresh_token_obj.user_id,
token_type="access_token", token_type="access_token",
token_jti=access_token_jti, token_jti=access_token_jti,
expires_at=datetime.utcnow() + timedelta(seconds=client.access_token_lifetime or 3600), expires_at=access_token_expires_at,
) )
# Log refresh event # Log refresh event
@@ -615,6 +711,17 @@ class OIDCService:
"id_token": id_token, "id_token": id_token,
"refresh_token": new_refresh, "refresh_token": new_refresh,
} }
logger.debug("[OIDC SERVICE] refresh_access_token completed successfully")
logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
logger.debug("[OIDC SERVICE] ===========================================")
return {
"access_token": access_token,
"token_type": "Bearer",
"expires_in": client.access_token_lifetime or 3600,
"id_token": id_token,
"refresh_token": new_refresh,
}
@classmethod @classmethod
def validate_access_token(cls, token: str, client_id: str = None) -> Dict: def validate_access_token(cls, token: str, client_id: str = None) -> Dict:
@@ -630,10 +737,23 @@ class OIDCService:
Raises: Raises:
InvalidTokenError: If token is invalid InvalidTokenError: If token is invalid
""" """
logger.debug("[OIDC SERVICE] ===========================================")
logger.debug("[OIDC SERVICE] validate_access_token() called")
logger.debug("[OIDC SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token)
logger.debug("[OIDC SERVICE] Token length: %d", len(token))
logger.debug("[OIDC SERVICE] Client ID: %s", client_id)
try: try:
logger.debug("[OIDC SERVICE] Calling OIDCTokenService.validate_access_token()...")
claims = OIDCTokenService.validate_access_token(token, client_id) claims = OIDCTokenService.validate_access_token(token, client_id)
logger.debug("[OIDC SERVICE] Token validation successful")
logger.debug("[OIDC SERVICE] Token claims: %s", claims)
logger.debug("[OIDC SERVICE] ===========================================")
return claims return claims
except Exception as e: except Exception as e:
logger.error("[OIDC SERVICE] Token validation failed: %s: %s", type(e).__name__, str(e))
import traceback
logger.error("[OIDC SERVICE] Traceback: %s", traceback.format_exc())
OIDCAuditService.log_event( OIDCAuditService.log_event(
event_type="token_validation", event_type="token_validation",
client_id=client_id, client_id=client_id,
@@ -770,29 +890,67 @@ class OIDCService:
Returns: Returns:
User information dictionary User information dictionary
""" """
logger.debug("[OIDC SERVICE] ===========================================")
logger.debug("[OIDC SERVICE] get_userinfo() called")
logger.debug("[OIDC SERVICE] Access token (first 50 chars): %s...", access_token[:50] if len(access_token) > 50 else access_token)
logger.debug("[OIDC SERVICE] Access token length: %d", len(access_token))
# Validate access token
logger.debug("[OIDC SERVICE] Validating access token...")
claims = cls.validate_access_token(access_token) claims = cls.validate_access_token(access_token)
logger.debug("[OIDC SERVICE] Access token validated successfully")
logger.debug("[OIDC SERVICE] Token claims: %s", claims)
user_id = claims.get("sub") user_id = claims.get("sub")
logger.debug("[OIDC SERVICE] User ID from token: %s", user_id)
logger.debug("[OIDC SERVICE] Querying user from database...")
user = User.query.get(user_id) user = User.query.get(user_id)
logger.debug("[OIDC SERVICE] User query result: %s", user)
if not user: if not user:
logger.error("[OIDC SERVICE] User not found in database: user_id=%s", user_id)
raise NotFoundError("User not found") raise NotFoundError("User not found")
logger.debug("[OIDC SERVICE] User found: user_id=%s, email=%s, full_name=%s", user.id, user.email, user.full_name)
# Get scopes from token # Get scopes from token
scope_str = claims.get("scope", "") scope_str = claims.get("scope", "")
scopes = scope_str.split() if scope_str else [] scopes = scope_str.split() if scope_str else []
logger.debug("[OIDC SERVICE] Scope string from token: '%s'", scope_str)
logger.debug("[OIDC SERVICE] Parsed scopes: %s", scopes)
userinfo = {"sub": user_id} userinfo = {"sub": user_id}
logger.debug("[OIDC SERVICE] Initial userinfo: %s", userinfo)
# Add claims based on scope # Add claims based on scope
if "profile" in scopes and user.full_name: if "profile" in scopes and user.full_name:
logger.debug("[OIDC SERVICE] Found 'profile' in scope, adding name claim")
userinfo["name"] = user.full_name userinfo["name"] = user.full_name
logger.debug("[OIDC SERVICE] Added name: %s", user.full_name)
else:
logger.debug("[OIDC SERVICE] 'profile' not in scope or user.full_name is None: profile_in_scope=%s, full_name=%s", "profile" in scopes, user.full_name)
if "email" in scopes: if "email" in scopes:
logger.debug("[OIDC SERVICE] Found 'email' in scope, adding email claims")
userinfo["email"] = user.email userinfo["email"] = user.email
userinfo["email_verified"] = user.email_verified userinfo["email_verified"] = user.email_verified
logger.debug("[OIDC SERVICE] Added email: %s, email_verified: %s", user.email, user.email_verified)
else:
logger.debug("[OIDC SERVICE] 'email' not in scope")
if "roles" in scopes:
logger.debug("[OIDC SERVICE] Found 'roles' in scope, adding roles claim")
user_roles = cls._get_user_roles(user)
userinfo["roles"] = user_roles
logger.debug("[OIDC SERVICE] Added roles: %s", user_roles)
else:
logger.debug("[OIDC SERVICE] 'roles' not in scope")
logger.debug("[OIDC SERVICE] Final userinfo: %s", userinfo)
# Log userinfo access # Log userinfo access
logger.debug("[OIDC SERVICE] Logging userinfo access event...")
OIDCAuditService.log_userinfo_event( OIDCAuditService.log_userinfo_event(
access_token=access_token, access_token=access_token,
user_id=user_id, user_id=user_id,
@@ -800,5 +958,54 @@ class OIDCService:
success=True, success=True,
scopes_claimed=scopes, scopes_claimed=scopes,
) )
logger.debug("[OIDC SERVICE] Userinfo access event logged")
logger.debug("[OIDC SERVICE] get_userinfo() completed successfully")
logger.debug("[OIDC SERVICE] ===========================================")
return userinfo return userinfo
@staticmethod
def _get_user_roles(user: User) -> list:
"""Get user's organization roles.
Args:
user: User instance
Returns:
List of role objects with organization_id and role
"""
logger.debug("[OIDC SERVICE] _get_user_roles() called")
logger.debug("[OIDC SERVICE] User: %s", user)
roles = []
if not user:
logger.debug("[OIDC SERVICE] User is None, returning empty roles list")
return roles
logger.debug("[OIDC SERVICE] User ID: %s", user.id)
logger.debug("[OIDC SERVICE] User email: %s", user.email)
logger.debug("[OIDC SERVICE] User organization_memberships: %s", user.organization_memberships)
if user.organization_memberships:
logger.debug("[OIDC SERVICE] User has %d organization memberships", len(user.organization_memberships))
for idx, member in enumerate(user.organization_memberships):
logger.debug("[OIDC SERVICE] Processing membership %d: member=%s", idx, member)
logger.debug("[OIDC SERVICE] organization_id: %s", member.organization_id)
logger.debug("[OIDC SERVICE] role: %s", member.role)
logger.debug("[OIDC SERVICE] role.value: %s", member.role.value)
role_entry = {
"organization_id": str(member.organization_id),
"role": member.role.value
}
roles.append(role_entry)
logger.debug("[OIDC SERVICE] Added role entry: %s", role_entry)
else:
logger.debug("[OIDC SERVICE] User has no organization memberships")
logger.debug("[OIDC SERVICE] Final roles list: %s", roles)
logger.debug("[OIDC SERVICE] _get_user_roles() completed")
return roles
+3 -2
View File
@@ -3,6 +3,7 @@ import secrets
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
from datetime import timezone
from flask import current_app, g from flask import current_app, g
from app.extensions import db from app.extensions import db
@@ -219,11 +220,11 @@ class OIDCSessionService:
""" """
from datetime import timedelta from datetime import timedelta
cutoff = datetime.utcnow() - timedelta(hours=older_than_hours) cutoff = datetime.now(timezone.utc) - timedelta(hours=older_than_hours)
# Get expired sessions # Get expired sessions
expired_sessions = OIDCSession.query.filter( expired_sessions = OIDCSession.query.filter(
OIDCSession.expires_at < datetime.utcnow(), OIDCSession.expires_at < datetime.now(timezone.utc),
OIDCSession.deleted_at == None OIDCSession.deleted_at == None
).all() ).all()
+191 -37
View File
@@ -2,15 +2,20 @@
import hashlib import hashlib
import base64 import base64
import secrets import secrets
from datetime import datetime, timedelta import logging
import time
from datetime import datetime, timedelta, timezone
from typing import Dict, Optional, Any from typing import Dict, Optional, Any
import jwt import jwt
from flask import current_app, g from flask import current_app, g
from app.models import User, OIDCClient from app.models import User, OIDCClient
from app.models.organization_member import OrganizationMember
from app.services.oidc_jwks_service import OIDCJWKSService from app.services.oidc_jwks_service import OIDCJWKSService
logger = logging.getLogger(__name__)
class OIDCTokenService: class OIDCTokenService:
"""Service for generating and validating OIDC tokens. """Service for generating and validating OIDC tokens.
@@ -134,7 +139,7 @@ class OIDCTokenService:
return lifetimes.get(token_type, 3600) return lifetimes.get(token_type, 3600)
@classmethod @classmethod
def create_access_token(cls, client_id: str, user_id: str, scope: list, def create_access_token(cls, client_id: str, user_id: str, scope: list,
jti: str = None) -> str: jti: str = None) -> str:
"""Create a JWT access token. """Create a JWT access token.
@@ -147,25 +152,44 @@ class OIDCTokenService:
Returns: Returns:
JWT access token string JWT access token string
""" """
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
logger.debug("[OIDC TOKEN SERVICE] create_access_token called")
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
logger.debug("[OIDC TOKEN SERVICE] scope=%s", scope)
jti = jti or cls._generate_jti() jti = jti or cls._generate_jti()
now = datetime.utcnow() now_timestamp = int(time.time())
now = datetime.now(timezone.utc)
logger.debug("[OIDC TOKEN SERVICE] Token creation time (UTC): %s", now.isoformat())
logger.debug("[OIDC TOKEN SERVICE] Token creation timestamp: %s", now_timestamp)
# Get client for token lifetime # Get client for token lifetime
client = OIDCClient.query.filter_by(client_id=client_id).first() client = OIDCClient.query.filter_by(client_id=client_id).first()
lifetime = cls._get_token_lifetime(client, "access_token") if client else 3600 lifetime = cls._get_token_lifetime(client, "access_token") if client else 3600
logger.debug("[OIDC TOKEN SERVICE] Access token lifetime (seconds): %s", lifetime)
exp_timestamp = now_timestamp + lifetime
exp_time = now + timedelta(seconds=lifetime)
logger.debug("[OIDC TOKEN SERVICE] Access token expiration time (UTC): %s", exp_time.isoformat())
logger.debug("[OIDC TOKEN SERVICE] Access token expiration timestamp: %s", exp_timestamp)
logger.debug("[OIDC TOKEN SERVICE] Time until expiration (seconds): %s", lifetime)
claims = { claims = {
"iss": cls._get_issuer(), "iss": cls._get_issuer(),
"sub": user_id, "sub": user_id,
"aud": client_id, "aud": client_id,
"exp": int((now + timedelta(seconds=lifetime)).timestamp()), "exp": exp_timestamp,
"iat": int(now.timestamp()), "iat": now_timestamp,
"nbf": int(now.timestamp()), "nbf": now_timestamp,
"jti": jti, "jti": jti,
"client_id": client_id, "client_id": client_id,
"scope": " ".join(scope) if isinstance(scope, list) else scope, "scope": " ".join(scope) if isinstance(scope, list) else scope,
} }
logger.debug("[OIDC TOKEN SERVICE] Token claims: exp=%s, iat=%s, nbf=%s",
claims["exp"], claims["iat"], claims["nbf"])
# Get signing key # Get signing key
jwks_service = OIDCJWKSService() jwks_service = OIDCJWKSService()
signing_key = jwks_service.get_signing_key() signing_key = jwks_service.get_signing_key()
@@ -174,6 +198,7 @@ class OIDCTokenService:
raise ValueError("No signing key available") raise ValueError("No signing key available")
# Sign with RS256 # Sign with RS256
logger.debug("[OIDC TOKEN SERVICE] Signing token with RS256...")
token = jwt.encode( token = jwt.encode(
claims, claims,
signing_key.private_key, signing_key.private_key,
@@ -181,6 +206,9 @@ class OIDCTokenService:
headers={"kid": signing_key.kid} headers={"kid": signing_key.kid}
) )
logger.debug("[OIDC TOKEN SERVICE] Access token created successfully")
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
return token return token
@classmethod @classmethod
@@ -200,12 +228,30 @@ class OIDCTokenService:
Returns: Returns:
JWT ID token string JWT ID token string
""" """
now = datetime.utcnow() logger.debug("[OIDC TOKEN SERVICE] ===========================================")
auth_time = auth_time or int(now.timestamp()) logger.debug("[OIDC TOKEN SERVICE] create_id_token called")
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
logger.debug("[OIDC TOKEN SERVICE] nonce=%s, auth_time=%s", nonce, auth_time)
logger.debug("[OIDC TOKEN SERVICE] scope=%s", scope)
now_timestamp = int(time.time())
now = datetime.now(timezone.utc)
logger.debug("[OIDC TOKEN SERVICE] Token creation time (UTC): %s", now.isoformat())
logger.debug("[OIDC TOKEN SERVICE] Token creation timestamp: %s", now_timestamp)
auth_time = auth_time or now_timestamp
logger.debug("[OIDC TOKEN SERVICE] auth_time (Unix timestamp): %s", auth_time)
# Get client for token lifetime # Get client for token lifetime
client = OIDCClient.query.filter_by(client_id=client_id).first() client = OIDCClient.query.filter_by(client_id=client_id).first()
lifetime = cls._get_token_lifetime(client, "id_token") if client else 3600 lifetime = cls._get_token_lifetime(client, "id_token") if client else 3600
logger.debug("[OIDC TOKEN SERVICE] ID token lifetime (seconds): %s", lifetime)
exp_timestamp = now_timestamp + lifetime
exp_time = now + timedelta(seconds=lifetime)
logger.debug("[OIDC TOKEN SERVICE] ID token expiration time (UTC): %s", exp_time.isoformat())
logger.debug("[OIDC TOKEN SERVICE] ID token expiration timestamp: %s", exp_timestamp)
logger.debug("[OIDC TOKEN SERVICE] Time until expiration (seconds): %s", lifetime)
# Get user for claims # Get user for claims
user = User.query.get(user_id) user = User.query.get(user_id)
@@ -214,11 +260,14 @@ class OIDCTokenService:
"iss": cls._get_issuer(), "iss": cls._get_issuer(),
"sub": user_id, "sub": user_id,
"aud": client_id, "aud": client_id,
"exp": int((now + timedelta(seconds=lifetime)).timestamp()), "exp": exp_timestamp,
"iat": int(now.timestamp()), "iat": now_timestamp,
"auth_time": auth_time, "auth_time": auth_time,
} }
logger.debug("[OIDC TOKEN SERVICE] Token claims: exp=%s, iat=%s, auth_time=%s",
claims["exp"], claims["iat"], claims["auth_time"])
# Add nonce if provided # Add nonce if provided
if nonce: if nonce:
claims["nonce"] = nonce claims["nonce"] = nonce
@@ -235,6 +284,10 @@ class OIDCTokenService:
if user.full_name: if user.full_name:
claims["name"] = user.full_name claims["name"] = user.full_name
# Add roles claim if scope is granted
if scope and "roles" in scope:
claims["roles"] = cls._get_user_roles(user)
# Add scope if provided # Add scope if provided
if scope: if scope:
claims["scope"] = " ".join(scope) if isinstance(scope, list) else scope claims["scope"] = " ".join(scope) if isinstance(scope, list) else scope
@@ -247,6 +300,7 @@ class OIDCTokenService:
raise ValueError("No signing key available") raise ValueError("No signing key available")
# Sign with RS256 # Sign with RS256
logger.debug("[OIDC TOKEN SERVICE] Signing token with RS256...")
token = jwt.encode( token = jwt.encode(
claims, claims,
signing_key.private_key, signing_key.private_key,
@@ -254,10 +308,32 @@ class OIDCTokenService:
headers={"kid": signing_key.kid} headers={"kid": signing_key.kid}
) )
logger.debug("[OIDC TOKEN SERVICE] ID token created successfully")
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
return token return token
@staticmethod
def _get_user_roles(user: User) -> list:
"""Get user's organization roles.
Args:
user: User instance
Returns:
List of role objects with organization_id and role
"""
roles = []
if user and user.organization_memberships:
for member in user.organization_memberships:
roles.append({
"organization_id": str(member.organization_id),
"role": member.role.value
})
return roles
@classmethod @classmethod
def create_refresh_token(cls, client_id: str, user_id: str, def create_refresh_token(cls, client_id: str, user_id: str,
scope: list = None, access_token_id: str = None) -> str: scope: list = None, access_token_id: str = None) -> str:
"""Create an opaque refresh token. """Create an opaque refresh token.
@@ -270,11 +346,21 @@ class OIDCTokenService:
Returns: Returns:
Opaque refresh token string Opaque refresh token string
""" """
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
logger.debug("[OIDC TOKEN SERVICE] create_refresh_token called")
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
logger.debug("[OIDC TOKEN SERVICE] scope=%s, access_token_id=%s", scope, access_token_id)
token = cls._generate_opaque_token() token = cls._generate_opaque_token()
logger.debug("[OIDC TOKEN SERVICE] Refresh token generated: %s...", token[:20] if token else None)
# Hash for storage # Hash for storage
token_hash = cls._hash_token(token) token_hash = cls._hash_token(token)
logger.debug("[OIDC TOKEN SERVICE] Refresh token created successfully")
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
return token, token_hash return token, token_hash
@classmethod @classmethod
@@ -292,54 +378,91 @@ class OIDCTokenService:
jwt.ExpiredSignatureError: If token is expired jwt.ExpiredSignatureError: If token is expired
jwt.InvalidTokenError: If token is invalid jwt.InvalidTokenError: If token is invalid
""" """
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
logger.debug("[OIDC TOKEN SERVICE] verify_token_signature() called")
logger.debug("[OIDC TOKEN SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token)
logger.debug("[OIDC TOKEN SERVICE] Token length: %d", len(token))
# Get the JWKS with public keys # Get the JWKS with public keys
logger.debug("[OIDC TOKEN SERVICE] Getting JWKS...")
jwks_service = OIDCJWKSService() jwks_service = OIDCJWKSService()
jwks = jwks_service.get_jwks() jwks = jwks_service.get_jwks(include_private_keys=True)
logger.debug("[OIDC TOKEN SERVICE] JWKS retrieved: %d keys", len(jwks.get("keys", [])))
# Get the key ID from token header # Get the key ID from token header
try: try:
logger.debug("[OIDC TOKEN SERVICE] Getting unverified token header...")
unverified_header = jwt.get_unverified_header(token) unverified_header = jwt.get_unverified_header(token)
except jwt.DecodeError: logger.debug("[OIDC TOKEN SERVICE] Unverified header: %s", unverified_header)
except jwt.DecodeError as e:
logger.error("[OIDC TOKEN SERVICE] Failed to decode token header: %s", str(e))
raise jwt.InvalidTokenError("Invalid token header") raise jwt.InvalidTokenError("Invalid token header")
kid = unverified_header.get("kid") kid = unverified_header.get("kid")
logger.debug("[OIDC TOKEN SERVICE] Key ID (kid) from token header: %s", kid)
# Find the matching public key # Find the matching public key
logger.debug("[OIDC TOKEN SERVICE] Searching for matching public key...")
public_key = None public_key = None
for key in jwks.get("keys", []): for idx, key in enumerate(jwks.get("keys", [])):
logger.debug("[OIDC TOKEN SERVICE] Checking key %d: kid=%s", idx, key.get("kid"))
if key.get("kid") == kid: if key.get("kid") == kid:
logger.debug("[OIDC TOKEN SERVICE] Found matching key at index %d", idx)
try: try:
from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
logger.debug("[OIDC TOKEN SERVICE] Loading PEM public key...")
public_key = serialization.load_pem_public_key( public_key = serialization.load_pem_public_key(
key["public_key"].encode() if isinstance(key["public_key"], str) key["public_key"].encode() if isinstance(key["public_key"], str)
else key["public_key"], else key["public_key"],
backend=default_backend() backend=default_backend()
) )
logger.debug("[OIDC TOKEN SERVICE] Public key loaded successfully")
break break
except (ImportError, Exception): except (ImportError, Exception) as e:
logger.error("[OIDC TOKEN SERVICE] Failed to load public key: %s: %s", type(e).__name__, str(e))
continue continue
if not public_key: if not public_key:
logger.error("[OIDC TOKEN SERVICE] No matching public key found for kid=%s", kid)
raise jwt.InvalidSignatureError(f"Key with kid={kid} not found") raise jwt.InvalidSignatureError(f"Key with kid={kid} not found")
# Verify the signature logger.debug("[OIDC TOKEN SERVICE] Public key found, verifying signature...")
claims = jwt.decode(
token,
public_key,
algorithms=["RS256"],
audience=None, # We'll validate audience separately
issuer=cls._get_issuer(),
options={
"verify_signature": True,
"verify_exp": True,
"verify_aud": False, # Handle audience manually
"verify_iss": False, # Handle issuer manually
}
)
return claims # Verify the signature
try:
claims = jwt.decode(
token,
public_key,
algorithms=["RS256"],
audience=None, # We'll validate audience separately
issuer=cls._get_issuer(),
options={
"verify_signature": True,
"verify_exp": True,
"verify_aud": False, # Handle audience manually
"verify_iss": False, # Handle issuer manually
}
)
logger.debug("[OIDC TOKEN SERVICE] Signature verification successful")
logger.debug("[OIDC TOKEN SERVICE] Decoded claims: %s", claims)
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
return claims
except jwt.ExpiredSignatureError as e:
logger.error("[OIDC TOKEN SERVICE] Token has expired: %s", str(e))
raise
except jwt.InvalidSignatureError as e:
logger.error("[OIDC TOKEN SERVICE] Invalid token signature: %s", str(e))
raise
except jwt.InvalidTokenError as e:
logger.error("[OIDC TOKEN SERVICE] Invalid token: %s: %s", type(e).__name__, str(e))
raise
except Exception as e:
logger.error("[OIDC TOKEN SERVICE] Unexpected error during token verification: %s: %s", type(e).__name__, str(e))
import traceback
logger.error("[OIDC TOKEN SERVICE] Traceback: %s", traceback.format_exc())
raise
@classmethod @classmethod
def decode_token(cls, token: str, verify: bool = False) -> Dict: def decode_token(cls, token: str, verify: bool = False) -> Dict:
@@ -378,16 +501,41 @@ class OIDCTokenService:
jwt.InvalidTokenError: If token is invalid jwt.InvalidTokenError: If token is invalid
ValueError: If token is expired or audience mismatch ValueError: If token is expired or audience mismatch
""" """
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
logger.debug("[OIDC TOKEN SERVICE] validate_access_token() called")
logger.debug("[OIDC TOKEN SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token)
logger.debug("[OIDC TOKEN SERVICE] Token length: %d", len(token))
logger.debug("[OIDC TOKEN SERVICE] Client ID: %s", client_id)
# Verify token signature
logger.debug("[OIDC TOKEN SERVICE] Verifying token signature...")
claims = cls.verify_token_signature(token) claims = cls.verify_token_signature(token)
logger.debug("[OIDC TOKEN SERVICE] Token signature verified")
logger.debug("[OIDC TOKEN SERVICE] Claims: %s", claims)
# Check expiration # Check expiration
if claims.get("exp", 0) < datetime.utcnow().timestamp(): exp = claims.get("exp", 0)
now_timestamp = int(time.time())
if exp < now_timestamp:
logger.error("[OIDC TOKEN SERVICE] Token has expired")
raise ValueError("Token has expired") raise ValueError("Token has expired")
# Validate audience if client_id provided # Validate audience if client_id provided
aud = claims.get("aud")
logger.debug("[OIDC TOKEN SERVICE] Token audience (aud): %s", aud)
logger.debug("[OIDC TOKEN SERVICE] Expected client_id: %s", client_id)
if client_id: if client_id:
if claims.get("aud") != client_id: if aud != client_id:
logger.error("[OIDC TOKEN SERVICE] Audience mismatch: expected=%s, got=%s", client_id, aud)
raise ValueError("Invalid audience") raise ValueError("Invalid audience")
logger.debug("[OIDC TOKEN SERVICE] Audience validation passed")
else:
logger.debug("[OIDC TOKEN SERVICE] No client_id provided, skipping audience validation")
logger.debug("[OIDC TOKEN SERVICE] validate_access_token() completed successfully")
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
return claims return claims
@@ -410,11 +558,17 @@ class OIDCTokenService:
claims = cls.validate_access_token(token, client_id) claims = cls.validate_access_token(token, client_id)
# Calculate remaining time # Calculate remaining time
now = datetime.utcnow().timestamp() now_timestamp = int(time.time())
now = datetime.now(timezone.utc)
exp = claims.get("exp", 0) exp = claims.get("exp", 0)
iat = claims.get("iat", 0) iat = claims.get("iat", 0)
result["active"] = exp > now logger.debug("[OIDC TOKEN SERVICE] Introspection - Current UTC time: %s", now.isoformat())
logger.debug("[OIDC TOKEN SERVICE] Introspection - Token expiration timestamp: %s", exp)
logger.debug("[OIDC TOKEN SERVICE] Introspection - Token expiration datetime (UTC): %s", datetime.fromtimestamp(exp, tz=timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] Introspection - Time until expiration: %s seconds", exp - now_timestamp)
result["active"] = exp > now_timestamp
result.update({ result.update({
"iss": claims.get("iss"), "iss": claims.get("iss"),
"sub": claims.get("sub"), "sub": claims.get("sub"),
@@ -429,8 +583,8 @@ class OIDCTokenService:
}) })
# Add expiry in seconds # Add expiry in seconds
if exp > now: if exp > now_timestamp:
result["exp"] = int(exp - now) result["exp"] = int(exp - now_timestamp)
except (jwt.InvalidTokenError, ValueError) as e: except (jwt.InvalidTokenError, ValueError) as e:
result["active"] = False result["active"] = False
+188
View File
@@ -0,0 +1,188 @@
"""TOTP (Time-based One-Time Password) service."""
import base64
import io
import logging
import secrets
from typing import Tuple
import pyotp
from app.extensions import bcrypt
logger = logging.getLogger(__name__)
class TOTPService:
"""Service for TOTP operations."""
@staticmethod
def generate_secret() -> str:
"""
Generate a new TOTP secret.
Returns:
Base32 encoded secret (32 characters)
Note:
The secret is generated using cryptographically secure random bytes
and encoded in base32 format for compatibility with authenticator apps.
"""
# Generate 20 random bytes (160 bits) and encode as base32
random_bytes = secrets.token_bytes(20)
secret = base64.b32encode(random_bytes).decode("utf-8")
logger.debug(f"Generated new TOTP secret: {secret[:8]}...")
return secret
@staticmethod
def generate_provisioning_uri(user_email: str, secret: str, issuer: str = "Gatehouse") -> str:
"""
Generate provisioning URI for QR code.
Args:
user_email: User's email address
secret: TOTP secret (base32 encoded)
issuer: Issuer name (default: "Gatehouse")
Returns:
otpauth:// URI for QR code generation
Example:
>>> uri = TOTPService.generate_provisioning_uri("user@example.com", "JBSWY3DPEHPK3PXP")
>>> print(uri)
otpauth://totp/Gatehouse:user@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse
"""
totp = pyotp.TOTP(secret)
uri = totp.provisioning_uri(name=user_email, issuer_name=issuer)
logger.debug(f"Generated provisioning URI for user: {user_email}")
return uri
@staticmethod
def verify_code(secret: str, code: str, window: int = 1) -> bool:
"""
Verify a TOTP code against the secret.
Args:
secret: TOTP secret (base32 encoded)
code: 6-digit TOTP code to verify
window: Time window for code validation (default: 1, allows codes from previous/next time steps)
Returns:
True if code is valid, False otherwise
Note:
The window parameter allows for clock skew between the server
and the authenticator app. A window of 1 allows codes from
the previous, current, and next 30-second intervals.
"""
totp = pyotp.TOTP(secret)
is_valid = totp.verify(code, valid_window=window)
logger.debug(f"TOTP code verification: valid={is_valid}, window={window}")
return is_valid
@staticmethod
def generate_backup_codes(count: int = 10) -> Tuple[list[str], list[str]]:
"""
Generate backup codes for TOTP recovery.
Args:
count: Number of backup codes to generate (default: 10)
Returns:
Tuple of (plain_codes, hashed_codes)
- plain_codes: List of plain text backup codes (for display to user)
- hashed_codes: List of bcrypt hashed backup codes (for storage)
Note:
Backup codes are 16-character alphanumeric codes that can be used
to recover access if the TOTP device is lost. Each code can only
be used once.
"""
plain_codes = []
hashed_codes = []
for _ in range(count):
# Generate a 16-character alphanumeric code
code = secrets.token_hex(8).upper()
plain_codes.append(code)
# Hash the code using bcrypt
hashed_code = bcrypt.generate_password_hash(code).decode("utf-8")
hashed_codes.append(hashed_code)
logger.debug(f"Generated {count} backup codes")
return plain_codes, hashed_codes
@staticmethod
def verify_backup_code(hashed_codes: list[str], code: str) -> Tuple[bool, list[str]]:
"""
Verify and consume a backup code.
Args:
hashed_codes: List of bcrypt hashed backup codes
code: Plain text backup code to verify
Returns:
Tuple of (is_valid, remaining_codes)
- is_valid: True if code was valid and consumed, False otherwise
- remaining_codes: List of remaining hashed codes (with consumed code removed)
Note:
Once a backup code is used, it is removed from the list and cannot
be used again. This ensures each code is single-use.
"""
remaining_codes = []
for hashed_code in hashed_codes:
if bcrypt.check_password_hash(hashed_code, code):
# Code found and valid - don't add to remaining codes (consumed)
logger.debug("Backup code verified and consumed")
return True, remaining_codes
else:
# Code doesn't match - keep it in remaining codes
remaining_codes.append(hashed_code)
logger.debug("Backup code verification failed")
return False, remaining_codes
@staticmethod
def generate_qr_code_data_uri(provisioning_uri: str) -> str:
"""
Generate QR code as data URI for frontend display.
Args:
provisioning_uri: otpauth:// URI to encode in QR code
Returns:
Base64 encoded PNG image as data URI (data:image/png;base64,...)
Note:
If the qrcode library is not installed, returns a placeholder message.
Install with: pip install qrcode[pil]
"""
try:
import qrcode
# Create QR code
qr = qrcode.QRCode(
version=1,
error_correction=qrcode.constants.ERROR_CORRECT_L,
box_size=10,
border=4,
)
qr.add_data(provisioning_uri)
qr.make(fit=True)
# Generate image
img = qr.make_image(fill_color="black", back_color="white")
# Convert to base64
buffer = io.BytesIO()
img.save(buffer, format="PNG")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
data_uri = f"data:image/png;base64,{img_base64}"
logger.debug("Generated QR code data URI")
return data_uri
except ImportError:
logger.warning("qrcode library not installed, returning placeholder")
return "QR code generation requires the qrcode library. Install with: pip install qrcode[pil]"
+8
View File
@@ -24,6 +24,7 @@ class AuthMethodType(str, Enum):
"""Authentication method types.""" """Authentication method types."""
PASSWORD = "password" PASSWORD = "password"
TOTP = "totp"
GOOGLE = "google" GOOGLE = "google"
GITHUB = "github" GITHUB = "github"
MICROSOFT = "microsoft" MICROSOFT = "microsoft"
@@ -66,6 +67,13 @@ class AuditAction(str, Enum):
# Auth method actions # Auth method actions
AUTH_METHOD_ADD = "auth.method.add" AUTH_METHOD_ADD = "auth.method.add"
AUTH_METHOD_REMOVE = "auth.method.remove" AUTH_METHOD_REMOVE = "auth.method.remove"
TOTP_ENROLL_INITIATED = "totp.enroll.initiated"
TOTP_ENROLL_COMPLETED = "totp.enroll.completed"
TOTP_VERIFY_SUCCESS = "totp.verify.success"
TOTP_VERIFY_FAILED = "totp.verify.failed"
TOTP_DISABLED = "totp.disabled"
TOTP_BACKUP_CODE_USED = "totp.backup_code.used"
TOTP_BACKUP_CODES_REGENERATED = "totp.backup_codes.regenerated"
class OIDCGrantType(str, Enum): class OIDCGrantType(str, Enum):
+102
View File
@@ -0,0 +1,102 @@
#!/usr/bin/env bash
set -euo pipefail
ISSUER="https://oidctest.wsweet.org"
CLIENT_ID="secret"
CLIENT_SECRET="tardis"
REDIRECT_URI="http://127.0.0.1:5556/callback"
SCOPE="openid profile email offline_access"
# ---------------------------
# Discover OIDC endpoints
# ---------------------------
DISCOVERY=$(curl -s "$ISSUER/.well-known/openid-configuration")
AUTH_ENDPOINT=$(echo "$DISCOVERY" | jq -r .authorization_endpoint)
TOKEN_ENDPOINT=$(echo "$DISCOVERY" | jq -r .token_endpoint)
USERINFO_ENDPOINT=$(echo "$DISCOVERY" | jq -r .userinfo_endpoint)
echo "Auth endpoint : $AUTH_ENDPOINT"
echo "Token endpoint: $TOKEN_ENDPOINT"
echo
# ---------------------------
# PKCE
# ---------------------------
CODE_VERIFIER=$(openssl rand -base64 32 | tr -d '=+/')
CODE_CHALLENGE=$(echo -n "$CODE_VERIFIER" | openssl dgst -sha256 -binary | openssl base64 | tr -d '=+/' | tr '/+' '_-')
STATE=$(openssl rand -hex 16)
NONCE=$(openssl rand -hex 16)
# ---------------------------
# Build auth URL
# ---------------------------
AUTH_URL="$AUTH_ENDPOINT?response_type=code\
&client_id=$CLIENT_ID\
&redirect_uri=$(printf '%s' "$REDIRECT_URI" | jq -s -R -r @uri)\
&scope=$(printf '%s' "$SCOPE" | jq -s -R -r @uri)\
&state=$STATE\
&nonce=$NONCE\
&code_challenge=$CODE_CHALLENGE\
&code_challenge_method=S256"
echo "Open this URL in a browser:"
echo
echo "$AUTH_URL"
echo
echo "After login you will be redirected to:"
echo "$REDIRECT_URI?code=XXXX&state=YYYY"
echo
read -p "Paste the full redirect URL: " REDIRECT
CODE=$(echo "$REDIRECT" | sed -n 's/.*code=\([^&]*\).*/\1/p')
RETURNED_STATE=$(echo "$REDIRECT" | sed -n 's/.*state=\([^&]*\).*/\1/p')
if [ "$RETURNED_STATE" != "$STATE" ]; then
echo "STATE MISMATCH"
exit 1
fi
# ---------------------------
# Exchange code for tokens
# ---------------------------
TOKENS=$(curl -s -X POST "$TOKEN_ENDPOINT" \
-u "$CLIENT_ID:$CLIENT_SECRET" \
-H "Content-Type: application/x-www-form-urlencoded" \
-d "grant_type=authorization_code" \
-d "code=$CODE" \
-d "redirect_uri=$REDIRECT_URI" \
-d "code_verifier=$CODE_VERIFIER")
echo
echo "Token response:"
echo "$TOKENS" | jq .
ACCESS_TOKEN=$(echo "$TOKENS" | jq -r .access_token)
ID_TOKEN=$(echo "$TOKENS" | jq -r .id_token)
# ---------------------------
# JWT decode function
# ---------------------------
decode() {
echo "$1" | awk -F. '{print $2}' | tr '_-' '/+' | base64 -d 2>/dev/null | jq .
}
echo
echo "================ ID TOKEN ================"
decode "$ID_TOKEN"
echo
echo "============== ACCESS TOKEN =============="
decode "$ACCESS_TOKEN"
# ---------------------------
# Userinfo (optional)
# ---------------------------
if [ "$USERINFO_ENDPOINT" != "null" ]; then
echo
echo "=============== USERINFO ================="
curl -s -H "Authorization: Bearer $ACCESS_TOKEN" "$USERINFO_ENDPOINT" | jq .
fi
+1
View File
@@ -16,6 +16,7 @@ marshmallow-sqlalchemy==0.29.0
# Security # Security
bcrypt==4.1.2 bcrypt==4.1.2
Flask-Bcrypt==1.0.1 Flask-Bcrypt==1.0.1
pyotp==2.9.0
# JWT / OIDC # JWT / OIDC
PyJWT==2.8.0 PyJWT==2.8.0
+8
View File
@@ -0,0 +1,8 @@
<!doctype html>
<html>
<body>
<h1>OK - protected</h1>
<p>User: __USER__</p>
<p>Email: __EMAIL__</p>
</body>
</html>
+92
View File
@@ -0,0 +1,92 @@
version: "3.9"
services:
nginx:
image: nginx:1.27-alpine
container_name: app-nginx
volumes:
- ./app:/usr/share/nginx/html:ro
- ./nginx.conf:/etc/nginx/conf.d/default.conf:ro
expose:
- "80"
# oauth2-proxy:
# image: quay.io/oauth2-proxy/oauth2-proxy:v7.7.1
# container_name: oauth2-proxy
# depends_on:
# - nginx
# ports:
# - "8086:4180"
# environment:
# # ----- Logging -----
# OAUTH2_PROXY_LOG_LEVEL: "debug"
# OAUTH2_PROXY_AUTH_LOGGING: "true"
# OAUTH2_PROXY_REQUEST_LOGGING: "true"
# OAUTH2_PROXY_STANDARD_LOGGING: "true"
# # ----- Gatehouse OIDC -----
# OAUTH2_PROXY_PROVIDER: oidc
# OAUTH2_PROXY_OIDC_ISSUER_URL: "http://192.168.64.7:8888"
# OAUTH2_PROXY_CLIENT_ID: "acme-portal-001"
# OAUTH2_PROXY_CLIENT_SECRET: "acme_secret_portal_2024"
# OAUTH2_PROXY_REDIRECT_URL: "http://192.168.64.7:8086/oauth2/callback"
# # ----- Session -----
# OAUTH2_PROXY_COOKIE_SECRET: "afhXcfftf5qKezJ217qhED7U4UeVyqBHd7lhISNGpXo="
# OAUTH2_PROXY_COOKIE_SECURE: "false"
# OAUTH2_PROXY_COOKIE_SAMESITE: "lax"
# OAUTH2_PROXY_EMAIL_DOMAINS: "*"
# OAUTH2_PROXY_HTTP_ADDRESS: "0.0.0.0:4180"
# # ----- Upstream -----
# OAUTH2_PROXY_UPSTREAMS: "http://nginx:80"
# # ----- Identity headers -----
# OAUTH2_PROXY_SET_XAUTHREQUEST: "true"
# OAUTH2_PROXY_PASS_ACCESS_TOKEN: "true"
# OAUTH2_PROXY_PASS_AUTHORIZATION_HEADER: "true"
# # ----- Claim mapping (Gatehouse) -----
# OAUTH2_PROXY_OIDC_EMAIL_CLAIM: "email"
# OAUTH2_PROXY_USER_ID_CLAIM: "preferred_username"
# OAUTH2_PROXY_OIDC_GROUPS_CLAIM: "roles"
# OAUTH2_PROXY_SKIP_PROVIDER_BUTTON: "true"
# OAUTH2_PROXY_SCOPE: "openid email profile"
# # OAUTH2_PROXY_OIDC_EMAIL_CLAIM: "email"
# OAUTH2_PROXY_INSECURE_OIDC_ALLOW_UNVERIFIED_EMAIL: "true"
oauth2-proxy:
image: quay.io/oauth2-proxy/oauth2-proxy:v7.7.1
container_name: oauth2-proxy
depends_on: [nginx]
ports:
- "8086:4180"
command:
- --provider=oidc
- --oidc-issuer-url=http://192.168.64.7:8888
- --client-id=acme-portal-001
- --client-secret=acme_secret_portal_2024
- --redirect-url=http://192.168.64.7:8086/oauth2/callback
- --scope=openid email profile
- --email-domain=*
- --upstream=http://nginx:80
- --http-address=0.0.0.0:4180
- --cookie-secret=afhXcfftf5qKezJ217qhED7U4UeVyqBHd7lhISNGpXo=
- --cookie-secure=false
- --cookie-samesite=lax
- --skip-provider-button=true
- --standard-logging=true
- --request-logging=true
- --auth-logging=true
- --set-xauthrequest=true
- --pass-authorization-header=true # optional (token passthrough)
- --pass-access-token=true # optional (token passthrough)
- --pass-user-headers=true
+23
View File
@@ -0,0 +1,23 @@
server {
listen 80;
server_name _;
root /usr/share/nginx/html;
index index.html;
location / {
try_files $uri $uri/ =404;
}
location = /whoami {
default_type text/plain;
return 200
"user: $http_x_auth_request_user
email: $http_x_auth_request_email
preferred_username: $http_x_forwarded_preferred_username
x-forwarded-user: $http_x_forwarded_user
x-forwarded-email: $http_x_forwarded_email
authorization: $http_authorization
";
}
}
@@ -0,0 +1,285 @@
"""Unit tests for TOTPService."""
import base64
import pytest
from app.services.totp_service import TOTPService
@pytest.mark.unit
class TestTOTPService:
"""Tests for TOTPService."""
# Test generate_secret()
def test_generate_secret_returns_string(self):
"""Test that generate_secret returns a string."""
secret = TOTPService.generate_secret()
assert isinstance(secret, str)
def test_generate_secret_length(self):
"""Test that generate_secret returns a 32-character string."""
secret = TOTPService.generate_secret()
assert len(secret) == 32
def test_generate_secret_base32_encoded(self):
"""Test that generate_secret returns a base32 encoded string."""
secret = TOTPService.generate_secret()
# Base32 characters are A-Z and 2-7
valid_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567")
assert all(c in valid_chars for c in secret)
def test_generate_secret_unique(self):
"""Test that generate_secret produces unique secrets."""
secret1 = TOTPService.generate_secret()
secret2 = TOTPService.generate_secret()
assert secret1 != secret2
# Test generate_provisioning_uri()
def test_generate_provisioning_uri_format(self):
"""Test that provisioning URI is generated correctly."""
email = "user@example.com"
secret = "JBSWY3DPEHPK3PXP"
issuer = "Gatehouse"
uri = TOTPService.generate_provisioning_uri(email, secret, issuer)
assert isinstance(uri, str)
assert uri.startswith("otpauth://totp/")
def test_generate_provisioning_uri_contains_email(self):
"""Test that provisioning URI contains the user email."""
email = "user@example.com"
secret = "JBSWY3DPEHPK3PXP"
issuer = "Gatehouse"
uri = TOTPService.generate_provisioning_uri(email, secret, issuer)
assert email in uri
def test_generate_provisioning_uri_contains_secret(self):
"""Test that provisioning URI contains the secret."""
email = "user@example.com"
secret = "JBSWY3DPEHPK3PXP"
issuer = "Gatehouse"
uri = TOTPService.generate_provisioning_uri(email, secret, issuer)
assert secret in uri
def test_generate_provisioning_uri_contains_issuer(self):
"""Test that provisioning URI contains the issuer."""
email = "user@example.com"
secret = "JBSWY3DPEHPK3PXP"
issuer = "Gatehouse"
uri = TOTPService.generate_provisioning_uri(email, secret, issuer)
assert issuer in uri
def test_generate_provisioning_uri_custom_issuer(self):
"""Test that provisioning URI uses custom issuer."""
email = "user@example.com"
secret = "JBSWY3DPEHPK3PXP"
custom_issuer = "MyApp"
uri = TOTPService.generate_provisioning_uri(email, secret, custom_issuer)
assert custom_issuer in uri
# Test verify_code()
def test_verify_code_valid(self):
"""Test that a valid TOTP code is accepted."""
secret = TOTPService.generate_secret()
# Generate a valid code using pyotp
import pyotp
totp = pyotp.TOTP(secret)
valid_code = totp.now()
result = TOTPService.verify_code(secret, valid_code)
assert result is True
def test_verify_code_invalid(self):
"""Test that an invalid TOTP code is rejected."""
secret = TOTPService.generate_secret()
invalid_code = "000000"
result = TOTPService.verify_code(secret, invalid_code)
assert result is False
def test_verify_code_window_parameter(self):
"""Test that the time window parameter works correctly."""
secret = TOTPService.generate_secret()
import pyotp
totp = pyotp.TOTP(secret)
# Get current code
current_code = totp.now()
# Verify with window=1 (default) - should accept current code
result = TOTPService.verify_code(secret, current_code, window=1)
assert result is True
# Verify with window=0 - should only accept exact time match
result = TOTPService.verify_code(secret, current_code, window=0)
assert result is True
def test_verify_code_wrong_length(self):
"""Test that codes with wrong length are rejected."""
secret = TOTPService.generate_secret()
wrong_length_code = "12345" # 5 digits instead of 6
result = TOTPService.verify_code(secret, wrong_length_code)
assert result is False
# Test generate_backup_codes()
def test_generate_backup_codes_default_count(self):
"""Test that generate_backup_codes generates 10 codes by default."""
plain_codes, hashed_codes = TOTPService.generate_backup_codes()
assert len(plain_codes) == 10
assert len(hashed_codes) == 10
def test_generate_backup_codes_custom_count(self):
"""Test that generate_backup_codes generates the specified number of codes."""
count = 5
plain_codes, hashed_codes = TOTPService.generate_backup_codes(count)
assert len(plain_codes) == count
assert len(hashed_codes) == count
def test_generate_backup_codes_plain_are_strings(self):
"""Test that plain backup codes are strings."""
plain_codes, hashed_codes = TOTPService.generate_backup_codes()
assert all(isinstance(code, str) for code in plain_codes)
def test_generate_backup_codes_plain_length(self):
"""Test that plain backup codes are 16 characters long."""
plain_codes, hashed_codes = TOTPService.generate_backup_codes()
assert all(len(code) == 16 for code in plain_codes)
def test_generate_backup_codes_hashed_different_from_plain(self):
"""Test that hashed codes are different from plain codes."""
plain_codes, hashed_codes = TOTPService.generate_backup_codes()
for plain, hashed in zip(plain_codes, hashed_codes):
assert plain != hashed
def test_generate_backup_codes_are_bcrypt_hashes(self):
"""Test that hashed codes are bcrypt hashes."""
plain_codes, hashed_codes = TOTPService.generate_backup_codes()
# Bcrypt hashes start with $2a$, $2b$, or $2y$
for hashed in hashed_codes:
assert hashed.startswith("$2")
def test_generate_backup_codes_unique(self):
"""Test that generated backup codes are unique."""
plain_codes, hashed_codes = TOTPService.generate_backup_codes()
assert len(set(plain_codes)) == len(plain_codes)
assert len(set(hashed_codes)) == len(hashed_codes)
# Test verify_backup_code()
def test_verify_backup_code_valid(self):
"""Test that a valid backup code is accepted and removed."""
plain_codes, hashed_codes = TOTPService.generate_backup_codes(count=3)
code_to_verify = plain_codes[0]
is_valid, remaining_codes = TOTPService.verify_backup_code(hashed_codes, code_to_verify)
assert is_valid is True
assert len(remaining_codes) == 2
def test_verify_backup_code_invalid(self):
"""Test that an invalid backup code is rejected."""
plain_codes, hashed_codes = TOTPService.generate_backup_codes(count=3)
invalid_code = "INVALIDCODE1234"
is_valid, remaining_codes = TOTPService.verify_backup_code(hashed_codes, invalid_code)
assert is_valid is False
assert len(remaining_codes) == 3
def test_verify_backup_code_remaining_updated(self):
"""Test that the remaining codes list is updated correctly."""
plain_codes, hashed_codes = TOTPService.generate_backup_codes(count=5)
code_to_verify = plain_codes[2]
is_valid, remaining_codes = TOTPService.verify_backup_code(hashed_codes, code_to_verify)
assert is_valid is True
# The verified code should be removed
assert len(remaining_codes) == 4
# The remaining codes should not include the verified code's hash
assert hashed_codes[2] not in remaining_codes
def test_verify_backup_code_case_sensitive(self):
"""Test that backup code verification is case sensitive."""
plain_codes, hashed_codes = TOTPService.generate_backup_codes(count=1)
code_to_verify = plain_codes[0].lower() # Convert to lowercase
is_valid, remaining_codes = TOTPService.verify_backup_code(hashed_codes, code_to_verify)
assert is_valid is False
assert len(remaining_codes) == 1
def test_verify_backup_code_single_use(self):
"""Test that a backup code can only be used once."""
plain_codes, hashed_codes = TOTPService.generate_backup_codes(count=1)
code_to_verify = plain_codes[0]
# First use - should succeed
is_valid1, remaining1 = TOTPService.verify_backup_code(hashed_codes, code_to_verify)
assert is_valid1 is True
assert len(remaining1) == 0
# Second use - should fail (code already consumed)
is_valid2, remaining2 = TOTPService.verify_backup_code(remaining1, code_to_verify)
assert is_valid2 is False
assert len(remaining2) == 0
# Test generate_qr_code_data_uri()
def test_generate_qr_code_data_uri_format(self):
"""Test that a data URI is generated."""
provisioning_uri = "otpauth://totp/Gatehouse:user@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse"
data_uri = TOTPService.generate_qr_code_data_uri(provisioning_uri)
assert isinstance(data_uri, str)
def test_generate_qr_code_data_uri_starts_with_prefix(self):
"""Test that the data URI starts with the correct prefix."""
provisioning_uri = "otpauth://totp/Gatehouse:user@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse"
data_uri = TOTPService.generate_qr_code_data_uri(provisioning_uri)
assert data_uri.startswith("data:image/png;base64,")
def test_generate_qr_code_data_uri_contains_base64(self):
"""Test that the data URI contains base64 encoded data."""
provisioning_uri = "otpauth://totp/Gatehouse:user@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse"
data_uri = TOTPService.generate_qr_code_data_uri(provisioning_uri)
# Extract the base64 part (after the prefix)
base64_part = data_uri.split("data:image/png;base64,")[1]
# Verify it's valid base64
try:
base64.b64decode(base64_part)
assert True
except Exception:
assert False, "Data URI does not contain valid base64 data"
def test_generate_qr_code_data_uri_different_uris(self):
"""Test that different provisioning URIs generate different QR codes."""
uri1 = "otpauth://totp/Gatehouse:user1@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse"
uri2 = "otpauth://totp/Gatehouse:user2@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse"
data_uri1 = TOTPService.generate_qr_code_data_uri(uri1)
data_uri2 = TOTPService.generate_qr_code_data_uri(uri2)
assert data_uri1 != data_uri2