major checkpoint
This commit is contained in:
+24
-10
@@ -3,12 +3,21 @@ import os
|
||||
import logging
|
||||
from flask import Flask
|
||||
from config import get_config
|
||||
from app.extensions import db, migrate, bcrypt, cors, ma, limiter, session
|
||||
from app.extensions import db, migrate, bcrypt, ma, limiter, session
|
||||
from app.middleware import RequestIDMiddleware, SecurityHeadersMiddleware, setup_cors
|
||||
from app.exceptions.base import BaseAPIException
|
||||
from app.utils.response import api_response
|
||||
import redis
|
||||
|
||||
# Configure SQLAlchemy logging BEFORE any database operations
|
||||
# This must be done before db.init_app() to prevent verbose logging
|
||||
_log_level_env = os.getenv("SQLALCHEMY_LOG_LEVEL", "WARNING").upper()
|
||||
_sqlalchemy_log_level = getattr(logging, _log_level_env, logging.WARNING)
|
||||
logging.getLogger('sqlalchemy').setLevel(_sqlalchemy_log_level)
|
||||
logging.getLogger('sqlalchemy.engine').setLevel(_sqlalchemy_log_level)
|
||||
logging.getLogger('sqlalchemy.dialects').setLevel(_sqlalchemy_log_level)
|
||||
logging.getLogger('sqlalchemy.pool').setLevel(_sqlalchemy_log_level)
|
||||
|
||||
|
||||
def create_app(config_name=None):
|
||||
"""
|
||||
@@ -53,12 +62,9 @@ def initialize_extensions(app):
|
||||
# Security
|
||||
bcrypt.init_app(app)
|
||||
|
||||
# CORS
|
||||
cors.init_app(
|
||||
app,
|
||||
origins=app.config.get("CORS_ORIGINS", []),
|
||||
supports_credentials=app.config.get("CORS_SUPPORTS_CREDENTIALS", True),
|
||||
)
|
||||
# CORS - using custom middleware only (see app/middleware/cors.py)
|
||||
# Flask-CORS disabled to avoid conflicts
|
||||
# cors.init_app(app)
|
||||
|
||||
# Marshmallow
|
||||
ma.init_app(app)
|
||||
@@ -84,15 +90,19 @@ def setup_middleware(app):
|
||||
"""Setup application middleware."""
|
||||
RequestIDMiddleware(app)
|
||||
SecurityHeadersMiddleware(app)
|
||||
setup_cors(app, cors)
|
||||
setup_cors(app)
|
||||
|
||||
|
||||
def register_blueprints(app):
|
||||
"""Register application blueprints."""
|
||||
from app.api import register_api_blueprints
|
||||
from app.api.oidc import oidc_bp
|
||||
|
||||
register_api_blueprints(app)
|
||||
|
||||
# Register OIDC blueprint at root level
|
||||
app.register_blueprint(oidc_bp)
|
||||
|
||||
|
||||
def register_error_handlers(app):
|
||||
"""Register error handlers."""
|
||||
@@ -169,7 +179,11 @@ def setup_logging(app):
|
||||
|
||||
app.logger.setLevel(log_level)
|
||||
|
||||
# Reduce SQLAlchemy logging noise
|
||||
logging.getLogger('sqlalchemy').setLevel(logging.WARNING)
|
||||
# Configure SQLAlchemy logging level (also set at module level before DB init)
|
||||
sqlalchemy_log_level = getattr(logging, app.config.get("SQLALCHEMY_LOG_LEVEL", "WARNING"), logging.WARNING)
|
||||
logging.getLogger('sqlalchemy').setLevel(sqlalchemy_log_level)
|
||||
logging.getLogger('sqlalchemy.engine').setLevel(sqlalchemy_log_level)
|
||||
logging.getLogger('sqlalchemy.dialects').setLevel(sqlalchemy_log_level)
|
||||
logging.getLogger('sqlalchemy.pool').setLevel(sqlalchemy_log_level)
|
||||
|
||||
app.logger.info("Application startup")
|
||||
|
||||
+964
@@ -0,0 +1,964 @@
|
||||
"""OIDC (OpenID Connect) API endpoints - Root level blueprint."""
|
||||
import base64
|
||||
import json
|
||||
import secrets
|
||||
from urllib.parse import urlencode, urlparse, parse_qs
|
||||
|
||||
import bcrypt
|
||||
from flask import Blueprint, request, redirect, jsonify, session, g, current_app, Response
|
||||
|
||||
from app.utils.response import api_response
|
||||
from app.services.oidc_service import (
|
||||
OIDCService, InvalidClientError, InvalidGrantError, InvalidRequestError
|
||||
)
|
||||
from app.services.auth_service import AuthService
|
||||
from app.extensions import db
|
||||
from app.models import User, OIDCClient
|
||||
from app.models.organization import Organization
|
||||
from app.exceptions.auth_exceptions import InvalidCredentialsError
|
||||
|
||||
|
||||
# Create OIDC blueprint registered at root level
|
||||
oidc_bp = Blueprint("oidc", __name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Helper Functions
|
||||
# ============================================================================
|
||||
|
||||
def get_oidc_config():
|
||||
"""Get OIDC configuration from app config."""
|
||||
base_url = current_app.config.get("OIDC_ISSUER_URL", "http://localhost:5000")
|
||||
return {
|
||||
"issuer": base_url,
|
||||
"authorization_endpoint": f"{base_url}/oidc/authorize",
|
||||
"token_endpoint": f"{base_url}/oidc/token",
|
||||
"userinfo_endpoint": f"{base_url}/oidc/userinfo",
|
||||
"jwks_uri": f"{base_url}/oidc/jwks",
|
||||
"registration_endpoint": f"{base_url}/oidc/register",
|
||||
"revocation_endpoint": f"{base_url}/oidc/revoke",
|
||||
"introspection_endpoint": f"{base_url}/oidc/introspect",
|
||||
"scopes_supported": ["openid", "profile", "email"],
|
||||
"response_types_supported": ["code"],
|
||||
"response_modes_supported": ["query"],
|
||||
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||
"token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"],
|
||||
"subject_types_supported": ["public"],
|
||||
"id_token_signing_alg_values_supported": ["RS256"],
|
||||
"claims_supported": ["sub", "name", "email", "email_verified"],
|
||||
}
|
||||
|
||||
|
||||
def authenticate_client(client_id, client_secret=None):
|
||||
"""Authenticate an OIDC client.
|
||||
|
||||
Args:
|
||||
client_id: The client ID
|
||||
client_secret: Optional client secret
|
||||
|
||||
Returns:
|
||||
OIDCClient instance
|
||||
|
||||
Raises:
|
||||
InvalidClientError: If authentication fails
|
||||
"""
|
||||
client = OIDCClient.query.filter_by(client_id=client_id, is_active=True).first()
|
||||
if not client:
|
||||
raise InvalidClientError("Invalid client")
|
||||
|
||||
if client.is_confidential and client_secret:
|
||||
if not bcrypt.check_password_hash(client.client_secret_hash, client_secret):
|
||||
raise InvalidClientError("Invalid client credentials")
|
||||
|
||||
return client
|
||||
|
||||
|
||||
def require_valid_token():
|
||||
"""Validate Bearer token from Authorization header.
|
||||
|
||||
Sets g.current_token and g.current_user on success.
|
||||
|
||||
Raises:
|
||||
InvalidGrantError: If token is invalid
|
||||
"""
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
raise InvalidGrantError("Invalid token: Missing or invalid Authorization header")
|
||||
|
||||
token = auth_header[7:]
|
||||
claims = OIDCService.validate_access_token(token)
|
||||
g.current_token = claims
|
||||
|
||||
user = User.query.get(claims.get("sub"))
|
||||
if not user:
|
||||
raise InvalidGrantError("Invalid token: User not found")
|
||||
|
||||
g.current_user = user
|
||||
|
||||
|
||||
def parse_basic_auth():
|
||||
"""Parse Basic authentication from Authorization header.
|
||||
|
||||
Returns:
|
||||
Tuple of (client_id, client_secret) or (None, None)
|
||||
"""
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Basic "):
|
||||
try:
|
||||
encoded = auth_header[6:]
|
||||
decoded = base64.b64decode(encoded).decode("utf-8")
|
||||
client_id, client_secret = decoded.split(":", 1)
|
||||
return client_id, client_secret
|
||||
except Exception:
|
||||
pass
|
||||
return None, None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Discovery Endpoint
|
||||
# ============================================================================
|
||||
|
||||
@oidc_bp.route("/.well-known/openid-configuration", methods=["GET"])
|
||||
def oidc_discovery():
|
||||
"""OpenID Connect Discovery endpoint.
|
||||
|
||||
Returns the OIDC configuration as JSON.
|
||||
|
||||
Cache-Control: max-age=86400
|
||||
No authentication required.
|
||||
|
||||
Returns:
|
||||
200: OIDC discovery document
|
||||
"""
|
||||
config = get_oidc_config()
|
||||
|
||||
response = jsonify(config)
|
||||
response.headers["Cache-Control"] = "max-age=86400"
|
||||
return response, 200
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Authorization Endpoint
|
||||
# ============================================================================
|
||||
|
||||
@oidc_bp.route("/oidc/authorize", methods=["GET", "POST"])
|
||||
def oidc_authorize():
|
||||
"""OpenID Connect Authorization endpoint.
|
||||
|
||||
Initiates the OIDC authentication flow.
|
||||
|
||||
GET Parameters:
|
||||
client_id: The client ID
|
||||
redirect_uri: The redirect URI
|
||||
response_type: Must be "code" for authorization code flow
|
||||
scope: Space-separated scopes (e.g., "openid profile email")
|
||||
state: Opaque state value for CSRF protection
|
||||
nonce: Nonce for ID token replay protection
|
||||
code_challenge: PKCE code challenge
|
||||
code_challenge_method: PKCE method ("S256" or "plain")
|
||||
prompt: "login", "consent", "select_account", "none"
|
||||
max_age: Maximum authentication age in seconds
|
||||
acr_values: Requested Authentication Context Class Reference
|
||||
|
||||
POST Parameters:
|
||||
Same as GET, plus:
|
||||
email: User email
|
||||
password: User password
|
||||
|
||||
Returns:
|
||||
302: Redirect with authorization code or error
|
||||
200: Login page (GET when not authenticated)
|
||||
400: Invalid request
|
||||
"""
|
||||
# Parse request parameters
|
||||
if request.method == "GET":
|
||||
params = request.args.to_dict()
|
||||
else:
|
||||
params = request.form.to_dict()
|
||||
|
||||
# Extract required parameters
|
||||
client_id = params.get("client_id")
|
||||
redirect_uri = params.get("redirect_uri")
|
||||
response_type = params.get("response_type")
|
||||
scope = params.get("scope", "")
|
||||
state = params.get("state", "")
|
||||
nonce = params.get("nonce", "")
|
||||
code_challenge = params.get("code_challenge")
|
||||
code_challenge_method = params.get("code_challenge_method")
|
||||
|
||||
# Validate required parameters
|
||||
errors = []
|
||||
if not client_id:
|
||||
errors.append("client_id is required")
|
||||
if not redirect_uri:
|
||||
errors.append("redirect_uri is required")
|
||||
if not response_type:
|
||||
errors.append("response_type is required")
|
||||
|
||||
if errors:
|
||||
return _redirect_with_error(redirect_uri, "invalid_request", "; ".join(errors), state)
|
||||
|
||||
# Validate response_type
|
||||
if response_type != "code":
|
||||
return _redirect_with_error(
|
||||
redirect_uri, "unsupported_response_type",
|
||||
"Only response_type=code is supported", state
|
||||
)
|
||||
|
||||
# Validate client
|
||||
client = OIDCClient.query.filter_by(client_id=client_id, is_active=True).first()
|
||||
if not client:
|
||||
return _redirect_with_error(redirect_uri, "unauthorized_client", "Invalid client", state)
|
||||
|
||||
# Validate redirect URI
|
||||
if not client.is_redirect_uri_allowed(redirect_uri):
|
||||
return _redirect_with_error(redirect_uri, "invalid_request", "Invalid redirect_uri", state)
|
||||
|
||||
# Validate scopes
|
||||
requested_scopes = scope.split() if scope else []
|
||||
allowed_scopes = client.scopes or []
|
||||
valid_scopes = [s for s in requested_scopes if s in allowed_scopes]
|
||||
|
||||
if not valid_scopes:
|
||||
return _redirect_with_error(redirect_uri, "invalid_scope", "Invalid or no scopes requested", state)
|
||||
|
||||
# Check if user is already authenticated via session
|
||||
user_id = session.get("oidc_user_id")
|
||||
|
||||
# Handle POST with credentials
|
||||
if request.method == "POST" and not user_id:
|
||||
email = params.get("email")
|
||||
password = params.get("password")
|
||||
|
||||
if not email or not password:
|
||||
return _show_login_page(
|
||||
client_id=client_id,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
response_type=response_type,
|
||||
error="Invalid credentials"
|
||||
)
|
||||
|
||||
try:
|
||||
user = AuthService.authenticate(email, password)
|
||||
user_id = user.id
|
||||
session["oidc_user_id"] = user_id
|
||||
except InvalidCredentialsError:
|
||||
return _show_login_page(
|
||||
client_id=client_id,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
response_type=response_type,
|
||||
error="Invalid email or password"
|
||||
)
|
||||
|
||||
# If no user, show login page
|
||||
if not user_id:
|
||||
return _show_login_page(
|
||||
client_id=client_id,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
response_type=response_type
|
||||
)
|
||||
|
||||
# User is authenticated, generate authorization code
|
||||
user = User.query.get(user_id)
|
||||
if not user:
|
||||
return _redirect_with_error(redirect_uri, "server_error", "User not found", state)
|
||||
|
||||
try:
|
||||
code = OIDCService.generate_authorization_code(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=valid_scopes,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
ip_address=request.remote_addr,
|
||||
user_agent=request.headers.get("User-Agent"),
|
||||
)
|
||||
except Exception as e:
|
||||
return _redirect_with_error(redirect_uri, "server_error", str(e), state)
|
||||
|
||||
# Redirect with authorization code
|
||||
redirect_params = {"code": code}
|
||||
if state:
|
||||
redirect_params["state"] = state
|
||||
|
||||
return redirect(f"{redirect_uri}?{urlencode(redirect_params)}")
|
||||
|
||||
|
||||
def _redirect_with_error(redirect_uri, error, error_description, state=None):
|
||||
"""Redirect to client with error parameters."""
|
||||
if not redirect_uri:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=error_description,
|
||||
status=400,
|
||||
error_type=error.upper(),
|
||||
error_details={"error": error, "error_description": error_description},
|
||||
)
|
||||
|
||||
params = {
|
||||
"error": error,
|
||||
"error_description": error_description,
|
||||
}
|
||||
if state:
|
||||
params["state"] = state
|
||||
|
||||
return redirect(f"{redirect_uri}?{urlencode(params)}")
|
||||
|
||||
|
||||
def _show_login_page(client_id, redirect_uri, scope, state, nonce, response_type, error=None):
|
||||
"""Show the login page for authorization."""
|
||||
# Simple HTML login page
|
||||
html = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Sign In - OIDC Authorization</title>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; margin: 40px; background: #f5f5f5; }}
|
||||
.container {{ max-width: 400px; margin: 0 auto; background: white; padding: 30px; border-radius: 8px; box-shadow: 0 2px 10px rgba(0,0,0,0.1); }}
|
||||
h1 {{ color: #333; font-size: 24px; margin-bottom: 20px; }}
|
||||
.form-group {{ margin-bottom: 15px; }}
|
||||
label {{ display: block; margin-bottom: 5px; color: #555; font-weight: bold; }}
|
||||
input[type="email"], input[type="password"] {{ width: 100%; padding: 10px; border: 1px solid #ddd; border-radius: 4px; box-sizing: border-box; }}
|
||||
button {{ width: 100%; padding: 12px; background: #007bff; color: white; border: none; border-radius: 4px; cursor: pointer; font-size: 16px; }}
|
||||
button:hover {{ background: #0056b3; }}
|
||||
.error {{ color: #dc3545; margin-bottom: 15px; }}
|
||||
.cancel {{ text-align: center; margin-top: 15px; }}
|
||||
.cancel a {{ color: #666; text-decoration: none; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>Sign In</h1>
|
||||
{"<p class='error'>" + error + "</p>" if error else ""}
|
||||
<form method="POST">
|
||||
<input type="hidden" name="client_id" value="{client_id}">
|
||||
<input type="hidden" name="redirect_uri" value="{redirect_uri}">
|
||||
<input type="hidden" name="scope" value="{scope}">
|
||||
<input type="hidden" name="state" value="{state}">
|
||||
<input type="hidden" name="nonce" value="{nonce}">
|
||||
<input type="hidden" name="response_type" value="{response_type}">
|
||||
|
||||
<div class="form-group">
|
||||
<label for="email">Email</label>
|
||||
<input type="email" id="email" name="email" required>
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label for="password">Password</label>
|
||||
<input type="password" id="password" name="password" required>
|
||||
</div>
|
||||
|
||||
<button type="submit">Sign In</button>
|
||||
</form>
|
||||
<p class="cancel">
|
||||
<a href="{redirect_uri}">Cancel</a>
|
||||
</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
return Response(html, mimetype="text/html"), 200
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Endpoint
|
||||
# ============================================================================
|
||||
|
||||
@oidc_bp.route("/oidc/token", methods=["POST"])
|
||||
def oidc_token():
|
||||
"""OpenID Connect Token endpoint.
|
||||
|
||||
Exchanges authorization code for tokens or refreshes tokens.
|
||||
|
||||
Request body (application/x-www-form-urlencoded):
|
||||
grant_type: "authorization_code" or "refresh_token"
|
||||
|
||||
For authorization_code:
|
||||
code: The authorization code
|
||||
redirect_uri: The redirect URI used in authorization
|
||||
client_id: The client ID
|
||||
client_secret: The client secret (optional if using Basic auth)
|
||||
code_verifier: PKCE code verifier (optional)
|
||||
|
||||
For refresh_token:
|
||||
refresh_token: The refresh token
|
||||
scope: Optional scope override
|
||||
client_id: The client ID
|
||||
client_secret: The client secret (optional if using Basic auth)
|
||||
|
||||
Authentication:
|
||||
- Basic auth with client_id:client_secret, or
|
||||
- client_id + client_secret in request body
|
||||
|
||||
Returns:
|
||||
200: JSON with tokens
|
||||
400: Invalid request
|
||||
401: Invalid client
|
||||
"""
|
||||
# Parse request body
|
||||
if request.content_type and "application/x-www-form-urlencoded" in request.content_type:
|
||||
data = request.form.to_dict()
|
||||
else:
|
||||
data = request.json or {}
|
||||
|
||||
grant_type = data.get("grant_type")
|
||||
|
||||
# Validate grant_type
|
||||
if not grant_type:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="grant_type is required",
|
||||
status=400,
|
||||
error_type="INVALID_REQUEST",
|
||||
error_details={"error": "invalid_request", "error_description": "grant_type is required"},
|
||||
)
|
||||
|
||||
# Authenticate client
|
||||
client_id = data.get("client_id")
|
||||
client_secret = data.get("client_secret")
|
||||
|
||||
# Try Basic auth if client_id not in body
|
||||
if not client_id:
|
||||
client_id, client_secret = parse_basic_auth()
|
||||
|
||||
if not client_id:
|
||||
# Return 401 with WWW-Authenticate header for Basic auth
|
||||
response = jsonify({
|
||||
"error": "invalid_client",
|
||||
"error_description": "Client authentication required"
|
||||
})
|
||||
response.headers["WWW-Authenticate"] = 'Basic realm="OIDC Token Endpoint"'
|
||||
return response, 401
|
||||
|
||||
try:
|
||||
client = authenticate_client(client_id, client_secret)
|
||||
except InvalidClientError:
|
||||
response = jsonify({
|
||||
"error": "invalid_client",
|
||||
"error_description": "Invalid client credentials"
|
||||
})
|
||||
return response, 401
|
||||
|
||||
# Handle authorization_code grant
|
||||
if grant_type == "authorization_code":
|
||||
return _handle_authorization_code_grant(data, client)
|
||||
|
||||
# Handle refresh_token grant
|
||||
elif grant_type == "refresh_token":
|
||||
return _handle_refresh_token_grant(data, client)
|
||||
|
||||
# Unsupported grant type
|
||||
else:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Unsupported grant_type",
|
||||
status=400,
|
||||
error_type="UNSUPPORTED_GRANT_TYPE",
|
||||
error_details={"error": "unsupported_grant_type", "error_description": f"Grant type '{grant_type}' is not supported"},
|
||||
)
|
||||
|
||||
|
||||
def _handle_authorization_code_grant(data, client):
|
||||
"""Handle authorization_code grant type."""
|
||||
code = data.get("code")
|
||||
redirect_uri = data.get("redirect_uri")
|
||||
code_verifier = data.get("code_verifier")
|
||||
|
||||
if not code:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="code is required",
|
||||
status=400,
|
||||
error_type="INVALID_REQUEST",
|
||||
error_details={"error": "invalid_request", "error_description": "code is required"},
|
||||
)
|
||||
|
||||
if not redirect_uri:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="redirect_uri is required",
|
||||
status=400,
|
||||
error_type="INVALID_REQUEST",
|
||||
error_details={"error": "invalid_request", "error_description": "redirect_uri is required"},
|
||||
)
|
||||
|
||||
try:
|
||||
claims, user = OIDCService.validate_authorization_code(
|
||||
code=code,
|
||||
client_id=client.client_id,
|
||||
redirect_uri=redirect_uri,
|
||||
code_verifier=code_verifier,
|
||||
ip_address=request.remote_addr,
|
||||
user_agent=request.headers.get("User-Agent"),
|
||||
)
|
||||
except InvalidGrantError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=str(e),
|
||||
status=400,
|
||||
error_type="INVALID_GRANT",
|
||||
error_details={"error": "invalid_grant", "error_description": str(e)},
|
||||
)
|
||||
|
||||
# Generate tokens
|
||||
try:
|
||||
tokens = OIDCService.generate_tokens(
|
||||
client_id=client.client_id,
|
||||
user_id=claims["user_id"],
|
||||
scope=claims["scope"],
|
||||
nonce=claims.get("nonce"),
|
||||
ip_address=request.remote_addr,
|
||||
user_agent=request.headers.get("User-Agent"),
|
||||
auth_time=int(__import__("time").time()),
|
||||
)
|
||||
except Exception as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Failed to generate tokens",
|
||||
status=500,
|
||||
error_type="SERVER_ERROR",
|
||||
error_details={"error": "server_error", "error_description": str(e)},
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data=tokens,
|
||||
message="Tokens issued successfully",
|
||||
status=200,
|
||||
)
|
||||
|
||||
|
||||
def _handle_refresh_token_grant(data, client):
|
||||
"""Handle refresh_token grant type."""
|
||||
refresh_token = data.get("refresh_token")
|
||||
scope = data.get("scope")
|
||||
|
||||
if not refresh_token:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="refresh_token is required",
|
||||
status=400,
|
||||
error_type="INVALID_REQUEST",
|
||||
error_details={"error": "invalid_request", "error_description": "refresh_token is required"},
|
||||
)
|
||||
|
||||
# Parse scope if provided
|
||||
scope_list = scope.split() if scope else None
|
||||
|
||||
try:
|
||||
tokens = OIDCService.refresh_access_token(
|
||||
refresh_token=refresh_token,
|
||||
client_id=client.client_id,
|
||||
scope=scope_list,
|
||||
ip_address=request.remote_addr,
|
||||
user_agent=request.headers.get("User-Agent"),
|
||||
)
|
||||
except InvalidGrantError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=str(e),
|
||||
status=400,
|
||||
error_type="INVALID_GRANT",
|
||||
error_details={"error": "invalid_grant", "error_description": str(e)},
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data=tokens,
|
||||
message="Tokens refreshed successfully",
|
||||
status=200,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# UserInfo Endpoint
|
||||
# ============================================================================
|
||||
|
||||
@oidc_bp.route("/oidc/userinfo", methods=["GET", "POST"])
|
||||
def oidc_userinfo():
|
||||
"""OpenID Connect UserInfo endpoint.
|
||||
|
||||
Returns claims about the authenticated user.
|
||||
|
||||
Authorization: Bearer {access_token}
|
||||
|
||||
Returns claims based on granted scopes:
|
||||
- sub: User ID (always included)
|
||||
- name: User full name (if "profile" scope)
|
||||
- email: User email (if "email" scope)
|
||||
- email_verified: Email verification status (if "email" scope)
|
||||
|
||||
Returns:
|
||||
200: User claims
|
||||
401: Invalid or insufficient token
|
||||
"""
|
||||
try:
|
||||
require_valid_token()
|
||||
except InvalidGrantError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=str(e),
|
||||
status=401,
|
||||
error_type="INVALID_TOKEN",
|
||||
error_details={"error": "invalid_token", "error_description": str(e)},
|
||||
)
|
||||
|
||||
# Get userinfo
|
||||
try:
|
||||
userinfo = OIDCService.get_userinfo(g.current_token.get("access_token", ""))
|
||||
except Exception as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Failed to get user info",
|
||||
status=500,
|
||||
error_type="SERVER_ERROR",
|
||||
error_details={"error": "server_error", "error_description": str(e)},
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data=userinfo,
|
||||
message="User info retrieved successfully",
|
||||
status=200,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# JWKS Endpoint
|
||||
# ============================================================================
|
||||
|
||||
@oidc_bp.route("/oidc/jwks", methods=["GET"])
|
||||
def oidc_jwks():
|
||||
"""OpenID Connect JSON Web Key Set endpoint.
|
||||
|
||||
Returns the public keys used to sign tokens.
|
||||
|
||||
Cache-Control: max-age=3600
|
||||
No authentication required.
|
||||
|
||||
Returns:
|
||||
200: JWKS document
|
||||
"""
|
||||
try:
|
||||
jwks = OIDCService.get_jwks()
|
||||
except Exception as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Failed to get JWKS",
|
||||
status=500,
|
||||
error_type="SERVER_ERROR",
|
||||
error_details={"error": "server_error", "error_description": str(e)},
|
||||
)
|
||||
|
||||
response = jsonify(jwks)
|
||||
response.headers["Cache-Control"] = "max-age=3600"
|
||||
return response, 200
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Revocation Endpoint
|
||||
# ============================================================================
|
||||
|
||||
@oidc_bp.route("/oidc/revoke", methods=["POST"])
|
||||
def oidc_revoke():
|
||||
"""OAuth2 Token Revocation endpoint.
|
||||
|
||||
Revokes an access token or refresh token.
|
||||
|
||||
Request body (application/x-www-form-urlencoded):
|
||||
token: The token to revoke
|
||||
token_type_hint: Optional hint ("access_token" or "refresh_token")
|
||||
client_id: The client ID
|
||||
client_secret: The client secret (optional if using Basic auth)
|
||||
|
||||
Authentication:
|
||||
- Basic auth with client_id:client_secret, or
|
||||
- client_id + client_secret in request body
|
||||
|
||||
Returns:
|
||||
200: Token revoked successfully
|
||||
400: Invalid request
|
||||
401: Invalid client
|
||||
"""
|
||||
# Parse request body
|
||||
if request.content_type and "application/x-www-form-urlencoded" in request.content_type:
|
||||
data = request.form.to_dict()
|
||||
else:
|
||||
data = request.json or {}
|
||||
|
||||
token = data.get("token")
|
||||
|
||||
if not token:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="token is required",
|
||||
status=400,
|
||||
error_type="INVALID_REQUEST",
|
||||
error_details={"error": "invalid_request", "error_description": "token is required"},
|
||||
)
|
||||
|
||||
# Authenticate client
|
||||
client_id = data.get("client_id")
|
||||
client_secret = data.get("client_secret")
|
||||
|
||||
if not client_id:
|
||||
client_id, client_secret = parse_basic_auth()
|
||||
|
||||
if not client_id:
|
||||
response = jsonify({
|
||||
"error": "invalid_client",
|
||||
"error_description": "Client authentication required"
|
||||
})
|
||||
response.headers["WWW-Authenticate"] = 'Basic realm="OIDC Revoke Endpoint"'
|
||||
return response, 401
|
||||
|
||||
try:
|
||||
client = authenticate_client(client_id, client_secret)
|
||||
except InvalidClientError:
|
||||
response = jsonify({
|
||||
"error": "invalid_client",
|
||||
"error_description": "Invalid client credentials"
|
||||
})
|
||||
return response, 401
|
||||
|
||||
token_type_hint = data.get("token_type_hint")
|
||||
|
||||
try:
|
||||
OIDCService.revoke_token(
|
||||
token=token,
|
||||
client_id=client.client_id,
|
||||
token_type_hint=token_type_hint,
|
||||
ip_address=request.remote_addr,
|
||||
user_agent=request.headers.get("User-Agent"),
|
||||
)
|
||||
except Exception as e:
|
||||
# Revocation should succeed even if token is invalid
|
||||
pass
|
||||
|
||||
return api_response(
|
||||
message="Token revoked successfully",
|
||||
status=200,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Introspection Endpoint
|
||||
# ============================================================================
|
||||
|
||||
@oidc_bp.route("/oidc/introspect", methods=["POST"])
|
||||
def oidc_introspect():
|
||||
"""OAuth2 Token Introspection endpoint.
|
||||
|
||||
Returns information about a token.
|
||||
|
||||
Request body (application/x-www-form-urlencoded):
|
||||
token: The token to introspect
|
||||
token_type_hint: Optional hint ("access_token" or "refresh_token")
|
||||
client_id: The client ID
|
||||
client_secret: The client secret (optional if using Basic auth)
|
||||
|
||||
Authentication:
|
||||
- Basic auth with client_id:client_secret, or
|
||||
- client_id + client_secret in request body
|
||||
|
||||
Returns:
|
||||
200: Token status and claims
|
||||
400: Invalid request
|
||||
401: Invalid client
|
||||
"""
|
||||
# Parse request body
|
||||
if request.content_type and "application/x-www-form-urlencoded" in request.content_type:
|
||||
data = request.form.to_dict()
|
||||
else:
|
||||
data = request.json or {}
|
||||
|
||||
token = data.get("token")
|
||||
|
||||
if not token:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="token is required",
|
||||
status=400,
|
||||
error_type="INVALID_REQUEST",
|
||||
error_details={"error": "invalid_request", "error_description": "token is required"},
|
||||
)
|
||||
|
||||
# Authenticate client
|
||||
client_id = data.get("client_id")
|
||||
client_secret = data.get("client_secret")
|
||||
|
||||
if not client_id:
|
||||
client_id, client_secret = parse_basic_auth()
|
||||
|
||||
if not client_id:
|
||||
response = jsonify({
|
||||
"error": "invalid_client",
|
||||
"error_description": "Client authentication required"
|
||||
})
|
||||
response.headers["WWW-Authenticate"] = 'Basic realm="OIDC Introspect Endpoint"'
|
||||
return response, 401
|
||||
|
||||
try:
|
||||
client = authenticate_client(client_id, client_secret)
|
||||
except InvalidClientError:
|
||||
response = jsonify({
|
||||
"error": "invalid_client",
|
||||
"error_description": "Invalid client credentials"
|
||||
})
|
||||
return response, 401
|
||||
|
||||
token_type_hint = data.get("token_type_hint")
|
||||
|
||||
try:
|
||||
result = OIDCService.introspect_token(
|
||||
token=token,
|
||||
client_id=client.client_id,
|
||||
ip_address=request.remote_addr,
|
||||
user_agent=request.headers.get("User-Agent"),
|
||||
)
|
||||
except Exception as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Failed to introspect token",
|
||||
status=500,
|
||||
error_type="SERVER_ERROR",
|
||||
error_details={"error": "server_error", "error_description": str(e)},
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data=result,
|
||||
message="Token introspection successful",
|
||||
status=200,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Client Registration Endpoint (Optional)
|
||||
# ============================================================================
|
||||
|
||||
@oidc_bp.route("/oidc/register", methods=["POST"])
|
||||
def oidc_register():
|
||||
"""OpenID Connect Client Registration endpoint.
|
||||
|
||||
Registers a new OIDC client.
|
||||
|
||||
Request body (application/json):
|
||||
client_name: Name of the client
|
||||
redirect_uris: List of redirect URIs
|
||||
token_endpoint_auth_method: "client_secret_basic" or "client_secret_post"
|
||||
grant_types: List of grant types ["authorization_code", "refresh_token"]
|
||||
response_types: List of response types ["code"]
|
||||
scope: Space-separated scopes (default: "openid profile email")
|
||||
|
||||
Returns:
|
||||
201: Client registered successfully
|
||||
400: Invalid request
|
||||
"""
|
||||
data = request.json or {}
|
||||
|
||||
# Validate required fields
|
||||
client_name = data.get("client_name")
|
||||
redirect_uris = data.get("redirect_uris", [])
|
||||
|
||||
if not client_name:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="client_name is required",
|
||||
status=400,
|
||||
error_type="INVALID_REQUEST",
|
||||
error_details={"error": "invalid_request", "error_description": "client_name is required"},
|
||||
)
|
||||
|
||||
if not redirect_uris:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="redirect_uris is required",
|
||||
status=400,
|
||||
error_type="INVALID_REQUEST",
|
||||
error_details={"error": "invalid_request", "error_description": "redirect_uris is required"},
|
||||
)
|
||||
|
||||
# Validate redirect_uris
|
||||
for uri in redirect_uris:
|
||||
try:
|
||||
parsed = urlparse(uri)
|
||||
if not parsed.scheme or not parsed.netloc:
|
||||
raise ValueError(f"Invalid redirect URI: {uri}")
|
||||
except Exception:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=f"Invalid redirect_uri: {uri}",
|
||||
status=400,
|
||||
error_type="INVALID_REQUEST",
|
||||
error_details={"error": "invalid_request", "error_description": f"Invalid redirect_uri: {uri}"},
|
||||
)
|
||||
|
||||
# Generate client credentials
|
||||
client_id = f"oidc_{secrets.token_urlsafe(16)}"
|
||||
client_secret = f"secret_{secrets.token_urlsafe(24)}"
|
||||
client_secret_hash = bcrypt.generate_password_hash(client_secret).decode("utf-8")
|
||||
|
||||
# Get organization from request or default
|
||||
org_id = data.get("organization_id")
|
||||
if org_id:
|
||||
organization = Organization.query.get(org_id)
|
||||
else:
|
||||
# Get first active organization or create a default one
|
||||
organization = Organization.query.filter_by(is_active=True).first()
|
||||
|
||||
if not organization:
|
||||
# Create a default organization for the client
|
||||
organization = Organization(
|
||||
name=f"OIDC Clients",
|
||||
slug=f"oidc-clients-{secrets.token_urlsafe(8)}",
|
||||
)
|
||||
organization.save()
|
||||
|
||||
# Create OIDC client
|
||||
client = OIDCClient(
|
||||
organization_id=organization.id,
|
||||
name=client_name,
|
||||
client_id=client_id,
|
||||
client_secret_hash=client_secret_hash,
|
||||
redirect_uris=redirect_uris,
|
||||
grant_types=data.get("grant_types", ["authorization_code", "refresh_token"]),
|
||||
response_types=data.get("response_types", ["code"]),
|
||||
scopes=data.get("scope", "openid profile email").split(),
|
||||
token_endpoint_auth_method=data.get("token_endpoint_auth_method", "client_secret_basic"),
|
||||
is_active=True,
|
||||
is_confidential=True,
|
||||
require_pkce=True,
|
||||
logo_uri=data.get("logo_uri"),
|
||||
client_uri=data.get("client_uri"),
|
||||
policy_uri=data.get("policy_uri"),
|
||||
tos_uri=data.get("tos_uri"),
|
||||
)
|
||||
client.save()
|
||||
|
||||
# Return client credentials
|
||||
return api_response(
|
||||
data={
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"client_id_issued_at": int(__import__("time").time()),
|
||||
"client_secret_expires_at": 0, # Never expires
|
||||
"client_name": client_name,
|
||||
"redirect_uris": redirect_uris,
|
||||
"token_endpoint_auth_method": data.get("token_endpoint_auth_method", "client_secret_basic"),
|
||||
"grant_types": client.grant_types,
|
||||
"response_types": client.response_types,
|
||||
"scope": " ".join(client.scopes),
|
||||
},
|
||||
message="Client registered successfully",
|
||||
status=201,
|
||||
)
|
||||
+1
-7
@@ -12,13 +12,7 @@ from flask_session import Session
|
||||
db = SQLAlchemy()
|
||||
migrate = Migrate()
|
||||
bcrypt = Bcrypt()
|
||||
cors = CORS(
|
||||
supports_credentials=True,
|
||||
resources={r"/api/*": {"origins": "*"}}, # Apply CORS to all API routes
|
||||
allow_headers=["Content-Type", "Authorization", "X-Request-ID"],
|
||||
methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
|
||||
expose_headers=["X-Request-ID"],
|
||||
)
|
||||
cors = CORS()
|
||||
ma = Marshmallow()
|
||||
limiter = Limiter(
|
||||
key_func=get_remote_address,
|
||||
|
||||
+41
-8
@@ -1,17 +1,40 @@
|
||||
"""CORS middleware configuration."""
|
||||
from flask import request
|
||||
from flask import request, make_response
|
||||
|
||||
|
||||
def setup_cors(app, cors):
|
||||
def setup_cors(app):
|
||||
"""
|
||||
Configure CORS for the application.
|
||||
|
||||
Args:
|
||||
app: Flask application instance
|
||||
cors: Flask-CORS instance
|
||||
"""
|
||||
# CORS is already initialized in extensions.py
|
||||
# This function provides additional configuration if needed
|
||||
|
||||
@app.before_request
|
||||
def handle_preflight():
|
||||
"""Handle CORS preflight OPTIONS requests."""
|
||||
if request.method == "OPTIONS":
|
||||
origin = request.headers.get("Origin")
|
||||
cors_origins = app.config.get("CORS_ORIGINS", [])
|
||||
|
||||
# Allow all origins if CORS_ORIGINS is "*" (string) or ["*"] (list with wildcard)
|
||||
allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
|
||||
|
||||
if allow_all:
|
||||
response = make_response("", 204)
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID"
|
||||
response.headers["Access-Control-Max-Age"] = "3600"
|
||||
return response
|
||||
elif origin and origin in cors_origins:
|
||||
response = make_response("", 204)
|
||||
response.headers["Access-Control-Allow-Origin"] = origin
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID"
|
||||
response.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
response.headers["Access-Control-Max-Age"] = "3600"
|
||||
return response
|
||||
|
||||
@app.after_request
|
||||
def after_request_cors(response):
|
||||
@@ -19,11 +42,21 @@ def setup_cors(app, cors):
|
||||
origin = request.headers.get("Origin")
|
||||
cors_origins = app.config.get("CORS_ORIGINS", [])
|
||||
|
||||
# Allow all origins in development if CORS_ORIGINS is "*"
|
||||
if cors_origins == "*" or origin in cors_origins:
|
||||
response.headers["Access-Control-Allow-Origin"] = origin if cors_origins != "*" else "*"
|
||||
# Allow all origins if CORS_ORIGINS is "*" (string) or ["*"] (list with wildcard)
|
||||
allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
|
||||
|
||||
if allow_all:
|
||||
# When allowing all origins, set header to "*"
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID"
|
||||
response.headers["Access-Control-Max-Age"] = "3600"
|
||||
elif origin and origin in cors_origins:
|
||||
# When allowing specific origins, echo the request origin
|
||||
response.headers["Access-Control-Allow-Origin"] = origin
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID"
|
||||
response.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
response.headers["Access-Control-Max-Age"] = "3600"
|
||||
|
||||
return response
|
||||
|
||||
@@ -7,6 +7,11 @@ from app.models.authentication_method import AuthenticationMethod
|
||||
from app.models.session import Session
|
||||
from app.models.audit_log import AuditLog
|
||||
from app.models.oidc_client import OIDCClient
|
||||
from app.models.oidc_authorization_code import OIDCAuthCode
|
||||
from app.models.oidc_refresh_token import OIDCRefreshToken
|
||||
from app.models.oidc_session import OIDCSession
|
||||
from app.models.oidc_token_metadata import OIDCTokenMetadata
|
||||
from app.models.oidc_audit_log import OIDCAuditLog
|
||||
|
||||
__all__ = [
|
||||
"BaseModel",
|
||||
@@ -17,4 +22,9 @@ __all__ = [
|
||||
"Session",
|
||||
"AuditLog",
|
||||
"OIDCClient",
|
||||
"OIDCAuthCode",
|
||||
"OIDCRefreshToken",
|
||||
"OIDCSession",
|
||||
"OIDCTokenMetadata",
|
||||
"OIDCAuditLog",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,231 @@
|
||||
"""OIDC Audit Log model for comprehensive OIDC event tracking."""
|
||||
from datetime import datetime
|
||||
from app.extensions import db
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class OIDCAuditLog(BaseModel):
|
||||
"""OIDC Audit Log model for comprehensive OIDC event tracking.
|
||||
|
||||
This model logs all OIDC-related events for security, compliance,
|
||||
and debugging purposes.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_audit_logs"
|
||||
|
||||
# Event type categorization
|
||||
event_type = db.Column(db.String(100), nullable=False, index=True)
|
||||
|
||||
# Client and User references
|
||||
client_id = db.Column(
|
||||
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=True, index=True
|
||||
)
|
||||
user_id = db.Column(
|
||||
db.String(36), db.ForeignKey("users.id"), nullable=True, index=True
|
||||
)
|
||||
|
||||
# Event outcome
|
||||
success = db.Column(db.Boolean, default=True, nullable=False, index=True)
|
||||
|
||||
# Error details (for failed events)
|
||||
error_code = db.Column(db.String(100), nullable=True)
|
||||
error_description = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Request context
|
||||
ip_address = db.Column(db.String(45), nullable=True, index=True)
|
||||
user_agent = db.Column(db.Text, nullable=True)
|
||||
request_id = db.Column(db.String(36), nullable=True, index=True)
|
||||
|
||||
# Additional event metadata
|
||||
event_metadata = db.Column(db.JSON, nullable=True)
|
||||
|
||||
# Relationships
|
||||
client = db.relationship("OIDCClient", back_populates="audit_logs")
|
||||
user = db.relationship("User", back_populates="oidc_audit_logs")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCAuditLog."""
|
||||
status = "success" if self.success else "failed"
|
||||
return f"<OIDCAuditLog event={self.event_type} status={status} client={self.client_id}>"
|
||||
|
||||
@classmethod
|
||||
def log_event(cls, event_type, client_id=None, user_id=None, success=True,
|
||||
error_code=None, error_description=None, ip_address=None,
|
||||
user_agent=None, request_id=None, event_metadata=None):
|
||||
"""Log an OIDC event.
|
||||
|
||||
Args:
|
||||
event_type: Type of event (e.g., "authorization_request", "token_issue")
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID
|
||||
success: Whether the event was successful
|
||||
error_code: Error code if event failed
|
||||
error_description: Error description if event failed
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
request_id: Request ID for correlation
|
||||
event_metadata: Additional event metadata
|
||||
|
||||
Returns:
|
||||
OIDCAuditLog instance
|
||||
"""
|
||||
log = cls(
|
||||
event_type=event_type,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=success,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
event_metadata=event_metadata,
|
||||
)
|
||||
db.session.add(log)
|
||||
db.session.commit()
|
||||
return log
|
||||
|
||||
@classmethod
|
||||
def log_authorization_request(cls, client_id, user_id, redirect_uri, scope,
|
||||
ip_address=None, user_agent=None, request_id=None,
|
||||
success=True, error_code=None, error_description=None):
|
||||
"""Log an authorization request event."""
|
||||
return cls.log_event(
|
||||
event_type="authorization_request",
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=success,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
event_metadata={
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": scope,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_token_issue(cls, client_id, user_id, token_type,
|
||||
ip_address=None, user_agent=None, request_id=None):
|
||||
"""Log a token issuance event."""
|
||||
return cls.log_event(
|
||||
event_type="token_issue",
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=True,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
event_metadata={"token_type": token_type}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_token_revocation(cls, client_id, user_id, token_type, reason=None,
|
||||
ip_address=None, user_agent=None, request_id=None):
|
||||
"""Log a token revocation event."""
|
||||
return cls.log_event(
|
||||
event_type="token_revocation",
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=True,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
event_metadata={
|
||||
"token_type": token_type,
|
||||
"reason": reason,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_authentication_failure(cls, client_id, error_code, error_description,
|
||||
ip_address=None, user_agent=None, request_id=None):
|
||||
"""Log an authentication failure event."""
|
||||
return cls.log_event(
|
||||
event_type="authentication_failure",
|
||||
client_id=client_id,
|
||||
success=False,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_events_for_user(cls, user_id, limit=100):
|
||||
"""Get audit events for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of OIDCAuditLog instances
|
||||
"""
|
||||
return cls.query.filter_by(user_id=user_id, deleted_at=None)\
|
||||
.order_by(cls.created_at.desc())\
|
||||
.limit(limit)\
|
||||
.all()
|
||||
|
||||
@classmethod
|
||||
def get_events_for_client(cls, client_id, limit=100):
|
||||
"""Get audit events for a client.
|
||||
|
||||
Args:
|
||||
client_id: The client ID
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of OIDCAuditLog instances
|
||||
"""
|
||||
return cls.query.filter_by(client_id=client_id, deleted_at=None)\
|
||||
.order_by(cls.created_at.desc())\
|
||||
.limit(limit)\
|
||||
.all()
|
||||
|
||||
@classmethod
|
||||
def get_failed_events(cls, client_id=None, user_id=None, start_date=None,
|
||||
end_date=None, limit=100):
|
||||
"""Get failed audit events.
|
||||
|
||||
Args:
|
||||
client_id: Optional client ID filter
|
||||
user_id: Optional user ID filter
|
||||
start_date: Optional start date filter
|
||||
end_date: Optional end date filter
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of OIDCAuditLog instances
|
||||
"""
|
||||
query = cls.query.filter_by(success=False, deleted_at=None)
|
||||
if client_id:
|
||||
query = query.filter_by(client_id=client_id)
|
||||
if user_id:
|
||||
query = query.filter_by(user_id=user_id)
|
||||
if start_date:
|
||||
query = query.filter(cls.created_at >= start_date)
|
||||
if end_date:
|
||||
query = query.filter(cls.created_at <= end_date)
|
||||
|
||||
return query.order_by(cls.created_at.desc()).limit(limit).all()
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add relationship back to User model
|
||||
from app.models.user import User
|
||||
User.oidc_audit_logs = db.relationship(
|
||||
"OIDCAuditLog", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to OIDCClient model
|
||||
from app.models.oidc_client import OIDCClient
|
||||
OIDCClient.audit_logs = db.relationship(
|
||||
"OIDCAuditLog", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -0,0 +1,120 @@
|
||||
"""OIDC Authorization Code model for auth code flow."""
|
||||
from datetime import datetime, timedelta
|
||||
from app.extensions import db
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class OIDCAuthCode(BaseModel):
|
||||
"""OIDC Authorization Code model for authorization code flow.
|
||||
|
||||
Authorization codes are single-use, short-lived codes used in the
|
||||
authorization code grant flow. The code is hashed for security.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_authorization_codes"
|
||||
|
||||
# Client and User references
|
||||
client_id = db.Column(
|
||||
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
|
||||
)
|
||||
user_id = db.Column(
|
||||
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# Authorization code (hashed for security)
|
||||
code_hash = db.Column(db.String(255), nullable=False)
|
||||
|
||||
# Request parameters
|
||||
redirect_uri = db.Column(db.String(512), nullable=False)
|
||||
scope = db.Column(db.JSON, nullable=True) # Requested scopes
|
||||
nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation
|
||||
code_verifier = db.Column(db.String(255), nullable=True) # For PKCE
|
||||
|
||||
# Status tracking
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
used_at = db.Column(db.DateTime, nullable=True)
|
||||
is_used = db.Column(db.Boolean, default=False, nullable=False)
|
||||
|
||||
# Request metadata
|
||||
ip_address = db.Column(db.String(45), nullable=True)
|
||||
user_agent = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Relationships
|
||||
client = db.relationship("OIDCClient", back_populates="authorization_codes")
|
||||
user = db.relationship("User", back_populates="oidc_auth_codes")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCAuthCode."""
|
||||
return f"<OIDCAuthCode client_id={self.client_id} user_id={self.user_id} used={self.is_used}>"
|
||||
|
||||
def is_expired(self):
|
||||
"""Check if the authorization code has expired."""
|
||||
return datetime.utcnow() > self.expires_at
|
||||
|
||||
def is_valid(self):
|
||||
"""Check if the authorization code is valid for use."""
|
||||
return not self.is_used and not self.is_expired() and self.deleted_at is None
|
||||
|
||||
def mark_as_used(self):
|
||||
"""Mark the authorization code as used."""
|
||||
self.is_used = True
|
||||
self.used_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def create_code(cls, client_id, user_id, code_hash, redirect_uri, scope=None,
|
||||
nonce=None, code_verifier=None, ip_address=None, user_agent=None,
|
||||
lifetime_seconds=600):
|
||||
"""Create a new authorization code.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID
|
||||
code_hash: Hashed authorization code
|
||||
redirect_uri: The redirect URI
|
||||
scope: Requested scopes
|
||||
nonce: OIDC nonce
|
||||
code_verifier: PKCE code verifier
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
lifetime_seconds: Code lifetime in seconds (default 10 minutes)
|
||||
|
||||
Returns:
|
||||
OIDCAuthCode instance
|
||||
"""
|
||||
code = cls(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
code_hash=code_hash,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
nonce=nonce,
|
||||
code_verifier=code_verifier,
|
||||
expires_at=datetime.utcnow() + timedelta(seconds=lifetime_seconds),
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
db.session.add(code)
|
||||
db.session.commit()
|
||||
return code
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Always exclude code hash
|
||||
exclude.append("code_hash")
|
||||
exclude.append("code_verifier")
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add relationship back to User model
|
||||
from app.models.user import User
|
||||
User.oidc_auth_codes = db.relationship(
|
||||
"OIDCAuthCode", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to OIDCClient model
|
||||
from app.models.oidc_client import OIDCClient
|
||||
OIDCClient.authorization_codes = db.relationship(
|
||||
"OIDCAuthCode", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -0,0 +1,159 @@
|
||||
"""OIDC Refresh Token model for token rotation."""
|
||||
from datetime import datetime
|
||||
from app.extensions import db
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class OIDCRefreshToken(BaseModel):
|
||||
"""OIDC Refresh Token model for token refresh and rotation.
|
||||
|
||||
Refresh tokens are long-lived credentials used to obtain new access tokens.
|
||||
They support token rotation for enhanced security.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_refresh_tokens"
|
||||
|
||||
# Client and User references
|
||||
client_id = db.Column(
|
||||
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
|
||||
)
|
||||
user_id = db.Column(
|
||||
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# Token (hashed for security)
|
||||
token_hash = db.Column(db.String(255), nullable=False, unique=True, index=True)
|
||||
|
||||
# Associated access token ID
|
||||
access_token_id = db.Column(
|
||||
db.String(36), db.ForeignKey("sessions.id"), nullable=True, index=True
|
||||
)
|
||||
|
||||
# Token scope
|
||||
scope = db.Column(db.JSON, nullable=True) # Granted scopes
|
||||
|
||||
# Timing
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
|
||||
# Revocation tracking
|
||||
revoked_at = db.Column(db.DateTime, nullable=True)
|
||||
revoked_reason = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Token rotation metadata
|
||||
previous_token_hash = db.Column(db.String(255), nullable=True) # For rotation
|
||||
rotation_count = db.Column(db.Integer, default=0, nullable=False)
|
||||
|
||||
# Request metadata
|
||||
ip_address = db.Column(db.String(45), nullable=True)
|
||||
user_agent = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Relationships
|
||||
client = db.relationship("OIDCClient", back_populates="refresh_tokens")
|
||||
user = db.relationship("User", back_populates="oidc_refresh_tokens")
|
||||
access_token = db.relationship("Session", back_populates="oidc_refresh_token")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCRefreshToken."""
|
||||
return f"<OIDCRefreshToken client_id={self.client_id} user_id={self.user_id} revoked={self.is_revoked()}>"
|
||||
|
||||
def is_expired(self):
|
||||
"""Check if the refresh token has expired."""
|
||||
return datetime.utcnow() > self.expires_at
|
||||
|
||||
def is_revoked(self):
|
||||
"""Check if the refresh token has been revoked."""
|
||||
return self.revoked_at is not None
|
||||
|
||||
def is_valid(self):
|
||||
"""Check if the refresh token is valid for use."""
|
||||
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
|
||||
|
||||
def revoke(self, reason=None):
|
||||
"""Revoke the refresh token.
|
||||
|
||||
Args:
|
||||
reason: Optional reason for revocation
|
||||
"""
|
||||
self.revoked_at = datetime.utcnow()
|
||||
self.revoked_reason = reason
|
||||
db.session.commit()
|
||||
|
||||
def rotate(self, new_token_hash):
|
||||
"""Rotate the refresh token (invalidate old, create new).
|
||||
|
||||
Args:
|
||||
new_token_hash: Hash of the new refresh token
|
||||
|
||||
Returns:
|
||||
self for chaining
|
||||
"""
|
||||
# Store reference to old token
|
||||
self.previous_token_hash = self.token_hash
|
||||
self.token_hash = new_token_hash
|
||||
self.rotation_count += 1
|
||||
# Extend expiration on rotation
|
||||
from datetime import timedelta
|
||||
self.expires_at = datetime.utcnow() + timedelta(days=30)
|
||||
db.session.commit()
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def create_token(cls, client_id, user_id, token_hash, scope=None,
|
||||
access_token_id=None, ip_address=None, user_agent=None,
|
||||
lifetime_seconds=2592000):
|
||||
"""Create a new refresh token.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID
|
||||
token_hash: Hashed refresh token
|
||||
scope: Granted scopes
|
||||
access_token_id: Associated access token ID
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
lifetime_seconds: Token lifetime in seconds (default 30 days)
|
||||
|
||||
Returns:
|
||||
OIDCRefreshToken instance
|
||||
"""
|
||||
from datetime import timedelta
|
||||
token = cls(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
token_hash=token_hash,
|
||||
scope=scope,
|
||||
access_token_id=access_token_id,
|
||||
expires_at=datetime.utcnow() + timedelta(seconds=lifetime_seconds),
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
db.session.add(token)
|
||||
db.session.commit()
|
||||
return token
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Always exclude token hashes
|
||||
exclude.append("token_hash")
|
||||
exclude.append("previous_token_hash")
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add relationship back to User model
|
||||
from app.models.user import User
|
||||
User.oidc_refresh_tokens = db.relationship(
|
||||
"OIDCRefreshToken", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to OIDCClient model
|
||||
from app.models.oidc_client import OIDCClient
|
||||
OIDCClient.refresh_tokens = db.relationship(
|
||||
"OIDCRefreshToken", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to Session model
|
||||
from app.models.session import Session
|
||||
Session.oidc_refresh_token = db.relationship(
|
||||
"OIDCRefreshToken", back_populates="access_token", uselist=False
|
||||
)
|
||||
@@ -0,0 +1,162 @@
|
||||
"""OIDC Session model for OIDC session tracking."""
|
||||
from datetime import datetime
|
||||
from app.extensions import db
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class OIDCSession(BaseModel):
|
||||
"""OIDC Session model for tracking OIDC authentication sessions.
|
||||
|
||||
This model tracks the state during the OIDC authentication flow,
|
||||
including PKCE parameters and nonce validation.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_sessions"
|
||||
|
||||
# User reference
|
||||
user_id = db.Column(
|
||||
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# Client reference
|
||||
client_id = db.Column(
|
||||
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# State management
|
||||
state = db.Column(db.String(255), nullable=False, index=True)
|
||||
nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation
|
||||
|
||||
# Authorization request parameters
|
||||
redirect_uri = db.Column(db.String(512), nullable=False)
|
||||
scope = db.Column(db.JSON, nullable=True) # Requested scopes
|
||||
|
||||
# PKCE parameters
|
||||
code_challenge = db.Column(db.String(255), nullable=True)
|
||||
code_challenge_method = db.Column(db.String(10), nullable=True) # "S256" or "plain"
|
||||
|
||||
# Timing
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
authenticated_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Relationships
|
||||
user = db.relationship("User", back_populates="oidc_sessions")
|
||||
client = db.relationship("OIDCClient", back_populates="oidc_sessions")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCSession."""
|
||||
return f"<OIDCSession user_id={self.user_id} client_id={self.client_id} state={self.state[:8]}...>"
|
||||
|
||||
def is_expired(self):
|
||||
"""Check if the OIDC session has expired."""
|
||||
return datetime.utcnow() > self.expires_at
|
||||
|
||||
def is_authenticated(self):
|
||||
"""Check if the user has been authenticated in this session."""
|
||||
return self.authenticated_at is not None
|
||||
|
||||
def mark_authenticated(self):
|
||||
"""Mark the session as authenticated."""
|
||||
self.authenticated_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
def validate_nonce(self, expected_nonce):
|
||||
"""Validate the nonce matches the expected value.
|
||||
|
||||
Args:
|
||||
expected_nonce: The expected nonce value
|
||||
|
||||
Returns:
|
||||
bool: True if nonce matches
|
||||
"""
|
||||
return self.nonce == expected_nonce
|
||||
|
||||
def validate_code_challenge(self, code_verifier):
|
||||
"""Validate the code verifier against the stored code challenge.
|
||||
|
||||
Args:
|
||||
code_verifier: The PKCE code verifier
|
||||
|
||||
Returns:
|
||||
bool: True if code challenge is valid
|
||||
"""
|
||||
if not self.code_challenge:
|
||||
return False
|
||||
|
||||
if self.code_challenge_method == "S256":
|
||||
import hashlib
|
||||
import base64
|
||||
# SHA256 hash of code_verifier
|
||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||
# Base64 URL encode without padding
|
||||
expected = base64.urlsafe_b64encode(digest).decode().rstrip("=")
|
||||
return self.code_challenge == expected
|
||||
elif self.code_challenge_method == "plain":
|
||||
return self.code_challenge == code_verifier
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def create_session(cls, user_id, client_id, state, redirect_uri, scope=None,
|
||||
nonce=None, code_challenge=None, code_challenge_method=None,
|
||||
lifetime_seconds=600):
|
||||
"""Create a new OIDC session.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
client_id: The OIDC client ID
|
||||
state: The state parameter
|
||||
redirect_uri: The redirect URI
|
||||
scope: Requested scopes
|
||||
nonce: OIDC nonce
|
||||
code_challenge: PKCE code challenge
|
||||
code_challenge_method: PKCE method ("S256" or "plain")
|
||||
lifetime_seconds: Session lifetime in seconds
|
||||
|
||||
Returns:
|
||||
OIDCSession instance
|
||||
"""
|
||||
from datetime import timedelta
|
||||
session = cls(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
state=state,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
nonce=nonce,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
expires_at=datetime.utcnow() + timedelta(seconds=lifetime_seconds),
|
||||
)
|
||||
db.session.add(session)
|
||||
db.session.commit()
|
||||
return session
|
||||
|
||||
@classmethod
|
||||
def get_by_state(cls, state):
|
||||
"""Get a session by state parameter.
|
||||
|
||||
Args:
|
||||
state: The state parameter
|
||||
|
||||
Returns:
|
||||
OIDCSession instance or None
|
||||
"""
|
||||
return cls.query.filter_by(state=state, deleted_at=None).first()
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add relationship back to User model
|
||||
from app.models.user import User
|
||||
User.oidc_sessions = db.relationship(
|
||||
"OIDCSession", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to OIDCClient model
|
||||
from app.models.oidc_client import OIDCClient
|
||||
OIDCClient.oidc_sessions = db.relationship(
|
||||
"OIDCSession", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -0,0 +1,192 @@
|
||||
"""OIDC Token Metadata model for token revocation tracking."""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from app.extensions import db
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class OIDCTokenMetadata(BaseModel):
|
||||
"""OIDC Token Metadata model for tracking issued tokens.
|
||||
|
||||
This model stores metadata about issued tokens (access tokens, refresh tokens, ID tokens)
|
||||
for the purpose of token revocation. The id field matches the JTI (JWT ID) claim.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_token_metadata"
|
||||
|
||||
# Token identifier (matches JTI in JWT)
|
||||
id = db.Column(
|
||||
db.String(36), primary_key=True, default=lambda: str(uuid.uuid4())
|
||||
)
|
||||
|
||||
# Client and User references
|
||||
client_id = db.Column(
|
||||
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
|
||||
)
|
||||
user_id = db.Column(
|
||||
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# Token type
|
||||
token_type = db.Column(db.String(50), nullable=False) # "access_token", "refresh_token", "id_token"
|
||||
|
||||
# Token identifier for revocation lookup
|
||||
token_jti = db.Column(db.String(255), nullable=False, index=True) # JWT ID claim
|
||||
|
||||
# Timing
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
|
||||
# Revocation tracking
|
||||
revoked_at = db.Column(db.DateTime, nullable=True)
|
||||
revoked_reason = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Relationships
|
||||
client = db.relationship("OIDCClient", back_populates="token_metadata")
|
||||
user = db.relationship("User", back_populates="oidc_token_metadata")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCTokenMetadata."""
|
||||
return f"<OIDCTokenMetadata jti={self.token_jti[:8]}... type={self.token_type} revoked={self.is_revoked()}>"
|
||||
|
||||
def is_expired(self):
|
||||
"""Check if the token has expired."""
|
||||
return datetime.utcnow() > self.expires_at
|
||||
|
||||
def is_revoked(self):
|
||||
"""Check if the token has been revoked."""
|
||||
return self.revoked_at is not None
|
||||
|
||||
def is_valid(self):
|
||||
"""Check if the token is valid (not expired and not revoked)."""
|
||||
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
|
||||
|
||||
def revoke(self, reason=None):
|
||||
"""Revoke the token.
|
||||
|
||||
Args:
|
||||
reason: Optional reason for revocation
|
||||
"""
|
||||
self.revoked_at = datetime.utcnow()
|
||||
self.revoked_reason = reason
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def create_metadata(cls, client_id, user_id, token_type, token_jti,
|
||||
expires_at, ip_address=None, user_agent=None):
|
||||
"""Create token metadata for tracking.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID
|
||||
token_type: Type of token ("access_token", "refresh_token", "id_token")
|
||||
token_jti: JWT ID claim
|
||||
expires_at: Token expiration datetime
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
|
||||
Returns:
|
||||
OIDCTokenMetadata instance
|
||||
"""
|
||||
metadata = cls(
|
||||
id=str(uuid.uuid4()),
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
token_type=token_type,
|
||||
token_jti=token_jti,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.session.add(metadata)
|
||||
db.session.commit()
|
||||
return metadata
|
||||
|
||||
@classmethod
|
||||
def get_by_jti(cls, token_jti):
|
||||
"""Get token metadata by JWT ID.
|
||||
|
||||
Args:
|
||||
token_jti: The JWT ID
|
||||
|
||||
Returns:
|
||||
OIDCTokenMetadata instance or None
|
||||
"""
|
||||
return cls.query.filter_by(token_jti=token_jti, deleted_at=None).first()
|
||||
|
||||
@classmethod
|
||||
def revoke_by_jti(cls, token_jti, reason=None):
|
||||
"""Revoke a token by its JWT ID.
|
||||
|
||||
Args:
|
||||
token_jti: The JWT ID
|
||||
reason: Optional revocation reason
|
||||
|
||||
Returns:
|
||||
bool: True if token was found and revoked
|
||||
"""
|
||||
metadata = cls.get_by_jti(token_jti)
|
||||
if metadata:
|
||||
metadata.revoke(reason)
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def revoke_all_for_user(cls, user_id, client_id=None, reason=None):
|
||||
"""Revoke all tokens for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
client_id: Optional client ID to filter by
|
||||
reason: Optional revocation reason
|
||||
|
||||
Returns:
|
||||
int: Number of tokens revoked
|
||||
"""
|
||||
query = cls.query.filter_by(user_id=user_id, deleted_at=None)
|
||||
if client_id:
|
||||
query = query.filter_by(client_id=client_id)
|
||||
|
||||
tokens = query.filter(cls.revoked_at == None).all()
|
||||
count = 0
|
||||
for token in tokens:
|
||||
token.revoke(reason)
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@classmethod
|
||||
def revoke_all_for_client(cls, client_id, user_id=None, reason=None):
|
||||
"""Revoke all tokens for a client.
|
||||
|
||||
Args:
|
||||
client_id: The client ID
|
||||
user_id: Optional user ID to filter by
|
||||
reason: Optional revocation reason
|
||||
|
||||
Returns:
|
||||
int: Number of tokens revoked
|
||||
"""
|
||||
query = cls.query.filter_by(client_id=client_id, deleted_at=None)
|
||||
if user_id:
|
||||
query = query.filter_by(user_id=user_id)
|
||||
|
||||
tokens = query.filter(cls.revoked_at == None).all()
|
||||
count = 0
|
||||
for token in tokens:
|
||||
token.revoke(reason)
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add relationship back to User model
|
||||
from app.models.user import User
|
||||
User.oidc_token_metadata = db.relationship(
|
||||
"OIDCTokenMetadata", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to OIDCClient model
|
||||
from app.models.oidc_client import OIDCClient
|
||||
OIDCClient.token_metadata = db.relationship(
|
||||
"OIDCTokenMetadata", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -4,6 +4,11 @@ from app.services.user_service import UserService
|
||||
from app.services.organization_service import OrganizationService
|
||||
from app.services.session_service import SessionService
|
||||
from app.services.audit_service import AuditService
|
||||
from app.services.oidc_service import OIDCService, OIDCError
|
||||
from app.services.oidc_jwks_service import OIDCJWKSService
|
||||
from app.services.oidc_token_service import OIDCTokenService
|
||||
from app.services.oidc_session_service import OIDCSessionService
|
||||
from app.services.oidc_audit_service import OIDCAuditService
|
||||
|
||||
__all__ = [
|
||||
"AuthService",
|
||||
@@ -11,4 +16,10 @@ __all__ = [
|
||||
"OrganizationService",
|
||||
"SessionService",
|
||||
"AuditService",
|
||||
"OIDCService",
|
||||
"OIDCError",
|
||||
"OIDCJWKSService",
|
||||
"OIDCTokenService",
|
||||
"OIDCSessionService",
|
||||
"OIDCAuditService",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,408 @@
|
||||
"""OIDC Audit Service for comprehensive OIDC event logging."""
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from flask import g
|
||||
|
||||
from app.models import OIDCAuditLog, OIDCClient, User
|
||||
from app.exceptions.validation_exceptions import NotFoundError
|
||||
|
||||
|
||||
class OIDCAuditService:
|
||||
"""Service for OIDC-specific audit logging.
|
||||
|
||||
This service provides methods to log all OIDC-related events including:
|
||||
- Authorization requests and responses
|
||||
- Token issuance and refresh
|
||||
- Token revocation
|
||||
- UserInfo endpoint access
|
||||
- Authentication failures
|
||||
"""
|
||||
|
||||
# Event type constants
|
||||
EVENT_AUTHORIZATION_REQUEST = "authorization_request"
|
||||
EVENT_AUTHORIZATION_RESPONSE = "authorization_response"
|
||||
EVENT_TOKEN_ISSUE = "token_issue"
|
||||
EVENT_TOKEN_REFRESH = "token_refresh"
|
||||
EVENT_TOKEN_REVOCATION = "token_revocation"
|
||||
EVENT_TOKEN_INTROSPECTION = "token_introspection"
|
||||
EVENT_USERINFO_ACCESS = "userinfo_access"
|
||||
EVENT_AUTHENTICATION_FAILURE = "authentication_failure"
|
||||
EVENT_AUTHORIZATION_FAILURE = "authorization_failure"
|
||||
EVENT_JWKS_ACCESS = "jwks_access"
|
||||
EVENT_REGISTRATION = "client_registration"
|
||||
|
||||
@classmethod
|
||||
def _get_request_context(cls) -> Dict:
|
||||
"""Extract request context for logging.
|
||||
|
||||
Returns:
|
||||
Dictionary with IP, user_agent, and request_id
|
||||
"""
|
||||
from flask import request
|
||||
|
||||
return {
|
||||
"ip_address": request.remote_addr if request else None,
|
||||
"user_agent": request.headers.get("User-Agent") if request else None,
|
||||
"request_id": g.get("request_id"),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def log_event(
|
||||
cls,
|
||||
event_type: str,
|
||||
client_id: str = None,
|
||||
user_id: str = None,
|
||||
success: bool = True,
|
||||
error_code: str = None,
|
||||
error_description: str = None,
|
||||
metadata: Dict = None
|
||||
) -> OIDCAuditLog:
|
||||
"""Log a generic OIDC event.
|
||||
|
||||
Args:
|
||||
event_type: Type of event
|
||||
client_id: OIDC client ID
|
||||
user_id: User ID
|
||||
success: Whether the event was successful
|
||||
error_code: Error code if failed
|
||||
error_description: Error description if failed
|
||||
metadata: Additional event metadata
|
||||
|
||||
Returns:
|
||||
OIDCAuditLog instance
|
||||
"""
|
||||
context = cls._get_request_context()
|
||||
|
||||
log = OIDCAuditLog.log_event(
|
||||
event_type=event_type,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=success,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
ip_address=context["ip_address"],
|
||||
user_agent=context["user_agent"],
|
||||
request_id=context["request_id"],
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
return log
|
||||
|
||||
@classmethod
|
||||
def log_authorization_event(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str = None,
|
||||
success: bool = True,
|
||||
error_code: str = None,
|
||||
error_description: str = None,
|
||||
redirect_uri: str = None,
|
||||
scope: list = None,
|
||||
response_type: str = None
|
||||
) -> OIDCAuditLog:
|
||||
"""Log an authorization event.
|
||||
|
||||
Args:
|
||||
client_id: OIDC client ID
|
||||
user_id: User ID (if authenticated)
|
||||
success: Whether authorization was successful
|
||||
error_code: Error code if failed
|
||||
error_description: Error description if failed
|
||||
redirect_uri: Redirect URI from request
|
||||
scope: Requested scopes
|
||||
response_type: Response type (e.g., "code")
|
||||
|
||||
Returns:
|
||||
OIDCAuditLog instance
|
||||
"""
|
||||
metadata = {
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": scope,
|
||||
"response_type": response_type,
|
||||
}
|
||||
metadata = {k: v for k, v in metadata.items() if v is not None}
|
||||
|
||||
return cls.log_event(
|
||||
event_type=cls.EVENT_AUTHORIZATION_REQUEST,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=success,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_token_event(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str = None,
|
||||
token_type: str = "access_token",
|
||||
success: bool = True,
|
||||
error_code: str = None,
|
||||
error_description: str = None,
|
||||
grant_type: str = None,
|
||||
scopes: list = None
|
||||
) -> OIDCAuditLog:
|
||||
"""Log a token issuance or refresh event.
|
||||
|
||||
Args:
|
||||
client_id: OIDC client ID
|
||||
user_id: User ID
|
||||
token_type: Type of token issued
|
||||
success: Whether token issuance was successful
|
||||
error_code: Error code if failed
|
||||
error_description: Error description if failed
|
||||
grant_type: Grant type used (e.g., "authorization_code", "refresh_token")
|
||||
scopes: Scopes included in the token
|
||||
|
||||
Returns:
|
||||
OIDCAuditLog instance
|
||||
"""
|
||||
metadata = {
|
||||
"token_type": token_type,
|
||||
"grant_type": grant_type,
|
||||
"scopes": scopes,
|
||||
}
|
||||
metadata = {k: v for k, v in metadata.items() if v is not None}
|
||||
|
||||
return cls.log_event(
|
||||
event_type=cls.EVENT_TOKEN_ISSUE if token_type else cls.EVENT_TOKEN_REFRESH,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=success,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_userinfo_event(
|
||||
cls,
|
||||
access_token: str = None,
|
||||
user_id: str = None,
|
||||
client_id: str = None,
|
||||
success: bool = True,
|
||||
error_code: str = None,
|
||||
error_description: str = None,
|
||||
scopes_claimed: list = None
|
||||
) -> OIDCAuditLog:
|
||||
"""Log a UserInfo endpoint access event.
|
||||
|
||||
Args:
|
||||
access_token: Access token used (masked)
|
||||
user_id: User ID returned
|
||||
client_id: Client ID making the request
|
||||
success: Whether access was successful
|
||||
error_code: Error code if failed
|
||||
error_description: Error description if failed
|
||||
scopes_claimed: Scopes claimed in the request
|
||||
|
||||
Returns:
|
||||
OIDCAuditLog instance
|
||||
"""
|
||||
# Mask the access token for security
|
||||
masked_token = None
|
||||
if access_token:
|
||||
masked_token = access_token[:8] + "..." + access_token[-4:] if len(access_token) > 12 else "***"
|
||||
|
||||
metadata = {
|
||||
"token_prefix": masked_token,
|
||||
"scopes_claimed": scopes_claimed,
|
||||
}
|
||||
metadata = {k: v for k, v in metadata.items() if v is not None}
|
||||
|
||||
return cls.log_event(
|
||||
event_type=cls.EVENT_USERINFO_ACCESS,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=success,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_token_revocation_event(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str = None,
|
||||
token_type: str = "access_token",
|
||||
reason: str = None,
|
||||
success: bool = True,
|
||||
error_code: str = None,
|
||||
error_description: str = None
|
||||
) -> OIDCAuditLog:
|
||||
"""Log a token revocation event.
|
||||
|
||||
Args:
|
||||
client_id: OIDC client ID
|
||||
user_id: User ID
|
||||
token_type: Type of token being revoked
|
||||
reason: Revocation reason
|
||||
success: Whether revocation was successful
|
||||
error_code: Error code if failed
|
||||
error_description: Error description if failed
|
||||
|
||||
Returns:
|
||||
OIDCAuditLog instance
|
||||
"""
|
||||
metadata = {
|
||||
"token_type": token_type,
|
||||
"reason": reason,
|
||||
}
|
||||
metadata = {k: v for k, v in metadata.items() if v is not None}
|
||||
|
||||
return cls.log_event(
|
||||
event_type=cls.EVENT_TOKEN_REVOCATION,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=success,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_authentication_failure(
|
||||
cls,
|
||||
client_id: str = None,
|
||||
error_code: str = "authentication_failed",
|
||||
error_description: str = "Authentication failed",
|
||||
user_id: str = None
|
||||
) -> OIDCAuditLog:
|
||||
"""Log an authentication failure event.
|
||||
|
||||
Args:
|
||||
client_id: OIDC client ID
|
||||
error_code: Error code
|
||||
error_description: Error description
|
||||
user_id: User ID if known
|
||||
|
||||
Returns:
|
||||
OIDCAuditLog instance
|
||||
"""
|
||||
return cls.log_event(
|
||||
event_type=cls.EVENT_AUTHENTICATION_FAILURE,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=False,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_events_for_user(
|
||||
cls,
|
||||
user_id: str,
|
||||
limit: int = 100,
|
||||
include_deleted: bool = False
|
||||
) -> List[OIDCAuditLog]:
|
||||
"""Get audit events for a specific user.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
limit: Maximum number of events to return
|
||||
include_deleted: Include soft-deleted events
|
||||
|
||||
Returns:
|
||||
List of OIDCAuditLog instances
|
||||
"""
|
||||
return OIDCAuditLog.get_events_for_user(user_id, limit)
|
||||
|
||||
@classmethod
|
||||
def get_events_for_client(
|
||||
cls,
|
||||
client_id: str,
|
||||
limit: int = 100
|
||||
) -> List[OIDCAuditLog]:
|
||||
"""Get audit events for a specific client.
|
||||
|
||||
Args:
|
||||
client_id: Client ID
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of OIDCAuditLog instances
|
||||
"""
|
||||
return OIDCAuditLog.get_events_for_client(client_id, limit)
|
||||
|
||||
@classmethod
|
||||
def get_failed_events(
|
||||
cls,
|
||||
client_id: str = None,
|
||||
user_id: str = None,
|
||||
start_date: datetime = None,
|
||||
end_date: datetime = None,
|
||||
limit: int = 100
|
||||
) -> List[OIDCAuditLog]:
|
||||
"""Get failed audit events for analysis.
|
||||
|
||||
Args:
|
||||
client_id: Optional client ID filter
|
||||
user_id: Optional user ID filter
|
||||
start_date: Optional start date filter
|
||||
end_date: Optional end date filter
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of failed OIDCAuditLog instances
|
||||
"""
|
||||
return OIDCAuditLog.get_failed_events(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_event_summary(
|
||||
cls,
|
||||
client_id: str = None,
|
||||
days: int = 30
|
||||
) -> Dict:
|
||||
"""Get a summary of audit events.
|
||||
|
||||
Args:
|
||||
client_id: Optional client ID filter
|
||||
days: Number of days to look back
|
||||
|
||||
Returns:
|
||||
Summary dictionary with event counts
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
start_date = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
query = OIDCAuditLog.query.filter(
|
||||
OIDCAuditLog.created_at >= start_date
|
||||
)
|
||||
|
||||
if client_id:
|
||||
query = query.filter_by(client_id=client_id)
|
||||
|
||||
events = query.all()
|
||||
|
||||
# Count by event type
|
||||
event_counts = {}
|
||||
success_count = 0
|
||||
failure_count = 0
|
||||
|
||||
for event in events:
|
||||
event_type = event.event_type
|
||||
event_counts[event_type] = event_counts.get(event_type, 0) + 1
|
||||
|
||||
if event.success:
|
||||
success_count += 1
|
||||
else:
|
||||
failure_count += 1
|
||||
|
||||
return {
|
||||
"total_events": len(events),
|
||||
"successful_events": success_count,
|
||||
"failed_events": failure_count,
|
||||
"by_event_type": event_counts,
|
||||
"period_days": days,
|
||||
}
|
||||
@@ -0,0 +1,300 @@
|
||||
"""OIDC JWKS Service for key management and rotation."""
|
||||
import uuid
|
||||
import json
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from app.extensions import db
|
||||
|
||||
|
||||
class JWKSKey:
|
||||
"""Represents a JWKS key entry."""
|
||||
|
||||
def __init__(self, kid: str, private_key: str, public_key: str,
|
||||
algorithm: str = "RS256", created_at: datetime = None,
|
||||
expires_at: datetime = None, is_active: bool = True):
|
||||
self.kid = kid
|
||||
self.private_key = private_key
|
||||
self.public_key = public_key
|
||||
self.algorithm = algorithm
|
||||
self.created_at = created_at or datetime.utcnow()
|
||||
self.expires_at = expires_at or datetime.utcnow() + timedelta(days=365)
|
||||
self.is_active = is_active
|
||||
|
||||
def to_jwk(self) -> Dict:
|
||||
"""Convert to JWK format for JWKS endpoint."""
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa, padding
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
# Import cryptography here to avoid issues if not installed
|
||||
try:
|
||||
# Get public key from PEM
|
||||
public_key = serialization.load_pem_public_key(
|
||||
self.public_key.encode(), backend=default_backend()
|
||||
)
|
||||
|
||||
# Get RSA parameters
|
||||
public_numbers = public_key.public_numbers()
|
||||
|
||||
return {
|
||||
"kty": "RSA",
|
||||
"kid": self.kid,
|
||||
"use": "sig",
|
||||
"alg": self.algorithm,
|
||||
"n": _base64url_encode(public_numbers.n),
|
||||
"e": _base64url_encode(public_numbers.e),
|
||||
}
|
||||
except ImportError:
|
||||
# Fallback for when cryptography is not installed
|
||||
return {
|
||||
"kty": "RSA",
|
||||
"kid": self.kid,
|
||||
"use": "sig",
|
||||
"alg": self.algorithm,
|
||||
}
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert to dictionary for storage."""
|
||||
return {
|
||||
"kid": self.kid,
|
||||
"private_key": self.private_key,
|
||||
"public_key": self.public_key,
|
||||
"algorithm": self.algorithm,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"expires_at": self.expires_at.isoformat(),
|
||||
"is_active": self.is_active,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> "JWKSKey":
|
||||
"""Create from dictionary."""
|
||||
return cls(
|
||||
kid=data["kid"],
|
||||
private_key=data["private_key"],
|
||||
public_key=data["public_key"],
|
||||
algorithm=data.get("algorithm", "RS256"),
|
||||
created_at=datetime.fromisoformat(data["created_at"]),
|
||||
expires_at=datetime.fromisoformat(data["expires_at"]),
|
||||
is_active=data.get("is_active", True),
|
||||
)
|
||||
|
||||
|
||||
def _base64url_encode(value: int) -> str:
|
||||
"""Encode an integer to base64url format."""
|
||||
import base64
|
||||
byte_length = (value.bit_length() + 7) // 8 or 1
|
||||
encoded = value.to_bytes(byte_length, byteorder="big")
|
||||
return base64.urlsafe_b64encode(encoded).decode().rstrip("=")
|
||||
|
||||
|
||||
class OIDCJWKSService:
|
||||
"""Service for managing OIDC signing keys (JWKS).
|
||||
|
||||
This service handles RSA key pair generation, rotation, and JWKS document
|
||||
generation for the OIDC implementation.
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_keys: Dict[str, JWKSKey] = {}
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._keys = {}
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def reset(cls):
|
||||
"""Reset the singleton (for testing)."""
|
||||
cls._instance = None
|
||||
cls._keys = {}
|
||||
|
||||
def _generate_kid(self, private_key: str) -> str:
|
||||
"""Generate a key ID from the private key fingerprint."""
|
||||
kid_hash = hashlib.sha256(private_key.encode()).hexdigest()[:32]
|
||||
return kid_hash
|
||||
|
||||
def _generate_rsa_key_pair(self) -> Tuple[str, str]:
|
||||
"""Generate a new RSA key pair in PEM format.
|
||||
|
||||
Returns:
|
||||
Tuple of (private_key_pem, public_key_pem)
|
||||
"""
|
||||
try:
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
# Generate RSA private key
|
||||
private_key = rsa.generate_private_key(
|
||||
public_exponent=65537,
|
||||
key_size=2048,
|
||||
backend=default_backend()
|
||||
)
|
||||
|
||||
# Get public key
|
||||
public_key = private_key.public_key()
|
||||
|
||||
# Serialize to PEM
|
||||
private_pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption()
|
||||
).decode()
|
||||
|
||||
public_pem = public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
||||
).decode()
|
||||
|
||||
return private_pem, public_pem
|
||||
except ImportError:
|
||||
# Fallback for testing without cryptography
|
||||
import secrets
|
||||
return f"private_key_{secrets.token_hex(32)}", f"public_key_{secrets.token_hex(32)}"
|
||||
|
||||
def get_jwks(self, include_private_keys: bool = False) -> Dict:
|
||||
"""Get the JWKS document containing public keys.
|
||||
|
||||
Args:
|
||||
include_private_keys: Whether to include private keys (for internal use only)
|
||||
|
||||
Returns:
|
||||
JWKS document dictionary
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
|
||||
keys = []
|
||||
for kid, key in self._keys.items():
|
||||
# Only include active, non-expired keys
|
||||
if key.is_active and key.expires_at > now:
|
||||
if include_private_keys:
|
||||
keys.append(key.to_dict())
|
||||
else:
|
||||
keys.append(key.to_jwk())
|
||||
|
||||
return {
|
||||
"keys": keys
|
||||
}
|
||||
|
||||
def get_signing_key(self) -> Optional[JWKSKey]:
|
||||
"""Get the current active signing key.
|
||||
|
||||
Returns:
|
||||
JWKSKey instance or None if no active key
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
|
||||
for kid, key in self._keys.items():
|
||||
if key.is_active and key.expires_at > now:
|
||||
return key
|
||||
|
||||
return None
|
||||
|
||||
def get_key_by_kid(self, kid: str) -> Optional[JWKSKey]:
|
||||
"""Get a specific key by its ID.
|
||||
|
||||
Args:
|
||||
kid: Key ID to look up
|
||||
|
||||
Returns:
|
||||
JWKSKey instance or None if not found
|
||||
"""
|
||||
return self._keys.get(kid)
|
||||
|
||||
def generate_new_key_pair(self, expires_in_days: int = 365) -> JWKSKey:
|
||||
"""Generate a new RSA key pair for signing.
|
||||
|
||||
Args:
|
||||
expires_in_days: Days until key expiration
|
||||
|
||||
Returns:
|
||||
JWKSKey instance
|
||||
"""
|
||||
private_key, public_key = self._generate_rsa_key_pair()
|
||||
kid = self._generate_kid(private_key)
|
||||
|
||||
now = datetime.utcnow()
|
||||
key = JWKSKey(
|
||||
kid=kid,
|
||||
private_key=private_key,
|
||||
public_key=public_key,
|
||||
algorithm="RS256",
|
||||
created_at=now,
|
||||
expires_at=now + timedelta(days=expires_in_days),
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
self._keys[kid] = key
|
||||
|
||||
# Deactivate old keys (but keep them for grace period)
|
||||
for old_kid in self._keys:
|
||||
if old_kid != kid:
|
||||
self._keys[old_kid].is_active = False
|
||||
|
||||
return key
|
||||
|
||||
def rotate_keys(self, grace_period_hours: int = 24) -> Tuple[JWKSKey, List[str]]:
|
||||
"""Rotate signing keys, keeping previous key active for grace period.
|
||||
|
||||
Args:
|
||||
grace_period_hours: Hours to keep old keys active
|
||||
|
||||
Returns:
|
||||
Tuple of (new_key, list_of_deprecated_kids)
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
grace_end = now + timedelta(hours=grace_period_hours)
|
||||
|
||||
# Mark current key as deprecated
|
||||
current_key = self.get_signing_key()
|
||||
deprecated_kids = []
|
||||
|
||||
if current_key:
|
||||
deprecated_kids.append(current_key.kid)
|
||||
# Keep key active but mark as deprecated
|
||||
current_key.is_active = False
|
||||
current_key.expires_at = grace_end
|
||||
|
||||
# Generate new key
|
||||
new_key = self.generate_new_key_pair()
|
||||
|
||||
# Clean up expired keys
|
||||
expired_kids = [
|
||||
kid for kid, key in self._keys.items()
|
||||
if key.expires_at < now
|
||||
]
|
||||
for kid in expired_kids:
|
||||
del self._keys[kid]
|
||||
|
||||
return new_key, deprecated_kids
|
||||
|
||||
def verify_key_exists(self, kid: str) -> bool:
|
||||
"""Check if a key with the given ID exists and is valid.
|
||||
|
||||
Args:
|
||||
kid: Key ID to check
|
||||
|
||||
Returns:
|
||||
True if key exists and is valid
|
||||
"""
|
||||
key = self.get_key_by_kid(kid)
|
||||
if not key:
|
||||
return False
|
||||
|
||||
now = datetime.utcnow()
|
||||
return key.is_active and key.expires_at > now
|
||||
|
||||
def initialize_with_key(self) -> JWKSKey:
|
||||
"""Initialize the service with a key if none exists.
|
||||
|
||||
Returns:
|
||||
JWKSKey instance
|
||||
"""
|
||||
if not self._keys:
|
||||
return self.generate_new_key_pair()
|
||||
return self.get_signing_key()
|
||||
@@ -0,0 +1,745 @@
|
||||
"""OIDC Service - Main OIDC service layer."""
|
||||
import secrets
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from flask import current_app, g
|
||||
|
||||
from app.extensions import db
|
||||
from app.models import (
|
||||
User, OIDCClient, OIDCAuthCode, OIDCRefreshToken,
|
||||
OIDCSession, OIDCTokenMetadata
|
||||
)
|
||||
from app.exceptions.validation_exceptions import (
|
||||
ValidationError, NotFoundError, BadRequestError
|
||||
)
|
||||
from app.exceptions.auth_exceptions import UnauthorizedError, InvalidTokenError
|
||||
from app.services.oidc_token_service import OIDCTokenService
|
||||
from app.services.oidc_session_service import OIDCSessionService
|
||||
from app.services.oidc_audit_service import OIDCAuditService
|
||||
from app.services.oidc_jwks_service import OIDCJWKSService
|
||||
|
||||
|
||||
class OIDCError(Exception):
|
||||
"""Base exception for OIDC errors."""
|
||||
|
||||
def __init__(self, error: str, error_description: str = None, status_code: int = 400):
|
||||
self.error = error
|
||||
self.error_description = error_description
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class InvalidClientError(OIDCError):
|
||||
"""Raised when client authentication fails."""
|
||||
|
||||
def __init__(self, error_description: str = "Invalid client"):
|
||||
super().__init__("invalid_client", error_description, 401)
|
||||
|
||||
|
||||
class InvalidGrantError(OIDCError):
|
||||
"""Raised when grant is invalid."""
|
||||
|
||||
def __init__(self, error_description: str = "Invalid grant"):
|
||||
super().__init__("invalid_grant", error_description, 400)
|
||||
|
||||
|
||||
class InvalidRequestError(OIDCError):
|
||||
"""Raised when request is malformed."""
|
||||
|
||||
def __init__(self, error_description: str = "Invalid request"):
|
||||
super().__init__("invalid_request", error_description, 400)
|
||||
|
||||
|
||||
class OIDCService:
|
||||
"""Main OIDC service handling all OpenID Connect operations.
|
||||
|
||||
This service provides:
|
||||
- Authorization code generation and validation
|
||||
- Token generation (access, refresh, ID tokens)
|
||||
- Token refresh with rotation
|
||||
- Token validation and introspection
|
||||
- Token revocation
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _generate_code() -> str:
|
||||
"""Generate a secure authorization code.
|
||||
|
||||
Returns:
|
||||
URL-safe base64 encoded code
|
||||
"""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
@staticmethod
|
||||
def _hash_value(value: str) -> str:
|
||||
"""Hash a value for secure storage.
|
||||
|
||||
Args:
|
||||
value: Value to hash
|
||||
|
||||
Returns:
|
||||
SHA256 hash
|
||||
"""
|
||||
return hashlib.sha256(value.encode()).hexdigest()
|
||||
|
||||
@classmethod
|
||||
def generate_authorization_code(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
redirect_uri: str,
|
||||
scope: list,
|
||||
state: str,
|
||||
nonce: str,
|
||||
code_challenge: str = None,
|
||||
code_challenge_method: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None
|
||||
) -> str:
|
||||
"""Generate an authorization code for the auth code flow.
|
||||
|
||||
Args:
|
||||
client_id: OIDC client ID
|
||||
user_id: User ID
|
||||
redirect_uri: Redirect URI
|
||||
scope: Requested scopes
|
||||
state: State parameter
|
||||
nonce: Nonce for ID token
|
||||
code_challenge: PKCE code challenge
|
||||
code_challenge_method: PKCE method ("S256" or "plain")
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
|
||||
Returns:
|
||||
Authorization code string
|
||||
|
||||
Raises:
|
||||
ValidationError: If parameters are invalid
|
||||
NotFoundError: If client not found
|
||||
"""
|
||||
# Validate client exists and is active
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
if not client:
|
||||
raise NotFoundError("Client not found")
|
||||
|
||||
if not client.is_active:
|
||||
raise ValidationError("Client is not active")
|
||||
|
||||
# Validate redirect URI
|
||||
if not client.is_redirect_uri_allowed(redirect_uri):
|
||||
raise ValidationError("Invalid redirect_uri")
|
||||
|
||||
# Validate scopes
|
||||
allowed_scopes = client.scopes or []
|
||||
valid_scopes = [s for s in scope if s in allowed_scopes]
|
||||
|
||||
if not valid_scopes:
|
||||
raise ValidationError("Invalid scopes")
|
||||
|
||||
# Generate authorization code
|
||||
code = cls._generate_code()
|
||||
code_hash = cls._hash_value(code)
|
||||
|
||||
# Create auth code record
|
||||
auth_code = OIDCAuthCode.create_code(
|
||||
client_id=client.id,
|
||||
user_id=user_id,
|
||||
code_hash=code_hash,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=valid_scopes,
|
||||
nonce=nonce,
|
||||
code_verifier=code_challenge, # Store for validation
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
lifetime_seconds=600, # 10 minutes
|
||||
)
|
||||
|
||||
# Log authorization event
|
||||
OIDCAuditService.log_authorization_event(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=True,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=valid_scopes,
|
||||
)
|
||||
|
||||
return code
|
||||
|
||||
@classmethod
|
||||
def validate_authorization_code(
|
||||
cls,
|
||||
code: str,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
code_verifier: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None
|
||||
) -> Tuple[Dict, User]:
|
||||
"""Validate and consume an authorization code.
|
||||
|
||||
Args:
|
||||
code: Authorization code
|
||||
client_id: OIDC client ID
|
||||
redirect_uri: Redirect URI
|
||||
code_verifier: PKCE code verifier (required if PKCE was used)
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
|
||||
Returns:
|
||||
Tuple of (claims dict, User instance)
|
||||
|
||||
Raises:
|
||||
InvalidGrantError: If code is invalid
|
||||
ValidationError: If PKCE validation fails
|
||||
"""
|
||||
# Get client
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
if not client:
|
||||
raise InvalidGrantError("Invalid client")
|
||||
|
||||
# Hash the provided code and find matching auth code
|
||||
code_hash = cls._hash_value(code)
|
||||
auth_code = OIDCAuthCode.query.filter_by(
|
||||
code_hash=code_hash,
|
||||
client_id=client.id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not auth_code:
|
||||
OIDCAuditService.log_authorization_event(
|
||||
client_id=client_id,
|
||||
success=False,
|
||||
error_code="invalid_grant",
|
||||
error_description="Invalid or expired authorization code",
|
||||
)
|
||||
raise InvalidGrantError("Invalid or expired authorization code")
|
||||
|
||||
# Check if already used
|
||||
if auth_code.is_used:
|
||||
OIDCAuditService.log_authorization_event(
|
||||
client_id=client_id,
|
||||
user_id=auth_code.user_id,
|
||||
success=False,
|
||||
error_code="invalid_grant",
|
||||
error_description="Authorization code already used",
|
||||
)
|
||||
raise InvalidGrantError("Authorization code already used")
|
||||
|
||||
# Check expiration
|
||||
if auth_code.is_expired():
|
||||
OIDCAuditService.log_authorization_event(
|
||||
client_id=client_id,
|
||||
user_id=auth_code.user_id,
|
||||
success=False,
|
||||
error_code="invalid_grant",
|
||||
error_description="Authorization code expired",
|
||||
)
|
||||
raise InvalidGrantError("Authorization code expired")
|
||||
|
||||
# Validate redirect URI
|
||||
if auth_code.redirect_uri != redirect_uri:
|
||||
raise InvalidGrantError("Invalid redirect_uri")
|
||||
|
||||
# Validate PKCE if required
|
||||
if client.require_pkce and auth_code.code_verifier:
|
||||
if not code_verifier:
|
||||
raise ValidationError("code_verifier is required")
|
||||
|
||||
# Verify code verifier
|
||||
expected_challenge = cls._compute_code_challenge(code_verifier, "S256")
|
||||
if expected_challenge != auth_code.code_verifier:
|
||||
OIDCAuditService.log_authorization_event(
|
||||
client_id=client_id,
|
||||
user_id=auth_code.user_id,
|
||||
success=False,
|
||||
error_code="invalid_grant",
|
||||
error_description="Invalid code_verifier",
|
||||
)
|
||||
raise InvalidGrantError("Invalid code_verifier")
|
||||
|
||||
# Mark code as used
|
||||
auth_code.mark_as_used()
|
||||
|
||||
# Get user
|
||||
user = User.query.get(auth_code.user_id)
|
||||
if not user:
|
||||
raise InvalidGrantError("User not found")
|
||||
|
||||
claims = {
|
||||
"user_id": auth_code.user_id,
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": auth_code.scope,
|
||||
"nonce": auth_code.nonce,
|
||||
}
|
||||
|
||||
return claims, user
|
||||
|
||||
@classmethod
|
||||
def _compute_code_challenge(cls, verifier: str, method: str = "S256") -> str:
|
||||
"""Compute PKCE code challenge from verifier.
|
||||
|
||||
Args:
|
||||
verifier: Code verifier
|
||||
method: Challenge method
|
||||
|
||||
Returns:
|
||||
Code challenge
|
||||
"""
|
||||
import hashlib
|
||||
import base64
|
||||
|
||||
if method == "S256":
|
||||
digest = hashlib.sha256(verifier.encode()).digest()
|
||||
return base64.urlsafe_b64encode(digest).decode().rstrip("=")
|
||||
return verifier
|
||||
|
||||
@classmethod
|
||||
def generate_tokens(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
scope: list,
|
||||
nonce: str = None,
|
||||
refresh_token: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
auth_time: int = None
|
||||
) -> Dict:
|
||||
"""Generate access token, ID token, and refresh token.
|
||||
|
||||
Args:
|
||||
client_id: OIDC client ID
|
||||
user_id: User ID
|
||||
scope: Granted scopes
|
||||
nonce: Nonce for ID token
|
||||
refresh_token: Existing refresh token (for rotation)
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
auth_time: Authentication time
|
||||
|
||||
Returns:
|
||||
Dictionary with tokens
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
# Get client
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
if not client:
|
||||
raise InvalidClientError()
|
||||
|
||||
# Generate access token
|
||||
access_token_jti = OIDCTokenService._generate_jti()
|
||||
access_token = OIDCTokenService.create_access_token(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
scope=scope,
|
||||
jti=access_token_jti,
|
||||
)
|
||||
|
||||
# Generate ID token
|
||||
id_token = OIDCTokenService.create_id_token(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
nonce=nonce,
|
||||
scope=scope,
|
||||
access_token=access_token,
|
||||
auth_time=auth_time,
|
||||
)
|
||||
|
||||
# Generate or rotate refresh token
|
||||
if "refresh_token" in (client.grant_types or []):
|
||||
if refresh_token:
|
||||
# Rotate existing refresh token
|
||||
refresh_token_obj = OIDCRefreshToken.query.filter_by(
|
||||
token_hash=hashlib.sha256(refresh_token.encode()).hexdigest(),
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if refresh_token_obj and refresh_token_obj.is_valid():
|
||||
# Create new refresh token
|
||||
new_refresh, new_hash = OIDCTokenService.create_refresh_token(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
scope=scope,
|
||||
access_token_id=access_token_jti,
|
||||
)
|
||||
|
||||
# Rotate in database
|
||||
refresh_token_obj.rotate(new_hash)
|
||||
final_refresh_token = new_refresh
|
||||
else:
|
||||
final_refresh_token = None
|
||||
else:
|
||||
# Create new refresh token
|
||||
final_refresh_token, refresh_hash = OIDCTokenService.create_refresh_token(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
scope=scope,
|
||||
access_token_id=access_token_jti,
|
||||
)
|
||||
|
||||
# Store refresh token
|
||||
OIDCRefreshToken.create_token(
|
||||
client_id=client.id,
|
||||
user_id=user_id,
|
||||
token_hash=refresh_hash,
|
||||
scope=scope,
|
||||
access_token_id=access_token_jti,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
lifetime_seconds=client.refresh_token_lifetime or 2592000,
|
||||
)
|
||||
else:
|
||||
final_refresh_token = None
|
||||
|
||||
# Store token metadata
|
||||
client_db_id = client.id
|
||||
|
||||
# Access token metadata
|
||||
OIDCTokenMetadata.create_metadata(
|
||||
client_id=client_db_id,
|
||||
user_id=user_id,
|
||||
token_type="access_token",
|
||||
token_jti=access_token_jti,
|
||||
expires_at=datetime.utcnow() + timedelta(seconds=client.access_token_lifetime or 3600),
|
||||
)
|
||||
|
||||
# ID token metadata (using access token JTI as reference)
|
||||
id_token_jti = OIDCTokenService._generate_jti()
|
||||
OIDCTokenMetadata.create_metadata(
|
||||
client_id=client_db_id,
|
||||
user_id=user_id,
|
||||
token_type="id_token",
|
||||
token_jti=id_token_jti,
|
||||
expires_at=datetime.utcnow() + timedelta(seconds=client.id_token_lifetime or 3600),
|
||||
)
|
||||
|
||||
# Log token event
|
||||
OIDCAuditService.log_token_event(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
token_type="access_token",
|
||||
success=True,
|
||||
grant_type="authorization_code",
|
||||
scopes=scope,
|
||||
)
|
||||
|
||||
result = {
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": client.access_token_lifetime or 3600,
|
||||
"id_token": id_token,
|
||||
}
|
||||
|
||||
if final_refresh_token:
|
||||
result["refresh_token"] = final_refresh_token
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def refresh_access_token(
|
||||
cls,
|
||||
refresh_token: str,
|
||||
client_id: str,
|
||||
scope: list = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None
|
||||
) -> Dict:
|
||||
"""Refresh an access token with token rotation.
|
||||
|
||||
Args:
|
||||
refresh_token: The refresh token
|
||||
client_id: OIDC client ID
|
||||
scope: Optional scope override
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
|
||||
Returns:
|
||||
Dictionary with new tokens
|
||||
|
||||
Raises:
|
||||
InvalidGrantError: If refresh token is invalid
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
# Get client
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
if not client:
|
||||
raise InvalidClientError()
|
||||
|
||||
# Find refresh token
|
||||
token_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
|
||||
refresh_token_obj = OIDCRefreshToken.query.filter_by(
|
||||
token_hash=token_hash,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not refresh_token_obj:
|
||||
OIDCAuditService.log_token_event(
|
||||
client_id=client_id,
|
||||
success=False,
|
||||
error_code="invalid_grant",
|
||||
error_description="Invalid refresh token",
|
||||
)
|
||||
raise InvalidGrantError("Invalid refresh token")
|
||||
|
||||
# Check if valid
|
||||
if not refresh_token_obj.is_valid():
|
||||
OIDCAuditService.log_token_event(
|
||||
client_id=client_id,
|
||||
user_id=refresh_token_obj.user_id,
|
||||
success=False,
|
||||
error_code="invalid_grant",
|
||||
error_description="Refresh token expired or revoked",
|
||||
)
|
||||
raise InvalidGrantError("Refresh token expired or revoked")
|
||||
|
||||
# Validate client matches
|
||||
if refresh_token_obj.client_id != client.id:
|
||||
raise InvalidGrantError("Client mismatch")
|
||||
|
||||
# Get original scope or use provided
|
||||
granted_scope = scope or (refresh_token_obj.scope or [])
|
||||
|
||||
# Generate new access token
|
||||
access_token_jti = OIDCTokenService._generate_jti()
|
||||
access_token = OIDCTokenService.create_access_token(
|
||||
client_id=client_id,
|
||||
user_id=refresh_token_obj.user_id,
|
||||
scope=granted_scope,
|
||||
jti=access_token_jti,
|
||||
)
|
||||
|
||||
# Generate new ID token
|
||||
id_token = OIDCTokenService.create_id_token(
|
||||
client_id=client_id,
|
||||
user_id=refresh_token_obj.user_id,
|
||||
scope=granted_scope,
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
# Rotate refresh token
|
||||
new_refresh, new_hash = OIDCTokenService.create_refresh_token(
|
||||
client_id=client_id,
|
||||
user_id=refresh_token_obj.user_id,
|
||||
scope=granted_scope,
|
||||
access_token_id=access_token_jti,
|
||||
)
|
||||
|
||||
refresh_token_obj.rotate(new_hash)
|
||||
|
||||
# Store new token metadata
|
||||
OIDCTokenMetadata.create_metadata(
|
||||
client_id=client.id,
|
||||
user_id=refresh_token_obj.user_id,
|
||||
token_type="access_token",
|
||||
token_jti=access_token_jti,
|
||||
expires_at=datetime.utcnow() + timedelta(seconds=client.access_token_lifetime or 3600),
|
||||
)
|
||||
|
||||
# Log refresh event
|
||||
OIDCAuditService.log_token_event(
|
||||
client_id=client_id,
|
||||
user_id=refresh_token_obj.user_id,
|
||||
token_type="access_token",
|
||||
success=True,
|
||||
grant_type="refresh_token",
|
||||
scopes=granted_scope,
|
||||
)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": client.access_token_lifetime or 3600,
|
||||
"id_token": id_token,
|
||||
"refresh_token": new_refresh,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def validate_access_token(cls, token: str, client_id: str = None) -> Dict:
|
||||
"""Validate an access token and return its claims.
|
||||
|
||||
Args:
|
||||
token: JWT access token
|
||||
client_id: Optional client ID to validate audience
|
||||
|
||||
Returns:
|
||||
Token claims
|
||||
|
||||
Raises:
|
||||
InvalidTokenError: If token is invalid
|
||||
"""
|
||||
try:
|
||||
claims = OIDCTokenService.validate_access_token(token, client_id)
|
||||
return claims
|
||||
except Exception as e:
|
||||
OIDCAuditService.log_event(
|
||||
event_type="token_validation",
|
||||
client_id=client_id,
|
||||
success=False,
|
||||
error_code="invalid_token",
|
||||
error_description=str(e),
|
||||
)
|
||||
raise InvalidTokenError(str(e))
|
||||
|
||||
@classmethod
|
||||
def revoke_token(
|
||||
cls,
|
||||
token: str,
|
||||
client_id: str,
|
||||
token_type_hint: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None
|
||||
) -> bool:
|
||||
"""Revoke a token.
|
||||
|
||||
Args:
|
||||
token: Token to revoke
|
||||
client_id: OIDC client ID
|
||||
token_type_hint: Hint about token type
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
|
||||
Returns:
|
||||
True if token was revoked
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
# Get client
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
if not client:
|
||||
raise InvalidClientError()
|
||||
|
||||
revoked = False
|
||||
token_hash = hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
# Try to revoke as refresh token
|
||||
if token_type_hint in (None, "refresh_token"):
|
||||
refresh_token = OIDCRefreshToken.query.filter_by(
|
||||
token_hash=token_hash,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if refresh_token:
|
||||
refresh_token.revoke(reason="revoked_by_client")
|
||||
revoked = True
|
||||
|
||||
OIDCAuditService.log_token_revocation_event(
|
||||
client_id=client_id,
|
||||
user_id=refresh_token.user_id,
|
||||
token_type="refresh_token",
|
||||
reason="revoked_by_client",
|
||||
)
|
||||
|
||||
# Try to revoke as access token (JTI lookup)
|
||||
if not revoked or token_type_hint in (None, "access_token"):
|
||||
try:
|
||||
# Decode token to get JTI
|
||||
claims = OIDCTokenService.decode_token(token)
|
||||
jti = claims.get("jti")
|
||||
|
||||
if jti:
|
||||
revoked_at = OIDCTokenMetadata.revoke_by_jti(
|
||||
jti,
|
||||
reason="revoked_by_client"
|
||||
)
|
||||
if revoked_at:
|
||||
revoked = True
|
||||
|
||||
OIDCAuditService.log_token_revocation_event(
|
||||
client_id=client_id,
|
||||
user_id=claims.get("sub"),
|
||||
token_type="access_token",
|
||||
reason="revoked_by_client",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return revoked
|
||||
|
||||
@classmethod
|
||||
def introspect_token(
|
||||
cls,
|
||||
token: str,
|
||||
client_id: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None
|
||||
) -> Dict:
|
||||
"""Introspect a token and return its status and claims.
|
||||
|
||||
Args:
|
||||
token: Token to introspect
|
||||
client_id: Client ID for validation
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
|
||||
Returns:
|
||||
Introspection response
|
||||
"""
|
||||
result = OIDCTokenService.introspect_token(token, client_id)
|
||||
|
||||
# Log introspection
|
||||
OIDCAuditService.log_event(
|
||||
event_type="token_introspection",
|
||||
client_id=client_id,
|
||||
user_id=result.get("sub"),
|
||||
success=result.get("active", False),
|
||||
metadata={"active": result.get("active")},
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_jwks(cls) -> Dict:
|
||||
"""Get the JWKS document.
|
||||
|
||||
Returns:
|
||||
JWKS document
|
||||
"""
|
||||
jwks_service = OIDCJWKSService()
|
||||
return jwks_service.get_jwks()
|
||||
|
||||
@classmethod
|
||||
def get_userinfo(cls, access_token: str) -> Dict:
|
||||
"""Get user information using access token.
|
||||
|
||||
Args:
|
||||
access_token: Access token
|
||||
|
||||
Returns:
|
||||
User information dictionary
|
||||
"""
|
||||
claims = cls.validate_access_token(access_token)
|
||||
|
||||
user_id = claims.get("sub")
|
||||
user = User.query.get(user_id)
|
||||
|
||||
if not user:
|
||||
raise NotFoundError("User not found")
|
||||
|
||||
# Get scopes from token
|
||||
scope_str = claims.get("scope", "")
|
||||
scopes = scope_str.split() if scope_str else []
|
||||
|
||||
userinfo = {"sub": user_id}
|
||||
|
||||
# Add claims based on scope
|
||||
if "profile" in scopes and user.full_name:
|
||||
userinfo["name"] = user.full_name
|
||||
|
||||
if "email" in scopes:
|
||||
userinfo["email"] = user.email
|
||||
userinfo["email_verified"] = user.email_verified
|
||||
|
||||
# Log userinfo access
|
||||
OIDCAuditService.log_userinfo_event(
|
||||
access_token=access_token,
|
||||
user_id=user_id,
|
||||
client_id=claims.get("client_id"),
|
||||
success=True,
|
||||
scopes_claimed=scopes,
|
||||
)
|
||||
|
||||
return userinfo
|
||||
@@ -0,0 +1,288 @@
|
||||
"""OIDC Session Service for session management during OIDC flow."""
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from flask import current_app, g
|
||||
|
||||
from app.extensions import db
|
||||
from app.models import OIDCSession, OIDCClient, User
|
||||
from app.exceptions.validation_exceptions import NotFoundError, ValidationError
|
||||
|
||||
|
||||
class OIDCSessionService:
|
||||
"""Service for managing OIDC authentication sessions.
|
||||
|
||||
This service handles:
|
||||
- Creating OIDC sessions during authorization flow
|
||||
- Validating sessions with state and nonce
|
||||
- Managing PKCE code challenges
|
||||
- Cleaning up expired sessions
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _generate_state() -> str:
|
||||
"""Generate a secure state parameter.
|
||||
|
||||
Returns:
|
||||
URL-safe base64 encoded state
|
||||
"""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
@staticmethod
|
||||
def _generate_nonce() -> str:
|
||||
"""Generate a secure nonce for OIDC.
|
||||
|
||||
Returns:
|
||||
URL-safe base64 encoded nonce
|
||||
"""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
@staticmethod
|
||||
def _generate_code_challenge(verifier: str, method: str = "S256") -> str:
|
||||
"""Generate a PKCE code challenge from verifier.
|
||||
|
||||
Args:
|
||||
verifier: The code verifier
|
||||
method: Challenge method ("S256" or "plain")
|
||||
|
||||
Returns:
|
||||
Code challenge string
|
||||
"""
|
||||
import hashlib
|
||||
import base64
|
||||
|
||||
if method == "S256":
|
||||
digest = hashlib.sha256(verifier.encode()).digest()
|
||||
return base64.urlsafe_b64encode(digest).decode().rstrip("=")
|
||||
elif method == "plain":
|
||||
return verifier
|
||||
else:
|
||||
raise ValueError(f"Unsupported code challenge method: {method}")
|
||||
|
||||
@classmethod
|
||||
def validate_code_verifier(cls, code_verifier: str, code_challenge: str,
|
||||
method: str = "S256") -> bool:
|
||||
"""Validate a PKCE code verifier against the stored challenge.
|
||||
|
||||
Args:
|
||||
code_verifier: The code verifier from the token request
|
||||
code_challenge: The code challenge from the authorization request
|
||||
method: The challenge method used
|
||||
|
||||
Returns:
|
||||
True if validation succeeds
|
||||
"""
|
||||
if not code_verifier or not code_challenge:
|
||||
return False
|
||||
|
||||
# Validate code verifier length (43-128 characters)
|
||||
if method == "S256" and not (43 <= len(code_verifier) <= 128):
|
||||
return False
|
||||
|
||||
# Calculate expected challenge
|
||||
expected_challenge = cls._generate_code_challenge(code_verifier, method)
|
||||
|
||||
return secrets.compare_digest(expected_challenge, code_challenge)
|
||||
|
||||
@classmethod
|
||||
def create_session(
|
||||
cls,
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
state: str = None,
|
||||
nonce: str = None,
|
||||
redirect_uri: str = None,
|
||||
scope: list = None,
|
||||
code_challenge: str = None,
|
||||
code_challenge_method: str = None,
|
||||
lifetime_seconds: int = 600
|
||||
) -> OIDCSession:
|
||||
"""Create a new OIDC session for the authorization flow.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
client_id: The OIDC client ID
|
||||
state: State parameter (generated if not provided)
|
||||
nonce: Nonce for ID token validation (generated if not provided)
|
||||
redirect_uri: Redirect URI from authorization request
|
||||
scope: Requested scopes
|
||||
code_challenge: PKCE code challenge
|
||||
code_challenge_method: PKCE method ("S256" or "plain")
|
||||
lifetime_seconds: Session lifetime in seconds
|
||||
|
||||
Returns:
|
||||
OIDCSession instance
|
||||
"""
|
||||
# Generate state and nonce if not provided
|
||||
state = state or cls._generate_state()
|
||||
nonce = nonce or cls._generate_nonce()
|
||||
|
||||
session = OIDCSession.create_session(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
lifetime_seconds=lifetime_seconds,
|
||||
)
|
||||
|
||||
return session
|
||||
|
||||
@classmethod
|
||||
def validate_session(cls, state: str, nonce: str = None) -> Tuple[OIDCSession, User]:
|
||||
"""Validate an OIDC session by state and optionally nonce.
|
||||
|
||||
Args:
|
||||
state: The state parameter
|
||||
nonce: The nonce to validate (optional)
|
||||
|
||||
Returns:
|
||||
Tuple of (OIDCSession, User)
|
||||
|
||||
Raises:
|
||||
ValidationError: If session is invalid
|
||||
NotFoundError: If session not found
|
||||
"""
|
||||
session = OIDCSession.get_by_state(state)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError("OIDC session not found or expired")
|
||||
|
||||
if session.is_expired():
|
||||
raise ValidationError("OIDC session has expired")
|
||||
|
||||
# Validate nonce if provided
|
||||
if nonce and not session.validate_nonce(nonce):
|
||||
raise ValidationError("Invalid nonce")
|
||||
|
||||
# Get user
|
||||
user = User.query.get(session.user_id)
|
||||
if not user:
|
||||
raise NotFoundError("User not found")
|
||||
|
||||
return session, user
|
||||
|
||||
@classmethod
|
||||
def validate_pkce(cls, session: OIDCSession, code_verifier: str) -> bool:
|
||||
"""Validate PKCE code verifier against the session's code challenge.
|
||||
|
||||
Args:
|
||||
session: OIDCSession instance
|
||||
code_verifier: The code verifier from token request
|
||||
|
||||
Returns:
|
||||
True if validation succeeds
|
||||
|
||||
Raises:
|
||||
ValidationError: If PKCE validation fails
|
||||
"""
|
||||
if not session.code_challenge:
|
||||
# No PKCE was used, skip validation
|
||||
return True
|
||||
|
||||
if not code_verifier:
|
||||
raise ValidationError("code_verifier is required")
|
||||
|
||||
is_valid = session.validate_code_challenge(code_verifier)
|
||||
|
||||
if not is_valid:
|
||||
raise ValidationError("Invalid code_verifier")
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def mark_session_authenticated(cls, session: OIDCSession) -> OIDCSession:
|
||||
"""Mark a session as authenticated (user has logged in).
|
||||
|
||||
Args:
|
||||
session: OIDCSession instance
|
||||
|
||||
Returns:
|
||||
Updated OIDCSession instance
|
||||
"""
|
||||
session.mark_authenticated()
|
||||
return session
|
||||
|
||||
@classmethod
|
||||
def cleanup_expired_sessions(cls, older_than_hours: int = 24) -> int:
|
||||
"""Remove expired OIDC sessions.
|
||||
|
||||
Args:
|
||||
older_than_hours: Only delete sessions expired more than this many hours ago
|
||||
|
||||
Returns:
|
||||
Number of sessions deleted
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
cutoff = datetime.utcnow() - timedelta(hours=older_than_hours)
|
||||
|
||||
# Get expired sessions
|
||||
expired_sessions = OIDCSession.query.filter(
|
||||
OIDCSession.expires_at < datetime.utcnow(),
|
||||
OIDCSession.deleted_at == None
|
||||
).all()
|
||||
|
||||
count = 0
|
||||
for session in expired_sessions:
|
||||
# Only hard delete if past the grace period
|
||||
if session.expires_at < cutoff:
|
||||
session.delete()
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
@classmethod
|
||||
def get_session_by_state(cls, state: str) -> Optional[OIDCSession]:
|
||||
"""Get an OIDC session by state.
|
||||
|
||||
Args:
|
||||
state: The state parameter
|
||||
|
||||
Returns:
|
||||
OIDCSession instance or None
|
||||
"""
|
||||
return OIDCSession.get_by_state(state)
|
||||
|
||||
@classmethod
|
||||
def validate_redirect_uri(cls, client_id: str, redirect_uri: str) -> bool:
|
||||
"""Validate that a redirect URI is allowed for a client.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
redirect_uri: The redirect URI to validate
|
||||
|
||||
Returns:
|
||||
True if redirect URI is allowed
|
||||
"""
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
if not client:
|
||||
return False
|
||||
|
||||
return client.is_redirect_uri_allowed(redirect_uri)
|
||||
|
||||
@classmethod
|
||||
def validate_scopes(cls, client_id: str, requested_scopes: list) -> list:
|
||||
"""Validate and filter scopes against client's allowed scopes.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
requested_scopes: List of requested scopes
|
||||
|
||||
Returns:
|
||||
List of allowed scopes
|
||||
"""
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
if not client:
|
||||
return []
|
||||
|
||||
allowed_scopes = client.scopes or []
|
||||
|
||||
# Filter to only allowed scopes
|
||||
valid_scopes = [s for s in requested_scopes if s in allowed_scopes]
|
||||
|
||||
return valid_scopes
|
||||
@@ -0,0 +1,439 @@
|
||||
"""OIDC Token Service for JWT token generation and validation."""
|
||||
import hashlib
|
||||
import base64
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Optional, Any
|
||||
|
||||
import jwt
|
||||
from flask import current_app, g
|
||||
|
||||
from app.models import User, OIDCClient
|
||||
from app.services.oidc_jwks_service import OIDCJWKSService
|
||||
|
||||
|
||||
class OIDCTokenService:
|
||||
"""Service for generating and validating OIDC tokens.
|
||||
|
||||
This service handles:
|
||||
- Access token creation (JWT)
|
||||
- ID token creation (JWT)
|
||||
- Refresh token creation (opaque)
|
||||
- Token signature verification
|
||||
- Hash generation for PKCE claims (at_hash, c_hash)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _generate_jti() -> str:
|
||||
"""Generate a unique JWT ID."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
@staticmethod
|
||||
def _generate_opaque_token(length: int = 43) -> str:
|
||||
"""Generate an opaque token (for refresh tokens).
|
||||
|
||||
Args:
|
||||
length: Length of the token
|
||||
|
||||
Returns:
|
||||
URL-safe base64 encoded token
|
||||
"""
|
||||
return secrets.token_urlsafe(length)
|
||||
|
||||
@staticmethod
|
||||
def _hash_token(token: str) -> str:
|
||||
"""Hash a token for secure storage.
|
||||
|
||||
Args:
|
||||
token: Token to hash
|
||||
|
||||
Returns:
|
||||
SHA256 hash of the token
|
||||
"""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _base64url_encode(data: bytes) -> str:
|
||||
"""Encode bytes to base64url format without padding.
|
||||
|
||||
Args:
|
||||
data: Bytes to encode
|
||||
|
||||
Returns:
|
||||
Base64url encoded string
|
||||
"""
|
||||
return base64.urlsafe_b64encode(data).decode().rstrip("=")
|
||||
|
||||
@staticmethod
|
||||
def create_at_hash(access_token: str) -> str:
|
||||
"""Create the at_hash claim for ID token.
|
||||
|
||||
Implements OIDC spec for access token hash generation.
|
||||
Hash is the left-most half of the hash of the ASCII representation
|
||||
of the access token.
|
||||
|
||||
Args:
|
||||
access_token: The access token string
|
||||
|
||||
Returns:
|
||||
Base64url encoded hash
|
||||
"""
|
||||
# Hash the access token using SHA256
|
||||
hash_digest = hashlib.sha256(access_token.encode()).digest()
|
||||
|
||||
# Take left-most half of the hash
|
||||
half_length = len(hash_digest) // 2
|
||||
left_half = hash_digest[:half_length]
|
||||
|
||||
# Base64url encode
|
||||
return OIDCTokenService._base64url_encode(left_half)
|
||||
|
||||
@staticmethod
|
||||
def create_c_hash(code: str) -> str:
|
||||
"""Create the c_hash claim for ID token.
|
||||
|
||||
Implements OIDC spec for authorization code hash generation.
|
||||
|
||||
Args:
|
||||
code: The authorization code string
|
||||
|
||||
Returns:
|
||||
Base64url encoded hash
|
||||
"""
|
||||
# Hash the code using SHA256
|
||||
hash_digest = hashlib.sha256(code.encode()).digest()
|
||||
|
||||
# Take left-most half of the hash
|
||||
half_length = len(hash_digest) // 2
|
||||
left_half = hash_digest[:half_length]
|
||||
|
||||
# Base64url encode
|
||||
return OIDCTokenService._base64url_encode(left_half)
|
||||
|
||||
@staticmethod
|
||||
def _get_issuer() -> str:
|
||||
"""Get the OIDC issuer URL."""
|
||||
return current_app.config.get("OIDC_ISSUER_URL", "http://localhost:5000")
|
||||
|
||||
@staticmethod
|
||||
def _get_token_lifetime(client: OIDCClient, token_type: str) -> int:
|
||||
"""Get the token lifetime in seconds for a client.
|
||||
|
||||
Args:
|
||||
client: OIDCClient instance
|
||||
token_type: Type of token ("access_token", "refresh_token", "id_token")
|
||||
|
||||
Returns:
|
||||
Lifetime in seconds
|
||||
"""
|
||||
lifetimes = {
|
||||
"access_token": client.access_token_lifetime or 3600,
|
||||
"refresh_token": client.refresh_token_lifetime or 2592000,
|
||||
"id_token": client.id_token_lifetime or 3600,
|
||||
}
|
||||
return lifetimes.get(token_type, 3600)
|
||||
|
||||
@classmethod
|
||||
def create_access_token(cls, client_id: str, user_id: str, scope: list,
|
||||
jti: str = None) -> str:
|
||||
"""Create a JWT access token.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID (subject)
|
||||
scope: List of granted scopes
|
||||
jti: Optional JWT ID (generated if not provided)
|
||||
|
||||
Returns:
|
||||
JWT access token string
|
||||
"""
|
||||
jti = jti or cls._generate_jti()
|
||||
now = datetime.utcnow()
|
||||
|
||||
# Get client for token lifetime
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
lifetime = cls._get_token_lifetime(client, "access_token") if client else 3600
|
||||
|
||||
claims = {
|
||||
"iss": cls._get_issuer(),
|
||||
"sub": user_id,
|
||||
"aud": client_id,
|
||||
"exp": int((now + timedelta(seconds=lifetime)).timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"nbf": int(now.timestamp()),
|
||||
"jti": jti,
|
||||
"client_id": client_id,
|
||||
"scope": " ".join(scope) if isinstance(scope, list) else scope,
|
||||
}
|
||||
|
||||
# Get signing key
|
||||
jwks_service = OIDCJWKSService()
|
||||
signing_key = jwks_service.get_signing_key()
|
||||
|
||||
if not signing_key:
|
||||
raise ValueError("No signing key available")
|
||||
|
||||
# Sign with RS256
|
||||
token = jwt.encode(
|
||||
claims,
|
||||
signing_key.private_key,
|
||||
algorithm="RS256",
|
||||
headers={"kid": signing_key.kid}
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def create_id_token(cls, client_id: str, user_id: str, nonce: str = None,
|
||||
scope: list = None, access_token: str = None,
|
||||
auth_time: int = None) -> str:
|
||||
"""Create a JWT ID token.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID (subject)
|
||||
nonce: Nonce for replay protection
|
||||
scope: Requested/Granted scopes
|
||||
access_token: Associated access token (for at_hash)
|
||||
auth_time: Authentication time (Unix timestamp)
|
||||
|
||||
Returns:
|
||||
JWT ID token string
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
auth_time = auth_time or int(now.timestamp())
|
||||
|
||||
# Get client for token lifetime
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
lifetime = cls._get_token_lifetime(client, "id_token") if client else 3600
|
||||
|
||||
# Get user for claims
|
||||
user = User.query.get(user_id)
|
||||
|
||||
claims = {
|
||||
"iss": cls._get_issuer(),
|
||||
"sub": user_id,
|
||||
"aud": client_id,
|
||||
"exp": int((now + timedelta(seconds=lifetime)).timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"auth_time": auth_time,
|
||||
}
|
||||
|
||||
# Add nonce if provided
|
||||
if nonce:
|
||||
claims["nonce"] = nonce
|
||||
|
||||
# Add at_hash if access token provided
|
||||
if access_token:
|
||||
claims["at_hash"] = cls.create_at_hash(access_token)
|
||||
|
||||
# Add standard claims if user exists
|
||||
if user:
|
||||
if user.email:
|
||||
claims["email"] = user.email
|
||||
claims["email_verified"] = user.email_verified
|
||||
if user.full_name:
|
||||
claims["name"] = user.full_name
|
||||
|
||||
# Add scope if provided
|
||||
if scope:
|
||||
claims["scope"] = " ".join(scope) if isinstance(scope, list) else scope
|
||||
|
||||
# Get signing key
|
||||
jwks_service = OIDCJWKSService()
|
||||
signing_key = jwks_service.get_signing_key()
|
||||
|
||||
if not signing_key:
|
||||
raise ValueError("No signing key available")
|
||||
|
||||
# Sign with RS256
|
||||
token = jwt.encode(
|
||||
claims,
|
||||
signing_key.private_key,
|
||||
algorithm="RS256",
|
||||
headers={"kid": signing_key.kid}
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def create_refresh_token(cls, client_id: str, user_id: str,
|
||||
scope: list = None, access_token_id: str = None) -> str:
|
||||
"""Create an opaque refresh token.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID
|
||||
scope: List of granted scopes
|
||||
access_token_id: Associated access token ID
|
||||
|
||||
Returns:
|
||||
Opaque refresh token string
|
||||
"""
|
||||
token = cls._generate_opaque_token()
|
||||
|
||||
# Hash for storage
|
||||
token_hash = cls._hash_token(token)
|
||||
|
||||
return token, token_hash
|
||||
|
||||
@classmethod
|
||||
def verify_token_signature(cls, token: str) -> Dict:
|
||||
"""Verify the signature of a JWT token.
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
|
||||
Returns:
|
||||
Decoded token claims
|
||||
|
||||
Raises:
|
||||
jwt.InvalidSignatureError: If signature verification fails
|
||||
jwt.ExpiredSignatureError: If token is expired
|
||||
jwt.InvalidTokenError: If token is invalid
|
||||
"""
|
||||
# Get the JWKS with public keys
|
||||
jwks_service = OIDCJWKSService()
|
||||
jwks = jwks_service.get_jwks()
|
||||
|
||||
# Get the key ID from token header
|
||||
try:
|
||||
unverified_header = jwt.get_unverified_header(token)
|
||||
except jwt.DecodeError:
|
||||
raise jwt.InvalidTokenError("Invalid token header")
|
||||
|
||||
kid = unverified_header.get("kid")
|
||||
|
||||
# Find the matching public key
|
||||
public_key = None
|
||||
for key in jwks.get("keys", []):
|
||||
if key.get("kid") == kid:
|
||||
try:
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
public_key = serialization.load_pem_public_key(
|
||||
key["public_key"].encode() if isinstance(key["public_key"], str)
|
||||
else key["public_key"],
|
||||
backend=default_backend()
|
||||
)
|
||||
break
|
||||
except (ImportError, Exception):
|
||||
continue
|
||||
|
||||
if not public_key:
|
||||
raise jwt.InvalidSignatureError(f"Key with kid={kid} not found")
|
||||
|
||||
# Verify the signature
|
||||
claims = jwt.decode(
|
||||
token,
|
||||
public_key,
|
||||
algorithms=["RS256"],
|
||||
audience=None, # We'll validate audience separately
|
||||
issuer=cls._get_issuer(),
|
||||
options={
|
||||
"verify_signature": True,
|
||||
"verify_exp": True,
|
||||
"verify_aud": False, # Handle audience manually
|
||||
"verify_iss": False, # Handle issuer manually
|
||||
}
|
||||
)
|
||||
|
||||
return claims
|
||||
|
||||
@classmethod
|
||||
def decode_token(cls, token: str, verify: bool = False) -> Dict:
|
||||
"""Decode a JWT token without verification (for debugging).
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
verify: Whether to verify signature
|
||||
|
||||
Returns:
|
||||
Decoded token claims
|
||||
"""
|
||||
if verify:
|
||||
return cls.verify_token_signature(token)
|
||||
|
||||
return jwt.decode(
|
||||
token,
|
||||
options={
|
||||
"verify_signature": False,
|
||||
"verify_exp": False,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_access_token(cls, token: str, client_id: str = None) -> Dict:
|
||||
"""Validate an access token and return its claims.
|
||||
|
||||
Args:
|
||||
token: JWT access token
|
||||
client_id: Optional client ID to validate audience
|
||||
|
||||
Returns:
|
||||
Token claims dictionary
|
||||
|
||||
Raises:
|
||||
jwt.InvalidTokenError: If token is invalid
|
||||
ValueError: If token is expired or audience mismatch
|
||||
"""
|
||||
claims = cls.verify_token_signature(token)
|
||||
|
||||
# Check expiration
|
||||
if claims.get("exp", 0) < datetime.utcnow().timestamp():
|
||||
raise ValueError("Token has expired")
|
||||
|
||||
# Validate audience if client_id provided
|
||||
if client_id:
|
||||
if claims.get("aud") != client_id:
|
||||
raise ValueError("Invalid audience")
|
||||
|
||||
return claims
|
||||
|
||||
@classmethod
|
||||
def introspect_token(cls, token: str, client_id: str = None) -> Dict:
|
||||
"""Introspect a token and return its status and claims.
|
||||
|
||||
Args:
|
||||
token: JWT token to introspect
|
||||
client_id: Client ID for audience validation
|
||||
|
||||
Returns:
|
||||
Dictionary with active status and claims
|
||||
"""
|
||||
result = {
|
||||
"active": False,
|
||||
}
|
||||
|
||||
try:
|
||||
claims = cls.validate_access_token(token, client_id)
|
||||
|
||||
# Calculate remaining time
|
||||
now = datetime.utcnow().timestamp()
|
||||
exp = claims.get("exp", 0)
|
||||
iat = claims.get("iat", 0)
|
||||
|
||||
result["active"] = exp > now
|
||||
result.update({
|
||||
"iss": claims.get("iss"),
|
||||
"sub": claims.get("sub"),
|
||||
"aud": claims.get("aud"),
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": claims.get("nbf"),
|
||||
"jti": claims.get("jti"),
|
||||
"client_id": claims.get("client_id"),
|
||||
"scope": claims.get("scope"),
|
||||
"token_type": "Bearer",
|
||||
})
|
||||
|
||||
# Add expiry in seconds
|
||||
if exp > now:
|
||||
result["exp"] = int(exp - now)
|
||||
|
||||
except (jwt.InvalidTokenError, ValueError) as e:
|
||||
result["active"] = False
|
||||
result["error"] = str(e)
|
||||
|
||||
return result
|
||||
Reference in New Issue
Block a user