functional totp

This commit is contained in:
2026-01-14 18:06:26 +10:30
parent cfd79190ee
commit 5e4cffcf73
17 changed files with 1052 additions and 56 deletions
+6 -6
View File
@@ -221,6 +221,8 @@ def initialize_oidc_jwks(app):
"""Initialize OIDC JWKS service with a signing key.
This ensures that signing keys are available for token generation.
Keys are loaded from the database if available, otherwise a new key
is generated and persisted to the database.
Args:
app: Flask application instance
@@ -228,11 +230,9 @@ def initialize_oidc_jwks(app):
with app.app_context():
try:
jwks_service = OIDCJWKSService()
signing_key = jwks_service.get_signing_key()
if not signing_key:
signing_key = jwks_service.initialize_with_key()
app.logger.info(f"[OIDC] Generated new signing key: kid={signing_key.kid}")
else:
app.logger.info(f"[OIDC] Using existing signing key: kid={signing_key.kid}")
# Use initialize_with_key which handles loading from DB
# or generating a new key if none exists
signing_key = jwks_service.initialize_with_key()
app.logger.info(f"[OIDC] Signing key initialized: kid={signing_key.kid}")
except Exception as e:
app.logger.error(f"[OIDC] Failed to initialize JWKS: {e}")
+4 -4
View File
@@ -24,7 +24,7 @@ def setup_cors(app):
response = make_response("", 204)
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID, Cache-Control, Pragma"
response.headers["Access-Control-Max-Age"] = "3600"
response.headers["Cache-Control"] = "no-cache, no-store"
return response
@@ -32,7 +32,7 @@ def setup_cors(app):
response = make_response("", 204)
response.headers["Access-Control-Allow-Origin"] = origin
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID, Cache-Control, Pragma"
response.headers["Access-Control-Allow-Credentials"] = "true"
response.headers["Access-Control-Max-Age"] = "3600"
response.headers["Cache-Control"] = "no-cache, no-store"
@@ -51,13 +51,13 @@ def setup_cors(app):
# When allowing all origins, set header to "*"
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID, Cache-Control, Pragma"
response.headers["Access-Control-Max-Age"] = "3600"
elif origin and origin in cors_origins:
# When allowing specific origins, echo the request origin
response.headers["Access-Control-Allow-Origin"] = origin
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID, Cache-Control, Pragma"
response.headers["Access-Control-Allow-Credentials"] = "true"
response.headers["Access-Control-Max-Age"] = "3600"
+4 -4
View File
@@ -19,10 +19,10 @@ class AuthenticationMethod(BaseModel):
provider_user_id = db.Column(db.String(255), nullable=True)
provider_data = db.Column(db.JSON, nullable=True)
# # For TOTP authentication
# totp_secret = db.Column(db.String(32), nullable=True)
# totp_backup_codes = db.Column(db.JSON, nullable=True)
# totp_verified_at = db.Column(db.DateTime, nullable=True)
# For TOTP authentication
totp_secret = db.Column(db.String(32), nullable=True)
totp_backup_codes = db.Column(db.JSON, nullable=True)
totp_verified_at = db.Column(db.DateTime, nullable=True)
# Metadata
is_primary = db.Column(db.Boolean, default=False, nullable=False)
+77
View File
@@ -0,0 +1,77 @@
"""OIDC JWKS Key model for persisting signing keys."""
from datetime import datetime, timezone
from app.extensions import db
from app.models.base import BaseModel
class OidcJwksKey(BaseModel):
"""
OIDC JWKS Key model for persisting JSON Web Key Set signing keys.
This model stores RSA/ECDSA key pairs used for signing OIDC tokens.
Multiple keys can be stored to support key rotation scenarios.
Attributes:
id: Integer primary key
kid: Unique key ID used in JWT "kid" header
key_type: Type of key (e.g., "RSA", "EC")
private_key: PEM-encoded private key
public_key: PEM-encoded public key
algorithm: Signing algorithm (e.g., "RS256", "ES256")
created_at: When the key was created
is_active: Whether this key is currently active for signing
is_primary: Whether this is the primary signing key
expires_at: ...
"""
__tablename__ = "oidc_jwks_keys"
# Override the default UUID id with integer primary key
id = db.Column(db.Integer, primary_key=True)
expires_at = db.Column(db.DateTime, nullable=True)
# Key identification and type
kid = db.Column(db.String(255), unique=True, nullable=False, index=True)
key_type = db.Column(db.String(50), nullable=False) # e.g., "RSA", "EC"
algorithm = db.Column(db.String(50), nullable=False) # e.g., "RS256", "ES256"
# Key material (PEM-encoded)
private_key = db.Column(db.Text, nullable=False)
public_key = db.Column(db.Text, nullable=False)
# Key status
is_active = db.Column(db.Boolean, default=True, nullable=False)
is_primary = db.Column(db.Boolean, default=False, nullable=False)
def __repr__(self):
"""String representation of OidcJwksKey."""
return f"<OidcJwksKey kid={self.kid} key_type={self.key_type} algorithm={self.algorithm}>"
def to_dict(self, exclude_private_key=True):
"""
Convert model to dictionary.
Args:
exclude_private_key: If True, excludes the private key from output
Returns:
Dictionary representation of the model
"""
exclude = ["private_key"] if exclude_private_key else []
return super().to_dict(exclude=exclude)
@classmethod
def get_active_keys(cls):
"""Get all active keys for signing operations."""
return cls.query.filter(cls.is_active == True).all()
@classmethod
def get_primary_key(cls):
"""Get the primary signing key."""
return cls.query.filter(cls.is_primary == True).first()
@classmethod
def get_key_by_kid(cls, kid):
"""Get a key by its key ID."""
return cls.query.filter(cls.kid == kid, cls.is_active == True).first()
+12 -3
View File
@@ -21,7 +21,7 @@ class Session(BaseModel):
# Timing
expires_at = db.Column(db.DateTime, nullable=False)
last_activity_at = db.Column(db.DateTime, nullable=False, default=datetime.utcnow)
last_activity_at = db.Column(db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc))
revoked_at = db.Column(db.DateTime, nullable=True)
revoked_reason = db.Column(db.String(255), nullable=True)
@@ -35,15 +35,24 @@ class Session(BaseModel):
def is_active(self):
"""Check if session is currently active."""
now = datetime.now(timezone.utc)
# Make expires_at timezone-aware if it's naive
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return (
self.status == SessionStatus.ACTIVE
and self.expires_at > now
and expires_at > now
and self.deleted_at is None
)
def is_expired(self):
"""Check if session has expired."""
return datetime.now(timezone.utc) > self.expires_at
now = datetime.now(timezone.utc)
# Make expires_at timezone-aware if it's naive
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return now > expires_at
def refresh(self, duration_seconds=86400):
"""
+5 -1
View File
@@ -84,10 +84,14 @@ class User(BaseModel):
Returns:
The AuthenticationMethod instance for TOTP or None if not found.
Note:
Returns the most recently created TOTP method to handle cases where
multiple enrollment attempts may exist.
"""
from app.models.authentication_method import AuthenticationMethod
from app.utils.constants import AuthMethodType
return AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.TOTP, deleted_at=None
).first()
).order_by(AuthenticationMethod.created_at.desc()).first()
+18 -8
View File
@@ -1,7 +1,7 @@
"""Authentication service."""
import logging
import secrets
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from flask import request, g, current_app
from app.extensions import db, bcrypt
from app.models.user import User
@@ -131,9 +131,9 @@ class AuthService:
raise InvalidCredentialsError()
# Update last login
user.last_login_at = datetime.utcnow()
user.last_login_at = datetime.now(timezone.utc)
user.last_login_ip = request.remote_addr
auth_method.last_used_at = datetime.utcnow()
auth_method.last_used_at = datetime.now(timezone.utc)
db.session.commit()
return user
@@ -160,8 +160,8 @@ class AuthService:
status=SessionStatus.ACTIVE,
ip_address=request.remote_addr,
user_agent=request.headers.get("User-Agent"),
expires_at=datetime.utcnow() + timedelta(seconds=duration_seconds),
last_activity_at=datetime.utcnow(),
expires_at=datetime.now(timezone.utc) + timedelta(seconds=duration_seconds),
last_activity_at=datetime.now(timezone.utc),
)
session.save()
@@ -260,6 +260,14 @@ class AuthService:
if user.has_totp_enabled():
raise ConflictError("TOTP is already enabled for this account")
# Clean up any existing unverified TOTP enrollment attempts
# Use hard delete for unverified methods since they're incomplete enrollment attempts
existing_totp_method = user.get_totp_method()
if existing_totp_method and not existing_totp_method.verified:
logger.debug(f"Removing existing unverified TOTP method for user {user.id}")
db.session.delete(existing_totp_method) # Hard delete - unverified methods are temporary
db.session.commit() # Commit to ensure deletion before creating new record
# Generate TOTP secret
secret = TOTPService.generate_secret()
@@ -339,7 +347,7 @@ class AuthService:
# Mark TOTP as verified
auth_method.verified = True
auth_method.totp_verified_at = datetime.utcnow()
auth_method.totp_verified_at = datetime.now(timezone.utc)
db.session.commit()
# Log TOTP enrollment completion
@@ -436,8 +444,10 @@ class AuthService:
"secret": auth_method.provider_data.get("secret"),
"backup_codes": remaining_codes,
}
auth_method.last_used_at = datetime.utcnow()
auth_method.last_used_at = datetime.now(timezone.utc)
db.session.add(auth_method)
db.session.commit()
logger.debug(f"[BACKUP CODE] Updated provider_data: {auth_method.provider_data}")
# Log backup code usage
AuditService.log_action(
@@ -470,7 +480,7 @@ class AuthService:
is_valid = TOTPService.verify_code(secret, code)
if is_valid:
auth_method.last_used_at = datetime.utcnow()
auth_method.last_used_at = datetime.now(timezone.utc)
db.session.commit()
# Log successful verification
+2 -2
View File
@@ -1,5 +1,5 @@
"""OIDC Audit Service for comprehensive OIDC event logging."""
from datetime import datetime
from datetime import datetime, timezone
from typing import Dict, List, Optional
from flask import g
@@ -374,7 +374,7 @@ class OIDCAuditService:
"""
from datetime import timedelta
start_date = datetime.utcnow() - timedelta(days=days)
start_date = datetime.now(timezone.utc) - timedelta(days=days)
query = OIDCAuditLog.query.filter(
OIDCAuditLog.created_at >= start_date
+130 -12
View File
@@ -2,12 +2,13 @@
import uuid
import json
import hashlib
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, Tuple
from flask import current_app
from app.extensions import db
from app.models.oidc_jwks_key import OidcJwksKey
class JWKSKey:
@@ -20,8 +21,8 @@ class JWKSKey:
self.private_key = private_key
self.public_key = public_key
self.algorithm = algorithm
self.created_at = created_at or datetime.utcnow()
self.expires_at = expires_at or datetime.utcnow() + timedelta(days=365)
self.created_at = created_at or datetime.now(timezone.utc)
self.expires_at = expires_at or datetime.now(timezone.utc) + timedelta(days=365)
self.is_active = is_active
def to_jwk(self) -> Dict:
@@ -166,7 +167,7 @@ class OIDCJWKSService:
Returns:
JWKS document dictionary
"""
now = datetime.utcnow()
now = datetime.now(timezone.utc)
keys = []
for kid, key in self._keys.items():
@@ -181,14 +182,88 @@ class OIDCJWKSService:
"keys": keys
}
def load_keys_from_db(self) -> int:
"""Load existing keys from the database.
Returns:
Number of keys loaded
"""
try:
db_keys = OidcJwksKey.get_active_keys()
now = datetime.now(timezone.utc)
for db_key in db_keys:
# Create JWKSKey from database model
key = JWKSKey(
kid=db_key.kid,
private_key=db_key.private_key,
public_key=db_key.public_key,
algorithm=db_key.algorithm,
created_at=db_key.created_at,
expires_at=db_key.expires_at or now + timedelta(days=365),
is_active=db_key.is_active,
)
self._keys[db_key.kid] = key
return len(self._keys)
except Exception as e:
current_app.logger.error(f"Error loading keys from database: {e}")
return 0
def save_key_to_db(self, key: JWKSKey, is_primary: bool = False) -> OidcJwksKey:
"""Save a key to the database.
Args:
key: JWKSKey instance to save
is_primary: Whether this is the primary signing key
Returns:
OidcJwksKey database model instance
"""
db_key = OidcJwksKey(
kid=key.kid,
key_type="RSA",
algorithm=key.algorithm,
private_key=key.private_key,
public_key=key.public_key,
is_active=key.is_active,
is_primary=is_primary,
)
db.session.add(db_key)
db.session.commit()
return db_key
def get_signing_key(self) -> Optional[JWKSKey]:
"""Get the current active signing key.
Returns:
JWKSKey instance or None if no active key
"""
now = datetime.utcnow()
now = datetime.now(timezone.utc)
# First try to get the primary key from database
try:
primary_db_key = OidcJwksKey.get_primary_key()
if primary_db_key:
# Check if we have it in memory, if not load it
if primary_db_key.kid not in self._keys:
key = JWKSKey(
kid=primary_db_key.kid,
private_key=primary_db_key.private_key,
public_key=primary_db_key.public_key,
algorithm=primary_db_key.algorithm,
created_at=primary_db_key.created_at,
expires_at=primary_db_key.expires_at or now + timedelta(days=365),
is_active=primary_db_key.is_active,
)
self._keys[primary_db_key.kid] = key
return self._keys[primary_db_key.kid]
except Exception as e:
current_app.logger.error(f"Error getting primary key from database: {e}")
# Fall back to in-memory keys
for kid, key in self._keys.items():
if key.is_active and key.expires_at > now:
return key
@@ -218,7 +293,7 @@ class OIDCJWKSService:
private_key, public_key = self._generate_rsa_key_pair()
kid = self._generate_kid(private_key)
now = datetime.utcnow()
now = datetime.now(timezone.utc)
key = JWKSKey(
kid=kid,
private_key=private_key,
@@ -247,7 +322,7 @@ class OIDCJWKSService:
Returns:
Tuple of (new_key, list_of_deprecated_kids)
"""
now = datetime.utcnow()
now = datetime.now(timezone.utc)
grace_end = now + timedelta(hours=grace_period_hours)
# Mark current key as deprecated
@@ -286,15 +361,58 @@ class OIDCJWKSService:
if not key:
return False
now = datetime.utcnow()
now = datetime.now(timezone.utc)
return key.is_active and key.expires_at > now
def initialize_with_key(self) -> JWKSKey:
"""Initialize the service with a key if none exists.
"""Initialize the service with a key, loading from database if available.
This method first attempts to load existing keys from the database.
If no active primary key exists, it generates a new key and saves it to the database.
Returns:
JWKSKey instance
"""
if not self._keys:
return self.generate_new_key_pair()
return self.get_signing_key()
# First, try to load keys from database
try:
# Check if there's a primary key in the database
primary_db_key = OidcJwksKey.get_primary_key()
if primary_db_key:
# Load the primary key into memory
now = datetime.now(timezone.utc)
key = JWKSKey(
kid=primary_db_key.kid,
private_key=primary_db_key.private_key,
public_key=primary_db_key.public_key,
algorithm=primary_db_key.algorithm,
created_at=primary_db_key.created_at,
expires_at=primary_db_key.expires_at or now + timedelta(days=365),
is_active=primary_db_key.is_active,
)
self._keys[primary_db_key.kid] = key
current_app.logger.info(f"[OIDC] Loaded existing signing key from database: kid={primary_db_key.kid}")
return key
# Try to load all active keys from database
loaded_count = self.load_keys_from_db()
if loaded_count > 0:
# Get the signing key from loaded keys
signing_key = self.get_signing_key()
if signing_key:
current_app.logger.info(f"[OIDC] Loaded {loaded_count} keys from database, using signing key: kid={signing_key.kid}")
return signing_key
except Exception as e:
current_app.logger.error(f"Error loading keys from database: {e}")
# No keys in database, generate a new key and save it
current_app.logger.info("[OIDC] No existing keys found in database, generating new signing key")
new_key = self.generate_new_key_pair()
# Save the new key to database
try:
self.save_key_to_db(new_key, is_primary=True)
current_app.logger.info(f"[OIDC] Saved new signing key to database: kid={new_key.kid}")
except Exception as e:
current_app.logger.error(f"Error saving key to database: {e}")
return new_key
+4 -4
View File
@@ -1,6 +1,6 @@
"""Organization service."""
import logging
from datetime import datetime
from datetime import datetime, timezone
from flask import current_app
from app.extensions import db
from app.models.organization import Organization
@@ -53,7 +53,7 @@ class OrganizationService:
user_id=owner_user_id,
organization_id=org.id,
role=OrganizationRole.OWNER,
joined_at=datetime.utcnow(),
joined_at=datetime.now(timezone.utc),
)
member.save()
@@ -208,8 +208,8 @@ class OrganizationService:
organization_id=org.id,
role=role,
invited_by_id=inviter_id,
invited_at=datetime.utcnow(),
joined_at=datetime.utcnow(),
invited_at=datetime.now(timezone.utc),
joined_at=datetime.now(timezone.utc),
)
member.save()
+3 -3
View File
@@ -1,5 +1,5 @@
"""Session service."""
from datetime import datetime
from datetime import datetime, timezone
from app.models.session import Session
from app.utils.constants import SessionStatus
@@ -41,7 +41,7 @@ class SessionService:
if active_only:
query = query.filter_by(status=SessionStatus.ACTIVE).filter(
Session.expires_at > datetime.utcnow()
Session.expires_at > datetime.now(timezone.utc)
)
return query.all()
@@ -65,7 +65,7 @@ class SessionService:
"""Clean up expired sessions."""
expired_sessions = Session.query.filter(
Session.status == SessionStatus.ACTIVE,
Session.expires_at < datetime.utcnow(),
Session.expires_at < datetime.now(timezone.utc),
Session.deleted_at.is_(None),
).all()
+34 -8
View File
@@ -3,6 +3,7 @@ import base64
import io
import logging
import secrets
from datetime import datetime, timezone
from typing import Tuple
import pyotp
@@ -72,10 +73,34 @@ class TOTPService:
The window parameter allows for clock skew between the server
and the authenticator app. A window of 1 allows codes from
the previous, current, and next 30-second intervals.
IMPORTANT: Always uses UTC time for verification to ensure
consistency across all timezones.
"""
totp = pyotp.TOTP(secret)
is_valid = totp.verify(code, valid_window=window)
logger.debug(f"TOTP code verification: valid={is_valid}, window={window}")
# Use timezone-aware UTC datetime for verification
# IMPORTANT: We must pass a datetime object, NOT a Unix timestamp
# pyotp's internal datetime.utcfromtimestamp() is deprecated and can be
# affected by local timezone settings, causing the 10.5 hour skew issue
utc_now = datetime.now(timezone.utc)
# DEBUG: Log detailed timezone information
logger.debug(f"[TOTP DEBUG] UTC now: {utc_now}")
logger.debug(f"[TOTP DEBUG] UTC now isoformat: {utc_now.isoformat()}")
logger.debug(f"[TOTP DEBUG] UTC timestamp: {utc_now.timestamp()}")
logger.debug(f"[TOTP DEBUG] UTC now tzinfo: {utc_now.tzinfo}")
# Generate what the TOTP code should be at this moment using UTC datetime
expected_code = totp.at(utc_now)
logger.debug(f"[TOTP DEBUG] Expected TOTP code at UTC: {expected_code}")
# Verify with the provided code using UTC datetime object
# Passing a datetime object avoids pyotp's utcfromtimestamp() issues
is_valid = totp.verify(code, valid_window=window, for_time=utc_now)
logger.debug(f"[TOTP DEBUG] TOTP code verification: valid={is_valid}, window={window}")
logger.debug(f"[TOTP DEBUG] Provided code: {code}, Expected code: {expected_code}")
return is_valid
@staticmethod
@@ -133,15 +158,16 @@ class TOTPService:
for hashed_code in hashed_codes:
if bcrypt.check_password_hash(hashed_code, code):
# Code found and valid - don't add to remaining codes (consumed)
logger.debug("Backup code verified and consumed")
return True, remaining_codes
# Code found and valid - mark as matched but don't add to remaining codes
matched = True
else:
# Code doesn't match - keep it in remaining codes
remaining_codes.append(hashed_code)
logger.debug("Backup code verification failed")
return False, remaining_codes
if matched:
return True, remaining_codes
else:
return False, remaining_codes
@staticmethod
def generate_qr_code_data_uri(provisioning_uri: str) -> str:
@@ -185,4 +211,4 @@ class TOTPService:
except ImportError:
logger.warning("qrcode library not installed, returning placeholder")
return "QR code generation requires the qrcode library. Install with: pip install qrcode[pil]"
return "QR code generation requires the qrcode library. Install with: pip install qrcode[pil]"