Fix: Migration

oidc_jwks_keys table doesn't exist
uix_org_provider_type constraint multiple use
transaction abort/never rolled back
This commit is contained in:
2026-03-05 11:35:09 +05:45
parent 7cb522b590
commit cc9dc5064e
4 changed files with 75 additions and 40 deletions
@@ -213,7 +213,7 @@ class OrganizationProviderOverride(BaseModel):
__table_args__ = ( __table_args__ = (
db.UniqueConstraint( db.UniqueConstraint(
"organization_id", "provider_type", name="uix_org_provider_type" "organization_id", "provider_type", name="uix_org_provider_override_type"
), ),
) )
+1
View File
@@ -107,6 +107,7 @@ class CA(BaseModel):
) )
__table_args__ = ( __table_args__ = (
db.UniqueConstraint("organization_id", "name", name="uix_org_ca_name"),
db.Index("idx_ca_org_active", "organization_id", "is_active"), db.Index("idx_ca_org_active", "organization_id", "is_active"),
) )
+56 -35
View File
@@ -188,6 +188,8 @@ class OIDCJWKSService:
Returns: Returns:
Number of keys loaded Number of keys loaded
""" """
if not self._table_exists():
return 0
try: try:
db_keys = OidcJwksKey.get_active_keys() db_keys = OidcJwksKey.get_active_keys()
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
@@ -208,6 +210,7 @@ class OIDCJWKSService:
return len(self._keys) return len(self._keys)
except Exception as e: except Exception as e:
current_app.logger.error(f"Error loading keys from database: {e}") current_app.logger.error(f"Error loading keys from database: {e}")
db.session.rollback()
return 0 return 0
def save_key_to_db(self, key: JWKSKey, is_primary: bool = False) -> OidcJwksKey: def save_key_to_db(self, key: JWKSKey, is_primary: bool = False) -> OidcJwksKey:
@@ -364,6 +367,15 @@ class OIDCJWKSService:
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
return key.is_active and key.expires_at > now return key.is_active and key.expires_at > now
def _table_exists(self) -> bool:
"""Check if the oidc_jwks_keys table exists in the database."""
try:
from sqlalchemy import inspect
inspector = inspect(db.engine)
return "oidc_jwks_keys" in inspector.get_table_names()
except Exception:
return False
def initialize_with_key(self) -> JWKSKey: def initialize_with_key(self) -> JWKSKey:
"""Initialize the service with a key, loading from database if available. """Initialize the service with a key, loading from database if available.
@@ -373,46 +385,55 @@ class OIDCJWKSService:
Returns: Returns:
JWKSKey instance JWKSKey instance
""" """
# First, try to load keys from database # Check if the table exists before attempting any DB operations
try: table_exists = self._table_exists()
# Check if there's a primary key in the database
primary_db_key = OidcJwksKey.get_primary_key()
if primary_db_key:
# Load the primary key into memory
now = datetime.now(timezone.utc)
key = JWKSKey(
kid=primary_db_key.kid,
private_key=primary_db_key.private_key,
public_key=primary_db_key.public_key,
algorithm=primary_db_key.algorithm,
created_at=primary_db_key.created_at,
expires_at=primary_db_key.expires_at or now + timedelta(days=365),
is_active=primary_db_key.is_active,
)
self._keys[primary_db_key.kid] = key
current_app.logger.info(f"[OIDC] Loaded existing signing key from database: kid={primary_db_key.kid}")
return key
# Try to load all active keys from database if table_exists:
loaded_count = self.load_keys_from_db() # First, try to load keys from database
if loaded_count > 0: try:
# Get the signing key from loaded keys # Check if there's a primary key in the database
signing_key = self.get_signing_key() primary_db_key = OidcJwksKey.get_primary_key()
if signing_key: if primary_db_key:
current_app.logger.info(f"[OIDC] Loaded {loaded_count} keys from database, using signing key: kid={signing_key.kid}") # Load the primary key into memory
return signing_key now = datetime.now(timezone.utc)
except Exception as e: key = JWKSKey(
current_app.logger.error(f"Error loading keys from database: {e}") kid=primary_db_key.kid,
private_key=primary_db_key.private_key,
public_key=primary_db_key.public_key,
algorithm=primary_db_key.algorithm,
created_at=primary_db_key.created_at,
expires_at=primary_db_key.expires_at or now + timedelta(days=365),
is_active=primary_db_key.is_active,
)
self._keys[primary_db_key.kid] = key
current_app.logger.info(f"[OIDC] Loaded existing signing key from database: kid={primary_db_key.kid}")
return key
# Try to load all active keys from database
loaded_count = self.load_keys_from_db()
if loaded_count > 0:
# Get the signing key from loaded keys
signing_key = self.get_signing_key()
if signing_key:
current_app.logger.info(f"[OIDC] Loaded {loaded_count} keys from database, using signing key: kid={signing_key.kid}")
return signing_key
except Exception as e:
current_app.logger.error(f"Error loading keys from database: {e}")
db.session.rollback()
else:
current_app.logger.info("[OIDC] Table oidc_jwks_keys does not exist yet, skipping DB load")
# No keys in database, generate a new key and save it # No keys in database, generate a new key and save it
current_app.logger.info("[OIDC] No existing keys found in database, generating new signing key") current_app.logger.info("[OIDC] No existing keys found in database, generating new signing key")
new_key = self.generate_new_key_pair() new_key = self.generate_new_key_pair()
# Save the new key to database # Save the new key to database (only if table exists)
try: if table_exists:
self.save_key_to_db(new_key, is_primary=True) try:
current_app.logger.info(f"[OIDC] Saved new signing key to database: kid={new_key.kid}") self.save_key_to_db(new_key, is_primary=True)
except Exception as e: current_app.logger.info(f"[OIDC] Saved new signing key to database: kid={new_key.kid}")
current_app.logger.error(f"Error saving key to database: {e}") except Exception as e:
current_app.logger.error(f"Error saving key to database: {e}")
db.session.rollback()
return new_key return new_key
+16 -3
View File
@@ -1,7 +1,10 @@
"""Initialize database script.""" """Initialize database script."""
from gatehouse_app import create_app from gatehouse_app import create_app
from gatehouse_app.extensions import db from gatehouse_app.extensions import db
from sqlalchemy import text
from dotenv import load_dotenv from dotenv import load_dotenv
import os
import time
# Load environment variables # Load environment variables
load_dotenv() load_dotenv()
@@ -9,10 +12,20 @@ load_dotenv()
# Create application # Create application
app = create_app() app = create_app()
import gatehouse_app.models
with app.app_context(): with app.app_context():
# Drop all tables # Drop all tables, constraints, and indexes cleanly
print("Dropping all tables...") db_url = os.getenv("DATABASE_URL", "")
db.drop_all() db_name = db_url.split("/")[-1] if db_url else "gatehouse_db"
print(f"⚠️ WARNING: About to drop all tables in database '{db_name}'!")
print("Countdown to deletion:")
for i in range(5, 0, -1):
print(f"{i}...")
time.sleep(1)
db.session.execute(text("DROP SCHEMA public CASCADE"))
db.session.execute(text("CREATE SCHEMA public"))
db.session.commit()
# Create all tables # Create all tables
print("Creating all tables...") print("Creating all tables...")