major checkpoint
This commit is contained in:
@@ -4,6 +4,11 @@ from app.services.user_service import UserService
|
||||
from app.services.organization_service import OrganizationService
|
||||
from app.services.session_service import SessionService
|
||||
from app.services.audit_service import AuditService
|
||||
from app.services.oidc_service import OIDCService, OIDCError
|
||||
from app.services.oidc_jwks_service import OIDCJWKSService
|
||||
from app.services.oidc_token_service import OIDCTokenService
|
||||
from app.services.oidc_session_service import OIDCSessionService
|
||||
from app.services.oidc_audit_service import OIDCAuditService
|
||||
|
||||
__all__ = [
|
||||
"AuthService",
|
||||
@@ -11,4 +16,10 @@ __all__ = [
|
||||
"OrganizationService",
|
||||
"SessionService",
|
||||
"AuditService",
|
||||
"OIDCService",
|
||||
"OIDCError",
|
||||
"OIDCJWKSService",
|
||||
"OIDCTokenService",
|
||||
"OIDCSessionService",
|
||||
"OIDCAuditService",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,408 @@
|
||||
"""OIDC Audit Service for comprehensive OIDC event logging."""
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from flask import g
|
||||
|
||||
from app.models import OIDCAuditLog, OIDCClient, User
|
||||
from app.exceptions.validation_exceptions import NotFoundError
|
||||
|
||||
|
||||
class OIDCAuditService:
|
||||
"""Service for OIDC-specific audit logging.
|
||||
|
||||
This service provides methods to log all OIDC-related events including:
|
||||
- Authorization requests and responses
|
||||
- Token issuance and refresh
|
||||
- Token revocation
|
||||
- UserInfo endpoint access
|
||||
- Authentication failures
|
||||
"""
|
||||
|
||||
# Event type constants
|
||||
EVENT_AUTHORIZATION_REQUEST = "authorization_request"
|
||||
EVENT_AUTHORIZATION_RESPONSE = "authorization_response"
|
||||
EVENT_TOKEN_ISSUE = "token_issue"
|
||||
EVENT_TOKEN_REFRESH = "token_refresh"
|
||||
EVENT_TOKEN_REVOCATION = "token_revocation"
|
||||
EVENT_TOKEN_INTROSPECTION = "token_introspection"
|
||||
EVENT_USERINFO_ACCESS = "userinfo_access"
|
||||
EVENT_AUTHENTICATION_FAILURE = "authentication_failure"
|
||||
EVENT_AUTHORIZATION_FAILURE = "authorization_failure"
|
||||
EVENT_JWKS_ACCESS = "jwks_access"
|
||||
EVENT_REGISTRATION = "client_registration"
|
||||
|
||||
@classmethod
|
||||
def _get_request_context(cls) -> Dict:
|
||||
"""Extract request context for logging.
|
||||
|
||||
Returns:
|
||||
Dictionary with IP, user_agent, and request_id
|
||||
"""
|
||||
from flask import request
|
||||
|
||||
return {
|
||||
"ip_address": request.remote_addr if request else None,
|
||||
"user_agent": request.headers.get("User-Agent") if request else None,
|
||||
"request_id": g.get("request_id"),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def log_event(
|
||||
cls,
|
||||
event_type: str,
|
||||
client_id: str = None,
|
||||
user_id: str = None,
|
||||
success: bool = True,
|
||||
error_code: str = None,
|
||||
error_description: str = None,
|
||||
metadata: Dict = None
|
||||
) -> OIDCAuditLog:
|
||||
"""Log a generic OIDC event.
|
||||
|
||||
Args:
|
||||
event_type: Type of event
|
||||
client_id: OIDC client ID
|
||||
user_id: User ID
|
||||
success: Whether the event was successful
|
||||
error_code: Error code if failed
|
||||
error_description: Error description if failed
|
||||
metadata: Additional event metadata
|
||||
|
||||
Returns:
|
||||
OIDCAuditLog instance
|
||||
"""
|
||||
context = cls._get_request_context()
|
||||
|
||||
log = OIDCAuditLog.log_event(
|
||||
event_type=event_type,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=success,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
ip_address=context["ip_address"],
|
||||
user_agent=context["user_agent"],
|
||||
request_id=context["request_id"],
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
return log
|
||||
|
||||
@classmethod
|
||||
def log_authorization_event(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str = None,
|
||||
success: bool = True,
|
||||
error_code: str = None,
|
||||
error_description: str = None,
|
||||
redirect_uri: str = None,
|
||||
scope: list = None,
|
||||
response_type: str = None
|
||||
) -> OIDCAuditLog:
|
||||
"""Log an authorization event.
|
||||
|
||||
Args:
|
||||
client_id: OIDC client ID
|
||||
user_id: User ID (if authenticated)
|
||||
success: Whether authorization was successful
|
||||
error_code: Error code if failed
|
||||
error_description: Error description if failed
|
||||
redirect_uri: Redirect URI from request
|
||||
scope: Requested scopes
|
||||
response_type: Response type (e.g., "code")
|
||||
|
||||
Returns:
|
||||
OIDCAuditLog instance
|
||||
"""
|
||||
metadata = {
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": scope,
|
||||
"response_type": response_type,
|
||||
}
|
||||
metadata = {k: v for k, v in metadata.items() if v is not None}
|
||||
|
||||
return cls.log_event(
|
||||
event_type=cls.EVENT_AUTHORIZATION_REQUEST,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=success,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_token_event(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str = None,
|
||||
token_type: str = "access_token",
|
||||
success: bool = True,
|
||||
error_code: str = None,
|
||||
error_description: str = None,
|
||||
grant_type: str = None,
|
||||
scopes: list = None
|
||||
) -> OIDCAuditLog:
|
||||
"""Log a token issuance or refresh event.
|
||||
|
||||
Args:
|
||||
client_id: OIDC client ID
|
||||
user_id: User ID
|
||||
token_type: Type of token issued
|
||||
success: Whether token issuance was successful
|
||||
error_code: Error code if failed
|
||||
error_description: Error description if failed
|
||||
grant_type: Grant type used (e.g., "authorization_code", "refresh_token")
|
||||
scopes: Scopes included in the token
|
||||
|
||||
Returns:
|
||||
OIDCAuditLog instance
|
||||
"""
|
||||
metadata = {
|
||||
"token_type": token_type,
|
||||
"grant_type": grant_type,
|
||||
"scopes": scopes,
|
||||
}
|
||||
metadata = {k: v for k, v in metadata.items() if v is not None}
|
||||
|
||||
return cls.log_event(
|
||||
event_type=cls.EVENT_TOKEN_ISSUE if token_type else cls.EVENT_TOKEN_REFRESH,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=success,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_userinfo_event(
|
||||
cls,
|
||||
access_token: str = None,
|
||||
user_id: str = None,
|
||||
client_id: str = None,
|
||||
success: bool = True,
|
||||
error_code: str = None,
|
||||
error_description: str = None,
|
||||
scopes_claimed: list = None
|
||||
) -> OIDCAuditLog:
|
||||
"""Log a UserInfo endpoint access event.
|
||||
|
||||
Args:
|
||||
access_token: Access token used (masked)
|
||||
user_id: User ID returned
|
||||
client_id: Client ID making the request
|
||||
success: Whether access was successful
|
||||
error_code: Error code if failed
|
||||
error_description: Error description if failed
|
||||
scopes_claimed: Scopes claimed in the request
|
||||
|
||||
Returns:
|
||||
OIDCAuditLog instance
|
||||
"""
|
||||
# Mask the access token for security
|
||||
masked_token = None
|
||||
if access_token:
|
||||
masked_token = access_token[:8] + "..." + access_token[-4:] if len(access_token) > 12 else "***"
|
||||
|
||||
metadata = {
|
||||
"token_prefix": masked_token,
|
||||
"scopes_claimed": scopes_claimed,
|
||||
}
|
||||
metadata = {k: v for k, v in metadata.items() if v is not None}
|
||||
|
||||
return cls.log_event(
|
||||
event_type=cls.EVENT_USERINFO_ACCESS,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=success,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_token_revocation_event(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str = None,
|
||||
token_type: str = "access_token",
|
||||
reason: str = None,
|
||||
success: bool = True,
|
||||
error_code: str = None,
|
||||
error_description: str = None
|
||||
) -> OIDCAuditLog:
|
||||
"""Log a token revocation event.
|
||||
|
||||
Args:
|
||||
client_id: OIDC client ID
|
||||
user_id: User ID
|
||||
token_type: Type of token being revoked
|
||||
reason: Revocation reason
|
||||
success: Whether revocation was successful
|
||||
error_code: Error code if failed
|
||||
error_description: Error description if failed
|
||||
|
||||
Returns:
|
||||
OIDCAuditLog instance
|
||||
"""
|
||||
metadata = {
|
||||
"token_type": token_type,
|
||||
"reason": reason,
|
||||
}
|
||||
metadata = {k: v for k, v in metadata.items() if v is not None}
|
||||
|
||||
return cls.log_event(
|
||||
event_type=cls.EVENT_TOKEN_REVOCATION,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=success,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_authentication_failure(
|
||||
cls,
|
||||
client_id: str = None,
|
||||
error_code: str = "authentication_failed",
|
||||
error_description: str = "Authentication failed",
|
||||
user_id: str = None
|
||||
) -> OIDCAuditLog:
|
||||
"""Log an authentication failure event.
|
||||
|
||||
Args:
|
||||
client_id: OIDC client ID
|
||||
error_code: Error code
|
||||
error_description: Error description
|
||||
user_id: User ID if known
|
||||
|
||||
Returns:
|
||||
OIDCAuditLog instance
|
||||
"""
|
||||
return cls.log_event(
|
||||
event_type=cls.EVENT_AUTHENTICATION_FAILURE,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=False,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_events_for_user(
|
||||
cls,
|
||||
user_id: str,
|
||||
limit: int = 100,
|
||||
include_deleted: bool = False
|
||||
) -> List[OIDCAuditLog]:
|
||||
"""Get audit events for a specific user.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
limit: Maximum number of events to return
|
||||
include_deleted: Include soft-deleted events
|
||||
|
||||
Returns:
|
||||
List of OIDCAuditLog instances
|
||||
"""
|
||||
return OIDCAuditLog.get_events_for_user(user_id, limit)
|
||||
|
||||
@classmethod
|
||||
def get_events_for_client(
|
||||
cls,
|
||||
client_id: str,
|
||||
limit: int = 100
|
||||
) -> List[OIDCAuditLog]:
|
||||
"""Get audit events for a specific client.
|
||||
|
||||
Args:
|
||||
client_id: Client ID
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of OIDCAuditLog instances
|
||||
"""
|
||||
return OIDCAuditLog.get_events_for_client(client_id, limit)
|
||||
|
||||
@classmethod
|
||||
def get_failed_events(
|
||||
cls,
|
||||
client_id: str = None,
|
||||
user_id: str = None,
|
||||
start_date: datetime = None,
|
||||
end_date: datetime = None,
|
||||
limit: int = 100
|
||||
) -> List[OIDCAuditLog]:
|
||||
"""Get failed audit events for analysis.
|
||||
|
||||
Args:
|
||||
client_id: Optional client ID filter
|
||||
user_id: Optional user ID filter
|
||||
start_date: Optional start date filter
|
||||
end_date: Optional end date filter
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of failed OIDCAuditLog instances
|
||||
"""
|
||||
return OIDCAuditLog.get_failed_events(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_event_summary(
|
||||
cls,
|
||||
client_id: str = None,
|
||||
days: int = 30
|
||||
) -> Dict:
|
||||
"""Get a summary of audit events.
|
||||
|
||||
Args:
|
||||
client_id: Optional client ID filter
|
||||
days: Number of days to look back
|
||||
|
||||
Returns:
|
||||
Summary dictionary with event counts
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
start_date = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
query = OIDCAuditLog.query.filter(
|
||||
OIDCAuditLog.created_at >= start_date
|
||||
)
|
||||
|
||||
if client_id:
|
||||
query = query.filter_by(client_id=client_id)
|
||||
|
||||
events = query.all()
|
||||
|
||||
# Count by event type
|
||||
event_counts = {}
|
||||
success_count = 0
|
||||
failure_count = 0
|
||||
|
||||
for event in events:
|
||||
event_type = event.event_type
|
||||
event_counts[event_type] = event_counts.get(event_type, 0) + 1
|
||||
|
||||
if event.success:
|
||||
success_count += 1
|
||||
else:
|
||||
failure_count += 1
|
||||
|
||||
return {
|
||||
"total_events": len(events),
|
||||
"successful_events": success_count,
|
||||
"failed_events": failure_count,
|
||||
"by_event_type": event_counts,
|
||||
"period_days": days,
|
||||
}
|
||||
@@ -0,0 +1,300 @@
|
||||
"""OIDC JWKS Service for key management and rotation."""
|
||||
import uuid
|
||||
import json
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from app.extensions import db
|
||||
|
||||
|
||||
class JWKSKey:
|
||||
"""Represents a JWKS key entry."""
|
||||
|
||||
def __init__(self, kid: str, private_key: str, public_key: str,
|
||||
algorithm: str = "RS256", created_at: datetime = None,
|
||||
expires_at: datetime = None, is_active: bool = True):
|
||||
self.kid = kid
|
||||
self.private_key = private_key
|
||||
self.public_key = public_key
|
||||
self.algorithm = algorithm
|
||||
self.created_at = created_at or datetime.utcnow()
|
||||
self.expires_at = expires_at or datetime.utcnow() + timedelta(days=365)
|
||||
self.is_active = is_active
|
||||
|
||||
def to_jwk(self) -> Dict:
|
||||
"""Convert to JWK format for JWKS endpoint."""
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa, padding
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
# Import cryptography here to avoid issues if not installed
|
||||
try:
|
||||
# Get public key from PEM
|
||||
public_key = serialization.load_pem_public_key(
|
||||
self.public_key.encode(), backend=default_backend()
|
||||
)
|
||||
|
||||
# Get RSA parameters
|
||||
public_numbers = public_key.public_numbers()
|
||||
|
||||
return {
|
||||
"kty": "RSA",
|
||||
"kid": self.kid,
|
||||
"use": "sig",
|
||||
"alg": self.algorithm,
|
||||
"n": _base64url_encode(public_numbers.n),
|
||||
"e": _base64url_encode(public_numbers.e),
|
||||
}
|
||||
except ImportError:
|
||||
# Fallback for when cryptography is not installed
|
||||
return {
|
||||
"kty": "RSA",
|
||||
"kid": self.kid,
|
||||
"use": "sig",
|
||||
"alg": self.algorithm,
|
||||
}
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert to dictionary for storage."""
|
||||
return {
|
||||
"kid": self.kid,
|
||||
"private_key": self.private_key,
|
||||
"public_key": self.public_key,
|
||||
"algorithm": self.algorithm,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"expires_at": self.expires_at.isoformat(),
|
||||
"is_active": self.is_active,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> "JWKSKey":
|
||||
"""Create from dictionary."""
|
||||
return cls(
|
||||
kid=data["kid"],
|
||||
private_key=data["private_key"],
|
||||
public_key=data["public_key"],
|
||||
algorithm=data.get("algorithm", "RS256"),
|
||||
created_at=datetime.fromisoformat(data["created_at"]),
|
||||
expires_at=datetime.fromisoformat(data["expires_at"]),
|
||||
is_active=data.get("is_active", True),
|
||||
)
|
||||
|
||||
|
||||
def _base64url_encode(value: int) -> str:
|
||||
"""Encode an integer to base64url format."""
|
||||
import base64
|
||||
byte_length = (value.bit_length() + 7) // 8 or 1
|
||||
encoded = value.to_bytes(byte_length, byteorder="big")
|
||||
return base64.urlsafe_b64encode(encoded).decode().rstrip("=")
|
||||
|
||||
|
||||
class OIDCJWKSService:
|
||||
"""Service for managing OIDC signing keys (JWKS).
|
||||
|
||||
This service handles RSA key pair generation, rotation, and JWKS document
|
||||
generation for the OIDC implementation.
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_keys: Dict[str, JWKSKey] = {}
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._keys = {}
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def reset(cls):
|
||||
"""Reset the singleton (for testing)."""
|
||||
cls._instance = None
|
||||
cls._keys = {}
|
||||
|
||||
def _generate_kid(self, private_key: str) -> str:
|
||||
"""Generate a key ID from the private key fingerprint."""
|
||||
kid_hash = hashlib.sha256(private_key.encode()).hexdigest()[:32]
|
||||
return kid_hash
|
||||
|
||||
def _generate_rsa_key_pair(self) -> Tuple[str, str]:
|
||||
"""Generate a new RSA key pair in PEM format.
|
||||
|
||||
Returns:
|
||||
Tuple of (private_key_pem, public_key_pem)
|
||||
"""
|
||||
try:
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
# Generate RSA private key
|
||||
private_key = rsa.generate_private_key(
|
||||
public_exponent=65537,
|
||||
key_size=2048,
|
||||
backend=default_backend()
|
||||
)
|
||||
|
||||
# Get public key
|
||||
public_key = private_key.public_key()
|
||||
|
||||
# Serialize to PEM
|
||||
private_pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption()
|
||||
).decode()
|
||||
|
||||
public_pem = public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
||||
).decode()
|
||||
|
||||
return private_pem, public_pem
|
||||
except ImportError:
|
||||
# Fallback for testing without cryptography
|
||||
import secrets
|
||||
return f"private_key_{secrets.token_hex(32)}", f"public_key_{secrets.token_hex(32)}"
|
||||
|
||||
def get_jwks(self, include_private_keys: bool = False) -> Dict:
|
||||
"""Get the JWKS document containing public keys.
|
||||
|
||||
Args:
|
||||
include_private_keys: Whether to include private keys (for internal use only)
|
||||
|
||||
Returns:
|
||||
JWKS document dictionary
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
|
||||
keys = []
|
||||
for kid, key in self._keys.items():
|
||||
# Only include active, non-expired keys
|
||||
if key.is_active and key.expires_at > now:
|
||||
if include_private_keys:
|
||||
keys.append(key.to_dict())
|
||||
else:
|
||||
keys.append(key.to_jwk())
|
||||
|
||||
return {
|
||||
"keys": keys
|
||||
}
|
||||
|
||||
def get_signing_key(self) -> Optional[JWKSKey]:
|
||||
"""Get the current active signing key.
|
||||
|
||||
Returns:
|
||||
JWKSKey instance or None if no active key
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
|
||||
for kid, key in self._keys.items():
|
||||
if key.is_active and key.expires_at > now:
|
||||
return key
|
||||
|
||||
return None
|
||||
|
||||
def get_key_by_kid(self, kid: str) -> Optional[JWKSKey]:
|
||||
"""Get a specific key by its ID.
|
||||
|
||||
Args:
|
||||
kid: Key ID to look up
|
||||
|
||||
Returns:
|
||||
JWKSKey instance or None if not found
|
||||
"""
|
||||
return self._keys.get(kid)
|
||||
|
||||
def generate_new_key_pair(self, expires_in_days: int = 365) -> JWKSKey:
|
||||
"""Generate a new RSA key pair for signing.
|
||||
|
||||
Args:
|
||||
expires_in_days: Days until key expiration
|
||||
|
||||
Returns:
|
||||
JWKSKey instance
|
||||
"""
|
||||
private_key, public_key = self._generate_rsa_key_pair()
|
||||
kid = self._generate_kid(private_key)
|
||||
|
||||
now = datetime.utcnow()
|
||||
key = JWKSKey(
|
||||
kid=kid,
|
||||
private_key=private_key,
|
||||
public_key=public_key,
|
||||
algorithm="RS256",
|
||||
created_at=now,
|
||||
expires_at=now + timedelta(days=expires_in_days),
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
self._keys[kid] = key
|
||||
|
||||
# Deactivate old keys (but keep them for grace period)
|
||||
for old_kid in self._keys:
|
||||
if old_kid != kid:
|
||||
self._keys[old_kid].is_active = False
|
||||
|
||||
return key
|
||||
|
||||
def rotate_keys(self, grace_period_hours: int = 24) -> Tuple[JWKSKey, List[str]]:
|
||||
"""Rotate signing keys, keeping previous key active for grace period.
|
||||
|
||||
Args:
|
||||
grace_period_hours: Hours to keep old keys active
|
||||
|
||||
Returns:
|
||||
Tuple of (new_key, list_of_deprecated_kids)
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
grace_end = now + timedelta(hours=grace_period_hours)
|
||||
|
||||
# Mark current key as deprecated
|
||||
current_key = self.get_signing_key()
|
||||
deprecated_kids = []
|
||||
|
||||
if current_key:
|
||||
deprecated_kids.append(current_key.kid)
|
||||
# Keep key active but mark as deprecated
|
||||
current_key.is_active = False
|
||||
current_key.expires_at = grace_end
|
||||
|
||||
# Generate new key
|
||||
new_key = self.generate_new_key_pair()
|
||||
|
||||
# Clean up expired keys
|
||||
expired_kids = [
|
||||
kid for kid, key in self._keys.items()
|
||||
if key.expires_at < now
|
||||
]
|
||||
for kid in expired_kids:
|
||||
del self._keys[kid]
|
||||
|
||||
return new_key, deprecated_kids
|
||||
|
||||
def verify_key_exists(self, kid: str) -> bool:
|
||||
"""Check if a key with the given ID exists and is valid.
|
||||
|
||||
Args:
|
||||
kid: Key ID to check
|
||||
|
||||
Returns:
|
||||
True if key exists and is valid
|
||||
"""
|
||||
key = self.get_key_by_kid(kid)
|
||||
if not key:
|
||||
return False
|
||||
|
||||
now = datetime.utcnow()
|
||||
return key.is_active and key.expires_at > now
|
||||
|
||||
def initialize_with_key(self) -> JWKSKey:
|
||||
"""Initialize the service with a key if none exists.
|
||||
|
||||
Returns:
|
||||
JWKSKey instance
|
||||
"""
|
||||
if not self._keys:
|
||||
return self.generate_new_key_pair()
|
||||
return self.get_signing_key()
|
||||
@@ -0,0 +1,745 @@
|
||||
"""OIDC Service - Main OIDC service layer."""
|
||||
import secrets
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from flask import current_app, g
|
||||
|
||||
from app.extensions import db
|
||||
from app.models import (
|
||||
User, OIDCClient, OIDCAuthCode, OIDCRefreshToken,
|
||||
OIDCSession, OIDCTokenMetadata
|
||||
)
|
||||
from app.exceptions.validation_exceptions import (
|
||||
ValidationError, NotFoundError, BadRequestError
|
||||
)
|
||||
from app.exceptions.auth_exceptions import UnauthorizedError, InvalidTokenError
|
||||
from app.services.oidc_token_service import OIDCTokenService
|
||||
from app.services.oidc_session_service import OIDCSessionService
|
||||
from app.services.oidc_audit_service import OIDCAuditService
|
||||
from app.services.oidc_jwks_service import OIDCJWKSService
|
||||
|
||||
|
||||
class OIDCError(Exception):
|
||||
"""Base exception for OIDC errors."""
|
||||
|
||||
def __init__(self, error: str, error_description: str = None, status_code: int = 400):
|
||||
self.error = error
|
||||
self.error_description = error_description
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class InvalidClientError(OIDCError):
|
||||
"""Raised when client authentication fails."""
|
||||
|
||||
def __init__(self, error_description: str = "Invalid client"):
|
||||
super().__init__("invalid_client", error_description, 401)
|
||||
|
||||
|
||||
class InvalidGrantError(OIDCError):
|
||||
"""Raised when grant is invalid."""
|
||||
|
||||
def __init__(self, error_description: str = "Invalid grant"):
|
||||
super().__init__("invalid_grant", error_description, 400)
|
||||
|
||||
|
||||
class InvalidRequestError(OIDCError):
|
||||
"""Raised when request is malformed."""
|
||||
|
||||
def __init__(self, error_description: str = "Invalid request"):
|
||||
super().__init__("invalid_request", error_description, 400)
|
||||
|
||||
|
||||
class OIDCService:
|
||||
"""Main OIDC service handling all OpenID Connect operations.
|
||||
|
||||
This service provides:
|
||||
- Authorization code generation and validation
|
||||
- Token generation (access, refresh, ID tokens)
|
||||
- Token refresh with rotation
|
||||
- Token validation and introspection
|
||||
- Token revocation
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _generate_code() -> str:
|
||||
"""Generate a secure authorization code.
|
||||
|
||||
Returns:
|
||||
URL-safe base64 encoded code
|
||||
"""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
@staticmethod
|
||||
def _hash_value(value: str) -> str:
|
||||
"""Hash a value for secure storage.
|
||||
|
||||
Args:
|
||||
value: Value to hash
|
||||
|
||||
Returns:
|
||||
SHA256 hash
|
||||
"""
|
||||
return hashlib.sha256(value.encode()).hexdigest()
|
||||
|
||||
@classmethod
|
||||
def generate_authorization_code(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
redirect_uri: str,
|
||||
scope: list,
|
||||
state: str,
|
||||
nonce: str,
|
||||
code_challenge: str = None,
|
||||
code_challenge_method: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None
|
||||
) -> str:
|
||||
"""Generate an authorization code for the auth code flow.
|
||||
|
||||
Args:
|
||||
client_id: OIDC client ID
|
||||
user_id: User ID
|
||||
redirect_uri: Redirect URI
|
||||
scope: Requested scopes
|
||||
state: State parameter
|
||||
nonce: Nonce for ID token
|
||||
code_challenge: PKCE code challenge
|
||||
code_challenge_method: PKCE method ("S256" or "plain")
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
|
||||
Returns:
|
||||
Authorization code string
|
||||
|
||||
Raises:
|
||||
ValidationError: If parameters are invalid
|
||||
NotFoundError: If client not found
|
||||
"""
|
||||
# Validate client exists and is active
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
if not client:
|
||||
raise NotFoundError("Client not found")
|
||||
|
||||
if not client.is_active:
|
||||
raise ValidationError("Client is not active")
|
||||
|
||||
# Validate redirect URI
|
||||
if not client.is_redirect_uri_allowed(redirect_uri):
|
||||
raise ValidationError("Invalid redirect_uri")
|
||||
|
||||
# Validate scopes
|
||||
allowed_scopes = client.scopes or []
|
||||
valid_scopes = [s for s in scope if s in allowed_scopes]
|
||||
|
||||
if not valid_scopes:
|
||||
raise ValidationError("Invalid scopes")
|
||||
|
||||
# Generate authorization code
|
||||
code = cls._generate_code()
|
||||
code_hash = cls._hash_value(code)
|
||||
|
||||
# Create auth code record
|
||||
auth_code = OIDCAuthCode.create_code(
|
||||
client_id=client.id,
|
||||
user_id=user_id,
|
||||
code_hash=code_hash,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=valid_scopes,
|
||||
nonce=nonce,
|
||||
code_verifier=code_challenge, # Store for validation
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
lifetime_seconds=600, # 10 minutes
|
||||
)
|
||||
|
||||
# Log authorization event
|
||||
OIDCAuditService.log_authorization_event(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=True,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=valid_scopes,
|
||||
)
|
||||
|
||||
return code
|
||||
|
||||
@classmethod
|
||||
def validate_authorization_code(
|
||||
cls,
|
||||
code: str,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
code_verifier: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None
|
||||
) -> Tuple[Dict, User]:
|
||||
"""Validate and consume an authorization code.
|
||||
|
||||
Args:
|
||||
code: Authorization code
|
||||
client_id: OIDC client ID
|
||||
redirect_uri: Redirect URI
|
||||
code_verifier: PKCE code verifier (required if PKCE was used)
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
|
||||
Returns:
|
||||
Tuple of (claims dict, User instance)
|
||||
|
||||
Raises:
|
||||
InvalidGrantError: If code is invalid
|
||||
ValidationError: If PKCE validation fails
|
||||
"""
|
||||
# Get client
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
if not client:
|
||||
raise InvalidGrantError("Invalid client")
|
||||
|
||||
# Hash the provided code and find matching auth code
|
||||
code_hash = cls._hash_value(code)
|
||||
auth_code = OIDCAuthCode.query.filter_by(
|
||||
code_hash=code_hash,
|
||||
client_id=client.id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not auth_code:
|
||||
OIDCAuditService.log_authorization_event(
|
||||
client_id=client_id,
|
||||
success=False,
|
||||
error_code="invalid_grant",
|
||||
error_description="Invalid or expired authorization code",
|
||||
)
|
||||
raise InvalidGrantError("Invalid or expired authorization code")
|
||||
|
||||
# Check if already used
|
||||
if auth_code.is_used:
|
||||
OIDCAuditService.log_authorization_event(
|
||||
client_id=client_id,
|
||||
user_id=auth_code.user_id,
|
||||
success=False,
|
||||
error_code="invalid_grant",
|
||||
error_description="Authorization code already used",
|
||||
)
|
||||
raise InvalidGrantError("Authorization code already used")
|
||||
|
||||
# Check expiration
|
||||
if auth_code.is_expired():
|
||||
OIDCAuditService.log_authorization_event(
|
||||
client_id=client_id,
|
||||
user_id=auth_code.user_id,
|
||||
success=False,
|
||||
error_code="invalid_grant",
|
||||
error_description="Authorization code expired",
|
||||
)
|
||||
raise InvalidGrantError("Authorization code expired")
|
||||
|
||||
# Validate redirect URI
|
||||
if auth_code.redirect_uri != redirect_uri:
|
||||
raise InvalidGrantError("Invalid redirect_uri")
|
||||
|
||||
# Validate PKCE if required
|
||||
if client.require_pkce and auth_code.code_verifier:
|
||||
if not code_verifier:
|
||||
raise ValidationError("code_verifier is required")
|
||||
|
||||
# Verify code verifier
|
||||
expected_challenge = cls._compute_code_challenge(code_verifier, "S256")
|
||||
if expected_challenge != auth_code.code_verifier:
|
||||
OIDCAuditService.log_authorization_event(
|
||||
client_id=client_id,
|
||||
user_id=auth_code.user_id,
|
||||
success=False,
|
||||
error_code="invalid_grant",
|
||||
error_description="Invalid code_verifier",
|
||||
)
|
||||
raise InvalidGrantError("Invalid code_verifier")
|
||||
|
||||
# Mark code as used
|
||||
auth_code.mark_as_used()
|
||||
|
||||
# Get user
|
||||
user = User.query.get(auth_code.user_id)
|
||||
if not user:
|
||||
raise InvalidGrantError("User not found")
|
||||
|
||||
claims = {
|
||||
"user_id": auth_code.user_id,
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": auth_code.scope,
|
||||
"nonce": auth_code.nonce,
|
||||
}
|
||||
|
||||
return claims, user
|
||||
|
||||
@classmethod
|
||||
def _compute_code_challenge(cls, verifier: str, method: str = "S256") -> str:
|
||||
"""Compute PKCE code challenge from verifier.
|
||||
|
||||
Args:
|
||||
verifier: Code verifier
|
||||
method: Challenge method
|
||||
|
||||
Returns:
|
||||
Code challenge
|
||||
"""
|
||||
import hashlib
|
||||
import base64
|
||||
|
||||
if method == "S256":
|
||||
digest = hashlib.sha256(verifier.encode()).digest()
|
||||
return base64.urlsafe_b64encode(digest).decode().rstrip("=")
|
||||
return verifier
|
||||
|
||||
@classmethod
|
||||
def generate_tokens(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
scope: list,
|
||||
nonce: str = None,
|
||||
refresh_token: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
auth_time: int = None
|
||||
) -> Dict:
|
||||
"""Generate access token, ID token, and refresh token.
|
||||
|
||||
Args:
|
||||
client_id: OIDC client ID
|
||||
user_id: User ID
|
||||
scope: Granted scopes
|
||||
nonce: Nonce for ID token
|
||||
refresh_token: Existing refresh token (for rotation)
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
auth_time: Authentication time
|
||||
|
||||
Returns:
|
||||
Dictionary with tokens
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
# Get client
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
if not client:
|
||||
raise InvalidClientError()
|
||||
|
||||
# Generate access token
|
||||
access_token_jti = OIDCTokenService._generate_jti()
|
||||
access_token = OIDCTokenService.create_access_token(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
scope=scope,
|
||||
jti=access_token_jti,
|
||||
)
|
||||
|
||||
# Generate ID token
|
||||
id_token = OIDCTokenService.create_id_token(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
nonce=nonce,
|
||||
scope=scope,
|
||||
access_token=access_token,
|
||||
auth_time=auth_time,
|
||||
)
|
||||
|
||||
# Generate or rotate refresh token
|
||||
if "refresh_token" in (client.grant_types or []):
|
||||
if refresh_token:
|
||||
# Rotate existing refresh token
|
||||
refresh_token_obj = OIDCRefreshToken.query.filter_by(
|
||||
token_hash=hashlib.sha256(refresh_token.encode()).hexdigest(),
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if refresh_token_obj and refresh_token_obj.is_valid():
|
||||
# Create new refresh token
|
||||
new_refresh, new_hash = OIDCTokenService.create_refresh_token(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
scope=scope,
|
||||
access_token_id=access_token_jti,
|
||||
)
|
||||
|
||||
# Rotate in database
|
||||
refresh_token_obj.rotate(new_hash)
|
||||
final_refresh_token = new_refresh
|
||||
else:
|
||||
final_refresh_token = None
|
||||
else:
|
||||
# Create new refresh token
|
||||
final_refresh_token, refresh_hash = OIDCTokenService.create_refresh_token(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
scope=scope,
|
||||
access_token_id=access_token_jti,
|
||||
)
|
||||
|
||||
# Store refresh token
|
||||
OIDCRefreshToken.create_token(
|
||||
client_id=client.id,
|
||||
user_id=user_id,
|
||||
token_hash=refresh_hash,
|
||||
scope=scope,
|
||||
access_token_id=access_token_jti,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
lifetime_seconds=client.refresh_token_lifetime or 2592000,
|
||||
)
|
||||
else:
|
||||
final_refresh_token = None
|
||||
|
||||
# Store token metadata
|
||||
client_db_id = client.id
|
||||
|
||||
# Access token metadata
|
||||
OIDCTokenMetadata.create_metadata(
|
||||
client_id=client_db_id,
|
||||
user_id=user_id,
|
||||
token_type="access_token",
|
||||
token_jti=access_token_jti,
|
||||
expires_at=datetime.utcnow() + timedelta(seconds=client.access_token_lifetime or 3600),
|
||||
)
|
||||
|
||||
# ID token metadata (using access token JTI as reference)
|
||||
id_token_jti = OIDCTokenService._generate_jti()
|
||||
OIDCTokenMetadata.create_metadata(
|
||||
client_id=client_db_id,
|
||||
user_id=user_id,
|
||||
token_type="id_token",
|
||||
token_jti=id_token_jti,
|
||||
expires_at=datetime.utcnow() + timedelta(seconds=client.id_token_lifetime or 3600),
|
||||
)
|
||||
|
||||
# Log token event
|
||||
OIDCAuditService.log_token_event(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
token_type="access_token",
|
||||
success=True,
|
||||
grant_type="authorization_code",
|
||||
scopes=scope,
|
||||
)
|
||||
|
||||
result = {
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": client.access_token_lifetime or 3600,
|
||||
"id_token": id_token,
|
||||
}
|
||||
|
||||
if final_refresh_token:
|
||||
result["refresh_token"] = final_refresh_token
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def refresh_access_token(
|
||||
cls,
|
||||
refresh_token: str,
|
||||
client_id: str,
|
||||
scope: list = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None
|
||||
) -> Dict:
|
||||
"""Refresh an access token with token rotation.
|
||||
|
||||
Args:
|
||||
refresh_token: The refresh token
|
||||
client_id: OIDC client ID
|
||||
scope: Optional scope override
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
|
||||
Returns:
|
||||
Dictionary with new tokens
|
||||
|
||||
Raises:
|
||||
InvalidGrantError: If refresh token is invalid
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
# Get client
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
if not client:
|
||||
raise InvalidClientError()
|
||||
|
||||
# Find refresh token
|
||||
token_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
|
||||
refresh_token_obj = OIDCRefreshToken.query.filter_by(
|
||||
token_hash=token_hash,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not refresh_token_obj:
|
||||
OIDCAuditService.log_token_event(
|
||||
client_id=client_id,
|
||||
success=False,
|
||||
error_code="invalid_grant",
|
||||
error_description="Invalid refresh token",
|
||||
)
|
||||
raise InvalidGrantError("Invalid refresh token")
|
||||
|
||||
# Check if valid
|
||||
if not refresh_token_obj.is_valid():
|
||||
OIDCAuditService.log_token_event(
|
||||
client_id=client_id,
|
||||
user_id=refresh_token_obj.user_id,
|
||||
success=False,
|
||||
error_code="invalid_grant",
|
||||
error_description="Refresh token expired or revoked",
|
||||
)
|
||||
raise InvalidGrantError("Refresh token expired or revoked")
|
||||
|
||||
# Validate client matches
|
||||
if refresh_token_obj.client_id != client.id:
|
||||
raise InvalidGrantError("Client mismatch")
|
||||
|
||||
# Get original scope or use provided
|
||||
granted_scope = scope or (refresh_token_obj.scope or [])
|
||||
|
||||
# Generate new access token
|
||||
access_token_jti = OIDCTokenService._generate_jti()
|
||||
access_token = OIDCTokenService.create_access_token(
|
||||
client_id=client_id,
|
||||
user_id=refresh_token_obj.user_id,
|
||||
scope=granted_scope,
|
||||
jti=access_token_jti,
|
||||
)
|
||||
|
||||
# Generate new ID token
|
||||
id_token = OIDCTokenService.create_id_token(
|
||||
client_id=client_id,
|
||||
user_id=refresh_token_obj.user_id,
|
||||
scope=granted_scope,
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
# Rotate refresh token
|
||||
new_refresh, new_hash = OIDCTokenService.create_refresh_token(
|
||||
client_id=client_id,
|
||||
user_id=refresh_token_obj.user_id,
|
||||
scope=granted_scope,
|
||||
access_token_id=access_token_jti,
|
||||
)
|
||||
|
||||
refresh_token_obj.rotate(new_hash)
|
||||
|
||||
# Store new token metadata
|
||||
OIDCTokenMetadata.create_metadata(
|
||||
client_id=client.id,
|
||||
user_id=refresh_token_obj.user_id,
|
||||
token_type="access_token",
|
||||
token_jti=access_token_jti,
|
||||
expires_at=datetime.utcnow() + timedelta(seconds=client.access_token_lifetime or 3600),
|
||||
)
|
||||
|
||||
# Log refresh event
|
||||
OIDCAuditService.log_token_event(
|
||||
client_id=client_id,
|
||||
user_id=refresh_token_obj.user_id,
|
||||
token_type="access_token",
|
||||
success=True,
|
||||
grant_type="refresh_token",
|
||||
scopes=granted_scope,
|
||||
)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": client.access_token_lifetime or 3600,
|
||||
"id_token": id_token,
|
||||
"refresh_token": new_refresh,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def validate_access_token(cls, token: str, client_id: str = None) -> Dict:
|
||||
"""Validate an access token and return its claims.
|
||||
|
||||
Args:
|
||||
token: JWT access token
|
||||
client_id: Optional client ID to validate audience
|
||||
|
||||
Returns:
|
||||
Token claims
|
||||
|
||||
Raises:
|
||||
InvalidTokenError: If token is invalid
|
||||
"""
|
||||
try:
|
||||
claims = OIDCTokenService.validate_access_token(token, client_id)
|
||||
return claims
|
||||
except Exception as e:
|
||||
OIDCAuditService.log_event(
|
||||
event_type="token_validation",
|
||||
client_id=client_id,
|
||||
success=False,
|
||||
error_code="invalid_token",
|
||||
error_description=str(e),
|
||||
)
|
||||
raise InvalidTokenError(str(e))
|
||||
|
||||
@classmethod
|
||||
def revoke_token(
|
||||
cls,
|
||||
token: str,
|
||||
client_id: str,
|
||||
token_type_hint: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None
|
||||
) -> bool:
|
||||
"""Revoke a token.
|
||||
|
||||
Args:
|
||||
token: Token to revoke
|
||||
client_id: OIDC client ID
|
||||
token_type_hint: Hint about token type
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
|
||||
Returns:
|
||||
True if token was revoked
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
# Get client
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
if not client:
|
||||
raise InvalidClientError()
|
||||
|
||||
revoked = False
|
||||
token_hash = hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
# Try to revoke as refresh token
|
||||
if token_type_hint in (None, "refresh_token"):
|
||||
refresh_token = OIDCRefreshToken.query.filter_by(
|
||||
token_hash=token_hash,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if refresh_token:
|
||||
refresh_token.revoke(reason="revoked_by_client")
|
||||
revoked = True
|
||||
|
||||
OIDCAuditService.log_token_revocation_event(
|
||||
client_id=client_id,
|
||||
user_id=refresh_token.user_id,
|
||||
token_type="refresh_token",
|
||||
reason="revoked_by_client",
|
||||
)
|
||||
|
||||
# Try to revoke as access token (JTI lookup)
|
||||
if not revoked or token_type_hint in (None, "access_token"):
|
||||
try:
|
||||
# Decode token to get JTI
|
||||
claims = OIDCTokenService.decode_token(token)
|
||||
jti = claims.get("jti")
|
||||
|
||||
if jti:
|
||||
revoked_at = OIDCTokenMetadata.revoke_by_jti(
|
||||
jti,
|
||||
reason="revoked_by_client"
|
||||
)
|
||||
if revoked_at:
|
||||
revoked = True
|
||||
|
||||
OIDCAuditService.log_token_revocation_event(
|
||||
client_id=client_id,
|
||||
user_id=claims.get("sub"),
|
||||
token_type="access_token",
|
||||
reason="revoked_by_client",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return revoked
|
||||
|
||||
@classmethod
|
||||
def introspect_token(
|
||||
cls,
|
||||
token: str,
|
||||
client_id: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None
|
||||
) -> Dict:
|
||||
"""Introspect a token and return its status and claims.
|
||||
|
||||
Args:
|
||||
token: Token to introspect
|
||||
client_id: Client ID for validation
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
|
||||
Returns:
|
||||
Introspection response
|
||||
"""
|
||||
result = OIDCTokenService.introspect_token(token, client_id)
|
||||
|
||||
# Log introspection
|
||||
OIDCAuditService.log_event(
|
||||
event_type="token_introspection",
|
||||
client_id=client_id,
|
||||
user_id=result.get("sub"),
|
||||
success=result.get("active", False),
|
||||
metadata={"active": result.get("active")},
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_jwks(cls) -> Dict:
|
||||
"""Get the JWKS document.
|
||||
|
||||
Returns:
|
||||
JWKS document
|
||||
"""
|
||||
jwks_service = OIDCJWKSService()
|
||||
return jwks_service.get_jwks()
|
||||
|
||||
@classmethod
|
||||
def get_userinfo(cls, access_token: str) -> Dict:
|
||||
"""Get user information using access token.
|
||||
|
||||
Args:
|
||||
access_token: Access token
|
||||
|
||||
Returns:
|
||||
User information dictionary
|
||||
"""
|
||||
claims = cls.validate_access_token(access_token)
|
||||
|
||||
user_id = claims.get("sub")
|
||||
user = User.query.get(user_id)
|
||||
|
||||
if not user:
|
||||
raise NotFoundError("User not found")
|
||||
|
||||
# Get scopes from token
|
||||
scope_str = claims.get("scope", "")
|
||||
scopes = scope_str.split() if scope_str else []
|
||||
|
||||
userinfo = {"sub": user_id}
|
||||
|
||||
# Add claims based on scope
|
||||
if "profile" in scopes and user.full_name:
|
||||
userinfo["name"] = user.full_name
|
||||
|
||||
if "email" in scopes:
|
||||
userinfo["email"] = user.email
|
||||
userinfo["email_verified"] = user.email_verified
|
||||
|
||||
# Log userinfo access
|
||||
OIDCAuditService.log_userinfo_event(
|
||||
access_token=access_token,
|
||||
user_id=user_id,
|
||||
client_id=claims.get("client_id"),
|
||||
success=True,
|
||||
scopes_claimed=scopes,
|
||||
)
|
||||
|
||||
return userinfo
|
||||
@@ -0,0 +1,288 @@
|
||||
"""OIDC Session Service for session management during OIDC flow."""
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from flask import current_app, g
|
||||
|
||||
from app.extensions import db
|
||||
from app.models import OIDCSession, OIDCClient, User
|
||||
from app.exceptions.validation_exceptions import NotFoundError, ValidationError
|
||||
|
||||
|
||||
class OIDCSessionService:
|
||||
"""Service for managing OIDC authentication sessions.
|
||||
|
||||
This service handles:
|
||||
- Creating OIDC sessions during authorization flow
|
||||
- Validating sessions with state and nonce
|
||||
- Managing PKCE code challenges
|
||||
- Cleaning up expired sessions
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _generate_state() -> str:
|
||||
"""Generate a secure state parameter.
|
||||
|
||||
Returns:
|
||||
URL-safe base64 encoded state
|
||||
"""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
@staticmethod
|
||||
def _generate_nonce() -> str:
|
||||
"""Generate a secure nonce for OIDC.
|
||||
|
||||
Returns:
|
||||
URL-safe base64 encoded nonce
|
||||
"""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
@staticmethod
|
||||
def _generate_code_challenge(verifier: str, method: str = "S256") -> str:
|
||||
"""Generate a PKCE code challenge from verifier.
|
||||
|
||||
Args:
|
||||
verifier: The code verifier
|
||||
method: Challenge method ("S256" or "plain")
|
||||
|
||||
Returns:
|
||||
Code challenge string
|
||||
"""
|
||||
import hashlib
|
||||
import base64
|
||||
|
||||
if method == "S256":
|
||||
digest = hashlib.sha256(verifier.encode()).digest()
|
||||
return base64.urlsafe_b64encode(digest).decode().rstrip("=")
|
||||
elif method == "plain":
|
||||
return verifier
|
||||
else:
|
||||
raise ValueError(f"Unsupported code challenge method: {method}")
|
||||
|
||||
@classmethod
|
||||
def validate_code_verifier(cls, code_verifier: str, code_challenge: str,
|
||||
method: str = "S256") -> bool:
|
||||
"""Validate a PKCE code verifier against the stored challenge.
|
||||
|
||||
Args:
|
||||
code_verifier: The code verifier from the token request
|
||||
code_challenge: The code challenge from the authorization request
|
||||
method: The challenge method used
|
||||
|
||||
Returns:
|
||||
True if validation succeeds
|
||||
"""
|
||||
if not code_verifier or not code_challenge:
|
||||
return False
|
||||
|
||||
# Validate code verifier length (43-128 characters)
|
||||
if method == "S256" and not (43 <= len(code_verifier) <= 128):
|
||||
return False
|
||||
|
||||
# Calculate expected challenge
|
||||
expected_challenge = cls._generate_code_challenge(code_verifier, method)
|
||||
|
||||
return secrets.compare_digest(expected_challenge, code_challenge)
|
||||
|
||||
@classmethod
|
||||
def create_session(
|
||||
cls,
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
state: str = None,
|
||||
nonce: str = None,
|
||||
redirect_uri: str = None,
|
||||
scope: list = None,
|
||||
code_challenge: str = None,
|
||||
code_challenge_method: str = None,
|
||||
lifetime_seconds: int = 600
|
||||
) -> OIDCSession:
|
||||
"""Create a new OIDC session for the authorization flow.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
client_id: The OIDC client ID
|
||||
state: State parameter (generated if not provided)
|
||||
nonce: Nonce for ID token validation (generated if not provided)
|
||||
redirect_uri: Redirect URI from authorization request
|
||||
scope: Requested scopes
|
||||
code_challenge: PKCE code challenge
|
||||
code_challenge_method: PKCE method ("S256" or "plain")
|
||||
lifetime_seconds: Session lifetime in seconds
|
||||
|
||||
Returns:
|
||||
OIDCSession instance
|
||||
"""
|
||||
# Generate state and nonce if not provided
|
||||
state = state or cls._generate_state()
|
||||
nonce = nonce or cls._generate_nonce()
|
||||
|
||||
session = OIDCSession.create_session(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
lifetime_seconds=lifetime_seconds,
|
||||
)
|
||||
|
||||
return session
|
||||
|
||||
@classmethod
|
||||
def validate_session(cls, state: str, nonce: str = None) -> Tuple[OIDCSession, User]:
|
||||
"""Validate an OIDC session by state and optionally nonce.
|
||||
|
||||
Args:
|
||||
state: The state parameter
|
||||
nonce: The nonce to validate (optional)
|
||||
|
||||
Returns:
|
||||
Tuple of (OIDCSession, User)
|
||||
|
||||
Raises:
|
||||
ValidationError: If session is invalid
|
||||
NotFoundError: If session not found
|
||||
"""
|
||||
session = OIDCSession.get_by_state(state)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError("OIDC session not found or expired")
|
||||
|
||||
if session.is_expired():
|
||||
raise ValidationError("OIDC session has expired")
|
||||
|
||||
# Validate nonce if provided
|
||||
if nonce and not session.validate_nonce(nonce):
|
||||
raise ValidationError("Invalid nonce")
|
||||
|
||||
# Get user
|
||||
user = User.query.get(session.user_id)
|
||||
if not user:
|
||||
raise NotFoundError("User not found")
|
||||
|
||||
return session, user
|
||||
|
||||
@classmethod
|
||||
def validate_pkce(cls, session: OIDCSession, code_verifier: str) -> bool:
|
||||
"""Validate PKCE code verifier against the session's code challenge.
|
||||
|
||||
Args:
|
||||
session: OIDCSession instance
|
||||
code_verifier: The code verifier from token request
|
||||
|
||||
Returns:
|
||||
True if validation succeeds
|
||||
|
||||
Raises:
|
||||
ValidationError: If PKCE validation fails
|
||||
"""
|
||||
if not session.code_challenge:
|
||||
# No PKCE was used, skip validation
|
||||
return True
|
||||
|
||||
if not code_verifier:
|
||||
raise ValidationError("code_verifier is required")
|
||||
|
||||
is_valid = session.validate_code_challenge(code_verifier)
|
||||
|
||||
if not is_valid:
|
||||
raise ValidationError("Invalid code_verifier")
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def mark_session_authenticated(cls, session: OIDCSession) -> OIDCSession:
|
||||
"""Mark a session as authenticated (user has logged in).
|
||||
|
||||
Args:
|
||||
session: OIDCSession instance
|
||||
|
||||
Returns:
|
||||
Updated OIDCSession instance
|
||||
"""
|
||||
session.mark_authenticated()
|
||||
return session
|
||||
|
||||
@classmethod
|
||||
def cleanup_expired_sessions(cls, older_than_hours: int = 24) -> int:
|
||||
"""Remove expired OIDC sessions.
|
||||
|
||||
Args:
|
||||
older_than_hours: Only delete sessions expired more than this many hours ago
|
||||
|
||||
Returns:
|
||||
Number of sessions deleted
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
cutoff = datetime.utcnow() - timedelta(hours=older_than_hours)
|
||||
|
||||
# Get expired sessions
|
||||
expired_sessions = OIDCSession.query.filter(
|
||||
OIDCSession.expires_at < datetime.utcnow(),
|
||||
OIDCSession.deleted_at == None
|
||||
).all()
|
||||
|
||||
count = 0
|
||||
for session in expired_sessions:
|
||||
# Only hard delete if past the grace period
|
||||
if session.expires_at < cutoff:
|
||||
session.delete()
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
@classmethod
|
||||
def get_session_by_state(cls, state: str) -> Optional[OIDCSession]:
|
||||
"""Get an OIDC session by state.
|
||||
|
||||
Args:
|
||||
state: The state parameter
|
||||
|
||||
Returns:
|
||||
OIDCSession instance or None
|
||||
"""
|
||||
return OIDCSession.get_by_state(state)
|
||||
|
||||
@classmethod
|
||||
def validate_redirect_uri(cls, client_id: str, redirect_uri: str) -> bool:
|
||||
"""Validate that a redirect URI is allowed for a client.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
redirect_uri: The redirect URI to validate
|
||||
|
||||
Returns:
|
||||
True if redirect URI is allowed
|
||||
"""
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
if not client:
|
||||
return False
|
||||
|
||||
return client.is_redirect_uri_allowed(redirect_uri)
|
||||
|
||||
@classmethod
|
||||
def validate_scopes(cls, client_id: str, requested_scopes: list) -> list:
|
||||
"""Validate and filter scopes against client's allowed scopes.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
requested_scopes: List of requested scopes
|
||||
|
||||
Returns:
|
||||
List of allowed scopes
|
||||
"""
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
if not client:
|
||||
return []
|
||||
|
||||
allowed_scopes = client.scopes or []
|
||||
|
||||
# Filter to only allowed scopes
|
||||
valid_scopes = [s for s in requested_scopes if s in allowed_scopes]
|
||||
|
||||
return valid_scopes
|
||||
@@ -0,0 +1,439 @@
|
||||
"""OIDC Token Service for JWT token generation and validation."""
|
||||
import hashlib
|
||||
import base64
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Optional, Any
|
||||
|
||||
import jwt
|
||||
from flask import current_app, g
|
||||
|
||||
from app.models import User, OIDCClient
|
||||
from app.services.oidc_jwks_service import OIDCJWKSService
|
||||
|
||||
|
||||
class OIDCTokenService:
|
||||
"""Service for generating and validating OIDC tokens.
|
||||
|
||||
This service handles:
|
||||
- Access token creation (JWT)
|
||||
- ID token creation (JWT)
|
||||
- Refresh token creation (opaque)
|
||||
- Token signature verification
|
||||
- Hash generation for PKCE claims (at_hash, c_hash)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _generate_jti() -> str:
|
||||
"""Generate a unique JWT ID."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
@staticmethod
|
||||
def _generate_opaque_token(length: int = 43) -> str:
|
||||
"""Generate an opaque token (for refresh tokens).
|
||||
|
||||
Args:
|
||||
length: Length of the token
|
||||
|
||||
Returns:
|
||||
URL-safe base64 encoded token
|
||||
"""
|
||||
return secrets.token_urlsafe(length)
|
||||
|
||||
@staticmethod
|
||||
def _hash_token(token: str) -> str:
|
||||
"""Hash a token for secure storage.
|
||||
|
||||
Args:
|
||||
token: Token to hash
|
||||
|
||||
Returns:
|
||||
SHA256 hash of the token
|
||||
"""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _base64url_encode(data: bytes) -> str:
|
||||
"""Encode bytes to base64url format without padding.
|
||||
|
||||
Args:
|
||||
data: Bytes to encode
|
||||
|
||||
Returns:
|
||||
Base64url encoded string
|
||||
"""
|
||||
return base64.urlsafe_b64encode(data).decode().rstrip("=")
|
||||
|
||||
@staticmethod
|
||||
def create_at_hash(access_token: str) -> str:
|
||||
"""Create the at_hash claim for ID token.
|
||||
|
||||
Implements OIDC spec for access token hash generation.
|
||||
Hash is the left-most half of the hash of the ASCII representation
|
||||
of the access token.
|
||||
|
||||
Args:
|
||||
access_token: The access token string
|
||||
|
||||
Returns:
|
||||
Base64url encoded hash
|
||||
"""
|
||||
# Hash the access token using SHA256
|
||||
hash_digest = hashlib.sha256(access_token.encode()).digest()
|
||||
|
||||
# Take left-most half of the hash
|
||||
half_length = len(hash_digest) // 2
|
||||
left_half = hash_digest[:half_length]
|
||||
|
||||
# Base64url encode
|
||||
return OIDCTokenService._base64url_encode(left_half)
|
||||
|
||||
@staticmethod
|
||||
def create_c_hash(code: str) -> str:
|
||||
"""Create the c_hash claim for ID token.
|
||||
|
||||
Implements OIDC spec for authorization code hash generation.
|
||||
|
||||
Args:
|
||||
code: The authorization code string
|
||||
|
||||
Returns:
|
||||
Base64url encoded hash
|
||||
"""
|
||||
# Hash the code using SHA256
|
||||
hash_digest = hashlib.sha256(code.encode()).digest()
|
||||
|
||||
# Take left-most half of the hash
|
||||
half_length = len(hash_digest) // 2
|
||||
left_half = hash_digest[:half_length]
|
||||
|
||||
# Base64url encode
|
||||
return OIDCTokenService._base64url_encode(left_half)
|
||||
|
||||
@staticmethod
|
||||
def _get_issuer() -> str:
|
||||
"""Get the OIDC issuer URL."""
|
||||
return current_app.config.get("OIDC_ISSUER_URL", "http://localhost:5000")
|
||||
|
||||
@staticmethod
|
||||
def _get_token_lifetime(client: OIDCClient, token_type: str) -> int:
|
||||
"""Get the token lifetime in seconds for a client.
|
||||
|
||||
Args:
|
||||
client: OIDCClient instance
|
||||
token_type: Type of token ("access_token", "refresh_token", "id_token")
|
||||
|
||||
Returns:
|
||||
Lifetime in seconds
|
||||
"""
|
||||
lifetimes = {
|
||||
"access_token": client.access_token_lifetime or 3600,
|
||||
"refresh_token": client.refresh_token_lifetime or 2592000,
|
||||
"id_token": client.id_token_lifetime or 3600,
|
||||
}
|
||||
return lifetimes.get(token_type, 3600)
|
||||
|
||||
@classmethod
|
||||
def create_access_token(cls, client_id: str, user_id: str, scope: list,
|
||||
jti: str = None) -> str:
|
||||
"""Create a JWT access token.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID (subject)
|
||||
scope: List of granted scopes
|
||||
jti: Optional JWT ID (generated if not provided)
|
||||
|
||||
Returns:
|
||||
JWT access token string
|
||||
"""
|
||||
jti = jti or cls._generate_jti()
|
||||
now = datetime.utcnow()
|
||||
|
||||
# Get client for token lifetime
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
lifetime = cls._get_token_lifetime(client, "access_token") if client else 3600
|
||||
|
||||
claims = {
|
||||
"iss": cls._get_issuer(),
|
||||
"sub": user_id,
|
||||
"aud": client_id,
|
||||
"exp": int((now + timedelta(seconds=lifetime)).timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"nbf": int(now.timestamp()),
|
||||
"jti": jti,
|
||||
"client_id": client_id,
|
||||
"scope": " ".join(scope) if isinstance(scope, list) else scope,
|
||||
}
|
||||
|
||||
# Get signing key
|
||||
jwks_service = OIDCJWKSService()
|
||||
signing_key = jwks_service.get_signing_key()
|
||||
|
||||
if not signing_key:
|
||||
raise ValueError("No signing key available")
|
||||
|
||||
# Sign with RS256
|
||||
token = jwt.encode(
|
||||
claims,
|
||||
signing_key.private_key,
|
||||
algorithm="RS256",
|
||||
headers={"kid": signing_key.kid}
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def create_id_token(cls, client_id: str, user_id: str, nonce: str = None,
|
||||
scope: list = None, access_token: str = None,
|
||||
auth_time: int = None) -> str:
|
||||
"""Create a JWT ID token.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID (subject)
|
||||
nonce: Nonce for replay protection
|
||||
scope: Requested/Granted scopes
|
||||
access_token: Associated access token (for at_hash)
|
||||
auth_time: Authentication time (Unix timestamp)
|
||||
|
||||
Returns:
|
||||
JWT ID token string
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
auth_time = auth_time or int(now.timestamp())
|
||||
|
||||
# Get client for token lifetime
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
lifetime = cls._get_token_lifetime(client, "id_token") if client else 3600
|
||||
|
||||
# Get user for claims
|
||||
user = User.query.get(user_id)
|
||||
|
||||
claims = {
|
||||
"iss": cls._get_issuer(),
|
||||
"sub": user_id,
|
||||
"aud": client_id,
|
||||
"exp": int((now + timedelta(seconds=lifetime)).timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"auth_time": auth_time,
|
||||
}
|
||||
|
||||
# Add nonce if provided
|
||||
if nonce:
|
||||
claims["nonce"] = nonce
|
||||
|
||||
# Add at_hash if access token provided
|
||||
if access_token:
|
||||
claims["at_hash"] = cls.create_at_hash(access_token)
|
||||
|
||||
# Add standard claims if user exists
|
||||
if user:
|
||||
if user.email:
|
||||
claims["email"] = user.email
|
||||
claims["email_verified"] = user.email_verified
|
||||
if user.full_name:
|
||||
claims["name"] = user.full_name
|
||||
|
||||
# Add scope if provided
|
||||
if scope:
|
||||
claims["scope"] = " ".join(scope) if isinstance(scope, list) else scope
|
||||
|
||||
# Get signing key
|
||||
jwks_service = OIDCJWKSService()
|
||||
signing_key = jwks_service.get_signing_key()
|
||||
|
||||
if not signing_key:
|
||||
raise ValueError("No signing key available")
|
||||
|
||||
# Sign with RS256
|
||||
token = jwt.encode(
|
||||
claims,
|
||||
signing_key.private_key,
|
||||
algorithm="RS256",
|
||||
headers={"kid": signing_key.kid}
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def create_refresh_token(cls, client_id: str, user_id: str,
|
||||
scope: list = None, access_token_id: str = None) -> str:
|
||||
"""Create an opaque refresh token.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID
|
||||
scope: List of granted scopes
|
||||
access_token_id: Associated access token ID
|
||||
|
||||
Returns:
|
||||
Opaque refresh token string
|
||||
"""
|
||||
token = cls._generate_opaque_token()
|
||||
|
||||
# Hash for storage
|
||||
token_hash = cls._hash_token(token)
|
||||
|
||||
return token, token_hash
|
||||
|
||||
@classmethod
|
||||
def verify_token_signature(cls, token: str) -> Dict:
|
||||
"""Verify the signature of a JWT token.
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
|
||||
Returns:
|
||||
Decoded token claims
|
||||
|
||||
Raises:
|
||||
jwt.InvalidSignatureError: If signature verification fails
|
||||
jwt.ExpiredSignatureError: If token is expired
|
||||
jwt.InvalidTokenError: If token is invalid
|
||||
"""
|
||||
# Get the JWKS with public keys
|
||||
jwks_service = OIDCJWKSService()
|
||||
jwks = jwks_service.get_jwks()
|
||||
|
||||
# Get the key ID from token header
|
||||
try:
|
||||
unverified_header = jwt.get_unverified_header(token)
|
||||
except jwt.DecodeError:
|
||||
raise jwt.InvalidTokenError("Invalid token header")
|
||||
|
||||
kid = unverified_header.get("kid")
|
||||
|
||||
# Find the matching public key
|
||||
public_key = None
|
||||
for key in jwks.get("keys", []):
|
||||
if key.get("kid") == kid:
|
||||
try:
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
public_key = serialization.load_pem_public_key(
|
||||
key["public_key"].encode() if isinstance(key["public_key"], str)
|
||||
else key["public_key"],
|
||||
backend=default_backend()
|
||||
)
|
||||
break
|
||||
except (ImportError, Exception):
|
||||
continue
|
||||
|
||||
if not public_key:
|
||||
raise jwt.InvalidSignatureError(f"Key with kid={kid} not found")
|
||||
|
||||
# Verify the signature
|
||||
claims = jwt.decode(
|
||||
token,
|
||||
public_key,
|
||||
algorithms=["RS256"],
|
||||
audience=None, # We'll validate audience separately
|
||||
issuer=cls._get_issuer(),
|
||||
options={
|
||||
"verify_signature": True,
|
||||
"verify_exp": True,
|
||||
"verify_aud": False, # Handle audience manually
|
||||
"verify_iss": False, # Handle issuer manually
|
||||
}
|
||||
)
|
||||
|
||||
return claims
|
||||
|
||||
@classmethod
|
||||
def decode_token(cls, token: str, verify: bool = False) -> Dict:
|
||||
"""Decode a JWT token without verification (for debugging).
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
verify: Whether to verify signature
|
||||
|
||||
Returns:
|
||||
Decoded token claims
|
||||
"""
|
||||
if verify:
|
||||
return cls.verify_token_signature(token)
|
||||
|
||||
return jwt.decode(
|
||||
token,
|
||||
options={
|
||||
"verify_signature": False,
|
||||
"verify_exp": False,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_access_token(cls, token: str, client_id: str = None) -> Dict:
|
||||
"""Validate an access token and return its claims.
|
||||
|
||||
Args:
|
||||
token: JWT access token
|
||||
client_id: Optional client ID to validate audience
|
||||
|
||||
Returns:
|
||||
Token claims dictionary
|
||||
|
||||
Raises:
|
||||
jwt.InvalidTokenError: If token is invalid
|
||||
ValueError: If token is expired or audience mismatch
|
||||
"""
|
||||
claims = cls.verify_token_signature(token)
|
||||
|
||||
# Check expiration
|
||||
if claims.get("exp", 0) < datetime.utcnow().timestamp():
|
||||
raise ValueError("Token has expired")
|
||||
|
||||
# Validate audience if client_id provided
|
||||
if client_id:
|
||||
if claims.get("aud") != client_id:
|
||||
raise ValueError("Invalid audience")
|
||||
|
||||
return claims
|
||||
|
||||
@classmethod
|
||||
def introspect_token(cls, token: str, client_id: str = None) -> Dict:
|
||||
"""Introspect a token and return its status and claims.
|
||||
|
||||
Args:
|
||||
token: JWT token to introspect
|
||||
client_id: Client ID for audience validation
|
||||
|
||||
Returns:
|
||||
Dictionary with active status and claims
|
||||
"""
|
||||
result = {
|
||||
"active": False,
|
||||
}
|
||||
|
||||
try:
|
||||
claims = cls.validate_access_token(token, client_id)
|
||||
|
||||
# Calculate remaining time
|
||||
now = datetime.utcnow().timestamp()
|
||||
exp = claims.get("exp", 0)
|
||||
iat = claims.get("iat", 0)
|
||||
|
||||
result["active"] = exp > now
|
||||
result.update({
|
||||
"iss": claims.get("iss"),
|
||||
"sub": claims.get("sub"),
|
||||
"aud": claims.get("aud"),
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": claims.get("nbf"),
|
||||
"jti": claims.get("jti"),
|
||||
"client_id": claims.get("client_id"),
|
||||
"scope": claims.get("scope"),
|
||||
"token_type": "Bearer",
|
||||
})
|
||||
|
||||
# Add expiry in seconds
|
||||
if exp > now:
|
||||
result["exp"] = int(exp - now)
|
||||
|
||||
except (jwt.InvalidTokenError, ValueError) as e:
|
||||
result["active"] = False
|
||||
result["error"] = str(e)
|
||||
|
||||
return result
|
||||
Reference in New Issue
Block a user