functional totp
This commit is contained in:
@@ -2,12 +2,13 @@
|
||||
import uuid
|
||||
import json
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from app.extensions import db
|
||||
from app.models.oidc_jwks_key import OidcJwksKey
|
||||
|
||||
|
||||
class JWKSKey:
|
||||
@@ -20,8 +21,8 @@ class JWKSKey:
|
||||
self.private_key = private_key
|
||||
self.public_key = public_key
|
||||
self.algorithm = algorithm
|
||||
self.created_at = created_at or datetime.utcnow()
|
||||
self.expires_at = expires_at or datetime.utcnow() + timedelta(days=365)
|
||||
self.created_at = created_at or datetime.now(timezone.utc)
|
||||
self.expires_at = expires_at or datetime.now(timezone.utc) + timedelta(days=365)
|
||||
self.is_active = is_active
|
||||
|
||||
def to_jwk(self) -> Dict:
|
||||
@@ -166,7 +167,7 @@ class OIDCJWKSService:
|
||||
Returns:
|
||||
JWKS document dictionary
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
keys = []
|
||||
for kid, key in self._keys.items():
|
||||
@@ -181,14 +182,88 @@ class OIDCJWKSService:
|
||||
"keys": keys
|
||||
}
|
||||
|
||||
def load_keys_from_db(self) -> int:
|
||||
"""Load existing keys from the database.
|
||||
|
||||
Returns:
|
||||
Number of keys loaded
|
||||
"""
|
||||
try:
|
||||
db_keys = OidcJwksKey.get_active_keys()
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
for db_key in db_keys:
|
||||
# Create JWKSKey from database model
|
||||
key = JWKSKey(
|
||||
kid=db_key.kid,
|
||||
private_key=db_key.private_key,
|
||||
public_key=db_key.public_key,
|
||||
algorithm=db_key.algorithm,
|
||||
created_at=db_key.created_at,
|
||||
expires_at=db_key.expires_at or now + timedelta(days=365),
|
||||
is_active=db_key.is_active,
|
||||
)
|
||||
self._keys[db_key.kid] = key
|
||||
|
||||
return len(self._keys)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error loading keys from database: {e}")
|
||||
return 0
|
||||
|
||||
def save_key_to_db(self, key: JWKSKey, is_primary: bool = False) -> OidcJwksKey:
|
||||
"""Save a key to the database.
|
||||
|
||||
Args:
|
||||
key: JWKSKey instance to save
|
||||
is_primary: Whether this is the primary signing key
|
||||
|
||||
Returns:
|
||||
OidcJwksKey database model instance
|
||||
"""
|
||||
db_key = OidcJwksKey(
|
||||
kid=key.kid,
|
||||
key_type="RSA",
|
||||
algorithm=key.algorithm,
|
||||
private_key=key.private_key,
|
||||
public_key=key.public_key,
|
||||
is_active=key.is_active,
|
||||
is_primary=is_primary,
|
||||
)
|
||||
|
||||
db.session.add(db_key)
|
||||
db.session.commit()
|
||||
|
||||
return db_key
|
||||
|
||||
def get_signing_key(self) -> Optional[JWKSKey]:
|
||||
"""Get the current active signing key.
|
||||
|
||||
Returns:
|
||||
JWKSKey instance or None if no active key
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# First try to get the primary key from database
|
||||
try:
|
||||
primary_db_key = OidcJwksKey.get_primary_key()
|
||||
if primary_db_key:
|
||||
# Check if we have it in memory, if not load it
|
||||
if primary_db_key.kid not in self._keys:
|
||||
key = JWKSKey(
|
||||
kid=primary_db_key.kid,
|
||||
private_key=primary_db_key.private_key,
|
||||
public_key=primary_db_key.public_key,
|
||||
algorithm=primary_db_key.algorithm,
|
||||
created_at=primary_db_key.created_at,
|
||||
expires_at=primary_db_key.expires_at or now + timedelta(days=365),
|
||||
is_active=primary_db_key.is_active,
|
||||
)
|
||||
self._keys[primary_db_key.kid] = key
|
||||
return self._keys[primary_db_key.kid]
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error getting primary key from database: {e}")
|
||||
|
||||
# Fall back to in-memory keys
|
||||
for kid, key in self._keys.items():
|
||||
if key.is_active and key.expires_at > now:
|
||||
return key
|
||||
@@ -218,7 +293,7 @@ class OIDCJWKSService:
|
||||
private_key, public_key = self._generate_rsa_key_pair()
|
||||
kid = self._generate_kid(private_key)
|
||||
|
||||
now = datetime.utcnow()
|
||||
now = datetime.now(timezone.utc)
|
||||
key = JWKSKey(
|
||||
kid=kid,
|
||||
private_key=private_key,
|
||||
@@ -247,7 +322,7 @@ class OIDCJWKSService:
|
||||
Returns:
|
||||
Tuple of (new_key, list_of_deprecated_kids)
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
now = datetime.now(timezone.utc)
|
||||
grace_end = now + timedelta(hours=grace_period_hours)
|
||||
|
||||
# Mark current key as deprecated
|
||||
@@ -286,15 +361,58 @@ class OIDCJWKSService:
|
||||
if not key:
|
||||
return False
|
||||
|
||||
now = datetime.utcnow()
|
||||
now = datetime.now(timezone.utc)
|
||||
return key.is_active and key.expires_at > now
|
||||
|
||||
def initialize_with_key(self) -> JWKSKey:
|
||||
"""Initialize the service with a key if none exists.
|
||||
"""Initialize the service with a key, loading from database if available.
|
||||
|
||||
This method first attempts to load existing keys from the database.
|
||||
If no active primary key exists, it generates a new key and saves it to the database.
|
||||
|
||||
Returns:
|
||||
JWKSKey instance
|
||||
"""
|
||||
if not self._keys:
|
||||
return self.generate_new_key_pair()
|
||||
return self.get_signing_key()
|
||||
# First, try to load keys from database
|
||||
try:
|
||||
# Check if there's a primary key in the database
|
||||
primary_db_key = OidcJwksKey.get_primary_key()
|
||||
if primary_db_key:
|
||||
# Load the primary key into memory
|
||||
now = datetime.now(timezone.utc)
|
||||
key = JWKSKey(
|
||||
kid=primary_db_key.kid,
|
||||
private_key=primary_db_key.private_key,
|
||||
public_key=primary_db_key.public_key,
|
||||
algorithm=primary_db_key.algorithm,
|
||||
created_at=primary_db_key.created_at,
|
||||
expires_at=primary_db_key.expires_at or now + timedelta(days=365),
|
||||
is_active=primary_db_key.is_active,
|
||||
)
|
||||
self._keys[primary_db_key.kid] = key
|
||||
current_app.logger.info(f"[OIDC] Loaded existing signing key from database: kid={primary_db_key.kid}")
|
||||
return key
|
||||
|
||||
# Try to load all active keys from database
|
||||
loaded_count = self.load_keys_from_db()
|
||||
if loaded_count > 0:
|
||||
# Get the signing key from loaded keys
|
||||
signing_key = self.get_signing_key()
|
||||
if signing_key:
|
||||
current_app.logger.info(f"[OIDC] Loaded {loaded_count} keys from database, using signing key: kid={signing_key.kid}")
|
||||
return signing_key
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error loading keys from database: {e}")
|
||||
|
||||
# No keys in database, generate a new key and save it
|
||||
current_app.logger.info("[OIDC] No existing keys found in database, generating new signing key")
|
||||
new_key = self.generate_new_key_pair()
|
||||
|
||||
# Save the new key to database
|
||||
try:
|
||||
self.save_key_to_db(new_key, is_primary=True)
|
||||
current_app.logger.info(f"[OIDC] Saved new signing key to database: kid={new_key.kid}")
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error saving key to database: {e}")
|
||||
|
||||
return new_key
|
||||
|
||||
Reference in New Issue
Block a user