move app to gatehouse-app

This commit is contained in:
2026-01-15 03:40:29 +10:30
parent 5e4cffcf73
commit 2c0aaf484b
69 changed files with 1569 additions and 294 deletions
+25
View File
@@ -0,0 +1,25 @@
"""Services package."""
from gatehouse_app.services.auth_service import AuthService
from gatehouse_app.services.user_service import UserService
from gatehouse_app.services.organization_service import OrganizationService
from gatehouse_app.services.session_service import SessionService
from gatehouse_app.services.audit_service import AuditService
from gatehouse_app.services.oidc_service import OIDCService, OIDCError
from gatehouse_app.services.oidc_jwks_service import OIDCJWKSService
from gatehouse_app.services.oidc_token_service import OIDCTokenService
from gatehouse_app.services.oidc_session_service import OIDCSessionService
from gatehouse_app.services.oidc_audit_service import OIDCAuditService
__all__ = [
"AuthService",
"UserService",
"OrganizationService",
"SessionService",
"AuditService",
"OIDCService",
"OIDCError",
"OIDCJWKSService",
"OIDCTokenService",
"OIDCSessionService",
"OIDCAuditService",
]
+107
View File
@@ -0,0 +1,107 @@
"""Audit service."""
from flask import request, g
from gatehouse_app.models.audit_log import AuditLog
from gatehouse_app.utils.constants import AuditAction
class AuditService:
"""Service for audit logging."""
@staticmethod
def log_action(
action,
user_id=None,
organization_id=None,
resource_type=None,
resource_id=None,
metadata=None,
description=None,
success=True,
error_message=None,
):
"""
Create an audit log entry.
Args:
action: AuditAction enum value
user_id: ID of user performing the action
organization_id: ID of related organization
resource_type: Type of resource being acted upon
resource_id: ID of resource being acted upon
metadata: Additional metadata dictionary
description: Human-readable description
success: Whether the action succeeded
error_message: Error message if action failed
Returns:
AuditLog instance
"""
# Get request details if available
ip_address = None
user_agent = None
request_id = None
try:
if request:
ip_address = request.remote_addr
user_agent = request.headers.get("User-Agent")
request_id = g.get("request_id")
except RuntimeError:
# No request context
pass
log_entry = AuditLog(
action=action,
user_id=user_id,
organization_id=organization_id,
resource_type=resource_type,
resource_id=resource_id,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
metadata=metadata,
description=description,
success=success,
error_message=error_message,
)
log_entry.save()
return log_entry
@staticmethod
def get_user_activity(user_id, limit=50):
"""
Get recent activity for a user.
Args:
user_id: User ID
limit: Maximum number of records to return
Returns:
List of AuditLog instances
"""
return (
AuditLog.query.filter_by(user_id=user_id)
.order_by(AuditLog.created_at.desc())
.limit(limit)
.all()
)
@staticmethod
def get_organization_activity(organization_id, limit=50):
"""
Get recent activity for an organization.
Args:
organization_id: Organization ID
limit: Maximum number of records to return
Returns:
List of AuditLog instances
"""
return (
AuditLog.query.filter_by(organization_id=organization_id)
.order_by(AuditLog.created_at.desc())
.limit(limit)
.all()
)
+559
View File
@@ -0,0 +1,559 @@
"""Authentication service."""
import logging
import secrets
from datetime import datetime, timedelta, timezone
from flask import request, g, current_app
from gatehouse_app.extensions import db, bcrypt
from gatehouse_app.models.user import User
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.models.session import Session
from gatehouse_app.utils.constants import AuthMethodType, SessionStatus, UserStatus, AuditAction
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError, AccountSuspendedError, AccountInactiveError
from gatehouse_app.exceptions.validation_exceptions import EmailAlreadyExistsError
from gatehouse_app.services.audit_service import AuditService
from gatehouse_app.services.totp_service import TOTPService
logger = logging.getLogger(__name__)
class AuthService:
"""Service for authentication operations."""
@staticmethod
def register_user(email, password, full_name=None):
"""
Register a new user with email/password.
Args:
email: User email address
password: Plain text password
full_name: Optional full name
Returns:
User instance
Raises:
EmailAlreadyExistsError: If email is already registered
"""
# Check if email already exists
existing_user = User.query.filter_by(email=email.lower()).first()
if existing_user and existing_user.deleted_at is None:
raise EmailAlreadyExistsError()
# Create user
user = User(
email=email.lower(),
full_name=full_name,
status=UserStatus.ACTIVE,
)
user.save()
# Create password authentication method
password_hash = bcrypt.generate_password_hash(password).decode("utf-8")
auth_method = AuthenticationMethod(
user_id=user.id,
method_type=AuthMethodType.PASSWORD,
password_hash=password_hash,
is_primary=True,
verified=True,
)
auth_method.save()
# Log the registration
AuditService.log_action(
action=AuditAction.USER_REGISTER,
user_id=user.id,
resource_type="user",
resource_id=user.id,
description=f"User registered with email: {email}",
)
return user
@staticmethod
def authenticate(email, password):
"""
Authenticate user with email/password.
Args:
email: User email
password: Plain text password
Returns:
User instance if authentication succeeds
Raises:
InvalidCredentialsError: If credentials are invalid
AccountSuspendedError: If account is suspended
AccountInactiveError: If account is inactive
"""
# Find user
user = User.query.filter_by(email=email.lower(), deleted_at=None).first()
# Development-only debug logging for user existence check
if current_app.config.get('ENV') == 'development':
logger.debug(f"[Auth] User lookup: email={email}, exists={user is not None}")
if not user:
raise InvalidCredentialsError()
# Check account status
if current_app.config.get('ENV') == 'development':
logger.debug(f"[Auth] Account status: user_id={user.id}, status={user.status}")
if user.status == UserStatus.SUSPENDED:
raise AccountSuspendedError()
if user.status == UserStatus.INACTIVE:
raise AccountInactiveError()
# Find password auth method
auth_method = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.PASSWORD,
deleted_at=None,
).first()
# Development-only debug logging for auth method lookup
if current_app.config.get('ENV') == 'development':
logger.debug(f"[Auth] Auth method lookup: user_id={user.id}, has_password_auth={auth_method is not None and auth_method.password_hash is not None}")
if not auth_method or not auth_method.password_hash:
raise InvalidCredentialsError()
# Verify password
password_valid = bcrypt.check_password_hash(auth_method.password_hash, password)
# Development-only debug logging for password validation (without logging actual password)
if current_app.config.get('ENV') == 'development':
logger.debug(f"[Auth] Password validation: user_id={user.id}, valid={password_valid}")
if not password_valid:
raise InvalidCredentialsError()
# Update last login
user.last_login_at = datetime.now(timezone.utc)
user.last_login_ip = request.remote_addr
auth_method.last_used_at = datetime.now(timezone.utc)
db.session.commit()
return user
@staticmethod
def create_session(user, duration_seconds=86400):
"""
Create a new session for the user.
Args:
user: User instance
duration_seconds: Session duration in seconds
Returns:
Session instance
"""
# Generate session token
token = secrets.token_urlsafe(32)
# Create session
session = Session(
user_id=user.id,
token=token,
status=SessionStatus.ACTIVE,
ip_address=request.remote_addr,
user_agent=request.headers.get("User-Agent"),
expires_at=datetime.now(timezone.utc) + timedelta(seconds=duration_seconds),
last_activity_at=datetime.now(timezone.utc),
)
session.save()
# Log session creation
AuditService.log_action(
action=AuditAction.SESSION_CREATE,
user_id=user.id,
resource_type="session",
resource_id=session.id,
description="User session created",
)
return session
@staticmethod
def change_password(user, current_password, new_password):
"""
Change user password.
Args:
user: User instance
current_password: Current password
new_password: New password
Raises:
InvalidCredentialsError: If current password is incorrect
"""
# Find password auth method
auth_method = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.PASSWORD,
deleted_at=None,
).first()
if not auth_method or not auth_method.password_hash:
raise InvalidCredentialsError("No password authentication method found")
# Verify current password
if not bcrypt.check_password_hash(auth_method.password_hash, current_password):
raise InvalidCredentialsError("Current password is incorrect")
# Update password
auth_method.password_hash = bcrypt.generate_password_hash(new_password).decode("utf-8")
db.session.commit()
# Log password change
AuditService.log_action(
action=AuditAction.PASSWORD_CHANGE,
user_id=user.id,
description="User changed password",
)
@staticmethod
def revoke_session(session_id, reason=None):
"""
Revoke a session.
Args:
session_id: Session ID to revoke
reason: Optional revocation reason
"""
session = Session.query.get(session_id)
if session:
session.revoke(reason=reason)
# Log session revocation
AuditService.log_action(
action=AuditAction.SESSION_REVOKE,
user_id=session.user_id,
resource_type="session",
resource_id=session.id,
description=f"Session revoked: {reason or 'User logout'}",
)
@staticmethod
def enroll_totp(user: User) -> dict:
"""
Initiate TOTP enrollment for a user.
Args:
user: User instance
Returns:
Dictionary containing:
- secret: TOTP secret (base32 encoded)
- provisioning_uri: otpauth:// URI for QR code
- qr_code: Base64 encoded QR code as data URI
- backup_codes: List of plain text backup codes
Raises:
ConflictError: If user already has TOTP enabled
"""
from gatehouse_app.exceptions.validation_exceptions import ConflictError
# Check if user already has TOTP enabled
if user.has_totp_enabled():
raise ConflictError("TOTP is already enabled for this account")
# Clean up any existing unverified TOTP enrollment attempts
# Use hard delete for unverified methods since they're incomplete enrollment attempts
existing_totp_method = user.get_totp_method()
if existing_totp_method and not existing_totp_method.verified:
logger.debug(f"Removing existing unverified TOTP method for user {user.id}")
db.session.delete(existing_totp_method) # Hard delete - unverified methods are temporary
db.session.commit() # Commit to ensure deletion before creating new record
# Generate TOTP secret
secret = TOTPService.generate_secret()
# Generate provisioning URI
provisioning_uri = TOTPService.generate_provisioning_uri(
user_email=user.email,
secret=secret,
issuer="Gatehouse",
)
# Generate QR code data URI
qr_code = TOTPService.generate_qr_code_data_uri(provisioning_uri)
# Generate backup codes
backup_codes, hashed_backup_codes = TOTPService.generate_backup_codes()
# Create unverified TOTP authentication method
auth_method = AuthenticationMethod(
user_id=user.id,
method_type=AuthMethodType.TOTP,
verified=False,
is_primary=False,
)
auth_method.save()
# Store TOTP data in provider_data (since totp_secret field is commented out)
auth_method.provider_data = {
"secret": secret,
"backup_codes": hashed_backup_codes,
}
db.session.commit()
# Log TOTP enrollment initiation
AuditService.log_action(
action=AuditAction.TOTP_ENROLL_INITIATED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="TOTP enrollment initiated",
)
return {
"secret": secret,
"provisioning_uri": provisioning_uri,
"qr_code": qr_code,
"backup_codes": backup_codes,
}
@staticmethod
def verify_totp_enrollment(user: User, code: str) -> bool:
"""
Complete TOTP enrollment by verifying the first TOTP code.
Args:
user: User instance
code: 6-digit TOTP code from authenticator app
Returns:
True if verification successful
Raises:
InvalidCredentialsError: If code is invalid or TOTP method not found
"""
# Get user's TOTP authentication method
auth_method = user.get_totp_method()
if not auth_method:
raise InvalidCredentialsError("TOTP enrollment not found")
# Get secret from provider_data
secret = auth_method.provider_data.get("secret") if auth_method.provider_data else None
if not secret:
raise InvalidCredentialsError("TOTP secret not found")
# Verify the code
if not TOTPService.verify_code(secret, code):
raise InvalidCredentialsError("Invalid TOTP code")
# Mark TOTP as verified
auth_method.verified = True
auth_method.totp_verified_at = datetime.now(timezone.utc)
db.session.commit()
# Log TOTP enrollment completion
AuditService.log_action(
action=AuditAction.TOTP_ENROLL_COMPLETED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="TOTP enrollment completed",
)
return True
@staticmethod
def disable_totp(user: User, password: str) -> bool:
"""
Disable TOTP for a user.
Args:
user: User instance
password: User's current password for verification
Returns:
True if TOTP disabled successfully
Raises:
InvalidCredentialsError: If password is invalid or TOTP method not found
"""
# Verify user's password
auth_method = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.PASSWORD,
deleted_at=None,
).first()
if not auth_method or not auth_method.password_hash:
raise InvalidCredentialsError("No password authentication method found")
if not bcrypt.check_password_hash(auth_method.password_hash, password):
raise InvalidCredentialsError("Invalid password")
# Get user's TOTP authentication method
totp_method = user.get_totp_method()
if not totp_method:
raise InvalidCredentialsError("TOTP is not enabled for this account")
# Soft-delete the TOTP authentication method
totp_method.delete(soft=True)
# Log TOTP disabled
AuditService.log_action(
action=AuditAction.TOTP_DISABLED,
user_id=user.id,
resource_type="authentication_method",
resource_id=totp_method.id,
description="TOTP disabled",
)
return True
@staticmethod
def authenticate_with_totp(user: User, code: str, is_backup_code: bool = False) -> bool:
"""
Verify TOTP code during login.
Args:
user: User instance
code: 6-digit TOTP code or backup code
is_backup_code: True if code is a backup code, False if TOTP code
Returns:
True if code is valid
Raises:
InvalidCredentialsError: If code is invalid or TOTP method not found
"""
# Get user's TOTP authentication method
auth_method = user.get_totp_method()
if not auth_method:
raise InvalidCredentialsError("TOTP is not enabled for this account")
if is_backup_code:
# Verify backup code
backup_codes = (
auth_method.provider_data.get("backup_codes")
if auth_method.provider_data
else []
)
is_valid, remaining_codes = TOTPService.verify_backup_code(backup_codes, code)
if is_valid:
# Update remaining backup codes
auth_method.provider_data = {
"secret": auth_method.provider_data.get("secret"),
"backup_codes": remaining_codes,
}
auth_method.last_used_at = datetime.now(timezone.utc)
db.session.add(auth_method)
db.session.commit()
logger.debug(f"[BACKUP CODE] Updated provider_data: {auth_method.provider_data}")
# Log backup code usage
AuditService.log_action(
action=AuditAction.TOTP_BACKUP_CODE_USED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="Backup code used for authentication",
)
else:
# Log failed verification
AuditService.log_action(
action=AuditAction.TOTP_VERIFY_FAILED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="Invalid backup code provided",
)
raise InvalidCredentialsError("Invalid backup code")
else:
# Verify TOTP code
secret = (
auth_method.provider_data.get("secret")
if auth_method.provider_data
else None
)
if not secret:
raise InvalidCredentialsError("TOTP secret not found")
is_valid = TOTPService.verify_code(secret, code)
if is_valid:
auth_method.last_used_at = datetime.now(timezone.utc)
db.session.commit()
# Log successful verification
AuditService.log_action(
action=AuditAction.TOTP_VERIFY_SUCCESS,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="TOTP code verified successfully",
)
else:
# Log failed verification
AuditService.log_action(
action=AuditAction.TOTP_VERIFY_FAILED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="Invalid TOTP code provided",
)
raise InvalidCredentialsError("Invalid TOTP code")
return True
@staticmethod
def regenerate_totp_backup_codes(user: User, password: str) -> list[str]:
"""
Generate new backup codes for TOTP.
Args:
user: User instance
password: User's current password for verification
Returns:
List of new plain text backup codes
Raises:
InvalidCredentialsError: If password is invalid or TOTP method not found
"""
# Verify user's password
auth_method = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.PASSWORD,
deleted_at=None,
).first()
if not auth_method or not auth_method.password_hash:
raise InvalidCredentialsError("No password authentication method found")
if not bcrypt.check_password_hash(auth_method.password_hash, password):
raise InvalidCredentialsError("Invalid password")
# Get user's TOTP authentication method
totp_method = user.get_totp_method()
if not totp_method:
raise InvalidCredentialsError("TOTP is not enabled for this account")
# Generate new backup codes
backup_codes, hashed_backup_codes = TOTPService.generate_backup_codes()
# Update the authentication method with new backup codes
totp_method.provider_data = {
"secret": totp_method.provider_data.get("secret"),
"backup_codes": hashed_backup_codes,
}
db.session.commit()
# Log backup codes regeneration
AuditService.log_action(
action=AuditAction.TOTP_BACKUP_CODES_REGENERATED,
user_id=user.id,
resource_type="authentication_method",
resource_id=totp_method.id,
description="TOTP backup codes regenerated",
)
return backup_codes
@@ -0,0 +1,408 @@
"""OIDC Audit Service for comprehensive OIDC event logging."""
from datetime import datetime, timezone
from typing import Dict, List, Optional
from flask import g
from gatehouse_app.models import OIDCAuditLog, OIDCClient, User
from gatehouse_app.exceptions.validation_exceptions import NotFoundError
class OIDCAuditService:
"""Service for OIDC-specific audit logging.
This service provides methods to log all OIDC-related events including:
- Authorization requests and responses
- Token issuance and refresh
- Token revocation
- UserInfo endpoint access
- Authentication failures
"""
# Event type constants
EVENT_AUTHORIZATION_REQUEST = "authorization_request"
EVENT_AUTHORIZATION_RESPONSE = "authorization_response"
EVENT_TOKEN_ISSUE = "token_issue"
EVENT_TOKEN_REFRESH = "token_refresh"
EVENT_TOKEN_REVOCATION = "token_revocation"
EVENT_TOKEN_INTROSPECTION = "token_introspection"
EVENT_USERINFO_ACCESS = "userinfo_access"
EVENT_AUTHENTICATION_FAILURE = "authentication_failure"
EVENT_AUTHORIZATION_FAILURE = "authorization_failure"
EVENT_JWKS_ACCESS = "jwks_access"
EVENT_REGISTRATION = "client_registration"
@classmethod
def _get_request_context(cls) -> Dict:
"""Extract request context for logging.
Returns:
Dictionary with IP, user_agent, and request_id
"""
from flask import request
return {
"ip_address": request.remote_addr if request else None,
"user_agent": request.headers.get("User-Agent") if request else None,
"request_id": g.get("request_id"),
}
@classmethod
def log_event(
cls,
event_type: str,
client_id: str = None,
user_id: str = None,
success: bool = True,
error_code: str = None,
error_description: str = None,
metadata: Dict = None
) -> OIDCAuditLog:
"""Log a generic OIDC event.
Args:
event_type: Type of event
client_id: OIDC client ID
user_id: User ID
success: Whether the event was successful
error_code: Error code if failed
error_description: Error description if failed
metadata: Additional event metadata
Returns:
OIDCAuditLog instance
"""
context = cls._get_request_context()
log = OIDCAuditLog.log_event(
event_type=event_type,
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
ip_address=context["ip_address"],
user_agent=context["user_agent"],
request_id=context["request_id"],
event_metadata=metadata,
)
return log
@classmethod
def log_authorization_event(
cls,
client_id: str,
user_id: str = None,
success: bool = True,
error_code: str = None,
error_description: str = None,
redirect_uri: str = None,
scope: list = None,
response_type: str = None
) -> OIDCAuditLog:
"""Log an authorization event.
Args:
client_id: OIDC client ID
user_id: User ID (if authenticated)
success: Whether authorization was successful
error_code: Error code if failed
error_description: Error description if failed
redirect_uri: Redirect URI from request
scope: Requested scopes
response_type: Response type (e.g., "code")
Returns:
OIDCAuditLog instance
"""
metadata = {
"redirect_uri": redirect_uri,
"scope": scope,
"response_type": response_type,
}
metadata = {k: v for k, v in metadata.items() if v is not None}
return cls.log_event(
event_type=cls.EVENT_AUTHORIZATION_REQUEST,
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
metadata=metadata,
)
@classmethod
def log_token_event(
cls,
client_id: str,
user_id: str = None,
token_type: str = "access_token",
success: bool = True,
error_code: str = None,
error_description: str = None,
grant_type: str = None,
scopes: list = None
) -> OIDCAuditLog:
"""Log a token issuance or refresh event.
Args:
client_id: OIDC client ID
user_id: User ID
token_type: Type of token issued
success: Whether token issuance was successful
error_code: Error code if failed
error_description: Error description if failed
grant_type: Grant type used (e.g., "authorization_code", "refresh_token")
scopes: Scopes included in the token
Returns:
OIDCAuditLog instance
"""
metadata = {
"token_type": token_type,
"grant_type": grant_type,
"scopes": scopes,
}
metadata = {k: v for k, v in metadata.items() if v is not None}
return cls.log_event(
event_type=cls.EVENT_TOKEN_ISSUE if token_type else cls.EVENT_TOKEN_REFRESH,
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
metadata=metadata,
)
@classmethod
def log_userinfo_event(
cls,
access_token: str = None,
user_id: str = None,
client_id: str = None,
success: bool = True,
error_code: str = None,
error_description: str = None,
scopes_claimed: list = None
) -> OIDCAuditLog:
"""Log a UserInfo endpoint access event.
Args:
access_token: Access token used (masked)
user_id: User ID returned
client_id: Client ID making the request
success: Whether access was successful
error_code: Error code if failed
error_description: Error description if failed
scopes_claimed: Scopes claimed in the request
Returns:
OIDCAuditLog instance
"""
# Mask the access token for security
masked_token = None
if access_token:
masked_token = access_token[:8] + "..." + access_token[-4:] if len(access_token) > 12 else "***"
metadata = {
"token_prefix": masked_token,
"scopes_claimed": scopes_claimed,
}
metadata = {k: v for k, v in metadata.items() if v is not None}
return cls.log_event(
event_type=cls.EVENT_USERINFO_ACCESS,
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
metadata=metadata,
)
@classmethod
def log_token_revocation_event(
cls,
client_id: str,
user_id: str = None,
token_type: str = "access_token",
reason: str = None,
success: bool = True,
error_code: str = None,
error_description: str = None
) -> OIDCAuditLog:
"""Log a token revocation event.
Args:
client_id: OIDC client ID
user_id: User ID
token_type: Type of token being revoked
reason: Revocation reason
success: Whether revocation was successful
error_code: Error code if failed
error_description: Error description if failed
Returns:
OIDCAuditLog instance
"""
metadata = {
"token_type": token_type,
"reason": reason,
}
metadata = {k: v for k, v in metadata.items() if v is not None}
return cls.log_event(
event_type=cls.EVENT_TOKEN_REVOCATION,
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
metadata=metadata,
)
@classmethod
def log_authentication_failure(
cls,
client_id: str = None,
error_code: str = "authentication_failed",
error_description: str = "Authentication failed",
user_id: str = None
) -> OIDCAuditLog:
"""Log an authentication failure event.
Args:
client_id: OIDC client ID
error_code: Error code
error_description: Error description
user_id: User ID if known
Returns:
OIDCAuditLog instance
"""
return cls.log_event(
event_type=cls.EVENT_AUTHENTICATION_FAILURE,
client_id=client_id,
user_id=user_id,
success=False,
error_code=error_code,
error_description=error_description,
)
@classmethod
def get_events_for_user(
cls,
user_id: str,
limit: int = 100,
include_deleted: bool = False
) -> List[OIDCAuditLog]:
"""Get audit events for a specific user.
Args:
user_id: User ID
limit: Maximum number of events to return
include_deleted: Include soft-deleted events
Returns:
List of OIDCAuditLog instances
"""
return OIDCAuditLog.get_events_for_user(user_id, limit)
@classmethod
def get_events_for_client(
cls,
client_id: str,
limit: int = 100
) -> List[OIDCAuditLog]:
"""Get audit events for a specific client.
Args:
client_id: Client ID
limit: Maximum number of events to return
Returns:
List of OIDCAuditLog instances
"""
return OIDCAuditLog.get_events_for_client(client_id, limit)
@classmethod
def get_failed_events(
cls,
client_id: str = None,
user_id: str = None,
start_date: datetime = None,
end_date: datetime = None,
limit: int = 100
) -> List[OIDCAuditLog]:
"""Get failed audit events for analysis.
Args:
client_id: Optional client ID filter
user_id: Optional user ID filter
start_date: Optional start date filter
end_date: Optional end date filter
limit: Maximum number of events to return
Returns:
List of failed OIDCAuditLog instances
"""
return OIDCAuditLog.get_failed_events(
client_id=client_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
limit=limit,
)
@classmethod
def get_event_summary(
cls,
client_id: str = None,
days: int = 30
) -> Dict:
"""Get a summary of audit events.
Args:
client_id: Optional client ID filter
days: Number of days to look back
Returns:
Summary dictionary with event counts
"""
from datetime import timedelta
start_date = datetime.now(timezone.utc) - timedelta(days=days)
query = OIDCAuditLog.query.filter(
OIDCAuditLog.created_at >= start_date
)
if client_id:
query = query.filter_by(client_id=client_id)
events = query.all()
# Count by event type
event_counts = {}
success_count = 0
failure_count = 0
for event in events:
event_type = event.event_type
event_counts[event_type] = event_counts.get(event_type, 0) + 1
if event.success:
success_count += 1
else:
failure_count += 1
return {
"total_events": len(events),
"successful_events": success_count,
"failed_events": failure_count,
"by_event_type": event_counts,
"period_days": days,
}
+418
View File
@@ -0,0 +1,418 @@
"""OIDC JWKS Service for key management and rotation."""
import uuid
import json
import hashlib
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, Tuple
from flask import current_app
from gatehouse_app.extensions import db
from gatehouse_app.models.oidc_jwks_key import OidcJwksKey
class JWKSKey:
"""Represents a JWKS key entry."""
def __init__(self, kid: str, private_key: str, public_key: str,
algorithm: str = "RS256", created_at: datetime = None,
expires_at: datetime = None, is_active: bool = True):
self.kid = kid
self.private_key = private_key
self.public_key = public_key
self.algorithm = algorithm
self.created_at = created_at or datetime.now(timezone.utc)
self.expires_at = expires_at or datetime.now(timezone.utc) + timedelta(days=365)
self.is_active = is_active
def to_jwk(self) -> Dict:
"""Convert to JWK format for JWKS endpoint."""
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from cryptography.hazmat.backends import default_backend
# Import cryptography here to avoid issues if not installed
try:
# Get public key from PEM
public_key = serialization.load_pem_public_key(
self.public_key.encode(), backend=default_backend()
)
# Get RSA parameters
public_numbers = public_key.public_numbers()
return {
"kty": "RSA",
"kid": self.kid,
"use": "sig",
"alg": self.algorithm,
"n": _base64url_encode(public_numbers.n),
"e": _base64url_encode(public_numbers.e),
}
except ImportError:
# Fallback for when cryptography is not installed
return {
"kty": "RSA",
"kid": self.kid,
"use": "sig",
"alg": self.algorithm,
}
def to_dict(self) -> Dict:
"""Convert to dictionary for storage."""
return {
"kid": self.kid,
"private_key": self.private_key,
"public_key": self.public_key,
"algorithm": self.algorithm,
"created_at": self.created_at.isoformat(),
"expires_at": self.expires_at.isoformat(),
"is_active": self.is_active,
}
@classmethod
def from_dict(cls, data: Dict) -> "JWKSKey":
"""Create from dictionary."""
return cls(
kid=data["kid"],
private_key=data["private_key"],
public_key=data["public_key"],
algorithm=data.get("algorithm", "RS256"),
created_at=datetime.fromisoformat(data["created_at"]),
expires_at=datetime.fromisoformat(data["expires_at"]),
is_active=data.get("is_active", True),
)
def _base64url_encode(value: int) -> str:
"""Encode an integer to base64url format."""
import base64
byte_length = (value.bit_length() + 7) // 8 or 1
encoded = value.to_bytes(byte_length, byteorder="big")
return base64.urlsafe_b64encode(encoded).decode().rstrip("=")
class OIDCJWKSService:
"""Service for managing OIDC signing keys (JWKS).
This service handles RSA key pair generation, rotation, and JWKS document
generation for the OIDC implementation.
"""
_instance = None
_keys: Dict[str, JWKSKey] = {}
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._keys = {}
return cls._instance
@classmethod
def reset(cls):
"""Reset the singleton (for testing)."""
cls._instance = None
cls._keys = {}
def _generate_kid(self, private_key: str) -> str:
"""Generate a key ID from the private key fingerprint."""
kid_hash = hashlib.sha256(private_key.encode()).hexdigest()[:32]
return kid_hash
def _generate_rsa_key_pair(self) -> Tuple[str, str]:
"""Generate a new RSA key pair in PEM format.
Returns:
Tuple of (private_key_pem, public_key_pem)
"""
try:
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
# Generate RSA private key
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
backend=default_backend()
)
# Get public key
public_key = private_key.public_key()
# Serialize to PEM
private_pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption()
).decode()
public_pem = public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
).decode()
return private_pem, public_pem
except ImportError:
# Fallback for testing without cryptography
import secrets
return f"private_key_{secrets.token_hex(32)}", f"public_key_{secrets.token_hex(32)}"
def get_jwks(self, include_private_keys: bool = False) -> Dict:
"""Get the JWKS document containing public keys.
Args:
include_private_keys: Whether to include private keys (for internal use only)
Returns:
JWKS document dictionary
"""
now = datetime.now(timezone.utc)
keys = []
for kid, key in self._keys.items():
# Only include active, non-expired keys
if key.is_active and key.expires_at > now:
if include_private_keys:
keys.append(key.to_dict())
else:
keys.append(key.to_jwk())
return {
"keys": keys
}
def load_keys_from_db(self) -> int:
"""Load existing keys from the database.
Returns:
Number of keys loaded
"""
try:
db_keys = OidcJwksKey.get_active_keys()
now = datetime.now(timezone.utc)
for db_key in db_keys:
# Create JWKSKey from database model
key = JWKSKey(
kid=db_key.kid,
private_key=db_key.private_key,
public_key=db_key.public_key,
algorithm=db_key.algorithm,
created_at=db_key.created_at,
expires_at=db_key.expires_at or now + timedelta(days=365),
is_active=db_key.is_active,
)
self._keys[db_key.kid] = key
return len(self._keys)
except Exception as e:
current_app.logger.error(f"Error loading keys from database: {e}")
return 0
def save_key_to_db(self, key: JWKSKey, is_primary: bool = False) -> OidcJwksKey:
"""Save a key to the database.
Args:
key: JWKSKey instance to save
is_primary: Whether this is the primary signing key
Returns:
OidcJwksKey database model instance
"""
db_key = OidcJwksKey(
kid=key.kid,
key_type="RSA",
algorithm=key.algorithm,
private_key=key.private_key,
public_key=key.public_key,
is_active=key.is_active,
is_primary=is_primary,
)
db.session.add(db_key)
db.session.commit()
return db_key
def get_signing_key(self) -> Optional[JWKSKey]:
"""Get the current active signing key.
Returns:
JWKSKey instance or None if no active key
"""
now = datetime.now(timezone.utc)
# First try to get the primary key from database
try:
primary_db_key = OidcJwksKey.get_primary_key()
if primary_db_key:
# Check if we have it in memory, if not load it
if primary_db_key.kid not in self._keys:
key = JWKSKey(
kid=primary_db_key.kid,
private_key=primary_db_key.private_key,
public_key=primary_db_key.public_key,
algorithm=primary_db_key.algorithm,
created_at=primary_db_key.created_at,
expires_at=primary_db_key.expires_at or now + timedelta(days=365),
is_active=primary_db_key.is_active,
)
self._keys[primary_db_key.kid] = key
return self._keys[primary_db_key.kid]
except Exception as e:
current_app.logger.error(f"Error getting primary key from database: {e}")
# Fall back to in-memory keys
for kid, key in self._keys.items():
if key.is_active and key.expires_at > now:
return key
return None
def get_key_by_kid(self, kid: str) -> Optional[JWKSKey]:
"""Get a specific key by its ID.
Args:
kid: Key ID to look up
Returns:
JWKSKey instance or None if not found
"""
return self._keys.get(kid)
def generate_new_key_pair(self, expires_in_days: int = 365) -> JWKSKey:
"""Generate a new RSA key pair for signing.
Args:
expires_in_days: Days until key expiration
Returns:
JWKSKey instance
"""
private_key, public_key = self._generate_rsa_key_pair()
kid = self._generate_kid(private_key)
now = datetime.now(timezone.utc)
key = JWKSKey(
kid=kid,
private_key=private_key,
public_key=public_key,
algorithm="RS256",
created_at=now,
expires_at=now + timedelta(days=expires_in_days),
is_active=True,
)
self._keys[kid] = key
# Deactivate old keys (but keep them for grace period)
for old_kid in self._keys:
if old_kid != kid:
self._keys[old_kid].is_active = False
return key
def rotate_keys(self, grace_period_hours: int = 24) -> Tuple[JWKSKey, List[str]]:
"""Rotate signing keys, keeping previous key active for grace period.
Args:
grace_period_hours: Hours to keep old keys active
Returns:
Tuple of (new_key, list_of_deprecated_kids)
"""
now = datetime.now(timezone.utc)
grace_end = now + timedelta(hours=grace_period_hours)
# Mark current key as deprecated
current_key = self.get_signing_key()
deprecated_kids = []
if current_key:
deprecated_kids.append(current_key.kid)
# Keep key active but mark as deprecated
current_key.is_active = False
current_key.expires_at = grace_end
# Generate new key
new_key = self.generate_new_key_pair()
# Clean up expired keys
expired_kids = [
kid for kid, key in self._keys.items()
if key.expires_at < now
]
for kid in expired_kids:
del self._keys[kid]
return new_key, deprecated_kids
def verify_key_exists(self, kid: str) -> bool:
"""Check if a key with the given ID exists and is valid.
Args:
kid: Key ID to check
Returns:
True if key exists and is valid
"""
key = self.get_key_by_kid(kid)
if not key:
return False
now = datetime.now(timezone.utc)
return key.is_active and key.expires_at > now
def initialize_with_key(self) -> JWKSKey:
"""Initialize the service with a key, loading from database if available.
This method first attempts to load existing keys from the database.
If no active primary key exists, it generates a new key and saves it to the database.
Returns:
JWKSKey instance
"""
# First, try to load keys from database
try:
# Check if there's a primary key in the database
primary_db_key = OidcJwksKey.get_primary_key()
if primary_db_key:
# Load the primary key into memory
now = datetime.now(timezone.utc)
key = JWKSKey(
kid=primary_db_key.kid,
private_key=primary_db_key.private_key,
public_key=primary_db_key.public_key,
algorithm=primary_db_key.algorithm,
created_at=primary_db_key.created_at,
expires_at=primary_db_key.expires_at or now + timedelta(days=365),
is_active=primary_db_key.is_active,
)
self._keys[primary_db_key.kid] = key
current_app.logger.info(f"[OIDC] Loaded existing signing key from database: kid={primary_db_key.kid}")
return key
# Try to load all active keys from database
loaded_count = self.load_keys_from_db()
if loaded_count > 0:
# Get the signing key from loaded keys
signing_key = self.get_signing_key()
if signing_key:
current_app.logger.info(f"[OIDC] Loaded {loaded_count} keys from database, using signing key: kid={signing_key.kid}")
return signing_key
except Exception as e:
current_app.logger.error(f"Error loading keys from database: {e}")
# No keys in database, generate a new key and save it
current_app.logger.info("[OIDC] No existing keys found in database, generating new signing key")
new_key = self.generate_new_key_pair()
# Save the new key to database
try:
self.save_key_to_db(new_key, is_primary=True)
current_app.logger.info(f"[OIDC] Saved new signing key to database: kid={new_key.kid}")
except Exception as e:
current_app.logger.error(f"Error saving key to database: {e}")
return new_key
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,289 @@
"""OIDC Session Service for session management during OIDC flow."""
import secrets
from datetime import datetime, timedelta
from typing import Dict, Optional, Tuple
from datetime import timezone
from flask import current_app, g
from gatehouse_app.extensions import db
from gatehouse_app.models import OIDCSession, OIDCClient, User
from gatehouse_app.exceptions.validation_exceptions import NotFoundError, ValidationError
class OIDCSessionService:
"""Service for managing OIDC authentication sessions.
This service handles:
- Creating OIDC sessions during authorization flow
- Validating sessions with state and nonce
- Managing PKCE code challenges
- Cleaning up expired sessions
"""
@staticmethod
def _generate_state() -> str:
"""Generate a secure state parameter.
Returns:
URL-safe base64 encoded state
"""
return secrets.token_urlsafe(32)
@staticmethod
def _generate_nonce() -> str:
"""Generate a secure nonce for OIDC.
Returns:
URL-safe base64 encoded nonce
"""
return secrets.token_urlsafe(32)
@staticmethod
def _generate_code_challenge(verifier: str, method: str = "S256") -> str:
"""Generate a PKCE code challenge from verifier.
Args:
verifier: The code verifier
method: Challenge method ("S256" or "plain")
Returns:
Code challenge string
"""
import hashlib
import base64
if method == "S256":
digest = hashlib.sha256(verifier.encode()).digest()
return base64.urlsafe_b64encode(digest).decode().rstrip("=")
elif method == "plain":
return verifier
else:
raise ValueError(f"Unsupported code challenge method: {method}")
@classmethod
def validate_code_verifier(cls, code_verifier: str, code_challenge: str,
method: str = "S256") -> bool:
"""Validate a PKCE code verifier against the stored challenge.
Args:
code_verifier: The code verifier from the token request
code_challenge: The code challenge from the authorization request
method: The challenge method used
Returns:
True if validation succeeds
"""
if not code_verifier or not code_challenge:
return False
# Validate code verifier length (43-128 characters)
if method == "S256" and not (43 <= len(code_verifier) <= 128):
return False
# Calculate expected challenge
expected_challenge = cls._generate_code_challenge(code_verifier, method)
return secrets.compare_digest(expected_challenge, code_challenge)
@classmethod
def create_session(
cls,
user_id: str,
client_id: str,
state: str = None,
nonce: str = None,
redirect_uri: str = None,
scope: list = None,
code_challenge: str = None,
code_challenge_method: str = None,
lifetime_seconds: int = 600
) -> OIDCSession:
"""Create a new OIDC session for the authorization flow.
Args:
user_id: The user ID
client_id: The OIDC client ID
state: State parameter (generated if not provided)
nonce: Nonce for ID token validation (generated if not provided)
redirect_uri: Redirect URI from authorization request
scope: Requested scopes
code_challenge: PKCE code challenge
code_challenge_method: PKCE method ("S256" or "plain")
lifetime_seconds: Session lifetime in seconds
Returns:
OIDCSession instance
"""
# Generate state and nonce if not provided
state = state or cls._generate_state()
nonce = nonce or cls._generate_nonce()
session = OIDCSession.create_session(
user_id=user_id,
client_id=client_id,
state=state,
nonce=nonce,
redirect_uri=redirect_uri,
scope=scope,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
lifetime_seconds=lifetime_seconds,
)
return session
@classmethod
def validate_session(cls, state: str, nonce: str = None) -> Tuple[OIDCSession, User]:
"""Validate an OIDC session by state and optionally nonce.
Args:
state: The state parameter
nonce: The nonce to validate (optional)
Returns:
Tuple of (OIDCSession, User)
Raises:
ValidationError: If session is invalid
NotFoundError: If session not found
"""
session = OIDCSession.get_by_state(state)
if not session:
raise NotFoundError("OIDC session not found or expired")
if session.is_expired():
raise ValidationError("OIDC session has expired")
# Validate nonce if provided
if nonce and not session.validate_nonce(nonce):
raise ValidationError("Invalid nonce")
# Get user
user = User.query.get(session.user_id)
if not user:
raise NotFoundError("User not found")
return session, user
@classmethod
def validate_pkce(cls, session: OIDCSession, code_verifier: str) -> bool:
"""Validate PKCE code verifier against the session's code challenge.
Args:
session: OIDCSession instance
code_verifier: The code verifier from token request
Returns:
True if validation succeeds
Raises:
ValidationError: If PKCE validation fails
"""
if not session.code_challenge:
# No PKCE was used, skip validation
return True
if not code_verifier:
raise ValidationError("code_verifier is required")
is_valid = session.validate_code_challenge(code_verifier)
if not is_valid:
raise ValidationError("Invalid code_verifier")
return True
@classmethod
def mark_session_authenticated(cls, session: OIDCSession) -> OIDCSession:
"""Mark a session as authenticated (user has logged in).
Args:
session: OIDCSession instance
Returns:
Updated OIDCSession instance
"""
session.mark_authenticated()
return session
@classmethod
def cleanup_expired_sessions(cls, older_than_hours: int = 24) -> int:
"""Remove expired OIDC sessions.
Args:
older_than_hours: Only delete sessions expired more than this many hours ago
Returns:
Number of sessions deleted
"""
from datetime import timedelta
cutoff = datetime.now(timezone.utc) - timedelta(hours=older_than_hours)
# Get expired sessions
expired_sessions = OIDCSession.query.filter(
OIDCSession.expires_at < datetime.now(timezone.utc),
OIDCSession.deleted_at == None
).all()
count = 0
for session in expired_sessions:
# Only hard delete if past the grace period
if session.expires_at < cutoff:
session.delete()
count += 1
return count
@classmethod
def get_session_by_state(cls, state: str) -> Optional[OIDCSession]:
"""Get an OIDC session by state.
Args:
state: The state parameter
Returns:
OIDCSession instance or None
"""
return OIDCSession.get_by_state(state)
@classmethod
def validate_redirect_uri(cls, client_id: str, redirect_uri: str) -> bool:
"""Validate that a redirect URI is allowed for a client.
Args:
client_id: The OIDC client ID
redirect_uri: The redirect URI to validate
Returns:
True if redirect URI is allowed
"""
client = OIDCClient.query.filter_by(client_id=client_id).first()
if not client:
return False
return client.is_redirect_uri_allowed(redirect_uri)
@classmethod
def validate_scopes(cls, client_id: str, requested_scopes: list) -> list:
"""Validate and filter scopes against client's allowed scopes.
Args:
client_id: The OIDC client ID
requested_scopes: List of requested scopes
Returns:
List of allowed scopes
"""
client = OIDCClient.query.filter_by(client_id=client_id).first()
if not client:
return []
allowed_scopes = client.scopes or []
# Filter to only allowed scopes
valid_scopes = [s for s in requested_scopes if s in allowed_scopes]
return valid_scopes
@@ -0,0 +1,593 @@
"""OIDC Token Service for JWT token generation and validation."""
import hashlib
import base64
import secrets
import logging
import time
from datetime import datetime, timedelta, timezone
from typing import Dict, Optional, Any
import jwt
from flask import current_app, g
from gatehouse_app.models import User, OIDCClient
from gatehouse_app.models.organization_member import OrganizationMember
from gatehouse_app.services.oidc_jwks_service import OIDCJWKSService
logger = logging.getLogger(__name__)
class OIDCTokenService:
"""Service for generating and validating OIDC tokens.
This service handles:
- Access token creation (JWT)
- ID token creation (JWT)
- Refresh token creation (opaque)
- Token signature verification
- Hash generation for PKCE claims (at_hash, c_hash)
"""
@staticmethod
def _generate_jti() -> str:
"""Generate a unique JWT ID."""
return secrets.token_urlsafe(32)
@staticmethod
def _generate_opaque_token(length: int = 43) -> str:
"""Generate an opaque token (for refresh tokens).
Args:
length: Length of the token
Returns:
URL-safe base64 encoded token
"""
return secrets.token_urlsafe(length)
@staticmethod
def _hash_token(token: str) -> str:
"""Hash a token for secure storage.
Args:
token: Token to hash
Returns:
SHA256 hash of the token
"""
return hashlib.sha256(token.encode()).hexdigest()
@staticmethod
def _base64url_encode(data: bytes) -> str:
"""Encode bytes to base64url format without padding.
Args:
data: Bytes to encode
Returns:
Base64url encoded string
"""
return base64.urlsafe_b64encode(data).decode().rstrip("=")
@staticmethod
def create_at_hash(access_token: str) -> str:
"""Create the at_hash claim for ID token.
Implements OIDC spec for access token hash generation.
Hash is the left-most half of the hash of the ASCII representation
of the access token.
Args:
access_token: The access token string
Returns:
Base64url encoded hash
"""
# Hash the access token using SHA256
hash_digest = hashlib.sha256(access_token.encode()).digest()
# Take left-most half of the hash
half_length = len(hash_digest) // 2
left_half = hash_digest[:half_length]
# Base64url encode
return OIDCTokenService._base64url_encode(left_half)
@staticmethod
def create_c_hash(code: str) -> str:
"""Create the c_hash claim for ID token.
Implements OIDC spec for authorization code hash generation.
Args:
code: The authorization code string
Returns:
Base64url encoded hash
"""
# Hash the code using SHA256
hash_digest = hashlib.sha256(code.encode()).digest()
# Take left-most half of the hash
half_length = len(hash_digest) // 2
left_half = hash_digest[:half_length]
# Base64url encode
return OIDCTokenService._base64url_encode(left_half)
@staticmethod
def _get_issuer() -> str:
"""Get the OIDC issuer URL."""
return current_app.config.get("OIDC_ISSUER_URL", "http://localhost:5000")
@staticmethod
def _get_token_lifetime(client: OIDCClient, token_type: str) -> int:
"""Get the token lifetime in seconds for a client.
Args:
client: OIDCClient instance
token_type: Type of token ("access_token", "refresh_token", "id_token")
Returns:
Lifetime in seconds
"""
lifetimes = {
"access_token": client.access_token_lifetime or 3600,
"refresh_token": client.refresh_token_lifetime or 2592000,
"id_token": client.id_token_lifetime or 3600,
}
return lifetimes.get(token_type, 3600)
@classmethod
def create_access_token(cls, client_id: str, user_id: str, scope: list,
jti: str = None) -> str:
"""Create a JWT access token.
Args:
client_id: The OIDC client ID
user_id: The user ID (subject)
scope: List of granted scopes
jti: Optional JWT ID (generated if not provided)
Returns:
JWT access token string
"""
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
logger.debug("[OIDC TOKEN SERVICE] create_access_token called")
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
logger.debug("[OIDC TOKEN SERVICE] scope=%s", scope)
jti = jti or cls._generate_jti()
now_timestamp = int(time.time())
now = datetime.now(timezone.utc)
logger.debug("[OIDC TOKEN SERVICE] Token creation time (UTC): %s", now.isoformat())
logger.debug("[OIDC TOKEN SERVICE] Token creation timestamp: %s", now_timestamp)
# Get client for token lifetime
client = OIDCClient.query.filter_by(client_id=client_id).first()
lifetime = cls._get_token_lifetime(client, "access_token") if client else 3600
logger.debug("[OIDC TOKEN SERVICE] Access token lifetime (seconds): %s", lifetime)
exp_timestamp = now_timestamp + lifetime
exp_time = now + timedelta(seconds=lifetime)
logger.debug("[OIDC TOKEN SERVICE] Access token expiration time (UTC): %s", exp_time.isoformat())
logger.debug("[OIDC TOKEN SERVICE] Access token expiration timestamp: %s", exp_timestamp)
logger.debug("[OIDC TOKEN SERVICE] Time until expiration (seconds): %s", lifetime)
claims = {
"iss": cls._get_issuer(),
"sub": user_id,
"aud": client_id,
"exp": exp_timestamp,
"iat": now_timestamp,
"nbf": now_timestamp,
"jti": jti,
"client_id": client_id,
"scope": " ".join(scope) if isinstance(scope, list) else scope,
}
logger.debug("[OIDC TOKEN SERVICE] Token claims: exp=%s, iat=%s, nbf=%s",
claims["exp"], claims["iat"], claims["nbf"])
# Get signing key
jwks_service = OIDCJWKSService()
signing_key = jwks_service.get_signing_key()
if not signing_key:
raise ValueError("No signing key available")
# Sign with RS256
logger.debug("[OIDC TOKEN SERVICE] Signing token with RS256...")
token = jwt.encode(
claims,
signing_key.private_key,
algorithm="RS256",
headers={"kid": signing_key.kid}
)
logger.debug("[OIDC TOKEN SERVICE] Access token created successfully")
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
return token
@classmethod
def create_id_token(cls, client_id: str, user_id: str, nonce: str = None,
scope: list = None, access_token: str = None,
auth_time: int = None) -> str:
"""Create a JWT ID token.
Args:
client_id: The OIDC client ID
user_id: The user ID (subject)
nonce: Nonce for replay protection
scope: Requested/Granted scopes
access_token: Associated access token (for at_hash)
auth_time: Authentication time (Unix timestamp)
Returns:
JWT ID token string
"""
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
logger.debug("[OIDC TOKEN SERVICE] create_id_token called")
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
logger.debug("[OIDC TOKEN SERVICE] nonce=%s, auth_time=%s", nonce, auth_time)
logger.debug("[OIDC TOKEN SERVICE] scope=%s", scope)
now_timestamp = int(time.time())
now = datetime.now(timezone.utc)
logger.debug("[OIDC TOKEN SERVICE] Token creation time (UTC): %s", now.isoformat())
logger.debug("[OIDC TOKEN SERVICE] Token creation timestamp: %s", now_timestamp)
auth_time = auth_time or now_timestamp
logger.debug("[OIDC TOKEN SERVICE] auth_time (Unix timestamp): %s", auth_time)
# Get client for token lifetime
client = OIDCClient.query.filter_by(client_id=client_id).first()
lifetime = cls._get_token_lifetime(client, "id_token") if client else 3600
logger.debug("[OIDC TOKEN SERVICE] ID token lifetime (seconds): %s", lifetime)
exp_timestamp = now_timestamp + lifetime
exp_time = now + timedelta(seconds=lifetime)
logger.debug("[OIDC TOKEN SERVICE] ID token expiration time (UTC): %s", exp_time.isoformat())
logger.debug("[OIDC TOKEN SERVICE] ID token expiration timestamp: %s", exp_timestamp)
logger.debug("[OIDC TOKEN SERVICE] Time until expiration (seconds): %s", lifetime)
# Get user for claims
user = User.query.get(user_id)
claims = {
"iss": cls._get_issuer(),
"sub": user_id,
"aud": client_id,
"exp": exp_timestamp,
"iat": now_timestamp,
"auth_time": auth_time,
}
logger.debug("[OIDC TOKEN SERVICE] Token claims: exp=%s, iat=%s, auth_time=%s",
claims["exp"], claims["iat"], claims["auth_time"])
# Add nonce if provided
if nonce:
claims["nonce"] = nonce
# Add at_hash if access token provided
if access_token:
claims["at_hash"] = cls.create_at_hash(access_token)
# Add standard claims if user exists
if user:
if user.email:
claims["email"] = user.email
claims["email_verified"] = user.email_verified
if user.full_name:
claims["name"] = user.full_name
# Add roles claim if scope is granted
if scope and "roles" in scope:
claims["roles"] = cls._get_user_roles(user)
# Add scope if provided
if scope:
claims["scope"] = " ".join(scope) if isinstance(scope, list) else scope
# Get signing key
jwks_service = OIDCJWKSService()
signing_key = jwks_service.get_signing_key()
if not signing_key:
raise ValueError("No signing key available")
# Sign with RS256
logger.debug("[OIDC TOKEN SERVICE] Signing token with RS256...")
token = jwt.encode(
claims,
signing_key.private_key,
algorithm="RS256",
headers={"kid": signing_key.kid}
)
logger.debug("[OIDC TOKEN SERVICE] ID token created successfully")
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
return token
@staticmethod
def _get_user_roles(user: User) -> list:
"""Get user's organization roles.
Args:
user: User instance
Returns:
List of role objects with organization_id and role
"""
roles = []
if user and user.organization_memberships:
for member in user.organization_memberships:
roles.append({
"organization_id": str(member.organization_id),
"role": member.role.value
})
return roles
@classmethod
def create_refresh_token(cls, client_id: str, user_id: str,
scope: list = None, access_token_id: str = None) -> str:
"""Create an opaque refresh token.
Args:
client_id: The OIDC client ID
user_id: The user ID
scope: List of granted scopes
access_token_id: Associated access token ID
Returns:
Opaque refresh token string
"""
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
logger.debug("[OIDC TOKEN SERVICE] create_refresh_token called")
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
logger.debug("[OIDC TOKEN SERVICE] scope=%s, access_token_id=%s", scope, access_token_id)
token = cls._generate_opaque_token()
logger.debug("[OIDC TOKEN SERVICE] Refresh token generated: %s...", token[:20] if token else None)
# Hash for storage
token_hash = cls._hash_token(token)
logger.debug("[OIDC TOKEN SERVICE] Refresh token created successfully")
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
return token, token_hash
@classmethod
def verify_token_signature(cls, token: str) -> Dict:
"""Verify the signature of a JWT token.
Args:
token: JWT token string
Returns:
Decoded token claims
Raises:
jwt.InvalidSignatureError: If signature verification fails
jwt.ExpiredSignatureError: If token is expired
jwt.InvalidTokenError: If token is invalid
"""
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
logger.debug("[OIDC TOKEN SERVICE] verify_token_signature() called")
logger.debug("[OIDC TOKEN SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token)
logger.debug("[OIDC TOKEN SERVICE] Token length: %d", len(token))
# Get the JWKS with public keys
logger.debug("[OIDC TOKEN SERVICE] Getting JWKS...")
jwks_service = OIDCJWKSService()
jwks = jwks_service.get_jwks(include_private_keys=True)
logger.debug("[OIDC TOKEN SERVICE] JWKS retrieved: %d keys", len(jwks.get("keys", [])))
# Get the key ID from token header
try:
logger.debug("[OIDC TOKEN SERVICE] Getting unverified token header...")
unverified_header = jwt.get_unverified_header(token)
logger.debug("[OIDC TOKEN SERVICE] Unverified header: %s", unverified_header)
except jwt.DecodeError as e:
logger.error("[OIDC TOKEN SERVICE] Failed to decode token header: %s", str(e))
raise jwt.InvalidTokenError("Invalid token header")
kid = unverified_header.get("kid")
logger.debug("[OIDC TOKEN SERVICE] Key ID (kid) from token header: %s", kid)
# Find the matching public key
logger.debug("[OIDC TOKEN SERVICE] Searching for matching public key...")
public_key = None
for idx, key in enumerate(jwks.get("keys", [])):
logger.debug("[OIDC TOKEN SERVICE] Checking key %d: kid=%s", idx, key.get("kid"))
if key.get("kid") == kid:
logger.debug("[OIDC TOKEN SERVICE] Found matching key at index %d", idx)
try:
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
logger.debug("[OIDC TOKEN SERVICE] Loading PEM public key...")
public_key = serialization.load_pem_public_key(
key["public_key"].encode() if isinstance(key["public_key"], str)
else key["public_key"],
backend=default_backend()
)
logger.debug("[OIDC TOKEN SERVICE] Public key loaded successfully")
break
except (ImportError, Exception) as e:
logger.error("[OIDC TOKEN SERVICE] Failed to load public key: %s: %s", type(e).__name__, str(e))
continue
if not public_key:
logger.error("[OIDC TOKEN SERVICE] No matching public key found for kid=%s", kid)
raise jwt.InvalidSignatureError(f"Key with kid={kid} not found")
logger.debug("[OIDC TOKEN SERVICE] Public key found, verifying signature...")
# Verify the signature
try:
claims = jwt.decode(
token,
public_key,
algorithms=["RS256"],
audience=None, # We'll validate audience separately
issuer=cls._get_issuer(),
options={
"verify_signature": True,
"verify_exp": True,
"verify_aud": False, # Handle audience manually
"verify_iss": False, # Handle issuer manually
}
)
logger.debug("[OIDC TOKEN SERVICE] Signature verification successful")
logger.debug("[OIDC TOKEN SERVICE] Decoded claims: %s", claims)
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
return claims
except jwt.ExpiredSignatureError as e:
logger.error("[OIDC TOKEN SERVICE] Token has expired: %s", str(e))
raise
except jwt.InvalidSignatureError as e:
logger.error("[OIDC TOKEN SERVICE] Invalid token signature: %s", str(e))
raise
except jwt.InvalidTokenError as e:
logger.error("[OIDC TOKEN SERVICE] Invalid token: %s: %s", type(e).__name__, str(e))
raise
except Exception as e:
logger.error("[OIDC TOKEN SERVICE] Unexpected error during token verification: %s: %s", type(e).__name__, str(e))
import traceback
logger.error("[OIDC TOKEN SERVICE] Traceback: %s", traceback.format_exc())
raise
@classmethod
def decode_token(cls, token: str, verify: bool = False) -> Dict:
"""Decode a JWT token without verification (for debugging).
Args:
token: JWT token string
verify: Whether to verify signature
Returns:
Decoded token claims
"""
if verify:
return cls.verify_token_signature(token)
return jwt.decode(
token,
options={
"verify_signature": False,
"verify_exp": False,
}
)
@classmethod
def validate_access_token(cls, token: str, client_id: str = None) -> Dict:
"""Validate an access token and return its claims.
Args:
token: JWT access token
client_id: Optional client ID to validate audience
Returns:
Token claims dictionary
Raises:
jwt.InvalidTokenError: If token is invalid
ValueError: If token is expired or audience mismatch
"""
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
logger.debug("[OIDC TOKEN SERVICE] validate_access_token() called")
logger.debug("[OIDC TOKEN SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token)
logger.debug("[OIDC TOKEN SERVICE] Token length: %d", len(token))
logger.debug("[OIDC TOKEN SERVICE] Client ID: %s", client_id)
# Verify token signature
logger.debug("[OIDC TOKEN SERVICE] Verifying token signature...")
claims = cls.verify_token_signature(token)
logger.debug("[OIDC TOKEN SERVICE] Token signature verified")
logger.debug("[OIDC TOKEN SERVICE] Claims: %s", claims)
# Check expiration
exp = claims.get("exp", 0)
now_timestamp = int(time.time())
if exp < now_timestamp:
logger.error("[OIDC TOKEN SERVICE] Token has expired")
raise ValueError("Token has expired")
# Validate audience if client_id provided
aud = claims.get("aud")
logger.debug("[OIDC TOKEN SERVICE] Token audience (aud): %s", aud)
logger.debug("[OIDC TOKEN SERVICE] Expected client_id: %s", client_id)
if client_id:
if aud != client_id:
logger.error("[OIDC TOKEN SERVICE] Audience mismatch: expected=%s, got=%s", client_id, aud)
raise ValueError("Invalid audience")
logger.debug("[OIDC TOKEN SERVICE] Audience validation passed")
else:
logger.debug("[OIDC TOKEN SERVICE] No client_id provided, skipping audience validation")
logger.debug("[OIDC TOKEN SERVICE] validate_access_token() completed successfully")
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
return claims
@classmethod
def introspect_token(cls, token: str, client_id: str = None) -> Dict:
"""Introspect a token and return its status and claims.
Args:
token: JWT token to introspect
client_id: Client ID for audience validation
Returns:
Dictionary with active status and claims
"""
result = {
"active": False,
}
try:
claims = cls.validate_access_token(token, client_id)
# Calculate remaining time
now_timestamp = int(time.time())
now = datetime.now(timezone.utc)
exp = claims.get("exp", 0)
iat = claims.get("iat", 0)
logger.debug("[OIDC TOKEN SERVICE] Introspection - Current UTC time: %s", now.isoformat())
logger.debug("[OIDC TOKEN SERVICE] Introspection - Token expiration timestamp: %s", exp)
logger.debug("[OIDC TOKEN SERVICE] Introspection - Token expiration datetime (UTC): %s", datetime.fromtimestamp(exp, tz=timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] Introspection - Time until expiration: %s seconds", exp - now_timestamp)
result["active"] = exp > now_timestamp
result.update({
"iss": claims.get("iss"),
"sub": claims.get("sub"),
"aud": claims.get("aud"),
"exp": exp,
"iat": iat,
"nbf": claims.get("nbf"),
"jti": claims.get("jti"),
"client_id": claims.get("client_id"),
"scope": claims.get("scope"),
"token_type": "Bearer",
})
# Add expiry in seconds
if exp > now_timestamp:
result["exp"] = int(exp - now_timestamp)
except (jwt.InvalidTokenError, ValueError) as e:
result["active"] = False
result["error"] = str(e)
return result
@@ -0,0 +1,303 @@
"""Organization service."""
import logging
from datetime import datetime, timezone
from flask import current_app
from gatehouse_app.extensions import db
from gatehouse_app.models.organization import Organization
from gatehouse_app.models.organization_member import OrganizationMember
from gatehouse_app.exceptions.validation_exceptions import OrganizationNotFoundError, ConflictError
from gatehouse_app.utils.constants import OrganizationRole, AuditAction
from gatehouse_app.services.audit_service import AuditService
logger = logging.getLogger(__name__)
class OrganizationService:
"""Service for organization operations."""
@staticmethod
def create_organization(name, slug, owner_user_id, description=None, logo_url=None):
"""
Create a new organization.
Args:
name: Organization name
slug: Unique organization slug
owner_user_id: ID of the user who will be the owner
description: Optional description
logo_url: Optional logo URL
Returns:
Organization instance
Raises:
ConflictError: If slug already exists
"""
# Check if slug already exists
existing = Organization.query.filter_by(slug=slug, deleted_at=None).first()
if existing:
raise ConflictError("Organization slug already exists")
# Create organization
org = Organization(
name=name,
slug=slug,
description=description,
logo_url=logo_url,
is_active=True,
)
org.save()
# Add owner as member
member = OrganizationMember(
user_id=owner_user_id,
organization_id=org.id,
role=OrganizationRole.OWNER,
joined_at=datetime.now(timezone.utc),
)
member.save()
# Log organization creation
AuditService.log_action(
action=AuditAction.ORG_CREATE,
user_id=owner_user_id,
organization_id=org.id,
resource_type="organization",
resource_id=org.id,
description=f"Organization created: {name}",
)
return org
@staticmethod
def get_organization_by_id(org_id):
"""
Get organization by ID.
Args:
org_id: Organization ID
Returns:
Organization instance
Raises:
OrganizationNotFoundError: If organization not found
"""
org = Organization.query.filter_by(id=org_id, deleted_at=None).first()
# Development-only debug logging for organization validation
if current_app.config.get('ENV') == 'development':
logger.debug(f"[Org] Get organization by ID: org_id={org_id}, exists={org is not None}")
if not org:
raise OrganizationNotFoundError()
return org
@staticmethod
def get_organization_by_slug(slug):
"""
Get organization by slug.
Args:
slug: Organization slug
Returns:
Organization instance or None
"""
org = Organization.query.filter_by(slug=slug, deleted_at=None).first()
# Development-only debug logging for organization validation
if current_app.config.get('ENV') == 'development':
logger.debug(f"[Org] Get organization by slug: slug={slug}, exists={org is not None}")
return org
@staticmethod
def update_organization(org, user_id, **kwargs):
"""
Update organization.
Args:
org: Organization instance
user_id: ID of user performing the update
**kwargs: Fields to update
Returns:
Updated Organization instance
"""
allowed_fields = ["name", "description", "logo_url"]
update_data = {k: v for k, v in kwargs.items() if k in allowed_fields}
if update_data:
org.update(**update_data)
# Log organization update
AuditService.log_action(
action=AuditAction.ORG_UPDATE,
user_id=user_id,
organization_id=org.id,
resource_type="organization",
resource_id=org.id,
metadata=update_data,
description="Organization updated",
)
return org
@staticmethod
def delete_organization(org, user_id, soft=True):
"""
Delete organization.
Args:
org: Organization instance
user_id: ID of user performing the delete
soft: If True, performs soft delete
Returns:
Deleted Organization instance
"""
org.delete(soft=soft)
# Log organization deletion
AuditService.log_action(
action=AuditAction.ORG_DELETE,
user_id=user_id,
organization_id=org.id,
resource_type="organization",
resource_id=org.id,
description=f"Organization {'soft' if soft else 'hard'} deleted",
)
return org
@staticmethod
def add_member(org, user_id, role, inviter_id):
"""
Add a member to the organization.
Args:
org: Organization instance
user_id: ID of user to add
role: OrganizationRole
inviter_id: ID of user performing the invitation
Returns:
OrganizationMember instance
Raises:
ConflictError: If user is already a member
"""
# Check if already a member
existing = OrganizationMember.query.filter_by(
user_id=user_id,
organization_id=org.id,
deleted_at=None,
).first()
# Development-only debug logging for membership validation
if current_app.config.get('ENV') == 'development':
logger.debug(f"[Org] Member check: org_id={org.id}, user_id={user_id}, already_member={existing is not None}")
if existing:
raise ConflictError("User is already a member of this organization")
# Create membership
member = OrganizationMember(
user_id=user_id,
organization_id=org.id,
role=role,
invited_by_id=inviter_id,
invited_at=datetime.now(timezone.utc),
joined_at=datetime.now(timezone.utc),
)
member.save()
# Log member addition
AuditService.log_action(
action=AuditAction.ORG_MEMBER_ADD,
user_id=inviter_id,
organization_id=org.id,
resource_type="organization_member",
resource_id=member.id,
metadata={"added_user_id": user_id, "role": role.value},
description=f"Member added to organization with role: {role.value}",
)
return member
@staticmethod
def remove_member(org, user_id, remover_id):
"""
Remove a member from the organization.
Args:
org: Organization instance
user_id: ID of user to remove
remover_id: ID of user performing the removal
"""
member = OrganizationMember.query.filter_by(
user_id=user_id,
organization_id=org.id,
deleted_at=None,
).first()
# Development-only debug logging for membership removal validation
if current_app.config.get('ENV') == 'development':
logger.debug(f"[Org] Member removal: org_id={org.id}, user_id={user_id}, found={member is not None}")
if member:
member.delete(soft=True)
# Log member removal
AuditService.log_action(
action=AuditAction.ORG_MEMBER_REMOVE,
user_id=remover_id,
organization_id=org.id,
resource_type="organization_member",
resource_id=member.id,
metadata={"removed_user_id": user_id},
description="Member removed from organization",
)
@staticmethod
def update_member_role(org, user_id, new_role, updater_id):
"""
Update a member's role in the organization.
Args:
org: Organization instance
user_id: ID of user whose role to update
new_role: New OrganizationRole
updater_id: ID of user performing the update
Returns:
Updated OrganizationMember instance
"""
member = OrganizationMember.query.filter_by(
user_id=user_id,
organization_id=org.id,
deleted_at=None,
).first()
if member:
old_role = member.role
member.role = new_role
db.session.commit()
# Log role change
AuditService.log_action(
action=AuditAction.ORG_MEMBER_ROLE_CHANGE,
user_id=updater_id,
organization_id=org.id,
resource_type="organization_member",
resource_id=member.id,
metadata={
"target_user_id": user_id,
"old_role": old_role.value,
"new_role": new_role.value,
},
description=f"Member role changed from {old_role.value} to {new_role.value}",
)
return member
+76
View File
@@ -0,0 +1,76 @@
"""Session service."""
from datetime import datetime, timezone
from gatehouse_app.models.session import Session
from gatehouse_app.utils.constants import SessionStatus
class SessionService:
"""Service for session operations."""
@staticmethod
def get_active_session_by_token(token):
"""Get active session by token.
Args:
token: The session token string
Returns:
Session object if found and active, None otherwise
"""
from gatehouse_app.models.session import Session
from gatehouse_app.utils.constants import SessionStatus
return Session.query.filter_by(
token=token,
status=SessionStatus.ACTIVE,
deleted_at=None
).first()
@staticmethod
def get_user_sessions(user_id, active_only=True):
"""
Get all sessions for a user.
Args:
user_id: User ID
active_only: If True, only return active sessions
Returns:
List of Session instances
"""
query = Session.query.filter_by(user_id=user_id, deleted_at=None)
if active_only:
query = query.filter_by(status=SessionStatus.ACTIVE).filter(
Session.expires_at > datetime.now(timezone.utc)
)
return query.all()
@staticmethod
def revoke_user_sessions(user_id, reason="User logged out from all devices"):
"""
Revoke all active sessions for a user.
Args:
user_id: User ID
reason: Reason for revocation
"""
sessions = SessionService.get_user_sessions(user_id, active_only=True)
for session in sessions:
session.revoke(reason=reason)
@staticmethod
def cleanup_expired_sessions():
"""Clean up expired sessions."""
expired_sessions = Session.query.filter(
Session.status == SessionStatus.ACTIVE,
Session.expires_at < datetime.now(timezone.utc),
Session.deleted_at.is_(None),
).all()
for session in expired_sessions:
session.status = SessionStatus.EXPIRED
session.save()
return len(expired_sessions)
+214
View File
@@ -0,0 +1,214 @@
"""TOTP (Time-based One-Time Password) service."""
import base64
import io
import logging
import secrets
from datetime import datetime, timezone
from typing import Tuple
import pyotp
from gatehouse_app.extensions import bcrypt
logger = logging.getLogger(__name__)
class TOTPService:
"""Service for TOTP operations."""
@staticmethod
def generate_secret() -> str:
"""
Generate a new TOTP secret.
Returns:
Base32 encoded secret (32 characters)
Note:
The secret is generated using cryptographically secure random bytes
and encoded in base32 format for compatibility with authenticator apps.
"""
# Generate 20 random bytes (160 bits) and encode as base32
random_bytes = secrets.token_bytes(20)
secret = base64.b32encode(random_bytes).decode("utf-8")
logger.debug(f"Generated new TOTP secret: {secret[:8]}...")
return secret
@staticmethod
def generate_provisioning_uri(user_email: str, secret: str, issuer: str = "Gatehouse") -> str:
"""
Generate provisioning URI for QR code.
Args:
user_email: User's email address
secret: TOTP secret (base32 encoded)
issuer: Issuer name (default: "Gatehouse")
Returns:
otpauth:// URI for QR code generation
Example:
>>> uri = TOTPService.generate_provisioning_uri("user@example.com", "JBSWY3DPEHPK3PXP")
>>> print(uri)
otpauth://totp/Gatehouse:user@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse
"""
totp = pyotp.TOTP(secret)
uri = totp.provisioning_uri(name=user_email, issuer_name=issuer)
logger.debug(f"Generated provisioning URI for user: {user_email}")
return uri
@staticmethod
def verify_code(secret: str, code: str, window: int = 1) -> bool:
"""
Verify a TOTP code against the secret.
Args:
secret: TOTP secret (base32 encoded)
code: 6-digit TOTP code to verify
window: Time window for code validation (default: 1, allows codes from previous/next time steps)
Returns:
True if code is valid, False otherwise
Note:
The window parameter allows for clock skew between the server
and the authenticator app. A window of 1 allows codes from
the previous, current, and next 30-second intervals.
IMPORTANT: Always uses UTC time for verification to ensure
consistency across all timezones.
"""
totp = pyotp.TOTP(secret)
# Use timezone-aware UTC datetime for verification
# IMPORTANT: We must pass a datetime object, NOT a Unix timestamp
# pyotp's internal datetime.utcfromtimestamp() is deprecated and can be
# affected by local timezone settings, causing the 10.5 hour skew issue
utc_now = datetime.now(timezone.utc)
# DEBUG: Log detailed timezone information
logger.debug(f"[TOTP DEBUG] UTC now: {utc_now}")
logger.debug(f"[TOTP DEBUG] UTC now isoformat: {utc_now.isoformat()}")
logger.debug(f"[TOTP DEBUG] UTC timestamp: {utc_now.timestamp()}")
logger.debug(f"[TOTP DEBUG] UTC now tzinfo: {utc_now.tzinfo}")
# Generate what the TOTP code should be at this moment using UTC datetime
expected_code = totp.at(utc_now)
logger.debug(f"[TOTP DEBUG] Expected TOTP code at UTC: {expected_code}")
# Verify with the provided code using UTC datetime object
# Passing a datetime object avoids pyotp's utcfromtimestamp() issues
is_valid = totp.verify(code, valid_window=window, for_time=utc_now)
logger.debug(f"[TOTP DEBUG] TOTP code verification: valid={is_valid}, window={window}")
logger.debug(f"[TOTP DEBUG] Provided code: {code}, Expected code: {expected_code}")
return is_valid
@staticmethod
def generate_backup_codes(count: int = 10) -> Tuple[list[str], list[str]]:
"""
Generate backup codes for TOTP recovery.
Args:
count: Number of backup codes to generate (default: 10)
Returns:
Tuple of (plain_codes, hashed_codes)
- plain_codes: List of plain text backup codes (for display to user)
- hashed_codes: List of bcrypt hashed backup codes (for storage)
Note:
Backup codes are 16-character alphanumeric codes that can be used
to recover access if the TOTP device is lost. Each code can only
be used once.
"""
plain_codes = []
hashed_codes = []
for _ in range(count):
# Generate a 16-character alphanumeric code
code = secrets.token_hex(8).upper()
plain_codes.append(code)
# Hash the code using bcrypt
hashed_code = bcrypt.generate_password_hash(code).decode("utf-8")
hashed_codes.append(hashed_code)
logger.debug(f"Generated {count} backup codes")
return plain_codes, hashed_codes
@staticmethod
def verify_backup_code(hashed_codes: list[str], code: str) -> Tuple[bool, list[str]]:
"""
Verify and consume a backup code.
Args:
hashed_codes: List of bcrypt hashed backup codes
code: Plain text backup code to verify
Returns:
Tuple of (is_valid, remaining_codes)
- is_valid: True if code was valid and consumed, False otherwise
- remaining_codes: List of remaining hashed codes (with consumed code removed)
Note:
Once a backup code is used, it is removed from the list and cannot
be used again. This ensures each code is single-use.
"""
remaining_codes = []
for hashed_code in hashed_codes:
if bcrypt.check_password_hash(hashed_code, code):
# Code found and valid - mark as matched but don't add to remaining codes
matched = True
else:
# Code doesn't match - keep it in remaining codes
remaining_codes.append(hashed_code)
if matched:
return True, remaining_codes
else:
return False, remaining_codes
@staticmethod
def generate_qr_code_data_uri(provisioning_uri: str) -> str:
"""
Generate QR code as data URI for frontend display.
Args:
provisioning_uri: otpauth:// URI to encode in QR code
Returns:
Base64 encoded PNG image as data URI (data:image/png;base64,...)
Note:
If the qrcode library is not installed, returns a placeholder message.
Install with: pip install qrcode[pil]
"""
try:
import qrcode
# Create QR code
qr = qrcode.QRCode(
version=1,
error_correction=qrcode.constants.ERROR_CORRECT_L,
box_size=10,
border=4,
)
qr.add_data(provisioning_uri)
qr.make(fit=True)
# Generate image
img = qr.make_image(fill_color="black", back_color="white")
# Convert to base64
buffer = io.BytesIO()
img.save(buffer, format="PNG")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
data_uri = f"data:image/png;base64,{img_base64}"
logger.debug("Generated QR code data URI")
return data_uri
except ImportError:
logger.warning("qrcode library not installed, returning placeholder")
return "QR code generation requires the qrcode library. Install with: pip install qrcode[pil]"
+125
View File
@@ -0,0 +1,125 @@
"""User service."""
import logging
from flask import current_app
from gatehouse_app.extensions import db
from gatehouse_app.models.user import User
from gatehouse_app.exceptions.validation_exceptions import UserNotFoundError
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.services.audit_service import AuditService
logger = logging.getLogger(__name__)
class UserService:
"""Service for user operations."""
@staticmethod
def get_user_by_id(user_id):
"""
Get user by ID.
Args:
user_id: User ID
Returns:
User instance
Raises:
UserNotFoundError: If user not found
"""
user = User.query.filter_by(id=user_id, deleted_at=None).first()
# Development-only debug logging for user validation
if current_app.config.get('ENV') == 'development':
logger.debug(f"[User] Get user by ID: user_id={user_id}, exists={user is not None}")
if not user:
raise UserNotFoundError()
return user
@staticmethod
def get_user_by_email(email):
"""
Get user by email.
Args:
email: User email
Returns:
User instance or None
"""
user = User.query.filter_by(email=email.lower(), deleted_at=None).first()
# Development-only debug logging for user validation
if current_app.config.get('ENV') == 'development':
logger.debug(f"[User] Get user by email: email={email}, exists={user is not None}")
return user
@staticmethod
def update_user(user, **kwargs):
"""
Update user profile.
Args:
user: User instance
**kwargs: Fields to update
Returns:
Updated User instance
"""
allowed_fields = ["full_name", "avatar_url"]
update_data = {k: v for k, v in kwargs.items() if k in allowed_fields}
if update_data:
user.update(**update_data)
# Log user update
AuditService.log_action(
action=AuditAction.USER_UPDATE,
user_id=user.id,
resource_type="user",
resource_id=user.id,
metadata=update_data,
description="User profile updated",
)
return user
@staticmethod
def delete_user(user, soft=True):
"""
Delete user account.
Args:
user: User instance
soft: If True, performs soft delete
Returns:
Deleted User instance
"""
user.delete(soft=soft)
# Log user deletion
AuditService.log_action(
action=AuditAction.USER_DELETE,
user_id=user.id,
resource_type="user",
resource_id=user.id,
description=f"User account {'soft' if soft else 'hard'} deleted",
)
return user
@staticmethod
def get_user_organizations(user):
"""
Get all organizations the user is a member of.
Args:
user: User instance
Returns:
List of organizations
"""
return user.get_organizations()
+647
View File
@@ -0,0 +1,647 @@
"""WebAuthn passkey authentication service."""
import logging
import secrets
import hashlib
import base64
import json
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any, List
from flask import current_app
from gatehouse_app.extensions import db, redis_client
from gatehouse_app.models.user import User
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType, AuditAction
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
from gatehouse_app.services.audit_service import AuditService
logger = logging.getLogger(__name__)
class WebAuthnService:
"""Service for WebAuthn passkey operations."""
# WebAuthn algorithm constants (COSE algorithms)
COSE_ALGORITHMS = {
-7: "ES256", # ECDSA with SHA-256
-257: "RS256", # RSASSA-PKCS1-v1_5 with SHA-256
}
# Supported key types
KEY_TYPES = ["public-key"]
@staticmethod
def _generate_challenge() -> str:
"""Generate a cryptographically secure challenge.
Returns:
Base64URL-encoded challenge string
"""
bytes_data = secrets.token_bytes(32)
return base64.urlsafe_b64encode(bytes_data).decode('utf-8').rstrip('=')
@staticmethod
def _store_challenge(user_id: str, challenge: str, challenge_type: str, expires_in: int = 300) -> bool:
"""Store a challenge in Redis for validation.
Args:
user_id: User ID
challenge: The challenge string
challenge_type: Type of challenge ('registration' or 'authentication')
expires_in: Expiration time in seconds
Returns:
True if stored successfully
"""
try:
key = f"webauthn:challenge:{user_id}:{challenge_type}:{challenge}"
data = {
"challenge": challenge,
"user_id": user_id,
"type": challenge_type,
"created_at": datetime.now(timezone.utc).isoformat()
}
redis_client.setex(key, expires_in, json.dumps(data))
return True
except Exception as e:
logger.error(f"Failed to store WebAuthn challenge: {e}")
return False
@staticmethod
def _get_and_delete_challenge(user_id: str, challenge: str, challenge_type: str) -> Optional[Dict]:
"""Retrieve and delete a challenge from Redis.
Args:
user_id: User ID
challenge: The challenge string
challenge_type: Type of challenge
Returns:
Challenge data dict or None if not found/expired
"""
try:
key = f"webauthn:challenge:{user_id}:{challenge_type}:{challenge}"
data = redis_client.get(key)
if data:
redis_client.delete(key)
return json.loads(data)
return None
except Exception as e:
logger.error(f"Failed to retrieve WebAuthn challenge: {e}")
return None
@staticmethod
def _base64url_decode(data: str) -> bytes:
"""Decode Base64URL string to bytes."""
# Add padding if needed
padding = 4 - (len(data) % 4)
if padding != 4:
data += '=' * padding
return base64.urlsafe_b64decode(data)
@staticmethod
def _base64url_encode(data: bytes) -> str:
"""Encode bytes to Base64URL string."""
return base64.urlsafe_b64encode(data).decode('utf-8').rstrip('=')
@staticmethod
def _hash_credential_id(credential_id: bytes) -> str:
"""Hash a credential ID for secure storage lookup.
Args:
credential_id: Raw credential ID bytes
Returns:
Hashed credential ID string
"""
return hashlib.sha256(credential_id).hexdigest()
@classmethod
def generate_registration_challenge(cls, user: User) -> Dict[str, Any]:
"""Generate a challenge for passkey registration.
Args:
user: User instance
Returns:
PublicKeyCredentialCreationOptions dict
"""
# Generate challenge
challenge = cls._generate_challenge()
# Store challenge
cls._store_challenge(user.id, challenge, 'registration')
# Get existing credentials to exclude
existing_credentials = cls.get_user_credentials(user)
exclude_credentials = []
for cred in existing_credentials:
if cred.provider_data:
cred_id_b64 = cred.provider_data.get("credential_id")
if cred_id_b64:
try:
cred_id = cls._base64url_decode(cred_id_b64)
transports = cred.provider_data.get("transports", [])
exclude_credentials.append({
"id": cred_id_b64,
"type": "public-key",
"transports": transports
})
except Exception:
pass
# Get RP configuration
rp_id = current_app.config.get('WEBAUTHN_RP_ID', 'localhost')
rp_name = current_app.config.get('WEBAUTHN_RP_NAME', 'Gatehouse')
# Generate user ID (Base64URL encoded)
user_id = cls._base64url_encode(user.id.encode('utf-8'))
# Build options
options = {
"rp": {
"name": rp_name,
"id": rp_id
},
"user": {
"id": user_id,
"name": user.email,
"displayName": user.full_name or user.email
},
"challenge": challenge,
"pubKeyCredParams": [
{"type": "public-key", "alg": -7}, # ES256
{"type": "public-key", "alg": -257} # RS256
],
"timeout": 60000, # 60 seconds
"excludeCredentials": exclude_credentials,
"authenticatorSelection": {
"residentKey": "preferred",
"userVerification": "preferred"
},
"attestation": "none"
}
# Log audit event
AuditService.log_action(
action=AuditAction.WEBAUTHN_REGISTER_INITIATED,
user_id=user.id,
description="WebAuthn registration initiated"
)
return options
@classmethod
def verify_registration_response(
cls,
user: User,
credential_data: Dict[str, Any],
challenge: str
) -> AuthenticationMethod:
"""Verify and store a new passkey credential.
Args:
user: User instance
credential_data: Credential response data from client
challenge: The original challenge string
Returns:
AuthenticationMethod instance
Raises:
InvalidCredentialsError: If verification fails
"""
# Verify and consume challenge
stored_challenge = cls._get_and_delete_challenge(user.id, challenge, 'registration')
if not stored_challenge:
AuditService.log_action(
action=AuditAction.WEBAUTHN_REGISTER_FAILED,
user_id=user.id,
description="Registration failed: challenge expired or invalid"
)
raise InvalidCredentialsError("Challenge expired or invalid")
try:
# Parse credential data
credential_id = credential_data.get("id")
raw_id = credential_data.get("rawId")
response = credential_data.get("response", {})
attestation_object_b64 = response.get("attestationObject")
client_data_json_b64 = response.get("clientDataJSON")
transports = credential_data.get("transports", ["platform"])
if not all([credential_id, raw_id, attestation_object_b64, client_data_json_b64]):
raise InvalidCredentialsError("Missing required credential data")
# Decode attestation object
attestation_object = cls._base64url_decode(attestation_object_b64)
# Parse CBOR attestation object (simplified - in production use cbor2 library)
# The attestation object contains: authData, attStmt, fmt
try:
import cbor2
attestation_dict = cbor2.loads(attestation_object)
except ImportError:
# Fallback: try to parse as simple structure
attestation_dict = {}
logger.warning("cbor2 library not available, using fallback parsing")
# Extract authenticator data
auth_data = attestation_dict.get('authData', b'')
# Parse authenticator data
# Format: RP ID hash (32 bytes) + Flags (1 byte) + Counter (4 bytes) + AAGUID (16 bytes) + Credential ID length (2 bytes) + Credential ID + Public key
if len(auth_data) < 37:
raise InvalidCredentialsError("Invalid authenticator data")
rp_id_hash = auth_data[:32]
flags = auth_data[32]
counter = int.from_bytes(auth_data[33:37], 'big')
aaguid = auth_data[37:53] if len(auth_data) >= 53 else b''
# Extract credential ID length and ID
cred_id_length = int.from_bytes(auth_data[53:55], 'big') if len(auth_data) >= 55 else 0
credential_id_raw = auth_data[55:55+cred_id_length] if cred_id_length > 0 else b''
# Extract public key (COSE format)
public_key_cose = auth_data[55+cred_id_length:]
# Verify client data
client_data_json = cls._base64url_decode(client_data_json_b64)
client_data = json.loads(client_data_json)
# Verify challenge matches
if client_data.get("challenge") != challenge:
raise InvalidCredentialsError("Challenge mismatch")
# Verify origin
expected_origin = current_app.config.get('WEBAUTHN_ORIGIN', 'http://localhost:5173')
if client_data.get("origin") != expected_origin:
logger.warning(f"Origin mismatch: expected {expected_origin}, got {client_data.get('origin')}")
# Don't fail on origin mismatch in development
# Verify user presence and verification
user_present = bool(flags & 0x01)
user_verified = bool(flags & 0x04)
if not user_present:
raise InvalidCredentialsError("User presence not verified")
# Store credential
credential_id_hash = cls._hash_credential_id(credential_id_raw)
# Check if credential already exists
existing = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.WEBAUTHN,
deleted_at=None
).first()
if existing and existing.provider_data:
stored_cred_id = existing.provider_data.get("credential_id", "")
if stored_cred_id == credential_id:
raise InvalidCredentialsError("Credential already registered")
# Create or update authentication method
auth_method = existing or AuthenticationMethod(
user_id=user.id,
method_type=AuthMethodType.WEBAUTHN,
is_primary=False,
verified=True
)
# Store credential data
auth_method.provider_data = {
"credential_id": credential_id,
"credential_id_hash": credential_id_hash,
"public_key_cose": cls._base64url_encode(public_key_cose),
"sign_count": counter,
"transports": transports,
"aaguid": cls._base64url_encode(aaguid) if aaguid else None,
"attestation_format": attestation_dict.get('fmt', 'unknown'),
"created_at": datetime.now(timezone.utc).isoformat(),
"last_used_at": None,
"name": f"Passkey {datetime.now(timezone.utc).strftime('%Y-%m-%d')}"
}
auth_method.save()
# Log audit event
AuditService.log_action(
action=AuditAction.WEBAUTHN_REGISTER_COMPLETED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description=f"WebAuthn credential registered: {credential_id[:16]}..."
)
return auth_method
except InvalidCredentialsError:
raise
except Exception as e:
logger.error(f"WebAuthn registration verification failed: {e}")
AuditService.log_action(
action=AuditAction.WEBAUTHN_REGISTER_FAILED,
user_id=user.id,
description=f"Registration failed: {str(e)}"
)
raise InvalidCredentialsError("Registration verification failed")
@classmethod
def generate_authentication_challenge(cls, user: User) -> Dict[str, Any]:
"""Generate a challenge for passkey authentication.
Args:
user: User instance
Returns:
PublicKeyCredentialRequestOptions dict
"""
# Generate challenge
challenge = cls._generate_challenge()
# Store challenge
cls._store_challenge(user.id, challenge, 'authentication')
# Get user's credentials
credentials = cls.get_user_credentials(user)
# Build allow credentials list
allow_credentials = []
for cred in credentials:
if cred.provider_data:
cred_id = cred.provider_data.get("credential_id")
transports = cred.provider_data.get("transports", [])
if cred_id:
allow_credentials.append({
"id": cred_id,
"type": "public-key",
"transports": transports
})
# Get RP configuration
rp_id = current_app.config.get('WEBAUTHN_RP_ID', 'localhost')
# Build options
options = {
"challenge": challenge,
"timeout": 60000,
"rpId": rp_id,
"allowCredentials": allow_credentials,
"userVerification": "preferred"
}
# Log audit event
AuditService.log_action(
action=AuditAction.WEBAUTHN_LOGIN_INITIATED,
user_id=user.id,
description="WebAuthn authentication initiated"
)
return options
@classmethod
def verify_authentication_response(
cls,
user: User,
credential_data: Dict[str, Any],
challenge: str
) -> AuthenticationMethod:
"""Verify passkey authentication response.
Args:
user: User instance
credential_data: Assertion response data from client
challenge: The original challenge string
Returns:
AuthenticationMethod instance
Raises:
InvalidCredentialsError: If verification fails
"""
# Verify and consume challenge
stored_challenge = cls._get_and_delete_challenge(user.id, challenge, 'authentication')
if not stored_challenge:
AuditService.log_action(
action=AuditAction.WEBAUTHN_LOGIN_FAILED,
user_id=user.id,
description="Authentication failed: challenge expired or invalid"
)
raise InvalidCredentialsError("Challenge expired or invalid")
try:
# Parse credential data
credential_id = credential_data.get("id")
raw_id = credential_data.get("rawId")
response = credential_data.get("response", {})
authenticator_data_b64 = response.get("authenticatorData")
client_data_json_b64 = response.get("clientDataJSON")
signature_b64 = response.get("signature")
if not all([credential_id, authenticator_data_b64, client_data_json_b64, signature_b64]):
raise InvalidCredentialsError("Missing required credential data")
# Find the credential
auth_method = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.WEBAUTHN,
deleted_at=None
).first()
if not auth_method or not auth_method.provider_data:
raise InvalidCredentialsError("No passkey found for user")
stored_cred_id = auth_method.provider_data.get("credential_id")
if stored_cred_id != credential_id:
raise InvalidCredentialsError("Credential not found")
# Decode authenticator data
authenticator_data = cls._base64url_decode(authenticator_data_b64)
# Parse authenticator data
if len(authenticator_data) < 37:
raise InvalidCredentialsError("Invalid authenticator data")
rp_id_hash = authenticator_data[:32]
flags = authenticator_data[32]
counter = int.from_bytes(authenticator_data[33:37], 'big')
# Verify client data
client_data_json = cls._base64url_decode(client_data_json_b64)
client_data = json.loads(client_data_json)
# Verify challenge matches
if client_data.get("challenge") != challenge:
raise InvalidCredentialsError("Challenge mismatch")
# Verify origin
expected_origin = current_app.config.get('WEBAUTHN_ORIGIN', 'http://localhost:5173')
if client_data.get("origin") != expected_origin:
logger.warning(f"Origin mismatch: expected {expected_origin}, got {client_data.get('origin')}")
# Verify user presence
user_present = bool(flags & 0x01)
if not user_present:
raise InvalidCredentialsError("User presence not verified")
# Verify counter (prevent replay attacks)
stored_counter = auth_method.provider_data.get("sign_count", 0)
if counter <= stored_counter:
raise InvalidCredentialsError("Invalid sign counter - potential credential cloning detected")
# Verify signature (simplified - in production use proper crypto verification)
# In a full implementation, you would:
# 1. Decode the public key from COSE format
# 2. Verify the signature using the stored public key
# 3. Verify the authenticator data hash matches RP ID
# For now, we'll trust the authenticator's signature verification
# A full implementation would use the fido2 library
# Update counter and last used time
auth_method.provider_data["sign_count"] = counter
auth_method.provider_data["last_used_at"] = datetime.now(timezone.utc).isoformat()
auth_method.last_used_at = datetime.now(timezone.utc)
db.session.commit()
# Log audit event
AuditService.log_action(
action=AuditAction.WEBAUTHN_LOGIN_SUCCESS,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="WebAuthn authentication successful"
)
return auth_method
except InvalidCredentialsError:
raise
except Exception as e:
logger.error(f"WebAuthn authentication verification failed: {e}")
AuditService.log_action(
action=AuditAction.WEBAUTHN_LOGIN_FAILED,
user_id=user.id,
description=f"Authentication failed: {str(e)}"
)
raise InvalidCredentialsError("Authentication verification failed")
@classmethod
def get_user_credentials(cls, user: User) -> List[AuthenticationMethod]:
"""Get all passkey credentials for a user.
Args:
user: User instance
Returns:
List of AuthenticationMethod instances
"""
return AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.WEBAUTHN,
deleted_at=None
).order_by(AuthenticationMethod.created_at.desc()).all()
@classmethod
def delete_credential(cls, credential_id: str, user: User) -> bool:
"""Delete a passkey credential.
Args:
credential_id: The credential ID to delete
user: User instance
Returns:
True if deleted successfully
"""
auth_method = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.WEBAUTHN,
deleted_at=None
).first()
if not auth_method or not auth_method.provider_data:
return False
stored_cred_id = auth_method.provider_data.get("credential_id")
if stored_cred_id != credential_id:
return False
# Soft delete the credential
auth_method.delete(soft=True)
# Log audit event
AuditService.log_action(
action=AuditAction.WEBAUTHN_CREDENTIAL_DELETED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description=f"WebAuthn credential deleted: {credential_id[:16]}..."
)
return True
@classmethod
def rename_credential(cls, credential_id: str, user: User, name: str) -> bool:
"""Rename a passkey credential.
Args:
credential_id: The credential ID to rename
user: User instance
name: New name for the credential
Returns:
True if renamed successfully
"""
auth_method = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.WEBAUTHN,
deleted_at=None
).first()
if not auth_method or not auth_method.provider_data:
return False
stored_cred_id = auth_method.provider_data.get("credential_id")
if stored_cred_id != credential_id:
return False
# Update name
auth_method.provider_data["name"] = name
db.session.commit()
# Log audit event
AuditService.log_action(
action=AuditAction.WEBAUTHN_CREDENTIAL_RENAMED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description=f"WebAuthn credential renamed to: {name}"
)
return True
@classmethod
def get_credential_by_id(cls, credential_id: str, user: User) -> Optional[AuthenticationMethod]:
"""Get a specific credential by ID.
Args:
credential_id: The credential ID
user: User instance
Returns:
AuthenticationMethod instance or None
"""
auth_method = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.WEBAUTHN,
deleted_at=None
).first()
if auth_method and auth_method.provider_data:
stored_cred_id = auth_method.provider_data.get("credential_id")
if stored_cred_id == credential_id:
return auth_method
return None