From 16d04bd5f757ec7bb54ce1e1c43d439cbdf4c610 Mon Sep 17 00:00:00 2001 From: James Bhattarai Date: Fri, 6 Mar 2026 01:36:23 +0545 Subject: [PATCH 1/9] Chore: Setup and Env --- .env.example | 123 ++++++++++++++++++++-------- config/base.py | 2 + manage.py | 73 ++++++++++++++--- scripts/configure_oauth_provider.py | 2 +- scripts/init_db.py | 7 ++ 5 files changed, 162 insertions(+), 45 deletions(-) diff --git a/.env.example b/.env.example index 204b2bf..537face 100644 --- a/.env.example +++ b/.env.example @@ -1,58 +1,117 @@ -# Flask Configuration -FLASK_APP=wsgi.py +FLASK_APP=manage.py FLASK_ENV=development -SECRET_KEY=your-secret-key-here-change-in-production +FLASK_DEBUG=1 # Database -DATABASE_URL=postgresql://user:password@localhost:5432/authy2_dev +DATABASE_URL=postgresql://user:password@localhost:5432/gatehouse_dev SQLALCHEMY_ECHO=False SQLALCHEMY_LOG_LEVEL=WARNING -# Security +# Security / Encryption +SECRET_KEY=change-me-in-production +ENCRYPTION_KEY=change-me-in-production-32-bytes!! +# Used to encrypt SSH CA private keys stored in the database +CA_ENCRYPTION_KEY=change-me-in-production BCRYPT_LOG_ROUNDS=12 -ENCRYPTION_KEY=your-encryption-key-here-change-in-production + +# Session cookies SESSION_COOKIE_SECURE=False -SESSION_COOKIE_HTTPONLY=True SESSION_COOKIE_SAMESITE=Lax +# Only needed when sharing cookies across subdomains (e.g. api.example.com + ui.example.com) +# SESSION_COOKIE_DOMAIN=example.com MAX_SESSION_DURATION=86400 -# CORS -#CORS_ORIGINS=http://localhost:3000,http://localhost:5173,https://oidc-playpen.lovable.app/,http://localhost:8080/ -CORS_ORIGINS=* - - -# JWT (if using JWT instead of sessions) -JWT_SECRET_KEY=your-jwt-secret-key-here +# ───────────────────────────────────────────────────────────────────────────── +# JWT +# ───────────────────────────────────────────────────────────────────────────── +JWT_SECRET_KEY=change-me-in-production JWT_ACCESS_TOKEN_EXPIRES=3600 JWT_REFRESH_TOKEN_EXPIRES=2592000 -# Redis (for session storage) +# ───────────────────────────────────────────────────────────────────────────── +# Redis (session storage + rate limiting) +# ───────────────────────────────────────────────────────────────────────────── REDIS_URL=redis://localhost:6379/0 +SESSION_REDIS_URL=redis://localhost:6379/0 +RATELIMIT_STORAGE_URL=redis://localhost:6379/1 -# OIDC +# ───────────────────────────────────────────────────────────────────────────── +# CORS +# ───────────────────────────────────────────────────────────────────────────── +CORS_ORIGINS=http://localhost:8080,http://localhost:5173 + +# ───────────────────────────────────────────────────────────────────────────── +# Frontend / App URLs +# All three should point at the browser-facing SPA. They are used for: +# FRONTEND_URL → OAuth callback redirects after provider auth +# APP_URL → Password-reset and email-verify links in emails +# OIDC_UI_URL → OIDC /authorize redirects to the React consent/login UI +# ───────────────────────────────────────────────────────────────────────────── +FRONTEND_URL=http://localhost:8080 +APP_URL=http://localhost:8080 +OIDC_UI_URL=http://localhost:8080 + +# ───────────────────────────────────────────────────────────────────────────── +# OIDC / OAuth issuer +# ───────────────────────────────────────────────────────────────────────────── OIDC_ISSUER_URL=http://localhost:5000 +OIDC_BASE_URL=http://localhost:5000 +# ───────────────────────────────────────────────────────────────────────────── +# WebAuthn +# ───────────────────────────────────────────────────────────────────────────── +WEBAUTHN_RP_ID=localhost +WEBAUTHN_RP_NAME=Gatehouse +WEBAUTHN_ORIGIN=http://localhost:8080 + +# ───────────────────────────────────────────────────────────────────────────── +# SSH CA (pick one) +# ───────────────────────────────────────────────────────────────────────────── +SSH_CA_KEY_PATH=/path/to/ca-users +# SSH_CA_PRIVATE_KEY= # raw key content; takes priority over SSH_CA_KEY_PATH + +# ───────────────────────────────────────────────────────────────────────────── +# Email / SMTP +# ───────────────────────────────────────────────────────────────────────────── +EMAIL_ENABLED=False +SMTP_HOST=smtp.gmail.com +SMTP_PORT=587 +SMTP_USE_TLS=True +SMTP_USERNAME= +SMTP_PASSWORD= +FROM_ADDRESS=noreply@gatehouse.local + +# ───────────────────────────────────────────────────────────────────────────── # Logging +# ───────────────────────────────────────────────────────────────────────────── LOG_LEVEL=INFO LOG_TO_STDOUT=True +# ───────────────────────────────────────────────────────────────────────────── # Rate Limiting +# ───────────────────────────────────────────────────────────────────────────── RATELIMIT_ENABLED=True -RATELIMIT_STORAGE_URL=redis://localhost:6379/1 - -# SSH CA -# Path to CA private key file (alternative to SSH_CA_PRIVATE_KEY env var) -SSH_CA_KEY_PATH=/path/to/ca-users -# Or set the key content directly (takes priority over SSH_CA_KEY_PATH): -# SSH_CA_PRIVATE_KEY= - -EMAIL_ENABLED= -SMTP_HOST= -SMTP_PORT= -SMTP_USERNAME= -SMTP_PASSWORD= -FROM_ADDRESS= -WEBAUTHN_ORIGIN= +# Per-endpoint auth limits (optional — defaults shown) +# RATELIMIT_AUTH_REGISTER=10 per minute; 50 per hour +# RATELIMIT_AUTH_LOGIN=20 per minute; 100 per hour +# RATELIMIT_AUTH_TOTP_VERIFY=20 per minute; 100 per hour +# RATELIMIT_AUTH_FORGOT_PASSWORD=5 per minute; 20 per hour +# RATELIMIT_AUTH_RESET_PASSWORD=10 per minute; 30 per hour ZEROTIER_API_TOKEN= -ZEROTIER_API_URL= \ No newline at end of file +ZEROTIER_API_URL= + +# ───────────────────────────────────────────────────────────────────────────── +# OIDC token lifetimes & security (optional — defaults shown) +# ───────────────────────────────────────────────────────────────────────────── +# OIDC_ACCESS_TOKEN_LIFETIME=3600 +# OIDC_REFRESH_TOKEN_LIFETIME=2592000 +# OIDC_ID_TOKEN_LIFETIME=3600 +# OIDC_AUTHORIZATION_CODE_LIFETIME=600 +# OIDC_REQUIRE_PKCE=True +# OIDC_ALLOW_IMPLICIT_FLOW=False +# OIDC_KEY_ROTATION_DAYS=90 +# OIDC_KEY_GRACE_PERIOD_DAYS=30 +# OIDC_RATE_LIMIT_AUTHORIZE=10/minute +# OIDC_RATE_LIMIT_TOKEN=20/minute +# OIDC_RATE_LIMIT_USERINFO=60/minute diff --git a/config/base.py b/config/base.py index b940767..e4d9329 100644 --- a/config/base.py +++ b/config/base.py @@ -128,6 +128,8 @@ class BaseConfig: # Frontend URL (for OAuth callback redirects) FRONTEND_URL = os.getenv("FRONTEND_URL", "http://localhost:8080") + APP_URL = os.getenv("APP_URL", os.getenv("FRONTEND_URL", "http://localhost:8080")) + OIDC_UI_URL = os.getenv("OIDC_UI_URL", os.getenv("FRONTEND_URL", "http://localhost:8080")) # ZeroTier Configuration ZEROTIER_API_TOKEN = os.getenv("ZEROTIER_API_TOKEN", "") diff --git a/manage.py b/manage.py index 3f694fc..06d3fc7 100644 --- a/manage.py +++ b/manage.py @@ -1,5 +1,6 @@ """Management script for Flask application.""" import os +import click from dotenv import load_dotenv # Load environment variables FIRST, before any app imports @@ -153,36 +154,75 @@ def mfa_compliance_status(): @cli.command("configure_oauth") -def configure_oauth(): - """Interactively configure an OAuth provider at the application level. +@click.argument("provider", required=False) +@click.option("--client-id", default=None, help="OAuth client ID") +@click.option("--client-secret", default=None, help="OAuth client secret") +@click.option("--redirect-url", default=None, help="Default redirect URL (e.g. https://yourdomain.com/api/v1/auth/external//callback)") +def configure_oauth(provider, client_id, client_secret, redirect_url): + """Configure an OAuth provider at the application level. - Usage: + Usage (interactive): python manage.py configure_oauth + Usage (non-interactive): + python manage.py configure_oauth google --client-id ID --client-secret SECRET + Supported providers: google, github, microsoft """ import getpass - from gatehouse_app.models.authentication_method import ApplicationProviderConfig + from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig from gatehouse_app.extensions import db SUPPORTED = ["google", "github", "microsoft"] - print("=" * 60) - print("OAuth Provider Configuration") - print("=" * 60) - print(f"Supported providers: {', '.join(SUPPORTED)}") + # Well-known endpoints — stored in additional_config so the adapter can + # resolve auth_url / token_url / userinfo_url without extra logic. + PROVIDER_DEFAULTS = { + "google": { + "auth_url": "https://accounts.google.com/o/oauth2/v2/auth", + "token_url": "https://oauth2.googleapis.com/token", + "userinfo_url": "https://www.googleapis.com/oauth2/v3/userinfo", + }, + "github": { + "auth_url": "https://github.com/login/oauth/authorize", + "token_url": "https://github.com/login/oauth/access_token", + "userinfo_url": "https://api.github.com/user", + }, + "microsoft": { + "auth_url": "https://login.microsoftonline.com/common/oauth2/v2.0/authorize", + "token_url": "https://login.microsoftonline.com/common/oauth2/v2.0/token", + "userinfo_url": "https://graph.microsoft.com/oidc/userinfo", + }, + } - provider = input("Provider [google/github/microsoft]: ").strip().lower() + if not provider: + print("=" * 60) + print("OAuth Provider Configuration") + print("=" * 60) + print(f"Supported providers: {', '.join(SUPPORTED)}") + provider = input("Provider [google/github/microsoft]: ").strip().lower() + + provider = provider.strip().lower() if provider not in SUPPORTED: print(f"❌ Unknown provider: {provider}") return - client_id = input("Client ID: ").strip() + if not client_id: + client_id = input("Client ID: ").strip() if not client_id: print("❌ client_id is required") return - client_secret = getpass.getpass("Client Secret (leave blank to keep existing): ").strip() + if not client_secret: + client_secret = getpass.getpass("Client Secret (leave blank to keep existing): ").strip() + + if not redirect_url: + base_url = os.getenv("API_BASE_URL", "http://localhost:5000/api/v1") + default = f"{base_url}/auth/external/{provider}/callback" + entered = input(f"Default redirect URL [{default}]: ").strip() + redirect_url = entered or default + + additional_config = PROVIDER_DEFAULTS[provider].copy() with app.app_context(): config = ApplicationProviderConfig.query.filter_by(provider_type=provider).first() @@ -191,6 +231,11 @@ def configure_oauth(): if client_secret: config.set_client_secret(client_secret) config.is_enabled = True + config.default_redirect_url = redirect_url + config.additional_config = { + **(config.additional_config or {}), + **additional_config, + } db.session.commit() print(f"✅ Updated {provider} provider config.") else: @@ -198,12 +243,16 @@ def configure_oauth(): provider_type=provider, client_id=client_id, is_enabled=True, + default_redirect_url=redirect_url, + additional_config=additional_config, ) if client_secret: config.set_client_secret(client_secret) db.session.add(config) db.session.commit() print(f"✅ Created {provider} provider config.") + print(f" redirect_url : {redirect_url}") + print(f" auth_url : {additional_config['auth_url']}") @cli.command("list_oauth") @@ -213,7 +262,7 @@ def list_oauth(): Usage: python manage.py list_oauth """ - from gatehouse_app.models.authentication_method import ApplicationProviderConfig + from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig with app.app_context(): configs = ApplicationProviderConfig.query.all() diff --git a/scripts/configure_oauth_provider.py b/scripts/configure_oauth_provider.py index fe8bcc0..5372bd7 100755 --- a/scripts/configure_oauth_provider.py +++ b/scripts/configure_oauth_provider.py @@ -58,7 +58,7 @@ if os.path.exists(env_file): # Import after path setup from gatehouse_app import create_app -from gatehouse_app.services.external_auth_service import ExternalAuthService, ExternalAuthError +from gatehouse_app.services.external_auth import ExternalAuthService, ExternalAuthError def _microsoft_defaults() -> dict: diff --git a/scripts/init_db.py b/scripts/init_db.py index 6f620bd..9398243 100644 --- a/scripts/init_db.py +++ b/scripts/init_db.py @@ -1,4 +1,11 @@ """Initialize database script.""" +import sys +import os +import time +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + from gatehouse_app import create_app from gatehouse_app.extensions import db from sqlalchemy import text From 7492c406688a929c10be5a4979807a038188f4a4 Mon Sep 17 00:00:00 2001 From: James Bhattarai Date: Fri, 6 Mar 2026 18:20:09 +0545 Subject: [PATCH 2/9] Fix: Admin Expiry Hours --- gatehouse_app/api/v1/ssh/certs.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/gatehouse_app/api/v1/ssh/certs.py b/gatehouse_app/api/v1/ssh/certs.py index 429f6f0..d7537fc 100644 --- a/gatehouse_app/api/v1/ssh/certs.py +++ b/gatehouse_app/api/v1/ssh/certs.py @@ -130,14 +130,18 @@ def sign_certificate(): dept_policy = _get_merged_dept_cert_policy(user_id) if dept_policy: - if is_org_admin: + if not dept_policy["allow_user_expiry"]: + expiry_hours = dept_policy["default_expiry_hours"] + elif is_org_admin: if expiry_hours is not None: expiry_hours = min(int(expiry_hours), dept_policy["max_expiry_hours"]) - elif not dept_policy["allow_user_expiry"]: - expiry_hours = dept_policy["default_expiry_hours"] + else: + expiry_hours = dept_policy["default_expiry_hours"] else: if expiry_hours is not None: expiry_hours = min(int(expiry_hours), dept_policy["max_expiry_hours"]) + else: + expiry_hours = dept_policy["default_expiry_hours"] policy_extensions = dept_policy["extensions"] else: policy_extensions = None From ff976ee1cca3047e4e74f6a66f7c0dd2821f4640 Mon Sep 17 00:00:00 2001 From: James Bhattarai Date: Fri, 6 Mar 2026 18:41:46 +0545 Subject: [PATCH 3/9] Fix: Serial uniqueness --- gatehouse_app/models/ssh_ca/ca.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/gatehouse_app/models/ssh_ca/ca.py b/gatehouse_app/models/ssh_ca/ca.py index eee9909..91548fc 100644 --- a/gatehouse_app/models/ssh_ca/ca.py +++ b/gatehouse_app/models/ssh_ca/ca.py @@ -1,10 +1,15 @@ """Certificate Authority (CA) model.""" +import time from enum import Enum from datetime import datetime, timezone from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel +def _serial_start() -> int: + return int(time.time() * 1000) + + class KeyType(str, Enum): """SSH CA key types.""" @@ -91,7 +96,9 @@ class CA(BaseModel): # Monotonically-increasing serial counter. Every cert this CA issues # gets the next value so serials are unique, ordered, and auditable. # Protected by a row-level SELECT … FOR UPDATE in get_next_serial(). - next_serial_number = db.Column(db.BigInteger, default=1, nullable=False) + # Initialised to the current Unix timestamp in milliseconds so serials + # are globally unique across CAs from the moment of creation. + next_serial_number = db.Column(db.BigInteger, default=_serial_start, nullable=False) # Relationships organization = db.relationship("Organization", back_populates="cas") From f334000da343c1c11275eddb55bbd460c49c107b Mon Sep 17 00:00:00 2001 From: James Bhattarai Date: Sun, 8 Mar 2026 18:10:26 +0545 Subject: [PATCH 4/9] Feat: Implemented SUDO Department & API Key, CA Serial --- gatehouse_app/api/v1/__init__.py | 3 +- gatehouse_app/api/v1/departments.py | 4 + .../api/v1/organizations/__init__.py | 4 +- .../api/v1/organizations/api_keys.py | 299 ++++++++++++++++++ gatehouse_app/api/v1/organizations/audit.py | 98 ++++++ gatehouse_app/api/v1/sudo.py | 137 ++++++++ gatehouse_app/models/organization/__init__.py | 2 + .../models/organization/department.py | 1 + .../organization/department_cert_policy.py | 5 +- .../models/organization/organization.py | 3 + .../organization/organization_api_key.py | 158 +++++++++ .../services/ssh_ca_signing_service.py | 6 + .../versions/020_ca_serial_timestamp_start.py | 76 +++++ ...c2d5_add_cert_token_to_ssh_certificates.py | 30 ++ .../versions/add_can_sudo_to_departments.py | 34 ++ .../add_organization_api_keys_table.py | 56 ++++ 16 files changed, 911 insertions(+), 5 deletions(-) create mode 100644 gatehouse_app/api/v1/organizations/api_keys.py create mode 100644 gatehouse_app/api/v1/sudo.py create mode 100644 gatehouse_app/models/organization/organization_api_key.py create mode 100644 migrations/versions/020_ca_serial_timestamp_start.py create mode 100644 migrations/versions/3de11c5dc2d5_add_cert_token_to_ssh_certificates.py create mode 100644 migrations/versions/add_can_sudo_to_departments.py create mode 100644 migrations/versions/add_organization_api_keys_table.py diff --git a/gatehouse_app/api/v1/__init__.py b/gatehouse_app/api/v1/__init__.py index a5f676c..836fa79 100644 --- a/gatehouse_app/api/v1/__init__.py +++ b/gatehouse_app/api/v1/__init__.py @@ -5,6 +5,7 @@ from flask import Blueprint api_v1_bp = Blueprint("api_v1", __name__) # Import route modules to register them -from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth, departments, principals, ssh, zerotier +from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth, departments, principals, ssh, zerotier, sudo api_v1_bp.register_blueprint(ssh.ssh_bp) + diff --git a/gatehouse_app/api/v1/departments.py b/gatehouse_app/api/v1/departments.py index d305c66..4ced781 100644 --- a/gatehouse_app/api/v1/departments.py +++ b/gatehouse_app/api/v1/departments.py @@ -16,12 +16,15 @@ class DepartmentCreateSchema(Schema): """Schema for creating a department.""" name = fields.Str(required=True, validate=validate.Length(min=1, max=255)) description = fields.Str(allow_none=True, validate=validate.Length(max=2000)) + can_sudo = fields.Bool(allow_none=True, load_default=False) + class DepartmentUpdateSchema(Schema): """Schema for updating a department.""" name = fields.Str(validate=validate.Length(min=1, max=255)) description = fields.Str(allow_none=True, validate=validate.Length(max=2000)) + can_sudo = fields.Bool(allow_none=True) class AddDepartmentMemberSchema(Schema): @@ -119,6 +122,7 @@ def create_department(org_id): organization_id=org_id, name=data["name"], description=data.get("description"), + can_sudo=data.get("can_sudo", False), ) db.session.add(dept) db.session.commit() diff --git a/gatehouse_app/api/v1/organizations/__init__.py b/gatehouse_app/api/v1/organizations/__init__.py index 76f6fdd..fba555b 100644 --- a/gatehouse_app/api/v1/organizations/__init__.py +++ b/gatehouse_app/api/v1/organizations/__init__.py @@ -1,4 +1,4 @@ """Organization routes package.""" -from gatehouse_app.api.v1.organizations import core, members, invites, clients, cas, audit, roles +from gatehouse_app.api.v1.organizations import core, members, invites, clients, cas, audit, roles, api_keys -__all__ = ["core", "members", "invites", "clients", "cas", "audit", "roles"] +__all__ = ["core", "members", "invites", "clients", "cas", "audit", "roles", "api_keys"] diff --git a/gatehouse_app/api/v1/organizations/api_keys.py b/gatehouse_app/api/v1/organizations/api_keys.py new file mode 100644 index 0000000..90d83ee --- /dev/null +++ b/gatehouse_app/api/v1/organizations/api_keys.py @@ -0,0 +1,299 @@ +"""Organization API Key management endpoints.""" +from flask import g, request +from marshmallow import Schema, fields, validate, ValidationError + +from gatehouse_app.api.v1 import api_v1_bp +from gatehouse_app.utils.response import api_response +from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required +from gatehouse_app.models.organization import OrganizationApiKey +from gatehouse_app.services.organization_service import OrganizationService +from gatehouse_app.extensions import db + + +class ApiKeyCreateSchema(Schema): + """Schema for creating an API key.""" + name = fields.Str(required=True, validate=validate.Length(min=1, max=255)) + description = fields.Str(allow_none=True, validate=validate.Length(max=2000)) + + +class ApiKeyUpdateSchema(Schema): + """Schema for updating an API key.""" + name = fields.Str(validate=validate.Length(min=1, max=255)) + description = fields.Str(allow_none=True, validate=validate.Length(max=2000)) + + +@api_v1_bp.route("/organizations//api-keys", methods=["GET"]) +@login_required +@require_admin +@full_access_required +def list_api_keys(org_id): + """ + List all API keys for an organization. + + Only accessible by organization admins. + + Args: + org_id: Organization ID + + Returns: + 200: List of API keys (without key values) + 401: Not authenticated + 403: Not an admin + 404: Organization not found + """ + org = OrganizationService.get_organization_by_id(org_id) + + # Check if user is an admin + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.utils.constants import OrganizationRole + + membership = OrganizationMember.query.filter_by( + user_id=g.current_user.id, + organization_id=org_id, + deleted_at=None + ).first() + + if not membership or membership.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]: + return api_response( + success=False, + message="You do not have permission to manage API keys", + status=403, + error_type="AUTHORIZATION_ERROR", + ) + + api_keys = OrganizationApiKey.query.filter_by( + organization_id=org_id, + deleted_at=None + ).all() + + return api_response( + data={ + "api_keys": [k.to_dict() for k in api_keys], + "count": len(api_keys), + }, + message="API keys retrieved successfully", + ) + + +@api_v1_bp.route("/organizations//api-keys", methods=["POST"]) +@login_required +@require_admin +@full_access_required +def create_api_key(org_id): + """ + Create a new API key for an organization. + + Only accessible by organization admins. + The plain text key is returned only on creation and should be stored securely. + + Args: + org_id: Organization ID + + Request body: + name: API key name (required) + description: Optional description + + Returns: + 201: API key created successfully + 400: Validation error + 401: Not authenticated + 403: Not an admin + 404: Organization not found + """ + try: + org = OrganizationService.get_organization_by_id(org_id) + + # Check if user is an admin + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.utils.constants import OrganizationRole + + membership = OrganizationMember.query.filter_by( + user_id=g.current_user.id, + organization_id=org_id, + deleted_at=None + ).first() + + if not membership or membership.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]: + return api_response( + success=False, + message="You do not have permission to create API keys", + status=403, + error_type="AUTHORIZATION_ERROR", + ) + + schema = ApiKeyCreateSchema() + data = schema.load(request.json or {}) + + # Create the API key + api_key, plain_key = OrganizationApiKey.create_key( + organization_id=org_id, + name=data["name"], + description=data.get("description"), + ) + + # Return the key data with the plain text key (only on creation) + key_dict = api_key.to_dict() + key_dict["key"] = plain_key # Include plain text only on creation + + return api_response( + data={"api_key": key_dict}, + message="API key created successfully. Store the key value securely - it cannot be retrieved later.", + status=201, + ) + + except ValidationError as e: + return api_response( + success=False, + message="Validation failed", + status=400, + error_type="VALIDATION_ERROR", + error_details=e.messages, + ) + + +@api_v1_bp.route("/organizations//api-keys/", methods=["PATCH"]) +@login_required +@require_admin +@full_access_required +def update_api_key(org_id, key_id): + """ + Update an API key. + + Only accessible by organization admins. + + Args: + org_id: Organization ID + key_id: API Key ID + + Request body: + name: New name (optional) + description: New description (optional) + + Returns: + 200: API key updated successfully + 400: Validation error + 401: Not authenticated + 403: Not an admin + 404: Organization or API key not found + """ + try: + org = OrganizationService.get_organization_by_id(org_id) + + # Check if user is an admin + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.utils.constants import OrganizationRole + + membership = OrganizationMember.query.filter_by( + user_id=g.current_user.id, + organization_id=org_id, + deleted_at=None + ).first() + + if not membership or membership.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]: + return api_response( + success=False, + message="You do not have permission to update API keys", + status=403, + error_type="AUTHORIZATION_ERROR", + ) + + api_key = OrganizationApiKey.query.filter_by( + id=key_id, + organization_id=org_id, + deleted_at=None + ).first() + + if not api_key: + return api_response( + success=False, + message="API key not found", + status=404, + error_type="NOT_FOUND", + ) + + schema = ApiKeyUpdateSchema() + data = schema.load(request.json or {}) + + # Update fields + if "name" in data: + api_key.name = data["name"] + if "description" in data: + api_key.description = data["description"] + + api_key.save() + + return api_response( + data={"api_key": api_key.to_dict()}, + message="API key updated successfully", + ) + + except ValidationError as e: + return api_response( + success=False, + message="Validation failed", + status=400, + error_type="VALIDATION_ERROR", + error_details=e.messages, + ) + + +@api_v1_bp.route("/organizations//api-keys/", methods=["DELETE"]) +@login_required +@require_admin +@full_access_required +def delete_api_key(org_id, key_id): + """ + Delete/revoke an API key. + + Only accessible by organization admins. + + Args: + org_id: Organization ID + key_id: API Key ID + + Returns: + 200: API key deleted successfully + 401: Not authenticated + 403: Not an admin + 404: Organization or API key not found + """ + org = OrganizationService.get_organization_by_id(org_id) + + # Check if user is an admin + from gatehouse_app.models.organization.organization_member import OrganizationMember + from gatehouse_app.utils.constants import OrganizationRole + + membership = OrganizationMember.query.filter_by( + user_id=g.current_user.id, + organization_id=org_id, + deleted_at=None + ).first() + + if not membership or membership.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]: + return api_response( + success=False, + message="You do not have permission to delete API keys", + status=403, + error_type="AUTHORIZATION_ERROR", + ) + + api_key = OrganizationApiKey.query.filter_by( + id=key_id, + organization_id=org_id, + deleted_at=None + ).first() + + if not api_key: + return api_response( + success=False, + message="API key not found", + status=404, + error_type="NOT_FOUND", + ) + + # Soft delete the API key + api_key.delete(soft=True) + + return api_response( + message="API key deleted successfully", + ) diff --git a/gatehouse_app/api/v1/organizations/audit.py b/gatehouse_app/api/v1/organizations/audit.py index 0ddd315..eeb4cad 100644 --- a/gatehouse_app/api/v1/organizations/audit.py +++ b/gatehouse_app/api/v1/organizations/audit.py @@ -173,3 +173,101 @@ def get_my_audit_logs(): }, message="Activity retrieved", ) + + +@api_v1_bp.route("/organizations//certificates/audit", methods=["GET"]) +@login_required +@require_admin +@full_access_required +def get_certificate_audit_logs(org_id): + """ + Get certificate issuance audit logs for an organization. + + Only accessible by organization admins. + Returns certificate serial IDs, user IDs, and issuance timestamps for compliance. + + Args: + org_id: Organization ID + + Returns: + 200: List of certificate audit logs + 401: Not authenticated + 403: Not an admin + 404: Organization not found + """ + from gatehouse_app.models.ssh_ca.certificate_audit_log import CertificateAuditLog + from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate + from gatehouse_app.models.user import User + + org = OrganizationService.get_organization_by_id(org_id) + + page = max(1, int(request.args.get("page", 1))) + per_page = min(int(request.args.get("per_page", 50)), 200) + action_filter = request.args.get("action", "signed") # Default to signed certificates + + # Get all CAs for this organization + from gatehouse_app.models.ssh_ca import CA + org_cas = CA.query.filter_by(organization_id=org_id, deleted_at=None).all() + org_ca_ids = [ca.id for ca in org_cas] + + if not org_ca_ids: + return api_response( + data={ + "audit_logs": [], + "count": 0, + "page": page, + "per_page": per_page, + "pages": 0, + }, + message="No certificate audit logs found", + ) + + # Query certificate audit logs for certificates issued by org's CAs + query = CertificateAuditLog.query.join( + SSHCertificate, + CertificateAuditLog.certificate_id == SSHCertificate.id + ).filter( + SSHCertificate.ca_id.in_(org_ca_ids), + CertificateAuditLog.deleted_at.is_(None) + ) + + if action_filter: + query = query.filter(CertificateAuditLog.action == action_filter) + + query = query.order_by(CertificateAuditLog.created_at.desc()) + total = query.count() + logs = query.offset((page - 1) * per_page).limit(per_page).all() + + # Build response data with certificate details + audit_data = [] + for log in logs: + cert = log.certificate + user = log.user + audit_data.append({ + "id": log.id, + "action": log.action, + "certificate_serial": cert.serial, + "key_id": cert.key_id, + "principals": cert.principals, + "user_id": user.id if user else cert.user_id, + "user_email": user.email if user else None, + "issued_at": cert.created_at.isoformat() if cert.created_at else None, + "valid_after": cert.valid_after.isoformat() if cert.valid_after else None, + "valid_before": cert.valid_before.isoformat() if cert.valid_before else None, + "ip_address": log.ip_address, + "user_agent": log.user_agent, + "message": log.message, + "success": log.success, + "created_at": log.created_at.isoformat() if log.created_at else None, + }) + + return api_response( + data={ + "audit_logs": audit_data, + "count": total, + "page": page, + "per_page": per_page, + "pages": (total + per_page - 1) // per_page, + }, + message="Certificate audit logs retrieved successfully", + ) diff --git a/gatehouse_app/api/v1/sudo.py b/gatehouse_app/api/v1/sudo.py new file mode 100644 index 0000000..f828587 --- /dev/null +++ b/gatehouse_app/api/v1/sudo.py @@ -0,0 +1,137 @@ +"""Sudoer check and sudo-related endpoints.""" +from flask import request +from gatehouse_app.api.v1 import api_v1_bp +from gatehouse_app.utils.response import api_response +from gatehouse_app.models.organization import OrganizationApiKey +from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate +from gatehouse_app.models.organization import Department, DepartmentMembership + + +@api_v1_bp.route("/sudo/check", methods=["POST"]) +def check_sudoer(): + """ + Check if a user with a given certificate can sudo. + + This endpoint validates an API key for an organization, retrieves the certificate + by serial ID, finds the user and their departments, and checks if any of their + departments have sudo capability. + + Request body: + api_key: Organization API key (required) + certificate_serial: Certificate serial ID (required) + + Returns: + 200: Sudoer status returned + 400: Invalid request body + 401: Invalid API key + 403: Certificate not found or user not found + 404: Organization or certificate not found + """ + try: + data = request.get_json() + + if not data: + return api_response( + success=False, + message="Request body is required", + status=400, + error_type="INVALID_REQUEST", + ) + + api_key = data.get("api_key") + certificate_serial = data.get("certificate_serial") + + if not api_key or certificate_serial is None: + return api_response( + success=False, + message="api_key and certificate_serial are required", + status=400, + error_type="MISSING_REQUIRED_FIELDS", + ) + + # Find the certificate by serial + certificate = SSHCertificate.query.filter_by( + serial=certificate_serial, + deleted_at=None + ).first() + + if not certificate: + return api_response( + success=False, + message="Certificate not found", + status=404, + error_type="NOT_FOUND", + ) + + # Get the CA and organization + ca = certificate.ca + if not ca: + return api_response( + success=False, + message="Certificate CA not found", + status=404, + error_type="NOT_FOUND", + ) + + org_id = ca.organization_id + + # Verify the API key for this organization + org_api_key = OrganizationApiKey.verify_key(org_id, api_key) + + if not org_api_key: + return api_response( + success=False, + message="Invalid API key for organization", + status=401, + error_type="UNAUTHORIZED", + ) + + # Get the user from the certificate + user = certificate.user + if not user: + return api_response( + success=False, + message="Certificate user not found", + status=404, + error_type="NOT_FOUND", + ) + + # Get all departments the user belongs to + user_departments = DepartmentMembership.query.filter_by( + user_id=user.id, + deleted_at=None + ).all() + + # Check if any of the user's departments have sudo capability + can_sudo = False + sudoer_departments = [] + + for dept_membership in user_departments: + dept = dept_membership.department + if dept and dept.can_sudo and dept.deleted_at is None: + can_sudo = True + sudoer_departments.append({ + "id": dept.id, + "name": dept.name, + }) + + return api_response( + data={ + "can_sudo": can_sudo, + "user_id": user.id, + "user_email": user.email, + "certificate_serial": certificate.serial, + "sudoer_departments": sudoer_departments, + "all_departments_count": len(user_departments), + }, + message="Sudoer status retrieved successfully", + status=200, + ) + + except Exception as e: + return api_response( + success=False, + message=f"An error occurred: {str(e)}", + status=500, + error_type="INTERNAL_ERROR", + ) diff --git a/gatehouse_app/models/organization/__init__.py b/gatehouse_app/models/organization/__init__.py index aa33f8e..52e29cf 100644 --- a/gatehouse_app/models/organization/__init__.py +++ b/gatehouse_app/models/organization/__init__.py @@ -12,6 +12,7 @@ from gatehouse_app.models.organization.department_cert_policy import ( ) from gatehouse_app.models.organization.principal import Principal, PrincipalMembership from gatehouse_app.models.organization.org_invite_token import OrgInviteToken +from gatehouse_app.models.organization.organization_api_key import OrganizationApiKey __all__ = [ "Organization", @@ -24,4 +25,5 @@ __all__ = [ "Principal", "PrincipalMembership", "OrgInviteToken", + "OrganizationApiKey", ] diff --git a/gatehouse_app/models/organization/department.py b/gatehouse_app/models/organization/department.py index 800780b..f46385a 100644 --- a/gatehouse_app/models/organization/department.py +++ b/gatehouse_app/models/organization/department.py @@ -27,6 +27,7 @@ class Department(BaseModel): ) name = db.Column(db.String(255), nullable=False, index=True) description = db.Column(db.Text, nullable=True) + can_sudo = db.Column(db.Boolean, default=False, nullable=False) # Relationships organization = db.relationship("Organization", back_populates="departments") diff --git a/gatehouse_app/models/organization/department_cert_policy.py b/gatehouse_app/models/organization/department_cert_policy.py index 357329f..3f8a93f 100644 --- a/gatehouse_app/models/organization/department_cert_policy.py +++ b/gatehouse_app/models/organization/department_cert_policy.py @@ -4,12 +4,13 @@ from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel -# Standard SSH certificate extensions +# Standard SSH certificate extensions — must be in strict lexical order +# (OpenSSH RFC 4251 §5 / golang.org/x/crypto/ssh requires lexical ordering) STANDARD_EXTENSIONS = [ "permit-X11-forwarding", "permit-agent-forwarding", - "permit-pty", "permit-port-forwarding", + "permit-pty", "permit-user-rc", ] diff --git a/gatehouse_app/models/organization/organization.py b/gatehouse_app/models/organization/organization.py index 9be5c65..349edb0 100644 --- a/gatehouse_app/models/organization/organization.py +++ b/gatehouse_app/models/organization/organization.py @@ -43,6 +43,9 @@ class Organization(BaseModel): cas = db.relationship( "CA", back_populates="organization", cascade="all, delete-orphan" ) + api_keys = db.relationship( + "OrganizationApiKey", back_populates="organization", cascade="all, delete-orphan" + ) def __repr__(self): """String representation of Organization.""" diff --git a/gatehouse_app/models/organization/organization_api_key.py b/gatehouse_app/models/organization/organization_api_key.py new file mode 100644 index 0000000..64feefa --- /dev/null +++ b/gatehouse_app/models/organization/organization_api_key.py @@ -0,0 +1,158 @@ +"""Organization API Key model — API keys for organizations for external integrations.""" +import secrets +from datetime import datetime, timezone +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel + + +class OrganizationApiKey(BaseModel): + """API Key model representing an API key for an organization. + + API keys are used to authenticate external integrations or services + that need programmatic access to the organization's resources. + Each key is tied to an organization and can be revoked/deleted as needed. + """ + + __tablename__ = "organization_api_keys" + + organization_id = db.Column( + db.String(36), + db.ForeignKey("organizations.id"), + nullable=False, + index=True, + ) + + # Human-readable name for the API key + name = db.Column(db.String(255), nullable=False) + + # Hashed key value (never store plain text) + key_hash = db.Column(db.String(255), nullable=False, unique=True, index=True) + + # Last used timestamp for tracking activity + last_used_at = db.Column(db.DateTime, nullable=True) + + # Revocation status + is_revoked = db.Column(db.Boolean, default=False, nullable=False, index=True) + revoked_at = db.Column(db.DateTime, nullable=True) + revoke_reason = db.Column(db.String(255), nullable=True) + + # Description/purpose of the key + description = db.Column(db.Text, nullable=True) + + # Relationships + organization = db.relationship("Organization", back_populates="api_keys") + + __table_args__ = ( + db.Index("idx_org_api_key_org_active", "organization_id", "is_revoked"), + db.Index("idx_api_key_last_used", "last_used_at"), + ) + + def __repr__(self): + """String representation of OrganizationApiKey.""" + return f"" + + @staticmethod + def generate_key() -> str: + """Generate a random API key. + + Returns: + A random 32-byte hex string suitable for use as an API key + """ + return secrets.token_hex(32) + + @classmethod + def create_key( + cls, + organization_id: str, + name: str, + description: str = None, + ) -> tuple: + """Create and store a new API key for an organization. + + Args: + organization_id: ID of the organization + name: Human-readable name for the key + description: Optional description/purpose of the key + + Returns: + Tuple of (OrganizationApiKey instance, plain_text_key_string) + The plain text key is only returned on creation and should be + stored securely by the user. It cannot be retrieved later. + """ + # Generate a plain text key + plain_key = cls.generate_key() + + # Hash it using the key_hash method + key_hash = cls.hash_key(plain_key) + + # Create the database record + api_key = cls( + organization_id=organization_id, + name=name, + key_hash=key_hash, + description=description, + ) + api_key.save() + + return api_key, plain_key + + @staticmethod + def hash_key(plain_key: str) -> str: + """Hash an API key for storage. + + Args: + plain_key: The plain text API key + + Returns: + Hashed version of the key + """ + import hashlib + return hashlib.sha256(plain_key.encode()).hexdigest() + + @classmethod + def verify_key(cls, organization_id: str, plain_key: str) -> "OrganizationApiKey": + """Verify an API key for an organization. + + Args: + organization_id: ID of the organization + plain_key: The plain text API key to verify + + Returns: + OrganizationApiKey instance if valid and active, None otherwise + """ + key_hash = cls.hash_key(plain_key) + + api_key = cls.query.filter_by( + organization_id=organization_id, + key_hash=key_hash, + is_revoked=False, + deleted_at=None, + ).first() + + if api_key: + # Update last used timestamp + api_key.last_used_at = datetime.now(timezone.utc) + api_key.save() + + return api_key + + def revoke(self, reason: str = None) -> None: + """Revoke this API key. + + Args: + reason: Optional reason for revocation + """ + self.is_revoked = True + self.revoked_at = datetime.now(timezone.utc) + self.revoke_reason = reason + self.save() + + def to_dict(self, exclude=None): + """Convert API key to dictionary. + + The key_hash is excluded by default for security. + """ + exclude = exclude or [] + if "key_hash" not in exclude: + exclude.append("key_hash") + return super().to_dict(exclude=exclude) diff --git a/gatehouse_app/services/ssh_ca_signing_service.py b/gatehouse_app/services/ssh_ca_signing_service.py index 2f760cc..e092e57 100644 --- a/gatehouse_app/services/ssh_ca_signing_service.py +++ b/gatehouse_app/services/ssh_ca_signing_service.py @@ -288,6 +288,12 @@ class SSHCASigningService: else: extensions = [] # host certs: no extensions + # OpenSSH (RFC 4251 §5) and golang.org/x/crypto/ssh require + # certificate extensions to be in strict lexical (alphabetical) order. + # Sort unconditionally so any caller-supplied or policy-derived list + # is guaranteed to be compliant. + extensions = sorted(extensions) + certificate.fields.extensions = extensions certificate.fields.critical_options = signing_request.critical_options or {} diff --git a/migrations/versions/020_ca_serial_timestamp_start.py b/migrations/versions/020_ca_serial_timestamp_start.py new file mode 100644 index 0000000..2556607 --- /dev/null +++ b/migrations/versions/020_ca_serial_timestamp_start.py @@ -0,0 +1,76 @@ +"""Seed CA serial counters with a timestamp-based starting value. + +Revision ID: 020_ca_serial_timestamp_start +Revises: 019_audit_varchar, d34bfb72844e +Create Date: 2026-03-06 + +WHY +--- +``next_serial_number`` was originally seeded at ``1`` for every CA +(``server_default="1"`` in migration 017). Because the +``ix_ssh_certificates_serial`` index enforces a globally-unique constraint on +the serial column, any two CAs issuing their first certificate would both try +to insert serial ``1``, causing a UniqueViolation. + +FIX — new CAs +------------- +The CA model's Python-side ``default`` is now ``_serial_start()``, which +returns ``int(time.time() * 1000)`` (Unix milliseconds) at row-creation time. +CAs created after this migration will start their serial counter at the +millisecond they were first inserted, so serials are globally unique across +CAs and still monotonically increasing within each CA. + +FIX — existing CAs +------------------- +This migration performs a data migration: any CA whose ``next_serial_number`` +is still ``<= 2`` (i.e. has issued at most one certificate since the original +``1``-based default) is given a new timestamp-based starting value. + +CAs that have already issued many certificates keep their current counter +unchanged — their serials are already beyond the low collision-prone range. + +NOTE: the ``server_default`` on the column is intentionally NOT changed here +because SQLAlchemy uses the Python-side ``default=_serial_start`` callable for +new rows; the ``server_default`` is only a database-level fallback that is +never hit when rows are inserted via the ORM. +""" +import time +from alembic import op +import sqlalchemy as sa + +revision = "020_ca_serial_timestamp_start" +down_revision = ("3de11c5dc2d5", "d34bfb72844e") +branch_labels = None +depends_on = None + + +def _now_ms() -> int: + return int(time.time() * 1000) + + +def upgrade(): + conn = op.get_bind() + + # Update ALL CAs to a timestamp-based starting serial — not just those + # stuck at 1. Any CA with a serial below the current ms timestamp is in + # the low collision-prone range (serials 1–N where N is tiny). Resetting + # every CA to a fresh ms timestamp is safe: the counter only moves forward + # from here, and no existing certificate serial is changed. + rows = conn.execute( + sa.text("SELECT id FROM cas") + ).fetchall() + + for (ca_id,) in rows: + new_start = _now_ms() + conn.execute( + sa.text( + "UPDATE cas SET next_serial_number = :val WHERE id = :id" + ), + {"val": new_start, "id": ca_id}, + ) + + +def downgrade(): + # There is no safe downgrade for a data migration that assigns new serial + # starting points — resetting to 1 would recreate the collision risk. + pass diff --git a/migrations/versions/3de11c5dc2d5_add_cert_token_to_ssh_certificates.py b/migrations/versions/3de11c5dc2d5_add_cert_token_to_ssh_certificates.py new file mode 100644 index 0000000..0b066ad --- /dev/null +++ b/migrations/versions/3de11c5dc2d5_add_cert_token_to_ssh_certificates.py @@ -0,0 +1,30 @@ +"""add_cert_token_to_ssh_certificates + +Revision ID: 3de11c5dc2d5 +Revises: 019_audit_varchar +Create Date: 2026-03-06 16:04:33.561099 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '3de11c5dc2d5' +down_revision = '019_audit_varchar' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('ssh_certificates', sa.Column('cert_token', sa.String(length=64), nullable=True)) + op.create_index(op.f('ix_ssh_certificates_cert_token'), 'ssh_certificates', ['cert_token'], unique=True) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_ssh_certificates_cert_token'), table_name='ssh_certificates') + op.drop_column('ssh_certificates', 'cert_token') + # ### end Alembic commands ### diff --git a/migrations/versions/add_can_sudo_to_departments.py b/migrations/versions/add_can_sudo_to_departments.py new file mode 100644 index 0000000..ccc72e0 --- /dev/null +++ b/migrations/versions/add_can_sudo_to_departments.py @@ -0,0 +1,34 @@ +"""Add can_sudo column to departments table. + +Revision ID: 002_add_can_sudo_to_departments +Revises: 001_add_org_api_keys +Create Date: 2026-03-07 23:40:30.000000 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '002_add_can_sudo_to_departments' +down_revision = '001_add_org_api_keys' +branch_labels = None +depends_on = None + + +def upgrade(): + # Add can_sudo column to departments table + op.add_column('departments', + sa.Column('can_sudo', sa.Boolean(), nullable=False, server_default='false')) + + # Create index for performance + op.create_index('idx_dept_can_sudo', 'departments', + ['organization_id', 'can_sudo']) + + +def downgrade(): + # Drop index + op.drop_index('idx_dept_can_sudo', table_name='departments') + + # Drop column + op.drop_column('departments', 'can_sudo') diff --git a/migrations/versions/add_organization_api_keys_table.py b/migrations/versions/add_organization_api_keys_table.py new file mode 100644 index 0000000..1c62994 --- /dev/null +++ b/migrations/versions/add_organization_api_keys_table.py @@ -0,0 +1,56 @@ +"""Add organization_api_keys table for API key management. + +Revision ID: 001_add_org_api_keys +Revises: 3de11c5dc2d5 +Create Date: 2026-03-07 23:40:00.000000 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '001_add_org_api_keys' +down_revision = '3de11c5dc2d5' +branch_labels = None +depends_on = None + + +def upgrade(): + # Create organization_api_keys table + op.create_table( + 'organization_api_keys', + sa.Column('id', sa.String(36), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), + sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('organization_id', sa.String(36), nullable=False), + sa.Column('name', sa.String(255), nullable=False), + sa.Column('key_hash', sa.String(255), nullable=False), + sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('is_revoked', sa.Boolean(), nullable=False, server_default='false'), + sa.Column('revoked_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('revoke_reason', sa.String(255), nullable=True), + sa.Column('description', sa.Text(), nullable=True), + sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('key_hash'), + ) + + # Create indexes for performance + op.create_index('idx_org_api_key_org_active', 'organization_api_keys', + ['organization_id', 'is_revoked']) + op.create_index('idx_api_key_last_used', 'organization_api_keys', + ['last_used_at']) + op.create_index('idx_org_api_key_org_id', 'organization_api_keys', + ['organization_id']) + + +def downgrade(): + # Drop indexes + op.drop_index('idx_org_api_key_org_id', table_name='organization_api_keys') + op.drop_index('idx_api_key_last_used', table_name='organization_api_keys') + op.drop_index('idx_org_api_key_org_active', table_name='organization_api_keys') + + # Drop table + op.drop_table('organization_api_keys') From 42ff4f2f4f12478b935d786f79ea40ac774056a6 Mon Sep 17 00:00:00 2001 From: James Bhattarai Date: Mon, 9 Mar 2026 17:35:19 +0545 Subject: [PATCH 5/9] Fix: Migration Heads --- migrations/versions/021_merge_heads.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 migrations/versions/021_merge_heads.py diff --git a/migrations/versions/021_merge_heads.py b/migrations/versions/021_merge_heads.py new file mode 100644 index 0000000..9b1e662 --- /dev/null +++ b/migrations/versions/021_merge_heads.py @@ -0,0 +1,22 @@ +"""Merge 020_ca_serial_timestamp_start and 002_add_can_sudo_to_departments into a single head. + +Revision ID: 021_merge_heads +Revises: 020_ca_serial_timestamp_start, 002_add_can_sudo_to_departments +Create Date: 2026-03-09 + +""" +from alembic import op + +# revision identifiers, used by Alembic. +revision = '021_merge_heads' +down_revision = ('020_ca_serial_timestamp_start', '002_add_can_sudo_to_departments') +branch_labels = None +depends_on = None + + +def upgrade(): + pass + + +def downgrade(): + pass From a7915c93283a982ddbb7eea7c10d298b5c3abd5e Mon Sep 17 00:00:00 2001 From: James Bhattarai Date: Fri, 13 Mar 2026 11:43:36 +0545 Subject: [PATCH 6/9] =?UTF-8?q?Fix:=20SSH=20key=20verification=20=E2=80=94?= =?UTF-8?q?=20accept=20raw=20armor=20+=20base64,=20clearer=20error=20messa?= =?UTF-8?q?ges?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gatehouse_app/services/ssh_key_service.py | 36 ++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/gatehouse_app/services/ssh_key_service.py b/gatehouse_app/services/ssh_key_service.py index db07b0b..4da133b 100644 --- a/gatehouse_app/services/ssh_key_service.py +++ b/gatehouse_app/services/ssh_key_service.py @@ -206,6 +206,40 @@ class SSHKeyService: return challenge_text + @staticmethod + def _decode_signature(signature: str) -> bytes: + """Decode a user-supplied signature into raw SSH signature bytes. + + Accepts either: + 1. The raw SSH armored signature (-----BEGIN SSH SIGNATURE-----) + 2. A base64-encoded version of that armored signature + (produced by ``cat file.sig | base64 -w0``) + + Returns the raw armored signature bytes suitable for writing to a + ``.sig`` file that ``ssh-keygen -Y verify`` can read. + """ + stripped = signature.strip() + + # If it already looks like a raw SSH signature armor, use it directly + if stripped.startswith("-----BEGIN SSH SIGNATURE-----"): + return stripped.encode("utf-8") + + # Otherwise treat it as base64 — strip any embedded whitespace first + cleaned = stripped.replace("\n", "").replace("\r", "").replace(" ", "") + try: + decoded = base64.b64decode(cleaned) + except Exception as exc: + raise SSHKeyError(f"Could not decode signature: {exc}") + + # Sanity-check: the decoded bytes should be a valid SSH signature + text = decoded.decode("utf-8", errors="replace") + if "-----BEGIN SSH SIGNATURE-----" not in text: + raise SSHKeyError( + "Invalid signature format. Please paste the output of: " + "cat /tmp/challenge.txt.sig | base64 -w0" + ) + return decoded + def verify_ssh_key_ownership( self, key_id: str, @@ -247,7 +281,7 @@ class SSHKeyService: # allowed_signers format: " " # We use the key fingerprint as the identity. - sig_bytes = base64.b64decode(signature) + sig_bytes = self._decode_signature(signature) challenge_text = key.verify_text + "\n" with tempfile.TemporaryDirectory() as tmpdir: From 05eb0922282119a1b1bae88f889cd983ad119799 Mon Sep 17 00:00:00 2001 From: James Bhattarai Date: Mon, 23 Mar 2026 17:51:55 +0545 Subject: [PATCH 7/9] Fix: DB Migration --- .../models/ssh_ca/ssh_certificate.py | 4 +- .../models/zerotier/activation_session.py | 6 +- .../models/zerotier/zerotier_membership.py | 4 +- .../versions/020_add_zerotier_models.py | 307 +------------- migrations/versions/022_add_command_events.py | 29 ++ .../versions/023_zerotier_drop_legacy.py | 393 ++++++++++++++++++ .../versions/024_fix_zerotier_schema.py | 291 +++++++++++++ migrations/versions/025_fix_zt_timestamps.py | 101 +++++ migrations/versions/026_schema_cleanup.py | 216 ++++++++++ .../027_fix_cert_serial_uniqueness.py | 105 +++++ 10 files changed, 1151 insertions(+), 305 deletions(-) create mode 100644 migrations/versions/022_add_command_events.py create mode 100644 migrations/versions/023_zerotier_drop_legacy.py create mode 100644 migrations/versions/024_fix_zerotier_schema.py create mode 100644 migrations/versions/025_fix_zt_timestamps.py create mode 100644 migrations/versions/026_schema_cleanup.py create mode 100644 migrations/versions/027_fix_cert_serial_uniqueness.py diff --git a/gatehouse_app/models/ssh_ca/ssh_certificate.py b/gatehouse_app/models/ssh_ca/ssh_certificate.py index a76fbd4..3affb58 100644 --- a/gatehouse_app/models/ssh_ca/ssh_certificate.py +++ b/gatehouse_app/models/ssh_ca/ssh_certificate.py @@ -50,7 +50,7 @@ class SSHCertificate(BaseModel): certificate = db.Column(db.Text, nullable=False) # Certificate metadata - serial = db.Column(db.String(255), nullable=False, unique=True, index=True) + serial = db.Column(db.String(255), nullable=False) key_id = db.Column(db.String(255), nullable=False) # Usually user email cert_type = db.Column( db.Enum(CertType, values_callable=lambda x: [e.value for e in x]), @@ -103,6 +103,8 @@ class SSHCertificate(BaseModel): ) __table_args__ = ( + db.UniqueConstraint("ca_id", "serial", name="uq_ssh_certificates_ca_serial"), + db.Index("ix_ssh_certificates_serial", "serial"), db.Index("idx_cert_user_status", "user_id", "status"), db.Index("idx_cert_validity", "valid_after", "valid_before"), db.Index("idx_cert_revoked", "revoked", "revoked_at"), diff --git a/gatehouse_app/models/zerotier/activation_session.py b/gatehouse_app/models/zerotier/activation_session.py index 8ee06cb..f26a9ef 100644 --- a/gatehouse_app/models/zerotier/activation_session.py +++ b/gatehouse_app/models/zerotier/activation_session.py @@ -45,14 +45,14 @@ class ActivationSession(BaseModel): index=True, ) authenticated_at = db.Column( - db.DateTime(timezone=True), + db.DateTime, nullable=False, ) expires_at = db.Column( - db.DateTime(timezone=True), + db.DateTime, nullable=False, ) - ended_at = db.Column(db.DateTime(timezone=True), nullable=True) + ended_at = db.Column(db.DateTime, nullable=True) end_reason = db.Column( db.Enum(ActivationEndReason, name="activation_end_reason"), nullable=True, diff --git a/gatehouse_app/models/zerotier/zerotier_membership.py b/gatehouse_app/models/zerotier/zerotier_membership.py index 4897056..ad2b7e9 100644 --- a/gatehouse_app/models/zerotier/zerotier_membership.py +++ b/gatehouse_app/models/zerotier/zerotier_membership.py @@ -51,8 +51,8 @@ class ZeroTierMembership(BaseModel): ) member_seen = db.Column(db.Boolean, default=False, nullable=False) authorized = db.Column(db.Boolean, default=False, nullable=False) - join_seen_at = db.Column(db.DateTime(timezone=True), nullable=True) - last_synced_at = db.Column(db.DateTime(timezone=True), nullable=True) + join_seen_at = db.Column(db.DateTime, nullable=True) + last_synced_at = db.Column(db.DateTime, nullable=True) raw_controller_payload = db.Column(db.JSON, nullable=True) # Relationships diff --git a/migrations/versions/020_add_zerotier_models.py b/migrations/versions/020_add_zerotier_models.py index 5b15439..bba02a4 100644 --- a/migrations/versions/020_add_zerotier_models.py +++ b/migrations/versions/020_add_zerotier_models.py @@ -4,314 +4,23 @@ Revision ID: 020_zerotier Revises: 019_audit_varchar Create Date: 2026-03-19 -Tables created: - - portal_networks — manager-created ZeroTier network bindings - - devices — user-registered ZeroTier node endpoints - - user_network_approvals — durable manager approval records - - device_network_memberships — per-device per-network workflow records - - activation_sessions — temporary activation windows - - zerotier_memberships — observed controller-side member state - - kill_switch_events — explicit rapid deactivation records +SUPERSEDED by 023_zerotier_drop_legacy which creates all ZeroTier tables +idempotently (with IF NOT EXISTS / if_not_exists=True). This migration is +kept as a no-op to preserve the Alembic revision chain for databases that +already have '020_zerotier' stamped (e.g. dev environments). """ -from alembic import op -import sqlalchemy as sa - revision = "020_zerotier" down_revision = "019_audit_varchar" branch_labels = None depends_on = None -def _pg_enum(enum_name: str, values: list[str]) -> sa.Enum: - return sa.Enum(*values, name=enum_name, create_type=False) - - def upgrade(): - bind = op.get_bind() - dialect = bind.dialect.name - - # ── 1. Enum types ───────────────────────────────────────────────────────── - - if dialect == "postgresql": - op.execute("CREATE TYPE network_environment AS ENUM (%s)" % ", ".join( - f"'{v}'" for v in ["production", "staging", "development", "lab"] - )) - op.execute("CREATE TYPE network_request_mode AS ENUM (%s)" % ", ".join( - f"'{v}'" for v in ["open", "approval_required", "invite_only"] - )) - op.execute("CREATE TYPE approval_grant_type AS ENUM (%s)" % ", ".join( - f"'{v}'" for v in ["requested", "assigned"] - )) - op.execute("CREATE TYPE approval_state AS ENUM (%s)" % ", ".join( - f"'{v}'" for v in ["pending", "approved", "rejected", "revoked", "suspended"] - )) - op.execute("CREATE TYPE membership_state AS ENUM (%s)" % ", ".join( - f"'{v}'" for v in [ - "pending_device_registration", - "pending_request", - "pending_manager_approval", - "approved_inactive", - "joined_deauthorized", - "active_authorized", - "activation_expired", - "suspended", - "revoked", - "rejected", - ] - )) - op.execute("CREATE TYPE activation_end_reason AS ENUM (%s)" % ", ".join( - f"'{v}'" for v in [ - "expired", "logout", "kill_switch", - "manual_revoke", "approval_revoked", "admin_action", - ] - )) - op.execute("CREATE TYPE kill_switch_scope AS ENUM (%s)" % ", ".join( - f"'{v}'" for v in ["organization", "global", "selected_networks"] - )) - op.execute("CREATE TYPE device_status AS ENUM (%s)" % ", ".join( - f"'{v}'" for v in ["active", "inactive"] - )) - - # ── 2. portal_networks ──────────────────────────────────────────────────── - - op.create_table( - "portal_networks", - sa.Column("id", sa.String(36), primary_key=True), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True), - sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False, index=True), - sa.Column("name", sa.String(255), nullable=False), - sa.Column("description", sa.Text, nullable=True), - sa.Column("owner_user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False), - sa.Column("zerotier_network_id", sa.String(16), nullable=False, index=True), - sa.Column( - "environment", - _pg_enum("network_environment", ["production", "staging", "development", "lab"]) if dialect == "postgresql" - else sa.String(20), - nullable=False, - ), - sa.Column( - "request_mode", - _pg_enum("network_request_mode", ["open", "approval_required", "invite_only"]) if dialect == "postgresql" - else sa.String(20), - nullable=False, - ), - sa.Column("default_activation_lifetime_minutes", sa.Integer, nullable=False, default=480), - sa.Column("max_activation_lifetime_minutes", sa.Integer, nullable=True), - sa.Column("is_active", sa.Boolean, nullable=False, default=True), - ) - op.create_index( - "ix_portal_networks_org_zt", - "portal_networks", - ["organization_id", "zerotier_network_id"], - unique=True, - postgresql_where=sa.text("deleted_at IS NULL"), - ) - - # ── 3. devices ─────────────────────────────────────────────────────────── - - op.create_table( - "devices", - sa.Column("id", sa.String(36), primary_key=True), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True), - sa.Column("user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False, index=True), - sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False, index=True), - sa.Column("node_id", sa.String(10), nullable=False, index=True), - sa.Column("device_nickname", sa.String(255), nullable=True), - sa.Column("hostname", sa.String(255), nullable=True), - sa.Column("asset_tag", sa.String(255), nullable=True), - sa.Column("serial_number", sa.String(255), nullable=True), - sa.Column( - "status", - _pg_enum("device_status", ["active", "inactive"]) if dialect == "postgresql" - else sa.String(20), - nullable=False, - default="active", - ), - ) - if dialect == "postgresql": - op.create_index( - "ix_devices_node_id_active", - "devices", - ["node_id"], - unique=True, - postgresql_where=sa.text("deleted_at IS NULL"), - ) - else: - op.create_index("ix_devices_node_id", "devices", ["node_id"], unique=False) - - # ── 4. user_network_approvals ───────────────────────────────────────────── - - op.create_table( - "user_network_approvals", - sa.Column("id", sa.String(36), primary_key=True), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True), - sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False, index=True), - sa.Column("user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False, index=True), - sa.Column("portal_network_id", sa.String(36), sa.ForeignKey("portal_networks.id"), nullable=False, index=True), - sa.Column("granted_by_user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=True), - sa.Column( - "grant_type", - _pg_enum("approval_grant_type", ["requested", "assigned"]) if dialect == "postgresql" - else sa.String(20), - nullable=False, - default="requested", - ), - sa.Column( - "state", - _pg_enum("approval_state", ["pending", "approved", "rejected", "revoked", "suspended"]) if dialect == "postgresql" - else sa.String(20), - nullable=False, - default="pending", - index=True, - ), - sa.Column("justification", sa.Text, nullable=True), - ) - op.create_index( - "ix_user_network_approvals_user_network", - "user_network_approvals", - ["user_id", "portal_network_id"], - unique=True, - postgresql_where=sa.text("deleted_at IS NULL"), - ) - - # ── 5. device_network_memberships ──────────────────────────────────────── - - op.create_table( - "device_network_memberships", - sa.Column("id", sa.String(36), primary_key=True), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True), - sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False, index=True), - sa.Column("user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False, index=True), - sa.Column("device_id", sa.String(36), sa.ForeignKey("devices.id"), nullable=False, index=True), - sa.Column("portal_network_id", sa.String(36), sa.ForeignKey("portal_networks.id"), nullable=False, index=True), - sa.Column("user_network_approval_id", sa.String(36), sa.ForeignKey("user_network_approvals.id"), nullable=True, index=True), - sa.Column( - "state", - _pg_enum( - "membership_state", - [ - "pending_device_registration", "pending_request", - "pending_manager_approval", "approved_inactive", - "joined_deauthorized", "active_authorized", - "activation_expired", "suspended", "revoked", "rejected", - ], - ) if dialect == "postgresql" else sa.String(30), - nullable=False, - default="pending_device_registration", - index=True, - ), - sa.Column("join_seen", sa.Boolean, nullable=False, default=False), - sa.Column("currently_authorized", sa.Boolean, nullable=False, default=False), - sa.Column("approved_for_activation", sa.Boolean, nullable=False, default=True), - ) - op.create_index( - "ix_device_network_memberships_device_network", - "device_network_memberships", - ["device_id", "portal_network_id"], - unique=True, - postgresql_where=sa.text("deleted_at IS NULL"), - ) - - # ── 6. activation_sessions ──────────────────────────────────────────────── - - op.create_table( - "activation_sessions", - sa.Column("id", sa.String(36), primary_key=True), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True), - sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False, index=True), - sa.Column("user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False, index=True), - sa.Column("device_network_membership_id", sa.String(36), sa.ForeignKey("device_network_memberships.id"), nullable=False, index=True), - sa.Column("authenticated_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("ended_at", sa.DateTime(timezone=True), nullable=True), - sa.Column( - "end_reason", - _pg_enum( - "activation_end_reason", - ["expired", "logout", "kill_switch", "manual_revoke", "approval_revoked", "admin_action"], - ) if dialect == "postgresql" else sa.String(20), - nullable=True, - ), - sa.Column("created_by", sa.String(36), sa.ForeignKey("users.id"), nullable=False), - ) - - # ── 7. zerotier_memberships ─────────────────────────────────────────────── - - op.create_table( - "zerotier_memberships", - sa.Column("id", sa.String(36), primary_key=True), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True), - sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False, index=True), - sa.Column("device_network_membership_id", sa.String(36), sa.ForeignKey("device_network_memberships.id"), nullable=True, index=True), - sa.Column("zerotier_network_id", sa.String(16), nullable=False, index=True), - sa.Column("node_id", sa.String(10), nullable=False, index=True), - sa.Column("member_seen", sa.Boolean, nullable=False, default=False), - sa.Column("authorized", sa.Boolean, nullable=False, default=False), - sa.Column("join_seen_at", sa.DateTime(timezone=True), nullable=True), - sa.Column("last_synced_at", sa.DateTime(timezone=True), nullable=True), - sa.Column("raw_controller_payload", sa.JSON, nullable=True), - ) - op.create_index( - "ix_zerotier_memberships_network_node", - "zerotier_memberships", - ["zerotier_network_id", "node_id"], - unique=True, - ) - - # ── 8. kill_switch_events ──────────────────────────────────────────────── - - op.create_table( - "kill_switch_events", - sa.Column("id", sa.String(36), primary_key=True), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True), - sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False, index=True), - sa.Column("target_user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False, index=True), - sa.Column( - "scope", - _pg_enum("kill_switch_scope", ["organization", "global", "selected_networks"]) if dialect == "postgresql" - else sa.String(20), - nullable=False, - default="organization", - ), - sa.Column("triggered_by_user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False), - sa.Column("reason", sa.Text, nullable=True), - sa.Column("network_ids", sa.JSON, nullable=True), - ) + # No-op — 023_zerotier_drop_legacy handles everything idempotently. + pass def downgrade(): - bind = op.get_bind() - dialect = bind.dialect.name - - op.drop_table("kill_switch_events") - op.drop_table("zerotier_memberships") - op.drop_table("activation_sessions") - op.drop_table("device_network_memberships") - op.drop_table("user_network_approvals") - op.drop_table("devices") - op.drop_table("portal_networks") - - if dialect == "postgresql": - op.execute("DROP TYPE IF EXISTS kill_switch_scope") - op.execute("DROP TYPE IF EXISTS device_status") - op.execute("DROP TYPE IF EXISTS activation_end_reason") - op.execute("DROP TYPE IF EXISTS membership_state") - op.execute("DROP TYPE IF EXISTS approval_state") - op.execute("DROP TYPE IF EXISTS approval_grant_type") - op.execute("DROP TYPE IF EXISTS network_request_mode") - op.execute("DROP TYPE IF EXISTS network_environment") + # No-op — 023_zerotier_drop_legacy handles rollback. + pass diff --git a/migrations/versions/022_add_command_events.py b/migrations/versions/022_add_command_events.py new file mode 100644 index 0000000..a410af7 --- /dev/null +++ b/migrations/versions/022_add_command_events.py @@ -0,0 +1,29 @@ +"""Merge zerotier + CA/sudo/api-key branches. + +Revision ID: 022_add_command_events +Revises: 020_zerotier, 021_merge_heads +Create Date: 2026-03-09 + +Pure merge-point for 020_zerotier and 021_merge_heads. +Revision ID kept as-is for compatibility with production databases that +already have '022_add_command_events' stamped in alembic_version. +""" + +from alembic import op + + +# --------------------------------------------------------------------------- +# revision identifiers +# --------------------------------------------------------------------------- +revision = "022_add_command_events" +down_revision = ("020_zerotier", "021_merge_heads") +branch_labels = None +depends_on = None + + +def upgrade(): + pass + + +def downgrade(): + pass diff --git a/migrations/versions/023_zerotier_drop_legacy.py b/migrations/versions/023_zerotier_drop_legacy.py new file mode 100644 index 0000000..c1a1608 --- /dev/null +++ b/migrations/versions/023_zerotier_drop_legacy.py @@ -0,0 +1,393 @@ +"""Apply ZeroTier tables and drop legacy SSH-session tables. + +Revision ID: 023_apply_zerotier_drop_legacy_ssh_tables +Revises: 022_add_command_events +Create Date: 2026-03-22 + +CONTEXT +------- +Migration 020_zerotier was never applied to the production database — the +alembic_version stamp jumped directly from a pre-zerotier revision to +022_add_command_events. This migration catches the DB up by: + + 1. Creating all ZeroTier / Portal Network tables (idempotent — every + create_table uses if_not_exists=True so it is safe to run on a DB + that already has some of these tables). + + 2. Dropping the legacy SSH-session tables that no longer have + corresponding ORM models: + - command_events (dropped first — has FKs to servers + host_sessions) + - sudo_events (dropped first — has FK to host_sessions) + - host_sessions (dropped second — referenced by the two above) + - servers (dropped last) + + All drops use IF EXISTS so the migration is also safe on a fresh DB + that ran 020_zerotier correctly (those tables would already be absent). + +PROD SAFETY +----------- +- All create_table calls use if_not_exists=True. +- All drop_table calls use IF EXISTS via op.execute() for tables that may + or may not be present. +- No data migration; no destructive schema change on tables that still + have ORM models. +""" + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.engine.reflection import Inspector + + +# --------------------------------------------------------------------------- +revision = "023_zerotier_drop_legacy" +down_revision = "022_add_command_events" +branch_labels = None +depends_on = None +# --------------------------------------------------------------------------- + + +def _table_exists(conn, table: str) -> bool: + return Inspector.from_engine(conn).has_table(table) + + +def _index_exists(conn, table: str, index: str) -> bool: + insp = Inspector.from_engine(conn) + return any(i["name"] == index for i in insp.get_indexes(table)) if _table_exists(conn, table) else False + + +def _type_exists(conn, type_name: str) -> bool: + result = conn.execute( + sa.text("SELECT 1 FROM pg_type WHERE typname = :t"), + {"t": type_name}, + ).scalar() + return bool(result) + + +def _pg_enum(name: str) -> sa.Text: + """Return a plain Text column type for use inside create_table. + + We rely on the enum type already existing in PostgreSQL (created above via + 'CREATE TYPE ... IF NOT EXISTS'). Using sa.String avoids SQLAlchemy's + automatic 'CREATE TYPE' emission inside create_table, which would fail if + the type already exists. A cast via server_default / CHECK constraint is + not required — PostgreSQL accepts varchar literals for enum columns when + inserted from SQLAlchemy's ORM layer, which uses the Python Enum type map. + """ + return sa.String(40) + + +# --------------------------------------------------------------------------- +# upgrade +# --------------------------------------------------------------------------- + +def upgrade(): + conn = op.get_bind() + dialect = conn.dialect.name + + # ── 1. Enum types (PostgreSQL only, idempotent) ─────────────────────────── + + if dialect == "postgresql": + enum_defs = { + "network_environment": ["production", "staging", "development", "lab"], + "network_request_mode": ["open", "approval_required", "invite_only"], + "approval_grant_type": ["requested", "assigned"], + "approval_state": ["pending", "approved", "rejected", "revoked", "suspended"], + "membership_state": [ + "pending_device_registration", "pending_request", + "pending_manager_approval", "approved_inactive", + "joined_deauthorized", "active_authorized", + "activation_expired", "suspended", "revoked", "rejected", + ], + "activation_end_reason": [ + "expired", "logout", "kill_switch", + "manual_revoke", "approval_revoked", "admin_action", + ], + "kill_switch_scope": ["organization", "global", "selected_networks"], + "device_status": ["active", "inactive"], + } + for type_name, values in enum_defs.items(): + if not _type_exists(conn, type_name): + quoted = ", ".join(f"'{v}'" for v in values) + conn.execute(sa.text(f"CREATE TYPE {type_name} AS ENUM ({quoted})")) + + # ── 2. portal_networks ──────────────────────────────────────────────────── + + op.create_table( + "portal_networks", + sa.Column("id", sa.String(36), primary_key=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False), + sa.Column("name", sa.String(255), nullable=False), + sa.Column("description", sa.Text, nullable=True), + sa.Column("owner_user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False), + sa.Column("zerotier_network_id", sa.String(16), nullable=False), + sa.Column("environment", sa.String(40), nullable=False), + sa.Column("request_mode", sa.String(40), nullable=False), + sa.Column("default_activation_lifetime_minutes", sa.Integer, nullable=False, server_default="480"), + sa.Column("max_activation_lifetime_minutes", sa.Integer, nullable=True), + sa.Column("is_active", sa.Boolean, nullable=False, server_default="true"), + if_not_exists=True, + ) + if not _index_exists(conn, "portal_networks", "ix_portal_networks_organization_id"): + op.create_index("ix_portal_networks_organization_id", "portal_networks", ["organization_id"]) + if not _index_exists(conn, "portal_networks", "ix_portal_networks_zerotier_network_id"): + op.create_index("ix_portal_networks_zerotier_network_id", "portal_networks", ["zerotier_network_id"]) + if not _index_exists(conn, "portal_networks", "ix_portal_networks_org_zt"): + op.create_index( + "ix_portal_networks_org_zt", "portal_networks", + ["organization_id", "zerotier_network_id"], + unique=True, + postgresql_where=sa.text("deleted_at IS NULL"), + ) + + # ── 3. devices ──────────────────────────────────────────────────────────── + + op.create_table( + "devices", + sa.Column("id", sa.String(36), primary_key=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False), + sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False), + sa.Column("node_id", sa.String(10), nullable=False), + sa.Column("device_nickname", sa.String(255), nullable=True), + sa.Column("hostname", sa.String(255), nullable=True), + sa.Column("asset_tag", sa.String(255), nullable=True), + sa.Column("serial_number", sa.String(255), nullable=True), + sa.Column("status", sa.String(40), nullable=False, server_default="active"), + if_not_exists=True, + ) + if not _index_exists(conn, "devices", "ix_devices_user_id"): + op.create_index("ix_devices_user_id", "devices", ["user_id"]) + if not _index_exists(conn, "devices", "ix_devices_organization_id"): + op.create_index("ix_devices_organization_id", "devices", ["organization_id"]) + if not _index_exists(conn, "devices", "ix_devices_node_id_active") and dialect == "postgresql": + op.create_index( + "ix_devices_node_id_active", "devices", ["node_id"], + unique=True, + postgresql_where=sa.text("deleted_at IS NULL"), + ) + elif not _index_exists(conn, "devices", "ix_devices_node_id") and dialect != "postgresql": + op.create_index("ix_devices_node_id", "devices", ["node_id"]) + + # ── 4. user_network_approvals ───────────────────────────────────────────── + + op.create_table( + "user_network_approvals", + sa.Column("id", sa.String(36), primary_key=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False), + sa.Column("user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False), + sa.Column("portal_network_id", sa.String(36), sa.ForeignKey("portal_networks.id"), nullable=False), + sa.Column("granted_by_user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=True), + sa.Column("grant_type", sa.String(40), nullable=False, server_default="requested"), + sa.Column("state", sa.String(40), nullable=False, server_default="pending"), + sa.Column("justification", sa.Text, nullable=True), + if_not_exists=True, + ) + if not _index_exists(conn, "user_network_approvals", "ix_user_network_approvals_organization_id"): + op.create_index("ix_user_network_approvals_organization_id", "user_network_approvals", ["organization_id"]) + if not _index_exists(conn, "user_network_approvals", "ix_user_network_approvals_user_id"): + op.create_index("ix_user_network_approvals_user_id", "user_network_approvals", ["user_id"]) + if not _index_exists(conn, "user_network_approvals", "ix_user_network_approvals_portal_network_id"): + op.create_index("ix_user_network_approvals_portal_network_id", "user_network_approvals", ["portal_network_id"]) + if not _index_exists(conn, "user_network_approvals", "ix_user_network_approvals_state"): + op.create_index("ix_user_network_approvals_state", "user_network_approvals", ["state"]) + if not _index_exists(conn, "user_network_approvals", "ix_user_network_approvals_user_network"): + op.create_index( + "ix_user_network_approvals_user_network", "user_network_approvals", + ["user_id", "portal_network_id"], + unique=True, + postgresql_where=sa.text("deleted_at IS NULL"), + ) + + # ── 5. device_network_memberships ───────────────────────────────────────── + + op.create_table( + "device_network_memberships", + sa.Column("id", sa.String(36), primary_key=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False), + sa.Column("user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False), + sa.Column("device_id", sa.String(36), sa.ForeignKey("devices.id"), nullable=False), + sa.Column("portal_network_id", sa.String(36), sa.ForeignKey("portal_networks.id"), nullable=False), + sa.Column("user_network_approval_id", sa.String(36), sa.ForeignKey("user_network_approvals.id"), nullable=True), + sa.Column("state", sa.String(40), nullable=False, server_default="pending_device_registration"), + sa.Column("join_seen", sa.Boolean, nullable=False, server_default="false"), + sa.Column("currently_authorized", sa.Boolean, nullable=False, server_default="false"), + sa.Column("approved_for_activation", sa.Boolean, nullable=False, server_default="true"), + if_not_exists=True, + ) + if not _index_exists(conn, "device_network_memberships", "ix_device_network_memberships_organization_id"): + op.create_index("ix_device_network_memberships_organization_id", "device_network_memberships", ["organization_id"]) + if not _index_exists(conn, "device_network_memberships", "ix_device_network_memberships_user_id"): + op.create_index("ix_device_network_memberships_user_id", "device_network_memberships", ["user_id"]) + if not _index_exists(conn, "device_network_memberships", "ix_device_network_memberships_device_id"): + op.create_index("ix_device_network_memberships_device_id", "device_network_memberships", ["device_id"]) + if not _index_exists(conn, "device_network_memberships", "ix_device_network_memberships_portal_network_id"): + op.create_index("ix_device_network_memberships_portal_network_id", "device_network_memberships", ["portal_network_id"]) + if not _index_exists(conn, "device_network_memberships", "ix_device_network_memberships_state"): + op.create_index("ix_device_network_memberships_state", "device_network_memberships", ["state"]) + if not _index_exists(conn, "device_network_memberships", "ix_device_network_memberships_user_network_approval_id"): + op.create_index("ix_device_network_memberships_user_network_approval_id", "device_network_memberships", ["user_network_approval_id"]) + if not _index_exists(conn, "device_network_memberships", "ix_device_network_memberships_device_network"): + op.create_index( + "ix_device_network_memberships_device_network", "device_network_memberships", + ["device_id", "portal_network_id"], + unique=True, + postgresql_where=sa.text("deleted_at IS NULL"), + ) + + # ── 6. activation_sessions ──────────────────────────────────────────────── + + op.create_table( + "activation_sessions", + sa.Column("id", sa.String(36), primary_key=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False), + sa.Column("user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False), + sa.Column("device_network_membership_id", sa.String(36), sa.ForeignKey("device_network_memberships.id"), nullable=False), + sa.Column("authenticated_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("ended_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("end_reason", sa.String(40), nullable=True), + sa.Column("created_by", sa.String(36), sa.ForeignKey("users.id"), nullable=False), + if_not_exists=True, + ) + if not _index_exists(conn, "activation_sessions", "ix_activation_sessions_organization_id"): + op.create_index("ix_activation_sessions_organization_id", "activation_sessions", ["organization_id"]) + if not _index_exists(conn, "activation_sessions", "ix_activation_sessions_user_id"): + op.create_index("ix_activation_sessions_user_id", "activation_sessions", ["user_id"]) + if not _index_exists(conn, "activation_sessions", "ix_activation_sessions_device_network_membership_id"): + op.create_index("ix_activation_sessions_device_network_membership_id", "activation_sessions", ["device_network_membership_id"]) + + # ── 7. zerotier_memberships ─────────────────────────────────────────────── + + op.create_table( + "zerotier_memberships", + sa.Column("id", sa.String(36), primary_key=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False), + sa.Column("device_network_membership_id", sa.String(36), sa.ForeignKey("device_network_memberships.id"), nullable=True), + sa.Column("zerotier_network_id", sa.String(16), nullable=False), + sa.Column("node_id", sa.String(10), nullable=False), + sa.Column("member_seen", sa.Boolean, nullable=False, server_default="false"), + sa.Column("authorized", sa.Boolean, nullable=False, server_default="false"), + sa.Column("join_seen_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("last_synced_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("raw_controller_payload", sa.JSON, nullable=True), + if_not_exists=True, + ) + if not _index_exists(conn, "zerotier_memberships", "ix_zerotier_memberships_organization_id"): + op.create_index("ix_zerotier_memberships_organization_id", "zerotier_memberships", ["organization_id"]) + if not _index_exists(conn, "zerotier_memberships", "ix_zerotier_memberships_device_network_membership_id"): + op.create_index("ix_zerotier_memberships_device_network_membership_id", "zerotier_memberships", ["device_network_membership_id"]) + if not _index_exists(conn, "zerotier_memberships", "ix_zerotier_memberships_zerotier_network_id"): + op.create_index("ix_zerotier_memberships_zerotier_network_id", "zerotier_memberships", ["zerotier_network_id"]) + if not _index_exists(conn, "zerotier_memberships", "ix_zerotier_memberships_node_id"): + op.create_index("ix_zerotier_memberships_node_id", "zerotier_memberships", ["node_id"]) + if not _index_exists(conn, "zerotier_memberships", "ix_zerotier_memberships_network_node"): + op.create_index( + "ix_zerotier_memberships_network_node", "zerotier_memberships", + ["zerotier_network_id", "node_id"], + unique=True, + ) + + # ── 8. kill_switch_events ──────────────────────────────────────────────── + + op.create_table( + "kill_switch_events", + sa.Column("id", sa.String(36), primary_key=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False), + sa.Column("target_user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False), + sa.Column("scope", sa.String(40), nullable=False, server_default="organization"), + sa.Column("triggered_by_user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False), + sa.Column("reason", sa.Text, nullable=True), + sa.Column("network_ids", sa.JSON, nullable=True), + if_not_exists=True, + ) + if not _index_exists(conn, "kill_switch_events", "ix_kill_switch_events_organization_id"): + op.create_index("ix_kill_switch_events_organization_id", "kill_switch_events", ["organization_id"]) + if not _index_exists(conn, "kill_switch_events", "ix_kill_switch_events_target_user_id"): + op.create_index("ix_kill_switch_events_target_user_id", "kill_switch_events", ["target_user_id"]) + + # ── 9. Drop legacy SSH-session tables (IF EXISTS — safe on fresh DBs) ───── + # + # Order matters due to FK constraints: + # command_events → servers, host_sessions + # sudo_events → host_sessions + # host_sessions → (nothing that still exists) + # servers → (nothing that still exists) + + conn.execute(sa.text("DROP TABLE IF EXISTS command_events CASCADE")) + conn.execute(sa.text("DROP TABLE IF EXISTS sudo_events CASCADE")) + conn.execute(sa.text("DROP TABLE IF EXISTS host_sessions CASCADE")) + conn.execute(sa.text("DROP TABLE IF EXISTS servers CASCADE")) + + +# --------------------------------------------------------------------------- +# downgrade +# --------------------------------------------------------------------------- + +def downgrade(): + conn = op.get_bind() + dialect = conn.dialect.name + + # Re-create the legacy tables (minimal — enough for FK integrity) + op.create_table( + "servers", + sa.Column("id", sa.String(36), primary_key=True), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False), + sa.Column("hostname", sa.String(255), nullable=False), + sa.Column("display_name", sa.String(255), nullable=True), + sa.Column("ip_address", sa.String(64), nullable=True), + sa.Column("is_active", sa.Boolean, nullable=False, server_default="true"), + if_not_exists=True, + ) + + op.create_table( + "host_sessions", + sa.Column("id", sa.String(36), primary_key=True), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False), + sa.Column("user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False), + sa.Column("server_id", sa.String(36), sa.ForeignKey("servers.id"), nullable=False), + if_not_exists=True, + ) + + # Drop ZeroTier tables + op.drop_table("kill_switch_events", if_exists=True) + op.drop_table("zerotier_memberships", if_exists=True) + op.drop_table("activation_sessions", if_exists=True) + op.drop_table("device_network_memberships", if_exists=True) + op.drop_table("user_network_approvals", if_exists=True) + op.drop_table("devices", if_exists=True) + op.drop_table("portal_networks", if_exists=True) + + # Drop ZeroTier enum types + if dialect == "postgresql": + for t in [ + "kill_switch_scope", "device_status", "activation_end_reason", + "membership_state", "approval_state", "approval_grant_type", + "network_request_mode", "network_environment", + ]: + conn.execute(sa.text(f"DROP TYPE IF EXISTS {t}")) diff --git a/migrations/versions/024_fix_zerotier_schema.py b/migrations/versions/024_fix_zerotier_schema.py new file mode 100644 index 0000000..11c9aa2 --- /dev/null +++ b/migrations/versions/024_fix_zerotier_schema.py @@ -0,0 +1,291 @@ +"""Fix ZeroTier table schema: enum types, unique constraints, indexes, drop cert_token. + +Revision ID: 024_fix_zerotier_schema +Revises: 023_zerotier_drop_legacy +Create Date: 2026-03-22 + +Addresses all `db check` differences after 023: + - Cast VARCHAR(40) enum columns to their proper PostgreSQL enum types + (guarded — skipped if columns are already native enum, e.g. on a fresh DB + where 020_zerotier created them correctly) + - Replace partial unique indexes with named UniqueConstraints + - Fix devices.node_id partial index -> plain index + - Add UniqueConstraint on `id` for all new ZeroTier tables (BaseModel.unique=True) + - Drop orphan cert_token column and its index from ssh_certificates +""" + +from alembic import op +import sqlalchemy as sa + +revision = "024_fix_zerotier_schema" +down_revision = "023_zerotier_drop_legacy" +branch_labels = None +depends_on = None + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + +def _col_data_type(conn, table: str, column: str) -> str | None: + """Return the PostgreSQL data_type string for a column, or None.""" + row = conn.execute(sa.text( + "SELECT data_type FROM information_schema.columns " + "WHERE table_name = :t AND column_name = :c" + ), {"t": table, "c": column}).first() + return row[0] if row else None + + +def _column_exists(conn, table: str, column: str) -> bool: + return _col_data_type(conn, table, column) is not None + + +def _index_exists(conn, table: str, index: str) -> bool: + from sqlalchemy.engine.reflection import Inspector + insp = Inspector.from_engine(conn) + return any(i["name"] == index for i in insp.get_indexes(table)) + + +def _constraint_exists(conn, constraint: str) -> bool: + row = conn.execute(sa.text( + "SELECT 1 FROM information_schema.table_constraints " + "WHERE constraint_name = :c" + ), {"c": constraint}).first() + return row is not None + + +def upgrade(): + conn = op.get_bind() + + # ------------------------------------------------------------------------- + # 1. Cast VARCHAR(40) enum columns to proper PostgreSQL enum types. + # GUARDED: On a fresh DB, 020_zerotier already created these as native + # enum types. We only cast if the column is currently 'character varying'. + # ------------------------------------------------------------------------- + enum_casts = [ + ("portal_networks", "environment", "network_environment", None), + ("portal_networks", "request_mode", "network_request_mode", None), + ("devices", "status", "device_status", "'active'::device_status"), + ("device_network_memberships", "state", "membership_state", "'pending_device_registration'::membership_state"), + ("user_network_approvals", "grant_type", "approval_grant_type", "'requested'::approval_grant_type"), + ("user_network_approvals", "state", "approval_state", "'pending'::approval_state"), + ("activation_sessions", "end_reason", "activation_end_reason", None), + ("kill_switch_events", "scope", "kill_switch_scope", "'organization'::kill_switch_scope"), + ] + for table, col, enum_type, new_default in enum_casts: + dtype = _col_data_type(conn, table, col) + if dtype == "character varying": + conn.execute(sa.text(f'ALTER TABLE "{table}" ALTER COLUMN "{col}" DROP DEFAULT')) + conn.execute(sa.text( + f'ALTER TABLE "{table}" ALTER COLUMN "{col}" TYPE {enum_type} ' + f'USING "{col}"::text::{enum_type}' + )) + if new_default: + conn.execute(sa.text( + f'ALTER TABLE "{table}" ALTER COLUMN "{col}" SET DEFAULT {new_default}' + )) + elif dtype == "USER-DEFINED" and new_default: + # Already native enum (fresh DB path). Ensure server_default is set + # if 020 used `default=` (Python-side) instead of `server_default=`. + # This is harmless — SET DEFAULT is idempotent. + conn.execute(sa.text( + f'ALTER TABLE "{table}" ALTER COLUMN "{col}" SET DEFAULT {new_default}' + )) + + # ------------------------------------------------------------------------- + # 2. portal_networks: drop partial unique index, add named UniqueConstraint + # ------------------------------------------------------------------------- + if _index_exists(conn, "portal_networks", "ix_portal_networks_org_zt"): + op.drop_index("ix_portal_networks_org_zt", table_name="portal_networks") + if not _constraint_exists(conn, "uix_org_zt_network_id"): + op.create_unique_constraint( + "uix_org_zt_network_id", + "portal_networks", + ["organization_id", "zerotier_network_id"], + ) + + # ------------------------------------------------------------------------- + # 3. device_network_memberships: drop partial unique index, add named UC + # ------------------------------------------------------------------------- + if _index_exists(conn, "device_network_memberships", "ix_device_network_memberships_device_network"): + op.drop_index("ix_device_network_memberships_device_network", table_name="device_network_memberships") + if not _constraint_exists(conn, "uix_device_network"): + op.create_unique_constraint( + "uix_device_network", + "device_network_memberships", + ["device_id", "portal_network_id", "deleted_at"], + ) + + # ------------------------------------------------------------------------- + # 4. user_network_approvals: drop partial unique index, add named UC + # ------------------------------------------------------------------------- + if _index_exists(conn, "user_network_approvals", "ix_user_network_approvals_user_network"): + op.drop_index("ix_user_network_approvals_user_network", table_name="user_network_approvals") + if not _constraint_exists(conn, "uix_user_network_approval"): + op.create_unique_constraint( + "uix_user_network_approval", + "user_network_approvals", + ["user_id", "portal_network_id", "deleted_at"], + ) + + # ------------------------------------------------------------------------- + # 5. zerotier_memberships: drop index, add named UniqueConstraint + # ------------------------------------------------------------------------- + if _index_exists(conn, "zerotier_memberships", "ix_zerotier_memberships_network_node"): + op.drop_index("ix_zerotier_memberships_network_node", table_name="zerotier_memberships") + if not _constraint_exists(conn, "uix_zt_network_node"): + op.create_unique_constraint( + "uix_zt_network_node", + "zerotier_memberships", + ["zerotier_network_id", "node_id"], + ) + + # ------------------------------------------------------------------------- + # 6. devices.node_id: drop partial unique index, add plain non-unique index + # ------------------------------------------------------------------------- + if _index_exists(conn, "devices", "ix_devices_node_id_active"): + op.drop_index("ix_devices_node_id_active", table_name="devices") + if not _index_exists(conn, "devices", "ix_devices_node_id"): + op.create_index("ix_devices_node_id", "devices", ["node_id"]) + + # ------------------------------------------------------------------------- + # 7. Add UniqueConstraint on `id` for all ZeroTier tables + # BaseModel defines id with unique=True → separate _id_key constraint. + # ------------------------------------------------------------------------- + zt_tables = [ + "portal_networks", + "devices", + "device_network_memberships", + "user_network_approvals", + "activation_sessions", + "zerotier_memberships", + "kill_switch_events", + ] + for tbl in zt_tables: + cname = f"{tbl}_id_key" + if not _constraint_exists(conn, cname): + op.create_unique_constraint(cname, tbl, ["id"]) + + # ------------------------------------------------------------------------- + # 8. Drop orphan cert_token column and its index from ssh_certificates. + # cert_token was created by 3de11c5dc2d5 but the SSHCertificate model + # never uses it. Guarded in case a future revision removes it first. + # ------------------------------------------------------------------------- + if _index_exists(conn, "ssh_certificates", "ix_ssh_certificates_cert_token"): + op.drop_index("ix_ssh_certificates_cert_token", table_name="ssh_certificates") + if _column_exists(conn, "ssh_certificates", "cert_token"): + op.drop_column("ssh_certificates", "cert_token") + + +def downgrade(): + conn = op.get_bind() + + # Restore cert_token if it was dropped + if not _column_exists(conn, "ssh_certificates", "cert_token"): + op.add_column( + "ssh_certificates", + sa.Column("cert_token", sa.String(64), nullable=True), + ) + if not _index_exists(conn, "ssh_certificates", "ix_ssh_certificates_cert_token"): + op.create_index( + "ix_ssh_certificates_cert_token", + "ssh_certificates", + ["cert_token"], + unique=True, + ) + + # Drop id unique constraints on ZeroTier tables + zt_tables = [ + "portal_networks", + "devices", + "device_network_memberships", + "user_network_approvals", + "activation_sessions", + "zerotier_memberships", + "kill_switch_events", + ] + for tbl in zt_tables: + cname = f"{tbl}_id_key" + if _constraint_exists(conn, cname): + op.drop_constraint(cname, tbl, type_="unique") + + # Restore devices node_id index + if _index_exists(conn, "devices", "ix_devices_node_id"): + op.drop_index("ix_devices_node_id", table_name="devices") + if not _index_exists(conn, "devices", "ix_devices_node_id_active"): + op.create_index( + "ix_devices_node_id_active", + "devices", + ["node_id"], + unique=True, + postgresql_where=sa.text("deleted_at IS NULL"), + ) + + # Restore zerotier_memberships index + if _constraint_exists(conn, "uix_zt_network_node"): + op.drop_constraint("uix_zt_network_node", "zerotier_memberships", type_="unique") + if not _index_exists(conn, "zerotier_memberships", "ix_zerotier_memberships_network_node"): + op.create_index( + "ix_zerotier_memberships_network_node", + "zerotier_memberships", + ["zerotier_network_id", "node_id"], + unique=True, + ) + + # Restore user_network_approvals partial unique index + if _constraint_exists(conn, "uix_user_network_approval"): + op.drop_constraint("uix_user_network_approval", "user_network_approvals", type_="unique") + if not _index_exists(conn, "user_network_approvals", "ix_user_network_approvals_user_network"): + op.create_index( + "ix_user_network_approvals_user_network", + "user_network_approvals", + ["user_id", "portal_network_id"], + unique=True, + postgresql_where=sa.text("deleted_at IS NULL"), + ) + + # Restore device_network_memberships partial unique index + if _constraint_exists(conn, "uix_device_network"): + op.drop_constraint("uix_device_network", "device_network_memberships", type_="unique") + if not _index_exists(conn, "device_network_memberships", "ix_device_network_memberships_device_network"): + op.create_index( + "ix_device_network_memberships_device_network", + "device_network_memberships", + ["device_id", "portal_network_id"], + unique=True, + postgresql_where=sa.text("deleted_at IS NULL"), + ) + + # Restore portal_networks partial unique index + if _constraint_exists(conn, "uix_org_zt_network_id"): + op.drop_constraint("uix_org_zt_network_id", "portal_networks", type_="unique") + if not _index_exists(conn, "portal_networks", "ix_portal_networks_org_zt"): + op.create_index( + "ix_portal_networks_org_zt", + "portal_networks", + ["organization_id", "zerotier_network_id"], + unique=True, + postgresql_where=sa.text("deleted_at IS NULL"), + ) + + # Cast enum columns back to VARCHAR(40) — only if currently native enum + enum_casts = [ + ("portal_networks", "environment", "'development'::character varying"), + ("portal_networks", "request_mode", "'approval_required'::character varying"), + ("devices", "status", "'active'::character varying"), + ("device_network_memberships", "state", "'pending_device_registration'::character varying"), + ("user_network_approvals", "grant_type", "'requested'::character varying"), + ("user_network_approvals", "state", "'pending'::character varying"), + ("activation_sessions", "end_reason", None), + ("kill_switch_events", "scope", "'organization'::character varying"), + ] + for table, col, old_default in enum_casts: + conn.execute(sa.text(f'ALTER TABLE "{table}" ALTER COLUMN "{col}" DROP DEFAULT')) + conn.execute(sa.text( + f'ALTER TABLE "{table}" ALTER COLUMN "{col}" TYPE VARCHAR(40) ' + f'USING "{col}"::text' + )) + if old_default: + conn.execute(sa.text( + f'ALTER TABLE "{table}" ALTER COLUMN "{col}" SET DEFAULT {old_default}' + )) diff --git a/migrations/versions/025_fix_zt_timestamps.py b/migrations/versions/025_fix_zt_timestamps.py new file mode 100644 index 0000000..2b0029f --- /dev/null +++ b/migrations/versions/025_fix_zt_timestamps.py @@ -0,0 +1,101 @@ +"""Convert ZeroTier table timestamp columns from TIMESTAMPTZ to TIMESTAMP. + +Revision ID: 025_fix_zt_timestamps +Revises: 024_fix_zerotier_schema +Create Date: 2026-03-22 + +Migration 020_zerotier (and 023's fallback create_table) defined ZeroTier tables +with sa.DateTime(timezone=True), producing TIMESTAMP WITH TIME ZONE columns. +The rest of the codebase uses plain DateTime (timezone-naive TIMESTAMP WITHOUT +TIME ZONE). This migration aligns all ZeroTier table timestamp columns with the +existing codebase convention. + +GUARDED: Each ALTER is only executed if the column is currently +TIMESTAMP WITH TIME ZONE. On a DB that has already been converted (e.g. dev), +the migration is a harmless no-op. +""" + +from alembic import op +import sqlalchemy as sa + +revision = "025_fix_zt_timestamps" +down_revision = "024_fix_zerotier_schema" +branch_labels = None +depends_on = None + +# All ZeroTier tables that inherit BaseModel's created_at/updated_at/deleted_at +_ZT_BASE_TABLES = [ + "portal_networks", + "devices", + "device_network_memberships", + "user_network_approvals", + "kill_switch_events", + "activation_sessions", + "zerotier_memberships", +] + +# Additional datetime columns specific to individual models +_EXTRA_COLS = { + "activation_sessions": ["authenticated_at", "expires_at", "ended_at"], + "zerotier_memberships": ["join_seen_at", "last_synced_at"], +} + + +def _col_is_timestamptz(conn, table: str, column: str) -> bool: + """Return True if the column is TIMESTAMP WITH TIME ZONE.""" + row = conn.execute(sa.text( + "SELECT data_type FROM information_schema.columns " + "WHERE table_name = :t AND column_name = :c" + ), {"t": table, "c": column}).first() + return row is not None and row[0] == "timestamp with time zone" + + +def _col_is_timestamp(conn, table: str, column: str) -> bool: + """Return True if the column is TIMESTAMP WITHOUT TIME ZONE.""" + row = conn.execute(sa.text( + "SELECT data_type FROM information_schema.columns " + "WHERE table_name = :t AND column_name = :c" + ), {"t": table, "c": column}).first() + return row is not None and row[0] == "timestamp without time zone" + + +def upgrade(): + conn = op.get_bind() + + for tbl in _ZT_BASE_TABLES: + for col in ("created_at", "updated_at", "deleted_at"): + if _col_is_timestamptz(conn, tbl, col): + conn.execute(sa.text( + f'ALTER TABLE "{tbl}" ALTER COLUMN "{col}" ' + f'TYPE TIMESTAMP WITHOUT TIME ZONE ' + f'USING "{col}" AT TIME ZONE \'UTC\'' + )) + for col in _EXTRA_COLS.get(tbl, []): + if _col_is_timestamptz(conn, tbl, col): + conn.execute(sa.text( + f'ALTER TABLE "{tbl}" ALTER COLUMN "{col}" ' + f'TYPE TIMESTAMP WITHOUT TIME ZONE ' + f'USING CASE WHEN "{col}" IS NULL THEN NULL ' + f'ELSE "{col}" AT TIME ZONE \'UTC\' END' + )) + + +def downgrade(): + conn = op.get_bind() + + for tbl in _ZT_BASE_TABLES: + for col in ("created_at", "updated_at", "deleted_at"): + if _col_is_timestamp(conn, tbl, col): + conn.execute(sa.text( + f'ALTER TABLE "{tbl}" ALTER COLUMN "{col}" ' + f'TYPE TIMESTAMP WITH TIME ZONE ' + f'USING "{col}" AT TIME ZONE \'UTC\'' + )) + for col in _EXTRA_COLS.get(tbl, []): + if _col_is_timestamp(conn, tbl, col): + conn.execute(sa.text( + f'ALTER TABLE "{tbl}" ALTER COLUMN "{col}" ' + f'TYPE TIMESTAMP WITH TIME ZONE ' + f'USING CASE WHEN "{col}" IS NULL THEN NULL ' + f'ELSE "{col}" AT TIME ZONE \'UTC\' END' + )) diff --git a/migrations/versions/026_schema_cleanup.py b/migrations/versions/026_schema_cleanup.py new file mode 100644 index 0000000..15e5ef0 --- /dev/null +++ b/migrations/versions/026_schema_cleanup.py @@ -0,0 +1,216 @@ +"""Schema cleanup: id UniqueConstraints, organization_api_keys index/timestamp fixes. + +Revision ID: 026_schema_cleanup +Revises: 025_fix_zt_timestamps +Create Date: 2026-03-23 + +Addresses all `db check` differences after 025 on a database upgraded from +production (021_merge_heads): + +1. Add UniqueConstraint on `id` for all pre-existing tables that inherit + BaseModel (which declares id with unique=True). The ZeroTier tables + already got these in 024_fix_zerotier_schema; this covers the rest. + +2. organization_api_keys — fix schema drift vs. the current model: + - TIMESTAMPTZ → TIMESTAMP WITHOUT TIME ZONE (align with rest of codebase) + - Drop legacy unique constraint 'organization_api_keys_key_hash_key' + and replace with named index 'ix_organization_api_keys_key_hash' + - Drop extra index 'idx_org_api_key_org_id' (superseded by + 'ix_organization_api_keys_organization_id') + - Add 'ix_organization_api_keys_organization_id' and + 'ix_organization_api_keys_is_revoked' named indexes expected by model + +3. Drop 'idx_dept_can_sudo' index from departments — created by an old + migration but not declared in the current Department model. + +All operations are guarded so the migration is safe to re-run. +""" + +from alembic import op +import sqlalchemy as sa + + +revision = "026_schema_cleanup" +down_revision = "025_fix_zt_timestamps" +branch_labels = None +depends_on = None + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + +def _constraint_exists(conn, name: str) -> bool: + row = conn.execute(sa.text( + "SELECT 1 FROM information_schema.table_constraints " + "WHERE constraint_name = :n" + ), {"n": name}).first() + return row is not None + + +def _index_exists(conn, table: str, index: str) -> bool: + row = conn.execute(sa.text( + "SELECT 1 FROM pg_indexes " + "WHERE tablename = :t AND indexname = :i" + ), {"t": table, "i": index}).first() + return row is not None + + +def _col_is_timestamptz(conn, table: str, column: str) -> bool: + row = conn.execute(sa.text( + "SELECT data_type FROM information_schema.columns " + "WHERE table_name = :t AND column_name = :c" + ), {"t": table, "c": column}).first() + return row is not None and row[0] == "timestamp with time zone" + + +# --------------------------------------------------------------------------- +# Tables that inherit BaseModel and need an id UniqueConstraint. +# ZeroTier tables were handled in 024; all others are listed here. +# --------------------------------------------------------------------------- +_LEGACY_TABLES = [ + "application_provider_configs", + "audit_logs", + "authentication_methods", + "ca_permissions", + "cas", + "certificate_audit_logs", + "department_cert_policies", + "department_memberships", + "department_principals", + "departments", + "email_verification_tokens", + "external_provider_configs", + "mfa_policy_compliance", + "oauth_states", + "oidc_audit_logs", + "oidc_authorization_codes", + "oidc_clients", + "oidc_refresh_tokens", + "oidc_sessions", + # oidc_token_metadata intentionally excluded: its id column overrides + # BaseModel without unique=True (JTI is the PK but not separately unique) + "org_invite_tokens", + "organization_api_keys", + "organization_members", + "organization_provider_overrides", + "organization_security_policies", + "organizations", + "password_reset_tokens", + "principal_memberships", + "principals", + "sessions", + "ssh_certificates", + "ssh_keys", + "user_security_policies", + "users", +] + + +def upgrade(): + conn = op.get_bind() + + # ── 1. Add id UniqueConstraint to all legacy BaseModel tables ───────── + for tbl in _LEGACY_TABLES: + cname = f"{tbl}_id_key" + if not _constraint_exists(conn, cname): + op.create_unique_constraint(cname, tbl, ["id"]) + + # Drop the wrongly-added constraint on oidc_token_metadata if present + # (its id column overrides BaseModel without unique=True) + if _constraint_exists(conn, "oidc_token_metadata_id_key"): + op.drop_constraint("oidc_token_metadata_id_key", "oidc_token_metadata", type_="unique") + + # ── 2. organization_api_keys: timestamp columns TIMESTAMPTZ → TIMESTAMP + for col in ("created_at", "updated_at", "deleted_at", "last_used_at", "revoked_at"): + if _col_is_timestamptz(conn, "organization_api_keys", col): + conn.execute(sa.text( + f'ALTER TABLE organization_api_keys ALTER COLUMN "{col}" ' + f'TYPE TIMESTAMP WITHOUT TIME ZONE ' + f'USING CASE WHEN "{col}" IS NULL THEN NULL ' + f'ELSE "{col}" AT TIME ZONE \'UTC\' END' + )) + + # ── 3. organization_api_keys: replace legacy unique constraint + indexes + # Drop the anonymous unique constraint on key_hash (created by + # sa.UniqueConstraint('key_hash') in the original migration) + if _constraint_exists(conn, "organization_api_keys_key_hash_key"): + op.drop_constraint( + "organization_api_keys_key_hash_key", + "organization_api_keys", + type_="unique", + ) + # Add named unique index for key_hash expected by the model + if not _index_exists(conn, "organization_api_keys", "ix_organization_api_keys_key_hash"): + op.create_index( + "ix_organization_api_keys_key_hash", + "organization_api_keys", + ["key_hash"], + unique=True, + ) + + # Drop the legacy plain org-id index (superseded by the named one below) + if _index_exists(conn, "organization_api_keys", "idx_org_api_key_org_id"): + op.drop_index("idx_org_api_key_org_id", table_name="organization_api_keys") + + # Add named org-id index expected by the model + if not _index_exists(conn, "organization_api_keys", "ix_organization_api_keys_organization_id"): + op.create_index( + "ix_organization_api_keys_organization_id", + "organization_api_keys", + ["organization_id"], + ) + + # Add named is_revoked index expected by the model + if not _index_exists(conn, "organization_api_keys", "ix_organization_api_keys_is_revoked"): + op.create_index( + "ix_organization_api_keys_is_revoked", + "organization_api_keys", + ["is_revoked"], + ) + + # ── 4. Drop orphan idx_dept_can_sudo from departments ───────────────── + if _index_exists(conn, "departments", "idx_dept_can_sudo"): + op.drop_index("idx_dept_can_sudo", table_name="departments") + + # NOTE: ix_ssh_certificates_serial uniqueness is handled in + # 027_fix_cert_serial_uniqueness (composite unique per CA). + + +def downgrade(): + conn = op.get_bind() + + # Restore idx_dept_can_sudo + if not _index_exists(conn, "departments", "idx_dept_can_sudo"): + op.create_index("idx_dept_can_sudo", "departments", ["organization_id", "can_sudo"]) + + # Restore organization_api_keys indexes + if _index_exists(conn, "organization_api_keys", "ix_organization_api_keys_is_revoked"): + op.drop_index("ix_organization_api_keys_is_revoked", table_name="organization_api_keys") + if _index_exists(conn, "organization_api_keys", "ix_organization_api_keys_organization_id"): + op.drop_index("ix_organization_api_keys_organization_id", table_name="organization_api_keys") + if _index_exists(conn, "organization_api_keys", "ix_organization_api_keys_key_hash"): + op.drop_index("ix_organization_api_keys_key_hash", table_name="organization_api_keys") + if not _constraint_exists(conn, "organization_api_keys_key_hash_key"): + op.create_unique_constraint( + "organization_api_keys_key_hash_key", + "organization_api_keys", + ["key_hash"], + ) + if not _index_exists(conn, "organization_api_keys", "idx_org_api_key_org_id"): + op.create_index("idx_org_api_key_org_id", "organization_api_keys", ["organization_id"]) + + # Restore TIMESTAMPTZ on organization_api_keys + for col in ("created_at", "updated_at", "deleted_at", "last_used_at", "revoked_at"): + conn.execute(sa.text( + f'ALTER TABLE organization_api_keys ALTER COLUMN "{col}" ' + f'TYPE TIMESTAMP WITH TIME ZONE ' + f'USING CASE WHEN "{col}" IS NULL THEN NULL ' + f'ELSE "{col}" AT TIME ZONE \'UTC\' END' + )) + + # Drop id UniqueConstraints from legacy tables + for tbl in reversed(_LEGACY_TABLES): + cname = f"{tbl}_id_key" + if _constraint_exists(conn, cname): + op.drop_constraint(cname, tbl, type_="unique") diff --git a/migrations/versions/027_fix_cert_serial_uniqueness.py b/migrations/versions/027_fix_cert_serial_uniqueness.py new file mode 100644 index 0000000..3441b91 --- /dev/null +++ b/migrations/versions/027_fix_cert_serial_uniqueness.py @@ -0,0 +1,105 @@ +"""Fix ssh_certificates serial uniqueness: per-CA not global. + +Revision ID: 027_fix_cert_serial_uniqueness +Revises: 026_schema_cleanup +Create Date: 2026-03-23 + +The SSHCertificate model uses a per-CA monotonic serial counter, meaning +serial numbers are only unique within a single CA — not across the whole +table. The original migration created a global unique index on `serial` +alone, which is incorrect and was blocking enforcement (duplicate serial=1 +rows exist in production where two different CAs both issued their first +certificate). + +This migration: + 1. Drops the old non-unique index ix_ssh_certificates_serial (which was + never enforcing uniqueness — just an index). + 2. Drops the stale unique constraint ssh_certificates_serial_key if it + somehow exists. + 3. Creates a proper composite unique constraint uq_ssh_certificates_ca_serial + on (ca_id, serial), reflecting the real invariant: a serial is unique + within one CA. + +All operations are guarded (IF EXISTS / try/except) so this is safe to +re-run on any DB state. +""" + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.engine.reflection import Inspector + + +# --------------------------------------------------------------------------- +# revision identifiers +# --------------------------------------------------------------------------- +revision = "027_fix_cert_serial_uniqueness" +down_revision = "026_schema_cleanup" +branch_labels = None +depends_on = None + + +def _index_exists(conn, table: str, index: str) -> bool: + insp = Inspector.from_engine(conn) + return any(i["name"] == index for i in insp.get_indexes(table)) + + +def _constraint_exists(conn, table: str, constraint: str) -> bool: + insp = Inspector.from_engine(conn) + for uc in insp.get_unique_constraints(table): + if uc["name"] == constraint: + return True + return False + + +def upgrade(): + conn = op.get_bind() + + # 1. Drop the old global non-unique index on serial (if present) + if _index_exists(conn, "ssh_certificates", "ix_ssh_certificates_serial"): + op.drop_index("ix_ssh_certificates_serial", table_name="ssh_certificates") + + # 2. Drop any stale global unique constraint on serial alone (defensive) + if _constraint_exists(conn, "ssh_certificates", "ssh_certificates_serial_key"): + op.drop_constraint( + "ssh_certificates_serial_key", + "ssh_certificates", + type_="unique", + ) + + # 3. Add composite unique constraint: serial is unique per CA + if not _constraint_exists(conn, "ssh_certificates", "uq_ssh_certificates_ca_serial"): + op.create_unique_constraint( + "uq_ssh_certificates_ca_serial", + "ssh_certificates", + ["ca_id", "serial"], + ) + + # 4. Re-create a plain non-unique index on serial for fast lookups + if not _index_exists(conn, "ssh_certificates", "ix_ssh_certificates_serial"): + op.create_index( + "ix_ssh_certificates_serial", + "ssh_certificates", + ["serial"], + unique=False, + ) + + +def downgrade(): + conn = op.get_bind() + + # Remove the composite constraint + if _constraint_exists(conn, "ssh_certificates", "uq_ssh_certificates_ca_serial"): + op.drop_constraint( + "uq_ssh_certificates_ca_serial", + "ssh_certificates", + type_="unique", + ) + + # Restore the old non-unique index (best effort — data may have duplicates) + if not _index_exists(conn, "ssh_certificates", "ix_ssh_certificates_serial"): + op.create_index( + "ix_ssh_certificates_serial", + "ssh_certificates", + ["serial"], + unique=False, + ) From 2b6f7e15af40282899020ff956513c044274eaf6 Mon Sep 17 00:00:00 2001 From: James Bhattarai Date: Sun, 29 Mar 2026 23:14:20 +0545 Subject: [PATCH 8/9] Feat(Fix): Multi-Tenant Zerotier Org Setups Imports Network From Zerotier Async Emails Migration guardrails Admin to see all approvals states --- config/base.py | 12 - gatehouse_app/api/v1/auth/core.py | 2 +- gatehouse_app/api/v1/auth/password.py | 6 +- gatehouse_app/api/v1/organizations/invites.py | 14 +- gatehouse_app/api/v1/organizations/members.py | 2 +- gatehouse_app/api/v1/zerotier.py | 355 +++++++++++++++++- gatehouse_app/exceptions/base.py | 5 +- .../models/organization/organization.py | 4 + .../models/zerotier/activation_session.py | 2 +- gatehouse_app/models/zerotier/device.py | 2 +- .../zerotier/device_network_membership.py | 2 +- .../models/zerotier/kill_switch_event.py | 2 +- .../models/zerotier/portal_network.py | 4 +- .../models/zerotier/user_network_approval.py | 4 +- .../services/network_access_service.py | 93 ++++- .../services/notification_service.py | 175 ++++----- .../services/portal_network_service.py | 122 +++++- .../services/zerotier_api_service.py | 112 ++++-- .../zerotier_reconciliation_service.py | 225 +++++++++-- gatehouse_app/utils/constants.py | 1 - .../versions/028_org_zerotier_config.py | 69 ++++ 21 files changed, 974 insertions(+), 239 deletions(-) create mode 100644 migrations/versions/028_org_zerotier_config.py diff --git a/config/base.py b/config/base.py index e4d9329..2ac6adb 100644 --- a/config/base.py +++ b/config/base.py @@ -132,18 +132,6 @@ class BaseConfig: OIDC_UI_URL = os.getenv("OIDC_UI_URL", os.getenv("FRONTEND_URL", "http://localhost:8080")) # ZeroTier Configuration - ZEROTIER_API_TOKEN = os.getenv("ZEROTIER_API_TOKEN", "") - ZEROTIER_API_URL = os.getenv( - "ZEROTIER_API_URL", - "http://localhost:9993", - ) - ZEROTIER_API_MODE = os.getenv("ZEROTIER_API_MODE", "controller").lower() - ZEROTIER_DEFAULT_ACTIVATION_LIFETIME_MINUTES = int( - os.getenv("ZEROTIER_DEFAULT_ACTIVATION_LIFETIME_MINUTES", "480") - ) - ZEROTIER_RECONCILIATION_INTERVAL_SECONDS = int( - os.getenv("ZEROTIER_RECONCILIATION_INTERVAL_SECONDS", "120") - ) # Email / SMTP EMAIL_ENABLED = os.getenv("EMAIL_ENABLED", "False").lower() == "true" diff --git a/gatehouse_app/api/v1/auth/core.py b/gatehouse_app/api/v1/auth/core.py index f401834..308633b 100644 --- a/gatehouse_app/api/v1/auth/core.py +++ b/gatehouse_app/api/v1/auth/core.py @@ -39,7 +39,7 @@ def register(): f"{verify_link}\n\n" f"Gatehouse Security Team" ) - NotificationService._send_email(to_address=user.email, subject=subject, body=body) + NotificationService._send_email_async(to_address=user.email, subject=subject, body=body) except Exception as exc: logging.getLogger(__name__).warning(f"Failed to send verification email on register: {exc}") diff --git a/gatehouse_app/api/v1/auth/password.py b/gatehouse_app/api/v1/auth/password.py index b972cbd..fac32fa 100644 --- a/gatehouse_app/api/v1/auth/password.py +++ b/gatehouse_app/api/v1/auth/password.py @@ -27,7 +27,7 @@ def forgot_password(): reset_token = PasswordResetToken.generate(user_id=user.id) app_url = current_app.config.get("APP_URL", "http://localhost:8080") reset_link = f"{app_url}/reset-password?token={reset_token.token}" - NotificationService._send_email( + NotificationService._send_email_async( to_address=user.email, subject="Reset your Gatehouse password", body=( @@ -129,7 +129,7 @@ def resend_verification(): verify_token = EmailVerificationToken.generate(user_id=user.id) app_url = current_app.config.get("APP_URL", "http://localhost:8080") verify_link = f"{app_url}/verify-email?token={verify_token.token}" - NotificationService._send_email( + NotificationService._send_email_async( to_address=user.email, subject="Verify your Gatehouse email address", body=( @@ -200,7 +200,7 @@ def resend_activation(): app_url = current_app.config.get("APP_URL", current_app.config.get("FRONTEND_URL", "http://localhost:8080")) activate_link = f"{app_url}/activate?code={code}" - NotificationService._send_email( + NotificationService._send_email_async( to_address=user.email, subject="Activate your Gatehouse account", body=( diff --git a/gatehouse_app/api/v1/organizations/invites.py b/gatehouse_app/api/v1/organizations/invites.py index 2920bc4..4e3971e 100644 --- a/gatehouse_app/api/v1/organizations/invites.py +++ b/gatehouse_app/api/v1/organizations/invites.py @@ -37,7 +37,7 @@ def create_org_invite(org_id): app_url = current_app.config.get("APP_URL", "http://localhost:8080") invite_link = f"{app_url}/invite?token={invite.token}" - email_sent = NotificationService._send_email( + NotificationService._send_email_async( to_address=email, subject=f"You're invited to join {org.name} on Gatehouse", body=( @@ -47,16 +47,8 @@ def create_org_invite(org_id): f"Gatehouse Security Team" ), ) - - if not email_sent: - logging.getLogger(__name__).warning( - f"[INVITE LINK] Email not sent (EMAIL_ENABLED=False or SMTP down). " - f"Invite for {email} → {invite_link}" - ) - else: - logging.getLogger(__name__).info( - f"[INVITE] Email sent successfully to {email}" - ) + logging.getLogger(__name__).info(f"[INVITE] Email queued for {email}") + email_sent = True # async — assume queued successfully response_data = { "invite": { diff --git a/gatehouse_app/api/v1/organizations/members.py b/gatehouse_app/api/v1/organizations/members.py index df594f5..c605104 100644 --- a/gatehouse_app/api/v1/organizations/members.py +++ b/gatehouse_app/api/v1/organizations/members.py @@ -161,7 +161,7 @@ def send_mfa_reminder(org_id, user_id): if compliance and policy and compliance.deadline_at: NotificationService.send_mfa_deadline_reminder(user, compliance, policy) else: - NotificationService._send_email( + NotificationService._send_email_async( to_address=user.email, subject="Reminder: Set up multi-factor authentication", body=( diff --git a/gatehouse_app/api/v1/zerotier.py b/gatehouse_app/api/v1/zerotier.py index 1468ce1..1ba82c9 100644 --- a/gatehouse_app/api/v1/zerotier.py +++ b/gatehouse_app/api/v1/zerotier.py @@ -2,8 +2,10 @@ from flask import g, request from marshmallow import Schema, fields, validate, ValidationError +from sqlalchemy.exc import IntegrityError from gatehouse_app.api.v1 import api_v1_bp +from gatehouse_app.extensions import db from gatehouse_app.utils.response import api_response from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required from gatehouse_app.services import portal_network_service @@ -19,6 +21,8 @@ from gatehouse_app.models import ( ActivationSession, ) from gatehouse_app.models.organization import Organization +from gatehouse_app.models.organization.organization_member import OrganizationMember +from gatehouse_app.utils.constants import OrganizationRole from gatehouse_app.exceptions import ( ValidationError as AppValidationError, ZeroTierAPIError, @@ -39,6 +43,17 @@ def _org_check(org_id): return org, None +def _is_org_admin(org_id: str, user_id: str) -> bool: + """Return True if the user is an admin or owner of the org.""" + return OrganizationMember.query.filter( + OrganizationMember.organization_id == org_id, + OrganizationMember.user_id == user_id, + OrganizationMember.role.in_([OrganizationRole.ADMIN, OrganizationRole.OWNER]), + OrganizationMember.deleted_at.is_(None), + ).first() is not None + + + # ── Schemas ─────────────────────────────────────────────────────────────────── @@ -154,6 +169,63 @@ def create_network(org_id): return api_response(success=False, message=str(e.message), status=400, error_type=e.error_type) except ZeroTierAPIError as e: return api_response(success=False, message=str(e), status=502, error_type=e.error_type) + except IntegrityError: + db.session.rollback() + return api_response( + success=False, + message="A portal network with this ZeroTier ID already exists in this organization.", + status=409, + error_type="DUPLICATE_NETWORK", + ) + + +@api_v1_bp.route("/organizations//zerotier/available-networks", methods=["GET"]) +@login_required +@require_admin +@full_access_required +def list_zerotier_available_networks(org_id): + """List all ZeroTier networks from the org's ZT controller/account. + + Cross-references against managed portal networks so the UI can show + which ones are already imported and which can be imported. + """ + org, err = _org_check(org_id) + if err: + return err + + # Fetch all active portal networks for this org, keyed by ZT network ID + managed = { + pn.zerotier_network_id: pn + for pn in PortalNetwork.query.filter( + PortalNetwork.organization_id == org_id, + PortalNetwork.deleted_at.is_(None), + ).all() + } + + try: + zt_networks = zt.list_networks(organization_id=org_id) + except ZeroTierAPIError as e: + # Return an empty list with a flag so the UI can show a helpful message + # rather than an error page (e.g. "ZeroTier not configured yet"). + return api_response( + data={"networks": [], "count": 0, "zt_error": str(e)}, + message="ZeroTier unavailable — no networks returned", + ) + + result = [] + for zt_net in zt_networks: + portal = managed.get(zt_net.id) + result.append({ + **zt_net.to_dict(), + "already_managed": portal is not None, + "portal_network_id": portal.id if portal else None, + "portal_network_name": portal.name if portal else None, + }) + + return api_response( + data={"networks": result, "count": len(result)}, + message="Available ZeroTier networks retrieved", + ) @api_v1_bp.route("/organizations//networks/", methods=["GET"]) @@ -346,6 +418,9 @@ def update_device(org_id, device_id): except ValidationError as e: return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) + if "nickname" in data: + data["device_nickname"] = data.pop("nickname") + try: device = device_service.update_device(device_id, g.current_user.id, **data) return api_response(data={"device": device.to_dict()}, message="Device updated successfully") @@ -520,6 +595,25 @@ def assign_access(org_id): return api_response(success=False, message=str(e.message), status=400, error_type=e.error_type) +@api_v1_bp.route("/organizations//admin/approvals", methods=["GET"]) +@login_required +@require_admin +@full_access_required +def admin_list_all_approvals(org_id): + """List ALL approval records across all users in the org (admin only).""" + org, err = _org_check(org_id) + if err: + return err + + network_id = request.args.get("network_id") + state = request.args.get("state") + approvals = network_access_service.list_all_org_approvals(org_id, network_id=network_id, state=state) + return api_response( + data={"approvals": [a.to_dict() for a in approvals], "count": len(approvals)}, + message="Approvals retrieved successfully", + ) + + # ── Memberships ─────────────────────────────────────────────────────────────── @@ -548,7 +642,7 @@ def list_memberships(org_id): @login_required @full_access_required def activate_membership(org_id, membership_id): - """Activate an approved device membership.""" + """Activate an approved device membership. Admins can activate any membership; regular members can only activate their own.""" org, err = _org_check(org_id) if err: return err @@ -559,11 +653,14 @@ def activate_membership(org_id, membership_id): except ValidationError as e: return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) + is_admin = _is_org_admin(org_id, g.current_user.id) + try: session = network_access_service.activate_device_membership( membership_id=membership_id, user_id=g.current_user.id, lifetime_minutes=data.get("lifetime_minutes"), + admin_override=is_admin, ) membership = DeviceNetworkMembership.query.get(membership_id) return api_response(data={"session": session.to_dict(), "membership": membership.to_dict()}, message="Membership activated successfully") @@ -577,11 +674,21 @@ def activate_membership(org_id, membership_id): @login_required @full_access_required def deactivate_membership(org_id, membership_id): - """Deactivate an active device membership.""" + """Deactivate an active device membership. Admins can deactivate any; regular members can only deactivate their own.""" org, err = _org_check(org_id) if err: return err + # Verify ownership for non-admins + if not _is_org_admin(org_id, g.current_user.id): + membership_check = DeviceNetworkMembership.query.filter( + DeviceNetworkMembership.id == membership_id, + DeviceNetworkMembership.user_id == g.current_user.id, + DeviceNetworkMembership.deleted_at.is_(None), + ).first() + if not membership_check: + return api_response(success=False, message="Membership not found", status=404, error_type="NOT_FOUND") + try: membership = network_access_service.deactivate_membership( membership_id=membership_id, @@ -597,7 +704,7 @@ def deactivate_membership(org_id, membership_id): @login_required @full_access_required def activate_all_memberships(org_id): - """Bulk-activate all approved inactive memberships.""" + """Bulk-activate all of the caller's approved inactive memberships in this org.""" org, err = _org_check(org_id) if err: return err @@ -744,6 +851,7 @@ def trigger_kill_switch(org_id): event = network_access_service.kill_switch( target_user_id=data["target_user_id"], triggered_by_user_id=g.current_user.id, + organization_id=org_id, scope=data.get("scope", "organization"), reason=data.get("reason"), network_ids=data.get("network_ids"), @@ -794,12 +902,20 @@ def admin_delete_membership(org_id, membership_id): @api_v1_bp.route("/admin/zerotier/status", methods=["GET"]) @login_required -@require_admin @full_access_required def zerotier_status(): - """Check ZeroTier controller connectivity and status (admin only).""" + """Check ZeroTier controller connectivity and status. + + Requires ?org_id= — credentials are looked up from that org. + Caller must be an admin/owner of that specific org. + """ + org_id = request.args.get("org_id") + if not org_id: + return api_response(success=False, message="org_id query parameter is required", status=400, error_type="VALIDATION_ERROR") + if not _is_org_admin(org_id, g.current_user.id): + return api_response(success=False, message="Admin or owner role required for this organization", status=403, error_type="AUTHORIZATION_ERROR") try: - status = zt.get_status() + status = zt.get_status(organization_id=org_id) return api_response(data={"status": status}, message="ZeroTier controller is reachable") except ZeroTierAPIError as e: return api_response(success=False, message=str(e), status=502, error_type=e.error_type) @@ -807,12 +923,20 @@ def zerotier_status(): @api_v1_bp.route("/admin/zerotier/networks", methods=["GET"]) @login_required -@require_admin @full_access_required def zerotier_list_networks(): - """List networks from the ZeroTier controller (admin only).""" + """List networks from the ZeroTier controller. + + Requires ?org_id= — credentials are looked up from that org. + Caller must be an admin/owner of that specific org. + """ + org_id = request.args.get("org_id") + if not org_id: + return api_response(success=False, message="org_id query parameter is required", status=400, error_type="VALIDATION_ERROR") + if not _is_org_admin(org_id, g.current_user.id): + return api_response(success=False, message="Admin or owner role required for this organization", status=403, error_type="AUTHORIZATION_ERROR") try: - networks = zt.list_networks() + networks = zt.list_networks(organization_id=org_id) return api_response( data={"networks": [n.to_dict() if hasattr(n, 'to_dict') else {"id": getattr(n, "id", str(n))} for n in networks], "count": len(networks)}, message="Networks retrieved successfully", @@ -823,12 +947,20 @@ def zerotier_list_networks(): @api_v1_bp.route("/admin/zerotier/networks/", methods=["GET"]) @login_required -@require_admin @full_access_required def zerotier_get_network(network_id): - """Get a ZeroTier network from the controller (admin only).""" + """Get a ZeroTier network from the controller. + + Requires ?org_id= — credentials are looked up from that org. + Caller must be an admin/owner of that specific org. + """ + org_id = request.args.get("org_id") + if not org_id: + return api_response(success=False, message="org_id query parameter is required", status=400, error_type="VALIDATION_ERROR") + if not _is_org_admin(org_id, g.current_user.id): + return api_response(success=False, message="Admin or owner role required for this organization", status=403, error_type="AUTHORIZATION_ERROR") try: - network = zt.get_network(network_id) + network = zt.get_network(network_id, organization_id=org_id) return api_response(data={"network": network.to_dict()}, message="Network retrieved successfully") except ZeroTierAPIError as e: return api_response(success=False, message=str(e), status=502, error_type=e.error_type) @@ -836,12 +968,20 @@ def zerotier_get_network(network_id): @api_v1_bp.route("/admin/zerotier/networks//members", methods=["GET"]) @login_required -@require_admin @full_access_required def zerotier_list_members(network_id): - """List members on a ZeroTier network from the controller (admin only).""" + """List members on a ZeroTier network from the controller. + + Requires ?org_id= — credentials are looked up from that org. + Caller must be an admin/owner of that specific org. + """ + org_id = request.args.get("org_id") + if not org_id: + return api_response(success=False, message="org_id query parameter is required", status=400, error_type="VALIDATION_ERROR") + if not _is_org_admin(org_id, g.current_user.id): + return api_response(success=False, message="Admin or owner role required for this organization", status=403, error_type="AUTHORIZATION_ERROR") try: - members = zt.list_members(network_id) + members = zt.list_members(network_id, organization_id=org_id) return api_response( data={"members": [m.to_dict() for m in members], "count": len(members)}, message="Members retrieved successfully", @@ -852,9 +992,190 @@ def zerotier_list_members(network_id): @api_v1_bp.route("/admin/zerotier/reconcile", methods=["POST"]) @login_required -@require_admin @full_access_required def trigger_reconciliation(): - """Trigger full reconciliation across all networks (admin only).""" + """Trigger full reconciliation across all networks (requires org admin in at least one org).""" + from gatehouse_app.models.organization.organization_member import OrganizationMember + is_any_admin = OrganizationMember.query.filter( + OrganizationMember.user_id == g.current_user.id, + OrganizationMember.role.in_([OrganizationRole.ADMIN, OrganizationRole.OWNER]), + OrganizationMember.deleted_at.is_(None), + ).first() is not None + if not is_any_admin: + return api_response(success=False, message="Admin or owner role required", status=403, error_type="AUTHORIZATION_ERROR") result = zerotier_reconciliation_service.reconcile_all() return api_response(data=result, message="Reconciliation complete") + + +# ── Per-org ZeroTier configuration ─────────────────────────────────────────── + + +class ZeroTierConfigSchema(Schema): + zt_api_token = fields.Str(required=True, validate=validate.Length(min=1, max=512)) + zt_api_url = fields.Str(required=True, validate=validate.Length(min=1, max=512)) + zt_api_mode = fields.Str( + required=True, + validate=validate.OneOf(["central", "controller"]), + ) + + +@api_v1_bp.route("/organizations//zerotier-config", methods=["GET"]) +@login_required +@require_admin +@full_access_required +def get_zerotier_config(org_id): + """Return the current ZeroTier configuration for an organization (admin only). + + The token is masked — only its presence is indicated, not the value. + """ + org, err = _org_check(org_id) + if err: + return err + + return api_response( + data={ + "zerotier_config": { + "zt_api_token_set": bool(org.zt_api_token), + "zt_api_url": org.zt_api_url, + "zt_api_mode": org.zt_api_mode, + } + }, + message="ZeroTier configuration retrieved successfully", + ) + + +@api_v1_bp.route("/organizations//zerotier-config", methods=["PUT"]) +@login_required +@require_admin +@full_access_required +def set_zerotier_config(org_id): + """Set (or replace) the ZeroTier credentials for an organization (admin only). + + All three fields are required — there are no server-level defaults. + + Body: + zt_api_token (required) – API token for ZeroTier Central or authtoken.secret + zt_api_url (required) – full base URL, e.g. http://host:9993 or + https://api.zerotier.com/api/v1 + zt_api_mode (required) – "central" | "controller" + """ + org, err = _org_check(org_id) + if err: + return err + + try: + schema = ZeroTierConfigSchema() + data = schema.load(request.json or {}) + except ValidationError as e: + return api_response( + success=False, message="Validation failed", + status=400, error_type="VALIDATION_ERROR", error_details=e.messages, + ) + + # Test connectivity BEFORE saving — reject bad credentials early + connectivity_ok = False + connectivity_error = None + + # Temporarily set the credentials so _get_client() can build a client + old_token, old_url, old_mode = org.zt_api_token, org.zt_api_url, org.zt_api_mode + org.zt_api_token = data["zt_api_token"] + org.zt_api_url = data["zt_api_url"] + org.zt_api_mode = data["zt_api_mode"] + db.session.flush() # make visible to _get_client query without committing + + try: + zt.get_status(organization_id=org_id) + connectivity_ok = True + except ZeroTierAPIError as exc: + connectivity_error = str(exc) + except Exception as exc: + connectivity_error = str(exc) + + if not connectivity_ok: + # Roll back — don't persist bad credentials + org.zt_api_token = old_token + org.zt_api_url = old_url + org.zt_api_mode = old_mode + db.session.commit() + return api_response( + success=False, + message="Controller Connectivity test failed", + status=400, + error_type="ZEROTIER_CONNECTIVITY_FAILED", + error_details={ + "connectivity_test": { + "ok": False, + "error": connectivity_error, + }, + }, + ) + + # Connectivity verified — commit the new credentials + org.save() + + from gatehouse_app.services.audit_service import AuditService + AuditService.log_action( + action="org.zerotier_config.updated", + user_id=g.current_user.id, + organization_id=org_id, + resource_type="organization", + resource_id=org_id, + metadata={ + "zt_api_url": org.zt_api_url, + "zt_api_mode": org.zt_api_mode, + "connectivity_ok": connectivity_ok, + }, + description="Organization ZeroTier config updated", + success=True, + ) + + return api_response( + data={ + "zerotier_config": { + "zt_api_token_set": True, + "zt_api_url": org.zt_api_url, + "zt_api_mode": org.zt_api_mode, + }, + "connectivity_test": { + "ok": True, + "error": None, + }, + }, + message="ZeroTier configuration saved successfully", + ) + + +@api_v1_bp.route("/organizations//zerotier-config", methods=["DELETE"]) +@login_required +@require_admin +@full_access_required +def delete_zerotier_config(org_id): + """Remove the org-level ZeroTier credentials (admin only). + + After removal, all ZeroTier operations for this organization will fail + until new credentials + are configured via the ZeroTier Config page. + """ + org, err = _org_check(org_id) + if err: + return err + + org.zt_api_token = None + org.zt_api_url = None + org.zt_api_mode = None + org.save() + + from gatehouse_app.services.audit_service import AuditService + AuditService.log_action( + action="org.zerotier_config.deleted", + user_id=g.current_user.id, + organization_id=org_id, + resource_type="organization", + resource_id=org_id, + metadata={}, + description="Organization ZeroTier config removed — ZeroTier operations disabled until reconfigured", + success=True, + ) + + return api_response(message="ZeroTier configuration removed. Configure new credentials to re-enable ZeroTier features.") + diff --git a/gatehouse_app/exceptions/base.py b/gatehouse_app/exceptions/base.py index f7cfb0e..575ef47 100644 --- a/gatehouse_app/exceptions/base.py +++ b/gatehouse_app/exceptions/base.py @@ -8,19 +8,22 @@ class BaseAPIException(Exception): error_type = "INTERNAL_ERROR" message = "An unexpected error occurred" - def __init__(self, message=None, error_details=None): + def __init__(self, message=None, error_details=None, status_code=None): """ Initialize exception. Args: message: Custom error message error_details: Additional error details dictionary + status_code: Override the class-level HTTP status code """ super().__init__(self.message) if message: self.message = message super().__init__(message) # update args so str(e) works self.error_details = error_details or {} + if status_code is not None: + self.status_code = status_code def to_dict(self): """Convert exception to dictionary for API response.""" diff --git a/gatehouse_app/models/organization/organization.py b/gatehouse_app/models/organization/organization.py index 349edb0..64d3c7c 100644 --- a/gatehouse_app/models/organization/organization.py +++ b/gatehouse_app/models/organization/organization.py @@ -17,6 +17,10 @@ class Organization(BaseModel): # Settings (stored as JSON) settings = db.Column(db.JSON, nullable=True, default=dict) + zt_api_token = db.Column(db.String(512), nullable=True) + zt_api_url = db.Column(db.String(512), nullable=True) + zt_api_mode = db.Column(db.String(32), nullable=True) # "central" | "controller" + # Relationships members = db.relationship( "OrganizationMember", back_populates="organization", cascade="all, delete-orphan" diff --git a/gatehouse_app/models/zerotier/activation_session.py b/gatehouse_app/models/zerotier/activation_session.py index f26a9ef..00e4086 100644 --- a/gatehouse_app/models/zerotier/activation_session.py +++ b/gatehouse_app/models/zerotier/activation_session.py @@ -54,7 +54,7 @@ class ActivationSession(BaseModel): ) ended_at = db.Column(db.DateTime, nullable=True) end_reason = db.Column( - db.Enum(ActivationEndReason, name="activation_end_reason"), + db.Enum(ActivationEndReason, name="activation_end_reason", values_callable=lambda x: [e.value for e in x]), nullable=True, ) created_by = db.Column( diff --git a/gatehouse_app/models/zerotier/device.py b/gatehouse_app/models/zerotier/device.py index 695b5f3..5955890 100644 --- a/gatehouse_app/models/zerotier/device.py +++ b/gatehouse_app/models/zerotier/device.py @@ -47,7 +47,7 @@ class Device(BaseModel): asset_tag = db.Column(db.String(255), nullable=True) serial_number = db.Column(db.String(255), nullable=True) status = db.Column( - db.Enum(DeviceStatus, name="device_status"), + db.Enum(DeviceStatus, name="device_status", values_callable=lambda x: [e.value for e in x]), default=DeviceStatus.ACTIVE, nullable=False, ) diff --git a/gatehouse_app/models/zerotier/device_network_membership.py b/gatehouse_app/models/zerotier/device_network_membership.py index 8e1118c..cc6b85d 100644 --- a/gatehouse_app/models/zerotier/device_network_membership.py +++ b/gatehouse_app/models/zerotier/device_network_membership.py @@ -58,7 +58,7 @@ class DeviceNetworkMembership(BaseModel): index=True, ) state = db.Column( - db.Enum(MembershipState, name="membership_state"), + db.Enum(MembershipState, name="membership_state", values_callable=lambda x: [e.value for e in x]), default=MembershipState.PENDING_DEVICE_REGISTRATION, nullable=False, index=True, diff --git a/gatehouse_app/models/zerotier/kill_switch_event.py b/gatehouse_app/models/zerotier/kill_switch_event.py index 6571582..f7b67de 100644 --- a/gatehouse_app/models/zerotier/kill_switch_event.py +++ b/gatehouse_app/models/zerotier/kill_switch_event.py @@ -35,7 +35,7 @@ class KillSwitchEvent(BaseModel): index=True, ) scope = db.Column( - db.Enum(KillSwitchScope, name="kill_switch_scope"), + db.Enum(KillSwitchScope, name="kill_switch_scope", values_callable=lambda x: [e.value for e in x]), default=KillSwitchScope.ORGANIZATION, nullable=False, ) diff --git a/gatehouse_app/models/zerotier/portal_network.py b/gatehouse_app/models/zerotier/portal_network.py index b321832..bea0971 100644 --- a/gatehouse_app/models/zerotier/portal_network.py +++ b/gatehouse_app/models/zerotier/portal_network.py @@ -45,12 +45,12 @@ class PortalNetwork(BaseModel): index=True, ) environment = db.Column( - db.Enum(NetworkEnvironment, name="network_environment"), + db.Enum(NetworkEnvironment, name="network_environment", values_callable=lambda x: [e.value for e in x]), default=NetworkEnvironment.DEVELOPMENT, nullable=False, ) request_mode = db.Column( - db.Enum(NetworkRequestMode, name="network_request_mode"), + db.Enum(NetworkRequestMode, name="network_request_mode", values_callable=lambda x: [e.value for e in x]), default=NetworkRequestMode.APPROVAL_REQUIRED, nullable=False, ) diff --git a/gatehouse_app/models/zerotier/user_network_approval.py b/gatehouse_app/models/zerotier/user_network_approval.py index 44e57ae..dfe559e 100644 --- a/gatehouse_app/models/zerotier/user_network_approval.py +++ b/gatehouse_app/models/zerotier/user_network_approval.py @@ -48,12 +48,12 @@ class UserNetworkApproval(BaseModel): nullable=True, ) grant_type = db.Column( - db.Enum(ApprovalGrantType, name="approval_grant_type"), + db.Enum(ApprovalGrantType, name="approval_grant_type", values_callable=lambda x: [e.value for e in x]), default=ApprovalGrantType.REQUESTED, nullable=False, ) state = db.Column( - db.Enum(ApprovalState, name="approval_state"), + db.Enum(ApprovalState, name="approval_state", values_callable=lambda x: [e.value for e in x]), default=ApprovalState.PENDING, nullable=False, index=True, diff --git a/gatehouse_app/services/network_access_service.py b/gatehouse_app/services/network_access_service.py index 6d43ebd..a801a86 100644 --- a/gatehouse_app/services/network_access_service.py +++ b/gatehouse_app/services/network_access_service.py @@ -33,6 +33,7 @@ from gatehouse_app.exceptions import ( DeviceNotFoundError, ApprovalAlreadyExistsError, ValidationError, + ZeroTierAPIError, ) logger = logging.getLogger(__name__) @@ -74,9 +75,30 @@ def request_access( raise ApprovalAlreadyExistsError( "An access request or approval already exists for this user and network." ) - existing.state = ApprovalState.PENDING + is_open = network.request_mode.value == "open" + existing.state = ApprovalState.APPROVED if is_open else ApprovalState.PENDING existing.justification = justification existing.save() + + existing_membership = DeviceNetworkMembership.query.filter( + DeviceNetworkMembership.user_network_approval_id == existing.id, + DeviceNetworkMembership.device_id == device_id, + DeviceNetworkMembership.deleted_at.is_(None), + ).first() + if not existing_membership: + membership_state = MembershipState.APPROVED_INACTIVE if is_open else MembershipState.PENDING_DEVICE_REGISTRATION + membership = DeviceNetworkMembership( + organization_id=organization_id, + user_id=user_id, + device_id=device_id, + portal_network_id=portal_network_id, + user_network_approval_id=existing.id, + state=membership_state, + approved_for_activation=is_open, + ) + membership.save() + _ensure_zerotier_member(device.node_id, portal_network_id, authorized=False) + return existing is_open = network.request_mode.value == "open" @@ -329,6 +351,23 @@ def list_user_approvals(user_id: str, organization_id: str) -> list[UserNetworkA ).all() +def list_all_org_approvals( + organization_id: str, + network_id: str | None = None, + state: str | None = None, +) -> list[UserNetworkApproval]: + """List all approval records across all users in an org (admin use).""" + q = UserNetworkApproval.query.filter( + UserNetworkApproval.organization_id == organization_id, + UserNetworkApproval.deleted_at.is_(None), + ) + if network_id: + q = q.filter(UserNetworkApproval.portal_network_id == network_id) + if state: + q = q.filter(UserNetworkApproval.state == state) + return q.order_by(UserNetworkApproval.created_at.desc()).all() + + # ── Membership materialisation ─────────────────────────────────────────────── @@ -428,11 +467,12 @@ def activate_device_membership( membership_id: str, user_id: str, lifetime_minutes: int | None = None, + admin_override: bool = False, ) -> ActivationSession: """Activate an approved device on a network. Creates an activation session and authorizes in ZT.""" membership = _get_membership(membership_id) - if membership.user_id != user_id: + if not admin_override and membership.user_id != user_id: raise MembershipNotFoundError("Membership not found.") # Check approval is still active @@ -536,7 +576,8 @@ def deactivate_membership( # Deauthorize in ZeroTier device = Device.query.get(membership.device_id) network = PortalNetwork.query.get(membership.portal_network_id) - _deauthorize_in_zerotier(device.node_id, network.zerotier_network_id) + _deauthorize_in_zerotier(device.node_id, network.zerotier_network_id, + organization_id=membership.organization_id) membership.state = MembershipState.APPROVED_INACTIVE membership.currently_authorized = False @@ -567,6 +608,7 @@ def kill_switch( target_user_id: str, triggered_by_user_id: str, scope: str, + organization_id: str | None = None, reason: str | None = None, network_ids: list[str] | None = None, ) -> KillSwitchEvent: @@ -579,14 +621,18 @@ def kill_switch( DeviceNetworkMembership.deleted_at.is_(None), ) - org_id = None + org_id = organization_id # Use caller-supplied org_id as the primary source if scope_enum == KillSwitchScope.ORGANIZATION: - # Use the first membership's org - first = q.first() - org_id = first.organization_id if first else None + if not org_id: + # Fall back to deriving from first active membership + first = q.first() + org_id = first.organization_id if first else None + else: + # Scope query to the specified org + q = q.filter(DeviceNetworkMembership.organization_id == org_id) elif scope_enum == KillSwitchScope.SELECTED_NETWORKS and network_ids: q = q.filter(DeviceNetworkMembership.portal_network_id.in_(network_ids)) - if network_ids: + if not org_id: first_network = PortalNetwork.query.filter( PortalNetwork.id.in_(network_ids), PortalNetwork.deleted_at.is_(None), @@ -594,7 +640,7 @@ def kill_switch( org_id = first_network.organization_id if first_network else None if not org_id: - org_id = network_ids[0] if network_ids else None + raise ValidationError("Cannot determine organization for kill switch event.") # Create kill switch event event = KillSwitchEvent( @@ -608,14 +654,16 @@ def kill_switch( event.save() # Suspend all approvals - ApprovalState._value2member_map_ # just reference approvals = UserNetworkApproval.query.filter( UserNetworkApproval.user_id == target_user_id, UserNetworkApproval.state == ApprovalState.APPROVED, UserNetworkApproval.deleted_at.is_(None), ).all() for approval in approvals: - if scope_enum == KillSwitchScope.SELECTED_NETWORKS and network_ids: + if scope_enum == KillSwitchScope.ORGANIZATION and org_id: + if approval.organization_id != org_id: + continue + elif scope_enum == KillSwitchScope.SELECTED_NETWORKS and network_ids: if approval.portal_network_id not in network_ids: continue approval.state = ApprovalState.SUSPENDED @@ -691,7 +739,8 @@ def _ensure_zerotier_member( return try: - zt.add_member(network.zerotier_network_id, node_id, authorized=authorized) + zt.add_member(network.zerotier_network_id, node_id, authorized=authorized, + organization_id=network.organization_id) except Exception as exc: logger.warning( f"[_ensure_zerotier_member] Could not add member {node_id} " @@ -705,7 +754,8 @@ def _authorize_in_zerotier( membership: DeviceNetworkMembership, ) -> None: try: - zt.authorize_member(zerotier_network_id, node_id) + zt.authorize_member(zerotier_network_id, node_id, + organization_id=membership.organization_id) # Update zerotier_membership cache zt_membership = ZeroTierMembership.query.filter( @@ -740,6 +790,11 @@ def _authorize_in_zerotier( success=True, ) + except ZeroTierAPIError as exc: + logger.warning( + f"[_authorize_in_zerotier] ZeroTier unavailable — skipping authorization " + f"for {node_id} on {zerotier_network_id}: {exc}" + ) except Exception as exc: logger.error( f"[_authorize_in_zerotier] Failed to authorize {node_id} " @@ -748,9 +803,11 @@ def _authorize_in_zerotier( raise -def _deauthorize_in_zerotier(node_id: str, zerotier_network_id: str) -> None: +def _deauthorize_in_zerotier(node_id: str, zerotier_network_id: str, + organization_id: str | None = None) -> None: try: - zt.deauthorize_member(zerotier_network_id, node_id) + zt.deauthorize_member(zerotier_network_id, node_id, + organization_id=organization_id) zt_membership = ZeroTierMembership.query.filter( ZeroTierMembership.zerotier_network_id == zerotier_network_id, @@ -940,7 +997,8 @@ def revoke_membership_soft( if device and network: try: - zt.deauthorize_member(network.zerotier_network_id, device.node_id) + zt.deauthorize_member(network.zerotier_network_id, device.node_id, + organization_id=membership.organization_id) except Exception as exc: logger.warning(f"[revoke_membership_soft] ZT deauthorize failed for {device.node_id}: {exc}") @@ -984,7 +1042,8 @@ def hard_delete_membership(membership_id: str) -> None: if device and network: try: - zt.delete_network_member(network.zerotier_network_id, device.node_id) + zt.delete_network_member(network.zerotier_network_id, device.node_id, + organization_id=membership.organization_id) logger.info(f"[hard_delete_membership] Deleted {device.node_id} from ZT network {network.zerotier_network_id}") except Exception as exc: logger.warning(f"[hard_delete_membership] ZT delete failed for {device.node_id}: {exc}") diff --git a/gatehouse_app/services/notification_service.py b/gatehouse_app/services/notification_service.py index fa942c0..068fe8f 100644 --- a/gatehouse_app/services/notification_service.py +++ b/gatehouse_app/services/notification_service.py @@ -17,6 +17,7 @@ from datetime import datetime, timezone from typing import Optional, Dict, Any import logging import json +import threading from gatehouse_app.extensions import db from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance @@ -78,29 +79,22 @@ class NotificationService: ) # Send the notification - success = NotificationService._send_email( + NotificationService._send_email_async( to_address=user.email, subject=subject, body=body, ) - - if success: - logger.info( - f"Sent MFA deadline reminder to {user.email} " - f"({days_until_deadline} days remaining)" - ) - AuditService.log_action( - action=AuditAction.MFA_POLICY_USER_COMPLIANT, - user_id=user.id, - organization_id=compliance.organization_id, - description=f"MFA deadline reminder sent. Days remaining: {days_until_deadline}", - ) - else: - logger.warning( - f"Failed to send MFA deadline reminder to {user.email}" - ) - - return success + logger.info( + f"Sent MFA deadline reminder to {user.email} " + f"({days_until_deadline} days remaining)" + ) + AuditService.log_action( + action=AuditAction.MFA_POLICY_USER_COMPLIANT, + user_id=user.id, + organization_id=compliance.organization_id, + description=f"MFA deadline reminder sent. Days remaining: {days_until_deadline}", + ) + return True except Exception as e: logger.exception(f"Error sending MFA deadline reminder to {user.email}: {e}") @@ -136,27 +130,19 @@ class NotificationService: ) # Send the notification - success = NotificationService._send_email( + NotificationService._send_email_async( to_address=user.email, subject=subject, body=body, ) - - if success: - logger.info(f"Sent MFA suspension notification to {user.email}") - # Audit log - AuditService.log_action( - action=AuditAction.MFA_POLICY_USER_SUSPENDED, - user_id=user.id, - organization_id=compliance.organization_id, - description="MFA compliance suspension notification sent", - ) - else: - logger.warning( - f"Failed to send MFA suspension notification to {user.email}" - ) - - return success + logger.info(f"Sent MFA suspension notification to {user.email}") + AuditService.log_action( + action=AuditAction.MFA_POLICY_USER_SUSPENDED, + user_id=user.id, + organization_id=compliance.organization_id, + description="MFA compliance suspension notification sent", + ) + return True except Exception as e: logger.exception( @@ -285,89 +271,84 @@ Gatehouse Security Team return body @staticmethod - def _send_email( + def _send_email_async( to_address: str, subject: str, body: str, html_body: Optional[str] = None, - ) -> bool: - """Send an email via SMTP. + ) -> None: + """Send an email on a daemon thread so the calling request returns immediately. - Returns True if the email was sent successfully, False otherwise. - If EMAIL_ENABLED is False, logs the email body instead (simulation mode). + If EMAIL_ENABLED is False, logs instead of sending. All SMTP exceptions are caught and logged — this method never raises. + The Flask app context is pushed inside the thread so current_app works correctly. """ import smtplib from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from flask import current_app - email_enabled = current_app.config.get(NotificationService.EMAIL_ENABLED_KEY, False) + app = current_app._get_current_object() # capture real app before leaving request context - if not email_enabled: - logger.info( - f"[EMAIL DISABLED] Would have sent to: {to_address} | Subject: {subject}\n" - f"Body: {body[:500]}" - ) - return False + def _send(): + with app.app_context(): + email_enabled = app.config.get(NotificationService.EMAIL_ENABLED_KEY, False) + if not email_enabled: + logger.info( + f"[EMAIL DISABLED] Would have sent to: {to_address} | Subject: {subject}\n" + f"Body: {body[:500]}" + ) + return - smtp_host = current_app.config.get(NotificationService.SMTP_HOST_KEY, "") - smtp_port_raw = current_app.config.get(NotificationService.SMTP_PORT_KEY, 587) - smtp_username = current_app.config.get(NotificationService.SMTP_USERNAME_KEY) - smtp_password = current_app.config.get(NotificationService.SMTP_PASSWORD_KEY) - from_address = current_app.config.get( - NotificationService.FROM_ADDRESS_KEY, "" - ) + smtp_host = app.config.get(NotificationService.SMTP_HOST_KEY, "") + smtp_port_raw = app.config.get(NotificationService.SMTP_PORT_KEY, 587) + smtp_username = app.config.get(NotificationService.SMTP_USERNAME_KEY) + smtp_password = app.config.get(NotificationService.SMTP_PASSWORD_KEY) + from_address = app.config.get(NotificationService.FROM_ADDRESS_KEY, "") - # Guard: refuse to attempt a connection when critical config is missing. - # This surfaces a clear log message instead of a confusing socket error. - missing = [k for k, v in [ - ("SMTP_HOST", smtp_host), - ("FROM_ADDRESS", from_address), - ] if not v] - if missing: - logger.error( - f"[EMAIL] Cannot send — missing config: {', '.join(missing)}. " - f"Would have sent to: {to_address} | Subject: {subject}" - ) - return False + missing = [k for k, v in [("SMTP_HOST", smtp_host), ("FROM_ADDRESS", from_address)] if not v] + if missing: + logger.error( + f"[EMAIL] Cannot send — missing config: {', '.join(missing)}. " + f"Would have sent to: {to_address} | Subject: {subject}" + ) + return - try: - smtp_port = int(smtp_port_raw) - except (TypeError, ValueError): - logger.error(f"[EMAIL] Invalid SMTP_PORT value: {smtp_port_raw!r}") - return False + try: + smtp_port = int(smtp_port_raw) + except (TypeError, ValueError): + logger.error(f"[EMAIL] Invalid SMTP_PORT value: {smtp_port_raw!r}") + return - smtp_use_tls = current_app.config.get( - NotificationService.SMTP_USE_TLS_KEY, - smtp_port not in (25, 1025), - ) + smtp_use_tls = app.config.get( + NotificationService.SMTP_USE_TLS_KEY, + smtp_port not in (25, 1025), + ) - try: - msg = MIMEMultipart("alternative") - msg["Subject"] = subject - msg["From"] = from_address - msg["To"] = to_address - msg.attach(MIMEText(body, "plain")) - if html_body: - msg.attach(MIMEText(html_body, "html")) + try: + msg = MIMEMultipart("alternative") + msg["Subject"] = subject + msg["From"] = from_address + msg["To"] = to_address + msg.attach(MIMEText(body, "plain")) + if html_body: + msg.attach(MIMEText(html_body, "html")) - with smtplib.SMTP(smtp_host, smtp_port) as server: - server.ehlo() - if smtp_use_tls: - server.starttls() - server.ehlo() - if smtp_username and smtp_password: - server.login(smtp_username, smtp_password) - server.send_message(msg) + with smtplib.SMTP(smtp_host, smtp_port) as server: + server.ehlo() + if smtp_use_tls: + server.starttls() + server.ehlo() + if smtp_username and smtp_password: + server.login(smtp_username, smtp_password) + server.send_message(msg) - logger.info(f"[EMAIL] Sent to {to_address} | Subject: {subject}") - return True + logger.info(f"[EMAIL] Sent to {to_address} | Subject: {subject}") - except Exception as e: - logger.error(f"[EMAIL] Failed to send to {to_address}: {e}") - return False + except Exception as e: + logger.error(f"[EMAIL] Failed to send to {to_address}: {e}") + threading.Thread(target=_send, daemon=True).start() @staticmethod def get_notification_stats(user_id: str) -> Dict[str, Any]: diff --git a/gatehouse_app/services/portal_network_service.py b/gatehouse_app/services/portal_network_service.py index fcff9ea..0e2a140 100644 --- a/gatehouse_app/services/portal_network_service.py +++ b/gatehouse_app/services/portal_network_service.py @@ -9,7 +9,7 @@ from gatehouse_app.models.organization import Organization from gatehouse_app.models.user import User from gatehouse_app.services.audit_service import AuditService from gatehouse_app.services import zerotier_api_service as zt -from gatehouse_app.utils.constants import NetworkRequestMode +from gatehouse_app.utils.constants import NetworkRequestMode, NetworkEnvironment from gatehouse_app.exceptions import ( NetworkNotFoundError, InvalidNetworkIdError, @@ -57,23 +57,74 @@ def create_network( default_activation_lifetime_minutes: Default session length max_activation_lifetime_minutes: Cap on activation lifetime """ - from gatehouse_app.utils.constants import NetworkEnvironment - zerotier_network_id = _validate_network_id(zerotier_network_id) - existing = PortalNetwork.query.filter( + existing_active = PortalNetwork.query.filter( PortalNetwork.organization_id == organization_id, PortalNetwork.zerotier_network_id == zerotier_network_id, PortalNetwork.deleted_at.is_(None), ).first() - if existing: + if existing_active: raise ValidationError( f"A portal network already exists for ZT network {zerotier_network_id} " f"in this organization." ) - env = NetworkEnvironment(environment) if environment else NetworkEnvironment.DEVELOPMENT - mode = NetworkRequestMode(request_mode) + # Normalize to lowercase so callers may pass "PRODUCTION" or "production" interchangeably + env_str = environment.lower() if environment else None + mode_str = request_mode.lower() if request_mode else "approval_required" + + try: + env = NetworkEnvironment(env_str) if env_str else NetworkEnvironment.DEVELOPMENT + except ValueError: + valid = [e.value for e in NetworkEnvironment] + raise ValidationError(f"Invalid environment '{environment}'. Must be one of: {valid}") + + try: + mode = NetworkRequestMode(mode_str) + except ValueError: + valid = [e.value for e in NetworkRequestMode] + raise ValidationError(f"Invalid request_mode '{request_mode}'. Must be one of: {valid}") + + # If a soft-deleted record for the same (org, zt_network_id) pair exists, restore it + # rather than inserting a new row (which would violate the unique constraint). + deleted = PortalNetwork.query.filter( + PortalNetwork.organization_id == organization_id, + PortalNetwork.zerotier_network_id == zerotier_network_id, + PortalNetwork.deleted_at.isnot(None), + ).first() + if deleted: + logger.info( + f"[PortalNetwork] Restoring soft-deleted portal network {deleted.id} " + f"for ZT network {zerotier_network_id}" + ) + deleted.deleted_at = None + deleted.name = name + deleted.description = description + deleted.owner_user_id = owner_user_id + deleted.environment = env + deleted.request_mode = mode + deleted.default_activation_lifetime_minutes = default_activation_lifetime_minutes + deleted.max_activation_lifetime_minutes = max_activation_lifetime_minutes + deleted.is_active = True + deleted.save() + + AuditService.log_action( + action="zt.network.restored", + user_id=owner_user_id, + organization_id=organization_id, + resource_type="portal_network", + resource_id=deleted.id, + metadata={ + "zerotier_network_id": zerotier_network_id, + "name": name, + "environment": env.value, + "request_mode": mode.value, + }, + description=f"Portal network '{name}' restored (ZT: {zerotier_network_id})", + success=True, + ) + return deleted network = PortalNetwork( organization_id=organization_id, @@ -90,7 +141,7 @@ def create_network( # Try to verify the network exists in ZeroTier try: - zt_network = zt.get_network(zerotier_network_id) + zt_network = zt.get_network(zerotier_network_id, organization_id=organization_id) logger.info( f"[PortalNetwork] Verified ZT network {zerotier_network_id} " f"exists in ZeroTier: {zt_network.name}" @@ -100,7 +151,7 @@ def create_network( f"[PortalNetwork] ZT network {zerotier_network_id} not found " "in ZeroTier — will be reconciled later." ) - except ZeroTierAPIError as exc: + except (ZeroTierAPIError, Exception) as exc: logger.warning( f"[PortalNetwork] Could not verify ZT network {zerotier_network_id}: {exc}" ) @@ -175,6 +226,23 @@ def update_network( if key not in allowed: raise ValidationError(f"Cannot update field: {key}") + # Normalize environment / request_mode strings to lowercase enum values + if "environment" in kwargs and isinstance(kwargs["environment"], str): + env_str = kwargs["environment"].lower() + try: + kwargs["environment"] = NetworkEnvironment(env_str) + except ValueError: + valid = [e.value for e in NetworkEnvironment] + raise ValidationError(f"Invalid environment '{kwargs['environment']}'. Must be one of: {valid}") + + if "request_mode" in kwargs and isinstance(kwargs["request_mode"], str): + mode_str = kwargs["request_mode"].lower() + try: + kwargs["request_mode"] = NetworkRequestMode(mode_str) + except ValueError: + valid = [e.value for e in NetworkRequestMode] + raise ValidationError(f"Invalid request_mode '{kwargs['request_mode']}'. Must be one of: {valid}") + network.update(**kwargs) AuditService.log_action( @@ -192,7 +260,11 @@ def update_network( def delete_network(network_id: str, user_id: str) -> None: - """Soft-delete a portal network and deactivate all memberships.""" + """Soft-delete a portal network and deactivate/clean up all related records.""" + from datetime import datetime, timezone + from gatehouse_app.models import UserNetworkApproval + from gatehouse_app.extensions import db + network = get_network(network_id) # Deauthorize all active memberships in ZeroTier @@ -203,6 +275,36 @@ def delete_network(network_id: str, user_id: str) -> None: network.delete(soft=True) + # Cascade soft-delete all active approvals and memberships for this network. + now = datetime.now(timezone.utc) + db.session.execute( + db.text( + "UPDATE user_network_approvals AS a " + "SET deleted_at = :now + (s.rn * interval '1 microsecond') " + "FROM (" + " SELECT id, row_number() OVER () AS rn " + " FROM user_network_approvals " + " WHERE portal_network_id = :network_id AND deleted_at IS NULL" + ") s " + "WHERE a.id = s.id" + ), + {"now": now, "network_id": network_id}, + ) + db.session.execute( + db.text( + "UPDATE device_network_memberships AS m " + "SET deleted_at = :now + (s.rn * interval '1 microsecond') " + "FROM (" + " SELECT id, row_number() OVER () AS rn " + " FROM device_network_memberships " + " WHERE portal_network_id = :network_id AND deleted_at IS NULL" + ") s " + "WHERE m.id = s.id" + ), + {"now": now, "network_id": network_id}, + ) + db.session.commit() + AuditService.log_action( action="zt.network.deleted", user_id=user_id, diff --git a/gatehouse_app/services/zerotier_api_service.py b/gatehouse_app/services/zerotier_api_service.py index a501755..7364386 100644 --- a/gatehouse_app/services/zerotier_api_service.py +++ b/gatehouse_app/services/zerotier_api_service.py @@ -1,7 +1,11 @@ """ZeroTier API service — thin Flask adapter around the ZeroTierClient SDK. -Reads configuration from app config and translates SDK exceptions to -Secuird typed exceptions. +ZeroTier is managed exclusively at the organization level. Each organization +configures its own ZeroTier credentials (token, URL, mode) via the web UI +(ZeroTier Config page → stored in the organizations table). + +Every call that interacts with ZeroTier must supply an organization_id so the correct org credentials +can be loaded from the database. """ import logging @@ -19,97 +23,147 @@ from gatehouse_app.utils.zerotier_client import ( logger = logging.getLogger(__name__) -def _get_client(app=None) -> ZeroTierClient: - """Build a ZeroTierClient from current app config.""" - from flask import current_app +def _get_client(organization_id: Optional[str] = None, app=None) -> ZeroTierClient: + """Build a ZeroTierClient using the organization's stored ZeroTier credentials. - app = app or current_app + Credentials are read exclusively from the organization record + (org.zt_api_token / org.zt_api_url / org.zt_api_mode). + + Args: + organization_id: The org whose credentials should be used. + Required for any ZeroTier operation. + app: Flask app instance (defaults to current_app, only needed for + background tasks that run outside a request context). + + Raises: + ZeroTierAPIError: If organization_id is missing, the org is not found, + or the org has incomplete ZeroTier credentials. + """ + if not organization_id: + raise ZeroTierAPIError( + "organization_id is required — ZeroTier credentials are managed " + "per-organization. Configure them via the ZeroTier Config page." + ) + + try: + from gatehouse_app.models.organization.organization import Organization + from gatehouse_app.extensions import db + org = db.session.get(Organization, organization_id) + except Exception as exc: + logger.error(f"[ZT] Failed to load org {organization_id} from DB: {exc}") + raise ZeroTierAPIError( + f"Could not load organization {organization_id}: {exc}" + ) from exc + + if not org: + raise ZeroTierAPIError(f"Organization {organization_id} not found.") + + token: Optional[str] = org.zt_api_token or None + if not token: + raise ZeroTierAPIError( + f"Organization '{org.name}' has no ZeroTier credentials configured. " + "Go to Settings → ZeroTier Config to add a token, mode, and controller URL." + ) + + mode_str = (org.zt_api_mode or "").strip().lower() + if mode_str not in ("central", "controller"): + raise ZeroTierAPIError( + f"Organization '{org.name}' has no ZeroTier mode set. " + "Go to Settings → ZeroTier Config and select 'Central' or 'Controller'." + ) + + url: str = (org.zt_api_url or "").strip() + if not url: + raise ZeroTierAPIError( + f"Organization '{org.name}' has no ZeroTier controller/API URL set. " + "Go to Settings → ZeroTier Config and enter the URL for your ZeroTier " + "controller (e.g. http://host:9993) or Central API." + ) - mode_str = app.config.get("ZEROTIER_API_MODE", "controller") mode = APIMode.CENTRAL if mode_str == "central" else APIMode.CONTROLLER - return ZeroTierClient( - api_token=app.config.get("ZEROTIER_API_TOKEN", ""), - base_url=app.config.get("ZEROTIER_API_URL", "http://localhost:9993"), - mode=mode, + logger.debug( + f"[ZT] Client for org:{organization_id} mode={mode_str} url={url}" ) + return ZeroTierClient(api_token=token, base_url=url, mode=mode) -def get_status() -> dict: + +def get_status(organization_id: Optional[str] = None) -> dict: """Verify connectivity to the ZeroTier controller.""" - client = _get_client() + client = _get_client(organization_id) try: return client.get_status() except SDKZeroTierAPIError as exc: raise ZeroTierAPIError(str(exc), status_code=exc.status_code) from exc -def list_networks(): +def list_networks(organization_id: Optional[str] = None): """List all networks accessible to the configured token.""" - client = _get_client() + client = _get_client(organization_id) try: return client.list_networks() except SDKZeroTierAPIError as exc: raise ZeroTierAPIError(str(exc), status_code=exc.status_code) from exc -def get_network(network_id: str): +def get_network(network_id: str, organization_id: Optional[str] = None): """Fetch a single network by ID.""" - client = _get_client() + client = _get_client(organization_id) try: return client.get_network(network_id) except SDKZeroTierAPIError as exc: raise ZeroTierAPIError(str(exc), status_code=exc.status_code) from exc -def list_members(network_id: str): +def list_members(network_id: str, organization_id: Optional[str] = None): """List all members on a network.""" - client = _get_client() + client = _get_client(organization_id) try: return client.list_members(network_id) except SDKZeroTierAPIError as exc: raise ZeroTierAPIError(str(exc), status_code=exc.status_code) from exc -def get_member(network_id: str, node_id: str): +def get_member(network_id: str, node_id: str, organization_id: Optional[str] = None): """Fetch a single member on a network.""" - client = _get_client() + client = _get_client(organization_id) try: return client.get_member(network_id, node_id) except SDKZeroTierAPIError as exc: raise ZeroTierAPIError(str(exc), status_code=exc.status_code) from exc -def authorize_member(network_id: str, node_id: str): +def authorize_member(network_id: str, node_id: str, organization_id: Optional[str] = None): """Authorize a member on a network. Returns updated member.""" - client = _get_client() + client = _get_client(organization_id) try: return client.authorize_member(network_id, node_id) except SDKZeroTierAPIError as exc: raise ZeroTierAPIError(str(exc), status_code=exc.status_code) from exc -def deauthorize_member(network_id: str, node_id: str): +def deauthorize_member(network_id: str, node_id: str, organization_id: Optional[str] = None): """De-authorize a member on a network. Returns updated member.""" - client = _get_client() + client = _get_client(organization_id) try: return client.deauthorize_member(network_id, node_id) except SDKZeroTierAPIError as exc: raise ZeroTierAPIError(str(exc), status_code=exc.status_code) from exc -def add_member(network_id: str, node_id: str, authorized: bool = False): +def add_member(network_id: str, node_id: str, authorized: bool = False, organization_id: Optional[str] = None): """Manually add/pre-provision a member on a network.""" - client = _get_client() + client = _get_client(organization_id) try: return client.add_member(network_id, node_id, authorized=authorized) except SDKZeroTierAPIError as exc: raise ZeroTierAPIError(str(exc), status_code=exc.status_code) from exc -def delete_network_member(network_id: str, node_id: str): +def delete_network_member(network_id: str, node_id: str, organization_id: Optional[str] = None): """Remove a member entirely from a ZeroTier network.""" - client = _get_client() + client = _get_client(organization_id) try: return client.delete_member(network_id, node_id) except SDKZeroTierAPIError as exc: diff --git a/gatehouse_app/services/zerotier_reconciliation_service.py b/gatehouse_app/services/zerotier_reconciliation_service.py index c43e941..c78119b 100644 --- a/gatehouse_app/services/zerotier_reconciliation_service.py +++ b/gatehouse_app/services/zerotier_reconciliation_service.py @@ -1,6 +1,7 @@ """ZeroTier reconciliation service — polling loop to sync state with the controller.""" import logging +import time from datetime import datetime, timezone from gatehouse_app.extensions import db @@ -34,16 +35,24 @@ def reconcile_expired_activations() -> int: ActivationSession.deleted_at.is_(None), ).all() + logger.debug(f"[Reconciliation] Expiry check: {len(expired)} overdue session(s) found.") + count = 0 for session in expired: try: _expire_session(session) count += 1 except Exception as exc: - logger.error(f"[Reconciliation] Failed to expire session {session.id}: {exc}") + logger.error( + f"[Reconciliation] Failed to expire session {session.id} " + f"(user={session.user_id} membership={session.device_network_membership_id}): {exc}", + exc_info=True, + ) if count > 0: - logger.info(f"[Reconciliation] Expired {count} activation sessions.") + logger.info(f"[Reconciliation] Expired {count} activation session(s).") + else: + logger.debug("[Reconciliation] No activation sessions to expire.") return count @@ -55,9 +64,14 @@ def reconcile_network(portal_network_id: str) -> dict: """ network = PortalNetwork.query.get(portal_network_id) if not network or not network.is_active: + logger.debug( + f"[Reconciliation] Skipping portal_network_id={portal_network_id}: " + f"{'not found' if not network else 'inactive or deleted'}." + ) return {"skipped": True, "reason": "network_inactive_or_deleted"} zerotier_network_id = network.zerotier_network_id + network_label = f"{network.name} ({zerotier_network_id})" actions = { "zt_members_checked": 0, "zt_members_added": 0, @@ -67,15 +81,25 @@ def reconcile_network(portal_network_id: str) -> dict: "unknown_members": [], } + t_start = time.monotonic() + logger.debug(f"[Reconciliation] Starting network reconciliation for {network_label}.") + # Get current ZT members try: - zt_members = {m.node_id: m for m in zt.list_members(zerotier_network_id)} + zt_members = {m.node_id: m for m in zt.list_members(zerotier_network_id, + organization_id=network.organization_id)} except Exception as exc: - logger.error(f"[Reconciliation] Failed to list ZT members for {zerotier_network_id}: {exc}") + logger.error( + f"[Reconciliation] Failed to list ZT members for {network_label}: {exc}", + exc_info=True, + ) actions["error"] = str(exc) return actions actions["zt_members_checked"] = len(zt_members) + logger.debug( + f"[Reconciliation] {network_label}: {len(zt_members)} member(s) fetched from ZT controller." + ) # Get our portal memberships for this network our_memberships = { @@ -87,13 +111,21 @@ def reconcile_network(portal_network_id: str) -> dict: if m.device and m.device.deleted_at is None } + logger.debug( + f"[Reconciliation] {network_label}: {len(our_memberships)} portal membership(s) to reconcile." + ) + # Reconcile each portal membership for node_id, membership in our_memberships.items(): zt_member = zt_members.pop(node_id, None) device = membership.device if not zt_member: - # Member not seen in ZT yet + # Member not seen in ZT yet — could be freshly joined or never connected + logger.debug( + f"[Reconciliation] {network_label}: node {node_id} " + f"(device={device.display_name!r}, state={membership.state}) not yet seen in ZT controller." + ) continue actions["join_seen_updated"] += 1 @@ -104,31 +136,67 @@ def reconcile_network(portal_network_id: str) -> dict: # Sync authorization state if membership.state == MembershipState.ACTIVE_AUTHORIZED: if not zt_member.is_authorized: - # We think it's active but ZT says it's not — re-authorize + # Portal says active but ZT disagrees — drift, re-authorize + logger.warning( + f"[Reconciliation] {network_label}: DRIFT detected — portal=ACTIVE_AUTHORIZED " + f"but ZT says unauthorized for node {node_id} (device={device.display_name!r}). Re-authorizing." + ) try: - zt.authorize_member(zerotier_network_id, node_id) + zt.authorize_member(zerotier_network_id, node_id, + organization_id=network.organization_id) actions["authorized"] += 1 + logger.info( + f"[Reconciliation] {network_label}: Re-authorized node {node_id} (device={device.display_name!r})." + ) except Exception as exc: - logger.warning(f"[Reconciliation] Re-authorize failed for {node_id}: {exc}") + logger.warning( + f"[Reconciliation] {network_label}: Re-authorize failed for node {node_id}: {exc}" + ) + else: + logger.debug( + f"[Reconciliation] {network_label}: node {node_id} — portal=ACTIVE_AUTHORIZED, ZT=authorized. OK." + ) else: if zt_member.is_authorized: - # We think it's not authorized but ZT says it is — deauthorize - # (could be manual override in ZT console) + # ZT says authorized but portal doesn't — could be manual override in ZT console + logger.warning( + f"[Reconciliation] {network_label}: DRIFT detected — portal state={membership.state} " + f"but ZT says authorized for node {node_id} (device={device.display_name!r}). Deauthorizing." + ) try: - zt.deauthorize_member(zerotier_network_id, node_id) + zt.deauthorize_member(zerotier_network_id, node_id, + organization_id=network.organization_id) actions["deauthorized"] += 1 + logger.info( + f"[Reconciliation] {network_label}: Deauthorized node {node_id} (device={device.display_name!r})." + ) except Exception as exc: - logger.warning(f"[Reconciliation] Deauthorize failed for {node_id}: {exc}") + logger.warning( + f"[Reconciliation] {network_label}: Deauthorize failed for node {node_id}: {exc}" + ) + else: + logger.debug( + f"[Reconciliation] {network_label}: node {node_id} — " + f"portal={membership.state}, ZT=unauthorized. OK." + ) - # Unknown ZT members not in our portal - actions["unknown_members"] = list(zt_members.keys()) + # Unknown ZT members not in our portal — log only, do not touch + unknown = list(zt_members.keys()) + actions["unknown_members"] = unknown + if unknown: + logger.warning( + f"[Reconciliation] {network_label}: {len(unknown)} ZT member(s) not in portal — " + f"node IDs: {', '.join(unknown)}" + ) + elapsed_ms = int((time.monotonic() - t_start) * 1000) logger.info( - f"[Reconciliation] Network {zerotier_network_id}: " + f"[Reconciliation] Network {network_label}: " f"checked={actions['zt_members_checked']} " f"authorized={actions['authorized']} " f"deauthorized={actions['deauthorized']} " - f"unknown={len(actions['unknown_members'])}" + f"unknown={len(actions['unknown_members'])} " + f"elapsed={elapsed_ms}ms" ) return actions @@ -144,16 +212,34 @@ def reconcile_all() -> dict: PortalNetwork.deleted_at.is_(None), ).all() - results = {"networks_processed": 0, "errors": 0} + logger.info(f"[Reconciliation] reconcile_all: {len(networks)} active network(s) to process.") + + results = {"networks_processed": 0, "errors": 0, "authorized": 0, "deauthorized": 0, "unknown_members": []} for network in networks: try: result = reconcile_network(network.id) if "error" in result: + logger.error( + f"[Reconciliation] Network {network.name} ({network.zerotier_network_id}) " + f"failed: {result['error']}" + ) results["errors"] += 1 + elif result.get("skipped"): + logger.debug( + f"[Reconciliation] Network {network.name} ({network.zerotier_network_id}) " + f"skipped: {result.get('reason')}" + ) else: results["networks_processed"] += 1 + results["authorized"] += result.get("authorized", 0) + results["deauthorized"] += result.get("deauthorized", 0) + results["unknown_members"].extend(result.get("unknown_members", [])) except Exception as exc: - logger.error(f"[Reconciliation] Failed to reconcile network {network.id}: {exc}") + logger.error( + f"[Reconciliation] Unhandled error reconciling network " + f"{network.name} ({network.id}): {exc}", + exc_info=True, + ) results["errors"] += 1 deleted_result = reconcile_deleted_memberships() @@ -161,8 +247,11 @@ def reconcile_all() -> dict: results["delete_errors"] = deleted_result.get("errors", 0) logger.info( - f"[Reconciliation] Complete: {results['networks_processed']} networks processed, " - f"{results['errors']} errors, {results.get('deleted_memberships', 0)} memberships purged." + f"[Reconciliation] Complete: " + f"networks={results['networks_processed']} " + f"errors={results['errors']} " + f"purged={results.get('deleted_memberships', 0)} " + f"purge_errors={results.get('delete_errors', 0)}" ) return results @@ -180,8 +269,11 @@ def reconcile_deleted_memberships() -> dict: ).all() if not deleted: + logger.debug("[Reconciliation] No soft-deleted memberships to purge.") return {"deleted": 0, "errors": 0} + logger.info(f"[Reconciliation] Purging {len(deleted)} soft-deleted membership(s) from ZT and DB.") + results = {"deleted": 0, "errors": 0} for membership in deleted: try: @@ -189,30 +281,49 @@ def reconcile_deleted_memberships() -> dict: network = PortalNetwork.query.get(membership.portal_network_id) if not device or not network: + logger.warning( + f"[Reconciliation] Membership {membership.id}: missing " + f"{'device' if not device else 'network'} — hard-deleting record only." + ) db.session.delete(membership) db.session.commit() results["deleted"] += 1 continue + node_id = device.node_id + zt_network_id = network.zerotier_network_id + network_label = f"{network.name} ({zt_network_id})" + try: - zt.delete_network_member(network.zerotier_network_id, device.node_id) - logger.info(f"[Reconciliation] Deleted {device.node_id} from ZT network {network.zerotier_network_id}") + zt.delete_network_member(zt_network_id, node_id, + organization_id=network.organization_id) + logger.info( + f"[Reconciliation] Removed node {node_id} (device={device.display_name!r}) " + f"from ZT network {network_label}." + ) except Exception as zt_exc: logger.warning( - f"[Reconciliation] ZT delete failed for {device.node_id} " - f"on {network.zerotier_network_id}: {zt_exc}" + f"[Reconciliation] ZT delete failed for node {node_id} " + f"on {network_label}: {zt_exc} — proceeding with DB hard-delete." ) db.session.delete(membership) db.session.commit() results["deleted"] += 1 + logger.debug( + f"[Reconciliation] Hard-deleted membership {membership.id} " + f"(node={node_id}, network={network_label})." + ) except Exception as exc: - logger.error(f"[Reconciliation] Failed to hard-delete membership {membership.id}: {exc}") + logger.error( + f"[Reconciliation] Failed to hard-delete membership {membership.id}: {exc}", + exc_info=True, + ) results["errors"] += 1 if results["deleted"] > 0: - logger.info(f"[Reconciliation] Purged {results['deleted']} memberships.") + logger.info(f"[Reconciliation] Purged {results['deleted']} membership(s).") return results @@ -228,7 +339,12 @@ def _sync_zt_membership(membership: DeviceNetworkMembership, zt_member) -> None: ZeroTierMembership.deleted_at.is_(None), ).first() - if not zt_membership: + is_new = zt_membership is None + if is_new: + logger.debug( + f"[Reconciliation] Creating new ZeroTierMembership cache record for " + f"node {device.node_id} on network {network.zerotier_network_id}." + ) zt_membership = ZeroTierMembership( organization_id=membership.organization_id, device_network_membership_id=membership.id, @@ -236,6 +352,8 @@ def _sync_zt_membership(membership: DeviceNetworkMembership, zt_member) -> None: node_id=device.node_id, ) + prev_authorized = zt_membership.authorized if not is_new else None + zt_membership.member_seen = True zt_membership.authorized = zt_member.is_authorized zt_membership.last_synced_at = datetime.now(timezone.utc) @@ -248,11 +366,27 @@ def _sync_zt_membership(membership: DeviceNetworkMembership, zt_member) -> None: zt_membership.save() + if not is_new and prev_authorized != zt_member.is_authorized: + logger.info( + f"[Reconciliation] ZT auth state changed for node {device.node_id} " + f"(device={device.display_name!r}): {prev_authorized} → {zt_member.is_authorized}" + ) + # Update membership join_seen flag if not membership.join_seen: + logger.info( + f"[Reconciliation] First join seen for node {device.node_id} " + f"(device={device.display_name!r}, membership={membership.id}). " + f"State: {membership.state} → {MembershipState.JOINED_DEAUTHORIZED}" + ) membership.join_seen = True membership.state = MembershipState.JOINED_DEAUTHORIZED membership.save() + else: + logger.debug( + f"[Reconciliation] Synced ZT membership for node {device.node_id} " + f"(device={device.display_name!r}, authorized={zt_member.is_authorized})." + ) def _expire_session(session: ActivationSession) -> None: @@ -261,8 +395,19 @@ def _expire_session(session: ActivationSession) -> None: session.end_reason = ActivationEndReason.EXPIRED session.save() + logger.info( + f"[Reconciliation] Expiring activation session {session.id} " + f"(user={session.user_id}, membership={session.device_network_membership_id}, " + f"expired_at={session.expires_at.isoformat()})." + ) + membership = DeviceNetworkMembership.query.get(session.device_network_membership_id) - if membership: + if not membership: + logger.warning( + f"[Reconciliation] Session {session.id}: membership " + f"{session.device_network_membership_id} not found — skipping ZT deauth." + ) + else: membership.state = MembershipState.ACTIVATION_EXPIRED membership.currently_authorized = False membership.save() @@ -270,8 +415,14 @@ def _expire_session(session: ActivationSession) -> None: device = Device.query.get(membership.device_id) network = PortalNetwork.query.get(membership.portal_network_id) if device and network: + network_label = f"{network.name} ({network.zerotier_network_id})" try: - zt.deauthorize_member(network.zerotier_network_id, device.node_id) + zt.deauthorize_member(network.zerotier_network_id, device.node_id, + organization_id=network.organization_id) + logger.info( + f"[Reconciliation] Deauthorized expired node {device.node_id} " + f"(device={device.display_name!r}) on {network_label}." + ) # Update ZT membership cache zt_membership = ZeroTierMembership.query.filter( @@ -283,12 +434,24 @@ def _expire_session(session: ActivationSession) -> None: zt_membership.authorized = False zt_membership.last_synced_at = datetime.now(timezone.utc) zt_membership.save() + else: + logger.debug( + f"[Reconciliation] No ZeroTierMembership cache record found for " + f"node {device.node_id} on {network_label} — nothing to update." + ) except Exception as exc: logger.warning( - f"[_expire_session] Failed to deauthorize {device.node_id} " - f"on {network.zerotier_network_id}: {exc}" + f"[_expire_session] Failed to deauthorize node {device.node_id} " + f"on {network_label}: {exc}", + exc_info=True, ) + else: + logger.warning( + f"[Reconciliation] Session {session.id}: missing " + f"{'device' if not device else 'network'} for membership " + f"{membership.id} — ZT deauth skipped." + ) from gatehouse_app.services.audit_service import AuditService AuditService.log_action( diff --git a/gatehouse_app/utils/constants.py b/gatehouse_app/utils/constants.py index 1fd9285..9968a01 100644 --- a/gatehouse_app/utils/constants.py +++ b/gatehouse_app/utils/constants.py @@ -282,7 +282,6 @@ class KillSwitchScope(str, Enum): """Scope of a kill switch event.""" ORGANIZATION = "organization" - GLOBAL = "global" SELECTED_NETWORKS = "selected_networks" diff --git a/migrations/versions/028_org_zerotier_config.py b/migrations/versions/028_org_zerotier_config.py new file mode 100644 index 0000000..c5aad28 --- /dev/null +++ b/migrations/versions/028_org_zerotier_config.py @@ -0,0 +1,69 @@ +"""Add per-org ZeroTier credentials to organizations table. + +Revision ID: 028_org_zerotier_config +Revises: 026_schema_cleanup +Create Date: 2026-03-25 + +Adds three nullable columns to `organizations`: + - zt_api_token VARCHAR(512) — API token (Central) or authtoken.secret (controller) + - zt_api_url VARCHAR(512) — base URL of the controller / Central API + - zt_api_mode VARCHAR(32) — "central" | "controller" + +When these are NULL the server-level ZEROTIER_API_* env vars are used instead, +so existing deployments are fully backwards-compatible with no data migration needed. +""" + +from alembic import op +import sqlalchemy as sa + + +revision = "028_org_zerotier_config" +down_revision = "027_fix_cert_serial_uniqueness" +branch_labels = None +depends_on = None + + +def _col_exists(conn, table: str, column: str) -> bool: + row = conn.execute( + sa.text( + "SELECT 1 FROM information_schema.columns " + "WHERE table_name = :t AND column_name = :c" + ), + {"t": table, "c": column}, + ).first() + return row is not None + + +def upgrade(): + conn = op.get_bind() + + if not _col_exists(conn, "organizations", "zt_api_token"): + op.add_column( + "organizations", + sa.Column("zt_api_token", sa.String(512), nullable=True), + ) + + if not _col_exists(conn, "organizations", "zt_api_url"): + op.add_column( + "organizations", + sa.Column("zt_api_url", sa.String(512), nullable=True), + ) + + if not _col_exists(conn, "organizations", "zt_api_mode"): + op.add_column( + "organizations", + sa.Column("zt_api_mode", sa.String(32), nullable=True), + ) + + +def downgrade(): + conn = op.get_bind() + + if _col_exists(conn, "organizations", "zt_api_mode"): + op.drop_column("organizations", "zt_api_mode") + + if _col_exists(conn, "organizations", "zt_api_url"): + op.drop_column("organizations", "zt_api_url") + + if _col_exists(conn, "organizations", "zt_api_token"): + op.drop_column("organizations", "zt_api_token") From 78c2ee5c5a126916c515eae3c8333b39a43a00af Mon Sep 17 00:00:00 2001 From: James Bhattarai Date: Tue, 31 Mar 2026 13:45:07 +0545 Subject: [PATCH 9/9] Feat:Added Update Client --- gatehouse_app/api/v1/organizations/clients.py | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/gatehouse_app/api/v1/organizations/clients.py b/gatehouse_app/api/v1/organizations/clients.py index 553837a..3817023 100644 --- a/gatehouse_app/api/v1/organizations/clients.py +++ b/gatehouse_app/api/v1/organizations/clients.py @@ -95,6 +95,53 @@ def create_org_client(org_id): ) +@api_v1_bp.route("/organizations//clients/", methods=["PATCH"]) +@login_required +@require_admin +def update_org_client(org_id, client_id): + from gatehouse_app.models import OIDCClient + + client = OIDCClient.query.filter_by(id=client_id, organization_id=org_id).first() + if not client: + return api_response(success=False, message="Client not found", status=404) + + data = request.get_json() or {} + + if "name" in data: + name = (data["name"] or "").strip() + if not name: + return api_response(success=False, message="Client name cannot be empty", status=400, error_type="VALIDATION_ERROR") + client.name = name + + if "redirect_uris" in data: + redirect_uris_raw = data["redirect_uris"] + if isinstance(redirect_uris_raw, str): + uris = [u.strip() for u in redirect_uris_raw.replace(",", "\n").splitlines() if u.strip()] + else: + uris = [u.strip() for u in redirect_uris_raw if isinstance(u, str) and u.strip()] + if not uris: + return api_response(success=False, message="At least one redirect URI is required", status=400, error_type="VALIDATION_ERROR") + client.redirect_uris = uris + + db.session.commit() + + return api_response( + data={ + "client": { + "id": client.id, + "name": client.name, + "client_id": client.client_id, + "redirect_uris": client.redirect_uris, + "scopes": client.scopes, + "grant_types": client.grant_types, + "is_active": client.is_active, + "created_at": client.created_at.isoformat() + "Z", + } + }, + message="Client updated successfully", + ) + + @api_v1_bp.route("/organizations//clients/", methods=["DELETE"]) @login_required @require_admin