746 lines
24 KiB
Python
746 lines
24 KiB
Python
"""OIDC Service - Main OIDC service layer."""
|
|
import secrets
|
|
import hashlib
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
from flask import current_app, g
|
|
|
|
from app.extensions import db
|
|
from app.models import (
|
|
User, OIDCClient, OIDCAuthCode, OIDCRefreshToken,
|
|
OIDCSession, OIDCTokenMetadata
|
|
)
|
|
from app.exceptions.validation_exceptions import (
|
|
ValidationError, NotFoundError, BadRequestError
|
|
)
|
|
from app.exceptions.auth_exceptions import UnauthorizedError, InvalidTokenError
|
|
from app.services.oidc_token_service import OIDCTokenService
|
|
from app.services.oidc_session_service import OIDCSessionService
|
|
from app.services.oidc_audit_service import OIDCAuditService
|
|
from app.services.oidc_jwks_service import OIDCJWKSService
|
|
|
|
|
|
class OIDCError(Exception):
    """Base exception for OIDC errors.

    Carries an OAuth 2.0 / OIDC error code, an optional human-readable
    description, and the HTTP status code to respond with.

    Attributes:
        error: OAuth error code, e.g. ``"invalid_grant"``.
        error_description: Optional human-readable explanation (may be None).
        status_code: HTTP status code for the error response (default 400).
    """

    def __init__(self, error: str, error_description: Optional[str] = None, status_code: int = 400):
        # Initialise Exception with a real message. Without this call,
        # ``exc.args`` is just whatever positional arguments the constructor
        # happened to receive: empty for e.g. ``InvalidClientError()`` (so
        # ``str(exc) == ""``) and a tuple repr for ``OIDCError(code, desc)``.
        super().__init__(error_description or error)
        self.error = error
        self.error_description = error_description
        self.status_code = status_code
|
|
|
|
|
|
class InvalidClientError(OIDCError):
    """Client authentication failed: OAuth ``invalid_client`` with HTTP 401."""

    def __init__(self, error_description: str = "Invalid client"):
        super().__init__(
            error="invalid_client",
            error_description=error_description,
            status_code=401,
        )
|
|
|
|
|
|
class InvalidGrantError(OIDCError):
    """The presented grant (code, refresh token, PKCE verifier...) is invalid.

    Maps to OAuth ``invalid_grant`` with HTTP 400.
    """

    def __init__(self, error_description: str = "Invalid grant"):
        super().__init__(
            error="invalid_grant",
            error_description=error_description,
            status_code=400,
        )
|
|
|
|
|
|
class InvalidRequestError(OIDCError):
    """The request is malformed: OAuth ``invalid_request`` with HTTP 400."""

    def __init__(self, error_description: str = "Invalid request"):
        super().__init__(
            error="invalid_request",
            error_description=error_description,
            status_code=400,
        )
|
|
|
|
|
|
class OIDCService:
    """Main OIDC service handling all OpenID Connect operations.

    This service provides:
    - Authorization code generation and validation (including PKCE)
    - Token generation (access, refresh, ID tokens)
    - Token refresh with rotation
    - Token validation and introspection
    - Token revocation

    Every method is a class or static method; the service keeps no state of
    its own.
    """

    # Authorization codes are short-lived (10 minutes).
    AUTH_CODE_LIFETIME_SECONDS = 600
    # Fallbacks used when the client record leaves a lifetime unset.
    DEFAULT_ACCESS_TOKEN_LIFETIME = 3600
    DEFAULT_ID_TOKEN_LIFETIME = 3600
    DEFAULT_REFRESH_TOKEN_LIFETIME = 2592000  # 30 days

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _generate_code() -> str:
        """Generate a secure authorization code.

        Returns:
            URL-safe base64 encoded code built from 32 random bytes
        """
        return secrets.token_urlsafe(32)

    @staticmethod
    def _hash_value(value: str) -> str:
        """Hash a value for secure storage.

        Authorization codes and refresh tokens are stored only as this digest;
        lookups hash the presented value and query by digest.

        Args:
            value: Value to hash

        Returns:
            SHA256 hex digest
        """
        return hashlib.sha256(value.encode()).hexdigest()

    @staticmethod
    def _normalize_scope(value) -> List[str]:
        """Return *value* as a list of scope strings.

        Accepts None/empty, a space-delimited string (the OAuth wire format),
        or any iterable of scope strings. Without this, iterating a string
        scope would yield single characters.
        """
        if not value:
            return []
        if isinstance(value, str):
            return value.split()
        return list(value)

    @staticmethod
    def _safe_equals(a: str, b: str) -> bool:
        """Compare two strings in constant time.

        Both sides are encoded to bytes first because ``secrets.compare_digest``
        raises TypeError for ``str`` arguments containing non-ASCII characters,
        and one side here originates from client input.
        """
        return secrets.compare_digest(a.encode(), b.encode())

    @staticmethod
    def _rejected_code(client_id: str, description: str, user_id: str = None) -> InvalidGrantError:
        """Audit a failed authorization-code exchange and return the error to raise.

        ``user_id`` is only forwarded to the audit log when known, matching how
        the failure events were logged before.
        """
        event = {
            "client_id": client_id,
            "success": False,
            "error_code": "invalid_grant",
            "error_description": description,
        }
        if user_id is not None:
            event["user_id"] = user_id
        OIDCAuditService.log_authorization_event(**event)
        return InvalidGrantError(description)

    @classmethod
    def _compute_code_challenge(cls, verifier: str, method: str = "S256") -> str:
        """Compute PKCE code challenge from verifier.

        Args:
            verifier: Code verifier
            method: Challenge method; "S256" hashes the verifier, any other
                value returns it unchanged ("plain")

        Returns:
            Code challenge (unpadded base64url SHA-256 digest for S256)
        """
        import base64

        if method == "S256":
            digest = hashlib.sha256(verifier.encode()).digest()
            return base64.urlsafe_b64encode(digest).decode().rstrip("=")
        return verifier

    @classmethod
    def _prepare_code_challenge(
        cls,
        client,
        code_challenge: Optional[str],
        code_challenge_method: Optional[str],
    ) -> Optional[str]:
        """Validate PKCE request parameters and return the challenge to persist.

        Only the challenge is persisted (``OIDCAuthCode.create_code`` is not
        given the method), and validation always recomputes an S256 challenge
        from the verifier. A ``plain`` challenge equals the verifier, so it is
        stored in its S256 form to keep that later comparison correct.

        Returns:
            The S256 challenge to store, or None when PKCE was not used.

        Raises:
            ValidationError: If PKCE is required but missing, or the method is
                not supported.
        """
        if not code_challenge:
            if client.require_pkce:
                raise ValidationError("code_challenge is required")
            return None

        # NOTE(review): a missing method is treated as S256 (pre-existing
        # behaviour); RFC 7636 §4.3 defines the default as "plain".
        method = code_challenge_method or "S256"
        if method == "S256":
            return code_challenge
        if method == "plain":
            return cls._compute_code_challenge(code_challenge, "S256")
        raise ValidationError("Unsupported code_challenge_method")

    # ------------------------------------------------------------------
    # Authorization code flow
    # ------------------------------------------------------------------

    @classmethod
    def generate_authorization_code(
        cls,
        client_id: str,
        user_id: str,
        redirect_uri: str,
        scope: list,
        state: str,
        nonce: str,
        code_challenge: str = None,
        code_challenge_method: str = None,
        ip_address: str = None,
        user_agent: str = None
    ) -> str:
        """Generate an authorization code for the auth code flow.

        Args:
            client_id: OIDC client ID
            user_id: User ID
            redirect_uri: Redirect URI
            scope: Requested scopes (list, or space-delimited string)
            state: State parameter (accepted for interface compatibility; not
                stored or used here)
            nonce: Nonce for ID token
            code_challenge: PKCE code challenge (required if the client has
                ``require_pkce`` set)
            code_challenge_method: PKCE method ("S256" or "plain")
            ip_address: Client IP address
            user_agent: Client user agent

        Returns:
            Authorization code string (only its hash is stored)

        Raises:
            ValidationError: If the client is inactive, the redirect URI or
                scopes are invalid, or the PKCE parameters are missing/invalid
            NotFoundError: If client not found
        """
        client = OIDCClient.query.filter_by(client_id=client_id).first()
        if not client:
            raise NotFoundError("Client not found")
        if not client.is_active:
            raise ValidationError("Client is not active")

        if not client.is_redirect_uri_allowed(redirect_uri):
            raise ValidationError("Invalid redirect_uri")

        # Silently drop scopes the client is not registered for; fail only if
        # nothing usable remains.
        allowed_scopes = client.scopes or []
        valid_scopes = [s for s in cls._normalize_scope(scope) if s in allowed_scopes]
        if not valid_scopes:
            raise ValidationError("Invalid scopes")

        stored_challenge = cls._prepare_code_challenge(
            client, code_challenge, code_challenge_method
        )

        code = cls._generate_code()

        OIDCAuthCode.create_code(
            client_id=client.id,
            user_id=user_id,
            code_hash=cls._hash_value(code),
            redirect_uri=redirect_uri,
            scope=valid_scopes,
            nonce=nonce,
            # The model's ``code_verifier`` column holds the S256 *challenge*;
            # the real verifier is only presented at token exchange.
            code_verifier=stored_challenge,
            ip_address=ip_address,
            user_agent=user_agent,
            lifetime_seconds=cls.AUTH_CODE_LIFETIME_SECONDS,
        )

        OIDCAuditService.log_authorization_event(
            client_id=client_id,
            user_id=user_id,
            success=True,
            redirect_uri=redirect_uri,
            scope=valid_scopes,
        )

        return code

    @classmethod
    def validate_authorization_code(
        cls,
        code: str,
        client_id: str,
        redirect_uri: str,
        code_verifier: str = None,
        ip_address: str = None,
        user_agent: str = None
    ) -> Tuple[Dict, User]:
        """Validate and consume an authorization code.

        Args:
            code: Authorization code
            client_id: OIDC client ID
            redirect_uri: Redirect URI (must equal the one used at issuance)
            code_verifier: PKCE code verifier (required whenever a challenge
                was supplied at issuance)
            ip_address: Client IP address
            user_agent: Client user agent

        Returns:
            Tuple of (claims dict, User instance)

        Raises:
            InvalidGrantError: If the code is unknown, used, expired, bound to a
                different redirect URI, fails PKCE, or its user is gone
            ValidationError: If a PKCE verifier is required but not provided
        """
        client = OIDCClient.query.filter_by(client_id=client_id).first()
        if not client:
            raise InvalidGrantError("Invalid client")

        auth_code = OIDCAuthCode.query.filter_by(
            code_hash=cls._hash_value(code),
            client_id=client.id,
            deleted_at=None,
        ).first()

        if not auth_code:
            raise cls._rejected_code(client_id, "Invalid or expired authorization code")

        # NOTE(review): RFC 6749 §4.1.2 suggests revoking tokens previously
        # issued from a code that is replayed; that is not done here.
        if auth_code.is_used:
            raise cls._rejected_code(client_id, "Authorization code already used", auth_code.user_id)

        if auth_code.is_expired():
            raise cls._rejected_code(client_id, "Authorization code expired", auth_code.user_id)

        if auth_code.redirect_uri != redirect_uri:
            raise cls._rejected_code(client_id, "Invalid redirect_uri", auth_code.user_id)

        # PKCE: if a challenge was stored, the verifier MUST be checked
        # regardless of the client's require_pkce flag (RFC 7636 §4.6);
        # otherwise an intercepted code could be redeemed without it.
        stored_challenge = auth_code.code_verifier
        if stored_challenge:
            if not code_verifier:
                raise ValidationError("code_verifier is required")
            expected_challenge = cls._compute_code_challenge(code_verifier, "S256")
            if not cls._safe_equals(expected_challenge, stored_challenge):
                raise cls._rejected_code(client_id, "Invalid code_verifier", auth_code.user_id)
        elif client.require_pkce:
            # Code was issued without a challenge to a PKCE-mandatory client.
            raise cls._rejected_code(client_id, "PKCE is required for this client", auth_code.user_id)

        auth_code.mark_as_used()

        user = User.query.get(auth_code.user_id)
        if not user:
            raise InvalidGrantError("User not found")

        claims = {
            "user_id": auth_code.user_id,
            "client_id": client_id,
            "redirect_uri": redirect_uri,
            "scope": auth_code.scope,
            "nonce": auth_code.nonce,
        }

        return claims, user

    # ------------------------------------------------------------------
    # Token issuance
    # ------------------------------------------------------------------

    @classmethod
    def generate_tokens(
        cls,
        client_id: str,
        user_id: str,
        scope: list,
        nonce: str = None,
        refresh_token: str = None,
        ip_address: str = None,
        user_agent: str = None,
        auth_time: int = None
    ) -> Dict:
        """Generate access token, ID token, and refresh token.

        A refresh token is only issued when the client's ``grant_types``
        include ``"refresh_token"``.

        Args:
            client_id: OIDC client ID
            user_id: User ID
            scope: Granted scopes
            nonce: Nonce for ID token
            refresh_token: Existing refresh token to rotate. It is rotated only
                if it exists, is valid, and was issued to this client;
                otherwise no refresh token is returned.
            ip_address: Client IP address
            user_agent: Client user agent
            auth_time: Authentication time

        Returns:
            Dictionary with ``access_token``, ``token_type``, ``expires_in``,
            ``id_token`` and, when issued, ``refresh_token``

        Raises:
            InvalidClientError: If the client does not exist
        """
        client = OIDCClient.query.filter_by(client_id=client_id).first()
        if not client:
            raise InvalidClientError()

        access_lifetime = client.access_token_lifetime or cls.DEFAULT_ACCESS_TOKEN_LIFETIME

        access_token_jti = OIDCTokenService._generate_jti()
        access_token = OIDCTokenService.create_access_token(
            client_id=client_id,
            user_id=user_id,
            scope=scope,
            jti=access_token_jti,
        )

        id_token = OIDCTokenService.create_id_token(
            client_id=client_id,
            user_id=user_id,
            nonce=nonce,
            scope=scope,
            access_token=access_token,
            auth_time=auth_time,
        )

        final_refresh_token = None
        if "refresh_token" in (client.grant_types or []):
            if refresh_token:
                existing = OIDCRefreshToken.query.filter_by(
                    token_hash=cls._hash_value(refresh_token),
                    deleted_at=None,
                ).first()
                # Rotate only a live token issued to this same client; a token
                # belonging to another client must not be rotated from here.
                if (
                    existing
                    and existing.client_id == client.id
                    and existing.is_valid()
                ):
                    final_refresh_token, new_hash = OIDCTokenService.create_refresh_token(
                        client_id=client_id,
                        user_id=user_id,
                        scope=scope,
                        access_token_id=access_token_jti,
                    )
                    existing.rotate(new_hash)
            else:
                final_refresh_token, refresh_hash = OIDCTokenService.create_refresh_token(
                    client_id=client_id,
                    user_id=user_id,
                    scope=scope,
                    access_token_id=access_token_jti,
                )
                OIDCRefreshToken.create_token(
                    client_id=client.id,
                    user_id=user_id,
                    token_hash=refresh_hash,
                    scope=scope,
                    access_token_id=access_token_jti,
                    ip_address=ip_address,
                    user_agent=user_agent,
                    lifetime_seconds=client.refresh_token_lifetime or cls.DEFAULT_REFRESH_TOKEN_LIFETIME,
                )

        now = datetime.utcnow()

        OIDCTokenMetadata.create_metadata(
            client_id=client.id,
            user_id=user_id,
            token_type="access_token",
            token_jti=access_token_jti,
            expires_at=now + timedelta(seconds=access_lifetime),
        )

        # NOTE(review): this JTI is generated here and is not passed to
        # create_id_token, so it likely does not match any claim inside the
        # issued ID token. Confirm whether create_id_token embeds its own jti.
        id_token_jti = OIDCTokenService._generate_jti()
        OIDCTokenMetadata.create_metadata(
            client_id=client.id,
            user_id=user_id,
            token_type="id_token",
            token_jti=id_token_jti,
            expires_at=now + timedelta(seconds=client.id_token_lifetime or cls.DEFAULT_ID_TOKEN_LIFETIME),
        )

        OIDCAuditService.log_token_event(
            client_id=client_id,
            user_id=user_id,
            token_type="access_token",
            success=True,
            # A presented refresh token means this call is part of a refresh.
            grant_type="refresh_token" if refresh_token else "authorization_code",
            scopes=scope,
        )

        result = {
            "access_token": access_token,
            "token_type": "Bearer",
            "expires_in": access_lifetime,
            "id_token": id_token,
        }
        if final_refresh_token:
            result["refresh_token"] = final_refresh_token

        return result

    @classmethod
    def refresh_access_token(
        cls,
        refresh_token: str,
        client_id: str,
        scope: list = None,
        ip_address: str = None,
        user_agent: str = None
    ) -> Dict:
        """Refresh an access token with token rotation.

        Args:
            refresh_token: The refresh token
            client_id: OIDC client ID
            scope: Optional narrower scope. Per RFC 6749 §6 it may not include
                any scope not originally granted to the refresh token.
            ip_address: Client IP address
            user_agent: Client user agent

        Returns:
            Dictionary with new tokens (including the rotated refresh token)

        Raises:
            InvalidClientError: If the client does not exist
            InvalidGrantError: If refresh token is invalid, revoked, expired,
                or was issued to a different client
            OIDCError: ``invalid_scope`` if ``scope`` exceeds the original grant
        """
        client = OIDCClient.query.filter_by(client_id=client_id).first()
        if not client:
            raise InvalidClientError()

        refresh_token_obj = OIDCRefreshToken.query.filter_by(
            token_hash=cls._hash_value(refresh_token),
            deleted_at=None,
        ).first()

        if not refresh_token_obj:
            OIDCAuditService.log_token_event(
                client_id=client_id,
                success=False,
                error_code="invalid_grant",
                error_description="Invalid refresh token",
            )
            raise InvalidGrantError("Invalid refresh token")

        # Check ownership before validity so a client learns nothing about
        # the state of another client's token.
        if refresh_token_obj.client_id != client.id:
            OIDCAuditService.log_token_event(
                client_id=client_id,
                success=False,
                error_code="invalid_grant",
                error_description="Client mismatch",
            )
            raise InvalidGrantError("Client mismatch")

        if not refresh_token_obj.is_valid():
            OIDCAuditService.log_token_event(
                client_id=client_id,
                user_id=refresh_token_obj.user_id,
                success=False,
                error_code="invalid_grant",
                error_description="Refresh token expired or revoked",
            )
            raise InvalidGrantError("Refresh token expired or revoked")

        user_id = refresh_token_obj.user_id
        original_scope = cls._normalize_scope(refresh_token_obj.scope)
        requested_scope = cls._normalize_scope(scope)

        if requested_scope:
            # RFC 6749 §6: a refresh may narrow but never widen the grant.
            excess = [s for s in requested_scope if s not in original_scope]
            if excess:
                OIDCAuditService.log_token_event(
                    client_id=client_id,
                    user_id=user_id,
                    success=False,
                    error_code="invalid_scope",
                    error_description="Requested scope exceeds originally granted scope",
                )
                raise OIDCError(
                    "invalid_scope",
                    "Requested scope exceeds originally granted scope",
                    400,
                )
            granted_scope = requested_scope
        else:
            granted_scope = original_scope

        access_lifetime = client.access_token_lifetime or cls.DEFAULT_ACCESS_TOKEN_LIFETIME

        access_token_jti = OIDCTokenService._generate_jti()
        access_token = OIDCTokenService.create_access_token(
            client_id=client_id,
            user_id=user_id,
            scope=granted_scope,
            jti=access_token_jti,
        )

        id_token = OIDCTokenService.create_id_token(
            client_id=client_id,
            user_id=user_id,
            scope=granted_scope,
            access_token=access_token,
        )

        new_refresh, new_hash = OIDCTokenService.create_refresh_token(
            client_id=client_id,
            user_id=user_id,
            scope=granted_scope,
            access_token_id=access_token_jti,
        )
        refresh_token_obj.rotate(new_hash)

        OIDCTokenMetadata.create_metadata(
            client_id=client.id,
            user_id=user_id,
            token_type="access_token",
            token_jti=access_token_jti,
            expires_at=datetime.utcnow() + timedelta(seconds=access_lifetime),
        )

        OIDCAuditService.log_token_event(
            client_id=client_id,
            user_id=user_id,
            token_type="access_token",
            success=True,
            grant_type="refresh_token",
            scopes=granted_scope,
        )

        return {
            "access_token": access_token,
            "token_type": "Bearer",
            "expires_in": access_lifetime,
            "id_token": id_token,
            "refresh_token": new_refresh,
        }

    # ------------------------------------------------------------------
    # Validation, revocation, introspection
    # ------------------------------------------------------------------

    @classmethod
    def validate_access_token(cls, token: str, client_id: str = None) -> Dict:
        """Validate an access token and return its claims.

        Args:
            token: JWT access token
            client_id: Optional client ID to validate audience

        Returns:
            Token claims

        Raises:
            InvalidTokenError: If token is invalid (any error raised by
                OIDCTokenService.validate_access_token is audited and wrapped)
        """
        try:
            return OIDCTokenService.validate_access_token(token, client_id)
        except Exception as e:
            OIDCAuditService.log_event(
                event_type="token_validation",
                client_id=client_id,
                success=False,
                error_code="invalid_token",
                error_description=str(e),
            )
            raise InvalidTokenError(str(e)) from e

    @classmethod
    def _revoke_refresh_token(cls, token: str, client, client_id: str) -> bool:
        """Revoke *token* as a refresh token owned by *client*.

        Returns:
            True if a matching refresh token issued to this client was revoked.
        """
        refresh_token = OIDCRefreshToken.query.filter_by(
            token_hash=cls._hash_value(token),
            deleted_at=None,
        ).first()

        # RFC 7009 §2.1: a client may only revoke tokens issued to itself.
        if not refresh_token or refresh_token.client_id != client.id:
            return False

        refresh_token.revoke(reason="revoked_by_client")
        OIDCAuditService.log_token_revocation_event(
            client_id=client_id,
            user_id=refresh_token.user_id,
            token_type="refresh_token",
            reason="revoked_by_client",
        )
        return True

    @classmethod
    def _revoke_access_token(cls, token: str, client_id: str) -> bool:
        """Revoke *token* as a JWT access token, looked up by its ``jti``.

        Returns:
            True if token metadata for the token's JTI was revoked.
        """
        try:
            claims = OIDCTokenService.decode_token(token)
        except Exception:
            # Anything decode_token rejects (e.g. an opaque refresh token) is
            # simply not a revocable access token. Only decoding is guarded,
            # so database errors below still propagate.
            return False

        jti = claims.get("jti")
        if not jti:
            return False

        # NOTE(review): ownership is only enforced when the token carries a
        # ``client_id`` claim; confirm the claim layout, and confirm that
        # decode_token verifies the signature before these claims are trusted.
        token_client_id = claims.get("client_id")
        if token_client_id is not None and token_client_id != client_id:
            return False

        if not OIDCTokenMetadata.revoke_by_jti(jti, reason="revoked_by_client"):
            return False

        OIDCAuditService.log_token_revocation_event(
            client_id=client_id,
            user_id=claims.get("sub"),
            token_type="access_token",
            reason="revoked_by_client",
        )
        return True

    @classmethod
    def revoke_token(
        cls,
        token: str,
        client_id: str,
        token_type_hint: str = None,
        ip_address: str = None,
        user_agent: str = None
    ) -> bool:
        """Revoke a token.

        The hint only decides which token type is tried first. If the token is
        not found under that type, the other type is tried as well (RFC 7009
        §2.1). Unknown hints are ignored.

        Args:
            token: Token to revoke
            client_id: OIDC client ID (only tokens issued to it are revoked)
            token_type_hint: "refresh_token", "access_token", or None
            ip_address: Client IP address (currently unused)
            user_agent: Client user agent (currently unused)

        Returns:
            True if token was revoked, False if no matching token owned by the
            client was found

        Raises:
            InvalidClientError: If the client does not exist
        """
        client = OIDCClient.query.filter_by(client_id=client_id).first()
        if not client:
            raise InvalidClientError()

        if token_type_hint == "access_token":
            return (
                cls._revoke_access_token(token, client_id)
                or cls._revoke_refresh_token(token, client, client_id)
            )
        return (
            cls._revoke_refresh_token(token, client, client_id)
            or cls._revoke_access_token(token, client_id)
        )

    @classmethod
    def introspect_token(
        cls,
        token: str,
        client_id: str = None,
        ip_address: str = None,
        user_agent: str = None
    ) -> Dict:
        """Introspect a token and return its status and claims.

        Args:
            token: Token to introspect
            client_id: Client ID for validation
            ip_address: Client IP address (currently unused)
            user_agent: Client user agent (currently unused)

        Returns:
            Introspection response from OIDCTokenService (contains ``active``
            and, when present, ``sub``)
        """
        result = OIDCTokenService.introspect_token(token, client_id)

        OIDCAuditService.log_event(
            event_type="token_introspection",
            client_id=client_id,
            user_id=result.get("sub"),
            success=result.get("active", False),
            metadata={"active": result.get("active")},
        )

        return result

    @classmethod
    def get_jwks(cls) -> Dict:
        """Get the JWKS document.

        Returns:
            JWKS document produced by OIDCJWKSService
        """
        return OIDCJWKSService().get_jwks()

    @classmethod
    def get_userinfo(cls, access_token: str) -> Dict:
        """Get user information using access token.

        Only claims covered by the token's scopes are returned: ``name`` for
        ``profile`` (when set), ``email``/``email_verified`` for ``email``.

        Args:
            access_token: Access token

        Returns:
            User information dictionary (always contains ``sub``)

        Raises:
            InvalidTokenError: If the access token is invalid
            NotFoundError: If the token's subject no longer exists
        """
        claims = cls.validate_access_token(access_token)

        user_id = claims.get("sub")
        user = User.query.get(user_id)
        if not user:
            raise NotFoundError("User not found")

        # The scope claim may be a space-delimited string or a list.
        scopes = cls._normalize_scope(claims.get("scope"))

        userinfo = {"sub": user_id}
        if "profile" in scopes and user.full_name:
            userinfo["name"] = user.full_name
        if "email" in scopes:
            userinfo["email"] = user.email
            userinfo["email_verified"] = user.email_verified

        # NOTE(review): the raw bearer token is passed to the audit service;
        # confirm it is hashed or truncated there before being persisted.
        OIDCAuditService.log_userinfo_event(
            access_token=access_token,
            user_id=user_id,
            client_id=claims.get("client_id"),
            success=True,
            scopes_claimed=scopes,
        )

        return userinfo
|