440 lines
14 KiB
Python
440 lines
14 KiB
Python
"""OIDC Token Service for JWT token generation and validation."""
|
|
import hashlib
|
|
import base64
|
|
import secrets
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, Optional, Any
|
|
|
|
import jwt
|
|
from flask import current_app, g
|
|
|
|
from app.models import User, OIDCClient
|
|
from app.services.oidc_jwks_service import OIDCJWKSService
|
|
|
|
|
|
class OIDCTokenService:
    """Service for generating and validating OIDC tokens.

    This service handles:
    - Access token creation (JWT)
    - ID token creation (JWT)
    - Refresh token creation (opaque)
    - Token signature verification
    - Hash generation for ID token claims (at_hash, c_hash)

    All timestamps placed in or compared against token claims are Unix
    epoch seconds derived from a timezone-aware UTC clock.
    """

    # Lifetime (seconds) used when the client cannot be found or the
    # requested token type is unknown.
    _FALLBACK_LIFETIME = 3600

    @staticmethod
    def _utc_now() -> datetime:
        """Return the current time as a timezone-aware UTC datetime.

        A naive ``datetime.utcnow()`` must not be used here: calling
        ``.timestamp()`` on a naive datetime interprets it as *local* time,
        which shifts every epoch value by the host's UTC offset.
        """
        # Local import: ``timezone`` is not part of the module-level imports.
        from datetime import timezone

        return datetime.now(timezone.utc)

    @staticmethod
    def _generate_jti() -> str:
        """Generate a unique JWT ID from 32 random bytes (URL-safe text)."""
        return secrets.token_urlsafe(32)

    @staticmethod
    def _generate_opaque_token(length: int = 43) -> str:
        """Generate an opaque token (for refresh tokens).

        Args:
            length: Number of random *bytes* to draw. The returned string is
                longer than this (roughly 1.3 characters per byte), because
                the bytes are base64url encoded.

        Returns:
            URL-safe base64 encoded token
        """
        return secrets.token_urlsafe(length)

    @staticmethod
    def _hash_token(token: str) -> str:
        """Hash a token for secure storage.

        Args:
            token: Token to hash

        Returns:
            Hex-encoded SHA256 hash of the token
        """
        return hashlib.sha256(token.encode()).hexdigest()

    @staticmethod
    def _base64url_encode(data: bytes) -> str:
        """Encode bytes to base64url format without padding.

        Args:
            data: Bytes to encode

        Returns:
            Base64url encoded string with trailing ``=`` padding removed
        """
        return base64.urlsafe_b64encode(data).decode().rstrip("=")

    @staticmethod
    def _left_half_hash(value: str) -> str:
        """Return the base64url encoding of the left half of SHA256(value).

        Shared implementation of the ``at_hash`` and ``c_hash`` claims.
        SHA256 is used because the tokens are signed with RS256.
        """
        digest = hashlib.sha256(value.encode()).digest()
        return OIDCTokenService._base64url_encode(digest[: len(digest) // 2])

    @staticmethod
    def create_at_hash(access_token: str) -> str:
        """Create the at_hash claim for ID token.

        The value is the base64url encoding of the left-most half of the
        SHA256 hash of the access token.

        Args:
            access_token: The access token string

        Returns:
            Base64url encoded hash
        """
        return OIDCTokenService._left_half_hash(access_token)

    @staticmethod
    def create_c_hash(code: str) -> str:
        """Create the c_hash claim for ID token.

        The value is the base64url encoding of the left-most half of the
        SHA256 hash of the authorization code.

        Args:
            code: The authorization code string

        Returns:
            Base64url encoded hash
        """
        return OIDCTokenService._left_half_hash(code)

    @staticmethod
    def _get_issuer() -> str:
        """Get the OIDC issuer URL from the Flask app config."""
        return current_app.config.get("OIDC_ISSUER_URL", "http://localhost:5000")

    @staticmethod
    def _get_token_lifetime(client: "OIDCClient", token_type: str) -> int:
        """Get the token lifetime in seconds for a client.

        Args:
            client: OIDCClient instance
            token_type: Type of token ("access_token", "refresh_token", "id_token")

        Returns:
            Lifetime in seconds. Unset (falsy) client values fall back to
            3600 for access/ID tokens and 2592000 for refresh tokens; an
            unknown ``token_type`` yields 3600.
        """
        lifetimes = {
            "access_token": client.access_token_lifetime or 3600,
            "refresh_token": client.refresh_token_lifetime or 2592000,
            "id_token": client.id_token_lifetime or 3600,
        }
        return lifetimes.get(token_type, OIDCTokenService._FALLBACK_LIFETIME)

    @classmethod
    def _lifetime_for_client_id(cls, client_id: str, token_type: str) -> int:
        """Look up the client by ID and return its lifetime for ``token_type``.

        Falls back to 3600 seconds when no client with that ID exists.
        """
        client = OIDCClient.query.filter_by(client_id=client_id).first()
        if not client:
            return cls._FALLBACK_LIFETIME
        return cls._get_token_lifetime(client, token_type)

    @staticmethod
    def _format_scope(scope):
        """Join a list of scopes with spaces; pass any other value through."""
        return " ".join(scope) if isinstance(scope, list) else scope

    @staticmethod
    def _sign_claims(claims: Dict) -> str:
        """Sign ``claims`` as an RS256 JWT using the current signing key.

        Raises:
            ValueError: If the JWKS service has no signing key available
        """
        signing_key = OIDCJWKSService().get_signing_key()
        if not signing_key:
            raise ValueError("No signing key available")

        # The ``kid`` header lets verifiers pick the matching public key.
        return jwt.encode(
            claims,
            signing_key.private_key,
            algorithm="RS256",
            headers={"kid": signing_key.kid},
        )

    @classmethod
    def create_access_token(cls, client_id: str, user_id: str, scope: list,
                            jti: Optional[str] = None) -> str:
        """Create a JWT access token.

        Args:
            client_id: The OIDC client ID
            user_id: The user ID (subject)
            scope: List of granted scopes
            jti: Optional JWT ID (generated if not provided)

        Returns:
            JWT access token string

        Raises:
            ValueError: If no signing key is available
        """
        jti = jti or cls._generate_jti()
        now = cls._utc_now()
        lifetime = cls._lifetime_for_client_id(client_id, "access_token")
        issued_at = int(now.timestamp())

        claims = {
            "iss": cls._get_issuer(),
            "sub": user_id,
            "aud": client_id,
            "exp": int((now + timedelta(seconds=lifetime)).timestamp()),
            "iat": issued_at,
            "nbf": issued_at,
            "jti": jti,
            "client_id": client_id,
            "scope": cls._format_scope(scope),
        }
        return cls._sign_claims(claims)

    @classmethod
    def create_id_token(cls, client_id: str, user_id: str,
                        nonce: Optional[str] = None,
                        scope: Optional[list] = None,
                        access_token: Optional[str] = None,
                        auth_time: Optional[int] = None) -> str:
        """Create a JWT ID token.

        Args:
            client_id: The OIDC client ID
            user_id: The user ID (subject)
            nonce: Nonce for replay protection
            scope: Requested/Granted scopes
            access_token: Associated access token (for at_hash)
            auth_time: Authentication time (Unix timestamp); defaults to
                the current time when omitted

        Returns:
            JWT ID token string

        Raises:
            ValueError: If no signing key is available
        """
        now = cls._utc_now()
        issued_at = int(now.timestamp())
        # ``is None`` so that an explicit timestamp of 0 is not discarded.
        if auth_time is None:
            auth_time = issued_at

        lifetime = cls._lifetime_for_client_id(client_id, "id_token")

        # The user record supplies the optional profile claims below.
        user = User.query.get(user_id)

        claims = {
            "iss": cls._get_issuer(),
            "sub": user_id,
            "aud": client_id,
            "exp": int((now + timedelta(seconds=lifetime)).timestamp()),
            "iat": issued_at,
            "auth_time": auth_time,
        }

        if nonce:
            claims["nonce"] = nonce

        if access_token:
            claims["at_hash"] = cls.create_at_hash(access_token)

        if user:
            if user.email:
                claims["email"] = user.email
                claims["email_verified"] = user.email_verified
            if user.full_name:
                claims["name"] = user.full_name

        if scope:
            claims["scope"] = cls._format_scope(scope)

        return cls._sign_claims(claims)

    @classmethod
    def create_refresh_token(cls, client_id: str, user_id: str,
                             scope: Optional[list] = None,
                             access_token_id: Optional[str] = None) -> tuple:
        """Create an opaque refresh token together with its storage hash.

        The arguments are accepted for interface compatibility; this method
        does not read or persist them.

        Args:
            client_id: The OIDC client ID
            user_id: The user ID
            scope: List of granted scopes
            access_token_id: Associated access token ID

        Returns:
            A ``(token, token_hash)`` tuple: the opaque refresh token string
            and its hex-encoded SHA256 hash for storage
        """
        token = cls._generate_opaque_token()
        return token, cls._hash_token(token)

    @staticmethod
    def _find_public_key(jwks: Dict, kid):
        """Return the loaded public key whose ``kid`` matches, or None.

        Entries that cannot be loaded (cryptography missing, malformed PEM,
        missing "public_key" field) are skipped so that a later entry with
        the same ``kid`` can still be tried.
        """
        for key in jwks.get("keys", []):
            if key.get("kid") != kid:
                continue
            try:
                from cryptography.hazmat.primitives import serialization
                from cryptography.hazmat.backends import default_backend

                pem = key["public_key"]
                if isinstance(pem, str):
                    pem = pem.encode()
                return serialization.load_pem_public_key(
                    pem, backend=default_backend()
                )
            except Exception:
                # Best effort: the caller raises if no usable key is found.
                continue
        return None

    @classmethod
    def verify_token_signature(cls, token: str) -> Dict:
        """Verify the signature of a JWT token.

        Args:
            token: JWT token string

        Returns:
            Decoded token claims

        Raises:
            jwt.InvalidSignatureError: If signature verification fails or no
                usable key matches the token's ``kid``
            jwt.ExpiredSignatureError: If token is expired
            jwt.InvalidTokenError: If token is invalid
        """
        jwks = OIDCJWKSService().get_jwks()

        try:
            unverified_header = jwt.get_unverified_header(token)
        except jwt.DecodeError as exc:
            raise jwt.InvalidTokenError("Invalid token header") from exc

        kid = unverified_header.get("kid")
        public_key = cls._find_public_key(jwks, kid)
        if not public_key:
            raise jwt.InvalidSignatureError(f"Key with kid={kid} not found")

        # NOTE(review): issuer verification is disabled here and nothing in
        # this class checks ``iss`` afterwards -- confirm callers do so.
        return jwt.decode(
            token,
            public_key,
            algorithms=["RS256"],
            audience=None,  # We'll validate audience separately
            issuer=cls._get_issuer(),
            options={
                "verify_signature": True,
                "verify_exp": True,
                "verify_aud": False,  # Handle audience manually
                "verify_iss": False,  # Handle issuer manually
            },
        )

    @classmethod
    def decode_token(cls, token: str, verify: bool = False) -> Dict:
        """Decode a JWT token, by default without verification (for debugging).

        Args:
            token: JWT token string
            verify: Whether to verify signature (delegates to
                ``verify_token_signature``)

        Returns:
            Decoded token claims
        """
        if verify:
            return cls.verify_token_signature(token)

        return jwt.decode(
            token,
            options={
                "verify_signature": False,
                "verify_exp": False,
            },
        )

    @classmethod
    def validate_access_token(cls, token: str,
                              client_id: Optional[str] = None) -> Dict:
        """Validate an access token and return its claims.

        Args:
            token: JWT access token
            client_id: Optional client ID to validate audience

        Returns:
            Token claims dictionary

        Raises:
            jwt.InvalidTokenError: If token is invalid
            ValueError: If token is expired or audience mismatch
        """
        claims = cls.verify_token_signature(token)

        # Also rejects tokens that carry no ``exp`` claim at all.
        if claims.get("exp", 0) < cls._utc_now().timestamp():
            raise ValueError("Token has expired")

        if client_id and claims.get("aud") != client_id:
            raise ValueError("Invalid audience")

        return claims

    @classmethod
    def introspect_token(cls, token: str,
                         client_id: Optional[str] = None) -> Dict:
        """Introspect a token and return its status and claims.

        Args:
            token: JWT token to introspect
            client_id: Client ID for audience validation

        Returns:
            Dictionary with ``active`` plus, for a valid token, its claims.
            ``exp`` is the absolute expiry timestamp (epoch seconds) and
            ``expires_in`` is the remaining lifetime in seconds. For an
            invalid token the dictionary holds ``active: False`` and an
            ``error`` message.
        """
        result = {"active": False}

        try:
            claims = cls.validate_access_token(token, client_id)
        except (jwt.InvalidTokenError, ValueError) as exc:
            result["error"] = str(exc)
            return result

        now = cls._utc_now().timestamp()
        exp = claims.get("exp", 0)
        still_valid = exp > now

        result.update({
            "active": still_valid,
            "iss": claims.get("iss"),
            "sub": claims.get("sub"),
            "aud": claims.get("aud"),
            "exp": exp,
            "iat": claims.get("iat", 0),
            "nbf": claims.get("nbf"),
            "jti": claims.get("jti"),
            "client_id": claims.get("client_id"),
            "scope": claims.get("scope"),
            "token_type": "Bearer",
        })

        # Remaining lifetime goes in its own key so ``exp`` stays absolute.
        if still_valid:
            result["expires_in"] = int(exp - now)

        return result
|