419 lines
15 KiB
Python
419 lines
15 KiB
Python
"""OIDC JWKS Service for key management and rotation."""
|
|
import uuid
|
|
import json
|
|
import hashlib
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
from flask import current_app
|
|
|
|
from gatehouse_app.extensions import db
|
|
from gatehouse_app.models.oidc_jwks_key import OidcJwksKey
|
|
|
|
|
|
class JWKSKey:
    """In-memory representation of one signing key pair and its metadata.

    Attributes:
        kid: Key identifier published in the JWKS document.
        private_key: Private key as a PEM string.
        public_key: Public key as a PEM string.
        algorithm: JWS algorithm name (defaults to ``"RS256"``).
        created_at: Timezone-aware creation timestamp.
        expires_at: Timezone-aware expiry timestamp.
        is_active: Whether the key is currently flagged as active.
    """

    def __init__(self, kid: str, private_key: str, public_key: str,
                 algorithm: str = "RS256", created_at: Optional[datetime] = None,
                 expires_at: Optional[datetime] = None, is_active: bool = True):
        """Store the key material and fill in default timestamps.

        Args:
            kid: Key identifier.
            private_key: Private key PEM text.
            public_key: Public key PEM text.
            algorithm: JWS algorithm name.
            created_at: Creation time; defaults to the current UTC time.
            expires_at: Expiry time; defaults to 365 days from now.
            is_active: Initial active flag.
        """
        # Read the clock once so both defaults derive from the same instant.
        now = datetime.now(timezone.utc)

        self.kid = kid
        self.private_key = private_key
        self.public_key = public_key
        self.algorithm = algorithm
        self.created_at = now if created_at is None else self._as_utc(created_at)
        self.expires_at = (
            now + timedelta(days=365) if expires_at is None
            else self._as_utc(expires_at)
        )
        self.is_active = is_active

    @staticmethod
    def _as_utc(value: datetime) -> datetime:
        """Return *value* unchanged if timezone-aware, else tagged as UTC.

        Callers compare these timestamps against ``datetime.now(timezone.utc)``;
        comparing a naive datetime with an aware one raises ``TypeError``.
        NOTE(review): naive inputs are assumed to already be in UTC - confirm
        against how the database column stores timestamps.
        """
        if value.tzinfo is None:
            return value.replace(tzinfo=timezone.utc)
        return value

    def to_jwk(self) -> Dict:
        """Convert to JWK format for the JWKS endpoint.

        Returns:
            A dict with ``kty``, ``kid``, ``use`` and ``alg``; when the
            ``cryptography`` package is importable it also carries the RSA
            modulus ``n`` and exponent ``e`` as base64url strings.
        """
        # The imports live inside the try block so that a missing
        # ``cryptography`` package really reaches the fallback below.
        try:
            from cryptography.hazmat.primitives import serialization
            from cryptography.hazmat.backends import default_backend
        except ImportError:
            # Fallback for when cryptography is not installed: metadata only.
            return {
                "kty": "RSA",
                "kid": self.kid,
                "use": "sig",
                "alg": self.algorithm,
            }

        # Parse the PEM text and pull out the RSA public numbers.
        public_key = serialization.load_pem_public_key(
            self.public_key.encode(), backend=default_backend()
        )
        public_numbers = public_key.public_numbers()

        return {
            "kty": "RSA",
            "kid": self.kid,
            "use": "sig",
            "alg": self.algorithm,
            "n": _base64url_encode(public_numbers.n),
            "e": _base64url_encode(public_numbers.e),
        }

    def to_dict(self) -> Dict:
        """Convert to a plain dictionary for storage.

        The result includes the private key and ISO-8601 timestamp strings.
        """
        return {
            "kid": self.kid,
            "private_key": self.private_key,
            "public_key": self.public_key,
            "algorithm": self.algorithm,
            "created_at": self.created_at.isoformat(),
            "expires_at": self.expires_at.isoformat(),
            "is_active": self.is_active,
        }

    @classmethod
    def from_dict(cls, data: Dict) -> "JWKSKey":
        """Create an instance from a dictionary shaped like ``to_dict()`` output.

        ``algorithm`` and ``is_active`` are optional; every other field is
        required and a missing one raises ``KeyError``.
        """
        return cls(
            kid=data["kid"],
            private_key=data["private_key"],
            public_key=data["public_key"],
            algorithm=data.get("algorithm", "RS256"),
            created_at=datetime.fromisoformat(data["created_at"]),
            expires_at=datetime.fromisoformat(data["expires_at"]),
            is_active=data.get("is_active", True),
        )
|
|
|
|
|
|
def _base64url_encode(value: int) -> str:
    """Return *value* as an unpadded base64url string.

    The integer is serialized big-endian using the fewest bytes that hold it
    (at least one byte, so zero encodes as a single null byte), then
    base64url-encoded with the trailing ``=`` padding removed.
    """
    from base64 import urlsafe_b64encode

    width = max(1, (value.bit_length() + 7) // 8)
    raw = value.to_bytes(width, "big")
    return urlsafe_b64encode(raw).rstrip(b"=").decode()
|
|
|
|
|
|
class OIDCJWKSService:
    """Service for managing OIDC signing keys (JWKS).

    This service handles RSA key pair generation, rotation, and JWKS document
    generation for the OIDC implementation.

    The class is a singleton: every ``OIDCJWKSService()`` call returns the same
    instance, and the key cache is stored on the class itself.
    """

    _instance = None
    # kid -> key cache shared by every instance of the singleton.
    _keys: Dict[str, "JWKSKey"] = {}
    # kids of rotated-out keys that stay published/verifiable (but are no
    # longer used for signing) until their ``expires_at`` passes.
    _grace_kids: set = set()

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._keys = {}
            cls._grace_kids = set()
        return cls._instance

    @classmethod
    def reset(cls):
        """Reset the singleton (for testing)."""
        cls._instance = None
        cls._keys = {}
        cls._grace_kids = set()

    @staticmethod
    def _key_from_db_model(db_key, now: Optional[datetime] = None) -> "JWKSKey":
        """Build a ``JWKSKey`` from an ``OidcJwksKey`` row.

        A row without ``expires_at`` gets an expiry of 365 days after *now*
        (the current UTC time when *now* is omitted).
        """
        if now is None:
            now = datetime.now(timezone.utc)
        return JWKSKey(
            kid=db_key.kid,
            private_key=db_key.private_key,
            public_key=db_key.public_key,
            algorithm=db_key.algorithm,
            created_at=db_key.created_at,
            expires_at=db_key.expires_at or now + timedelta(days=365),
            is_active=db_key.is_active,
        )

    def _is_publishable(self, kid: str, key, now: datetime) -> bool:
        """Return True if *key* is unexpired and either active or in its grace period."""
        if not key.expires_at > now:
            return False
        return bool(key.is_active) or kid in self._grace_kids

    def _generate_kid(self, private_key: str) -> str:
        """Generate a key ID from the private key fingerprint.

        The ID is the first 32 hex characters of the SHA-256 digest of the
        private key text.
        """
        kid_hash = hashlib.sha256(private_key.encode()).hexdigest()[:32]
        return kid_hash

    def _generate_rsa_key_pair(self) -> Tuple[str, str]:
        """Generate a new RSA key pair in PEM format.

        Returns:
            Tuple of (private_key_pem, public_key_pem)
        """
        try:
            from cryptography.hazmat.primitives.asymmetric import rsa
            from cryptography.hazmat.primitives import serialization
            from cryptography.hazmat.backends import default_backend

            # Generate a 2048-bit RSA private key.
            private_key = rsa.generate_private_key(
                public_exponent=65537,
                key_size=2048,
                backend=default_backend()
            )

            public_key = private_key.public_key()

            # Serialize to PEM: unencrypted PKCS#8 for the private half,
            # SubjectPublicKeyInfo for the public half.
            private_pem = private_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption()
            ).decode()

            public_pem = public_key.public_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PublicFormat.SubjectPublicKeyInfo
            ).decode()

            return private_pem, public_pem
        except ImportError:
            # Fallback for testing without cryptography. These are random
            # placeholder strings, not PEM keys.
            import secrets
            return f"private_key_{secrets.token_hex(32)}", f"public_key_{secrets.token_hex(32)}"

    def get_jwks(self, include_private_keys: bool = False) -> Dict:
        """Get the JWKS document containing public keys.

        Unexpired keys are included when they are active or were rotated out
        by ``rotate_keys`` and are still inside their grace period.

        Args:
            include_private_keys: Whether to include private keys (for internal
                use only). When True each entry is the key's ``to_dict()``
                output instead of a JWK, so the result must never be served.

        Returns:
            JWKS document dictionary
        """
        now = datetime.now(timezone.utc)

        keys = []
        for kid, key in self._keys.items():
            if not self._is_publishable(kid, key, now):
                continue
            keys.append(key.to_dict() if include_private_keys else key.to_jwk())

        return {
            "keys": keys
        }

    def load_keys_from_db(self) -> int:
        """Load existing keys from the database into the in-memory cache.

        Returns:
            Number of keys loaded by this call; 0 if the load failed (the
            error is logged, not raised).
        """
        try:
            db_keys = OidcJwksKey.get_active_keys()
            now = datetime.now(timezone.utc)

            loaded = 0
            for db_key in db_keys:
                self._keys[db_key.kid] = self._key_from_db_model(db_key, now)
                loaded += 1

            # Count only what this call loaded, not keys cached earlier.
            return loaded
        except Exception as e:
            current_app.logger.error(f"Error loading keys from database: {e}")
            return 0

    def save_key_to_db(self, key: "JWKSKey", is_primary: bool = False) -> "OidcJwksKey":
        """Save a key to the database.

        Args:
            key: JWKSKey instance to save
            is_primary: Whether this is the primary signing key

        Returns:
            OidcJwksKey database model instance

        Raises:
            Exception: Whatever the commit raised; the session is rolled back
                first so it stays usable.
        """
        # NOTE(review): key.expires_at / key.created_at are not written here,
        # so a reloaded row may fall back to the 365-day default expiry.
        # Confirm whether the model accepts these fields.
        db_key = OidcJwksKey(
            kid=key.kid,
            key_type="RSA",
            algorithm=key.algorithm,
            private_key=key.private_key,
            public_key=key.public_key,
            is_active=key.is_active,
            is_primary=is_primary,
        )

        db.session.add(db_key)
        try:
            db.session.commit()
        except Exception:
            # Leave the session clean for the caller, then propagate.
            db.session.rollback()
            raise

        return db_key

    def get_signing_key(self) -> Optional["JWKSKey"]:
        """Get the current active signing key.

        The database primary key is preferred, but only while its in-memory
        copy is active and unexpired; otherwise the first active, unexpired
        cached key is used.

        Returns:
            JWKSKey instance or None if no active key
        """
        now = datetime.now(timezone.utc)

        # First try to get the primary key from database
        try:
            primary_db_key = OidcJwksKey.get_primary_key()
            if primary_db_key:
                # Check if we have it in memory, if not load it
                if primary_db_key.kid not in self._keys:
                    self._keys[primary_db_key.kid] = self._key_from_db_model(
                        primary_db_key, now
                    )
                primary = self._keys[primary_db_key.kid]
                # A primary that was rotated out or has expired must not be
                # used for signing: it is not (or soon will not be) published.
                if primary.is_active and primary.expires_at > now:
                    return primary
        except Exception as e:
            current_app.logger.error(f"Error getting primary key from database: {e}")

        # Fall back to in-memory keys
        for key in self._keys.values():
            if key.is_active and key.expires_at > now:
                return key

        return None

    def get_key_by_kid(self, kid: str) -> Optional["JWKSKey"]:
        """Get a specific key by its ID.

        Args:
            kid: Key ID to look up

        Returns:
            JWKSKey instance or None if not found
        """
        return self._keys.get(kid)

    def generate_new_key_pair(self, expires_in_days: int = 365) -> "JWKSKey":
        """Generate a new RSA key pair for signing.

        The new key is cached as the only active key; every other cached key
        is flagged inactive but kept in memory. Nothing is written to the
        database.

        Args:
            expires_in_days: Days until key expiration

        Returns:
            JWKSKey instance
        """
        private_key, public_key = self._generate_rsa_key_pair()
        kid = self._generate_kid(private_key)

        now = datetime.now(timezone.utc)
        key = JWKSKey(
            kid=kid,
            private_key=private_key,
            public_key=public_key,
            algorithm="RS256",
            created_at=now,
            expires_at=now + timedelta(days=expires_in_days),
            is_active=True,
        )

        self._keys[kid] = key

        # Deactivate every other cached key (they stay in the cache).
        for old_kid, old_key in self._keys.items():
            if old_kid != kid:
                old_key.is_active = False

        return key

    def rotate_keys(self, grace_period_hours: int = 24) -> Tuple["JWKSKey", List[str]]:
        """Rotate signing keys, keeping the previous key published for a grace period.

        The previous signing key stops being used for signing immediately but
        remains in the JWKS document, and valid for ``verify_key_exists``,
        until the grace period ends.

        NOTE(review): rotation only changes in-memory state; the new key is
        not saved to the database or promoted to primary here.

        Args:
            grace_period_hours: Hours to keep the old key published

        Returns:
            Tuple of (new_key, list_of_deprecated_kids)
        """
        now = datetime.now(timezone.utc)
        grace_end = now + timedelta(hours=grace_period_hours)

        # Mark current key as deprecated
        current_key = self.get_signing_key()
        deprecated_kids = []

        if current_key:
            deprecated_kids.append(current_key.kid)
            # Stop signing with it, but keep it published until grace_end.
            current_key.is_active = False
            current_key.expires_at = grace_end
            self._grace_kids.add(current_key.kid)

        # Generate new key
        new_key = self.generate_new_key_pair()

        # Clean up expired keys
        expired_kids = [
            kid for kid, key in self._keys.items()
            if key.expires_at < now
        ]
        for kid in expired_kids:
            del self._keys[kid]
            self._grace_kids.discard(kid)

        return new_key, deprecated_kids

    def verify_key_exists(self, kid: str) -> bool:
        """Check if a key with the given ID exists and is valid.

        A key is valid when it is unexpired and either active or still inside
        its rotation grace period.

        Args:
            kid: Key ID to check

        Returns:
            True if key exists and is valid
        """
        key = self.get_key_by_kid(kid)
        if not key:
            return False

        return self._is_publishable(kid, key, datetime.now(timezone.utc))

    def initialize_with_key(self) -> "JWKSKey":
        """Initialize the service with a key, loading from database if available.

        This method first attempts to load existing keys from the database.
        If no usable key is found, it generates a new key and saves it to the
        database as primary. A failed save is logged and the in-memory key is
        still returned.

        Returns:
            JWKSKey instance
        """
        # First, try to load keys from database
        try:
            # Check if there's a primary key in the database
            primary_db_key = OidcJwksKey.get_primary_key()
            if primary_db_key:
                # Load the primary key into memory
                key = self._key_from_db_model(primary_db_key)
                self._keys[primary_db_key.kid] = key
                current_app.logger.info(f"[OIDC] Loaded existing signing key from database: kid={primary_db_key.kid}")
                return key

            # Try to load all active keys from database
            loaded_count = self.load_keys_from_db()
            # Also reuse keys that are already cached (for example after an
            # earlier call whose database save failed) instead of replacing
            # them with a freshly generated key.
            if self._keys:
                signing_key = self.get_signing_key()
                if signing_key:
                    current_app.logger.info(f"[OIDC] Loaded {loaded_count} keys from database, using signing key: kid={signing_key.kid}")
                    return signing_key
        except Exception as e:
            current_app.logger.error(f"Error loading keys from database: {e}")

        # No keys in database, generate a new key and save it
        current_app.logger.info("[OIDC] No existing keys found in database, generating new signing key")
        new_key = self.generate_new_key_pair()

        # Save the new key to database (best effort)
        try:
            self.save_key_to_db(new_key, is_primary=True)
            current_app.logger.info(f"[OIDC] Saved new signing key to database: kid={new_key.kid}")
        except Exception as e:
            current_app.logger.error(f"Error saving key to database: {e}")

        return new_key
|