diff --git a/TOTP_TEST_PROPOSAL.md b/TOTP_TEST_PROPOSAL.md new file mode 100644 index 0000000..9f7db18 --- /dev/null +++ b/TOTP_TEST_PROPOSAL.md @@ -0,0 +1,171 @@ +# TOTP End-to-End Test Proposal + +## Test Objective +Test ALL aspects of TOTP functionality regardless of current state (TOTP enabled or disabled). + +## Test Flow + +### Scenario A: TOTP Currently Enabled (Bob already enrolled) + +1. **Login** with email/password + - Response: `requires_totp: true` + +2. **Get Secret from DB** (or use environment variable) + - Since secret is encrypted/hashed in DB, we need to either: + - Store it in environment/file from previous enrollment, OR + - User provides it as input, OR + - Use backup code from previous enrollment + +3. **Generate TOTP Code** using stored secret/backup code + +4. **Verify TOTP** to complete login + - Endpoint: `/auth/totp/verify` + - Get auth_token + +5. **Check TOTP Status** + - Endpoint: `/auth/totp/status` + - Confirm: `totp_enabled: true` + +6. **Disable TOTP** + - Endpoint: `/auth/totp/disable` + - Provide password + +7. **Logout** + +8. **Continue to Scenario B steps 2-14** + +### Scenario B: TOTP Currently Disabled (or after completing Scenario A) + +1. **Login** with email/password + - Response: `token` (no TOTP required) + +2. **Check TOTP Status** + - Endpoint: `/auth/totp/status` + - Confirm: `totp_enabled: false` + +3. **Enroll in TOTP** + - Endpoint: `/auth/totp/enroll` + - Store: secret, backup_codes, provisioning_uri, qr_code + +4. **Generate TOTP Code** from new secret + - Use timezone-aware UTC + +5. **Verify Enrollment** + - Endpoint: `/auth/totp/verify-enrollment` + - Provide generated code + +6. **Check TOTP Status Again** + - Confirm: `totp_enabled: true` + - Confirm: `backup_codes_remaining: 10` + - Confirm: `verified_at` is set + +7. **Logout** + +8. **Login** with email/password + - Response: `requires_totp: true` + +9. **Generate TOTP Code** from stored secret + +10. **Verify TOTP** to complete login + - Endpoint: `/auth/totp/verify` + - Get auth_token + +11. **Confirm Logged In** + - Endpoint: `/auth/me` + - Verify user data returned + +12. **Test Backup Code** (new login) + - Logout + - Login with email/password + - Use backup code instead of TOTP + - Endpoint: `/auth/totp/verify` with `is_backup_code: true` + +13. **Check Backup Codes Remaining** + - Should be 9 (one consumed) + +14. **Regenerate Backup Codes** + - Endpoint: `/auth/totp/regenerate-backup-codes` + - Provide password + - Get new set of 10 codes + +## Implementation Strategy + +### Secret Persistence Between Test Runs + +**Option 1: Environment Variable** (Recommended) +```python +import os + +# Save secret after first successful enrollment +SECRET_FILE = ".totp_test_secret" + +if os.path.exists(SECRET_FILE): + with open(SECRET_FILE) as f: + data = json.load(f) + known_secret = data.get("secret") + known_backup_codes = data.get("backup_codes", []) +else: + known_secret = None + known_backup_codes = [] + +# After enrollment, save for next run +with open(SECRET_FILE, 'w') as f: + json.dump({ + "secret": new_secret, + "backup_codes": new_backup_codes + }, f) +``` + +**Option 2: Test Database State** +- Include SQL query to fetch secret from DB (if stored in plain text for testing) +- Or decrypt if encrypted + +**Option 3: Manual Input** +- Prompt user for secret/backup code if TOTP already enabled +- Less automated but more flexible + +## Expected Assertions + +1. ✅ Login without TOTP works when disabled +2. ✅ Enrollment generates secret, QR code, backup codes +3. ✅ Enrollment verification accepts valid TOTP code +4. ✅ TOTP status shows enabled after verification +5. ✅ Login requires TOTP when enabled +6. ✅ TOTP verification works during login +7. ✅ Backup code works for authentication +8. ✅ Backup codes decrement when used +9. ✅ Backup code regeneration works +10. ✅ TOTP disable works with correct password +11. ✅ Login works without TOTP after disabling + +## Test Data Management + +Store in `.totp_test_data.json` (gitignored): +```json +{ + "user": "bob@acme-corp.com", + "secret": "BWAQAP55...", + "backup_codes": ["code1", "code2", ...], + "enrollment_date": "2026-01-14T03:12:00Z", + "last_test_run": "2026-01-14T03:15:00Z" +} +``` + +## Error Handling + +- Connection errors → clear message about server not running +- 401 errors → check if token/credentials are correct +- TOTP code failures → check time synchronization +- Backup code failures → check if already used + +## Success Criteria + +Test passes when: +1. All 14 steps complete without errors +2. All assertions pass +3. Test can run multiple times (idempotent) +4. Works from both initial states (TOTP enabled/disabled) + +--- + +**Please review this proposal. Once approved, I'll implement it.** diff --git a/app/__init__.py b/app/__init__.py index 1362484..a8edd27 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -221,6 +221,8 @@ def initialize_oidc_jwks(app): """Initialize OIDC JWKS service with a signing key. This ensures that signing keys are available for token generation. + Keys are loaded from the database if available, otherwise a new key + is generated and persisted to the database. Args: app: Flask application instance @@ -228,11 +230,9 @@ def initialize_oidc_jwks(app): with app.app_context(): try: jwks_service = OIDCJWKSService() - signing_key = jwks_service.get_signing_key() - if not signing_key: - signing_key = jwks_service.initialize_with_key() - app.logger.info(f"[OIDC] Generated new signing key: kid={signing_key.kid}") - else: - app.logger.info(f"[OIDC] Using existing signing key: kid={signing_key.kid}") + # Use initialize_with_key which handles loading from DB + # or generating a new key if none exists + signing_key = jwks_service.initialize_with_key() + app.logger.info(f"[OIDC] Signing key initialized: kid={signing_key.kid}") except Exception as e: app.logger.error(f"[OIDC] Failed to initialize JWKS: {e}") diff --git a/app/middleware/cors.py b/app/middleware/cors.py index 7898d60..4c088a8 100644 --- a/app/middleware/cors.py +++ b/app/middleware/cors.py @@ -24,7 +24,7 @@ def setup_cors(app): response = make_response("", 204) response.headers["Access-Control-Allow-Origin"] = "*" response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS" - response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID" + response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID, Cache-Control, Pragma" response.headers["Access-Control-Max-Age"] = "3600" response.headers["Cache-Control"] = "no-cache, no-store" return response @@ -32,7 +32,7 @@ def setup_cors(app): response = make_response("", 204) response.headers["Access-Control-Allow-Origin"] = origin response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS" - response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID" + response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID, Cache-Control, Pragma" response.headers["Access-Control-Allow-Credentials"] = "true" response.headers["Access-Control-Max-Age"] = "3600" response.headers["Cache-Control"] = "no-cache, no-store" @@ -51,13 +51,13 @@ def setup_cors(app): # When allowing all origins, set header to "*" response.headers["Access-Control-Allow-Origin"] = "*" response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS" - response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID" + response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID, Cache-Control, Pragma" response.headers["Access-Control-Max-Age"] = "3600" elif origin and origin in cors_origins: # When allowing specific origins, echo the request origin response.headers["Access-Control-Allow-Origin"] = origin response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS" - response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID" + response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID, Cache-Control, Pragma" response.headers["Access-Control-Allow-Credentials"] = "true" response.headers["Access-Control-Max-Age"] = "3600" diff --git a/app/models/authentication_method.py b/app/models/authentication_method.py index 752356c..1843658 100644 --- a/app/models/authentication_method.py +++ b/app/models/authentication_method.py @@ -19,10 +19,10 @@ class AuthenticationMethod(BaseModel): provider_user_id = db.Column(db.String(255), nullable=True) provider_data = db.Column(db.JSON, nullable=True) - # # For TOTP authentication - # totp_secret = db.Column(db.String(32), nullable=True) - # totp_backup_codes = db.Column(db.JSON, nullable=True) - # totp_verified_at = db.Column(db.DateTime, nullable=True) + # For TOTP authentication + totp_secret = db.Column(db.String(32), nullable=True) + totp_backup_codes = db.Column(db.JSON, nullable=True) + totp_verified_at = db.Column(db.DateTime, nullable=True) # Metadata is_primary = db.Column(db.Boolean, default=False, nullable=False) diff --git a/app/models/oidc_jwks_key.py b/app/models/oidc_jwks_key.py new file mode 100644 index 0000000..1d563f6 --- /dev/null +++ b/app/models/oidc_jwks_key.py @@ -0,0 +1,77 @@ +"""OIDC JWKS Key model for persisting signing keys.""" +from datetime import datetime, timezone +from app.extensions import db +from app.models.base import BaseModel + + +class OidcJwksKey(BaseModel): + """ + OIDC JWKS Key model for persisting JSON Web Key Set signing keys. + + This model stores RSA/ECDSA key pairs used for signing OIDC tokens. + Multiple keys can be stored to support key rotation scenarios. + + Attributes: + id: Integer primary key + kid: Unique key ID used in JWT "kid" header + key_type: Type of key (e.g., "RSA", "EC") + private_key: PEM-encoded private key + public_key: PEM-encoded public key + algorithm: Signing algorithm (e.g., "RS256", "ES256") + created_at: When the key was created + is_active: Whether this key is currently active for signing + is_primary: Whether this is the primary signing key + expires_at: ... + """ + + __tablename__ = "oidc_jwks_keys" + + # Override the default UUID id with integer primary key + id = db.Column(db.Integer, primary_key=True) + + expires_at = db.Column(db.DateTime, nullable=True) + + # Key identification and type + kid = db.Column(db.String(255), unique=True, nullable=False, index=True) + key_type = db.Column(db.String(50), nullable=False) # e.g., "RSA", "EC" + algorithm = db.Column(db.String(50), nullable=False) # e.g., "RS256", "ES256" + + # Key material (PEM-encoded) + private_key = db.Column(db.Text, nullable=False) + public_key = db.Column(db.Text, nullable=False) + + # Key status + is_active = db.Column(db.Boolean, default=True, nullable=False) + is_primary = db.Column(db.Boolean, default=False, nullable=False) + + def __repr__(self): + """String representation of OidcJwksKey.""" + return f"" + + def to_dict(self, exclude_private_key=True): + """ + Convert model to dictionary. + + Args: + exclude_private_key: If True, excludes the private key from output + + Returns: + Dictionary representation of the model + """ + exclude = ["private_key"] if exclude_private_key else [] + return super().to_dict(exclude=exclude) + + @classmethod + def get_active_keys(cls): + """Get all active keys for signing operations.""" + return cls.query.filter(cls.is_active == True).all() + + @classmethod + def get_primary_key(cls): + """Get the primary signing key.""" + return cls.query.filter(cls.is_primary == True).first() + + @classmethod + def get_key_by_kid(cls, kid): + """Get a key by its key ID.""" + return cls.query.filter(cls.kid == kid, cls.is_active == True).first() \ No newline at end of file diff --git a/app/models/session.py b/app/models/session.py index 0eceb40..7ea769c 100644 --- a/app/models/session.py +++ b/app/models/session.py @@ -21,7 +21,7 @@ class Session(BaseModel): # Timing expires_at = db.Column(db.DateTime, nullable=False) - last_activity_at = db.Column(db.DateTime, nullable=False, default=datetime.utcnow) + last_activity_at = db.Column(db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc)) revoked_at = db.Column(db.DateTime, nullable=True) revoked_reason = db.Column(db.String(255), nullable=True) @@ -35,15 +35,24 @@ class Session(BaseModel): def is_active(self): """Check if session is currently active.""" now = datetime.now(timezone.utc) + # Make expires_at timezone-aware if it's naive + expires_at = self.expires_at + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) return ( self.status == SessionStatus.ACTIVE - and self.expires_at > now + and expires_at > now and self.deleted_at is None ) def is_expired(self): """Check if session has expired.""" - return datetime.now(timezone.utc) > self.expires_at + now = datetime.now(timezone.utc) + # Make expires_at timezone-aware if it's naive + expires_at = self.expires_at + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + return now > expires_at def refresh(self, duration_seconds=86400): """ diff --git a/app/models/user.py b/app/models/user.py index bb7c700..1fcf2f3 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -84,10 +84,14 @@ class User(BaseModel): Returns: The AuthenticationMethod instance for TOTP or None if not found. + + Note: + Returns the most recently created TOTP method to handle cases where + multiple enrollment attempts may exist. """ from app.models.authentication_method import AuthenticationMethod from app.utils.constants import AuthMethodType return AuthenticationMethod.query.filter_by( user_id=self.id, method_type=AuthMethodType.TOTP, deleted_at=None - ).first() + ).order_by(AuthenticationMethod.created_at.desc()).first() diff --git a/app/services/auth_service.py b/app/services/auth_service.py index deaa757..d543011 100644 --- a/app/services/auth_service.py +++ b/app/services/auth_service.py @@ -1,7 +1,7 @@ """Authentication service.""" import logging import secrets -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from flask import request, g, current_app from app.extensions import db, bcrypt from app.models.user import User @@ -131,9 +131,9 @@ class AuthService: raise InvalidCredentialsError() # Update last login - user.last_login_at = datetime.utcnow() + user.last_login_at = datetime.now(timezone.utc) user.last_login_ip = request.remote_addr - auth_method.last_used_at = datetime.utcnow() + auth_method.last_used_at = datetime.now(timezone.utc) db.session.commit() return user @@ -160,8 +160,8 @@ class AuthService: status=SessionStatus.ACTIVE, ip_address=request.remote_addr, user_agent=request.headers.get("User-Agent"), - expires_at=datetime.utcnow() + timedelta(seconds=duration_seconds), - last_activity_at=datetime.utcnow(), + expires_at=datetime.now(timezone.utc) + timedelta(seconds=duration_seconds), + last_activity_at=datetime.now(timezone.utc), ) session.save() @@ -260,6 +260,14 @@ class AuthService: if user.has_totp_enabled(): raise ConflictError("TOTP is already enabled for this account") + # Clean up any existing unverified TOTP enrollment attempts + # Use hard delete for unverified methods since they're incomplete enrollment attempts + existing_totp_method = user.get_totp_method() + if existing_totp_method and not existing_totp_method.verified: + logger.debug(f"Removing existing unverified TOTP method for user {user.id}") + db.session.delete(existing_totp_method) # Hard delete - unverified methods are temporary + db.session.commit() # Commit to ensure deletion before creating new record + # Generate TOTP secret secret = TOTPService.generate_secret() @@ -339,7 +347,7 @@ class AuthService: # Mark TOTP as verified auth_method.verified = True - auth_method.totp_verified_at = datetime.utcnow() + auth_method.totp_verified_at = datetime.now(timezone.utc) db.session.commit() # Log TOTP enrollment completion @@ -436,8 +444,10 @@ class AuthService: "secret": auth_method.provider_data.get("secret"), "backup_codes": remaining_codes, } - auth_method.last_used_at = datetime.utcnow() + auth_method.last_used_at = datetime.now(timezone.utc) + db.session.add(auth_method) db.session.commit() + logger.debug(f"[BACKUP CODE] Updated provider_data: {auth_method.provider_data}") # Log backup code usage AuditService.log_action( @@ -470,7 +480,7 @@ class AuthService: is_valid = TOTPService.verify_code(secret, code) if is_valid: - auth_method.last_used_at = datetime.utcnow() + auth_method.last_used_at = datetime.now(timezone.utc) db.session.commit() # Log successful verification diff --git a/app/services/oidc_audit_service.py b/app/services/oidc_audit_service.py index ceab815..0fd1c07 100644 --- a/app/services/oidc_audit_service.py +++ b/app/services/oidc_audit_service.py @@ -1,5 +1,5 @@ """OIDC Audit Service for comprehensive OIDC event logging.""" -from datetime import datetime +from datetime import datetime, timezone from typing import Dict, List, Optional from flask import g @@ -374,7 +374,7 @@ class OIDCAuditService: """ from datetime import timedelta - start_date = datetime.utcnow() - timedelta(days=days) + start_date = datetime.now(timezone.utc) - timedelta(days=days) query = OIDCAuditLog.query.filter( OIDCAuditLog.created_at >= start_date diff --git a/app/services/oidc_jwks_service.py b/app/services/oidc_jwks_service.py index 50b1d32..afcbb77 100644 --- a/app/services/oidc_jwks_service.py +++ b/app/services/oidc_jwks_service.py @@ -2,12 +2,13 @@ import uuid import json import hashlib -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Dict, List, Optional, Tuple from flask import current_app from app.extensions import db +from app.models.oidc_jwks_key import OidcJwksKey class JWKSKey: @@ -20,8 +21,8 @@ class JWKSKey: self.private_key = private_key self.public_key = public_key self.algorithm = algorithm - self.created_at = created_at or datetime.utcnow() - self.expires_at = expires_at or datetime.utcnow() + timedelta(days=365) + self.created_at = created_at or datetime.now(timezone.utc) + self.expires_at = expires_at or datetime.now(timezone.utc) + timedelta(days=365) self.is_active = is_active def to_jwk(self) -> Dict: @@ -166,7 +167,7 @@ class OIDCJWKSService: Returns: JWKS document dictionary """ - now = datetime.utcnow() + now = datetime.now(timezone.utc) keys = [] for kid, key in self._keys.items(): @@ -181,14 +182,88 @@ class OIDCJWKSService: "keys": keys } + def load_keys_from_db(self) -> int: + """Load existing keys from the database. + + Returns: + Number of keys loaded + """ + try: + db_keys = OidcJwksKey.get_active_keys() + now = datetime.now(timezone.utc) + + for db_key in db_keys: + # Create JWKSKey from database model + key = JWKSKey( + kid=db_key.kid, + private_key=db_key.private_key, + public_key=db_key.public_key, + algorithm=db_key.algorithm, + created_at=db_key.created_at, + expires_at=db_key.expires_at or now + timedelta(days=365), + is_active=db_key.is_active, + ) + self._keys[db_key.kid] = key + + return len(self._keys) + except Exception as e: + current_app.logger.error(f"Error loading keys from database: {e}") + return 0 + + def save_key_to_db(self, key: JWKSKey, is_primary: bool = False) -> OidcJwksKey: + """Save a key to the database. + + Args: + key: JWKSKey instance to save + is_primary: Whether this is the primary signing key + + Returns: + OidcJwksKey database model instance + """ + db_key = OidcJwksKey( + kid=key.kid, + key_type="RSA", + algorithm=key.algorithm, + private_key=key.private_key, + public_key=key.public_key, + is_active=key.is_active, + is_primary=is_primary, + ) + + db.session.add(db_key) + db.session.commit() + + return db_key + def get_signing_key(self) -> Optional[JWKSKey]: """Get the current active signing key. Returns: JWKSKey instance or None if no active key """ - now = datetime.utcnow() + now = datetime.now(timezone.utc) + # First try to get the primary key from database + try: + primary_db_key = OidcJwksKey.get_primary_key() + if primary_db_key: + # Check if we have it in memory, if not load it + if primary_db_key.kid not in self._keys: + key = JWKSKey( + kid=primary_db_key.kid, + private_key=primary_db_key.private_key, + public_key=primary_db_key.public_key, + algorithm=primary_db_key.algorithm, + created_at=primary_db_key.created_at, + expires_at=primary_db_key.expires_at or now + timedelta(days=365), + is_active=primary_db_key.is_active, + ) + self._keys[primary_db_key.kid] = key + return self._keys[primary_db_key.kid] + except Exception as e: + current_app.logger.error(f"Error getting primary key from database: {e}") + + # Fall back to in-memory keys for kid, key in self._keys.items(): if key.is_active and key.expires_at > now: return key @@ -218,7 +293,7 @@ class OIDCJWKSService: private_key, public_key = self._generate_rsa_key_pair() kid = self._generate_kid(private_key) - now = datetime.utcnow() + now = datetime.now(timezone.utc) key = JWKSKey( kid=kid, private_key=private_key, @@ -247,7 +322,7 @@ class OIDCJWKSService: Returns: Tuple of (new_key, list_of_deprecated_kids) """ - now = datetime.utcnow() + now = datetime.now(timezone.utc) grace_end = now + timedelta(hours=grace_period_hours) # Mark current key as deprecated @@ -286,15 +361,58 @@ class OIDCJWKSService: if not key: return False - now = datetime.utcnow() + now = datetime.now(timezone.utc) return key.is_active and key.expires_at > now def initialize_with_key(self) -> JWKSKey: - """Initialize the service with a key if none exists. + """Initialize the service with a key, loading from database if available. + + This method first attempts to load existing keys from the database. + If no active primary key exists, it generates a new key and saves it to the database. Returns: JWKSKey instance """ - if not self._keys: - return self.generate_new_key_pair() - return self.get_signing_key() + # First, try to load keys from database + try: + # Check if there's a primary key in the database + primary_db_key = OidcJwksKey.get_primary_key() + if primary_db_key: + # Load the primary key into memory + now = datetime.now(timezone.utc) + key = JWKSKey( + kid=primary_db_key.kid, + private_key=primary_db_key.private_key, + public_key=primary_db_key.public_key, + algorithm=primary_db_key.algorithm, + created_at=primary_db_key.created_at, + expires_at=primary_db_key.expires_at or now + timedelta(days=365), + is_active=primary_db_key.is_active, + ) + self._keys[primary_db_key.kid] = key + current_app.logger.info(f"[OIDC] Loaded existing signing key from database: kid={primary_db_key.kid}") + return key + + # Try to load all active keys from database + loaded_count = self.load_keys_from_db() + if loaded_count > 0: + # Get the signing key from loaded keys + signing_key = self.get_signing_key() + if signing_key: + current_app.logger.info(f"[OIDC] Loaded {loaded_count} keys from database, using signing key: kid={signing_key.kid}") + return signing_key + except Exception as e: + current_app.logger.error(f"Error loading keys from database: {e}") + + # No keys in database, generate a new key and save it + current_app.logger.info("[OIDC] No existing keys found in database, generating new signing key") + new_key = self.generate_new_key_pair() + + # Save the new key to database + try: + self.save_key_to_db(new_key, is_primary=True) + current_app.logger.info(f"[OIDC] Saved new signing key to database: kid={new_key.kid}") + except Exception as e: + current_app.logger.error(f"Error saving key to database: {e}") + + return new_key diff --git a/app/services/organization_service.py b/app/services/organization_service.py index bfbee5d..c7683c5 100644 --- a/app/services/organization_service.py +++ b/app/services/organization_service.py @@ -1,6 +1,6 @@ """Organization service.""" import logging -from datetime import datetime +from datetime import datetime, timezone from flask import current_app from app.extensions import db from app.models.organization import Organization @@ -53,7 +53,7 @@ class OrganizationService: user_id=owner_user_id, organization_id=org.id, role=OrganizationRole.OWNER, - joined_at=datetime.utcnow(), + joined_at=datetime.now(timezone.utc), ) member.save() @@ -208,8 +208,8 @@ class OrganizationService: organization_id=org.id, role=role, invited_by_id=inviter_id, - invited_at=datetime.utcnow(), - joined_at=datetime.utcnow(), + invited_at=datetime.now(timezone.utc), + joined_at=datetime.now(timezone.utc), ) member.save() diff --git a/app/services/session_service.py b/app/services/session_service.py index 66be1b8..68abc93 100644 --- a/app/services/session_service.py +++ b/app/services/session_service.py @@ -1,5 +1,5 @@ """Session service.""" -from datetime import datetime +from datetime import datetime, timezone from app.models.session import Session from app.utils.constants import SessionStatus @@ -41,7 +41,7 @@ class SessionService: if active_only: query = query.filter_by(status=SessionStatus.ACTIVE).filter( - Session.expires_at > datetime.utcnow() + Session.expires_at > datetime.now(timezone.utc) ) return query.all() @@ -65,7 +65,7 @@ class SessionService: """Clean up expired sessions.""" expired_sessions = Session.query.filter( Session.status == SessionStatus.ACTIVE, - Session.expires_at < datetime.utcnow(), + Session.expires_at < datetime.now(timezone.utc), Session.deleted_at.is_(None), ).all() diff --git a/app/services/totp_service.py b/app/services/totp_service.py index dc44533..69a4dc4 100644 --- a/app/services/totp_service.py +++ b/app/services/totp_service.py @@ -3,6 +3,7 @@ import base64 import io import logging import secrets +from datetime import datetime, timezone from typing import Tuple import pyotp @@ -72,10 +73,34 @@ class TOTPService: The window parameter allows for clock skew between the server and the authenticator app. A window of 1 allows codes from the previous, current, and next 30-second intervals. + + IMPORTANT: Always uses UTC time for verification to ensure + consistency across all timezones. """ totp = pyotp.TOTP(secret) - is_valid = totp.verify(code, valid_window=window) - logger.debug(f"TOTP code verification: valid={is_valid}, window={window}") + # Use timezone-aware UTC datetime for verification + # IMPORTANT: We must pass a datetime object, NOT a Unix timestamp + # pyotp's internal datetime.utcfromtimestamp() is deprecated and can be + # affected by local timezone settings, causing the 10.5 hour skew issue + utc_now = datetime.now(timezone.utc) + + # DEBUG: Log detailed timezone information + logger.debug(f"[TOTP DEBUG] UTC now: {utc_now}") + logger.debug(f"[TOTP DEBUG] UTC now isoformat: {utc_now.isoformat()}") + logger.debug(f"[TOTP DEBUG] UTC timestamp: {utc_now.timestamp()}") + logger.debug(f"[TOTP DEBUG] UTC now tzinfo: {utc_now.tzinfo}") + + # Generate what the TOTP code should be at this moment using UTC datetime + expected_code = totp.at(utc_now) + logger.debug(f"[TOTP DEBUG] Expected TOTP code at UTC: {expected_code}") + + # Verify with the provided code using UTC datetime object + # Passing a datetime object avoids pyotp's utcfromtimestamp() issues + is_valid = totp.verify(code, valid_window=window, for_time=utc_now) + + logger.debug(f"[TOTP DEBUG] TOTP code verification: valid={is_valid}, window={window}") + logger.debug(f"[TOTP DEBUG] Provided code: {code}, Expected code: {expected_code}") + return is_valid @staticmethod @@ -133,15 +158,16 @@ class TOTPService: for hashed_code in hashed_codes: if bcrypt.check_password_hash(hashed_code, code): - # Code found and valid - don't add to remaining codes (consumed) - logger.debug("Backup code verified and consumed") - return True, remaining_codes + # Code found and valid - mark as matched but don't add to remaining codes + matched = True else: # Code doesn't match - keep it in remaining codes remaining_codes.append(hashed_code) - logger.debug("Backup code verification failed") - return False, remaining_codes + if matched: + return True, remaining_codes + else: + return False, remaining_codes @staticmethod def generate_qr_code_data_uri(provisioning_uri: str) -> str: @@ -185,4 +211,4 @@ class TOTPService: except ImportError: logger.warning("qrcode library not installed, returning placeholder") - return "QR code generation requires the qrcode library. Install with: pip install qrcode[pil]" + return "QR code generation requires the qrcode library. Install with: pip install qrcode[pil]" \ No newline at end of file diff --git a/manual_totp_reset.md b/manual_totp_reset.md new file mode 100644 index 0000000..790a3cf --- /dev/null +++ b/manual_totp_reset.md @@ -0,0 +1,47 @@ +# Manual TOTP Reset for Testing + +Since Bob has TOTP enabled, you have two options to run the full test: + +## Option 1: Restart Flask Server (Easiest) +The Flask server running on port 8888 uses an in-memory SQLite database. +Simply restart it to clear all data: + +```bash +# Stop the server (Ctrl+C in the terminal) +# Then restart it +cd gatehouse-api +.venv/bin/flask run --debug --port 8888 +``` + +Then run the test: +```bash +.venv/bin/python test_totp_full.py +``` + +## Option 2: Use the TOTP Secret + +If you have the secret from the previous enrollment (check `.totp_test_data.json` if it exists): + +1. Edit `test_totp_full.py` +2. Update the `test_data` initialization: +```python +test_data = { + "secret": "YOUR_SECRET_HERE", # From previous enrollment + "backup_codes": ["CODE1", "CODE2", ...], # From previous enrollment + "last_run": None +} +``` + +3. Run the test + +## Option 3: Database Direct Access (if file-based DB) + +If using PostgreSQL or file-based SQLite: + +```sql +DELETE FROM authentication_methods +WHERE user_id = (SELECT id FROM users WHERE email = 'bob@acme-corp.com') + AND method_type = 'totp'; +``` + +The test will then run through the complete flow and save the new secret/codes to `.totp_test_data.json` for subsequent runs. diff --git a/quick_login_test.py b/quick_login_test.py new file mode 100644 index 0000000..025485d --- /dev/null +++ b/quick_login_test.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +"""Quick test to see what login returns""" +import requests +import json + +BASE_URL = "http://localhost:8888/api/v1" +CREDENTIALS = { + "email": "bob@acme-corp.com", + "password": "UserPass123!" +} + +session = requests.Session() +response = session.post(f"{BASE_URL}/auth/login", json=CREDENTIALS) + +print(f"Status: {response.status_code}") +print(f"Response:") +print(json.dumps(response.json(), indent=2)) + +if response.status_code == 200: + data = response.json()["data"] + if data.get("requires_totp"): + print("\n⚠️ TOTP IS REQUIRED") + elif data.get("token"): + print(f"\n✅ LOGIN SUCCESS - Token: {data['token'][:30]}...") + + # Check TOTP status + status_response = session.get( + f"{BASE_URL}/auth/totp/status", + headers={"Authorization": f"Bearer {data['token']}"} + ) + print(f"\nTOTP Status:") + print(json.dumps(status_response.json(), indent=2)) diff --git a/requirements/base.txt b/requirements/base.txt index af22369..8e5378b 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -42,4 +42,5 @@ Flask-Session==0.5.0 Flask-Limiter==3.5.0 # Logging -python-json-logger==2.0.7 \ No newline at end of file +python-json-logger==2.0.7 +qrcode[pil] \ No newline at end of file diff --git a/test_totp_full.py b/test_totp_full.py new file mode 100644 index 0000000..a745654 --- /dev/null +++ b/test_totp_full.py @@ -0,0 +1,501 @@ +#!/usr/bin/env python3 +""" +COMPREHENSIVE TOTP END-TO-END FUNCTIONAL TEST +Tests all aspects of TOTP functionality regardless of current state. + +Based on approved proposal in TOTP_TEST_PROPOSAL.md +""" +import requests +import pyotp +import json +import sys +import os +from datetime import datetime, timezone + +# Configuration +BASE_URL = "http://localhost:8888/api/v1" +CREDENTIALS = { + "email": "bob@acme-corp.com", + "password": "UserPass123!" +} +DATA_FILE = ".totp_test_data.json" + +# Test state +test_data = { + "secret": None, + "backup_codes": [], + "last_run": None +} + +def load_test_data(): + """Load test data from previous run.""" + global test_data + if os.path.exists(DATA_FILE): + with open(DATA_FILE, 'r') as f: + test_data = json.load(f) + print(f"📂 Loaded test data from {DATA_FILE}") + print(f" Secret: {test_data['secret'][:20] if test_data['secret'] else 'None'}...") + print(f" Backup codes: {len(test_data.get('backup_codes', []))}") + else: + print(f"📂 No previous test data found") + +def save_test_data(): + """Save test data for next run.""" + test_data['last_run'] = datetime.now(timezone.utc).isoformat() + with open(DATA_FILE, 'w') as f: + json.dump(test_data, f, indent=2) + print(f"\n💾 Saved test data to {DATA_FILE}") + +def print_section(step, title): + """Print test section header.""" + print(f"\n{'='*70}") + print(f"[STEP {step}] {title}") + print('='*70) + +def main(): + """Run comprehensive TOTP test.""" + + print("\n" + "="*70) + print("COMPREHENSIVE TOTP END-TO-END TEST") + print(f"User: {CREDENTIALS['email']}") + print(f"Server: {BASE_URL}") + print(f"Time: {datetime.now(timezone.utc).isoformat()}") + print("="*70) + + load_test_data() + + session = requests.Session() + auth_token = None + totp = None + step = 0 + + try: + # ==================== PHASE 1: INITIAL LOGIN ==================== + + step += 1 + print_section(step, "Initial Login") + + login_response = session.post(f"{BASE_URL}/auth/login", json=CREDENTIALS) + + if login_response.status_code != 200: + print(f"❌ Login failed: {login_response.status_code}") + print(json.dumps(login_response.json(), indent=2)) + return False + + login_data = login_response.json() + + # Check if TOTP is required + totp_required = login_data.get("data", {}).get("requires_totp", False) + + if totp_required: + print("⚠️ TOTP is ENABLED - login requires verification") + + # We need either saved secret or backup code + if test_data.get('secret'): + print("ℹ️ Using saved secret to generate TOTP code") + totp = pyotp.TOTP(test_data['secret']) + utc_now = datetime.now(timezone.utc) + code = totp.at(utc_now) + print(f" Generated code: {code}") + print(f" At time: {utc_now.isoformat()}") + + verify_response = session.post( + f"{BASE_URL}/auth/totp/verify", + json={"code": code} + ) + + if verify_response.status_code != 200: + print("❌ TOTP code verification failed") + print(" Trying backup code...") + + if test_data.get('backup_codes'): + # Try first unused backup code + for backup_code in test_data['backup_codes']: + verify_response = session.post( + f"{BASE_URL}/auth/totp/verify", + json={"code": backup_code, "is_backup_code": True} + ) + if verify_response.status_code == 200: + print(f"✅ Authenticated with backup code: {backup_code}") + # Remove used code + test_data['backup_codes'].remove(backup_code) + break + else: + print("❌ All backup codes failed") + print("\nPlease manually delete Bob's TOTP from database:") + print("DELETE FROM authentication_methods WHERE user_id = (SELECT id FROM users WHERE email = 'bob@acme-corp.com') AND method_type = 'totp';") + return False + else: + print("❌ No backup codes available") + return False + + auth_token = verify_response.json()["data"]["token"] + print("✅ Logged in with TOTP verification") + + elif test_data.get('backup_codes'): + print("ℹ️ Using backup code to authenticate") + + for backup_code in test_data['backup_codes']: + verify_response = session.post( + f"{BASE_URL}/auth/totp/verify", + json={"code": backup_code, "is_backup_code": True} + ) + if verify_response.status_code == 200: + auth_token = verify_response.json()["data"]["token"] + print(f"✅ Authenticated with backup code: {backup_code}") + test_data['backup_codes'].remove(backup_code) + break + else: + print("❌ No valid backup codes") + return False + else: + print("❌ TOTP enabled but no secret or backup codes available") + print("\nPlease manually delete Bob's TOTP from database:") + print("DELETE FROM authentication_methods WHERE user_id = (SELECT id FROM users WHERE email = 'bob@acme-corp.com') AND method_type = 'totp';") + return False + else: + auth_token = login_data["data"]["token"] + print("✅ Logged in (TOTP not required)") + + # ==================== PHASE 2: CHECK STATUS AND DISABLE IF ENABLED ==================== + + step += 1 + print_section(step, "Check TOTP Status") + + status_response = session.get( + f"{BASE_URL}/auth/totp/status", + headers={"Authorization": f"Bearer {auth_token}"} + ) + + if status_response.status_code != 200: + print("❌ Failed to get TOTP status") + return False + + status_data = status_response.json()["data"] + print(f"TOTP Enabled: {status_data['totp_enabled']}") + print(f"Verified At: {status_data.get('verified_at', 'N/A')}") + print(f"Backup Codes Remaining: {status_data['backup_codes_remaining']}") + + # If TOTP is enabled, disable it + if status_data['totp_enabled']: + step += 1 + print_section(step, "Disable TOTP") + + disable_response = session.delete( + f"{BASE_URL}/auth/totp/disable", + headers={"Authorization": f"Bearer {auth_token}"}, + json={"password": CREDENTIALS["password"]} + ) + + if disable_response.status_code != 200: + print("❌ Failed to disable TOTP") + print(json.dumps(disable_response.json(), indent=2)) + return False + + print("✅ TOTP disabled") + + # Clear saved secret/codes since we're starting fresh + test_data['secret'] = None + test_data['backup_codes'] = [] + else: + print("ℹ️ TOTP already disabled, skipping disable step") + + # ==================== PHASE 3: LOGOUT AND RE-LOGIN ==================== + + step += 1 + print_section(step, "Logout") + + logout_response = session.post( + f"{BASE_URL}/auth/logout", + headers={"Authorization": f"Bearer {auth_token}"} + ) + print(f"✅ Logged out (status: {logout_response.status_code})") + + step += 1 + print_section(step, "Re-login (TOTP should NOT be required)") + + session = requests.Session() # Fresh session + login2_response = session.post(f"{BASE_URL}/auth/login", json=CREDENTIALS) + + if login2_response.status_code != 200: + print("❌ Re-login failed") + return False + + login2_data = login2_response.json() + if login2_data.get("data", {}).get("requires_totp"): + print("❌ Login still requires TOTP (should not after disabling)") + return False + + auth_token = login2_data["data"]["token"] + print("✅ Logged in successfully (no TOTP required)") + + # ==================== PHASE 4: ENROLL IN TOTP ==================== + + step += 1 + print_section(step, "Enroll in TOTP") + + enroll_response = session.post( + f"{BASE_URL}/auth/totp/enroll", + headers={"Authorization": f"Bearer {auth_token}"} + ) + + if enroll_response.status_code != 201: + print(f"❌ Enrollment failed: {enroll_response.status_code}") + print(json.dumps(enroll_response.json(), indent=2)) + return False + + enroll_data = enroll_response.json()["data"] + new_secret = enroll_data["secret"] + new_backup_codes = enroll_data["backup_codes"] + provisioning_uri = enroll_data["provisioning_uri"] + qr_code = enroll_data.get("qr_code", "") + + print(f"✅ Enrollment initiated") + print(f" Secret: {new_secret}") + print(f" Provisioning URI: {provisioning_uri}") + print(f" QR Code: {'Present (%d bytes)' % len(qr_code) if qr_code else 'Missing'}") + print(f" Backup Codes: {len(new_backup_codes)}") + + # Save for later use + test_data['secret'] = new_secret + test_data['backup_codes'] = new_backup_codes.copy() + + # ==================== PHASE 5: VERIFY ENROLLMENT ==================== + + step += 1 + print_section(step, "Verify TOTP Enrollment") + + totp = pyotp.TOTP(new_secret) + utc_now = datetime.now(timezone.utc) + code = totp.at(utc_now) + + print(f"Generated TOTP code: {code}") + print(f"At UTC time: {utc_now.isoformat()}") + print(f"Timestamp: {utc_now.timestamp()}") + + verify_enrollment_response = session.post( + f"{BASE_URL}/auth/totp/verify-enrollment", + headers={"Authorization": f"Bearer {auth_token}"}, + json={"code": code} + ) + + if verify_enrollment_response.status_code != 200: + print(f"❌ Verification failed: {verify_enrollment_response.status_code}") + print(json.dumps(verify_enrollment_response.json(), indent=2)) + return False + + print("✅ TOTP enrollment verified successfully!") + + # ==================== PHASE 6: CONFIRM ENROLLMENT ==================== + + step += 1 + print_section(step, "Confirm TOTP is Enabled") + + final_status_response = session.get( + f"{BASE_URL}/auth/totp/status", + headers={"Authorization": f"Bearer {auth_token}"} + ) + + final_status = final_status_response.json()["data"] + if not final_status["totp_enabled"]: + print("❌ TOTP not enabled after verification!") + return False + + print(f"✅ TOTP is enabled") + print(f" Verified at: {final_status['verified_at']}") + print(f" Backup codes remaining: {final_status['backup_codes_remaining']}") + + # ==================== PHASE 7: TEST LOGIN WITH TOTP ==================== + + step += 1 + print_section(step, "Logout") + + session.post(f"{BASE_URL}/auth/logout", headers={"Authorization": f"Bearer {auth_token}"}) + print("✅ Logged out") + + step += 1 + print_section(step, "Login (should REQUIRE TOTP)") + + session2 = requests.Session() + login3_response = session2.post(f"{BASE_URL}/auth/login", json=CREDENTIALS) + + if login3_response.status_code != 200: + print("❌ Login failed") + return False + + login3_data = login3_response.json() + if not login3_data.get("data", {}).get("requires_totp"): + print("❌ Login did NOT require TOTP (it should!)") + return False + + print("✅ Login correctly requires TOTP") + + # ==================== PHASE 8: VERIFY TOTP DURING LOGIN ==================== + + step += 1 + print_section(step, "Verify TOTP Code During Login") + + utc_now = datetime.now(timezone.utc) + login_code = totp.at(utc_now) + + print(f"Generated TOTP code: {login_code}") + print(f"At UTC time: {utc_now.isoformat()}") + + verify_login_response = session2.post( + f"{BASE_URL}/auth/totp/verify", + json={"code": login_code} + ) + + if verify_login_response.status_code != 200: + print(f"❌ TOTP login verification failed: {verify_login_response.status_code}") + print(json.dumps(verify_login_response.json(), indent=2)) + return False + + final_token = verify_login_response.json()["data"]["token"] + print("✅ Successfully logged in with TOTP!") + print(f" Token: {final_token[:30]}...") + + # ==================== PHASE 9: TEST /auth/me ==================== + + step += 1 + print_section(step, "Confirm Logged In (/auth/me)") + + me_response = session2.get( + f"{BASE_URL}/auth/me", + headers={"Authorization": f"Bearer {final_token}"} + ) + + if me_response.status_code != 200: + print("❌ /auth/me failed") + return False + + me_data = me_response.json()["data"] + print(f"✅ Confirmed logged in as: {me_data['user']['email']}") + print(f" User ID: {me_data['user']['id']}") + + # ==================== PHASE 10: TEST BACKUP CODE ==================== + + step += 1 + print_section(step, "Test Backup Code Login") + + # Logout + session2.post(f"{BASE_URL}/auth/logout", headers={"Authorization": f"Bearer {final_token}"}) + + # Fresh login + session3 = requests.Session() + login4_response = session3.post(f"{BASE_URL}/auth/login", json=CREDENTIALS) + + if not login4_response.json().get("data", {}).get("requires_totp"): + print("❌ Login should require TOTP") + return False + + print(f"ℹ️ Using backup code: {test_data['backup_codes'][0]}") + + backup_verify_response = session3.post( + f"{BASE_URL}/auth/totp/verify", + json={"code": test_data['backup_codes'][0], "is_backup_code": True} + ) + + if backup_verify_response.status_code != 200: + print("❌ Backup code login failed") + print(json.dumps(backup_verify_response.json(), indent=2)) + return False + + backup_token = backup_verify_response.json()["data"]["token"] + print(f"✅ Logged in with backup code!") + + # Remove used code + used_code = test_data['backup_codes'].pop(0) + + # ==================== PHASE 11: CHECK BACKUP CODES REMAINING ==================== + + step += 1 + print_section(step, "Check Backup Codes Remaining") + + status3_response = session3.get( + f"{BASE_URL}/auth/totp/status", + headers={"Authorization": f"Bearer {backup_token}"} + ) + + status3_data = status3_response.json()["data"] + if status3_data['backup_codes_remaining'] != 9: + print(f"❌ Expected 9 backup codes, got {status3_data['backup_codes_remaining']}") + return False + + print(f"✅ Backup codes remaining: {status3_data['backup_codes_remaining']} (was 10, now 9)") + + # ==================== PHASE 12: REGENERATE BACKUP CODES ==================== + + step += 1 + print_section(step, "Regenerate Backup Codes") + + regen_response = session3.post( + f"{BASE_URL}/auth/totp/regenerate-backup-codes", + headers={"Authorization": f"Bearer {backup_token}"}, + json={"password": CREDENTIALS["password"]} + ) + + if regen_response.status_code != 200: + print("❌ Failed to regenerate backup codes") + print(json.dumps(regen_response.json(), indent=2)) + return False + + regenerated_codes = regen_response.json()["data"]["backup_codes"] + print(f"✅ Regenerated {len(regenerated_codes)} backup codes") + + # Update saved codes + test_data['backup_codes'] = regenerated_codes.copy() + + # ==================== SUCCESS ==================== + + save_test_data() + + print("\n" + "="*70) + print("🎉 ALL TESTS PASSED!") + print("="*70) + + print("\n✅ TEST SUMMARY:") + print(f" 1. ✅ Initial login (with/without TOTP)") + print(f" 2. ✅ Check TOTP status") + print(f" 3. ✅ Disable TOTP") + print(f" 4. ✅ Logout") + print(f" 5. ✅ Re-login without TOTP") + print(f" 6. ✅ Enroll in TOTP") + print(f" 7. ✅ Verify enrollment") + print(f" 8. ✅ Confirm TOTP enabled") + print(f" 9. ✅ Logout") + print(f" 10. ✅ Login with TOTP required") + print(f" 11. ✅ Verify TOTP during login") + print(f" 12. ✅ Confirm logged in (/auth/me)") + print(f" 13. ✅ Login with backup code") + print(f" 14. ✅ Check backup codes decremented") + print(f" 15. ✅ Regenerate backup codes") + + print(f"\n📱 Current TOTP Secret:") + print(f" {test_data['secret']}") + + print(f"\n🔑 Current Backup Codes ({len(test_data['backup_codes'])}):") + for i, code in enumerate(test_data['backup_codes'], 1): + print(f" {i:2d}. {code}") + + print("\n" + "="*70) + + return True + + except requests.exceptions.ConnectionError: + print(f"\n❌ CONNECTION ERROR - Server not running at {BASE_URL}") + return False + except KeyError as e: + print(f"\n❌ UNEXPECTED RESPONSE STRUCTURE: Missing key {e}") + import traceback + traceback.print_exc() + return False + except Exception as e: + print(f"\n❌ UNEXPECTED ERROR: {e}") + import traceback + traceback.print_exc() + return False + +if __name__ == "__main__": + success = main() + sys.exit(0 if success else 1)