diff --git a/README.md b/README.md index 7d837f2..fabcf2f 100644 --- a/README.md +++ b/README.md @@ -305,3 +305,7 @@ client_secret: acme_secret_portal_2024 ## User email: bob@acme-corp.com password: UserPass123! + + +## Sqlite editor +sqlite_web instance/db_file.db --port 9999 --host 0.0.0.0 \ No newline at end of file diff --git a/app/api/oidc.py b/app/api/oidc.py index 175db2e..cf3d4dc 100644 --- a/app/api/oidc.py +++ b/app/api/oidc.py @@ -3,6 +3,7 @@ import base64 import json import logging import secrets +from datetime import datetime, timezone from urllib.parse import urlencode, urlparse, parse_qs import bcrypt @@ -42,14 +43,14 @@ def get_oidc_config(): "registration_endpoint": f"{base_url}/oidc/register", "revocation_endpoint": f"{base_url}/oidc/revoke", "introspection_endpoint": f"{base_url}/oidc/introspect", - "scopes_supported": ["openid", "profile", "email"], + "scopes_supported": ["openid", "profile", "email", "roles"], "response_types_supported": ["code"], "response_modes_supported": ["query"], "grant_types_supported": ["authorization_code", "refresh_token"], "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"], "subject_types_supported": ["public"], "id_token_signing_alg_values_supported": ["RS256"], - "claims_supported": ["sub", "name", "email", "email_verified"], + "claims_supported": ["sub", "name", "email", "email_verified", "roles"], } @@ -94,19 +95,49 @@ def require_valid_token(): Raises: InvalidGrantError: If token is invalid """ + logger.debug("[OIDC USERINFO] ===========================================") + logger.debug("[OIDC USERINFO] require_valid_token() called") + logger.debug("[OIDC USERINFO] Request method: %s", request.method) + logger.debug("[OIDC USERINFO] Request URL: %s", request.url) + logger.debug("[OIDC USERINFO] Request headers: %s", dict(request.headers)) + auth_header = request.headers.get("Authorization", "") + logger.debug("[OIDC USERINFO] Authorization header: %s", auth_header[:20] + "..." if len(auth_header) > 20 else auth_header) + if not auth_header.startswith("Bearer "): + logger.error("[OIDC USERINFO] Invalid Authorization header format - missing 'Bearer ' prefix") raise InvalidGrantError("Invalid token: Missing or invalid Authorization header") token = auth_header[7:] - claims = OIDCService.validate_access_token(token) - g.current_token = claims + logger.debug("[OIDC USERINFO] Token extracted (first 50 chars): %s...", token[:50] if len(token) > 50 else token) + logger.debug("[OIDC USERINFO] Token length: %d", len(token)) + + try: + logger.debug("[OIDC USERINFO] Calling OIDCService.validate_access_token()...") + claims = OIDCService.validate_access_token(token) + logger.debug("[OIDC USERINFO] Token validation successful") + logger.debug("[OIDC USERINFO] Token claims: %s", claims) + except Exception as e: + logger.error("[OIDC USERINFO] Token validation failed: %s: %s", type(e).__name__, str(e)) + raise + + g.current_token = claims + g.access_token = token # Store the original access token for get_userinfo() + logger.debug("[OIDC USERINFO] g.current_token set") + + user_id = claims.get("sub") + logger.debug("[OIDC USERINFO] User ID from token: %s", user_id) + + user = User.query.get(user_id) + logger.debug("[OIDC USERINFO] User query result: %s", user) - user = User.query.get(claims.get("sub")) if not user: + logger.error("[OIDC USERINFO] User not found in database: user_id=%s", user_id) raise InvalidGrantError("Invalid token: User not found") g.current_user = user + logger.debug("[OIDC USERINFO] g.current_user set: user_id=%s, email=%s", user.id, user.email) + logger.debug("[OIDC USERINFO] require_valid_token() completed successfully") def _check_password_hash(client, password): @@ -175,10 +206,11 @@ def oidc_discovery(): No authentication required. Returns: - 200: OIDC discovery document + 200: OIDC discovery document (application/json) """ config = get_oidc_config() + # Return discovery document as application/json (per OpenID Connect Discovery 1.0) response = jsonify(config) response.headers["Cache-Control"] = "max-age=86400" return response, 200 @@ -217,8 +249,12 @@ def oidc_authorize(): 200: Login page (GET when not authenticated) 400: Invalid request """ + logger.debug("[OIDC] ===========================================") logger.debug("[OIDC] oidc_authorize called") - + logger.debug("[OIDC] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z") + logger.debug("[OIDC] Request method: %s", request.method) + logger.debug("[OIDC] Request URL: %s", request.url) + logger.debug("[OIDC] Remote address: %s", request.remote_addr) # Parse request parameters if request.method == "GET": @@ -227,6 +263,7 @@ def oidc_authorize(): params = request.form.to_dict() logger.debug("[OIDC] Raw request params: %s", params) + # Extract required parameters logger.debug("[OIDC] Extracting request parameters...") client_id = params.get("client_id") @@ -367,6 +404,7 @@ def oidc_authorize(): # User is authenticated, generate authorization code logger.debug("[OIDC] User is authenticated, fetching user from database...") + logger.debug("[OIDC] Current UTC time before code generation: %s", datetime.now(timezone.utc).isoformat() + "Z") user = User.query.get(user_id) logger.debug("[OIDC] User query result: %s", user) @@ -393,12 +431,17 @@ def oidc_authorize(): user_agent=request.headers.get("User-Agent"), ) logger.debug("[OIDC] Authorization code generated successfully: %s...", code[:20] if code else None) + logger.debug("[OIDC] Current UTC time after code generation: %s", datetime.now(timezone.utc).isoformat() + "Z") except Exception as e: - logger.debug("[OIDC] Authorization code generation failed: %s", str(e)) + logger.error("[OIDC] Authorization code generation failed: %s", str(e)) + logger.error("[OIDC] Current UTC time at failure: %s", datetime.now(timezone.utc).isoformat() + "Z") + import traceback + logger.error("[OIDC] Traceback: %s", traceback.format_exc()) return _redirect_with_error(redirect_uri, "server_error", str(e), state) # Redirect with authorization code logger.debug("[OIDC] Redirecting with authorization code...") + logger.debug("[OIDC] Current UTC time before redirect: %s", datetime.now(timezone.utc).isoformat() + "Z") redirect_params = {"code": code} if state: redirect_params["state"] = state @@ -406,6 +449,7 @@ def oidc_authorize(): redirect_url = f"{redirect_uri}?{urlencode(redirect_params)}" logger.debug("[OIDC] Redirect URL: %s...", redirect_url[:100] if redirect_url else None) logger.debug("[OIDC] oidc_authorize completed successfully") + logger.debug("[OIDC] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z") logger.debug("[OIDC] ===========================================") return redirect(redirect_url) @@ -544,14 +588,13 @@ def oidc_token(): # Validate grant_type if not grant_type: - logger.error("[OIDC] grant_type is requred") - return api_response( - success=False, - message="grant_type is required", - status=400, - error_type="INVALID_REQUEST", - error_details={"error": "invalid_request", "error_description": "grant_type is required"}, - ) + logger.error("[OIDC] grant_type is required") + # RFC 6749 Section 5.2: Error response for invalid request + response = jsonify({ + "error": "invalid_request", + "error_description": "grant_type is required" + }) + return response, 400 # Authenticate client client_id = data.get("client_id") @@ -600,46 +643,51 @@ def oidc_token(): # Unsupported grant type else: logger.error("[OIDC] Unsupported grant_type") - return api_response( - success=False, - message="Unsupported grant_type", - status=400, - error_type="UNSUPPORTED_GRANT_TYPE", - error_details={"error": "unsupported_grant_type", "error_description": f"Grant type '{grant_type}' is not supported"}, - ) + # RFC 6749 Section 5.2: Error response for unsupported grant type + response = jsonify({ + "error": "unsupported_grant_type", + "error_description": f"Grant type '{grant_type}' is not supported" + }) + return response, 400 def _handle_authorization_code_grant(data, client): """Handle authorization_code grant type.""" + logger.debug("[OIDC] ===========================================") + logger.debug("[OIDC] _handle_authorization_code_grant called") + logger.debug("[OIDC] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z") + code = data.get("code") redirect_uri = data.get("redirect_uri") code_verifier = data.get("code_verifier") + logger.debug("[OIDC] Code provided: %s", bool(code)) + logger.debug("[OIDC] Redirect URI: %s", redirect_uri) + logger.debug("[OIDC] Code verifier provided: %s", bool(code_verifier)) + if not code: logger.error("[OIDC] code is required") - return api_response( - success=False, - message="code is required", - status=400, - error_type="INVALID_REQUEST", - error_details={"error": "invalid_request", "error_description": "code is required"}, - ) + # RFC 6749 Section 5.2: Error response for invalid request + response = jsonify({ + "error": "invalid_request", + "error_description": "code is required" + }) + return response, 400 if not redirect_uri: logger.error("[OIDC] redirect_uri is required") - return api_response( - success=False, - message="redirect_uri is required", - status=400, - error_type="INVALID_REQUEST", - error_details={"error": "invalid_request", "error_description": "redirect_uri is required"}, - ) + response = jsonify({ + "error": "invalid_request", + "error_description": "redirect_uri is required" + }) + return response, 400 try: # Development-only debug logging for authorization code validation if current_app.config.get('ENV') == 'development': logger.debug(f"[OIDC] Authorization code validation: client_id={client.client_id}, code_provided=True") + logger.debug("[OIDC] Current UTC time before code validation: %s", datetime.now(timezone.utc).isoformat() + "Z") claims, user = OIDCService.validate_authorization_code( code=code, client_id=client.client_id, @@ -649,23 +697,22 @@ def _handle_authorization_code_grant(data, client): user_agent=request.headers.get("User-Agent"), ) except InvalidGrantError as e: - logger.error(f"[OIDC] INVALID_GRANT: {str(e)}") - return api_response( - success=False, - message=str(e), - status=400, - error_type="INVALID_GRANT", - error_details={"error": "invalid_grant", "error_description": str(e)}, - ) + logger.error("[OIDC] INVALID_GRANT: %s", str(e)) + logger.error("[OIDC] Current UTC time at validation failure: %s", datetime.now(timezone.utc).isoformat() + "Z") + # RFC 6749 Section 5.2: Error response for invalid grant + response = jsonify({ + "error": "invalid_grant", + "error_description": str(e) + }) + return response, 400 except Exception as e: - logger.error(f"[OIDC] Authorization code validation error: {type(e).__name__}: {str(e)}") - return api_response( - success=False, - message=str(e), - status=400, - error_type="INVALID_GRANT", - error_details={"error": "invalid_grant", "error_description": str(e)}, - ) + logger.error("[OIDC] Authorization code validation error: %s: %s", type(e).__name__, str(e)) + logger.error("[OIDC] Current UTC time at validation error: %s", datetime.now(timezone.utc).isoformat() + "Z") + response = jsonify({ + "error": "invalid_grant", + "error_description": str(e) + }) + return response, 400 # Generate tokens try: @@ -673,6 +720,8 @@ def _handle_authorization_code_grant(data, client): if current_app.config.get('ENV') == 'development': logger.debug(f"[OIDC] Token generation: client_id={client.client_id}, user_id={claims['user_id']}, scope={claims['scope']}") + logger.debug("[OIDC] Current UTC time before token generation: %s", datetime.now(timezone.utc).isoformat() + "Z") + tokens = OIDCService.generate_tokens( client_id=client.client_id, user_id=claims["user_id"], @@ -683,40 +732,64 @@ def _handle_authorization_code_grant(data, client): auth_time=int(__import__("time").time()), ) except Exception as e: - logger.error(f"[OIDC] Failed to generate tokens {str(e)}") - return api_response( - success=False, - message="Failed to generate tokens", - status=500, - error_type="SERVER_ERROR", - error_details={"error": "server_error", "error_description": str(e)}, - ) + logger.error("[OIDC] Failed to generate tokens: %s", str(e)) + logger.error("[OIDC] Current UTC time at token generation failure: %s", datetime.now(timezone.utc).isoformat() + "Z") + response = jsonify({ + "error": "server_error", + "error_description": str(e) + }) + return response, 500 - return api_response( - data=tokens, - message="Tokens issued successfully", - status=200, - ) + # Return standard OAuth2/OIDC token response (application/json) + # Per RFC 6749 Section 5.1 and OIDC Core 1.0 + logger.debug("[OIDC] Current UTC time after token generation: %s", datetime.now(timezone.utc).isoformat() + "Z") + logger.debug("[OIDC] _handle_authorization_code_grant completed successfully") + + # Echo tokens to console for diagnostics + print(f"[TOKEN DIAGNOSTICS] Authorization code exchange completed") + print(f"[TOKEN DIAGNOSTICS] Access Token: {tokens.get('access_token', '')}..." if len(tokens.get('access_token', '')) > 50 else f"[TOKEN DIAGNOSTICS] Access Token: {tokens.get('access_token', '')}") + print(f"[TOKEN DIAGNOSTICS] Token Type: {tokens.get('token_type', '')}") + print(f"[TOKEN DIAGNOSTICS] Expires In: {tokens.get('expires_in', '')}") + if 'id_token' in tokens: + print(f"[TOKEN DIAGNOSTICS] ID Token: {tokens['id_token']}..." if len(tokens['id_token']) > 50 else f"[TOKEN DIAGNOSTICS] ID Token: {tokens['id_token']}") + if 'refresh_token' in tokens: + print(f"[TOKEN DIAGNOSTICS] Refresh Token: {tokens['refresh_token'][:50]}..." if len(tokens['refresh_token']) > 50 else f"[TOKEN DIAGNOSTICS] Refresh Token: {tokens['refresh_token']}") + print(f"[TOKEN DIAGNOSTICS] Scope: {tokens.get('scope', '')}") + print(f"[TOKEN DIAGNOSTICS] ===========================================") + + logger.debug("[OIDC] ===========================================") + response = jsonify(tokens) + print(tokens) + response.headers["Cache-Control"] = "no-store" + response.headers["Pragma"] = "no-cache" + return response, 200 def _handle_refresh_token_grant(data, client): """Handle refresh_token grant type.""" + logger.debug("[OIDC] ===========================================") + logger.debug("[OIDC] _handle_refresh_token_grant called") + logger.debug("[OIDC] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z") + refresh_token = data.get("refresh_token") scope = data.get("scope") + logger.debug("[OIDC] Refresh token provided: %s", bool(refresh_token)) + logger.debug("[OIDC] Scope: %s", scope) + if not refresh_token: - return api_response( - success=False, - message="refresh_token is required", - status=400, - error_type="INVALID_REQUEST", - error_details={"error": "invalid_request", "error_description": "refresh_token is required"}, - ) + # RFC 6749 Section 5.2: Error response for invalid request + response = jsonify({ + "error": "invalid_request", + "error_description": "refresh_token is required" + }) + return response, 400 # Parse scope if provided scope_list = scope.split() if scope else None try: + logger.debug("[OIDC] Current UTC time before token refresh: %s", datetime.now(timezone.utc).isoformat() + "Z") tokens = OIDCService.refresh_access_token( refresh_token=refresh_token, client_id=client.client_id, @@ -725,19 +798,37 @@ def _handle_refresh_token_grant(data, client): user_agent=request.headers.get("User-Agent"), ) except InvalidGrantError as e: - return api_response( - success=False, - message=str(e), - status=400, - error_type="INVALID_GRANT", - error_details={"error": "invalid_grant", "error_description": str(e)}, - ) + logger.error("[OIDC] Refresh token error: %s", str(e)) + logger.error("[OIDC] Current UTC time at refresh failure: %s", datetime.now(timezone.utc).isoformat() + "Z") + # RFC 6749 Section 5.2: Error response for invalid grant + response = jsonify({ + "error": "invalid_grant", + "error_description": str(e) + }) + return response, 400 - return api_response( - data=tokens, - message="Tokens refreshed successfully", - status=200, - ) + # Return standard OAuth2/OIDC token response (application/json) + # Per RFC 6749 Section 5.1 and OIDC Core 1.0 + logger.debug("[OIDC] Current UTC time after token refresh: %s", datetime.now(timezone.utc).isoformat() + "Z") + logger.debug("[OIDC] _handle_refresh_token_grant completed successfully") + + # Echo tokens to console for diagnostics + print(f"[TOKEN DIAGNOSTICS] Token refresh completed") + print(f"[TOKEN DIAGNOSTICS] Access Token: {tokens.get('access_token', '')[:50]}..." if len(tokens.get('access_token', '')) > 50 else f"[TOKEN DIAGNOSTICS] Access Token: {tokens.get('access_token', '')}") + print(f"[TOKEN DIAGNOSTICS] Token Type: {tokens.get('token_type', '')}") + print(f"[TOKEN DIAGNOSTICS] Expires In: {tokens.get('expires_in', '')}") + if 'id_token' in tokens: + print(f"[TOKEN DIAGNOSTICS] ID Token: {tokens['id_token'][:50]}..." if len(tokens['id_token']) > 50 else f"[TOKEN DIAGNOSTICS] ID Token: {tokens['id_token']}") + if 'refresh_token' in tokens: + print(f"[TOKEN DIAGNOSTICS] Refresh Token: {tokens['refresh_token'][:50]}..." if len(tokens['refresh_token']) > 50 else f"[TOKEN DIAGNOSTICS] Refresh Token: {tokens['refresh_token']}") + print(f"[TOKEN DIAGNOSTICS] Scope: {tokens.get('scope', '')}") + print(f"[TOKEN DIAGNOSTICS] ===========================================") + + logger.debug("[OIDC] ===========================================") + response = jsonify(tokens) + response.headers["Cache-Control"] = "no-store" + response.headers["Pragma"] = "no-cache" + return response, 200 # ============================================================================ @@ -759,37 +850,72 @@ def oidc_userinfo(): - email_verified: Email verification status (if "email" scope) Returns: - 200: User claims - 401: Invalid or insufficient token + 200: User claims in JSON format (application/json) + 401: Invalid or missing token (with WWW-Authenticate header per RFC 6750) """ + logger.debug("[OIDC USERINFO] ===========================================") + logger.debug("[OIDC USERINFO] oidc_userinfo() endpoint called") + logger.debug("[OIDC USERINFO] Request method: %s", request.method) + logger.debug("[OIDC USERINFO] Request URL: %s", request.url) + logger.debug("[OIDC USERINFO] Request content_type: %s", request.content_type) + logger.debug("[OIDC USERINFO] Request headers: %s", dict(request.headers)) + logger.debug("[OIDC USERINFO] Request args: %s", dict(request.args)) + logger.debug("[OIDC USERINFO] Request form: %s", dict(request.form)) + request_json = request.get_json(silent=True) + logger.debug("[OIDC USERINFO] Request json: %s", request_json) + logger.debug("[OIDC USERINFO] Request data length: %d bytes", len(request.get_data())) + try: + logger.debug("[OIDC USERINFO] Calling require_valid_token()...") require_valid_token() + logger.debug("[OIDC USERINFO] Token validation successful") except InvalidGrantError as e: - return api_response( - success=False, - message=str(e), - status=401, - error_type="INVALID_TOKEN", - error_details={"error": "invalid_token", "error_description": str(e)}, - ) - - # Get userinfo - try: - userinfo = OIDCService.get_userinfo(g.current_token.get("access_token", "")) + logger.error("[OIDC USERINFO] Token validation failed: %s", str(e)) + # RFC 6750 Section 3: Return 401 with WWW-Authenticate header for invalid tokens + response = jsonify({ + "error": "invalid_token", + "error_description": str(e) + }) + response.headers["WWW-Authenticate"] = 'Bearer realm="OIDC UserInfo Endpoint", error="invalid_token", error_description="' + str(e) + '"' + return response, 401 except Exception as e: - return api_response( - success=False, - message="Failed to get user info", - status=500, - error_type="SERVER_ERROR", - error_details={"error": "server_error", "error_description": str(e)}, - ) + logger.error("[OIDC USERINFO] Unexpected error during token validation: %s: %s", type(e).__name__, str(e)) + response = jsonify({ + "error": "server_error", + "error_description": str(e) + }) + response.headers["WWW-Authenticate"] = 'Bearer realm="OIDC UserInfo Endpoint", error="server_error"' + return response, 500 - return api_response( - data=userinfo, - message="User info retrieved successfully", - status=200, - ) + logger.debug("[OIDC USERINFO] g.current_token: %s", g.current_token) + logger.debug("[OIDC USERINFO] g.current_user: user_id=%s, email=%s", g.current_user.id, g.current_user.email) + + # Get userinfo using the original access token + access_token = g.access_token + logger.debug("[OIDC USERINFO] Access token from g.access_token: %s...", access_token[:50] if len(access_token) > 50 else access_token) + + try: + logger.debug("[OIDC USERINFO] Calling OIDCService.get_userinfo()...") + userinfo = OIDCService.get_userinfo(access_token) + logger.debug("[OIDC USERINFO] Userinfo retrieved successfully: %s", userinfo) + except Exception as e: + logger.error("[OIDC USERINFO] Failed to get user info: %s: %s", type(e).__name__, str(e)) + import traceback + logger.error("[OIDC USERINFO] Traceback: %s", traceback.format_exc()) + response = jsonify({ + "error": "server_error", + "error_description": str(e) + }) + return response, 500 + + logger.debug("[OIDC USERINFO] Returning userinfo response") + logger.debug("[OIDC USERINFO] ===========================================") + + # Return standard OIDC UserInfo response (application/json) + # Per OpenID Connect Core 1.0 Section 5.3, response is a JSON object + response = jsonify(userinfo) + response.headers["Cache-Control"] = "no-cache, no-store" + return response, 200 # ============================================================================ @@ -806,19 +932,18 @@ def oidc_jwks(): No authentication required. Returns: - 200: JWKS document + 200: JWKS document (application/json) """ try: jwks = OIDCService.get_jwks() except Exception as e: - return api_response( - success=False, - message="Failed to get JWKS", - status=500, - error_type="SERVER_ERROR", - error_details={"error": "server_error", "error_description": str(e)}, - ) + response = jsonify({ + "error": "server_error", + "error_description": str(e) + }) + return response, 500 + # Return JWKS as application/json (per OpenID Connect Discovery 1.0) response = jsonify(jwks) response.headers["Cache-Control"] = "max-age=3600" return response, 200 @@ -858,13 +983,12 @@ def oidc_revoke(): token = data.get("token") if not token: - return api_response( - success=False, - message="token is required", - status=400, - error_type="INVALID_REQUEST", - error_details={"error": "invalid_request", "error_description": "token is required"}, - ) + # RFC 7009 Section 2.1: Error response for invalid request + response = jsonify({ + "error": "invalid_request", + "error_description": "token is required" + }) + return response, 400 # Authenticate client client_id = data.get("client_id") @@ -901,13 +1025,11 @@ def oidc_revoke(): user_agent=request.headers.get("User-Agent"), ) except Exception as e: - # Revocation should succeed even if token is invalid + # Revocation should succeed even if token is invalid (RFC 7009) pass - return api_response( - message="Token revoked successfully", - status=200, - ) + # RFC 7009 Section 2.2: Successful revocation returns empty body with 200 + return "", 200 # ============================================================================ @@ -944,13 +1066,12 @@ def oidc_introspect(): token = data.get("token") if not token: - return api_response( - success=False, - message="token is required", - status=400, - error_type="INVALID_REQUEST", - error_details={"error": "invalid_request", "error_description": "token is required"}, - ) + # RFC 7009 Section 2.1: Error response for invalid request + response = jsonify({ + "error": "invalid_request", + "error_description": "token is required" + }) + return response, 400 # Authenticate client client_id = data.get("client_id") @@ -986,19 +1107,17 @@ def oidc_introspect(): user_agent=request.headers.get("User-Agent"), ) except Exception as e: - return api_response( - success=False, - message="Failed to introspect token", - status=500, - error_type="SERVER_ERROR", - error_details={"error": "server_error", "error_description": str(e)}, - ) + # RFC 7009 Section 2.2: Error response + response = jsonify({ + "error": "server_error", + "error_description": str(e) + }) + return response, 500 - return api_response( - data=result, - message="Token introspection successful", - status=200, - ) + # RFC 7009 Section 2.3: Return introspection response (application/json) + response = jsonify(result) + response.headers["Cache-Control"] = "no-cache, no-store" + return response, 200 # ============================================================================ @@ -1030,22 +1149,18 @@ def oidc_register(): redirect_uris = data.get("redirect_uris", []) if not client_name: - return api_response( - success=False, - message="client_name is required", - status=400, - error_type="INVALID_REQUEST", - error_details={"error": "invalid_request", "error_description": "client_name is required"}, - ) + response = jsonify({ + "error": "invalid_request", + "error_description": "client_name is required" + }) + return response, 400 if not redirect_uris: - return api_response( - success=False, - message="redirect_uris is required", - status=400, - error_type="INVALID_REQUEST", - error_details={"error": "invalid_request", "error_description": "redirect_uris is required"}, - ) + response = jsonify({ + "error": "invalid_request", + "error_description": "redirect_uris is required" + }) + return response, 400 # Validate redirect_uris for uri in redirect_uris: @@ -1054,13 +1169,11 @@ def oidc_register(): if not parsed.scheme or not parsed.netloc: raise ValueError(f"Invalid redirect URI: {uri}") except Exception: - return api_response( - success=False, - message=f"Invalid redirect_uri: {uri}", - status=400, - error_type="INVALID_REQUEST", - error_details={"error": "invalid_request", "error_description": f"Invalid redirect_uri: {uri}"}, - ) + response = jsonify({ + "error": "invalid_request", + "error_description": f"Invalid redirect_uri: {uri}" + }) + return response, 400 # Generate client credentials client_id = f"oidc_{secrets.token_urlsafe(16)}" @@ -1092,7 +1205,7 @@ def oidc_register(): redirect_uris=redirect_uris, grant_types=data.get("grant_types", ["authorization_code", "refresh_token"]), response_types=data.get("response_types", ["code"]), - scopes=data.get("scope", "openid profile email").split(), + scopes=data.get("scope", "openid profile email roles").split(), token_endpoint_auth_method=data.get("token_endpoint_auth_method", "client_secret_basic"), is_active=True, is_confidential=True, @@ -1105,19 +1218,16 @@ def oidc_register(): client.save() # Return client credentials - return api_response( - data={ - "client_id": client_id, - "client_secret": client_secret, - "client_id_issued_at": int(__import__("time").time()), - "client_secret_expires_at": 0, # Never expires - "client_name": client_name, - "redirect_uris": redirect_uris, - "token_endpoint_auth_method": data.get("token_endpoint_auth_method", "client_secret_basic"), - "grant_types": client.grant_types, - "response_types": client.response_types, - "scope": " ".join(client.scopes), - }, - message="Client registered successfully", - status=201, - ) + response = jsonify({ + "client_id": client_id, + "client_secret": client_secret, + "client_id_issued_at": int(__import__("time").time()), + "client_secret_expires_at": 0, # Never expires + "client_name": client_name, + "redirect_uris": redirect_uris, + "token_endpoint_auth_method": data.get("token_endpoint_auth_method", "client_secret_basic"), + "grant_types": client.grant_types, + "response_types": client.response_types, + "scope": " ".join(client.scopes), + }) + return response, 201 diff --git a/app/api/v1/auth.py b/app/api/v1/auth.py index bab831f..c1635ce 100644 --- a/app/api/v1/auth.py +++ b/app/api/v1/auth.py @@ -3,11 +3,20 @@ from flask import request, session, g from marshmallow import ValidationError from app.api.v1 import api_v1_bp from app.utils.response import api_response -from app.schemas.auth_schema import RegisterSchema, LoginSchema +from app.schemas.auth_schema import ( + RegisterSchema, + LoginSchema, + TOTPVerifyEnrollmentSchema, + TOTPVerifySchema, + TOTPDisableSchema, + TOTPRegenerateBackupCodesSchema, +) from app.services.auth_service import AuthService from app.services.user_service import UserService from app.utils.decorators import login_required from app.utils.constants import AuditAction +from app.exceptions.auth_exceptions import InvalidCredentialsError +from app.exceptions.validation_exceptions import ConflictError @api_v1_bp.route("/auth/register", methods=["POST"]) @@ -72,7 +81,7 @@ def login(): remember_me: Optional boolean for extended session Returns: - 200: Login successful + 200: Login successful or TOTP code required 400: Validation error 401: Invalid credentials """ @@ -81,13 +90,29 @@ def login(): schema = LoginSchema() data = schema.load(request.json) - # Authenticate user + # Authenticate user with email and password user = AuthService.authenticate( email=data["email"], password=data["password"], ) - # Create session + # Check if user has TOTP enabled for two-factor authentication + if user.has_totp_enabled(): + # TOTP is enabled - store user_id in session for TOTP verification + # The /auth/totp/verify endpoint will retrieve this user_id + session["totp_pending_user_id"] = user.id + + # Return response indicating TOTP code is required + # Do NOT create session or return token yet - wait for TOTP verification + return api_response( + data={ + "requires_totp": True, + }, + message="TOTP code required. Please enter your 6-digit code from your authenticator app.", + ) + + # TOTP is NOT enabled - proceed with normal login flow + # Create session with appropriate duration based on remember_me preference duration = 2592000 if data.get("remember_me") else 86400 # 30 days vs 1 day user_session = AuthService.create_session(user, duration_seconds=duration) @@ -210,3 +235,295 @@ def revoke_session(session_id): return api_response( message="Session revoked successfully", ) + + +@api_v1_bp.route("/auth/totp/enroll", methods=["POST"]) +@login_required +def enroll_totp(): + """ + Initiate TOTP enrollment for the current user. + + Returns: + 201: TOTP enrollment initiated with secret, provisioning_uri, qr_code, and backup_codes + 401: Not authenticated + 409: TOTP already enabled + """ + try: + # Initiate TOTP enrollment + result = AuthService.enroll_totp(g.current_user) + + return api_response( + data={ + "secret": result["secret"], + "provisioning_uri": result["provisioning_uri"], + "qr_code": result["qr_code"], + "backup_codes": result["backup_codes"], + }, + message="TOTP enrollment initiated. Please verify with your authenticator app.", + status=201, + ) + + except ConflictError as e: + return api_response( + success=False, + message=e.message, + status=e.status_code, + error_type=e.error_type, + ) + + +@api_v1_bp.route("/auth/totp/verify-enrollment", methods=["POST"]) +@login_required +def verify_totp_enrollment(): + """ + Complete TOTP enrollment by verifying the first TOTP code. + + Request body: + code: 6-digit TOTP code from authenticator app + + Returns: + 200: TOTP enrollment completed successfully + 400: Validation error + 401: Not authenticated + 401: Invalid TOTP code + """ + try: + # Validate request data + schema = TOTPVerifyEnrollmentSchema() + data = schema.load(request.json) + + # Verify TOTP enrollment + AuthService.verify_totp_enrollment(g.current_user, data["code"]) + + return api_response( + message="TOTP enrollment completed successfully", + ) + + except ValidationError as e: + return api_response( + success=False, + message="Validation failed", + status=400, + error_type="VALIDATION_ERROR", + error_details=e.messages, + ) + + except InvalidCredentialsError as e: + return api_response( + success=False, + message=e.message, + status=e.status_code, + error_type=e.error_type, + ) + + +@api_v1_bp.route("/auth/totp/verify", methods=["POST"]) +def verify_totp(): + """ + Verify TOTP code during login. + + Request body: + code: 6-digit TOTP code or backup code + is_backup_code: True if code is a backup code, False if TOTP code (default: False) + + Returns: + 200: TOTP code verified successfully with session token + 400: Validation error + 401: Invalid TOTP code or session not found + """ + try: + # Validate request data + schema = TOTPVerifySchema() + data = schema.load(request.json) + + # Get user from temporary session (stored in Flask session by login endpoint) + user_id = session.get("totp_pending_user_id") + if not user_id: + return api_response( + success=False, + message="No pending TOTP verification. Please login first.", + status=401, + error_type="AUTHENTICATION_ERROR", + ) + + # Get user from database + from app.models.user import User + user = User.query.get(user_id) + if not user: + return api_response( + success=False, + message="User not found", + status=401, + error_type="AUTHENTICATION_ERROR", + ) + + # Verify TOTP code + AuthService.authenticate_with_totp( + user, data["code"], data.get("is_backup_code", False) + ) + + # Create full session + user_session = AuthService.create_session(user) + + # Clear temporary session + session.pop("totp_pending_user_id", None) + + return api_response( + data={ + "user": user.to_dict(), + "token": user_session.token, + "expires_at": user_session.expires_at.isoformat() + "Z" + if user_session.expires_at.isoformat()[-1] != "Z" + else user_session.expires_at.isoformat(), + }, + message="TOTP verification successful", + ) + + except ValidationError as e: + return api_response( + success=False, + message="Validation failed", + status=400, + error_type="VALIDATION_ERROR", + error_details=e.messages, + ) + + except InvalidCredentialsError as e: + return api_response( + success=False, + message=e.message, + status=e.status_code, + error_type=e.error_type, + ) + + +@api_v1_bp.route("/auth/totp/disable", methods=["DELETE"]) +@login_required +def disable_totp(): + """ + Disable TOTP for the current user. + + Request body: + password: User's current password for verification + + Returns: + 200: TOTP disabled successfully + 400: Validation error + 401: Not authenticated or invalid password + 401: TOTP not enabled + """ + try: + # Validate request data + schema = TOTPDisableSchema() + data = schema.load(request.json) + + # Disable TOTP + AuthService.disable_totp(g.current_user, data["password"]) + + return api_response( + message="TOTP disabled successfully", + ) + + except ValidationError as e: + return api_response( + success=False, + message="Validation failed", + status=400, + error_type="VALIDATION_ERROR", + error_details=e.messages, + ) + + except InvalidCredentialsError as e: + return api_response( + success=False, + message=e.message, + status=e.status_code, + error_type=e.error_type, + ) + + +@api_v1_bp.route("/auth/totp/status", methods=["GET"]) +@login_required +def get_totp_status(): + """ + Get TOTP status for the current user. + + Returns: + 200: TOTP status with totp_enabled, verified_at, and backup_codes_remaining + 401: Not authenticated + """ + user = g.current_user + + # Check if TOTP is enabled + totp_enabled = user.has_totp_enabled() + + # Get TOTP method to check backup codes remaining + backup_codes_remaining = 0 + verified_at = None + + if totp_enabled: + totp_method = user.get_totp_method() + if totp_method and totp_method.provider_data: + backup_codes = totp_method.provider_data.get("backup_codes", []) + backup_codes_remaining = len(backup_codes) + if totp_method and totp_method.totp_verified_at: + verified_at = totp_method.totp_verified_at.isoformat() + "Z" if totp_method.totp_verified_at.isoformat()[-1] != "Z" else totp_method.totp_verified_at.isoformat() + + return api_response( + data={ + "totp_enabled": totp_enabled, + "verified_at": verified_at, + "backup_codes_remaining": backup_codes_remaining, + }, + message="TOTP status retrieved successfully", + ) + + +@api_v1_bp.route("/auth/totp/regenerate-backup-codes", methods=["POST"]) +@login_required +def regenerate_totp_backup_codes(): + """ + Generate new backup codes for TOTP. + + Request body: + password: User's current password for verification + + Returns: + 200: New backup codes generated successfully + 400: Validation error + 401: Not authenticated or invalid password + 401: TOTP not enabled + """ + try: + # Validate request data + schema = TOTPRegenerateBackupCodesSchema() + data = schema.load(request.json) + + # Regenerate backup codes + backup_codes = AuthService.regenerate_totp_backup_codes( + g.current_user, data["password"] + ) + + return api_response( + data={ + "backup_codes": backup_codes, + }, + message="Backup codes regenerated successfully", + ) + + except ValidationError as e: + return api_response( + success=False, + message="Validation failed", + status=400, + error_type="VALIDATION_ERROR", + error_details=e.messages, + ) + + except InvalidCredentialsError as e: + return api_response( + success=False, + message=e.message, + status=e.status_code, + error_type=e.error_type, + ) diff --git a/app/middleware/cors.py b/app/middleware/cors.py index 85f3aba..7898d60 100644 --- a/app/middleware/cors.py +++ b/app/middleware/cors.py @@ -26,6 +26,7 @@ def setup_cors(app): response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS" response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID" response.headers["Access-Control-Max-Age"] = "3600" + response.headers["Cache-Control"] = "no-cache, no-store" return response elif origin and origin in cors_origins: response = make_response("", 204) @@ -34,6 +35,7 @@ def setup_cors(app): response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID" response.headers["Access-Control-Allow-Credentials"] = "true" response.headers["Access-Control-Max-Age"] = "3600" + response.headers["Cache-Control"] = "no-cache, no-store" return response @app.after_request diff --git a/app/middleware/security_headers.py b/app/middleware/security_headers.py index d6e5bb0..eb1b6a4 100644 --- a/app/middleware/security_headers.py +++ b/app/middleware/security_headers.py @@ -51,4 +51,15 @@ class SecurityHeadersMiddleware: "geolocation=(), microphone=(), camera=()" ) + # Cache-Control: Allow OIDC endpoints to set their own Cache-Control + # Only set no-cache for API responses that haven't set their own cache headers + if "Cache-Control" not in response.headers: + # Check if this is a JSON API response (shouldn't be cached) + content_type = response.headers.get("Content-Type", "") + if "application/json" in content_type: + response.headers["Cache-Control"] = "no-cache, no-store" + elif "text/html" not in content_type: + # For non-HTML responses, add Pragma for HTTP/1.0 compatibility + response.headers["Pragma"] = "no-cache" + return response diff --git a/app/models/authentication_method.py b/app/models/authentication_method.py index 9bd7794..752356c 100644 --- a/app/models/authentication_method.py +++ b/app/models/authentication_method.py @@ -19,6 +19,11 @@ class AuthenticationMethod(BaseModel): provider_user_id = db.Column(db.String(255), nullable=True) provider_data = db.Column(db.JSON, nullable=True) + # # For TOTP authentication + # totp_secret = db.Column(db.String(32), nullable=True) + # totp_backup_codes = db.Column(db.JSON, nullable=True) + # totp_verified_at = db.Column(db.DateTime, nullable=True) + # Metadata is_primary = db.Column(db.Boolean, default=False, nullable=False) verified = db.Column(db.Boolean, default=False, nullable=False) @@ -51,9 +56,15 @@ class AuthenticationMethod(BaseModel): AuthMethodType.MICROSOFT, ] + def is_totp(self): + """Check if this is a TOTP authentication method.""" + return self.method_type == AuthMethodType.TOTP + def to_dict(self, exclude=None): """Convert to dictionary, excluding sensitive fields.""" exclude = exclude or [] - # Always exclude password hash + # Always exclude password hash and TOTP secrets exclude.append("password_hash") + exclude.append("totp_secret") + exclude.append("totp_backup_codes") return super().to_dict(exclude=exclude) diff --git a/app/models/base.py b/app/models/base.py index 594fd74..167eb88 100644 --- a/app/models/base.py +++ b/app/models/base.py @@ -1,6 +1,6 @@ """Base model with common fields and functionality.""" import uuid -from datetime import datetime +from datetime import datetime, timezone from app.extensions import db @@ -16,9 +16,9 @@ class BaseModel(db.Model): unique=True, nullable=False, ) - created_at = db.Column(db.DateTime, nullable=False, default=datetime.utcnow) + created_at = db.Column(db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc)) updated_at = db.Column( - db.DateTime, nullable=False, default=datetime.utcnow, onupdate=datetime.utcnow + db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc) ) deleted_at = db.Column(db.DateTime, nullable=True) @@ -36,7 +36,7 @@ class BaseModel(db.Model): soft: If True, performs soft delete. If False, hard delete. """ if soft: - self.deleted_at = datetime.utcnow() + self.deleted_at = datetime.now(timezone.utc) db.session.commit() else: db.session.delete(self) @@ -47,7 +47,7 @@ class BaseModel(db.Model): for key, value in kwargs.items(): if hasattr(self, key): setattr(self, key, value) - self.updated_at = datetime.utcnow() + self.updated_at = datetime.now(timezone.utc) db.session.commit() return self diff --git a/app/models/oidc_authorization_code.py b/app/models/oidc_authorization_code.py index 4d39cea..82091b7 100644 --- a/app/models/oidc_authorization_code.py +++ b/app/models/oidc_authorization_code.py @@ -1,5 +1,5 @@ """OIDC Authorization Code model for auth code flow.""" -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from app.extensions import db from app.models.base import BaseModel @@ -49,7 +49,12 @@ class OIDCAuthCode(BaseModel): def is_expired(self): """Check if the authorization code has expired.""" - return datetime.utcnow() > self.expires_at + # Handle both timezone-aware and timezone-naive expires_at values + expires_at = self.expires_at + if expires_at.tzinfo is None: + # Make naive datetime timezone-aware (UTC) + expires_at = expires_at.replace(tzinfo=timezone.utc) + return datetime.now(timezone.utc) > expires_at def is_valid(self): """Check if the authorization code is valid for use.""" @@ -58,7 +63,7 @@ class OIDCAuthCode(BaseModel): def mark_as_used(self): """Mark the authorization code as used.""" self.is_used = True - self.used_at = datetime.utcnow() + self.used_at = datetime.now(timezone.utc) db.session.commit() @classmethod @@ -90,7 +95,7 @@ class OIDCAuthCode(BaseModel): scope=scope, nonce=nonce, code_verifier=code_verifier, - expires_at=datetime.utcnow() + timedelta(seconds=lifetime_seconds), + expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds), ip_address=ip_address, user_agent=user_agent, ) diff --git a/app/models/oidc_refresh_token.py b/app/models/oidc_refresh_token.py index e499915..e0f88db 100644 --- a/app/models/oidc_refresh_token.py +++ b/app/models/oidc_refresh_token.py @@ -1,5 +1,5 @@ """OIDC Refresh Token model for token rotation.""" -from datetime import datetime +from datetime import datetime, timezone from app.extensions import db from app.models.base import BaseModel @@ -58,7 +58,11 @@ class OIDCRefreshToken(BaseModel): def is_expired(self): """Check if the refresh token has expired.""" - return datetime.utcnow() > self.expires_at + # Handle both timezone-aware and timezone-naive expires_at values + expires_at = self.expires_at + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + return datetime.now(timezone.utc) > expires_at def is_revoked(self): """Check if the refresh token has been revoked.""" @@ -74,7 +78,7 @@ class OIDCRefreshToken(BaseModel): Args: reason: Optional reason for revocation """ - self.revoked_at = datetime.utcnow() + self.revoked_at = datetime.now(timezone.utc) self.revoked_reason = reason db.session.commit() @@ -93,7 +97,7 @@ class OIDCRefreshToken(BaseModel): self.rotation_count += 1 # Extend expiration on rotation from datetime import timedelta - self.expires_at = datetime.utcnow() + timedelta(days=30) + self.expires_at = datetime.now(timezone.utc) + timedelta(days=30) db.session.commit() return self @@ -123,7 +127,7 @@ class OIDCRefreshToken(BaseModel): token_hash=token_hash, scope=scope, access_token_id=access_token_id, - expires_at=datetime.utcnow() + timedelta(seconds=lifetime_seconds), + expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds), ip_address=ip_address, user_agent=user_agent, ) diff --git a/app/models/oidc_session.py b/app/models/oidc_session.py index 8e0f503..d23f892 100644 --- a/app/models/oidc_session.py +++ b/app/models/oidc_session.py @@ -1,5 +1,5 @@ """OIDC Session model for OIDC session tracking.""" -from datetime import datetime +from datetime import datetime, timezone from app.extensions import db from app.models.base import BaseModel @@ -49,7 +49,7 @@ class OIDCSession(BaseModel): def is_expired(self): """Check if the OIDC session has expired.""" - return datetime.utcnow() > self.expires_at + return datetime.now(timezone.utc) > self.expires_at def is_authenticated(self): """Check if the user has been authenticated in this session.""" @@ -57,7 +57,7 @@ class OIDCSession(BaseModel): def mark_authenticated(self): """Mark the session as authenticated.""" - self.authenticated_at = datetime.utcnow() + self.authenticated_at = datetime.now(timezone.utc) db.session.commit() def validate_nonce(self, expected_nonce): @@ -126,7 +126,7 @@ class OIDCSession(BaseModel): nonce=nonce, code_challenge=code_challenge, code_challenge_method=code_challenge_method, - expires_at=datetime.utcnow() + timedelta(seconds=lifetime_seconds), + expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds), ) db.session.add(session) db.session.commit() diff --git a/app/models/oidc_token_metadata.py b/app/models/oidc_token_metadata.py index 2833c30..c5cefa3 100644 --- a/app/models/oidc_token_metadata.py +++ b/app/models/oidc_token_metadata.py @@ -1,6 +1,6 @@ """OIDC Token Metadata model for token revocation tracking.""" import uuid -from datetime import datetime +from datetime import datetime, timezone from app.extensions import db from app.models.base import BaseModel @@ -50,7 +50,11 @@ class OIDCTokenMetadata(BaseModel): def is_expired(self): """Check if the token has expired.""" - return datetime.utcnow() > self.expires_at + # Handle both timezone-aware and timezone-naive expires_at values + expires_at = self.expires_at + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + return datetime.now(timezone.utc) > expires_at def is_revoked(self): """Check if the token has been revoked.""" @@ -66,7 +70,7 @@ class OIDCTokenMetadata(BaseModel): Args: reason: Optional reason for revocation """ - self.revoked_at = datetime.utcnow() + self.revoked_at = datetime.now(timezone.utc) self.revoked_reason = reason db.session.commit() diff --git a/app/models/session.py b/app/models/session.py index b166d63..0eceb40 100644 --- a/app/models/session.py +++ b/app/models/session.py @@ -1,5 +1,5 @@ """Session model.""" -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from app.extensions import db from app.models.base import BaseModel from app.utils.constants import SessionStatus @@ -34,7 +34,7 @@ class Session(BaseModel): def is_active(self): """Check if session is currently active.""" - now = datetime.utcnow() + now = datetime.now(timezone.utc) return ( self.status == SessionStatus.ACTIVE and self.expires_at > now @@ -43,7 +43,7 @@ class Session(BaseModel): def is_expired(self): """Check if session has expired.""" - return datetime.utcnow() > self.expires_at + return datetime.now(timezone.utc) > self.expires_at def refresh(self, duration_seconds=86400): """ @@ -52,8 +52,8 @@ class Session(BaseModel): Args: duration_seconds: New session duration in seconds """ - self.expires_at = datetime.utcnow() + timedelta(seconds=duration_seconds) - self.last_activity_at = datetime.utcnow() + self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=duration_seconds) + self.last_activity_at = datetime.now(timezone.utc) db.session.commit() def revoke(self, reason=None): @@ -64,7 +64,7 @@ class Session(BaseModel): reason: Optional reason for revocation """ self.status = SessionStatus.REVOKED - self.revoked_at = datetime.utcnow() + self.revoked_at = datetime.now(timezone.utc) if reason: self.revoked_reason = reason db.session.commit() diff --git a/app/models/user.py b/app/models/user.py index 154c456..bb7c700 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -59,3 +59,35 @@ class User(BaseModel): def get_organizations(self): """Get all organizations the user is a member of.""" return [membership.organization for membership in self.organization_memberships] + + def has_totp_enabled(self) -> bool: + """Check if user has TOTP enabled and verified. + + Returns: + True if user has a verified TOTP authentication method, False otherwise. + """ + from app.models.authentication_method import AuthenticationMethod + from app.utils.constants import AuthMethodType + + return ( + AuthenticationMethod.query.filter_by( + user_id=self.id, + method_type=AuthMethodType.TOTP, + verified=True, + deleted_at=None, + ).first() + is not None + ) + + def get_totp_method(self): + """Get user's TOTP authentication method. + + Returns: + The AuthenticationMethod instance for TOTP or None if not found. + """ + from app.models.authentication_method import AuthenticationMethod + from app.utils.constants import AuthMethodType + + return AuthenticationMethod.query.filter_by( + user_id=self.id, method_type=AuthMethodType.TOTP, deleted_at=None + ).first() diff --git a/app/schemas/auth_schema.py b/app/schemas/auth_schema.py index 0f9bda5..360b92a 100644 --- a/app/schemas/auth_schema.py +++ b/app/schemas/auth_schema.py @@ -55,3 +55,34 @@ class ResetPasswordSchema(Schema): """Validate that passwords match.""" if data.get("password") != data.get("password_confirm"): raise ValidationError("Passwords do not match", field_name="password_confirm") + + +class TOTPVerifyEnrollmentSchema(Schema): + """Schema for TOTP enrollment verification.""" + + code = fields.Str( + required=True, + validate=validate.Regexp( + r"^\d{6}$", + error="Code must be a 6-digit number", + ), + ) + + +class TOTPVerifySchema(Schema): + """Schema for TOTP code verification during login.""" + + code = fields.Str(required=True) + is_backup_code = fields.Bool(missing=False) + + +class TOTPDisableSchema(Schema): + """Schema for disabling TOTP.""" + + password = fields.Str(required=True, validate=validate.Length(min=1)) + + +class TOTPRegenerateBackupCodesSchema(Schema): + """Schema for regenerating backup codes.""" + + password = fields.Str(required=True, validate=validate.Length(min=1)) diff --git a/app/services/auth_service.py b/app/services/auth_service.py index fbda5a5..deaa757 100644 --- a/app/services/auth_service.py +++ b/app/services/auth_service.py @@ -11,6 +11,7 @@ from app.utils.constants import AuthMethodType, SessionStatus, UserStatus, Audit from app.exceptions.auth_exceptions import InvalidCredentialsError, AccountSuspendedError, AccountInactiveError from app.exceptions.validation_exceptions import EmailAlreadyExistsError from app.services.audit_service import AuditService +from app.services.totp_service import TOTPService logger = logging.getLogger(__name__) @@ -234,3 +235,315 @@ class AuthService: resource_id=session.id, description=f"Session revoked: {reason or 'User logout'}", ) + + @staticmethod + def enroll_totp(user: User) -> dict: + """ + Initiate TOTP enrollment for a user. + + Args: + user: User instance + + Returns: + Dictionary containing: + - secret: TOTP secret (base32 encoded) + - provisioning_uri: otpauth:// URI for QR code + - qr_code: Base64 encoded QR code as data URI + - backup_codes: List of plain text backup codes + + Raises: + ConflictError: If user already has TOTP enabled + """ + from app.exceptions.validation_exceptions import ConflictError + + # Check if user already has TOTP enabled + if user.has_totp_enabled(): + raise ConflictError("TOTP is already enabled for this account") + + # Generate TOTP secret + secret = TOTPService.generate_secret() + + # Generate provisioning URI + provisioning_uri = TOTPService.generate_provisioning_uri( + user_email=user.email, + secret=secret, + issuer="Gatehouse", + ) + + # Generate QR code data URI + qr_code = TOTPService.generate_qr_code_data_uri(provisioning_uri) + + # Generate backup codes + backup_codes, hashed_backup_codes = TOTPService.generate_backup_codes() + + # Create unverified TOTP authentication method + auth_method = AuthenticationMethod( + user_id=user.id, + method_type=AuthMethodType.TOTP, + verified=False, + is_primary=False, + ) + auth_method.save() + + # Store TOTP data in provider_data (since totp_secret field is commented out) + auth_method.provider_data = { + "secret": secret, + "backup_codes": hashed_backup_codes, + } + db.session.commit() + + # Log TOTP enrollment initiation + AuditService.log_action( + action=AuditAction.TOTP_ENROLL_INITIATED, + user_id=user.id, + resource_type="authentication_method", + resource_id=auth_method.id, + description="TOTP enrollment initiated", + ) + + return { + "secret": secret, + "provisioning_uri": provisioning_uri, + "qr_code": qr_code, + "backup_codes": backup_codes, + } + + @staticmethod + def verify_totp_enrollment(user: User, code: str) -> bool: + """ + Complete TOTP enrollment by verifying the first TOTP code. + + Args: + user: User instance + code: 6-digit TOTP code from authenticator app + + Returns: + True if verification successful + + Raises: + InvalidCredentialsError: If code is invalid or TOTP method not found + """ + # Get user's TOTP authentication method + auth_method = user.get_totp_method() + if not auth_method: + raise InvalidCredentialsError("TOTP enrollment not found") + + # Get secret from provider_data + secret = auth_method.provider_data.get("secret") if auth_method.provider_data else None + if not secret: + raise InvalidCredentialsError("TOTP secret not found") + + # Verify the code + if not TOTPService.verify_code(secret, code): + raise InvalidCredentialsError("Invalid TOTP code") + + # Mark TOTP as verified + auth_method.verified = True + auth_method.totp_verified_at = datetime.utcnow() + db.session.commit() + + # Log TOTP enrollment completion + AuditService.log_action( + action=AuditAction.TOTP_ENROLL_COMPLETED, + user_id=user.id, + resource_type="authentication_method", + resource_id=auth_method.id, + description="TOTP enrollment completed", + ) + + return True + + @staticmethod + def disable_totp(user: User, password: str) -> bool: + """ + Disable TOTP for a user. + + Args: + user: User instance + password: User's current password for verification + + Returns: + True if TOTP disabled successfully + + Raises: + InvalidCredentialsError: If password is invalid or TOTP method not found + """ + # Verify user's password + auth_method = AuthenticationMethod.query.filter_by( + user_id=user.id, + method_type=AuthMethodType.PASSWORD, + deleted_at=None, + ).first() + + if not auth_method or not auth_method.password_hash: + raise InvalidCredentialsError("No password authentication method found") + + if not bcrypt.check_password_hash(auth_method.password_hash, password): + raise InvalidCredentialsError("Invalid password") + + # Get user's TOTP authentication method + totp_method = user.get_totp_method() + if not totp_method: + raise InvalidCredentialsError("TOTP is not enabled for this account") + + # Soft-delete the TOTP authentication method + totp_method.delete(soft=True) + + # Log TOTP disabled + AuditService.log_action( + action=AuditAction.TOTP_DISABLED, + user_id=user.id, + resource_type="authentication_method", + resource_id=totp_method.id, + description="TOTP disabled", + ) + + return True + + @staticmethod + def authenticate_with_totp(user: User, code: str, is_backup_code: bool = False) -> bool: + """ + Verify TOTP code during login. + + Args: + user: User instance + code: 6-digit TOTP code or backup code + is_backup_code: True if code is a backup code, False if TOTP code + + Returns: + True if code is valid + + Raises: + InvalidCredentialsError: If code is invalid or TOTP method not found + """ + # Get user's TOTP authentication method + auth_method = user.get_totp_method() + if not auth_method: + raise InvalidCredentialsError("TOTP is not enabled for this account") + + if is_backup_code: + # Verify backup code + backup_codes = ( + auth_method.provider_data.get("backup_codes") + if auth_method.provider_data + else [] + ) + is_valid, remaining_codes = TOTPService.verify_backup_code(backup_codes, code) + + if is_valid: + # Update remaining backup codes + auth_method.provider_data = { + "secret": auth_method.provider_data.get("secret"), + "backup_codes": remaining_codes, + } + auth_method.last_used_at = datetime.utcnow() + db.session.commit() + + # Log backup code usage + AuditService.log_action( + action=AuditAction.TOTP_BACKUP_CODE_USED, + user_id=user.id, + resource_type="authentication_method", + resource_id=auth_method.id, + description="Backup code used for authentication", + ) + else: + # Log failed verification + AuditService.log_action( + action=AuditAction.TOTP_VERIFY_FAILED, + user_id=user.id, + resource_type="authentication_method", + resource_id=auth_method.id, + description="Invalid backup code provided", + ) + raise InvalidCredentialsError("Invalid backup code") + else: + # Verify TOTP code + secret = ( + auth_method.provider_data.get("secret") + if auth_method.provider_data + else None + ) + if not secret: + raise InvalidCredentialsError("TOTP secret not found") + + is_valid = TOTPService.verify_code(secret, code) + + if is_valid: + auth_method.last_used_at = datetime.utcnow() + db.session.commit() + + # Log successful verification + AuditService.log_action( + action=AuditAction.TOTP_VERIFY_SUCCESS, + user_id=user.id, + resource_type="authentication_method", + resource_id=auth_method.id, + description="TOTP code verified successfully", + ) + else: + # Log failed verification + AuditService.log_action( + action=AuditAction.TOTP_VERIFY_FAILED, + user_id=user.id, + resource_type="authentication_method", + resource_id=auth_method.id, + description="Invalid TOTP code provided", + ) + raise InvalidCredentialsError("Invalid TOTP code") + + return True + + @staticmethod + def regenerate_totp_backup_codes(user: User, password: str) -> list[str]: + """ + Generate new backup codes for TOTP. + + Args: + user: User instance + password: User's current password for verification + + Returns: + List of new plain text backup codes + + Raises: + InvalidCredentialsError: If password is invalid or TOTP method not found + """ + # Verify user's password + auth_method = AuthenticationMethod.query.filter_by( + user_id=user.id, + method_type=AuthMethodType.PASSWORD, + deleted_at=None, + ).first() + + if not auth_method or not auth_method.password_hash: + raise InvalidCredentialsError("No password authentication method found") + + if not bcrypt.check_password_hash(auth_method.password_hash, password): + raise InvalidCredentialsError("Invalid password") + + # Get user's TOTP authentication method + totp_method = user.get_totp_method() + if not totp_method: + raise InvalidCredentialsError("TOTP is not enabled for this account") + + # Generate new backup codes + backup_codes, hashed_backup_codes = TOTPService.generate_backup_codes() + + # Update the authentication method with new backup codes + totp_method.provider_data = { + "secret": totp_method.provider_data.get("secret"), + "backup_codes": hashed_backup_codes, + } + db.session.commit() + + # Log backup codes regeneration + AuditService.log_action( + action=AuditAction.TOTP_BACKUP_CODES_REGENERATED, + user_id=user.id, + resource_type="authentication_method", + resource_id=totp_method.id, + description="TOTP backup codes regenerated", + ) + + return backup_codes diff --git a/app/services/oidc_service.py b/app/services/oidc_service.py index 462f612..0476d57 100644 --- a/app/services/oidc_service.py +++ b/app/services/oidc_service.py @@ -2,7 +2,7 @@ import logging import secrets import hashlib -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Dict, List, Optional, Tuple from flask import current_app, g @@ -14,6 +14,7 @@ from app.models import ( User, OIDCClient, OIDCAuthCode, OIDCRefreshToken, OIDCSession, OIDCTokenMetadata ) +from app.models.organization_member import OrganizationMember from app.exceptions.validation_exceptions import ( ValidationError, NotFoundError, BadRequestError ) @@ -121,6 +122,14 @@ class OIDCService: ValidationError: If parameters are invalid NotFoundError: If client not found """ + logger.debug("[OIDC SERVICE] ===========================================") + logger.debug("[OIDC SERVICE] generate_authorization_code called") + logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z") + logger.debug("[OIDC SERVICE] client_id=%s, user_id=%s", client_id, user_id) + logger.debug("[OIDC SERVICE] redirect_uri=%s", redirect_uri) + logger.debug("[OIDC SERVICE] scope=%s", scope) + logger.debug("[OIDC SERVICE] state=%s, nonce=%s", state, nonce) + # Validate client exists and is active client = OIDCClient.query.filter_by(client_id=client_id).first() @@ -152,14 +161,19 @@ class OIDCService: raise ValidationError("Invalid scopes") # Generate authorization code + logger.debug("[OIDC SERVICE] Generating authorization code...") + logger.debug("[OIDC SERVICE] Current UTC time before code generation: %s", datetime.now(timezone.utc).isoformat() + "Z") code = cls._generate_code() code_hash = cls._hash_value(code) + logger.debug("[OIDC SERVICE] Code generated: %s...", code[:20] if code else None) # Development-only debug logging for PKCE in code creation if current_app.config.get('ENV') == 'development': logger.debug(f"[OIDC] Generate auth code - PKCE: code_challenge={code_challenge is not None}, code_challenge_method={code_challenge_method}") # Create auth code record + logger.debug("[OIDC SERVICE] Creating auth code record with lifetime_seconds=600 (10 minutes)") + logger.debug("[OIDC SERVICE] Current UTC time before creating auth code: %s", datetime.now(timezone.utc).isoformat() + "Z") auth_code = OIDCAuthCode.create_code( client_id=client.id, user_id=user_id, @@ -172,6 +186,9 @@ class OIDCService: user_agent=user_agent, lifetime_seconds=600, # 10 minutes ) + logger.debug("[OIDC SERVICE] Auth code created successfully") + logger.debug("[OIDC SERVICE] Auth code expires_at (UTC): %s", auth_code.expires_at.isoformat() + "Z") + logger.debug("[OIDC SERVICE] Current UTC time after creating auth code: %s", datetime.now(timezone.utc).isoformat() + "Z") # Log authorization event OIDCAuditService.log_authorization_event( @@ -182,6 +199,9 @@ class OIDCService: scope=valid_scopes, ) + logger.debug("[OIDC SERVICE] generate_authorization_code completed successfully") + logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z") + logger.debug("[OIDC SERVICE] ===========================================") return code @classmethod @@ -211,6 +231,12 @@ class OIDCService: InvalidGrantError: If code is invalid ValidationError: If PKCE validation fails """ + logger.debug("[OIDC SERVICE] ===========================================") + logger.debug("[OIDC SERVICE] validate_authorization_code called") + logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z") + logger.debug("[OIDC SERVICE] client_id=%s, redirect_uri=%s", client_id, redirect_uri) + logger.debug("[OIDC SERVICE] code_verifier provided: %s", bool(code_verifier)) + # Get client client = OIDCClient.query.filter_by(client_id=client_id).first() @@ -223,6 +249,8 @@ class OIDCService: raise InvalidGrantError("Invalid client") # Hash the provided code and find matching auth code + logger.debug("[OIDC SERVICE] Looking up authorization code...") + logger.debug("[OIDC SERVICE] Current UTC time before code lookup: %s", datetime.now(timezone.utc).isoformat() + "Z") code_hash = cls._hash_value(code) auth_code = OIDCAuthCode.query.filter_by( code_hash=code_hash, @@ -256,8 +284,18 @@ class OIDCService: raise InvalidGrantError("Authorization code already used") # Check expiration + logger.debug("[OIDC SERVICE] Checking if authorization code is expired...") + logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z") + logger.debug("[OIDC SERVICE] Auth code expires_at (UTC): %s", auth_code.expires_at.isoformat() + "Z") + # Handle timezone-naive expires_at from database + expires_at = auth_code.expires_at + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + logger.debug("[OIDC SERVICE] Time until expiration (seconds): %s", (expires_at - datetime.now(timezone.utc)).total_seconds()) + if auth_code.is_expired(): - logger.error(f"[OIDC] Validate auth code - Code expired: code_hash={code_hash[:20]}..., expires_at={auth_code.expires_at}") + logger.error("[OIDC] Validate auth code - Code expired: code_hash=%s..., expires_at (UTC)=%s, current UTC time=%s", + code_hash[:20], auth_code.expires_at.isoformat() + "Z", datetime.now(timezone.utc).isoformat() + "Z") OIDCAuditService.log_authorization_event( client_id=client_id, user_id=auth_code.user_id, @@ -316,6 +354,9 @@ class OIDCService: "nonce": auth_code.nonce, } + logger.debug("[OIDC SERVICE] validate_authorization_code completed successfully") + logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z") + logger.debug("[OIDC SERVICE] ===========================================") return claims, user @classmethod @@ -366,6 +407,12 @@ class OIDCService: """ import hashlib + logger.debug("[OIDC SERVICE] ===========================================") + logger.debug("[OIDC SERVICE] generate_tokens called") + logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z") + logger.debug("[OIDC SERVICE] client_id=%s, user_id=%s, scope=%s", client_id, user_id, scope) + logger.debug("[OIDC SERVICE] nonce=%s, auth_time=%s", nonce, auth_time) + # Get client client = OIDCClient.query.filter_by(client_id=client_id).first() @@ -377,6 +424,9 @@ class OIDCService: raise InvalidClientError() # Generate access token + logger.debug("[OIDC SERVICE] Generating access token...") + logger.debug("[OIDC SERVICE] Current UTC time before access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z") + logger.debug("[OIDC SERVICE] Access token lifetime (seconds): %s", client.access_token_lifetime or 3600) access_token_jti = OIDCTokenService._generate_jti() access_token = OIDCTokenService.create_access_token( client_id=client_id, @@ -384,8 +434,13 @@ class OIDCService: scope=scope, jti=access_token_jti, ) + logger.debug("[OIDC SERVICE] Access token generated successfully") + logger.debug("[OIDC SERVICE] Current UTC time after access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z") # Generate ID token + logger.debug("[OIDC SERVICE] Generating ID token...") + logger.debug("[OIDC SERVICE] Current UTC time before ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z") + logger.debug("[OIDC SERVICE] ID token lifetime (seconds): %s", client.id_token_lifetime or 3600) id_token = OIDCTokenService.create_id_token( client_id=client_id, user_id=user_id, @@ -394,6 +449,8 @@ class OIDCService: access_token=access_token, auth_time=auth_time, ) + logger.debug("[OIDC SERVICE] ID token generated successfully") + logger.debug("[OIDC SERVICE] Current UTC time after ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z") # Generate or rotate refresh token if "refresh_token" in (client.grant_types or []): @@ -445,22 +502,28 @@ class OIDCService: client_db_id = client.id # Access token metadata + logger.debug("[OIDC SERVICE] Creating access token metadata...") + access_token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=client.access_token_lifetime or 3600) + logger.debug("[OIDC SERVICE] Access token expires_at (UTC): %s", access_token_expires_at.isoformat() + "Z") OIDCTokenMetadata.create_metadata( client_id=client_db_id, user_id=user_id, token_type="access_token", token_jti=access_token_jti, - expires_at=datetime.utcnow() + timedelta(seconds=client.access_token_lifetime or 3600), + expires_at=access_token_expires_at, ) # ID token metadata (using access token JTI as reference) + logger.debug("[OIDC SERVICE] Creating ID token metadata...") id_token_jti = OIDCTokenService._generate_jti() + id_token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=client.id_token_lifetime or 3600) + logger.debug("[OIDC SERVICE] ID token expires_at (UTC): %s", id_token_expires_at.isoformat() + "Z") OIDCTokenMetadata.create_metadata( client_id=client_db_id, user_id=user_id, token_type="id_token", token_jti=id_token_jti, - expires_at=datetime.utcnow() + timedelta(seconds=client.id_token_lifetime or 3600), + expires_at=id_token_expires_at, ) # Log token event @@ -483,6 +546,9 @@ class OIDCService: if final_refresh_token: result["refresh_token"] = final_refresh_token + logger.debug("[OIDC SERVICE] generate_tokens completed successfully") + logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z") + logger.debug("[OIDC SERVICE] ===========================================") return result @classmethod @@ -511,6 +577,11 @@ class OIDCService: """ import hashlib + logger.debug("[OIDC SERVICE] ===========================================") + logger.debug("[OIDC SERVICE] refresh_access_token called") + logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z") + logger.debug("[OIDC SERVICE] client_id=%s, scope=%s", client_id, scope) + # Get client client = OIDCClient.query.filter_by(client_id=client_id).first() @@ -522,6 +593,8 @@ class OIDCService: raise InvalidClientError() # Find refresh token + logger.debug("[OIDC SERVICE] Looking up refresh token...") + logger.debug("[OIDC SERVICE] Current UTC time before refresh token lookup: %s", datetime.now(timezone.utc).isoformat() + "Z") token_hash = hashlib.sha256(refresh_token.encode()).hexdigest() refresh_token_obj = OIDCRefreshToken.query.filter_by( token_hash=token_hash, @@ -542,6 +615,16 @@ class OIDCService: raise InvalidGrantError("Invalid refresh token") # Check if valid + logger.debug("[OIDC SERVICE] Checking if refresh token is valid...") + logger.debug("[OIDC SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z") + if refresh_token_obj: + logger.debug("[OIDC SERVICE] Refresh token expires_at (UTC): %s", refresh_token_obj.expires_at.isoformat() + "Z") + # Handle timezone-naive expires_at from database + rt_expires_at = refresh_token_obj.expires_at + if rt_expires_at.tzinfo is None: + rt_expires_at = rt_expires_at.replace(tzinfo=timezone.utc) + logger.debug("[OIDC SERVICE] Time until expiration (seconds): %s", (rt_expires_at - datetime.now(timezone.utc)).total_seconds()) + if not refresh_token_obj.is_valid(): OIDCAuditService.log_token_event( client_id=client_id, @@ -563,6 +646,9 @@ class OIDCService: granted_scope = scope or (refresh_token_obj.scope or []) # Generate new access token + logger.debug("[OIDC SERVICE] Generating new access token...") + logger.debug("[OIDC SERVICE] Current UTC time before access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z") + logger.debug("[OIDC SERVICE] Access token lifetime (seconds): %s", client.access_token_lifetime or 3600) access_token_jti = OIDCTokenService._generate_jti() access_token = OIDCTokenService.create_access_token( client_id=client_id, @@ -570,14 +656,21 @@ class OIDCService: scope=granted_scope, jti=access_token_jti, ) + logger.debug("[OIDC SERVICE] Access token generated successfully") + logger.debug("[OIDC SERVICE] Current UTC time after access token generation: %s", datetime.now(timezone.utc).isoformat() + "Z") # Generate new ID token + logger.debug("[OIDC SERVICE] Generating new ID token...") + logger.debug("[OIDC SERVICE] Current UTC time before ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z") + logger.debug("[OIDC SERVICE] ID token lifetime (seconds): %s", client.id_token_lifetime or 3600) id_token = OIDCTokenService.create_id_token( client_id=client_id, user_id=refresh_token_obj.user_id, scope=granted_scope, access_token=access_token, ) + logger.debug("[OIDC SERVICE] ID token generated successfully") + logger.debug("[OIDC SERVICE] Current UTC time after ID token generation: %s", datetime.now(timezone.utc).isoformat() + "Z") # Rotate refresh token new_refresh, new_hash = OIDCTokenService.create_refresh_token( @@ -590,12 +683,15 @@ class OIDCService: refresh_token_obj.rotate(new_hash) # Store new token metadata + logger.debug("[OIDC SERVICE] Creating access token metadata...") + access_token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=client.access_token_lifetime or 3600) + logger.debug("[OIDC SERVICE] Access token expires_at (UTC): %s", access_token_expires_at.isoformat() + "Z") OIDCTokenMetadata.create_metadata( client_id=client.id, user_id=refresh_token_obj.user_id, token_type="access_token", token_jti=access_token_jti, - expires_at=datetime.utcnow() + timedelta(seconds=client.access_token_lifetime or 3600), + expires_at=access_token_expires_at, ) # Log refresh event @@ -615,6 +711,17 @@ class OIDCService: "id_token": id_token, "refresh_token": new_refresh, } + + logger.debug("[OIDC SERVICE] refresh_access_token completed successfully") + logger.debug("[OIDC SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat() + "Z") + logger.debug("[OIDC SERVICE] ===========================================") + return { + "access_token": access_token, + "token_type": "Bearer", + "expires_in": client.access_token_lifetime or 3600, + "id_token": id_token, + "refresh_token": new_refresh, + } @classmethod def validate_access_token(cls, token: str, client_id: str = None) -> Dict: @@ -630,10 +737,23 @@ class OIDCService: Raises: InvalidTokenError: If token is invalid """ + logger.debug("[OIDC SERVICE] ===========================================") + logger.debug("[OIDC SERVICE] validate_access_token() called") + logger.debug("[OIDC SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token) + logger.debug("[OIDC SERVICE] Token length: %d", len(token)) + logger.debug("[OIDC SERVICE] Client ID: %s", client_id) + try: + logger.debug("[OIDC SERVICE] Calling OIDCTokenService.validate_access_token()...") claims = OIDCTokenService.validate_access_token(token, client_id) + logger.debug("[OIDC SERVICE] Token validation successful") + logger.debug("[OIDC SERVICE] Token claims: %s", claims) + logger.debug("[OIDC SERVICE] ===========================================") return claims except Exception as e: + logger.error("[OIDC SERVICE] Token validation failed: %s: %s", type(e).__name__, str(e)) + import traceback + logger.error("[OIDC SERVICE] Traceback: %s", traceback.format_exc()) OIDCAuditService.log_event( event_type="token_validation", client_id=client_id, @@ -770,29 +890,67 @@ class OIDCService: Returns: User information dictionary """ + logger.debug("[OIDC SERVICE] ===========================================") + logger.debug("[OIDC SERVICE] get_userinfo() called") + logger.debug("[OIDC SERVICE] Access token (first 50 chars): %s...", access_token[:50] if len(access_token) > 50 else access_token) + logger.debug("[OIDC SERVICE] Access token length: %d", len(access_token)) + + # Validate access token + logger.debug("[OIDC SERVICE] Validating access token...") claims = cls.validate_access_token(access_token) + logger.debug("[OIDC SERVICE] Access token validated successfully") + logger.debug("[OIDC SERVICE] Token claims: %s", claims) user_id = claims.get("sub") + logger.debug("[OIDC SERVICE] User ID from token: %s", user_id) + + logger.debug("[OIDC SERVICE] Querying user from database...") user = User.query.get(user_id) + logger.debug("[OIDC SERVICE] User query result: %s", user) if not user: + logger.error("[OIDC SERVICE] User not found in database: user_id=%s", user_id) raise NotFoundError("User not found") + logger.debug("[OIDC SERVICE] User found: user_id=%s, email=%s, full_name=%s", user.id, user.email, user.full_name) + # Get scopes from token scope_str = claims.get("scope", "") scopes = scope_str.split() if scope_str else [] + logger.debug("[OIDC SERVICE] Scope string from token: '%s'", scope_str) + logger.debug("[OIDC SERVICE] Parsed scopes: %s", scopes) userinfo = {"sub": user_id} + logger.debug("[OIDC SERVICE] Initial userinfo: %s", userinfo) # Add claims based on scope if "profile" in scopes and user.full_name: + logger.debug("[OIDC SERVICE] Found 'profile' in scope, adding name claim") userinfo["name"] = user.full_name + logger.debug("[OIDC SERVICE] Added name: %s", user.full_name) + else: + logger.debug("[OIDC SERVICE] 'profile' not in scope or user.full_name is None: profile_in_scope=%s, full_name=%s", "profile" in scopes, user.full_name) if "email" in scopes: + logger.debug("[OIDC SERVICE] Found 'email' in scope, adding email claims") userinfo["email"] = user.email userinfo["email_verified"] = user.email_verified + logger.debug("[OIDC SERVICE] Added email: %s, email_verified: %s", user.email, user.email_verified) + else: + logger.debug("[OIDC SERVICE] 'email' not in scope") + + if "roles" in scopes: + logger.debug("[OIDC SERVICE] Found 'roles' in scope, adding roles claim") + user_roles = cls._get_user_roles(user) + userinfo["roles"] = user_roles + logger.debug("[OIDC SERVICE] Added roles: %s", user_roles) + else: + logger.debug("[OIDC SERVICE] 'roles' not in scope") + + logger.debug("[OIDC SERVICE] Final userinfo: %s", userinfo) # Log userinfo access + logger.debug("[OIDC SERVICE] Logging userinfo access event...") OIDCAuditService.log_userinfo_event( access_token=access_token, user_id=user_id, @@ -800,5 +958,54 @@ class OIDCService: success=True, scopes_claimed=scopes, ) + logger.debug("[OIDC SERVICE] Userinfo access event logged") + + logger.debug("[OIDC SERVICE] get_userinfo() completed successfully") + logger.debug("[OIDC SERVICE] ===========================================") return userinfo + + @staticmethod + def _get_user_roles(user: User) -> list: + """Get user's organization roles. + + Args: + user: User instance + + Returns: + List of role objects with organization_id and role + """ + logger.debug("[OIDC SERVICE] _get_user_roles() called") + logger.debug("[OIDC SERVICE] User: %s", user) + + roles = [] + + if not user: + logger.debug("[OIDC SERVICE] User is None, returning empty roles list") + return roles + + logger.debug("[OIDC SERVICE] User ID: %s", user.id) + logger.debug("[OIDC SERVICE] User email: %s", user.email) + logger.debug("[OIDC SERVICE] User organization_memberships: %s", user.organization_memberships) + + if user.organization_memberships: + logger.debug("[OIDC SERVICE] User has %d organization memberships", len(user.organization_memberships)) + for idx, member in enumerate(user.organization_memberships): + logger.debug("[OIDC SERVICE] Processing membership %d: member=%s", idx, member) + logger.debug("[OIDC SERVICE] organization_id: %s", member.organization_id) + logger.debug("[OIDC SERVICE] role: %s", member.role) + logger.debug("[OIDC SERVICE] role.value: %s", member.role.value) + + role_entry = { + "organization_id": str(member.organization_id), + "role": member.role.value + } + roles.append(role_entry) + logger.debug("[OIDC SERVICE] Added role entry: %s", role_entry) + else: + logger.debug("[OIDC SERVICE] User has no organization memberships") + + logger.debug("[OIDC SERVICE] Final roles list: %s", roles) + logger.debug("[OIDC SERVICE] _get_user_roles() completed") + + return roles diff --git a/app/services/oidc_session_service.py b/app/services/oidc_session_service.py index e771143..0be391d 100644 --- a/app/services/oidc_session_service.py +++ b/app/services/oidc_session_service.py @@ -3,6 +3,7 @@ import secrets from datetime import datetime, timedelta from typing import Dict, Optional, Tuple +from datetime import timezone from flask import current_app, g from app.extensions import db @@ -219,11 +220,11 @@ class OIDCSessionService: """ from datetime import timedelta - cutoff = datetime.utcnow() - timedelta(hours=older_than_hours) + cutoff = datetime.now(timezone.utc) - timedelta(hours=older_than_hours) # Get expired sessions expired_sessions = OIDCSession.query.filter( - OIDCSession.expires_at < datetime.utcnow(), + OIDCSession.expires_at < datetime.now(timezone.utc), OIDCSession.deleted_at == None ).all() diff --git a/app/services/oidc_token_service.py b/app/services/oidc_token_service.py index e8efdf5..662019f 100644 --- a/app/services/oidc_token_service.py +++ b/app/services/oidc_token_service.py @@ -2,15 +2,20 @@ import hashlib import base64 import secrets -from datetime import datetime, timedelta +import logging +import time +from datetime import datetime, timedelta, timezone from typing import Dict, Optional, Any import jwt from flask import current_app, g from app.models import User, OIDCClient +from app.models.organization_member import OrganizationMember from app.services.oidc_jwks_service import OIDCJWKSService +logger = logging.getLogger(__name__) + class OIDCTokenService: """Service for generating and validating OIDC tokens. @@ -134,7 +139,7 @@ class OIDCTokenService: return lifetimes.get(token_type, 3600) @classmethod - def create_access_token(cls, client_id: str, user_id: str, scope: list, + def create_access_token(cls, client_id: str, user_id: str, scope: list, jti: str = None) -> str: """Create a JWT access token. @@ -147,25 +152,44 @@ class OIDCTokenService: Returns: JWT access token string """ + logger.debug("[OIDC TOKEN SERVICE] ===========================================") + logger.debug("[OIDC TOKEN SERVICE] create_access_token called") + logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat()) + logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id) + logger.debug("[OIDC TOKEN SERVICE] scope=%s", scope) + jti = jti or cls._generate_jti() - now = datetime.utcnow() + now_timestamp = int(time.time()) + now = datetime.now(timezone.utc) + logger.debug("[OIDC TOKEN SERVICE] Token creation time (UTC): %s", now.isoformat()) + logger.debug("[OIDC TOKEN SERVICE] Token creation timestamp: %s", now_timestamp) # Get client for token lifetime client = OIDCClient.query.filter_by(client_id=client_id).first() lifetime = cls._get_token_lifetime(client, "access_token") if client else 3600 + logger.debug("[OIDC TOKEN SERVICE] Access token lifetime (seconds): %s", lifetime) + + exp_timestamp = now_timestamp + lifetime + exp_time = now + timedelta(seconds=lifetime) + logger.debug("[OIDC TOKEN SERVICE] Access token expiration time (UTC): %s", exp_time.isoformat()) + logger.debug("[OIDC TOKEN SERVICE] Access token expiration timestamp: %s", exp_timestamp) + logger.debug("[OIDC TOKEN SERVICE] Time until expiration (seconds): %s", lifetime) claims = { "iss": cls._get_issuer(), "sub": user_id, "aud": client_id, - "exp": int((now + timedelta(seconds=lifetime)).timestamp()), - "iat": int(now.timestamp()), - "nbf": int(now.timestamp()), + "exp": exp_timestamp, + "iat": now_timestamp, + "nbf": now_timestamp, "jti": jti, "client_id": client_id, "scope": " ".join(scope) if isinstance(scope, list) else scope, } + logger.debug("[OIDC TOKEN SERVICE] Token claims: exp=%s, iat=%s, nbf=%s", + claims["exp"], claims["iat"], claims["nbf"]) + # Get signing key jwks_service = OIDCJWKSService() signing_key = jwks_service.get_signing_key() @@ -174,6 +198,7 @@ class OIDCTokenService: raise ValueError("No signing key available") # Sign with RS256 + logger.debug("[OIDC TOKEN SERVICE] Signing token with RS256...") token = jwt.encode( claims, signing_key.private_key, @@ -181,6 +206,9 @@ class OIDCTokenService: headers={"kid": signing_key.kid} ) + logger.debug("[OIDC TOKEN SERVICE] Access token created successfully") + logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat()) + logger.debug("[OIDC TOKEN SERVICE] ===========================================") return token @classmethod @@ -200,12 +228,30 @@ class OIDCTokenService: Returns: JWT ID token string """ - now = datetime.utcnow() - auth_time = auth_time or int(now.timestamp()) + logger.debug("[OIDC TOKEN SERVICE] ===========================================") + logger.debug("[OIDC TOKEN SERVICE] create_id_token called") + logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat()) + logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id) + logger.debug("[OIDC TOKEN SERVICE] nonce=%s, auth_time=%s", nonce, auth_time) + logger.debug("[OIDC TOKEN SERVICE] scope=%s", scope) + + now_timestamp = int(time.time()) + now = datetime.now(timezone.utc) + logger.debug("[OIDC TOKEN SERVICE] Token creation time (UTC): %s", now.isoformat()) + logger.debug("[OIDC TOKEN SERVICE] Token creation timestamp: %s", now_timestamp) + auth_time = auth_time or now_timestamp + logger.debug("[OIDC TOKEN SERVICE] auth_time (Unix timestamp): %s", auth_time) # Get client for token lifetime client = OIDCClient.query.filter_by(client_id=client_id).first() lifetime = cls._get_token_lifetime(client, "id_token") if client else 3600 + logger.debug("[OIDC TOKEN SERVICE] ID token lifetime (seconds): %s", lifetime) + + exp_timestamp = now_timestamp + lifetime + exp_time = now + timedelta(seconds=lifetime) + logger.debug("[OIDC TOKEN SERVICE] ID token expiration time (UTC): %s", exp_time.isoformat()) + logger.debug("[OIDC TOKEN SERVICE] ID token expiration timestamp: %s", exp_timestamp) + logger.debug("[OIDC TOKEN SERVICE] Time until expiration (seconds): %s", lifetime) # Get user for claims user = User.query.get(user_id) @@ -214,11 +260,14 @@ class OIDCTokenService: "iss": cls._get_issuer(), "sub": user_id, "aud": client_id, - "exp": int((now + timedelta(seconds=lifetime)).timestamp()), - "iat": int(now.timestamp()), + "exp": exp_timestamp, + "iat": now_timestamp, "auth_time": auth_time, } + logger.debug("[OIDC TOKEN SERVICE] Token claims: exp=%s, iat=%s, auth_time=%s", + claims["exp"], claims["iat"], claims["auth_time"]) + # Add nonce if provided if nonce: claims["nonce"] = nonce @@ -235,6 +284,10 @@ class OIDCTokenService: if user.full_name: claims["name"] = user.full_name + # Add roles claim if scope is granted + if scope and "roles" in scope: + claims["roles"] = cls._get_user_roles(user) + # Add scope if provided if scope: claims["scope"] = " ".join(scope) if isinstance(scope, list) else scope @@ -247,6 +300,7 @@ class OIDCTokenService: raise ValueError("No signing key available") # Sign with RS256 + logger.debug("[OIDC TOKEN SERVICE] Signing token with RS256...") token = jwt.encode( claims, signing_key.private_key, @@ -254,10 +308,32 @@ class OIDCTokenService: headers={"kid": signing_key.kid} ) + logger.debug("[OIDC TOKEN SERVICE] ID token created successfully") + logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat()) + logger.debug("[OIDC TOKEN SERVICE] ===========================================") return token + @staticmethod + def _get_user_roles(user: User) -> list: + """Get user's organization roles. + + Args: + user: User instance + + Returns: + List of role objects with organization_id and role + """ + roles = [] + if user and user.organization_memberships: + for member in user.organization_memberships: + roles.append({ + "organization_id": str(member.organization_id), + "role": member.role.value + }) + return roles + @classmethod - def create_refresh_token(cls, client_id: str, user_id: str, + def create_refresh_token(cls, client_id: str, user_id: str, scope: list = None, access_token_id: str = None) -> str: """Create an opaque refresh token. @@ -270,11 +346,21 @@ class OIDCTokenService: Returns: Opaque refresh token string """ + logger.debug("[OIDC TOKEN SERVICE] ===========================================") + logger.debug("[OIDC TOKEN SERVICE] create_refresh_token called") + logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat()) + logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id) + logger.debug("[OIDC TOKEN SERVICE] scope=%s, access_token_id=%s", scope, access_token_id) + token = cls._generate_opaque_token() + logger.debug("[OIDC TOKEN SERVICE] Refresh token generated: %s...", token[:20] if token else None) # Hash for storage token_hash = cls._hash_token(token) + logger.debug("[OIDC TOKEN SERVICE] Refresh token created successfully") + logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat()) + logger.debug("[OIDC TOKEN SERVICE] ===========================================") return token, token_hash @classmethod @@ -292,54 +378,91 @@ class OIDCTokenService: jwt.ExpiredSignatureError: If token is expired jwt.InvalidTokenError: If token is invalid """ + logger.debug("[OIDC TOKEN SERVICE] ===========================================") + logger.debug("[OIDC TOKEN SERVICE] verify_token_signature() called") + logger.debug("[OIDC TOKEN SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token) + logger.debug("[OIDC TOKEN SERVICE] Token length: %d", len(token)) + # Get the JWKS with public keys + logger.debug("[OIDC TOKEN SERVICE] Getting JWKS...") jwks_service = OIDCJWKSService() - jwks = jwks_service.get_jwks() + jwks = jwks_service.get_jwks(include_private_keys=True) + logger.debug("[OIDC TOKEN SERVICE] JWKS retrieved: %d keys", len(jwks.get("keys", []))) # Get the key ID from token header try: + logger.debug("[OIDC TOKEN SERVICE] Getting unverified token header...") unverified_header = jwt.get_unverified_header(token) - except jwt.DecodeError: + logger.debug("[OIDC TOKEN SERVICE] Unverified header: %s", unverified_header) + except jwt.DecodeError as e: + logger.error("[OIDC TOKEN SERVICE] Failed to decode token header: %s", str(e)) raise jwt.InvalidTokenError("Invalid token header") kid = unverified_header.get("kid") + logger.debug("[OIDC TOKEN SERVICE] Key ID (kid) from token header: %s", kid) # Find the matching public key + logger.debug("[OIDC TOKEN SERVICE] Searching for matching public key...") public_key = None - for key in jwks.get("keys", []): + for idx, key in enumerate(jwks.get("keys", [])): + logger.debug("[OIDC TOKEN SERVICE] Checking key %d: kid=%s", idx, key.get("kid")) if key.get("kid") == kid: + logger.debug("[OIDC TOKEN SERVICE] Found matching key at index %d", idx) try: from cryptography.hazmat.primitives import serialization from cryptography.hazmat.backends import default_backend + logger.debug("[OIDC TOKEN SERVICE] Loading PEM public key...") public_key = serialization.load_pem_public_key( - key["public_key"].encode() if isinstance(key["public_key"], str) + key["public_key"].encode() if isinstance(key["public_key"], str) else key["public_key"], backend=default_backend() ) + logger.debug("[OIDC TOKEN SERVICE] Public key loaded successfully") break - except (ImportError, Exception): + except (ImportError, Exception) as e: + logger.error("[OIDC TOKEN SERVICE] Failed to load public key: %s: %s", type(e).__name__, str(e)) continue if not public_key: + logger.error("[OIDC TOKEN SERVICE] No matching public key found for kid=%s", kid) raise jwt.InvalidSignatureError(f"Key with kid={kid} not found") - # Verify the signature - claims = jwt.decode( - token, - public_key, - algorithms=["RS256"], - audience=None, # We'll validate audience separately - issuer=cls._get_issuer(), - options={ - "verify_signature": True, - "verify_exp": True, - "verify_aud": False, # Handle audience manually - "verify_iss": False, # Handle issuer manually - } - ) + logger.debug("[OIDC TOKEN SERVICE] Public key found, verifying signature...") - return claims + # Verify the signature + try: + claims = jwt.decode( + token, + public_key, + algorithms=["RS256"], + audience=None, # We'll validate audience separately + issuer=cls._get_issuer(), + options={ + "verify_signature": True, + "verify_exp": True, + "verify_aud": False, # Handle audience manually + "verify_iss": False, # Handle issuer manually + } + ) + logger.debug("[OIDC TOKEN SERVICE] Signature verification successful") + logger.debug("[OIDC TOKEN SERVICE] Decoded claims: %s", claims) + logger.debug("[OIDC TOKEN SERVICE] ===========================================") + return claims + except jwt.ExpiredSignatureError as e: + logger.error("[OIDC TOKEN SERVICE] Token has expired: %s", str(e)) + raise + except jwt.InvalidSignatureError as e: + logger.error("[OIDC TOKEN SERVICE] Invalid token signature: %s", str(e)) + raise + except jwt.InvalidTokenError as e: + logger.error("[OIDC TOKEN SERVICE] Invalid token: %s: %s", type(e).__name__, str(e)) + raise + except Exception as e: + logger.error("[OIDC TOKEN SERVICE] Unexpected error during token verification: %s: %s", type(e).__name__, str(e)) + import traceback + logger.error("[OIDC TOKEN SERVICE] Traceback: %s", traceback.format_exc()) + raise @classmethod def decode_token(cls, token: str, verify: bool = False) -> Dict: @@ -378,16 +501,41 @@ class OIDCTokenService: jwt.InvalidTokenError: If token is invalid ValueError: If token is expired or audience mismatch """ + logger.debug("[OIDC TOKEN SERVICE] ===========================================") + logger.debug("[OIDC TOKEN SERVICE] validate_access_token() called") + logger.debug("[OIDC TOKEN SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token) + logger.debug("[OIDC TOKEN SERVICE] Token length: %d", len(token)) + logger.debug("[OIDC TOKEN SERVICE] Client ID: %s", client_id) + + # Verify token signature + logger.debug("[OIDC TOKEN SERVICE] Verifying token signature...") claims = cls.verify_token_signature(token) + logger.debug("[OIDC TOKEN SERVICE] Token signature verified") + logger.debug("[OIDC TOKEN SERVICE] Claims: %s", claims) # Check expiration - if claims.get("exp", 0) < datetime.utcnow().timestamp(): + exp = claims.get("exp", 0) + now_timestamp = int(time.time()) + + if exp < now_timestamp: + logger.error("[OIDC TOKEN SERVICE] Token has expired") raise ValueError("Token has expired") # Validate audience if client_id provided + aud = claims.get("aud") + logger.debug("[OIDC TOKEN SERVICE] Token audience (aud): %s", aud) + logger.debug("[OIDC TOKEN SERVICE] Expected client_id: %s", client_id) + if client_id: - if claims.get("aud") != client_id: + if aud != client_id: + logger.error("[OIDC TOKEN SERVICE] Audience mismatch: expected=%s, got=%s", client_id, aud) raise ValueError("Invalid audience") + logger.debug("[OIDC TOKEN SERVICE] Audience validation passed") + else: + logger.debug("[OIDC TOKEN SERVICE] No client_id provided, skipping audience validation") + + logger.debug("[OIDC TOKEN SERVICE] validate_access_token() completed successfully") + logger.debug("[OIDC TOKEN SERVICE] ===========================================") return claims @@ -410,11 +558,17 @@ class OIDCTokenService: claims = cls.validate_access_token(token, client_id) # Calculate remaining time - now = datetime.utcnow().timestamp() + now_timestamp = int(time.time()) + now = datetime.now(timezone.utc) exp = claims.get("exp", 0) iat = claims.get("iat", 0) - result["active"] = exp > now + logger.debug("[OIDC TOKEN SERVICE] Introspection - Current UTC time: %s", now.isoformat()) + logger.debug("[OIDC TOKEN SERVICE] Introspection - Token expiration timestamp: %s", exp) + logger.debug("[OIDC TOKEN SERVICE] Introspection - Token expiration datetime (UTC): %s", datetime.fromtimestamp(exp, tz=timezone.utc).isoformat()) + logger.debug("[OIDC TOKEN SERVICE] Introspection - Time until expiration: %s seconds", exp - now_timestamp) + + result["active"] = exp > now_timestamp result.update({ "iss": claims.get("iss"), "sub": claims.get("sub"), @@ -429,8 +583,8 @@ class OIDCTokenService: }) # Add expiry in seconds - if exp > now: - result["exp"] = int(exp - now) + if exp > now_timestamp: + result["exp"] = int(exp - now_timestamp) except (jwt.InvalidTokenError, ValueError) as e: result["active"] = False diff --git a/app/services/totp_service.py b/app/services/totp_service.py new file mode 100644 index 0000000..dc44533 --- /dev/null +++ b/app/services/totp_service.py @@ -0,0 +1,188 @@ +"""TOTP (Time-based One-Time Password) service.""" +import base64 +import io +import logging +import secrets +from typing import Tuple + +import pyotp +from app.extensions import bcrypt + +logger = logging.getLogger(__name__) + + +class TOTPService: + """Service for TOTP operations.""" + + @staticmethod + def generate_secret() -> str: + """ + Generate a new TOTP secret. + + Returns: + Base32 encoded secret (32 characters) + + Note: + The secret is generated using cryptographically secure random bytes + and encoded in base32 format for compatibility with authenticator apps. + """ + # Generate 20 random bytes (160 bits) and encode as base32 + random_bytes = secrets.token_bytes(20) + secret = base64.b32encode(random_bytes).decode("utf-8") + logger.debug(f"Generated new TOTP secret: {secret[:8]}...") + return secret + + @staticmethod + def generate_provisioning_uri(user_email: str, secret: str, issuer: str = "Gatehouse") -> str: + """ + Generate provisioning URI for QR code. + + Args: + user_email: User's email address + secret: TOTP secret (base32 encoded) + issuer: Issuer name (default: "Gatehouse") + + Returns: + otpauth:// URI for QR code generation + + Example: + >>> uri = TOTPService.generate_provisioning_uri("user@example.com", "JBSWY3DPEHPK3PXP") + >>> print(uri) + otpauth://totp/Gatehouse:user@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse + """ + totp = pyotp.TOTP(secret) + uri = totp.provisioning_uri(name=user_email, issuer_name=issuer) + logger.debug(f"Generated provisioning URI for user: {user_email}") + return uri + + @staticmethod + def verify_code(secret: str, code: str, window: int = 1) -> bool: + """ + Verify a TOTP code against the secret. + + Args: + secret: TOTP secret (base32 encoded) + code: 6-digit TOTP code to verify + window: Time window for code validation (default: 1, allows codes from previous/next time steps) + + Returns: + True if code is valid, False otherwise + + Note: + The window parameter allows for clock skew between the server + and the authenticator app. A window of 1 allows codes from + the previous, current, and next 30-second intervals. + """ + totp = pyotp.TOTP(secret) + is_valid = totp.verify(code, valid_window=window) + logger.debug(f"TOTP code verification: valid={is_valid}, window={window}") + return is_valid + + @staticmethod + def generate_backup_codes(count: int = 10) -> Tuple[list[str], list[str]]: + """ + Generate backup codes for TOTP recovery. + + Args: + count: Number of backup codes to generate (default: 10) + + Returns: + Tuple of (plain_codes, hashed_codes) + - plain_codes: List of plain text backup codes (for display to user) + - hashed_codes: List of bcrypt hashed backup codes (for storage) + + Note: + Backup codes are 16-character alphanumeric codes that can be used + to recover access if the TOTP device is lost. Each code can only + be used once. + """ + plain_codes = [] + hashed_codes = [] + + for _ in range(count): + # Generate a 16-character alphanumeric code + code = secrets.token_hex(8).upper() + plain_codes.append(code) + + # Hash the code using bcrypt + hashed_code = bcrypt.generate_password_hash(code).decode("utf-8") + hashed_codes.append(hashed_code) + + logger.debug(f"Generated {count} backup codes") + return plain_codes, hashed_codes + + @staticmethod + def verify_backup_code(hashed_codes: list[str], code: str) -> Tuple[bool, list[str]]: + """ + Verify and consume a backup code. + + Args: + hashed_codes: List of bcrypt hashed backup codes + code: Plain text backup code to verify + + Returns: + Tuple of (is_valid, remaining_codes) + - is_valid: True if code was valid and consumed, False otherwise + - remaining_codes: List of remaining hashed codes (with consumed code removed) + + Note: + Once a backup code is used, it is removed from the list and cannot + be used again. This ensures each code is single-use. + """ + remaining_codes = [] + + for hashed_code in hashed_codes: + if bcrypt.check_password_hash(hashed_code, code): + # Code found and valid - don't add to remaining codes (consumed) + logger.debug("Backup code verified and consumed") + return True, remaining_codes + else: + # Code doesn't match - keep it in remaining codes + remaining_codes.append(hashed_code) + + logger.debug("Backup code verification failed") + return False, remaining_codes + + @staticmethod + def generate_qr_code_data_uri(provisioning_uri: str) -> str: + """ + Generate QR code as data URI for frontend display. + + Args: + provisioning_uri: otpauth:// URI to encode in QR code + + Returns: + Base64 encoded PNG image as data URI (data:image/png;base64,...) + + Note: + If the qrcode library is not installed, returns a placeholder message. + Install with: pip install qrcode[pil] + """ + try: + import qrcode + + # Create QR code + qr = qrcode.QRCode( + version=1, + error_correction=qrcode.constants.ERROR_CORRECT_L, + box_size=10, + border=4, + ) + qr.add_data(provisioning_uri) + qr.make(fit=True) + + # Generate image + img = qr.make_image(fill_color="black", back_color="white") + + # Convert to base64 + buffer = io.BytesIO() + img.save(buffer, format="PNG") + img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8") + + data_uri = f"data:image/png;base64,{img_base64}" + logger.debug("Generated QR code data URI") + return data_uri + + except ImportError: + logger.warning("qrcode library not installed, returning placeholder") + return "QR code generation requires the qrcode library. Install with: pip install qrcode[pil]" diff --git a/app/utils/constants.py b/app/utils/constants.py index 2425d7e..86c662d 100644 --- a/app/utils/constants.py +++ b/app/utils/constants.py @@ -24,6 +24,7 @@ class AuthMethodType(str, Enum): """Authentication method types.""" PASSWORD = "password" + TOTP = "totp" GOOGLE = "google" GITHUB = "github" MICROSOFT = "microsoft" @@ -66,6 +67,13 @@ class AuditAction(str, Enum): # Auth method actions AUTH_METHOD_ADD = "auth.method.add" AUTH_METHOD_REMOVE = "auth.method.remove" + TOTP_ENROLL_INITIATED = "totp.enroll.initiated" + TOTP_ENROLL_COMPLETED = "totp.enroll.completed" + TOTP_VERIFY_SUCCESS = "totp.verify.success" + TOTP_VERIFY_FAILED = "totp.verify.failed" + TOTP_DISABLED = "totp.disabled" + TOTP_BACKUP_CODE_USED = "totp.backup_code.used" + TOTP_BACKUP_CODES_REGENERATED = "totp.backup_codes.regenerated" class OIDCGrantType(str, Enum): diff --git a/oidc_test.sh b/oidc_test.sh new file mode 100644 index 0000000..f5abb1a --- /dev/null +++ b/oidc_test.sh @@ -0,0 +1,102 @@ +#!/usr/bin/env bash +set -euo pipefail + +ISSUER="https://oidctest.wsweet.org" +CLIENT_ID="secret" +CLIENT_SECRET="tardis" +REDIRECT_URI="http://127.0.0.1:5556/callback" +SCOPE="openid profile email offline_access" + +# --------------------------- +# Discover OIDC endpoints +# --------------------------- +DISCOVERY=$(curl -s "$ISSUER/.well-known/openid-configuration") + +AUTH_ENDPOINT=$(echo "$DISCOVERY" | jq -r .authorization_endpoint) +TOKEN_ENDPOINT=$(echo "$DISCOVERY" | jq -r .token_endpoint) +USERINFO_ENDPOINT=$(echo "$DISCOVERY" | jq -r .userinfo_endpoint) + +echo "Auth endpoint : $AUTH_ENDPOINT" +echo "Token endpoint: $TOKEN_ENDPOINT" +echo + +# --------------------------- +# PKCE +# --------------------------- +CODE_VERIFIER=$(openssl rand -base64 32 | tr -d '=+/') +CODE_CHALLENGE=$(echo -n "$CODE_VERIFIER" | openssl dgst -sha256 -binary | openssl base64 | tr -d '=+/' | tr '/+' '_-') + +STATE=$(openssl rand -hex 16) +NONCE=$(openssl rand -hex 16) + +# --------------------------- +# Build auth URL +# --------------------------- +AUTH_URL="$AUTH_ENDPOINT?response_type=code\ +&client_id=$CLIENT_ID\ +&redirect_uri=$(printf '%s' "$REDIRECT_URI" | jq -s -R -r @uri)\ +&scope=$(printf '%s' "$SCOPE" | jq -s -R -r @uri)\ +&state=$STATE\ +&nonce=$NONCE\ +&code_challenge=$CODE_CHALLENGE\ +&code_challenge_method=S256" + +echo "Open this URL in a browser:" +echo +echo "$AUTH_URL" +echo +echo "After login you will be redirected to:" +echo "$REDIRECT_URI?code=XXXX&state=YYYY" +echo +read -p "Paste the full redirect URL: " REDIRECT + +CODE=$(echo "$REDIRECT" | sed -n 's/.*code=\([^&]*\).*/\1/p') +RETURNED_STATE=$(echo "$REDIRECT" | sed -n 's/.*state=\([^&]*\).*/\1/p') + +if [ "$RETURNED_STATE" != "$STATE" ]; then + echo "STATE MISMATCH" + exit 1 +fi + +# --------------------------- +# Exchange code for tokens +# --------------------------- +TOKENS=$(curl -s -X POST "$TOKEN_ENDPOINT" \ + -u "$CLIENT_ID:$CLIENT_SECRET" \ + -H "Content-Type: application/x-www-form-urlencoded" \ + -d "grant_type=authorization_code" \ + -d "code=$CODE" \ + -d "redirect_uri=$REDIRECT_URI" \ + -d "code_verifier=$CODE_VERIFIER") + +echo +echo "Token response:" +echo "$TOKENS" | jq . + +ACCESS_TOKEN=$(echo "$TOKENS" | jq -r .access_token) +ID_TOKEN=$(echo "$TOKENS" | jq -r .id_token) + +# --------------------------- +# JWT decode function +# --------------------------- +decode() { + echo "$1" | awk -F. '{print $2}' | tr '_-' '/+' | base64 -d 2>/dev/null | jq . +} + +echo +echo "================ ID TOKEN ================" +decode "$ID_TOKEN" + +echo +echo "============== ACCESS TOKEN ==============" +decode "$ACCESS_TOKEN" + +# --------------------------- +# Userinfo (optional) +# --------------------------- +if [ "$USERINFO_ENDPOINT" != "null" ]; then + echo + echo "=============== USERINFO =================" + curl -s -H "Authorization: Bearer $ACCESS_TOKEN" "$USERINFO_ENDPOINT" | jq . +fi + diff --git a/requirements/base.txt b/requirements/base.txt index ffbaa0f..af22369 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -16,6 +16,7 @@ marshmallow-sqlalchemy==0.29.0 # Security bcrypt==4.1.2 Flask-Bcrypt==1.0.1 +pyotp==2.9.0 # JWT / OIDC PyJWT==2.8.0 diff --git a/test-container/app/index.html b/test-container/app/index.html new file mode 100644 index 0000000..c313423 --- /dev/null +++ b/test-container/app/index.html @@ -0,0 +1,8 @@ + + +
+User: __USER__
+Email: __EMAIL__
+ + diff --git a/test-container/docker-compose.yml b/test-container/docker-compose.yml new file mode 100644 index 0000000..f37a34f --- /dev/null +++ b/test-container/docker-compose.yml @@ -0,0 +1,92 @@ +version: "3.9" + +services: + nginx: + image: nginx:1.27-alpine + container_name: app-nginx + volumes: + - ./app:/usr/share/nginx/html:ro + - ./nginx.conf:/etc/nginx/conf.d/default.conf:ro + + expose: + - "80" + + # oauth2-proxy: + # image: quay.io/oauth2-proxy/oauth2-proxy:v7.7.1 + # container_name: oauth2-proxy + # depends_on: + # - nginx + # ports: + # - "8086:4180" + # environment: + + # # ----- Logging ----- + # OAUTH2_PROXY_LOG_LEVEL: "debug" + # OAUTH2_PROXY_AUTH_LOGGING: "true" + # OAUTH2_PROXY_REQUEST_LOGGING: "true" + # OAUTH2_PROXY_STANDARD_LOGGING: "true" + + + # # ----- Gatehouse OIDC ----- + # OAUTH2_PROXY_PROVIDER: oidc + # OAUTH2_PROXY_OIDC_ISSUER_URL: "http://192.168.64.7:8888" + # OAUTH2_PROXY_CLIENT_ID: "acme-portal-001" + # OAUTH2_PROXY_CLIENT_SECRET: "acme_secret_portal_2024" + # OAUTH2_PROXY_REDIRECT_URL: "http://192.168.64.7:8086/oauth2/callback" + + # # ----- Session ----- + # OAUTH2_PROXY_COOKIE_SECRET: "afhXcfftf5qKezJ217qhED7U4UeVyqBHd7lhISNGpXo=" + # OAUTH2_PROXY_COOKIE_SECURE: "false" + # OAUTH2_PROXY_COOKIE_SAMESITE: "lax" + # OAUTH2_PROXY_EMAIL_DOMAINS: "*" + # OAUTH2_PROXY_HTTP_ADDRESS: "0.0.0.0:4180" + + + # # ----- Upstream ----- + # OAUTH2_PROXY_UPSTREAMS: "http://nginx:80" + + # # ----- Identity headers ----- + # OAUTH2_PROXY_SET_XAUTHREQUEST: "true" + # OAUTH2_PROXY_PASS_ACCESS_TOKEN: "true" + # OAUTH2_PROXY_PASS_AUTHORIZATION_HEADER: "true" + + # # ----- Claim mapping (Gatehouse) ----- + # OAUTH2_PROXY_OIDC_EMAIL_CLAIM: "email" + # OAUTH2_PROXY_USER_ID_CLAIM: "preferred_username" + # OAUTH2_PROXY_OIDC_GROUPS_CLAIM: "roles" + + # OAUTH2_PROXY_SKIP_PROVIDER_BUTTON: "true" + + + # OAUTH2_PROXY_SCOPE: "openid email profile" + # # OAUTH2_PROXY_OIDC_EMAIL_CLAIM: "email" + # OAUTH2_PROXY_INSECURE_OIDC_ALLOW_UNVERIFIED_EMAIL: "true" + + + oauth2-proxy: + image: quay.io/oauth2-proxy/oauth2-proxy:v7.7.1 + container_name: oauth2-proxy + depends_on: [nginx] + ports: + - "8086:4180" + command: + - --provider=oidc + - --oidc-issuer-url=http://192.168.64.7:8888 + - --client-id=acme-portal-001 + - --client-secret=acme_secret_portal_2024 + - --redirect-url=http://192.168.64.7:8086/oauth2/callback + - --scope=openid email profile + - --email-domain=* + - --upstream=http://nginx:80 + - --http-address=0.0.0.0:4180 + - --cookie-secret=afhXcfftf5qKezJ217qhED7U4UeVyqBHd7lhISNGpXo= + - --cookie-secure=false + - --cookie-samesite=lax + - --skip-provider-button=true + - --standard-logging=true + - --request-logging=true + - --auth-logging=true + - --set-xauthrequest=true + - --pass-authorization-header=true # optional (token passthrough) + - --pass-access-token=true # optional (token passthrough) + - --pass-user-headers=true diff --git a/test-container/nginx.conf b/test-container/nginx.conf new file mode 100644 index 0000000..b079024 --- /dev/null +++ b/test-container/nginx.conf @@ -0,0 +1,23 @@ +server { + listen 80; + server_name _; + + root /usr/share/nginx/html; + index index.html; + + location / { + try_files $uri $uri/ =404; + } + + location = /whoami { + default_type text/plain; + return 200 +"user: $http_x_auth_request_user +email: $http_x_auth_request_email +preferred_username: $http_x_forwarded_preferred_username +x-forwarded-user: $http_x_forwarded_user +x-forwarded-email: $http_x_forwarded_email +authorization: $http_authorization +"; + } +} diff --git a/tests/unit/test_services/test_totp_service.py b/tests/unit/test_services/test_totp_service.py new file mode 100644 index 0000000..78b0bd6 --- /dev/null +++ b/tests/unit/test_services/test_totp_service.py @@ -0,0 +1,285 @@ +"""Unit tests for TOTPService.""" +import base64 +import pytest +from app.services.totp_service import TOTPService + + +@pytest.mark.unit +class TestTOTPService: + """Tests for TOTPService.""" + + # Test generate_secret() + def test_generate_secret_returns_string(self): + """Test that generate_secret returns a string.""" + secret = TOTPService.generate_secret() + assert isinstance(secret, str) + + def test_generate_secret_length(self): + """Test that generate_secret returns a 32-character string.""" + secret = TOTPService.generate_secret() + assert len(secret) == 32 + + def test_generate_secret_base32_encoded(self): + """Test that generate_secret returns a base32 encoded string.""" + secret = TOTPService.generate_secret() + # Base32 characters are A-Z and 2-7 + valid_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567") + assert all(c in valid_chars for c in secret) + + def test_generate_secret_unique(self): + """Test that generate_secret produces unique secrets.""" + secret1 = TOTPService.generate_secret() + secret2 = TOTPService.generate_secret() + assert secret1 != secret2 + + # Test generate_provisioning_uri() + def test_generate_provisioning_uri_format(self): + """Test that provisioning URI is generated correctly.""" + email = "user@example.com" + secret = "JBSWY3DPEHPK3PXP" + issuer = "Gatehouse" + + uri = TOTPService.generate_provisioning_uri(email, secret, issuer) + + assert isinstance(uri, str) + assert uri.startswith("otpauth://totp/") + + def test_generate_provisioning_uri_contains_email(self): + """Test that provisioning URI contains the user email.""" + email = "user@example.com" + secret = "JBSWY3DPEHPK3PXP" + issuer = "Gatehouse" + + uri = TOTPService.generate_provisioning_uri(email, secret, issuer) + + assert email in uri + + def test_generate_provisioning_uri_contains_secret(self): + """Test that provisioning URI contains the secret.""" + email = "user@example.com" + secret = "JBSWY3DPEHPK3PXP" + issuer = "Gatehouse" + + uri = TOTPService.generate_provisioning_uri(email, secret, issuer) + + assert secret in uri + + def test_generate_provisioning_uri_contains_issuer(self): + """Test that provisioning URI contains the issuer.""" + email = "user@example.com" + secret = "JBSWY3DPEHPK3PXP" + issuer = "Gatehouse" + + uri = TOTPService.generate_provisioning_uri(email, secret, issuer) + + assert issuer in uri + + def test_generate_provisioning_uri_custom_issuer(self): + """Test that provisioning URI uses custom issuer.""" + email = "user@example.com" + secret = "JBSWY3DPEHPK3PXP" + custom_issuer = "MyApp" + + uri = TOTPService.generate_provisioning_uri(email, secret, custom_issuer) + + assert custom_issuer in uri + + # Test verify_code() + def test_verify_code_valid(self): + """Test that a valid TOTP code is accepted.""" + secret = TOTPService.generate_secret() + # Generate a valid code using pyotp + import pyotp + totp = pyotp.TOTP(secret) + valid_code = totp.now() + + result = TOTPService.verify_code(secret, valid_code) + + assert result is True + + def test_verify_code_invalid(self): + """Test that an invalid TOTP code is rejected.""" + secret = TOTPService.generate_secret() + invalid_code = "000000" + + result = TOTPService.verify_code(secret, invalid_code) + + assert result is False + + def test_verify_code_window_parameter(self): + """Test that the time window parameter works correctly.""" + secret = TOTPService.generate_secret() + import pyotp + totp = pyotp.TOTP(secret) + + # Get current code + current_code = totp.now() + + # Verify with window=1 (default) - should accept current code + result = TOTPService.verify_code(secret, current_code, window=1) + assert result is True + + # Verify with window=0 - should only accept exact time match + result = TOTPService.verify_code(secret, current_code, window=0) + assert result is True + + def test_verify_code_wrong_length(self): + """Test that codes with wrong length are rejected.""" + secret = TOTPService.generate_secret() + wrong_length_code = "12345" # 5 digits instead of 6 + + result = TOTPService.verify_code(secret, wrong_length_code) + + assert result is False + + # Test generate_backup_codes() + def test_generate_backup_codes_default_count(self): + """Test that generate_backup_codes generates 10 codes by default.""" + plain_codes, hashed_codes = TOTPService.generate_backup_codes() + + assert len(plain_codes) == 10 + assert len(hashed_codes) == 10 + + def test_generate_backup_codes_custom_count(self): + """Test that generate_backup_codes generates the specified number of codes.""" + count = 5 + plain_codes, hashed_codes = TOTPService.generate_backup_codes(count) + + assert len(plain_codes) == count + assert len(hashed_codes) == count + + def test_generate_backup_codes_plain_are_strings(self): + """Test that plain backup codes are strings.""" + plain_codes, hashed_codes = TOTPService.generate_backup_codes() + + assert all(isinstance(code, str) for code in plain_codes) + + def test_generate_backup_codes_plain_length(self): + """Test that plain backup codes are 16 characters long.""" + plain_codes, hashed_codes = TOTPService.generate_backup_codes() + + assert all(len(code) == 16 for code in plain_codes) + + def test_generate_backup_codes_hashed_different_from_plain(self): + """Test that hashed codes are different from plain codes.""" + plain_codes, hashed_codes = TOTPService.generate_backup_codes() + + for plain, hashed in zip(plain_codes, hashed_codes): + assert plain != hashed + + def test_generate_backup_codes_are_bcrypt_hashes(self): + """Test that hashed codes are bcrypt hashes.""" + plain_codes, hashed_codes = TOTPService.generate_backup_codes() + + # Bcrypt hashes start with $2a$, $2b$, or $2y$ + for hashed in hashed_codes: + assert hashed.startswith("$2") + + def test_generate_backup_codes_unique(self): + """Test that generated backup codes are unique.""" + plain_codes, hashed_codes = TOTPService.generate_backup_codes() + + assert len(set(plain_codes)) == len(plain_codes) + assert len(set(hashed_codes)) == len(hashed_codes) + + # Test verify_backup_code() + def test_verify_backup_code_valid(self): + """Test that a valid backup code is accepted and removed.""" + plain_codes, hashed_codes = TOTPService.generate_backup_codes(count=3) + code_to_verify = plain_codes[0] + + is_valid, remaining_codes = TOTPService.verify_backup_code(hashed_codes, code_to_verify) + + assert is_valid is True + assert len(remaining_codes) == 2 + + def test_verify_backup_code_invalid(self): + """Test that an invalid backup code is rejected.""" + plain_codes, hashed_codes = TOTPService.generate_backup_codes(count=3) + invalid_code = "INVALIDCODE1234" + + is_valid, remaining_codes = TOTPService.verify_backup_code(hashed_codes, invalid_code) + + assert is_valid is False + assert len(remaining_codes) == 3 + + def test_verify_backup_code_remaining_updated(self): + """Test that the remaining codes list is updated correctly.""" + plain_codes, hashed_codes = TOTPService.generate_backup_codes(count=5) + code_to_verify = plain_codes[2] + + is_valid, remaining_codes = TOTPService.verify_backup_code(hashed_codes, code_to_verify) + + assert is_valid is True + # The verified code should be removed + assert len(remaining_codes) == 4 + # The remaining codes should not include the verified code's hash + assert hashed_codes[2] not in remaining_codes + + def test_verify_backup_code_case_sensitive(self): + """Test that backup code verification is case sensitive.""" + plain_codes, hashed_codes = TOTPService.generate_backup_codes(count=1) + code_to_verify = plain_codes[0].lower() # Convert to lowercase + + is_valid, remaining_codes = TOTPService.verify_backup_code(hashed_codes, code_to_verify) + + assert is_valid is False + assert len(remaining_codes) == 1 + + def test_verify_backup_code_single_use(self): + """Test that a backup code can only be used once.""" + plain_codes, hashed_codes = TOTPService.generate_backup_codes(count=1) + code_to_verify = plain_codes[0] + + # First use - should succeed + is_valid1, remaining1 = TOTPService.verify_backup_code(hashed_codes, code_to_verify) + assert is_valid1 is True + assert len(remaining1) == 0 + + # Second use - should fail (code already consumed) + is_valid2, remaining2 = TOTPService.verify_backup_code(remaining1, code_to_verify) + assert is_valid2 is False + assert len(remaining2) == 0 + + # Test generate_qr_code_data_uri() + def test_generate_qr_code_data_uri_format(self): + """Test that a data URI is generated.""" + provisioning_uri = "otpauth://totp/Gatehouse:user@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse" + + data_uri = TOTPService.generate_qr_code_data_uri(provisioning_uri) + + assert isinstance(data_uri, str) + + def test_generate_qr_code_data_uri_starts_with_prefix(self): + """Test that the data URI starts with the correct prefix.""" + provisioning_uri = "otpauth://totp/Gatehouse:user@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse" + + data_uri = TOTPService.generate_qr_code_data_uri(provisioning_uri) + + assert data_uri.startswith("data:image/png;base64,") + + def test_generate_qr_code_data_uri_contains_base64(self): + """Test that the data URI contains base64 encoded data.""" + provisioning_uri = "otpauth://totp/Gatehouse:user@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse" + + data_uri = TOTPService.generate_qr_code_data_uri(provisioning_uri) + + # Extract the base64 part (after the prefix) + base64_part = data_uri.split("data:image/png;base64,")[1] + + # Verify it's valid base64 + try: + base64.b64decode(base64_part) + assert True + except Exception: + assert False, "Data URI does not contain valid base64 data" + + def test_generate_qr_code_data_uri_different_uris(self): + """Test that different provisioning URIs generate different QR codes.""" + uri1 = "otpauth://totp/Gatehouse:user1@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse" + uri2 = "otpauth://totp/Gatehouse:user2@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse" + + data_uri1 = TOTPService.generate_qr_code_data_uri(uri1) + data_uri2 = TOTPService.generate_qr_code_data_uri(uri2) + + assert data_uri1 != data_uri2