major checkpoint

This commit is contained in:
2026-01-08 15:59:53 +10:30
parent 211854ca0a
commit 5e060f267d
33 changed files with 8088 additions and 43 deletions
+4 -1
View File
@@ -6,6 +6,7 @@ SECRET_KEY=your-secret-key-here-change-in-production
# Database # Database
DATABASE_URL=postgresql://user:password@localhost:5432/authy2_dev DATABASE_URL=postgresql://user:password@localhost:5432/authy2_dev
SQLALCHEMY_ECHO=False SQLALCHEMY_ECHO=False
SQLALCHEMY_LOG_LEVEL=WARNING
# Security # Security
BCRYPT_LOG_ROUNDS=12 BCRYPT_LOG_ROUNDS=12
@@ -15,7 +16,9 @@ SESSION_COOKIE_SAMESITE=Lax
MAX_SESSION_DURATION=86400 MAX_SESSION_DURATION=86400
# CORS # CORS
CORS_ORIGINS=http://localhost:3000,http://localhost:5173 #CORS_ORIGINS=http://localhost:3000,http://localhost:5173,https://oidc-playpen.lovable.app/,http://localhost:8080/
CORS_ORIGINS=*
# JWT (if using JWT instead of sessions) # JWT (if using JWT instead of sessions)
JWT_SECRET_KEY=your-jwt-secret-key-here JWT_SECRET_KEY=your-jwt-secret-key-here
+10
View File
@@ -282,3 +282,13 @@ MIT
For issues and questions: For issues and questions:
- GitHub Issues: [repository-url]/issues - GitHub Issues: [repository-url]/issues
- Documentation: See `docs/` directory - Documentation: See `docs/` directory
# Boostrap db
python manage.py db upgrade
python manage.py db migrate
## running seed
python -m scripts.seed_data
+135
View File
@@ -0,0 +1,135 @@
# OIDC Extension to Seed Data Script
## Summary
Extended [`scripts/seed_data.py`](scripts/seed_data.py) to include OIDC client seeding functionality.
## Changes Made
### 1. Added Imports
- `import secrets` - For generating secure random values
- `import hashlib` - For hashing client secrets
- `from app.models.oidc_client import OIDCClient` - OIDC client model
### 2. New Helper Function: `create_or_get_oidc_client()`
Creates OIDC clients with proper configuration or returns existing ones. Features:
- Checks for existing clients by `client_id`
- Hashes client secrets using SHA256
- Supports all OIDC client configuration options
- Proper error handling and logging
### 3. New Seed Step: Step 5 - Create OIDC Clients
Added 4 OIDC clients across the 3 seeded organizations:
#### Acme Corporation (2 clients)
1. **Acme Internal Portal** (`acme-portal-001`)
- Confidential client
- Grant types: authorization_code, refresh_token
- Scopes: openid, profile, email, offline_access
- PKCE required
- Redirect URIs for production and localhost
2. **Acme Mobile App** (`acme-mobile-001`)
- Public client (mobile app)
- Shorter token lifetimes for security
- PKCE required
- Custom URL scheme for mobile redirect
#### Tech Startup Inc (1 client)
3. **Tech Startup Dashboard** (`tech-dashboard-001`)
- Confidential client
- Standard OIDC configuration
- PKCE required
#### Data Systems Inc (1 client)
4. **Data Systems API Client** (`data-api-001`)
- Confidential server-to-server client
- Additional grant type: client_credentials
- Custom scopes: api:read, api:write
- PKCE not required (server-to-server)
## OIDC Client Test Credentials
All clients are configured with test credentials for development:
| Client | Client ID | Client Secret |
|--------|-----------|---------------|
| Acme Portal | `acme-portal-001` | `acme_secret_portal_2024` |
| Acme Mobile | `acme-mobile-001` | `acme_secret_mobile_2024` |
| Tech Dashboard | `tech-dashboard-001` | `tech_secret_dashboard_2024` |
| Data API | `data-api-001` | `data_secret_api_2024` |
## Enhanced Summary Output
The seed script now displays:
- Total count of OIDC clients created
- Detailed information for each client including:
- Client name and ID
- Organization
- Configured grant types
- Configured scopes
- Number of redirect URIs
- Complete test credentials table
## Example Output
```
[Step 5] Creating OIDC Clients...
Acme Corporation OIDC Clients:
→ Created OIDC client: Acme Internal Portal
→ Created OIDC client: Acme Mobile App
Tech Startup OIDC Clients:
→ Created OIDC client: Tech Startup Dashboard
Data Systems OIDC Clients:
→ Created OIDC client: Data Systems API Client
Created 4 OIDC clients
============================================================
Seed Complete!
============================================================
📊 Summary:
Organizations: 3
Admin Users: 2
Regular Users: 9
OIDC Clients: 4
🔐 OIDC Clients:
Acme Internal Portal
Client ID: acme-portal-001
Organization: Acme Corporation
Grant Types: authorization_code, refresh_token
Scopes: openid, profile, email, offline_access
Redirect URIs: 2 configured
...
```
## Features
- **Idempotent**: Running the script multiple times won't create duplicate clients
- **Comprehensive**: Creates diverse client types (confidential, public, server-to-server)
- **Production-ready**: Includes proper secret hashing and security configurations
- **Developer-friendly**: Includes localhost URLs and clear test credentials
- **Well-documented**: Clear console output showing what was created
## Usage
Run the seed script as usual:
```bash
python scripts/seed_data.py
```
The OIDC clients will be automatically created along with users and organizations.
## Security Notes
- Client secrets are hashed using SHA256 before storage
- Test credentials are clearly marked and should **not** be used in production
- PKCE is enabled by default for web and mobile clients
- Token lifetimes are configured appropriately for each client type
+24 -10
View File
@@ -3,12 +3,21 @@ import os
import logging import logging
from flask import Flask from flask import Flask
from config import get_config from config import get_config
from app.extensions import db, migrate, bcrypt, cors, ma, limiter, session from app.extensions import db, migrate, bcrypt, ma, limiter, session
from app.middleware import RequestIDMiddleware, SecurityHeadersMiddleware, setup_cors from app.middleware import RequestIDMiddleware, SecurityHeadersMiddleware, setup_cors
from app.exceptions.base import BaseAPIException from app.exceptions.base import BaseAPIException
from app.utils.response import api_response from app.utils.response import api_response
import redis import redis
# Configure SQLAlchemy logging BEFORE any database operations
# This must be done before db.init_app() to prevent verbose logging
_log_level_env = os.getenv("SQLALCHEMY_LOG_LEVEL", "WARNING").upper()
_sqlalchemy_log_level = getattr(logging, _log_level_env, logging.WARNING)
logging.getLogger('sqlalchemy').setLevel(_sqlalchemy_log_level)
logging.getLogger('sqlalchemy.engine').setLevel(_sqlalchemy_log_level)
logging.getLogger('sqlalchemy.dialects').setLevel(_sqlalchemy_log_level)
logging.getLogger('sqlalchemy.pool').setLevel(_sqlalchemy_log_level)
def create_app(config_name=None): def create_app(config_name=None):
""" """
@@ -53,12 +62,9 @@ def initialize_extensions(app):
# Security # Security
bcrypt.init_app(app) bcrypt.init_app(app)
# CORS # CORS - using custom middleware only (see app/middleware/cors.py)
cors.init_app( # Flask-CORS disabled to avoid conflicts
app, # cors.init_app(app)
origins=app.config.get("CORS_ORIGINS", []),
supports_credentials=app.config.get("CORS_SUPPORTS_CREDENTIALS", True),
)
# Marshmallow # Marshmallow
ma.init_app(app) ma.init_app(app)
@@ -84,15 +90,19 @@ def setup_middleware(app):
"""Setup application middleware.""" """Setup application middleware."""
RequestIDMiddleware(app) RequestIDMiddleware(app)
SecurityHeadersMiddleware(app) SecurityHeadersMiddleware(app)
setup_cors(app, cors) setup_cors(app)
def register_blueprints(app): def register_blueprints(app):
"""Register application blueprints.""" """Register application blueprints."""
from app.api import register_api_blueprints from app.api import register_api_blueprints
from app.api.oidc import oidc_bp
register_api_blueprints(app) register_api_blueprints(app)
# Register OIDC blueprint at root level
app.register_blueprint(oidc_bp)
def register_error_handlers(app): def register_error_handlers(app):
"""Register error handlers.""" """Register error handlers."""
@@ -169,7 +179,11 @@ def setup_logging(app):
app.logger.setLevel(log_level) app.logger.setLevel(log_level)
# Reduce SQLAlchemy logging noise # Configure SQLAlchemy logging level (also set at module level before DB init)
logging.getLogger('sqlalchemy').setLevel(logging.WARNING) sqlalchemy_log_level = getattr(logging, app.config.get("SQLALCHEMY_LOG_LEVEL", "WARNING"), logging.WARNING)
logging.getLogger('sqlalchemy').setLevel(sqlalchemy_log_level)
logging.getLogger('sqlalchemy.engine').setLevel(sqlalchemy_log_level)
logging.getLogger('sqlalchemy.dialects').setLevel(sqlalchemy_log_level)
logging.getLogger('sqlalchemy.pool').setLevel(sqlalchemy_log_level)
app.logger.info("Application startup") app.logger.info("Application startup")
+964
View File
@@ -0,0 +1,964 @@
"""OIDC (OpenID Connect) API endpoints - Root level blueprint."""
import base64
import json
import secrets
from urllib.parse import urlencode, urlparse, parse_qs
import bcrypt
from flask import Blueprint, request, redirect, jsonify, session, g, current_app, Response
from app.utils.response import api_response
from app.services.oidc_service import (
OIDCService, InvalidClientError, InvalidGrantError, InvalidRequestError
)
from app.services.auth_service import AuthService
from app.extensions import db
from app.models import User, OIDCClient
from app.models.organization import Organization
from app.exceptions.auth_exceptions import InvalidCredentialsError
# Create OIDC blueprint registered at root level
oidc_bp = Blueprint("oidc", __name__)
# ============================================================================
# Helper Functions
# ============================================================================
def get_oidc_config():
"""Get OIDC configuration from app config."""
base_url = current_app.config.get("OIDC_ISSUER_URL", "http://localhost:5000")
return {
"issuer": base_url,
"authorization_endpoint": f"{base_url}/oidc/authorize",
"token_endpoint": f"{base_url}/oidc/token",
"userinfo_endpoint": f"{base_url}/oidc/userinfo",
"jwks_uri": f"{base_url}/oidc/jwks",
"registration_endpoint": f"{base_url}/oidc/register",
"revocation_endpoint": f"{base_url}/oidc/revoke",
"introspection_endpoint": f"{base_url}/oidc/introspect",
"scopes_supported": ["openid", "profile", "email"],
"response_types_supported": ["code"],
"response_modes_supported": ["query"],
"grant_types_supported": ["authorization_code", "refresh_token"],
"token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"],
"subject_types_supported": ["public"],
"id_token_signing_alg_values_supported": ["RS256"],
"claims_supported": ["sub", "name", "email", "email_verified"],
}
def authenticate_client(client_id, client_secret=None):
"""Authenticate an OIDC client.
Args:
client_id: The client ID
client_secret: Optional client secret
Returns:
OIDCClient instance
Raises:
InvalidClientError: If authentication fails
"""
client = OIDCClient.query.filter_by(client_id=client_id, is_active=True).first()
if not client:
raise InvalidClientError("Invalid client")
if client.is_confidential and client_secret:
if not bcrypt.check_password_hash(client.client_secret_hash, client_secret):
raise InvalidClientError("Invalid client credentials")
return client
def require_valid_token():
"""Validate Bearer token from Authorization header.
Sets g.current_token and g.current_user on success.
Raises:
InvalidGrantError: If token is invalid
"""
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
raise InvalidGrantError("Invalid token: Missing or invalid Authorization header")
token = auth_header[7:]
claims = OIDCService.validate_access_token(token)
g.current_token = claims
user = User.query.get(claims.get("sub"))
if not user:
raise InvalidGrantError("Invalid token: User not found")
g.current_user = user
def parse_basic_auth():
"""Parse Basic authentication from Authorization header.
Returns:
Tuple of (client_id, client_secret) or (None, None)
"""
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Basic "):
try:
encoded = auth_header[6:]
decoded = base64.b64decode(encoded).decode("utf-8")
client_id, client_secret = decoded.split(":", 1)
return client_id, client_secret
except Exception:
pass
return None, None
# ============================================================================
# Discovery Endpoint
# ============================================================================
@oidc_bp.route("/.well-known/openid-configuration", methods=["GET"])
def oidc_discovery():
"""OpenID Connect Discovery endpoint.
Returns the OIDC configuration as JSON.
Cache-Control: max-age=86400
No authentication required.
Returns:
200: OIDC discovery document
"""
config = get_oidc_config()
response = jsonify(config)
response.headers["Cache-Control"] = "max-age=86400"
return response, 200
# ============================================================================
# Authorization Endpoint
# ============================================================================
@oidc_bp.route("/oidc/authorize", methods=["GET", "POST"])
def oidc_authorize():
"""OpenID Connect Authorization endpoint.
Initiates the OIDC authentication flow.
GET Parameters:
client_id: The client ID
redirect_uri: The redirect URI
response_type: Must be "code" for authorization code flow
scope: Space-separated scopes (e.g., "openid profile email")
state: Opaque state value for CSRF protection
nonce: Nonce for ID token replay protection
code_challenge: PKCE code challenge
code_challenge_method: PKCE method ("S256" or "plain")
prompt: "login", "consent", "select_account", "none"
max_age: Maximum authentication age in seconds
acr_values: Requested Authentication Context Class Reference
POST Parameters:
Same as GET, plus:
email: User email
password: User password
Returns:
302: Redirect with authorization code or error
200: Login page (GET when not authenticated)
400: Invalid request
"""
# Parse request parameters
if request.method == "GET":
params = request.args.to_dict()
else:
params = request.form.to_dict()
# Extract required parameters
client_id = params.get("client_id")
redirect_uri = params.get("redirect_uri")
response_type = params.get("response_type")
scope = params.get("scope", "")
state = params.get("state", "")
nonce = params.get("nonce", "")
code_challenge = params.get("code_challenge")
code_challenge_method = params.get("code_challenge_method")
# Validate required parameters
errors = []
if not client_id:
errors.append("client_id is required")
if not redirect_uri:
errors.append("redirect_uri is required")
if not response_type:
errors.append("response_type is required")
if errors:
return _redirect_with_error(redirect_uri, "invalid_request", "; ".join(errors), state)
# Validate response_type
if response_type != "code":
return _redirect_with_error(
redirect_uri, "unsupported_response_type",
"Only response_type=code is supported", state
)
# Validate client
client = OIDCClient.query.filter_by(client_id=client_id, is_active=True).first()
if not client:
return _redirect_with_error(redirect_uri, "unauthorized_client", "Invalid client", state)
# Validate redirect URI
if not client.is_redirect_uri_allowed(redirect_uri):
return _redirect_with_error(redirect_uri, "invalid_request", "Invalid redirect_uri", state)
# Validate scopes
requested_scopes = scope.split() if scope else []
allowed_scopes = client.scopes or []
valid_scopes = [s for s in requested_scopes if s in allowed_scopes]
if not valid_scopes:
return _redirect_with_error(redirect_uri, "invalid_scope", "Invalid or no scopes requested", state)
# Check if user is already authenticated via session
user_id = session.get("oidc_user_id")
# Handle POST with credentials
if request.method == "POST" and not user_id:
email = params.get("email")
password = params.get("password")
if not email or not password:
return _show_login_page(
client_id=client_id,
redirect_uri=redirect_uri,
scope=scope,
state=state,
nonce=nonce,
response_type=response_type,
error="Invalid credentials"
)
try:
user = AuthService.authenticate(email, password)
user_id = user.id
session["oidc_user_id"] = user_id
except InvalidCredentialsError:
return _show_login_page(
client_id=client_id,
redirect_uri=redirect_uri,
scope=scope,
state=state,
nonce=nonce,
response_type=response_type,
error="Invalid email or password"
)
# If no user, show login page
if not user_id:
return _show_login_page(
client_id=client_id,
redirect_uri=redirect_uri,
scope=scope,
state=state,
nonce=nonce,
response_type=response_type
)
# User is authenticated, generate authorization code
user = User.query.get(user_id)
if not user:
return _redirect_with_error(redirect_uri, "server_error", "User not found", state)
try:
code = OIDCService.generate_authorization_code(
client_id=client_id,
user_id=user_id,
redirect_uri=redirect_uri,
scope=valid_scopes,
state=state,
nonce=nonce,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
ip_address=request.remote_addr,
user_agent=request.headers.get("User-Agent"),
)
except Exception as e:
return _redirect_with_error(redirect_uri, "server_error", str(e), state)
# Redirect with authorization code
redirect_params = {"code": code}
if state:
redirect_params["state"] = state
return redirect(f"{redirect_uri}?{urlencode(redirect_params)}")
def _redirect_with_error(redirect_uri, error, error_description, state=None):
"""Redirect to client with error parameters."""
if not redirect_uri:
return api_response(
success=False,
message=error_description,
status=400,
error_type=error.upper(),
error_details={"error": error, "error_description": error_description},
)
params = {
"error": error,
"error_description": error_description,
}
if state:
params["state"] = state
return redirect(f"{redirect_uri}?{urlencode(params)}")
def _show_login_page(client_id, redirect_uri, scope, state, nonce, response_type, error=None):
"""Show the login page for authorization."""
# Simple HTML login page
html = f"""
<!DOCTYPE html>
<html>
<head>
<title>Sign In - OIDC Authorization</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 40px; background: #f5f5f5; }}
.container {{ max-width: 400px; margin: 0 auto; background: white; padding: 30px; border-radius: 8px; box-shadow: 0 2px 10px rgba(0,0,0,0.1); }}
h1 {{ color: #333; font-size: 24px; margin-bottom: 20px; }}
.form-group {{ margin-bottom: 15px; }}
label {{ display: block; margin-bottom: 5px; color: #555; font-weight: bold; }}
input[type="email"], input[type="password"] {{ width: 100%; padding: 10px; border: 1px solid #ddd; border-radius: 4px; box-sizing: border-box; }}
button {{ width: 100%; padding: 12px; background: #007bff; color: white; border: none; border-radius: 4px; cursor: pointer; font-size: 16px; }}
button:hover {{ background: #0056b3; }}
.error {{ color: #dc3545; margin-bottom: 15px; }}
.cancel {{ text-align: center; margin-top: 15px; }}
.cancel a {{ color: #666; text-decoration: none; }}
</style>
</head>
<body>
<div class="container">
<h1>Sign In</h1>
{"<p class='error'>" + error + "</p>" if error else ""}
<form method="POST">
<input type="hidden" name="client_id" value="{client_id}">
<input type="hidden" name="redirect_uri" value="{redirect_uri}">
<input type="hidden" name="scope" value="{scope}">
<input type="hidden" name="state" value="{state}">
<input type="hidden" name="nonce" value="{nonce}">
<input type="hidden" name="response_type" value="{response_type}">
<div class="form-group">
<label for="email">Email</label>
<input type="email" id="email" name="email" required>
</div>
<div class="form-group">
<label for="password">Password</label>
<input type="password" id="password" name="password" required>
</div>
<button type="submit">Sign In</button>
</form>
<p class="cancel">
<a href="{redirect_uri}">Cancel</a>
</p>
</div>
</body>
</html>
"""
return Response(html, mimetype="text/html"), 200
# ============================================================================
# Token Endpoint
# ============================================================================
@oidc_bp.route("/oidc/token", methods=["POST"])
def oidc_token():
"""OpenID Connect Token endpoint.
Exchanges authorization code for tokens or refreshes tokens.
Request body (application/x-www-form-urlencoded):
grant_type: "authorization_code" or "refresh_token"
For authorization_code:
code: The authorization code
redirect_uri: The redirect URI used in authorization
client_id: The client ID
client_secret: The client secret (optional if using Basic auth)
code_verifier: PKCE code verifier (optional)
For refresh_token:
refresh_token: The refresh token
scope: Optional scope override
client_id: The client ID
client_secret: The client secret (optional if using Basic auth)
Authentication:
- Basic auth with client_id:client_secret, or
- client_id + client_secret in request body
Returns:
200: JSON with tokens
400: Invalid request
401: Invalid client
"""
# Parse request body
if request.content_type and "application/x-www-form-urlencoded" in request.content_type:
data = request.form.to_dict()
else:
data = request.json or {}
grant_type = data.get("grant_type")
# Validate grant_type
if not grant_type:
return api_response(
success=False,
message="grant_type is required",
status=400,
error_type="INVALID_REQUEST",
error_details={"error": "invalid_request", "error_description": "grant_type is required"},
)
# Authenticate client
client_id = data.get("client_id")
client_secret = data.get("client_secret")
# Try Basic auth if client_id not in body
if not client_id:
client_id, client_secret = parse_basic_auth()
if not client_id:
# Return 401 with WWW-Authenticate header for Basic auth
response = jsonify({
"error": "invalid_client",
"error_description": "Client authentication required"
})
response.headers["WWW-Authenticate"] = 'Basic realm="OIDC Token Endpoint"'
return response, 401
try:
client = authenticate_client(client_id, client_secret)
except InvalidClientError:
response = jsonify({
"error": "invalid_client",
"error_description": "Invalid client credentials"
})
return response, 401
# Handle authorization_code grant
if grant_type == "authorization_code":
return _handle_authorization_code_grant(data, client)
# Handle refresh_token grant
elif grant_type == "refresh_token":
return _handle_refresh_token_grant(data, client)
# Unsupported grant type
else:
return api_response(
success=False,
message="Unsupported grant_type",
status=400,
error_type="UNSUPPORTED_GRANT_TYPE",
error_details={"error": "unsupported_grant_type", "error_description": f"Grant type '{grant_type}' is not supported"},
)
def _handle_authorization_code_grant(data, client):
"""Handle authorization_code grant type."""
code = data.get("code")
redirect_uri = data.get("redirect_uri")
code_verifier = data.get("code_verifier")
if not code:
return api_response(
success=False,
message="code is required",
status=400,
error_type="INVALID_REQUEST",
error_details={"error": "invalid_request", "error_description": "code is required"},
)
if not redirect_uri:
return api_response(
success=False,
message="redirect_uri is required",
status=400,
error_type="INVALID_REQUEST",
error_details={"error": "invalid_request", "error_description": "redirect_uri is required"},
)
try:
claims, user = OIDCService.validate_authorization_code(
code=code,
client_id=client.client_id,
redirect_uri=redirect_uri,
code_verifier=code_verifier,
ip_address=request.remote_addr,
user_agent=request.headers.get("User-Agent"),
)
except InvalidGrantError as e:
return api_response(
success=False,
message=str(e),
status=400,
error_type="INVALID_GRANT",
error_details={"error": "invalid_grant", "error_description": str(e)},
)
# Generate tokens
try:
tokens = OIDCService.generate_tokens(
client_id=client.client_id,
user_id=claims["user_id"],
scope=claims["scope"],
nonce=claims.get("nonce"),
ip_address=request.remote_addr,
user_agent=request.headers.get("User-Agent"),
auth_time=int(__import__("time").time()),
)
except Exception as e:
return api_response(
success=False,
message="Failed to generate tokens",
status=500,
error_type="SERVER_ERROR",
error_details={"error": "server_error", "error_description": str(e)},
)
return api_response(
data=tokens,
message="Tokens issued successfully",
status=200,
)
def _handle_refresh_token_grant(data, client):
"""Handle refresh_token grant type."""
refresh_token = data.get("refresh_token")
scope = data.get("scope")
if not refresh_token:
return api_response(
success=False,
message="refresh_token is required",
status=400,
error_type="INVALID_REQUEST",
error_details={"error": "invalid_request", "error_description": "refresh_token is required"},
)
# Parse scope if provided
scope_list = scope.split() if scope else None
try:
tokens = OIDCService.refresh_access_token(
refresh_token=refresh_token,
client_id=client.client_id,
scope=scope_list,
ip_address=request.remote_addr,
user_agent=request.headers.get("User-Agent"),
)
except InvalidGrantError as e:
return api_response(
success=False,
message=str(e),
status=400,
error_type="INVALID_GRANT",
error_details={"error": "invalid_grant", "error_description": str(e)},
)
return api_response(
data=tokens,
message="Tokens refreshed successfully",
status=200,
)
# ============================================================================
# UserInfo Endpoint
# ============================================================================
@oidc_bp.route("/oidc/userinfo", methods=["GET", "POST"])
def oidc_userinfo():
"""OpenID Connect UserInfo endpoint.
Returns claims about the authenticated user.
Authorization: Bearer {access_token}
Returns claims based on granted scopes:
- sub: User ID (always included)
- name: User full name (if "profile" scope)
- email: User email (if "email" scope)
- email_verified: Email verification status (if "email" scope)
Returns:
200: User claims
401: Invalid or insufficient token
"""
try:
require_valid_token()
except InvalidGrantError as e:
return api_response(
success=False,
message=str(e),
status=401,
error_type="INVALID_TOKEN",
error_details={"error": "invalid_token", "error_description": str(e)},
)
# Get userinfo
try:
userinfo = OIDCService.get_userinfo(g.current_token.get("access_token", ""))
except Exception as e:
return api_response(
success=False,
message="Failed to get user info",
status=500,
error_type="SERVER_ERROR",
error_details={"error": "server_error", "error_description": str(e)},
)
return api_response(
data=userinfo,
message="User info retrieved successfully",
status=200,
)
# ============================================================================
# JWKS Endpoint
# ============================================================================
@oidc_bp.route("/oidc/jwks", methods=["GET"])
def oidc_jwks():
"""OpenID Connect JSON Web Key Set endpoint.
Returns the public keys used to sign tokens.
Cache-Control: max-age=3600
No authentication required.
Returns:
200: JWKS document
"""
try:
jwks = OIDCService.get_jwks()
except Exception as e:
return api_response(
success=False,
message="Failed to get JWKS",
status=500,
error_type="SERVER_ERROR",
error_details={"error": "server_error", "error_description": str(e)},
)
response = jsonify(jwks)
response.headers["Cache-Control"] = "max-age=3600"
return response, 200
# ============================================================================
# Token Revocation Endpoint
# ============================================================================
@oidc_bp.route("/oidc/revoke", methods=["POST"])
def oidc_revoke():
"""OAuth2 Token Revocation endpoint.
Revokes an access token or refresh token.
Request body (application/x-www-form-urlencoded):
token: The token to revoke
token_type_hint: Optional hint ("access_token" or "refresh_token")
client_id: The client ID
client_secret: The client secret (optional if using Basic auth)
Authentication:
- Basic auth with client_id:client_secret, or
- client_id + client_secret in request body
Returns:
200: Token revoked successfully
400: Invalid request
401: Invalid client
"""
# Parse request body
if request.content_type and "application/x-www-form-urlencoded" in request.content_type:
data = request.form.to_dict()
else:
data = request.json or {}
token = data.get("token")
if not token:
return api_response(
success=False,
message="token is required",
status=400,
error_type="INVALID_REQUEST",
error_details={"error": "invalid_request", "error_description": "token is required"},
)
# Authenticate client
client_id = data.get("client_id")
client_secret = data.get("client_secret")
if not client_id:
client_id, client_secret = parse_basic_auth()
if not client_id:
response = jsonify({
"error": "invalid_client",
"error_description": "Client authentication required"
})
response.headers["WWW-Authenticate"] = 'Basic realm="OIDC Revoke Endpoint"'
return response, 401
try:
client = authenticate_client(client_id, client_secret)
except InvalidClientError:
response = jsonify({
"error": "invalid_client",
"error_description": "Invalid client credentials"
})
return response, 401
token_type_hint = data.get("token_type_hint")
try:
OIDCService.revoke_token(
token=token,
client_id=client.client_id,
token_type_hint=token_type_hint,
ip_address=request.remote_addr,
user_agent=request.headers.get("User-Agent"),
)
except Exception as e:
# Revocation should succeed even if token is invalid
pass
return api_response(
message="Token revoked successfully",
status=200,
)
# ============================================================================
# Token Introspection Endpoint
# ============================================================================
@oidc_bp.route("/oidc/introspect", methods=["POST"])
def oidc_introspect():
"""OAuth2 Token Introspection endpoint.
Returns information about a token.
Request body (application/x-www-form-urlencoded):
token: The token to introspect
token_type_hint: Optional hint ("access_token" or "refresh_token")
client_id: The client ID
client_secret: The client secret (optional if using Basic auth)
Authentication:
- Basic auth with client_id:client_secret, or
- client_id + client_secret in request body
Returns:
200: Token status and claims
400: Invalid request
401: Invalid client
"""
# Parse request body
if request.content_type and "application/x-www-form-urlencoded" in request.content_type:
data = request.form.to_dict()
else:
data = request.json or {}
token = data.get("token")
if not token:
return api_response(
success=False,
message="token is required",
status=400,
error_type="INVALID_REQUEST",
error_details={"error": "invalid_request", "error_description": "token is required"},
)
# Authenticate client
client_id = data.get("client_id")
client_secret = data.get("client_secret")
if not client_id:
client_id, client_secret = parse_basic_auth()
if not client_id:
response = jsonify({
"error": "invalid_client",
"error_description": "Client authentication required"
})
response.headers["WWW-Authenticate"] = 'Basic realm="OIDC Introspect Endpoint"'
return response, 401
try:
client = authenticate_client(client_id, client_secret)
except InvalidClientError:
response = jsonify({
"error": "invalid_client",
"error_description": "Invalid client credentials"
})
return response, 401
token_type_hint = data.get("token_type_hint")
try:
result = OIDCService.introspect_token(
token=token,
client_id=client.client_id,
ip_address=request.remote_addr,
user_agent=request.headers.get("User-Agent"),
)
except Exception as e:
return api_response(
success=False,
message="Failed to introspect token",
status=500,
error_type="SERVER_ERROR",
error_details={"error": "server_error", "error_description": str(e)},
)
return api_response(
data=result,
message="Token introspection successful",
status=200,
)
# ============================================================================
# Client Registration Endpoint (Optional)
# ============================================================================
@oidc_bp.route("/oidc/register", methods=["POST"])
def oidc_register():
"""OpenID Connect Client Registration endpoint.
Registers a new OIDC client.
Request body (application/json):
client_name: Name of the client
redirect_uris: List of redirect URIs
token_endpoint_auth_method: "client_secret_basic" or "client_secret_post"
grant_types: List of grant types ["authorization_code", "refresh_token"]
response_types: List of response types ["code"]
scope: Space-separated scopes (default: "openid profile email")
Returns:
201: Client registered successfully
400: Invalid request
"""
data = request.json or {}
# Validate required fields
client_name = data.get("client_name")
redirect_uris = data.get("redirect_uris", [])
if not client_name:
return api_response(
success=False,
message="client_name is required",
status=400,
error_type="INVALID_REQUEST",
error_details={"error": "invalid_request", "error_description": "client_name is required"},
)
if not redirect_uris:
return api_response(
success=False,
message="redirect_uris is required",
status=400,
error_type="INVALID_REQUEST",
error_details={"error": "invalid_request", "error_description": "redirect_uris is required"},
)
# Validate redirect_uris
for uri in redirect_uris:
try:
parsed = urlparse(uri)
if not parsed.scheme or not parsed.netloc:
raise ValueError(f"Invalid redirect URI: {uri}")
except Exception:
return api_response(
success=False,
message=f"Invalid redirect_uri: {uri}",
status=400,
error_type="INVALID_REQUEST",
error_details={"error": "invalid_request", "error_description": f"Invalid redirect_uri: {uri}"},
)
# Generate client credentials
client_id = f"oidc_{secrets.token_urlsafe(16)}"
client_secret = f"secret_{secrets.token_urlsafe(24)}"
client_secret_hash = bcrypt.generate_password_hash(client_secret).decode("utf-8")
# Get organization from request or default
org_id = data.get("organization_id")
if org_id:
organization = Organization.query.get(org_id)
else:
# Get first active organization or create a default one
organization = Organization.query.filter_by(is_active=True).first()
if not organization:
# Create a default organization for the client
organization = Organization(
name=f"OIDC Clients",
slug=f"oidc-clients-{secrets.token_urlsafe(8)}",
)
organization.save()
# Create OIDC client
client = OIDCClient(
organization_id=organization.id,
name=client_name,
client_id=client_id,
client_secret_hash=client_secret_hash,
redirect_uris=redirect_uris,
grant_types=data.get("grant_types", ["authorization_code", "refresh_token"]),
response_types=data.get("response_types", ["code"]),
scopes=data.get("scope", "openid profile email").split(),
token_endpoint_auth_method=data.get("token_endpoint_auth_method", "client_secret_basic"),
is_active=True,
is_confidential=True,
require_pkce=True,
logo_uri=data.get("logo_uri"),
client_uri=data.get("client_uri"),
policy_uri=data.get("policy_uri"),
tos_uri=data.get("tos_uri"),
)
client.save()
# Return client credentials
return api_response(
data={
"client_id": client_id,
"client_secret": client_secret,
"client_id_issued_at": int(__import__("time").time()),
"client_secret_expires_at": 0, # Never expires
"client_name": client_name,
"redirect_uris": redirect_uris,
"token_endpoint_auth_method": data.get("token_endpoint_auth_method", "client_secret_basic"),
"grant_types": client.grant_types,
"response_types": client.response_types,
"scope": " ".join(client.scopes),
},
message="Client registered successfully",
status=201,
)
+1 -7
View File
@@ -12,13 +12,7 @@ from flask_session import Session
db = SQLAlchemy() db = SQLAlchemy()
migrate = Migrate() migrate = Migrate()
bcrypt = Bcrypt() bcrypt = Bcrypt()
cors = CORS( cors = CORS()
supports_credentials=True,
resources={r"/api/*": {"origins": "*"}}, # Apply CORS to all API routes
allow_headers=["Content-Type", "Authorization", "X-Request-ID"],
methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
expose_headers=["X-Request-ID"],
)
ma = Marshmallow() ma = Marshmallow()
limiter = Limiter( limiter = Limiter(
key_func=get_remote_address, key_func=get_remote_address,
+41 -8
View File
@@ -1,17 +1,40 @@
"""CORS middleware configuration.""" """CORS middleware configuration."""
from flask import request from flask import request, make_response
def setup_cors(app, cors): def setup_cors(app):
""" """
Configure CORS for the application. Configure CORS for the application.
Args: Args:
app: Flask application instance app: Flask application instance
cors: Flask-CORS instance
""" """
# CORS is already initialized in extensions.py
# This function provides additional configuration if needed @app.before_request
def handle_preflight():
"""Handle CORS preflight OPTIONS requests."""
if request.method == "OPTIONS":
origin = request.headers.get("Origin")
cors_origins = app.config.get("CORS_ORIGINS", [])
# Allow all origins if CORS_ORIGINS is "*" (string) or ["*"] (list with wildcard)
allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
if allow_all:
response = make_response("", 204)
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID"
response.headers["Access-Control-Max-Age"] = "3600"
return response
elif origin and origin in cors_origins:
response = make_response("", 204)
response.headers["Access-Control-Allow-Origin"] = origin
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID"
response.headers["Access-Control-Allow-Credentials"] = "true"
response.headers["Access-Control-Max-Age"] = "3600"
return response
@app.after_request @app.after_request
def after_request_cors(response): def after_request_cors(response):
@@ -19,11 +42,21 @@ def setup_cors(app, cors):
origin = request.headers.get("Origin") origin = request.headers.get("Origin")
cors_origins = app.config.get("CORS_ORIGINS", []) cors_origins = app.config.get("CORS_ORIGINS", [])
# Allow all origins in development if CORS_ORIGINS is "*" # Allow all origins if CORS_ORIGINS is "*" (string) or ["*"] (list with wildcard)
if cors_origins == "*" or origin in cors_origins: allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
response.headers["Access-Control-Allow-Origin"] = origin if cors_origins != "*" else "*"
if allow_all:
# When allowing all origins, set header to "*"
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS" response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID" response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID"
response.headers["Access-Control-Max-Age"] = "3600" response.headers["Access-Control-Max-Age"] = "3600"
elif origin and origin in cors_origins:
# When allowing specific origins, echo the request origin
response.headers["Access-Control-Allow-Origin"] = origin
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID"
response.headers["Access-Control-Allow-Credentials"] = "true"
response.headers["Access-Control-Max-Age"] = "3600"
return response return response
+10
View File
@@ -7,6 +7,11 @@ from app.models.authentication_method import AuthenticationMethod
from app.models.session import Session from app.models.session import Session
from app.models.audit_log import AuditLog from app.models.audit_log import AuditLog
from app.models.oidc_client import OIDCClient from app.models.oidc_client import OIDCClient
from app.models.oidc_authorization_code import OIDCAuthCode
from app.models.oidc_refresh_token import OIDCRefreshToken
from app.models.oidc_session import OIDCSession
from app.models.oidc_token_metadata import OIDCTokenMetadata
from app.models.oidc_audit_log import OIDCAuditLog
__all__ = [ __all__ = [
"BaseModel", "BaseModel",
@@ -17,4 +22,9 @@ __all__ = [
"Session", "Session",
"AuditLog", "AuditLog",
"OIDCClient", "OIDCClient",
"OIDCAuthCode",
"OIDCRefreshToken",
"OIDCSession",
"OIDCTokenMetadata",
"OIDCAuditLog",
] ]
+231
View File
@@ -0,0 +1,231 @@
"""OIDC Audit Log model for comprehensive OIDC event tracking."""
from datetime import datetime
from app.extensions import db
from app.models.base import BaseModel
class OIDCAuditLog(BaseModel):
"""OIDC Audit Log model for comprehensive OIDC event tracking.
This model logs all OIDC-related events for security, compliance,
and debugging purposes.
"""
__tablename__ = "oidc_audit_logs"
# Event type categorization
event_type = db.Column(db.String(100), nullable=False, index=True)
# Client and User references
client_id = db.Column(
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=True, index=True
)
user_id = db.Column(
db.String(36), db.ForeignKey("users.id"), nullable=True, index=True
)
# Event outcome
success = db.Column(db.Boolean, default=True, nullable=False, index=True)
# Error details (for failed events)
error_code = db.Column(db.String(100), nullable=True)
error_description = db.Column(db.Text, nullable=True)
# Request context
ip_address = db.Column(db.String(45), nullable=True, index=True)
user_agent = db.Column(db.Text, nullable=True)
request_id = db.Column(db.String(36), nullable=True, index=True)
# Additional event metadata
event_metadata = db.Column(db.JSON, nullable=True)
# Relationships
client = db.relationship("OIDCClient", back_populates="audit_logs")
user = db.relationship("User", back_populates="oidc_audit_logs")
def __repr__(self):
"""String representation of OIDCAuditLog."""
status = "success" if self.success else "failed"
return f"<OIDCAuditLog event={self.event_type} status={status} client={self.client_id}>"
@classmethod
def log_event(cls, event_type, client_id=None, user_id=None, success=True,
error_code=None, error_description=None, ip_address=None,
user_agent=None, request_id=None, event_metadata=None):
"""Log an OIDC event.
Args:
event_type: Type of event (e.g., "authorization_request", "token_issue")
client_id: The OIDC client ID
user_id: The user ID
success: Whether the event was successful
error_code: Error code if event failed
error_description: Error description if event failed
ip_address: Client IP address
user_agent: Client user agent
request_id: Request ID for correlation
event_metadata: Additional event metadata
Returns:
OIDCAuditLog instance
"""
log = cls(
event_type=event_type,
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
event_metadata=event_metadata,
)
db.session.add(log)
db.session.commit()
return log
@classmethod
def log_authorization_request(cls, client_id, user_id, redirect_uri, scope,
ip_address=None, user_agent=None, request_id=None,
success=True, error_code=None, error_description=None):
"""Log an authorization request event."""
return cls.log_event(
event_type="authorization_request",
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
event_metadata={
"redirect_uri": redirect_uri,
"scope": scope,
}
)
@classmethod
def log_token_issue(cls, client_id, user_id, token_type,
ip_address=None, user_agent=None, request_id=None):
"""Log a token issuance event."""
return cls.log_event(
event_type="token_issue",
client_id=client_id,
user_id=user_id,
success=True,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
event_metadata={"token_type": token_type}
)
@classmethod
def log_token_revocation(cls, client_id, user_id, token_type, reason=None,
ip_address=None, user_agent=None, request_id=None):
"""Log a token revocation event."""
return cls.log_event(
event_type="token_revocation",
client_id=client_id,
user_id=user_id,
success=True,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
event_metadata={
"token_type": token_type,
"reason": reason,
}
)
@classmethod
def log_authentication_failure(cls, client_id, error_code, error_description,
ip_address=None, user_agent=None, request_id=None):
"""Log an authentication failure event."""
return cls.log_event(
event_type="authentication_failure",
client_id=client_id,
success=False,
error_code=error_code,
error_description=error_description,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
)
@classmethod
def get_events_for_user(cls, user_id, limit=100):
"""Get audit events for a user.
Args:
user_id: The user ID
limit: Maximum number of events to return
Returns:
List of OIDCAuditLog instances
"""
return cls.query.filter_by(user_id=user_id, deleted_at=None)\
.order_by(cls.created_at.desc())\
.limit(limit)\
.all()
@classmethod
def get_events_for_client(cls, client_id, limit=100):
"""Get audit events for a client.
Args:
client_id: The client ID
limit: Maximum number of events to return
Returns:
List of OIDCAuditLog instances
"""
return cls.query.filter_by(client_id=client_id, deleted_at=None)\
.order_by(cls.created_at.desc())\
.limit(limit)\
.all()
@classmethod
def get_failed_events(cls, client_id=None, user_id=None, start_date=None,
end_date=None, limit=100):
"""Get failed audit events.
Args:
client_id: Optional client ID filter
user_id: Optional user ID filter
start_date: Optional start date filter
end_date: Optional end date filter
limit: Maximum number of events to return
Returns:
List of OIDCAuditLog instances
"""
query = cls.query.filter_by(success=False, deleted_at=None)
if client_id:
query = query.filter_by(client_id=client_id)
if user_id:
query = query.filter_by(user_id=user_id)
if start_date:
query = query.filter(cls.created_at >= start_date)
if end_date:
query = query.filter(cls.created_at <= end_date)
return query.order_by(cls.created_at.desc()).limit(limit).all()
def to_dict(self, exclude=None):
"""Convert to dictionary."""
return super().to_dict(exclude=exclude)
# Add relationship back to User model
from app.models.user import User
User.oidc_audit_logs = db.relationship(
"OIDCAuditLog", back_populates="user", cascade="all, delete-orphan"
)
# Add relationship back to OIDCClient model
from app.models.oidc_client import OIDCClient
OIDCClient.audit_logs = db.relationship(
"OIDCAuditLog", back_populates="client", cascade="all, delete-orphan"
)
+120
View File
@@ -0,0 +1,120 @@
"""OIDC Authorization Code model for auth code flow."""
from datetime import datetime, timedelta
from app.extensions import db
from app.models.base import BaseModel
class OIDCAuthCode(BaseModel):
"""OIDC Authorization Code model for authorization code flow.
Authorization codes are single-use, short-lived codes used in the
authorization code grant flow. The code is hashed for security.
"""
__tablename__ = "oidc_authorization_codes"
# Client and User references
client_id = db.Column(
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
)
user_id = db.Column(
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
)
# Authorization code (hashed for security)
code_hash = db.Column(db.String(255), nullable=False)
# Request parameters
redirect_uri = db.Column(db.String(512), nullable=False)
scope = db.Column(db.JSON, nullable=True) # Requested scopes
nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation
code_verifier = db.Column(db.String(255), nullable=True) # For PKCE
# Status tracking
expires_at = db.Column(db.DateTime, nullable=False, index=True)
used_at = db.Column(db.DateTime, nullable=True)
is_used = db.Column(db.Boolean, default=False, nullable=False)
# Request metadata
ip_address = db.Column(db.String(45), nullable=True)
user_agent = db.Column(db.Text, nullable=True)
# Relationships
client = db.relationship("OIDCClient", back_populates="authorization_codes")
user = db.relationship("User", back_populates="oidc_auth_codes")
def __repr__(self):
"""String representation of OIDCAuthCode."""
return f"<OIDCAuthCode client_id={self.client_id} user_id={self.user_id} used={self.is_used}>"
def is_expired(self):
"""Check if the authorization code has expired."""
return datetime.utcnow() > self.expires_at
def is_valid(self):
"""Check if the authorization code is valid for use."""
return not self.is_used and not self.is_expired() and self.deleted_at is None
def mark_as_used(self):
"""Mark the authorization code as used."""
self.is_used = True
self.used_at = datetime.utcnow()
db.session.commit()
@classmethod
def create_code(cls, client_id, user_id, code_hash, redirect_uri, scope=None,
nonce=None, code_verifier=None, ip_address=None, user_agent=None,
lifetime_seconds=600):
"""Create a new authorization code.
Args:
client_id: The OIDC client ID
user_id: The user ID
code_hash: Hashed authorization code
redirect_uri: The redirect URI
scope: Requested scopes
nonce: OIDC nonce
code_verifier: PKCE code verifier
ip_address: Client IP address
user_agent: Client user agent
lifetime_seconds: Code lifetime in seconds (default 10 minutes)
Returns:
OIDCAuthCode instance
"""
code = cls(
client_id=client_id,
user_id=user_id,
code_hash=code_hash,
redirect_uri=redirect_uri,
scope=scope,
nonce=nonce,
code_verifier=code_verifier,
expires_at=datetime.utcnow() + timedelta(seconds=lifetime_seconds),
ip_address=ip_address,
user_agent=user_agent,
)
db.session.add(code)
db.session.commit()
return code
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
# Always exclude code hash
exclude.append("code_hash")
exclude.append("code_verifier")
return super().to_dict(exclude=exclude)
# Add relationship back to User model
from app.models.user import User
User.oidc_auth_codes = db.relationship(
"OIDCAuthCode", back_populates="user", cascade="all, delete-orphan"
)
# Add relationship back to OIDCClient model
from app.models.oidc_client import OIDCClient
OIDCClient.authorization_codes = db.relationship(
"OIDCAuthCode", back_populates="client", cascade="all, delete-orphan"
)
+159
View File
@@ -0,0 +1,159 @@
"""OIDC Refresh Token model for token rotation."""
from datetime import datetime
from app.extensions import db
from app.models.base import BaseModel
class OIDCRefreshToken(BaseModel):
"""OIDC Refresh Token model for token refresh and rotation.
Refresh tokens are long-lived credentials used to obtain new access tokens.
They support token rotation for enhanced security.
"""
__tablename__ = "oidc_refresh_tokens"
# Client and User references
client_id = db.Column(
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
)
user_id = db.Column(
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
)
# Token (hashed for security)
token_hash = db.Column(db.String(255), nullable=False, unique=True, index=True)
# Associated access token ID
access_token_id = db.Column(
db.String(36), db.ForeignKey("sessions.id"), nullable=True, index=True
)
# Token scope
scope = db.Column(db.JSON, nullable=True) # Granted scopes
# Timing
expires_at = db.Column(db.DateTime, nullable=False, index=True)
# Revocation tracking
revoked_at = db.Column(db.DateTime, nullable=True)
revoked_reason = db.Column(db.String(255), nullable=True)
# Token rotation metadata
previous_token_hash = db.Column(db.String(255), nullable=True) # For rotation
rotation_count = db.Column(db.Integer, default=0, nullable=False)
# Request metadata
ip_address = db.Column(db.String(45), nullable=True)
user_agent = db.Column(db.Text, nullable=True)
# Relationships
client = db.relationship("OIDCClient", back_populates="refresh_tokens")
user = db.relationship("User", back_populates="oidc_refresh_tokens")
access_token = db.relationship("Session", back_populates="oidc_refresh_token")
def __repr__(self):
"""String representation of OIDCRefreshToken."""
return f"<OIDCRefreshToken client_id={self.client_id} user_id={self.user_id} revoked={self.is_revoked()}>"
def is_expired(self):
"""Check if the refresh token has expired."""
return datetime.utcnow() > self.expires_at
def is_revoked(self):
"""Check if the refresh token has been revoked."""
return self.revoked_at is not None
def is_valid(self):
"""Check if the refresh token is valid for use."""
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
def revoke(self, reason=None):
"""Revoke the refresh token.
Args:
reason: Optional reason for revocation
"""
self.revoked_at = datetime.utcnow()
self.revoked_reason = reason
db.session.commit()
def rotate(self, new_token_hash):
"""Rotate the refresh token (invalidate old, create new).
Args:
new_token_hash: Hash of the new refresh token
Returns:
self for chaining
"""
# Store reference to old token
self.previous_token_hash = self.token_hash
self.token_hash = new_token_hash
self.rotation_count += 1
# Extend expiration on rotation
from datetime import timedelta
self.expires_at = datetime.utcnow() + timedelta(days=30)
db.session.commit()
return self
@classmethod
def create_token(cls, client_id, user_id, token_hash, scope=None,
access_token_id=None, ip_address=None, user_agent=None,
lifetime_seconds=2592000):
"""Create a new refresh token.
Args:
client_id: The OIDC client ID
user_id: The user ID
token_hash: Hashed refresh token
scope: Granted scopes
access_token_id: Associated access token ID
ip_address: Client IP address
user_agent: Client user agent
lifetime_seconds: Token lifetime in seconds (default 30 days)
Returns:
OIDCRefreshToken instance
"""
from datetime import timedelta
token = cls(
client_id=client_id,
user_id=user_id,
token_hash=token_hash,
scope=scope,
access_token_id=access_token_id,
expires_at=datetime.utcnow() + timedelta(seconds=lifetime_seconds),
ip_address=ip_address,
user_agent=user_agent,
)
db.session.add(token)
db.session.commit()
return token
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
# Always exclude token hashes
exclude.append("token_hash")
exclude.append("previous_token_hash")
return super().to_dict(exclude=exclude)
# Add relationship back to User model
from app.models.user import User
User.oidc_refresh_tokens = db.relationship(
"OIDCRefreshToken", back_populates="user", cascade="all, delete-orphan"
)
# Add relationship back to OIDCClient model
from app.models.oidc_client import OIDCClient
OIDCClient.refresh_tokens = db.relationship(
"OIDCRefreshToken", back_populates="client", cascade="all, delete-orphan"
)
# Add relationship back to Session model
from app.models.session import Session
Session.oidc_refresh_token = db.relationship(
"OIDCRefreshToken", back_populates="access_token", uselist=False
)
+162
View File
@@ -0,0 +1,162 @@
"""OIDC Session model for OIDC session tracking."""
from datetime import datetime
from app.extensions import db
from app.models.base import BaseModel
class OIDCSession(BaseModel):
"""OIDC Session model for tracking OIDC authentication sessions.
This model tracks the state during the OIDC authentication flow,
including PKCE parameters and nonce validation.
"""
__tablename__ = "oidc_sessions"
# User reference
user_id = db.Column(
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
)
# Client reference
client_id = db.Column(
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
)
# State management
state = db.Column(db.String(255), nullable=False, index=True)
nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation
# Authorization request parameters
redirect_uri = db.Column(db.String(512), nullable=False)
scope = db.Column(db.JSON, nullable=True) # Requested scopes
# PKCE parameters
code_challenge = db.Column(db.String(255), nullable=True)
code_challenge_method = db.Column(db.String(10), nullable=True) # "S256" or "plain"
# Timing
expires_at = db.Column(db.DateTime, nullable=False, index=True)
authenticated_at = db.Column(db.DateTime, nullable=True)
# Relationships
user = db.relationship("User", back_populates="oidc_sessions")
client = db.relationship("OIDCClient", back_populates="oidc_sessions")
def __repr__(self):
"""String representation of OIDCSession."""
return f"<OIDCSession user_id={self.user_id} client_id={self.client_id} state={self.state[:8]}...>"
def is_expired(self):
"""Check if the OIDC session has expired."""
return datetime.utcnow() > self.expires_at
def is_authenticated(self):
"""Check if the user has been authenticated in this session."""
return self.authenticated_at is not None
def mark_authenticated(self):
"""Mark the session as authenticated."""
self.authenticated_at = datetime.utcnow()
db.session.commit()
def validate_nonce(self, expected_nonce):
"""Validate the nonce matches the expected value.
Args:
expected_nonce: The expected nonce value
Returns:
bool: True if nonce matches
"""
return self.nonce == expected_nonce
def validate_code_challenge(self, code_verifier):
"""Validate the code verifier against the stored code challenge.
Args:
code_verifier: The PKCE code verifier
Returns:
bool: True if code challenge is valid
"""
if not self.code_challenge:
return False
if self.code_challenge_method == "S256":
import hashlib
import base64
# SHA256 hash of code_verifier
digest = hashlib.sha256(code_verifier.encode()).digest()
# Base64 URL encode without padding
expected = base64.urlsafe_b64encode(digest).decode().rstrip("=")
return self.code_challenge == expected
elif self.code_challenge_method == "plain":
return self.code_challenge == code_verifier
return False
@classmethod
def create_session(cls, user_id, client_id, state, redirect_uri, scope=None,
nonce=None, code_challenge=None, code_challenge_method=None,
lifetime_seconds=600):
"""Create a new OIDC session.
Args:
user_id: The user ID
client_id: The OIDC client ID
state: The state parameter
redirect_uri: The redirect URI
scope: Requested scopes
nonce: OIDC nonce
code_challenge: PKCE code challenge
code_challenge_method: PKCE method ("S256" or "plain")
lifetime_seconds: Session lifetime in seconds
Returns:
OIDCSession instance
"""
from datetime import timedelta
session = cls(
user_id=user_id,
client_id=client_id,
state=state,
redirect_uri=redirect_uri,
scope=scope,
nonce=nonce,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
expires_at=datetime.utcnow() + timedelta(seconds=lifetime_seconds),
)
db.session.add(session)
db.session.commit()
return session
@classmethod
def get_by_state(cls, state):
"""Get a session by state parameter.
Args:
state: The state parameter
Returns:
OIDCSession instance or None
"""
return cls.query.filter_by(state=state, deleted_at=None).first()
def to_dict(self, exclude=None):
"""Convert to dictionary."""
return super().to_dict(exclude=exclude)
# Add relationship back to User model
from app.models.user import User
User.oidc_sessions = db.relationship(
"OIDCSession", back_populates="user", cascade="all, delete-orphan"
)
# Add relationship back to OIDCClient model
from app.models.oidc_client import OIDCClient
OIDCClient.oidc_sessions = db.relationship(
"OIDCSession", back_populates="client", cascade="all, delete-orphan"
)
+192
View File
@@ -0,0 +1,192 @@
"""OIDC Token Metadata model for token revocation tracking."""
import uuid
from datetime import datetime
from app.extensions import db
from app.models.base import BaseModel
class OIDCTokenMetadata(BaseModel):
"""OIDC Token Metadata model for tracking issued tokens.
This model stores metadata about issued tokens (access tokens, refresh tokens, ID tokens)
for the purpose of token revocation. The id field matches the JTI (JWT ID) claim.
"""
__tablename__ = "oidc_token_metadata"
# Token identifier (matches JTI in JWT)
id = db.Column(
db.String(36), primary_key=True, default=lambda: str(uuid.uuid4())
)
# Client and User references
client_id = db.Column(
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
)
user_id = db.Column(
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
)
# Token type
token_type = db.Column(db.String(50), nullable=False) # "access_token", "refresh_token", "id_token"
# Token identifier for revocation lookup
token_jti = db.Column(db.String(255), nullable=False, index=True) # JWT ID claim
# Timing
expires_at = db.Column(db.DateTime, nullable=False, index=True)
# Revocation tracking
revoked_at = db.Column(db.DateTime, nullable=True)
revoked_reason = db.Column(db.String(255), nullable=True)
# Relationships
client = db.relationship("OIDCClient", back_populates="token_metadata")
user = db.relationship("User", back_populates="oidc_token_metadata")
def __repr__(self):
"""String representation of OIDCTokenMetadata."""
return f"<OIDCTokenMetadata jti={self.token_jti[:8]}... type={self.token_type} revoked={self.is_revoked()}>"
def is_expired(self):
"""Check if the token has expired."""
return datetime.utcnow() > self.expires_at
def is_revoked(self):
"""Check if the token has been revoked."""
return self.revoked_at is not None
def is_valid(self):
"""Check if the token is valid (not expired and not revoked)."""
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
def revoke(self, reason=None):
"""Revoke the token.
Args:
reason: Optional reason for revocation
"""
self.revoked_at = datetime.utcnow()
self.revoked_reason = reason
db.session.commit()
@classmethod
def create_metadata(cls, client_id, user_id, token_type, token_jti,
expires_at, ip_address=None, user_agent=None):
"""Create token metadata for tracking.
Args:
client_id: The OIDC client ID
user_id: The user ID
token_type: Type of token ("access_token", "refresh_token", "id_token")
token_jti: JWT ID claim
expires_at: Token expiration datetime
ip_address: Client IP address
user_agent: Client user agent
Returns:
OIDCTokenMetadata instance
"""
metadata = cls(
id=str(uuid.uuid4()),
client_id=client_id,
user_id=user_id,
token_type=token_type,
token_jti=token_jti,
expires_at=expires_at,
)
db.session.add(metadata)
db.session.commit()
return metadata
@classmethod
def get_by_jti(cls, token_jti):
"""Get token metadata by JWT ID.
Args:
token_jti: The JWT ID
Returns:
OIDCTokenMetadata instance or None
"""
return cls.query.filter_by(token_jti=token_jti, deleted_at=None).first()
@classmethod
def revoke_by_jti(cls, token_jti, reason=None):
"""Revoke a token by its JWT ID.
Args:
token_jti: The JWT ID
reason: Optional revocation reason
Returns:
bool: True if token was found and revoked
"""
metadata = cls.get_by_jti(token_jti)
if metadata:
metadata.revoke(reason)
return True
return False
@classmethod
def revoke_all_for_user(cls, user_id, client_id=None, reason=None):
"""Revoke all tokens for a user.
Args:
user_id: The user ID
client_id: Optional client ID to filter by
reason: Optional revocation reason
Returns:
int: Number of tokens revoked
"""
query = cls.query.filter_by(user_id=user_id, deleted_at=None)
if client_id:
query = query.filter_by(client_id=client_id)
tokens = query.filter(cls.revoked_at == None).all()
count = 0
for token in tokens:
token.revoke(reason)
count += 1
return count
@classmethod
def revoke_all_for_client(cls, client_id, user_id=None, reason=None):
"""Revoke all tokens for a client.
Args:
client_id: The client ID
user_id: Optional user ID to filter by
reason: Optional revocation reason
Returns:
int: Number of tokens revoked
"""
query = cls.query.filter_by(client_id=client_id, deleted_at=None)
if user_id:
query = query.filter_by(user_id=user_id)
tokens = query.filter(cls.revoked_at == None).all()
count = 0
for token in tokens:
token.revoke(reason)
count += 1
return count
def to_dict(self, exclude=None):
"""Convert to dictionary."""
return super().to_dict(exclude=exclude)
# Add relationship back to User model
from app.models.user import User
User.oidc_token_metadata = db.relationship(
"OIDCTokenMetadata", back_populates="user", cascade="all, delete-orphan"
)
# Add relationship back to OIDCClient model
from app.models.oidc_client import OIDCClient
OIDCClient.token_metadata = db.relationship(
"OIDCTokenMetadata", back_populates="client", cascade="all, delete-orphan"
)
+11
View File
@@ -4,6 +4,11 @@ from app.services.user_service import UserService
from app.services.organization_service import OrganizationService from app.services.organization_service import OrganizationService
from app.services.session_service import SessionService from app.services.session_service import SessionService
from app.services.audit_service import AuditService from app.services.audit_service import AuditService
from app.services.oidc_service import OIDCService, OIDCError
from app.services.oidc_jwks_service import OIDCJWKSService
from app.services.oidc_token_service import OIDCTokenService
from app.services.oidc_session_service import OIDCSessionService
from app.services.oidc_audit_service import OIDCAuditService
__all__ = [ __all__ = [
"AuthService", "AuthService",
@@ -11,4 +16,10 @@ __all__ = [
"OrganizationService", "OrganizationService",
"SessionService", "SessionService",
"AuditService", "AuditService",
"OIDCService",
"OIDCError",
"OIDCJWKSService",
"OIDCTokenService",
"OIDCSessionService",
"OIDCAuditService",
] ]
+408
View File
@@ -0,0 +1,408 @@
"""OIDC Audit Service for comprehensive OIDC event logging."""
from datetime import datetime
from typing import Dict, List, Optional
from flask import g
from app.models import OIDCAuditLog, OIDCClient, User
from app.exceptions.validation_exceptions import NotFoundError
class OIDCAuditService:
"""Service for OIDC-specific audit logging.
This service provides methods to log all OIDC-related events including:
- Authorization requests and responses
- Token issuance and refresh
- Token revocation
- UserInfo endpoint access
- Authentication failures
"""
# Event type constants
EVENT_AUTHORIZATION_REQUEST = "authorization_request"
EVENT_AUTHORIZATION_RESPONSE = "authorization_response"
EVENT_TOKEN_ISSUE = "token_issue"
EVENT_TOKEN_REFRESH = "token_refresh"
EVENT_TOKEN_REVOCATION = "token_revocation"
EVENT_TOKEN_INTROSPECTION = "token_introspection"
EVENT_USERINFO_ACCESS = "userinfo_access"
EVENT_AUTHENTICATION_FAILURE = "authentication_failure"
EVENT_AUTHORIZATION_FAILURE = "authorization_failure"
EVENT_JWKS_ACCESS = "jwks_access"
EVENT_REGISTRATION = "client_registration"
@classmethod
def _get_request_context(cls) -> Dict:
"""Extract request context for logging.
Returns:
Dictionary with IP, user_agent, and request_id
"""
from flask import request
return {
"ip_address": request.remote_addr if request else None,
"user_agent": request.headers.get("User-Agent") if request else None,
"request_id": g.get("request_id"),
}
@classmethod
def log_event(
cls,
event_type: str,
client_id: str = None,
user_id: str = None,
success: bool = True,
error_code: str = None,
error_description: str = None,
metadata: Dict = None
) -> OIDCAuditLog:
"""Log a generic OIDC event.
Args:
event_type: Type of event
client_id: OIDC client ID
user_id: User ID
success: Whether the event was successful
error_code: Error code if failed
error_description: Error description if failed
metadata: Additional event metadata
Returns:
OIDCAuditLog instance
"""
context = cls._get_request_context()
log = OIDCAuditLog.log_event(
event_type=event_type,
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
ip_address=context["ip_address"],
user_agent=context["user_agent"],
request_id=context["request_id"],
metadata=metadata,
)
return log
@classmethod
def log_authorization_event(
cls,
client_id: str,
user_id: str = None,
success: bool = True,
error_code: str = None,
error_description: str = None,
redirect_uri: str = None,
scope: list = None,
response_type: str = None
) -> OIDCAuditLog:
"""Log an authorization event.
Args:
client_id: OIDC client ID
user_id: User ID (if authenticated)
success: Whether authorization was successful
error_code: Error code if failed
error_description: Error description if failed
redirect_uri: Redirect URI from request
scope: Requested scopes
response_type: Response type (e.g., "code")
Returns:
OIDCAuditLog instance
"""
metadata = {
"redirect_uri": redirect_uri,
"scope": scope,
"response_type": response_type,
}
metadata = {k: v for k, v in metadata.items() if v is not None}
return cls.log_event(
event_type=cls.EVENT_AUTHORIZATION_REQUEST,
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
metadata=metadata,
)
@classmethod
def log_token_event(
cls,
client_id: str,
user_id: str = None,
token_type: str = "access_token",
success: bool = True,
error_code: str = None,
error_description: str = None,
grant_type: str = None,
scopes: list = None
) -> OIDCAuditLog:
"""Log a token issuance or refresh event.
Args:
client_id: OIDC client ID
user_id: User ID
token_type: Type of token issued
success: Whether token issuance was successful
error_code: Error code if failed
error_description: Error description if failed
grant_type: Grant type used (e.g., "authorization_code", "refresh_token")
scopes: Scopes included in the token
Returns:
OIDCAuditLog instance
"""
metadata = {
"token_type": token_type,
"grant_type": grant_type,
"scopes": scopes,
}
metadata = {k: v for k, v in metadata.items() if v is not None}
return cls.log_event(
event_type=cls.EVENT_TOKEN_ISSUE if token_type else cls.EVENT_TOKEN_REFRESH,
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
metadata=metadata,
)
@classmethod
def log_userinfo_event(
cls,
access_token: str = None,
user_id: str = None,
client_id: str = None,
success: bool = True,
error_code: str = None,
error_description: str = None,
scopes_claimed: list = None
) -> OIDCAuditLog:
"""Log a UserInfo endpoint access event.
Args:
access_token: Access token used (masked)
user_id: User ID returned
client_id: Client ID making the request
success: Whether access was successful
error_code: Error code if failed
error_description: Error description if failed
scopes_claimed: Scopes claimed in the request
Returns:
OIDCAuditLog instance
"""
# Mask the access token for security
masked_token = None
if access_token:
masked_token = access_token[:8] + "..." + access_token[-4:] if len(access_token) > 12 else "***"
metadata = {
"token_prefix": masked_token,
"scopes_claimed": scopes_claimed,
}
metadata = {k: v for k, v in metadata.items() if v is not None}
return cls.log_event(
event_type=cls.EVENT_USERINFO_ACCESS,
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
metadata=metadata,
)
@classmethod
def log_token_revocation_event(
cls,
client_id: str,
user_id: str = None,
token_type: str = "access_token",
reason: str = None,
success: bool = True,
error_code: str = None,
error_description: str = None
) -> OIDCAuditLog:
"""Log a token revocation event.
Args:
client_id: OIDC client ID
user_id: User ID
token_type: Type of token being revoked
reason: Revocation reason
success: Whether revocation was successful
error_code: Error code if failed
error_description: Error description if failed
Returns:
OIDCAuditLog instance
"""
metadata = {
"token_type": token_type,
"reason": reason,
}
metadata = {k: v for k, v in metadata.items() if v is not None}
return cls.log_event(
event_type=cls.EVENT_TOKEN_REVOCATION,
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
metadata=metadata,
)
@classmethod
def log_authentication_failure(
cls,
client_id: str = None,
error_code: str = "authentication_failed",
error_description: str = "Authentication failed",
user_id: str = None
) -> OIDCAuditLog:
"""Log an authentication failure event.
Args:
client_id: OIDC client ID
error_code: Error code
error_description: Error description
user_id: User ID if known
Returns:
OIDCAuditLog instance
"""
return cls.log_event(
event_type=cls.EVENT_AUTHENTICATION_FAILURE,
client_id=client_id,
user_id=user_id,
success=False,
error_code=error_code,
error_description=error_description,
)
@classmethod
def get_events_for_user(
cls,
user_id: str,
limit: int = 100,
include_deleted: bool = False
) -> List[OIDCAuditLog]:
"""Get audit events for a specific user.
Args:
user_id: User ID
limit: Maximum number of events to return
include_deleted: Include soft-deleted events
Returns:
List of OIDCAuditLog instances
"""
return OIDCAuditLog.get_events_for_user(user_id, limit)
@classmethod
def get_events_for_client(
cls,
client_id: str,
limit: int = 100
) -> List[OIDCAuditLog]:
"""Get audit events for a specific client.
Args:
client_id: Client ID
limit: Maximum number of events to return
Returns:
List of OIDCAuditLog instances
"""
return OIDCAuditLog.get_events_for_client(client_id, limit)
@classmethod
def get_failed_events(
cls,
client_id: str = None,
user_id: str = None,
start_date: datetime = None,
end_date: datetime = None,
limit: int = 100
) -> List[OIDCAuditLog]:
"""Get failed audit events for analysis.
Args:
client_id: Optional client ID filter
user_id: Optional user ID filter
start_date: Optional start date filter
end_date: Optional end date filter
limit: Maximum number of events to return
Returns:
List of failed OIDCAuditLog instances
"""
return OIDCAuditLog.get_failed_events(
client_id=client_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
limit=limit,
)
@classmethod
def get_event_summary(
cls,
client_id: str = None,
days: int = 30
) -> Dict:
"""Get a summary of audit events.
Args:
client_id: Optional client ID filter
days: Number of days to look back
Returns:
Summary dictionary with event counts
"""
from datetime import timedelta
start_date = datetime.utcnow() - timedelta(days=days)
query = OIDCAuditLog.query.filter(
OIDCAuditLog.created_at >= start_date
)
if client_id:
query = query.filter_by(client_id=client_id)
events = query.all()
# Count by event type
event_counts = {}
success_count = 0
failure_count = 0
for event in events:
event_type = event.event_type
event_counts[event_type] = event_counts.get(event_type, 0) + 1
if event.success:
success_count += 1
else:
failure_count += 1
return {
"total_events": len(events),
"successful_events": success_count,
"failed_events": failure_count,
"by_event_type": event_counts,
"period_days": days,
}
+300
View File
@@ -0,0 +1,300 @@
"""OIDC JWKS Service for key management and rotation."""
import uuid
import json
import hashlib
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple
from flask import current_app
from app.extensions import db
class JWKSKey:
"""Represents a JWKS key entry."""
def __init__(self, kid: str, private_key: str, public_key: str,
algorithm: str = "RS256", created_at: datetime = None,
expires_at: datetime = None, is_active: bool = True):
self.kid = kid
self.private_key = private_key
self.public_key = public_key
self.algorithm = algorithm
self.created_at = created_at or datetime.utcnow()
self.expires_at = expires_at or datetime.utcnow() + timedelta(days=365)
self.is_active = is_active
def to_jwk(self) -> Dict:
"""Convert to JWK format for JWKS endpoint."""
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from cryptography.hazmat.backends import default_backend
# Import cryptography here to avoid issues if not installed
try:
# Get public key from PEM
public_key = serialization.load_pem_public_key(
self.public_key.encode(), backend=default_backend()
)
# Get RSA parameters
public_numbers = public_key.public_numbers()
return {
"kty": "RSA",
"kid": self.kid,
"use": "sig",
"alg": self.algorithm,
"n": _base64url_encode(public_numbers.n),
"e": _base64url_encode(public_numbers.e),
}
except ImportError:
# Fallback for when cryptography is not installed
return {
"kty": "RSA",
"kid": self.kid,
"use": "sig",
"alg": self.algorithm,
}
def to_dict(self) -> Dict:
"""Convert to dictionary for storage."""
return {
"kid": self.kid,
"private_key": self.private_key,
"public_key": self.public_key,
"algorithm": self.algorithm,
"created_at": self.created_at.isoformat(),
"expires_at": self.expires_at.isoformat(),
"is_active": self.is_active,
}
@classmethod
def from_dict(cls, data: Dict) -> "JWKSKey":
"""Create from dictionary."""
return cls(
kid=data["kid"],
private_key=data["private_key"],
public_key=data["public_key"],
algorithm=data.get("algorithm", "RS256"),
created_at=datetime.fromisoformat(data["created_at"]),
expires_at=datetime.fromisoformat(data["expires_at"]),
is_active=data.get("is_active", True),
)
def _base64url_encode(value: int) -> str:
"""Encode an integer to base64url format."""
import base64
byte_length = (value.bit_length() + 7) // 8 or 1
encoded = value.to_bytes(byte_length, byteorder="big")
return base64.urlsafe_b64encode(encoded).decode().rstrip("=")
class OIDCJWKSService:
"""Service for managing OIDC signing keys (JWKS).
This service handles RSA key pair generation, rotation, and JWKS document
generation for the OIDC implementation.
"""
_instance = None
_keys: Dict[str, JWKSKey] = {}
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._keys = {}
return cls._instance
@classmethod
def reset(cls):
"""Reset the singleton (for testing)."""
cls._instance = None
cls._keys = {}
def _generate_kid(self, private_key: str) -> str:
"""Generate a key ID from the private key fingerprint."""
kid_hash = hashlib.sha256(private_key.encode()).hexdigest()[:32]
return kid_hash
def _generate_rsa_key_pair(self) -> Tuple[str, str]:
"""Generate a new RSA key pair in PEM format.
Returns:
Tuple of (private_key_pem, public_key_pem)
"""
try:
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
# Generate RSA private key
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
backend=default_backend()
)
# Get public key
public_key = private_key.public_key()
# Serialize to PEM
private_pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption()
).decode()
public_pem = public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
).decode()
return private_pem, public_pem
except ImportError:
# Fallback for testing without cryptography
import secrets
return f"private_key_{secrets.token_hex(32)}", f"public_key_{secrets.token_hex(32)}"
def get_jwks(self, include_private_keys: bool = False) -> Dict:
"""Get the JWKS document containing public keys.
Args:
include_private_keys: Whether to include private keys (for internal use only)
Returns:
JWKS document dictionary
"""
now = datetime.utcnow()
keys = []
for kid, key in self._keys.items():
# Only include active, non-expired keys
if key.is_active and key.expires_at > now:
if include_private_keys:
keys.append(key.to_dict())
else:
keys.append(key.to_jwk())
return {
"keys": keys
}
def get_signing_key(self) -> Optional[JWKSKey]:
"""Get the current active signing key.
Returns:
JWKSKey instance or None if no active key
"""
now = datetime.utcnow()
for kid, key in self._keys.items():
if key.is_active and key.expires_at > now:
return key
return None
def get_key_by_kid(self, kid: str) -> Optional[JWKSKey]:
"""Get a specific key by its ID.
Args:
kid: Key ID to look up
Returns:
JWKSKey instance or None if not found
"""
return self._keys.get(kid)
def generate_new_key_pair(self, expires_in_days: int = 365) -> JWKSKey:
"""Generate a new RSA key pair for signing.
Args:
expires_in_days: Days until key expiration
Returns:
JWKSKey instance
"""
private_key, public_key = self._generate_rsa_key_pair()
kid = self._generate_kid(private_key)
now = datetime.utcnow()
key = JWKSKey(
kid=kid,
private_key=private_key,
public_key=public_key,
algorithm="RS256",
created_at=now,
expires_at=now + timedelta(days=expires_in_days),
is_active=True,
)
self._keys[kid] = key
# Deactivate old keys (but keep them for grace period)
for old_kid in self._keys:
if old_kid != kid:
self._keys[old_kid].is_active = False
return key
def rotate_keys(self, grace_period_hours: int = 24) -> Tuple[JWKSKey, List[str]]:
"""Rotate signing keys, keeping previous key active for grace period.
Args:
grace_period_hours: Hours to keep old keys active
Returns:
Tuple of (new_key, list_of_deprecated_kids)
"""
now = datetime.utcnow()
grace_end = now + timedelta(hours=grace_period_hours)
# Mark current key as deprecated
current_key = self.get_signing_key()
deprecated_kids = []
if current_key:
deprecated_kids.append(current_key.kid)
# Keep key active but mark as deprecated
current_key.is_active = False
current_key.expires_at = grace_end
# Generate new key
new_key = self.generate_new_key_pair()
# Clean up expired keys
expired_kids = [
kid for kid, key in self._keys.items()
if key.expires_at < now
]
for kid in expired_kids:
del self._keys[kid]
return new_key, deprecated_kids
def verify_key_exists(self, kid: str) -> bool:
"""Check if a key with the given ID exists and is valid.
Args:
kid: Key ID to check
Returns:
True if key exists and is valid
"""
key = self.get_key_by_kid(kid)
if not key:
return False
now = datetime.utcnow()
return key.is_active and key.expires_at > now
def initialize_with_key(self) -> JWKSKey:
"""Initialize the service with a key if none exists.
Returns:
JWKSKey instance
"""
if not self._keys:
return self.generate_new_key_pair()
return self.get_signing_key()
+745
View File
@@ -0,0 +1,745 @@
"""OIDC Service - Main OIDC service layer."""
import secrets
import hashlib
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple
from flask import current_app, g
from app.extensions import db
from app.models import (
User, OIDCClient, OIDCAuthCode, OIDCRefreshToken,
OIDCSession, OIDCTokenMetadata
)
from app.exceptions.validation_exceptions import (
ValidationError, NotFoundError, BadRequestError
)
from app.exceptions.auth_exceptions import UnauthorizedError, InvalidTokenError
from app.services.oidc_token_service import OIDCTokenService
from app.services.oidc_session_service import OIDCSessionService
from app.services.oidc_audit_service import OIDCAuditService
from app.services.oidc_jwks_service import OIDCJWKSService
class OIDCError(Exception):
"""Base exception for OIDC errors."""
def __init__(self, error: str, error_description: str = None, status_code: int = 400):
self.error = error
self.error_description = error_description
self.status_code = status_code
class InvalidClientError(OIDCError):
"""Raised when client authentication fails."""
def __init__(self, error_description: str = "Invalid client"):
super().__init__("invalid_client", error_description, 401)
class InvalidGrantError(OIDCError):
"""Raised when grant is invalid."""
def __init__(self, error_description: str = "Invalid grant"):
super().__init__("invalid_grant", error_description, 400)
class InvalidRequestError(OIDCError):
"""Raised when request is malformed."""
def __init__(self, error_description: str = "Invalid request"):
super().__init__("invalid_request", error_description, 400)
class OIDCService:
"""Main OIDC service handling all OpenID Connect operations.
This service provides:
- Authorization code generation and validation
- Token generation (access, refresh, ID tokens)
- Token refresh with rotation
- Token validation and introspection
- Token revocation
"""
@staticmethod
def _generate_code() -> str:
"""Generate a secure authorization code.
Returns:
URL-safe base64 encoded code
"""
return secrets.token_urlsafe(32)
@staticmethod
def _hash_value(value: str) -> str:
"""Hash a value for secure storage.
Args:
value: Value to hash
Returns:
SHA256 hash
"""
return hashlib.sha256(value.encode()).hexdigest()
@classmethod
def generate_authorization_code(
cls,
client_id: str,
user_id: str,
redirect_uri: str,
scope: list,
state: str,
nonce: str,
code_challenge: str = None,
code_challenge_method: str = None,
ip_address: str = None,
user_agent: str = None
) -> str:
"""Generate an authorization code for the auth code flow.
Args:
client_id: OIDC client ID
user_id: User ID
redirect_uri: Redirect URI
scope: Requested scopes
state: State parameter
nonce: Nonce for ID token
code_challenge: PKCE code challenge
code_challenge_method: PKCE method ("S256" or "plain")
ip_address: Client IP address
user_agent: Client user agent
Returns:
Authorization code string
Raises:
ValidationError: If parameters are invalid
NotFoundError: If client not found
"""
# Validate client exists and is active
client = OIDCClient.query.filter_by(client_id=client_id).first()
if not client:
raise NotFoundError("Client not found")
if not client.is_active:
raise ValidationError("Client is not active")
# Validate redirect URI
if not client.is_redirect_uri_allowed(redirect_uri):
raise ValidationError("Invalid redirect_uri")
# Validate scopes
allowed_scopes = client.scopes or []
valid_scopes = [s for s in scope if s in allowed_scopes]
if not valid_scopes:
raise ValidationError("Invalid scopes")
# Generate authorization code
code = cls._generate_code()
code_hash = cls._hash_value(code)
# Create auth code record
auth_code = OIDCAuthCode.create_code(
client_id=client.id,
user_id=user_id,
code_hash=code_hash,
redirect_uri=redirect_uri,
scope=valid_scopes,
nonce=nonce,
code_verifier=code_challenge, # Store for validation
ip_address=ip_address,
user_agent=user_agent,
lifetime_seconds=600, # 10 minutes
)
# Log authorization event
OIDCAuditService.log_authorization_event(
client_id=client_id,
user_id=user_id,
success=True,
redirect_uri=redirect_uri,
scope=valid_scopes,
)
return code
@classmethod
def validate_authorization_code(
cls,
code: str,
client_id: str,
redirect_uri: str,
code_verifier: str = None,
ip_address: str = None,
user_agent: str = None
) -> Tuple[Dict, User]:
"""Validate and consume an authorization code.
Args:
code: Authorization code
client_id: OIDC client ID
redirect_uri: Redirect URI
code_verifier: PKCE code verifier (required if PKCE was used)
ip_address: Client IP address
user_agent: Client user agent
Returns:
Tuple of (claims dict, User instance)
Raises:
InvalidGrantError: If code is invalid
ValidationError: If PKCE validation fails
"""
# Get client
client = OIDCClient.query.filter_by(client_id=client_id).first()
if not client:
raise InvalidGrantError("Invalid client")
# Hash the provided code and find matching auth code
code_hash = cls._hash_value(code)
auth_code = OIDCAuthCode.query.filter_by(
code_hash=code_hash,
client_id=client.id,
deleted_at=None
).first()
if not auth_code:
OIDCAuditService.log_authorization_event(
client_id=client_id,
success=False,
error_code="invalid_grant",
error_description="Invalid or expired authorization code",
)
raise InvalidGrantError("Invalid or expired authorization code")
# Check if already used
if auth_code.is_used:
OIDCAuditService.log_authorization_event(
client_id=client_id,
user_id=auth_code.user_id,
success=False,
error_code="invalid_grant",
error_description="Authorization code already used",
)
raise InvalidGrantError("Authorization code already used")
# Check expiration
if auth_code.is_expired():
OIDCAuditService.log_authorization_event(
client_id=client_id,
user_id=auth_code.user_id,
success=False,
error_code="invalid_grant",
error_description="Authorization code expired",
)
raise InvalidGrantError("Authorization code expired")
# Validate redirect URI
if auth_code.redirect_uri != redirect_uri:
raise InvalidGrantError("Invalid redirect_uri")
# Validate PKCE if required
if client.require_pkce and auth_code.code_verifier:
if not code_verifier:
raise ValidationError("code_verifier is required")
# Verify code verifier
expected_challenge = cls._compute_code_challenge(code_verifier, "S256")
if expected_challenge != auth_code.code_verifier:
OIDCAuditService.log_authorization_event(
client_id=client_id,
user_id=auth_code.user_id,
success=False,
error_code="invalid_grant",
error_description="Invalid code_verifier",
)
raise InvalidGrantError("Invalid code_verifier")
# Mark code as used
auth_code.mark_as_used()
# Get user
user = User.query.get(auth_code.user_id)
if not user:
raise InvalidGrantError("User not found")
claims = {
"user_id": auth_code.user_id,
"client_id": client_id,
"redirect_uri": redirect_uri,
"scope": auth_code.scope,
"nonce": auth_code.nonce,
}
return claims, user
@classmethod
def _compute_code_challenge(cls, verifier: str, method: str = "S256") -> str:
"""Compute PKCE code challenge from verifier.
Args:
verifier: Code verifier
method: Challenge method
Returns:
Code challenge
"""
import hashlib
import base64
if method == "S256":
digest = hashlib.sha256(verifier.encode()).digest()
return base64.urlsafe_b64encode(digest).decode().rstrip("=")
return verifier
@classmethod
def generate_tokens(
cls,
client_id: str,
user_id: str,
scope: list,
nonce: str = None,
refresh_token: str = None,
ip_address: str = None,
user_agent: str = None,
auth_time: int = None
) -> Dict:
"""Generate access token, ID token, and refresh token.
Args:
client_id: OIDC client ID
user_id: User ID
scope: Granted scopes
nonce: Nonce for ID token
refresh_token: Existing refresh token (for rotation)
ip_address: Client IP address
user_agent: Client user agent
auth_time: Authentication time
Returns:
Dictionary with tokens
"""
import hashlib
# Get client
client = OIDCClient.query.filter_by(client_id=client_id).first()
if not client:
raise InvalidClientError()
# Generate access token
access_token_jti = OIDCTokenService._generate_jti()
access_token = OIDCTokenService.create_access_token(
client_id=client_id,
user_id=user_id,
scope=scope,
jti=access_token_jti,
)
# Generate ID token
id_token = OIDCTokenService.create_id_token(
client_id=client_id,
user_id=user_id,
nonce=nonce,
scope=scope,
access_token=access_token,
auth_time=auth_time,
)
# Generate or rotate refresh token
if "refresh_token" in (client.grant_types or []):
if refresh_token:
# Rotate existing refresh token
refresh_token_obj = OIDCRefreshToken.query.filter_by(
token_hash=hashlib.sha256(refresh_token.encode()).hexdigest(),
deleted_at=None
).first()
if refresh_token_obj and refresh_token_obj.is_valid():
# Create new refresh token
new_refresh, new_hash = OIDCTokenService.create_refresh_token(
client_id=client_id,
user_id=user_id,
scope=scope,
access_token_id=access_token_jti,
)
# Rotate in database
refresh_token_obj.rotate(new_hash)
final_refresh_token = new_refresh
else:
final_refresh_token = None
else:
# Create new refresh token
final_refresh_token, refresh_hash = OIDCTokenService.create_refresh_token(
client_id=client_id,
user_id=user_id,
scope=scope,
access_token_id=access_token_jti,
)
# Store refresh token
OIDCRefreshToken.create_token(
client_id=client.id,
user_id=user_id,
token_hash=refresh_hash,
scope=scope,
access_token_id=access_token_jti,
ip_address=ip_address,
user_agent=user_agent,
lifetime_seconds=client.refresh_token_lifetime or 2592000,
)
else:
final_refresh_token = None
# Store token metadata
client_db_id = client.id
# Access token metadata
OIDCTokenMetadata.create_metadata(
client_id=client_db_id,
user_id=user_id,
token_type="access_token",
token_jti=access_token_jti,
expires_at=datetime.utcnow() + timedelta(seconds=client.access_token_lifetime or 3600),
)
# ID token metadata (using access token JTI as reference)
id_token_jti = OIDCTokenService._generate_jti()
OIDCTokenMetadata.create_metadata(
client_id=client_db_id,
user_id=user_id,
token_type="id_token",
token_jti=id_token_jti,
expires_at=datetime.utcnow() + timedelta(seconds=client.id_token_lifetime or 3600),
)
# Log token event
OIDCAuditService.log_token_event(
client_id=client_id,
user_id=user_id,
token_type="access_token",
success=True,
grant_type="authorization_code",
scopes=scope,
)
result = {
"access_token": access_token,
"token_type": "Bearer",
"expires_in": client.access_token_lifetime or 3600,
"id_token": id_token,
}
if final_refresh_token:
result["refresh_token"] = final_refresh_token
return result
@classmethod
def refresh_access_token(
cls,
refresh_token: str,
client_id: str,
scope: list = None,
ip_address: str = None,
user_agent: str = None
) -> Dict:
"""Refresh an access token with token rotation.
Args:
refresh_token: The refresh token
client_id: OIDC client ID
scope: Optional scope override
ip_address: Client IP address
user_agent: Client user agent
Returns:
Dictionary with new tokens
Raises:
InvalidGrantError: If refresh token is invalid
"""
import hashlib
# Get client
client = OIDCClient.query.filter_by(client_id=client_id).first()
if not client:
raise InvalidClientError()
# Find refresh token
token_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
refresh_token_obj = OIDCRefreshToken.query.filter_by(
token_hash=token_hash,
deleted_at=None
).first()
if not refresh_token_obj:
OIDCAuditService.log_token_event(
client_id=client_id,
success=False,
error_code="invalid_grant",
error_description="Invalid refresh token",
)
raise InvalidGrantError("Invalid refresh token")
# Check if valid
if not refresh_token_obj.is_valid():
OIDCAuditService.log_token_event(
client_id=client_id,
user_id=refresh_token_obj.user_id,
success=False,
error_code="invalid_grant",
error_description="Refresh token expired or revoked",
)
raise InvalidGrantError("Refresh token expired or revoked")
# Validate client matches
if refresh_token_obj.client_id != client.id:
raise InvalidGrantError("Client mismatch")
# Get original scope or use provided
granted_scope = scope or (refresh_token_obj.scope or [])
# Generate new access token
access_token_jti = OIDCTokenService._generate_jti()
access_token = OIDCTokenService.create_access_token(
client_id=client_id,
user_id=refresh_token_obj.user_id,
scope=granted_scope,
jti=access_token_jti,
)
# Generate new ID token
id_token = OIDCTokenService.create_id_token(
client_id=client_id,
user_id=refresh_token_obj.user_id,
scope=granted_scope,
access_token=access_token,
)
# Rotate refresh token
new_refresh, new_hash = OIDCTokenService.create_refresh_token(
client_id=client_id,
user_id=refresh_token_obj.user_id,
scope=granted_scope,
access_token_id=access_token_jti,
)
refresh_token_obj.rotate(new_hash)
# Store new token metadata
OIDCTokenMetadata.create_metadata(
client_id=client.id,
user_id=refresh_token_obj.user_id,
token_type="access_token",
token_jti=access_token_jti,
expires_at=datetime.utcnow() + timedelta(seconds=client.access_token_lifetime or 3600),
)
# Log refresh event
OIDCAuditService.log_token_event(
client_id=client_id,
user_id=refresh_token_obj.user_id,
token_type="access_token",
success=True,
grant_type="refresh_token",
scopes=granted_scope,
)
return {
"access_token": access_token,
"token_type": "Bearer",
"expires_in": client.access_token_lifetime or 3600,
"id_token": id_token,
"refresh_token": new_refresh,
}
@classmethod
def validate_access_token(cls, token: str, client_id: str = None) -> Dict:
"""Validate an access token and return its claims.
Args:
token: JWT access token
client_id: Optional client ID to validate audience
Returns:
Token claims
Raises:
InvalidTokenError: If token is invalid
"""
try:
claims = OIDCTokenService.validate_access_token(token, client_id)
return claims
except Exception as e:
OIDCAuditService.log_event(
event_type="token_validation",
client_id=client_id,
success=False,
error_code="invalid_token",
error_description=str(e),
)
raise InvalidTokenError(str(e))
@classmethod
def revoke_token(
cls,
token: str,
client_id: str,
token_type_hint: str = None,
ip_address: str = None,
user_agent: str = None
) -> bool:
"""Revoke a token.
Args:
token: Token to revoke
client_id: OIDC client ID
token_type_hint: Hint about token type
ip_address: Client IP address
user_agent: Client user agent
Returns:
True if token was revoked
"""
import hashlib
# Get client
client = OIDCClient.query.filter_by(client_id=client_id).first()
if not client:
raise InvalidClientError()
revoked = False
token_hash = hashlib.sha256(token.encode()).hexdigest()
# Try to revoke as refresh token
if token_type_hint in (None, "refresh_token"):
refresh_token = OIDCRefreshToken.query.filter_by(
token_hash=token_hash,
deleted_at=None
).first()
if refresh_token:
refresh_token.revoke(reason="revoked_by_client")
revoked = True
OIDCAuditService.log_token_revocation_event(
client_id=client_id,
user_id=refresh_token.user_id,
token_type="refresh_token",
reason="revoked_by_client",
)
# Try to revoke as access token (JTI lookup)
if not revoked or token_type_hint in (None, "access_token"):
try:
# Decode token to get JTI
claims = OIDCTokenService.decode_token(token)
jti = claims.get("jti")
if jti:
revoked_at = OIDCTokenMetadata.revoke_by_jti(
jti,
reason="revoked_by_client"
)
if revoked_at:
revoked = True
OIDCAuditService.log_token_revocation_event(
client_id=client_id,
user_id=claims.get("sub"),
token_type="access_token",
reason="revoked_by_client",
)
except Exception:
pass
return revoked
@classmethod
def introspect_token(
cls,
token: str,
client_id: str = None,
ip_address: str = None,
user_agent: str = None
) -> Dict:
"""Introspect a token and return its status and claims.
Args:
token: Token to introspect
client_id: Client ID for validation
ip_address: Client IP address
user_agent: Client user agent
Returns:
Introspection response
"""
result = OIDCTokenService.introspect_token(token, client_id)
# Log introspection
OIDCAuditService.log_event(
event_type="token_introspection",
client_id=client_id,
user_id=result.get("sub"),
success=result.get("active", False),
metadata={"active": result.get("active")},
)
return result
@classmethod
def get_jwks(cls) -> Dict:
"""Get the JWKS document.
Returns:
JWKS document
"""
jwks_service = OIDCJWKSService()
return jwks_service.get_jwks()
@classmethod
def get_userinfo(cls, access_token: str) -> Dict:
"""Get user information using access token.
Args:
access_token: Access token
Returns:
User information dictionary
"""
claims = cls.validate_access_token(access_token)
user_id = claims.get("sub")
user = User.query.get(user_id)
if not user:
raise NotFoundError("User not found")
# Get scopes from token
scope_str = claims.get("scope", "")
scopes = scope_str.split() if scope_str else []
userinfo = {"sub": user_id}
# Add claims based on scope
if "profile" in scopes and user.full_name:
userinfo["name"] = user.full_name
if "email" in scopes:
userinfo["email"] = user.email
userinfo["email_verified"] = user.email_verified
# Log userinfo access
OIDCAuditService.log_userinfo_event(
access_token=access_token,
user_id=user_id,
client_id=claims.get("client_id"),
success=True,
scopes_claimed=scopes,
)
return userinfo
+288
View File
@@ -0,0 +1,288 @@
"""OIDC Session Service for session management during OIDC flow."""
import secrets
from datetime import datetime, timedelta
from typing import Dict, Optional, Tuple
from flask import current_app, g
from app.extensions import db
from app.models import OIDCSession, OIDCClient, User
from app.exceptions.validation_exceptions import NotFoundError, ValidationError
class OIDCSessionService:
"""Service for managing OIDC authentication sessions.
This service handles:
- Creating OIDC sessions during authorization flow
- Validating sessions with state and nonce
- Managing PKCE code challenges
- Cleaning up expired sessions
"""
@staticmethod
def _generate_state() -> str:
"""Generate a secure state parameter.
Returns:
URL-safe base64 encoded state
"""
return secrets.token_urlsafe(32)
@staticmethod
def _generate_nonce() -> str:
"""Generate a secure nonce for OIDC.
Returns:
URL-safe base64 encoded nonce
"""
return secrets.token_urlsafe(32)
@staticmethod
def _generate_code_challenge(verifier: str, method: str = "S256") -> str:
"""Generate a PKCE code challenge from verifier.
Args:
verifier: The code verifier
method: Challenge method ("S256" or "plain")
Returns:
Code challenge string
"""
import hashlib
import base64
if method == "S256":
digest = hashlib.sha256(verifier.encode()).digest()
return base64.urlsafe_b64encode(digest).decode().rstrip("=")
elif method == "plain":
return verifier
else:
raise ValueError(f"Unsupported code challenge method: {method}")
@classmethod
def validate_code_verifier(cls, code_verifier: str, code_challenge: str,
method: str = "S256") -> bool:
"""Validate a PKCE code verifier against the stored challenge.
Args:
code_verifier: The code verifier from the token request
code_challenge: The code challenge from the authorization request
method: The challenge method used
Returns:
True if validation succeeds
"""
if not code_verifier or not code_challenge:
return False
# Validate code verifier length (43-128 characters)
if method == "S256" and not (43 <= len(code_verifier) <= 128):
return False
# Calculate expected challenge
expected_challenge = cls._generate_code_challenge(code_verifier, method)
return secrets.compare_digest(expected_challenge, code_challenge)
@classmethod
def create_session(
cls,
user_id: str,
client_id: str,
state: str = None,
nonce: str = None,
redirect_uri: str = None,
scope: list = None,
code_challenge: str = None,
code_challenge_method: str = None,
lifetime_seconds: int = 600
) -> OIDCSession:
"""Create a new OIDC session for the authorization flow.
Args:
user_id: The user ID
client_id: The OIDC client ID
state: State parameter (generated if not provided)
nonce: Nonce for ID token validation (generated if not provided)
redirect_uri: Redirect URI from authorization request
scope: Requested scopes
code_challenge: PKCE code challenge
code_challenge_method: PKCE method ("S256" or "plain")
lifetime_seconds: Session lifetime in seconds
Returns:
OIDCSession instance
"""
# Generate state and nonce if not provided
state = state or cls._generate_state()
nonce = nonce or cls._generate_nonce()
session = OIDCSession.create_session(
user_id=user_id,
client_id=client_id,
state=state,
nonce=nonce,
redirect_uri=redirect_uri,
scope=scope,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
lifetime_seconds=lifetime_seconds,
)
return session
@classmethod
def validate_session(cls, state: str, nonce: str = None) -> Tuple[OIDCSession, User]:
"""Validate an OIDC session by state and optionally nonce.
Args:
state: The state parameter
nonce: The nonce to validate (optional)
Returns:
Tuple of (OIDCSession, User)
Raises:
ValidationError: If session is invalid
NotFoundError: If session not found
"""
session = OIDCSession.get_by_state(state)
if not session:
raise NotFoundError("OIDC session not found or expired")
if session.is_expired():
raise ValidationError("OIDC session has expired")
# Validate nonce if provided
if nonce and not session.validate_nonce(nonce):
raise ValidationError("Invalid nonce")
# Get user
user = User.query.get(session.user_id)
if not user:
raise NotFoundError("User not found")
return session, user
@classmethod
def validate_pkce(cls, session: OIDCSession, code_verifier: str) -> bool:
"""Validate PKCE code verifier against the session's code challenge.
Args:
session: OIDCSession instance
code_verifier: The code verifier from token request
Returns:
True if validation succeeds
Raises:
ValidationError: If PKCE validation fails
"""
if not session.code_challenge:
# No PKCE was used, skip validation
return True
if not code_verifier:
raise ValidationError("code_verifier is required")
is_valid = session.validate_code_challenge(code_verifier)
if not is_valid:
raise ValidationError("Invalid code_verifier")
return True
@classmethod
def mark_session_authenticated(cls, session: OIDCSession) -> OIDCSession:
"""Mark a session as authenticated (user has logged in).
Args:
session: OIDCSession instance
Returns:
Updated OIDCSession instance
"""
session.mark_authenticated()
return session
@classmethod
def cleanup_expired_sessions(cls, older_than_hours: int = 24) -> int:
"""Remove expired OIDC sessions.
Args:
older_than_hours: Only delete sessions expired more than this many hours ago
Returns:
Number of sessions deleted
"""
from datetime import timedelta
cutoff = datetime.utcnow() - timedelta(hours=older_than_hours)
# Get expired sessions
expired_sessions = OIDCSession.query.filter(
OIDCSession.expires_at < datetime.utcnow(),
OIDCSession.deleted_at == None
).all()
count = 0
for session in expired_sessions:
# Only hard delete if past the grace period
if session.expires_at < cutoff:
session.delete()
count += 1
return count
@classmethod
def get_session_by_state(cls, state: str) -> Optional[OIDCSession]:
"""Get an OIDC session by state.
Args:
state: The state parameter
Returns:
OIDCSession instance or None
"""
return OIDCSession.get_by_state(state)
@classmethod
def validate_redirect_uri(cls, client_id: str, redirect_uri: str) -> bool:
"""Validate that a redirect URI is allowed for a client.
Args:
client_id: The OIDC client ID
redirect_uri: The redirect URI to validate
Returns:
True if redirect URI is allowed
"""
client = OIDCClient.query.filter_by(client_id=client_id).first()
if not client:
return False
return client.is_redirect_uri_allowed(redirect_uri)
@classmethod
def validate_scopes(cls, client_id: str, requested_scopes: list) -> list:
"""Validate and filter scopes against client's allowed scopes.
Args:
client_id: The OIDC client ID
requested_scopes: List of requested scopes
Returns:
List of allowed scopes
"""
client = OIDCClient.query.filter_by(client_id=client_id).first()
if not client:
return []
allowed_scopes = client.scopes or []
# Filter to only allowed scopes
valid_scopes = [s for s in requested_scopes if s in allowed_scopes]
return valid_scopes
+439
View File
@@ -0,0 +1,439 @@
"""OIDC Token Service for JWT token generation and validation."""
import hashlib
import base64
import secrets
from datetime import datetime, timedelta
from typing import Dict, Optional, Any
import jwt
from flask import current_app, g
from app.models import User, OIDCClient
from app.services.oidc_jwks_service import OIDCJWKSService
class OIDCTokenService:
"""Service for generating and validating OIDC tokens.
This service handles:
- Access token creation (JWT)
- ID token creation (JWT)
- Refresh token creation (opaque)
- Token signature verification
- Hash generation for PKCE claims (at_hash, c_hash)
"""
@staticmethod
def _generate_jti() -> str:
"""Generate a unique JWT ID."""
return secrets.token_urlsafe(32)
@staticmethod
def _generate_opaque_token(length: int = 43) -> str:
"""Generate an opaque token (for refresh tokens).
Args:
length: Length of the token
Returns:
URL-safe base64 encoded token
"""
return secrets.token_urlsafe(length)
@staticmethod
def _hash_token(token: str) -> str:
"""Hash a token for secure storage.
Args:
token: Token to hash
Returns:
SHA256 hash of the token
"""
return hashlib.sha256(token.encode()).hexdigest()
@staticmethod
def _base64url_encode(data: bytes) -> str:
"""Encode bytes to base64url format without padding.
Args:
data: Bytes to encode
Returns:
Base64url encoded string
"""
return base64.urlsafe_b64encode(data).decode().rstrip("=")
@staticmethod
def create_at_hash(access_token: str) -> str:
"""Create the at_hash claim for ID token.
Implements OIDC spec for access token hash generation.
Hash is the left-most half of the hash of the ASCII representation
of the access token.
Args:
access_token: The access token string
Returns:
Base64url encoded hash
"""
# Hash the access token using SHA256
hash_digest = hashlib.sha256(access_token.encode()).digest()
# Take left-most half of the hash
half_length = len(hash_digest) // 2
left_half = hash_digest[:half_length]
# Base64url encode
return OIDCTokenService._base64url_encode(left_half)
@staticmethod
def create_c_hash(code: str) -> str:
"""Create the c_hash claim for ID token.
Implements OIDC spec for authorization code hash generation.
Args:
code: The authorization code string
Returns:
Base64url encoded hash
"""
# Hash the code using SHA256
hash_digest = hashlib.sha256(code.encode()).digest()
# Take left-most half of the hash
half_length = len(hash_digest) // 2
left_half = hash_digest[:half_length]
# Base64url encode
return OIDCTokenService._base64url_encode(left_half)
@staticmethod
def _get_issuer() -> str:
"""Get the OIDC issuer URL."""
return current_app.config.get("OIDC_ISSUER_URL", "http://localhost:5000")
@staticmethod
def _get_token_lifetime(client: OIDCClient, token_type: str) -> int:
"""Get the token lifetime in seconds for a client.
Args:
client: OIDCClient instance
token_type: Type of token ("access_token", "refresh_token", "id_token")
Returns:
Lifetime in seconds
"""
lifetimes = {
"access_token": client.access_token_lifetime or 3600,
"refresh_token": client.refresh_token_lifetime or 2592000,
"id_token": client.id_token_lifetime or 3600,
}
return lifetimes.get(token_type, 3600)
@classmethod
def create_access_token(cls, client_id: str, user_id: str, scope: list,
jti: str = None) -> str:
"""Create a JWT access token.
Args:
client_id: The OIDC client ID
user_id: The user ID (subject)
scope: List of granted scopes
jti: Optional JWT ID (generated if not provided)
Returns:
JWT access token string
"""
jti = jti or cls._generate_jti()
now = datetime.utcnow()
# Get client for token lifetime
client = OIDCClient.query.filter_by(client_id=client_id).first()
lifetime = cls._get_token_lifetime(client, "access_token") if client else 3600
claims = {
"iss": cls._get_issuer(),
"sub": user_id,
"aud": client_id,
"exp": int((now + timedelta(seconds=lifetime)).timestamp()),
"iat": int(now.timestamp()),
"nbf": int(now.timestamp()),
"jti": jti,
"client_id": client_id,
"scope": " ".join(scope) if isinstance(scope, list) else scope,
}
# Get signing key
jwks_service = OIDCJWKSService()
signing_key = jwks_service.get_signing_key()
if not signing_key:
raise ValueError("No signing key available")
# Sign with RS256
token = jwt.encode(
claims,
signing_key.private_key,
algorithm="RS256",
headers={"kid": signing_key.kid}
)
return token
@classmethod
def create_id_token(cls, client_id: str, user_id: str, nonce: str = None,
scope: list = None, access_token: str = None,
auth_time: int = None) -> str:
"""Create a JWT ID token.
Args:
client_id: The OIDC client ID
user_id: The user ID (subject)
nonce: Nonce for replay protection
scope: Requested/Granted scopes
access_token: Associated access token (for at_hash)
auth_time: Authentication time (Unix timestamp)
Returns:
JWT ID token string
"""
now = datetime.utcnow()
auth_time = auth_time or int(now.timestamp())
# Get client for token lifetime
client = OIDCClient.query.filter_by(client_id=client_id).first()
lifetime = cls._get_token_lifetime(client, "id_token") if client else 3600
# Get user for claims
user = User.query.get(user_id)
claims = {
"iss": cls._get_issuer(),
"sub": user_id,
"aud": client_id,
"exp": int((now + timedelta(seconds=lifetime)).timestamp()),
"iat": int(now.timestamp()),
"auth_time": auth_time,
}
# Add nonce if provided
if nonce:
claims["nonce"] = nonce
# Add at_hash if access token provided
if access_token:
claims["at_hash"] = cls.create_at_hash(access_token)
# Add standard claims if user exists
if user:
if user.email:
claims["email"] = user.email
claims["email_verified"] = user.email_verified
if user.full_name:
claims["name"] = user.full_name
# Add scope if provided
if scope:
claims["scope"] = " ".join(scope) if isinstance(scope, list) else scope
# Get signing key
jwks_service = OIDCJWKSService()
signing_key = jwks_service.get_signing_key()
if not signing_key:
raise ValueError("No signing key available")
# Sign with RS256
token = jwt.encode(
claims,
signing_key.private_key,
algorithm="RS256",
headers={"kid": signing_key.kid}
)
return token
@classmethod
def create_refresh_token(cls, client_id: str, user_id: str,
scope: list = None, access_token_id: str = None) -> str:
"""Create an opaque refresh token.
Args:
client_id: The OIDC client ID
user_id: The user ID
scope: List of granted scopes
access_token_id: Associated access token ID
Returns:
Opaque refresh token string
"""
token = cls._generate_opaque_token()
# Hash for storage
token_hash = cls._hash_token(token)
return token, token_hash
@classmethod
def verify_token_signature(cls, token: str) -> Dict:
"""Verify the signature of a JWT token.
Args:
token: JWT token string
Returns:
Decoded token claims
Raises:
jwt.InvalidSignatureError: If signature verification fails
jwt.ExpiredSignatureError: If token is expired
jwt.InvalidTokenError: If token is invalid
"""
# Get the JWKS with public keys
jwks_service = OIDCJWKSService()
jwks = jwks_service.get_jwks()
# Get the key ID from token header
try:
unverified_header = jwt.get_unverified_header(token)
except jwt.DecodeError:
raise jwt.InvalidTokenError("Invalid token header")
kid = unverified_header.get("kid")
# Find the matching public key
public_key = None
for key in jwks.get("keys", []):
if key.get("kid") == kid:
try:
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
public_key = serialization.load_pem_public_key(
key["public_key"].encode() if isinstance(key["public_key"], str)
else key["public_key"],
backend=default_backend()
)
break
except (ImportError, Exception):
continue
if not public_key:
raise jwt.InvalidSignatureError(f"Key with kid={kid} not found")
# Verify the signature
claims = jwt.decode(
token,
public_key,
algorithms=["RS256"],
audience=None, # We'll validate audience separately
issuer=cls._get_issuer(),
options={
"verify_signature": True,
"verify_exp": True,
"verify_aud": False, # Handle audience manually
"verify_iss": False, # Handle issuer manually
}
)
return claims
@classmethod
def decode_token(cls, token: str, verify: bool = False) -> Dict:
"""Decode a JWT token without verification (for debugging).
Args:
token: JWT token string
verify: Whether to verify signature
Returns:
Decoded token claims
"""
if verify:
return cls.verify_token_signature(token)
return jwt.decode(
token,
options={
"verify_signature": False,
"verify_exp": False,
}
)
@classmethod
def validate_access_token(cls, token: str, client_id: str = None) -> Dict:
"""Validate an access token and return its claims.
Args:
token: JWT access token
client_id: Optional client ID to validate audience
Returns:
Token claims dictionary
Raises:
jwt.InvalidTokenError: If token is invalid
ValueError: If token is expired or audience mismatch
"""
claims = cls.verify_token_signature(token)
# Check expiration
if claims.get("exp", 0) < datetime.utcnow().timestamp():
raise ValueError("Token has expired")
# Validate audience if client_id provided
if client_id:
if claims.get("aud") != client_id:
raise ValueError("Invalid audience")
return claims
@classmethod
def introspect_token(cls, token: str, client_id: str = None) -> Dict:
"""Introspect a token and return its status and claims.
Args:
token: JWT token to introspect
client_id: Client ID for audience validation
Returns:
Dictionary with active status and claims
"""
result = {
"active": False,
}
try:
claims = cls.validate_access_token(token, client_id)
# Calculate remaining time
now = datetime.utcnow().timestamp()
exp = claims.get("exp", 0)
iat = claims.get("iat", 0)
result["active"] = exp > now
result.update({
"iss": claims.get("iss"),
"sub": claims.get("sub"),
"aud": claims.get("aud"),
"exp": exp,
"iat": iat,
"nbf": claims.get("nbf"),
"jti": claims.get("jti"),
"client_id": claims.get("client_id"),
"scope": claims.get("scope"),
"token_type": "Bearer",
})
# Add expiry in seconds
if exp > now:
result["exp"] = int(exp - now)
except (jwt.InvalidTokenError, ValueError) as e:
result["active"] = False
result["error"] = str(e)
return result
+7 -2
View File
@@ -3,13 +3,11 @@ import os
from config.base import BaseConfig from config.base import BaseConfig
from config.development import DevelopmentConfig from config.development import DevelopmentConfig
from config.testing import TestingConfig from config.testing import TestingConfig
from config.production import ProductionConfig
config_by_name = { config_by_name = {
"development": DevelopmentConfig, "development": DevelopmentConfig,
"testing": TestingConfig, "testing": TestingConfig,
"production": ProductionConfig,
"default": DevelopmentConfig, "default": DevelopmentConfig,
} }
@@ -18,4 +16,11 @@ def get_config(config_name=None):
"""Get configuration object based on environment.""" """Get configuration object based on environment."""
if config_name is None: if config_name is None:
config_name = os.getenv("FLASK_ENV", "development") config_name = os.getenv("FLASK_ENV", "development")
# Lazy import of ProductionConfig to avoid requiring SECRET_KEY in non-production environments
if config_name == "production":
from config.production import ProductionConfig
config_by_name["production"] = ProductionConfig
return ProductionConfig
return config_by_name.get(config_name, DevelopmentConfig) return config_by_name.get(config_name, DevelopmentConfig)
+30 -4
View File
@@ -17,6 +17,7 @@ class BaseConfig:
) )
SQLALCHEMY_TRACK_MODIFICATIONS = False SQLALCHEMY_TRACK_MODIFICATIONS = False
SQLALCHEMY_ECHO = os.getenv("SQLALCHEMY_ECHO", "False").lower() == "true" SQLALCHEMY_ECHO = os.getenv("SQLALCHEMY_ECHO", "False").lower() == "true"
SQLALCHEMY_LOG_LEVEL = os.getenv("SQLALCHEMY_LOG_LEVEL", "WARNING")
SQLALCHEMY_ENGINE_OPTIONS = { SQLALCHEMY_ENGINE_OPTIONS = {
"pool_pre_ping": True, "pool_pre_ping": True,
"pool_recycle": 300, "pool_recycle": 300,
@@ -47,9 +48,12 @@ class BaseConfig:
# Redis # Redis
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0") REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
# Flask session configuration - deprecated, migrating to Bearer token authentication
# SESSION_TYPE = "redis" # Flask Session configuration
# SESSION_REDIS = None # Will be set at app initialization SESSION_TYPE = os.getenv("SESSION_TYPE", "filesystem")
SESSION_FILE_DIR = os.getenv("SESSION_FILE_DIR", "/tmp/flask_session")
SESSION_FILE_THRESHOLD = int(os.getenv("SESSION_FILE_THRESHOLD", "500"))
SESSION_REDIS = None # Will be set at app initialization
# Rate Limiting # Rate Limiting
RATELIMIT_ENABLED = os.getenv("RATELIMIT_ENABLED", "True").lower() == "true" RATELIMIT_ENABLED = os.getenv("RATELIMIT_ENABLED", "True").lower() == "true"
@@ -60,8 +64,30 @@ class BaseConfig:
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO") LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
LOG_TO_STDOUT = os.getenv("LOG_TO_STDOUT", "False").lower() == "true" LOG_TO_STDOUT = os.getenv("LOG_TO_STDOUT", "False").lower() == "true"
# OIDC # OIDC Configuration
OIDC_ISSUER_URL = os.getenv("OIDC_ISSUER_URL", "http://localhost:5000") OIDC_ISSUER_URL = os.getenv("OIDC_ISSUER_URL", "http://localhost:5000")
OIDC_BASE_URL = os.getenv("OIDC_BASE_URL", OIDC_ISSUER_URL)
# Token lifetimes
OIDC_ACCESS_TOKEN_LIFETIME = int(os.getenv("OIDC_ACCESS_TOKEN_LIFETIME", "3600"))
OIDC_REFRESH_TOKEN_LIFETIME = int(os.getenv("OIDC_REFRESH_TOKEN_LIFETIME", "2592000"))
OIDC_ID_TOKEN_LIFETIME = int(os.getenv("OIDC_ID_TOKEN_LIFETIME", "3600"))
OIDC_AUTHORIZATION_CODE_LIFETIME = int(os.getenv("OIDC_AUTHORIZATION_CODE_LIFETIME", "600"))
# Security settings
OIDC_REQUIRE_PKCE = os.getenv("OIDC_REQUIRE_PKCE", "True").lower() == "true"
OIDC_ALLOW_IMPLICIT_FLOW = os.getenv("OIDC_ALLOW_IMPLICIT_FLOW", "False").lower() == "true"
OIDC_SUPPORTED_SCOPES = ["openid", "profile", "email"]
OIDC_DEFAULT_SCOPES = ["openid", "profile", "email"]
# Key rotation
OIDC_KEY_ROTATION_DAYS = int(os.getenv("OIDC_KEY_ROTATION_DAYS", "90"))
OIDC_KEY_GRACE_PERIOD_DAYS = int(os.getenv("OIDC_KEY_GRACE_PERIOD_DAYS", "30"))
# Rate limiting
OIDC_RATE_LIMIT_AUTHORIZE = os.getenv("OIDC_RATE_LIMIT_AUTHORIZE", "10/minute")
OIDC_RATE_LIMIT_TOKEN = os.getenv("OIDC_RATE_LIMIT_TOKEN", "20/minute")
OIDC_RATE_LIMIT_USERINFO = os.getenv("OIDC_RATE_LIMIT_USERINFO", "60/minute")
# API Versioning # API Versioning
API_VERSION = "1.0.0" API_VERSION = "1.0.0"
+3 -2
View File
@@ -1,12 +1,13 @@
"""Development environment configuration.""" """Development environment configuration."""
from config.base import BaseConfig from config.base import BaseConfig
import os
class DevelopmentConfig(BaseConfig): class DevelopmentConfig(BaseConfig):
"""Development configuration.""" """Development configuration."""
DEBUG = True DEBUG = True
SQLALCHEMY_ECHO = True # Use environment variable like BaseConfig does
SQLALCHEMY_ECHO = os.getenv("SQLALCHEMY_ECHO", "False").lower() == "true"
SESSION_COOKIE_SECURE = False SESSION_COOKIE_SECURE = False
# More verbose logging in development # More verbose logging in development
+8
View File
@@ -1,5 +1,6 @@
"""Testing environment configuration.""" """Testing environment configuration."""
from config.base import BaseConfig from config.base import BaseConfig
import os
class TestingConfig(BaseConfig): class TestingConfig(BaseConfig):
@@ -8,6 +9,9 @@ class TestingConfig(BaseConfig):
TESTING = True TESTING = True
DEBUG = True DEBUG = True
# Explicitly set SECRET_KEY for testing
SECRET_KEY = os.getenv("SECRET_KEY", "test-secret-key-for-testing")
# Use in-memory SQLite for testing # Use in-memory SQLite for testing
SQLALCHEMY_DATABASE_URI = "sqlite:///:memory:" SQLALCHEMY_DATABASE_URI = "sqlite:///:memory:"
SQLALCHEMY_ECHO = False SQLALCHEMY_ECHO = False
@@ -23,3 +27,7 @@ class TestingConfig(BaseConfig):
# Use different Redis DB for testing # Use different Redis DB for testing
REDIS_URL = "redis://localhost:6379/15" REDIS_URL = "redis://localhost:6379/15"
# Use filesystem for sessions in testing
SESSION_TYPE = "filesystem"
SESSION_FILE_DIR = "/tmp/flask_session_test"
+1516
View File
File diff suppressed because it is too large Load Diff
+572
View File
@@ -0,0 +1,572 @@
# OIDC Testing Guide
This guide provides step-by-step instructions for manually testing the OIDC implementation using curl commands.
## Prerequisites
1. A running instance of the authy2 backend
2. curl installed
3. A test user account
4. A registered OIDC client
## Setup
### Start the Backend
```bash
# Development mode
python -m flask run --host=0.0.0.0 --port=5000
# Or using the manage.py script
python manage.py runserver --host=0.0.0.0 --port=5000
```
### Register a Test User (if needed)
```bash
curl -X POST http://localhost:5000/api/v1/auth/register \
-H "Content-Type: application/json" \
-d '{
"email": "test@example.com",
"password": "TestPassword123!",
"password_confirm": "TestPassword123!",
"full_name": "Test User"
}'
```
### Register an OIDC Client
```bash
# Register a new OIDC client
curl -X POST http://localhost:5000/oidc/register \
-H "Content-Type: application/json" \
-d '{
"client_name": "Test Client",
"redirect_uris": ["http://localhost:8080/callback"],
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"scope": "openid profile email",
"token_endpoint_auth_method": "client_secret_basic"
}'
```
**Save the `client_id` and `client_secret` from the response for later use.**
## Testing Endpoints
### 1. Discovery Endpoint
**Purpose:** Verify OIDC discovery configuration is accessible and correct.
```bash
curl -s http://localhost:5000/.well-known/openid-configuration | jq
```
**Expected Response:**
```json
{
"issuer": "http://localhost:5000",
"authorization_endpoint": "http://localhost:5000/oidc/authorize",
"token_endpoint": "http://localhost:5000/oidc/token",
"userinfo_endpoint": "http://localhost:5000/oidc/userinfo",
"jwks_uri": "http://localhost:5000/oidc/jwks",
"registration_endpoint": "http://localhost:5000/oidc/register",
"revocation_endpoint": "http://localhost:5000/oidc/revoke",
"introspection_endpoint": "http://localhost:5000/oidc/introspect",
"scopes_supported": ["openid", "profile", "email"],
"response_types_supported": ["code"],
"response_modes_supported": ["query"],
"grant_types_supported": ["authorization_code", "refresh_token"],
"token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"],
"subject_types_supported": ["public"],
"id_token_signing_alg_values_supported": ["RS256"],
"claims_supported": ["sub", "name", "email", "email_verified"]
}
```
**Verification:**
- All endpoints are present and use the correct base URL
- Cache-Control header is set: `curl -I http://localhost:5000/.well-known/openid-configuration`
### 2. JWKS Endpoint
**Purpose:** Verify JWKS is accessible and contains valid signing keys.
```bash
curl -s http://localhost:5000/oidc/jwks | jq
```
**Expected Response:**
```json
{
"keys": [
{
"kty": "RSA",
"kid": "...",
"use": "sig",
"alg": "RS256",
"n": "...",
"e": "..."
}
]
}
```
**Verification:**
- At least one key is present
- Key has `kty: "RSA"`, `alg: "RS256"`
- Cache-Control header is set
### 3. Authorization Code Flow with PKCE
This is the complete OAuth2/OIDC authentication flow.
#### Step 1: Generate PKCE Parameters
```bash
# Generate code verifier (43-128 characters)
CODE_VERIFIER=$(openssl rand -base64 32 | tr -d '=' | tr '/+' '_-' | cut -c1-43)
# Generate code challenge from verifier
CODE_CHALLENGE=$(echo -n "$CODE_VERIFIER" | openssl sha256 -binary | base64 | tr -d '=' | tr '/+' '_-')
# Generate state parameter
STATE=$(openssl rand -hex 16)
# Generate nonce for ID token
NONCE=$(openssl rand -hex 16)
echo "Code Verifier: $CODE_VERIFIER"
echo "Code Challenge: $CODE_CHALLENGE"
echo "State: $STATE"
echo "Nonce: $NONCE"
```
#### Step 2: Request Authorization Code
**Option A: Browser-based flow (redirect flow)**
```
# Open this URL in a browser
http://localhost:5000/oidc/authorize?\
client_id=YOUR_CLIENT_ID&\
redirect_uri=http://localhost:8080/callback&\
response_type=code&\
scope=openid%20profile%20email&\
state=YOUR_STATE&\
nonce=YOUR_NONCE&\
code_challenge=YOUR_CODE_CHALLENGE&\
code_challenge_method=S256
```
**Option B: POST-based flow (for testing with curl)**
```bash
curl -v -X POST http://localhost:5000/oidc/authorize \
-d "client_id=YOUR_CLIENT_ID" \
-d "redirect_uri=http://localhost:8080/callback" \
-d "response_type=code" \
-d "scope=openid profile email" \
-d "state=$STATE" \
-d "nonce=$NONCE" \
-d "code_challenge=$CODE_CHALLENGE" \
-d "code_challenge_method=S256" \
-d "email=test@example.com" \
-d "password=TestPassword123!"
```
**Expected Response:** 302 Redirect with `code` parameter
```http
HTTP/1.1 302 Found
Location: http://localhost:8080/callback?code=AUTHORIZATION_CODE&state=YOUR_STATE
```
**Extract the authorization code:**
```bash
# From the Location header
AUTH_CODE=$(curl -v -X POST http://localhost:5000/oidc/authorize \
-d "client_id=YOUR_CLIENT_ID" \
-d "redirect_uri=http://localhost:8080/callback" \
-d "response_type=code" \
-d "scope=openid profile email" \
-d "state=$STATE" \
-d "nonce=$NONCE" \
-d "code_challenge=$CODE_CHALLENGE" \
-d "code_challenge_method=S256" \
-d "email=test@example.com" \
-d "password=TestPassword123!" 2>&1 | grep -i "Location:" | cut -d' ' -f2 | cut -d'?' -f2 | cut -d'=' -f2)
```
#### Step 3: Exchange Authorization Code for Tokens
```bash
# Using client_id and client_secret
curl -X POST http://localhost:5000/oidc/token \
-H "Content-Type: application/x-www-form-urlencoded" \
-d "grant_type=authorization_code" \
-d "code=$AUTH_CODE" \
-d "redirect_uri=http://localhost:8080/callback" \
-d "client_id=YOUR_CLIENT_ID" \
-d "client_secret=YOUR_CLIENT_SECRET" \
-d "code_verifier=$CODE_VERIFIER"
```
**Expected Response:**
```json
{
"version": "1.0",
"success": true,
"code": 200,
"message": "Tokens issued successfully",
"request_id": "...",
"data": {
"access_token": "eyJ...",
"token_type": "Bearer",
"expires_in": 3600,
"id_token": "eyJ...",
"refresh_token": "..."
}
}
```
**Verification:**
- `access_token` is a JWT (check at jwt.io)
- `token_type` is "Bearer"
- `expires_in` indicates token lifetime
- `id_token` contains expected claims (sub, iss, aud, etc.)
### 4. UserInfo Endpoint
**Purpose:** Retrieve user information using the access token.
```bash
curl -X GET http://localhost:5000/oidc/userinfo \
-H "Authorization: Bearer YOUR_ACCESS_TOKEN"
```
**Expected Response:**
```json
{
"version": "1.0",
"success": true,
"code": 200,
"message": "User info retrieved successfully",
"request_id": "...",
"data": {
"sub": "user-id",
"name": "Test User",
"email": "test@example.com",
"email_verified": true
}
}
```
**Verification:**
- `sub` matches the user ID
- `email` and `email_verified` are present if email scope was requested
- `name` is present if profile scope was requested
### 5. Token Refresh
**Purpose:** Obtain a new access token using a refresh token.
```bash
curl -X POST http://localhost:5000/oidc/token \
-H "Content-Type: application/x-www-form-urlencoded" \
-d "grant_type=refresh_token" \
-d "refresh_token=YOUR_REFRESH_TOKEN" \
-d "client_id=YOUR_CLIENT_ID" \
-d "client_secret=YOUR_CLIENT_SECRET"
```
**Expected Response:**
```json
{
"version": "1.0",
"success": true,
"code": 200,
"message": "Tokens refreshed successfully",
"request_id": "...",
"data": {
"access_token": "eyJ...",
"token_type": "Bearer",
"expires_in": 3600,
"id_token": "eyJ...",
"refresh_token": "..."
}
}
```
**Verification:**
- New `access_token` is returned
- New `refresh_token` is returned (token rotation)
- Old refresh token is now invalid
### 6. Token Revocation
**Purpose:** Revoke a token to invalidate it.
```bash
# Revoke access token
curl -X POST http://localhost:5000/oidc/revoke \
-H "Content-Type: application/x-www-form-urlencoded" \
-d "token=YOUR_ACCESS_TOKEN" \
-d "token_type_hint=access_token" \
-d "client_id=YOUR_CLIENT_ID" \
-d "client_secret=YOUR_CLIENT_SECRET"
# Revoke refresh token
curl -X POST http://localhost:5000/oidc/revoke \
-H "Content-Type: application/x-www-form-urlencoded" \
-d "token=YOUR_REFRESH_TOKEN" \
-d "token_type_hint=refresh_token" \
-d "client_id=YOUR_CLIENT_ID" \
-d "client_secret=YOUR_CLIENT_SECRET"
```
**Expected Response:**
```json
{
"version": "1.0",
"success": true,
"code": 200,
"message": "Token revoked successfully",
"request_id": "..."
}
```
**Verification:**
- Revoked refresh token cannot be used for refresh
- Revoked access token cannot be used for UserInfo
### 7. Token Introspection
**Purpose:** Check if a token is active and get its claims.
```bash
curl -X POST http://localhost:5000/oidc/introspect \
-H "Content-Type: application/x-www-form-urlencoded" \
-d "token=YOUR_ACCESS_TOKEN" \
-d "client_id=YOUR_CLIENT_ID" \
-d "client_secret=YOUR_CLIENT_SECRET"
```
**Expected Response (active token):**
```json
{
"version": "1.0",
"success": true,
"code": 200,
"message": "Token introspection successful",
"request_id": "...",
"data": {
"active": true,
"iss": "http://localhost:5000",
"sub": "user-id",
"aud": "YOUR_CLIENT_ID",
"exp": 1234567890,
"iat": 1234564290,
"scope": "openid profile email",
"token_type": "Bearer"
}
}
```
**Expected Response (invalid/expired token):**
```json
{
"version": "1.0",
"success": true,
"code": 200,
"message": "Token introspection successful",
"request_id": "...",
"data": {
"active": false
}
}
```
## Complete Flow Test Script
Here's a comprehensive script that tests the complete OIDC flow:
```bash
#!/bin/bash
set -e
BASE_URL="http://localhost:5000"
CLIENT_ID="YOUR_CLIENT_ID"
CLIENT_SECRET="YOUR_CLIENT_SECRET"
EMAIL="test@example.com"
PASSWORD="TestPassword123!"
REDIRECT_URI="http://localhost:8080/callback"
echo "=== OIDC Complete Flow Test ==="
# 1. Discovery
echo -e "\n1. Testing Discovery Endpoint..."
curl -s "$BASE_URL/.well-known/openid-configuration" | jq . > /dev/null
echo " ✓ Discovery endpoint working"
# 2. JWKS
echo -e "\n2. Testing JWKS Endpoint..."
curl -s "$BASE_URL/oidc/jwks" | jq . > /dev/null
echo " ✓ JWKS endpoint working"
# 3. Generate PKCE parameters
echo -e "\n3. Generating PKCE parameters..."
CODE_VERIFIER=$(openssl rand -base64 32 | tr -d '=' | tr '/+' '_-' | cut -c1-43)
CODE_CHALLENGE=$(echo -n "$CODE_VERIFIER" | openssl sha256 -binary | base64 | tr -d '=' | tr '/+' '_-')
STATE=$(openssl rand -hex 16)
echo " ✓ PKCE parameters generated"
# 4. Get Authorization Code
echo -e "\n4. Getting Authorization Code..."
AUTH_RESPONSE=$(curl -s -D - -X POST "$BASE_URL/oidc/authorize" \
-d "client_id=$CLIENT_ID" \
-d "redirect_uri=$REDIRECT_URI" \
-d "response_type=code" \
-d "scope=openid profile email" \
-d "state=$STATE" \
-d "code_challenge=$CODE_CHALLENGE" \
-d "code_challenge_method=S256" \
-d "email=$EMAIL" \
-d "password=$PASSWORD")
AUTH_CODE=$(echo "$AUTH_RESPONSE" | grep -i "Location:" | cut -d'?' -f2 | cut -d'=' -f2 | tr -d '\r')
echo " ✓ Authorization code received: ${AUTH_CODE:0:20}..."
# 5. Exchange Code for Tokens
echo -e "\n5. Exchanging Code for Tokens..."
TOKEN_RESPONSE=$(curl -s -X POST "$BASE_URL/oidc/token" \
-H "Content-Type: application/x-www-form-urlencoded" \
-d "grant_type=authorization_code" \
-d "code=$AUTH_CODE" \
-d "redirect_uri=$REDIRECT_URI" \
-d "client_id=$CLIENT_ID" \
-d "client_secret=$CLIENT_SECRET" \
-d "code_verifier=$CODE_VERIFIER")
ACCESS_TOKEN=$(echo "$TOKEN_RESPONSE" | jq -r '.data.access_token')
REFRESH_TOKEN=$(echo "$TOKEN_RESPONSE" | jq -r '.data.refresh_token')
echo " ✓ Tokens received"
# 6. UserInfo
echo -e "\n6. Testing UserInfo Endpoint..."
USERINFO=$(curl -s -X GET "$BASE_URL/oidc/userinfo" \
-H "Authorization: Bearer $ACCESS_TOKEN")
echo " ✓ UserInfo response: $(echo "$USERINFO" | jq -r '.data.sub')"
# 7. Token Refresh
echo -e "\n7. Testing Token Refresh..."
REFRESH_RESPONSE=$(curl -s -X POST "$BASE_URL/oidc/token" \
-H "Content-Type: application/x-www-form-urlencoded" \
-d "grant_type=refresh_token" \
-d "refresh_token=$REFRESH_TOKEN" \
-d "client_id=$CLIENT_ID" \
-d "client_secret=$CLIENT_SECRET")
NEW_ACCESS_TOKEN=$(echo "$REFRESH_RESPONSE" | jq -r '.data.access_token')
NEW_REFRESH_TOKEN=$(echo "$REFRESH_RESPONSE" | jq -r '.data.refresh_token')
echo " ✓ Token refresh successful"
# 8. Token Introspection
echo -e "\n8. Testing Token Introspection..."
INTROSPECT=$(curl -s -X POST "$BASE_URL/oidc/introspect" \
-H "Content-Type: application/x-www-form-urlencoded" \
-d "token=$NEW_ACCESS_TOKEN" \
-d "client_id=$CLIENT_ID" \
-d "client_secret=$CLIENT_SECRET")
IS_ACTIVE=$(echo "$INTROSPECT" | jq -r '.data.active')
echo " ✓ Token introspection: active=$IS_ACTIVE"
# 9. Token Revocation
echo -e "\n9. Testing Token Revocation..."
curl -s -X POST "$BASE_URL/oidc/revoke" \
-H "Content-Type: application/x-www-form-urlencoded" \
-d "token=$NEW_REFRESH_TOKEN" \
-d "token_type_hint=refresh_token" \
-d "client_id=$CLIENT_ID" \
-d "client_secret=$CLIENT_SECRET" > /dev/null
echo " ✓ Token revoked"
# 10. Verify Revoked Token
echo -e "\n10. Verifying Revoked Token..."
REVOKE_VERIFY=$(curl -s -X POST "$BASE_URL/oidc/token" \
-H "Content-Type: application/x-www-form-urlencoded" \
-d "grant_type=refresh_token" \
-d "refresh_token=$NEW_REFRESH_TOKEN" \
-d "client_id=$CLIENT_ID" \
-d "client_secret=$CLIENT_SECRET")
IS_INVALID=$(echo "$REVOKE_VERIFY" | jq -r '.success')
echo " ✓ Revoked token is invalid: success=$IS_INVALID"
echo -e "\n=== OIDC Flow Test Complete ==="
echo "All endpoints tested successfully!"
```
## Error Handling Tests
### Invalid Client
```bash
curl -X POST http://localhost:5000/oidc/token \
-H "Content-Type: application/x-www-form-urlencoded" \
-d "grant_type=authorization_code" \
-d "code=invalid" \
-d "client_id=invalid_client" \
-d "client_secret=invalid_secret"
```
### Invalid Authorization Code
```bash
curl -X POST http://localhost:5000/oidc/token \
-H "Content-Type: application/x-www-form-urlencoded" \
-d "grant_type=authorization_code" \
-d "code=INVALID_CODE" \
-d "redirect_uri=http://localhost:8080/callback" \
-d "client_id=YOUR_CLIENT_ID"
```
### Expired Authorization Code
Authorization codes expire after 10 minutes. Wait 10+ minutes and try to use the code again.
### Invalid PKCE Verifier
Use an incorrect `code_verifier` during token exchange:
```bash
curl -X POST http://localhost:5000/oidc/token \
-H "Content-Type: application/x-www-form-urlencoded" \
-d "grant_type=authorization_code" \
-d "code=YOUR_AUTH_CODE" \
-d "redirect_uri=http://localhost:8080/callback" \
-d "client_id=YOUR_CLIENT_ID" \
-d "code_verifier=wrong_verifier"
```
## Troubleshooting
### Connection Refused
Ensure the backend is running:
```bash
ps aux | grep flask
lsof -i :5000
```
### Authentication Failures
1. Verify user credentials are correct
2. Check that the user exists in the database
3. Ensure the client is active and has correct redirect URIs
### Token Errors
1. Verify access token hasn't expired
2. Check that the token was signed by the OIDC provider
3. Ensure the audience (client_id) matches
### Redirect URI Mismatch
Ensure the `redirect_uri` used in authorization and token exchange exactly matches a registered redirect URI.
+251
View File
@@ -0,0 +1,251 @@
# OAuth2-Proxy Configuration Example
# ================================
# This configuration file demonstrates how to configure oauth2-proxy
# to use this OIDC provider for authentication.
#
# oauth2-proxy project: https://oauth2-proxy.github.io/oauth2-proxy/
#
# Usage:
# oauth2-proxy -config /path/to/oauth2-proxy-config.yaml
#
# Environment variables can also be used by prefixing with OAUTH2_PROXY_
# e.g., OAUTH2_PROXY_PROVIDER="oidc"
# Server Configuration
# --------------------
# The address and port to bind to
http_address: "0.0.0.0:4180"
https_address: ":4443"
# OIDC Provider Configuration
# ---------------------------
# Provider configuration - OIDC for our authy2 backend
provider: "oidc"
# OIDC issuer URL - points to our OIDC discovery endpoint
# This should be the base URL of your authy2 backend
oidc_issuer_url: "http://localhost:5000"
# Email domains to allow (empty means any email is allowed)
# email_domains:
# - "*"
# Client Configuration
# --------------------
# Client ID and secret obtained from OIDC Client Registration
# Run: curl -X POST http://localhost:5000/oidc/register -H "Content-Type: application/json" -d '{"client_name":"oauth2-proxy","redirect_uris":["http://localhost:4180/oauth2/callback"],"scope":"openid profile email"}'
client_id: "your-client-id-here"
client_secret: "your-client-secret-here"
# Client ID file (alternative to providing secret directly)
# client_id_file: "/etc/oauth2-proxy/client_id"
# client_secret_file: "/etc/oauth2-proxy/client_secret"
# OIDC Scopes
# ------------
# Scopes to request from the OIDC provider
# The "openid" scope is always requested
# Available scopes in our OIDC provider: openid, profile, email
scope: "openid profile email"
# Cookie Configuration
# --------------------
# Secret key for cookie encryption (should be random and kept secret)
# Generate with: openssl rand -base64 32 | head -c 32 | xargs
cookie_secret: "your-random-cookie-secret-min-32-chars"
# Name of the cookie that oauth2-proxy will use
cookie_name: "_oauth2_proxy"
# Cookie options
cookie_expire: "168h" # 7 days
cookie_refresh: "1h" # Refresh cookie every hour
secure_cookies: false # Set to true in production with HTTPS
http_only_cookies: true
# Upstream Configuration
# ---------------------
# The upstream application to proxy requests to
# Multiple upstreams can be configured
upstream: "http://127.0.0.1:8080/"
# Internal upstream (not accessible from internet)
# internal_upstream: "http://127.0.0.1:8081/"
# Response Configuration
# ----------------------
# URL to redirect users to after successful authentication
# Can be overridden per-request with &rd parameter
redirect_url: "http://localhost:4180/oauth2/callback"
# Sign-in URL (shown when not authenticated)
sign_in_url: "http://localhost:4180/sign_in"
# Sign-out URL
sign_out_url: "http://localhost:4180/sign_out"
# Proxy Configuration
# -------------------
# List of paths to protect
# Requests to these paths will require authentication
proxy_root_controller: true
# Skip JWT verification for specific routes (advanced)
# skip_auth_routes:
# - path: /public
# regex: false
# - path: /api/health
# regex: true
# Headers Configuration
# ---------------------
# Headers to set for authenticated requests
# These headers are passed to the upstream application
set_authorization_header: true
set_x_auth_request_header: true
# Pass headers from OIDC provider
# pass_access_token: true
# pass_id_token_header: true
# Custom headers
# headers:
# X-Forwarded-User: "${email}"
# X-Forwarded-Groups: "${groups}"
# Token Validation
# ----------------
# Validate tokens against the OIDC provider
validate_session: true
# Refresh expired tokens
# refresh_token: true
# Logging Configuration
# ---------------------
# Log level: debug, info, warn, error
log_level: "info"
# Log format: apache, json, nginx
log_format: "json"
# Metrics Configuration
# ---------------------
# Enable metrics endpoint
metrics_address: "0.0.0.0:9090"
# Request Logging
# ---------------
# Log requests to stdout
request_logging: true
# Batch request logging
# batch_request_logging: false
# Reverse Proxy Headers
# ---------------------
# Use X-Real-IP header from reverse proxy
real_ip_header: "X-Real-IP"
# Trusted CIDRs (for determining client IP)
# trusted_cirs:
# - "10.0.0.0/8"
# - "172.16.0.0/12"
# - "192.168.0.0/16"
# Rate Limiting
# -------------
# Enable rate limiting
# enable_ratelimit: true
# ratelimit:
# type: "memory"
# requests_per_second: 10
# Advanced Options
# ----------------
# Whitelist emails (users who can authenticate)
# whitelist_emails:
# - "admin@example.com"
# Blacklist emails (users who cannot authenticate)
# blacklist_emails:
# - "banned@example.com"
# Whitelist domains
# whitelist_domains:
# - "@example.com"
# Skip OIDC discovery (use manual endpoints)
# skip_oidc_discovery: false
# login_url: "http://localhost:5000/oidc/authorize"
# redeem_url: "http://localhost:5000/oidc/token"
# profile_url: "http://localhost:5000/oidc/userinfo"
# validate_url: "http://localhost:5000/oidc/jwks"
# TLS Configuration
# -----------------
# Enable TLS (uncomment in production)
# tls: true
# tls_cert_file: "/etc/ssl/certs/oauth2-proxy.crt"
# tls_key_file: "/etc/ssl/private/oauth2-proxy.key"
# Skip TLS verification (for testing only)
# tls_insecure_skip_verify: false
# OIDC Extra Configuration
# ------------------------
# Extra parameters to pass to authorization request
# authorise_params:
# acr_values: "urn:goauthentik.io:authentication:factor"
# max_age: "3600"
# Ping path for health checks
# ping_path: "/ping"
# Example Usage Scenarios
# =======================
# Scenario 1: Basic Setup with Local OIDC Provider
# ------------------------------------------------
# Use this configuration when running oauth2-proxy locally
# pointing to the authy2 backend running on localhost:5000
# Scenario 2: Production Setup with HTTPS
# ---------------------------------------
# For production, use HTTPS for all connections
# Set secure_cookies: true
# Configure TLS certificates
# Point to your production OIDC issuer URL
# Scenario 3: Docker Compose Setup
# --------------------------------
# Example docker-compose.yml for oauth2-proxy:
#
# version: '3'
# services:
# oauth2-proxy:
# image: oauth2-proxy/oauth2-proxy:latest
# ports:
# - "4180:4180"
# volumes:
# - ./oauth2-proxy-config.yaml:/etc/oauth2-proxy/config.yaml
# environment:
# - OAUTH2_PROXY_PROVIDER=oidc
# - OAUTH2_PROXY_OIDC_ISSUER_URL=http://authy2:5000
# - OAUTH2_PROXY_CLIENT_ID=${OIDC_CLIENT_ID}
# - OAUTH2_PROXY_CLIENT_SECRET=${OIDC_CLIENT_SECRET}
# - OAUTH2_PROXY_COOKIE_SECRET=${COOKIE_SECRET}
# depends_on:
# - authy2
# Scenario 4: Kubernetes Ingress with oauth2-proxy
# -------------------------------------------------
# Example annotation for Kubernetes Ingress:
#
# nginx.ingress.kubernetes.io/auth-url: https://$host/oauth2/auth
# nginx.ingress.kubernetes.io/auth-signin: https://$host/oauth2/sign_in
# nginx.ingress.kubernetes.io/configuration-snippet: |
# auth_request_set $user $upstream_http_x_auth_request_user;
# auth_request_set $email $upstream_http_x_auth_request_email;
# proxy_set_header X-User $user;
# proxy_set_header X-Email $email;
+5 -4
View File
@@ -1,11 +1,12 @@
"""Management script for Flask application.""" """Management script for Flask application."""
import os import os
from flask.cli import FlaskGroup
from app import create_app
from dotenv import load_dotenv from dotenv import load_dotenv
# Load environment variables # Load environment variables FIRST, before any app imports
load_dotenv() load_dotenv(dotenv_path=os.path.join(os.path.dirname(os.path.abspath(__file__)), '.env'))
from flask.cli import FlaskGroup
from app import create_app
# Create application # Create application
app = create_app(os.getenv("FLASK_ENV", "development")) app = create_app(os.getenv("FLASK_ENV", "development"))
+150
View File
@@ -0,0 +1,150 @@
"""Database migration: Create OIDC tables.
Revision ID: 001
Revises:
Create Date: 2024-01-01 00:00:00
This migration creates all OIDC-related tables for the authorization code flow,
refresh token management, OIDC session tracking, token metadata, and audit logging.
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# Revision identifiers
revision = '001'
down_revision = None
branch_labels = None
depends_on = None
def upgrade():
"""Create OIDC tables."""
# OIDC Authorization Codes table
op.create_table(
'oidc_authorization_codes',
sa.Column('id', sa.String(36), primary_key=True),
sa.Column('created_at', sa.DateTime, nullable=False),
sa.Column('updated_at', sa.DateTime, nullable=False),
sa.Column('deleted_at', sa.DateTime, nullable=True),
sa.Column('client_id', sa.String(255), sa.ForeignKey('oidc_clients.id'), nullable=False),
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id'), nullable=False),
sa.Column('code_hash', sa.String(255), nullable=False),
sa.Column('redirect_uri', sa.String(512), nullable=False),
sa.Column('scope', postgresql.JSON, nullable=True),
sa.Column('nonce', sa.String(255), nullable=True),
sa.Column('code_verifier', sa.String(255), nullable=True),
sa.Column('expires_at', sa.DateTime, nullable=False),
sa.Column('used_at', sa.DateTime, nullable=True),
sa.Column('is_used', sa.Boolean, default=False, nullable=False),
sa.Column('ip_address', sa.String(45), nullable=True),
sa.Column('user_agent', sa.Text, nullable=True),
)
op.create_index('ix_oidc_authorization_codes_client_id', 'oidc_authorization_codes', ['client_id'])
op.create_index('ix_oidc_authorization_codes_user_id', 'oidc_authorization_codes', ['user_id'])
op.create_index('ix_oidc_authorization_codes_expires_at', 'oidc_authorization_codes', ['expires_at'])
# OIDC Refresh Tokens table
op.create_table(
'oidc_refresh_tokens',
sa.Column('id', sa.String(36), primary_key=True),
sa.Column('created_at', sa.DateTime, nullable=False),
sa.Column('updated_at', sa.DateTime, nullable=False),
sa.Column('deleted_at', sa.DateTime, nullable=True),
sa.Column('client_id', sa.String(255), sa.ForeignKey('oidc_clients.id'), nullable=False),
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id'), nullable=False),
sa.Column('token_hash', sa.String(255), nullable=False),
sa.Column('access_token_id', sa.String(36), sa.ForeignKey('sessions.id'), nullable=True),
sa.Column('scope', postgresql.JSON, nullable=True),
sa.Column('expires_at', sa.DateTime, nullable=False),
sa.Column('revoked_at', sa.DateTime, nullable=True),
sa.Column('revoked_reason', sa.String(255), nullable=True),
sa.Column('previous_token_hash', sa.String(255), nullable=True),
sa.Column('rotation_count', sa.Integer, default=0, nullable=False),
sa.Column('ip_address', sa.String(45), nullable=True),
sa.Column('user_agent', sa.Text, nullable=True),
)
op.create_index('ix_oidc_refresh_tokens_client_id', 'oidc_refresh_tokens', ['client_id'])
op.create_index('ix_oidc_refresh_tokens_user_id', 'oidc_refresh_tokens', ['user_id'])
op.create_index('ix_oidc_refresh_tokens_token_hash', 'oidc_refresh_tokens', ['token_hash'], unique=True)
op.create_index('ix_oidc_refresh_tokens_access_token_id', 'oidc_refresh_tokens', ['access_token_id'])
op.create_index('ix_oidc_refresh_tokens_expires_at', 'oidc_refresh_tokens', ['expires_at'])
# OIDC Sessions table
op.create_table(
'oidc_sessions',
sa.Column('id', sa.String(36), primary_key=True),
sa.Column('created_at', sa.DateTime, nullable=False),
sa.Column('updated_at', sa.DateTime, nullable=False),
sa.Column('deleted_at', sa.DateTime, nullable=True),
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id'), nullable=False),
sa.Column('client_id', sa.String(255), sa.ForeignKey('oidc_clients.id'), nullable=False),
sa.Column('state', sa.String(255), nullable=False),
sa.Column('nonce', sa.String(255), nullable=True),
sa.Column('redirect_uri', sa.String(512), nullable=False),
sa.Column('scope', postgresql.JSON, nullable=True),
sa.Column('code_challenge', sa.String(255), nullable=True),
sa.Column('code_challenge_method', sa.String(10), nullable=True),
sa.Column('expires_at', sa.DateTime, nullable=False),
sa.Column('authenticated_at', sa.DateTime, nullable=True),
)
op.create_index('ix_oidc_sessions_user_id', 'oidc_sessions', ['user_id'])
op.create_index('ix_oidc_sessions_client_id', 'oidc_sessions', ['client_id'])
op.create_index('ix_oidc_sessions_state', 'oidc_sessions', ['state'])
op.create_index('ix_oidc_sessions_expires_at', 'oidc_sessions', ['expires_at'])
# OIDC Token Metadata table
op.create_table(
'oidc_token_metadata',
sa.Column('id', sa.String(36), primary_key=True),
sa.Column('created_at', sa.DateTime, nullable=False),
sa.Column('updated_at', sa.DateTime, nullable=False),
sa.Column('deleted_at', sa.DateTime, nullable=True),
sa.Column('client_id', sa.String(255), sa.ForeignKey('oidc_clients.id'), nullable=False),
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id'), nullable=False),
sa.Column('token_type', sa.String(50), nullable=False),
sa.Column('token_jti', sa.String(255), nullable=False),
sa.Column('expires_at', sa.DateTime, nullable=False),
sa.Column('revoked_at', sa.DateTime, nullable=True),
sa.Column('revoked_reason', sa.String(255), nullable=True),
)
op.create_index('ix_oidc_token_metadata_client_id', 'oidc_token_metadata', ['client_id'])
op.create_index('ix_oidc_token_metadata_user_id', 'oidc_token_metadata', ['user_id'])
op.create_index('ix_oidc_token_metadata_token_jti', 'oidc_token_metadata', ['token_jti'])
op.create_index('ix_oidc_token_metadata_expires_at', 'oidc_token_metadata', ['expires_at'])
# OIDC Audit Logs table
op.create_table(
'oidc_audit_logs',
sa.Column('id', sa.String(36), primary_key=True),
sa.Column('created_at', sa.DateTime, nullable=False),
sa.Column('updated_at', sa.DateTime, nullable=False),
sa.Column('deleted_at', sa.DateTime, nullable=True),
sa.Column('event_type', sa.String(100), nullable=False),
sa.Column('client_id', sa.String(255), sa.ForeignKey('oidc_clients.id'), nullable=True),
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id'), nullable=True),
sa.Column('success', sa.Boolean, default=True, nullable=False),
sa.Column('error_code', sa.String(100), nullable=True),
sa.Column('error_description', sa.Text, nullable=True),
sa.Column('ip_address', sa.String(45), nullable=True),
sa.Column('user_agent', sa.Text, nullable=True),
sa.Column('request_id', sa.String(36), nullable=True),
sa.Column('event_metadata', postgresql.JSON, nullable=True),
)
op.create_index('ix_oidc_audit_logs_event_type', 'oidc_audit_logs', ['event_type'])
op.create_index('ix_oidc_audit_logs_client_id', 'oidc_audit_logs', ['client_id'])
op.create_index('ix_oidc_audit_logs_user_id', 'oidc_audit_logs', ['user_id'])
op.create_index('ix_oidc_audit_logs_success', 'oidc_audit_logs', ['success'])
op.create_index('ix_oidc_audit_logs_ip_address', 'oidc_audit_logs', ['ip_address'])
op.create_index('ix_oidc_audit_logs_request_id', 'oidc_audit_logs', ['request_id'])
def downgrade():
"""Drop OIDC tables."""
op.drop_table('oidc_audit_logs')
op.drop_table('oidc_token_metadata')
op.drop_table('oidc_sessions')
op.drop_table('oidc_refresh_tokens')
op.drop_table('oidc_authorization_codes')
+74
View File
@@ -0,0 +1,74 @@
"""Flask-Migrate environment configuration."""
import os
import sys
# Add the parent directory to the path so we can import the app
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Load environment variables
from dotenv import load_dotenv
load_dotenv(dotenv_path=os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), '.env'))
# Import the Flask app and db
from app import create_app
from app.extensions import db
# Get the app
app = create_app(os.getenv("FLASK_ENV", "development"))
# Set the Flask application context
with app.app_context():
from alembic import context
# this is the Alembic Config object, which provides access
# to the values within the .ini file in use.
config = context.config
# Set the SQLAlchemy URL from the app config
config.set_main_option('sqlalchemy.url', app.config.get('SQLALCHEMY_DATABASE_URI'))
# Set the target metadata
target_metadata = db.metadata
def run_migrations_offline():
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here too. By skipping the Engine creation
we don't even need a DBAPI to be available.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online():
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connection = db.engine.connect()
context.configure(
connection=connection,
target_metadata=target_metadata
)
try:
with context.begin_transaction():
context.run_migrations()
finally:
connection.close()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
+24
View File
@@ -0,0 +1,24 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision}
Create Date: ${create_date}
"""
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}
def upgrade():
${upgrades if upgrades else "pass"}
def downgrade():
${downgrades if downgrades else "pass"}
+4
View File
@@ -17,6 +17,10 @@ marshmallow-sqlalchemy==0.29.0
bcrypt==4.1.2 bcrypt==4.1.2
Flask-Bcrypt==1.0.1 Flask-Bcrypt==1.0.1
# JWT / OIDC
PyJWT==2.8.0
cryptography==41.0.7
# CORS # CORS
Flask-CORS==4.0.0 Flask-CORS==4.0.0
+171 -4
View File
@@ -7,6 +7,12 @@ This script creates:
- Proper organization memberships with different roles - Proper organization memberships with different roles
""" """
import sys import sys
import secrets
import hashlib
from dotenv import load_dotenv
# Load environment variables FIRST before any app imports
load_dotenv()
from app import create_app from app import create_app
from app.extensions import db from app.extensions import db
@@ -14,13 +20,10 @@ from app.models.user import User
from app.models.organization import Organization from app.models.organization import Organization
from app.models.organization_member import OrganizationMember from app.models.organization_member import OrganizationMember
from app.models.authentication_method import AuthenticationMethod from app.models.authentication_method import AuthenticationMethod
from app.models.oidc_client import OIDCClient
from app.services.auth_service import AuthService from app.services.auth_service import AuthService
from app.services.organization_service import OrganizationService from app.services.organization_service import OrganizationService
from app.utils.constants import OrganizationRole, UserStatus, AuthMethodType from app.utils.constants import OrganizationRole, UserStatus, AuthMethodType
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
# Create application # Create application
app = create_app() app = create_app()
@@ -115,6 +118,38 @@ def add_org_member(org, user_id, role, inviter_id):
raise e raise e
def create_or_get_oidc_client(org_id, name, client_id, client_secret,
redirect_uris, grant_types, response_types, scopes,
**kwargs):
"""Create an OIDC client if it doesn't exist, or return existing client."""
existing = OIDCClient.query.filter_by(client_id=client_id, deleted_at=None).first()
if existing:
print(f" → OIDC Client {name} already exists, skipping")
return existing
try:
# Hash the client secret
client_secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
client = OIDCClient(
organization_id=org_id,
name=name,
client_id=client_id,
client_secret_hash=client_secret_hash,
redirect_uris=redirect_uris,
grant_types=grant_types,
response_types=response_types,
scopes=scopes,
**kwargs
)
client.save()
print(f" → Created OIDC client: {name}")
return client
except Exception as e:
print(f" → Error creating OIDC client {name}: {e}")
raise e
def seed_data(): def seed_data():
"""Seed the database with test data.""" """Seed the database with test data."""
print("=" * 60) print("=" * 60)
@@ -387,6 +422,113 @@ def seed_data():
if sarah and tech_org: if sarah and tech_org:
add_org_member(tech_org, charlie.id, OrganizationRole.MEMBER, sarah.id) add_org_member(tech_org, charlie.id, OrganizationRole.MEMBER, sarah.id)
# =========================================================================
# Step 5: Create OIDC Clients
# =========================================================================
print("\n[Step 5] Creating OIDC Clients...")
oidc_clients = {}
# OIDC Client for Acme Corp - Internal Portal
if acme_org:
print("\n Acme Corporation OIDC Clients:")
acme_portal_client = create_or_get_oidc_client(
org_id=acme_org.id,
name="Acme Internal Portal",
client_id="acme-portal-001",
client_secret="acme_secret_portal_2024",
redirect_uris=[
"https://portal.acme-corp.com/auth/callback",
"http://localhost:3000/auth/callback",
],
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
scopes=["openid", "profile", "email", "offline_access"],
is_active=True,
is_confidential=True,
require_pkce=True,
access_token_lifetime=3600, # 1 hour
refresh_token_lifetime=2592000, # 30 days
id_token_lifetime=3600, # 1 hour
logo_uri="https://portal.acme-corp.com/logo.png",
client_uri="https://portal.acme-corp.com",
)
oidc_clients["acme-portal"] = acme_portal_client
# OIDC Client for Acme Corp - Mobile App
acme_mobile_client = create_or_get_oidc_client(
org_id=acme_org.id,
name="Acme Mobile App",
client_id="acme-mobile-001",
client_secret="acme_secret_mobile_2024",
redirect_uris=[
"com.acmecorp.app://oauth/callback",
"http://localhost:8080/callback",
],
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
scopes=["openid", "profile", "email", "offline_access"],
is_active=True,
is_confidential=False, # Public client (mobile)
require_pkce=True,
access_token_lifetime=1800, # 30 minutes
refresh_token_lifetime=604800, # 7 days
id_token_lifetime=1800, # 30 minutes
)
oidc_clients["acme-mobile"] = acme_mobile_client
# OIDC Client for Tech Startup
if tech_org:
print("\n Tech Startup OIDC Clients:")
tech_app_client = create_or_get_oidc_client(
org_id=tech_org.id,
name="Tech Startup Dashboard",
client_id="tech-dashboard-001",
client_secret="tech_secret_dashboard_2024",
redirect_uris=[
"https://dashboard.tech-startup.com/auth/callback",
"http://localhost:4200/auth/callback",
],
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
scopes=["openid", "profile", "email", "offline_access"],
is_active=True,
is_confidential=True,
require_pkce=True,
access_token_lifetime=3600, # 1 hour
refresh_token_lifetime=2592000, # 30 days
id_token_lifetime=3600, # 1 hour
logo_uri="https://tech-startup.com/logo.png",
client_uri="https://tech-startup.com",
)
oidc_clients["tech-dashboard"] = tech_app_client
# OIDC Client for Data Systems
if data_org:
print("\n Data Systems OIDC Clients:")
data_api_client = create_or_get_oidc_client(
org_id=data_org.id,
name="Data Systems API Client",
client_id="data-api-001",
client_secret="data_secret_api_2024",
redirect_uris=[
"https://api.data-systems.com/oauth/callback",
"http://localhost:5000/oauth/callback",
],
grant_types=["authorization_code", "refresh_token", "client_credentials"],
response_types=["code"],
scopes=["openid", "profile", "email", "api:read", "api:write"],
is_active=True,
is_confidential=True,
require_pkce=False, # Server-to-server client
access_token_lifetime=7200, # 2 hours
refresh_token_lifetime=2592000, # 30 days
id_token_lifetime=3600, # 1 hour
client_uri="https://data-systems.com",
)
oidc_clients["data-api"] = data_api_client
print(f"\n Created {len(oidc_clients)} OIDC clients")
# ========================================================================= # =========================================================================
# Summary # Summary
# ========================================================================= # =========================================================================
@@ -398,6 +540,7 @@ def seed_data():
print(f" Organizations: {len(org_objects)}") print(f" Organizations: {len(org_objects)}")
print(f" Admin Users: {len(admin_objects)}") print(f" Admin Users: {len(admin_objects)}")
print(f" Regular Users: {len(all_users)}") print(f" Regular Users: {len(all_users)}")
print(f" OIDC Clients: {len(oidc_clients)}")
print("\n🔐 Test Credentials:") print("\n🔐 Test Credentials:")
print("\n Admin Accounts:") print("\n Admin Accounts:")
@@ -421,6 +564,30 @@ def seed_data():
print(f" {org.name} (slug: {slug})") print(f" {org.name} (slug: {slug})")
print(f" Members: {member_count}, Owner: {owner_email}") print(f" Members: {member_count}, Owner: {owner_email}")
print("\n🔐 OIDC Clients:")
for key, client in oidc_clients.items():
print(f" {client.name}")
print(f" Client ID: {client.client_id}")
print(f" Organization: {client.organization.name}")
print(f" Grant Types: {', '.join(client.grant_types)}")
print(f" Scopes: {', '.join(client.scopes)}")
print(f" Redirect URIs: {len(client.redirect_uris)} configured")
if oidc_clients:
print("\n 📝 OIDC Client Credentials (for testing):")
print(" Acme Portal:")
print(" client_id: acme-portal-001")
print(" client_secret: acme_secret_portal_2024")
print(" Acme Mobile:")
print(" client_id: acme-mobile-001")
print(" client_secret: acme_secret_mobile_2024")
print(" Tech Dashboard:")
print(" client_id: tech-dashboard-001")
print(" client_secret: tech_secret_dashboard_2024")
print(" Data API:")
print(" client_id: data-api-001")
print(" client_secret: data_secret_api_2024")
print("\n" + "=" * 60) print("\n" + "=" * 60)
File diff suppressed because it is too large Load Diff