feat(auth): implement TOTP two-factor authentication with enrollment and verification
Adds TOTP (Time-based One-Time Password) two-factor authentication support including: - New TOTP service with secret generation, QR code provisioning, and code verification - New auth endpoints for enrollment, verification, status, and backup code management - New TOTP authentication method type and user methods for TOTP management - Backup codes generation and verification for account recovery - Updated OIDC endpoints with timezone-aware datetime handling and RFC-compliant responses - Added "roles" scope support for OIDC userinfo and ID tokens - New pyotp dependency for TOTP operations - Comprehensive unit tests for TOTP service
This commit is contained in:
@@ -305,3 +305,7 @@ client_secret: acme_secret_portal_2024
|
|||||||
## User
|
## User
|
||||||
email: bob@acme-corp.com
|
email: bob@acme-corp.com
|
||||||
password: UserPass123!
|
password: UserPass123!
|
||||||
|
|
||||||
|
|
||||||
|
## Sqlite editor
|
||||||
|
sqlite_web instance/db_file.db --port 9999 --host 0.0.0.0
|
||||||
+297
-187
@@ -3,6 +3,7 @@ import base64
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
|
from datetime import datetime, timezone
|
||||||
from urllib.parse import urlencode, urlparse, parse_qs
|
from urllib.parse import urlencode, urlparse, parse_qs
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
@@ -42,14 +43,14 @@ def get_oidc_config():
|
|||||||
"registration_endpoint": f"{base_url}/oidc/register",
|
"registration_endpoint": f"{base_url}/oidc/register",
|
||||||
"revocation_endpoint": f"{base_url}/oidc/revoke",
|
"revocation_endpoint": f"{base_url}/oidc/revoke",
|
||||||
"introspection_endpoint": f"{base_url}/oidc/introspect",
|
"introspection_endpoint": f"{base_url}/oidc/introspect",
|
||||||
"scopes_supported": ["openid", "profile", "email"],
|
"scopes_supported": ["openid", "profile", "email", "roles"],
|
||||||
"response_types_supported": ["code"],
|
"response_types_supported": ["code"],
|
||||||
"response_modes_supported": ["query"],
|
"response_modes_supported": ["query"],
|
||||||
"grant_types_supported": ["authorization_code", "refresh_token"],
|
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||||
"token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"],
|
"token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"],
|
||||||
"subject_types_supported": ["public"],
|
"subject_types_supported": ["public"],
|
||||||
"id_token_signing_alg_values_supported": ["RS256"],
|
"id_token_signing_alg_values_supported": ["RS256"],
|
||||||
"claims_supported": ["sub", "name", "email", "email_verified"],
|
"claims_supported": ["sub", "name", "email", "email_verified", "roles"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -94,19 +95,49 @@ def require_valid_token():
|
|||||||
Raises:
|
Raises:
|
||||||
InvalidGrantError: If token is invalid
|
InvalidGrantError: If token is invalid
|
||||||
"""
|
"""
|
||||||
|
logger.debug("[OIDC USERINFO] ===========================================")
|
||||||
|
logger.debug("[OIDC USERINFO] require_valid_token() called")
|
||||||
|
logger.debug("[OIDC USERINFO] Request method: %s", request.method)
|
||||||
|
logger.debug("[OIDC USERINFO] Request URL: %s", request.url)
|
||||||
|
logger.debug("[OIDC USERINFO] Request headers: %s", dict(request.headers))
|
||||||
|
|
||||||
auth_header = request.headers.get("Authorization", "")
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
logger.debug("[OIDC USERINFO] Authorization header: %s", auth_header[:20] + "..." if len(auth_header) > 20 else auth_header)
|
||||||
|
|
||||||
if not auth_header.startswith("Bearer "):
|
if not auth_header.startswith("Bearer "):
|
||||||
|
logger.error("[OIDC USERINFO] Invalid Authorization header format - missing 'Bearer ' prefix")
|
||||||
raise InvalidGrantError("Invalid token: Missing or invalid Authorization header")
|
raise InvalidGrantError("Invalid token: Missing or invalid Authorization header")
|
||||||
|
|
||||||
token = auth_header[7:]
|
token = auth_header[7:]
|
||||||
claims = OIDCService.validate_access_token(token)
|
logger.debug("[OIDC USERINFO] Token extracted (first 50 chars): %s...", token[:50] if len(token) > 50 else token)
|
||||||
g.current_token = claims
|
logger.debug("[OIDC USERINFO] Token length: %d", len(token))
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.debug("[OIDC USERINFO] Calling OIDCService.validate_access_token()...")
|
||||||
|
claims = OIDCService.validate_access_token(token)
|
||||||
|
logger.debug("[OIDC USERINFO] Token validation successful")
|
||||||
|
logger.debug("[OIDC USERINFO] Token claims: %s", claims)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("[OIDC USERINFO] Token validation failed: %s: %s", type(e).__name__, str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
g.current_token = claims
|
||||||
|
g.access_token = token # Store the original access token for get_userinfo()
|
||||||
|
logger.debug("[OIDC USERINFO] g.current_token set")
|
||||||
|
|
||||||
|
user_id = claims.get("sub")
|
||||||
|
logger.debug("[OIDC USERINFO] User ID from token: %s", user_id)
|
||||||
|
|
||||||
|
user = User.query.get(user_id)
|
||||||
|
logger.debug("[OIDC USERINFO] User query result: %s", user)
|
||||||
|
|
||||||
user = User.query.get(claims.get("sub"))
|
|
||||||
if not user:
|
if not user:
|
||||||
|
logger.error("[OIDC USERINFO] User not found in database: user_id=%s", user_id)
|
||||||
raise InvalidGrantError("Invalid token: User not found")
|
raise InvalidGrantError("Invalid token: User not found")
|
||||||
|
|
||||||
g.current_user = user
|
g.current_user = user
|
||||||
|
logger.debug("[OIDC USERINFO] g.current_user set: user_id=%s, email=%s", user.id, user.email)
|
||||||
|
logger.debug("[OIDC USERINFO] require_valid_token() completed successfully")
|
||||||
|
|
||||||
|
|
||||||
def _check_password_hash(client, password):
|
def _check_password_hash(client, password):
|
||||||
@@ -175,10 +206,11 @@ def oidc_discovery():
|
|||||||
No authentication required.
|
No authentication required.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
200: OIDC discovery document
|
200: OIDC discovery document (application/json)
|
||||||
"""
|
"""
|
||||||
config = get_oidc_config()
|
config = get_oidc_config()
|
||||||
|
|
||||||
|
# Return discovery document as application/json (per OpenID Connect Discovery 1.0)
|
||||||
response = jsonify(config)
|
response = jsonify(config)
|
||||||
response.headers["Cache-Control"] = "max-age=86400"
|
response.headers["Cache-Control"] = "max-age=86400"
|
||||||
return response, 200
|
return response, 200
|
||||||
@@ -217,8 +249,12 @@ def oidc_authorize():
|
|||||||
200: Login page (GET when not authenticated)
|
200: Login page (GET when not authenticated)
|
||||||
400: Invalid request
|
400: Invalid request
|
||||||
"""
|
"""
|
||||||
|
logger.debug("[OIDC] ===========================================")
|
||||||
logger.debug("[OIDC] oidc_authorize called")
|
logger.debug("[OIDC] oidc_authorize called")
|
||||||
|
logger.debug("[OIDC] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
|
logger.debug("[OIDC] Request method: %s", request.method)
|
||||||
|
logger.debug("[OIDC] Request URL: %s", request.url)
|
||||||
|
logger.debug("[OIDC] Remote address: %s", request.remote_addr)
|
||||||
|
|
||||||
# Parse request parameters
|
# Parse request parameters
|
||||||
if request.method == "GET":
|
if request.method == "GET":
|
||||||
@@ -227,6 +263,7 @@ def oidc_authorize():
|
|||||||
params = request.form.to_dict()
|
params = request.form.to_dict()
|
||||||
|
|
||||||
logger.debug("[OIDC] Raw request params: %s", params)
|
logger.debug("[OIDC] Raw request params: %s", params)
|
||||||
|
|
||||||
# Extract required parameters
|
# Extract required parameters
|
||||||
logger.debug("[OIDC] Extracting request parameters...")
|
logger.debug("[OIDC] Extracting request parameters...")
|
||||||
client_id = params.get("client_id")
|
client_id = params.get("client_id")
|
||||||
@@ -367,6 +404,7 @@ def oidc_authorize():
|
|||||||
|
|
||||||
# User is authenticated, generate authorization code
|
# User is authenticated, generate authorization code
|
||||||
logger.debug("[OIDC] User is authenticated, fetching user from database...")
|
logger.debug("[OIDC] User is authenticated, fetching user from database...")
|
||||||
|
logger.debug("[OIDC] Current UTC time before code generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
user = User.query.get(user_id)
|
user = User.query.get(user_id)
|
||||||
logger.debug("[OIDC] User query result: %s", user)
|
logger.debug("[OIDC] User query result: %s", user)
|
||||||
|
|
||||||
@@ -393,12 +431,17 @@ def oidc_authorize():
|
|||||||
user_agent=request.headers.get("User-Agent"),
|
user_agent=request.headers.get("User-Agent"),
|
||||||
)
|
)
|
||||||
logger.debug("[OIDC] Authorization code generated successfully: %s...", code[:20] if code else None)
|
logger.debug("[OIDC] Authorization code generated successfully: %s...", code[:20] if code else None)
|
||||||
|
logger.debug("[OIDC] Current UTC time after code generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("[OIDC] Authorization code generation failed: %s", str(e))
|
logger.error("[OIDC] Authorization code generation failed: %s", str(e))
|
||||||
|
logger.error("[OIDC] Current UTC time at failure: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
|
import traceback
|
||||||
|
logger.error("[OIDC] Traceback: %s", traceback.format_exc())
|
||||||
return _redirect_with_error(redirect_uri, "server_error", str(e), state)
|
return _redirect_with_error(redirect_uri, "server_error", str(e), state)
|
||||||
|
|
||||||
# Redirect with authorization code
|
# Redirect with authorization code
|
||||||
logger.debug("[OIDC] Redirecting with authorization code...")
|
logger.debug("[OIDC] Redirecting with authorization code...")
|
||||||
|
logger.debug("[OIDC] Current UTC time before redirect: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
redirect_params = {"code": code}
|
redirect_params = {"code": code}
|
||||||
if state:
|
if state:
|
||||||
redirect_params["state"] = state
|
redirect_params["state"] = state
|
||||||
@@ -406,6 +449,7 @@ def oidc_authorize():
|
|||||||
redirect_url = f"{redirect_uri}?{urlencode(redirect_params)}"
|
redirect_url = f"{redirect_uri}?{urlencode(redirect_params)}"
|
||||||
logger.debug("[OIDC] Redirect URL: %s...", redirect_url[:100] if redirect_url else None)
|
logger.debug("[OIDC] Redirect URL: %s...", redirect_url[:100] if redirect_url else None)
|
||||||
logger.debug("[OIDC] oidc_authorize completed successfully")
|
logger.debug("[OIDC] oidc_authorize completed successfully")
|
||||||
|
logger.debug("[OIDC] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
logger.debug("[OIDC] ===========================================")
|
logger.debug("[OIDC] ===========================================")
|
||||||
|
|
||||||
return redirect(redirect_url)
|
return redirect(redirect_url)
|
||||||
@@ -544,14 +588,13 @@ def oidc_token():
|
|||||||
|
|
||||||
# Validate grant_type
|
# Validate grant_type
|
||||||
if not grant_type:
|
if not grant_type:
|
||||||
logger.error("[OIDC] grant_type is requred")
|
logger.error("[OIDC] grant_type is required")
|
||||||
return api_response(
|
# RFC 6749 Section 5.2: Error response for invalid request
|
||||||
success=False,
|
response = jsonify({
|
||||||
message="grant_type is required",
|
"error": "invalid_request",
|
||||||
status=400,
|
"error_description": "grant_type is required"
|
||||||
error_type="INVALID_REQUEST",
|
})
|
||||||
error_details={"error": "invalid_request", "error_description": "grant_type is required"},
|
return response, 400
|
||||||
)
|
|
||||||
|
|
||||||
# Authenticate client
|
# Authenticate client
|
||||||
client_id = data.get("client_id")
|
client_id = data.get("client_id")
|
||||||
@@ -600,46 +643,51 @@ def oidc_token():
|
|||||||
# Unsupported grant type
|
# Unsupported grant type
|
||||||
else:
|
else:
|
||||||
logger.error("[OIDC] Unsupported grant_type")
|
logger.error("[OIDC] Unsupported grant_type")
|
||||||
return api_response(
|
# RFC 6749 Section 5.2: Error response for unsupported grant type
|
||||||
success=False,
|
response = jsonify({
|
||||||
message="Unsupported grant_type",
|
"error": "unsupported_grant_type",
|
||||||
status=400,
|
"error_description": f"Grant type '{grant_type}' is not supported"
|
||||||
error_type="UNSUPPORTED_GRANT_TYPE",
|
})
|
||||||
error_details={"error": "unsupported_grant_type", "error_description": f"Grant type '{grant_type}' is not supported"},
|
return response, 400
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_authorization_code_grant(data, client):
|
def _handle_authorization_code_grant(data, client):
|
||||||
"""Handle authorization_code grant type."""
|
"""Handle authorization_code grant type."""
|
||||||
|
logger.debug("[OIDC] ===========================================")
|
||||||
|
logger.debug("[OIDC] _handle_authorization_code_grant called")
|
||||||
|
logger.debug("[OIDC] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
|
|
||||||
code = data.get("code")
|
code = data.get("code")
|
||||||
redirect_uri = data.get("redirect_uri")
|
redirect_uri = data.get("redirect_uri")
|
||||||
code_verifier = data.get("code_verifier")
|
code_verifier = data.get("code_verifier")
|
||||||
|
|
||||||
|
logger.debug("[OIDC] Code provided: %s", bool(code))
|
||||||
|
logger.debug("[OIDC] Redirect URI: %s", redirect_uri)
|
||||||
|
logger.debug("[OIDC] Code verifier provided: %s", bool(code_verifier))
|
||||||
|
|
||||||
if not code:
|
if not code:
|
||||||
logger.error("[OIDC] code is required")
|
logger.error("[OIDC] code is required")
|
||||||
return api_response(
|
# RFC 6749 Section 5.2: Error response for invalid request
|
||||||
success=False,
|
response = jsonify({
|
||||||
message="code is required",
|
"error": "invalid_request",
|
||||||
status=400,
|
"error_description": "code is required"
|
||||||
error_type="INVALID_REQUEST",
|
})
|
||||||
error_details={"error": "invalid_request", "error_description": "code is required"},
|
return response, 400
|
||||||
)
|
|
||||||
|
|
||||||
if not redirect_uri:
|
if not redirect_uri:
|
||||||
logger.error("[OIDC] redirect_uri is required")
|
logger.error("[OIDC] redirect_uri is required")
|
||||||
return api_response(
|
response = jsonify({
|
||||||
success=False,
|
"error": "invalid_request",
|
||||||
message="redirect_uri is required",
|
"error_description": "redirect_uri is required"
|
||||||
status=400,
|
})
|
||||||
error_type="INVALID_REQUEST",
|
return response, 400
|
||||||
error_details={"error": "invalid_request", "error_description": "redirect_uri is required"},
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Development-only debug logging for authorization code validation
|
# Development-only debug logging for authorization code validation
|
||||||
if current_app.config.get('ENV') == 'development':
|
if current_app.config.get('ENV') == 'development':
|
||||||
logger.debug(f"[OIDC] Authorization code validation: client_id={client.client_id}, code_provided=True")
|
logger.debug(f"[OIDC] Authorization code validation: client_id={client.client_id}, code_provided=True")
|
||||||
|
|
||||||
|
logger.debug("[OIDC] Current UTC time before code validation: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
claims, user = OIDCService.validate_authorization_code(
|
claims, user = OIDCService.validate_authorization_code(
|
||||||
code=code,
|
code=code,
|
||||||
client_id=client.client_id,
|
client_id=client.client_id,
|
||||||
@@ -649,23 +697,22 @@ def _handle_authorization_code_grant(data, client):
|
|||||||
user_agent=request.headers.get("User-Agent"),
|
user_agent=request.headers.get("User-Agent"),
|
||||||
)
|
)
|
||||||
except InvalidGrantError as e:
|
except InvalidGrantError as e:
|
||||||
logger.error(f"[OIDC] INVALID_GRANT: {str(e)}")
|
logger.error("[OIDC] INVALID_GRANT: %s", str(e))
|
||||||
return api_response(
|
logger.error("[OIDC] Current UTC time at validation failure: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
success=False,
|
# RFC 6749 Section 5.2: Error response for invalid grant
|
||||||
message=str(e),
|
response = jsonify({
|
||||||
status=400,
|
"error": "invalid_grant",
|
||||||
error_type="INVALID_GRANT",
|
"error_description": str(e)
|
||||||
error_details={"error": "invalid_grant", "error_description": str(e)},
|
})
|
||||||
)
|
return response, 400
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[OIDC] Authorization code validation error: {type(e).__name__}: {str(e)}")
|
logger.error("[OIDC] Authorization code validation error: %s: %s", type(e).__name__, str(e))
|
||||||
return api_response(
|
logger.error("[OIDC] Current UTC time at validation error: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
success=False,
|
response = jsonify({
|
||||||
message=str(e),
|
"error": "invalid_grant",
|
||||||
status=400,
|
"error_description": str(e)
|
||||||
error_type="INVALID_GRANT",
|
})
|
||||||
error_details={"error": "invalid_grant", "error_description": str(e)},
|
return response, 400
|
||||||
)
|
|
||||||
|
|
||||||
# Generate tokens
|
# Generate tokens
|
||||||
try:
|
try:
|
||||||
@@ -673,6 +720,8 @@ def _handle_authorization_code_grant(data, client):
|
|||||||
if current_app.config.get('ENV') == 'development':
|
if current_app.config.get('ENV') == 'development':
|
||||||
logger.debug(f"[OIDC] Token generation: client_id={client.client_id}, user_id={claims['user_id']}, scope={claims['scope']}")
|
logger.debug(f"[OIDC] Token generation: client_id={client.client_id}, user_id={claims['user_id']}, scope={claims['scope']}")
|
||||||
|
|
||||||
|
logger.debug("[OIDC] Current UTC time before token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
|
|
||||||
tokens = OIDCService.generate_tokens(
|
tokens = OIDCService.generate_tokens(
|
||||||
client_id=client.client_id,
|
client_id=client.client_id,
|
||||||
user_id=claims["user_id"],
|
user_id=claims["user_id"],
|
||||||
@@ -683,40 +732,64 @@ def _handle_authorization_code_grant(data, client):
|
|||||||
auth_time=int(__import__("time").time()),
|
auth_time=int(__import__("time").time()),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[OIDC] Failed to generate tokens {str(e)}")
|
logger.error("[OIDC] Failed to generate tokens: %s", str(e))
|
||||||
return api_response(
|
logger.error("[OIDC] Current UTC time at token generation failure: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
success=False,
|
response = jsonify({
|
||||||
message="Failed to generate tokens",
|
"error": "server_error",
|
||||||
status=500,
|
"error_description": str(e)
|
||||||
error_type="SERVER_ERROR",
|
})
|
||||||
error_details={"error": "server_error", "error_description": str(e)},
|
return response, 500
|
||||||
)
|
|
||||||
|
|
||||||
return api_response(
|
# Return standard OAuth2/OIDC token response (application/json)
|
||||||
data=tokens,
|
# Per RFC 6749 Section 5.1 and OIDC Core 1.0
|
||||||
message="Tokens issued successfully",
|
logger.debug("[OIDC] Current UTC time after token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
status=200,
|
logger.debug("[OIDC] _handle_authorization_code_grant completed successfully")
|
||||||
)
|
|
||||||
|
# Echo tokens to console for diagnostics
|
||||||
|
print(f"[TOKEN DIAGNOSTICS] Authorization code exchange completed")
|
||||||
|
print(f"[TOKEN DIAGNOSTICS] Access Token: {tokens.get('access_token', '')}..." if len(tokens.get('access_token', '')) > 50 else f"[TOKEN DIAGNOSTICS] Access Token: {tokens.get('access_token', '')}")
|
||||||
|
print(f"[TOKEN DIAGNOSTICS] Token Type: {tokens.get('token_type', '')}")
|
||||||
|
print(f"[TOKEN DIAGNOSTICS] Expires In: {tokens.get('expires_in', '')}")
|
||||||
|
if 'id_token' in tokens:
|
||||||
|
print(f"[TOKEN DIAGNOSTICS] ID Token: {tokens['id_token']}..." if len(tokens['id_token']) > 50 else f"[TOKEN DIAGNOSTICS] ID Token: {tokens['id_token']}")
|
||||||
|
if 'refresh_token' in tokens:
|
||||||
|
print(f"[TOKEN DIAGNOSTICS] Refresh Token: {tokens['refresh_token'][:50]}..." if len(tokens['refresh_token']) > 50 else f"[TOKEN DIAGNOSTICS] Refresh Token: {tokens['refresh_token']}")
|
||||||
|
print(f"[TOKEN DIAGNOSTICS] Scope: {tokens.get('scope', '')}")
|
||||||
|
print(f"[TOKEN DIAGNOSTICS] ===========================================")
|
||||||
|
|
||||||
|
logger.debug("[OIDC] ===========================================")
|
||||||
|
response = jsonify(tokens)
|
||||||
|
print(tokens)
|
||||||
|
response.headers["Cache-Control"] = "no-store"
|
||||||
|
response.headers["Pragma"] = "no-cache"
|
||||||
|
return response, 200
|
||||||
|
|
||||||
|
|
||||||
def _handle_refresh_token_grant(data, client):
|
def _handle_refresh_token_grant(data, client):
|
||||||
"""Handle refresh_token grant type."""
|
"""Handle refresh_token grant type."""
|
||||||
|
logger.debug("[OIDC] ===========================================")
|
||||||
|
logger.debug("[OIDC] _handle_refresh_token_grant called")
|
||||||
|
logger.debug("[OIDC] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
|
|
||||||
refresh_token = data.get("refresh_token")
|
refresh_token = data.get("refresh_token")
|
||||||
scope = data.get("scope")
|
scope = data.get("scope")
|
||||||
|
|
||||||
|
logger.debug("[OIDC] Refresh token provided: %s", bool(refresh_token))
|
||||||
|
logger.debug("[OIDC] Scope: %s", scope)
|
||||||
|
|
||||||
if not refresh_token:
|
if not refresh_token:
|
||||||
return api_response(
|
# RFC 6749 Section 5.2: Error response for invalid request
|
||||||
success=False,
|
response = jsonify({
|
||||||
message="refresh_token is required",
|
"error": "invalid_request",
|
||||||
status=400,
|
"error_description": "refresh_token is required"
|
||||||
error_type="INVALID_REQUEST",
|
})
|
||||||
error_details={"error": "invalid_request", "error_description": "refresh_token is required"},
|
return response, 400
|
||||||
)
|
|
||||||
|
|
||||||
# Parse scope if provided
|
# Parse scope if provided
|
||||||
scope_list = scope.split() if scope else None
|
scope_list = scope.split() if scope else None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
logger.debug("[OIDC] Current UTC time before token refresh: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
tokens = OIDCService.refresh_access_token(
|
tokens = OIDCService.refresh_access_token(
|
||||||
refresh_token=refresh_token,
|
refresh_token=refresh_token,
|
||||||
client_id=client.client_id,
|
client_id=client.client_id,
|
||||||
@@ -725,19 +798,37 @@ def _handle_refresh_token_grant(data, client):
|
|||||||
user_agent=request.headers.get("User-Agent"),
|
user_agent=request.headers.get("User-Agent"),
|
||||||
)
|
)
|
||||||
except InvalidGrantError as e:
|
except InvalidGrantError as e:
|
||||||
return api_response(
|
logger.error("[OIDC] Refresh token error: %s", str(e))
|
||||||
success=False,
|
logger.error("[OIDC] Current UTC time at refresh failure: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
message=str(e),
|
# RFC 6749 Section 5.2: Error response for invalid grant
|
||||||
status=400,
|
response = jsonify({
|
||||||
error_type="INVALID_GRANT",
|
"error": "invalid_grant",
|
||||||
error_details={"error": "invalid_grant", "error_description": str(e)},
|
"error_description": str(e)
|
||||||
)
|
})
|
||||||
|
return response, 400
|
||||||
|
|
||||||
return api_response(
|
# Return standard OAuth2/OIDC token response (application/json)
|
||||||
data=tokens,
|
# Per RFC 6749 Section 5.1 and OIDC Core 1.0
|
||||||
message="Tokens refreshed successfully",
|
logger.debug("[OIDC] Current UTC time after token refresh: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
status=200,
|
logger.debug("[OIDC] _handle_refresh_token_grant completed successfully")
|
||||||
)
|
|
||||||
|
# Echo tokens to console for diagnostics
|
||||||
|
print(f"[TOKEN DIAGNOSTICS] Token refresh completed")
|
||||||
|
print(f"[TOKEN DIAGNOSTICS] Access Token: {tokens.get('access_token', '')[:50]}..." if len(tokens.get('access_token', '')) > 50 else f"[TOKEN DIAGNOSTICS] Access Token: {tokens.get('access_token', '')}")
|
||||||
|
print(f"[TOKEN DIAGNOSTICS] Token Type: {tokens.get('token_type', '')}")
|
||||||
|
print(f"[TOKEN DIAGNOSTICS] Expires In: {tokens.get('expires_in', '')}")
|
||||||
|
if 'id_token' in tokens:
|
||||||
|
print(f"[TOKEN DIAGNOSTICS] ID Token: {tokens['id_token'][:50]}..." if len(tokens['id_token']) > 50 else f"[TOKEN DIAGNOSTICS] ID Token: {tokens['id_token']}")
|
||||||
|
if 'refresh_token' in tokens:
|
||||||
|
print(f"[TOKEN DIAGNOSTICS] Refresh Token: {tokens['refresh_token'][:50]}..." if len(tokens['refresh_token']) > 50 else f"[TOKEN DIAGNOSTICS] Refresh Token: {tokens['refresh_token']}")
|
||||||
|
print(f"[TOKEN DIAGNOSTICS] Scope: {tokens.get('scope', '')}")
|
||||||
|
print(f"[TOKEN DIAGNOSTICS] ===========================================")
|
||||||
|
|
||||||
|
logger.debug("[OIDC] ===========================================")
|
||||||
|
response = jsonify(tokens)
|
||||||
|
response.headers["Cache-Control"] = "no-store"
|
||||||
|
response.headers["Pragma"] = "no-cache"
|
||||||
|
return response, 200
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@@ -759,37 +850,72 @@ def oidc_userinfo():
|
|||||||
- email_verified: Email verification status (if "email" scope)
|
- email_verified: Email verification status (if "email" scope)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
200: User claims
|
200: User claims in JSON format (application/json)
|
||||||
401: Invalid or insufficient token
|
401: Invalid or missing token (with WWW-Authenticate header per RFC 6750)
|
||||||
"""
|
"""
|
||||||
|
logger.debug("[OIDC USERINFO] ===========================================")
|
||||||
|
logger.debug("[OIDC USERINFO] oidc_userinfo() endpoint called")
|
||||||
|
logger.debug("[OIDC USERINFO] Request method: %s", request.method)
|
||||||
|
logger.debug("[OIDC USERINFO] Request URL: %s", request.url)
|
||||||
|
logger.debug("[OIDC USERINFO] Request content_type: %s", request.content_type)
|
||||||
|
logger.debug("[OIDC USERINFO] Request headers: %s", dict(request.headers))
|
||||||
|
logger.debug("[OIDC USERINFO] Request args: %s", dict(request.args))
|
||||||
|
logger.debug("[OIDC USERINFO] Request form: %s", dict(request.form))
|
||||||
|
request_json = request.get_json(silent=True)
|
||||||
|
logger.debug("[OIDC USERINFO] Request json: %s", request_json)
|
||||||
|
logger.debug("[OIDC USERINFO] Request data length: %d bytes", len(request.get_data()))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
logger.debug("[OIDC USERINFO] Calling require_valid_token()...")
|
||||||
require_valid_token()
|
require_valid_token()
|
||||||
|
logger.debug("[OIDC USERINFO] Token validation successful")
|
||||||
except InvalidGrantError as e:
|
except InvalidGrantError as e:
|
||||||
return api_response(
|
logger.error("[OIDC USERINFO] Token validation failed: %s", str(e))
|
||||||
success=False,
|
# RFC 6750 Section 3: Return 401 with WWW-Authenticate header for invalid tokens
|
||||||
message=str(e),
|
response = jsonify({
|
||||||
status=401,
|
"error": "invalid_token",
|
||||||
error_type="INVALID_TOKEN",
|
"error_description": str(e)
|
||||||
error_details={"error": "invalid_token", "error_description": str(e)},
|
})
|
||||||
)
|
response.headers["WWW-Authenticate"] = 'Bearer realm="OIDC UserInfo Endpoint", error="invalid_token", error_description="' + str(e) + '"'
|
||||||
|
return response, 401
|
||||||
# Get userinfo
|
|
||||||
try:
|
|
||||||
userinfo = OIDCService.get_userinfo(g.current_token.get("access_token", ""))
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return api_response(
|
logger.error("[OIDC USERINFO] Unexpected error during token validation: %s: %s", type(e).__name__, str(e))
|
||||||
success=False,
|
response = jsonify({
|
||||||
message="Failed to get user info",
|
"error": "server_error",
|
||||||
status=500,
|
"error_description": str(e)
|
||||||
error_type="SERVER_ERROR",
|
})
|
||||||
error_details={"error": "server_error", "error_description": str(e)},
|
response.headers["WWW-Authenticate"] = 'Bearer realm="OIDC UserInfo Endpoint", error="server_error"'
|
||||||
)
|
return response, 500
|
||||||
|
|
||||||
return api_response(
|
logger.debug("[OIDC USERINFO] g.current_token: %s", g.current_token)
|
||||||
data=userinfo,
|
logger.debug("[OIDC USERINFO] g.current_user: user_id=%s, email=%s", g.current_user.id, g.current_user.email)
|
||||||
message="User info retrieved successfully",
|
|
||||||
status=200,
|
# Get userinfo using the original access token
|
||||||
)
|
access_token = g.access_token
|
||||||
|
logger.debug("[OIDC USERINFO] Access token from g.access_token: %s...", access_token[:50] if len(access_token) > 50 else access_token)
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.debug("[OIDC USERINFO] Calling OIDCService.get_userinfo()...")
|
||||||
|
userinfo = OIDCService.get_userinfo(access_token)
|
||||||
|
logger.debug("[OIDC USERINFO] Userinfo retrieved successfully: %s", userinfo)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("[OIDC USERINFO] Failed to get user info: %s: %s", type(e).__name__, str(e))
|
||||||
|
import traceback
|
||||||
|
logger.error("[OIDC USERINFO] Traceback: %s", traceback.format_exc())
|
||||||
|
response = jsonify({
|
||||||
|
"error": "server_error",
|
||||||
|
"error_description": str(e)
|
||||||
|
})
|
||||||
|
return response, 500
|
||||||
|
|
||||||
|
logger.debug("[OIDC USERINFO] Returning userinfo response")
|
||||||
|
logger.debug("[OIDC USERINFO] ===========================================")
|
||||||
|
|
||||||
|
# Return standard OIDC UserInfo response (application/json)
|
||||||
|
# Per OpenID Connect Core 1.0 Section 5.3, response is a JSON object
|
||||||
|
response = jsonify(userinfo)
|
||||||
|
response.headers["Cache-Control"] = "no-cache, no-store"
|
||||||
|
return response, 200
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@@ -806,19 +932,18 @@ def oidc_jwks():
|
|||||||
No authentication required.
|
No authentication required.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
200: JWKS document
|
200: JWKS document (application/json)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
jwks = OIDCService.get_jwks()
|
jwks = OIDCService.get_jwks()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return api_response(
|
response = jsonify({
|
||||||
success=False,
|
"error": "server_error",
|
||||||
message="Failed to get JWKS",
|
"error_description": str(e)
|
||||||
status=500,
|
})
|
||||||
error_type="SERVER_ERROR",
|
return response, 500
|
||||||
error_details={"error": "server_error", "error_description": str(e)},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# Return JWKS as application/json (per OpenID Connect Discovery 1.0)
|
||||||
response = jsonify(jwks)
|
response = jsonify(jwks)
|
||||||
response.headers["Cache-Control"] = "max-age=3600"
|
response.headers["Cache-Control"] = "max-age=3600"
|
||||||
return response, 200
|
return response, 200
|
||||||
@@ -858,13 +983,12 @@ def oidc_revoke():
|
|||||||
token = data.get("token")
|
token = data.get("token")
|
||||||
|
|
||||||
if not token:
|
if not token:
|
||||||
return api_response(
|
# RFC 7009 Section 2.1: Error response for invalid request
|
||||||
success=False,
|
response = jsonify({
|
||||||
message="token is required",
|
"error": "invalid_request",
|
||||||
status=400,
|
"error_description": "token is required"
|
||||||
error_type="INVALID_REQUEST",
|
})
|
||||||
error_details={"error": "invalid_request", "error_description": "token is required"},
|
return response, 400
|
||||||
)
|
|
||||||
|
|
||||||
# Authenticate client
|
# Authenticate client
|
||||||
client_id = data.get("client_id")
|
client_id = data.get("client_id")
|
||||||
@@ -901,13 +1025,11 @@ def oidc_revoke():
|
|||||||
user_agent=request.headers.get("User-Agent"),
|
user_agent=request.headers.get("User-Agent"),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Revocation should succeed even if token is invalid
|
# Revocation should succeed even if token is invalid (RFC 7009)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return api_response(
|
# RFC 7009 Section 2.2: Successful revocation returns empty body with 200
|
||||||
message="Token revoked successfully",
|
return "", 200
|
||||||
status=200,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@@ -944,13 +1066,12 @@ def oidc_introspect():
|
|||||||
token = data.get("token")
|
token = data.get("token")
|
||||||
|
|
||||||
if not token:
|
if not token:
|
||||||
return api_response(
|
# RFC 7009 Section 2.1: Error response for invalid request
|
||||||
success=False,
|
response = jsonify({
|
||||||
message="token is required",
|
"error": "invalid_request",
|
||||||
status=400,
|
"error_description": "token is required"
|
||||||
error_type="INVALID_REQUEST",
|
})
|
||||||
error_details={"error": "invalid_request", "error_description": "token is required"},
|
return response, 400
|
||||||
)
|
|
||||||
|
|
||||||
# Authenticate client
|
# Authenticate client
|
||||||
client_id = data.get("client_id")
|
client_id = data.get("client_id")
|
||||||
@@ -986,19 +1107,17 @@ def oidc_introspect():
|
|||||||
user_agent=request.headers.get("User-Agent"),
|
user_agent=request.headers.get("User-Agent"),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return api_response(
|
# RFC 7009 Section 2.2: Error response
|
||||||
success=False,
|
response = jsonify({
|
||||||
message="Failed to introspect token",
|
"error": "server_error",
|
||||||
status=500,
|
"error_description": str(e)
|
||||||
error_type="SERVER_ERROR",
|
})
|
||||||
error_details={"error": "server_error", "error_description": str(e)},
|
return response, 500
|
||||||
)
|
|
||||||
|
|
||||||
return api_response(
|
# RFC 7009 Section 2.3: Return introspection response (application/json)
|
||||||
data=result,
|
response = jsonify(result)
|
||||||
message="Token introspection successful",
|
response.headers["Cache-Control"] = "no-cache, no-store"
|
||||||
status=200,
|
return response, 200
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@@ -1030,22 +1149,18 @@ def oidc_register():
|
|||||||
redirect_uris = data.get("redirect_uris", [])
|
redirect_uris = data.get("redirect_uris", [])
|
||||||
|
|
||||||
if not client_name:
|
if not client_name:
|
||||||
return api_response(
|
response = jsonify({
|
||||||
success=False,
|
"error": "invalid_request",
|
||||||
message="client_name is required",
|
"error_description": "client_name is required"
|
||||||
status=400,
|
})
|
||||||
error_type="INVALID_REQUEST",
|
return response, 400
|
||||||
error_details={"error": "invalid_request", "error_description": "client_name is required"},
|
|
||||||
)
|
|
||||||
|
|
||||||
if not redirect_uris:
|
if not redirect_uris:
|
||||||
return api_response(
|
response = jsonify({
|
||||||
success=False,
|
"error": "invalid_request",
|
||||||
message="redirect_uris is required",
|
"error_description": "redirect_uris is required"
|
||||||
status=400,
|
})
|
||||||
error_type="INVALID_REQUEST",
|
return response, 400
|
||||||
error_details={"error": "invalid_request", "error_description": "redirect_uris is required"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate redirect_uris
|
# Validate redirect_uris
|
||||||
for uri in redirect_uris:
|
for uri in redirect_uris:
|
||||||
@@ -1054,13 +1169,11 @@ def oidc_register():
|
|||||||
if not parsed.scheme or not parsed.netloc:
|
if not parsed.scheme or not parsed.netloc:
|
||||||
raise ValueError(f"Invalid redirect URI: {uri}")
|
raise ValueError(f"Invalid redirect URI: {uri}")
|
||||||
except Exception:
|
except Exception:
|
||||||
return api_response(
|
response = jsonify({
|
||||||
success=False,
|
"error": "invalid_request",
|
||||||
message=f"Invalid redirect_uri: {uri}",
|
"error_description": f"Invalid redirect_uri: {uri}"
|
||||||
status=400,
|
})
|
||||||
error_type="INVALID_REQUEST",
|
return response, 400
|
||||||
error_details={"error": "invalid_request", "error_description": f"Invalid redirect_uri: {uri}"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate client credentials
|
# Generate client credentials
|
||||||
client_id = f"oidc_{secrets.token_urlsafe(16)}"
|
client_id = f"oidc_{secrets.token_urlsafe(16)}"
|
||||||
@@ -1092,7 +1205,7 @@ def oidc_register():
|
|||||||
redirect_uris=redirect_uris,
|
redirect_uris=redirect_uris,
|
||||||
grant_types=data.get("grant_types", ["authorization_code", "refresh_token"]),
|
grant_types=data.get("grant_types", ["authorization_code", "refresh_token"]),
|
||||||
response_types=data.get("response_types", ["code"]),
|
response_types=data.get("response_types", ["code"]),
|
||||||
scopes=data.get("scope", "openid profile email").split(),
|
scopes=data.get("scope", "openid profile email roles").split(),
|
||||||
token_endpoint_auth_method=data.get("token_endpoint_auth_method", "client_secret_basic"),
|
token_endpoint_auth_method=data.get("token_endpoint_auth_method", "client_secret_basic"),
|
||||||
is_active=True,
|
is_active=True,
|
||||||
is_confidential=True,
|
is_confidential=True,
|
||||||
@@ -1105,19 +1218,16 @@ def oidc_register():
|
|||||||
client.save()
|
client.save()
|
||||||
|
|
||||||
# Return client credentials
|
# Return client credentials
|
||||||
return api_response(
|
response = jsonify({
|
||||||
data={
|
"client_id": client_id,
|
||||||
"client_id": client_id,
|
"client_secret": client_secret,
|
||||||
"client_secret": client_secret,
|
"client_id_issued_at": int(__import__("time").time()),
|
||||||
"client_id_issued_at": int(__import__("time").time()),
|
"client_secret_expires_at": 0, # Never expires
|
||||||
"client_secret_expires_at": 0, # Never expires
|
"client_name": client_name,
|
||||||
"client_name": client_name,
|
"redirect_uris": redirect_uris,
|
||||||
"redirect_uris": redirect_uris,
|
"token_endpoint_auth_method": data.get("token_endpoint_auth_method", "client_secret_basic"),
|
||||||
"token_endpoint_auth_method": data.get("token_endpoint_auth_method", "client_secret_basic"),
|
"grant_types": client.grant_types,
|
||||||
"grant_types": client.grant_types,
|
"response_types": client.response_types,
|
||||||
"response_types": client.response_types,
|
"scope": " ".join(client.scopes),
|
||||||
"scope": " ".join(client.scopes),
|
})
|
||||||
},
|
return response, 201
|
||||||
message="Client registered successfully",
|
|
||||||
status=201,
|
|
||||||
)
|
|
||||||
|
|||||||
+321
-4
@@ -3,11 +3,20 @@ from flask import request, session, g
|
|||||||
from marshmallow import ValidationError
|
from marshmallow import ValidationError
|
||||||
from app.api.v1 import api_v1_bp
|
from app.api.v1 import api_v1_bp
|
||||||
from app.utils.response import api_response
|
from app.utils.response import api_response
|
||||||
from app.schemas.auth_schema import RegisterSchema, LoginSchema
|
from app.schemas.auth_schema import (
|
||||||
|
RegisterSchema,
|
||||||
|
LoginSchema,
|
||||||
|
TOTPVerifyEnrollmentSchema,
|
||||||
|
TOTPVerifySchema,
|
||||||
|
TOTPDisableSchema,
|
||||||
|
TOTPRegenerateBackupCodesSchema,
|
||||||
|
)
|
||||||
from app.services.auth_service import AuthService
|
from app.services.auth_service import AuthService
|
||||||
from app.services.user_service import UserService
|
from app.services.user_service import UserService
|
||||||
from app.utils.decorators import login_required
|
from app.utils.decorators import login_required
|
||||||
from app.utils.constants import AuditAction
|
from app.utils.constants import AuditAction
|
||||||
|
from app.exceptions.auth_exceptions import InvalidCredentialsError
|
||||||
|
from app.exceptions.validation_exceptions import ConflictError
|
||||||
|
|
||||||
|
|
||||||
@api_v1_bp.route("/auth/register", methods=["POST"])
|
@api_v1_bp.route("/auth/register", methods=["POST"])
|
||||||
@@ -72,7 +81,7 @@ def login():
|
|||||||
remember_me: Optional boolean for extended session
|
remember_me: Optional boolean for extended session
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
200: Login successful
|
200: Login successful or TOTP code required
|
||||||
400: Validation error
|
400: Validation error
|
||||||
401: Invalid credentials
|
401: Invalid credentials
|
||||||
"""
|
"""
|
||||||
@@ -81,13 +90,29 @@ def login():
|
|||||||
schema = LoginSchema()
|
schema = LoginSchema()
|
||||||
data = schema.load(request.json)
|
data = schema.load(request.json)
|
||||||
|
|
||||||
# Authenticate user
|
# Authenticate user with email and password
|
||||||
user = AuthService.authenticate(
|
user = AuthService.authenticate(
|
||||||
email=data["email"],
|
email=data["email"],
|
||||||
password=data["password"],
|
password=data["password"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create session
|
# Check if user has TOTP enabled for two-factor authentication
|
||||||
|
if user.has_totp_enabled():
|
||||||
|
# TOTP is enabled - store user_id in session for TOTP verification
|
||||||
|
# The /auth/totp/verify endpoint will retrieve this user_id
|
||||||
|
session["totp_pending_user_id"] = user.id
|
||||||
|
|
||||||
|
# Return response indicating TOTP code is required
|
||||||
|
# Do NOT create session or return token yet - wait for TOTP verification
|
||||||
|
return api_response(
|
||||||
|
data={
|
||||||
|
"requires_totp": True,
|
||||||
|
},
|
||||||
|
message="TOTP code required. Please enter your 6-digit code from your authenticator app.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# TOTP is NOT enabled - proceed with normal login flow
|
||||||
|
# Create session with appropriate duration based on remember_me preference
|
||||||
duration = 2592000 if data.get("remember_me") else 86400 # 30 days vs 1 day
|
duration = 2592000 if data.get("remember_me") else 86400 # 30 days vs 1 day
|
||||||
user_session = AuthService.create_session(user, duration_seconds=duration)
|
user_session = AuthService.create_session(user, duration_seconds=duration)
|
||||||
|
|
||||||
@@ -210,3 +235,295 @@ def revoke_session(session_id):
|
|||||||
return api_response(
|
return api_response(
|
||||||
message="Session revoked successfully",
|
message="Session revoked successfully",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@api_v1_bp.route("/auth/totp/enroll", methods=["POST"])
|
||||||
|
@login_required
|
||||||
|
def enroll_totp():
|
||||||
|
"""
|
||||||
|
Initiate TOTP enrollment for the current user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
201: TOTP enrollment initiated with secret, provisioning_uri, qr_code, and backup_codes
|
||||||
|
401: Not authenticated
|
||||||
|
409: TOTP already enabled
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Initiate TOTP enrollment
|
||||||
|
result = AuthService.enroll_totp(g.current_user)
|
||||||
|
|
||||||
|
return api_response(
|
||||||
|
data={
|
||||||
|
"secret": result["secret"],
|
||||||
|
"provisioning_uri": result["provisioning_uri"],
|
||||||
|
"qr_code": result["qr_code"],
|
||||||
|
"backup_codes": result["backup_codes"],
|
||||||
|
},
|
||||||
|
message="TOTP enrollment initiated. Please verify with your authenticator app.",
|
||||||
|
status=201,
|
||||||
|
)
|
||||||
|
|
||||||
|
except ConflictError as e:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message=e.message,
|
||||||
|
status=e.status_code,
|
||||||
|
error_type=e.error_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@api_v1_bp.route("/auth/totp/verify-enrollment", methods=["POST"])
|
||||||
|
@login_required
|
||||||
|
def verify_totp_enrollment():
|
||||||
|
"""
|
||||||
|
Complete TOTP enrollment by verifying the first TOTP code.
|
||||||
|
|
||||||
|
Request body:
|
||||||
|
code: 6-digit TOTP code from authenticator app
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
200: TOTP enrollment completed successfully
|
||||||
|
400: Validation error
|
||||||
|
401: Not authenticated
|
||||||
|
401: Invalid TOTP code
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Validate request data
|
||||||
|
schema = TOTPVerifyEnrollmentSchema()
|
||||||
|
data = schema.load(request.json)
|
||||||
|
|
||||||
|
# Verify TOTP enrollment
|
||||||
|
AuthService.verify_totp_enrollment(g.current_user, data["code"])
|
||||||
|
|
||||||
|
return api_response(
|
||||||
|
message="TOTP enrollment completed successfully",
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValidationError as e:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="Validation failed",
|
||||||
|
status=400,
|
||||||
|
error_type="VALIDATION_ERROR",
|
||||||
|
error_details=e.messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
except InvalidCredentialsError as e:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message=e.message,
|
||||||
|
status=e.status_code,
|
||||||
|
error_type=e.error_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@api_v1_bp.route("/auth/totp/verify", methods=["POST"])
|
||||||
|
def verify_totp():
|
||||||
|
"""
|
||||||
|
Verify TOTP code during login.
|
||||||
|
|
||||||
|
Request body:
|
||||||
|
code: 6-digit TOTP code or backup code
|
||||||
|
is_backup_code: True if code is a backup code, False if TOTP code (default: False)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
200: TOTP code verified successfully with session token
|
||||||
|
400: Validation error
|
||||||
|
401: Invalid TOTP code or session not found
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Validate request data
|
||||||
|
schema = TOTPVerifySchema()
|
||||||
|
data = schema.load(request.json)
|
||||||
|
|
||||||
|
# Get user from temporary session (stored in Flask session by login endpoint)
|
||||||
|
user_id = session.get("totp_pending_user_id")
|
||||||
|
if not user_id:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="No pending TOTP verification. Please login first.",
|
||||||
|
status=401,
|
||||||
|
error_type="AUTHENTICATION_ERROR",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get user from database
|
||||||
|
from app.models.user import User
|
||||||
|
user = User.query.get(user_id)
|
||||||
|
if not user:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="User not found",
|
||||||
|
status=401,
|
||||||
|
error_type="AUTHENTICATION_ERROR",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify TOTP code
|
||||||
|
AuthService.authenticate_with_totp(
|
||||||
|
user, data["code"], data.get("is_backup_code", False)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create full session
|
||||||
|
user_session = AuthService.create_session(user)
|
||||||
|
|
||||||
|
# Clear temporary session
|
||||||
|
session.pop("totp_pending_user_id", None)
|
||||||
|
|
||||||
|
return api_response(
|
||||||
|
data={
|
||||||
|
"user": user.to_dict(),
|
||||||
|
"token": user_session.token,
|
||||||
|
"expires_at": user_session.expires_at.isoformat() + "Z"
|
||||||
|
if user_session.expires_at.isoformat()[-1] != "Z"
|
||||||
|
else user_session.expires_at.isoformat(),
|
||||||
|
},
|
||||||
|
message="TOTP verification successful",
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValidationError as e:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="Validation failed",
|
||||||
|
status=400,
|
||||||
|
error_type="VALIDATION_ERROR",
|
||||||
|
error_details=e.messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
except InvalidCredentialsError as e:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message=e.message,
|
||||||
|
status=e.status_code,
|
||||||
|
error_type=e.error_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@api_v1_bp.route("/auth/totp/disable", methods=["DELETE"])
|
||||||
|
@login_required
|
||||||
|
def disable_totp():
|
||||||
|
"""
|
||||||
|
Disable TOTP for the current user.
|
||||||
|
|
||||||
|
Request body:
|
||||||
|
password: User's current password for verification
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
200: TOTP disabled successfully
|
||||||
|
400: Validation error
|
||||||
|
401: Not authenticated or invalid password
|
||||||
|
401: TOTP not enabled
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Validate request data
|
||||||
|
schema = TOTPDisableSchema()
|
||||||
|
data = schema.load(request.json)
|
||||||
|
|
||||||
|
# Disable TOTP
|
||||||
|
AuthService.disable_totp(g.current_user, data["password"])
|
||||||
|
|
||||||
|
return api_response(
|
||||||
|
message="TOTP disabled successfully",
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValidationError as e:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="Validation failed",
|
||||||
|
status=400,
|
||||||
|
error_type="VALIDATION_ERROR",
|
||||||
|
error_details=e.messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
except InvalidCredentialsError as e:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message=e.message,
|
||||||
|
status=e.status_code,
|
||||||
|
error_type=e.error_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@api_v1_bp.route("/auth/totp/status", methods=["GET"])
|
||||||
|
@login_required
|
||||||
|
def get_totp_status():
|
||||||
|
"""
|
||||||
|
Get TOTP status for the current user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
200: TOTP status with totp_enabled, verified_at, and backup_codes_remaining
|
||||||
|
401: Not authenticated
|
||||||
|
"""
|
||||||
|
user = g.current_user
|
||||||
|
|
||||||
|
# Check if TOTP is enabled
|
||||||
|
totp_enabled = user.has_totp_enabled()
|
||||||
|
|
||||||
|
# Get TOTP method to check backup codes remaining
|
||||||
|
backup_codes_remaining = 0
|
||||||
|
verified_at = None
|
||||||
|
|
||||||
|
if totp_enabled:
|
||||||
|
totp_method = user.get_totp_method()
|
||||||
|
if totp_method and totp_method.provider_data:
|
||||||
|
backup_codes = totp_method.provider_data.get("backup_codes", [])
|
||||||
|
backup_codes_remaining = len(backup_codes)
|
||||||
|
if totp_method and totp_method.totp_verified_at:
|
||||||
|
verified_at = totp_method.totp_verified_at.isoformat() + "Z" if totp_method.totp_verified_at.isoformat()[-1] != "Z" else totp_method.totp_verified_at.isoformat()
|
||||||
|
|
||||||
|
return api_response(
|
||||||
|
data={
|
||||||
|
"totp_enabled": totp_enabled,
|
||||||
|
"verified_at": verified_at,
|
||||||
|
"backup_codes_remaining": backup_codes_remaining,
|
||||||
|
},
|
||||||
|
message="TOTP status retrieved successfully",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@api_v1_bp.route("/auth/totp/regenerate-backup-codes", methods=["POST"])
|
||||||
|
@login_required
|
||||||
|
def regenerate_totp_backup_codes():
|
||||||
|
"""
|
||||||
|
Generate new backup codes for TOTP.
|
||||||
|
|
||||||
|
Request body:
|
||||||
|
password: User's current password for verification
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
200: New backup codes generated successfully
|
||||||
|
400: Validation error
|
||||||
|
401: Not authenticated or invalid password
|
||||||
|
401: TOTP not enabled
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Validate request data
|
||||||
|
schema = TOTPRegenerateBackupCodesSchema()
|
||||||
|
data = schema.load(request.json)
|
||||||
|
|
||||||
|
# Regenerate backup codes
|
||||||
|
backup_codes = AuthService.regenerate_totp_backup_codes(
|
||||||
|
g.current_user, data["password"]
|
||||||
|
)
|
||||||
|
|
||||||
|
return api_response(
|
||||||
|
data={
|
||||||
|
"backup_codes": backup_codes,
|
||||||
|
},
|
||||||
|
message="Backup codes regenerated successfully",
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValidationError as e:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="Validation failed",
|
||||||
|
status=400,
|
||||||
|
error_type="VALIDATION_ERROR",
|
||||||
|
error_details=e.messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
except InvalidCredentialsError as e:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message=e.message,
|
||||||
|
status=e.status_code,
|
||||||
|
error_type=e.error_type,
|
||||||
|
)
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ def setup_cors(app):
|
|||||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID"
|
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID"
|
||||||
response.headers["Access-Control-Max-Age"] = "3600"
|
response.headers["Access-Control-Max-Age"] = "3600"
|
||||||
|
response.headers["Cache-Control"] = "no-cache, no-store"
|
||||||
return response
|
return response
|
||||||
elif origin and origin in cors_origins:
|
elif origin and origin in cors_origins:
|
||||||
response = make_response("", 204)
|
response = make_response("", 204)
|
||||||
@@ -34,6 +35,7 @@ def setup_cors(app):
|
|||||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID"
|
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID"
|
||||||
response.headers["Access-Control-Allow-Credentials"] = "true"
|
response.headers["Access-Control-Allow-Credentials"] = "true"
|
||||||
response.headers["Access-Control-Max-Age"] = "3600"
|
response.headers["Access-Control-Max-Age"] = "3600"
|
||||||
|
response.headers["Cache-Control"] = "no-cache, no-store"
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@app.after_request
|
@app.after_request
|
||||||
|
|||||||
@@ -51,4 +51,15 @@ class SecurityHeadersMiddleware:
|
|||||||
"geolocation=(), microphone=(), camera=()"
|
"geolocation=(), microphone=(), camera=()"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Cache-Control: Allow OIDC endpoints to set their own Cache-Control
|
||||||
|
# Only set no-cache for API responses that haven't set their own cache headers
|
||||||
|
if "Cache-Control" not in response.headers:
|
||||||
|
# Check if this is a JSON API response (shouldn't be cached)
|
||||||
|
content_type = response.headers.get("Content-Type", "")
|
||||||
|
if "application/json" in content_type:
|
||||||
|
response.headers["Cache-Control"] = "no-cache, no-store"
|
||||||
|
elif "text/html" not in content_type:
|
||||||
|
# For non-HTML responses, add Pragma for HTTP/1.0 compatibility
|
||||||
|
response.headers["Pragma"] = "no-cache"
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|||||||
@@ -19,6 +19,11 @@ class AuthenticationMethod(BaseModel):
|
|||||||
provider_user_id = db.Column(db.String(255), nullable=True)
|
provider_user_id = db.Column(db.String(255), nullable=True)
|
||||||
provider_data = db.Column(db.JSON, nullable=True)
|
provider_data = db.Column(db.JSON, nullable=True)
|
||||||
|
|
||||||
|
# # For TOTP authentication
|
||||||
|
# totp_secret = db.Column(db.String(32), nullable=True)
|
||||||
|
# totp_backup_codes = db.Column(db.JSON, nullable=True)
|
||||||
|
# totp_verified_at = db.Column(db.DateTime, nullable=True)
|
||||||
|
|
||||||
# Metadata
|
# Metadata
|
||||||
is_primary = db.Column(db.Boolean, default=False, nullable=False)
|
is_primary = db.Column(db.Boolean, default=False, nullable=False)
|
||||||
verified = db.Column(db.Boolean, default=False, nullable=False)
|
verified = db.Column(db.Boolean, default=False, nullable=False)
|
||||||
@@ -51,9 +56,15 @@ class AuthenticationMethod(BaseModel):
|
|||||||
AuthMethodType.MICROSOFT,
|
AuthMethodType.MICROSOFT,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def is_totp(self):
|
||||||
|
"""Check if this is a TOTP authentication method."""
|
||||||
|
return self.method_type == AuthMethodType.TOTP
|
||||||
|
|
||||||
def to_dict(self, exclude=None):
|
def to_dict(self, exclude=None):
|
||||||
"""Convert to dictionary, excluding sensitive fields."""
|
"""Convert to dictionary, excluding sensitive fields."""
|
||||||
exclude = exclude or []
|
exclude = exclude or []
|
||||||
# Always exclude password hash
|
# Always exclude password hash and TOTP secrets
|
||||||
exclude.append("password_hash")
|
exclude.append("password_hash")
|
||||||
|
exclude.append("totp_secret")
|
||||||
|
exclude.append("totp_backup_codes")
|
||||||
return super().to_dict(exclude=exclude)
|
return super().to_dict(exclude=exclude)
|
||||||
|
|||||||
+5
-5
@@ -1,6 +1,6 @@
|
|||||||
"""Base model with common fields and functionality."""
|
"""Base model with common fields and functionality."""
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
from app.extensions import db
|
from app.extensions import db
|
||||||
|
|
||||||
|
|
||||||
@@ -16,9 +16,9 @@ class BaseModel(db.Model):
|
|||||||
unique=True,
|
unique=True,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
)
|
)
|
||||||
created_at = db.Column(db.DateTime, nullable=False, default=datetime.utcnow)
|
created_at = db.Column(db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc))
|
||||||
updated_at = db.Column(
|
updated_at = db.Column(
|
||||||
db.DateTime, nullable=False, default=datetime.utcnow, onupdate=datetime.utcnow
|
db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)
|
||||||
)
|
)
|
||||||
deleted_at = db.Column(db.DateTime, nullable=True)
|
deleted_at = db.Column(db.DateTime, nullable=True)
|
||||||
|
|
||||||
@@ -36,7 +36,7 @@ class BaseModel(db.Model):
|
|||||||
soft: If True, performs soft delete. If False, hard delete.
|
soft: If True, performs soft delete. If False, hard delete.
|
||||||
"""
|
"""
|
||||||
if soft:
|
if soft:
|
||||||
self.deleted_at = datetime.utcnow()
|
self.deleted_at = datetime.now(timezone.utc)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
else:
|
else:
|
||||||
db.session.delete(self)
|
db.session.delete(self)
|
||||||
@@ -47,7 +47,7 @@ class BaseModel(db.Model):
|
|||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
if hasattr(self, key):
|
if hasattr(self, key):
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
self.updated_at = datetime.utcnow()
|
self.updated_at = datetime.now(timezone.utc)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""OIDC Authorization Code model for auth code flow."""
|
"""OIDC Authorization Code model for auth code flow."""
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta, timezone
|
||||||
from app.extensions import db
|
from app.extensions import db
|
||||||
from app.models.base import BaseModel
|
from app.models.base import BaseModel
|
||||||
|
|
||||||
@@ -49,7 +49,12 @@ class OIDCAuthCode(BaseModel):
|
|||||||
|
|
||||||
def is_expired(self):
|
def is_expired(self):
|
||||||
"""Check if the authorization code has expired."""
|
"""Check if the authorization code has expired."""
|
||||||
return datetime.utcnow() > self.expires_at
|
# Handle both timezone-aware and timezone-naive expires_at values
|
||||||
|
expires_at = self.expires_at
|
||||||
|
if expires_at.tzinfo is None:
|
||||||
|
# Make naive datetime timezone-aware (UTC)
|
||||||
|
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||||
|
return datetime.now(timezone.utc) > expires_at
|
||||||
|
|
||||||
def is_valid(self):
|
def is_valid(self):
|
||||||
"""Check if the authorization code is valid for use."""
|
"""Check if the authorization code is valid for use."""
|
||||||
@@ -58,7 +63,7 @@ class OIDCAuthCode(BaseModel):
|
|||||||
def mark_as_used(self):
|
def mark_as_used(self):
|
||||||
"""Mark the authorization code as used."""
|
"""Mark the authorization code as used."""
|
||||||
self.is_used = True
|
self.is_used = True
|
||||||
self.used_at = datetime.utcnow()
|
self.used_at = datetime.now(timezone.utc)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -90,7 +95,7 @@ class OIDCAuthCode(BaseModel):
|
|||||||
scope=scope,
|
scope=scope,
|
||||||
nonce=nonce,
|
nonce=nonce,
|
||||||
code_verifier=code_verifier,
|
code_verifier=code_verifier,
|
||||||
expires_at=datetime.utcnow() + timedelta(seconds=lifetime_seconds),
|
expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds),
|
||||||
ip_address=ip_address,
|
ip_address=ip_address,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""OIDC Refresh Token model for token rotation."""
|
"""OIDC Refresh Token model for token rotation."""
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
from app.extensions import db
|
from app.extensions import db
|
||||||
from app.models.base import BaseModel
|
from app.models.base import BaseModel
|
||||||
|
|
||||||
@@ -58,7 +58,11 @@ class OIDCRefreshToken(BaseModel):
|
|||||||
|
|
||||||
def is_expired(self):
|
def is_expired(self):
|
||||||
"""Check if the refresh token has expired."""
|
"""Check if the refresh token has expired."""
|
||||||
return datetime.utcnow() > self.expires_at
|
# Handle both timezone-aware and timezone-naive expires_at values
|
||||||
|
expires_at = self.expires_at
|
||||||
|
if expires_at.tzinfo is None:
|
||||||
|
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||||
|
return datetime.now(timezone.utc) > expires_at
|
||||||
|
|
||||||
def is_revoked(self):
|
def is_revoked(self):
|
||||||
"""Check if the refresh token has been revoked."""
|
"""Check if the refresh token has been revoked."""
|
||||||
@@ -74,7 +78,7 @@ class OIDCRefreshToken(BaseModel):
|
|||||||
Args:
|
Args:
|
||||||
reason: Optional reason for revocation
|
reason: Optional reason for revocation
|
||||||
"""
|
"""
|
||||||
self.revoked_at = datetime.utcnow()
|
self.revoked_at = datetime.now(timezone.utc)
|
||||||
self.revoked_reason = reason
|
self.revoked_reason = reason
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
@@ -93,7 +97,7 @@ class OIDCRefreshToken(BaseModel):
|
|||||||
self.rotation_count += 1
|
self.rotation_count += 1
|
||||||
# Extend expiration on rotation
|
# Extend expiration on rotation
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
self.expires_at = datetime.utcnow() + timedelta(days=30)
|
self.expires_at = datetime.now(timezone.utc) + timedelta(days=30)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@@ -123,7 +127,7 @@ class OIDCRefreshToken(BaseModel):
|
|||||||
token_hash=token_hash,
|
token_hash=token_hash,
|
||||||
scope=scope,
|
scope=scope,
|
||||||
access_token_id=access_token_id,
|
access_token_id=access_token_id,
|
||||||
expires_at=datetime.utcnow() + timedelta(seconds=lifetime_seconds),
|
expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds),
|
||||||
ip_address=ip_address,
|
ip_address=ip_address,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""OIDC Session model for OIDC session tracking."""
|
"""OIDC Session model for OIDC session tracking."""
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
from app.extensions import db
|
from app.extensions import db
|
||||||
from app.models.base import BaseModel
|
from app.models.base import BaseModel
|
||||||
|
|
||||||
@@ -49,7 +49,7 @@ class OIDCSession(BaseModel):
|
|||||||
|
|
||||||
def is_expired(self):
|
def is_expired(self):
|
||||||
"""Check if the OIDC session has expired."""
|
"""Check if the OIDC session has expired."""
|
||||||
return datetime.utcnow() > self.expires_at
|
return datetime.now(timezone.utc) > self.expires_at
|
||||||
|
|
||||||
def is_authenticated(self):
|
def is_authenticated(self):
|
||||||
"""Check if the user has been authenticated in this session."""
|
"""Check if the user has been authenticated in this session."""
|
||||||
@@ -57,7 +57,7 @@ class OIDCSession(BaseModel):
|
|||||||
|
|
||||||
def mark_authenticated(self):
|
def mark_authenticated(self):
|
||||||
"""Mark the session as authenticated."""
|
"""Mark the session as authenticated."""
|
||||||
self.authenticated_at = datetime.utcnow()
|
self.authenticated_at = datetime.now(timezone.utc)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
def validate_nonce(self, expected_nonce):
|
def validate_nonce(self, expected_nonce):
|
||||||
@@ -126,7 +126,7 @@ class OIDCSession(BaseModel):
|
|||||||
nonce=nonce,
|
nonce=nonce,
|
||||||
code_challenge=code_challenge,
|
code_challenge=code_challenge,
|
||||||
code_challenge_method=code_challenge_method,
|
code_challenge_method=code_challenge_method,
|
||||||
expires_at=datetime.utcnow() + timedelta(seconds=lifetime_seconds),
|
expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds),
|
||||||
)
|
)
|
||||||
db.session.add(session)
|
db.session.add(session)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""OIDC Token Metadata model for token revocation tracking."""
|
"""OIDC Token Metadata model for token revocation tracking."""
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
from app.extensions import db
|
from app.extensions import db
|
||||||
from app.models.base import BaseModel
|
from app.models.base import BaseModel
|
||||||
|
|
||||||
@@ -50,7 +50,11 @@ class OIDCTokenMetadata(BaseModel):
|
|||||||
|
|
||||||
def is_expired(self):
|
def is_expired(self):
|
||||||
"""Check if the token has expired."""
|
"""Check if the token has expired."""
|
||||||
return datetime.utcnow() > self.expires_at
|
# Handle both timezone-aware and timezone-naive expires_at values
|
||||||
|
expires_at = self.expires_at
|
||||||
|
if expires_at.tzinfo is None:
|
||||||
|
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||||
|
return datetime.now(timezone.utc) > expires_at
|
||||||
|
|
||||||
def is_revoked(self):
|
def is_revoked(self):
|
||||||
"""Check if the token has been revoked."""
|
"""Check if the token has been revoked."""
|
||||||
@@ -66,7 +70,7 @@ class OIDCTokenMetadata(BaseModel):
|
|||||||
Args:
|
Args:
|
||||||
reason: Optional reason for revocation
|
reason: Optional reason for revocation
|
||||||
"""
|
"""
|
||||||
self.revoked_at = datetime.utcnow()
|
self.revoked_at = datetime.now(timezone.utc)
|
||||||
self.revoked_reason = reason
|
self.revoked_reason = reason
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""Session model."""
|
"""Session model."""
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta, timezone
|
||||||
from app.extensions import db
|
from app.extensions import db
|
||||||
from app.models.base import BaseModel
|
from app.models.base import BaseModel
|
||||||
from app.utils.constants import SessionStatus
|
from app.utils.constants import SessionStatus
|
||||||
@@ -34,7 +34,7 @@ class Session(BaseModel):
|
|||||||
|
|
||||||
def is_active(self):
|
def is_active(self):
|
||||||
"""Check if session is currently active."""
|
"""Check if session is currently active."""
|
||||||
now = datetime.utcnow()
|
now = datetime.now(timezone.utc)
|
||||||
return (
|
return (
|
||||||
self.status == SessionStatus.ACTIVE
|
self.status == SessionStatus.ACTIVE
|
||||||
and self.expires_at > now
|
and self.expires_at > now
|
||||||
@@ -43,7 +43,7 @@ class Session(BaseModel):
|
|||||||
|
|
||||||
def is_expired(self):
|
def is_expired(self):
|
||||||
"""Check if session has expired."""
|
"""Check if session has expired."""
|
||||||
return datetime.utcnow() > self.expires_at
|
return datetime.now(timezone.utc) > self.expires_at
|
||||||
|
|
||||||
def refresh(self, duration_seconds=86400):
|
def refresh(self, duration_seconds=86400):
|
||||||
"""
|
"""
|
||||||
@@ -52,8 +52,8 @@ class Session(BaseModel):
|
|||||||
Args:
|
Args:
|
||||||
duration_seconds: New session duration in seconds
|
duration_seconds: New session duration in seconds
|
||||||
"""
|
"""
|
||||||
self.expires_at = datetime.utcnow() + timedelta(seconds=duration_seconds)
|
self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=duration_seconds)
|
||||||
self.last_activity_at = datetime.utcnow()
|
self.last_activity_at = datetime.now(timezone.utc)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
def revoke(self, reason=None):
|
def revoke(self, reason=None):
|
||||||
@@ -64,7 +64,7 @@ class Session(BaseModel):
|
|||||||
reason: Optional reason for revocation
|
reason: Optional reason for revocation
|
||||||
"""
|
"""
|
||||||
self.status = SessionStatus.REVOKED
|
self.status = SessionStatus.REVOKED
|
||||||
self.revoked_at = datetime.utcnow()
|
self.revoked_at = datetime.now(timezone.utc)
|
||||||
if reason:
|
if reason:
|
||||||
self.revoked_reason = reason
|
self.revoked_reason = reason
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|||||||
@@ -59,3 +59,35 @@ class User(BaseModel):
|
|||||||
def get_organizations(self):
|
def get_organizations(self):
|
||||||
"""Get all organizations the user is a member of."""
|
"""Get all organizations the user is a member of."""
|
||||||
return [membership.organization for membership in self.organization_memberships]
|
return [membership.organization for membership in self.organization_memberships]
|
||||||
|
|
||||||
|
def has_totp_enabled(self) -> bool:
|
||||||
|
"""Check if user has TOTP enabled and verified.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if user has a verified TOTP authentication method, False otherwise.
|
||||||
|
"""
|
||||||
|
from app.models.authentication_method import AuthenticationMethod
|
||||||
|
from app.utils.constants import AuthMethodType
|
||||||
|
|
||||||
|
return (
|
||||||
|
AuthenticationMethod.query.filter_by(
|
||||||
|
user_id=self.id,
|
||||||
|
method_type=AuthMethodType.TOTP,
|
||||||
|
verified=True,
|
||||||
|
deleted_at=None,
|
||||||
|
).first()
|
||||||
|
is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_totp_method(self):
|
||||||
|
"""Get user's TOTP authentication method.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The AuthenticationMethod instance for TOTP or None if not found.
|
||||||
|
"""
|
||||||
|
from app.models.authentication_method import AuthenticationMethod
|
||||||
|
from app.utils.constants import AuthMethodType
|
||||||
|
|
||||||
|
return AuthenticationMethod.query.filter_by(
|
||||||
|
user_id=self.id, method_type=AuthMethodType.TOTP, deleted_at=None
|
||||||
|
).first()
|
||||||
|
|||||||
@@ -55,3 +55,34 @@ class ResetPasswordSchema(Schema):
|
|||||||
"""Validate that passwords match."""
|
"""Validate that passwords match."""
|
||||||
if data.get("password") != data.get("password_confirm"):
|
if data.get("password") != data.get("password_confirm"):
|
||||||
raise ValidationError("Passwords do not match", field_name="password_confirm")
|
raise ValidationError("Passwords do not match", field_name="password_confirm")
|
||||||
|
|
||||||
|
|
||||||
|
class TOTPVerifyEnrollmentSchema(Schema):
|
||||||
|
"""Schema for TOTP enrollment verification."""
|
||||||
|
|
||||||
|
code = fields.Str(
|
||||||
|
required=True,
|
||||||
|
validate=validate.Regexp(
|
||||||
|
r"^\d{6}$",
|
||||||
|
error="Code must be a 6-digit number",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TOTPVerifySchema(Schema):
|
||||||
|
"""Schema for TOTP code verification during login."""
|
||||||
|
|
||||||
|
code = fields.Str(required=True)
|
||||||
|
is_backup_code = fields.Bool(missing=False)
|
||||||
|
|
||||||
|
|
||||||
|
class TOTPDisableSchema(Schema):
|
||||||
|
"""Schema for disabling TOTP."""
|
||||||
|
|
||||||
|
password = fields.Str(required=True, validate=validate.Length(min=1))
|
||||||
|
|
||||||
|
|
||||||
|
class TOTPRegenerateBackupCodesSchema(Schema):
|
||||||
|
"""Schema for regenerating backup codes."""
|
||||||
|
|
||||||
|
password = fields.Str(required=True, validate=validate.Length(min=1))
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from app.utils.constants import AuthMethodType, SessionStatus, UserStatus, Audit
|
|||||||
from app.exceptions.auth_exceptions import InvalidCredentialsError, AccountSuspendedError, AccountInactiveError
|
from app.exceptions.auth_exceptions import InvalidCredentialsError, AccountSuspendedError, AccountInactiveError
|
||||||
from app.exceptions.validation_exceptions import EmailAlreadyExistsError
|
from app.exceptions.validation_exceptions import EmailAlreadyExistsError
|
||||||
from app.services.audit_service import AuditService
|
from app.services.audit_service import AuditService
|
||||||
|
from app.services.totp_service import TOTPService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -234,3 +235,315 @@ class AuthService:
|
|||||||
resource_id=session.id,
|
resource_id=session.id,
|
||||||
description=f"Session revoked: {reason or 'User logout'}",
|
description=f"Session revoked: {reason or 'User logout'}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def enroll_totp(user: User) -> dict:
|
||||||
|
"""
|
||||||
|
Initiate TOTP enrollment for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing:
|
||||||
|
- secret: TOTP secret (base32 encoded)
|
||||||
|
- provisioning_uri: otpauth:// URI for QR code
|
||||||
|
- qr_code: Base64 encoded QR code as data URI
|
||||||
|
- backup_codes: List of plain text backup codes
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ConflictError: If user already has TOTP enabled
|
||||||
|
"""
|
||||||
|
from app.exceptions.validation_exceptions import ConflictError
|
||||||
|
|
||||||
|
# Check if user already has TOTP enabled
|
||||||
|
if user.has_totp_enabled():
|
||||||
|
raise ConflictError("TOTP is already enabled for this account")
|
||||||
|
|
||||||
|
# Generate TOTP secret
|
||||||
|
secret = TOTPService.generate_secret()
|
||||||
|
|
||||||
|
# Generate provisioning URI
|
||||||
|
provisioning_uri = TOTPService.generate_provisioning_uri(
|
||||||
|
user_email=user.email,
|
||||||
|
secret=secret,
|
||||||
|
issuer="Gatehouse",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate QR code data URI
|
||||||
|
qr_code = TOTPService.generate_qr_code_data_uri(provisioning_uri)
|
||||||
|
|
||||||
|
# Generate backup codes
|
||||||
|
backup_codes, hashed_backup_codes = TOTPService.generate_backup_codes()
|
||||||
|
|
||||||
|
# Create unverified TOTP authentication method
|
||||||
|
auth_method = AuthenticationMethod(
|
||||||
|
user_id=user.id,
|
||||||
|
method_type=AuthMethodType.TOTP,
|
||||||
|
verified=False,
|
||||||
|
is_primary=False,
|
||||||
|
)
|
||||||
|
auth_method.save()
|
||||||
|
|
||||||
|
# Store TOTP data in provider_data (since totp_secret field is commented out)
|
||||||
|
auth_method.provider_data = {
|
||||||
|
"secret": secret,
|
||||||
|
"backup_codes": hashed_backup_codes,
|
||||||
|
}
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Log TOTP enrollment initiation
|
||||||
|
AuditService.log_action(
|
||||||
|
action=AuditAction.TOTP_ENROLL_INITIATED,
|
||||||
|
user_id=user.id,
|
||||||
|
resource_type="authentication_method",
|
||||||
|
resource_id=auth_method.id,
|
||||||
|
description="TOTP enrollment initiated",
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"secret": secret,
|
||||||
|
"provisioning_uri": provisioning_uri,
|
||||||
|
"qr_code": qr_code,
|
||||||
|
"backup_codes": backup_codes,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def verify_totp_enrollment(user: User, code: str) -> bool:
|
||||||
|
"""
|
||||||
|
Complete TOTP enrollment by verifying the first TOTP code.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User instance
|
||||||
|
code: 6-digit TOTP code from authenticator app
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if verification successful
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidCredentialsError: If code is invalid or TOTP method not found
|
||||||
|
"""
|
||||||
|
# Get user's TOTP authentication method
|
||||||
|
auth_method = user.get_totp_method()
|
||||||
|
if not auth_method:
|
||||||
|
raise InvalidCredentialsError("TOTP enrollment not found")
|
||||||
|
|
||||||
|
# Get secret from provider_data
|
||||||
|
secret = auth_method.provider_data.get("secret") if auth_method.provider_data else None
|
||||||
|
if not secret:
|
||||||
|
raise InvalidCredentialsError("TOTP secret not found")
|
||||||
|
|
||||||
|
# Verify the code
|
||||||
|
if not TOTPService.verify_code(secret, code):
|
||||||
|
raise InvalidCredentialsError("Invalid TOTP code")
|
||||||
|
|
||||||
|
# Mark TOTP as verified
|
||||||
|
auth_method.verified = True
|
||||||
|
auth_method.totp_verified_at = datetime.utcnow()
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Log TOTP enrollment completion
|
||||||
|
AuditService.log_action(
|
||||||
|
action=AuditAction.TOTP_ENROLL_COMPLETED,
|
||||||
|
user_id=user.id,
|
||||||
|
resource_type="authentication_method",
|
||||||
|
resource_id=auth_method.id,
|
||||||
|
description="TOTP enrollment completed",
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def disable_totp(user: User, password: str) -> bool:
|
||||||
|
"""
|
||||||
|
Disable TOTP for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User instance
|
||||||
|
password: User's current password for verification
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if TOTP disabled successfully
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidCredentialsError: If password is invalid or TOTP method not found
|
||||||
|
"""
|
||||||
|
# Verify user's password
|
||||||
|
auth_method = AuthenticationMethod.query.filter_by(
|
||||||
|
user_id=user.id,
|
||||||
|
method_type=AuthMethodType.PASSWORD,
|
||||||
|
deleted_at=None,
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not auth_method or not auth_method.password_hash:
|
||||||
|
raise InvalidCredentialsError("No password authentication method found")
|
||||||
|
|
||||||
|
if not bcrypt.check_password_hash(auth_method.password_hash, password):
|
||||||
|
raise InvalidCredentialsError("Invalid password")
|
||||||
|
|
||||||
|
# Get user's TOTP authentication method
|
||||||
|
totp_method = user.get_totp_method()
|
||||||
|
if not totp_method:
|
||||||
|
raise InvalidCredentialsError("TOTP is not enabled for this account")
|
||||||
|
|
||||||
|
# Soft-delete the TOTP authentication method
|
||||||
|
totp_method.delete(soft=True)
|
||||||
|
|
||||||
|
# Log TOTP disabled
|
||||||
|
AuditService.log_action(
|
||||||
|
action=AuditAction.TOTP_DISABLED,
|
||||||
|
user_id=user.id,
|
||||||
|
resource_type="authentication_method",
|
||||||
|
resource_id=totp_method.id,
|
||||||
|
description="TOTP disabled",
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def authenticate_with_totp(user: User, code: str, is_backup_code: bool = False) -> bool:
|
||||||
|
"""
|
||||||
|
Verify TOTP code during login.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User instance
|
||||||
|
code: 6-digit TOTP code or backup code
|
||||||
|
is_backup_code: True if code is a backup code, False if TOTP code
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if code is valid
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidCredentialsError: If code is invalid or TOTP method not found
|
||||||
|
"""
|
||||||
|
# Get user's TOTP authentication method
|
||||||
|
auth_method = user.get_totp_method()
|
||||||
|
if not auth_method:
|
||||||
|
raise InvalidCredentialsError("TOTP is not enabled for this account")
|
||||||
|
|
||||||
|
if is_backup_code:
|
||||||
|
# Verify backup code
|
||||||
|
backup_codes = (
|
||||||
|
auth_method.provider_data.get("backup_codes")
|
||||||
|
if auth_method.provider_data
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
is_valid, remaining_codes = TOTPService.verify_backup_code(backup_codes, code)
|
||||||
|
|
||||||
|
if is_valid:
|
||||||
|
# Update remaining backup codes
|
||||||
|
auth_method.provider_data = {
|
||||||
|
"secret": auth_method.provider_data.get("secret"),
|
||||||
|
"backup_codes": remaining_codes,
|
||||||
|
}
|
||||||
|
auth_method.last_used_at = datetime.utcnow()
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Log backup code usage
|
||||||
|
AuditService.log_action(
|
||||||
|
action=AuditAction.TOTP_BACKUP_CODE_USED,
|
||||||
|
user_id=user.id,
|
||||||
|
resource_type="authentication_method",
|
||||||
|
resource_id=auth_method.id,
|
||||||
|
description="Backup code used for authentication",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Log failed verification
|
||||||
|
AuditService.log_action(
|
||||||
|
action=AuditAction.TOTP_VERIFY_FAILED,
|
||||||
|
user_id=user.id,
|
||||||
|
resource_type="authentication_method",
|
||||||
|
resource_id=auth_method.id,
|
||||||
|
description="Invalid backup code provided",
|
||||||
|
)
|
||||||
|
raise InvalidCredentialsError("Invalid backup code")
|
||||||
|
else:
|
||||||
|
# Verify TOTP code
|
||||||
|
secret = (
|
||||||
|
auth_method.provider_data.get("secret")
|
||||||
|
if auth_method.provider_data
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if not secret:
|
||||||
|
raise InvalidCredentialsError("TOTP secret not found")
|
||||||
|
|
||||||
|
is_valid = TOTPService.verify_code(secret, code)
|
||||||
|
|
||||||
|
if is_valid:
|
||||||
|
auth_method.last_used_at = datetime.utcnow()
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Log successful verification
|
||||||
|
AuditService.log_action(
|
||||||
|
action=AuditAction.TOTP_VERIFY_SUCCESS,
|
||||||
|
user_id=user.id,
|
||||||
|
resource_type="authentication_method",
|
||||||
|
resource_id=auth_method.id,
|
||||||
|
description="TOTP code verified successfully",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Log failed verification
|
||||||
|
AuditService.log_action(
|
||||||
|
action=AuditAction.TOTP_VERIFY_FAILED,
|
||||||
|
user_id=user.id,
|
||||||
|
resource_type="authentication_method",
|
||||||
|
resource_id=auth_method.id,
|
||||||
|
description="Invalid TOTP code provided",
|
||||||
|
)
|
||||||
|
raise InvalidCredentialsError("Invalid TOTP code")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def regenerate_totp_backup_codes(user: User, password: str) -> list[str]:
|
||||||
|
"""
|
||||||
|
Generate new backup codes for TOTP.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User instance
|
||||||
|
password: User's current password for verification
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of new plain text backup codes
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidCredentialsError: If password is invalid or TOTP method not found
|
||||||
|
"""
|
||||||
|
# Verify user's password
|
||||||
|
auth_method = AuthenticationMethod.query.filter_by(
|
||||||
|
user_id=user.id,
|
||||||
|
method_type=AuthMethodType.PASSWORD,
|
||||||
|
deleted_at=None,
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not auth_method or not auth_method.password_hash:
|
||||||
|
raise InvalidCredentialsError("No password authentication method found")
|
||||||
|
|
||||||
|
if not bcrypt.check_password_hash(auth_method.password_hash, password):
|
||||||
|
raise InvalidCredentialsError("Invalid password")
|
||||||
|
|
||||||
|
# Get user's TOTP authentication method
|
||||||
|
totp_method = user.get_totp_method()
|
||||||
|
if not totp_method:
|
||||||
|
raise InvalidCredentialsError("TOTP is not enabled for this account")
|
||||||
|
|
||||||
|
# Generate new backup codes
|
||||||
|
backup_codes, hashed_backup_codes = TOTPService.generate_backup_codes()
|
||||||
|
|
||||||
|
# Update the authentication method with new backup codes
|
||||||
|
totp_method.provider_data = {
|
||||||
|
"secret": totp_method.provider_data.get("secret"),
|
||||||
|
"backup_codes": hashed_backup_codes,
|
||||||
|
}
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Log backup codes regeneration
|
||||||
|
AuditService.log_action(
|
||||||
|
action=AuditAction.TOTP_BACKUP_CODES_REGENERATED,
|
||||||
|
user_id=user.id,
|
||||||
|
resource_type="authentication_method",
|
||||||
|
resource_id=totp_method.id,
|
||||||
|
description="TOTP backup codes regenerated",
|
||||||
|
)
|
||||||
|
|
||||||
|
return backup_codes
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
import hashlib
|
import hashlib
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from flask import current_app, g
|
from flask import current_app, g
|
||||||
@@ -14,6 +14,7 @@ from app.models import (
|
|||||||
User, OIDCClient, OIDCAuthCode, OIDCRefreshToken,
|
User, OIDCClient, OIDCAuthCode, OIDCRefreshToken,
|
||||||
OIDCSession, OIDCTokenMetadata
|
OIDCSession, OIDCTokenMetadata
|
||||||
)
|
)
|
||||||
|
from app.models.organization_member import OrganizationMember
|
||||||
from app.exceptions.validation_exceptions import (
|
from app.exceptions.validation_exceptions import (
|
||||||
ValidationError, NotFoundError, BadRequestError
|
ValidationError, NotFoundError, BadRequestError
|
||||||
)
|
)
|
||||||
@@ -121,6 +122,14 @@ class OIDCService:
|
|||||||
ValidationError: If parameters are invalid
|
ValidationError: If parameters are invalid
|
||||||
NotFoundError: If client not found
|
NotFoundError: If client not found
|
||||||
"""
|
"""
|
||||||
|
logger.debug("[OIDC SERVICE] ===========================================")
|
||||||
|
logger.debug("[OIDC SERVICE] generate_authorization_code called")
|
||||||
|
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
|
logger.debug("[OIDC SERVICE] client_id=%s, user_id=%s", client_id, user_id)
|
||||||
|
logger.debug("[OIDC SERVICE] redirect_uri=%s", redirect_uri)
|
||||||
|
logger.debug("[OIDC SERVICE] scope=%s", scope)
|
||||||
|
logger.debug("[OIDC SERVICE] state=%s, nonce=%s", state, nonce)
|
||||||
|
|
||||||
# Validate client exists and is active
|
# Validate client exists and is active
|
||||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||||
|
|
||||||
@@ -152,14 +161,19 @@ class OIDCService:
|
|||||||
raise ValidationError("Invalid scopes")
|
raise ValidationError("Invalid scopes")
|
||||||
|
|
||||||
# Generate authorization code
|
# Generate authorization code
|
||||||
|
logger.debug("[OIDC SERVICE] Generating authorization code...")
|
||||||
|
logger.debug("[OIDC SERVICE] Current UTC time before code generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
code = cls._generate_code()
|
code = cls._generate_code()
|
||||||
code_hash = cls._hash_value(code)
|
code_hash = cls._hash_value(code)
|
||||||
|
logger.debug("[OIDC SERVICE] Code generated: %s...", code[:20] if code else None)
|
||||||
|
|
||||||
# Development-only debug logging for PKCE in code creation
|
# Development-only debug logging for PKCE in code creation
|
||||||
if current_app.config.get('ENV') == 'development':
|
if current_app.config.get('ENV') == 'development':
|
||||||
logger.debug(f"[OIDC] Generate auth code - PKCE: code_challenge={code_challenge is not None}, code_challenge_method={code_challenge_method}")
|
logger.debug(f"[OIDC] Generate auth code - PKCE: code_challenge={code_challenge is not None}, code_challenge_method={code_challenge_method}")
|
||||||
|
|
||||||
# Create auth code record
|
# Create auth code record
|
||||||
|
logger.debug("[OIDC SERVICE] Creating auth code record with lifetime_seconds=600 (10 minutes)")
|
||||||
|
logger.debug("[OIDC SERVICE] Current UTC time before creating auth code: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
auth_code = OIDCAuthCode.create_code(
|
auth_code = OIDCAuthCode.create_code(
|
||||||
client_id=client.id,
|
client_id=client.id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -172,6 +186,9 @@ class OIDCService:
|
|||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
lifetime_seconds=600, # 10 minutes
|
lifetime_seconds=600, # 10 minutes
|
||||||
)
|
)
|
||||||
|
logger.debug("[OIDC SERVICE] Auth code created successfully")
|
||||||
|
logger.debug("[OIDC SERVICE] Auth code expires_at (UTC): %s", auth_code.expires_at.isoformat() + "Z")
|
||||||
|
logger.debug("[OIDC SERVICE] Current UTC time after creating auth code: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
|
|
||||||
# Log authorization event
|
# Log authorization event
|
||||||
OIDCAuditService.log_authorization_event(
|
OIDCAuditService.log_authorization_event(
|
||||||
@@ -182,6 +199,9 @@ class OIDCService:
|
|||||||
scope=valid_scopes,
|
scope=valid_scopes,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.debug("[OIDC SERVICE] generate_authorization_code completed successfully")
|
||||||
|
logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
|
logger.debug("[OIDC SERVICE] ===========================================")
|
||||||
return code
|
return code
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -211,6 +231,12 @@ class OIDCService:
|
|||||||
InvalidGrantError: If code is invalid
|
InvalidGrantError: If code is invalid
|
||||||
ValidationError: If PKCE validation fails
|
ValidationError: If PKCE validation fails
|
||||||
"""
|
"""
|
||||||
|
logger.debug("[OIDC SERVICE] ===========================================")
|
||||||
|
logger.debug("[OIDC SERVICE] validate_authorization_code called")
|
||||||
|
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
|
logger.debug("[OIDC SERVICE] client_id=%s, redirect_uri=%s", client_id, redirect_uri)
|
||||||
|
logger.debug("[OIDC SERVICE] code_verifier provided: %s", bool(code_verifier))
|
||||||
|
|
||||||
# Get client
|
# Get client
|
||||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||||
|
|
||||||
@@ -223,6 +249,8 @@ class OIDCService:
|
|||||||
raise InvalidGrantError("Invalid client")
|
raise InvalidGrantError("Invalid client")
|
||||||
|
|
||||||
# Hash the provided code and find matching auth code
|
# Hash the provided code and find matching auth code
|
||||||
|
logger.debug("[OIDC SERVICE] Looking up authorization code...")
|
||||||
|
logger.debug("[OIDC SERVICE] Current UTC time before code lookup: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
code_hash = cls._hash_value(code)
|
code_hash = cls._hash_value(code)
|
||||||
auth_code = OIDCAuthCode.query.filter_by(
|
auth_code = OIDCAuthCode.query.filter_by(
|
||||||
code_hash=code_hash,
|
code_hash=code_hash,
|
||||||
@@ -256,8 +284,18 @@ class OIDCService:
|
|||||||
raise InvalidGrantError("Authorization code already used")
|
raise InvalidGrantError("Authorization code already used")
|
||||||
|
|
||||||
# Check expiration
|
# Check expiration
|
||||||
|
logger.debug("[OIDC SERVICE] Checking if authorization code is expired...")
|
||||||
|
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
|
logger.debug("[OIDC SERVICE] Auth code expires_at (UTC): %s", auth_code.expires_at.isoformat() + "Z")
|
||||||
|
# Handle timezone-naive expires_at from database
|
||||||
|
expires_at = auth_code.expires_at
|
||||||
|
if expires_at.tzinfo is None:
|
||||||
|
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||||
|
logger.debug("[OIDC SERVICE] Time until expiration (seconds): %s", (expires_at - datetime.now(timezone.utc)).total_seconds())
|
||||||
|
|
||||||
if auth_code.is_expired():
|
if auth_code.is_expired():
|
||||||
logger.error(f"[OIDC] Validate auth code - Code expired: code_hash={code_hash[:20]}..., expires_at={auth_code.expires_at}")
|
logger.error("[OIDC] Validate auth code - Code expired: code_hash=%s..., expires_at (UTC)=%s, current UTC time=%s",
|
||||||
|
code_hash[:20], auth_code.expires_at.isoformat() + "Z", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
OIDCAuditService.log_authorization_event(
|
OIDCAuditService.log_authorization_event(
|
||||||
client_id=client_id,
|
client_id=client_id,
|
||||||
user_id=auth_code.user_id,
|
user_id=auth_code.user_id,
|
||||||
@@ -316,6 +354,9 @@ class OIDCService:
|
|||||||
"nonce": auth_code.nonce,
|
"nonce": auth_code.nonce,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.debug("[OIDC SERVICE] validate_authorization_code completed successfully")
|
||||||
|
logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
|
logger.debug("[OIDC SERVICE] ===========================================")
|
||||||
return claims, user
|
return claims, user
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -366,6 +407,12 @@ class OIDCService:
|
|||||||
"""
|
"""
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
|
logger.debug("[OIDC SERVICE] ===========================================")
|
||||||
|
logger.debug("[OIDC SERVICE] generate_tokens called")
|
||||||
|
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
|
logger.debug("[OIDC SERVICE] client_id=%s, user_id=%s, scope=%s", client_id, user_id, scope)
|
||||||
|
logger.debug("[OIDC SERVICE] nonce=%s, auth_time=%s", nonce, auth_time)
|
||||||
|
|
||||||
# Get client
|
# Get client
|
||||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||||
|
|
||||||
@@ -377,6 +424,9 @@ class OIDCService:
|
|||||||
raise InvalidClientError()
|
raise InvalidClientError()
|
||||||
|
|
||||||
# Generate access token
|
# Generate access token
|
||||||
|
logger.debug("[OIDC SERVICE] Generating access token...")
|
||||||
|
logger.debug("[OIDC SERVICE] Current UTC time before access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
|
logger.debug("[OIDC SERVICE] Access token lifetime (seconds): %s", client.access_token_lifetime or 3600)
|
||||||
access_token_jti = OIDCTokenService._generate_jti()
|
access_token_jti = OIDCTokenService._generate_jti()
|
||||||
access_token = OIDCTokenService.create_access_token(
|
access_token = OIDCTokenService.create_access_token(
|
||||||
client_id=client_id,
|
client_id=client_id,
|
||||||
@@ -384,8 +434,13 @@ class OIDCService:
|
|||||||
scope=scope,
|
scope=scope,
|
||||||
jti=access_token_jti,
|
jti=access_token_jti,
|
||||||
)
|
)
|
||||||
|
logger.debug("[OIDC SERVICE] Access token generated successfully")
|
||||||
|
logger.debug("[OIDC SERVICE] Current UTC time after access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
|
|
||||||
# Generate ID token
|
# Generate ID token
|
||||||
|
logger.debug("[OIDC SERVICE] Generating ID token...")
|
||||||
|
logger.debug("[OIDC SERVICE] Current UTC time before ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
|
logger.debug("[OIDC SERVICE] ID token lifetime (seconds): %s", client.id_token_lifetime or 3600)
|
||||||
id_token = OIDCTokenService.create_id_token(
|
id_token = OIDCTokenService.create_id_token(
|
||||||
client_id=client_id,
|
client_id=client_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -394,6 +449,8 @@ class OIDCService:
|
|||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
auth_time=auth_time,
|
auth_time=auth_time,
|
||||||
)
|
)
|
||||||
|
logger.debug("[OIDC SERVICE] ID token generated successfully")
|
||||||
|
logger.debug("[OIDC SERVICE] Current UTC time after ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
|
|
||||||
# Generate or rotate refresh token
|
# Generate or rotate refresh token
|
||||||
if "refresh_token" in (client.grant_types or []):
|
if "refresh_token" in (client.grant_types or []):
|
||||||
@@ -445,22 +502,28 @@ class OIDCService:
|
|||||||
client_db_id = client.id
|
client_db_id = client.id
|
||||||
|
|
||||||
# Access token metadata
|
# Access token metadata
|
||||||
|
logger.debug("[OIDC SERVICE] Creating access token metadata...")
|
||||||
|
access_token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=client.access_token_lifetime or 3600)
|
||||||
|
logger.debug("[OIDC SERVICE] Access token expires_at (UTC): %s", access_token_expires_at.isoformat() + "Z")
|
||||||
OIDCTokenMetadata.create_metadata(
|
OIDCTokenMetadata.create_metadata(
|
||||||
client_id=client_db_id,
|
client_id=client_db_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
token_type="access_token",
|
token_type="access_token",
|
||||||
token_jti=access_token_jti,
|
token_jti=access_token_jti,
|
||||||
expires_at=datetime.utcnow() + timedelta(seconds=client.access_token_lifetime or 3600),
|
expires_at=access_token_expires_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ID token metadata (using access token JTI as reference)
|
# ID token metadata (using access token JTI as reference)
|
||||||
|
logger.debug("[OIDC SERVICE] Creating ID token metadata...")
|
||||||
id_token_jti = OIDCTokenService._generate_jti()
|
id_token_jti = OIDCTokenService._generate_jti()
|
||||||
|
id_token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=client.id_token_lifetime or 3600)
|
||||||
|
logger.debug("[OIDC SERVICE] ID token expires_at (UTC): %s", id_token_expires_at.isoformat() + "Z")
|
||||||
OIDCTokenMetadata.create_metadata(
|
OIDCTokenMetadata.create_metadata(
|
||||||
client_id=client_db_id,
|
client_id=client_db_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
token_type="id_token",
|
token_type="id_token",
|
||||||
token_jti=id_token_jti,
|
token_jti=id_token_jti,
|
||||||
expires_at=datetime.utcnow() + timedelta(seconds=client.id_token_lifetime or 3600),
|
expires_at=id_token_expires_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Log token event
|
# Log token event
|
||||||
@@ -483,6 +546,9 @@ class OIDCService:
|
|||||||
if final_refresh_token:
|
if final_refresh_token:
|
||||||
result["refresh_token"] = final_refresh_token
|
result["refresh_token"] = final_refresh_token
|
||||||
|
|
||||||
|
logger.debug("[OIDC SERVICE] generate_tokens completed successfully")
|
||||||
|
logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
|
logger.debug("[OIDC SERVICE] ===========================================")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -511,6 +577,11 @@ class OIDCService:
|
|||||||
"""
|
"""
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
|
logger.debug("[OIDC SERVICE] ===========================================")
|
||||||
|
logger.debug("[OIDC SERVICE] refresh_access_token called")
|
||||||
|
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
|
logger.debug("[OIDC SERVICE] client_id=%s, scope=%s", client_id, scope)
|
||||||
|
|
||||||
# Get client
|
# Get client
|
||||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||||
|
|
||||||
@@ -522,6 +593,8 @@ class OIDCService:
|
|||||||
raise InvalidClientError()
|
raise InvalidClientError()
|
||||||
|
|
||||||
# Find refresh token
|
# Find refresh token
|
||||||
|
logger.debug("[OIDC SERVICE] Looking up refresh token...")
|
||||||
|
logger.debug("[OIDC SERVICE] Current UTC time before refresh token lookup: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
token_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
|
token_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
|
||||||
refresh_token_obj = OIDCRefreshToken.query.filter_by(
|
refresh_token_obj = OIDCRefreshToken.query.filter_by(
|
||||||
token_hash=token_hash,
|
token_hash=token_hash,
|
||||||
@@ -542,6 +615,16 @@ class OIDCService:
|
|||||||
raise InvalidGrantError("Invalid refresh token")
|
raise InvalidGrantError("Invalid refresh token")
|
||||||
|
|
||||||
# Check if valid
|
# Check if valid
|
||||||
|
logger.debug("[OIDC SERVICE] Checking if refresh token is valid...")
|
||||||
|
logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
|
if refresh_token_obj:
|
||||||
|
logger.debug("[OIDC SERVICE] Refresh token expires_at (UTC): %s", refresh_token_obj.expires_at.isoformat() + "Z")
|
||||||
|
# Handle timezone-naive expires_at from database
|
||||||
|
rt_expires_at = refresh_token_obj.expires_at
|
||||||
|
if rt_expires_at.tzinfo is None:
|
||||||
|
rt_expires_at = rt_expires_at.replace(tzinfo=timezone.utc)
|
||||||
|
logger.debug("[OIDC SERVICE] Time until expiration (seconds): %s", (rt_expires_at - datetime.now(timezone.utc)).total_seconds())
|
||||||
|
|
||||||
if not refresh_token_obj.is_valid():
|
if not refresh_token_obj.is_valid():
|
||||||
OIDCAuditService.log_token_event(
|
OIDCAuditService.log_token_event(
|
||||||
client_id=client_id,
|
client_id=client_id,
|
||||||
@@ -563,6 +646,9 @@ class OIDCService:
|
|||||||
granted_scope = scope or (refresh_token_obj.scope or [])
|
granted_scope = scope or (refresh_token_obj.scope or [])
|
||||||
|
|
||||||
# Generate new access token
|
# Generate new access token
|
||||||
|
logger.debug("[OIDC SERVICE] Generating new access token...")
|
||||||
|
logger.debug("[OIDC SERVICE] Current UTC time before access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
|
logger.debug("[OIDC SERVICE] Access token lifetime (seconds): %s", client.access_token_lifetime or 3600)
|
||||||
access_token_jti = OIDCTokenService._generate_jti()
|
access_token_jti = OIDCTokenService._generate_jti()
|
||||||
access_token = OIDCTokenService.create_access_token(
|
access_token = OIDCTokenService.create_access_token(
|
||||||
client_id=client_id,
|
client_id=client_id,
|
||||||
@@ -570,14 +656,21 @@ class OIDCService:
|
|||||||
scope=granted_scope,
|
scope=granted_scope,
|
||||||
jti=access_token_jti,
|
jti=access_token_jti,
|
||||||
)
|
)
|
||||||
|
logger.debug("[OIDC SERVICE] Access token generated successfully")
|
||||||
|
logger.debug("[OIDC SERVICE] Current UTC time after access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
|
|
||||||
# Generate new ID token
|
# Generate new ID token
|
||||||
|
logger.debug("[OIDC SERVICE] Generating new ID token...")
|
||||||
|
logger.debug("[OIDC SERVICE] Current UTC time before ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
|
logger.debug("[OIDC SERVICE] ID token lifetime (seconds): %s", client.id_token_lifetime or 3600)
|
||||||
id_token = OIDCTokenService.create_id_token(
|
id_token = OIDCTokenService.create_id_token(
|
||||||
client_id=client_id,
|
client_id=client_id,
|
||||||
user_id=refresh_token_obj.user_id,
|
user_id=refresh_token_obj.user_id,
|
||||||
scope=granted_scope,
|
scope=granted_scope,
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
)
|
)
|
||||||
|
logger.debug("[OIDC SERVICE] ID token generated successfully")
|
||||||
|
logger.debug("[OIDC SERVICE] Current UTC time after ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
|
|
||||||
# Rotate refresh token
|
# Rotate refresh token
|
||||||
new_refresh, new_hash = OIDCTokenService.create_refresh_token(
|
new_refresh, new_hash = OIDCTokenService.create_refresh_token(
|
||||||
@@ -590,12 +683,15 @@ class OIDCService:
|
|||||||
refresh_token_obj.rotate(new_hash)
|
refresh_token_obj.rotate(new_hash)
|
||||||
|
|
||||||
# Store new token metadata
|
# Store new token metadata
|
||||||
|
logger.debug("[OIDC SERVICE] Creating access token metadata...")
|
||||||
|
access_token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=client.access_token_lifetime or 3600)
|
||||||
|
logger.debug("[OIDC SERVICE] Access token expires_at (UTC): %s", access_token_expires_at.isoformat() + "Z")
|
||||||
OIDCTokenMetadata.create_metadata(
|
OIDCTokenMetadata.create_metadata(
|
||||||
client_id=client.id,
|
client_id=client.id,
|
||||||
user_id=refresh_token_obj.user_id,
|
user_id=refresh_token_obj.user_id,
|
||||||
token_type="access_token",
|
token_type="access_token",
|
||||||
token_jti=access_token_jti,
|
token_jti=access_token_jti,
|
||||||
expires_at=datetime.utcnow() + timedelta(seconds=client.access_token_lifetime or 3600),
|
expires_at=access_token_expires_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Log refresh event
|
# Log refresh event
|
||||||
@@ -615,6 +711,17 @@ class OIDCService:
|
|||||||
"id_token": id_token,
|
"id_token": id_token,
|
||||||
"refresh_token": new_refresh,
|
"refresh_token": new_refresh,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.debug("[OIDC SERVICE] refresh_access_token completed successfully")
|
||||||
|
logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z")
|
||||||
|
logger.debug("[OIDC SERVICE] ===========================================")
|
||||||
|
return {
|
||||||
|
"access_token": access_token,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_in": client.access_token_lifetime or 3600,
|
||||||
|
"id_token": id_token,
|
||||||
|
"refresh_token": new_refresh,
|
||||||
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_access_token(cls, token: str, client_id: str = None) -> Dict:
|
def validate_access_token(cls, token: str, client_id: str = None) -> Dict:
|
||||||
@@ -630,10 +737,23 @@ class OIDCService:
|
|||||||
Raises:
|
Raises:
|
||||||
InvalidTokenError: If token is invalid
|
InvalidTokenError: If token is invalid
|
||||||
"""
|
"""
|
||||||
|
logger.debug("[OIDC SERVICE] ===========================================")
|
||||||
|
logger.debug("[OIDC SERVICE] validate_access_token() called")
|
||||||
|
logger.debug("[OIDC SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token)
|
||||||
|
logger.debug("[OIDC SERVICE] Token length: %d", len(token))
|
||||||
|
logger.debug("[OIDC SERVICE] Client ID: %s", client_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
logger.debug("[OIDC SERVICE] Calling OIDCTokenService.validate_access_token()...")
|
||||||
claims = OIDCTokenService.validate_access_token(token, client_id)
|
claims = OIDCTokenService.validate_access_token(token, client_id)
|
||||||
|
logger.debug("[OIDC SERVICE] Token validation successful")
|
||||||
|
logger.debug("[OIDC SERVICE] Token claims: %s", claims)
|
||||||
|
logger.debug("[OIDC SERVICE] ===========================================")
|
||||||
return claims
|
return claims
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.error("[OIDC SERVICE] Token validation failed: %s: %s", type(e).__name__, str(e))
|
||||||
|
import traceback
|
||||||
|
logger.error("[OIDC SERVICE] Traceback: %s", traceback.format_exc())
|
||||||
OIDCAuditService.log_event(
|
OIDCAuditService.log_event(
|
||||||
event_type="token_validation",
|
event_type="token_validation",
|
||||||
client_id=client_id,
|
client_id=client_id,
|
||||||
@@ -770,29 +890,67 @@ class OIDCService:
|
|||||||
Returns:
|
Returns:
|
||||||
User information dictionary
|
User information dictionary
|
||||||
"""
|
"""
|
||||||
|
logger.debug("[OIDC SERVICE] ===========================================")
|
||||||
|
logger.debug("[OIDC SERVICE] get_userinfo() called")
|
||||||
|
logger.debug("[OIDC SERVICE] Access token (first 50 chars): %s...", access_token[:50] if len(access_token) > 50 else access_token)
|
||||||
|
logger.debug("[OIDC SERVICE] Access token length: %d", len(access_token))
|
||||||
|
|
||||||
|
# Validate access token
|
||||||
|
logger.debug("[OIDC SERVICE] Validating access token...")
|
||||||
claims = cls.validate_access_token(access_token)
|
claims = cls.validate_access_token(access_token)
|
||||||
|
logger.debug("[OIDC SERVICE] Access token validated successfully")
|
||||||
|
logger.debug("[OIDC SERVICE] Token claims: %s", claims)
|
||||||
|
|
||||||
user_id = claims.get("sub")
|
user_id = claims.get("sub")
|
||||||
|
logger.debug("[OIDC SERVICE] User ID from token: %s", user_id)
|
||||||
|
|
||||||
|
logger.debug("[OIDC SERVICE] Querying user from database...")
|
||||||
user = User.query.get(user_id)
|
user = User.query.get(user_id)
|
||||||
|
logger.debug("[OIDC SERVICE] User query result: %s", user)
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
|
logger.error("[OIDC SERVICE] User not found in database: user_id=%s", user_id)
|
||||||
raise NotFoundError("User not found")
|
raise NotFoundError("User not found")
|
||||||
|
|
||||||
|
logger.debug("[OIDC SERVICE] User found: user_id=%s, email=%s, full_name=%s", user.id, user.email, user.full_name)
|
||||||
|
|
||||||
# Get scopes from token
|
# Get scopes from token
|
||||||
scope_str = claims.get("scope", "")
|
scope_str = claims.get("scope", "")
|
||||||
scopes = scope_str.split() if scope_str else []
|
scopes = scope_str.split() if scope_str else []
|
||||||
|
logger.debug("[OIDC SERVICE] Scope string from token: '%s'", scope_str)
|
||||||
|
logger.debug("[OIDC SERVICE] Parsed scopes: %s", scopes)
|
||||||
|
|
||||||
userinfo = {"sub": user_id}
|
userinfo = {"sub": user_id}
|
||||||
|
logger.debug("[OIDC SERVICE] Initial userinfo: %s", userinfo)
|
||||||
|
|
||||||
# Add claims based on scope
|
# Add claims based on scope
|
||||||
if "profile" in scopes and user.full_name:
|
if "profile" in scopes and user.full_name:
|
||||||
|
logger.debug("[OIDC SERVICE] Found 'profile' in scope, adding name claim")
|
||||||
userinfo["name"] = user.full_name
|
userinfo["name"] = user.full_name
|
||||||
|
logger.debug("[OIDC SERVICE] Added name: %s", user.full_name)
|
||||||
|
else:
|
||||||
|
logger.debug("[OIDC SERVICE] 'profile' not in scope or user.full_name is None: profile_in_scope=%s, full_name=%s", "profile" in scopes, user.full_name)
|
||||||
|
|
||||||
if "email" in scopes:
|
if "email" in scopes:
|
||||||
|
logger.debug("[OIDC SERVICE] Found 'email' in scope, adding email claims")
|
||||||
userinfo["email"] = user.email
|
userinfo["email"] = user.email
|
||||||
userinfo["email_verified"] = user.email_verified
|
userinfo["email_verified"] = user.email_verified
|
||||||
|
logger.debug("[OIDC SERVICE] Added email: %s, email_verified: %s", user.email, user.email_verified)
|
||||||
|
else:
|
||||||
|
logger.debug("[OIDC SERVICE] 'email' not in scope")
|
||||||
|
|
||||||
|
if "roles" in scopes:
|
||||||
|
logger.debug("[OIDC SERVICE] Found 'roles' in scope, adding roles claim")
|
||||||
|
user_roles = cls._get_user_roles(user)
|
||||||
|
userinfo["roles"] = user_roles
|
||||||
|
logger.debug("[OIDC SERVICE] Added roles: %s", user_roles)
|
||||||
|
else:
|
||||||
|
logger.debug("[OIDC SERVICE] 'roles' not in scope")
|
||||||
|
|
||||||
|
logger.debug("[OIDC SERVICE] Final userinfo: %s", userinfo)
|
||||||
|
|
||||||
# Log userinfo access
|
# Log userinfo access
|
||||||
|
logger.debug("[OIDC SERVICE] Logging userinfo access event...")
|
||||||
OIDCAuditService.log_userinfo_event(
|
OIDCAuditService.log_userinfo_event(
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -800,5 +958,54 @@ class OIDCService:
|
|||||||
success=True,
|
success=True,
|
||||||
scopes_claimed=scopes,
|
scopes_claimed=scopes,
|
||||||
)
|
)
|
||||||
|
logger.debug("[OIDC SERVICE] Userinfo access event logged")
|
||||||
|
|
||||||
|
logger.debug("[OIDC SERVICE] get_userinfo() completed successfully")
|
||||||
|
logger.debug("[OIDC SERVICE] ===========================================")
|
||||||
|
|
||||||
return userinfo
|
return userinfo
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_user_roles(user: User) -> list:
|
||||||
|
"""Get user's organization roles.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of role objects with organization_id and role
|
||||||
|
"""
|
||||||
|
logger.debug("[OIDC SERVICE] _get_user_roles() called")
|
||||||
|
logger.debug("[OIDC SERVICE] User: %s", user)
|
||||||
|
|
||||||
|
roles = []
|
||||||
|
|
||||||
|
if not user:
|
||||||
|
logger.debug("[OIDC SERVICE] User is None, returning empty roles list")
|
||||||
|
return roles
|
||||||
|
|
||||||
|
logger.debug("[OIDC SERVICE] User ID: %s", user.id)
|
||||||
|
logger.debug("[OIDC SERVICE] User email: %s", user.email)
|
||||||
|
logger.debug("[OIDC SERVICE] User organization_memberships: %s", user.organization_memberships)
|
||||||
|
|
||||||
|
if user.organization_memberships:
|
||||||
|
logger.debug("[OIDC SERVICE] User has %d organization memberships", len(user.organization_memberships))
|
||||||
|
for idx, member in enumerate(user.organization_memberships):
|
||||||
|
logger.debug("[OIDC SERVICE] Processing membership %d: member=%s", idx, member)
|
||||||
|
logger.debug("[OIDC SERVICE] organization_id: %s", member.organization_id)
|
||||||
|
logger.debug("[OIDC SERVICE] role: %s", member.role)
|
||||||
|
logger.debug("[OIDC SERVICE] role.value: %s", member.role.value)
|
||||||
|
|
||||||
|
role_entry = {
|
||||||
|
"organization_id": str(member.organization_id),
|
||||||
|
"role": member.role.value
|
||||||
|
}
|
||||||
|
roles.append(role_entry)
|
||||||
|
logger.debug("[OIDC SERVICE] Added role entry: %s", role_entry)
|
||||||
|
else:
|
||||||
|
logger.debug("[OIDC SERVICE] User has no organization memberships")
|
||||||
|
|
||||||
|
logger.debug("[OIDC SERVICE] Final roles list: %s", roles)
|
||||||
|
logger.debug("[OIDC SERVICE] _get_user_roles() completed")
|
||||||
|
|
||||||
|
return roles
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import secrets
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Dict, Optional, Tuple
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
|
from datetime import timezone
|
||||||
from flask import current_app, g
|
from flask import current_app, g
|
||||||
|
|
||||||
from app.extensions import db
|
from app.extensions import db
|
||||||
@@ -219,11 +220,11 @@ class OIDCSessionService:
|
|||||||
"""
|
"""
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
cutoff = datetime.utcnow() - timedelta(hours=older_than_hours)
|
cutoff = datetime.now(timezone.utc) - timedelta(hours=older_than_hours)
|
||||||
|
|
||||||
# Get expired sessions
|
# Get expired sessions
|
||||||
expired_sessions = OIDCSession.query.filter(
|
expired_sessions = OIDCSession.query.filter(
|
||||||
OIDCSession.expires_at < datetime.utcnow(),
|
OIDCSession.expires_at < datetime.now(timezone.utc),
|
||||||
OIDCSession.deleted_at == None
|
OIDCSession.deleted_at == None
|
||||||
).all()
|
).all()
|
||||||
|
|
||||||
|
|||||||
@@ -2,15 +2,20 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import base64
|
import base64
|
||||||
import secrets
|
import secrets
|
||||||
from datetime import datetime, timedelta
|
import logging
|
||||||
|
import time
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Dict, Optional, Any
|
from typing import Dict, Optional, Any
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
from flask import current_app, g
|
from flask import current_app, g
|
||||||
|
|
||||||
from app.models import User, OIDCClient
|
from app.models import User, OIDCClient
|
||||||
|
from app.models.organization_member import OrganizationMember
|
||||||
from app.services.oidc_jwks_service import OIDCJWKSService
|
from app.services.oidc_jwks_service import OIDCJWKSService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class OIDCTokenService:
|
class OIDCTokenService:
|
||||||
"""Service for generating and validating OIDC tokens.
|
"""Service for generating and validating OIDC tokens.
|
||||||
@@ -134,7 +139,7 @@ class OIDCTokenService:
|
|||||||
return lifetimes.get(token_type, 3600)
|
return lifetimes.get(token_type, 3600)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_access_token(cls, client_id: str, user_id: str, scope: list,
|
def create_access_token(cls, client_id: str, user_id: str, scope: list,
|
||||||
jti: str = None) -> str:
|
jti: str = None) -> str:
|
||||||
"""Create a JWT access token.
|
"""Create a JWT access token.
|
||||||
|
|
||||||
@@ -147,25 +152,44 @@ class OIDCTokenService:
|
|||||||
Returns:
|
Returns:
|
||||||
JWT access token string
|
JWT access token string
|
||||||
"""
|
"""
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] create_access_token called")
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] scope=%s", scope)
|
||||||
|
|
||||||
jti = jti or cls._generate_jti()
|
jti = jti or cls._generate_jti()
|
||||||
now = datetime.utcnow()
|
now_timestamp = int(time.time())
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Token creation time (UTC): %s", now.isoformat())
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Token creation timestamp: %s", now_timestamp)
|
||||||
|
|
||||||
# Get client for token lifetime
|
# Get client for token lifetime
|
||||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||||
lifetime = cls._get_token_lifetime(client, "access_token") if client else 3600
|
lifetime = cls._get_token_lifetime(client, "access_token") if client else 3600
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Access token lifetime (seconds): %s", lifetime)
|
||||||
|
|
||||||
|
exp_timestamp = now_timestamp + lifetime
|
||||||
|
exp_time = now + timedelta(seconds=lifetime)
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Access token expiration time (UTC): %s", exp_time.isoformat())
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Access token expiration timestamp: %s", exp_timestamp)
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Time until expiration (seconds): %s", lifetime)
|
||||||
|
|
||||||
claims = {
|
claims = {
|
||||||
"iss": cls._get_issuer(),
|
"iss": cls._get_issuer(),
|
||||||
"sub": user_id,
|
"sub": user_id,
|
||||||
"aud": client_id,
|
"aud": client_id,
|
||||||
"exp": int((now + timedelta(seconds=lifetime)).timestamp()),
|
"exp": exp_timestamp,
|
||||||
"iat": int(now.timestamp()),
|
"iat": now_timestamp,
|
||||||
"nbf": int(now.timestamp()),
|
"nbf": now_timestamp,
|
||||||
"jti": jti,
|
"jti": jti,
|
||||||
"client_id": client_id,
|
"client_id": client_id,
|
||||||
"scope": " ".join(scope) if isinstance(scope, list) else scope,
|
"scope": " ".join(scope) if isinstance(scope, list) else scope,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Token claims: exp=%s, iat=%s, nbf=%s",
|
||||||
|
claims["exp"], claims["iat"], claims["nbf"])
|
||||||
|
|
||||||
# Get signing key
|
# Get signing key
|
||||||
jwks_service = OIDCJWKSService()
|
jwks_service = OIDCJWKSService()
|
||||||
signing_key = jwks_service.get_signing_key()
|
signing_key = jwks_service.get_signing_key()
|
||||||
@@ -174,6 +198,7 @@ class OIDCTokenService:
|
|||||||
raise ValueError("No signing key available")
|
raise ValueError("No signing key available")
|
||||||
|
|
||||||
# Sign with RS256
|
# Sign with RS256
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Signing token with RS256...")
|
||||||
token = jwt.encode(
|
token = jwt.encode(
|
||||||
claims,
|
claims,
|
||||||
signing_key.private_key,
|
signing_key.private_key,
|
||||||
@@ -181,6 +206,9 @@ class OIDCTokenService:
|
|||||||
headers={"kid": signing_key.kid}
|
headers={"kid": signing_key.kid}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Access token created successfully")
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||||
return token
|
return token
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -200,12 +228,30 @@ class OIDCTokenService:
|
|||||||
Returns:
|
Returns:
|
||||||
JWT ID token string
|
JWT ID token string
|
||||||
"""
|
"""
|
||||||
now = datetime.utcnow()
|
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||||
auth_time = auth_time or int(now.timestamp())
|
logger.debug("[OIDC TOKEN SERVICE] create_id_token called")
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] nonce=%s, auth_time=%s", nonce, auth_time)
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] scope=%s", scope)
|
||||||
|
|
||||||
|
now_timestamp = int(time.time())
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Token creation time (UTC): %s", now.isoformat())
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Token creation timestamp: %s", now_timestamp)
|
||||||
|
auth_time = auth_time or now_timestamp
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] auth_time (Unix timestamp): %s", auth_time)
|
||||||
|
|
||||||
# Get client for token lifetime
|
# Get client for token lifetime
|
||||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||||
lifetime = cls._get_token_lifetime(client, "id_token") if client else 3600
|
lifetime = cls._get_token_lifetime(client, "id_token") if client else 3600
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] ID token lifetime (seconds): %s", lifetime)
|
||||||
|
|
||||||
|
exp_timestamp = now_timestamp + lifetime
|
||||||
|
exp_time = now + timedelta(seconds=lifetime)
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] ID token expiration time (UTC): %s", exp_time.isoformat())
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] ID token expiration timestamp: %s", exp_timestamp)
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Time until expiration (seconds): %s", lifetime)
|
||||||
|
|
||||||
# Get user for claims
|
# Get user for claims
|
||||||
user = User.query.get(user_id)
|
user = User.query.get(user_id)
|
||||||
@@ -214,11 +260,14 @@ class OIDCTokenService:
|
|||||||
"iss": cls._get_issuer(),
|
"iss": cls._get_issuer(),
|
||||||
"sub": user_id,
|
"sub": user_id,
|
||||||
"aud": client_id,
|
"aud": client_id,
|
||||||
"exp": int((now + timedelta(seconds=lifetime)).timestamp()),
|
"exp": exp_timestamp,
|
||||||
"iat": int(now.timestamp()),
|
"iat": now_timestamp,
|
||||||
"auth_time": auth_time,
|
"auth_time": auth_time,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Token claims: exp=%s, iat=%s, auth_time=%s",
|
||||||
|
claims["exp"], claims["iat"], claims["auth_time"])
|
||||||
|
|
||||||
# Add nonce if provided
|
# Add nonce if provided
|
||||||
if nonce:
|
if nonce:
|
||||||
claims["nonce"] = nonce
|
claims["nonce"] = nonce
|
||||||
@@ -235,6 +284,10 @@ class OIDCTokenService:
|
|||||||
if user.full_name:
|
if user.full_name:
|
||||||
claims["name"] = user.full_name
|
claims["name"] = user.full_name
|
||||||
|
|
||||||
|
# Add roles claim if scope is granted
|
||||||
|
if scope and "roles" in scope:
|
||||||
|
claims["roles"] = cls._get_user_roles(user)
|
||||||
|
|
||||||
# Add scope if provided
|
# Add scope if provided
|
||||||
if scope:
|
if scope:
|
||||||
claims["scope"] = " ".join(scope) if isinstance(scope, list) else scope
|
claims["scope"] = " ".join(scope) if isinstance(scope, list) else scope
|
||||||
@@ -247,6 +300,7 @@ class OIDCTokenService:
|
|||||||
raise ValueError("No signing key available")
|
raise ValueError("No signing key available")
|
||||||
|
|
||||||
# Sign with RS256
|
# Sign with RS256
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Signing token with RS256...")
|
||||||
token = jwt.encode(
|
token = jwt.encode(
|
||||||
claims,
|
claims,
|
||||||
signing_key.private_key,
|
signing_key.private_key,
|
||||||
@@ -254,10 +308,32 @@ class OIDCTokenService:
|
|||||||
headers={"kid": signing_key.kid}
|
headers={"kid": signing_key.kid}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] ID token created successfully")
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||||
return token
|
return token
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_user_roles(user: User) -> list:
|
||||||
|
"""Get user's organization roles.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of role objects with organization_id and role
|
||||||
|
"""
|
||||||
|
roles = []
|
||||||
|
if user and user.organization_memberships:
|
||||||
|
for member in user.organization_memberships:
|
||||||
|
roles.append({
|
||||||
|
"organization_id": str(member.organization_id),
|
||||||
|
"role": member.role.value
|
||||||
|
})
|
||||||
|
return roles
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_refresh_token(cls, client_id: str, user_id: str,
|
def create_refresh_token(cls, client_id: str, user_id: str,
|
||||||
scope: list = None, access_token_id: str = None) -> str:
|
scope: list = None, access_token_id: str = None) -> str:
|
||||||
"""Create an opaque refresh token.
|
"""Create an opaque refresh token.
|
||||||
|
|
||||||
@@ -270,11 +346,21 @@ class OIDCTokenService:
|
|||||||
Returns:
|
Returns:
|
||||||
Opaque refresh token string
|
Opaque refresh token string
|
||||||
"""
|
"""
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] create_refresh_token called")
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] scope=%s, access_token_id=%s", scope, access_token_id)
|
||||||
|
|
||||||
token = cls._generate_opaque_token()
|
token = cls._generate_opaque_token()
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Refresh token generated: %s...", token[:20] if token else None)
|
||||||
|
|
||||||
# Hash for storage
|
# Hash for storage
|
||||||
token_hash = cls._hash_token(token)
|
token_hash = cls._hash_token(token)
|
||||||
|
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Refresh token created successfully")
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||||
return token, token_hash
|
return token, token_hash
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -292,54 +378,91 @@ class OIDCTokenService:
|
|||||||
jwt.ExpiredSignatureError: If token is expired
|
jwt.ExpiredSignatureError: If token is expired
|
||||||
jwt.InvalidTokenError: If token is invalid
|
jwt.InvalidTokenError: If token is invalid
|
||||||
"""
|
"""
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] verify_token_signature() called")
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token)
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Token length: %d", len(token))
|
||||||
|
|
||||||
# Get the JWKS with public keys
|
# Get the JWKS with public keys
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Getting JWKS...")
|
||||||
jwks_service = OIDCJWKSService()
|
jwks_service = OIDCJWKSService()
|
||||||
jwks = jwks_service.get_jwks()
|
jwks = jwks_service.get_jwks(include_private_keys=True)
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] JWKS retrieved: %d keys", len(jwks.get("keys", [])))
|
||||||
|
|
||||||
# Get the key ID from token header
|
# Get the key ID from token header
|
||||||
try:
|
try:
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Getting unverified token header...")
|
||||||
unverified_header = jwt.get_unverified_header(token)
|
unverified_header = jwt.get_unverified_header(token)
|
||||||
except jwt.DecodeError:
|
logger.debug("[OIDC TOKEN SERVICE] Unverified header: %s", unverified_header)
|
||||||
|
except jwt.DecodeError as e:
|
||||||
|
logger.error("[OIDC TOKEN SERVICE] Failed to decode token header: %s", str(e))
|
||||||
raise jwt.InvalidTokenError("Invalid token header")
|
raise jwt.InvalidTokenError("Invalid token header")
|
||||||
|
|
||||||
kid = unverified_header.get("kid")
|
kid = unverified_header.get("kid")
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Key ID (kid) from token header: %s", kid)
|
||||||
|
|
||||||
# Find the matching public key
|
# Find the matching public key
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Searching for matching public key...")
|
||||||
public_key = None
|
public_key = None
|
||||||
for key in jwks.get("keys", []):
|
for idx, key in enumerate(jwks.get("keys", [])):
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Checking key %d: kid=%s", idx, key.get("kid"))
|
||||||
if key.get("kid") == kid:
|
if key.get("kid") == kid:
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Found matching key at index %d", idx)
|
||||||
try:
|
try:
|
||||||
from cryptography.hazmat.primitives import serialization
|
from cryptography.hazmat.primitives import serialization
|
||||||
from cryptography.hazmat.backends import default_backend
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Loading PEM public key...")
|
||||||
public_key = serialization.load_pem_public_key(
|
public_key = serialization.load_pem_public_key(
|
||||||
key["public_key"].encode() if isinstance(key["public_key"], str)
|
key["public_key"].encode() if isinstance(key["public_key"], str)
|
||||||
else key["public_key"],
|
else key["public_key"],
|
||||||
backend=default_backend()
|
backend=default_backend()
|
||||||
)
|
)
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Public key loaded successfully")
|
||||||
break
|
break
|
||||||
except (ImportError, Exception):
|
except (ImportError, Exception) as e:
|
||||||
|
logger.error("[OIDC TOKEN SERVICE] Failed to load public key: %s: %s", type(e).__name__, str(e))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not public_key:
|
if not public_key:
|
||||||
|
logger.error("[OIDC TOKEN SERVICE] No matching public key found for kid=%s", kid)
|
||||||
raise jwt.InvalidSignatureError(f"Key with kid={kid} not found")
|
raise jwt.InvalidSignatureError(f"Key with kid={kid} not found")
|
||||||
|
|
||||||
# Verify the signature
|
logger.debug("[OIDC TOKEN SERVICE] Public key found, verifying signature...")
|
||||||
claims = jwt.decode(
|
|
||||||
token,
|
|
||||||
public_key,
|
|
||||||
algorithms=["RS256"],
|
|
||||||
audience=None, # We'll validate audience separately
|
|
||||||
issuer=cls._get_issuer(),
|
|
||||||
options={
|
|
||||||
"verify_signature": True,
|
|
||||||
"verify_exp": True,
|
|
||||||
"verify_aud": False, # Handle audience manually
|
|
||||||
"verify_iss": False, # Handle issuer manually
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return claims
|
# Verify the signature
|
||||||
|
try:
|
||||||
|
claims = jwt.decode(
|
||||||
|
token,
|
||||||
|
public_key,
|
||||||
|
algorithms=["RS256"],
|
||||||
|
audience=None, # We'll validate audience separately
|
||||||
|
issuer=cls._get_issuer(),
|
||||||
|
options={
|
||||||
|
"verify_signature": True,
|
||||||
|
"verify_exp": True,
|
||||||
|
"verify_aud": False, # Handle audience manually
|
||||||
|
"verify_iss": False, # Handle issuer manually
|
||||||
|
}
|
||||||
|
)
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Signature verification successful")
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Decoded claims: %s", claims)
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||||
|
return claims
|
||||||
|
except jwt.ExpiredSignatureError as e:
|
||||||
|
logger.error("[OIDC TOKEN SERVICE] Token has expired: %s", str(e))
|
||||||
|
raise
|
||||||
|
except jwt.InvalidSignatureError as e:
|
||||||
|
logger.error("[OIDC TOKEN SERVICE] Invalid token signature: %s", str(e))
|
||||||
|
raise
|
||||||
|
except jwt.InvalidTokenError as e:
|
||||||
|
logger.error("[OIDC TOKEN SERVICE] Invalid token: %s: %s", type(e).__name__, str(e))
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("[OIDC TOKEN SERVICE] Unexpected error during token verification: %s: %s", type(e).__name__, str(e))
|
||||||
|
import traceback
|
||||||
|
logger.error("[OIDC TOKEN SERVICE] Traceback: %s", traceback.format_exc())
|
||||||
|
raise
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def decode_token(cls, token: str, verify: bool = False) -> Dict:
|
def decode_token(cls, token: str, verify: bool = False) -> Dict:
|
||||||
@@ -378,16 +501,41 @@ class OIDCTokenService:
|
|||||||
jwt.InvalidTokenError: If token is invalid
|
jwt.InvalidTokenError: If token is invalid
|
||||||
ValueError: If token is expired or audience mismatch
|
ValueError: If token is expired or audience mismatch
|
||||||
"""
|
"""
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] validate_access_token() called")
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token)
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Token length: %d", len(token))
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Client ID: %s", client_id)
|
||||||
|
|
||||||
|
# Verify token signature
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Verifying token signature...")
|
||||||
claims = cls.verify_token_signature(token)
|
claims = cls.verify_token_signature(token)
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Token signature verified")
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Claims: %s", claims)
|
||||||
|
|
||||||
# Check expiration
|
# Check expiration
|
||||||
if claims.get("exp", 0) < datetime.utcnow().timestamp():
|
exp = claims.get("exp", 0)
|
||||||
|
now_timestamp = int(time.time())
|
||||||
|
|
||||||
|
if exp < now_timestamp:
|
||||||
|
logger.error("[OIDC TOKEN SERVICE] Token has expired")
|
||||||
raise ValueError("Token has expired")
|
raise ValueError("Token has expired")
|
||||||
|
|
||||||
# Validate audience if client_id provided
|
# Validate audience if client_id provided
|
||||||
|
aud = claims.get("aud")
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Token audience (aud): %s", aud)
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Expected client_id: %s", client_id)
|
||||||
|
|
||||||
if client_id:
|
if client_id:
|
||||||
if claims.get("aud") != client_id:
|
if aud != client_id:
|
||||||
|
logger.error("[OIDC TOKEN SERVICE] Audience mismatch: expected=%s, got=%s", client_id, aud)
|
||||||
raise ValueError("Invalid audience")
|
raise ValueError("Invalid audience")
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Audience validation passed")
|
||||||
|
else:
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] No client_id provided, skipping audience validation")
|
||||||
|
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] validate_access_token() completed successfully")
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||||
|
|
||||||
return claims
|
return claims
|
||||||
|
|
||||||
@@ -410,11 +558,17 @@ class OIDCTokenService:
|
|||||||
claims = cls.validate_access_token(token, client_id)
|
claims = cls.validate_access_token(token, client_id)
|
||||||
|
|
||||||
# Calculate remaining time
|
# Calculate remaining time
|
||||||
now = datetime.utcnow().timestamp()
|
now_timestamp = int(time.time())
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
exp = claims.get("exp", 0)
|
exp = claims.get("exp", 0)
|
||||||
iat = claims.get("iat", 0)
|
iat = claims.get("iat", 0)
|
||||||
|
|
||||||
result["active"] = exp > now
|
logger.debug("[OIDC TOKEN SERVICE] Introspection - Current UTC time: %s", now.isoformat())
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Introspection - Token expiration timestamp: %s", exp)
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Introspection - Token expiration datetime (UTC): %s", datetime.fromtimestamp(exp, tz=timezone.utc).isoformat())
|
||||||
|
logger.debug("[OIDC TOKEN SERVICE] Introspection - Time until expiration: %s seconds", exp - now_timestamp)
|
||||||
|
|
||||||
|
result["active"] = exp > now_timestamp
|
||||||
result.update({
|
result.update({
|
||||||
"iss": claims.get("iss"),
|
"iss": claims.get("iss"),
|
||||||
"sub": claims.get("sub"),
|
"sub": claims.get("sub"),
|
||||||
@@ -429,8 +583,8 @@ class OIDCTokenService:
|
|||||||
})
|
})
|
||||||
|
|
||||||
# Add expiry in seconds
|
# Add expiry in seconds
|
||||||
if exp > now:
|
if exp > now_timestamp:
|
||||||
result["exp"] = int(exp - now)
|
result["exp"] = int(exp - now_timestamp)
|
||||||
|
|
||||||
except (jwt.InvalidTokenError, ValueError) as e:
|
except (jwt.InvalidTokenError, ValueError) as e:
|
||||||
result["active"] = False
|
result["active"] = False
|
||||||
|
|||||||
@@ -0,0 +1,188 @@
|
|||||||
|
"""TOTP (Time-based One-Time Password) service."""
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import pyotp
|
||||||
|
from app.extensions import bcrypt
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TOTPService:
|
||||||
|
"""Service for TOTP operations."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_secret() -> str:
|
||||||
|
"""
|
||||||
|
Generate a new TOTP secret.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Base32 encoded secret (32 characters)
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The secret is generated using cryptographically secure random bytes
|
||||||
|
and encoded in base32 format for compatibility with authenticator apps.
|
||||||
|
"""
|
||||||
|
# Generate 20 random bytes (160 bits) and encode as base32
|
||||||
|
random_bytes = secrets.token_bytes(20)
|
||||||
|
secret = base64.b32encode(random_bytes).decode("utf-8")
|
||||||
|
logger.debug(f"Generated new TOTP secret: {secret[:8]}...")
|
||||||
|
return secret
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_provisioning_uri(user_email: str, secret: str, issuer: str = "Gatehouse") -> str:
|
||||||
|
"""
|
||||||
|
Generate provisioning URI for QR code.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_email: User's email address
|
||||||
|
secret: TOTP secret (base32 encoded)
|
||||||
|
issuer: Issuer name (default: "Gatehouse")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
otpauth:// URI for QR code generation
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> uri = TOTPService.generate_provisioning_uri("user@example.com", "JBSWY3DPEHPK3PXP")
|
||||||
|
>>> print(uri)
|
||||||
|
otpauth://totp/Gatehouse:user@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse
|
||||||
|
"""
|
||||||
|
totp = pyotp.TOTP(secret)
|
||||||
|
uri = totp.provisioning_uri(name=user_email, issuer_name=issuer)
|
||||||
|
logger.debug(f"Generated provisioning URI for user: {user_email}")
|
||||||
|
return uri
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def verify_code(secret: str, code: str, window: int = 1) -> bool:
|
||||||
|
"""
|
||||||
|
Verify a TOTP code against the secret.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
secret: TOTP secret (base32 encoded)
|
||||||
|
code: 6-digit TOTP code to verify
|
||||||
|
window: Time window for code validation (default: 1, allows codes from previous/next time steps)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if code is valid, False otherwise
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The window parameter allows for clock skew between the server
|
||||||
|
and the authenticator app. A window of 1 allows codes from
|
||||||
|
the previous, current, and next 30-second intervals.
|
||||||
|
"""
|
||||||
|
totp = pyotp.TOTP(secret)
|
||||||
|
is_valid = totp.verify(code, valid_window=window)
|
||||||
|
logger.debug(f"TOTP code verification: valid={is_valid}, window={window}")
|
||||||
|
return is_valid
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_backup_codes(count: int = 10) -> Tuple[list[str], list[str]]:
|
||||||
|
"""
|
||||||
|
Generate backup codes for TOTP recovery.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
count: Number of backup codes to generate (default: 10)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (plain_codes, hashed_codes)
|
||||||
|
- plain_codes: List of plain text backup codes (for display to user)
|
||||||
|
- hashed_codes: List of bcrypt hashed backup codes (for storage)
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Backup codes are 16-character alphanumeric codes that can be used
|
||||||
|
to recover access if the TOTP device is lost. Each code can only
|
||||||
|
be used once.
|
||||||
|
"""
|
||||||
|
plain_codes = []
|
||||||
|
hashed_codes = []
|
||||||
|
|
||||||
|
for _ in range(count):
|
||||||
|
# Generate a 16-character alphanumeric code
|
||||||
|
code = secrets.token_hex(8).upper()
|
||||||
|
plain_codes.append(code)
|
||||||
|
|
||||||
|
# Hash the code using bcrypt
|
||||||
|
hashed_code = bcrypt.generate_password_hash(code).decode("utf-8")
|
||||||
|
hashed_codes.append(hashed_code)
|
||||||
|
|
||||||
|
logger.debug(f"Generated {count} backup codes")
|
||||||
|
return plain_codes, hashed_codes
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def verify_backup_code(hashed_codes: list[str], code: str) -> Tuple[bool, list[str]]:
|
||||||
|
"""
|
||||||
|
Verify and consume a backup code.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hashed_codes: List of bcrypt hashed backup codes
|
||||||
|
code: Plain text backup code to verify
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_valid, remaining_codes)
|
||||||
|
- is_valid: True if code was valid and consumed, False otherwise
|
||||||
|
- remaining_codes: List of remaining hashed codes (with consumed code removed)
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Once a backup code is used, it is removed from the list and cannot
|
||||||
|
be used again. This ensures each code is single-use.
|
||||||
|
"""
|
||||||
|
remaining_codes = []
|
||||||
|
|
||||||
|
for hashed_code in hashed_codes:
|
||||||
|
if bcrypt.check_password_hash(hashed_code, code):
|
||||||
|
# Code found and valid - don't add to remaining codes (consumed)
|
||||||
|
logger.debug("Backup code verified and consumed")
|
||||||
|
return True, remaining_codes
|
||||||
|
else:
|
||||||
|
# Code doesn't match - keep it in remaining codes
|
||||||
|
remaining_codes.append(hashed_code)
|
||||||
|
|
||||||
|
logger.debug("Backup code verification failed")
|
||||||
|
return False, remaining_codes
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_qr_code_data_uri(provisioning_uri: str) -> str:
|
||||||
|
"""
|
||||||
|
Generate QR code as data URI for frontend display.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provisioning_uri: otpauth:// URI to encode in QR code
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Base64 encoded PNG image as data URI (data:image/png;base64,...)
|
||||||
|
|
||||||
|
Note:
|
||||||
|
If the qrcode library is not installed, returns a placeholder message.
|
||||||
|
Install with: pip install qrcode[pil]
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import qrcode
|
||||||
|
|
||||||
|
# Create QR code
|
||||||
|
qr = qrcode.QRCode(
|
||||||
|
version=1,
|
||||||
|
error_correction=qrcode.constants.ERROR_CORRECT_L,
|
||||||
|
box_size=10,
|
||||||
|
border=4,
|
||||||
|
)
|
||||||
|
qr.add_data(provisioning_uri)
|
||||||
|
qr.make(fit=True)
|
||||||
|
|
||||||
|
# Generate image
|
||||||
|
img = qr.make_image(fill_color="black", back_color="white")
|
||||||
|
|
||||||
|
# Convert to base64
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
img.save(buffer, format="PNG")
|
||||||
|
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||||
|
|
||||||
|
data_uri = f"data:image/png;base64,{img_base64}"
|
||||||
|
logger.debug("Generated QR code data URI")
|
||||||
|
return data_uri
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
logger.warning("qrcode library not installed, returning placeholder")
|
||||||
|
return "QR code generation requires the qrcode library. Install with: pip install qrcode[pil]"
|
||||||
@@ -24,6 +24,7 @@ class AuthMethodType(str, Enum):
|
|||||||
"""Authentication method types."""
|
"""Authentication method types."""
|
||||||
|
|
||||||
PASSWORD = "password"
|
PASSWORD = "password"
|
||||||
|
TOTP = "totp"
|
||||||
GOOGLE = "google"
|
GOOGLE = "google"
|
||||||
GITHUB = "github"
|
GITHUB = "github"
|
||||||
MICROSOFT = "microsoft"
|
MICROSOFT = "microsoft"
|
||||||
@@ -66,6 +67,13 @@ class AuditAction(str, Enum):
|
|||||||
# Auth method actions
|
# Auth method actions
|
||||||
AUTH_METHOD_ADD = "auth.method.add"
|
AUTH_METHOD_ADD = "auth.method.add"
|
||||||
AUTH_METHOD_REMOVE = "auth.method.remove"
|
AUTH_METHOD_REMOVE = "auth.method.remove"
|
||||||
|
TOTP_ENROLL_INITIATED = "totp.enroll.initiated"
|
||||||
|
TOTP_ENROLL_COMPLETED = "totp.enroll.completed"
|
||||||
|
TOTP_VERIFY_SUCCESS = "totp.verify.success"
|
||||||
|
TOTP_VERIFY_FAILED = "totp.verify.failed"
|
||||||
|
TOTP_DISABLED = "totp.disabled"
|
||||||
|
TOTP_BACKUP_CODE_USED = "totp.backup_code.used"
|
||||||
|
TOTP_BACKUP_CODES_REGENERATED = "totp.backup_codes.regenerated"
|
||||||
|
|
||||||
|
|
||||||
class OIDCGrantType(str, Enum):
|
class OIDCGrantType(str, Enum):
|
||||||
|
|||||||
+102
@@ -0,0 +1,102 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
ISSUER="https://oidctest.wsweet.org"
|
||||||
|
CLIENT_ID="secret"
|
||||||
|
CLIENT_SECRET="tardis"
|
||||||
|
REDIRECT_URI="http://127.0.0.1:5556/callback"
|
||||||
|
SCOPE="openid profile email offline_access"
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# Discover OIDC endpoints
|
||||||
|
# ---------------------------
|
||||||
|
DISCOVERY=$(curl -s "$ISSUER/.well-known/openid-configuration")
|
||||||
|
|
||||||
|
AUTH_ENDPOINT=$(echo "$DISCOVERY" | jq -r .authorization_endpoint)
|
||||||
|
TOKEN_ENDPOINT=$(echo "$DISCOVERY" | jq -r .token_endpoint)
|
||||||
|
USERINFO_ENDPOINT=$(echo "$DISCOVERY" | jq -r .userinfo_endpoint)
|
||||||
|
|
||||||
|
echo "Auth endpoint : $AUTH_ENDPOINT"
|
||||||
|
echo "Token endpoint: $TOKEN_ENDPOINT"
|
||||||
|
echo
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# PKCE
|
||||||
|
# ---------------------------
|
||||||
|
CODE_VERIFIER=$(openssl rand -base64 32 | tr -d '=+/')
|
||||||
|
CODE_CHALLENGE=$(echo -n "$CODE_VERIFIER" | openssl dgst -sha256 -binary | openssl base64 | tr -d '=+/' | tr '/+' '_-')
|
||||||
|
|
||||||
|
STATE=$(openssl rand -hex 16)
|
||||||
|
NONCE=$(openssl rand -hex 16)
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# Build auth URL
|
||||||
|
# ---------------------------
|
||||||
|
AUTH_URL="$AUTH_ENDPOINT?response_type=code\
|
||||||
|
&client_id=$CLIENT_ID\
|
||||||
|
&redirect_uri=$(printf '%s' "$REDIRECT_URI" | jq -s -R -r @uri)\
|
||||||
|
&scope=$(printf '%s' "$SCOPE" | jq -s -R -r @uri)\
|
||||||
|
&state=$STATE\
|
||||||
|
&nonce=$NONCE\
|
||||||
|
&code_challenge=$CODE_CHALLENGE\
|
||||||
|
&code_challenge_method=S256"
|
||||||
|
|
||||||
|
echo "Open this URL in a browser:"
|
||||||
|
echo
|
||||||
|
echo "$AUTH_URL"
|
||||||
|
echo
|
||||||
|
echo "After login you will be redirected to:"
|
||||||
|
echo "$REDIRECT_URI?code=XXXX&state=YYYY"
|
||||||
|
echo
|
||||||
|
read -p "Paste the full redirect URL: " REDIRECT
|
||||||
|
|
||||||
|
CODE=$(echo "$REDIRECT" | sed -n 's/.*code=\([^&]*\).*/\1/p')
|
||||||
|
RETURNED_STATE=$(echo "$REDIRECT" | sed -n 's/.*state=\([^&]*\).*/\1/p')
|
||||||
|
|
||||||
|
if [ "$RETURNED_STATE" != "$STATE" ]; then
|
||||||
|
echo "STATE MISMATCH"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# Exchange code for tokens
|
||||||
|
# ---------------------------
|
||||||
|
TOKENS=$(curl -s -X POST "$TOKEN_ENDPOINT" \
|
||||||
|
-u "$CLIENT_ID:$CLIENT_SECRET" \
|
||||||
|
-H "Content-Type: application/x-www-form-urlencoded" \
|
||||||
|
-d "grant_type=authorization_code" \
|
||||||
|
-d "code=$CODE" \
|
||||||
|
-d "redirect_uri=$REDIRECT_URI" \
|
||||||
|
-d "code_verifier=$CODE_VERIFIER")
|
||||||
|
|
||||||
|
echo
|
||||||
|
echo "Token response:"
|
||||||
|
echo "$TOKENS" | jq .
|
||||||
|
|
||||||
|
ACCESS_TOKEN=$(echo "$TOKENS" | jq -r .access_token)
|
||||||
|
ID_TOKEN=$(echo "$TOKENS" | jq -r .id_token)
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# JWT decode function
|
||||||
|
# ---------------------------
|
||||||
|
decode() {
|
||||||
|
echo "$1" | awk -F. '{print $2}' | tr '_-' '/+' | base64 -d 2>/dev/null | jq .
|
||||||
|
}
|
||||||
|
|
||||||
|
echo
|
||||||
|
echo "================ ID TOKEN ================"
|
||||||
|
decode "$ID_TOKEN"
|
||||||
|
|
||||||
|
echo
|
||||||
|
echo "============== ACCESS TOKEN =============="
|
||||||
|
decode "$ACCESS_TOKEN"
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# Userinfo (optional)
|
||||||
|
# ---------------------------
|
||||||
|
if [ "$USERINFO_ENDPOINT" != "null" ]; then
|
||||||
|
echo
|
||||||
|
echo "=============== USERINFO ================="
|
||||||
|
curl -s -H "Authorization: Bearer $ACCESS_TOKEN" "$USERINFO_ENDPOINT" | jq .
|
||||||
|
fi
|
||||||
|
|
||||||
@@ -16,6 +16,7 @@ marshmallow-sqlalchemy==0.29.0
|
|||||||
# Security
|
# Security
|
||||||
bcrypt==4.1.2
|
bcrypt==4.1.2
|
||||||
Flask-Bcrypt==1.0.1
|
Flask-Bcrypt==1.0.1
|
||||||
|
pyotp==2.9.0
|
||||||
|
|
||||||
# JWT / OIDC
|
# JWT / OIDC
|
||||||
PyJWT==2.8.0
|
PyJWT==2.8.0
|
||||||
|
|||||||
@@ -0,0 +1,8 @@
|
|||||||
|
<!doctype html>
|
||||||
|
<html>
|
||||||
|
<body>
|
||||||
|
<h1>OK - protected</h1>
|
||||||
|
<p>User: __USER__</p>
|
||||||
|
<p>Email: __EMAIL__</p>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
@@ -0,0 +1,92 @@
|
|||||||
|
version: "3.9"
|
||||||
|
|
||||||
|
services:
|
||||||
|
nginx:
|
||||||
|
image: nginx:1.27-alpine
|
||||||
|
container_name: app-nginx
|
||||||
|
volumes:
|
||||||
|
- ./app:/usr/share/nginx/html:ro
|
||||||
|
- ./nginx.conf:/etc/nginx/conf.d/default.conf:ro
|
||||||
|
|
||||||
|
expose:
|
||||||
|
- "80"
|
||||||
|
|
||||||
|
# oauth2-proxy:
|
||||||
|
# image: quay.io/oauth2-proxy/oauth2-proxy:v7.7.1
|
||||||
|
# container_name: oauth2-proxy
|
||||||
|
# depends_on:
|
||||||
|
# - nginx
|
||||||
|
# ports:
|
||||||
|
# - "8086:4180"
|
||||||
|
# environment:
|
||||||
|
|
||||||
|
# # ----- Logging -----
|
||||||
|
# OAUTH2_PROXY_LOG_LEVEL: "debug"
|
||||||
|
# OAUTH2_PROXY_AUTH_LOGGING: "true"
|
||||||
|
# OAUTH2_PROXY_REQUEST_LOGGING: "true"
|
||||||
|
# OAUTH2_PROXY_STANDARD_LOGGING: "true"
|
||||||
|
|
||||||
|
|
||||||
|
# # ----- Gatehouse OIDC -----
|
||||||
|
# OAUTH2_PROXY_PROVIDER: oidc
|
||||||
|
# OAUTH2_PROXY_OIDC_ISSUER_URL: "http://192.168.64.7:8888"
|
||||||
|
# OAUTH2_PROXY_CLIENT_ID: "acme-portal-001"
|
||||||
|
# OAUTH2_PROXY_CLIENT_SECRET: "acme_secret_portal_2024"
|
||||||
|
# OAUTH2_PROXY_REDIRECT_URL: "http://192.168.64.7:8086/oauth2/callback"
|
||||||
|
|
||||||
|
# # ----- Session -----
|
||||||
|
# OAUTH2_PROXY_COOKIE_SECRET: "afhXcfftf5qKezJ217qhED7U4UeVyqBHd7lhISNGpXo="
|
||||||
|
# OAUTH2_PROXY_COOKIE_SECURE: "false"
|
||||||
|
# OAUTH2_PROXY_COOKIE_SAMESITE: "lax"
|
||||||
|
# OAUTH2_PROXY_EMAIL_DOMAINS: "*"
|
||||||
|
# OAUTH2_PROXY_HTTP_ADDRESS: "0.0.0.0:4180"
|
||||||
|
|
||||||
|
|
||||||
|
# # ----- Upstream -----
|
||||||
|
# OAUTH2_PROXY_UPSTREAMS: "http://nginx:80"
|
||||||
|
|
||||||
|
# # ----- Identity headers -----
|
||||||
|
# OAUTH2_PROXY_SET_XAUTHREQUEST: "true"
|
||||||
|
# OAUTH2_PROXY_PASS_ACCESS_TOKEN: "true"
|
||||||
|
# OAUTH2_PROXY_PASS_AUTHORIZATION_HEADER: "true"
|
||||||
|
|
||||||
|
# # ----- Claim mapping (Gatehouse) -----
|
||||||
|
# OAUTH2_PROXY_OIDC_EMAIL_CLAIM: "email"
|
||||||
|
# OAUTH2_PROXY_USER_ID_CLAIM: "preferred_username"
|
||||||
|
# OAUTH2_PROXY_OIDC_GROUPS_CLAIM: "roles"
|
||||||
|
|
||||||
|
# OAUTH2_PROXY_SKIP_PROVIDER_BUTTON: "true"
|
||||||
|
|
||||||
|
|
||||||
|
# OAUTH2_PROXY_SCOPE: "openid email profile"
|
||||||
|
# # OAUTH2_PROXY_OIDC_EMAIL_CLAIM: "email"
|
||||||
|
# OAUTH2_PROXY_INSECURE_OIDC_ALLOW_UNVERIFIED_EMAIL: "true"
|
||||||
|
|
||||||
|
|
||||||
|
oauth2-proxy:
|
||||||
|
image: quay.io/oauth2-proxy/oauth2-proxy:v7.7.1
|
||||||
|
container_name: oauth2-proxy
|
||||||
|
depends_on: [nginx]
|
||||||
|
ports:
|
||||||
|
- "8086:4180"
|
||||||
|
command:
|
||||||
|
- --provider=oidc
|
||||||
|
- --oidc-issuer-url=http://192.168.64.7:8888
|
||||||
|
- --client-id=acme-portal-001
|
||||||
|
- --client-secret=acme_secret_portal_2024
|
||||||
|
- --redirect-url=http://192.168.64.7:8086/oauth2/callback
|
||||||
|
- --scope=openid email profile
|
||||||
|
- --email-domain=*
|
||||||
|
- --upstream=http://nginx:80
|
||||||
|
- --http-address=0.0.0.0:4180
|
||||||
|
- --cookie-secret=afhXcfftf5qKezJ217qhED7U4UeVyqBHd7lhISNGpXo=
|
||||||
|
- --cookie-secure=false
|
||||||
|
- --cookie-samesite=lax
|
||||||
|
- --skip-provider-button=true
|
||||||
|
- --standard-logging=true
|
||||||
|
- --request-logging=true
|
||||||
|
- --auth-logging=true
|
||||||
|
- --set-xauthrequest=true
|
||||||
|
- --pass-authorization-header=true # optional (token passthrough)
|
||||||
|
- --pass-access-token=true # optional (token passthrough)
|
||||||
|
- --pass-user-headers=true
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
server {
|
||||||
|
listen 80;
|
||||||
|
server_name _;
|
||||||
|
|
||||||
|
root /usr/share/nginx/html;
|
||||||
|
index index.html;
|
||||||
|
|
||||||
|
location / {
|
||||||
|
try_files $uri $uri/ =404;
|
||||||
|
}
|
||||||
|
|
||||||
|
location = /whoami {
|
||||||
|
default_type text/plain;
|
||||||
|
return 200
|
||||||
|
"user: $http_x_auth_request_user
|
||||||
|
email: $http_x_auth_request_email
|
||||||
|
preferred_username: $http_x_forwarded_preferred_username
|
||||||
|
x-forwarded-user: $http_x_forwarded_user
|
||||||
|
x-forwarded-email: $http_x_forwarded_email
|
||||||
|
authorization: $http_authorization
|
||||||
|
";
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,285 @@
|
|||||||
|
"""Unit tests for TOTPService."""
|
||||||
|
import base64
|
||||||
|
import pytest
|
||||||
|
from app.services.totp_service import TOTPService
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestTOTPService:
|
||||||
|
"""Tests for TOTPService."""
|
||||||
|
|
||||||
|
# Test generate_secret()
|
||||||
|
def test_generate_secret_returns_string(self):
|
||||||
|
"""Test that generate_secret returns a string."""
|
||||||
|
secret = TOTPService.generate_secret()
|
||||||
|
assert isinstance(secret, str)
|
||||||
|
|
||||||
|
def test_generate_secret_length(self):
|
||||||
|
"""Test that generate_secret returns a 32-character string."""
|
||||||
|
secret = TOTPService.generate_secret()
|
||||||
|
assert len(secret) == 32
|
||||||
|
|
||||||
|
def test_generate_secret_base32_encoded(self):
|
||||||
|
"""Test that generate_secret returns a base32 encoded string."""
|
||||||
|
secret = TOTPService.generate_secret()
|
||||||
|
# Base32 characters are A-Z and 2-7
|
||||||
|
valid_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567")
|
||||||
|
assert all(c in valid_chars for c in secret)
|
||||||
|
|
||||||
|
def test_generate_secret_unique(self):
|
||||||
|
"""Test that generate_secret produces unique secrets."""
|
||||||
|
secret1 = TOTPService.generate_secret()
|
||||||
|
secret2 = TOTPService.generate_secret()
|
||||||
|
assert secret1 != secret2
|
||||||
|
|
||||||
|
# Test generate_provisioning_uri()
|
||||||
|
def test_generate_provisioning_uri_format(self):
|
||||||
|
"""Test that provisioning URI is generated correctly."""
|
||||||
|
email = "user@example.com"
|
||||||
|
secret = "JBSWY3DPEHPK3PXP"
|
||||||
|
issuer = "Gatehouse"
|
||||||
|
|
||||||
|
uri = TOTPService.generate_provisioning_uri(email, secret, issuer)
|
||||||
|
|
||||||
|
assert isinstance(uri, str)
|
||||||
|
assert uri.startswith("otpauth://totp/")
|
||||||
|
|
||||||
|
def test_generate_provisioning_uri_contains_email(self):
|
||||||
|
"""Test that provisioning URI contains the user email."""
|
||||||
|
email = "user@example.com"
|
||||||
|
secret = "JBSWY3DPEHPK3PXP"
|
||||||
|
issuer = "Gatehouse"
|
||||||
|
|
||||||
|
uri = TOTPService.generate_provisioning_uri(email, secret, issuer)
|
||||||
|
|
||||||
|
assert email in uri
|
||||||
|
|
||||||
|
def test_generate_provisioning_uri_contains_secret(self):
|
||||||
|
"""Test that provisioning URI contains the secret."""
|
||||||
|
email = "user@example.com"
|
||||||
|
secret = "JBSWY3DPEHPK3PXP"
|
||||||
|
issuer = "Gatehouse"
|
||||||
|
|
||||||
|
uri = TOTPService.generate_provisioning_uri(email, secret, issuer)
|
||||||
|
|
||||||
|
assert secret in uri
|
||||||
|
|
||||||
|
def test_generate_provisioning_uri_contains_issuer(self):
|
||||||
|
"""Test that provisioning URI contains the issuer."""
|
||||||
|
email = "user@example.com"
|
||||||
|
secret = "JBSWY3DPEHPK3PXP"
|
||||||
|
issuer = "Gatehouse"
|
||||||
|
|
||||||
|
uri = TOTPService.generate_provisioning_uri(email, secret, issuer)
|
||||||
|
|
||||||
|
assert issuer in uri
|
||||||
|
|
||||||
|
def test_generate_provisioning_uri_custom_issuer(self):
|
||||||
|
"""Test that provisioning URI uses custom issuer."""
|
||||||
|
email = "user@example.com"
|
||||||
|
secret = "JBSWY3DPEHPK3PXP"
|
||||||
|
custom_issuer = "MyApp"
|
||||||
|
|
||||||
|
uri = TOTPService.generate_provisioning_uri(email, secret, custom_issuer)
|
||||||
|
|
||||||
|
assert custom_issuer in uri
|
||||||
|
|
||||||
|
# Test verify_code()
|
||||||
|
def test_verify_code_valid(self):
|
||||||
|
"""Test that a valid TOTP code is accepted."""
|
||||||
|
secret = TOTPService.generate_secret()
|
||||||
|
# Generate a valid code using pyotp
|
||||||
|
import pyotp
|
||||||
|
totp = pyotp.TOTP(secret)
|
||||||
|
valid_code = totp.now()
|
||||||
|
|
||||||
|
result = TOTPService.verify_code(secret, valid_code)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
def test_verify_code_invalid(self):
|
||||||
|
"""Test that an invalid TOTP code is rejected."""
|
||||||
|
secret = TOTPService.generate_secret()
|
||||||
|
invalid_code = "000000"
|
||||||
|
|
||||||
|
result = TOTPService.verify_code(secret, invalid_code)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
def test_verify_code_window_parameter(self):
|
||||||
|
"""Test that the time window parameter works correctly."""
|
||||||
|
secret = TOTPService.generate_secret()
|
||||||
|
import pyotp
|
||||||
|
totp = pyotp.TOTP(secret)
|
||||||
|
|
||||||
|
# Get current code
|
||||||
|
current_code = totp.now()
|
||||||
|
|
||||||
|
# Verify with window=1 (default) - should accept current code
|
||||||
|
result = TOTPService.verify_code(secret, current_code, window=1)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
# Verify with window=0 - should only accept exact time match
|
||||||
|
result = TOTPService.verify_code(secret, current_code, window=0)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
def test_verify_code_wrong_length(self):
|
||||||
|
"""Test that codes with wrong length are rejected."""
|
||||||
|
secret = TOTPService.generate_secret()
|
||||||
|
wrong_length_code = "12345" # 5 digits instead of 6
|
||||||
|
|
||||||
|
result = TOTPService.verify_code(secret, wrong_length_code)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
# Test generate_backup_codes()
|
||||||
|
def test_generate_backup_codes_default_count(self):
|
||||||
|
"""Test that generate_backup_codes generates 10 codes by default."""
|
||||||
|
plain_codes, hashed_codes = TOTPService.generate_backup_codes()
|
||||||
|
|
||||||
|
assert len(plain_codes) == 10
|
||||||
|
assert len(hashed_codes) == 10
|
||||||
|
|
||||||
|
def test_generate_backup_codes_custom_count(self):
|
||||||
|
"""Test that generate_backup_codes generates the specified number of codes."""
|
||||||
|
count = 5
|
||||||
|
plain_codes, hashed_codes = TOTPService.generate_backup_codes(count)
|
||||||
|
|
||||||
|
assert len(plain_codes) == count
|
||||||
|
assert len(hashed_codes) == count
|
||||||
|
|
||||||
|
def test_generate_backup_codes_plain_are_strings(self):
|
||||||
|
"""Test that plain backup codes are strings."""
|
||||||
|
plain_codes, hashed_codes = TOTPService.generate_backup_codes()
|
||||||
|
|
||||||
|
assert all(isinstance(code, str) for code in plain_codes)
|
||||||
|
|
||||||
|
def test_generate_backup_codes_plain_length(self):
|
||||||
|
"""Test that plain backup codes are 16 characters long."""
|
||||||
|
plain_codes, hashed_codes = TOTPService.generate_backup_codes()
|
||||||
|
|
||||||
|
assert all(len(code) == 16 for code in plain_codes)
|
||||||
|
|
||||||
|
def test_generate_backup_codes_hashed_different_from_plain(self):
|
||||||
|
"""Test that hashed codes are different from plain codes."""
|
||||||
|
plain_codes, hashed_codes = TOTPService.generate_backup_codes()
|
||||||
|
|
||||||
|
for plain, hashed in zip(plain_codes, hashed_codes):
|
||||||
|
assert plain != hashed
|
||||||
|
|
||||||
|
def test_generate_backup_codes_are_bcrypt_hashes(self):
|
||||||
|
"""Test that hashed codes are bcrypt hashes."""
|
||||||
|
plain_codes, hashed_codes = TOTPService.generate_backup_codes()
|
||||||
|
|
||||||
|
# Bcrypt hashes start with $2a$, $2b$, or $2y$
|
||||||
|
for hashed in hashed_codes:
|
||||||
|
assert hashed.startswith("$2")
|
||||||
|
|
||||||
|
def test_generate_backup_codes_unique(self):
|
||||||
|
"""Test that generated backup codes are unique."""
|
||||||
|
plain_codes, hashed_codes = TOTPService.generate_backup_codes()
|
||||||
|
|
||||||
|
assert len(set(plain_codes)) == len(plain_codes)
|
||||||
|
assert len(set(hashed_codes)) == len(hashed_codes)
|
||||||
|
|
||||||
|
# Test verify_backup_code()
|
||||||
|
def test_verify_backup_code_valid(self):
|
||||||
|
"""Test that a valid backup code is accepted and removed."""
|
||||||
|
plain_codes, hashed_codes = TOTPService.generate_backup_codes(count=3)
|
||||||
|
code_to_verify = plain_codes[0]
|
||||||
|
|
||||||
|
is_valid, remaining_codes = TOTPService.verify_backup_code(hashed_codes, code_to_verify)
|
||||||
|
|
||||||
|
assert is_valid is True
|
||||||
|
assert len(remaining_codes) == 2
|
||||||
|
|
||||||
|
def test_verify_backup_code_invalid(self):
|
||||||
|
"""Test that an invalid backup code is rejected."""
|
||||||
|
plain_codes, hashed_codes = TOTPService.generate_backup_codes(count=3)
|
||||||
|
invalid_code = "INVALIDCODE1234"
|
||||||
|
|
||||||
|
is_valid, remaining_codes = TOTPService.verify_backup_code(hashed_codes, invalid_code)
|
||||||
|
|
||||||
|
assert is_valid is False
|
||||||
|
assert len(remaining_codes) == 3
|
||||||
|
|
||||||
|
def test_verify_backup_code_remaining_updated(self):
|
||||||
|
"""Test that the remaining codes list is updated correctly."""
|
||||||
|
plain_codes, hashed_codes = TOTPService.generate_backup_codes(count=5)
|
||||||
|
code_to_verify = plain_codes[2]
|
||||||
|
|
||||||
|
is_valid, remaining_codes = TOTPService.verify_backup_code(hashed_codes, code_to_verify)
|
||||||
|
|
||||||
|
assert is_valid is True
|
||||||
|
# The verified code should be removed
|
||||||
|
assert len(remaining_codes) == 4
|
||||||
|
# The remaining codes should not include the verified code's hash
|
||||||
|
assert hashed_codes[2] not in remaining_codes
|
||||||
|
|
||||||
|
def test_verify_backup_code_case_sensitive(self):
|
||||||
|
"""Test that backup code verification is case sensitive."""
|
||||||
|
plain_codes, hashed_codes = TOTPService.generate_backup_codes(count=1)
|
||||||
|
code_to_verify = plain_codes[0].lower() # Convert to lowercase
|
||||||
|
|
||||||
|
is_valid, remaining_codes = TOTPService.verify_backup_code(hashed_codes, code_to_verify)
|
||||||
|
|
||||||
|
assert is_valid is False
|
||||||
|
assert len(remaining_codes) == 1
|
||||||
|
|
||||||
|
def test_verify_backup_code_single_use(self):
|
||||||
|
"""Test that a backup code can only be used once."""
|
||||||
|
plain_codes, hashed_codes = TOTPService.generate_backup_codes(count=1)
|
||||||
|
code_to_verify = plain_codes[0]
|
||||||
|
|
||||||
|
# First use - should succeed
|
||||||
|
is_valid1, remaining1 = TOTPService.verify_backup_code(hashed_codes, code_to_verify)
|
||||||
|
assert is_valid1 is True
|
||||||
|
assert len(remaining1) == 0
|
||||||
|
|
||||||
|
# Second use - should fail (code already consumed)
|
||||||
|
is_valid2, remaining2 = TOTPService.verify_backup_code(remaining1, code_to_verify)
|
||||||
|
assert is_valid2 is False
|
||||||
|
assert len(remaining2) == 0
|
||||||
|
|
||||||
|
# Test generate_qr_code_data_uri()
|
||||||
|
def test_generate_qr_code_data_uri_format(self):
|
||||||
|
"""Test that a data URI is generated."""
|
||||||
|
provisioning_uri = "otpauth://totp/Gatehouse:user@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse"
|
||||||
|
|
||||||
|
data_uri = TOTPService.generate_qr_code_data_uri(provisioning_uri)
|
||||||
|
|
||||||
|
assert isinstance(data_uri, str)
|
||||||
|
|
||||||
|
def test_generate_qr_code_data_uri_starts_with_prefix(self):
|
||||||
|
"""Test that the data URI starts with the correct prefix."""
|
||||||
|
provisioning_uri = "otpauth://totp/Gatehouse:user@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse"
|
||||||
|
|
||||||
|
data_uri = TOTPService.generate_qr_code_data_uri(provisioning_uri)
|
||||||
|
|
||||||
|
assert data_uri.startswith("data:image/png;base64,")
|
||||||
|
|
||||||
|
def test_generate_qr_code_data_uri_contains_base64(self):
|
||||||
|
"""Test that the data URI contains base64 encoded data."""
|
||||||
|
provisioning_uri = "otpauth://totp/Gatehouse:user@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse"
|
||||||
|
|
||||||
|
data_uri = TOTPService.generate_qr_code_data_uri(provisioning_uri)
|
||||||
|
|
||||||
|
# Extract the base64 part (after the prefix)
|
||||||
|
base64_part = data_uri.split("data:image/png;base64,")[1]
|
||||||
|
|
||||||
|
# Verify it's valid base64
|
||||||
|
try:
|
||||||
|
base64.b64decode(base64_part)
|
||||||
|
assert True
|
||||||
|
except Exception:
|
||||||
|
assert False, "Data URI does not contain valid base64 data"
|
||||||
|
|
||||||
|
def test_generate_qr_code_data_uri_different_uris(self):
|
||||||
|
"""Test that different provisioning URIs generate different QR codes."""
|
||||||
|
uri1 = "otpauth://totp/Gatehouse:user1@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse"
|
||||||
|
uri2 = "otpauth://totp/Gatehouse:user2@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse"
|
||||||
|
|
||||||
|
data_uri1 = TOTPService.generate_qr_code_data_uri(uri1)
|
||||||
|
data_uri2 = TOTPService.generate_qr_code_data_uri(uri2)
|
||||||
|
|
||||||
|
assert data_uri1 != data_uri2
|
||||||
Reference in New Issue
Block a user