functional totp

This commit is contained in:
2026-01-14 18:06:26 +10:30
parent cfd79190ee
commit 5e4cffcf73
17 changed files with 1052 additions and 56 deletions
+18 -8
View File
@@ -1,7 +1,7 @@
"""Authentication service."""
import logging
import secrets
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from flask import request, g, current_app
from app.extensions import db, bcrypt
from app.models.user import User
@@ -131,9 +131,9 @@ class AuthService:
raise InvalidCredentialsError()
# Update last login
user.last_login_at = datetime.utcnow()
user.last_login_at = datetime.now(timezone.utc)
user.last_login_ip = request.remote_addr
auth_method.last_used_at = datetime.utcnow()
auth_method.last_used_at = datetime.now(timezone.utc)
db.session.commit()
return user
@@ -160,8 +160,8 @@ class AuthService:
status=SessionStatus.ACTIVE,
ip_address=request.remote_addr,
user_agent=request.headers.get("User-Agent"),
expires_at=datetime.utcnow() + timedelta(seconds=duration_seconds),
last_activity_at=datetime.utcnow(),
expires_at=datetime.now(timezone.utc) + timedelta(seconds=duration_seconds),
last_activity_at=datetime.now(timezone.utc),
)
session.save()
@@ -260,6 +260,14 @@ class AuthService:
if user.has_totp_enabled():
raise ConflictError("TOTP is already enabled for this account")
# Clean up any existing unverified TOTP enrollment attempts
# Use hard delete for unverified methods since they're incomplete enrollment attempts
existing_totp_method = user.get_totp_method()
if existing_totp_method and not existing_totp_method.verified:
logger.debug(f"Removing existing unverified TOTP method for user {user.id}")
db.session.delete(existing_totp_method) # Hard delete - unverified methods are temporary
db.session.commit() # Commit to ensure deletion before creating new record
# Generate TOTP secret
secret = TOTPService.generate_secret()
@@ -339,7 +347,7 @@ class AuthService:
# Mark TOTP as verified
auth_method.verified = True
auth_method.totp_verified_at = datetime.utcnow()
auth_method.totp_verified_at = datetime.now(timezone.utc)
db.session.commit()
# Log TOTP enrollment completion
@@ -436,8 +444,10 @@ class AuthService:
"secret": auth_method.provider_data.get("secret"),
"backup_codes": remaining_codes,
}
auth_method.last_used_at = datetime.utcnow()
auth_method.last_used_at = datetime.now(timezone.utc)
db.session.add(auth_method)
db.session.commit()
logger.debug(f"[BACKUP CODE] Updated provider_data: {auth_method.provider_data}")
# Log backup code usage
AuditService.log_action(
@@ -470,7 +480,7 @@ class AuthService:
is_valid = TOTPService.verify_code(secret, code)
if is_valid:
auth_method.last_used_at = datetime.utcnow()
auth_method.last_used_at = datetime.now(timezone.utc)
db.session.commit()
# Log successful verification
+2 -2
View File
@@ -1,5 +1,5 @@
"""OIDC Audit Service for comprehensive OIDC event logging."""
from datetime import datetime
from datetime import datetime, timezone
from typing import Dict, List, Optional
from flask import g
@@ -374,7 +374,7 @@ class OIDCAuditService:
"""
from datetime import timedelta
start_date = datetime.utcnow() - timedelta(days=days)
start_date = datetime.now(timezone.utc) - timedelta(days=days)
query = OIDCAuditLog.query.filter(
OIDCAuditLog.created_at >= start_date
+130 -12
View File
@@ -2,12 +2,13 @@
import uuid
import json
import hashlib
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, Tuple
from flask import current_app
from app.extensions import db
from app.models.oidc_jwks_key import OidcJwksKey
class JWKSKey:
@@ -20,8 +21,8 @@ class JWKSKey:
self.private_key = private_key
self.public_key = public_key
self.algorithm = algorithm
self.created_at = created_at or datetime.utcnow()
self.expires_at = expires_at or datetime.utcnow() + timedelta(days=365)
self.created_at = created_at or datetime.now(timezone.utc)
self.expires_at = expires_at or datetime.now(timezone.utc) + timedelta(days=365)
self.is_active = is_active
def to_jwk(self) -> Dict:
@@ -166,7 +167,7 @@ class OIDCJWKSService:
Returns:
JWKS document dictionary
"""
now = datetime.utcnow()
now = datetime.now(timezone.utc)
keys = []
for kid, key in self._keys.items():
@@ -181,14 +182,88 @@ class OIDCJWKSService:
"keys": keys
}
def load_keys_from_db(self) -> int:
"""Load existing keys from the database.
Returns:
Number of keys loaded
"""
try:
db_keys = OidcJwksKey.get_active_keys()
now = datetime.now(timezone.utc)
for db_key in db_keys:
# Create JWKSKey from database model
key = JWKSKey(
kid=db_key.kid,
private_key=db_key.private_key,
public_key=db_key.public_key,
algorithm=db_key.algorithm,
created_at=db_key.created_at,
expires_at=db_key.expires_at or now + timedelta(days=365),
is_active=db_key.is_active,
)
self._keys[db_key.kid] = key
return len(self._keys)
except Exception as e:
current_app.logger.error(f"Error loading keys from database: {e}")
return 0
def save_key_to_db(self, key: JWKSKey, is_primary: bool = False) -> OidcJwksKey:
"""Save a key to the database.
Args:
key: JWKSKey instance to save
is_primary: Whether this is the primary signing key
Returns:
OidcJwksKey database model instance
"""
db_key = OidcJwksKey(
kid=key.kid,
key_type="RSA",
algorithm=key.algorithm,
private_key=key.private_key,
public_key=key.public_key,
is_active=key.is_active,
is_primary=is_primary,
)
db.session.add(db_key)
db.session.commit()
return db_key
def get_signing_key(self) -> Optional[JWKSKey]:
"""Get the current active signing key.
Returns:
JWKSKey instance or None if no active key
"""
now = datetime.utcnow()
now = datetime.now(timezone.utc)
# First try to get the primary key from database
try:
primary_db_key = OidcJwksKey.get_primary_key()
if primary_db_key:
# Check if we have it in memory, if not load it
if primary_db_key.kid not in self._keys:
key = JWKSKey(
kid=primary_db_key.kid,
private_key=primary_db_key.private_key,
public_key=primary_db_key.public_key,
algorithm=primary_db_key.algorithm,
created_at=primary_db_key.created_at,
expires_at=primary_db_key.expires_at or now + timedelta(days=365),
is_active=primary_db_key.is_active,
)
self._keys[primary_db_key.kid] = key
return self._keys[primary_db_key.kid]
except Exception as e:
current_app.logger.error(f"Error getting primary key from database: {e}")
# Fall back to in-memory keys
for kid, key in self._keys.items():
if key.is_active and key.expires_at > now:
return key
@@ -218,7 +293,7 @@ class OIDCJWKSService:
private_key, public_key = self._generate_rsa_key_pair()
kid = self._generate_kid(private_key)
now = datetime.utcnow()
now = datetime.now(timezone.utc)
key = JWKSKey(
kid=kid,
private_key=private_key,
@@ -247,7 +322,7 @@ class OIDCJWKSService:
Returns:
Tuple of (new_key, list_of_deprecated_kids)
"""
now = datetime.utcnow()
now = datetime.now(timezone.utc)
grace_end = now + timedelta(hours=grace_period_hours)
# Mark current key as deprecated
@@ -286,15 +361,58 @@ class OIDCJWKSService:
if not key:
return False
now = datetime.utcnow()
now = datetime.now(timezone.utc)
return key.is_active and key.expires_at > now
def initialize_with_key(self) -> JWKSKey:
"""Initialize the service with a key if none exists.
"""Initialize the service with a key, loading from database if available.
This method first attempts to load existing keys from the database.
If no active primary key exists, it generates a new key and saves it to the database.
Returns:
JWKSKey instance
"""
if not self._keys:
return self.generate_new_key_pair()
return self.get_signing_key()
# First, try to load keys from database
try:
# Check if there's a primary key in the database
primary_db_key = OidcJwksKey.get_primary_key()
if primary_db_key:
# Load the primary key into memory
now = datetime.now(timezone.utc)
key = JWKSKey(
kid=primary_db_key.kid,
private_key=primary_db_key.private_key,
public_key=primary_db_key.public_key,
algorithm=primary_db_key.algorithm,
created_at=primary_db_key.created_at,
expires_at=primary_db_key.expires_at or now + timedelta(days=365),
is_active=primary_db_key.is_active,
)
self._keys[primary_db_key.kid] = key
current_app.logger.info(f"[OIDC] Loaded existing signing key from database: kid={primary_db_key.kid}")
return key
# Try to load all active keys from database
loaded_count = self.load_keys_from_db()
if loaded_count > 0:
# Get the signing key from loaded keys
signing_key = self.get_signing_key()
if signing_key:
current_app.logger.info(f"[OIDC] Loaded {loaded_count} keys from database, using signing key: kid={signing_key.kid}")
return signing_key
except Exception as e:
current_app.logger.error(f"Error loading keys from database: {e}")
# No keys in database, generate a new key and save it
current_app.logger.info("[OIDC] No existing keys found in database, generating new signing key")
new_key = self.generate_new_key_pair()
# Save the new key to database
try:
self.save_key_to_db(new_key, is_primary=True)
current_app.logger.info(f"[OIDC] Saved new signing key to database: kid={new_key.kid}")
except Exception as e:
current_app.logger.error(f"Error saving key to database: {e}")
return new_key
+4 -4
View File
@@ -1,6 +1,6 @@
"""Organization service."""
import logging
from datetime import datetime
from datetime import datetime, timezone
from flask import current_app
from app.extensions import db
from app.models.organization import Organization
@@ -53,7 +53,7 @@ class OrganizationService:
user_id=owner_user_id,
organization_id=org.id,
role=OrganizationRole.OWNER,
joined_at=datetime.utcnow(),
joined_at=datetime.now(timezone.utc),
)
member.save()
@@ -208,8 +208,8 @@ class OrganizationService:
organization_id=org.id,
role=role,
invited_by_id=inviter_id,
invited_at=datetime.utcnow(),
joined_at=datetime.utcnow(),
invited_at=datetime.now(timezone.utc),
joined_at=datetime.now(timezone.utc),
)
member.save()
+3 -3
View File
@@ -1,5 +1,5 @@
"""Session service."""
from datetime import datetime
from datetime import datetime, timezone
from app.models.session import Session
from app.utils.constants import SessionStatus
@@ -41,7 +41,7 @@ class SessionService:
if active_only:
query = query.filter_by(status=SessionStatus.ACTIVE).filter(
Session.expires_at > datetime.utcnow()
Session.expires_at > datetime.now(timezone.utc)
)
return query.all()
@@ -65,7 +65,7 @@ class SessionService:
"""Clean up expired sessions."""
expired_sessions = Session.query.filter(
Session.status == SessionStatus.ACTIVE,
Session.expires_at < datetime.utcnow(),
Session.expires_at < datetime.now(timezone.utc),
Session.deleted_at.is_(None),
).all()
+34 -8
View File
@@ -3,6 +3,7 @@ import base64
import io
import logging
import secrets
from datetime import datetime, timezone
from typing import Tuple
import pyotp
@@ -72,10 +73,34 @@ class TOTPService:
The window parameter allows for clock skew between the server
and the authenticator app. A window of 1 allows codes from
the previous, current, and next 30-second intervals.
IMPORTANT: Always uses UTC time for verification to ensure
consistency across all timezones.
"""
totp = pyotp.TOTP(secret)
is_valid = totp.verify(code, valid_window=window)
logger.debug(f"TOTP code verification: valid={is_valid}, window={window}")
# Use timezone-aware UTC datetime for verification
# IMPORTANT: We must pass a datetime object, NOT a Unix timestamp
# pyotp's internal datetime.utcfromtimestamp() is deprecated and can be
# affected by local timezone settings, causing the 10.5 hour skew issue
utc_now = datetime.now(timezone.utc)
# DEBUG: Log detailed timezone information
logger.debug(f"[TOTP DEBUG] UTC now: {utc_now}")
logger.debug(f"[TOTP DEBUG] UTC now isoformat: {utc_now.isoformat()}")
logger.debug(f"[TOTP DEBUG] UTC timestamp: {utc_now.timestamp()}")
logger.debug(f"[TOTP DEBUG] UTC now tzinfo: {utc_now.tzinfo}")
# Generate what the TOTP code should be at this moment using UTC datetime
expected_code = totp.at(utc_now)
logger.debug(f"[TOTP DEBUG] Expected TOTP code at UTC: {expected_code}")
# Verify with the provided code using UTC datetime object
# Passing a datetime object avoids pyotp's utcfromtimestamp() issues
is_valid = totp.verify(code, valid_window=window, for_time=utc_now)
logger.debug(f"[TOTP DEBUG] TOTP code verification: valid={is_valid}, window={window}")
logger.debug(f"[TOTP DEBUG] Provided code: {code}, Expected code: {expected_code}")
return is_valid
@staticmethod
@@ -133,15 +158,16 @@ class TOTPService:
for hashed_code in hashed_codes:
if bcrypt.check_password_hash(hashed_code, code):
# Code found and valid - don't add to remaining codes (consumed)
logger.debug("Backup code verified and consumed")
return True, remaining_codes
# Code found and valid - mark as matched but don't add to remaining codes
matched = True
else:
# Code doesn't match - keep it in remaining codes
remaining_codes.append(hashed_code)
logger.debug("Backup code verification failed")
return False, remaining_codes
if matched:
return True, remaining_codes
else:
return False, remaining_codes
@staticmethod
def generate_qr_code_data_uri(provisioning_uri: str) -> str:
@@ -185,4 +211,4 @@ class TOTPService:
except ImportError:
logger.warning("qrcode library not installed, returning placeholder")
return "QR code generation requires the qrcode library. Install with: pip install qrcode[pil]"
return "QR code generation requires the qrcode library. Install with: pip install qrcode[pil]"