diff --git a/gatehouse_app/models/auth/authentication_method.py b/gatehouse_app/models/auth/authentication_method.py index 3fbd7c0..c4c4352 100644 --- a/gatehouse_app/models/auth/authentication_method.py +++ b/gatehouse_app/models/auth/authentication_method.py @@ -213,7 +213,7 @@ class OrganizationProviderOverride(BaseModel): __table_args__ = ( db.UniqueConstraint( - "organization_id", "provider_type", name="uix_org_provider_type" + "organization_id", "provider_type", name="uix_org_provider_override_type" ), ) diff --git a/gatehouse_app/models/ssh_ca/ca.py b/gatehouse_app/models/ssh_ca/ca.py index 182d842..eee9909 100644 --- a/gatehouse_app/models/ssh_ca/ca.py +++ b/gatehouse_app/models/ssh_ca/ca.py @@ -107,6 +107,7 @@ class CA(BaseModel): ) __table_args__ = ( + db.UniqueConstraint("organization_id", "name", name="uix_org_ca_name"), db.Index("idx_ca_org_active", "organization_id", "is_active"), ) diff --git a/gatehouse_app/services/oidc_jwks_service.py b/gatehouse_app/services/oidc_jwks_service.py index 269dc08..ef8f993 100644 --- a/gatehouse_app/services/oidc_jwks_service.py +++ b/gatehouse_app/services/oidc_jwks_service.py @@ -188,6 +188,8 @@ class OIDCJWKSService: Returns: Number of keys loaded """ + if not self._table_exists(): + return 0 try: db_keys = OidcJwksKey.get_active_keys() now = datetime.now(timezone.utc) @@ -208,6 +210,7 @@ class OIDCJWKSService: return len(self._keys) except Exception as e: current_app.logger.error(f"Error loading keys from database: {e}") + db.session.rollback() return 0 def save_key_to_db(self, key: JWKSKey, is_primary: bool = False) -> OidcJwksKey: @@ -364,6 +367,15 @@ class OIDCJWKSService: now = datetime.now(timezone.utc) return key.is_active and key.expires_at > now + def _table_exists(self) -> bool: + """Check if the oidc_jwks_keys table exists in the database.""" + try: + from sqlalchemy import inspect + inspector = inspect(db.engine) + return "oidc_jwks_keys" in inspector.get_table_names() + except Exception: + return False + def initialize_with_key(self) -> JWKSKey: """Initialize the service with a key, loading from database if available. @@ -373,46 +385,55 @@ class OIDCJWKSService: Returns: JWKSKey instance """ - # First, try to load keys from database - try: - # Check if there's a primary key in the database - primary_db_key = OidcJwksKey.get_primary_key() - if primary_db_key: - # Load the primary key into memory - now = datetime.now(timezone.utc) - key = JWKSKey( - kid=primary_db_key.kid, - private_key=primary_db_key.private_key, - public_key=primary_db_key.public_key, - algorithm=primary_db_key.algorithm, - created_at=primary_db_key.created_at, - expires_at=primary_db_key.expires_at or now + timedelta(days=365), - is_active=primary_db_key.is_active, - ) - self._keys[primary_db_key.kid] = key - current_app.logger.info(f"[OIDC] Loaded existing signing key from database: kid={primary_db_key.kid}") - return key - - # Try to load all active keys from database - loaded_count = self.load_keys_from_db() - if loaded_count > 0: - # Get the signing key from loaded keys - signing_key = self.get_signing_key() - if signing_key: - current_app.logger.info(f"[OIDC] Loaded {loaded_count} keys from database, using signing key: kid={signing_key.kid}") - return signing_key - except Exception as e: - current_app.logger.error(f"Error loading keys from database: {e}") + # Check if the table exists before attempting any DB operations + table_exists = self._table_exists() + + if table_exists: + # First, try to load keys from database + try: + # Check if there's a primary key in the database + primary_db_key = OidcJwksKey.get_primary_key() + if primary_db_key: + # Load the primary key into memory + now = datetime.now(timezone.utc) + key = JWKSKey( + kid=primary_db_key.kid, + private_key=primary_db_key.private_key, + public_key=primary_db_key.public_key, + algorithm=primary_db_key.algorithm, + created_at=primary_db_key.created_at, + expires_at=primary_db_key.expires_at or now + timedelta(days=365), + is_active=primary_db_key.is_active, + ) + self._keys[primary_db_key.kid] = key + current_app.logger.info(f"[OIDC] Loaded existing signing key from database: kid={primary_db_key.kid}") + return key + + # Try to load all active keys from database + loaded_count = self.load_keys_from_db() + if loaded_count > 0: + # Get the signing key from loaded keys + signing_key = self.get_signing_key() + if signing_key: + current_app.logger.info(f"[OIDC] Loaded {loaded_count} keys from database, using signing key: kid={signing_key.kid}") + return signing_key + except Exception as e: + current_app.logger.error(f"Error loading keys from database: {e}") + db.session.rollback() + else: + current_app.logger.info("[OIDC] Table oidc_jwks_keys does not exist yet, skipping DB load") # No keys in database, generate a new key and save it current_app.logger.info("[OIDC] No existing keys found in database, generating new signing key") new_key = self.generate_new_key_pair() - # Save the new key to database - try: - self.save_key_to_db(new_key, is_primary=True) - current_app.logger.info(f"[OIDC] Saved new signing key to database: kid={new_key.kid}") - except Exception as e: - current_app.logger.error(f"Error saving key to database: {e}") + # Save the new key to database (only if table exists) + if table_exists: + try: + self.save_key_to_db(new_key, is_primary=True) + current_app.logger.info(f"[OIDC] Saved new signing key to database: kid={new_key.kid}") + except Exception as e: + current_app.logger.error(f"Error saving key to database: {e}") + db.session.rollback() return new_key diff --git a/scripts/init_db.py b/scripts/init_db.py index 7f69862..c6ecda9 100644 --- a/scripts/init_db.py +++ b/scripts/init_db.py @@ -1,7 +1,10 @@ """Initialize database script.""" from gatehouse_app import create_app from gatehouse_app.extensions import db +from sqlalchemy import text from dotenv import load_dotenv +import os +import time # Load environment variables load_dotenv() @@ -9,10 +12,20 @@ load_dotenv() # Create application app = create_app() +import gatehouse_app.models + with app.app_context(): - # Drop all tables - print("Dropping all tables...") - db.drop_all() + # Drop all tables, constraints, and indexes cleanly + db_url = os.getenv("DATABASE_URL", "") + db_name = db_url.split("/")[-1] if db_url else "gatehouse_db" + print(f"⚠️ WARNING: About to drop all tables in database '{db_name}'!") + print("Countdown to deletion:") + for i in range(5, 0, -1): + print(f"{i}...") + time.sleep(1) + db.session.execute(text("DROP SCHEMA public CASCADE")) + db.session.execute(text("CREATE SCHEMA public")) + db.session.commit() # Create all tables print("Creating all tables...")