major checkpoint

This commit is contained in:
2026-01-08 15:59:53 +10:30
parent 211854ca0a
commit 5e060f267d
33 changed files with 8088 additions and 43 deletions
+11
View File
@@ -4,6 +4,11 @@ from app.services.user_service import UserService
from app.services.organization_service import OrganizationService
from app.services.session_service import SessionService
from app.services.audit_service import AuditService
from app.services.oidc_service import OIDCService, OIDCError
from app.services.oidc_jwks_service import OIDCJWKSService
from app.services.oidc_token_service import OIDCTokenService
from app.services.oidc_session_service import OIDCSessionService
from app.services.oidc_audit_service import OIDCAuditService
__all__ = [
"AuthService",
@@ -11,4 +16,10 @@ __all__ = [
"OrganizationService",
"SessionService",
"AuditService",
"OIDCService",
"OIDCError",
"OIDCJWKSService",
"OIDCTokenService",
"OIDCSessionService",
"OIDCAuditService",
]
+408
View File
@@ -0,0 +1,408 @@
"""OIDC Audit Service for comprehensive OIDC event logging."""
from datetime import datetime
from typing import Dict, List, Optional
from flask import g
from app.models import OIDCAuditLog, OIDCClient, User
from app.exceptions.validation_exceptions import NotFoundError
class OIDCAuditService:
"""Service for OIDC-specific audit logging.
This service provides methods to log all OIDC-related events including:
- Authorization requests and responses
- Token issuance and refresh
- Token revocation
- UserInfo endpoint access
- Authentication failures
"""
# Event type constants
EVENT_AUTHORIZATION_REQUEST = "authorization_request"
EVENT_AUTHORIZATION_RESPONSE = "authorization_response"
EVENT_TOKEN_ISSUE = "token_issue"
EVENT_TOKEN_REFRESH = "token_refresh"
EVENT_TOKEN_REVOCATION = "token_revocation"
EVENT_TOKEN_INTROSPECTION = "token_introspection"
EVENT_USERINFO_ACCESS = "userinfo_access"
EVENT_AUTHENTICATION_FAILURE = "authentication_failure"
EVENT_AUTHORIZATION_FAILURE = "authorization_failure"
EVENT_JWKS_ACCESS = "jwks_access"
EVENT_REGISTRATION = "client_registration"
@classmethod
def _get_request_context(cls) -> Dict:
"""Extract request context for logging.
Returns:
Dictionary with IP, user_agent, and request_id
"""
from flask import request
return {
"ip_address": request.remote_addr if request else None,
"user_agent": request.headers.get("User-Agent") if request else None,
"request_id": g.get("request_id"),
}
@classmethod
def log_event(
cls,
event_type: str,
client_id: str = None,
user_id: str = None,
success: bool = True,
error_code: str = None,
error_description: str = None,
metadata: Dict = None
) -> OIDCAuditLog:
"""Log a generic OIDC event.
Args:
event_type: Type of event
client_id: OIDC client ID
user_id: User ID
success: Whether the event was successful
error_code: Error code if failed
error_description: Error description if failed
metadata: Additional event metadata
Returns:
OIDCAuditLog instance
"""
context = cls._get_request_context()
log = OIDCAuditLog.log_event(
event_type=event_type,
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
ip_address=context["ip_address"],
user_agent=context["user_agent"],
request_id=context["request_id"],
metadata=metadata,
)
return log
@classmethod
def log_authorization_event(
cls,
client_id: str,
user_id: str = None,
success: bool = True,
error_code: str = None,
error_description: str = None,
redirect_uri: str = None,
scope: list = None,
response_type: str = None
) -> OIDCAuditLog:
"""Log an authorization event.
Args:
client_id: OIDC client ID
user_id: User ID (if authenticated)
success: Whether authorization was successful
error_code: Error code if failed
error_description: Error description if failed
redirect_uri: Redirect URI from request
scope: Requested scopes
response_type: Response type (e.g., "code")
Returns:
OIDCAuditLog instance
"""
metadata = {
"redirect_uri": redirect_uri,
"scope": scope,
"response_type": response_type,
}
metadata = {k: v for k, v in metadata.items() if v is not None}
return cls.log_event(
event_type=cls.EVENT_AUTHORIZATION_REQUEST,
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
metadata=metadata,
)
@classmethod
def log_token_event(
cls,
client_id: str,
user_id: str = None,
token_type: str = "access_token",
success: bool = True,
error_code: str = None,
error_description: str = None,
grant_type: str = None,
scopes: list = None
) -> OIDCAuditLog:
"""Log a token issuance or refresh event.
Args:
client_id: OIDC client ID
user_id: User ID
token_type: Type of token issued
success: Whether token issuance was successful
error_code: Error code if failed
error_description: Error description if failed
grant_type: Grant type used (e.g., "authorization_code", "refresh_token")
scopes: Scopes included in the token
Returns:
OIDCAuditLog instance
"""
metadata = {
"token_type": token_type,
"grant_type": grant_type,
"scopes": scopes,
}
metadata = {k: v for k, v in metadata.items() if v is not None}
return cls.log_event(
event_type=cls.EVENT_TOKEN_ISSUE if token_type else cls.EVENT_TOKEN_REFRESH,
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
metadata=metadata,
)
@classmethod
def log_userinfo_event(
cls,
access_token: str = None,
user_id: str = None,
client_id: str = None,
success: bool = True,
error_code: str = None,
error_description: str = None,
scopes_claimed: list = None
) -> OIDCAuditLog:
"""Log a UserInfo endpoint access event.
Args:
access_token: Access token used (masked)
user_id: User ID returned
client_id: Client ID making the request
success: Whether access was successful
error_code: Error code if failed
error_description: Error description if failed
scopes_claimed: Scopes claimed in the request
Returns:
OIDCAuditLog instance
"""
# Mask the access token for security
masked_token = None
if access_token:
masked_token = access_token[:8] + "..." + access_token[-4:] if len(access_token) > 12 else "***"
metadata = {
"token_prefix": masked_token,
"scopes_claimed": scopes_claimed,
}
metadata = {k: v for k, v in metadata.items() if v is not None}
return cls.log_event(
event_type=cls.EVENT_USERINFO_ACCESS,
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
metadata=metadata,
)
@classmethod
def log_token_revocation_event(
cls,
client_id: str,
user_id: str = None,
token_type: str = "access_token",
reason: str = None,
success: bool = True,
error_code: str = None,
error_description: str = None
) -> OIDCAuditLog:
"""Log a token revocation event.
Args:
client_id: OIDC client ID
user_id: User ID
token_type: Type of token being revoked
reason: Revocation reason
success: Whether revocation was successful
error_code: Error code if failed
error_description: Error description if failed
Returns:
OIDCAuditLog instance
"""
metadata = {
"token_type": token_type,
"reason": reason,
}
metadata = {k: v for k, v in metadata.items() if v is not None}
return cls.log_event(
event_type=cls.EVENT_TOKEN_REVOCATION,
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
metadata=metadata,
)
@classmethod
def log_authentication_failure(
cls,
client_id: str = None,
error_code: str = "authentication_failed",
error_description: str = "Authentication failed",
user_id: str = None
) -> OIDCAuditLog:
"""Log an authentication failure event.
Args:
client_id: OIDC client ID
error_code: Error code
error_description: Error description
user_id: User ID if known
Returns:
OIDCAuditLog instance
"""
return cls.log_event(
event_type=cls.EVENT_AUTHENTICATION_FAILURE,
client_id=client_id,
user_id=user_id,
success=False,
error_code=error_code,
error_description=error_description,
)
@classmethod
def get_events_for_user(
cls,
user_id: str,
limit: int = 100,
include_deleted: bool = False
) -> List[OIDCAuditLog]:
"""Get audit events for a specific user.
Args:
user_id: User ID
limit: Maximum number of events to return
include_deleted: Include soft-deleted events
Returns:
List of OIDCAuditLog instances
"""
return OIDCAuditLog.get_events_for_user(user_id, limit)
@classmethod
def get_events_for_client(
cls,
client_id: str,
limit: int = 100
) -> List[OIDCAuditLog]:
"""Get audit events for a specific client.
Args:
client_id: Client ID
limit: Maximum number of events to return
Returns:
List of OIDCAuditLog instances
"""
return OIDCAuditLog.get_events_for_client(client_id, limit)
@classmethod
def get_failed_events(
cls,
client_id: str = None,
user_id: str = None,
start_date: datetime = None,
end_date: datetime = None,
limit: int = 100
) -> List[OIDCAuditLog]:
"""Get failed audit events for analysis.
Args:
client_id: Optional client ID filter
user_id: Optional user ID filter
start_date: Optional start date filter
end_date: Optional end date filter
limit: Maximum number of events to return
Returns:
List of failed OIDCAuditLog instances
"""
return OIDCAuditLog.get_failed_events(
client_id=client_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
limit=limit,
)
@classmethod
def get_event_summary(
cls,
client_id: str = None,
days: int = 30
) -> Dict:
"""Get a summary of audit events.
Args:
client_id: Optional client ID filter
days: Number of days to look back
Returns:
Summary dictionary with event counts
"""
from datetime import timedelta
start_date = datetime.utcnow() - timedelta(days=days)
query = OIDCAuditLog.query.filter(
OIDCAuditLog.created_at >= start_date
)
if client_id:
query = query.filter_by(client_id=client_id)
events = query.all()
# Count by event type
event_counts = {}
success_count = 0
failure_count = 0
for event in events:
event_type = event.event_type
event_counts[event_type] = event_counts.get(event_type, 0) + 1
if event.success:
success_count += 1
else:
failure_count += 1
return {
"total_events": len(events),
"successful_events": success_count,
"failed_events": failure_count,
"by_event_type": event_counts,
"period_days": days,
}
+300
View File
@@ -0,0 +1,300 @@
"""OIDC JWKS Service for key management and rotation."""
import uuid
import json
import hashlib
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple
from flask import current_app
from app.extensions import db
class JWKSKey:
"""Represents a JWKS key entry."""
def __init__(self, kid: str, private_key: str, public_key: str,
algorithm: str = "RS256", created_at: datetime = None,
expires_at: datetime = None, is_active: bool = True):
self.kid = kid
self.private_key = private_key
self.public_key = public_key
self.algorithm = algorithm
self.created_at = created_at or datetime.utcnow()
self.expires_at = expires_at or datetime.utcnow() + timedelta(days=365)
self.is_active = is_active
def to_jwk(self) -> Dict:
"""Convert to JWK format for JWKS endpoint."""
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from cryptography.hazmat.backends import default_backend
# Import cryptography here to avoid issues if not installed
try:
# Get public key from PEM
public_key = serialization.load_pem_public_key(
self.public_key.encode(), backend=default_backend()
)
# Get RSA parameters
public_numbers = public_key.public_numbers()
return {
"kty": "RSA",
"kid": self.kid,
"use": "sig",
"alg": self.algorithm,
"n": _base64url_encode(public_numbers.n),
"e": _base64url_encode(public_numbers.e),
}
except ImportError:
# Fallback for when cryptography is not installed
return {
"kty": "RSA",
"kid": self.kid,
"use": "sig",
"alg": self.algorithm,
}
def to_dict(self) -> Dict:
"""Convert to dictionary for storage."""
return {
"kid": self.kid,
"private_key": self.private_key,
"public_key": self.public_key,
"algorithm": self.algorithm,
"created_at": self.created_at.isoformat(),
"expires_at": self.expires_at.isoformat(),
"is_active": self.is_active,
}
@classmethod
def from_dict(cls, data: Dict) -> "JWKSKey":
"""Create from dictionary."""
return cls(
kid=data["kid"],
private_key=data["private_key"],
public_key=data["public_key"],
algorithm=data.get("algorithm", "RS256"),
created_at=datetime.fromisoformat(data["created_at"]),
expires_at=datetime.fromisoformat(data["expires_at"]),
is_active=data.get("is_active", True),
)
def _base64url_encode(value: int) -> str:
"""Encode an integer to base64url format."""
import base64
byte_length = (value.bit_length() + 7) // 8 or 1
encoded = value.to_bytes(byte_length, byteorder="big")
return base64.urlsafe_b64encode(encoded).decode().rstrip("=")
class OIDCJWKSService:
"""Service for managing OIDC signing keys (JWKS).
This service handles RSA key pair generation, rotation, and JWKS document
generation for the OIDC implementation.
"""
_instance = None
_keys: Dict[str, JWKSKey] = {}
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._keys = {}
return cls._instance
@classmethod
def reset(cls):
"""Reset the singleton (for testing)."""
cls._instance = None
cls._keys = {}
def _generate_kid(self, private_key: str) -> str:
"""Generate a key ID from the private key fingerprint."""
kid_hash = hashlib.sha256(private_key.encode()).hexdigest()[:32]
return kid_hash
def _generate_rsa_key_pair(self) -> Tuple[str, str]:
"""Generate a new RSA key pair in PEM format.
Returns:
Tuple of (private_key_pem, public_key_pem)
"""
try:
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
# Generate RSA private key
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
backend=default_backend()
)
# Get public key
public_key = private_key.public_key()
# Serialize to PEM
private_pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption()
).decode()
public_pem = public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
).decode()
return private_pem, public_pem
except ImportError:
# Fallback for testing without cryptography
import secrets
return f"private_key_{secrets.token_hex(32)}", f"public_key_{secrets.token_hex(32)}"
def get_jwks(self, include_private_keys: bool = False) -> Dict:
"""Get the JWKS document containing public keys.
Args:
include_private_keys: Whether to include private keys (for internal use only)
Returns:
JWKS document dictionary
"""
now = datetime.utcnow()
keys = []
for kid, key in self._keys.items():
# Only include active, non-expired keys
if key.is_active and key.expires_at > now:
if include_private_keys:
keys.append(key.to_dict())
else:
keys.append(key.to_jwk())
return {
"keys": keys
}
def get_signing_key(self) -> Optional[JWKSKey]:
"""Get the current active signing key.
Returns:
JWKSKey instance or None if no active key
"""
now = datetime.utcnow()
for kid, key in self._keys.items():
if key.is_active and key.expires_at > now:
return key
return None
def get_key_by_kid(self, kid: str) -> Optional[JWKSKey]:
"""Get a specific key by its ID.
Args:
kid: Key ID to look up
Returns:
JWKSKey instance or None if not found
"""
return self._keys.get(kid)
def generate_new_key_pair(self, expires_in_days: int = 365) -> JWKSKey:
"""Generate a new RSA key pair for signing.
Args:
expires_in_days: Days until key expiration
Returns:
JWKSKey instance
"""
private_key, public_key = self._generate_rsa_key_pair()
kid = self._generate_kid(private_key)
now = datetime.utcnow()
key = JWKSKey(
kid=kid,
private_key=private_key,
public_key=public_key,
algorithm="RS256",
created_at=now,
expires_at=now + timedelta(days=expires_in_days),
is_active=True,
)
self._keys[kid] = key
# Deactivate old keys (but keep them for grace period)
for old_kid in self._keys:
if old_kid != kid:
self._keys[old_kid].is_active = False
return key
def rotate_keys(self, grace_period_hours: int = 24) -> Tuple[JWKSKey, List[str]]:
"""Rotate signing keys, keeping previous key active for grace period.
Args:
grace_period_hours: Hours to keep old keys active
Returns:
Tuple of (new_key, list_of_deprecated_kids)
"""
now = datetime.utcnow()
grace_end = now + timedelta(hours=grace_period_hours)
# Mark current key as deprecated
current_key = self.get_signing_key()
deprecated_kids = []
if current_key:
deprecated_kids.append(current_key.kid)
# Keep key active but mark as deprecated
current_key.is_active = False
current_key.expires_at = grace_end
# Generate new key
new_key = self.generate_new_key_pair()
# Clean up expired keys
expired_kids = [
kid for kid, key in self._keys.items()
if key.expires_at < now
]
for kid in expired_kids:
del self._keys[kid]
return new_key, deprecated_kids
def verify_key_exists(self, kid: str) -> bool:
"""Check if a key with the given ID exists and is valid.
Args:
kid: Key ID to check
Returns:
True if key exists and is valid
"""
key = self.get_key_by_kid(kid)
if not key:
return False
now = datetime.utcnow()
return key.is_active and key.expires_at > now
def initialize_with_key(self) -> JWKSKey:
"""Initialize the service with a key if none exists.
Returns:
JWKSKey instance
"""
if not self._keys:
return self.generate_new_key_pair()
return self.get_signing_key()
+745
View File
@@ -0,0 +1,745 @@
"""OIDC Service - Main OIDC service layer."""
import secrets
import hashlib
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple
from flask import current_app, g
from app.extensions import db
from app.models import (
User, OIDCClient, OIDCAuthCode, OIDCRefreshToken,
OIDCSession, OIDCTokenMetadata
)
from app.exceptions.validation_exceptions import (
ValidationError, NotFoundError, BadRequestError
)
from app.exceptions.auth_exceptions import UnauthorizedError, InvalidTokenError
from app.services.oidc_token_service import OIDCTokenService
from app.services.oidc_session_service import OIDCSessionService
from app.services.oidc_audit_service import OIDCAuditService
from app.services.oidc_jwks_service import OIDCJWKSService
class OIDCError(Exception):
"""Base exception for OIDC errors."""
def __init__(self, error: str, error_description: str = None, status_code: int = 400):
self.error = error
self.error_description = error_description
self.status_code = status_code
class InvalidClientError(OIDCError):
"""Raised when client authentication fails."""
def __init__(self, error_description: str = "Invalid client"):
super().__init__("invalid_client", error_description, 401)
class InvalidGrantError(OIDCError):
"""Raised when grant is invalid."""
def __init__(self, error_description: str = "Invalid grant"):
super().__init__("invalid_grant", error_description, 400)
class InvalidRequestError(OIDCError):
"""Raised when request is malformed."""
def __init__(self, error_description: str = "Invalid request"):
super().__init__("invalid_request", error_description, 400)
class OIDCService:
"""Main OIDC service handling all OpenID Connect operations.
This service provides:
- Authorization code generation and validation
- Token generation (access, refresh, ID tokens)
- Token refresh with rotation
- Token validation and introspection
- Token revocation
"""
@staticmethod
def _generate_code() -> str:
"""Generate a secure authorization code.
Returns:
URL-safe base64 encoded code
"""
return secrets.token_urlsafe(32)
@staticmethod
def _hash_value(value: str) -> str:
"""Hash a value for secure storage.
Args:
value: Value to hash
Returns:
SHA256 hash
"""
return hashlib.sha256(value.encode()).hexdigest()
@classmethod
def generate_authorization_code(
cls,
client_id: str,
user_id: str,
redirect_uri: str,
scope: list,
state: str,
nonce: str,
code_challenge: str = None,
code_challenge_method: str = None,
ip_address: str = None,
user_agent: str = None
) -> str:
"""Generate an authorization code for the auth code flow.
Args:
client_id: OIDC client ID
user_id: User ID
redirect_uri: Redirect URI
scope: Requested scopes
state: State parameter
nonce: Nonce for ID token
code_challenge: PKCE code challenge
code_challenge_method: PKCE method ("S256" or "plain")
ip_address: Client IP address
user_agent: Client user agent
Returns:
Authorization code string
Raises:
ValidationError: If parameters are invalid
NotFoundError: If client not found
"""
# Validate client exists and is active
client = OIDCClient.query.filter_by(client_id=client_id).first()
if not client:
raise NotFoundError("Client not found")
if not client.is_active:
raise ValidationError("Client is not active")
# Validate redirect URI
if not client.is_redirect_uri_allowed(redirect_uri):
raise ValidationError("Invalid redirect_uri")
# Validate scopes
allowed_scopes = client.scopes or []
valid_scopes = [s for s in scope if s in allowed_scopes]
if not valid_scopes:
raise ValidationError("Invalid scopes")
# Generate authorization code
code = cls._generate_code()
code_hash = cls._hash_value(code)
# Create auth code record
auth_code = OIDCAuthCode.create_code(
client_id=client.id,
user_id=user_id,
code_hash=code_hash,
redirect_uri=redirect_uri,
scope=valid_scopes,
nonce=nonce,
code_verifier=code_challenge, # Store for validation
ip_address=ip_address,
user_agent=user_agent,
lifetime_seconds=600, # 10 minutes
)
# Log authorization event
OIDCAuditService.log_authorization_event(
client_id=client_id,
user_id=user_id,
success=True,
redirect_uri=redirect_uri,
scope=valid_scopes,
)
return code
@classmethod
def validate_authorization_code(
cls,
code: str,
client_id: str,
redirect_uri: str,
code_verifier: str = None,
ip_address: str = None,
user_agent: str = None
) -> Tuple[Dict, User]:
"""Validate and consume an authorization code.
Args:
code: Authorization code
client_id: OIDC client ID
redirect_uri: Redirect URI
code_verifier: PKCE code verifier (required if PKCE was used)
ip_address: Client IP address
user_agent: Client user agent
Returns:
Tuple of (claims dict, User instance)
Raises:
InvalidGrantError: If code is invalid
ValidationError: If PKCE validation fails
"""
# Get client
client = OIDCClient.query.filter_by(client_id=client_id).first()
if not client:
raise InvalidGrantError("Invalid client")
# Hash the provided code and find matching auth code
code_hash = cls._hash_value(code)
auth_code = OIDCAuthCode.query.filter_by(
code_hash=code_hash,
client_id=client.id,
deleted_at=None
).first()
if not auth_code:
OIDCAuditService.log_authorization_event(
client_id=client_id,
success=False,
error_code="invalid_grant",
error_description="Invalid or expired authorization code",
)
raise InvalidGrantError("Invalid or expired authorization code")
# Check if already used
if auth_code.is_used:
OIDCAuditService.log_authorization_event(
client_id=client_id,
user_id=auth_code.user_id,
success=False,
error_code="invalid_grant",
error_description="Authorization code already used",
)
raise InvalidGrantError("Authorization code already used")
# Check expiration
if auth_code.is_expired():
OIDCAuditService.log_authorization_event(
client_id=client_id,
user_id=auth_code.user_id,
success=False,
error_code="invalid_grant",
error_description="Authorization code expired",
)
raise InvalidGrantError("Authorization code expired")
# Validate redirect URI
if auth_code.redirect_uri != redirect_uri:
raise InvalidGrantError("Invalid redirect_uri")
# Validate PKCE if required
if client.require_pkce and auth_code.code_verifier:
if not code_verifier:
raise ValidationError("code_verifier is required")
# Verify code verifier
expected_challenge = cls._compute_code_challenge(code_verifier, "S256")
if expected_challenge != auth_code.code_verifier:
OIDCAuditService.log_authorization_event(
client_id=client_id,
user_id=auth_code.user_id,
success=False,
error_code="invalid_grant",
error_description="Invalid code_verifier",
)
raise InvalidGrantError("Invalid code_verifier")
# Mark code as used
auth_code.mark_as_used()
# Get user
user = User.query.get(auth_code.user_id)
if not user:
raise InvalidGrantError("User not found")
claims = {
"user_id": auth_code.user_id,
"client_id": client_id,
"redirect_uri": redirect_uri,
"scope": auth_code.scope,
"nonce": auth_code.nonce,
}
return claims, user
@classmethod
def _compute_code_challenge(cls, verifier: str, method: str = "S256") -> str:
"""Compute PKCE code challenge from verifier.
Args:
verifier: Code verifier
method: Challenge method
Returns:
Code challenge
"""
import hashlib
import base64
if method == "S256":
digest = hashlib.sha256(verifier.encode()).digest()
return base64.urlsafe_b64encode(digest).decode().rstrip("=")
return verifier
@classmethod
def generate_tokens(
cls,
client_id: str,
user_id: str,
scope: list,
nonce: str = None,
refresh_token: str = None,
ip_address: str = None,
user_agent: str = None,
auth_time: int = None
) -> Dict:
"""Generate access token, ID token, and refresh token.
Args:
client_id: OIDC client ID
user_id: User ID
scope: Granted scopes
nonce: Nonce for ID token
refresh_token: Existing refresh token (for rotation)
ip_address: Client IP address
user_agent: Client user agent
auth_time: Authentication time
Returns:
Dictionary with tokens
"""
import hashlib
# Get client
client = OIDCClient.query.filter_by(client_id=client_id).first()
if not client:
raise InvalidClientError()
# Generate access token
access_token_jti = OIDCTokenService._generate_jti()
access_token = OIDCTokenService.create_access_token(
client_id=client_id,
user_id=user_id,
scope=scope,
jti=access_token_jti,
)
# Generate ID token
id_token = OIDCTokenService.create_id_token(
client_id=client_id,
user_id=user_id,
nonce=nonce,
scope=scope,
access_token=access_token,
auth_time=auth_time,
)
# Generate or rotate refresh token
if "refresh_token" in (client.grant_types or []):
if refresh_token:
# Rotate existing refresh token
refresh_token_obj = OIDCRefreshToken.query.filter_by(
token_hash=hashlib.sha256(refresh_token.encode()).hexdigest(),
deleted_at=None
).first()
if refresh_token_obj and refresh_token_obj.is_valid():
# Create new refresh token
new_refresh, new_hash = OIDCTokenService.create_refresh_token(
client_id=client_id,
user_id=user_id,
scope=scope,
access_token_id=access_token_jti,
)
# Rotate in database
refresh_token_obj.rotate(new_hash)
final_refresh_token = new_refresh
else:
final_refresh_token = None
else:
# Create new refresh token
final_refresh_token, refresh_hash = OIDCTokenService.create_refresh_token(
client_id=client_id,
user_id=user_id,
scope=scope,
access_token_id=access_token_jti,
)
# Store refresh token
OIDCRefreshToken.create_token(
client_id=client.id,
user_id=user_id,
token_hash=refresh_hash,
scope=scope,
access_token_id=access_token_jti,
ip_address=ip_address,
user_agent=user_agent,
lifetime_seconds=client.refresh_token_lifetime or 2592000,
)
else:
final_refresh_token = None
# Store token metadata
client_db_id = client.id
# Access token metadata
OIDCTokenMetadata.create_metadata(
client_id=client_db_id,
user_id=user_id,
token_type="access_token",
token_jti=access_token_jti,
expires_at=datetime.utcnow() + timedelta(seconds=client.access_token_lifetime or 3600),
)
# ID token metadata (using access token JTI as reference)
id_token_jti = OIDCTokenService._generate_jti()
OIDCTokenMetadata.create_metadata(
client_id=client_db_id,
user_id=user_id,
token_type="id_token",
token_jti=id_token_jti,
expires_at=datetime.utcnow() + timedelta(seconds=client.id_token_lifetime or 3600),
)
# Log token event
OIDCAuditService.log_token_event(
client_id=client_id,
user_id=user_id,
token_type="access_token",
success=True,
grant_type="authorization_code",
scopes=scope,
)
result = {
"access_token": access_token,
"token_type": "Bearer",
"expires_in": client.access_token_lifetime or 3600,
"id_token": id_token,
}
if final_refresh_token:
result["refresh_token"] = final_refresh_token
return result
@classmethod
def refresh_access_token(
cls,
refresh_token: str,
client_id: str,
scope: list = None,
ip_address: str = None,
user_agent: str = None
) -> Dict:
"""Refresh an access token with token rotation.
Args:
refresh_token: The refresh token
client_id: OIDC client ID
scope: Optional scope override
ip_address: Client IP address
user_agent: Client user agent
Returns:
Dictionary with new tokens
Raises:
InvalidGrantError: If refresh token is invalid
"""
import hashlib
# Get client
client = OIDCClient.query.filter_by(client_id=client_id).first()
if not client:
raise InvalidClientError()
# Find refresh token
token_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
refresh_token_obj = OIDCRefreshToken.query.filter_by(
token_hash=token_hash,
deleted_at=None
).first()
if not refresh_token_obj:
OIDCAuditService.log_token_event(
client_id=client_id,
success=False,
error_code="invalid_grant",
error_description="Invalid refresh token",
)
raise InvalidGrantError("Invalid refresh token")
# Check if valid
if not refresh_token_obj.is_valid():
OIDCAuditService.log_token_event(
client_id=client_id,
user_id=refresh_token_obj.user_id,
success=False,
error_code="invalid_grant",
error_description="Refresh token expired or revoked",
)
raise InvalidGrantError("Refresh token expired or revoked")
# Validate client matches
if refresh_token_obj.client_id != client.id:
raise InvalidGrantError("Client mismatch")
# Get original scope or use provided
granted_scope = scope or (refresh_token_obj.scope or [])
# Generate new access token
access_token_jti = OIDCTokenService._generate_jti()
access_token = OIDCTokenService.create_access_token(
client_id=client_id,
user_id=refresh_token_obj.user_id,
scope=granted_scope,
jti=access_token_jti,
)
# Generate new ID token
id_token = OIDCTokenService.create_id_token(
client_id=client_id,
user_id=refresh_token_obj.user_id,
scope=granted_scope,
access_token=access_token,
)
# Rotate refresh token
new_refresh, new_hash = OIDCTokenService.create_refresh_token(
client_id=client_id,
user_id=refresh_token_obj.user_id,
scope=granted_scope,
access_token_id=access_token_jti,
)
refresh_token_obj.rotate(new_hash)
# Store new token metadata
OIDCTokenMetadata.create_metadata(
client_id=client.id,
user_id=refresh_token_obj.user_id,
token_type="access_token",
token_jti=access_token_jti,
expires_at=datetime.utcnow() + timedelta(seconds=client.access_token_lifetime or 3600),
)
# Log refresh event
OIDCAuditService.log_token_event(
client_id=client_id,
user_id=refresh_token_obj.user_id,
token_type="access_token",
success=True,
grant_type="refresh_token",
scopes=granted_scope,
)
return {
"access_token": access_token,
"token_type": "Bearer",
"expires_in": client.access_token_lifetime or 3600,
"id_token": id_token,
"refresh_token": new_refresh,
}
@classmethod
def validate_access_token(cls, token: str, client_id: str = None) -> Dict:
"""Validate an access token and return its claims.
Args:
token: JWT access token
client_id: Optional client ID to validate audience
Returns:
Token claims
Raises:
InvalidTokenError: If token is invalid
"""
try:
claims = OIDCTokenService.validate_access_token(token, client_id)
return claims
except Exception as e:
OIDCAuditService.log_event(
event_type="token_validation",
client_id=client_id,
success=False,
error_code="invalid_token",
error_description=str(e),
)
raise InvalidTokenError(str(e))
@classmethod
def revoke_token(
cls,
token: str,
client_id: str,
token_type_hint: str = None,
ip_address: str = None,
user_agent: str = None
) -> bool:
"""Revoke a token.
Args:
token: Token to revoke
client_id: OIDC client ID
token_type_hint: Hint about token type
ip_address: Client IP address
user_agent: Client user agent
Returns:
True if token was revoked
"""
import hashlib
# Get client
client = OIDCClient.query.filter_by(client_id=client_id).first()
if not client:
raise InvalidClientError()
revoked = False
token_hash = hashlib.sha256(token.encode()).hexdigest()
# Try to revoke as refresh token
if token_type_hint in (None, "refresh_token"):
refresh_token = OIDCRefreshToken.query.filter_by(
token_hash=token_hash,
deleted_at=None
).first()
if refresh_token:
refresh_token.revoke(reason="revoked_by_client")
revoked = True
OIDCAuditService.log_token_revocation_event(
client_id=client_id,
user_id=refresh_token.user_id,
token_type="refresh_token",
reason="revoked_by_client",
)
# Try to revoke as access token (JTI lookup)
if not revoked or token_type_hint in (None, "access_token"):
try:
# Decode token to get JTI
claims = OIDCTokenService.decode_token(token)
jti = claims.get("jti")
if jti:
revoked_at = OIDCTokenMetadata.revoke_by_jti(
jti,
reason="revoked_by_client"
)
if revoked_at:
revoked = True
OIDCAuditService.log_token_revocation_event(
client_id=client_id,
user_id=claims.get("sub"),
token_type="access_token",
reason="revoked_by_client",
)
except Exception:
pass
return revoked
@classmethod
def introspect_token(
cls,
token: str,
client_id: str = None,
ip_address: str = None,
user_agent: str = None
) -> Dict:
"""Introspect a token and return its status and claims.
Args:
token: Token to introspect
client_id: Client ID for validation
ip_address: Client IP address
user_agent: Client user agent
Returns:
Introspection response
"""
result = OIDCTokenService.introspect_token(token, client_id)
# Log introspection
OIDCAuditService.log_event(
event_type="token_introspection",
client_id=client_id,
user_id=result.get("sub"),
success=result.get("active", False),
metadata={"active": result.get("active")},
)
return result
@classmethod
def get_jwks(cls) -> Dict:
"""Get the JWKS document.
Returns:
JWKS document
"""
jwks_service = OIDCJWKSService()
return jwks_service.get_jwks()
@classmethod
def get_userinfo(cls, access_token: str) -> Dict:
"""Get user information using access token.
Args:
access_token: Access token
Returns:
User information dictionary
"""
claims = cls.validate_access_token(access_token)
user_id = claims.get("sub")
user = User.query.get(user_id)
if not user:
raise NotFoundError("User not found")
# Get scopes from token
scope_str = claims.get("scope", "")
scopes = scope_str.split() if scope_str else []
userinfo = {"sub": user_id}
# Add claims based on scope
if "profile" in scopes and user.full_name:
userinfo["name"] = user.full_name
if "email" in scopes:
userinfo["email"] = user.email
userinfo["email_verified"] = user.email_verified
# Log userinfo access
OIDCAuditService.log_userinfo_event(
access_token=access_token,
user_id=user_id,
client_id=claims.get("client_id"),
success=True,
scopes_claimed=scopes,
)
return userinfo
+288
View File
@@ -0,0 +1,288 @@
"""OIDC Session Service for session management during OIDC flow."""
import secrets
from datetime import datetime, timedelta
from typing import Dict, Optional, Tuple
from flask import current_app, g
from app.extensions import db
from app.models import OIDCSession, OIDCClient, User
from app.exceptions.validation_exceptions import NotFoundError, ValidationError
class OIDCSessionService:
"""Service for managing OIDC authentication sessions.
This service handles:
- Creating OIDC sessions during authorization flow
- Validating sessions with state and nonce
- Managing PKCE code challenges
- Cleaning up expired sessions
"""
@staticmethod
def _generate_state() -> str:
"""Generate a secure state parameter.
Returns:
URL-safe base64 encoded state
"""
return secrets.token_urlsafe(32)
@staticmethod
def _generate_nonce() -> str:
"""Generate a secure nonce for OIDC.
Returns:
URL-safe base64 encoded nonce
"""
return secrets.token_urlsafe(32)
@staticmethod
def _generate_code_challenge(verifier: str, method: str = "S256") -> str:
"""Generate a PKCE code challenge from verifier.
Args:
verifier: The code verifier
method: Challenge method ("S256" or "plain")
Returns:
Code challenge string
"""
import hashlib
import base64
if method == "S256":
digest = hashlib.sha256(verifier.encode()).digest()
return base64.urlsafe_b64encode(digest).decode().rstrip("=")
elif method == "plain":
return verifier
else:
raise ValueError(f"Unsupported code challenge method: {method}")
@classmethod
def validate_code_verifier(cls, code_verifier: str, code_challenge: str,
method: str = "S256") -> bool:
"""Validate a PKCE code verifier against the stored challenge.
Args:
code_verifier: The code verifier from the token request
code_challenge: The code challenge from the authorization request
method: The challenge method used
Returns:
True if validation succeeds
"""
if not code_verifier or not code_challenge:
return False
# Validate code verifier length (43-128 characters)
if method == "S256" and not (43 <= len(code_verifier) <= 128):
return False
# Calculate expected challenge
expected_challenge = cls._generate_code_challenge(code_verifier, method)
return secrets.compare_digest(expected_challenge, code_challenge)
@classmethod
def create_session(
cls,
user_id: str,
client_id: str,
state: str = None,
nonce: str = None,
redirect_uri: str = None,
scope: list = None,
code_challenge: str = None,
code_challenge_method: str = None,
lifetime_seconds: int = 600
) -> OIDCSession:
"""Create a new OIDC session for the authorization flow.
Args:
user_id: The user ID
client_id: The OIDC client ID
state: State parameter (generated if not provided)
nonce: Nonce for ID token validation (generated if not provided)
redirect_uri: Redirect URI from authorization request
scope: Requested scopes
code_challenge: PKCE code challenge
code_challenge_method: PKCE method ("S256" or "plain")
lifetime_seconds: Session lifetime in seconds
Returns:
OIDCSession instance
"""
# Generate state and nonce if not provided
state = state or cls._generate_state()
nonce = nonce or cls._generate_nonce()
session = OIDCSession.create_session(
user_id=user_id,
client_id=client_id,
state=state,
nonce=nonce,
redirect_uri=redirect_uri,
scope=scope,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
lifetime_seconds=lifetime_seconds,
)
return session
@classmethod
def validate_session(cls, state: str, nonce: str = None) -> Tuple[OIDCSession, User]:
"""Validate an OIDC session by state and optionally nonce.
Args:
state: The state parameter
nonce: The nonce to validate (optional)
Returns:
Tuple of (OIDCSession, User)
Raises:
ValidationError: If session is invalid
NotFoundError: If session not found
"""
session = OIDCSession.get_by_state(state)
if not session:
raise NotFoundError("OIDC session not found or expired")
if session.is_expired():
raise ValidationError("OIDC session has expired")
# Validate nonce if provided
if nonce and not session.validate_nonce(nonce):
raise ValidationError("Invalid nonce")
# Get user
user = User.query.get(session.user_id)
if not user:
raise NotFoundError("User not found")
return session, user
@classmethod
def validate_pkce(cls, session: OIDCSession, code_verifier: str) -> bool:
"""Validate PKCE code verifier against the session's code challenge.
Args:
session: OIDCSession instance
code_verifier: The code verifier from token request
Returns:
True if validation succeeds
Raises:
ValidationError: If PKCE validation fails
"""
if not session.code_challenge:
# No PKCE was used, skip validation
return True
if not code_verifier:
raise ValidationError("code_verifier is required")
is_valid = session.validate_code_challenge(code_verifier)
if not is_valid:
raise ValidationError("Invalid code_verifier")
return True
@classmethod
def mark_session_authenticated(cls, session: OIDCSession) -> OIDCSession:
"""Mark a session as authenticated (user has logged in).
Args:
session: OIDCSession instance
Returns:
Updated OIDCSession instance
"""
session.mark_authenticated()
return session
@classmethod
def cleanup_expired_sessions(cls, older_than_hours: int = 24) -> int:
"""Remove expired OIDC sessions.
Args:
older_than_hours: Only delete sessions expired more than this many hours ago
Returns:
Number of sessions deleted
"""
from datetime import timedelta
cutoff = datetime.utcnow() - timedelta(hours=older_than_hours)
# Get expired sessions
expired_sessions = OIDCSession.query.filter(
OIDCSession.expires_at < datetime.utcnow(),
OIDCSession.deleted_at == None
).all()
count = 0
for session in expired_sessions:
# Only hard delete if past the grace period
if session.expires_at < cutoff:
session.delete()
count += 1
return count
@classmethod
def get_session_by_state(cls, state: str) -> Optional[OIDCSession]:
"""Get an OIDC session by state.
Args:
state: The state parameter
Returns:
OIDCSession instance or None
"""
return OIDCSession.get_by_state(state)
@classmethod
def validate_redirect_uri(cls, client_id: str, redirect_uri: str) -> bool:
"""Validate that a redirect URI is allowed for a client.
Args:
client_id: The OIDC client ID
redirect_uri: The redirect URI to validate
Returns:
True if redirect URI is allowed
"""
client = OIDCClient.query.filter_by(client_id=client_id).first()
if not client:
return False
return client.is_redirect_uri_allowed(redirect_uri)
@classmethod
def validate_scopes(cls, client_id: str, requested_scopes: list) -> list:
"""Validate and filter scopes against client's allowed scopes.
Args:
client_id: The OIDC client ID
requested_scopes: List of requested scopes
Returns:
List of allowed scopes
"""
client = OIDCClient.query.filter_by(client_id=client_id).first()
if not client:
return []
allowed_scopes = client.scopes or []
# Filter to only allowed scopes
valid_scopes = [s for s in requested_scopes if s in allowed_scopes]
return valid_scopes
+439
View File
@@ -0,0 +1,439 @@
"""OIDC Token Service for JWT token generation and validation."""
import hashlib
import base64
import secrets
from datetime import datetime, timedelta
from typing import Dict, Optional, Any
import jwt
from flask import current_app, g
from app.models import User, OIDCClient
from app.services.oidc_jwks_service import OIDCJWKSService
class OIDCTokenService:
"""Service for generating and validating OIDC tokens.
This service handles:
- Access token creation (JWT)
- ID token creation (JWT)
- Refresh token creation (opaque)
- Token signature verification
- Hash generation for PKCE claims (at_hash, c_hash)
"""
@staticmethod
def _generate_jti() -> str:
"""Generate a unique JWT ID."""
return secrets.token_urlsafe(32)
@staticmethod
def _generate_opaque_token(length: int = 43) -> str:
"""Generate an opaque token (for refresh tokens).
Args:
length: Length of the token
Returns:
URL-safe base64 encoded token
"""
return secrets.token_urlsafe(length)
@staticmethod
def _hash_token(token: str) -> str:
"""Hash a token for secure storage.
Args:
token: Token to hash
Returns:
SHA256 hash of the token
"""
return hashlib.sha256(token.encode()).hexdigest()
@staticmethod
def _base64url_encode(data: bytes) -> str:
"""Encode bytes to base64url format without padding.
Args:
data: Bytes to encode
Returns:
Base64url encoded string
"""
return base64.urlsafe_b64encode(data).decode().rstrip("=")
@staticmethod
def create_at_hash(access_token: str) -> str:
"""Create the at_hash claim for ID token.
Implements OIDC spec for access token hash generation.
Hash is the left-most half of the hash of the ASCII representation
of the access token.
Args:
access_token: The access token string
Returns:
Base64url encoded hash
"""
# Hash the access token using SHA256
hash_digest = hashlib.sha256(access_token.encode()).digest()
# Take left-most half of the hash
half_length = len(hash_digest) // 2
left_half = hash_digest[:half_length]
# Base64url encode
return OIDCTokenService._base64url_encode(left_half)
@staticmethod
def create_c_hash(code: str) -> str:
"""Create the c_hash claim for ID token.
Implements OIDC spec for authorization code hash generation.
Args:
code: The authorization code string
Returns:
Base64url encoded hash
"""
# Hash the code using SHA256
hash_digest = hashlib.sha256(code.encode()).digest()
# Take left-most half of the hash
half_length = len(hash_digest) // 2
left_half = hash_digest[:half_length]
# Base64url encode
return OIDCTokenService._base64url_encode(left_half)
@staticmethod
def _get_issuer() -> str:
"""Get the OIDC issuer URL."""
return current_app.config.get("OIDC_ISSUER_URL", "http://localhost:5000")
@staticmethod
def _get_token_lifetime(client: OIDCClient, token_type: str) -> int:
"""Get the token lifetime in seconds for a client.
Args:
client: OIDCClient instance
token_type: Type of token ("access_token", "refresh_token", "id_token")
Returns:
Lifetime in seconds
"""
lifetimes = {
"access_token": client.access_token_lifetime or 3600,
"refresh_token": client.refresh_token_lifetime or 2592000,
"id_token": client.id_token_lifetime or 3600,
}
return lifetimes.get(token_type, 3600)
@classmethod
def create_access_token(cls, client_id: str, user_id: str, scope: list,
jti: str = None) -> str:
"""Create a JWT access token.
Args:
client_id: The OIDC client ID
user_id: The user ID (subject)
scope: List of granted scopes
jti: Optional JWT ID (generated if not provided)
Returns:
JWT access token string
"""
jti = jti or cls._generate_jti()
now = datetime.utcnow()
# Get client for token lifetime
client = OIDCClient.query.filter_by(client_id=client_id).first()
lifetime = cls._get_token_lifetime(client, "access_token") if client else 3600
claims = {
"iss": cls._get_issuer(),
"sub": user_id,
"aud": client_id,
"exp": int((now + timedelta(seconds=lifetime)).timestamp()),
"iat": int(now.timestamp()),
"nbf": int(now.timestamp()),
"jti": jti,
"client_id": client_id,
"scope": " ".join(scope) if isinstance(scope, list) else scope,
}
# Get signing key
jwks_service = OIDCJWKSService()
signing_key = jwks_service.get_signing_key()
if not signing_key:
raise ValueError("No signing key available")
# Sign with RS256
token = jwt.encode(
claims,
signing_key.private_key,
algorithm="RS256",
headers={"kid": signing_key.kid}
)
return token
@classmethod
def create_id_token(cls, client_id: str, user_id: str, nonce: str = None,
scope: list = None, access_token: str = None,
auth_time: int = None) -> str:
"""Create a JWT ID token.
Args:
client_id: The OIDC client ID
user_id: The user ID (subject)
nonce: Nonce for replay protection
scope: Requested/Granted scopes
access_token: Associated access token (for at_hash)
auth_time: Authentication time (Unix timestamp)
Returns:
JWT ID token string
"""
now = datetime.utcnow()
auth_time = auth_time or int(now.timestamp())
# Get client for token lifetime
client = OIDCClient.query.filter_by(client_id=client_id).first()
lifetime = cls._get_token_lifetime(client, "id_token") if client else 3600
# Get user for claims
user = User.query.get(user_id)
claims = {
"iss": cls._get_issuer(),
"sub": user_id,
"aud": client_id,
"exp": int((now + timedelta(seconds=lifetime)).timestamp()),
"iat": int(now.timestamp()),
"auth_time": auth_time,
}
# Add nonce if provided
if nonce:
claims["nonce"] = nonce
# Add at_hash if access token provided
if access_token:
claims["at_hash"] = cls.create_at_hash(access_token)
# Add standard claims if user exists
if user:
if user.email:
claims["email"] = user.email
claims["email_verified"] = user.email_verified
if user.full_name:
claims["name"] = user.full_name
# Add scope if provided
if scope:
claims["scope"] = " ".join(scope) if isinstance(scope, list) else scope
# Get signing key
jwks_service = OIDCJWKSService()
signing_key = jwks_service.get_signing_key()
if not signing_key:
raise ValueError("No signing key available")
# Sign with RS256
token = jwt.encode(
claims,
signing_key.private_key,
algorithm="RS256",
headers={"kid": signing_key.kid}
)
return token
@classmethod
def create_refresh_token(cls, client_id: str, user_id: str,
scope: list = None, access_token_id: str = None) -> str:
"""Create an opaque refresh token.
Args:
client_id: The OIDC client ID
user_id: The user ID
scope: List of granted scopes
access_token_id: Associated access token ID
Returns:
Opaque refresh token string
"""
token = cls._generate_opaque_token()
# Hash for storage
token_hash = cls._hash_token(token)
return token, token_hash
@classmethod
def verify_token_signature(cls, token: str) -> Dict:
"""Verify the signature of a JWT token.
Args:
token: JWT token string
Returns:
Decoded token claims
Raises:
jwt.InvalidSignatureError: If signature verification fails
jwt.ExpiredSignatureError: If token is expired
jwt.InvalidTokenError: If token is invalid
"""
# Get the JWKS with public keys
jwks_service = OIDCJWKSService()
jwks = jwks_service.get_jwks()
# Get the key ID from token header
try:
unverified_header = jwt.get_unverified_header(token)
except jwt.DecodeError:
raise jwt.InvalidTokenError("Invalid token header")
kid = unverified_header.get("kid")
# Find the matching public key
public_key = None
for key in jwks.get("keys", []):
if key.get("kid") == kid:
try:
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
public_key = serialization.load_pem_public_key(
key["public_key"].encode() if isinstance(key["public_key"], str)
else key["public_key"],
backend=default_backend()
)
break
except (ImportError, Exception):
continue
if not public_key:
raise jwt.InvalidSignatureError(f"Key with kid={kid} not found")
# Verify the signature
claims = jwt.decode(
token,
public_key,
algorithms=["RS256"],
audience=None, # We'll validate audience separately
issuer=cls._get_issuer(),
options={
"verify_signature": True,
"verify_exp": True,
"verify_aud": False, # Handle audience manually
"verify_iss": False, # Handle issuer manually
}
)
return claims
@classmethod
def decode_token(cls, token: str, verify: bool = False) -> Dict:
"""Decode a JWT token without verification (for debugging).
Args:
token: JWT token string
verify: Whether to verify signature
Returns:
Decoded token claims
"""
if verify:
return cls.verify_token_signature(token)
return jwt.decode(
token,
options={
"verify_signature": False,
"verify_exp": False,
}
)
@classmethod
def validate_access_token(cls, token: str, client_id: str = None) -> Dict:
"""Validate an access token and return its claims.
Args:
token: JWT access token
client_id: Optional client ID to validate audience
Returns:
Token claims dictionary
Raises:
jwt.InvalidTokenError: If token is invalid
ValueError: If token is expired or audience mismatch
"""
claims = cls.verify_token_signature(token)
# Check expiration
if claims.get("exp", 0) < datetime.utcnow().timestamp():
raise ValueError("Token has expired")
# Validate audience if client_id provided
if client_id:
if claims.get("aud") != client_id:
raise ValueError("Invalid audience")
return claims
@classmethod
def introspect_token(cls, token: str, client_id: str = None) -> Dict:
"""Introspect a token and return its status and claims.
Args:
token: JWT token to introspect
client_id: Client ID for audience validation
Returns:
Dictionary with active status and claims
"""
result = {
"active": False,
}
try:
claims = cls.validate_access_token(token, client_id)
# Calculate remaining time
now = datetime.utcnow().timestamp()
exp = claims.get("exp", 0)
iat = claims.get("iat", 0)
result["active"] = exp > now
result.update({
"iss": claims.get("iss"),
"sub": claims.get("sub"),
"aud": claims.get("aud"),
"exp": exp,
"iat": iat,
"nbf": claims.get("nbf"),
"jti": claims.get("jti"),
"client_id": claims.get("client_id"),
"scope": claims.get("scope"),
"token_type": "Bearer",
})
# Add expiry in seconds
if exp > now:
result["exp"] = int(exp - now)
except (jwt.InvalidTokenError, ValueError) as e:
result["active"] = False
result["error"] = str(e)
return result