major checkpoint
This commit is contained in:
@@ -0,0 +1,439 @@
|
||||
"""OIDC Token Service for JWT token generation and validation."""
|
||||
import hashlib
|
||||
import base64
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Optional, Any
|
||||
|
||||
import jwt
|
||||
from flask import current_app, g
|
||||
|
||||
from app.models import User, OIDCClient
|
||||
from app.services.oidc_jwks_service import OIDCJWKSService
|
||||
|
||||
|
||||
class OIDCTokenService:
|
||||
"""Service for generating and validating OIDC tokens.
|
||||
|
||||
This service handles:
|
||||
- Access token creation (JWT)
|
||||
- ID token creation (JWT)
|
||||
- Refresh token creation (opaque)
|
||||
- Token signature verification
|
||||
- Hash generation for PKCE claims (at_hash, c_hash)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _generate_jti() -> str:
|
||||
"""Generate a unique JWT ID."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
@staticmethod
|
||||
def _generate_opaque_token(length: int = 43) -> str:
|
||||
"""Generate an opaque token (for refresh tokens).
|
||||
|
||||
Args:
|
||||
length: Length of the token
|
||||
|
||||
Returns:
|
||||
URL-safe base64 encoded token
|
||||
"""
|
||||
return secrets.token_urlsafe(length)
|
||||
|
||||
@staticmethod
|
||||
def _hash_token(token: str) -> str:
|
||||
"""Hash a token for secure storage.
|
||||
|
||||
Args:
|
||||
token: Token to hash
|
||||
|
||||
Returns:
|
||||
SHA256 hash of the token
|
||||
"""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _base64url_encode(data: bytes) -> str:
|
||||
"""Encode bytes to base64url format without padding.
|
||||
|
||||
Args:
|
||||
data: Bytes to encode
|
||||
|
||||
Returns:
|
||||
Base64url encoded string
|
||||
"""
|
||||
return base64.urlsafe_b64encode(data).decode().rstrip("=")
|
||||
|
||||
@staticmethod
|
||||
def create_at_hash(access_token: str) -> str:
|
||||
"""Create the at_hash claim for ID token.
|
||||
|
||||
Implements OIDC spec for access token hash generation.
|
||||
Hash is the left-most half of the hash of the ASCII representation
|
||||
of the access token.
|
||||
|
||||
Args:
|
||||
access_token: The access token string
|
||||
|
||||
Returns:
|
||||
Base64url encoded hash
|
||||
"""
|
||||
# Hash the access token using SHA256
|
||||
hash_digest = hashlib.sha256(access_token.encode()).digest()
|
||||
|
||||
# Take left-most half of the hash
|
||||
half_length = len(hash_digest) // 2
|
||||
left_half = hash_digest[:half_length]
|
||||
|
||||
# Base64url encode
|
||||
return OIDCTokenService._base64url_encode(left_half)
|
||||
|
||||
@staticmethod
|
||||
def create_c_hash(code: str) -> str:
|
||||
"""Create the c_hash claim for ID token.
|
||||
|
||||
Implements OIDC spec for authorization code hash generation.
|
||||
|
||||
Args:
|
||||
code: The authorization code string
|
||||
|
||||
Returns:
|
||||
Base64url encoded hash
|
||||
"""
|
||||
# Hash the code using SHA256
|
||||
hash_digest = hashlib.sha256(code.encode()).digest()
|
||||
|
||||
# Take left-most half of the hash
|
||||
half_length = len(hash_digest) // 2
|
||||
left_half = hash_digest[:half_length]
|
||||
|
||||
# Base64url encode
|
||||
return OIDCTokenService._base64url_encode(left_half)
|
||||
|
||||
@staticmethod
|
||||
def _get_issuer() -> str:
|
||||
"""Get the OIDC issuer URL."""
|
||||
return current_app.config.get("OIDC_ISSUER_URL", "http://localhost:5000")
|
||||
|
||||
@staticmethod
|
||||
def _get_token_lifetime(client: OIDCClient, token_type: str) -> int:
|
||||
"""Get the token lifetime in seconds for a client.
|
||||
|
||||
Args:
|
||||
client: OIDCClient instance
|
||||
token_type: Type of token ("access_token", "refresh_token", "id_token")
|
||||
|
||||
Returns:
|
||||
Lifetime in seconds
|
||||
"""
|
||||
lifetimes = {
|
||||
"access_token": client.access_token_lifetime or 3600,
|
||||
"refresh_token": client.refresh_token_lifetime or 2592000,
|
||||
"id_token": client.id_token_lifetime or 3600,
|
||||
}
|
||||
return lifetimes.get(token_type, 3600)
|
||||
|
||||
@classmethod
|
||||
def create_access_token(cls, client_id: str, user_id: str, scope: list,
|
||||
jti: str = None) -> str:
|
||||
"""Create a JWT access token.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID (subject)
|
||||
scope: List of granted scopes
|
||||
jti: Optional JWT ID (generated if not provided)
|
||||
|
||||
Returns:
|
||||
JWT access token string
|
||||
"""
|
||||
jti = jti or cls._generate_jti()
|
||||
now = datetime.utcnow()
|
||||
|
||||
# Get client for token lifetime
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
lifetime = cls._get_token_lifetime(client, "access_token") if client else 3600
|
||||
|
||||
claims = {
|
||||
"iss": cls._get_issuer(),
|
||||
"sub": user_id,
|
||||
"aud": client_id,
|
||||
"exp": int((now + timedelta(seconds=lifetime)).timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"nbf": int(now.timestamp()),
|
||||
"jti": jti,
|
||||
"client_id": client_id,
|
||||
"scope": " ".join(scope) if isinstance(scope, list) else scope,
|
||||
}
|
||||
|
||||
# Get signing key
|
||||
jwks_service = OIDCJWKSService()
|
||||
signing_key = jwks_service.get_signing_key()
|
||||
|
||||
if not signing_key:
|
||||
raise ValueError("No signing key available")
|
||||
|
||||
# Sign with RS256
|
||||
token = jwt.encode(
|
||||
claims,
|
||||
signing_key.private_key,
|
||||
algorithm="RS256",
|
||||
headers={"kid": signing_key.kid}
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def create_id_token(cls, client_id: str, user_id: str, nonce: str = None,
|
||||
scope: list = None, access_token: str = None,
|
||||
auth_time: int = None) -> str:
|
||||
"""Create a JWT ID token.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID (subject)
|
||||
nonce: Nonce for replay protection
|
||||
scope: Requested/Granted scopes
|
||||
access_token: Associated access token (for at_hash)
|
||||
auth_time: Authentication time (Unix timestamp)
|
||||
|
||||
Returns:
|
||||
JWT ID token string
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
auth_time = auth_time or int(now.timestamp())
|
||||
|
||||
# Get client for token lifetime
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
lifetime = cls._get_token_lifetime(client, "id_token") if client else 3600
|
||||
|
||||
# Get user for claims
|
||||
user = User.query.get(user_id)
|
||||
|
||||
claims = {
|
||||
"iss": cls._get_issuer(),
|
||||
"sub": user_id,
|
||||
"aud": client_id,
|
||||
"exp": int((now + timedelta(seconds=lifetime)).timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"auth_time": auth_time,
|
||||
}
|
||||
|
||||
# Add nonce if provided
|
||||
if nonce:
|
||||
claims["nonce"] = nonce
|
||||
|
||||
# Add at_hash if access token provided
|
||||
if access_token:
|
||||
claims["at_hash"] = cls.create_at_hash(access_token)
|
||||
|
||||
# Add standard claims if user exists
|
||||
if user:
|
||||
if user.email:
|
||||
claims["email"] = user.email
|
||||
claims["email_verified"] = user.email_verified
|
||||
if user.full_name:
|
||||
claims["name"] = user.full_name
|
||||
|
||||
# Add scope if provided
|
||||
if scope:
|
||||
claims["scope"] = " ".join(scope) if isinstance(scope, list) else scope
|
||||
|
||||
# Get signing key
|
||||
jwks_service = OIDCJWKSService()
|
||||
signing_key = jwks_service.get_signing_key()
|
||||
|
||||
if not signing_key:
|
||||
raise ValueError("No signing key available")
|
||||
|
||||
# Sign with RS256
|
||||
token = jwt.encode(
|
||||
claims,
|
||||
signing_key.private_key,
|
||||
algorithm="RS256",
|
||||
headers={"kid": signing_key.kid}
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def create_refresh_token(cls, client_id: str, user_id: str,
|
||||
scope: list = None, access_token_id: str = None) -> str:
|
||||
"""Create an opaque refresh token.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID
|
||||
scope: List of granted scopes
|
||||
access_token_id: Associated access token ID
|
||||
|
||||
Returns:
|
||||
Opaque refresh token string
|
||||
"""
|
||||
token = cls._generate_opaque_token()
|
||||
|
||||
# Hash for storage
|
||||
token_hash = cls._hash_token(token)
|
||||
|
||||
return token, token_hash
|
||||
|
||||
@classmethod
|
||||
def verify_token_signature(cls, token: str) -> Dict:
|
||||
"""Verify the signature of a JWT token.
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
|
||||
Returns:
|
||||
Decoded token claims
|
||||
|
||||
Raises:
|
||||
jwt.InvalidSignatureError: If signature verification fails
|
||||
jwt.ExpiredSignatureError: If token is expired
|
||||
jwt.InvalidTokenError: If token is invalid
|
||||
"""
|
||||
# Get the JWKS with public keys
|
||||
jwks_service = OIDCJWKSService()
|
||||
jwks = jwks_service.get_jwks()
|
||||
|
||||
# Get the key ID from token header
|
||||
try:
|
||||
unverified_header = jwt.get_unverified_header(token)
|
||||
except jwt.DecodeError:
|
||||
raise jwt.InvalidTokenError("Invalid token header")
|
||||
|
||||
kid = unverified_header.get("kid")
|
||||
|
||||
# Find the matching public key
|
||||
public_key = None
|
||||
for key in jwks.get("keys", []):
|
||||
if key.get("kid") == kid:
|
||||
try:
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
public_key = serialization.load_pem_public_key(
|
||||
key["public_key"].encode() if isinstance(key["public_key"], str)
|
||||
else key["public_key"],
|
||||
backend=default_backend()
|
||||
)
|
||||
break
|
||||
except (ImportError, Exception):
|
||||
continue
|
||||
|
||||
if not public_key:
|
||||
raise jwt.InvalidSignatureError(f"Key with kid={kid} not found")
|
||||
|
||||
# Verify the signature
|
||||
claims = jwt.decode(
|
||||
token,
|
||||
public_key,
|
||||
algorithms=["RS256"],
|
||||
audience=None, # We'll validate audience separately
|
||||
issuer=cls._get_issuer(),
|
||||
options={
|
||||
"verify_signature": True,
|
||||
"verify_exp": True,
|
||||
"verify_aud": False, # Handle audience manually
|
||||
"verify_iss": False, # Handle issuer manually
|
||||
}
|
||||
)
|
||||
|
||||
return claims
|
||||
|
||||
@classmethod
|
||||
def decode_token(cls, token: str, verify: bool = False) -> Dict:
|
||||
"""Decode a JWT token without verification (for debugging).
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
verify: Whether to verify signature
|
||||
|
||||
Returns:
|
||||
Decoded token claims
|
||||
"""
|
||||
if verify:
|
||||
return cls.verify_token_signature(token)
|
||||
|
||||
return jwt.decode(
|
||||
token,
|
||||
options={
|
||||
"verify_signature": False,
|
||||
"verify_exp": False,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_access_token(cls, token: str, client_id: str = None) -> Dict:
|
||||
"""Validate an access token and return its claims.
|
||||
|
||||
Args:
|
||||
token: JWT access token
|
||||
client_id: Optional client ID to validate audience
|
||||
|
||||
Returns:
|
||||
Token claims dictionary
|
||||
|
||||
Raises:
|
||||
jwt.InvalidTokenError: If token is invalid
|
||||
ValueError: If token is expired or audience mismatch
|
||||
"""
|
||||
claims = cls.verify_token_signature(token)
|
||||
|
||||
# Check expiration
|
||||
if claims.get("exp", 0) < datetime.utcnow().timestamp():
|
||||
raise ValueError("Token has expired")
|
||||
|
||||
# Validate audience if client_id provided
|
||||
if client_id:
|
||||
if claims.get("aud") != client_id:
|
||||
raise ValueError("Invalid audience")
|
||||
|
||||
return claims
|
||||
|
||||
@classmethod
|
||||
def introspect_token(cls, token: str, client_id: str = None) -> Dict:
|
||||
"""Introspect a token and return its status and claims.
|
||||
|
||||
Args:
|
||||
token: JWT token to introspect
|
||||
client_id: Client ID for audience validation
|
||||
|
||||
Returns:
|
||||
Dictionary with active status and claims
|
||||
"""
|
||||
result = {
|
||||
"active": False,
|
||||
}
|
||||
|
||||
try:
|
||||
claims = cls.validate_access_token(token, client_id)
|
||||
|
||||
# Calculate remaining time
|
||||
now = datetime.utcnow().timestamp()
|
||||
exp = claims.get("exp", 0)
|
||||
iat = claims.get("iat", 0)
|
||||
|
||||
result["active"] = exp > now
|
||||
result.update({
|
||||
"iss": claims.get("iss"),
|
||||
"sub": claims.get("sub"),
|
||||
"aud": claims.get("aud"),
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": claims.get("nbf"),
|
||||
"jti": claims.get("jti"),
|
||||
"client_id": claims.get("client_id"),
|
||||
"scope": claims.get("scope"),
|
||||
"token_type": "Bearer",
|
||||
})
|
||||
|
||||
# Add expiry in seconds
|
||||
if exp > now:
|
||||
result["exp"] = int(exp - now)
|
||||
|
||||
except (jwt.InvalidTokenError, ValueError) as e:
|
||||
result["active"] = False
|
||||
result["error"] = str(e)
|
||||
|
||||
return result
|
||||
Reference in New Issue
Block a user