feat(auth): implement TOTP two-factor authentication with enrollment and verification

Adds TOTP (Time-based One-Time Password) two-factor authentication support including:
- New TOTP service with secret generation, QR code provisioning, and code verification
- New auth endpoints for enrollment, verification, status, and backup code management
- New TOTP authentication method type and user methods for TOTP management
- Backup codes generation and verification for account recovery
- Updated OIDC endpoints with timezone-aware datetime handling and RFC-compliant responses
- Added "roles" scope support for OIDC userinfo and ID tokens
- New pyotp dependency for TOTP operations
- Comprehensive unit tests for TOTP service
This commit is contained in:
2026-01-14 18:06:17 +10:30
parent 977abf66df
commit cfd79190ee
26 changed files with 2176 additions and 263 deletions
+12 -1
View File
@@ -19,6 +19,11 @@ class AuthenticationMethod(BaseModel):
provider_user_id = db.Column(db.String(255), nullable=True)
provider_data = db.Column(db.JSON, nullable=True)
# # For TOTP authentication
# totp_secret = db.Column(db.String(32), nullable=True)
# totp_backup_codes = db.Column(db.JSON, nullable=True)
# totp_verified_at = db.Column(db.DateTime, nullable=True)
# Metadata
is_primary = db.Column(db.Boolean, default=False, nullable=False)
verified = db.Column(db.Boolean, default=False, nullable=False)
@@ -51,9 +56,15 @@ class AuthenticationMethod(BaseModel):
AuthMethodType.MICROSOFT,
]
def is_totp(self):
"""Check if this is a TOTP authentication method."""
return self.method_type == AuthMethodType.TOTP
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
# Always exclude password hash
# Always exclude password hash and TOTP secrets
exclude.append("password_hash")
exclude.append("totp_secret")
exclude.append("totp_backup_codes")
return super().to_dict(exclude=exclude)
+5 -5
View File
@@ -1,6 +1,6 @@
"""Base model with common fields and functionality."""
import uuid
from datetime import datetime
from datetime import datetime, timezone
from app.extensions import db
@@ -16,9 +16,9 @@ class BaseModel(db.Model):
unique=True,
nullable=False,
)
created_at = db.Column(db.DateTime, nullable=False, default=datetime.utcnow)
created_at = db.Column(db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc))
updated_at = db.Column(
db.DateTime, nullable=False, default=datetime.utcnow, onupdate=datetime.utcnow
db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)
)
deleted_at = db.Column(db.DateTime, nullable=True)
@@ -36,7 +36,7 @@ class BaseModel(db.Model):
soft: If True, performs soft delete. If False, hard delete.
"""
if soft:
self.deleted_at = datetime.utcnow()
self.deleted_at = datetime.now(timezone.utc)
db.session.commit()
else:
db.session.delete(self)
@@ -47,7 +47,7 @@ class BaseModel(db.Model):
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
self.updated_at = datetime.utcnow()
self.updated_at = datetime.now(timezone.utc)
db.session.commit()
return self
+9 -4
View File
@@ -1,5 +1,5 @@
"""OIDC Authorization Code model for auth code flow."""
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from app.extensions import db
from app.models.base import BaseModel
@@ -49,7 +49,12 @@ class OIDCAuthCode(BaseModel):
def is_expired(self):
"""Check if the authorization code has expired."""
return datetime.utcnow() > self.expires_at
# Handle both timezone-aware and timezone-naive expires_at values
expires_at = self.expires_at
if expires_at.tzinfo is None:
# Make naive datetime timezone-aware (UTC)
expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at
def is_valid(self):
"""Check if the authorization code is valid for use."""
@@ -58,7 +63,7 @@ class OIDCAuthCode(BaseModel):
def mark_as_used(self):
"""Mark the authorization code as used."""
self.is_used = True
self.used_at = datetime.utcnow()
self.used_at = datetime.now(timezone.utc)
db.session.commit()
@classmethod
@@ -90,7 +95,7 @@ class OIDCAuthCode(BaseModel):
scope=scope,
nonce=nonce,
code_verifier=code_verifier,
expires_at=datetime.utcnow() + timedelta(seconds=lifetime_seconds),
expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds),
ip_address=ip_address,
user_agent=user_agent,
)
+9 -5
View File
@@ -1,5 +1,5 @@
"""OIDC Refresh Token model for token rotation."""
from datetime import datetime
from datetime import datetime, timezone
from app.extensions import db
from app.models.base import BaseModel
@@ -58,7 +58,11 @@ class OIDCRefreshToken(BaseModel):
def is_expired(self):
"""Check if the refresh token has expired."""
return datetime.utcnow() > self.expires_at
# Handle both timezone-aware and timezone-naive expires_at values
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at
def is_revoked(self):
"""Check if the refresh token has been revoked."""
@@ -74,7 +78,7 @@ class OIDCRefreshToken(BaseModel):
Args:
reason: Optional reason for revocation
"""
self.revoked_at = datetime.utcnow()
self.revoked_at = datetime.now(timezone.utc)
self.revoked_reason = reason
db.session.commit()
@@ -93,7 +97,7 @@ class OIDCRefreshToken(BaseModel):
self.rotation_count += 1
# Extend expiration on rotation
from datetime import timedelta
self.expires_at = datetime.utcnow() + timedelta(days=30)
self.expires_at = datetime.now(timezone.utc) + timedelta(days=30)
db.session.commit()
return self
@@ -123,7 +127,7 @@ class OIDCRefreshToken(BaseModel):
token_hash=token_hash,
scope=scope,
access_token_id=access_token_id,
expires_at=datetime.utcnow() + timedelta(seconds=lifetime_seconds),
expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds),
ip_address=ip_address,
user_agent=user_agent,
)
+4 -4
View File
@@ -1,5 +1,5 @@
"""OIDC Session model for OIDC session tracking."""
from datetime import datetime
from datetime import datetime, timezone
from app.extensions import db
from app.models.base import BaseModel
@@ -49,7 +49,7 @@ class OIDCSession(BaseModel):
def is_expired(self):
"""Check if the OIDC session has expired."""
return datetime.utcnow() > self.expires_at
return datetime.now(timezone.utc) > self.expires_at
def is_authenticated(self):
"""Check if the user has been authenticated in this session."""
@@ -57,7 +57,7 @@ class OIDCSession(BaseModel):
def mark_authenticated(self):
"""Mark the session as authenticated."""
self.authenticated_at = datetime.utcnow()
self.authenticated_at = datetime.now(timezone.utc)
db.session.commit()
def validate_nonce(self, expected_nonce):
@@ -126,7 +126,7 @@ class OIDCSession(BaseModel):
nonce=nonce,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
expires_at=datetime.utcnow() + timedelta(seconds=lifetime_seconds),
expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds),
)
db.session.add(session)
db.session.commit()
+7 -3
View File
@@ -1,6 +1,6 @@
"""OIDC Token Metadata model for token revocation tracking."""
import uuid
from datetime import datetime
from datetime import datetime, timezone
from app.extensions import db
from app.models.base import BaseModel
@@ -50,7 +50,11 @@ class OIDCTokenMetadata(BaseModel):
def is_expired(self):
"""Check if the token has expired."""
return datetime.utcnow() > self.expires_at
# Handle both timezone-aware and timezone-naive expires_at values
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at
def is_revoked(self):
"""Check if the token has been revoked."""
@@ -66,7 +70,7 @@ class OIDCTokenMetadata(BaseModel):
Args:
reason: Optional reason for revocation
"""
self.revoked_at = datetime.utcnow()
self.revoked_at = datetime.now(timezone.utc)
self.revoked_reason = reason
db.session.commit()
+6 -6
View File
@@ -1,5 +1,5 @@
"""Session model."""
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from app.extensions import db
from app.models.base import BaseModel
from app.utils.constants import SessionStatus
@@ -34,7 +34,7 @@ class Session(BaseModel):
def is_active(self):
"""Check if session is currently active."""
now = datetime.utcnow()
now = datetime.now(timezone.utc)
return (
self.status == SessionStatus.ACTIVE
and self.expires_at > now
@@ -43,7 +43,7 @@ class Session(BaseModel):
def is_expired(self):
"""Check if session has expired."""
return datetime.utcnow() > self.expires_at
return datetime.now(timezone.utc) > self.expires_at
def refresh(self, duration_seconds=86400):
"""
@@ -52,8 +52,8 @@ class Session(BaseModel):
Args:
duration_seconds: New session duration in seconds
"""
self.expires_at = datetime.utcnow() + timedelta(seconds=duration_seconds)
self.last_activity_at = datetime.utcnow()
self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=duration_seconds)
self.last_activity_at = datetime.now(timezone.utc)
db.session.commit()
def revoke(self, reason=None):
@@ -64,7 +64,7 @@ class Session(BaseModel):
reason: Optional reason for revocation
"""
self.status = SessionStatus.REVOKED
self.revoked_at = datetime.utcnow()
self.revoked_at = datetime.now(timezone.utc)
if reason:
self.revoked_reason = reason
db.session.commit()
+32
View File
@@ -59,3 +59,35 @@ class User(BaseModel):
def get_organizations(self):
"""Get all organizations the user is a member of."""
return [membership.organization for membership in self.organization_memberships]
def has_totp_enabled(self) -> bool:
"""Check if user has TOTP enabled and verified.
Returns:
True if user has a verified TOTP authentication method, False otherwise.
"""
from app.models.authentication_method import AuthenticationMethod
from app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id,
method_type=AuthMethodType.TOTP,
verified=True,
deleted_at=None,
).first()
is not None
)
def get_totp_method(self):
"""Get user's TOTP authentication method.
Returns:
The AuthenticationMethod instance for TOTP or None if not found.
"""
from app.models.authentication_method import AuthenticationMethod
from app.utils.constants import AuthMethodType
return AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.TOTP, deleted_at=None
).first()