Merge pull request #13 from jamesii-b/gatehouse/secuird-CA-merge-v2.01

Gatehouse/secuird ca merge v2.01
This commit is contained in:
2026-04-03 10:18:10 +10:30
committed by GitHub
54 changed files with 3317 additions and 599 deletions
+91 -32
View File
@@ -1,58 +1,117 @@
# Flask Configuration
FLASK_APP=wsgi.py
FLASK_APP=manage.py
FLASK_ENV=development
SECRET_KEY=your-secret-key-here-change-in-production
FLASK_DEBUG=1
# Database
DATABASE_URL=postgresql://user:password@localhost:5432/authy2_dev
DATABASE_URL=postgresql://user:password@localhost:5432/gatehouse_dev
SQLALCHEMY_ECHO=False
SQLALCHEMY_LOG_LEVEL=WARNING
# Security
# Security / Encryption
SECRET_KEY=change-me-in-production
ENCRYPTION_KEY=change-me-in-production-32-bytes!!
# Used to encrypt SSH CA private keys stored in the database
CA_ENCRYPTION_KEY=change-me-in-production
BCRYPT_LOG_ROUNDS=12
ENCRYPTION_KEY=your-encryption-key-here-change-in-production
# Session cookies
SESSION_COOKIE_SECURE=False
SESSION_COOKIE_HTTPONLY=True
SESSION_COOKIE_SAMESITE=Lax
# Only needed when sharing cookies across subdomains (e.g. api.example.com + ui.example.com)
# SESSION_COOKIE_DOMAIN=example.com
MAX_SESSION_DURATION=86400
# CORS
#CORS_ORIGINS=http://localhost:3000,http://localhost:5173,https://oidc-playpen.lovable.app/,http://localhost:8080/
CORS_ORIGINS=*
# JWT (if using JWT instead of sessions)
JWT_SECRET_KEY=your-jwt-secret-key-here
# ─────────────────────────────────────────────────────────────────────────────
# JWT
# ─────────────────────────────────────────────────────────────────────────────
JWT_SECRET_KEY=change-me-in-production
JWT_ACCESS_TOKEN_EXPIRES=3600
JWT_REFRESH_TOKEN_EXPIRES=2592000
# Redis (for session storage)
# ─────────────────────────────────────────────────────────────────────────────
# Redis (session storage + rate limiting)
# ─────────────────────────────────────────────────────────────────────────────
REDIS_URL=redis://localhost:6379/0
SESSION_REDIS_URL=redis://localhost:6379/0
RATELIMIT_STORAGE_URL=redis://localhost:6379/1
# OIDC
# ─────────────────────────────────────────────────────────────────────────────
# CORS
# ─────────────────────────────────────────────────────────────────────────────
CORS_ORIGINS=http://localhost:8080,http://localhost:5173
# ─────────────────────────────────────────────────────────────────────────────
# Frontend / App URLs
# All three should point at the browser-facing SPA. They are used for:
# FRONTEND_URL → OAuth callback redirects after provider auth
# APP_URL → Password-reset and email-verify links in emails
# OIDC_UI_URL → OIDC /authorize redirects to the React consent/login UI
# ─────────────────────────────────────────────────────────────────────────────
FRONTEND_URL=http://localhost:8080
APP_URL=http://localhost:8080
OIDC_UI_URL=http://localhost:8080
# ─────────────────────────────────────────────────────────────────────────────
# OIDC / OAuth issuer
# ─────────────────────────────────────────────────────────────────────────────
OIDC_ISSUER_URL=http://localhost:5000
OIDC_BASE_URL=http://localhost:5000
# ─────────────────────────────────────────────────────────────────────────────
# WebAuthn
# ─────────────────────────────────────────────────────────────────────────────
WEBAUTHN_RP_ID=localhost
WEBAUTHN_RP_NAME=Gatehouse
WEBAUTHN_ORIGIN=http://localhost:8080
# ─────────────────────────────────────────────────────────────────────────────
# SSH CA (pick one)
# ─────────────────────────────────────────────────────────────────────────────
SSH_CA_KEY_PATH=/path/to/ca-users
# SSH_CA_PRIVATE_KEY= # raw key content; takes priority over SSH_CA_KEY_PATH
# ─────────────────────────────────────────────────────────────────────────────
# Email / SMTP
# ─────────────────────────────────────────────────────────────────────────────
EMAIL_ENABLED=False
SMTP_HOST=smtp.gmail.com
SMTP_PORT=587
SMTP_USE_TLS=True
SMTP_USERNAME=
SMTP_PASSWORD=
FROM_ADDRESS=noreply@gatehouse.local
# ─────────────────────────────────────────────────────────────────────────────
# Logging
# ─────────────────────────────────────────────────────────────────────────────
LOG_LEVEL=INFO
LOG_TO_STDOUT=True
# ─────────────────────────────────────────────────────────────────────────────
# Rate Limiting
# ─────────────────────────────────────────────────────────────────────────────
RATELIMIT_ENABLED=True
RATELIMIT_STORAGE_URL=redis://localhost:6379/1
# SSH CA
# Path to CA private key file (alternative to SSH_CA_PRIVATE_KEY env var)
SSH_CA_KEY_PATH=/path/to/ca-users
# Or set the key content directly (takes priority over SSH_CA_KEY_PATH):
# SSH_CA_PRIVATE_KEY=
EMAIL_ENABLED=
SMTP_HOST=
SMTP_PORT=
SMTP_USERNAME=
SMTP_PASSWORD=
FROM_ADDRESS=
WEBAUTHN_ORIGIN=
# Per-endpoint auth limits (optional — defaults shown)
# RATELIMIT_AUTH_REGISTER=10 per minute; 50 per hour
# RATELIMIT_AUTH_LOGIN=20 per minute; 100 per hour
# RATELIMIT_AUTH_TOTP_VERIFY=20 per minute; 100 per hour
# RATELIMIT_AUTH_FORGOT_PASSWORD=5 per minute; 20 per hour
# RATELIMIT_AUTH_RESET_PASSWORD=10 per minute; 30 per hour
ZEROTIER_API_TOKEN=
ZEROTIER_API_URL=
ZEROTIER_API_URL=
# ─────────────────────────────────────────────────────────────────────────────
# OIDC token lifetimes & security (optional — defaults shown)
# ─────────────────────────────────────────────────────────────────────────────
# OIDC_ACCESS_TOKEN_LIFETIME=3600
# OIDC_REFRESH_TOKEN_LIFETIME=2592000
# OIDC_ID_TOKEN_LIFETIME=3600
# OIDC_AUTHORIZATION_CODE_LIFETIME=600
# OIDC_REQUIRE_PKCE=True
# OIDC_ALLOW_IMPLICIT_FLOW=False
# OIDC_KEY_ROTATION_DAYS=90
# OIDC_KEY_GRACE_PERIOD_DAYS=30
# OIDC_RATE_LIMIT_AUTHORIZE=10/minute
# OIDC_RATE_LIMIT_TOKEN=20/minute
# OIDC_RATE_LIMIT_USERINFO=60/minute
+2 -12
View File
@@ -128,20 +128,10 @@ class BaseConfig:
# Frontend URL (for OAuth callback redirects)
FRONTEND_URL = os.getenv("FRONTEND_URL", "http://localhost:8080")
APP_URL = os.getenv("APP_URL", os.getenv("FRONTEND_URL", "http://localhost:8080"))
OIDC_UI_URL = os.getenv("OIDC_UI_URL", os.getenv("FRONTEND_URL", "http://localhost:8080"))
# ZeroTier Configuration
ZEROTIER_API_TOKEN = os.getenv("ZEROTIER_API_TOKEN", "")
ZEROTIER_API_URL = os.getenv(
"ZEROTIER_API_URL",
"http://localhost:9993",
)
ZEROTIER_API_MODE = os.getenv("ZEROTIER_API_MODE", "controller").lower()
ZEROTIER_DEFAULT_ACTIVATION_LIFETIME_MINUTES = int(
os.getenv("ZEROTIER_DEFAULT_ACTIVATION_LIFETIME_MINUTES", "480")
)
ZEROTIER_RECONCILIATION_INTERVAL_SECONDS = int(
os.getenv("ZEROTIER_RECONCILIATION_INTERVAL_SECONDS", "120")
)
# Email / SMTP
EMAIL_ENABLED = os.getenv("EMAIL_ENABLED", "False").lower() == "true"
+2 -1
View File
@@ -5,6 +5,7 @@ from flask import Blueprint
api_v1_bp = Blueprint("api_v1", __name__)
# Import route modules to register them
from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth, departments, principals, ssh, zerotier
from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth, departments, principals, ssh, zerotier, sudo
api_v1_bp.register_blueprint(ssh.ssh_bp)
+1 -1
View File
@@ -39,7 +39,7 @@ def register():
f"{verify_link}\n\n"
f"Gatehouse Security Team"
)
NotificationService._send_email(to_address=user.email, subject=subject, body=body)
NotificationService._send_email_async(to_address=user.email, subject=subject, body=body)
except Exception as exc:
logging.getLogger(__name__).warning(f"Failed to send verification email on register: {exc}")
+3 -3
View File
@@ -27,7 +27,7 @@ def forgot_password():
reset_token = PasswordResetToken.generate(user_id=user.id)
app_url = current_app.config.get("APP_URL", "http://localhost:8080")
reset_link = f"{app_url}/reset-password?token={reset_token.token}"
NotificationService._send_email(
NotificationService._send_email_async(
to_address=user.email,
subject="Reset your Gatehouse password",
body=(
@@ -129,7 +129,7 @@ def resend_verification():
verify_token = EmailVerificationToken.generate(user_id=user.id)
app_url = current_app.config.get("APP_URL", "http://localhost:8080")
verify_link = f"{app_url}/verify-email?token={verify_token.token}"
NotificationService._send_email(
NotificationService._send_email_async(
to_address=user.email,
subject="Verify your Gatehouse email address",
body=(
@@ -200,7 +200,7 @@ def resend_activation():
app_url = current_app.config.get("APP_URL", current_app.config.get("FRONTEND_URL", "http://localhost:8080"))
activate_link = f"{app_url}/activate?code={code}"
NotificationService._send_email(
NotificationService._send_email_async(
to_address=user.email,
subject="Activate your Gatehouse account",
body=(
+4
View File
@@ -16,12 +16,15 @@ class DepartmentCreateSchema(Schema):
"""Schema for creating a department."""
name = fields.Str(required=True, validate=validate.Length(min=1, max=255))
description = fields.Str(allow_none=True, validate=validate.Length(max=2000))
can_sudo = fields.Bool(allow_none=True, load_default=False)
class DepartmentUpdateSchema(Schema):
"""Schema for updating a department."""
name = fields.Str(validate=validate.Length(min=1, max=255))
description = fields.Str(allow_none=True, validate=validate.Length(max=2000))
can_sudo = fields.Bool(allow_none=True)
class AddDepartmentMemberSchema(Schema):
@@ -119,6 +122,7 @@ def create_department(org_id):
organization_id=org_id,
name=data["name"],
description=data.get("description"),
can_sudo=data.get("can_sudo", False),
)
db.session.add(dept)
db.session.commit()
@@ -1,4 +1,4 @@
"""Organization routes package."""
from gatehouse_app.api.v1.organizations import core, members, invites, clients, cas, audit, roles
from gatehouse_app.api.v1.organizations import core, members, invites, clients, cas, audit, roles, api_keys
__all__ = ["core", "members", "invites", "clients", "cas", "audit", "roles"]
__all__ = ["core", "members", "invites", "clients", "cas", "audit", "roles", "api_keys"]
@@ -0,0 +1,299 @@
"""Organization API Key management endpoints."""
from flask import g, request
from marshmallow import Schema, fields, validate, ValidationError
from gatehouse_app.api.v1 import api_v1_bp
from gatehouse_app.utils.response import api_response
from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required
from gatehouse_app.models.organization import OrganizationApiKey
from gatehouse_app.services.organization_service import OrganizationService
from gatehouse_app.extensions import db
class ApiKeyCreateSchema(Schema):
"""Schema for creating an API key."""
name = fields.Str(required=True, validate=validate.Length(min=1, max=255))
description = fields.Str(allow_none=True, validate=validate.Length(max=2000))
class ApiKeyUpdateSchema(Schema):
"""Schema for updating an API key."""
name = fields.Str(validate=validate.Length(min=1, max=255))
description = fields.Str(allow_none=True, validate=validate.Length(max=2000))
@api_v1_bp.route("/organizations/<org_id>/api-keys", methods=["GET"])
@login_required
@require_admin
@full_access_required
def list_api_keys(org_id):
"""
List all API keys for an organization.
Only accessible by organization admins.
Args:
org_id: Organization ID
Returns:
200: List of API keys (without key values)
401: Not authenticated
403: Not an admin
404: Organization not found
"""
org = OrganizationService.get_organization_by_id(org_id)
# Check if user is an admin
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.utils.constants import OrganizationRole
membership = OrganizationMember.query.filter_by(
user_id=g.current_user.id,
organization_id=org_id,
deleted_at=None
).first()
if not membership or membership.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]:
return api_response(
success=False,
message="You do not have permission to manage API keys",
status=403,
error_type="AUTHORIZATION_ERROR",
)
api_keys = OrganizationApiKey.query.filter_by(
organization_id=org_id,
deleted_at=None
).all()
return api_response(
data={
"api_keys": [k.to_dict() for k in api_keys],
"count": len(api_keys),
},
message="API keys retrieved successfully",
)
@api_v1_bp.route("/organizations/<org_id>/api-keys", methods=["POST"])
@login_required
@require_admin
@full_access_required
def create_api_key(org_id):
"""
Create a new API key for an organization.
Only accessible by organization admins.
The plain text key is returned only on creation and should be stored securely.
Args:
org_id: Organization ID
Request body:
name: API key name (required)
description: Optional description
Returns:
201: API key created successfully
400: Validation error
401: Not authenticated
403: Not an admin
404: Organization not found
"""
try:
org = OrganizationService.get_organization_by_id(org_id)
# Check if user is an admin
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.utils.constants import OrganizationRole
membership = OrganizationMember.query.filter_by(
user_id=g.current_user.id,
organization_id=org_id,
deleted_at=None
).first()
if not membership or membership.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]:
return api_response(
success=False,
message="You do not have permission to create API keys",
status=403,
error_type="AUTHORIZATION_ERROR",
)
schema = ApiKeyCreateSchema()
data = schema.load(request.json or {})
# Create the API key
api_key, plain_key = OrganizationApiKey.create_key(
organization_id=org_id,
name=data["name"],
description=data.get("description"),
)
# Return the key data with the plain text key (only on creation)
key_dict = api_key.to_dict()
key_dict["key"] = plain_key # Include plain text only on creation
return api_response(
data={"api_key": key_dict},
message="API key created successfully. Store the key value securely - it cannot be retrieved later.",
status=201,
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
@api_v1_bp.route("/organizations/<org_id>/api-keys/<key_id>", methods=["PATCH"])
@login_required
@require_admin
@full_access_required
def update_api_key(org_id, key_id):
"""
Update an API key.
Only accessible by organization admins.
Args:
org_id: Organization ID
key_id: API Key ID
Request body:
name: New name (optional)
description: New description (optional)
Returns:
200: API key updated successfully
400: Validation error
401: Not authenticated
403: Not an admin
404: Organization or API key not found
"""
try:
org = OrganizationService.get_organization_by_id(org_id)
# Check if user is an admin
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.utils.constants import OrganizationRole
membership = OrganizationMember.query.filter_by(
user_id=g.current_user.id,
organization_id=org_id,
deleted_at=None
).first()
if not membership or membership.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]:
return api_response(
success=False,
message="You do not have permission to update API keys",
status=403,
error_type="AUTHORIZATION_ERROR",
)
api_key = OrganizationApiKey.query.filter_by(
id=key_id,
organization_id=org_id,
deleted_at=None
).first()
if not api_key:
return api_response(
success=False,
message="API key not found",
status=404,
error_type="NOT_FOUND",
)
schema = ApiKeyUpdateSchema()
data = schema.load(request.json or {})
# Update fields
if "name" in data:
api_key.name = data["name"]
if "description" in data:
api_key.description = data["description"]
api_key.save()
return api_response(
data={"api_key": api_key.to_dict()},
message="API key updated successfully",
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
@api_v1_bp.route("/organizations/<org_id>/api-keys/<key_id>", methods=["DELETE"])
@login_required
@require_admin
@full_access_required
def delete_api_key(org_id, key_id):
"""
Delete/revoke an API key.
Only accessible by organization admins.
Args:
org_id: Organization ID
key_id: API Key ID
Returns:
200: API key deleted successfully
401: Not authenticated
403: Not an admin
404: Organization or API key not found
"""
org = OrganizationService.get_organization_by_id(org_id)
# Check if user is an admin
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.utils.constants import OrganizationRole
membership = OrganizationMember.query.filter_by(
user_id=g.current_user.id,
organization_id=org_id,
deleted_at=None
).first()
if not membership or membership.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]:
return api_response(
success=False,
message="You do not have permission to delete API keys",
status=403,
error_type="AUTHORIZATION_ERROR",
)
api_key = OrganizationApiKey.query.filter_by(
id=key_id,
organization_id=org_id,
deleted_at=None
).first()
if not api_key:
return api_response(
success=False,
message="API key not found",
status=404,
error_type="NOT_FOUND",
)
# Soft delete the API key
api_key.delete(soft=True)
return api_response(
message="API key deleted successfully",
)
@@ -173,3 +173,101 @@ def get_my_audit_logs():
},
message="Activity retrieved",
)
@api_v1_bp.route("/organizations/<org_id>/certificates/audit", methods=["GET"])
@login_required
@require_admin
@full_access_required
def get_certificate_audit_logs(org_id):
"""
Get certificate issuance audit logs for an organization.
Only accessible by organization admins.
Returns certificate serial IDs, user IDs, and issuance timestamps for compliance.
Args:
org_id: Organization ID
Returns:
200: List of certificate audit logs
401: Not authenticated
403: Not an admin
404: Organization not found
"""
from gatehouse_app.models.ssh_ca.certificate_audit_log import CertificateAuditLog
from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate
from gatehouse_app.models.user import User
org = OrganizationService.get_organization_by_id(org_id)
page = max(1, int(request.args.get("page", 1)))
per_page = min(int(request.args.get("per_page", 50)), 200)
action_filter = request.args.get("action", "signed") # Default to signed certificates
# Get all CAs for this organization
from gatehouse_app.models.ssh_ca import CA
org_cas = CA.query.filter_by(organization_id=org_id, deleted_at=None).all()
org_ca_ids = [ca.id for ca in org_cas]
if not org_ca_ids:
return api_response(
data={
"audit_logs": [],
"count": 0,
"page": page,
"per_page": per_page,
"pages": 0,
},
message="No certificate audit logs found",
)
# Query certificate audit logs for certificates issued by org's CAs
query = CertificateAuditLog.query.join(
SSHCertificate,
CertificateAuditLog.certificate_id == SSHCertificate.id
).filter(
SSHCertificate.ca_id.in_(org_ca_ids),
CertificateAuditLog.deleted_at.is_(None)
)
if action_filter:
query = query.filter(CertificateAuditLog.action == action_filter)
query = query.order_by(CertificateAuditLog.created_at.desc())
total = query.count()
logs = query.offset((page - 1) * per_page).limit(per_page).all()
# Build response data with certificate details
audit_data = []
for log in logs:
cert = log.certificate
user = log.user
audit_data.append({
"id": log.id,
"action": log.action,
"certificate_serial": cert.serial,
"key_id": cert.key_id,
"principals": cert.principals,
"user_id": user.id if user else cert.user_id,
"user_email": user.email if user else None,
"issued_at": cert.created_at.isoformat() if cert.created_at else None,
"valid_after": cert.valid_after.isoformat() if cert.valid_after else None,
"valid_before": cert.valid_before.isoformat() if cert.valid_before else None,
"ip_address": log.ip_address,
"user_agent": log.user_agent,
"message": log.message,
"success": log.success,
"created_at": log.created_at.isoformat() if log.created_at else None,
})
return api_response(
data={
"audit_logs": audit_data,
"count": total,
"page": page,
"per_page": per_page,
"pages": (total + per_page - 1) // per_page,
},
message="Certificate audit logs retrieved successfully",
)
@@ -95,6 +95,53 @@ def create_org_client(org_id):
)
@api_v1_bp.route("/organizations/<org_id>/clients/<client_id>", methods=["PATCH"])
@login_required
@require_admin
def update_org_client(org_id, client_id):
from gatehouse_app.models import OIDCClient
client = OIDCClient.query.filter_by(id=client_id, organization_id=org_id).first()
if not client:
return api_response(success=False, message="Client not found", status=404)
data = request.get_json() or {}
if "name" in data:
name = (data["name"] or "").strip()
if not name:
return api_response(success=False, message="Client name cannot be empty", status=400, error_type="VALIDATION_ERROR")
client.name = name
if "redirect_uris" in data:
redirect_uris_raw = data["redirect_uris"]
if isinstance(redirect_uris_raw, str):
uris = [u.strip() for u in redirect_uris_raw.replace(",", "\n").splitlines() if u.strip()]
else:
uris = [u.strip() for u in redirect_uris_raw if isinstance(u, str) and u.strip()]
if not uris:
return api_response(success=False, message="At least one redirect URI is required", status=400, error_type="VALIDATION_ERROR")
client.redirect_uris = uris
db.session.commit()
return api_response(
data={
"client": {
"id": client.id,
"name": client.name,
"client_id": client.client_id,
"redirect_uris": client.redirect_uris,
"scopes": client.scopes,
"grant_types": client.grant_types,
"is_active": client.is_active,
"created_at": client.created_at.isoformat() + "Z",
}
},
message="Client updated successfully",
)
@api_v1_bp.route("/organizations/<org_id>/clients/<client_id>", methods=["DELETE"])
@login_required
@require_admin
+3 -11
View File
@@ -37,7 +37,7 @@ def create_org_invite(org_id):
app_url = current_app.config.get("APP_URL", "http://localhost:8080")
invite_link = f"{app_url}/invite?token={invite.token}"
email_sent = NotificationService._send_email(
NotificationService._send_email_async(
to_address=email,
subject=f"You're invited to join {org.name} on Gatehouse",
body=(
@@ -47,16 +47,8 @@ def create_org_invite(org_id):
f"Gatehouse Security Team"
),
)
if not email_sent:
logging.getLogger(__name__).warning(
f"[INVITE LINK] Email not sent (EMAIL_ENABLED=False or SMTP down). "
f"Invite for {email}{invite_link}"
)
else:
logging.getLogger(__name__).info(
f"[INVITE] Email sent successfully to {email}"
)
logging.getLogger(__name__).info(f"[INVITE] Email queued for {email}")
email_sent = True # async — assume queued successfully
response_data = {
"invite": {
@@ -161,7 +161,7 @@ def send_mfa_reminder(org_id, user_id):
if compliance and policy and compliance.deadline_at:
NotificationService.send_mfa_deadline_reminder(user, compliance, policy)
else:
NotificationService._send_email(
NotificationService._send_email_async(
to_address=user.email,
subject="Reminder: Set up multi-factor authentication",
body=(
+7 -3
View File
@@ -130,14 +130,18 @@ def sign_certificate():
dept_policy = _get_merged_dept_cert_policy(user_id)
if dept_policy:
if is_org_admin:
if not dept_policy["allow_user_expiry"]:
expiry_hours = dept_policy["default_expiry_hours"]
elif is_org_admin:
if expiry_hours is not None:
expiry_hours = min(int(expiry_hours), dept_policy["max_expiry_hours"])
elif not dept_policy["allow_user_expiry"]:
expiry_hours = dept_policy["default_expiry_hours"]
else:
expiry_hours = dept_policy["default_expiry_hours"]
else:
if expiry_hours is not None:
expiry_hours = min(int(expiry_hours), dept_policy["max_expiry_hours"])
else:
expiry_hours = dept_policy["default_expiry_hours"]
policy_extensions = dept_policy["extensions"]
else:
policy_extensions = None
+137
View File
@@ -0,0 +1,137 @@
"""Sudoer check and sudo-related endpoints."""
from flask import request
from gatehouse_app.api.v1 import api_v1_bp
from gatehouse_app.utils.response import api_response
from gatehouse_app.models.organization import OrganizationApiKey
from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate
from gatehouse_app.models.organization import Department, DepartmentMembership
@api_v1_bp.route("/sudo/check", methods=["POST"])
def check_sudoer():
"""
Check if a user with a given certificate can sudo.
This endpoint validates an API key for an organization, retrieves the certificate
by serial ID, finds the user and their departments, and checks if any of their
departments have sudo capability.
Request body:
api_key: Organization API key (required)
certificate_serial: Certificate serial ID (required)
Returns:
200: Sudoer status returned
400: Invalid request body
401: Invalid API key
403: Certificate not found or user not found
404: Organization or certificate not found
"""
try:
data = request.get_json()
if not data:
return api_response(
success=False,
message="Request body is required",
status=400,
error_type="INVALID_REQUEST",
)
api_key = data.get("api_key")
certificate_serial = data.get("certificate_serial")
if not api_key or certificate_serial is None:
return api_response(
success=False,
message="api_key and certificate_serial are required",
status=400,
error_type="MISSING_REQUIRED_FIELDS",
)
# Find the certificate by serial
certificate = SSHCertificate.query.filter_by(
serial=certificate_serial,
deleted_at=None
).first()
if not certificate:
return api_response(
success=False,
message="Certificate not found",
status=404,
error_type="NOT_FOUND",
)
# Get the CA and organization
ca = certificate.ca
if not ca:
return api_response(
success=False,
message="Certificate CA not found",
status=404,
error_type="NOT_FOUND",
)
org_id = ca.organization_id
# Verify the API key for this organization
org_api_key = OrganizationApiKey.verify_key(org_id, api_key)
if not org_api_key:
return api_response(
success=False,
message="Invalid API key for organization",
status=401,
error_type="UNAUTHORIZED",
)
# Get the user from the certificate
user = certificate.user
if not user:
return api_response(
success=False,
message="Certificate user not found",
status=404,
error_type="NOT_FOUND",
)
# Get all departments the user belongs to
user_departments = DepartmentMembership.query.filter_by(
user_id=user.id,
deleted_at=None
).all()
# Check if any of the user's departments have sudo capability
can_sudo = False
sudoer_departments = []
for dept_membership in user_departments:
dept = dept_membership.department
if dept and dept.can_sudo and dept.deleted_at is None:
can_sudo = True
sudoer_departments.append({
"id": dept.id,
"name": dept.name,
})
return api_response(
data={
"can_sudo": can_sudo,
"user_id": user.id,
"user_email": user.email,
"certificate_serial": certificate.serial,
"sudoer_departments": sudoer_departments,
"all_departments_count": len(user_departments),
},
message="Sudoer status retrieved successfully",
status=200,
)
except Exception as e:
return api_response(
success=False,
message=f"An error occurred: {str(e)}",
status=500,
error_type="INTERNAL_ERROR",
)
+338 -17
View File
@@ -2,8 +2,10 @@
from flask import g, request
from marshmallow import Schema, fields, validate, ValidationError
from sqlalchemy.exc import IntegrityError
from gatehouse_app.api.v1 import api_v1_bp
from gatehouse_app.extensions import db
from gatehouse_app.utils.response import api_response
from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required
from gatehouse_app.services import portal_network_service
@@ -19,6 +21,8 @@ from gatehouse_app.models import (
ActivationSession,
)
from gatehouse_app.models.organization import Organization
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.utils.constants import OrganizationRole
from gatehouse_app.exceptions import (
ValidationError as AppValidationError,
ZeroTierAPIError,
@@ -39,6 +43,17 @@ def _org_check(org_id):
return org, None
def _is_org_admin(org_id: str, user_id: str) -> bool:
"""Return True if the user is an admin or owner of the org."""
return OrganizationMember.query.filter(
OrganizationMember.organization_id == org_id,
OrganizationMember.user_id == user_id,
OrganizationMember.role.in_([OrganizationRole.ADMIN, OrganizationRole.OWNER]),
OrganizationMember.deleted_at.is_(None),
).first() is not None
# ── Schemas ───────────────────────────────────────────────────────────────────
@@ -154,6 +169,63 @@ def create_network(org_id):
return api_response(success=False, message=str(e.message), status=400, error_type=e.error_type)
except ZeroTierAPIError as e:
return api_response(success=False, message=str(e), status=502, error_type=e.error_type)
except IntegrityError:
db.session.rollback()
return api_response(
success=False,
message="A portal network with this ZeroTier ID already exists in this organization.",
status=409,
error_type="DUPLICATE_NETWORK",
)
@api_v1_bp.route("/organizations/<org_id>/zerotier/available-networks", methods=["GET"])
@login_required
@require_admin
@full_access_required
def list_zerotier_available_networks(org_id):
"""List all ZeroTier networks from the org's ZT controller/account.
Cross-references against managed portal networks so the UI can show
which ones are already imported and which can be imported.
"""
org, err = _org_check(org_id)
if err:
return err
# Fetch all active portal networks for this org, keyed by ZT network ID
managed = {
pn.zerotier_network_id: pn
for pn in PortalNetwork.query.filter(
PortalNetwork.organization_id == org_id,
PortalNetwork.deleted_at.is_(None),
).all()
}
try:
zt_networks = zt.list_networks(organization_id=org_id)
except ZeroTierAPIError as e:
# Return an empty list with a flag so the UI can show a helpful message
# rather than an error page (e.g. "ZeroTier not configured yet").
return api_response(
data={"networks": [], "count": 0, "zt_error": str(e)},
message="ZeroTier unavailable — no networks returned",
)
result = []
for zt_net in zt_networks:
portal = managed.get(zt_net.id)
result.append({
**zt_net.to_dict(),
"already_managed": portal is not None,
"portal_network_id": portal.id if portal else None,
"portal_network_name": portal.name if portal else None,
})
return api_response(
data={"networks": result, "count": len(result)},
message="Available ZeroTier networks retrieved",
)
@api_v1_bp.route("/organizations/<org_id>/networks/<network_id>", methods=["GET"])
@@ -346,6 +418,9 @@ def update_device(org_id, device_id):
except ValidationError as e:
return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages)
if "nickname" in data:
data["device_nickname"] = data.pop("nickname")
try:
device = device_service.update_device(device_id, g.current_user.id, **data)
return api_response(data={"device": device.to_dict()}, message="Device updated successfully")
@@ -520,6 +595,25 @@ def assign_access(org_id):
return api_response(success=False, message=str(e.message), status=400, error_type=e.error_type)
@api_v1_bp.route("/organizations/<org_id>/admin/approvals", methods=["GET"])
@login_required
@require_admin
@full_access_required
def admin_list_all_approvals(org_id):
"""List ALL approval records across all users in the org (admin only)."""
org, err = _org_check(org_id)
if err:
return err
network_id = request.args.get("network_id")
state = request.args.get("state")
approvals = network_access_service.list_all_org_approvals(org_id, network_id=network_id, state=state)
return api_response(
data={"approvals": [a.to_dict() for a in approvals], "count": len(approvals)},
message="Approvals retrieved successfully",
)
# ── Memberships ───────────────────────────────────────────────────────────────
@@ -548,7 +642,7 @@ def list_memberships(org_id):
@login_required
@full_access_required
def activate_membership(org_id, membership_id):
"""Activate an approved device membership."""
"""Activate an approved device membership. Admins can activate any membership; regular members can only activate their own."""
org, err = _org_check(org_id)
if err:
return err
@@ -559,11 +653,14 @@ def activate_membership(org_id, membership_id):
except ValidationError as e:
return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages)
is_admin = _is_org_admin(org_id, g.current_user.id)
try:
session = network_access_service.activate_device_membership(
membership_id=membership_id,
user_id=g.current_user.id,
lifetime_minutes=data.get("lifetime_minutes"),
admin_override=is_admin,
)
membership = DeviceNetworkMembership.query.get(membership_id)
return api_response(data={"session": session.to_dict(), "membership": membership.to_dict()}, message="Membership activated successfully")
@@ -577,11 +674,21 @@ def activate_membership(org_id, membership_id):
@login_required
@full_access_required
def deactivate_membership(org_id, membership_id):
"""Deactivate an active device membership."""
"""Deactivate an active device membership. Admins can deactivate any; regular members can only deactivate their own."""
org, err = _org_check(org_id)
if err:
return err
# Verify ownership for non-admins
if not _is_org_admin(org_id, g.current_user.id):
membership_check = DeviceNetworkMembership.query.filter(
DeviceNetworkMembership.id == membership_id,
DeviceNetworkMembership.user_id == g.current_user.id,
DeviceNetworkMembership.deleted_at.is_(None),
).first()
if not membership_check:
return api_response(success=False, message="Membership not found", status=404, error_type="NOT_FOUND")
try:
membership = network_access_service.deactivate_membership(
membership_id=membership_id,
@@ -597,7 +704,7 @@ def deactivate_membership(org_id, membership_id):
@login_required
@full_access_required
def activate_all_memberships(org_id):
"""Bulk-activate all approved inactive memberships."""
"""Bulk-activate all of the caller's approved inactive memberships in this org."""
org, err = _org_check(org_id)
if err:
return err
@@ -744,6 +851,7 @@ def trigger_kill_switch(org_id):
event = network_access_service.kill_switch(
target_user_id=data["target_user_id"],
triggered_by_user_id=g.current_user.id,
organization_id=org_id,
scope=data.get("scope", "organization"),
reason=data.get("reason"),
network_ids=data.get("network_ids"),
@@ -794,12 +902,20 @@ def admin_delete_membership(org_id, membership_id):
@api_v1_bp.route("/admin/zerotier/status", methods=["GET"])
@login_required
@require_admin
@full_access_required
def zerotier_status():
"""Check ZeroTier controller connectivity and status (admin only)."""
"""Check ZeroTier controller connectivity and status.
Requires ?org_id=<uuid> credentials are looked up from that org.
Caller must be an admin/owner of that specific org.
"""
org_id = request.args.get("org_id")
if not org_id:
return api_response(success=False, message="org_id query parameter is required", status=400, error_type="VALIDATION_ERROR")
if not _is_org_admin(org_id, g.current_user.id):
return api_response(success=False, message="Admin or owner role required for this organization", status=403, error_type="AUTHORIZATION_ERROR")
try:
status = zt.get_status()
status = zt.get_status(organization_id=org_id)
return api_response(data={"status": status}, message="ZeroTier controller is reachable")
except ZeroTierAPIError as e:
return api_response(success=False, message=str(e), status=502, error_type=e.error_type)
@@ -807,12 +923,20 @@ def zerotier_status():
@api_v1_bp.route("/admin/zerotier/networks", methods=["GET"])
@login_required
@require_admin
@full_access_required
def zerotier_list_networks():
"""List networks from the ZeroTier controller (admin only)."""
"""List networks from the ZeroTier controller.
Requires ?org_id=<uuid> credentials are looked up from that org.
Caller must be an admin/owner of that specific org.
"""
org_id = request.args.get("org_id")
if not org_id:
return api_response(success=False, message="org_id query parameter is required", status=400, error_type="VALIDATION_ERROR")
if not _is_org_admin(org_id, g.current_user.id):
return api_response(success=False, message="Admin or owner role required for this organization", status=403, error_type="AUTHORIZATION_ERROR")
try:
networks = zt.list_networks()
networks = zt.list_networks(organization_id=org_id)
return api_response(
data={"networks": [n.to_dict() if hasattr(n, 'to_dict') else {"id": getattr(n, "id", str(n))} for n in networks], "count": len(networks)},
message="Networks retrieved successfully",
@@ -823,12 +947,20 @@ def zerotier_list_networks():
@api_v1_bp.route("/admin/zerotier/networks/<network_id>", methods=["GET"])
@login_required
@require_admin
@full_access_required
def zerotier_get_network(network_id):
"""Get a ZeroTier network from the controller (admin only)."""
"""Get a ZeroTier network from the controller.
Requires ?org_id=<uuid> credentials are looked up from that org.
Caller must be an admin/owner of that specific org.
"""
org_id = request.args.get("org_id")
if not org_id:
return api_response(success=False, message="org_id query parameter is required", status=400, error_type="VALIDATION_ERROR")
if not _is_org_admin(org_id, g.current_user.id):
return api_response(success=False, message="Admin or owner role required for this organization", status=403, error_type="AUTHORIZATION_ERROR")
try:
network = zt.get_network(network_id)
network = zt.get_network(network_id, organization_id=org_id)
return api_response(data={"network": network.to_dict()}, message="Network retrieved successfully")
except ZeroTierAPIError as e:
return api_response(success=False, message=str(e), status=502, error_type=e.error_type)
@@ -836,12 +968,20 @@ def zerotier_get_network(network_id):
@api_v1_bp.route("/admin/zerotier/networks/<network_id>/members", methods=["GET"])
@login_required
@require_admin
@full_access_required
def zerotier_list_members(network_id):
"""List members on a ZeroTier network from the controller (admin only)."""
"""List members on a ZeroTier network from the controller.
Requires ?org_id=<uuid> credentials are looked up from that org.
Caller must be an admin/owner of that specific org.
"""
org_id = request.args.get("org_id")
if not org_id:
return api_response(success=False, message="org_id query parameter is required", status=400, error_type="VALIDATION_ERROR")
if not _is_org_admin(org_id, g.current_user.id):
return api_response(success=False, message="Admin or owner role required for this organization", status=403, error_type="AUTHORIZATION_ERROR")
try:
members = zt.list_members(network_id)
members = zt.list_members(network_id, organization_id=org_id)
return api_response(
data={"members": [m.to_dict() for m in members], "count": len(members)},
message="Members retrieved successfully",
@@ -852,9 +992,190 @@ def zerotier_list_members(network_id):
@api_v1_bp.route("/admin/zerotier/reconcile", methods=["POST"])
@login_required
@require_admin
@full_access_required
def trigger_reconciliation():
"""Trigger full reconciliation across all networks (admin only)."""
"""Trigger full reconciliation across all networks (requires org admin in at least one org)."""
from gatehouse_app.models.organization.organization_member import OrganizationMember
is_any_admin = OrganizationMember.query.filter(
OrganizationMember.user_id == g.current_user.id,
OrganizationMember.role.in_([OrganizationRole.ADMIN, OrganizationRole.OWNER]),
OrganizationMember.deleted_at.is_(None),
).first() is not None
if not is_any_admin:
return api_response(success=False, message="Admin or owner role required", status=403, error_type="AUTHORIZATION_ERROR")
result = zerotier_reconciliation_service.reconcile_all()
return api_response(data=result, message="Reconciliation complete")
# ── Per-org ZeroTier configuration ───────────────────────────────────────────
class ZeroTierConfigSchema(Schema):
zt_api_token = fields.Str(required=True, validate=validate.Length(min=1, max=512))
zt_api_url = fields.Str(required=True, validate=validate.Length(min=1, max=512))
zt_api_mode = fields.Str(
required=True,
validate=validate.OneOf(["central", "controller"]),
)
@api_v1_bp.route("/organizations/<org_id>/zerotier-config", methods=["GET"])
@login_required
@require_admin
@full_access_required
def get_zerotier_config(org_id):
"""Return the current ZeroTier configuration for an organization (admin only).
The token is masked only its presence is indicated, not the value.
"""
org, err = _org_check(org_id)
if err:
return err
return api_response(
data={
"zerotier_config": {
"zt_api_token_set": bool(org.zt_api_token),
"zt_api_url": org.zt_api_url,
"zt_api_mode": org.zt_api_mode,
}
},
message="ZeroTier configuration retrieved successfully",
)
@api_v1_bp.route("/organizations/<org_id>/zerotier-config", methods=["PUT"])
@login_required
@require_admin
@full_access_required
def set_zerotier_config(org_id):
"""Set (or replace) the ZeroTier credentials for an organization (admin only).
All three fields are required there are no server-level defaults.
Body:
zt_api_token (required) API token for ZeroTier Central or authtoken.secret
zt_api_url (required) full base URL, e.g. http://host:9993 or
https://api.zerotier.com/api/v1
zt_api_mode (required) "central" | "controller"
"""
org, err = _org_check(org_id)
if err:
return err
try:
schema = ZeroTierConfigSchema()
data = schema.load(request.json or {})
except ValidationError as e:
return api_response(
success=False, message="Validation failed",
status=400, error_type="VALIDATION_ERROR", error_details=e.messages,
)
# Test connectivity BEFORE saving — reject bad credentials early
connectivity_ok = False
connectivity_error = None
# Temporarily set the credentials so _get_client() can build a client
old_token, old_url, old_mode = org.zt_api_token, org.zt_api_url, org.zt_api_mode
org.zt_api_token = data["zt_api_token"]
org.zt_api_url = data["zt_api_url"]
org.zt_api_mode = data["zt_api_mode"]
db.session.flush() # make visible to _get_client query without committing
try:
zt.get_status(organization_id=org_id)
connectivity_ok = True
except ZeroTierAPIError as exc:
connectivity_error = str(exc)
except Exception as exc:
connectivity_error = str(exc)
if not connectivity_ok:
# Roll back — don't persist bad credentials
org.zt_api_token = old_token
org.zt_api_url = old_url
org.zt_api_mode = old_mode
db.session.commit()
return api_response(
success=False,
message="Controller Connectivity test failed",
status=400,
error_type="ZEROTIER_CONNECTIVITY_FAILED",
error_details={
"connectivity_test": {
"ok": False,
"error": connectivity_error,
},
},
)
# Connectivity verified — commit the new credentials
org.save()
from gatehouse_app.services.audit_service import AuditService
AuditService.log_action(
action="org.zerotier_config.updated",
user_id=g.current_user.id,
organization_id=org_id,
resource_type="organization",
resource_id=org_id,
metadata={
"zt_api_url": org.zt_api_url,
"zt_api_mode": org.zt_api_mode,
"connectivity_ok": connectivity_ok,
},
description="Organization ZeroTier config updated",
success=True,
)
return api_response(
data={
"zerotier_config": {
"zt_api_token_set": True,
"zt_api_url": org.zt_api_url,
"zt_api_mode": org.zt_api_mode,
},
"connectivity_test": {
"ok": True,
"error": None,
},
},
message="ZeroTier configuration saved successfully",
)
@api_v1_bp.route("/organizations/<org_id>/zerotier-config", methods=["DELETE"])
@login_required
@require_admin
@full_access_required
def delete_zerotier_config(org_id):
"""Remove the org-level ZeroTier credentials (admin only).
After removal, all ZeroTier operations for this organization will fail
until new credentials
are configured via the ZeroTier Config page.
"""
org, err = _org_check(org_id)
if err:
return err
org.zt_api_token = None
org.zt_api_url = None
org.zt_api_mode = None
org.save()
from gatehouse_app.services.audit_service import AuditService
AuditService.log_action(
action="org.zerotier_config.deleted",
user_id=g.current_user.id,
organization_id=org_id,
resource_type="organization",
resource_id=org_id,
metadata={},
description="Organization ZeroTier config removed — ZeroTier operations disabled until reconfigured",
success=True,
)
return api_response(message="ZeroTier configuration removed. Configure new credentials to re-enable ZeroTier features.")
+4 -1
View File
@@ -8,19 +8,22 @@ class BaseAPIException(Exception):
error_type = "INTERNAL_ERROR"
message = "An unexpected error occurred"
def __init__(self, message=None, error_details=None):
def __init__(self, message=None, error_details=None, status_code=None):
"""
Initialize exception.
Args:
message: Custom error message
error_details: Additional error details dictionary
status_code: Override the class-level HTTP status code
"""
super().__init__(self.message)
if message:
self.message = message
super().__init__(message) # update args so str(e) works
self.error_details = error_details or {}
if status_code is not None:
self.status_code = status_code
def to_dict(self):
"""Convert exception to dictionary for API response."""
@@ -12,6 +12,7 @@ from gatehouse_app.models.organization.department_cert_policy import (
)
from gatehouse_app.models.organization.principal import Principal, PrincipalMembership
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
from gatehouse_app.models.organization.organization_api_key import OrganizationApiKey
__all__ = [
"Organization",
@@ -24,4 +25,5 @@ __all__ = [
"Principal",
"PrincipalMembership",
"OrgInviteToken",
"OrganizationApiKey",
]
@@ -27,6 +27,7 @@ class Department(BaseModel):
)
name = db.Column(db.String(255), nullable=False, index=True)
description = db.Column(db.Text, nullable=True)
can_sudo = db.Column(db.Boolean, default=False, nullable=False)
# Relationships
organization = db.relationship("Organization", back_populates="departments")
@@ -4,12 +4,13 @@ from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
# Standard SSH certificate extensions
# Standard SSH certificate extensions — must be in strict lexical order
# (OpenSSH RFC 4251 §5 / golang.org/x/crypto/ssh requires lexical ordering)
STANDARD_EXTENSIONS = [
"permit-X11-forwarding",
"permit-agent-forwarding",
"permit-pty",
"permit-port-forwarding",
"permit-pty",
"permit-user-rc",
]
@@ -17,6 +17,10 @@ class Organization(BaseModel):
# Settings (stored as JSON)
settings = db.Column(db.JSON, nullable=True, default=dict)
zt_api_token = db.Column(db.String(512), nullable=True)
zt_api_url = db.Column(db.String(512), nullable=True)
zt_api_mode = db.Column(db.String(32), nullable=True) # "central" | "controller"
# Relationships
members = db.relationship(
"OrganizationMember", back_populates="organization", cascade="all, delete-orphan"
@@ -43,6 +47,9 @@ class Organization(BaseModel):
cas = db.relationship(
"CA", back_populates="organization", cascade="all, delete-orphan"
)
api_keys = db.relationship(
"OrganizationApiKey", back_populates="organization", cascade="all, delete-orphan"
)
def __repr__(self):
"""String representation of Organization."""
@@ -0,0 +1,158 @@
"""Organization API Key model — API keys for organizations for external integrations."""
import secrets
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class OrganizationApiKey(BaseModel):
"""API Key model representing an API key for an organization.
API keys are used to authenticate external integrations or services
that need programmatic access to the organization's resources.
Each key is tied to an organization and can be revoked/deleted as needed.
"""
__tablename__ = "organization_api_keys"
organization_id = db.Column(
db.String(36),
db.ForeignKey("organizations.id"),
nullable=False,
index=True,
)
# Human-readable name for the API key
name = db.Column(db.String(255), nullable=False)
# Hashed key value (never store plain text)
key_hash = db.Column(db.String(255), nullable=False, unique=True, index=True)
# Last used timestamp for tracking activity
last_used_at = db.Column(db.DateTime, nullable=True)
# Revocation status
is_revoked = db.Column(db.Boolean, default=False, nullable=False, index=True)
revoked_at = db.Column(db.DateTime, nullable=True)
revoke_reason = db.Column(db.String(255), nullable=True)
# Description/purpose of the key
description = db.Column(db.Text, nullable=True)
# Relationships
organization = db.relationship("Organization", back_populates="api_keys")
__table_args__ = (
db.Index("idx_org_api_key_org_active", "organization_id", "is_revoked"),
db.Index("idx_api_key_last_used", "last_used_at"),
)
def __repr__(self):
"""String representation of OrganizationApiKey."""
return f"<OrganizationApiKey name={self.name} org_id={self.organization_id}>"
@staticmethod
def generate_key() -> str:
"""Generate a random API key.
Returns:
A random 32-byte hex string suitable for use as an API key
"""
return secrets.token_hex(32)
@classmethod
def create_key(
cls,
organization_id: str,
name: str,
description: str = None,
) -> tuple:
"""Create and store a new API key for an organization.
Args:
organization_id: ID of the organization
name: Human-readable name for the key
description: Optional description/purpose of the key
Returns:
Tuple of (OrganizationApiKey instance, plain_text_key_string)
The plain text key is only returned on creation and should be
stored securely by the user. It cannot be retrieved later.
"""
# Generate a plain text key
plain_key = cls.generate_key()
# Hash it using the key_hash method
key_hash = cls.hash_key(plain_key)
# Create the database record
api_key = cls(
organization_id=organization_id,
name=name,
key_hash=key_hash,
description=description,
)
api_key.save()
return api_key, plain_key
@staticmethod
def hash_key(plain_key: str) -> str:
"""Hash an API key for storage.
Args:
plain_key: The plain text API key
Returns:
Hashed version of the key
"""
import hashlib
return hashlib.sha256(plain_key.encode()).hexdigest()
@classmethod
def verify_key(cls, organization_id: str, plain_key: str) -> "OrganizationApiKey":
"""Verify an API key for an organization.
Args:
organization_id: ID of the organization
plain_key: The plain text API key to verify
Returns:
OrganizationApiKey instance if valid and active, None otherwise
"""
key_hash = cls.hash_key(plain_key)
api_key = cls.query.filter_by(
organization_id=organization_id,
key_hash=key_hash,
is_revoked=False,
deleted_at=None,
).first()
if api_key:
# Update last used timestamp
api_key.last_used_at = datetime.now(timezone.utc)
api_key.save()
return api_key
def revoke(self, reason: str = None) -> None:
"""Revoke this API key.
Args:
reason: Optional reason for revocation
"""
self.is_revoked = True
self.revoked_at = datetime.now(timezone.utc)
self.revoke_reason = reason
self.save()
def to_dict(self, exclude=None):
"""Convert API key to dictionary.
The key_hash is excluded by default for security.
"""
exclude = exclude or []
if "key_hash" not in exclude:
exclude.append("key_hash")
return super().to_dict(exclude=exclude)
+8 -1
View File
@@ -1,10 +1,15 @@
"""Certificate Authority (CA) model."""
import time
from enum import Enum
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
def _serial_start() -> int:
return int(time.time() * 1000)
class KeyType(str, Enum):
"""SSH CA key types."""
@@ -91,7 +96,9 @@ class CA(BaseModel):
# Monotonically-increasing serial counter. Every cert this CA issues
# gets the next value so serials are unique, ordered, and auditable.
# Protected by a row-level SELECT … FOR UPDATE in get_next_serial().
next_serial_number = db.Column(db.BigInteger, default=1, nullable=False)
# Initialised to the current Unix timestamp in milliseconds so serials
# are globally unique across CAs from the moment of creation.
next_serial_number = db.Column(db.BigInteger, default=_serial_start, nullable=False)
# Relationships
organization = db.relationship("Organization", back_populates="cas")
@@ -50,7 +50,7 @@ class SSHCertificate(BaseModel):
certificate = db.Column(db.Text, nullable=False)
# Certificate metadata
serial = db.Column(db.String(255), nullable=False, unique=True, index=True)
serial = db.Column(db.String(255), nullable=False)
key_id = db.Column(db.String(255), nullable=False) # Usually user email
cert_type = db.Column(
db.Enum(CertType, values_callable=lambda x: [e.value for e in x]),
@@ -103,6 +103,8 @@ class SSHCertificate(BaseModel):
)
__table_args__ = (
db.UniqueConstraint("ca_id", "serial", name="uq_ssh_certificates_ca_serial"),
db.Index("ix_ssh_certificates_serial", "serial"),
db.Index("idx_cert_user_status", "user_id", "status"),
db.Index("idx_cert_validity", "valid_after", "valid_before"),
db.Index("idx_cert_revoked", "revoked", "revoked_at"),
@@ -45,16 +45,16 @@ class ActivationSession(BaseModel):
index=True,
)
authenticated_at = db.Column(
db.DateTime(timezone=True),
db.DateTime,
nullable=False,
)
expires_at = db.Column(
db.DateTime(timezone=True),
db.DateTime,
nullable=False,
)
ended_at = db.Column(db.DateTime(timezone=True), nullable=True)
ended_at = db.Column(db.DateTime, nullable=True)
end_reason = db.Column(
db.Enum(ActivationEndReason, name="activation_end_reason"),
db.Enum(ActivationEndReason, name="activation_end_reason", values_callable=lambda x: [e.value for e in x]),
nullable=True,
)
created_by = db.Column(
+1 -1
View File
@@ -47,7 +47,7 @@ class Device(BaseModel):
asset_tag = db.Column(db.String(255), nullable=True)
serial_number = db.Column(db.String(255), nullable=True)
status = db.Column(
db.Enum(DeviceStatus, name="device_status"),
db.Enum(DeviceStatus, name="device_status", values_callable=lambda x: [e.value for e in x]),
default=DeviceStatus.ACTIVE,
nullable=False,
)
@@ -58,7 +58,7 @@ class DeviceNetworkMembership(BaseModel):
index=True,
)
state = db.Column(
db.Enum(MembershipState, name="membership_state"),
db.Enum(MembershipState, name="membership_state", values_callable=lambda x: [e.value for e in x]),
default=MembershipState.PENDING_DEVICE_REGISTRATION,
nullable=False,
index=True,
@@ -35,7 +35,7 @@ class KillSwitchEvent(BaseModel):
index=True,
)
scope = db.Column(
db.Enum(KillSwitchScope, name="kill_switch_scope"),
db.Enum(KillSwitchScope, name="kill_switch_scope", values_callable=lambda x: [e.value for e in x]),
default=KillSwitchScope.ORGANIZATION,
nullable=False,
)
@@ -45,12 +45,12 @@ class PortalNetwork(BaseModel):
index=True,
)
environment = db.Column(
db.Enum(NetworkEnvironment, name="network_environment"),
db.Enum(NetworkEnvironment, name="network_environment", values_callable=lambda x: [e.value for e in x]),
default=NetworkEnvironment.DEVELOPMENT,
nullable=False,
)
request_mode = db.Column(
db.Enum(NetworkRequestMode, name="network_request_mode"),
db.Enum(NetworkRequestMode, name="network_request_mode", values_callable=lambda x: [e.value for e in x]),
default=NetworkRequestMode.APPROVAL_REQUIRED,
nullable=False,
)
@@ -48,12 +48,12 @@ class UserNetworkApproval(BaseModel):
nullable=True,
)
grant_type = db.Column(
db.Enum(ApprovalGrantType, name="approval_grant_type"),
db.Enum(ApprovalGrantType, name="approval_grant_type", values_callable=lambda x: [e.value for e in x]),
default=ApprovalGrantType.REQUESTED,
nullable=False,
)
state = db.Column(
db.Enum(ApprovalState, name="approval_state"),
db.Enum(ApprovalState, name="approval_state", values_callable=lambda x: [e.value for e in x]),
default=ApprovalState.PENDING,
nullable=False,
index=True,
@@ -51,8 +51,8 @@ class ZeroTierMembership(BaseModel):
)
member_seen = db.Column(db.Boolean, default=False, nullable=False)
authorized = db.Column(db.Boolean, default=False, nullable=False)
join_seen_at = db.Column(db.DateTime(timezone=True), nullable=True)
last_synced_at = db.Column(db.DateTime(timezone=True), nullable=True)
join_seen_at = db.Column(db.DateTime, nullable=True)
last_synced_at = db.Column(db.DateTime, nullable=True)
raw_controller_payload = db.Column(db.JSON, nullable=True)
# Relationships
@@ -33,6 +33,7 @@ from gatehouse_app.exceptions import (
DeviceNotFoundError,
ApprovalAlreadyExistsError,
ValidationError,
ZeroTierAPIError,
)
logger = logging.getLogger(__name__)
@@ -74,9 +75,30 @@ def request_access(
raise ApprovalAlreadyExistsError(
"An access request or approval already exists for this user and network."
)
existing.state = ApprovalState.PENDING
is_open = network.request_mode.value == "open"
existing.state = ApprovalState.APPROVED if is_open else ApprovalState.PENDING
existing.justification = justification
existing.save()
existing_membership = DeviceNetworkMembership.query.filter(
DeviceNetworkMembership.user_network_approval_id == existing.id,
DeviceNetworkMembership.device_id == device_id,
DeviceNetworkMembership.deleted_at.is_(None),
).first()
if not existing_membership:
membership_state = MembershipState.APPROVED_INACTIVE if is_open else MembershipState.PENDING_DEVICE_REGISTRATION
membership = DeviceNetworkMembership(
organization_id=organization_id,
user_id=user_id,
device_id=device_id,
portal_network_id=portal_network_id,
user_network_approval_id=existing.id,
state=membership_state,
approved_for_activation=is_open,
)
membership.save()
_ensure_zerotier_member(device.node_id, portal_network_id, authorized=False)
return existing
is_open = network.request_mode.value == "open"
@@ -329,6 +351,23 @@ def list_user_approvals(user_id: str, organization_id: str) -> list[UserNetworkA
).all()
def list_all_org_approvals(
organization_id: str,
network_id: str | None = None,
state: str | None = None,
) -> list[UserNetworkApproval]:
"""List all approval records across all users in an org (admin use)."""
q = UserNetworkApproval.query.filter(
UserNetworkApproval.organization_id == organization_id,
UserNetworkApproval.deleted_at.is_(None),
)
if network_id:
q = q.filter(UserNetworkApproval.portal_network_id == network_id)
if state:
q = q.filter(UserNetworkApproval.state == state)
return q.order_by(UserNetworkApproval.created_at.desc()).all()
# ── Membership materialisation ───────────────────────────────────────────────
@@ -428,11 +467,12 @@ def activate_device_membership(
membership_id: str,
user_id: str,
lifetime_minutes: int | None = None,
admin_override: bool = False,
) -> ActivationSession:
"""Activate an approved device on a network. Creates an activation session and authorizes in ZT."""
membership = _get_membership(membership_id)
if membership.user_id != user_id:
if not admin_override and membership.user_id != user_id:
raise MembershipNotFoundError("Membership not found.")
# Check approval is still active
@@ -536,7 +576,8 @@ def deactivate_membership(
# Deauthorize in ZeroTier
device = Device.query.get(membership.device_id)
network = PortalNetwork.query.get(membership.portal_network_id)
_deauthorize_in_zerotier(device.node_id, network.zerotier_network_id)
_deauthorize_in_zerotier(device.node_id, network.zerotier_network_id,
organization_id=membership.organization_id)
membership.state = MembershipState.APPROVED_INACTIVE
membership.currently_authorized = False
@@ -567,6 +608,7 @@ def kill_switch(
target_user_id: str,
triggered_by_user_id: str,
scope: str,
organization_id: str | None = None,
reason: str | None = None,
network_ids: list[str] | None = None,
) -> KillSwitchEvent:
@@ -579,14 +621,18 @@ def kill_switch(
DeviceNetworkMembership.deleted_at.is_(None),
)
org_id = None
org_id = organization_id # Use caller-supplied org_id as the primary source
if scope_enum == KillSwitchScope.ORGANIZATION:
# Use the first membership's org
first = q.first()
org_id = first.organization_id if first else None
if not org_id:
# Fall back to deriving from first active membership
first = q.first()
org_id = first.organization_id if first else None
else:
# Scope query to the specified org
q = q.filter(DeviceNetworkMembership.organization_id == org_id)
elif scope_enum == KillSwitchScope.SELECTED_NETWORKS and network_ids:
q = q.filter(DeviceNetworkMembership.portal_network_id.in_(network_ids))
if network_ids:
if not org_id:
first_network = PortalNetwork.query.filter(
PortalNetwork.id.in_(network_ids),
PortalNetwork.deleted_at.is_(None),
@@ -594,7 +640,7 @@ def kill_switch(
org_id = first_network.organization_id if first_network else None
if not org_id:
org_id = network_ids[0] if network_ids else None
raise ValidationError("Cannot determine organization for kill switch event.")
# Create kill switch event
event = KillSwitchEvent(
@@ -608,14 +654,16 @@ def kill_switch(
event.save()
# Suspend all approvals
ApprovalState._value2member_map_ # just reference
approvals = UserNetworkApproval.query.filter(
UserNetworkApproval.user_id == target_user_id,
UserNetworkApproval.state == ApprovalState.APPROVED,
UserNetworkApproval.deleted_at.is_(None),
).all()
for approval in approvals:
if scope_enum == KillSwitchScope.SELECTED_NETWORKS and network_ids:
if scope_enum == KillSwitchScope.ORGANIZATION and org_id:
if approval.organization_id != org_id:
continue
elif scope_enum == KillSwitchScope.SELECTED_NETWORKS and network_ids:
if approval.portal_network_id not in network_ids:
continue
approval.state = ApprovalState.SUSPENDED
@@ -691,7 +739,8 @@ def _ensure_zerotier_member(
return
try:
zt.add_member(network.zerotier_network_id, node_id, authorized=authorized)
zt.add_member(network.zerotier_network_id, node_id, authorized=authorized,
organization_id=network.organization_id)
except Exception as exc:
logger.warning(
f"[_ensure_zerotier_member] Could not add member {node_id} "
@@ -705,7 +754,8 @@ def _authorize_in_zerotier(
membership: DeviceNetworkMembership,
) -> None:
try:
zt.authorize_member(zerotier_network_id, node_id)
zt.authorize_member(zerotier_network_id, node_id,
organization_id=membership.organization_id)
# Update zerotier_membership cache
zt_membership = ZeroTierMembership.query.filter(
@@ -740,6 +790,11 @@ def _authorize_in_zerotier(
success=True,
)
except ZeroTierAPIError as exc:
logger.warning(
f"[_authorize_in_zerotier] ZeroTier unavailable — skipping authorization "
f"for {node_id} on {zerotier_network_id}: {exc}"
)
except Exception as exc:
logger.error(
f"[_authorize_in_zerotier] Failed to authorize {node_id} "
@@ -748,9 +803,11 @@ def _authorize_in_zerotier(
raise
def _deauthorize_in_zerotier(node_id: str, zerotier_network_id: str) -> None:
def _deauthorize_in_zerotier(node_id: str, zerotier_network_id: str,
organization_id: str | None = None) -> None:
try:
zt.deauthorize_member(zerotier_network_id, node_id)
zt.deauthorize_member(zerotier_network_id, node_id,
organization_id=organization_id)
zt_membership = ZeroTierMembership.query.filter(
ZeroTierMembership.zerotier_network_id == zerotier_network_id,
@@ -940,7 +997,8 @@ def revoke_membership_soft(
if device and network:
try:
zt.deauthorize_member(network.zerotier_network_id, device.node_id)
zt.deauthorize_member(network.zerotier_network_id, device.node_id,
organization_id=membership.organization_id)
except Exception as exc:
logger.warning(f"[revoke_membership_soft] ZT deauthorize failed for {device.node_id}: {exc}")
@@ -984,7 +1042,8 @@ def hard_delete_membership(membership_id: str) -> None:
if device and network:
try:
zt.delete_network_member(network.zerotier_network_id, device.node_id)
zt.delete_network_member(network.zerotier_network_id, device.node_id,
organization_id=membership.organization_id)
logger.info(f"[hard_delete_membership] Deleted {device.node_id} from ZT network {network.zerotier_network_id}")
except Exception as exc:
logger.warning(f"[hard_delete_membership] ZT delete failed for {device.node_id}: {exc}")
+78 -97
View File
@@ -17,6 +17,7 @@ from datetime import datetime, timezone
from typing import Optional, Dict, Any
import logging
import json
import threading
from gatehouse_app.extensions import db
from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance
@@ -78,29 +79,22 @@ class NotificationService:
)
# Send the notification
success = NotificationService._send_email(
NotificationService._send_email_async(
to_address=user.email,
subject=subject,
body=body,
)
if success:
logger.info(
f"Sent MFA deadline reminder to {user.email} "
f"({days_until_deadline} days remaining)"
)
AuditService.log_action(
action=AuditAction.MFA_POLICY_USER_COMPLIANT,
user_id=user.id,
organization_id=compliance.organization_id,
description=f"MFA deadline reminder sent. Days remaining: {days_until_deadline}",
)
else:
logger.warning(
f"Failed to send MFA deadline reminder to {user.email}"
)
return success
logger.info(
f"Sent MFA deadline reminder to {user.email} "
f"({days_until_deadline} days remaining)"
)
AuditService.log_action(
action=AuditAction.MFA_POLICY_USER_COMPLIANT,
user_id=user.id,
organization_id=compliance.organization_id,
description=f"MFA deadline reminder sent. Days remaining: {days_until_deadline}",
)
return True
except Exception as e:
logger.exception(f"Error sending MFA deadline reminder to {user.email}: {e}")
@@ -136,27 +130,19 @@ class NotificationService:
)
# Send the notification
success = NotificationService._send_email(
NotificationService._send_email_async(
to_address=user.email,
subject=subject,
body=body,
)
if success:
logger.info(f"Sent MFA suspension notification to {user.email}")
# Audit log
AuditService.log_action(
action=AuditAction.MFA_POLICY_USER_SUSPENDED,
user_id=user.id,
organization_id=compliance.organization_id,
description="MFA compliance suspension notification sent",
)
else:
logger.warning(
f"Failed to send MFA suspension notification to {user.email}"
)
return success
logger.info(f"Sent MFA suspension notification to {user.email}")
AuditService.log_action(
action=AuditAction.MFA_POLICY_USER_SUSPENDED,
user_id=user.id,
organization_id=compliance.organization_id,
description="MFA compliance suspension notification sent",
)
return True
except Exception as e:
logger.exception(
@@ -285,89 +271,84 @@ Gatehouse Security Team
return body
@staticmethod
def _send_email(
def _send_email_async(
to_address: str,
subject: str,
body: str,
html_body: Optional[str] = None,
) -> bool:
"""Send an email via SMTP.
) -> None:
"""Send an email on a daemon thread so the calling request returns immediately.
Returns True if the email was sent successfully, False otherwise.
If EMAIL_ENABLED is False, logs the email body instead (simulation mode).
If EMAIL_ENABLED is False, logs instead of sending.
All SMTP exceptions are caught and logged this method never raises.
The Flask app context is pushed inside the thread so current_app works correctly.
"""
import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from flask import current_app
email_enabled = current_app.config.get(NotificationService.EMAIL_ENABLED_KEY, False)
app = current_app._get_current_object() # capture real app before leaving request context
if not email_enabled:
logger.info(
f"[EMAIL DISABLED] Would have sent to: {to_address} | Subject: {subject}\n"
f"Body: {body[:500]}"
)
return False
def _send():
with app.app_context():
email_enabled = app.config.get(NotificationService.EMAIL_ENABLED_KEY, False)
if not email_enabled:
logger.info(
f"[EMAIL DISABLED] Would have sent to: {to_address} | Subject: {subject}\n"
f"Body: {body[:500]}"
)
return
smtp_host = current_app.config.get(NotificationService.SMTP_HOST_KEY, "")
smtp_port_raw = current_app.config.get(NotificationService.SMTP_PORT_KEY, 587)
smtp_username = current_app.config.get(NotificationService.SMTP_USERNAME_KEY)
smtp_password = current_app.config.get(NotificationService.SMTP_PASSWORD_KEY)
from_address = current_app.config.get(
NotificationService.FROM_ADDRESS_KEY, ""
)
smtp_host = app.config.get(NotificationService.SMTP_HOST_KEY, "")
smtp_port_raw = app.config.get(NotificationService.SMTP_PORT_KEY, 587)
smtp_username = app.config.get(NotificationService.SMTP_USERNAME_KEY)
smtp_password = app.config.get(NotificationService.SMTP_PASSWORD_KEY)
from_address = app.config.get(NotificationService.FROM_ADDRESS_KEY, "")
# Guard: refuse to attempt a connection when critical config is missing.
# This surfaces a clear log message instead of a confusing socket error.
missing = [k for k, v in [
("SMTP_HOST", smtp_host),
("FROM_ADDRESS", from_address),
] if not v]
if missing:
logger.error(
f"[EMAIL] Cannot send — missing config: {', '.join(missing)}. "
f"Would have sent to: {to_address} | Subject: {subject}"
)
return False
missing = [k for k, v in [("SMTP_HOST", smtp_host), ("FROM_ADDRESS", from_address)] if not v]
if missing:
logger.error(
f"[EMAIL] Cannot send — missing config: {', '.join(missing)}. "
f"Would have sent to: {to_address} | Subject: {subject}"
)
return
try:
smtp_port = int(smtp_port_raw)
except (TypeError, ValueError):
logger.error(f"[EMAIL] Invalid SMTP_PORT value: {smtp_port_raw!r}")
return False
try:
smtp_port = int(smtp_port_raw)
except (TypeError, ValueError):
logger.error(f"[EMAIL] Invalid SMTP_PORT value: {smtp_port_raw!r}")
return
smtp_use_tls = current_app.config.get(
NotificationService.SMTP_USE_TLS_KEY,
smtp_port not in (25, 1025),
)
smtp_use_tls = app.config.get(
NotificationService.SMTP_USE_TLS_KEY,
smtp_port not in (25, 1025),
)
try:
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["From"] = from_address
msg["To"] = to_address
msg.attach(MIMEText(body, "plain"))
if html_body:
msg.attach(MIMEText(html_body, "html"))
try:
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["From"] = from_address
msg["To"] = to_address
msg.attach(MIMEText(body, "plain"))
if html_body:
msg.attach(MIMEText(html_body, "html"))
with smtplib.SMTP(smtp_host, smtp_port) as server:
server.ehlo()
if smtp_use_tls:
server.starttls()
server.ehlo()
if smtp_username and smtp_password:
server.login(smtp_username, smtp_password)
server.send_message(msg)
with smtplib.SMTP(smtp_host, smtp_port) as server:
server.ehlo()
if smtp_use_tls:
server.starttls()
server.ehlo()
if smtp_username and smtp_password:
server.login(smtp_username, smtp_password)
server.send_message(msg)
logger.info(f"[EMAIL] Sent to {to_address} | Subject: {subject}")
return True
logger.info(f"[EMAIL] Sent to {to_address} | Subject: {subject}")
except Exception as e:
logger.error(f"[EMAIL] Failed to send to {to_address}: {e}")
return False
except Exception as e:
logger.error(f"[EMAIL] Failed to send to {to_address}: {e}")
threading.Thread(target=_send, daemon=True).start()
@staticmethod
def get_notification_stats(user_id: str) -> Dict[str, Any]:
+112 -10
View File
@@ -9,7 +9,7 @@ from gatehouse_app.models.organization import Organization
from gatehouse_app.models.user import User
from gatehouse_app.services.audit_service import AuditService
from gatehouse_app.services import zerotier_api_service as zt
from gatehouse_app.utils.constants import NetworkRequestMode
from gatehouse_app.utils.constants import NetworkRequestMode, NetworkEnvironment
from gatehouse_app.exceptions import (
NetworkNotFoundError,
InvalidNetworkIdError,
@@ -57,23 +57,74 @@ def create_network(
default_activation_lifetime_minutes: Default session length
max_activation_lifetime_minutes: Cap on activation lifetime
"""
from gatehouse_app.utils.constants import NetworkEnvironment
zerotier_network_id = _validate_network_id(zerotier_network_id)
existing = PortalNetwork.query.filter(
existing_active = PortalNetwork.query.filter(
PortalNetwork.organization_id == organization_id,
PortalNetwork.zerotier_network_id == zerotier_network_id,
PortalNetwork.deleted_at.is_(None),
).first()
if existing:
if existing_active:
raise ValidationError(
f"A portal network already exists for ZT network {zerotier_network_id} "
f"in this organization."
)
env = NetworkEnvironment(environment) if environment else NetworkEnvironment.DEVELOPMENT
mode = NetworkRequestMode(request_mode)
# Normalize to lowercase so callers may pass "PRODUCTION" or "production" interchangeably
env_str = environment.lower() if environment else None
mode_str = request_mode.lower() if request_mode else "approval_required"
try:
env = NetworkEnvironment(env_str) if env_str else NetworkEnvironment.DEVELOPMENT
except ValueError:
valid = [e.value for e in NetworkEnvironment]
raise ValidationError(f"Invalid environment '{environment}'. Must be one of: {valid}")
try:
mode = NetworkRequestMode(mode_str)
except ValueError:
valid = [e.value for e in NetworkRequestMode]
raise ValidationError(f"Invalid request_mode '{request_mode}'. Must be one of: {valid}")
# If a soft-deleted record for the same (org, zt_network_id) pair exists, restore it
# rather than inserting a new row (which would violate the unique constraint).
deleted = PortalNetwork.query.filter(
PortalNetwork.organization_id == organization_id,
PortalNetwork.zerotier_network_id == zerotier_network_id,
PortalNetwork.deleted_at.isnot(None),
).first()
if deleted:
logger.info(
f"[PortalNetwork] Restoring soft-deleted portal network {deleted.id} "
f"for ZT network {zerotier_network_id}"
)
deleted.deleted_at = None
deleted.name = name
deleted.description = description
deleted.owner_user_id = owner_user_id
deleted.environment = env
deleted.request_mode = mode
deleted.default_activation_lifetime_minutes = default_activation_lifetime_minutes
deleted.max_activation_lifetime_minutes = max_activation_lifetime_minutes
deleted.is_active = True
deleted.save()
AuditService.log_action(
action="zt.network.restored",
user_id=owner_user_id,
organization_id=organization_id,
resource_type="portal_network",
resource_id=deleted.id,
metadata={
"zerotier_network_id": zerotier_network_id,
"name": name,
"environment": env.value,
"request_mode": mode.value,
},
description=f"Portal network '{name}' restored (ZT: {zerotier_network_id})",
success=True,
)
return deleted
network = PortalNetwork(
organization_id=organization_id,
@@ -90,7 +141,7 @@ def create_network(
# Try to verify the network exists in ZeroTier
try:
zt_network = zt.get_network(zerotier_network_id)
zt_network = zt.get_network(zerotier_network_id, organization_id=organization_id)
logger.info(
f"[PortalNetwork] Verified ZT network {zerotier_network_id} "
f"exists in ZeroTier: {zt_network.name}"
@@ -100,7 +151,7 @@ def create_network(
f"[PortalNetwork] ZT network {zerotier_network_id} not found "
"in ZeroTier — will be reconciled later."
)
except ZeroTierAPIError as exc:
except (ZeroTierAPIError, Exception) as exc:
logger.warning(
f"[PortalNetwork] Could not verify ZT network {zerotier_network_id}: {exc}"
)
@@ -175,6 +226,23 @@ def update_network(
if key not in allowed:
raise ValidationError(f"Cannot update field: {key}")
# Normalize environment / request_mode strings to lowercase enum values
if "environment" in kwargs and isinstance(kwargs["environment"], str):
env_str = kwargs["environment"].lower()
try:
kwargs["environment"] = NetworkEnvironment(env_str)
except ValueError:
valid = [e.value for e in NetworkEnvironment]
raise ValidationError(f"Invalid environment '{kwargs['environment']}'. Must be one of: {valid}")
if "request_mode" in kwargs and isinstance(kwargs["request_mode"], str):
mode_str = kwargs["request_mode"].lower()
try:
kwargs["request_mode"] = NetworkRequestMode(mode_str)
except ValueError:
valid = [e.value for e in NetworkRequestMode]
raise ValidationError(f"Invalid request_mode '{kwargs['request_mode']}'. Must be one of: {valid}")
network.update(**kwargs)
AuditService.log_action(
@@ -192,7 +260,11 @@ def update_network(
def delete_network(network_id: str, user_id: str) -> None:
"""Soft-delete a portal network and deactivate all memberships."""
"""Soft-delete a portal network and deactivate/clean up all related records."""
from datetime import datetime, timezone
from gatehouse_app.models import UserNetworkApproval
from gatehouse_app.extensions import db
network = get_network(network_id)
# Deauthorize all active memberships in ZeroTier
@@ -203,6 +275,36 @@ def delete_network(network_id: str, user_id: str) -> None:
network.delete(soft=True)
# Cascade soft-delete all active approvals and memberships for this network.
now = datetime.now(timezone.utc)
db.session.execute(
db.text(
"UPDATE user_network_approvals AS a "
"SET deleted_at = :now + (s.rn * interval '1 microsecond') "
"FROM ("
" SELECT id, row_number() OVER () AS rn "
" FROM user_network_approvals "
" WHERE portal_network_id = :network_id AND deleted_at IS NULL"
") s "
"WHERE a.id = s.id"
),
{"now": now, "network_id": network_id},
)
db.session.execute(
db.text(
"UPDATE device_network_memberships AS m "
"SET deleted_at = :now + (s.rn * interval '1 microsecond') "
"FROM ("
" SELECT id, row_number() OVER () AS rn "
" FROM device_network_memberships "
" WHERE portal_network_id = :network_id AND deleted_at IS NULL"
") s "
"WHERE m.id = s.id"
),
{"now": now, "network_id": network_id},
)
db.session.commit()
AuditService.log_action(
action="zt.network.deleted",
user_id=user_id,
@@ -288,6 +288,12 @@ class SSHCASigningService:
else:
extensions = [] # host certs: no extensions
# OpenSSH (RFC 4251 §5) and golang.org/x/crypto/ssh require
# certificate extensions to be in strict lexical (alphabetical) order.
# Sort unconditionally so any caller-supplied or policy-derived list
# is guaranteed to be compliant.
extensions = sorted(extensions)
certificate.fields.extensions = extensions
certificate.fields.critical_options = signing_request.critical_options or {}
+35 -1
View File
@@ -206,6 +206,40 @@ class SSHKeyService:
return challenge_text
@staticmethod
def _decode_signature(signature: str) -> bytes:
"""Decode a user-supplied signature into raw SSH signature bytes.
Accepts either:
1. The raw SSH armored signature (-----BEGIN SSH SIGNATURE-----)
2. A base64-encoded version of that armored signature
(produced by ``cat file.sig | base64 -w0``)
Returns the raw armored signature bytes suitable for writing to a
``.sig`` file that ``ssh-keygen -Y verify`` can read.
"""
stripped = signature.strip()
# If it already looks like a raw SSH signature armor, use it directly
if stripped.startswith("-----BEGIN SSH SIGNATURE-----"):
return stripped.encode("utf-8")
# Otherwise treat it as base64 — strip any embedded whitespace first
cleaned = stripped.replace("\n", "").replace("\r", "").replace(" ", "")
try:
decoded = base64.b64decode(cleaned)
except Exception as exc:
raise SSHKeyError(f"Could not decode signature: {exc}")
# Sanity-check: the decoded bytes should be a valid SSH signature
text = decoded.decode("utf-8", errors="replace")
if "-----BEGIN SSH SIGNATURE-----" not in text:
raise SSHKeyError(
"Invalid signature format. Please paste the output of: "
"cat /tmp/challenge.txt.sig | base64 -w0"
)
return decoded
def verify_ssh_key_ownership(
self,
key_id: str,
@@ -247,7 +281,7 @@ class SSHKeyService:
# allowed_signers format: "<identity> <keytype> <pubkey>"
# We use the key fingerprint as the identity.
sig_bytes = base64.b64decode(signature)
sig_bytes = self._decode_signature(signature)
challenge_text = key.verify_text + "\n"
with tempfile.TemporaryDirectory() as tmpdir:
+83 -29
View File
@@ -1,7 +1,11 @@
"""ZeroTier API service — thin Flask adapter around the ZeroTierClient SDK.
Reads configuration from app config and translates SDK exceptions to
Secuird typed exceptions.
ZeroTier is managed exclusively at the organization level. Each organization
configures its own ZeroTier credentials (token, URL, mode) via the web UI
(ZeroTier Config page stored in the organizations table).
Every call that interacts with ZeroTier must supply an organization_id so the correct org credentials
can be loaded from the database.
"""
import logging
@@ -19,97 +23,147 @@ from gatehouse_app.utils.zerotier_client import (
logger = logging.getLogger(__name__)
def _get_client(app=None) -> ZeroTierClient:
"""Build a ZeroTierClient from current app config."""
from flask import current_app
def _get_client(organization_id: Optional[str] = None, app=None) -> ZeroTierClient:
"""Build a ZeroTierClient using the organization's stored ZeroTier credentials.
app = app or current_app
Credentials are read exclusively from the organization record
(org.zt_api_token / org.zt_api_url / org.zt_api_mode).
Args:
organization_id: The org whose credentials should be used.
Required for any ZeroTier operation.
app: Flask app instance (defaults to current_app, only needed for
background tasks that run outside a request context).
Raises:
ZeroTierAPIError: If organization_id is missing, the org is not found,
or the org has incomplete ZeroTier credentials.
"""
if not organization_id:
raise ZeroTierAPIError(
"organization_id is required — ZeroTier credentials are managed "
"per-organization. Configure them via the ZeroTier Config page."
)
try:
from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.extensions import db
org = db.session.get(Organization, organization_id)
except Exception as exc:
logger.error(f"[ZT] Failed to load org {organization_id} from DB: {exc}")
raise ZeroTierAPIError(
f"Could not load organization {organization_id}: {exc}"
) from exc
if not org:
raise ZeroTierAPIError(f"Organization {organization_id} not found.")
token: Optional[str] = org.zt_api_token or None
if not token:
raise ZeroTierAPIError(
f"Organization '{org.name}' has no ZeroTier credentials configured. "
"Go to Settings → ZeroTier Config to add a token, mode, and controller URL."
)
mode_str = (org.zt_api_mode or "").strip().lower()
if mode_str not in ("central", "controller"):
raise ZeroTierAPIError(
f"Organization '{org.name}' has no ZeroTier mode set. "
"Go to Settings → ZeroTier Config and select 'Central' or 'Controller'."
)
url: str = (org.zt_api_url or "").strip()
if not url:
raise ZeroTierAPIError(
f"Organization '{org.name}' has no ZeroTier controller/API URL set. "
"Go to Settings → ZeroTier Config and enter the URL for your ZeroTier "
"controller (e.g. http://host:9993) or Central API."
)
mode_str = app.config.get("ZEROTIER_API_MODE", "controller")
mode = APIMode.CENTRAL if mode_str == "central" else APIMode.CONTROLLER
return ZeroTierClient(
api_token=app.config.get("ZEROTIER_API_TOKEN", ""),
base_url=app.config.get("ZEROTIER_API_URL", "http://localhost:9993"),
mode=mode,
logger.debug(
f"[ZT] Client for org:{organization_id} mode={mode_str} url={url}"
)
return ZeroTierClient(api_token=token, base_url=url, mode=mode)
def get_status() -> dict:
def get_status(organization_id: Optional[str] = None) -> dict:
"""Verify connectivity to the ZeroTier controller."""
client = _get_client()
client = _get_client(organization_id)
try:
return client.get_status()
except SDKZeroTierAPIError as exc:
raise ZeroTierAPIError(str(exc), status_code=exc.status_code) from exc
def list_networks():
def list_networks(organization_id: Optional[str] = None):
"""List all networks accessible to the configured token."""
client = _get_client()
client = _get_client(organization_id)
try:
return client.list_networks()
except SDKZeroTierAPIError as exc:
raise ZeroTierAPIError(str(exc), status_code=exc.status_code) from exc
def get_network(network_id: str):
def get_network(network_id: str, organization_id: Optional[str] = None):
"""Fetch a single network by ID."""
client = _get_client()
client = _get_client(organization_id)
try:
return client.get_network(network_id)
except SDKZeroTierAPIError as exc:
raise ZeroTierAPIError(str(exc), status_code=exc.status_code) from exc
def list_members(network_id: str):
def list_members(network_id: str, organization_id: Optional[str] = None):
"""List all members on a network."""
client = _get_client()
client = _get_client(organization_id)
try:
return client.list_members(network_id)
except SDKZeroTierAPIError as exc:
raise ZeroTierAPIError(str(exc), status_code=exc.status_code) from exc
def get_member(network_id: str, node_id: str):
def get_member(network_id: str, node_id: str, organization_id: Optional[str] = None):
"""Fetch a single member on a network."""
client = _get_client()
client = _get_client(organization_id)
try:
return client.get_member(network_id, node_id)
except SDKZeroTierAPIError as exc:
raise ZeroTierAPIError(str(exc), status_code=exc.status_code) from exc
def authorize_member(network_id: str, node_id: str):
def authorize_member(network_id: str, node_id: str, organization_id: Optional[str] = None):
"""Authorize a member on a network. Returns updated member."""
client = _get_client()
client = _get_client(organization_id)
try:
return client.authorize_member(network_id, node_id)
except SDKZeroTierAPIError as exc:
raise ZeroTierAPIError(str(exc), status_code=exc.status_code) from exc
def deauthorize_member(network_id: str, node_id: str):
def deauthorize_member(network_id: str, node_id: str, organization_id: Optional[str] = None):
"""De-authorize a member on a network. Returns updated member."""
client = _get_client()
client = _get_client(organization_id)
try:
return client.deauthorize_member(network_id, node_id)
except SDKZeroTierAPIError as exc:
raise ZeroTierAPIError(str(exc), status_code=exc.status_code) from exc
def add_member(network_id: str, node_id: str, authorized: bool = False):
def add_member(network_id: str, node_id: str, authorized: bool = False, organization_id: Optional[str] = None):
"""Manually add/pre-provision a member on a network."""
client = _get_client()
client = _get_client(organization_id)
try:
return client.add_member(network_id, node_id, authorized=authorized)
except SDKZeroTierAPIError as exc:
raise ZeroTierAPIError(str(exc), status_code=exc.status_code) from exc
def delete_network_member(network_id: str, node_id: str):
def delete_network_member(network_id: str, node_id: str, organization_id: Optional[str] = None):
"""Remove a member entirely from a ZeroTier network."""
client = _get_client()
client = _get_client(organization_id)
try:
return client.delete_member(network_id, node_id)
except SDKZeroTierAPIError as exc:
@@ -1,6 +1,7 @@
"""ZeroTier reconciliation service — polling loop to sync state with the controller."""
import logging
import time
from datetime import datetime, timezone
from gatehouse_app.extensions import db
@@ -34,16 +35,24 @@ def reconcile_expired_activations() -> int:
ActivationSession.deleted_at.is_(None),
).all()
logger.debug(f"[Reconciliation] Expiry check: {len(expired)} overdue session(s) found.")
count = 0
for session in expired:
try:
_expire_session(session)
count += 1
except Exception as exc:
logger.error(f"[Reconciliation] Failed to expire session {session.id}: {exc}")
logger.error(
f"[Reconciliation] Failed to expire session {session.id} "
f"(user={session.user_id} membership={session.device_network_membership_id}): {exc}",
exc_info=True,
)
if count > 0:
logger.info(f"[Reconciliation] Expired {count} activation sessions.")
logger.info(f"[Reconciliation] Expired {count} activation session(s).")
else:
logger.debug("[Reconciliation] No activation sessions to expire.")
return count
@@ -55,9 +64,14 @@ def reconcile_network(portal_network_id: str) -> dict:
"""
network = PortalNetwork.query.get(portal_network_id)
if not network or not network.is_active:
logger.debug(
f"[Reconciliation] Skipping portal_network_id={portal_network_id}: "
f"{'not found' if not network else 'inactive or deleted'}."
)
return {"skipped": True, "reason": "network_inactive_or_deleted"}
zerotier_network_id = network.zerotier_network_id
network_label = f"{network.name} ({zerotier_network_id})"
actions = {
"zt_members_checked": 0,
"zt_members_added": 0,
@@ -67,15 +81,25 @@ def reconcile_network(portal_network_id: str) -> dict:
"unknown_members": [],
}
t_start = time.monotonic()
logger.debug(f"[Reconciliation] Starting network reconciliation for {network_label}.")
# Get current ZT members
try:
zt_members = {m.node_id: m for m in zt.list_members(zerotier_network_id)}
zt_members = {m.node_id: m for m in zt.list_members(zerotier_network_id,
organization_id=network.organization_id)}
except Exception as exc:
logger.error(f"[Reconciliation] Failed to list ZT members for {zerotier_network_id}: {exc}")
logger.error(
f"[Reconciliation] Failed to list ZT members for {network_label}: {exc}",
exc_info=True,
)
actions["error"] = str(exc)
return actions
actions["zt_members_checked"] = len(zt_members)
logger.debug(
f"[Reconciliation] {network_label}: {len(zt_members)} member(s) fetched from ZT controller."
)
# Get our portal memberships for this network
our_memberships = {
@@ -87,13 +111,21 @@ def reconcile_network(portal_network_id: str) -> dict:
if m.device and m.device.deleted_at is None
}
logger.debug(
f"[Reconciliation] {network_label}: {len(our_memberships)} portal membership(s) to reconcile."
)
# Reconcile each portal membership
for node_id, membership in our_memberships.items():
zt_member = zt_members.pop(node_id, None)
device = membership.device
if not zt_member:
# Member not seen in ZT yet
# Member not seen in ZT yet — could be freshly joined or never connected
logger.debug(
f"[Reconciliation] {network_label}: node {node_id} "
f"(device={device.display_name!r}, state={membership.state}) not yet seen in ZT controller."
)
continue
actions["join_seen_updated"] += 1
@@ -104,31 +136,67 @@ def reconcile_network(portal_network_id: str) -> dict:
# Sync authorization state
if membership.state == MembershipState.ACTIVE_AUTHORIZED:
if not zt_member.is_authorized:
# We think it's active but ZT says it's not — re-authorize
# Portal says active but ZT disagrees — drift, re-authorize
logger.warning(
f"[Reconciliation] {network_label}: DRIFT detected — portal=ACTIVE_AUTHORIZED "
f"but ZT says unauthorized for node {node_id} (device={device.display_name!r}). Re-authorizing."
)
try:
zt.authorize_member(zerotier_network_id, node_id)
zt.authorize_member(zerotier_network_id, node_id,
organization_id=network.organization_id)
actions["authorized"] += 1
logger.info(
f"[Reconciliation] {network_label}: Re-authorized node {node_id} (device={device.display_name!r})."
)
except Exception as exc:
logger.warning(f"[Reconciliation] Re-authorize failed for {node_id}: {exc}")
logger.warning(
f"[Reconciliation] {network_label}: Re-authorize failed for node {node_id}: {exc}"
)
else:
logger.debug(
f"[Reconciliation] {network_label}: node {node_id} — portal=ACTIVE_AUTHORIZED, ZT=authorized. OK."
)
else:
if zt_member.is_authorized:
# We think it's not authorized but ZT says it is — deauthorize
# (could be manual override in ZT console)
# ZT says authorized but portal doesn't — could be manual override in ZT console
logger.warning(
f"[Reconciliation] {network_label}: DRIFT detected — portal state={membership.state} "
f"but ZT says authorized for node {node_id} (device={device.display_name!r}). Deauthorizing."
)
try:
zt.deauthorize_member(zerotier_network_id, node_id)
zt.deauthorize_member(zerotier_network_id, node_id,
organization_id=network.organization_id)
actions["deauthorized"] += 1
logger.info(
f"[Reconciliation] {network_label}: Deauthorized node {node_id} (device={device.display_name!r})."
)
except Exception as exc:
logger.warning(f"[Reconciliation] Deauthorize failed for {node_id}: {exc}")
logger.warning(
f"[Reconciliation] {network_label}: Deauthorize failed for node {node_id}: {exc}"
)
else:
logger.debug(
f"[Reconciliation] {network_label}: node {node_id}"
f"portal={membership.state}, ZT=unauthorized. OK."
)
# Unknown ZT members not in our portal
actions["unknown_members"] = list(zt_members.keys())
# Unknown ZT members not in our portal — log only, do not touch
unknown = list(zt_members.keys())
actions["unknown_members"] = unknown
if unknown:
logger.warning(
f"[Reconciliation] {network_label}: {len(unknown)} ZT member(s) not in portal — "
f"node IDs: {', '.join(unknown)}"
)
elapsed_ms = int((time.monotonic() - t_start) * 1000)
logger.info(
f"[Reconciliation] Network {zerotier_network_id}: "
f"[Reconciliation] Network {network_label}: "
f"checked={actions['zt_members_checked']} "
f"authorized={actions['authorized']} "
f"deauthorized={actions['deauthorized']} "
f"unknown={len(actions['unknown_members'])}"
f"unknown={len(actions['unknown_members'])} "
f"elapsed={elapsed_ms}ms"
)
return actions
@@ -144,16 +212,34 @@ def reconcile_all() -> dict:
PortalNetwork.deleted_at.is_(None),
).all()
results = {"networks_processed": 0, "errors": 0}
logger.info(f"[Reconciliation] reconcile_all: {len(networks)} active network(s) to process.")
results = {"networks_processed": 0, "errors": 0, "authorized": 0, "deauthorized": 0, "unknown_members": []}
for network in networks:
try:
result = reconcile_network(network.id)
if "error" in result:
logger.error(
f"[Reconciliation] Network {network.name} ({network.zerotier_network_id}) "
f"failed: {result['error']}"
)
results["errors"] += 1
elif result.get("skipped"):
logger.debug(
f"[Reconciliation] Network {network.name} ({network.zerotier_network_id}) "
f"skipped: {result.get('reason')}"
)
else:
results["networks_processed"] += 1
results["authorized"] += result.get("authorized", 0)
results["deauthorized"] += result.get("deauthorized", 0)
results["unknown_members"].extend(result.get("unknown_members", []))
except Exception as exc:
logger.error(f"[Reconciliation] Failed to reconcile network {network.id}: {exc}")
logger.error(
f"[Reconciliation] Unhandled error reconciling network "
f"{network.name} ({network.id}): {exc}",
exc_info=True,
)
results["errors"] += 1
deleted_result = reconcile_deleted_memberships()
@@ -161,8 +247,11 @@ def reconcile_all() -> dict:
results["delete_errors"] = deleted_result.get("errors", 0)
logger.info(
f"[Reconciliation] Complete: {results['networks_processed']} networks processed, "
f"{results['errors']} errors, {results.get('deleted_memberships', 0)} memberships purged."
f"[Reconciliation] Complete: "
f"networks={results['networks_processed']} "
f"errors={results['errors']} "
f"purged={results.get('deleted_memberships', 0)} "
f"purge_errors={results.get('delete_errors', 0)}"
)
return results
@@ -180,8 +269,11 @@ def reconcile_deleted_memberships() -> dict:
).all()
if not deleted:
logger.debug("[Reconciliation] No soft-deleted memberships to purge.")
return {"deleted": 0, "errors": 0}
logger.info(f"[Reconciliation] Purging {len(deleted)} soft-deleted membership(s) from ZT and DB.")
results = {"deleted": 0, "errors": 0}
for membership in deleted:
try:
@@ -189,30 +281,49 @@ def reconcile_deleted_memberships() -> dict:
network = PortalNetwork.query.get(membership.portal_network_id)
if not device or not network:
logger.warning(
f"[Reconciliation] Membership {membership.id}: missing "
f"{'device' if not device else 'network'} — hard-deleting record only."
)
db.session.delete(membership)
db.session.commit()
results["deleted"] += 1
continue
node_id = device.node_id
zt_network_id = network.zerotier_network_id
network_label = f"{network.name} ({zt_network_id})"
try:
zt.delete_network_member(network.zerotier_network_id, device.node_id)
logger.info(f"[Reconciliation] Deleted {device.node_id} from ZT network {network.zerotier_network_id}")
zt.delete_network_member(zt_network_id, node_id,
organization_id=network.organization_id)
logger.info(
f"[Reconciliation] Removed node {node_id} (device={device.display_name!r}) "
f"from ZT network {network_label}."
)
except Exception as zt_exc:
logger.warning(
f"[Reconciliation] ZT delete failed for {device.node_id} "
f"on {network.zerotier_network_id}: {zt_exc}"
f"[Reconciliation] ZT delete failed for node {node_id} "
f"on {network_label}: {zt_exc} — proceeding with DB hard-delete."
)
db.session.delete(membership)
db.session.commit()
results["deleted"] += 1
logger.debug(
f"[Reconciliation] Hard-deleted membership {membership.id} "
f"(node={node_id}, network={network_label})."
)
except Exception as exc:
logger.error(f"[Reconciliation] Failed to hard-delete membership {membership.id}: {exc}")
logger.error(
f"[Reconciliation] Failed to hard-delete membership {membership.id}: {exc}",
exc_info=True,
)
results["errors"] += 1
if results["deleted"] > 0:
logger.info(f"[Reconciliation] Purged {results['deleted']} memberships.")
logger.info(f"[Reconciliation] Purged {results['deleted']} membership(s).")
return results
@@ -228,7 +339,12 @@ def _sync_zt_membership(membership: DeviceNetworkMembership, zt_member) -> None:
ZeroTierMembership.deleted_at.is_(None),
).first()
if not zt_membership:
is_new = zt_membership is None
if is_new:
logger.debug(
f"[Reconciliation] Creating new ZeroTierMembership cache record for "
f"node {device.node_id} on network {network.zerotier_network_id}."
)
zt_membership = ZeroTierMembership(
organization_id=membership.organization_id,
device_network_membership_id=membership.id,
@@ -236,6 +352,8 @@ def _sync_zt_membership(membership: DeviceNetworkMembership, zt_member) -> None:
node_id=device.node_id,
)
prev_authorized = zt_membership.authorized if not is_new else None
zt_membership.member_seen = True
zt_membership.authorized = zt_member.is_authorized
zt_membership.last_synced_at = datetime.now(timezone.utc)
@@ -248,11 +366,27 @@ def _sync_zt_membership(membership: DeviceNetworkMembership, zt_member) -> None:
zt_membership.save()
if not is_new and prev_authorized != zt_member.is_authorized:
logger.info(
f"[Reconciliation] ZT auth state changed for node {device.node_id} "
f"(device={device.display_name!r}): {prev_authorized}{zt_member.is_authorized}"
)
# Update membership join_seen flag
if not membership.join_seen:
logger.info(
f"[Reconciliation] First join seen for node {device.node_id} "
f"(device={device.display_name!r}, membership={membership.id}). "
f"State: {membership.state}{MembershipState.JOINED_DEAUTHORIZED}"
)
membership.join_seen = True
membership.state = MembershipState.JOINED_DEAUTHORIZED
membership.save()
else:
logger.debug(
f"[Reconciliation] Synced ZT membership for node {device.node_id} "
f"(device={device.display_name!r}, authorized={zt_member.is_authorized})."
)
def _expire_session(session: ActivationSession) -> None:
@@ -261,8 +395,19 @@ def _expire_session(session: ActivationSession) -> None:
session.end_reason = ActivationEndReason.EXPIRED
session.save()
logger.info(
f"[Reconciliation] Expiring activation session {session.id} "
f"(user={session.user_id}, membership={session.device_network_membership_id}, "
f"expired_at={session.expires_at.isoformat()})."
)
membership = DeviceNetworkMembership.query.get(session.device_network_membership_id)
if membership:
if not membership:
logger.warning(
f"[Reconciliation] Session {session.id}: membership "
f"{session.device_network_membership_id} not found — skipping ZT deauth."
)
else:
membership.state = MembershipState.ACTIVATION_EXPIRED
membership.currently_authorized = False
membership.save()
@@ -270,8 +415,14 @@ def _expire_session(session: ActivationSession) -> None:
device = Device.query.get(membership.device_id)
network = PortalNetwork.query.get(membership.portal_network_id)
if device and network:
network_label = f"{network.name} ({network.zerotier_network_id})"
try:
zt.deauthorize_member(network.zerotier_network_id, device.node_id)
zt.deauthorize_member(network.zerotier_network_id, device.node_id,
organization_id=network.organization_id)
logger.info(
f"[Reconciliation] Deauthorized expired node {device.node_id} "
f"(device={device.display_name!r}) on {network_label}."
)
# Update ZT membership cache
zt_membership = ZeroTierMembership.query.filter(
@@ -283,12 +434,24 @@ def _expire_session(session: ActivationSession) -> None:
zt_membership.authorized = False
zt_membership.last_synced_at = datetime.now(timezone.utc)
zt_membership.save()
else:
logger.debug(
f"[Reconciliation] No ZeroTierMembership cache record found for "
f"node {device.node_id} on {network_label} — nothing to update."
)
except Exception as exc:
logger.warning(
f"[_expire_session] Failed to deauthorize {device.node_id} "
f"on {network.zerotier_network_id}: {exc}"
f"[_expire_session] Failed to deauthorize node {device.node_id} "
f"on {network_label}: {exc}",
exc_info=True,
)
else:
logger.warning(
f"[Reconciliation] Session {session.id}: missing "
f"{'device' if not device else 'network'} for membership "
f"{membership.id} — ZT deauth skipped."
)
from gatehouse_app.services.audit_service import AuditService
AuditService.log_action(
-1
View File
@@ -282,7 +282,6 @@ class KillSwitchScope(str, Enum):
"""Scope of a kill switch event."""
ORGANIZATION = "organization"
GLOBAL = "global"
SELECTED_NETWORKS = "selected_networks"
+61 -12
View File
@@ -1,5 +1,6 @@
"""Management script for Flask application."""
import os
import click
from dotenv import load_dotenv
# Load environment variables FIRST, before any app imports
@@ -153,36 +154,75 @@ def mfa_compliance_status():
@cli.command("configure_oauth")
def configure_oauth():
"""Interactively configure an OAuth provider at the application level.
@click.argument("provider", required=False)
@click.option("--client-id", default=None, help="OAuth client ID")
@click.option("--client-secret", default=None, help="OAuth client secret")
@click.option("--redirect-url", default=None, help="Default redirect URL (e.g. https://yourdomain.com/api/v1/auth/external/<provider>/callback)")
def configure_oauth(provider, client_id, client_secret, redirect_url):
"""Configure an OAuth provider at the application level.
Usage:
Usage (interactive):
python manage.py configure_oauth
Usage (non-interactive):
python manage.py configure_oauth google --client-id ID --client-secret SECRET
Supported providers: google, github, microsoft
"""
import getpass
from gatehouse_app.models.authentication_method import ApplicationProviderConfig
from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig
from gatehouse_app.extensions import db
SUPPORTED = ["google", "github", "microsoft"]
print("=" * 60)
print("OAuth Provider Configuration")
print("=" * 60)
print(f"Supported providers: {', '.join(SUPPORTED)}")
# Well-known endpoints — stored in additional_config so the adapter can
# resolve auth_url / token_url / userinfo_url without extra logic.
PROVIDER_DEFAULTS = {
"google": {
"auth_url": "https://accounts.google.com/o/oauth2/v2/auth",
"token_url": "https://oauth2.googleapis.com/token",
"userinfo_url": "https://www.googleapis.com/oauth2/v3/userinfo",
},
"github": {
"auth_url": "https://github.com/login/oauth/authorize",
"token_url": "https://github.com/login/oauth/access_token",
"userinfo_url": "https://api.github.com/user",
},
"microsoft": {
"auth_url": "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
"token_url": "https://login.microsoftonline.com/common/oauth2/v2.0/token",
"userinfo_url": "https://graph.microsoft.com/oidc/userinfo",
},
}
provider = input("Provider [google/github/microsoft]: ").strip().lower()
if not provider:
print("=" * 60)
print("OAuth Provider Configuration")
print("=" * 60)
print(f"Supported providers: {', '.join(SUPPORTED)}")
provider = input("Provider [google/github/microsoft]: ").strip().lower()
provider = provider.strip().lower()
if provider not in SUPPORTED:
print(f"❌ Unknown provider: {provider}")
return
client_id = input("Client ID: ").strip()
if not client_id:
client_id = input("Client ID: ").strip()
if not client_id:
print("❌ client_id is required")
return
client_secret = getpass.getpass("Client Secret (leave blank to keep existing): ").strip()
if not client_secret:
client_secret = getpass.getpass("Client Secret (leave blank to keep existing): ").strip()
if not redirect_url:
base_url = os.getenv("API_BASE_URL", "http://localhost:5000/api/v1")
default = f"{base_url}/auth/external/{provider}/callback"
entered = input(f"Default redirect URL [{default}]: ").strip()
redirect_url = entered or default
additional_config = PROVIDER_DEFAULTS[provider].copy()
with app.app_context():
config = ApplicationProviderConfig.query.filter_by(provider_type=provider).first()
@@ -191,6 +231,11 @@ def configure_oauth():
if client_secret:
config.set_client_secret(client_secret)
config.is_enabled = True
config.default_redirect_url = redirect_url
config.additional_config = {
**(config.additional_config or {}),
**additional_config,
}
db.session.commit()
print(f"✅ Updated {provider} provider config.")
else:
@@ -198,12 +243,16 @@ def configure_oauth():
provider_type=provider,
client_id=client_id,
is_enabled=True,
default_redirect_url=redirect_url,
additional_config=additional_config,
)
if client_secret:
config.set_client_secret(client_secret)
db.session.add(config)
db.session.commit()
print(f"✅ Created {provider} provider config.")
print(f" redirect_url : {redirect_url}")
print(f" auth_url : {additional_config['auth_url']}")
@cli.command("list_oauth")
@@ -213,7 +262,7 @@ def list_oauth():
Usage:
python manage.py list_oauth
"""
from gatehouse_app.models.authentication_method import ApplicationProviderConfig
from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig
with app.app_context():
configs = ApplicationProviderConfig.query.all()
+8 -299
View File
@@ -4,314 +4,23 @@ Revision ID: 020_zerotier
Revises: 019_audit_varchar
Create Date: 2026-03-19
Tables created:
- portal_networks manager-created ZeroTier network bindings
- devices user-registered ZeroTier node endpoints
- user_network_approvals durable manager approval records
- device_network_memberships per-device per-network workflow records
- activation_sessions temporary activation windows
- zerotier_memberships observed controller-side member state
- kill_switch_events explicit rapid deactivation records
SUPERSEDED by 023_zerotier_drop_legacy which creates all ZeroTier tables
idempotently (with IF NOT EXISTS / if_not_exists=True). This migration is
kept as a no-op to preserve the Alembic revision chain for databases that
already have '020_zerotier' stamped (e.g. dev environments).
"""
from alembic import op
import sqlalchemy as sa
revision = "020_zerotier"
down_revision = "019_audit_varchar"
branch_labels = None
depends_on = None
def _pg_enum(enum_name: str, values: list[str]) -> sa.Enum:
return sa.Enum(*values, name=enum_name, create_type=False)
def upgrade():
bind = op.get_bind()
dialect = bind.dialect.name
# ── 1. Enum types ─────────────────────────────────────────────────────────
if dialect == "postgresql":
op.execute("CREATE TYPE network_environment AS ENUM (%s)" % ", ".join(
f"'{v}'" for v in ["production", "staging", "development", "lab"]
))
op.execute("CREATE TYPE network_request_mode AS ENUM (%s)" % ", ".join(
f"'{v}'" for v in ["open", "approval_required", "invite_only"]
))
op.execute("CREATE TYPE approval_grant_type AS ENUM (%s)" % ", ".join(
f"'{v}'" for v in ["requested", "assigned"]
))
op.execute("CREATE TYPE approval_state AS ENUM (%s)" % ", ".join(
f"'{v}'" for v in ["pending", "approved", "rejected", "revoked", "suspended"]
))
op.execute("CREATE TYPE membership_state AS ENUM (%s)" % ", ".join(
f"'{v}'" for v in [
"pending_device_registration",
"pending_request",
"pending_manager_approval",
"approved_inactive",
"joined_deauthorized",
"active_authorized",
"activation_expired",
"suspended",
"revoked",
"rejected",
]
))
op.execute("CREATE TYPE activation_end_reason AS ENUM (%s)" % ", ".join(
f"'{v}'" for v in [
"expired", "logout", "kill_switch",
"manual_revoke", "approval_revoked", "admin_action",
]
))
op.execute("CREATE TYPE kill_switch_scope AS ENUM (%s)" % ", ".join(
f"'{v}'" for v in ["organization", "global", "selected_networks"]
))
op.execute("CREATE TYPE device_status AS ENUM (%s)" % ", ".join(
f"'{v}'" for v in ["active", "inactive"]
))
# ── 2. portal_networks ────────────────────────────────────────────────────
op.create_table(
"portal_networks",
sa.Column("id", sa.String(36), primary_key=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False, index=True),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("description", sa.Text, nullable=True),
sa.Column("owner_user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False),
sa.Column("zerotier_network_id", sa.String(16), nullable=False, index=True),
sa.Column(
"environment",
_pg_enum("network_environment", ["production", "staging", "development", "lab"]) if dialect == "postgresql"
else sa.String(20),
nullable=False,
),
sa.Column(
"request_mode",
_pg_enum("network_request_mode", ["open", "approval_required", "invite_only"]) if dialect == "postgresql"
else sa.String(20),
nullable=False,
),
sa.Column("default_activation_lifetime_minutes", sa.Integer, nullable=False, default=480),
sa.Column("max_activation_lifetime_minutes", sa.Integer, nullable=True),
sa.Column("is_active", sa.Boolean, nullable=False, default=True),
)
op.create_index(
"ix_portal_networks_org_zt",
"portal_networks",
["organization_id", "zerotier_network_id"],
unique=True,
postgresql_where=sa.text("deleted_at IS NULL"),
)
# ── 3. devices ───────────────────────────────────────────────────────────
op.create_table(
"devices",
sa.Column("id", sa.String(36), primary_key=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False, index=True),
sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False, index=True),
sa.Column("node_id", sa.String(10), nullable=False, index=True),
sa.Column("device_nickname", sa.String(255), nullable=True),
sa.Column("hostname", sa.String(255), nullable=True),
sa.Column("asset_tag", sa.String(255), nullable=True),
sa.Column("serial_number", sa.String(255), nullable=True),
sa.Column(
"status",
_pg_enum("device_status", ["active", "inactive"]) if dialect == "postgresql"
else sa.String(20),
nullable=False,
default="active",
),
)
if dialect == "postgresql":
op.create_index(
"ix_devices_node_id_active",
"devices",
["node_id"],
unique=True,
postgresql_where=sa.text("deleted_at IS NULL"),
)
else:
op.create_index("ix_devices_node_id", "devices", ["node_id"], unique=False)
# ── 4. user_network_approvals ─────────────────────────────────────────────
op.create_table(
"user_network_approvals",
sa.Column("id", sa.String(36), primary_key=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False, index=True),
sa.Column("user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False, index=True),
sa.Column("portal_network_id", sa.String(36), sa.ForeignKey("portal_networks.id"), nullable=False, index=True),
sa.Column("granted_by_user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=True),
sa.Column(
"grant_type",
_pg_enum("approval_grant_type", ["requested", "assigned"]) if dialect == "postgresql"
else sa.String(20),
nullable=False,
default="requested",
),
sa.Column(
"state",
_pg_enum("approval_state", ["pending", "approved", "rejected", "revoked", "suspended"]) if dialect == "postgresql"
else sa.String(20),
nullable=False,
default="pending",
index=True,
),
sa.Column("justification", sa.Text, nullable=True),
)
op.create_index(
"ix_user_network_approvals_user_network",
"user_network_approvals",
["user_id", "portal_network_id"],
unique=True,
postgresql_where=sa.text("deleted_at IS NULL"),
)
# ── 5. device_network_memberships ────────────────────────────────────────
op.create_table(
"device_network_memberships",
sa.Column("id", sa.String(36), primary_key=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False, index=True),
sa.Column("user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False, index=True),
sa.Column("device_id", sa.String(36), sa.ForeignKey("devices.id"), nullable=False, index=True),
sa.Column("portal_network_id", sa.String(36), sa.ForeignKey("portal_networks.id"), nullable=False, index=True),
sa.Column("user_network_approval_id", sa.String(36), sa.ForeignKey("user_network_approvals.id"), nullable=True, index=True),
sa.Column(
"state",
_pg_enum(
"membership_state",
[
"pending_device_registration", "pending_request",
"pending_manager_approval", "approved_inactive",
"joined_deauthorized", "active_authorized",
"activation_expired", "suspended", "revoked", "rejected",
],
) if dialect == "postgresql" else sa.String(30),
nullable=False,
default="pending_device_registration",
index=True,
),
sa.Column("join_seen", sa.Boolean, nullable=False, default=False),
sa.Column("currently_authorized", sa.Boolean, nullable=False, default=False),
sa.Column("approved_for_activation", sa.Boolean, nullable=False, default=True),
)
op.create_index(
"ix_device_network_memberships_device_network",
"device_network_memberships",
["device_id", "portal_network_id"],
unique=True,
postgresql_where=sa.text("deleted_at IS NULL"),
)
# ── 6. activation_sessions ────────────────────────────────────────────────
op.create_table(
"activation_sessions",
sa.Column("id", sa.String(36), primary_key=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False, index=True),
sa.Column("user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False, index=True),
sa.Column("device_network_membership_id", sa.String(36), sa.ForeignKey("device_network_memberships.id"), nullable=False, index=True),
sa.Column("authenticated_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("ended_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"end_reason",
_pg_enum(
"activation_end_reason",
["expired", "logout", "kill_switch", "manual_revoke", "approval_revoked", "admin_action"],
) if dialect == "postgresql" else sa.String(20),
nullable=True,
),
sa.Column("created_by", sa.String(36), sa.ForeignKey("users.id"), nullable=False),
)
# ── 7. zerotier_memberships ───────────────────────────────────────────────
op.create_table(
"zerotier_memberships",
sa.Column("id", sa.String(36), primary_key=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False, index=True),
sa.Column("device_network_membership_id", sa.String(36), sa.ForeignKey("device_network_memberships.id"), nullable=True, index=True),
sa.Column("zerotier_network_id", sa.String(16), nullable=False, index=True),
sa.Column("node_id", sa.String(10), nullable=False, index=True),
sa.Column("member_seen", sa.Boolean, nullable=False, default=False),
sa.Column("authorized", sa.Boolean, nullable=False, default=False),
sa.Column("join_seen_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("last_synced_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("raw_controller_payload", sa.JSON, nullable=True),
)
op.create_index(
"ix_zerotier_memberships_network_node",
"zerotier_memberships",
["zerotier_network_id", "node_id"],
unique=True,
)
# ── 8. kill_switch_events ────────────────────────────────────────────────
op.create_table(
"kill_switch_events",
sa.Column("id", sa.String(36), primary_key=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False, index=True),
sa.Column("target_user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False, index=True),
sa.Column(
"scope",
_pg_enum("kill_switch_scope", ["organization", "global", "selected_networks"]) if dialect == "postgresql"
else sa.String(20),
nullable=False,
default="organization",
),
sa.Column("triggered_by_user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False),
sa.Column("reason", sa.Text, nullable=True),
sa.Column("network_ids", sa.JSON, nullable=True),
)
# No-op — 023_zerotier_drop_legacy handles everything idempotently.
pass
def downgrade():
bind = op.get_bind()
dialect = bind.dialect.name
op.drop_table("kill_switch_events")
op.drop_table("zerotier_memberships")
op.drop_table("activation_sessions")
op.drop_table("device_network_memberships")
op.drop_table("user_network_approvals")
op.drop_table("devices")
op.drop_table("portal_networks")
if dialect == "postgresql":
op.execute("DROP TYPE IF EXISTS kill_switch_scope")
op.execute("DROP TYPE IF EXISTS device_status")
op.execute("DROP TYPE IF EXISTS activation_end_reason")
op.execute("DROP TYPE IF EXISTS membership_state")
op.execute("DROP TYPE IF EXISTS approval_state")
op.execute("DROP TYPE IF EXISTS approval_grant_type")
op.execute("DROP TYPE IF EXISTS network_request_mode")
op.execute("DROP TYPE IF EXISTS network_environment")
# No-op — 023_zerotier_drop_legacy handles rollback.
pass
@@ -0,0 +1,76 @@
"""Seed CA serial counters with a timestamp-based starting value.
Revision ID: 020_ca_serial_timestamp_start
Revises: 019_audit_varchar, d34bfb72844e
Create Date: 2026-03-06
WHY
---
``next_serial_number`` was originally seeded at ``1`` for every CA
(``server_default="1"`` in migration 017). Because the
``ix_ssh_certificates_serial`` index enforces a globally-unique constraint on
the serial column, any two CAs issuing their first certificate would both try
to insert serial ``1``, causing a UniqueViolation.
FIX new CAs
-------------
The CA model's Python-side ``default`` is now ``_serial_start()``, which
returns ``int(time.time() * 1000)`` (Unix milliseconds) at row-creation time.
CAs created after this migration will start their serial counter at the
millisecond they were first inserted, so serials are globally unique across
CAs and still monotonically increasing within each CA.
FIX existing CAs
-------------------
This migration performs a data migration: any CA whose ``next_serial_number``
is still ``<= 2`` (i.e. has issued at most one certificate since the original
``1``-based default) is given a new timestamp-based starting value.
CAs that have already issued many certificates keep their current counter
unchanged their serials are already beyond the low collision-prone range.
NOTE: the ``server_default`` on the column is intentionally NOT changed here
because SQLAlchemy uses the Python-side ``default=_serial_start`` callable for
new rows; the ``server_default`` is only a database-level fallback that is
never hit when rows are inserted via the ORM.
"""
import time
from alembic import op
import sqlalchemy as sa
revision = "020_ca_serial_timestamp_start"
down_revision = ("3de11c5dc2d5", "d34bfb72844e")
branch_labels = None
depends_on = None
def _now_ms() -> int:
return int(time.time() * 1000)
def upgrade():
conn = op.get_bind()
# Update ALL CAs to a timestamp-based starting serial — not just those
# stuck at 1. Any CA with a serial below the current ms timestamp is in
# the low collision-prone range (serials 1N where N is tiny). Resetting
# every CA to a fresh ms timestamp is safe: the counter only moves forward
# from here, and no existing certificate serial is changed.
rows = conn.execute(
sa.text("SELECT id FROM cas")
).fetchall()
for (ca_id,) in rows:
new_start = _now_ms()
conn.execute(
sa.text(
"UPDATE cas SET next_serial_number = :val WHERE id = :id"
),
{"val": new_start, "id": ca_id},
)
def downgrade():
# There is no safe downgrade for a data migration that assigns new serial
# starting points — resetting to 1 would recreate the collision risk.
pass
+22
View File
@@ -0,0 +1,22 @@
"""Merge 020_ca_serial_timestamp_start and 002_add_can_sudo_to_departments into a single head.
Revision ID: 021_merge_heads
Revises: 020_ca_serial_timestamp_start, 002_add_can_sudo_to_departments
Create Date: 2026-03-09
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = '021_merge_heads'
down_revision = ('020_ca_serial_timestamp_start', '002_add_can_sudo_to_departments')
branch_labels = None
depends_on = None
def upgrade():
pass
def downgrade():
pass
@@ -0,0 +1,29 @@
"""Merge zerotier + CA/sudo/api-key branches.
Revision ID: 022_add_command_events
Revises: 020_zerotier, 021_merge_heads
Create Date: 2026-03-09
Pure merge-point for 020_zerotier and 021_merge_heads.
Revision ID kept as-is for compatibility with production databases that
already have '022_add_command_events' stamped in alembic_version.
"""
from alembic import op
# ---------------------------------------------------------------------------
# revision identifiers
# ---------------------------------------------------------------------------
revision = "022_add_command_events"
down_revision = ("020_zerotier", "021_merge_heads")
branch_labels = None
depends_on = None
def upgrade():
pass
def downgrade():
pass
@@ -0,0 +1,393 @@
"""Apply ZeroTier tables and drop legacy SSH-session tables.
Revision ID: 023_apply_zerotier_drop_legacy_ssh_tables
Revises: 022_add_command_events
Create Date: 2026-03-22
CONTEXT
-------
Migration 020_zerotier was never applied to the production database the
alembic_version stamp jumped directly from a pre-zerotier revision to
022_add_command_events. This migration catches the DB up by:
1. Creating all ZeroTier / Portal Network tables (idempotent every
create_table uses if_not_exists=True so it is safe to run on a DB
that already has some of these tables).
2. Dropping the legacy SSH-session tables that no longer have
corresponding ORM models:
- command_events (dropped first has FKs to servers + host_sessions)
- sudo_events (dropped first has FK to host_sessions)
- host_sessions (dropped second referenced by the two above)
- servers (dropped last)
All drops use IF EXISTS so the migration is also safe on a fresh DB
that ran 020_zerotier correctly (those tables would already be absent).
PROD SAFETY
-----------
- All create_table calls use if_not_exists=True.
- All drop_table calls use IF EXISTS via op.execute() for tables that may
or may not be present.
- No data migration; no destructive schema change on tables that still
have ORM models.
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.engine.reflection import Inspector
# ---------------------------------------------------------------------------
revision = "023_zerotier_drop_legacy"
down_revision = "022_add_command_events"
branch_labels = None
depends_on = None
# ---------------------------------------------------------------------------
def _table_exists(conn, table: str) -> bool:
return Inspector.from_engine(conn).has_table(table)
def _index_exists(conn, table: str, index: str) -> bool:
insp = Inspector.from_engine(conn)
return any(i["name"] == index for i in insp.get_indexes(table)) if _table_exists(conn, table) else False
def _type_exists(conn, type_name: str) -> bool:
result = conn.execute(
sa.text("SELECT 1 FROM pg_type WHERE typname = :t"),
{"t": type_name},
).scalar()
return bool(result)
def _pg_enum(name: str) -> sa.Text:
"""Return a plain Text column type for use inside create_table.
We rely on the enum type already existing in PostgreSQL (created above via
'CREATE TYPE ... IF NOT EXISTS'). Using sa.String avoids SQLAlchemy's
automatic 'CREATE TYPE' emission inside create_table, which would fail if
the type already exists. A cast via server_default / CHECK constraint is
not required PostgreSQL accepts varchar literals for enum columns when
inserted from SQLAlchemy's ORM layer, which uses the Python Enum type map.
"""
return sa.String(40)
# ---------------------------------------------------------------------------
# upgrade
# ---------------------------------------------------------------------------
def upgrade():
conn = op.get_bind()
dialect = conn.dialect.name
# ── 1. Enum types (PostgreSQL only, idempotent) ───────────────────────────
if dialect == "postgresql":
enum_defs = {
"network_environment": ["production", "staging", "development", "lab"],
"network_request_mode": ["open", "approval_required", "invite_only"],
"approval_grant_type": ["requested", "assigned"],
"approval_state": ["pending", "approved", "rejected", "revoked", "suspended"],
"membership_state": [
"pending_device_registration", "pending_request",
"pending_manager_approval", "approved_inactive",
"joined_deauthorized", "active_authorized",
"activation_expired", "suspended", "revoked", "rejected",
],
"activation_end_reason": [
"expired", "logout", "kill_switch",
"manual_revoke", "approval_revoked", "admin_action",
],
"kill_switch_scope": ["organization", "global", "selected_networks"],
"device_status": ["active", "inactive"],
}
for type_name, values in enum_defs.items():
if not _type_exists(conn, type_name):
quoted = ", ".join(f"'{v}'" for v in values)
conn.execute(sa.text(f"CREATE TYPE {type_name} AS ENUM ({quoted})"))
# ── 2. portal_networks ────────────────────────────────────────────────────
op.create_table(
"portal_networks",
sa.Column("id", sa.String(36), primary_key=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("description", sa.Text, nullable=True),
sa.Column("owner_user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False),
sa.Column("zerotier_network_id", sa.String(16), nullable=False),
sa.Column("environment", sa.String(40), nullable=False),
sa.Column("request_mode", sa.String(40), nullable=False),
sa.Column("default_activation_lifetime_minutes", sa.Integer, nullable=False, server_default="480"),
sa.Column("max_activation_lifetime_minutes", sa.Integer, nullable=True),
sa.Column("is_active", sa.Boolean, nullable=False, server_default="true"),
if_not_exists=True,
)
if not _index_exists(conn, "portal_networks", "ix_portal_networks_organization_id"):
op.create_index("ix_portal_networks_organization_id", "portal_networks", ["organization_id"])
if not _index_exists(conn, "portal_networks", "ix_portal_networks_zerotier_network_id"):
op.create_index("ix_portal_networks_zerotier_network_id", "portal_networks", ["zerotier_network_id"])
if not _index_exists(conn, "portal_networks", "ix_portal_networks_org_zt"):
op.create_index(
"ix_portal_networks_org_zt", "portal_networks",
["organization_id", "zerotier_network_id"],
unique=True,
postgresql_where=sa.text("deleted_at IS NULL"),
)
# ── 3. devices ────────────────────────────────────────────────────────────
op.create_table(
"devices",
sa.Column("id", sa.String(36), primary_key=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False),
sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False),
sa.Column("node_id", sa.String(10), nullable=False),
sa.Column("device_nickname", sa.String(255), nullable=True),
sa.Column("hostname", sa.String(255), nullable=True),
sa.Column("asset_tag", sa.String(255), nullable=True),
sa.Column("serial_number", sa.String(255), nullable=True),
sa.Column("status", sa.String(40), nullable=False, server_default="active"),
if_not_exists=True,
)
if not _index_exists(conn, "devices", "ix_devices_user_id"):
op.create_index("ix_devices_user_id", "devices", ["user_id"])
if not _index_exists(conn, "devices", "ix_devices_organization_id"):
op.create_index("ix_devices_organization_id", "devices", ["organization_id"])
if not _index_exists(conn, "devices", "ix_devices_node_id_active") and dialect == "postgresql":
op.create_index(
"ix_devices_node_id_active", "devices", ["node_id"],
unique=True,
postgresql_where=sa.text("deleted_at IS NULL"),
)
elif not _index_exists(conn, "devices", "ix_devices_node_id") and dialect != "postgresql":
op.create_index("ix_devices_node_id", "devices", ["node_id"])
# ── 4. user_network_approvals ─────────────────────────────────────────────
op.create_table(
"user_network_approvals",
sa.Column("id", sa.String(36), primary_key=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False),
sa.Column("user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False),
sa.Column("portal_network_id", sa.String(36), sa.ForeignKey("portal_networks.id"), nullable=False),
sa.Column("granted_by_user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=True),
sa.Column("grant_type", sa.String(40), nullable=False, server_default="requested"),
sa.Column("state", sa.String(40), nullable=False, server_default="pending"),
sa.Column("justification", sa.Text, nullable=True),
if_not_exists=True,
)
if not _index_exists(conn, "user_network_approvals", "ix_user_network_approvals_organization_id"):
op.create_index("ix_user_network_approvals_organization_id", "user_network_approvals", ["organization_id"])
if not _index_exists(conn, "user_network_approvals", "ix_user_network_approvals_user_id"):
op.create_index("ix_user_network_approvals_user_id", "user_network_approvals", ["user_id"])
if not _index_exists(conn, "user_network_approvals", "ix_user_network_approvals_portal_network_id"):
op.create_index("ix_user_network_approvals_portal_network_id", "user_network_approvals", ["portal_network_id"])
if not _index_exists(conn, "user_network_approvals", "ix_user_network_approvals_state"):
op.create_index("ix_user_network_approvals_state", "user_network_approvals", ["state"])
if not _index_exists(conn, "user_network_approvals", "ix_user_network_approvals_user_network"):
op.create_index(
"ix_user_network_approvals_user_network", "user_network_approvals",
["user_id", "portal_network_id"],
unique=True,
postgresql_where=sa.text("deleted_at IS NULL"),
)
# ── 5. device_network_memberships ─────────────────────────────────────────
op.create_table(
"device_network_memberships",
sa.Column("id", sa.String(36), primary_key=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False),
sa.Column("user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False),
sa.Column("device_id", sa.String(36), sa.ForeignKey("devices.id"), nullable=False),
sa.Column("portal_network_id", sa.String(36), sa.ForeignKey("portal_networks.id"), nullable=False),
sa.Column("user_network_approval_id", sa.String(36), sa.ForeignKey("user_network_approvals.id"), nullable=True),
sa.Column("state", sa.String(40), nullable=False, server_default="pending_device_registration"),
sa.Column("join_seen", sa.Boolean, nullable=False, server_default="false"),
sa.Column("currently_authorized", sa.Boolean, nullable=False, server_default="false"),
sa.Column("approved_for_activation", sa.Boolean, nullable=False, server_default="true"),
if_not_exists=True,
)
if not _index_exists(conn, "device_network_memberships", "ix_device_network_memberships_organization_id"):
op.create_index("ix_device_network_memberships_organization_id", "device_network_memberships", ["organization_id"])
if not _index_exists(conn, "device_network_memberships", "ix_device_network_memberships_user_id"):
op.create_index("ix_device_network_memberships_user_id", "device_network_memberships", ["user_id"])
if not _index_exists(conn, "device_network_memberships", "ix_device_network_memberships_device_id"):
op.create_index("ix_device_network_memberships_device_id", "device_network_memberships", ["device_id"])
if not _index_exists(conn, "device_network_memberships", "ix_device_network_memberships_portal_network_id"):
op.create_index("ix_device_network_memberships_portal_network_id", "device_network_memberships", ["portal_network_id"])
if not _index_exists(conn, "device_network_memberships", "ix_device_network_memberships_state"):
op.create_index("ix_device_network_memberships_state", "device_network_memberships", ["state"])
if not _index_exists(conn, "device_network_memberships", "ix_device_network_memberships_user_network_approval_id"):
op.create_index("ix_device_network_memberships_user_network_approval_id", "device_network_memberships", ["user_network_approval_id"])
if not _index_exists(conn, "device_network_memberships", "ix_device_network_memberships_device_network"):
op.create_index(
"ix_device_network_memberships_device_network", "device_network_memberships",
["device_id", "portal_network_id"],
unique=True,
postgresql_where=sa.text("deleted_at IS NULL"),
)
# ── 6. activation_sessions ────────────────────────────────────────────────
op.create_table(
"activation_sessions",
sa.Column("id", sa.String(36), primary_key=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False),
sa.Column("user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False),
sa.Column("device_network_membership_id", sa.String(36), sa.ForeignKey("device_network_memberships.id"), nullable=False),
sa.Column("authenticated_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("ended_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("end_reason", sa.String(40), nullable=True),
sa.Column("created_by", sa.String(36), sa.ForeignKey("users.id"), nullable=False),
if_not_exists=True,
)
if not _index_exists(conn, "activation_sessions", "ix_activation_sessions_organization_id"):
op.create_index("ix_activation_sessions_organization_id", "activation_sessions", ["organization_id"])
if not _index_exists(conn, "activation_sessions", "ix_activation_sessions_user_id"):
op.create_index("ix_activation_sessions_user_id", "activation_sessions", ["user_id"])
if not _index_exists(conn, "activation_sessions", "ix_activation_sessions_device_network_membership_id"):
op.create_index("ix_activation_sessions_device_network_membership_id", "activation_sessions", ["device_network_membership_id"])
# ── 7. zerotier_memberships ───────────────────────────────────────────────
op.create_table(
"zerotier_memberships",
sa.Column("id", sa.String(36), primary_key=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False),
sa.Column("device_network_membership_id", sa.String(36), sa.ForeignKey("device_network_memberships.id"), nullable=True),
sa.Column("zerotier_network_id", sa.String(16), nullable=False),
sa.Column("node_id", sa.String(10), nullable=False),
sa.Column("member_seen", sa.Boolean, nullable=False, server_default="false"),
sa.Column("authorized", sa.Boolean, nullable=False, server_default="false"),
sa.Column("join_seen_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("last_synced_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("raw_controller_payload", sa.JSON, nullable=True),
if_not_exists=True,
)
if not _index_exists(conn, "zerotier_memberships", "ix_zerotier_memberships_organization_id"):
op.create_index("ix_zerotier_memberships_organization_id", "zerotier_memberships", ["organization_id"])
if not _index_exists(conn, "zerotier_memberships", "ix_zerotier_memberships_device_network_membership_id"):
op.create_index("ix_zerotier_memberships_device_network_membership_id", "zerotier_memberships", ["device_network_membership_id"])
if not _index_exists(conn, "zerotier_memberships", "ix_zerotier_memberships_zerotier_network_id"):
op.create_index("ix_zerotier_memberships_zerotier_network_id", "zerotier_memberships", ["zerotier_network_id"])
if not _index_exists(conn, "zerotier_memberships", "ix_zerotier_memberships_node_id"):
op.create_index("ix_zerotier_memberships_node_id", "zerotier_memberships", ["node_id"])
if not _index_exists(conn, "zerotier_memberships", "ix_zerotier_memberships_network_node"):
op.create_index(
"ix_zerotier_memberships_network_node", "zerotier_memberships",
["zerotier_network_id", "node_id"],
unique=True,
)
# ── 8. kill_switch_events ────────────────────────────────────────────────
op.create_table(
"kill_switch_events",
sa.Column("id", sa.String(36), primary_key=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False),
sa.Column("target_user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False),
sa.Column("scope", sa.String(40), nullable=False, server_default="organization"),
sa.Column("triggered_by_user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False),
sa.Column("reason", sa.Text, nullable=True),
sa.Column("network_ids", sa.JSON, nullable=True),
if_not_exists=True,
)
if not _index_exists(conn, "kill_switch_events", "ix_kill_switch_events_organization_id"):
op.create_index("ix_kill_switch_events_organization_id", "kill_switch_events", ["organization_id"])
if not _index_exists(conn, "kill_switch_events", "ix_kill_switch_events_target_user_id"):
op.create_index("ix_kill_switch_events_target_user_id", "kill_switch_events", ["target_user_id"])
# ── 9. Drop legacy SSH-session tables (IF EXISTS — safe on fresh DBs) ─────
#
# Order matters due to FK constraints:
# command_events → servers, host_sessions
# sudo_events → host_sessions
# host_sessions → (nothing that still exists)
# servers → (nothing that still exists)
conn.execute(sa.text("DROP TABLE IF EXISTS command_events CASCADE"))
conn.execute(sa.text("DROP TABLE IF EXISTS sudo_events CASCADE"))
conn.execute(sa.text("DROP TABLE IF EXISTS host_sessions CASCADE"))
conn.execute(sa.text("DROP TABLE IF EXISTS servers CASCADE"))
# ---------------------------------------------------------------------------
# downgrade
# ---------------------------------------------------------------------------
def downgrade():
conn = op.get_bind()
dialect = conn.dialect.name
# Re-create the legacy tables (minimal — enough for FK integrity)
op.create_table(
"servers",
sa.Column("id", sa.String(36), primary_key=True),
sa.Column("created_at", sa.DateTime(), nullable=False),
sa.Column("updated_at", sa.DateTime(), nullable=False),
sa.Column("deleted_at", sa.DateTime(), nullable=True),
sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False),
sa.Column("hostname", sa.String(255), nullable=False),
sa.Column("display_name", sa.String(255), nullable=True),
sa.Column("ip_address", sa.String(64), nullable=True),
sa.Column("is_active", sa.Boolean, nullable=False, server_default="true"),
if_not_exists=True,
)
op.create_table(
"host_sessions",
sa.Column("id", sa.String(36), primary_key=True),
sa.Column("created_at", sa.DateTime(), nullable=False),
sa.Column("updated_at", sa.DateTime(), nullable=False),
sa.Column("deleted_at", sa.DateTime(), nullable=True),
sa.Column("organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=False),
sa.Column("user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False),
sa.Column("server_id", sa.String(36), sa.ForeignKey("servers.id"), nullable=False),
if_not_exists=True,
)
# Drop ZeroTier tables
op.drop_table("kill_switch_events", if_exists=True)
op.drop_table("zerotier_memberships", if_exists=True)
op.drop_table("activation_sessions", if_exists=True)
op.drop_table("device_network_memberships", if_exists=True)
op.drop_table("user_network_approvals", if_exists=True)
op.drop_table("devices", if_exists=True)
op.drop_table("portal_networks", if_exists=True)
# Drop ZeroTier enum types
if dialect == "postgresql":
for t in [
"kill_switch_scope", "device_status", "activation_end_reason",
"membership_state", "approval_state", "approval_grant_type",
"network_request_mode", "network_environment",
]:
conn.execute(sa.text(f"DROP TYPE IF EXISTS {t}"))
@@ -0,0 +1,291 @@
"""Fix ZeroTier table schema: enum types, unique constraints, indexes, drop cert_token.
Revision ID: 024_fix_zerotier_schema
Revises: 023_zerotier_drop_legacy
Create Date: 2026-03-22
Addresses all `db check` differences after 023:
- Cast VARCHAR(40) enum columns to their proper PostgreSQL enum types
(guarded skipped if columns are already native enum, e.g. on a fresh DB
where 020_zerotier created them correctly)
- Replace partial unique indexes with named UniqueConstraints
- Fix devices.node_id partial index -> plain index
- Add UniqueConstraint on `id` for all new ZeroTier tables (BaseModel.unique=True)
- Drop orphan cert_token column and its index from ssh_certificates
"""
from alembic import op
import sqlalchemy as sa
revision = "024_fix_zerotier_schema"
down_revision = "023_zerotier_drop_legacy"
branch_labels = None
depends_on = None
# ---------------------------------------------------------------------------
# helpers
# ---------------------------------------------------------------------------
def _col_data_type(conn, table: str, column: str) -> str | None:
"""Return the PostgreSQL data_type string for a column, or None."""
row = conn.execute(sa.text(
"SELECT data_type FROM information_schema.columns "
"WHERE table_name = :t AND column_name = :c"
), {"t": table, "c": column}).first()
return row[0] if row else None
def _column_exists(conn, table: str, column: str) -> bool:
return _col_data_type(conn, table, column) is not None
def _index_exists(conn, table: str, index: str) -> bool:
from sqlalchemy.engine.reflection import Inspector
insp = Inspector.from_engine(conn)
return any(i["name"] == index for i in insp.get_indexes(table))
def _constraint_exists(conn, constraint: str) -> bool:
row = conn.execute(sa.text(
"SELECT 1 FROM information_schema.table_constraints "
"WHERE constraint_name = :c"
), {"c": constraint}).first()
return row is not None
def upgrade():
conn = op.get_bind()
# -------------------------------------------------------------------------
# 1. Cast VARCHAR(40) enum columns to proper PostgreSQL enum types.
# GUARDED: On a fresh DB, 020_zerotier already created these as native
# enum types. We only cast if the column is currently 'character varying'.
# -------------------------------------------------------------------------
enum_casts = [
("portal_networks", "environment", "network_environment", None),
("portal_networks", "request_mode", "network_request_mode", None),
("devices", "status", "device_status", "'active'::device_status"),
("device_network_memberships", "state", "membership_state", "'pending_device_registration'::membership_state"),
("user_network_approvals", "grant_type", "approval_grant_type", "'requested'::approval_grant_type"),
("user_network_approvals", "state", "approval_state", "'pending'::approval_state"),
("activation_sessions", "end_reason", "activation_end_reason", None),
("kill_switch_events", "scope", "kill_switch_scope", "'organization'::kill_switch_scope"),
]
for table, col, enum_type, new_default in enum_casts:
dtype = _col_data_type(conn, table, col)
if dtype == "character varying":
conn.execute(sa.text(f'ALTER TABLE "{table}" ALTER COLUMN "{col}" DROP DEFAULT'))
conn.execute(sa.text(
f'ALTER TABLE "{table}" ALTER COLUMN "{col}" TYPE {enum_type} '
f'USING "{col}"::text::{enum_type}'
))
if new_default:
conn.execute(sa.text(
f'ALTER TABLE "{table}" ALTER COLUMN "{col}" SET DEFAULT {new_default}'
))
elif dtype == "USER-DEFINED" and new_default:
# Already native enum (fresh DB path). Ensure server_default is set
# if 020 used `default=` (Python-side) instead of `server_default=`.
# This is harmless — SET DEFAULT is idempotent.
conn.execute(sa.text(
f'ALTER TABLE "{table}" ALTER COLUMN "{col}" SET DEFAULT {new_default}'
))
# -------------------------------------------------------------------------
# 2. portal_networks: drop partial unique index, add named UniqueConstraint
# -------------------------------------------------------------------------
if _index_exists(conn, "portal_networks", "ix_portal_networks_org_zt"):
op.drop_index("ix_portal_networks_org_zt", table_name="portal_networks")
if not _constraint_exists(conn, "uix_org_zt_network_id"):
op.create_unique_constraint(
"uix_org_zt_network_id",
"portal_networks",
["organization_id", "zerotier_network_id"],
)
# -------------------------------------------------------------------------
# 3. device_network_memberships: drop partial unique index, add named UC
# -------------------------------------------------------------------------
if _index_exists(conn, "device_network_memberships", "ix_device_network_memberships_device_network"):
op.drop_index("ix_device_network_memberships_device_network", table_name="device_network_memberships")
if not _constraint_exists(conn, "uix_device_network"):
op.create_unique_constraint(
"uix_device_network",
"device_network_memberships",
["device_id", "portal_network_id", "deleted_at"],
)
# -------------------------------------------------------------------------
# 4. user_network_approvals: drop partial unique index, add named UC
# -------------------------------------------------------------------------
if _index_exists(conn, "user_network_approvals", "ix_user_network_approvals_user_network"):
op.drop_index("ix_user_network_approvals_user_network", table_name="user_network_approvals")
if not _constraint_exists(conn, "uix_user_network_approval"):
op.create_unique_constraint(
"uix_user_network_approval",
"user_network_approvals",
["user_id", "portal_network_id", "deleted_at"],
)
# -------------------------------------------------------------------------
# 5. zerotier_memberships: drop index, add named UniqueConstraint
# -------------------------------------------------------------------------
if _index_exists(conn, "zerotier_memberships", "ix_zerotier_memberships_network_node"):
op.drop_index("ix_zerotier_memberships_network_node", table_name="zerotier_memberships")
if not _constraint_exists(conn, "uix_zt_network_node"):
op.create_unique_constraint(
"uix_zt_network_node",
"zerotier_memberships",
["zerotier_network_id", "node_id"],
)
# -------------------------------------------------------------------------
# 6. devices.node_id: drop partial unique index, add plain non-unique index
# -------------------------------------------------------------------------
if _index_exists(conn, "devices", "ix_devices_node_id_active"):
op.drop_index("ix_devices_node_id_active", table_name="devices")
if not _index_exists(conn, "devices", "ix_devices_node_id"):
op.create_index("ix_devices_node_id", "devices", ["node_id"])
# -------------------------------------------------------------------------
# 7. Add UniqueConstraint on `id` for all ZeroTier tables
# BaseModel defines id with unique=True → separate _id_key constraint.
# -------------------------------------------------------------------------
zt_tables = [
"portal_networks",
"devices",
"device_network_memberships",
"user_network_approvals",
"activation_sessions",
"zerotier_memberships",
"kill_switch_events",
]
for tbl in zt_tables:
cname = f"{tbl}_id_key"
if not _constraint_exists(conn, cname):
op.create_unique_constraint(cname, tbl, ["id"])
# -------------------------------------------------------------------------
# 8. Drop orphan cert_token column and its index from ssh_certificates.
# cert_token was created by 3de11c5dc2d5 but the SSHCertificate model
# never uses it. Guarded in case a future revision removes it first.
# -------------------------------------------------------------------------
if _index_exists(conn, "ssh_certificates", "ix_ssh_certificates_cert_token"):
op.drop_index("ix_ssh_certificates_cert_token", table_name="ssh_certificates")
if _column_exists(conn, "ssh_certificates", "cert_token"):
op.drop_column("ssh_certificates", "cert_token")
def downgrade():
conn = op.get_bind()
# Restore cert_token if it was dropped
if not _column_exists(conn, "ssh_certificates", "cert_token"):
op.add_column(
"ssh_certificates",
sa.Column("cert_token", sa.String(64), nullable=True),
)
if not _index_exists(conn, "ssh_certificates", "ix_ssh_certificates_cert_token"):
op.create_index(
"ix_ssh_certificates_cert_token",
"ssh_certificates",
["cert_token"],
unique=True,
)
# Drop id unique constraints on ZeroTier tables
zt_tables = [
"portal_networks",
"devices",
"device_network_memberships",
"user_network_approvals",
"activation_sessions",
"zerotier_memberships",
"kill_switch_events",
]
for tbl in zt_tables:
cname = f"{tbl}_id_key"
if _constraint_exists(conn, cname):
op.drop_constraint(cname, tbl, type_="unique")
# Restore devices node_id index
if _index_exists(conn, "devices", "ix_devices_node_id"):
op.drop_index("ix_devices_node_id", table_name="devices")
if not _index_exists(conn, "devices", "ix_devices_node_id_active"):
op.create_index(
"ix_devices_node_id_active",
"devices",
["node_id"],
unique=True,
postgresql_where=sa.text("deleted_at IS NULL"),
)
# Restore zerotier_memberships index
if _constraint_exists(conn, "uix_zt_network_node"):
op.drop_constraint("uix_zt_network_node", "zerotier_memberships", type_="unique")
if not _index_exists(conn, "zerotier_memberships", "ix_zerotier_memberships_network_node"):
op.create_index(
"ix_zerotier_memberships_network_node",
"zerotier_memberships",
["zerotier_network_id", "node_id"],
unique=True,
)
# Restore user_network_approvals partial unique index
if _constraint_exists(conn, "uix_user_network_approval"):
op.drop_constraint("uix_user_network_approval", "user_network_approvals", type_="unique")
if not _index_exists(conn, "user_network_approvals", "ix_user_network_approvals_user_network"):
op.create_index(
"ix_user_network_approvals_user_network",
"user_network_approvals",
["user_id", "portal_network_id"],
unique=True,
postgresql_where=sa.text("deleted_at IS NULL"),
)
# Restore device_network_memberships partial unique index
if _constraint_exists(conn, "uix_device_network"):
op.drop_constraint("uix_device_network", "device_network_memberships", type_="unique")
if not _index_exists(conn, "device_network_memberships", "ix_device_network_memberships_device_network"):
op.create_index(
"ix_device_network_memberships_device_network",
"device_network_memberships",
["device_id", "portal_network_id"],
unique=True,
postgresql_where=sa.text("deleted_at IS NULL"),
)
# Restore portal_networks partial unique index
if _constraint_exists(conn, "uix_org_zt_network_id"):
op.drop_constraint("uix_org_zt_network_id", "portal_networks", type_="unique")
if not _index_exists(conn, "portal_networks", "ix_portal_networks_org_zt"):
op.create_index(
"ix_portal_networks_org_zt",
"portal_networks",
["organization_id", "zerotier_network_id"],
unique=True,
postgresql_where=sa.text("deleted_at IS NULL"),
)
# Cast enum columns back to VARCHAR(40) — only if currently native enum
enum_casts = [
("portal_networks", "environment", "'development'::character varying"),
("portal_networks", "request_mode", "'approval_required'::character varying"),
("devices", "status", "'active'::character varying"),
("device_network_memberships", "state", "'pending_device_registration'::character varying"),
("user_network_approvals", "grant_type", "'requested'::character varying"),
("user_network_approvals", "state", "'pending'::character varying"),
("activation_sessions", "end_reason", None),
("kill_switch_events", "scope", "'organization'::character varying"),
]
for table, col, old_default in enum_casts:
conn.execute(sa.text(f'ALTER TABLE "{table}" ALTER COLUMN "{col}" DROP DEFAULT'))
conn.execute(sa.text(
f'ALTER TABLE "{table}" ALTER COLUMN "{col}" TYPE VARCHAR(40) '
f'USING "{col}"::text'
))
if old_default:
conn.execute(sa.text(
f'ALTER TABLE "{table}" ALTER COLUMN "{col}" SET DEFAULT {old_default}'
))
@@ -0,0 +1,101 @@
"""Convert ZeroTier table timestamp columns from TIMESTAMPTZ to TIMESTAMP.
Revision ID: 025_fix_zt_timestamps
Revises: 024_fix_zerotier_schema
Create Date: 2026-03-22
Migration 020_zerotier (and 023's fallback create_table) defined ZeroTier tables
with sa.DateTime(timezone=True), producing TIMESTAMP WITH TIME ZONE columns.
The rest of the codebase uses plain DateTime (timezone-naive TIMESTAMP WITHOUT
TIME ZONE). This migration aligns all ZeroTier table timestamp columns with the
existing codebase convention.
GUARDED: Each ALTER is only executed if the column is currently
TIMESTAMP WITH TIME ZONE. On a DB that has already been converted (e.g. dev),
the migration is a harmless no-op.
"""
from alembic import op
import sqlalchemy as sa
revision = "025_fix_zt_timestamps"
down_revision = "024_fix_zerotier_schema"
branch_labels = None
depends_on = None
# All ZeroTier tables that inherit BaseModel's created_at/updated_at/deleted_at
_ZT_BASE_TABLES = [
"portal_networks",
"devices",
"device_network_memberships",
"user_network_approvals",
"kill_switch_events",
"activation_sessions",
"zerotier_memberships",
]
# Additional datetime columns specific to individual models
_EXTRA_COLS = {
"activation_sessions": ["authenticated_at", "expires_at", "ended_at"],
"zerotier_memberships": ["join_seen_at", "last_synced_at"],
}
def _col_is_timestamptz(conn, table: str, column: str) -> bool:
"""Return True if the column is TIMESTAMP WITH TIME ZONE."""
row = conn.execute(sa.text(
"SELECT data_type FROM information_schema.columns "
"WHERE table_name = :t AND column_name = :c"
), {"t": table, "c": column}).first()
return row is not None and row[0] == "timestamp with time zone"
def _col_is_timestamp(conn, table: str, column: str) -> bool:
"""Return True if the column is TIMESTAMP WITHOUT TIME ZONE."""
row = conn.execute(sa.text(
"SELECT data_type FROM information_schema.columns "
"WHERE table_name = :t AND column_name = :c"
), {"t": table, "c": column}).first()
return row is not None and row[0] == "timestamp without time zone"
def upgrade():
conn = op.get_bind()
for tbl in _ZT_BASE_TABLES:
for col in ("created_at", "updated_at", "deleted_at"):
if _col_is_timestamptz(conn, tbl, col):
conn.execute(sa.text(
f'ALTER TABLE "{tbl}" ALTER COLUMN "{col}" '
f'TYPE TIMESTAMP WITHOUT TIME ZONE '
f'USING "{col}" AT TIME ZONE \'UTC\''
))
for col in _EXTRA_COLS.get(tbl, []):
if _col_is_timestamptz(conn, tbl, col):
conn.execute(sa.text(
f'ALTER TABLE "{tbl}" ALTER COLUMN "{col}" '
f'TYPE TIMESTAMP WITHOUT TIME ZONE '
f'USING CASE WHEN "{col}" IS NULL THEN NULL '
f'ELSE "{col}" AT TIME ZONE \'UTC\' END'
))
def downgrade():
conn = op.get_bind()
for tbl in _ZT_BASE_TABLES:
for col in ("created_at", "updated_at", "deleted_at"):
if _col_is_timestamp(conn, tbl, col):
conn.execute(sa.text(
f'ALTER TABLE "{tbl}" ALTER COLUMN "{col}" '
f'TYPE TIMESTAMP WITH TIME ZONE '
f'USING "{col}" AT TIME ZONE \'UTC\''
))
for col in _EXTRA_COLS.get(tbl, []):
if _col_is_timestamp(conn, tbl, col):
conn.execute(sa.text(
f'ALTER TABLE "{tbl}" ALTER COLUMN "{col}" '
f'TYPE TIMESTAMP WITH TIME ZONE '
f'USING CASE WHEN "{col}" IS NULL THEN NULL '
f'ELSE "{col}" AT TIME ZONE \'UTC\' END'
))
+216
View File
@@ -0,0 +1,216 @@
"""Schema cleanup: id UniqueConstraints, organization_api_keys index/timestamp fixes.
Revision ID: 026_schema_cleanup
Revises: 025_fix_zt_timestamps
Create Date: 2026-03-23
Addresses all `db check` differences after 025 on a database upgraded from
production (021_merge_heads):
1. Add UniqueConstraint on `id` for all pre-existing tables that inherit
BaseModel (which declares id with unique=True). The ZeroTier tables
already got these in 024_fix_zerotier_schema; this covers the rest.
2. organization_api_keys fix schema drift vs. the current model:
- TIMESTAMPTZ TIMESTAMP WITHOUT TIME ZONE (align with rest of codebase)
- Drop legacy unique constraint 'organization_api_keys_key_hash_key'
and replace with named index 'ix_organization_api_keys_key_hash'
- Drop extra index 'idx_org_api_key_org_id' (superseded by
'ix_organization_api_keys_organization_id')
- Add 'ix_organization_api_keys_organization_id' and
'ix_organization_api_keys_is_revoked' named indexes expected by model
3. Drop 'idx_dept_can_sudo' index from departments created by an old
migration but not declared in the current Department model.
All operations are guarded so the migration is safe to re-run.
"""
from alembic import op
import sqlalchemy as sa
revision = "026_schema_cleanup"
down_revision = "025_fix_zt_timestamps"
branch_labels = None
depends_on = None
# ---------------------------------------------------------------------------
# helpers
# ---------------------------------------------------------------------------
def _constraint_exists(conn, name: str) -> bool:
row = conn.execute(sa.text(
"SELECT 1 FROM information_schema.table_constraints "
"WHERE constraint_name = :n"
), {"n": name}).first()
return row is not None
def _index_exists(conn, table: str, index: str) -> bool:
row = conn.execute(sa.text(
"SELECT 1 FROM pg_indexes "
"WHERE tablename = :t AND indexname = :i"
), {"t": table, "i": index}).first()
return row is not None
def _col_is_timestamptz(conn, table: str, column: str) -> bool:
row = conn.execute(sa.text(
"SELECT data_type FROM information_schema.columns "
"WHERE table_name = :t AND column_name = :c"
), {"t": table, "c": column}).first()
return row is not None and row[0] == "timestamp with time zone"
# ---------------------------------------------------------------------------
# Tables that inherit BaseModel and need an id UniqueConstraint.
# ZeroTier tables were handled in 024; all others are listed here.
# ---------------------------------------------------------------------------
_LEGACY_TABLES = [
"application_provider_configs",
"audit_logs",
"authentication_methods",
"ca_permissions",
"cas",
"certificate_audit_logs",
"department_cert_policies",
"department_memberships",
"department_principals",
"departments",
"email_verification_tokens",
"external_provider_configs",
"mfa_policy_compliance",
"oauth_states",
"oidc_audit_logs",
"oidc_authorization_codes",
"oidc_clients",
"oidc_refresh_tokens",
"oidc_sessions",
# oidc_token_metadata intentionally excluded: its id column overrides
# BaseModel without unique=True (JTI is the PK but not separately unique)
"org_invite_tokens",
"organization_api_keys",
"organization_members",
"organization_provider_overrides",
"organization_security_policies",
"organizations",
"password_reset_tokens",
"principal_memberships",
"principals",
"sessions",
"ssh_certificates",
"ssh_keys",
"user_security_policies",
"users",
]
def upgrade():
conn = op.get_bind()
# ── 1. Add id UniqueConstraint to all legacy BaseModel tables ─────────
for tbl in _LEGACY_TABLES:
cname = f"{tbl}_id_key"
if not _constraint_exists(conn, cname):
op.create_unique_constraint(cname, tbl, ["id"])
# Drop the wrongly-added constraint on oidc_token_metadata if present
# (its id column overrides BaseModel without unique=True)
if _constraint_exists(conn, "oidc_token_metadata_id_key"):
op.drop_constraint("oidc_token_metadata_id_key", "oidc_token_metadata", type_="unique")
# ── 2. organization_api_keys: timestamp columns TIMESTAMPTZ → TIMESTAMP
for col in ("created_at", "updated_at", "deleted_at", "last_used_at", "revoked_at"):
if _col_is_timestamptz(conn, "organization_api_keys", col):
conn.execute(sa.text(
f'ALTER TABLE organization_api_keys ALTER COLUMN "{col}" '
f'TYPE TIMESTAMP WITHOUT TIME ZONE '
f'USING CASE WHEN "{col}" IS NULL THEN NULL '
f'ELSE "{col}" AT TIME ZONE \'UTC\' END'
))
# ── 3. organization_api_keys: replace legacy unique constraint + indexes
# Drop the anonymous unique constraint on key_hash (created by
# sa.UniqueConstraint('key_hash') in the original migration)
if _constraint_exists(conn, "organization_api_keys_key_hash_key"):
op.drop_constraint(
"organization_api_keys_key_hash_key",
"organization_api_keys",
type_="unique",
)
# Add named unique index for key_hash expected by the model
if not _index_exists(conn, "organization_api_keys", "ix_organization_api_keys_key_hash"):
op.create_index(
"ix_organization_api_keys_key_hash",
"organization_api_keys",
["key_hash"],
unique=True,
)
# Drop the legacy plain org-id index (superseded by the named one below)
if _index_exists(conn, "organization_api_keys", "idx_org_api_key_org_id"):
op.drop_index("idx_org_api_key_org_id", table_name="organization_api_keys")
# Add named org-id index expected by the model
if not _index_exists(conn, "organization_api_keys", "ix_organization_api_keys_organization_id"):
op.create_index(
"ix_organization_api_keys_organization_id",
"organization_api_keys",
["organization_id"],
)
# Add named is_revoked index expected by the model
if not _index_exists(conn, "organization_api_keys", "ix_organization_api_keys_is_revoked"):
op.create_index(
"ix_organization_api_keys_is_revoked",
"organization_api_keys",
["is_revoked"],
)
# ── 4. Drop orphan idx_dept_can_sudo from departments ─────────────────
if _index_exists(conn, "departments", "idx_dept_can_sudo"):
op.drop_index("idx_dept_can_sudo", table_name="departments")
# NOTE: ix_ssh_certificates_serial uniqueness is handled in
# 027_fix_cert_serial_uniqueness (composite unique per CA).
def downgrade():
conn = op.get_bind()
# Restore idx_dept_can_sudo
if not _index_exists(conn, "departments", "idx_dept_can_sudo"):
op.create_index("idx_dept_can_sudo", "departments", ["organization_id", "can_sudo"])
# Restore organization_api_keys indexes
if _index_exists(conn, "organization_api_keys", "ix_organization_api_keys_is_revoked"):
op.drop_index("ix_organization_api_keys_is_revoked", table_name="organization_api_keys")
if _index_exists(conn, "organization_api_keys", "ix_organization_api_keys_organization_id"):
op.drop_index("ix_organization_api_keys_organization_id", table_name="organization_api_keys")
if _index_exists(conn, "organization_api_keys", "ix_organization_api_keys_key_hash"):
op.drop_index("ix_organization_api_keys_key_hash", table_name="organization_api_keys")
if not _constraint_exists(conn, "organization_api_keys_key_hash_key"):
op.create_unique_constraint(
"organization_api_keys_key_hash_key",
"organization_api_keys",
["key_hash"],
)
if not _index_exists(conn, "organization_api_keys", "idx_org_api_key_org_id"):
op.create_index("idx_org_api_key_org_id", "organization_api_keys", ["organization_id"])
# Restore TIMESTAMPTZ on organization_api_keys
for col in ("created_at", "updated_at", "deleted_at", "last_used_at", "revoked_at"):
conn.execute(sa.text(
f'ALTER TABLE organization_api_keys ALTER COLUMN "{col}" '
f'TYPE TIMESTAMP WITH TIME ZONE '
f'USING CASE WHEN "{col}" IS NULL THEN NULL '
f'ELSE "{col}" AT TIME ZONE \'UTC\' END'
))
# Drop id UniqueConstraints from legacy tables
for tbl in reversed(_LEGACY_TABLES):
cname = f"{tbl}_id_key"
if _constraint_exists(conn, cname):
op.drop_constraint(cname, tbl, type_="unique")
@@ -0,0 +1,105 @@
"""Fix ssh_certificates serial uniqueness: per-CA not global.
Revision ID: 027_fix_cert_serial_uniqueness
Revises: 026_schema_cleanup
Create Date: 2026-03-23
The SSHCertificate model uses a per-CA monotonic serial counter, meaning
serial numbers are only unique within a single CA not across the whole
table. The original migration created a global unique index on `serial`
alone, which is incorrect and was blocking enforcement (duplicate serial=1
rows exist in production where two different CAs both issued their first
certificate).
This migration:
1. Drops the old non-unique index ix_ssh_certificates_serial (which was
never enforcing uniqueness just an index).
2. Drops the stale unique constraint ssh_certificates_serial_key if it
somehow exists.
3. Creates a proper composite unique constraint uq_ssh_certificates_ca_serial
on (ca_id, serial), reflecting the real invariant: a serial is unique
within one CA.
All operations are guarded (IF EXISTS / try/except) so this is safe to
re-run on any DB state.
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.engine.reflection import Inspector
# ---------------------------------------------------------------------------
# revision identifiers
# ---------------------------------------------------------------------------
revision = "027_fix_cert_serial_uniqueness"
down_revision = "026_schema_cleanup"
branch_labels = None
depends_on = None
def _index_exists(conn, table: str, index: str) -> bool:
insp = Inspector.from_engine(conn)
return any(i["name"] == index for i in insp.get_indexes(table))
def _constraint_exists(conn, table: str, constraint: str) -> bool:
insp = Inspector.from_engine(conn)
for uc in insp.get_unique_constraints(table):
if uc["name"] == constraint:
return True
return False
def upgrade():
conn = op.get_bind()
# 1. Drop the old global non-unique index on serial (if present)
if _index_exists(conn, "ssh_certificates", "ix_ssh_certificates_serial"):
op.drop_index("ix_ssh_certificates_serial", table_name="ssh_certificates")
# 2. Drop any stale global unique constraint on serial alone (defensive)
if _constraint_exists(conn, "ssh_certificates", "ssh_certificates_serial_key"):
op.drop_constraint(
"ssh_certificates_serial_key",
"ssh_certificates",
type_="unique",
)
# 3. Add composite unique constraint: serial is unique per CA
if not _constraint_exists(conn, "ssh_certificates", "uq_ssh_certificates_ca_serial"):
op.create_unique_constraint(
"uq_ssh_certificates_ca_serial",
"ssh_certificates",
["ca_id", "serial"],
)
# 4. Re-create a plain non-unique index on serial for fast lookups
if not _index_exists(conn, "ssh_certificates", "ix_ssh_certificates_serial"):
op.create_index(
"ix_ssh_certificates_serial",
"ssh_certificates",
["serial"],
unique=False,
)
def downgrade():
conn = op.get_bind()
# Remove the composite constraint
if _constraint_exists(conn, "ssh_certificates", "uq_ssh_certificates_ca_serial"):
op.drop_constraint(
"uq_ssh_certificates_ca_serial",
"ssh_certificates",
type_="unique",
)
# Restore the old non-unique index (best effort — data may have duplicates)
if not _index_exists(conn, "ssh_certificates", "ix_ssh_certificates_serial"):
op.create_index(
"ix_ssh_certificates_serial",
"ssh_certificates",
["serial"],
unique=False,
)
@@ -0,0 +1,69 @@
"""Add per-org ZeroTier credentials to organizations table.
Revision ID: 028_org_zerotier_config
Revises: 026_schema_cleanup
Create Date: 2026-03-25
Adds three nullable columns to `organizations`:
- zt_api_token VARCHAR(512) API token (Central) or authtoken.secret (controller)
- zt_api_url VARCHAR(512) base URL of the controller / Central API
- zt_api_mode VARCHAR(32) "central" | "controller"
When these are NULL the server-level ZEROTIER_API_* env vars are used instead,
so existing deployments are fully backwards-compatible with no data migration needed.
"""
from alembic import op
import sqlalchemy as sa
revision = "028_org_zerotier_config"
down_revision = "027_fix_cert_serial_uniqueness"
branch_labels = None
depends_on = None
def _col_exists(conn, table: str, column: str) -> bool:
row = conn.execute(
sa.text(
"SELECT 1 FROM information_schema.columns "
"WHERE table_name = :t AND column_name = :c"
),
{"t": table, "c": column},
).first()
return row is not None
def upgrade():
conn = op.get_bind()
if not _col_exists(conn, "organizations", "zt_api_token"):
op.add_column(
"organizations",
sa.Column("zt_api_token", sa.String(512), nullable=True),
)
if not _col_exists(conn, "organizations", "zt_api_url"):
op.add_column(
"organizations",
sa.Column("zt_api_url", sa.String(512), nullable=True),
)
if not _col_exists(conn, "organizations", "zt_api_mode"):
op.add_column(
"organizations",
sa.Column("zt_api_mode", sa.String(32), nullable=True),
)
def downgrade():
conn = op.get_bind()
if _col_exists(conn, "organizations", "zt_api_mode"):
op.drop_column("organizations", "zt_api_mode")
if _col_exists(conn, "organizations", "zt_api_url"):
op.drop_column("organizations", "zt_api_url")
if _col_exists(conn, "organizations", "zt_api_token"):
op.drop_column("organizations", "zt_api_token")
@@ -0,0 +1,30 @@
"""add_cert_token_to_ssh_certificates
Revision ID: 3de11c5dc2d5
Revises: 019_audit_varchar
Create Date: 2026-03-06 16:04:33.561099
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '3de11c5dc2d5'
down_revision = '019_audit_varchar'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('ssh_certificates', sa.Column('cert_token', sa.String(length=64), nullable=True))
op.create_index(op.f('ix_ssh_certificates_cert_token'), 'ssh_certificates', ['cert_token'], unique=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_ssh_certificates_cert_token'), table_name='ssh_certificates')
op.drop_column('ssh_certificates', 'cert_token')
# ### end Alembic commands ###
@@ -0,0 +1,34 @@
"""Add can_sudo column to departments table.
Revision ID: 002_add_can_sudo_to_departments
Revises: 001_add_org_api_keys
Create Date: 2026-03-07 23:40:30.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '002_add_can_sudo_to_departments'
down_revision = '001_add_org_api_keys'
branch_labels = None
depends_on = None
def upgrade():
# Add can_sudo column to departments table
op.add_column('departments',
sa.Column('can_sudo', sa.Boolean(), nullable=False, server_default='false'))
# Create index for performance
op.create_index('idx_dept_can_sudo', 'departments',
['organization_id', 'can_sudo'])
def downgrade():
# Drop index
op.drop_index('idx_dept_can_sudo', table_name='departments')
# Drop column
op.drop_column('departments', 'can_sudo')
@@ -0,0 +1,56 @@
"""Add organization_api_keys table for API key management.
Revision ID: 001_add_org_api_keys
Revises: 3de11c5dc2d5
Create Date: 2026-03-07 23:40:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '001_add_org_api_keys'
down_revision = '3de11c5dc2d5'
branch_labels = None
depends_on = None
def upgrade():
# Create organization_api_keys table
op.create_table(
'organization_api_keys',
sa.Column('id', sa.String(36), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('organization_id', sa.String(36), nullable=False),
sa.Column('name', sa.String(255), nullable=False),
sa.Column('key_hash', sa.String(255), nullable=False),
sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('is_revoked', sa.Boolean(), nullable=False, server_default='false'),
sa.Column('revoked_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('revoke_reason', sa.String(255), nullable=True),
sa.Column('description', sa.Text(), nullable=True),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('key_hash'),
)
# Create indexes for performance
op.create_index('idx_org_api_key_org_active', 'organization_api_keys',
['organization_id', 'is_revoked'])
op.create_index('idx_api_key_last_used', 'organization_api_keys',
['last_used_at'])
op.create_index('idx_org_api_key_org_id', 'organization_api_keys',
['organization_id'])
def downgrade():
# Drop indexes
op.drop_index('idx_org_api_key_org_id', table_name='organization_api_keys')
op.drop_index('idx_api_key_last_used', table_name='organization_api_keys')
op.drop_index('idx_org_api_key_org_active', table_name='organization_api_keys')
# Drop table
op.drop_table('organization_api_keys')
+1 -1
View File
@@ -58,7 +58,7 @@ if os.path.exists(env_file):
# Import after path setup
from gatehouse_app import create_app
from gatehouse_app.services.external_auth_service import ExternalAuthService, ExternalAuthError
from gatehouse_app.services.external_auth import ExternalAuthService, ExternalAuthError
def _microsoft_defaults() -> dict:
+7
View File
@@ -1,4 +1,11 @@
"""Initialize database script."""
import sys
import os
import time
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from gatehouse_app import create_app
from gatehouse_app.extensions import db
from sqlalchemy import text