Merge branch 'main' into v1.01/stable

This commit is contained in:
2026-04-26 22:54:54 +08:00
committed by GitHub
135 changed files with 13567 additions and 307 deletions
+144
View File
@@ -0,0 +1,144 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
Pipfile.lock
# PEP 582
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# IDEs
.vscode/
.idea/
*.swp
*.swo
*~
# OS
.DS_Store
Thumbs.db
# Project specific
*.db
flask_session/
# Opencode files and folders
.opencode/
.swarm/
SWARM_PLAN.*
+40
View File
@@ -0,0 +1,40 @@
FROM python:3.11-slim as builder
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
libpq-dev \
&& rm -rf /var/lib/apt/lists/*
RUN python -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
WORKDIR /app
COPY requirements/base.txt requirements/base.txt
COPY requirements/production.txt requirements/production.txt
RUN pip install --no-cache-dir --upgrade pip wheel && \
pip install --no-cache-dir -r requirements/production.txt
FROM python:3.11-slim
RUN apt-get update && apt-get install -y --no-install-recommends \
libpq5 \
&& rm -rf /var/lib/apt/lists/*
RUN groupadd --gid 1000 appgroup && \
useradd --uid 1000 --gid appgroup --shell /bin/bash --create-home appuser
COPY --from=builder /opt/venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
WORKDIR /app
COPY --chown=appuser:appgroup . .
RUN mkdir -p /app/logs && chown -R appuser:appgroup /app/logs
USER appuser
HEALTHCHECK --interval=60s --timeout=10s --start-period=10s --retries=3 \
CMD pgrep -f "job_runner" || exit 1
CMD ["python", "scripts/job_runner.py"]
+50
View File
@@ -154,6 +154,7 @@ Copy `.env.example` to `.env` and configure:
- `POST /api/v1/auth/logout` - Logout
- `GET /api/v1/auth/me` - Get current user
- `GET /api/v1/auth/sessions` - Get user sessions
- `POST /api/v1/auth/sessions/refresh` - Extend session idle window
- `DELETE /api/v1/auth/sessions/:id` - Revoke session
### Users
@@ -174,6 +175,9 @@ Copy `.env.example` to `.env` and configure:
- `PATCH /api/v1/organizations/:id/members/:userId/role` - Update role
### Contact (Public — No Auth Required)
- `POST /api/v1/contact` - Submit a contact enquiry (demo request, sales enquiry, general, or support). Rate limited to 5 requests per IP per hour. Sends an email to info@secuird.tech.
### Health
- `GET /api/health` - Health check
@@ -261,6 +265,52 @@ gunicorn -w 4 -b 0.0.0.0:8000 wsgi:app
- Request ID tracking for audit trails
## Session Management
Sessions are database-backed bearer tokens stored in PostgreSQL. Each session is created at login and validated on every authenticated request via the `login_required` decorator.
### Sliding Timeout
Sessions use a **sliding window** model with two independent limits:
| Timeout | Default | Env Var | Behaviour |
|---------|---------|---------|-----------|
| **Idle** | 15 min | `SESSION_IDLE_TIMEOUT` | Extends automatically on every request. If no request is made within this window the session expires. |
| **Absolute** | 8 h | `SESSION_ABSOLUTE_TIMEOUT` | Hard cap measured from session creation. Activity cannot extend a session beyond this point. |
Every authenticated request resets the idle clock by calling `Session.refresh()`, which sets `expires_at = now + idle_timeout` — but never past `created_at + absolute_timeout`. This means:
- An active user stays logged in indefinitely **up to** the absolute cap.
- An idle user is logged out after the idle timeout.
- No session can survive longer than the absolute timeout regardless of activity.
### Configuration
Override defaults via environment variables:
```bash
SESSION_IDLE_TIMEOUT=900 # seconds (15 min)
SESSION_ABSOLUTE_TIMEOUT=28800 # seconds (8 h)
```
### Cleanup
Expired sessions are soft-marked as `EXPIRED` by the `cleanup_sessions` job. Run it periodically via the job runner:
```bash
python manage.py cleanup_sessions
# Or via the job runner (Docker):
JOB_NAME=cleanup_sessions JOB_INTERVAL_SECONDS=300
```
### Session Endpoints
- `GET /api/v1/auth/sessions` — List active sessions for the current user
- `POST /api/v1/auth/sessions/refresh` — Extend the current session's idle window (returns new `expires_at`)
- `DELETE /api/v1/auth/sessions/:id` — Revoke a specific session
# Boostrap db
python manage.py db upgrade
+163 -7
View File
@@ -49,10 +49,96 @@ class MyServer(BaseHTTPRequestHandler):
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()
self.wfile.write(bytes("<html><head><title>OIDC Workflow Tool</title></head>", "utf-8"))
self.wfile.write(bytes("<body><p>The token has been received</p>", "utf-8"))
self.wfile.write(bytes("<p>You may now close this window.</p>", "utf-8"))
self.wfile.write(bytes("</body></html>", "utf-8"))
html_content = """<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Authentication Successful - Gatehouse</title>
<!-- Best-effort CSS load from primary site -->
<link rel="stylesheet" href="{SIGN_URL}/static/css/main.css">
<style>
* {{
margin: 0;
padding: 0;
box-sizing: border-box;
}}
body {{
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
background-color: #f0f4f8;
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
}}
.card {{
background: white;
border-radius: 12px;
box-shadow: 0 4px 20px rgba(0, 0, 0, 0.08);
padding: 48px 40px;
text-align: center;
max-width: 400px;
width: 90%;
}}
.checkmark {{
width: 64px;
height: 64px;
background: #10b981;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
margin: 0 auto 24px;
}}
.checkmark svg {{
width: 32px;
height: 32px;
stroke: white;
stroke-width: 3;
fill: none;
}}
h1 {{
color: #1f2937;
font-size: 24px;
font-weight: 600;
margin-bottom: 12px;
}}
p {{
color: #6b7280;
font-size: 16px;
line-height: 1.5;
}}
.fallback {{
margin-top: 24px;
padding-top: 24px;
border-top: 1px solid #e5e7eb;
color: #9ca3af;
font-size: 14px;
}}
</style>
</head>
<body>
<div class="card">
<div class="checkmark">
<svg viewBox="0 0 24 24">
<polyline points="20 6 9 17 4 12"></polyline>
</svg>
</div>
<h1>Authentication Complete</h1>
<p>You can now return to the terminal.</p>
<p class="fallback">If this window doesn't close automatically, you can close it manually.</p>
</div>
<script>
setTimeout(function() {{
window.close();
if (window.innerHeight > 0) {{
document.querySelector('.fallback').textContent = 'Window refused to close. You may close this tab manually.';
}}
}}, 2000);
</script>
</body>
</html>""".format(SIGN_URL=SIGN_URL)
self.wfile.write(bytes(html_content, "utf-8"))
parsed_url = urlparse(self.path)
query_data = dict(parse_qsl(parsed_url.query))
@@ -253,7 +339,7 @@ def fetch_my_principals():
return principal_names
def request_certificate():
def request_certificate(org_id=None):
CERT_ID = os.getenv("CERT_ID") or get_activated_ssh_key()
principals = fetch_my_principals()
@@ -272,9 +358,13 @@ def request_certificate():
'principals': principals,
}
# Add organization_id if specified
if org_id:
payload['organization_id'] = org_id
try:
response = requests.post(f"{SIGN_URL}/api/v1/ssh/sign", json=payload, headers=headers)
if response.status_code == 201:
json_result = response.json().get('data', response.json())
with open(CERT_FILE_PATH, 'w') as f:
@@ -287,14 +377,41 @@ def request_certificate():
logger.info(f"Certificate signed successfully, located at {CERT_FILE_PATH}")
logger.info(f"Valid for principals: {', '.join(json_result.get('principals', principals))}")
# Show which org issued the cert
org_name = json_result.get('organization_name', 'Unknown')
logger.info(f"Issued by organization: {org_name}")
logger.info("You can login to your destination server with the following command")
logger.info(f"\tssh user@server -o CertificateFile={CERT_FILE_PATH}")
elif response.status_code == 400:
error_data = response.json()
if error_data.get('error', {}).get('type') == 'MULTIPLE_ORGS_AMBIGUOUS':
logger.error("You are a member of multiple organizations. Please specify one with --org-id")
logger.error("\nYour organizations:")
for org in error_data.get('error', {}).get('details', {}).get('organizations', []):
logger.error(f" - {org['name']} (ID: {org['id']}, Role: {org['role']})")
logger.error("\nRun: secuird --list-orgs to see all your organizations")
logger.error("Then run: secuird -r --org-id <organization_id>")
else:
logger.error(f"Error: {error_data.get('message', 'Unknown error')}")
exit(1)
elif response.status_code == 403:
error_data = response.json()
logger.error(f"Permission denied: {error_data.get('message', 'Unknown error')}")
exit(1)
else:
logger.error("Error in response from server")
logger.error(f"Status code: {response.status_code}")
logger.error(f"Response text: {response.text}")
exit(1)
except Exception as e:
logger.error(f"Error during certificate signing: {e}")
exit(1)
def generate_and_sign_challenge(ssh_key_file, key_id):
"""Fetch a challenge from the server, sign it with the SSH key, and submit the signature."""
@@ -321,11 +438,13 @@ def generate_and_sign_challenge(ssh_key_file, key_id):
with open(CHALLENGE_FILE_PATH, 'w') as f:
f.write(challenge_text)
os.chmod(CHALLENGE_FILE_PATH, 0o600)
subprocess.run(
["ssh-keygen", "-Y", "sign", "-f", ssh_key_file, "-n", "file", CHALLENGE_FILE_PATH],
check=True,
)
os.chmod(CHALLENGE_SIG_FILE_PATH, 0o600)
with open(CHALLENGE_SIG_FILE_PATH, 'rb') as f:
signature = base64.b64encode(f.read()).decode('utf-8')
@@ -399,6 +518,38 @@ def remove_ssh_key(key_id=None):
else:
logger.error(f"Failed to remove key {k['id']}: {del_response.status_code} - {del_response.text}")
def list_organizations():
"""List all organizations the user is a member of."""
response = requests.get(
f"{SIGN_URL}/api/v1/users/me/organizations/simple",
headers=auth_headers()
)
if response.status_code != 200:
logger.error(f"Failed to list organizations: {response.status_code} - {response.text}")
exit(1)
data = response.json().get('data', {})
orgs = data.get('organizations', [])
if not orgs:
print("You are not a member of any organizations.")
return
print("\nYour Organizations:")
print("-" * 80)
for org in orgs:
ca_status = []
if org.get('has_user_ca'):
ca_status.append("User CA ✓")
if org.get('has_host_ca'):
ca_status.append("Host CA ✓")
ca_str = f" ({', '.join(ca_status)})" if ca_status else " (No CAs configured)"
print(f" ID: {org['id']}")
print(f" Name: {org['name']}{ca_str}")
print(f" Role: {org['role']}")
print("-" * 80)
def add_ssh_key(ssh_key_file):
"""Add an SSH key to the server and auto-verify it."""
@@ -540,6 +691,11 @@ if __name__ == "__main__":
remove_ssh_key(args.remove_key if args.remove_key else None)
exit(0)
if args.list_orgs:
request_token()
list_organizations()
exit(0)
if args.list_keys:
request_token()
response = requests.get(f"{SIGN_URL}/api/v1/ssh/keys", headers=auth_headers())
@@ -580,5 +736,5 @@ if __name__ == "__main__":
if args.force:
logger.info("Forcing renewal of certificate")
if args.force or checkCert() == 1:
request_certificate()
request_certificate(org_id=args.org_id)
exit(0)
+5
View File
@@ -48,6 +48,10 @@ class BaseConfig:
seconds=int(os.getenv("MAX_SESSION_DURATION", "86400"))
)
# Session timeout policy (seconds)
SESSION_IDLE_TIMEOUT = int(os.getenv("SESSION_IDLE_TIMEOUT", "900"))
SESSION_ABSOLUTE_TIMEOUT = int(os.getenv("SESSION_ABSOLUTE_TIMEOUT", "28800"))
# CORS
CORS_ORIGINS = os.getenv(
"CORS_ORIGINS",
@@ -83,6 +87,7 @@ class BaseConfig:
RATELIMIT_AUTH_TOTP_VERIFY = os.getenv("RATELIMIT_AUTH_TOTP_VERIFY", "20 per minute; 100 per hour")
RATELIMIT_AUTH_FORGOT_PASSWORD = os.getenv("RATELIMIT_AUTH_FORGOT_PASSWORD", "5 per minute; 20 per hour")
RATELIMIT_AUTH_RESET_PASSWORD = os.getenv("RATELIMIT_AUTH_RESET_PASSWORD", "10 per minute; 30 per hour")
RATELIMIT_AUTH_RESEND_VERIFICATION = os.getenv("RATELIMIT_AUTH_RESEND_VERIFICATION", "5 per minute; 20 per hour")
# Logging
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
+7 -1
View File
@@ -30,7 +30,13 @@ class TestingConfig(BaseConfig):
# Use different Redis DB for testing
REDIS_URL = "redis://localhost:6379/15"
# Use filesystem for sessions in testing
SESSION_TYPE = "filesystem"
SESSION_FILE_DIR = "/tmp/flask_session_test"
# Override cookie domain so test_client on localhost can send cookies
SESSION_COOKIE_DOMAIN = None
WEBAUTHN_RP_ID = "localhost"
WEBAUTHN_ORIGIN = "http://localhost:8080"
FRONTEND_URL = "http://localhost:8080"
+36
View File
@@ -78,6 +78,42 @@ services:
timeout: 10s
retries: 3
zerotier-reconciler:
build:
context: .
dockerfile: Dockerfile.job
env_file:
- .env
environment:
- JOB_NAME=zerotier_reconciliation
- JOB_INTERVAL_SECONDS=${ZEROTIER_RECONCILE_INTERVAL:-120}
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
networks:
- authy2-network
restart: unless-stopped
mfa-compliance:
build:
context: .
dockerfile: Dockerfile.job
env_file:
- .env
environment:
- JOB_NAME=mfa_compliance
- JOB_INTERVAL_SECONDS=${MFA_COMPLIANCE_INTERVAL:-3600}
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
networks:
- authy2-network
restart: unless-stopped
networks:
authy2-network:
driver: bridge
+3 -1
View File
@@ -5,7 +5,9 @@ from flask import Blueprint
api_v1_bp = Blueprint("api_v1", __name__)
# Import route modules to register them
from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth, departments, principals, ssh, zerotier, sudo, oidc
from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth, departments, principals, ssh, zerotier, sudo, oidc, contact
from gatehouse_app.api.v1 import superadmin
api_v1_bp.register_blueprint(ssh.ssh_bp)
api_v1_bp.register_blueprint(superadmin.superadmin_bp)
+29 -2
View File
@@ -116,7 +116,7 @@ def login():
remember_me = data.get("remember_me", False)
policy_result = MfaPolicyService.after_primary_auth_success(user, remember_me)
duration = 2592000 if remember_me else 86400
duration = current_app.config.get("SESSION_ABSOLUTE_TIMEOUT", 28800) if remember_me else None
is_compliance_only = policy_result.create_compliance_only_session
user_session = AuthService.create_session(user, duration_seconds=duration, is_compliance_only=is_compliance_only)
@@ -227,6 +227,32 @@ def revoke_session(session_id):
return api_response(message="Session revoked successfully")
@api_v1_bp.route("/auth/sessions/refresh", methods=["POST"])
@login_required
def refresh_session():
"""Extend the current session's idle window.
The server already refreshes the session on every authenticated
request, but this endpoint exists so the frontend can proactively
keep a session alive (e.g. a heartbeat while the user is reading
a long page with no API calls).
Returns the new ``expires_at`` so the frontend can display a
countdown or warning before the absolute cap.
"""
session = g.current_session
session.refresh()
return api_response(
data={
"expires_at": session.expires_at.isoformat() + "Z"
if session.expires_at.isoformat()[-1] != "Z"
else session.expires_at.isoformat(),
},
message="Session refreshed",
)
@api_v1_bp.route("/auth/token", methods=["GET"])
@login_required
def get_token():
@@ -246,7 +272,8 @@ def get_token():
parsed_redirect = urlparse(redirect_url)
redirect_origin = f"{parsed_redirect.scheme}://{parsed_redirect.netloc}"
if redirect_origin not in allowed_origins:
wildcard = "*" in allowed_origins
if not wildcard and redirect_origin not in allowed_origins:
return api_response(success=False, message="Redirect URL is not allowed.", status=400, error_type="INVALID_REDIRECT")
sep = "&" if "?" in redirect_url else "?"
+32 -1
View File
@@ -1,8 +1,9 @@
"""Password reset, email verification, and account activation endpoints."""
import logging
from flask import request, current_app
from flask import request, current_app, g
from gatehouse_app.api.v1 import api_v1_bp
from gatehouse_app.extensions import limiter
from gatehouse_app.utils.decorators import login_required
from gatehouse_app.utils.response import api_response
from gatehouse_app.services.auth_service import AuthService
from gatehouse_app.services.notification_service import NotificationService
@@ -151,6 +152,36 @@ def resend_verification():
return api_response(data={}, message="If an account exists for this email and is not yet verified, you will receive a verification link shortly.")
@api_v1_bp.route("/auth/me/resend-verification", methods=["POST"])
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_RESEND_VERIFICATION"])
@login_required
def resend_verification_authenticated():
from gatehouse_app.models import EmailVerificationToken
user = g.current_user
if not user.email_verified:
try:
verify_token = EmailVerificationToken.generate(user_id=user.id)
app_url = current_app.config.get("APP_URL", "http://localhost:8080")
verify_link = f"{app_url}/verify-email?token={verify_token.token}"
email_body = build_email_verification_html(
user_name=user.full_name or user.email,
verify_link=verify_link,
expiry_hours=24,
)
NotificationService._send_email_async(
to_address=user.email,
subject="Verify your Secuird email address",
body=f"Verify your Secuird email: {verify_link}",
html_body=email_body,
)
_logger.info(f"Verification email sent for authenticated user {user.id}")
except Exception as exc:
_logger.exception(f"Error sending verification email: {exc}")
return api_response(data={}, message="If your email is not yet verified, you will receive a verification link shortly.")
@api_v1_bp.route("/auth/activate", methods=["POST"])
def activate_account():
import secrets
+68
View File
@@ -0,0 +1,68 @@
"""Contact form endpoint for website enquiries."""
import logging
from flask import request, current_app
from marshmallow import ValidationError
from gatehouse_app.api.v1 import api_v1_bp
from gatehouse_app.extensions import limiter
from gatehouse_app.utils.response import api_response
from gatehouse_app.schemas.contact_schema import ContactSchema
from gatehouse_app.services.notification_service import NotificationService
from gatehouse_app.services.email_templates import build_contact_enquiry_html
logger = logging.getLogger(__name__)
# Hardcoded destination for all contact submissions
CONTACT_DESTINATION = "info@secuird.tech"
@api_v1_bp.route("/contact", methods=["POST"])
@limiter.limit("5 per hour")
def contact():
"""Handle contact form submissions from the marketing website.
Accepts: email, name, company, enquiry_type, message, interest_area, _hp.
Sends an email to info@secuird.tech with the enquiry details.
Silently discards submissions where the honeypot field (_hp) is filled.
"""
try:
schema = ContactSchema()
data = schema.load(request.get_json() or {})
except ValidationError as err:
return api_response(
success=False,
message="Invalid request data",
status=400,
error_type="VALIDATION_ERROR",
error_details=err.messages,
)
# Honeypot check — silently succeed without sending
if data.get("_hp"):
logger.info(f"[Contact] Honeypot triggered, ip={request.remote_addr}")
return api_response(message="Thank you for your message!")
enquiry_type = data.get("enquiry_type") or "general"
email = data.get("email") or ""
# Build and send email
html_body = build_contact_enquiry_html(
enquiry_type=enquiry_type,
submitter_email=email,
name=data.get("name"),
company=data.get("company"),
interest_area=data.get("interest_area"),
message=data.get("message"),
)
NotificationService._send_email_async(
to_address=CONTACT_DESTINATION,
subject=f"Secuird Website: {enquiry_type.replace('_', ' ').title()} from {email}",
body=f"New contact enquiry ({enquiry_type}) from {email}",
html_body=html_body,
)
logger.info(f"[Contact] enquiry_type={enquiry_type} ip={request.remote_addr}")
return api_response(message="Thank you for your message!")
+4 -6
View File
@@ -21,7 +21,7 @@ def admin_list_app_providers():
return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN")
PROVIDERS = [{"id": "google", "name": "Google"}, {"id": "github", "name": "GitHub"}, {"id": "microsoft", "name": "Microsoft"}]
db_configs = {c.provider_type: c for c in ApplicationProviderConfig.query.all()}
db_configs = {c.provider_type: c for c in ApplicationProviderConfig.query.filter_by(deleted_at=None).all()}
result = []
for p in PROVIDERS:
@@ -64,7 +64,7 @@ def admin_configure_app_provider(provider: str):
if not client_id:
return api_response(success=False, message="client_id is required", status=400, error_type="VALIDATION_ERROR")
cfg = ApplicationProviderConfig.query.filter_by(provider_type=provider).first()
cfg = ApplicationProviderConfig.query.filter_by(provider_type=provider, deleted_at=None).first()
if cfg:
cfg.client_id = client_id
if client_secret:
@@ -90,7 +90,6 @@ def admin_delete_app_provider(provider: str):
from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig
from gatehouse_app.models import OrganizationMember
from gatehouse_app.utils.constants import OrganizationRole
from gatehouse_app.extensions import db
admin_memberships = OrganizationMember.query.filter(
OrganizationMember.user_id == g.current_user.id,
@@ -100,10 +99,9 @@ def admin_delete_app_provider(provider: str):
if not admin_memberships:
return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN")
cfg = ApplicationProviderConfig.query.filter_by(provider_type=provider).first()
cfg = ApplicationProviderConfig.query.filter_by(provider_type=provider, deleted_at=None).first()
if not cfg:
return api_response(success=False, message=f"Provider '{provider}' is not configured", status=404, error_type="NOT_FOUND")
db.session.delete(cfg)
db.session.commit()
cfg.delete()
return api_response(message=f"{provider.capitalize()} OAuth provider configuration removed")
+5 -4
View File
@@ -174,6 +174,7 @@ def select_organization():
auth_method = AuthenticationMethod.query.filter_by(
method_type=state_record.provider_type,
deleted_at=None,
).order_by(AuthenticationMethod.created_at.desc()).first()
if not auth_method:
@@ -181,16 +182,16 @@ def select_organization():
user = auth_method.user
org = Organization.query.get(organization_id)
org = Organization.query.filter_by(id=organization_id, deleted_at=None).first()
if not org:
return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND")
member = OrganizationMember.query.filter_by(user_id=user.id, organization_id=organization_id).first()
member = OrganizationMember.query.filter_by(user_id=user.id, organization_id=organization_id, deleted_at=None).first()
if not member:
return api_response(success=False, message="You are not a member of this organization", status=403, error_type="FORBIDDEN")
from gatehouse_app.services.session_service import SessionService
session = SessionService.create_session(user=user, organization_id=organization_id)
from gatehouse_app.services.auth_service import AuthService
session = AuthService.create_session(user=user)
state_record.mark_used()
provider_type_val = state_record.provider_type.value if isinstance(state_record.provider_type, _AuthMethodType) else state_record.provider_type
@@ -14,13 +14,13 @@ from gatehouse_app.api.v1.external_auth._helpers import get_provider_type, _get_
def list_providers():
from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig
app_configs = {c.provider_type.lower(): c for c in ApplicationProviderConfig.query.filter_by(is_enabled=True).all()}
app_configs = {c.provider_type.lower(): c for c in ApplicationProviderConfig.query.filter_by(is_enabled=True, deleted_at=None).all()}
user_orgs = g.current_user.get_organizations()
org_configs = {}
if user_orgs:
organization_id = user_orgs[0].id
org_level = ExternalProviderConfig.query.filter_by(organization_id=organization_id).all()
org_level = ExternalProviderConfig.query.filter_by(organization_id=organization_id, deleted_at=None).all()
org_configs = {c.provider_type.lower(): c for c in org_level}
def provider_info(provider_id, name):
@@ -50,11 +50,11 @@ def get_provider_config(provider: str):
return api_response(success=False, message="No organizations found for user", status=400, error_type="BAD_REQUEST")
organization_id = user_orgs[0].id
member = OrganizationMember.query.filter_by(user_id=g.current_user.id, organization_id=organization_id).first()
member = OrganizationMember.query.filter_by(user_id=g.current_user.id, organization_id=organization_id, deleted_at=None).first()
if not member or member.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]:
return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN")
config = ExternalProviderConfig.query.filter_by(organization_id=organization_id, provider_type=provider_type.value).first()
config = ExternalProviderConfig.query.filter_by(organization_id=organization_id, provider_type=provider_type.value, deleted_at=None).first()
if not config:
return api_response(success=False, message=f"{provider.title()} OAuth is not configured", status=404, error_type="NOT_FOUND")
@@ -74,7 +74,7 @@ def create_or_update_provider_config(provider: str):
return api_response(success=False, message="No organizations found for user", status=400, error_type="BAD_REQUEST")
organization_id = user_orgs[0].id
member = OrganizationMember.query.filter_by(user_id=g.current_user.id, organization_id=organization_id).first()
member = OrganizationMember.query.filter_by(user_id=g.current_user.id, organization_id=organization_id, deleted_at=None).first()
if not member or member.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]:
return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN")
@@ -85,7 +85,7 @@ def create_or_update_provider_config(provider: str):
if not client_id:
return api_response(success=False, message="client_id is required", status=400, error_type="VALIDATION_ERROR")
config = ExternalProviderConfig.query.filter_by(organization_id=organization_id, provider_type=provider_type.value).first()
config = ExternalProviderConfig.query.filter_by(organization_id=organization_id, provider_type=provider_type.value, deleted_at=None).first()
is_new = config is None
if config:
@@ -137,11 +137,11 @@ def delete_provider_config(provider: str):
return api_response(success=False, message="No organizations found for user", status=400, error_type="BAD_REQUEST")
organization_id = user_orgs[0].id
member = OrganizationMember.query.filter_by(user_id=g.current_user.id, organization_id=organization_id).first()
member = OrganizationMember.query.filter_by(user_id=g.current_user.id, organization_id=organization_id, deleted_at=None).first()
if not member or member.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]:
return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN")
config = ExternalProviderConfig.query.filter_by(organization_id=organization_id, provider_type=provider_type.value).first()
config = ExternalProviderConfig.query.filter_by(organization_id=organization_id, provider_type=provider_type.value, deleted_at=None).first()
if not config:
return api_response(success=False, message=f"{provider.title()} OAuth is not configured", status=404, error_type="NOT_FOUND")
+2 -2
View File
@@ -819,9 +819,9 @@ def oidc_register():
org_id = data.get("organization_id")
if org_id:
organization = Organization.query.get(org_id)
organization = Organization.query.filter_by(id=org_id, deleted_at=None).first()
else:
organization = Organization.query.filter_by(is_active=True).first()
organization = Organization.query.filter_by(is_active=True, deleted_at=None).first()
if not organization:
organization = Organization(
@@ -1,4 +1,5 @@
"""Organization core CRUD endpoints."""
import logging
from flask import g, request
from marshmallow import ValidationError
from gatehouse_app.api.v1 import api_v1_bp
@@ -12,6 +13,15 @@ from gatehouse_app.services.organization_service import OrganizationService
@login_required
@full_access_required
def create_organization():
try:
org_count = OrganizationService.get_user_org_count(g.current_user.id)
if org_count is not None and org_count >= 10:
return api_response(success=False, message="You cannot belong to more than 10 organizations", status=400, error_type="ORG_LIMIT_REACHED")
except Exception as e:
logger = logging.getLogger(__name__)
logger.warning(f"[Org] Failed to check org count for user {g.current_user.id}: {e}")
# Fail open to avoid blocking legitimate users when the count service is unavailable
try:
schema = OrganizationCreateSchema()
data = schema.load(request.json)
+16 -2
View File
@@ -8,7 +8,8 @@ from gatehouse_app.services.notification_service import NotificationService
from gatehouse_app.services.auth_service import AuthService
from gatehouse_app.services.organization_service import OrganizationService
from gatehouse_app.services.email_templates import build_org_invite_html
from gatehouse_app.utils.constants import OrganizationRole
from gatehouse_app.utils.constants import AuditAction, OrganizationRole
from gatehouse_app.services.audit_service import AuditService
@api_v1_bp.route("/organizations/<org_id>/invites", methods=["POST"])
@@ -56,6 +57,19 @@ def create_org_invite(org_id):
logging.getLogger(__name__).info(f"[INVITE] Email queued for {email}")
email_sent = True # async — assume queued successfully
AuditService.log_action(
action=AuditAction.ORG_INVITE_SENT,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="org_invite",
resource_id=invite.id,
metadata={
"invited_email": email,
"role": role,
},
description=f"Invitation sent to {email} with role {role}",
)
response_data = {
"invite": {
"id": invite.id,
@@ -173,7 +187,7 @@ def accept_invite(token):
if session_user.email.lower() != invite.email.lower():
return api_response(
success=False,
message="This invite was sent to a different email address.",
message="You are already logged in and this invite was sent to a different email address.",
status=403,
error_type="EMAIL_MISMATCH",
)
@@ -56,7 +56,10 @@ def add_organization_member(org_id):
@full_access_required
def remove_organization_member(org_id, user_id):
org = OrganizationService.get_organization_by_id(org_id)
OrganizationService.remove_member(org=org, user_id=user_id, remover_id=g.current_user.id)
try:
OrganizationService.remove_member(org=org, user_id=user_id, remover_id=g.current_user.id)
except ValueError as e:
return api_response(success=False, message=str(e), status=403, error_type="OWNER_PROTECTION")
return api_response(message="Member removed successfully")
@@ -155,7 +158,7 @@ def send_mfa_reminder(org_id, user_id):
if not user:
return api_response(success=False, message="User not found", status=404)
compliance = MfaPolicyCompliance.query.filter_by(user_id=user_id, organization_id=org_id).first()
compliance = MfaPolicyCompliance.query.filter_by(user_id=user_id, organization_id=org_id, deleted_at=None).first()
policy = OrganizationSecurityPolicy.query.filter_by(organization_id=org_id).first()
if compliance and policy and compliance.deadline_at:
+12 -5
View File
@@ -11,16 +11,23 @@ ssh_ca_service = SSHCASigningService()
_logger = logging.getLogger(__name__)
def _get_org_ca_for_user(user, ca_type: str = "user"):
def _get_org_ca_for_user(user, ca_type: str = "user", organization_id=None):
try:
from gatehouse_app.models.ssh_ca.ca import CA, CaType
org_ids = [m.organization_id for m in user.organization_memberships]
if organization_id:
org_ids = [organization_id]
else:
org_ids = [m.organization_id for m in user.get_active_memberships()]
if not org_ids:
return None
return CA.query.filter(
CA.organization_id.in_(org_ids),
CA.ca_type == CaType(ca_type),
CA.is_active == True, # noqa: E712
CA.is_active == True,
CA.deleted_at.is_(None),
).first()
except Exception:
return None
@@ -34,7 +41,7 @@ def _get_or_create_system_ca():
import os
try:
existing = CA.query.filter_by(name="system-config-ca").first()
existing = CA.query.filter_by(name="system-config-ca", deleted_at=None).first()
if existing:
return existing
@@ -60,7 +67,7 @@ def _get_or_create_system_ca():
fingerprint = compute_ssh_fingerprint(pub_key)
existing_by_fp = CA.query.filter_by(fingerprint=fingerprint).first()
existing_by_fp = CA.query.filter_by(fingerprint=fingerprint, deleted_at=None).first()
if existing_by_fp:
return existing_by_fp
+73 -22
View File
@@ -14,6 +14,12 @@ from gatehouse_app.utils.decorators import login_required
from gatehouse_app.utils.response import api_response
def _validate_uuid(uuid_str: str) -> bool:
"""Validate UUID format."""
import re
return bool(re.match(r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', uuid_str, re.I))
@ssh_bp.route('/dept-cert-policy', methods=['GET'])
@login_required
def get_my_dept_cert_policy():
@@ -60,6 +66,7 @@ def sign_certificate():
cert_type = data.get('cert_type', 'user')
key_id = data.get('key_id') or data.get('cert_id')
expiry_hours = data.get('expiry_hours')
requested_org_id = data.get('organization_id')
AuditLog.log(
action=AuditAction.SSH_CERT_REQUESTED,
@@ -67,22 +74,63 @@ def sign_certificate():
description=(f'{user.email} requested a certificate' + (f' for principals: {", ".join(requested_principals)}' if requested_principals else '')),
)
# Validate organization_id if provided
if requested_org_id and not _validate_uuid(requested_org_id):
return api_response(success=False, message="Invalid organization_id format. Must be a valid UUID.", status=400, error_type="INVALID_ORG_ID")
# Get user's active organization memberships
memberships = OrganizationMember.query.filter_by(user_id=user_id, deleted_at=None).all()
active_memberships = [om for om in memberships if om.organization and om.organization.deleted_at is None]
if not active_memberships:
return api_response(success=False, message="You are not a member of any active organizations.", status=400, error_type="NO_ORG_MEMBERSHIPS")
# Select target organization
target_org = None
if requested_org_id:
# Check if user is member of the requested organization
target_membership = next((om for om in active_memberships if str(om.organization_id).lower() == requested_org_id.lower()), None)
if not target_membership:
return api_response(success=False, message="You are not a member of the specified organization.", status=403, error_type="NOT_ORG_MEMBER")
target_org = target_membership.organization
if not target_org or target_org.deleted_at is not None:
return api_response(success=False, message="The specified organization was not found or has been deleted.", status=404, error_type="ORG_NOT_FOUND")
else:
# No organization specified - use default logic for backward compatibility
if len(active_memberships) > 1:
org_names = [om.organization.name for om in active_memberships]
orgs_data = [
{
"id": m.organization_id,
"name": m.organization.name,
"role": m.role.value if hasattr(m.role, "value") else str(m.role)
}
for m in active_memberships
]
return api_response(
success=False,
message="You are a member of multiple organizations. Please specify organization_id.",
status=400,
error_type="MULTIPLE_ORGS_AMBIGUOUS",
error_details={"organizations": orgs_data}
)
target_org = active_memberships[0].organization
# Get allowed principals for the selected organization
allowed_principal_names = set()
memberships = OrganizationMember.query.filter_by(user_id=user_id).all()
for om in memberships:
org = om.organization
if not org or org.deleted_at is not None:
continue
role = om.role
target_membership = next((om for om in active_memberships if str(om.organization_id).lower() == str(target_org.id).lower()), None)
if target_membership:
role = target_membership.role
if role in (OrganizationRole.ADMIN, OrganizationRole.OWNER):
for p in Principal.query.filter_by(organization_id=org.id, deleted_at=None).all():
for p in Principal.query.filter_by(organization_id=target_org.id, deleted_at=None).all():
allowed_principal_names.add(p.name)
else:
for pm in PrincipalMembership.query.filter_by(user_id=user_id, deleted_at=None).all():
if pm.principal and pm.principal.organization_id == org.id and pm.principal.deleted_at is None:
if pm.principal and pm.principal.organization_id == target_org.id and pm.principal.deleted_at is None:
allowed_principal_names.add(pm.principal.name)
for dm in DepartmentMembership.query.filter_by(user_id=user_id, deleted_at=None).all():
if dm.department and dm.department.organization_id == org.id and dm.department.deleted_at is None:
if dm.department and dm.department.organization_id == target_org.id and dm.department.deleted_at is None:
for dp in DepartmentPrincipal.query.filter_by(department_id=dm.department_id, deleted_at=None).all():
if dp.principal and dp.principal.deleted_at is None:
allowed_principal_names.add(dp.principal.name)
@@ -114,7 +162,8 @@ def sign_certificate():
if not ssh_key.verified:
return api_response(success=False, message="SSH key is not verified. Verify it before requesting a certificate.", status=400, error_type="KEY_NOT_VERIFIED")
db_ca = _get_org_ca_for_user(user, ca_type=cert_type)
# Use the selected organization's ID for CA selection
db_ca = _get_org_ca_for_user(user, ca_type=cert_type, organization_id=target_org.id)
if db_ca is None:
return api_response(
success=False,
@@ -122,11 +171,7 @@ def sign_certificate():
status=503, error_type="CA_NOT_CONFIGURED",
)
is_org_admin = any(
om.role in (OrganizationRole.ADMIN, OrganizationRole.OWNER)
for om in memberships
if om.organization and om.organization.deleted_at is None
)
is_org_admin = target_membership.role in (OrganizationRole.ADMIN, OrganizationRole.OWNER) if target_membership else False
dept_policy = _get_merged_dept_cert_policy(user_id)
if dept_policy:
@@ -146,11 +191,7 @@ def sign_certificate():
else:
policy_extensions = None
org_slugs = sorted({
om.organization.slug for om in memberships
if om.organization and om.organization.deleted_at is None and getattr(om.organization, 'slug', None)
})
org_slug = org_slugs[0] if org_slugs else "unknown"
org_slug = getattr(target_org, 'slug', 'unknown')
full_name = getattr(user, 'full_name', None) or getattr(user, 'name', None) or "unknown"
cert_identity = f"{user.email} ({full_name}) [org:{org_slug}]"
@@ -185,12 +226,13 @@ def sign_certificate():
resource_type='SSHCertificate', resource_id=cert_record.id if cert_record else key_id,
ip_address=request.remote_addr,
description=f'Certificate serial={response.serial} issued for {user.email}; principals: {", ".join(principals)}',
extra_data={'serial': response.serial, 'key_id': cert_identity, 'principals': principals, 'ca_id': str(db_ca.id), 'ssh_key_id': str(key_id)},
extra_data={'serial': response.serial, 'key_id': cert_identity, 'principals': principals, 'ca_id': str(db_ca.id), 'ssh_key_id': str(key_id), 'organization_id': str(target_org.id), 'organization_name': target_org.name},
)
if cert_record:
CertificateAuditLog.log(
certificate_id=cert_record.id, action='issued', user_id=user_id,
organization_id=str(target_org.id),
ip_address=request.remote_addr, user_agent=request.headers.get('User-Agent'),
message=f'Certificate serial={response.serial} issued for {user.email}; principals: {", ".join(principals)}',
extra_data={
@@ -198,6 +240,7 @@ def sign_certificate():
'ca_id': str(db_ca.id), 'ssh_key_id': str(key_id),
'valid_after': response.valid_after.isoformat() if response.valid_after else None,
'valid_before': response.valid_before.isoformat() if response.valid_before else None,
'organization_id': str(target_org.id),
},
success=True,
)
@@ -207,6 +250,8 @@ def sign_certificate():
'principals': response.principals,
'valid_after': response.valid_after.isoformat() if response.valid_after else None,
'valid_before': response.valid_before.isoformat() if response.valid_before else None,
'organization_id': str(target_org.id),
'organization_name': target_org.name,
}
if cert_record:
result['cert_id'] = str(cert_record.id)
@@ -371,7 +416,13 @@ def revoke_certificate(cert_id):
cert.revoke(reason=reason)
AuditLog.log(action=AuditAction.SSH_CERT_REVOKED, user_id=user_id, resource_type='SSHCertificate', resource_id=cert_id, ip_address=request.remote_addr, description=f'Revoked: {reason}')
CertificateAuditLog.log(certificate_id=cert_id, action='revoked', user_id=user_id, ip_address=request.remote_addr, user_agent=request.headers.get('User-Agent'), message=f'Certificate revoked: {reason}', success=True)
# Get organization from certificate's CA for audit logging
from gatehouse_app.models.ssh_ca.ca import CA
ca = CA.query.get(cert.ca_id)
org_id = ca.organization_id if ca else None
CertificateAuditLog.log(certificate_id=cert_id, action='revoked', user_id=user_id, organization_id=org_id, ip_address=request.remote_addr, user_agent=request.headers.get('User-Agent'), message=f'Certificate revoked: {reason}', success=True)
return api_response(success=True, message='Certificate revoked successfully', data={'status': 'revoked', 'cert_id': cert_id, 'reason': reason}, status=200)
except Exception as e:
+6 -5
View File
@@ -32,11 +32,12 @@ def add_ssh_key():
return api_response(success=False, message='public_key is required', status=400, error_type='BAD_REQUEST')
try:
ssh_key = ssh_key_service.add_ssh_key(user_id=user_id, public_key=public_key, description=description)
AuditLog.log(action=AuditAction.SSH_KEY_ADDED, user_id=user_id, resource_type='SSHKey', resource_id=ssh_key.id, ip_address=request.remote_addr)
return api_response(success=True, message='SSH key added', data=ssh_key.to_dict(), status=201)
except SSHKeyAlreadyExistsError as e:
return api_response(success=False, message=e.message, status=409, error_type='SSH_KEY_ALREADY_EXISTS')
ssh_key, is_new = ssh_key_service.add_ssh_key(user_id=user_id, public_key=public_key, description=description)
if is_new:
AuditLog.log(action=AuditAction.SSH_KEY_ADDED, user_id=user_id, resource_type='SSHKey', resource_id=ssh_key.id, ip_address=request.remote_addr)
return api_response(success=True, message='SSH key added', data=ssh_key.to_dict(), status=201)
else:
return api_response(success=True, message='SSH key already exists', data=ssh_key.to_dict(), status=200)
except IntegrityError:
return api_response(success=False, message='SSH key already exists', status=409, error_type='SSH_KEY_ALREADY_EXISTS')
except SSHKeyError as e:
@@ -0,0 +1,14 @@
"""Superadmin API blueprint."""
import logging
from flask import Blueprint
from gatehouse_app.extensions import limiter
logger = logging.getLogger(__name__)
# Create superadmin blueprint
superadmin_bp = Blueprint("superadmin", __name__, url_prefix="/superadmin")
# Import route modules to register them
from gatehouse_app.api.v1.superadmin import auth, organizations, organization_members, usage_analytics, users, billing, cas # noqa: F401
+286
View File
@@ -0,0 +1,286 @@
"""Superadmin authentication endpoints."""
import logging
from flask import request, g, current_app
from marshmallow import ValidationError
from gatehouse_app.api.v1.superadmin import superadmin_bp
from gatehouse_app.extensions import limiter
from gatehouse_app.utils.response import api_response
from gatehouse_app.services.superadmin_auth_service import SuperadminAuthService
from gatehouse_app.decorators.superadmin import superadmin_required, superadmin_audit_log
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
logger = logging.getLogger(__name__)
class LoginSchema:
"""Schema for superadmin login."""
@staticmethod
def load(data):
"""Validate login data."""
errors = {}
if not data.get('email'):
errors['email'] = ['Email is required']
elif '@' not in data['email']:
errors['email'] = ['Invalid email format']
if not data.get('password'):
errors['password'] = ['Password is required']
if errors:
raise ValidationError(errors)
return {
'email': data['email'].lower().strip(),
'password': data['password'],
}
@superadmin_bp.route("/auth/login", methods=["POST"])
@limiter.limit(lambda: current_app.config.get("RATELIMIT_AUTH_LOGIN", "100 per minute"))
def login():
"""Superadmin login endpoint.
Authenticates with email/password and returns a session token.
"""
try:
schema = LoginSchema()
data = schema.load(request.json)
# Authenticate
superadmin = SuperadminAuthService.authenticate(
email=data['email'],
credentials=data['password']
)
# Create session (default 8 hours)
session = SuperadminAuthService.create_session(
superadmin_id=superadmin.id,
duration_seconds=28800 # 8 hours
)
expires_str = session.expires_at.isoformat()
if not expires_str.endswith('Z'):
expires_str += 'Z'
logger.info(f"[SuperadminAuth] Login successful for: {superadmin.email}")
return api_response(
data={
"superadmin": superadmin.to_dict(),
"token": session.token,
"expires_at": expires_str,
},
message="Login successful",
status=200
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages
)
except InvalidCredentialsError:
return api_response(
success=False,
message="Invalid email or password",
status=401,
error_type="INVALID_CREDENTIALS"
)
except Exception as e:
logger.error(f"[SuperadminAuth] Login error: {e}")
return api_response(
success=False,
message="An error occurred during login",
status=500,
error_type="INTERNAL_ERROR"
)
@superadmin_bp.route("/auth/logout", methods=["POST"])
@superadmin_required
def logout():
"""Superadmin logout endpoint.
Invalidates the current session.
"""
try:
session = g.superadmin_session
if session:
SuperadminAuthService.revoke_session(session.id, reason="Superadmin logout")
return api_response(
message="Logout successful"
)
except Exception as e:
logger.error(f"[SuperadminAuth] Logout error: {e}")
return api_response(
success=False,
message="An error occurred during logout",
status=500,
error_type="INTERNAL_ERROR"
)
@superadmin_bp.route("/auth/me", methods=["GET"])
@superadmin_required
def get_current_superadmin():
"""Get current superadmin profile.
Returns the profile of the currently authenticated superadmin.
"""
try:
superadmin = g.current_superadmin
return api_response(
data={
"superadmin": superadmin.to_dict(),
},
message="Superadmin retrieved successfully"
)
except Exception as e:
logger.error(f"[SuperadminAuth] Get me error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR"
)
@superadmin_bp.route("/auth/impersonate/<user_id>", methods=["POST"])
@superadmin_required
@superadmin_audit_log(action="impersonate", resource_type="user")
def impersonate_user(user_id):
"""Create emergency access session by impersonating a user.
Creates a temporary session for the target user that allows
the superadmin to access the platform as that user.
This action is fully audited.
"""
try:
superadmin = g.current_superadmin
data = request.json or {}
reason = data.get('reason', 'Not specified')
duration_minutes = data.get('duration_minutes', 15)
# Limit duration to max 60 minutes
duration_minutes = min(duration_minutes, 60)
# Create emergency access
result = SuperadminAuthService.create_emergency_access(
superadmin_id=superadmin.id,
target_user_id=user_id,
reason=reason,
duration_minutes=duration_minutes
)
expires_str = result['expires_at'].isoformat()
if not expires_str.endswith('Z'):
expires_str += 'Z'
logger.warning(
f"[SuperadminAuth] IMPERSONATION: superadmin={superadmin.email} "
f"impersonated user_id={user_id} reason={reason}"
)
return api_response(
data={
"session_token": result['session'].token,
"expires_at": expires_str,
"target_user_id": user_id,
"reason": reason,
"duration_minutes": duration_minutes,
},
message="Emergency access session created",
status=201
)
except ValueError as e:
return api_response(
success=False,
message=str(e),
status=404,
error_type="NOT_FOUND"
)
except Exception as e:
logger.error(f"[SuperadminAuth] Impersonate error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR"
)
@superadmin_bp.route("/auth/emergency/<user_id>", methods=["POST"])
@superadmin_required
@superadmin_audit_log(action="emergency_access", resource_type="user")
def grant_emergency_access(user_id):
"""Grant temporary elevated access to a user.
Similar to impersonate but grants elevated permissions
rather than creating a session as the user.
This action is fully audited.
"""
try:
superadmin = g.current_superadmin
data = request.json or {}
reason = data.get('reason', 'Not specified')
duration_minutes = data.get('duration_minutes', 15)
# Limit duration to max 60 minutes
duration_minutes = min(duration_minutes, 60)
# Create emergency access
result = SuperadminAuthService.create_emergency_access(
superadmin_id=superadmin.id,
target_user_id=user_id,
reason=reason,
duration_minutes=duration_minutes
)
expires_str = result['expires_at'].isoformat()
if not expires_str.endswith('Z'):
expires_str += 'Z'
logger.warning(
f"[SuperadminAuth] EMERGENCY ACCESS: superadmin={superadmin.email} "
f"granted access to user_id={user_id} reason={reason}"
)
return api_response(
data={
"session_token": result['session'].token,
"expires_at": expires_str,
"target_user_id": user_id,
"reason": reason,
"duration_minutes": duration_minutes,
},
message="Emergency access granted",
status=201
)
except ValueError as e:
return api_response(
success=False,
message=str(e),
status=404,
error_type="NOT_FOUND"
)
except Exception as e:
logger.error(f"[SuperadminAuth] Emergency access error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR"
)
+568
View File
@@ -0,0 +1,568 @@
"""Superadmin billing endpoints for plans and subscriptions."""
import logging
from datetime import datetime, timezone
from flask import request
from gatehouse_app.api.v1.superadmin import superadmin_bp
from gatehouse_app.utils.response import api_response
from gatehouse_app.decorators.superadmin import superadmin_required, superadmin_audit_log
from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.extensions import db
logger = logging.getLogger(__name__)
# ============ Plans Endpoints ============
@superadmin_bp.route("/billing/plans", methods=["GET"])
@superadmin_required
def list_plans():
"""Get all available plans."""
try:
from gatehouse_app.models.billing.plan import Plan
plans = Plan.query.filter(Plan.is_active == True).order_by(Plan.price_monthly.asc()).all()
items = [{
"id": p.id,
"name": p.name,
"slug": p.slug,
"description": p.description,
"price_monthly": p.price_monthly,
"price_yearly": p.price_yearly,
"included_users": p.included_users,
"overage_rate_per_user": p.overage_rate_per_user,
"features": p.features,
"stripe_price_id_monthly": p.stripe_price_id_monthly,
"stripe_price_id_yearly": p.stripe_price_id_yearly,
"is_active": p.is_active,
"created_at": p.created_at.isoformat() + "Z" if p.created_at else None,
} for p in plans]
return api_response(data={"items": items}, message="Plans retrieved successfully")
except Exception as e:
logger.error(f"[Billing] List plans error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/billing/plans", methods=["POST"])
@superadmin_required
@superadmin_audit_log(action="plan.create", resource_type="plan")
def create_plan():
"""Create a new plan."""
try:
from gatehouse_app.models.billing.plan import Plan
from marshmallow import Schema, fields, validate
class CreatePlanSchema(Schema):
name = fields.Str(required=True, validate=validate.Length(min=1, max=100))
slug = fields.Str(required=True, validate=validate.Length(min=1, max=50))
description = fields.Str(allow_none=True)
price_monthly = fields.Int(required=True, validate=validate.Range(min=0))
price_yearly = fields.Int(required=True, validate=validate.Range(min=0))
included_users = fields.Int(required=True, validate=validate.Range(min=0))
overage_rate_per_user = fields.Int(required=True, validate=validate.Range(min=0))
features = fields.Dict(allow_none=True)
stripe_price_id_monthly = fields.Str(allow_none=True)
stripe_price_id_yearly = fields.Str(allow_none=True)
data = request.json or {}
schema = CreatePlanSchema()
errors = schema.validate(data)
if errors:
return api_response(
success=False,
message="Validation error",
status=400,
error_type="VALIDATION_ERROR",
error_details=errors,
)
# Check if slug already exists
existing = Plan.query.filter_by(slug=data["slug"]).first()
if existing:
return api_response(
success=False,
message="Plan with this slug already exists",
status=400,
error_type="VALIDATION_ERROR",
)
plan = Plan(
name=data["name"],
slug=data["slug"],
description=data.get("description"),
price_monthly=data["price_monthly"],
price_yearly=data["price_yearly"],
included_users=data["included_users"],
overage_rate_per_user=data["overage_rate_per_user"],
features=data.get("features"),
stripe_price_id_monthly=data.get("stripe_price_id_monthly"),
stripe_price_id_yearly=data.get("stripe_price_id_yearly"),
)
db.session.add(plan)
db.session.commit()
return api_response(data={
"id": plan.id,
"name": plan.name,
"slug": plan.slug,
}, message="Plan created successfully", status=201)
except Exception as e:
db.session.rollback()
logger.error(f"[Billing] Create plan error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/billing/plans/<plan_id>", methods=["GET"])
@superadmin_required
def get_plan(plan_id):
"""Get a single plan by ID."""
try:
from gatehouse_app.models.billing.plan import Plan
plan = Plan.query.get(plan_id)
if not plan:
return api_response(
success=False,
message="Plan not found",
status=404,
error_type="NOT_FOUND",
)
return api_response(data={
"id": plan.id,
"name": plan.name,
"slug": plan.slug,
"description": plan.description,
"price_monthly": plan.price_monthly,
"price_yearly": plan.price_yearly,
"included_users": plan.included_users,
"overage_rate_per_user": plan.overage_rate_per_user,
"features": plan.features,
"stripe_price_id_monthly": plan.stripe_price_id_monthly,
"stripe_price_id_yearly": plan.stripe_price_id_yearly,
"is_active": plan.is_active,
"created_at": plan.created_at.isoformat() + "Z" if plan.created_at else None,
"updated_at": plan.updated_at.isoformat() + "Z" if plan.updated_at else None,
}, message="Plan retrieved successfully")
except Exception as e:
logger.error(f"[Billing] Get plan error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/billing/plans/<plan_id>", methods=["PATCH"])
@superadmin_required
@superadmin_audit_log(action="plan.update", resource_type="plan")
def update_plan(plan_id):
"""Update a plan."""
try:
from gatehouse_app.models.billing.plan import Plan
from marshmallow import Schema, fields, validate
class UpdatePlanSchema(Schema):
name = fields.Str(validate=validate.Length(min=1, max=100))
description = fields.Str(allow_none=True)
price_monthly = fields.Int(validate=validate.Range(min=0))
price_yearly = fields.Int(validate=validate.Range(min=0))
included_users = fields.Int(validate=validate.Range(min=0))
overage_rate_per_user = fields.Int(validate=validate.Range(min=0))
features = fields.Dict(allow_none=True)
stripe_price_id_monthly = fields.Str(allow_none=True)
stripe_price_id_yearly = fields.Str(allow_none=True)
is_active = fields.Bool()
plan = Plan.query.get(plan_id)
if not plan:
return api_response(
success=False,
message="Plan not found",
status=404,
error_type="NOT_FOUND",
)
data = request.json or {}
schema = UpdatePlanSchema()
errors = schema.validate(data)
if errors:
return api_response(
success=False,
message="Validation error",
status=400,
error_type="VALIDATION_ERROR",
error_details=errors,
)
# Update fields
for key, value in data.items():
if hasattr(plan, key):
setattr(plan, key, value)
db.session.commit()
return api_response(data={
"id": plan.id,
"name": plan.name,
"slug": plan.slug,
}, message="Plan updated successfully")
except Exception as e:
db.session.rollback()
logger.error(f"[Billing] Update plan error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/billing/plans/<plan_id>", methods=["DELETE"])
@superadmin_required
@superadmin_audit_log(action="plan.delete", resource_type="plan")
def delete_plan(plan_id):
"""Soft-delete a plan by setting is_active=False."""
try:
from gatehouse_app.models.billing.plan import Plan
plan = Plan.query.get(plan_id)
if not plan:
return api_response(
success=False,
message="Plan not found",
status=404,
error_type="NOT_FOUND",
)
plan.is_active = False
db.session.commit()
return api_response(data={"id": plan.id}, message="Plan deleted successfully")
except Exception as e:
db.session.rollback()
logger.error(f"[Billing] Delete plan error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
# ============ Subscriptions Endpoints ============
@superadmin_bp.route("/billing/subscriptions", methods=["GET"])
@superadmin_required
def list_subscriptions():
"""Get all subscriptions with optional filters."""
try:
from gatehouse_app.models.billing.subscription import Subscription
from gatehouse_app.models.billing.plan import Plan
page = max(1, int(request.args.get("page", 1)))
per_page = min(100, max(1, int(request.args.get("per_page", 20))))
plan_id = request.args.get("plan_id")
status = request.args.get("status")
query = Subscription.query
if plan_id:
query = query.filter(Subscription.plan_id == plan_id)
if status:
query = query.filter(Subscription.status == status)
total = query.count()
subs = query.order_by(Subscription.created_at.desc()).offset((page - 1) * per_page).limit(per_page).all()
items = []
for sub in subs:
org = Organization.query.get(sub.organization_id)
plan = Plan.query.get(sub.plan_id) if sub.plan_id else None
# Calculate MRR
if plan and sub.status == "active":
mrr = plan.price_monthly if sub.billing_cycle == "monthly" else plan.price_yearly // 12
else:
mrr = 0
items.append({
"id": sub.id,
"organization_id": sub.organization_id,
"org_name": org.name if org else "Unknown",
"plan_id": sub.plan_id,
"plan_name": plan.name if plan else "Unknown",
"status": sub.status,
"billing_cycle": sub.billing_cycle,
"mrr": mrr,
"current_period_start": sub.current_period_start.isoformat() + "Z" if sub.current_period_start else None,
"current_period_end": sub.current_period_end.isoformat() + "Z" if sub.current_period_end else None,
"trial_ends_at": sub.trial_ends_at.isoformat() + "Z" if sub.trial_ends_at else None,
"cancel_at_period_end": sub.cancel_at_period_end,
"created_at": sub.created_at.isoformat() + "Z" if sub.created_at else None,
})
return api_response(data={
"items": items,
"total": total,
"page": page,
"per_page": per_page,
"pages": (total + per_page - 1) // per_page if per_page > 0 else 0,
}, message="Subscriptions retrieved successfully")
except Exception as e:
logger.error(f"[Billing] List subscriptions error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/billing/subscriptions/<sub_id>", methods=["GET"])
@superadmin_required
def get_subscription(sub_id):
"""Get a single subscription."""
try:
from gatehouse_app.models.billing.subscription import Subscription
from gatehouse_app.models.billing.plan import Plan
sub = Subscription.query.get(sub_id)
if not sub:
return api_response(
success=False,
message="Subscription not found",
status=404,
error_type="NOT_FOUND",
)
org = Organization.query.get(sub.organization_id)
plan = Plan.query.get(sub.plan_id) if sub.plan_id else None
return api_response(data={
"id": sub.id,
"organization_id": sub.organization_id,
"org_name": org.name if org else "Unknown",
"plan_id": sub.plan_id,
"plan_name": plan.name if plan else None,
"status": sub.status,
"billing_cycle": sub.billing_cycle,
"current_period_start": sub.current_period_start.isoformat() + "Z" if sub.current_period_start else None,
"current_period_end": sub.current_period_end.isoformat() + "Z" if sub.current_period_end else None,
"trial_ends_at": sub.trial_ends_at.isoformat() + "Z" if sub.trial_ends_at else None,
"stripe_subscription_id": sub.stripe_subscription_id,
"overage_enabled": sub.overage_enabled,
"cancelled_at": sub.cancelled_at.isoformat() + "Z" if sub.cancelled_at else None,
"cancel_at_period_end": sub.cancel_at_period_end,
}, message="Subscription retrieved successfully")
except Exception as e:
logger.error(f"[Billing] Get subscription error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/billing/subscriptions/<org_id>", methods=["PATCH"])
@superadmin_required
@superadmin_audit_log(action="subscription.update", resource_type="subscription")
def update_subscription(org_id):
"""Update subscription plan or billing cycle for an organization."""
try:
from gatehouse_app.models.billing.subscription import Subscription
from gatehouse_app.models.billing.plan import Plan
org = Organization.query.get(org_id)
if not org:
return api_response(
success=False,
message="Organization not found",
status=404,
error_type="NOT_FOUND",
)
sub = Subscription.query.filter_by(organization_id=org_id).first()
if not sub:
return api_response(
success=False,
message="No subscription found for this organization",
status=404,
error_type="NOT_FOUND",
)
data = request.json or {}
if "plan_id" in data:
plan = Plan.query.get(data["plan_id"])
if not plan:
return api_response(
success=False,
message="Plan not found",
status=404,
error_type="NOT_FOUND",
)
sub.plan_id = data["plan_id"]
if "billing_cycle" in data:
if data["billing_cycle"] not in ["monthly", "yearly"]:
return api_response(
success=False,
message="Invalid billing cycle. Must be 'monthly' or 'yearly'",
status=400,
error_type="VALIDATION_ERROR",
)
sub.billing_cycle = data["billing_cycle"]
db.session.commit()
return api_response(data={
"id": sub.id,
"organization_id": sub.organization_id,
"plan_id": sub.plan_id,
"billing_cycle": sub.billing_cycle,
}, message="Subscription updated successfully")
except Exception as e:
db.session.rollback()
logger.error(f"[Billing] Update subscription error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/billing/subscriptions/<org_id>/cancel", methods=["POST"])
@superadmin_required
@superadmin_audit_log(action="subscription.cancel", resource_type="subscription")
def cancel_subscription(org_id):
"""Cancel subscription at period end."""
try:
from gatehouse_app.models.billing.subscription import Subscription
org = Organization.query.get(org_id)
if not org:
return api_response(
success=False,
message="Organization not found",
status=404,
error_type="NOT_FOUND",
)
sub = Subscription.query.filter_by(organization_id=org_id).first()
if not sub:
return api_response(
success=False,
message="No subscription found for this organization",
status=404,
error_type="NOT_FOUND",
)
sub.cancel_at_period_end = True
sub.status = "cancelled"
db.session.commit()
return api_response(data={
"id": sub.id,
"cancel_at_period_end": True,
"status": sub.status,
}, message="Subscription will be cancelled at period end")
except Exception as e:
db.session.rollback()
logger.error(f"[Billing] Cancel subscription error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/billing/subscriptions/<org_id>/trial", methods=["POST"])
@superadmin_required
@superadmin_audit_log(action="subscription.extend_trial", resource_type="subscription")
def extend_trial(org_id):
"""Extend trial period for an organization."""
try:
from gatehouse_app.models.billing.subscription import Subscription
from datetime import timedelta
data = request.json or {}
days = data.get("days", 30)
if not isinstance(days, int) or days < 1:
return api_response(
success=False,
message="Days must be a positive integer",
status=400,
error_type="VALIDATION_ERROR",
)
org = Organization.query.get(org_id)
if not org:
return api_response(
success=False,
message="Organization not found",
status=404,
error_type="NOT_FOUND",
)
sub = Subscription.query.filter_by(organization_id=org_id).first()
if not sub:
return api_response(
success=False,
message="No subscription found for this organization",
status=404,
error_type="NOT_FOUND",
)
# Extend trial
if sub.trial_ends_at:
sub.trial_ends_at = sub.trial_ends_at + timedelta(days=days)
else:
sub.trial_ends_at = datetime.now(timezone.utc) + timedelta(days=days)
sub.status = "trial"
db.session.commit()
return api_response(data={
"id": sub.id,
"trial_ends_at": sub.trial_ends_at.isoformat() + "Z",
"days_added": days,
}, message=f"Trial extended by {days} days")
except Exception as e:
db.session.rollback()
logger.error(f"[Billing] Extend trial error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
+56
View File
@@ -0,0 +1,56 @@
"""Superadmin SSH CA management endpoints."""
import logging
from flask import request
from gatehouse_app.api.v1.superadmin import superadmin_bp
from gatehouse_app.utils.response import api_response
from gatehouse_app.decorators.superadmin import superadmin_required, superadmin_audit_log
from gatehouse_app.extensions import db
logger = logging.getLogger(__name__)
@superadmin_bp.route("/organizations/<org_id>/cas/<ca_id>", methods=["DELETE"])
@superadmin_required
@superadmin_audit_log(action="ca.delete", resource_type="CA")
def delete_org_ca(org_id, ca_id):
"""Soft-delete an SSH CA for an organization.
Sets is_active=False and deleted_at=now().
"""
from gatehouse_app.models.ssh_ca.ca import CA
from gatehouse_app.models.organization.organization import Organization
org = Organization.query.filter_by(id=org_id, deleted_at=None).first()
if not org:
return api_response(
success=False,
message="Organization not found",
status=404,
error_type="NOT_FOUND"
)
ca = CA.query.filter_by(id=ca_id, organization_id=org_id, deleted_at=None).first()
if not ca:
return api_response(
success=False,
message="CA not found",
status=404,
error_type="NOT_FOUND"
)
try:
ca.is_active = False
ca.delete(soft=True)
db.session.commit()
return api_response(data={"ca_id": ca_id}, message="CA deleted successfully")
except Exception:
db.session.rollback()
logger.exception(f"Failed to delete CA {ca_id}")
return api_response(
success=False,
message="Failed to delete CA",
status=500,
error_type="SERVER_ERROR"
)
@@ -0,0 +1,456 @@
"""Superadmin organization member management endpoints."""
import logging
from flask import request, g
from marshmallow import ValidationError
from gatehouse_app.api.v1.superadmin import superadmin_bp
from gatehouse_app.utils.response import api_response
from gatehouse_app.decorators.superadmin import superadmin_required, superadmin_audit_log
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.models.user.user import User
from gatehouse_app.extensions import db
from gatehouse_app.utils.constants import OrganizationRole
logger = logging.getLogger(__name__)
class ListMembersSchema:
"""Schema for list members query params."""
@staticmethod
def load(args):
"""Parse and validate query parameters."""
try:
page = max(1, int(args.get("page", 1)))
per_page = min(100, max(1, int(args.get("per_page", 20))))
except (ValueError, TypeError):
page = 1
per_page = 20
search = args.get("search")
role = args.get("role")
return {
"page": page,
"per_page": per_page,
"search": search,
"role": role,
}
class AddMemberSchema:
"""Schema for adding a member."""
@staticmethod
def load(data):
"""Parse and validate add member data."""
errors = {}
user_id = data.get("user_id")
if not user_id:
errors["user_id"] = ["User ID is required"]
role_str = data.get("role", "member")
try:
role = OrganizationRole(role_str)
except ValueError:
errors["role"] = [f"Invalid role. Must be one of: {', '.join(r.value for r in OrganizationRole)}"]
if errors:
raise ValidationError(errors)
return {"user_id": user_id, "role": role}
class UpdateMemberSchema:
"""Schema for updating a member role."""
@staticmethod
def load(data):
"""Parse and validate update data."""
errors = {}
role_str = data.get("role")
if not role_str:
errors["role"] = ["Role is required"]
try:
role = OrganizationRole(role_str)
except ValueError:
errors["role"] = [f"Invalid role. Must be one of: {', '.join(r.value for r in OrganizationRole)}"]
if errors:
raise ValidationError(errors)
return {"role": role}
@superadmin_bp.route("/organizations/<org_id>/members", methods=["GET"])
@superadmin_required
def list_organization_members(org_id):
"""List all members of an organization.
Query params:
page: Page number (default 1)
per_page: Items per page (default 20)
search: Search by user email or name
role: Filter by role
"""
try:
# Verify org exists
org = Organization.query.get(org_id)
if not org:
return api_response(
success=False,
message="Organization not found",
status=404,
error_type="NOT_FOUND",
)
schema = ListMembersSchema()
params = schema.load(request.args)
query = OrganizationMember.query.filter_by(organization_id=org_id, deleted_at=None)
# Search by user email or name
if params["search"]:
search_term = f"%{params['search']}%"
query = query.join(User).filter(
db.or_(
User.email.ilike(search_term),
User.full_name.ilike(search_term),
)
)
# Filter by role
if params["role"]:
try:
role = OrganizationRole(params["role"])
query = query.filter(OrganizationMember.role == role)
except ValueError:
pass # Ignore invalid role filter
# Order by joined_at desc
query = query.order_by(OrganizationMember.joined_at.desc())
# Paginate
pagination = query.paginate(page=params["page"], per_page=params["per_page"], error_out=False)
# Build response
items = []
for member in pagination.items:
user = User.query.get(member.user_id)
item = {
"user_id": member.user_id,
"organization_id": member.organization_id,
"role": member.role.value,
"joined_at": member.joined_at.isoformat() + "Z" if member.joined_at else None,
"user": {
"id": user.id,
"email": user.email,
"full_name": user.full_name,
"is_active": user.is_active,
} if user else {"id": member.user_id, "email": "[deleted]", "full_name": None, "is_active": False},
}
items.append(item)
return api_response(
data={
"items": items,
"total": pagination.total,
"page": params["page"],
"per_page": params["per_page"],
"pages": pagination.pages,
},
message="Members retrieved successfully",
)
except Exception as e:
logger.error(f"[SuperadminOrg] List members error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/organizations/<org_id>/members", methods=["POST"])
@superadmin_required
@superadmin_audit_log(action="add_member", resource_type="organization_member")
def add_organization_member(org_id):
"""Add a user to an organization.
Body:
user_id: User UUID to add
role: Role (owner, admin, member, guest)
"""
try:
schema = AddMemberSchema()
data = schema.load(request.json or {})
# Verify org exists
org = Organization.query.get(org_id)
if not org:
return api_response(
success=False,
message="Organization not found",
status=404,
error_type="NOT_FOUND",
)
# Verify user exists
user = User.query.get(data["user_id"])
if not user:
return api_response(
success=False,
message="User not found",
status=404,
error_type="NOT_FOUND",
)
# Check if already a member
existing = OrganizationMember.query.filter_by(
user_id=data["user_id"],
organization_id=org_id,
deleted_at=None,
).first()
if existing:
return api_response(
success=False,
message="User is already a member of this organization",
status=400,
error_type="ALREADY_EXISTS",
)
# Create membership
member = OrganizationMember(
user_id=data["user_id"],
organization_id=org_id,
role=data["role"],
invited_by_id=g.current_superadmin.id,
invited_at=db.func.now(),
joined_at=db.func.now(),
)
db.session.add(member)
db.session.commit()
logger.info(f"[SuperadminOrg] Added user {data['user_id']} to org {org_id} as {data['role'].value}")
return api_response(
data={
"member": {
"user_id": member.user_id,
"organization_id": member.organization_id,
"role": member.role.value,
"joined_at": member.joined_at.isoformat() + "Z" if member.joined_at else None,
}
},
message="Member added successfully",
status=201,
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
except Exception as e:
logger.error(f"[SuperadminOrg] Add member error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/organizations/<org_id>/members/<user_id>", methods=["PATCH"])
@superadmin_required
@superadmin_audit_log(action="update_member_role", resource_type="organization_member")
def update_organization_member(org_id, user_id):
"""Update a member's role.
Body:
role: New role (owner, admin, member, guest)
"""
try:
schema = UpdateMemberSchema()
data = schema.load(request.json or {})
# Find member
member = OrganizationMember.query.filter_by(
user_id=user_id,
organization_id=org_id,
deleted_at=None,
).first()
if not member:
return api_response(
success=False,
message="Member not found",
status=404,
error_type="NOT_FOUND",
)
# Update role
old_role = member.role
member.role = data["role"]
db.session.commit()
logger.info(f"[SuperadminOrg] Updated member {user_id} role from {old_role.value} to {data['role'].value}")
return api_response(
data={
"member": {
"user_id": member.user_id,
"organization_id": member.organization_id,
"role": member.role.value,
}
},
message="Member role updated successfully",
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
except Exception as e:
logger.error(f"[SuperadminOrg] Update member error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/organizations/<org_id>/members/<user_id>", methods=["DELETE"])
@superadmin_required
@superadmin_audit_log(action="remove_member", resource_type="organization_member")
def remove_organization_member(org_id, user_id):
"""Remove a user from an organization."""
try:
# Find member
member = OrganizationMember.query.filter_by(
user_id=user_id,
organization_id=org_id,
deleted_at=None,
).first()
if not member:
return api_response(
success=False,
message="Member not found",
status=404,
error_type="NOT_FOUND",
)
# Prevent removing owner without transferring ownership
if member.role == OrganizationRole.OWNER:
return api_response(
success=False,
message="Cannot remove organization owner. Transfer ownership first.",
status=400,
error_type="AUTHORIZATION_ERROR",
)
# Soft delete
from datetime import datetime, timezone
member.deleted_at = datetime.now(timezone.utc)
db.session.commit()
logger.info(f"[SuperadminOrg] Removed user {user_id} from org {org_id}")
return api_response(message="Member removed successfully")
except Exception as e:
logger.error(f"[SuperadminOrg] Remove member error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/organizations/<org_id>/transfer-ownership/<user_id>", methods=["POST"])
@superadmin_required
@superadmin_audit_log(action="transfer_ownership", resource_type="organization_member")
def transfer_organization_ownership(org_id, user_id):
"""Transfer organization ownership to another member.
The target user must already be a member of the organization.
The current owner will be changed to admin role.
"""
try:
# Verify org exists
org = Organization.query.get(org_id)
if not org:
return api_response(
success=False,
message="Organization not found",
status=404,
error_type="NOT_FOUND",
)
# Find current owner
current_owner = OrganizationMember.query.filter_by(
organization_id=org_id,
role=OrganizationRole.OWNER,
deleted_at=None,
).first()
if not current_owner:
return api_response(
success=False,
message="Current owner not found",
status=404,
error_type="NOT_FOUND",
)
# Find target user as member
target_member = OrganizationMember.query.filter_by(
user_id=user_id,
organization_id=org_id,
deleted_at=None,
).first()
if not target_member:
return api_response(
success=False,
message="Target user is not a member of this organization",
status=400,
error_type="AUTHORIZATION_ERROR",
)
# Transfer ownership
current_owner.role = OrganizationRole.ADMIN
target_member.role = OrganizationRole.OWNER
db.session.commit()
logger.warning(
f"[SuperadminOrg] TRANSFERRED OWNERSHIP: org={org_id} from user={current_owner.user_id} to user={user_id}"
)
return api_response(
data={
"owner": {
"user_id": target_member.user_id,
"role": target_member.role.value,
}
},
message="Ownership transferred successfully",
)
except Exception as e:
logger.error(f"[SuperadminOrg] Transfer ownership error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@@ -0,0 +1,254 @@
"""Superadmin organization management endpoints."""
import logging
from flask import request
from gatehouse_app.api.v1.superadmin import superadmin_bp
from gatehouse_app.utils.response import api_response
from gatehouse_app.services.superadmin_organization_service import SuperadminOrganizationService
from gatehouse_app.decorators.superadmin import superadmin_required, superadmin_audit_log
logger = logging.getLogger(__name__)
class ListOrganizationsSchema:
"""Schema for list organizations query params."""
@staticmethod
def load(args):
"""Parse and validate query parameters."""
try:
page = max(1, int(args.get("page", 1)))
per_page = min(100, max(1, int(args.get("per_page", 20))))
except (ValueError, TypeError):
page = 1
per_page = 20
search = args.get("search")
status = args.get("status")
plan_slug = args.get("plan_slug")
return {
"page": page,
"per_page": per_page,
"search": search,
"status": status,
"plan_slug": plan_slug,
}
class UpdateOrganizationSchema:
"""Schema for updating an organization."""
@staticmethod
def load(data):
"""Parse and validate update data."""
result = {}
if "name" in data:
result["name"] = data["name"]
if "description" in data:
result["description"] = data["description"]
if "is_active" in data:
result["is_active"] = bool(data["is_active"])
return result
@superadmin_bp.route("/organizations", methods=["GET"])
@superadmin_required
def list_organizations():
"""List all organizations with pagination and filtering.
Query params:
page: Page number (default 1)
per_page: Items per page (default 20, max 100)
search: Search by name or slug
status: Filter by status (active, suspended)
plan_slug: Filter by plan slug
"""
try:
schema = ListOrganizationsSchema()
params = schema.load(request.args)
result = SuperadminOrganizationService.list_organizations(
page=params["page"],
per_page=params["per_page"],
search=params["search"],
status=params["status"],
plan_slug=params["plan_slug"],
)
return api_response(data=result, message="Organizations retrieved successfully")
except Exception as e:
logger.error(f"[SuperadminOrg] List organizations error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/organizations/<org_id>", methods=["GET"])
@superadmin_required
def get_organization(org_id):
"""Get detailed organization information.
Returns org details including member count, owner info, and active sessions.
"""
try:
result = SuperadminOrganizationService.get_organization_detail(org_id)
return api_response(data={"organization": result}, message="Organization retrieved successfully")
except ValueError as e:
return api_response(
success=False,
message=str(e),
status=404,
error_type="NOT_FOUND",
)
except Exception as e:
logger.error(f"[SuperadminOrg] Get organization error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/organizations/<org_id>", methods=["PATCH"])
@superadmin_required
@superadmin_audit_log(action="update_organization", resource_type="organization")
def update_organization(org_id):
"""Update organization details.
Body:
name: New name (optional)
description: New description (optional)
is_active: New active status (optional)
"""
try:
schema = UpdateOrganizationSchema()
data = schema.load(request.json or {})
if not data:
return api_response(
success=False,
message="No update data provided",
status=400,
error_type="VALIDATION_ERROR",
)
org = SuperadminOrganizationService.update_organization(org_id, **data)
return api_response(
data={"organization": org.to_dict()},
message="Organization updated successfully",
)
except ValueError as e:
return api_response(
success=False,
message=str(e),
status=404,
error_type="NOT_FOUND",
)
except Exception as e:
logger.error(f"[SuperadminOrg] Update organization error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/organizations/<org_id>/suspend", methods=["POST"])
@superadmin_required
@superadmin_audit_log(action="suspend_organization", resource_type="organization")
def suspend_organization(org_id):
"""Suspend an organization.
Sets is_active=False and invalidates all member sessions.
"""
try:
org = SuperadminOrganizationService.suspend_organization(org_id)
return api_response(
data={"organization": org.to_dict()},
message="Organization suspended successfully",
)
except ValueError as e:
return api_response(
success=False,
message=str(e),
status=404,
error_type="NOT_FOUND",
)
except Exception as e:
logger.error(f"[SuperadminOrg] Suspend organization error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/organizations/<org_id>/unsuspend", methods=["POST"])
@superadmin_required
@superadmin_audit_log(action="unsuspend_organization", resource_type="organization")
def unsuspend_organization(org_id):
"""Restore a suspended organization."""
try:
org = SuperadminOrganizationService.restore_organization(org_id)
return api_response(
data={"organization": org.to_dict()},
message="Organization restored successfully",
)
except ValueError as e:
return api_response(
success=False,
message=str(e),
status=404,
error_type="NOT_FOUND",
)
except Exception as e:
logger.error(f"[SuperadminOrg] Unsuspend organization error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/organizations/<org_id>", methods=["DELETE"])
@superadmin_required
@superadmin_audit_log(action="delete_organization", resource_type="organization")
def delete_organization(org_id):
"""Soft-delete an organization."""
try:
SuperadminOrganizationService.soft_delete_organization(org_id)
return api_response(message="Organization deleted successfully")
except ValueError as e:
return api_response(
success=False,
message=str(e),
status=404,
error_type="NOT_FOUND",
)
except Exception as e:
logger.error(f"[SuperadminOrg] Delete organization error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@@ -0,0 +1,330 @@
"""Superadmin usage and analytics endpoints."""
import logging
from datetime import datetime, timezone
from flask import request
from gatehouse_app.api.v1.superadmin import superadmin_bp
from gatehouse_app.utils.response import api_response
from gatehouse_app.services.superadmin_usage_service import SuperadminUsageService
from gatehouse_app.services.superadmin_analytics_service import SuperadminAnalyticsService
from gatehouse_app.decorators.superadmin import superadmin_required, superadmin_audit_log
logger = logging.getLogger(__name__)
# ============ Analytics Endpoints ============
@superadmin_bp.route("/analytics/dashboard", methods=["GET"])
@superadmin_required
def get_dashboard_stats():
"""Get dashboard statistics for the overview page.
Returns aggregated stats: org counts, user counts, sessions, recent signups.
"""
try:
stats = SuperadminAnalyticsService.get_dashboard_stats()
return api_response(data=stats, message="Dashboard stats retrieved successfully")
except Exception as e:
logger.error(f"[SuperadminAnalytics] Dashboard stats error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/analytics/signup-trends", methods=["GET"])
@superadmin_required
def get_signup_trends():
"""Get signup trends over time.
Query params:
days: Number of days to analyze (default 30, max 365)
"""
try:
days = min(365, max(1, int(request.args.get("days", 30))))
trends = SuperadminAnalyticsService.get_signup_trends(days)
return api_response(data=trends, message="Signup trends retrieved successfully")
except Exception as e:
logger.error(f"[SuperadminAnalytics] Signup trends error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/analytics/org-distribution", methods=["GET"])
@superadmin_required
def get_org_distribution():
"""Get organization distribution by size."""
try:
distribution = SuperadminAnalyticsService.get_org_distribution()
return api_response(data=distribution, message="Organization distribution retrieved successfully")
except Exception as e:
logger.error(f"[SuperadminAnalytics] Org distribution error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/analytics/recent-activity", methods=["GET"])
@superadmin_required
def get_recent_activity():
"""Get recent superadmin actions.
Query params:
limit: Maximum number of entries (default 20, max 100)
"""
try:
limit = min(100, max(1, int(request.args.get("limit", 20))))
activity = SuperadminAnalyticsService.get_recent_activity(limit)
return api_response(data={"items": activity}, message="Recent activity retrieved successfully")
except Exception as e:
logger.error(f"[SuperadminAnalytics] Recent activity error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
# ============ Usage Endpoints ============
@superadmin_bp.route("/usage/<org_id>", methods=["GET"])
@superadmin_required
def get_organization_usage(org_id):
"""Get current usage for an organization.
Returns current period usage metrics: user count, active sessions.
"""
try:
usage = SuperadminUsageService.get_current_usage(org_id)
return api_response(data=usage, message="Usage retrieved successfully")
except ValueError as e:
return api_response(
success=False,
message=str(e),
status=404,
error_type="NOT_FOUND",
)
except Exception as e:
logger.error(f"[SuperadminUsage] Get usage error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/usage/<org_id>/history", methods=["GET"])
@superadmin_required
def get_usage_history(org_id):
"""Get usage history for an organization.
Query params:
metric: Metric type (users, sessions) - default users
days: Number of days of history (default 30, max 365)
"""
try:
metric = request.args.get("metric", "users")
days = min(365, max(1, int(request.args.get("days", 30))))
history = SuperadminUsageService.get_usage_history(org_id, metric, days)
return api_response(data=history, message="Usage history retrieved successfully")
except ValueError as e:
return api_response(
success=False,
message=str(e),
status=404,
error_type="NOT_FOUND",
)
except Exception as e:
logger.error(f"[SuperadminUsage] Get usage history error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/usage/<org_id>/seats", methods=["GET"])
@superadmin_required
def get_seat_count(org_id):
"""Get maximum seat count for billing period.
Query params:
year: Year (default current year)
month: Month (default current month)
"""
try:
year = int(request.args.get("year", datetime.now().year))
month = int(request.args.get("month", datetime.now().month))
seats = SuperadminUsageService.get_seat_count_for_period(org_id, year, month)
return api_response(data=seats, message="Seat count retrieved successfully")
except ValueError as e:
return api_response(
success=False,
message=str(e),
status=404,
error_type="NOT_FOUND",
)
except Exception as e:
logger.error(f"[SuperadminUsage] Get seat count error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/usage/<org_id>/adjust", methods=["POST"])
@superadmin_required
@superadmin_audit_log(action="usage_adjustment", resource_type="usage")
def adjust_usage(org_id):
"""Apply a manual usage adjustment.
Body:
metric: Metric to adjust
adjustment: Positive (credit) or negative (charge)
reason: Reason for adjustment
"""
try:
data = request.json or {}
metric = data.get("metric", "users")
adjustment = data.get("adjustment", 0)
reason = data.get("reason", "")
if not reason:
return api_response(
success=False,
message="Reason is required for usage adjustment",
status=400,
error_type="VALIDATION_ERROR",
)
result = SuperadminUsageService.adjust_usage(
org_id=org_id,
metric=metric,
adjustment=adjustment,
reason=reason,
superadmin_id="", # Will be filled from decorator
)
return api_response(data=result, message="Usage adjustment applied successfully")
except ValueError as e:
return api_response(
success=False,
message=str(e),
status=404,
error_type="NOT_FOUND",
)
except Exception as e:
logger.error(f"[SuperadminUsage] Adjust usage error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
# ============ Invoice Data Endpoint ============
@superadmin_bp.route("/invoice-data/<org_id>", methods=["GET"])
@superadmin_required
def get_invoice_data(org_id):
"""Get all data needed to generate an invoice for an organization.
Returns organization info, plan details, usage, and subscription status.
"""
try:
from datetime import datetime
from gatehouse_app.models.organization.organization import Organization
org = Organization.query.get(org_id)
if not org:
return api_response(
success=False,
message="Organization not found",
status=404,
error_type="NOT_FOUND",
)
# Get seat count for current month
now = datetime.now(timezone.utc)
seats = SuperadminUsageService.get_seat_count_for_period(org_id, now.year, now.month)
# Get owner
owner = org.get_owner()
invoice_data = {
"organization": {
"id": org.id,
"name": org.name,
"slug": org.slug,
"owner_email": owner.email if owner else None,
"created_at": org.created_at.isoformat() + "Z" if org.created_at else None,
},
"billing_period": {
"year": now.year,
"month": now.month,
"start": seats["period_start"],
"end": seats["period_end"],
},
"usage": {
"max_seats": seats["max_seats"],
"current_seats": seats["current_seats"],
"included_seats": 0, # Would come from plan
"overage": max(0, seats["current_seats"]), # Simplified
},
"subscription": {
"status": "active" if org.is_active else "suspended",
"is_active": org.is_active,
},
"line_items": [
{
"description": "Base subscription",
"quantity": 1,
"unit_price": 0, # Would come from plan
"total": 0,
},
{
"description": f"User seats ({seats['current_seats']})",
"quantity": seats["current_seats"],
"unit_price": 0, # Would come from plan per-seat price
"total": 0,
},
],
"total": 0, # Would be calculated
"generated_at": now.isoformat() + "Z",
}
return api_response(data=invoice_data, message="Invoice data retrieved successfully")
except Exception as e:
logger.error(f"[SuperadminAnalytics] Invoice data error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
+516
View File
@@ -0,0 +1,516 @@
"""Superadmin user management endpoints."""
import logging
from flask import request, g
from gatehouse_app.api.v1.superadmin import superadmin_bp
from gatehouse_app.utils.response import api_response
from gatehouse_app.decorators.superadmin import superadmin_required, superadmin_audit_log
from gatehouse_app.models.user.user import User
from gatehouse_app.models.user.session import Session
from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.extensions import db
logger = logging.getLogger(__name__)
@superadmin_bp.route("/users", methods=["GET"])
@superadmin_required
def list_users():
"""Get paginated list of users with optional filters.
Query params:
page: Page number (default 1)
per_page: Items per page (default 20, max 100)
organization_id: Filter by organization
status: Filter by status (active/suspended)
search: Search by email or name
"""
try:
page = max(1, int(request.args.get("page", 1)))
per_page = min(100, max(1, int(request.args.get("per_page", 20))))
org_id = request.args.get("organization_id")
status = request.args.get("status")
search = request.args.get("search", "").strip()
# Base query
query = User.query.filter(User.deleted_at.is_(None))
# Filter by organization
if org_id:
member_user_ids = db.session.query(OrganizationMember.user_id).filter(
OrganizationMember.organization_id == org_id,
OrganizationMember.deleted_at.is_(None),
).all()
user_ids = [m.user_id for m in member_user_ids]
query = query.filter(User.id.in_(user_ids))
# Filter by status
if status == "suspended":
query = query.filter(User.status == "GLOBAL_SUSPENDED")
elif status == "active":
query = query.filter(User.status != "GLOBAL_SUSPENDED")
# Search by email or name
if search:
search_filter = f"%{search}%"
query = query.filter(
db.or_(
User.email.ilike(search_filter),
User.full_name.ilike(search_filter),
)
)
# Order by created_at desc
query = query.order_by(User.created_at.desc())
# Paginate
total = query.count()
users = query.offset((page - 1) * per_page).limit(per_page).all()
# Get org memberships for each user
items = []
for user in users:
# Get organization memberships
memberships = db.session.query(OrganizationMember).filter(
OrganizationMember.user_id == user.id,
OrganizationMember.deleted_at.is_(None),
).all()
orgs = []
for m in memberships:
org = Organization.query.get(m.organization_id)
if org:
orgs.append({
"org_id": org.id,
"org_name": org.name,
"role": m.role,
"joined_at": m.created_at.isoformat() + "Z" if m.created_at else None,
})
# Get active session count
active_sessions = Session.query.filter(
Session.user_id == user.id,
Session.deleted_at.is_(None),
Session.status == "active",
).count()
items.append({
"id": user.id,
"email": user.email,
"full_name": user.full_name,
"status": user.status,
"mfa_enabled": user.mfa_enabled if hasattr(user, 'mfa_enabled') else False,
"org_count": len(orgs),
"orgs": orgs,
"active_sessions": active_sessions,
"last_login_at": user.last_login_at.isoformat() + "Z" if user.last_login_at else None,
"created_at": user.created_at.isoformat() + "Z" if user.created_at else None,
})
return api_response(data={
"items": items,
"total": total,
"page": page,
"per_page": per_page,
"pages": (total + per_page - 1) // per_page if per_page > 0 else 0,
}, message="Users retrieved successfully")
except Exception as e:
logger.error(f"[SuperadminUsers] List users error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/users/<user_id>", methods=["GET"])
@superadmin_required
def get_user(user_id):
"""Get detailed user information.
Returns user + all org memberships + active sessions + security methods.
"""
try:
user = User.query.get(user_id)
if not user or user.deleted_at is not None:
return api_response(
success=False,
message="User not found",
status=404,
error_type="NOT_FOUND",
)
# Get organization memberships
memberships = db.session.query(OrganizationMember).filter(
OrganizationMember.user_id == user_id,
OrganizationMember.deleted_at.is_(None),
).all()
orgs = []
for m in memberships:
org = Organization.query.get(m.organization_id)
if org:
orgs.append({
"org_id": org.id,
"org_name": org.name,
"org_slug": org.slug,
"role": m.role,
"joined_at": m.created_at.isoformat() + "Z" if m.created_at else None,
})
# Get active sessions
sessions = Session.query.filter(
Session.user_id == user_id,
Session.deleted_at.is_(None),
Session.status == "active",
).all()
active_sessions = [{
"id": s.id,
"ip_address": s.ip_address,
"user_agent": s.user_agent,
"created_at": s.created_at.isoformat() + "Z" if s.created_at else None,
"last_active_at": s.last_active_at.isoformat() + "Z" if hasattr(s, 'last_active_at') and s.last_active_at else None,
} for s in sessions]
# Get security methods (simplified - would need UserSecurityMethod model)
security_methods = []
if hasattr(user, 'totp_enabled') and user.totp_enabled:
security_methods.append({"type": "totp", "enabled": True})
if hasattr(user, 'webauthn_enabled') and user.webauthn_enabled:
security_methods.append({"type": "webauthn", "enabled": True})
return api_response(data={
"user": {
"id": user.id,
"email": user.email,
"full_name": user.full_name,
"status": user.status,
"mfa_enabled": user.mfa_enabled if hasattr(user, 'mfa_enabled') else False,
"last_login_at": user.last_login_at.isoformat() + "Z" if user.last_login_at else None,
"created_at": user.created_at.isoformat() + "Z" if user.created_at else None,
},
"organizations": orgs,
"active_sessions": active_sessions,
"security_methods": security_methods,
}, message="User retrieved successfully")
except Exception as e:
logger.error(f"[SuperadminUsers] Get user error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/users/<user_id>/suspend", methods=["POST"])
@superadmin_required
@superadmin_audit_log(action="user.suspend", resource_type="user")
def suspend_user(user_id):
"""Globally suspend a user (sets status=GLOBAL_SUSPENDED)."""
try:
user = User.query.get(user_id)
if not user or user.deleted_at is not None:
return api_response(
success=False,
message="User not found",
status=404,
error_type="NOT_FOUND",
)
if user.status == "GLOBAL_SUSPENDED":
return api_response(
success=False,
message="User is already suspended",
status=400,
error_type="VALIDATION_ERROR",
)
user.status = "GLOBAL_SUSPENDED"
db.session.commit()
# Revoke all sessions
revoked_count = Session.query.filter(
Session.user_id == user_id,
Session.deleted_at.is_(None),
).update({"status": "revoked", "deleted_at": db.func.now()})
db.session.commit()
logger.warning(f"[SuperadminUsers] User {user_id} globally suspended by {getattr(g, 'current_superadmin', {}).get('id', 'unknown')}")
return api_response(data={
"user": {
"id": user.id,
"email": user.email,
"status": user.status,
},
"sessions_revoked": revoked_count,
}, message="User suspended successfully")
except Exception as e:
db.session.rollback()
logger.error(f"[SuperadminUsers] Suspend user error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/users/<user_id>/unsuspend", methods=["POST"])
@superadmin_required
@superadmin_audit_log(action="user.unsuspend", resource_type="user")
def unsuspend_user(user_id):
"""Remove global suspension from a user."""
try:
user = User.query.get(user_id)
if not user or user.deleted_at is not None:
return api_response(
success=False,
message="User not found",
status=404,
error_type="NOT_FOUND",
)
if user.status != "GLOBAL_SUSPENDED":
return api_response(
success=False,
message="User is not suspended",
status=400,
error_type="VALIDATION_ERROR",
)
user.status = "active"
db.session.commit()
return api_response(data={
"user": {
"id": user.id,
"email": user.email,
"status": user.status,
},
}, message="User unsuspended successfully")
except Exception as e:
db.session.rollback()
logger.error(f"[SuperadminUsers] Unsuspend user error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/users/<user_id>/reset-password", methods=["POST"])
@superadmin_required
@superadmin_audit_log(action="user.reset_password", resource_type="user")
def reset_user_password(user_id):
"""Trigger password reset email flow for user."""
try:
user = User.query.get(user_id)
if not user or user.deleted_at is not None:
return api_response(
success=False,
message="User not found",
status=404,
error_type="NOT_FOUND",
)
# In production, this would call AuthService.send_password_reset_email(user.email)
# For now, just log and return success
logger.info(f"[SuperadminUsers] Password reset requested for {user.email} by superadmin")
return api_response(data={
"email": user.email,
}, message="Password reset email sent successfully")
except Exception as e:
logger.error(f"[SuperadminUsers] Reset password error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/users/<user_id>/sessions", methods=["DELETE"])
@superadmin_required
@superadmin_audit_log(action="user.revoke_sessions", resource_type="user")
def revoke_user_sessions(user_id):
"""Revoke all sessions for a user."""
try:
user = User.query.get(user_id)
if not user or user.deleted_at is not None:
return api_response(
success=False,
message="User not found",
status=404,
error_type="NOT_FOUND",
)
# Revoke all sessions
result = Session.query.filter(
Session.user_id == user_id,
Session.deleted_at.is_(None),
).update({"status": "revoked", "deleted_at": db.func.now()})
db.session.commit()
return api_response(data={
"user_id": user_id,
"count": result,
}, message=f"All sessions revoked ({result} sessions)")
except Exception as e:
db.session.rollback()
logger.error(f"[SuperadminUsers] Revoke sessions error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/users/<user_id>/add-to-org/<org_id>", methods=["POST"])
@superadmin_required
@superadmin_audit_log(action="user.add_to_org", resource_type="user")
def add_user_to_org(user_id, org_id):
"""Add a user to an organization with specified role."""
try:
data = request.json or {}
role = data.get("role", "member")
valid_roles = ["member", "admin", "owner"]
if role not in valid_roles:
return api_response(
success=False,
message=f"Invalid role. Must be one of: {', '.join(valid_roles)}",
status=400,
error_type="VALIDATION_ERROR",
)
user = User.query.get(user_id)
if not user or user.deleted_at is not None:
return api_response(
success=False,
message="User not found",
status=404,
error_type="NOT_FOUND",
)
org = Organization.query.get(org_id)
if not org or org.deleted_at is not None:
return api_response(
success=False,
message="Organization not found",
status=404,
error_type="NOT_FOUND",
)
# Check if already a member
existing = OrganizationMember.query.filter(
OrganizationMember.user_id == user_id,
OrganizationMember.organization_id == org_id,
OrganizationMember.deleted_at.is_(None),
).first()
if existing:
return api_response(
success=False,
message="User is already a member of this organization",
status=400,
error_type="VALIDATION_ERROR",
)
# Create membership
membership = OrganizationMember(
user_id=user_id,
organization_id=org_id,
role=role,
)
db.session.add(membership)
db.session.commit()
logger.info(f"[SuperadminUsers] User {user_id} added to org {org_id} as {role} by superadmin")
return api_response(data={
"user_id": user_id,
"organization_id": org_id,
"role": role,
"joined_at": membership.created_at.isoformat() + "Z" if membership.created_at else None,
}, message="User added to organization successfully")
except Exception as e:
db.session.rollback()
logger.error(f"[SuperadminUsers] Add to org error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
@superadmin_bp.route("/users/<user_id>/orgs/<org_id>", methods=["DELETE"])
@superadmin_required
@superadmin_audit_log(action="user.remove_from_org", resource_type="user")
def remove_user_from_org(user_id, org_id):
"""Remove a user from an organization."""
try:
membership = OrganizationMember.query.filter(
OrganizationMember.user_id == user_id,
OrganizationMember.organization_id == org_id,
OrganizationMember.deleted_at.is_(None),
).first()
if not membership:
return api_response(
success=False,
message="User is not a member of this organization",
status=404,
error_type="NOT_FOUND",
)
# Check if user is the only owner
if membership.role == "owner":
owner_count = OrganizationMember.query.filter(
OrganizationMember.organization_id == org_id,
OrganizationMember.role == "owner",
OrganizationMember.deleted_at.is_(None),
).count()
if owner_count <= 1:
return api_response(
success=False,
message="Cannot remove the only owner from an organization. Transfer ownership first.",
status=400,
error_type="VALIDATION_ERROR",
)
# Soft delete membership
membership.deleted_at = db.func.now()
db.session.commit()
logger.info(f"[SuperadminUsers] User {user_id} removed from org {org_id} by superadmin")
return api_response(data={
"user_id": user_id,
"organization_id": org_id,
}, message="User removed from organization successfully")
except Exception as e:
db.session.rollback()
logger.error(f"[SuperadminUsers] Remove from org error: {e}")
return api_response(
success=False,
message="An error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
+24 -31
View File
@@ -283,7 +283,8 @@ def admin_verify_user_email(user_id):
if was_inactive:
target.status = UserStatus.ACTIVE
EmailVerificationToken.query.filter_by(user_id=target.id, used_at=None).delete()
now = datetime.now(timezone.utc)
EmailVerificationToken.query.filter_by(user_id=target.id, used_at=None).filter(EmailVerificationToken.deleted_at == None).update({"deleted_at": now}, synchronize_session=False)
_db.session.commit()
AuditService.log_action(
@@ -300,14 +301,12 @@ def admin_verify_user_email(user_id):
@api_v1_bp.route("/admin/users/<user_id>/delete", methods=["POST"])
@login_required
@full_access_required
def admin_hard_delete_user(user_id):
def admin_delete_user(user_id):
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.models.user.user import User as _User
from gatehouse_app.models.ssh_ca.ssh_key import SSHKey
from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate
from gatehouse_app.models.ssh_ca.certificate_audit_log import CertificateAuditLog
from gatehouse_app.models.auth.authentication_method import OAuthState
from gatehouse_app.models.security.organization_security_policy import OrganizationSecurityPolicy
from gatehouse_app.extensions import db as _db
from gatehouse_app.utils.constants import AuditAction, OrganizationRole
from gatehouse_app.services.audit_service import AuditService
@@ -329,7 +328,7 @@ def admin_hard_delete_user(user_id):
if target.id == caller.id:
return api_response(success=False, message="Cannot delete your own account via this endpoint.", status=400, error_type="BAD_REQUEST")
target_org_ids = {m.organization_id for m in target.organization_memberships}
target_org_ids = {m.organization_id for m in target.get_active_memberships()}
admin_in_shared_org = OrganizationMember.query.filter(
OrganizationMember.user_id == caller.id,
OrganizationMember.organization_id.in_(target_org_ids),
@@ -373,44 +372,29 @@ def admin_hard_delete_user(user_id):
target_email = target.email
target_id_str = str(target.id)
now = datetime.now(timezone.utc)
try:
# NULL out FK references that don't cascade on delete so the
# session.delete() below doesn't hit FK constraint violations.
# Soft delete the user — set deleted_at timestamp.
target.deleted_at = now
# org_invite_tokens.invited_by_id — SET NULL is already on the FK column,
# but OrganizationMember.invited_by_id has no ondelete clause.
_db.session.execute(
_db.text("UPDATE organization_members SET invited_by_id = NULL WHERE invited_by_id = :uid"),
{"uid": target_id_str},
# Soft delete associated OAuthState records.
OAuthState.query.filter_by(user_id=target_id_str).filter(OAuthState.deleted_at == None).update(
{"deleted_at": now}, synchronize_session=False
)
# certificate_audit_logs.user_id — nullable, no ondelete clause.
CertificateAuditLog.query.filter_by(user_id=target_id_str).update(
{"user_id": None}, synchronize_session=False
)
# organization_security_policies.updated_by_user_id — nullable, no ondelete.
OrganizationSecurityPolicy.query.filter_by(updated_by_user_id=target_id_str).update(
{"updated_by_user_id": None}, synchronize_session=False
)
# oauth_states.user_id — nullable, no ondelete.
OAuthState.query.filter_by(user_id=target_id_str).delete(synchronize_session=False)
_db.session.delete(target)
_db.session.flush()
except Exception as exc:
_db.session.rollback()
_logger.error(f"Hard delete failed for {target_id_str}: {exc}")
_logger.error(f"Soft delete failed for {target_id_str}: {exc}")
return api_response(success=False, message="Failed to delete user account. Please try again.", status=500, error_type="SERVER_ERROR")
AuditService.log_action(
action=AuditAction.USER_HARD_DELETE,
action=AuditAction.USER_DELETE,
user_id=caller.id,
organization_id=admin_in_shared_org.organization_id,
resource_type="user", resource_id=target_id_str,
description=f"Admin permanently deleted user account: {target_email}",
description=f"Admin deleted user account: {target_email}",
metadata={
"deleted_user_id": target_id_str, "deleted_user_email": target_email,
"ssh_keys_deleted": ssh_key_count, "certs_revoked": active_cert_count,
@@ -419,7 +403,7 @@ def admin_hard_delete_user(user_id):
_db.session.commit()
return api_response(
message=f"User account {target_email} has been permanently deleted.",
message=f"User account {target_email} has been deleted.",
data={"deleted_user_id": target_id_str, "deleted_user_email": target_email,
"ssh_keys_deleted": ssh_key_count, "certs_revoked": active_cert_count},
)
@@ -512,7 +496,14 @@ def admin_get_user_mfa(user_id):
user_id=user_id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None,
).first()
if webauthn_method and webauthn_method.provider_data:
for cred in webauthn_method.provider_data.get("credentials", []):
# Handle both single credential (direct in provider_data) and multiple credentials (in credentials array)
credentials = webauthn_method.provider_data.get("credentials", [])
# If no credentials array, check if provider_data itself is a single credential
if not credentials and "credential_id" in webauthn_method.provider_data:
credentials = [webauthn_method.provider_data]
for cred in credentials:
if not cred.get("deleted_at"):
mfa_methods.append({
"id": cred.get("id") or cred.get("credential_id"),
@@ -588,6 +579,8 @@ def admin_remove_user_mfa(user_id, method_type):
credential_id = request.args.get("credential_id")
if credential_id:
credentials = (webauthn_method.provider_data or {}).get("credentials", [])
if not credentials and "credential_id" in (webauthn_method.provider_data or {}):
credentials = [webauthn_method.provider_data]
found = False
new_credentials = []
for cred in credentials:
+67 -17
View File
@@ -142,6 +142,55 @@ def get_my_organizations():
return api_response(data={"organizations": orgs, "count": len(orgs)}, message="Organizations retrieved successfully")
@api_v1_bp.route("/users/me/organizations/simple", methods=["GET"])
@login_required
def get_my_organizations_simple():
"""Lightweight organization list for CLI tool.
Returns organizations with CA status indicators for CLI users.
"""
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.models.ssh_ca.ca import CA, CaType
user = g.current_user
memberships = OrganizationMember.query.filter_by(user_id=user.id, deleted_at=None).all()
orgs = []
for membership in memberships:
org = membership.organization
if not org or org.deleted_at is not None:
continue
# Check for active CAs
user_ca = CA.query.filter_by(
organization_id=org.id,
ca_type=CaType.USER,
is_active=True,
deleted_at=None,
).first()
host_ca = CA.query.filter_by(
organization_id=org.id,
ca_type=CaType.HOST,
is_active=True,
deleted_at=None,
).first()
orgs.append({
"id": str(org.id),
"name": org.name,
"slug": getattr(org, 'slug', None),
"role": membership.role.value if hasattr(membership.role, "value") else str(membership.role),
"has_user_ca": user_ca is not None,
"has_host_ca": host_ca is not None,
})
return api_response(
data={"organizations": orgs, "count": len(orgs)},
message="Organizations retrieved successfully",
)
@api_v1_bp.route("/users/me/principals", methods=["GET"])
@login_required
@full_access_required
@@ -182,12 +231,11 @@ def get_my_principals():
my_principals = []
if effective_principal_ids:
for p in Principal.query.filter(
Principal.id.in_(list(effective_principal_ids)),
Principal.deleted_at == None,
).all():
for p in Principal.query.filter(Principal.id.in_(list(effective_principal_ids)), Principal.deleted_at == None).all():
my_principals.append({
"id": p.id, "name": p.name, "description": p.description,
"id": p.id,
"name": p.name,
"description": p.description,
"direct": p.id in direct_principal_ids,
})
@@ -197,7 +245,8 @@ def get_my_principals():
all_principals.append({"id": p.id, "name": p.name, "description": p.description})
orgs_result.append({
"org_id": org.id, "org_name": org.name,
"org_id": org.id,
"org_name": org.name,
"role": role.value if hasattr(role, "value") else role,
"is_admin": is_admin,
"my_principals": my_principals,
@@ -241,6 +290,7 @@ def get_my_pending_invites():
@api_v1_bp.route("/users/me/memberships", methods=["GET"])
@login_required
@full_access_required
def get_my_memberships():
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.models.organization.department import DepartmentMembership, DepartmentPrincipal, Department
@@ -258,15 +308,15 @@ def get_my_memberships():
dept_memberships = DepartmentMembership.query.filter_by(user_id=user.id, deleted_at=None).all()
user_depts = [
dm.department for dm in dept_memberships
if dm.department
and dm.department.organization_id == org.id
and dm.department.deleted_at is None
dm.department
for dm in dept_memberships
if dm.department and dm.department.organization_id == org.id and dm.department.deleted_at is None
]
direct_pm = PrincipalMembership.query.filter_by(user_id=user.id, deleted_at=None).all()
direct_principal_ids = {
pm.principal_id for pm in direct_pm
pm.principal_id
for pm in direct_pm
if pm.principal and pm.principal.organization_id == org.id and pm.principal.deleted_at is None
}
@@ -279,18 +329,18 @@ def get_my_memberships():
all_principal_ids = direct_principal_ids | via_dept_principal_ids
principals_list = []
if all_principal_ids:
for p in Principal.query.filter(
Principal.id.in_(list(all_principal_ids)),
Principal.deleted_at == None,
).all():
for p in Principal.query.filter(Principal.id.in_(list(all_principal_ids)), Principal.deleted_at == None).all():
principals_list.append({
"id": str(p.id), "name": p.name, "description": p.description,
"id": str(p.id),
"name": p.name,
"description": p.description,
"via_department": p.id not in direct_principal_ids,
})
role = membership.role
orgs_result.append({
"org_id": str(org.id), "org_name": org.name,
"org_id": str(org.id),
"org_name": org.name,
"role": role.value if hasattr(role, "value") else role,
"departments": [{"id": str(d.id), "name": d.name, "description": d.description} for d in user_depts],
"principals": principals_list,
+203
View File
@@ -0,0 +1,203 @@
"""Superadmin authentication and audit decorators."""
import logging
from functools import wraps
from datetime import datetime, timezone
from flask import request, g
from gatehouse_app.utils.response import api_response
logger = logging.getLogger(__name__)
def superadmin_required(f):
"""Decorator to require superadmin Bearer token authentication.
Extracts token from Authorization: Bearer {token} header,
validates the session against SuperadminSession table,
and sets g.current_superadmin and g.superadmin_session.
Returns 401 if no valid session, 403 if not a superadmin.
"""
@wraps(f)
def decorated_function(*args, **kwargs):
# Extract token from Authorization header
auth_header = request.headers.get('Authorization')
if not auth_header:
return api_response(
success=False,
message="Authorization header is required",
status=401,
error_type="AUTH_REQUIRED"
)
# Expect format: "Bearer {token}"
parts = auth_header.split()
if len(parts) != 2 or parts[0].lower() != 'bearer':
return api_response(
success=False,
message="Invalid authorization format. Use: Bearer {token}",
status=401,
error_type="INVALID_AUTH_FORMAT"
)
token = parts[1]
# Import here to avoid circular imports
from gatehouse_app.models.superadmin import SuperadminSession, Superadmin
# Get active session by token
session = SuperadminSession.query.filter_by(token=token).first()
if not session:
return api_response(
success=False,
message="Invalid or expired session",
status=401,
error_type="INVALID_TOKEN"
)
# Check if session is active
if not session.is_active():
return api_response(
success=False,
message="Session is no longer active",
status=401,
error_type="SESSION_INACTIVE"
)
# Get the superadmin
superadmin = session.superadmin
if not superadmin:
return api_response(
success=False,
message="Superadmin not found",
status=401,
error_type="INVALID_TOKEN"
)
# Check if superadmin is active
if not superadmin.is_active:
return api_response(
success=False,
message="Superadmin account is disabled",
status=403,
error_type="ACCOUNT_DISABLED"
)
# Update last_activity_at timestamp
session.last_activity_at = datetime.now(timezone.utc)
from gatehouse_app import db
db.session.commit()
# Set context variables
g.current_superadmin = superadmin
g.superadmin_session = session
return f(*args, **kwargs)
return decorated_function
def superadmin_audit_log(action, resource_type):
"""Decorator to log superadmin actions to SuperadminAuditLog.
Must be used AFTER @superadmin_required to have access to g.current_superadmin.
Args:
action: The action being performed (e.g., 'update', 'delete', 'create')
resource_type: The type of resource being acted on (e.g., 'organization', 'user')
Captures: superadmin_id, action, resource_type, resource_id, org_id, user_id,
ip_address, user_agent, request_id, extra_data
"""
def decorator(f):
@wraps(f)
def decorated_function(*args, **kwargs):
# Get superadmin from context (set by @superadmin_required)
superadmin = getattr(g, 'current_superadmin', None)
session = getattr(g, 'superadmin_session', None)
if not superadmin:
logger.warning(f"superadmin_audit_log used without @superadmin_required on {f.__name__}")
return f(*args, **kwargs)
# Extract resource_id from kwargs if present
resource_id = kwargs.get('resource_id') or kwargs.get(f'{resource_type}_id') or None
# Extract org_id and user_id from kwargs if present
org_id = kwargs.get('org_id') or None
user_id = kwargs.get('user_id') or None
# Get IP address and user agent
ip_address = request.remote_addr or None
user_agent = request.headers.get('User-Agent') or None
request_id = request.headers.get('X-Request-ID') or None
# Get extra data from request body (for POST/PATCH requests)
extra_data = None
if request.is_json:
try:
# Exclude sensitive fields
body = request.get_json(silent=True) or {}
sensitive_fields = {'password', 'password_hash', 'token', 'secret', 'key'}
extra_data = {k: v for k, v in body.items() if k not in sensitive_fields}
except Exception:
pass
# Get success status (default to True unless an error is raised)
success = True
error_message = None
try:
result = f(*args, **kwargs)
# Check if the response indicates failure
if hasattr(result, 'get_json'):
result_data = result.get_json()
if result_data and result_data.get('success') is False:
success = False
error_message = result_data.get('message')
return result
except Exception as e:
success = False
error_message = str(e)
raise
finally:
# Log the action
try:
from gatehouse_app.models.superadmin_audit_log import SuperadminAuditLog
from gatehouse_app import db
audit_entry = SuperadminAuditLog(
superadmin_id=superadmin.id,
action=action,
resource_type=resource_type,
resource_id=resource_id,
org_id=org_id,
user_id=user_id,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
extra_data=extra_data,
success=success,
error_message=error_message
)
db.session.add(audit_entry)
db.session.commit()
logger.info(
f"Superadmin audit: superadmin={superadmin.email} action={action} "
f"resource_type={resource_type} resource_id={resource_id} success={success}"
)
except Exception as e:
# Never let audit logging failures break the main operation
logger.error(f"Failed to write superadmin audit log: {e}")
return decorated_function
return decorator
+59 -36
View File
@@ -1,6 +1,44 @@
"""CORS middleware configuration."""
from flask import request, make_response
ALLOWED_METHODS = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
ALLOWED_HEADERS = (
"Content-Type, Authorization, X-Requested-With, X-Request-ID, "
"Cache-Control, Pragma, X-WebAuthn-Session-Token"
)
def _is_origin_allowed(origin, cors_origins):
"""Return True if the origin is permitted by the CORS config.
Handles both wildcard ("*") and explicit origin lists.
"""
if not origin:
return False
if cors_origins == "*":
return True
if isinstance(cors_origins, list):
if "*" in cors_origins:
return True
return origin in cors_origins
return False
def _cors_origin_header(cors_origins, request_origin):
"""Return the value for Access-Control-Allow-Origin.
Per the CORS spec, browsers reject ``*`` when credentials are involved,
so we echo the request origin when wildcard + credentials is configured.
"""
allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
if allow_all and request_origin:
return request_origin
if allow_all:
return "*"
if request_origin and request_origin in cors_origins:
return request_origin
return None
def setup_cors(app):
"""
@@ -9,6 +47,7 @@ def setup_cors(app):
Args:
app: Flask application instance
"""
supports_credentials = app.config.get("CORS_SUPPORTS_CREDENTIALS", True)
@app.before_request
def handle_preflight():
@@ -16,49 +55,33 @@ def setup_cors(app):
if request.method == "OPTIONS":
origin = request.headers.get("Origin")
cors_origins = app.config.get("CORS_ORIGINS", [])
# Allow all origins if CORS_ORIGINS is "*" (string) or ["*"] (list with wildcard)
allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
if allow_all:
response = make_response("", 204)
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma"
response.headers["Access-Control-Max-Age"] = "3600"
response.headers["Cache-Control"] = "no-cache, no-store"
return response
elif origin and origin in cors_origins:
response = make_response("", 204)
response.headers["Access-Control-Allow-Origin"] = origin
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma, X-WebAuthn-Session-Token"
if not _is_origin_allowed(origin, cors_origins):
return None
response = make_response("", 204)
response.headers["Access-Control-Allow-Origin"] = _cors_origin_header(cors_origins, origin)
response.headers["Access-Control-Allow-Methods"] = ALLOWED_METHODS
response.headers["Access-Control-Allow-Headers"] = ALLOWED_HEADERS
if supports_credentials:
response.headers["Access-Control-Allow-Credentials"] = "true"
response.headers["Access-Control-Max-Age"] = "3600"
response.headers["Cache-Control"] = "no-cache, no-store"
return response
response.headers["Access-Control-Max-Age"] = "3600"
response.headers["Cache-Control"] = "no-cache, no-store"
return response
@app.after_request
def after_request_cors(response):
"""Add additional CORS headers if needed."""
"""Add CORS headers to non-preflight responses."""
origin = request.headers.get("Origin")
cors_origins = app.config.get("CORS_ORIGINS", [])
# Allow all origins if CORS_ORIGINS is "*" (string) or ["*"] (list with wildcard)
allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
if allow_all:
# When allowing all origins, set header to "*"
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma"
response.headers["Access-Control-Max-Age"] = "3600"
elif origin and origin in cors_origins:
# When allowing specific origins, echo the request origin
response.headers["Access-Control-Allow-Origin"] = origin
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma, X-WebAuthn-Session-Token"
response.headers["Access-Control-Allow-Credentials"] = "true"
allow_origin = _cors_origin_header(cors_origins, origin)
if allow_origin:
response.headers["Access-Control-Allow-Origin"] = allow_origin
response.headers["Access-Control-Allow-Methods"] = ALLOWED_METHODS
response.headers["Access-Control-Allow-Headers"] = ALLOWED_HEADERS
if supports_credentials:
response.headers["Access-Control-Allow-Credentials"] = "true"
response.headers["Access-Control-Max-Age"] = "3600"
return response
+13
View File
@@ -113,6 +113,14 @@ from gatehouse_app.models.zerotier import ( # noqa: F401
ZeroTierMembership,
KillSwitchEvent,
)
# ── Superadmin ─────────────────────────────────────────────────────────────────
from gatehouse_app.models.superadmin import ( # noqa: F401
Superadmin,
SuperadminSession,
SuperadminSessionStatus,
)
from gatehouse_app.models.superadmin_audit_log import SuperadminAuditLog # noqa: F401
from gatehouse_app.models.security.user_security_policy import ( # noqa: F401
UserSecurityPolicy,
)
@@ -175,4 +183,9 @@ __all__ = [
"ActivationSession",
"ZeroTierMembership",
"KillSwitchEvent",
# Superadmin
"Superadmin",
"SuperadminSession",
"SuperadminSessionStatus",
"SuperadminAuditLog",
]
@@ -374,7 +374,7 @@ class OAuthState(BaseModel):
def cleanup_expired(cls) -> None:
"""Remove expired OAuth states."""
now = datetime.now(timezone.utc)
cls.query.filter(cls.expires_at < now).delete()
cls.query.filter(cls.expires_at < now).filter(cls.deleted_at == None).update({"deleted_at": now}, synchronize_session=False)
db.session.commit()
def to_dict(self, exclude=None):
@@ -32,7 +32,8 @@ class EmailVerificationToken(BaseModel):
Any existing unused tokens for this user are invalidated first.
"""
cls.query.filter_by(user_id=user_id, used_at=None).delete()
now = datetime.now(timezone.utc)
cls.query.filter_by(user_id=user_id, used_at=None).filter(cls.deleted_at == None).update({"deleted_at": now}, synchronize_session=False)
db.session.flush()
token_value = secrets.token_urlsafe(48)
@@ -33,7 +33,8 @@ class PasswordResetToken(BaseModel):
Any existing unused tokens for this user are invalidated first.
"""
# Invalidate any existing unused tokens for this user
cls.query.filter_by(user_id=user_id, used_at=None).delete()
now = datetime.now(timezone.utc)
cls.query.filter_by(user_id=user_id, used_at=None).filter(cls.deleted_at == None).update({"deleted_at": now}, synchronize_session=False)
db.session.flush()
token_value = secrets.token_urlsafe(48)
-1
View File
@@ -13,7 +13,6 @@ class BaseModel(db.Model):
db.String(36),
primary_key=True,
default=lambda: str(uuid.uuid4()),
unique=True,
nullable=False,
)
created_at = db.Column(db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc))
+5
View File
@@ -0,0 +1,5 @@
"""Billing models package."""
from gatehouse_app.models.billing.plan import Plan
from gatehouse_app.models.billing.subscription import Subscription, SubscriptionStatus, BillingCycle
__all__ = ["Plan", "Subscription", "SubscriptionStatus", "BillingCycle"]
+61
View File
@@ -0,0 +1,61 @@
"""Plan model for subscription tiers."""
import logging
from datetime import datetime, timezone
from sqlalchemy import Column, String, Integer, Boolean, DateTime, Text
from gatehouse_app.models.base import BaseModel
logger = logging.getLogger(__name__)
class Plan(BaseModel):
"""Subscription plan definition.
Represents different pricing tiers that organizations can subscribe to.
"""
__tablename__ = "plans"
name = Column(String(100), nullable=False)
slug = Column(String(50), unique=True, nullable=False)
description = Column(Text, nullable=True)
# Pricing in cents
price_monthly = Column(Integer, nullable=False, default=0) # Price in cents
price_yearly = Column(Integer, nullable=False, default=0) # Price in cents
# User limits
included_users = Column(Integer, nullable=False, default=0) # 0 = unlimited
# Overage pricing (cents per user over limit)
overage_rate_per_user = Column(Integer, nullable=False, default=0)
# Feature flags (JSON)
features = Column(Text, nullable=True) # JSON string
# Stripe integration
stripe_price_id_monthly = Column(String(100), nullable=True)
stripe_price_id_yearly = Column(String(100), nullable=True)
# Active/inactive
is_active = Column(Boolean, nullable=False, default=True)
def __repr__(self):
return f"<Plan {self.slug}: ${self.price_monthly / 100}/mo>"
def to_dict(self):
"""Convert plan to dictionary."""
return {
"id": self.id,
"name": self.name,
"slug": self.slug,
"description": self.description,
"price_monthly": self.price_monthly,
"price_yearly": self.price_yearly,
"included_users": self.included_users,
"overage_rate_per_user": self.overage_rate_per_user,
"features": self.features,
"stripe_price_id_monthly": self.stripe_price_id_monthly,
"stripe_price_id_yearly": self.stripe_price_id_yearly,
"is_active": self.is_active,
"created_at": self.created_at.isoformat() + "Z" if self.created_at else None,
"updated_at": self.updated_at.isoformat() + "Z" if self.updated_at else None,
}
@@ -0,0 +1,99 @@
"""Subscription model for organization billing."""
import logging
from datetime import datetime, timezone
from sqlalchemy import Column, String, Integer, Boolean, DateTime, ForeignKey, Enum
from gatehouse_app.models.base import BaseModel
import enum
logger = logging.getLogger(__name__)
class SubscriptionStatus(enum.Enum):
"""Subscription status values."""
TRIAL = "trial"
ACTIVE = "active"
PAST_DUE = "past_due"
CANCELLED = "cancelled"
SUSPENDED = "suspended"
class BillingCycle(enum.Enum):
"""Billing cycle values."""
MONTHLY = "monthly"
YEARLY = "yearly"
class Subscription(BaseModel):
"""Organization subscription record.
Links an organization to a plan and tracks billing state.
"""
__tablename__ = "subscriptions"
# Organization relation
organization_id = Column(
String(36),
ForeignKey("organizations.id", ondelete="CASCADE"),
unique=True,
nullable=False
)
# Plan relation
plan_id = Column(
String(36),
ForeignKey("plans.id", ondelete="SET NULL"),
nullable=True
)
# Status
status = Column(
Enum(SubscriptionStatus, name="subscription_status"),
nullable=False,
default=SubscriptionStatus.TRIAL
)
# Billing
billing_cycle = Column(
Enum(BillingCycle, name="billing_cycle"),
nullable=False,
default=BillingCycle.MONTHLY
)
# Period dates
current_period_start = Column(DateTime, nullable=True)
current_period_end = Column(DateTime, nullable=True)
# Trial
trial_ends_at = Column(DateTime, nullable=True)
# Stripe
stripe_subscription_id = Column(String(100), nullable=True)
# Overage
overage_enabled = Column(Boolean, nullable=False, default=True)
# Cancellation
cancelled_at = Column(DateTime, nullable=True)
cancel_at_period_end = Column(Boolean, nullable=False, default=False)
def __repr__(self):
return f"<Subscription org={self.organization_id} status={self.status.value}>"
def to_dict(self):
"""Convert subscription to dictionary."""
return {
"id": self.id,
"organization_id": self.organization_id,
"plan_id": self.plan_id,
"status": self.status.value if self.status else None,
"billing_cycle": self.billing_cycle.value if self.billing_cycle else None,
"current_period_start": self.current_period_start.isoformat() + "Z" if self.current_period_start else None,
"current_period_end": self.current_period_end.isoformat() + "Z" if self.current_period_end else None,
"trial_ends_at": self.trial_ends_at.isoformat() + "Z" if self.trial_ends_at else None,
"stripe_subscription_id": self.stripe_subscription_id,
"overage_enabled": self.overage_enabled,
"cancelled_at": self.cancelled_at.isoformat() + "Z" if self.cancelled_at else None,
"cancel_at_period_end": self.cancel_at_period_end,
"created_at": self.created_at.isoformat() + "Z" if self.created_at else None,
"updated_at": self.updated_at.isoformat() + "Z" if self.updated_at else None,
}
@@ -78,3 +78,43 @@ class Organization(BaseModel):
).first()
is not None
)
def get_active_members(self):
"""Get active (non-deleted) organization members.
Returns:
List of OrganizationMember instances where deleted_at is None.
"""
return [m for m in self.members if m.deleted_at is None]
def get_active_departments(self):
"""Get active (non-deleted) departments.
Returns:
List of Department instances where deleted_at is None.
"""
return [d for d in self.departments if d.deleted_at is None]
def get_active_principals(self):
"""Get active (non-deleted) principals.
Returns:
List of Principal instances where deleted_at is None.
"""
return [p for p in self.principals if p.deleted_at is None]
def get_active_cas(self):
"""Get active (non-deleted) certificate authorities.
Returns:
List of CA instances where deleted_at is None.
"""
return [ca for ca in self.cas if ca.deleted_at is None]
def get_active_api_keys(self):
"""Get active (non-deleted) API keys.
Returns:
List of OrganizationApiKey instances where deleted_at is None.
"""
return [k for k in self.api_keys if k.deleted_at is None]
@@ -29,6 +29,14 @@ class CertificateAuditLog(BaseModel):
index=True,
)
# The organization that owns the CA (null for system CAs)
organization_id = db.Column(
db.String(36),
db.ForeignKey("organizations.id"),
nullable=True,
index=True,
)
# Action type (e.g., "signed", "revoked", "validated", "requested")
action = db.Column(db.String(50), nullable=False, index=True)
@@ -50,6 +58,7 @@ class CertificateAuditLog(BaseModel):
# Relationships
certificate = db.relationship("SSHCertificate", back_populates="audit_logs")
user = db.relationship("User")
organization = db.relationship("Organization")
__table_args__ = (
db.Index("idx_cert_audit_cert_action", "certificate_id", "action"),
@@ -68,6 +77,7 @@ class CertificateAuditLog(BaseModel):
certificate_id: str,
action: str,
user_id: str = None,
organization_id: str = None,
**kwargs,
) -> "CertificateAuditLog":
"""Create a certificate audit log entry.
@@ -76,6 +86,7 @@ class CertificateAuditLog(BaseModel):
certificate_id: ID of the certificate
action: Action type (e.g., "signed", "revoked")
user_id: ID of the user performing the action (optional)
organization_id: ID of the organization that owns the CA (optional)
**kwargs: Additional fields (ip_address, user_agent, message, etc.)
Returns:
@@ -85,6 +96,7 @@ class CertificateAuditLog(BaseModel):
certificate_id=certificate_id,
action=action,
user_id=user_id,
organization_id=organization_id,
**kwargs,
)
log_entry.save()
+3 -2
View File
@@ -21,10 +21,10 @@ class SSHKey(BaseModel):
)
# SSH key payload in OpenSSH format (e.g., "ssh-ed25519 AAAAB3Nz...")
payload = db.Column(db.Text, nullable=False, unique=True)
payload = db.Column(db.Text, nullable=False)
# SHA256 fingerprint for quick comparison and deduplication
fingerprint = db.Column(db.String(255), nullable=False, unique=True, index=True)
fingerprint = db.Column(db.String(255), nullable=False, index=True)
# Optional human-readable description (e.g., "My laptop key")
description = db.Column(db.String(255), nullable=True)
@@ -53,6 +53,7 @@ class SSHKey(BaseModel):
__table_args__ = (
db.Index("idx_ssh_key_user_verified", "user_id", "verified"),
db.UniqueConstraint('user_id', 'fingerprint', name='uix_user_fingerprint'),
)
def __repr__(self):
@@ -0,0 +1,5 @@
"""Superadmin models."""
from gatehouse_app.models.superadmin.superadmin import Superadmin
from gatehouse_app.models.superadmin.superadmin_session import SuperadminSession, SuperadminSessionStatus
__all__ = ["Superadmin", "SuperadminSession", "SuperadminSessionStatus"]
@@ -0,0 +1,56 @@
"""Superadmin model."""
import logging
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
logger = logging.getLogger(__name__)
class Superadmin(BaseModel):
"""Superadmin model for SaaS platform operators.
Completely separate from User model - has its own email/password auth.
"""
__tablename__ = "superadmins"
email = db.Column(db.String(255), unique=True, nullable=False, index=True)
password_hash = db.Column(db.String(255), nullable=False)
full_name = db.Column(db.String(255), nullable=True)
is_active = db.Column(db.Boolean, default=True, nullable=False)
last_login_at = db.Column(db.DateTime, nullable=True)
# Relationship to sessions
sessions = db.relationship(
"SuperadminSession",
back_populates="superadmin",
cascade="all, delete-orphan"
)
# Relationship to audit logs
audit_logs = db.relationship(
"SuperadminAuditLog",
back_populates="superadmin",
cascade="all, delete-orphan"
)
def __repr__(self):
return f"<Superadmin {self.email}>"
def has_password_auth(self):
"""Check if superadmin has password authentication."""
return bool(self.password_hash)
def has_totp_enabled(self):
"""Check if superadmin has TOTP enabled."""
# TODO: Implement TOTP for superadmin if needed
return False
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
exclude.append("password_hash")
return super().to_dict(exclude=exclude)
@@ -0,0 +1,80 @@
"""Superadmin session model."""
import logging
from datetime import datetime, timezone, timedelta
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
logger = logging.getLogger(__name__)
class SuperadminSessionStatus:
"""Session status constants."""
ACTIVE = "active"
REVOKED = "revoked"
EXPIRED = "expired"
class SuperadminSession(BaseModel):
"""Session model for superadmin authentication."""
__tablename__ = "superadmin_sessions"
superadmin_id = db.Column(
db.String(36),
db.ForeignKey("superadmins.id"),
nullable=False,
index=True
)
token = db.Column(db.String(255), unique=True, nullable=False, index=True)
expires_at = db.Column(db.DateTime, nullable=False)
last_activity_at = db.Column(
db.DateTime,
nullable=False,
default=lambda: datetime.now(timezone.utc)
)
ip_address = db.Column(db.String(45), nullable=True)
user_agent = db.Column(db.Text, nullable=True)
revoked_at = db.Column(db.DateTime, nullable=True)
revoked_reason = db.Column(db.String(255), nullable=True)
# Relationship
superadmin = db.relationship("Superadmin", back_populates="sessions")
def __repr__(self):
return f"<SuperadminSession superadmin_id={self.superadmin_id}>"
def is_active(self):
"""Check if session is currently active."""
now = datetime.now(timezone.utc)
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return (
self.deleted_at is None
and self.revoked_at is None
and expires_at > now
)
def is_expired(self):
"""Check if session has expired."""
now = datetime.now(timezone.utc)
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return now > expires_at
def revoke(self, reason: str = None):
"""Revoke the session."""
self.revoked_at = datetime.now(timezone.utc)
if reason:
self.revoked_reason = reason
from gatehouse_app import db
db.session.commit()
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
exclude.append("token")
return super().to_dict(exclude=exclude)
@@ -0,0 +1,49 @@
"""Superadmin audit log model."""
import logging
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
logger = logging.getLogger(__name__)
class SuperadminAuditLog(BaseModel):
"""Audit log for superadmin actions.
Records every action performed by superadmins for security and compliance.
"""
__tablename__ = "superadmin_audit_logs"
superadmin_id = db.Column(
db.String(36),
db.ForeignKey("superadmins.id"),
nullable=False,
index=True
)
action = db.Column(db.String(100), nullable=False, index=True)
resource_type = db.Column(db.String(50), nullable=False, index=True)
resource_id = db.Column(db.String(36), nullable=True, index=True)
org_id = db.Column(db.String(36), nullable=True, index=True)
user_id = db.Column(db.String(36), nullable=True, index=True)
ip_address = db.Column(db.String(45), nullable=True)
user_agent = db.Column(db.Text, nullable=True)
request_id = db.Column(db.String(100), nullable=True)
extra_data = db.Column(db.JSON, nullable=True)
success = db.Column(db.Boolean, default=True, nullable=False)
error_message = db.Column(db.String(500), nullable=True)
# Relationship
superadmin = db.relationship("Superadmin", back_populates="audit_logs")
def __repr__(self):
return (
f"<SuperadminAuditLog superadmin={self.superadmin_id} "
f"action={self.action} resource={self.resource_type}/{self.resource_id}>"
)
def to_dict(self, exclude=None):
"""Convert to dictionary."""
exclude = exclude or []
return super().to_dict(exclude=exclude)
+55 -16
View File
@@ -1,5 +1,6 @@
"""Session model."""
from datetime import datetime, timedelta, timezone
from flask import current_app
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import SessionStatus
@@ -38,33 +39,71 @@ class Session(BaseModel):
return f"<Session user_id={self.user_id} status={self.status}>"
def is_active(self):
"""Check if session is currently active."""
"""Check if session is currently active.
Sessions are evaluated against two independent timeouts:
- Idle timeout: expires if no request has been made within
SESSION_IDLE_TIMEOUT seconds (default 15 min).
- Absolute timeout: expires if SESSION_ABSOLUTE_TIMEOUT seconds
have elapsed since the session was created (default 8 h),
regardless of activity.
A session must satisfy *both* constraints to remain active.
"""
now = datetime.now(timezone.utc)
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
created_at = self.created_at
last_activity_at = self.last_activity_at
if created_at.tzinfo is None:
created_at = created_at.replace(tzinfo=timezone.utc)
if last_activity_at.tzinfo is None:
last_activity_at = last_activity_at.replace(tzinfo=timezone.utc)
idle_timeout = current_app.config.get("SESSION_IDLE_TIMEOUT", 900)
absolute_timeout = current_app.config.get("SESSION_ABSOLUTE_TIMEOUT", 28800)
idle_expires_at = last_activity_at + timedelta(seconds=idle_timeout)
absolute_expires_at = created_at + timedelta(seconds=absolute_timeout)
return (
self.status == SessionStatus.ACTIVE
and expires_at > now
and now < idle_expires_at
and now < absolute_expires_at
and self.deleted_at is None
)
def is_expired(self):
"""Check if session has expired."""
now = datetime.now(timezone.utc)
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return now > expires_at
"""Check if session has expired (either idle or absolute)."""
return not self.is_active() and self.status != SessionStatus.REVOKED
def refresh(self, duration_seconds: int = 86400):
"""Refresh session expiration.
def refresh(self, duration_seconds: int = None):
"""Extend the session expiration using a sliding window.
The new ``expires_at`` is set to *now + idle timeout*, but is
capped so that the session never exceeds the absolute lifetime
(``created_at + absolute timeout``).
Args:
duration_seconds: New session duration in seconds
duration_seconds: Override for the idle timeout. When *None*
(the common case), the value is read from
``SESSION_IDLE_TIMEOUT`` in the Flask config.
"""
self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=duration_seconds)
self.last_activity_at = datetime.now(timezone.utc)
now = datetime.now(timezone.utc)
if duration_seconds is None:
duration_seconds = current_app.config.get("SESSION_IDLE_TIMEOUT", 900)
absolute_timeout = current_app.config.get("SESSION_ABSOLUTE_TIMEOUT", 28800)
idle_expires_at = now + timedelta(seconds=duration_seconds)
created_at = self.created_at
if created_at.tzinfo is None:
created_at = created_at.replace(tzinfo=timezone.utc)
absolute_expires_at = created_at + timedelta(seconds=absolute_timeout)
self.expires_at = min(idle_expires_at, absolute_expires_at)
self.last_activity_at = now
db.session.commit()
def revoke(self, reason: str = None):
+57 -2
View File
@@ -116,9 +116,64 @@ class User(BaseModel):
is not None
)
def get_active_memberships(self):
"""Get active (non-deleted) organization memberships with active organizations.
Returns:
List of OrganizationMember instances where:
- membership.deleted_at is None
- organization exists and organization.deleted_at is None
"""
return [
m for m in self.organization_memberships
if m.deleted_at is None
and m.organization
and m.organization.deleted_at is None
]
def get_organizations(self):
"""Get all organizations the user is a member of."""
return [membership.organization for membership in self.organization_memberships]
"""Get all active organizations the user is a member of."""
return [membership.organization for membership in self.get_active_memberships()]
def get_active_ssh_keys(self):
"""Get active (non-deleted) SSH keys.
Returns:
List of SSHKey instances where deleted_at is None.
"""
return [k for k in self.ssh_keys if k.deleted_at is None]
def get_active_auth_methods(self):
"""Get active (non-deleted) authentication methods.
Returns:
List of AuthenticationMethod instances where deleted_at is None.
"""
return [m for m in self.authentication_methods if m.deleted_at is None]
def get_active_department_memberships(self):
"""Get active (non-deleted) department memberships.
Returns:
List of DepartmentMembership instances where deleted_at is None.
"""
return [m for m in self.department_memberships if m.deleted_at is None]
def get_active_principal_memberships(self):
"""Get active (non-deleted) principal memberships.
Returns:
List of PrincipalMembership instances where deleted_at is None.
"""
return [m for m in self.principal_memberships if m.deleted_at is None]
def get_active_ca_permissions(self):
"""Get active (non-deleted) CA permissions.
Returns:
List of CAPermission instances where deleted_at is None.
"""
return [p for p in self.ca_permissions if p.deleted_at is None]
def has_totp_enabled(self) -> bool:
"""Check if user has TOTP enabled and verified.
+51
View File
@@ -0,0 +1,51 @@
"""Contact form validation schemas."""
import logging
import re
from marshmallow import Schema, fields, validate, validates_schema, ValidationError
logger = logging.getLogger(__name__)
class ContactSchema(Schema):
"""Schema for contact form submissions."""
email = fields.Email(required=True)
name = fields.Str(
allow_none=True,
load_default=None,
validate=validate.Length(max=255),
)
company = fields.Str(
allow_none=True,
load_default=None,
validate=validate.Length(max=255),
)
enquiry_type = fields.Str(
required=True,
validate=validate.OneOf(["demo_request", "sales_enquiry", "general", "support"]),
)
message = fields.Str(
allow_none=True,
load_default=None,
validate=validate.Length(max=2000),
)
interest_area = fields.Str(
allow_none=True,
load_default=None,
validate=validate.Length(max=100),
)
_hp = fields.Str(
allow_none=True,
load_default=None,
load_from="_hp",
)
@validates_schema
def sanitize_html(self, data, **kwargs):
"""Strip HTML tags from all text fields to prevent XSS."""
text_fields = ["name", "company", "message", "interest_area"]
for field in text_fields:
value = data.get(field)
if value and isinstance(value, str):
data[field] = re.sub(r"<[^>]*>", "", value)
+4
View File
@@ -9,6 +9,8 @@ from gatehouse_app.services.oidc_jwks_service import OIDCJWKSService
from gatehouse_app.services.oidc_token_service import OIDCTokenService
from gatehouse_app.services.oidc_session_service import OIDCSessionService
from gatehouse_app.services.oidc_audit_service import OIDCAuditService
from gatehouse_app.services.superadmin_auth_service import SuperadminAuthService
from gatehouse_app.services.superadmin_organization_service import SuperadminOrganizationService
__all__ = [
"AuthService",
@@ -22,4 +24,6 @@ __all__ = [
"OIDCTokenService",
"OIDCSessionService",
"OIDCAuditService",
"SuperadminAuthService",
"SuperadminOrganizationService",
]
+15 -8
View File
@@ -36,9 +36,9 @@ class AuthService:
Raises:
EmailAlreadyExistsError: If email is already registered
"""
# Check if email already exists
existing_user = User.query.filter_by(email=email.lower()).first()
if existing_user and existing_user.deleted_at is None:
# Check if email already exists
existing_user = User.query.filter_by(email=email.lower(), deleted_at=None).first()
if existing_user:
raise EmailAlreadyExistsError()
# Create user
@@ -140,18 +140,26 @@ class AuthService:
return user
@staticmethod
def create_session(user, duration_seconds=86400, is_compliance_only=False):
def create_session(user, duration_seconds=None, is_compliance_only=False):
"""
Create a new session for the user.
Args:
user: User instance
duration_seconds: Session duration in seconds
duration_seconds: Session idle timeout in seconds.
When None, defaults to SESSION_IDLE_TIMEOUT from config.
The absolute lifetime is always enforced by Session.is_active()
regardless of this value.
is_compliance_only: Whether this is a compliance-only session (limited access)
Returns:
Session instance
"""
from flask import current_app
if duration_seconds is None:
duration_seconds = current_app.config.get("SESSION_IDLE_TIMEOUT", 900)
# Generate session token
token = secrets.token_urlsafe(32)
@@ -280,12 +288,11 @@ class AuthService:
raise ConflictError("TOTP is already enabled for this account")
# Clean up any existing unverified TOTP enrollment attempts
# Use hard delete for unverified methods since they're incomplete enrollment attempts
# Soft delete for unverified methods since they're incomplete enrollment attempts
existing_totp_method = user.get_totp_method()
if existing_totp_method and not existing_totp_method.verified:
logger.debug(f"Removing existing unverified TOTP method for user {user.id}")
db.session.delete(existing_totp_method) # Hard delete - unverified methods are temporary
db.session.commit() # Commit to ensure deletion before creating new record
existing_totp_method.delete(soft=True) # Soft delete - unverified methods are temporary
# Generate TOTP secret
secret = TOTPService.generate_secret()
+192
View File
@@ -0,0 +1,192 @@
"""Billing service for superadmin operations."""
import logging
from datetime import datetime, timedelta, timezone
from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.models.billing.plan import Plan
from gatehouse_app.models.billing.subscription import Subscription, SubscriptionStatus, BillingCycle
from gatehouse_app.extensions import db
logger = logging.getLogger(__name__)
class BillingService:
"""Service for billing operations."""
@staticmethod
def get_plan(plan_id: str) -> Plan:
"""Get a plan by ID."""
plan = Plan.query.get(plan_id)
if not plan:
raise ValueError("Plan not found")
return plan
@staticmethod
def list_plans() -> list:
"""List all active plans."""
return Plan.query.filter(Plan.is_active == True).order_by(Plan.price_monthly.asc()).all()
@staticmethod
def create_subscription(
organization_id: str,
plan_id: str,
billing_cycle: str = "monthly"
) -> Subscription:
"""Create a new subscription for an organization.
Args:
organization_id: Organization UUID
plan_id: Plan UUID
billing_cycle: 'monthly' or 'yearly'
Returns:
New subscription
"""
org = Organization.query.get(organization_id)
if not org:
raise ValueError("Organization not found")
plan = Plan.query.get(plan_id)
if not plan:
raise ValueError("Plan not found")
# Check if subscription already exists
existing = Subscription.query.filter_by(organization_id=organization_id, deleted_at=None).first()
if existing:
raise ValueError("Organization already has a subscription")
now = datetime.now(timezone.utc)
# Calculate period
if billing_cycle == "yearly":
period_end = now + timedelta(days=365)
else:
period_end = now + timedelta(days=30)
subscription = Subscription(
organization_id=organization_id,
plan_id=plan_id,
status=SubscriptionStatus.ACTIVE,
billing_cycle=BillingCycle.MONTHLY if billing_cycle == "monthly" else BillingCycle.YEARLY,
current_period_start=now,
current_period_end=period_end,
)
db.session.add(subscription)
db.session.commit()
return subscription
@staticmethod
def change_plan(organization_id: str, new_plan_id: str) -> Subscription:
"""Change subscription plan.
Args:
organization_id: Organization UUID
new_plan_id: New plan UUID
Returns:
Updated subscription
"""
subscription = Subscription.query.filter_by(organization_id=organization_id, deleted_at=None).first()
if not subscription:
raise ValueError("No subscription found for organization")
new_plan = Plan.query.get(new_plan_id)
if not new_plan:
raise ValueError("Plan not found")
subscription.plan_id = new_plan_id
db.session.commit()
return subscription
@staticmethod
def cancel_subscription(organization_id: str) -> Subscription:
"""Cancel subscription at period end.
Args:
organization_id: Organization UUID
Returns:
Updated subscription
"""
subscription = Subscription.query.filter_by(organization_id=organization_id, deleted_at=None).first()
if not subscription:
raise ValueError("No subscription found for organization")
subscription.cancel_at_period_end = True
subscription.status = SubscriptionStatus.CANCELLED
db.session.commit()
return subscription
@staticmethod
def extend_trial(organization_id: str, days: int) -> Subscription:
"""Extend trial period.
Args:
organization_id: Organization UUID
days: Number of days to extend
Returns:
Updated subscription
"""
subscription = Subscription.query.filter_by(organization_id=organization_id, deleted_at=None).first()
if not subscription:
raise ValueError("No subscription found for organization")
now = datetime.now(timezone.utc)
if subscription.trial_ends_at:
subscription.trial_ends_at = subscription.trial_ends_at + timedelta(days=days)
else:
subscription.trial_ends_at = now + timedelta(days=days)
subscription.status = SubscriptionStatus.TRIAL
db.session.commit()
return subscription
@staticmethod
def calculate_overage(organization_id: str) -> dict:
"""Calculate overage charges for an organization.
Args:
organization_id: Organization UUID
Returns:
Overage calculation with details
"""
subscription = Subscription.query.filter_by(organization_id=organization_id, deleted_at=None).first()
if not subscription:
return {"has_overage": False, "overage_cost": 0, "user_count": 0, "included_users": 0}
plan = Plan.query.get(subscription.plan_id) if subscription.plan_id else None
if not plan:
return {"has_overage": False, "overage_cost": 0, "user_count": 0, "included_users": 0}
# Count current users
user_count = OrganizationMember.query.filter(
OrganizationMember.organization_id == organization_id,
OrganizationMember.deleted_at.is_(None),
).count()
included_users = plan.included_users
overage_users = max(0, user_count - included_users)
if overage_users > 0 and plan.overage_rate_per_user > 0:
overage_cost = overage_users * plan.overage_rate_per_user
has_overage = True
else:
overage_cost = 0
has_overage = False
return {
"has_overage": has_overage,
"user_count": user_count,
"included_users": included_users,
"overage_users": overage_users,
"overage_rate_per_user": plan.overage_rate_per_user,
"overage_cost": overage_cost,
}
+66
View File
@@ -496,3 +496,69 @@ def build_email_verification_resend_html(
{get_alert_box("If you didn't request this, you can safely ignore this email.", "info", "🔒")}
'''
return get_base_html(content, "Verify your Secuird email address", "Please verify your email address")
def build_contact_enquiry_html(
enquiry_type: str,
submitter_email: str,
name: Optional[str],
company: Optional[str],
interest_area: Optional[str],
message: Optional[str],
) -> str:
"""Build a contact enquiry notification email.
Args:
enquiry_type: One of demo_request, sales_enquiry, general, support
submitter_email: Email address of the person submitting the enquiry
name: Full name of the submitter (optional)
company: Company name (optional)
interest_area: Area of interest (optional)
message: Free-text message (optional)
Returns:
HTML email string
"""
# Map enquiry types to display labels and colors
type_labels = {
"demo_request": ("Demo Request", "info"),
"sales_enquiry": ("Sales Enquiry", "success"),
"general": ("General Enquiry", "info"),
"support": ("Support Request", "warning"),
}
type_label, alert_type = type_labels.get(enquiry_type, ("Enquiry", "info"))
name_display = name if name else "Not provided"
company_display = company if company else "Not provided"
interest_display = interest_area if interest_area else "Not provided"
message_display = message if message else "No message provided"
# Build details table
details_rows = f"""
{get_detail_row("Enquiry Type", type_label)}
{get_detail_row("Submitter Email", submitter_email)}
{get_detail_row("Name", name_display)}
{get_detail_row("Company", company_display)}
{get_detail_row("Interest Area", interest_display)}
"""
content = f'''
<h2 style="margin: 0 0 20px 0; color: {TEXT_COLOR}; font-size: 20px; font-weight: 600;">New {type_label}</h2>
<p style="margin: 0 0 20px 0; color: {TEXT_COLOR}; font-size: 15px; line-height: 1.6;">
A new {type_label.lower()} has been submitted through the Secuird website.
</p>
{get_alert_box(f"Enquiry type: <strong>{type_label}</strong>", alert_type, "📬")}
<table role="presentation" width="100%" cellspacing="0" cellpadding="0" style="margin: 20px 0; background-color: {BACKGROUND_COLOR}; border-radius: 8px;">
<tr>
<td style="padding: 20px;">
<h3 style="margin: 0 0 16px 0; color: {TEXT_COLOR}; font-size: 14px; font-weight: 600;">Enquiry Details</h3>
<table role="presentation" width="100%" cellspacing="0" cellpadding="0">
{details_rows}
</table>
</td>
</tr>
</table>
<h3 style="margin: 20px 0 12px 0; color: {TEXT_COLOR}; font-size: 14px; font-weight: 600;">Message</h3>
<p style="margin: 0; color: {TEXT_COLOR}; font-size: 14px; line-height: 1.6; white-space: pre-wrap;">{message_display}</p>
'''
return get_base_html(content, f"Secuird Website: {type_label}", f"New {type_label} from {submitter_email}")
@@ -42,7 +42,7 @@ class ExternalAuthService:
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
app_config = ApplicationProviderConfig.query.filter_by(
provider_type=provider_type_str
provider_type=provider_type_str, deleted_at=None
).first()
if not app_config:
@@ -64,6 +64,7 @@ class ExternalAuthService:
org_override_obj = OrganizationProviderOverride.query.filter_by(
organization_id=organization_id,
provider_type=provider_type_str,
deleted_at=None,
).first()
if org_override_obj and not org_override_obj.is_enabled:
@@ -14,7 +14,7 @@ def create_app_provider_config(
**kwargs,
) -> ApplicationProviderConfig:
existing = ApplicationProviderConfig.query.filter_by(
provider_type=provider_type
provider_type=provider_type, deleted_at=None
).first()
if existing:
@@ -51,7 +51,7 @@ def update_app_provider_config(
**updates,
) -> ApplicationProviderConfig:
config = ApplicationProviderConfig.query.filter_by(
provider_type=provider_type
provider_type=provider_type, deleted_at=None
).first()
if not config:
@@ -90,7 +90,7 @@ def update_app_provider_config(
def get_app_provider_config(provider_type: str) -> ApplicationProviderConfig:
config = ApplicationProviderConfig.query.filter_by(
provider_type=provider_type
provider_type=provider_type, deleted_at=None
).first()
if not config:
@@ -104,13 +104,13 @@ def get_app_provider_config(provider_type: str) -> ApplicationProviderConfig:
def list_app_provider_configs() -> list:
configs = ApplicationProviderConfig.query.all()
configs = ApplicationProviderConfig.query.filter_by(deleted_at=None).all()
return [config.to_dict() for config in configs]
def delete_app_provider_config(provider_type: str) -> bool:
config = ApplicationProviderConfig.query.filter_by(
provider_type=provider_type
provider_type=provider_type, deleted_at=None
).first()
if not config:
@@ -219,10 +219,11 @@ def authenticate_with_provider(
auth_method = AuthenticationMethod.query.filter_by(
method_type=provider_type,
provider_user_id=user_info["provider_user_id"],
deleted_at=None,
).first()
if not auth_method:
existing_user = User.query.filter_by(email=user_info["email"]).first()
existing_user = User.query.filter_by(email=user_info["email"], deleted_at=None).first()
if existing_user:
AuditService.log_external_auth_login_failed(
@@ -262,7 +263,7 @@ def authenticate_with_provider(
state_record.mark_used()
from gatehouse_app.services.auth_service import AuthService
session = AuthService.create_session(user=user, organization_id=organization_id)
session = AuthService.create_session(user=user)
AuditService.log_external_auth_login(
user_id=user.id,
@@ -286,12 +287,13 @@ def unlink_provider(
auth_method = AuthenticationMethod.query.filter_by(
user_id=user_id,
method_type=provider_type,
deleted_at=None,
).first()
if not auth_method:
raise ExternalAuthError("Provider not linked", "PROVIDER_NOT_LINKED", 400)
other_methods = AuthenticationMethod.query.filter_by(user_id=user_id).count()
other_methods = AuthenticationMethod.query.filter_by(user_id=user_id, deleted_at=None).count()
if other_methods <= 1:
raise ExternalAuthError(
"Cannot unlink the last authentication method",
@@ -16,7 +16,7 @@ def create_org_provider_override(
**kwargs,
) -> OrganizationProviderOverride:
app_config = ApplicationProviderConfig.query.filter_by(
provider_type=provider_type
provider_type=provider_type, deleted_at=None
).first()
if not app_config:
@@ -29,6 +29,7 @@ def create_org_provider_override(
existing = OrganizationProviderOverride.query.filter_by(
organization_id=organization_id,
provider_type=provider_type,
deleted_at=None,
).first()
if existing:
@@ -69,6 +70,7 @@ def update_org_provider_override(
override = OrganizationProviderOverride.query.filter_by(
organization_id=organization_id,
provider_type=provider_type,
deleted_at=None,
).first()
if not override:
@@ -110,6 +112,7 @@ def get_org_provider_override(
override = OrganizationProviderOverride.query.filter_by(
organization_id=organization_id,
provider_type=provider_type,
deleted_at=None,
).first()
if not override:
@@ -124,7 +127,7 @@ def get_org_provider_override(
def list_org_provider_overrides(organization_id: str) -> list:
overrides = OrganizationProviderOverride.query.filter_by(
organization_id=organization_id
organization_id=organization_id, deleted_at=None
).all()
return [override.to_dict() for override in overrides]
@@ -133,6 +136,7 @@ def delete_org_provider_override(organization_id: str, provider_type: str) -> bo
override = OrganizationProviderOverride.query.filter_by(
organization_id=organization_id,
provider_type=provider_type,
deleted_at=None,
).first()
if not override:
@@ -1024,17 +1024,17 @@ def revoke_membership_soft(
def hard_delete_membership(membership_id: str) -> None:
"""Hard delete a membership after ZeroTier has been cleaned up.
"""Soft delete a membership after ZeroTier has been cleaned up.
Called by the reconciliation job after successfully removing the member
from the ZeroTier controller. This is the final DB cleanup step.
from the ZeroTier controller. This marks the membership as deleted.
"""
membership = DeviceNetworkMembership.query.filter(
DeviceNetworkMembership.id == membership_id,
).first()
if not membership:
logger.warning(f"[hard_delete_membership] Membership {membership_id} not found, skipping.")
logger.warning(f"[hard_delete_membership] Membership {membership_id} not found or already deleted, skipping.")
return
device = Device.query.get(membership.device_id)
@@ -1048,7 +1048,7 @@ def hard_delete_membership(membership_id: str) -> None:
except Exception as exc:
logger.warning(f"[hard_delete_membership] ZT delete failed for {device.node_id}: {exc}")
db.session.delete(membership)
membership.delete(soft=True)
db.session.commit()
AuditService.log_action(
@@ -1061,6 +1061,6 @@ def hard_delete_membership(membership_id: str) -> None:
"device_node_id": device.node_id if device else None,
"network_id": network.zerotier_network_id if network else None,
},
description=f"Membership hard-deleted: device {device.node_id if device else 'unknown'} from network",
description=f"Membership deleted: device {device.node_id if device else 'unknown'} from network",
success=True,
)
@@ -111,7 +111,7 @@ def handle_register_callback(
access_token=tokens["access_token"],
)
existing_user = User.query.filter_by(email=user_info["email"]).first()
existing_user = User.query.filter_by(email=user_info["email"], deleted_at=None).first()
if existing_user:
raise OAuthFlowError(
f"An account with email {user_info['email']} already exists. "
+2 -2
View File
@@ -55,9 +55,9 @@ def get_userinfo(access_token: str, validate_access_token_fn) -> Dict:
def _get_user_roles(user: User) -> list:
roles = []
if not user or not user.organization_memberships:
if not user:
return roles
for member in user.organization_memberships:
for member in user.get_active_memberships():
roles.append({
"organization_id": str(member.organization_id),
"role": member.role.value,
+2 -2
View File
@@ -324,8 +324,8 @@ class OIDCTokenService:
List of role objects with organization_id and role
"""
roles = []
if user and user.organization_memberships:
for member in user.organization_memberships:
if user:
for member in user.get_active_memberships():
roles.append({
"organization_id": str(member.organization_id),
"role": member.role.value
+27 -2
View File
@@ -70,6 +70,22 @@ class OrganizationService:
return org
@staticmethod
def get_user_org_count(user_id):
"""
Get the count of organizations a user belongs to.
Args:
user_id: User ID
Returns:
Count of active memberships (deleted_at is NULL)
"""
return OrganizationMember.query.filter_by(
user_id=user_id,
deleted_at=None,
).count()
@staticmethod
def get_organization_by_id(org_id):
"""
@@ -286,7 +302,7 @@ class OrganizationService:
Raises:
ConflictError: If user is already a member
"""
# Check if already a member (active or soft-deleted — both blocked by DB unique constraint)
# Check for any membership (active or soft-deleted) to enable reactivation
existing = OrganizationMember.query.filter_by(
user_id=user_id,
organization_id=org.id,
@@ -294,7 +310,7 @@ class OrganizationService:
# Development-only debug logging for membership validation
if current_app.config.get('ENV') == 'development':
logger.debug(f"[Org] Member check: org_id={org.id}, user_id={user_id}, already_member={existing is not None}")
logger.debug(f"[Org] Member check: org_id={org.id}, user_id={user_id}, already_member={existing is not None}, soft_deleted={existing.deleted_at is not None if existing else False}")
if existing:
if existing.deleted_at is not None:
@@ -363,6 +379,15 @@ class OrganizationService:
logger.debug(f"[Org] Member removal: org_id={org.id}, user_id={user_id}, found={member is not None}")
if member:
if member.role == OrganizationRole.OWNER:
owner_count = OrganizationMember.query.filter(
OrganizationMember.organization_id == org.id,
OrganizationMember.role == OrganizationRole.OWNER,
OrganizationMember.deleted_at.is_(None),
OrganizationMember.user_id != user_id,
).count()
if owner_count < 1:
raise ValueError("Cannot remove the only owner from an organization. Transfer ownership first.")
member.delete(soft=True)
# Log member removal
+4 -2
View File
@@ -10,10 +10,10 @@ class SessionService:
@staticmethod
def get_active_session_by_token(token):
"""Get active session by token.
Args:
token: The session token string
Returns:
Session object if found and active, None otherwise
"""
@@ -23,6 +23,8 @@ class SessionService:
token=token,
status=SessionStatus.ACTIVE,
deleted_at=None
).filter(
Session.expires_at > datetime.now(timezone.utc)
).first()
@staticmethod
+26 -22
View File
@@ -41,46 +41,47 @@ class SSHKeyService:
user_id: str,
public_key: str,
description: Optional[str] = None,
) -> SSHKey:
) -> tuple[SSHKey, bool]:
"""Add an SSH public key for a user.
Args:
user_id: ID of the user
public_key: SSH public key in OpenSSH format
description: Optional description of the key
Returns:
Created SSHKey instance
Tuple of (SSHKey instance, is_new) where is_new is True for
newly created keys, False for existing keys (idempotent).
Raises:
UserNotFoundError: If user doesn't exist
SSHKeyError: If key format is invalid
SSHKeyAlreadyExistsError: If key already exists
"""
# Verify user exists
user = User.query.get(user_id)
if not user:
raise UserNotFoundError(f"User {user_id} not found")
# Validate key format
if not verify_ssh_key_format(public_key):
raise SSHKeyError("Invalid SSH public key format")
# Compute fingerprint
try:
fingerprint = compute_ssh_fingerprint(public_key)
except Exception as e:
logger.error(f"Failed to compute fingerprint: {str(e)}")
raise SSHKeyError(f"Failed to compute key fingerprint: {str(e)}")
# Check for duplicate (including soft-deleted records — fingerprint is unique in DB)
existing = SSHKey.query.filter_by(fingerprint=fingerprint).first()
# Check for duplicate per user (including soft-deleted records)
existing = SSHKey.query.filter_by(
user_id=user_id, fingerprint=fingerprint
).first()
if existing:
if existing.deleted_at is not None:
# Restore the soft-deleted key: clear deleted_at and update fields
existing.deleted_at = None
existing.user_id = user_id
existing.description = description or existing.description
existing.description = description if description is not None else existing.description
existing.verified = False
existing.verified_at = None
existing.verify_text = None
@@ -90,15 +91,18 @@ class SSHKeyService:
f"Restored soft-deleted SSH key for user {user_id}: "
f"fingerprint={fingerprint}"
)
return existing
raise SSHKeyAlreadyExistsError(
f"SSH key with fingerprint {fingerprint} already exists"
return existing, False
# Idempotent: return existing key without error
logger.info(
f"SSH key already exists for user {user_id}: "
f"fingerprint={fingerprint}"
)
return existing, False
# Extract metadata
key_type = extract_ssh_key_type(public_key)
key_comment = extract_ssh_key_comment(public_key)
# Create SSH key record
ssh_key = SSHKey(
user_id=user_id,
@@ -109,15 +113,15 @@ class SSHKeyService:
key_comment=key_comment,
verified=False,
)
ssh_key.save()
logger.info(
f"SSH key added for user {user_id}: "
f"fingerprint={fingerprint}, type={key_type}"
)
return ssh_key
return ssh_key, True
def get_ssh_key(self, key_id: str) -> SSHKey:
"""Get an SSH key by ID.
@@ -0,0 +1,177 @@
"""Analytics service for platform-wide statistics."""
import logging
from datetime import datetime, timedelta, timezone
from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.models.user.user import User
from gatehouse_app.models.user.session import Session
from gatehouse_app.models.superadmin_audit_log import SuperadminAuditLog
logger = logging.getLogger(__name__)
class SuperadminAnalyticsService:
"""Service for platform-wide analytics and statistics."""
@staticmethod
def get_dashboard_stats() -> dict:
"""Get dashboard statistics for the overview page.
Returns:
Dashboard stats including org count, user count, etc.
"""
now = datetime.now(timezone.utc)
thirty_days_ago = now - timedelta(days=30)
# Total organizations
total_orgs = Organization.query.filter(Organization.deleted_at.is_(None)).count()
active_orgs = Organization.query.filter(
Organization.deleted_at.is_(None),
Organization.is_active == True, # noqa: E712
).count()
# Total users
total_users = User.query.filter(User.deleted_at.is_(None)).count()
# Active sessions
active_sessions = Session.query.filter(
Session.deleted_at.is_(None),
Session.status == "active",
).count()
# New signups in last 30 days
new_users_30d = User.query.filter(
User.deleted_at.is_(None),
User.created_at >= thirty_days_ago,
).count()
# New organizations in last 30 days
new_orgs_30d = Organization.query.filter(
Organization.deleted_at.is_(None),
Organization.created_at >= thirty_days_ago,
).count()
# Suspended organizations
suspended_orgs = Organization.query.filter(
Organization.deleted_at.is_(None),
Organization.is_active == False, # noqa: E712
).count()
return {
"total_organizations": total_orgs,
"active_organizations": active_orgs,
"suspended_organizations": suspended_orgs,
"total_users": total_users,
"active_sessions": active_sessions,
"new_users_30d": new_users_30d,
"new_orgs_30d": new_orgs_30d,
"generated_at": now.isoformat() + "Z",
}
@staticmethod
def get_signup_trends(days: int = 30) -> dict:
"""Get signup trends over time.
Args:
days: Number of days to analyze
Returns:
Daily signup data
"""
now = datetime.now(timezone.utc)
start_date = now - timedelta(days=days)
# Get all users created in period
users = User.query.filter(
User.deleted_at.is_(None),
User.created_at >= start_date,
).all()
# Group by day
daily_signups = {}
for i in range(days):
date = (start_date + timedelta(days=i)).strftime("%Y-%m-%d")
daily_signups[date] = 0
for user in users:
date = user.created_at.strftime("%Y-%m-%d")
if date in daily_signups:
daily_signups[date] += 1
# Convert to list
history = [
{"date": date, "value": count}
for date, count in sorted(daily_signups.items())
]
return {
"period_start": start_date.isoformat() + "Z",
"period_end": now.isoformat() + "Z",
"total": len(users),
"history": history,
}
@staticmethod
def get_org_distribution() -> dict:
"""Get distribution of organizations by size.
Returns:
Organization size distribution
"""
orgs = Organization.query.filter(Organization.deleted_at.is_(None)).all()
distribution = {
"solo": 0, # 1 user
"small": 0, # 2-10 users
"medium": 0, # 11-50 users
"large": 0, # 51-200 users
"enterprise": 0, # 200+ users
}
for org in orgs:
count = org.get_member_count()
if count == 1:
distribution["solo"] += 1
elif count <= 10:
distribution["small"] += 1
elif count <= 50:
distribution["medium"] += 1
elif count <= 200:
distribution["large"] += 1
else:
distribution["enterprise"] += 1
return {
"distribution": distribution,
"total_orgs": len(orgs),
}
@staticmethod
def get_recent_activity(limit: int = 20) -> list:
"""Get recent superadmin actions.
Args:
limit: Maximum number of actions to return
Returns:
List of recent audit log entries
"""
logs = SuperadminAuditLog.query.filter(
SuperadminAuditLog.deleted_at.is_(None),
).order_by(
SuperadminAuditLog.created_at.desc()
).limit(limit).all()
return [
{
"id": log.id,
"superadmin_id": log.superadmin_id,
"action": log.action,
"resource_type": log.resource_type,
"resource_id": log.resource_id,
"extra_data": log.extra_data,
"ip_address": log.ip_address,
"user_agent": log.user_agent,
"created_at": log.created_at.isoformat() + "Z" if log.created_at else None,
}
for log in logs
]
@@ -0,0 +1,239 @@
"""Superadmin authentication service."""
import logging
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional
from flask import request, current_app
from gatehouse_app.extensions import db, bcrypt
from gatehouse_app.models.superadmin import Superadmin, SuperadminSession
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
logger = logging.getLogger(__name__)
class SuperadminAuthService:
"""Service for superadmin authentication operations."""
@staticmethod
def authenticate(email, credentials):
"""Authenticate superadmin with email/password credentials.
Args:
email: Superadmin email
credentials: Plain text credential
Returns:
Superadmin instance if authentication succeeds
Raises:
InvalidCredentialsError: If credentials are invalid or account is disabled
"""
# Find superadmin by email
superadmin = Superadmin.query.filter_by(email=email.lower()).first()
if not superadmin:
logger.warning(f"[SuperadminAuth] Login attempt for non-existent email: {email}")
raise InvalidCredentialsError()
# Check if account is active
if not superadmin.is_active:
logger.warning(f"[SuperadminAuth] Login attempt for disabled account: {email}")
raise InvalidCredentialsError("Account is disabled")
# Check credential
if not superadmin.password_hash:
logger.warning(f"[SuperadminAuth] Login attempt for account with no credential set: {email}")
raise InvalidCredentialsError()
# Verify credential
password_valid = bcrypt.check_password_hash(superadmin.password_hash, credentials)
if not password_valid:
logger.warning(f"[SuperadminAuth] Invalid password for: {email}")
raise InvalidCredentialsError()
# Update last login
superadmin.last_login_at = datetime.now(timezone.utc)
db.session.commit()
logger.info(f"[SuperadminAuth] Successful login for: {email}")
return superadmin
@staticmethod
def create_session(superadmin_id, duration_seconds=28800):
"""Create a new session for superadmin.
Args:
superadmin_id: Superadmin ID
duration_seconds: Session duration in seconds (default 8 hours)
Returns:
SuperadminSession instance
"""
# Generate secure token
token = secrets.token_urlsafe(32)
# Create session
session = SuperadminSession(
superadmin_id=superadmin_id,
token=token,
expires_at=datetime.now(timezone.utc) + timedelta(seconds=duration_seconds),
last_activity_at=datetime.now(timezone.utc),
ip_address=request.remote_addr,
user_agent=request.headers.get("User-Agent"),
)
session.save()
logger.info(f"[SuperadminAuth] Session created for superadmin_id={superadmin_id}")
return session
@staticmethod
def revoke_session(session_id, reason=None):
"""Revoke a superadmin session.
Args:
session_id: Session ID to revoke
reason: Optional revocation reason
"""
session = SuperadminSession.query.get(session_id)
if session:
session.revoke(reason=reason)
logger.info(f"[SuperadminAuth] Session {session_id} revoked: {reason or 'No reason'}")
@staticmethod
def revoke_all_sessions(superadmin_id, except_token=None, reason=None):
"""Revoke all sessions for a superadmin.
Args:
superadmin_id: Superadmin ID
except_token: Optional token to keep (current session)
reason: Optional revocation reason
"""
query = SuperadminSession.query.filter_by(superadmin_id=superadmin_id)
if except_token:
query = query.filter(SuperadminSession.token != except_token)
sessions = query.all()
for session in sessions:
session.revoke(reason=reason)
logger.info(f"[SuperadminAuth] Revoked {len(sessions)} sessions for superadmin_id={superadmin_id}")
return len(sessions)
@staticmethod
def create_emergency_access(superadmin_id, target_user_id, reason, duration_minutes=15):
"""Create emergency access to a user's account.
This creates a special emergency session that grants temporary elevated access.
Args:
superadmin_id: Superadmin ID initiating emergency access
target_user_id: User ID to access
reason: Reason for emergency access
duration_minutes: Duration of emergency access in minutes
Returns:
Dictionary with emergency session info
"""
from gatehouse_app.models.user.user import User
from gatehouse_app.services.auth_service import AuthService
from gatehouse_app.services.audit_service import AuditService
# Verify target user exists
target_user = User.query.get(target_user_id)
if not target_user:
raise ValueError(f"Target user not found: {target_user_id}")
# Create emergency session for the target user
emergency_session = AuthService.create_session(
user=target_user,
duration_seconds=duration_minutes * 60,
is_compliance_only=False
)
# Log the emergency access
logger.warning(
f"[SuperadminAuth] EMERGENCY ACCESS: superadmin_id={superadmin_id} "
f"accessed user_id={target_user_id} reason={reason}"
)
return {
"session": emergency_session,
"expires_at": emergency_session.expires_at,
"reason": reason,
"target_user_id": target_user_id,
}
@staticmethod
def hash_password(plain_credential):
"""Hash a credential for storage.
Args:
plain_credential: Plain text credential
Returns:
Hashed credential string
"""
return bcrypt.generate_password_hash(plain_credential).decode("utf-8")
@staticmethod
def create_superadmin(email, credential, full_name=None):
"""Create a new superadmin.
Args:
email: Superadmin email
credential: Plain text credential
full_name: Optional full name
Returns:
Superadmin instance
"""
# Check if email already exists
existing = Superadmin.query.filter_by(email=email.lower()).first()
if existing:
raise ValueError(f"Superadmin with email {email} already exists")
# Hash credential
password_hash = bcrypt.generate_password_hash(credential).decode("utf-8")
# Create superadmin
superadmin = Superadmin(
email=email.lower(),
password_hash=password_hash,
full_name=full_name,
is_active=True,
)
superadmin.save()
logger.info(f"[SuperadminAuth] Created new superadmin: {email}")
return superadmin
@staticmethod
def update_superadmin(superadmin_id, **kwargs):
"""Update superadmin details.
Args:
superadmin_id: Superadmin ID
**kwargs: Fields to update (email, full_name, is_active, credential)
Returns:
Updated Superadmin instance
"""
superadmin = Superadmin.query.get(superadmin_id)
if not superadmin:
raise ValueError(f"Superadmin not found: {superadmin_id}")
# Handle credential update
if 'password' in kwargs:
kwargs['password_hash'] = bcrypt.generate_password_hash(kwargs.pop('password')).decode("utf-8")
# Update fields
for key, value in kwargs.items():
if hasattr(superadmin, key):
setattr(superadmin, key, value)
superadmin.save()
logger.info(f"[SuperadminAuth] Updated superadmin_id={superadmin_id}")
return superadmin
@@ -0,0 +1,244 @@
"""Superadmin organization management service."""
import logging
from typing import Optional
from gatehouse_app.extensions import db
from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.models.user.session import Session
logger = logging.getLogger(__name__)
class SuperadminOrganizationService:
"""Service for superadmin organization management operations."""
@staticmethod
def list_organizations(
page: int = 1,
per_page: int = 20,
search: Optional[str] = None,
status: Optional[str] = None,
plan_slug: Optional[str] = None,
) -> dict:
"""List organizations with pagination and filtering.
Args:
page: Page number (1-indexed)
per_page: Items per page
search: Search by name or slug
status: Filter by status (active, suspended)
plan_slug: Filter by plan slug (not implemented yet - requires Subscription model)
Returns:
Paginated response with organization summaries
"""
query = Organization.query
# Search filter
if search:
search_term = f"%{search}%"
query = query.filter(
db.or_(
Organization.name.ilike(search_term),
Organization.slug.ilike(search_term),
)
)
# Status filter
if status == "active":
query = query.filter(Organization.is_active.is_(True))
elif status == "suspended":
query = query.filter(Organization.is_active.is_(False))
# Note: plan_slug filtering requires Plan/Subscription models (Phase 4)
# Currently ignored but parameter is accepted for API compatibility
# Order by created_at desc
query = query.order_by(Organization.created_at.desc())
# Paginate
pagination = query.paginate(page=page, per_page=per_page, error_out=False)
# Build response
items = []
for org in pagination.items:
item = {
"id": org.id,
"name": org.name,
"slug": org.slug,
"description": org.description,
"is_active": org.is_active,
"member_count": org.get_member_count(),
"created_at": org.created_at.isoformat() + "Z" if org.created_at else None,
"updated_at": org.updated_at.isoformat() + "Z" if org.updated_at else None,
}
items.append(item)
return {
"items": items,
"total": pagination.total,
"page": page,
"per_page": per_page,
"pages": pagination.pages,
}
@staticmethod
def get_organization_detail(org_id: str) -> dict:
"""Get detailed organization information.
Args:
org_id: Organization UUID
Returns:
Organization detail with member_count, owner, stats
Raises:
ValueError: If organization not found
"""
org = Organization.query.get(org_id)
if not org:
raise ValueError("Organization not found")
owner = org.get_owner()
# Count active sessions for org members
member_user_ids = [m.user_id for m in org.members if m.deleted_at is None]
active_sessions = Session.query.filter(
Session.user_id.in_(member_user_ids),
Session.deleted_at.is_(None),
).count()
return {
"id": org.id,
"name": org.name,
"slug": org.slug,
"description": org.description,
"is_active": org.is_active,
"settings": org.settings or {},
"member_count": org.get_member_count(),
"owner": {
"id": owner.id,
"email": owner.email,
"full_name": owner.full_name,
} if owner else None,
"active_sessions": active_sessions,
"created_at": org.created_at.isoformat() + "Z" if org.created_at else None,
"updated_at": org.updated_at.isoformat() + "Z" if org.updated_at else None,
}
@staticmethod
def update_organization(
org_id: str,
name: Optional[str] = None,
description: Optional[str] = None,
is_active: Optional[bool] = None,
) -> Organization:
"""Update organization details.
Args:
org_id: Organization UUID
name: New name (optional)
description: New description (optional)
is_active: New active status (optional)
Returns:
Updated organization
Raises:
ValueError: If organization not found
"""
org = Organization.query.get(org_id)
if not org:
raise ValueError("Organization not found")
if name is not None:
org.name = name
if description is not None:
org.description = description
if is_active is not None:
org.is_active = is_active
db.session.commit()
logger.info(f"[SuperadminOrg] Updated organization {org_id}")
return org
@staticmethod
def suspend_organization(org_id: str) -> Organization:
"""Suspend an organization.
Sets is_active=False and invalidates all member sessions.
Args:
org_id: Organization UUID
Returns:
Suspended organization
Raises:
ValueError: If organization not found
"""
org = Organization.query.get(org_id)
if not org:
raise ValueError("Organization not found")
org.is_active = False
# Invalidate all member sessions
member_user_ids = [m.user_id for m in org.members if m.deleted_at is None]
Session.query.filter(
Session.user_id.in_(member_user_ids),
Session.deleted_at.is_(None),
).update({"deleted_at": db.func.now()})
db.session.commit()
logger.warning(f"[SuperadminOrg] Suspended organization {org_id}")
return org
@staticmethod
def restore_organization(org_id: str) -> Organization:
"""Restore a suspended organization.
Args:
org_id: Organization UUID
Returns:
Restored organization
Raises:
ValueError: If organization not found
"""
org = Organization.query.get(org_id)
if not org:
raise ValueError("Organization not found")
org.is_active = True
db.session.commit()
logger.info(f"[SuperadminOrg] Restored organization {org_id}")
return org
@staticmethod
def soft_delete_organization(org_id: str) -> Organization:
"""Soft-delete an organization.
Args:
org_id: Organization UUID
Returns:
Soft-deleted organization
Raises:
ValueError: If organization not found
"""
from datetime import datetime, timezone
org = Organization.query.get(org_id)
if not org:
raise ValueError("Organization not found")
org.deleted_at = datetime.now(timezone.utc)
db.session.commit()
logger.warning(f"[SuperadminOrg] Soft-deleted organization {org_id}")
return org
@@ -0,0 +1,199 @@
"""Usage tracking service for superadmin operations."""
import logging
from datetime import datetime, timedelta, timezone
from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.models.user.session import Session
from gatehouse_app.models.organization.organization_member import OrganizationMember
logger = logging.getLogger(__name__)
class UsageMetric:
"""Usage metric types."""
USERS = "users"
SESSIONS = "sessions"
ACTIVE_SESSIONS = "active_sessions"
API_CALLS = "api_calls"
class SuperadminUsageService:
"""Service for tracking and retrieving usage metrics."""
@staticmethod
def get_current_usage(org_id: str) -> dict:
"""Get current period usage for an organization.
Args:
org_id: Organization UUID
Returns:
Current usage metrics including user count and active sessions
"""
org = Organization.query.get(org_id)
if not org:
raise ValueError("Organization not found")
# Get active member count
member_count = org.get_member_count()
# Get active sessions count
member_user_ids = [m.user_id for m in org.members if m.deleted_at is None]
active_sessions = Session.query.filter(
Session.user_id.in_(member_user_ids),
Session.deleted_at.is_(None),
Session.status == "active",
).count()
# Get max concurrent sessions this month
now = datetime.now(timezone.utc)
period_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
# For simplicity, we'll track peak concurrent sessions
# In production, you'd want a separate tracking table
max_sessions_this_month = active_sessions # Placeholder
return {
"organization_id": org_id,
"period_start": period_start.isoformat() + "Z",
"period_end": now.isoformat() + "Z",
"metrics": {
"users": {
"current": member_count,
"limit": None, # Will come from plan
"description": "Total organization members",
},
"active_sessions": {
"current": active_sessions,
"max_this_month": max_sessions_this_month,
"description": "Currently active user sessions",
},
},
}
@staticmethod
def get_usage_history(
org_id: str,
metric: str,
days: int = 30,
) -> dict:
"""Get usage history for a specific metric.
Args:
org_id: Organization UUID
metric: Metric type (users, sessions)
days: Number of days of history
Returns:
List of daily usage data points
"""
org = Organization.query.get(org_id)
if not org:
raise ValueError("Organization not found")
now = datetime.now(timezone.utc)
start_date = now - timedelta(days=days)
# Get member history (simplified - would need a history table in production)
history = []
current_count = org.get_member_count()
# Generate daily data points (placeholder - real implementation needs history table)
for i in range(days):
date = start_date + timedelta(days=i)
history.append({
"date": date.strftime("%Y-%m-%d"),
"value": current_count, # Simplified
})
return {
"organization_id": org_id,
"metric": metric,
"period_start": start_date.isoformat() + "Z",
"period_end": now.isoformat() + "Z",
"history": history,
}
@staticmethod
def get_seat_count_for_period(org_id: str, year: int, month: int) -> dict:
"""Calculate maximum seat count used in a given month.
For billing purposes - tracks the peak number of users.
Args:
org_id: Organization UUID
year: Year
month: Month
Returns:
Seat count data for the period
"""
org = Organization.query.get(org_id)
if not org:
raise ValueError("Organization not found")
# Calculate first and last day of month
first_day = datetime(year, month, 1, tzinfo=timezone.utc)
if month == 12:
last_day = datetime(year + 1, 1, 1, tzinfo=timezone.utc) - timedelta(seconds=1)
else:
last_day = datetime(year, month + 1, 1, tzinfo=timezone.utc) - timedelta(seconds=1)
# Get all members that existed during this period
members = OrganizationMember.query.filter(
OrganizationMember.organization_id == org_id,
OrganizationMember.deleted_at.is_(None),
).all()
# Count unique users who were members at any point during the month
max_seats = len(members)
# Current count at end of month
current_seats = len([m for m in members if m.deleted_at is None or m.deleted_at > last_day])
return {
"organization_id": org_id,
"period": f"{year}-{month:02d}",
"max_seats": max_seats,
"current_seats": current_seats,
"period_start": first_day.isoformat() + "Z",
"period_end": last_day.isoformat() + "Z",
}
@staticmethod
def adjust_usage(
org_id: str,
metric: str,
adjustment: int,
reason: str,
superadmin_id: str,
) -> dict:
"""Apply a manual usage adjustment (credit or charge).
Args:
org_id: Organization UUID
metric: Metric to adjust
adjustment: Positive (credit) or negative (charge)
reason: Reason for adjustment
superadmin_id: Superadmin making the adjustment
Returns:
Adjustment confirmation
"""
org = Organization.query.get(org_id)
if not org:
raise ValueError("Organization not found")
# In production, you'd create a UsageAdjustment record
# For now, just log and return
logger.warning(
f"[SuperadminUsage] Adjustment: org={org_id}, metric={metric}, "
f"adjustment={adjustment}, reason={reason}, by={superadmin_id}"
)
return {
"organization_id": org_id,
"metric": metric,
"adjustment": adjustment,
"reason": reason,
"applied_at": datetime.now(timezone.utc).isoformat() + "Z",
}
@@ -0,0 +1,371 @@
"""User management service for superadmin operations."""
import logging
from datetime import datetime, timezone
from gatehouse_app.models.user.user import User
from gatehouse_app.models.user.session import Session
from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.extensions import db
logger = logging.getLogger(__name__)
class SuperadminUserService:
"""Service for managing users across the platform."""
@staticmethod
def list_users(
page: int = 1,
per_page: int = 20,
org_id: str = None,
status: str = None,
search: str = None,
) -> dict:
"""List users with filters and pagination.
Args:
page: Page number
per_page: Items per page
org_id: Filter by organization
status: Filter by status (active/suspended)
search: Search by email or name
Returns:
Paginated user list with metadata
"""
query = User.query.filter(User.deleted_at.is_(None))
# Filter by organization
if org_id:
member_user_ids = db.session.query(OrganizationMember.user_id).filter(
OrganizationMember.organization_id == org_id,
OrganizationMember.deleted_at.is_(None),
).all()
user_ids = [m.user_id for m in member_user_ids]
query = query.filter(User.id.in_(user_ids))
# Filter by status
if status == "suspended":
query = query.filter(User.status == "GLOBAL_SUSPENDED")
elif status == "active":
query = query.filter(User.status != "GLOBAL_SUSPENDED")
# Search
if search:
search_filter = f"%{search}%"
query = query.filter(
db.or_(
User.email.ilike(search_filter),
User.full_name.ilike(search_filter),
)
)
query = query.order_by(User.created_at.desc())
total = query.count()
users = query.offset((page - 1) * per_page).limit(per_page).all()
items = []
for user in users:
# Get org memberships
memberships = OrganizationMember.query.filter(
OrganizationMember.user_id == user.id,
OrganizationMember.deleted_at.is_(None),
).all()
orgs = []
for m in memberships:
org = Organization.query.get(m.organization_id)
if org:
orgs.append({
"org_id": org.id,
"org_name": org.name,
"role": m.role,
})
# Get active sessions count
active_sessions = Session.query.filter(
Session.user_id == user.id,
Session.deleted_at.is_(None),
Session.status == "active",
).count()
items.append({
"id": user.id,
"email": user.email,
"full_name": user.full_name,
"status": user.status,
"org_count": len(orgs),
"orgs": orgs,
"active_sessions": active_sessions,
"last_login_at": user.last_login_at.isoformat() + "Z" if user.last_login_at else None,
"created_at": user.created_at.isoformat() + "Z" if user.created_at else None,
})
return {
"items": items,
"total": total,
"page": page,
"per_page": per_page,
"pages": (total + per_page - 1) // per_page if per_page > 0 else 0,
}
@staticmethod
def get_user_detail(user_id: str) -> dict:
"""Get detailed user information.
Args:
user_id: User UUID
Returns:
User detail with orgs, sessions, security methods
"""
user = User.query.get(user_id)
if not user or user.deleted_at is not None:
raise ValueError("User not found")
# Get org memberships
memberships = OrganizationMember.query.filter(
OrganizationMember.user_id == user_id,
OrganizationMember.deleted_at.is_(None),
).all()
orgs = []
for m in memberships:
org = Organization.query.get(m.organization_id)
if org:
orgs.append({
"org_id": org.id,
"org_name": org.name,
"org_slug": org.slug,
"role": m.role,
"joined_at": m.created_at.isoformat() + "Z" if m.created_at else None,
})
# Get active sessions
sessions = Session.query.filter(
Session.user_id == user_id,
Session.deleted_at.is_(None),
Session.status == "active",
).all()
active_sessions = [{
"id": s.id,
"ip_address": s.ip_address,
"user_agent": s.user_agent,
"created_at": s.created_at.isoformat() + "Z" if s.created_at else None,
} for s in sessions]
# Security methods
security_methods = []
if hasattr(user, 'totp_enabled') and user.totp_enabled:
security_methods.append({"type": "totp", "enabled": True})
if hasattr(user, 'webauthn_enabled') and user.webauthn_enabled:
security_methods.append({"type": "webauthn", "enabled": True})
return {
"user": {
"id": user.id,
"email": user.email,
"full_name": user.full_name,
"status": user.status,
"mfa_enabled": user.mfa_enabled if hasattr(user, 'mfa_enabled') else False,
"last_login_at": user.last_login_at.isoformat() + "Z" if user.last_login_at else None,
"created_at": user.created_at.isoformat() + "Z" if user.created_at else None,
},
"organizations": orgs,
"active_sessions": active_sessions,
"security_methods": security_methods,
}
@staticmethod
def suspend_user(user_id: str) -> dict:
"""Globally suspend a user.
Args:
user_id: User UUID
Returns:
Updated user info and count of revoked sessions
"""
user = User.query.get(user_id)
if not user or user.deleted_at is not None:
raise ValueError("User not found")
if user.status == "GLOBAL_SUSPENDED":
raise ValueError("User is already suspended")
user.status = "GLOBAL_SUSPENDED"
db.session.commit()
# Revoke all sessions
revoked_count = Session.query.filter(
Session.user_id == user_id,
Session.deleted_at.is_(None),
).update({"status": "revoked", "deleted_at": db.func.now()})
db.session.commit()
return {
"user": {
"id": user.id,
"email": user.email,
"status": user.status,
},
"sessions_revoked": revoked_count,
}
@staticmethod
def unsuspend_user(user_id: str) -> dict:
"""Remove global suspension from a user.
Args:
user_id: User UUID
Returns:
Updated user info
"""
user = User.query.get(user_id)
if not user or user.deleted_at is not None:
raise ValueError("User not found")
if user.status != "GLOBAL_SUSPENDED":
raise ValueError("User is not suspended")
user.status = "active"
db.session.commit()
return {
"user": {
"id": user.id,
"email": user.email,
"status": user.status,
},
}
@staticmethod
def reset_password(user_id: str) -> dict:
"""Trigger password reset for user.
Args:
user_id: User UUID
Returns:
Email of user
"""
user = User.query.get(user_id)
if not user or user.deleted_at is not None:
raise ValueError("User not found")
# In production, this would call AuthService.send_password_reset_email
logger.info(f"[SuperadminUserService] Password reset requested for {user.email}")
return {"email": user.email}
@staticmethod
def revoke_all_sessions(user_id: str) -> dict:
"""Revoke all sessions for a user.
Args:
user_id: User UUID
Returns:
Count of revoked sessions
"""
user = User.query.get(user_id)
if not user or user.deleted_at is not None:
raise ValueError("User not found")
result = Session.query.filter(
Session.user_id == user_id,
Session.deleted_at.is_(None),
).update({"status": "revoked", "deleted_at": db.func.now()})
db.session.commit()
return {
"user_id": user_id,
"count": result,
}
@staticmethod
def add_to_org(user_id: str, org_id: str, role: str = "member") -> dict:
"""Add a user to an organization.
Args:
user_id: User UUID
org_id: Organization UUID
role: Membership role
Returns:
Membership details
"""
user = User.query.get(user_id)
if not user or user.deleted_at is not None:
raise ValueError("User not found")
org = Organization.query.get(org_id)
if not org or org.deleted_at is not None:
raise ValueError("Organization not found")
# Check if already a member
existing = OrganizationMember.query.filter(
OrganizationMember.user_id == user_id,
OrganizationMember.organization_id == org_id,
OrganizationMember.deleted_at.is_(None),
).first()
if existing:
raise ValueError("User is already a member of this organization")
membership = OrganizationMember(
user_id=user_id,
organization_id=org_id,
role=role,
)
db.session.add(membership)
db.session.commit()
return {
"user_id": user_id,
"organization_id": org_id,
"role": role,
"joined_at": membership.created_at.isoformat() + "Z" if membership.created_at else None,
}
@staticmethod
def remove_from_org(user_id: str, org_id: str) -> dict:
"""Remove a user from an organization.
Args:
user_id: User UUID
org_id: Organization UUID
Returns:
Confirmation
"""
membership = OrganizationMember.query.filter(
OrganizationMember.user_id == user_id,
OrganizationMember.organization_id == org_id,
OrganizationMember.deleted_at.is_(None),
).first()
if not membership:
raise ValueError("User is not a member of this organization")
# Check if user is the only owner
if membership.role == "owner":
owner_count = OrganizationMember.query.filter(
OrganizationMember.organization_id == org_id,
OrganizationMember.role == "owner",
OrganizationMember.deleted_at.is_(None),
).count()
if owner_count <= 1:
raise ValueError("Cannot remove the only owner from an organization. Transfer ownership first.")
membership.deleted_at = db.func.now()
db.session.commit()
return {
"user_id": user_id,
"organization_id": org_id,
}
@@ -283,9 +283,9 @@ def reconcile_deleted_memberships() -> dict:
if not device or not network:
logger.warning(
f"[Reconciliation] Membership {membership.id}: missing "
f"{'device' if not device else 'network'}hard-deleting record only."
f"{'device' if not device else 'network'}soft-deleting record only."
)
db.session.delete(membership)
membership.delete(soft=True)
db.session.commit()
results["deleted"] += 1
continue
@@ -304,20 +304,20 @@ def reconcile_deleted_memberships() -> dict:
except Exception as zt_exc:
logger.warning(
f"[Reconciliation] ZT delete failed for node {node_id} "
f"on {network_label}: {zt_exc} — proceeding with DB hard-delete."
f"on {network_label}: {zt_exc} — proceeding with DB soft-delete."
)
db.session.delete(membership)
membership.delete(soft=True)
db.session.commit()
results["deleted"] += 1
logger.debug(
f"[Reconciliation] Hard-deleted membership {membership.id} "
f"[Reconciliation] Soft-deleted membership {membership.id} "
f"(node={node_id}, network={network_label})."
)
except Exception as exc:
logger.error(
f"[Reconciliation] Failed to hard-delete membership {membership.id}: {exc}",
f"[Reconciliation] Failed to soft-delete membership {membership.id}: {exc}",
exc_info=True,
)
results["errors"] += 1
+1
View File
@@ -75,6 +75,7 @@ class AuditAction(str, Enum):
ORG_MEMBER_REMOVE = "org.member.remove"
ORG_MEMBER_ROLE_CHANGE = "org.member.role_change"
ORG_OWNERSHIP_TRANSFERRED = "org.ownership.transferred"
ORG_INVITE_SENT = "org.invite.sent"
# Session actions
SESSION_CREATE = "session.create"
+3 -5
View File
@@ -59,11 +59,9 @@ def login_required(f):
error_type="SESSION_INACTIVE"
)
# Update last_activity_at timestamp
from datetime import datetime, timezone
session.last_activity_at = datetime.now(timezone.utc)
from gatehouse_app import db
db.session.commit()
# Extend session via sliding window (updates last_activity_at
# and recalculates expires_at within the idle / absolute caps).
session.refresh()
# Set context variables
g.current_user = session.user
+25
View File
@@ -153,6 +153,31 @@ def mfa_compliance_status():
print("=" * 60)
@cli.command("cleanup_sessions")
def cleanup_sessions():
"""Clean up expired user sessions.
Marks sessions as EXPIRED when they have passed their expires_at
timestamp. Safe to run frequently (e.g. every 5 minutes via job_runner).
Usage:
python manage.py cleanup_sessions
"""
from gatehouse_app.services.session_service import SessionService
print("=" * 60)
print("Session Cleanup Job")
print("=" * 60)
from datetime import datetime, timezone
print(f"Start time: {datetime.now(timezone.utc).isoformat()}")
count = SessionService.cleanup_expired_sessions()
print(f"Expired sessions marked: {count}")
print("=" * 60)
@cli.command("configure_oauth")
@click.argument("provider", required=False)
@click.option("--client-id", default=None, help="OAuth client ID")
@@ -29,8 +29,7 @@ def upgrade():
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id')
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_application_provider_configs_provider_type'), 'application_provider_configs', ['provider_type'], unique=True)
op.create_table('oidc_jwks_keys',
@@ -63,8 +62,7 @@ def upgrade():
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id')
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_organizations_slug'), 'organizations', ['slug'], unique=True)
op.create_table('users',
@@ -81,8 +79,7 @@ def upgrade():
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id')
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_users_activation_key'), 'users', ['activation_key'], unique=True)
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
@@ -105,8 +102,7 @@ def upgrade():
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id')
sa.PrimaryKeyConstraint('id')
)
op.create_index('idx_audit_org', 'audit_logs', ['organization_id', 'created_at'], unique=False)
op.create_index('idx_audit_resource', 'audit_logs', ['resource_type', 'resource_id'], unique=False)
@@ -135,7 +131,6 @@ def upgrade():
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('user_id', 'method_type', 'provider_user_id', name='uix_user_method_provider')
)
op.create_index('idx_user_method', 'authentication_methods', ['user_id', 'method_type'], unique=False)
@@ -165,7 +160,6 @@ def upgrade():
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('fingerprint'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('organization_id', 'name', name='uix_org_ca_name')
)
op.create_index('idx_ca_org_active', 'cas', ['organization_id', 'is_active'], unique=False)
@@ -182,7 +176,6 @@ def upgrade():
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('organization_id', 'name', name='uix_org_dept_name')
)
op.create_index(op.f('ix_departments_name'), 'departments', ['name'], unique=False)
@@ -202,8 +195,7 @@ def upgrade():
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id')
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_devices_node_id'), 'devices', ['node_id'], unique=False)
op.create_index(op.f('ix_devices_organization_id'), 'devices', ['organization_id'], unique=False)
@@ -218,8 +210,7 @@ def upgrade():
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id')
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_email_verification_tokens_token'), 'email_verification_tokens', ['token'], unique=True)
op.create_index(op.f('ix_email_verification_tokens_user_id'), 'email_verification_tokens', ['user_id'], unique=False)
@@ -242,7 +233,6 @@ def upgrade():
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('organization_id', 'provider_type', name='uix_org_provider_type')
)
op.create_index('idx_provider_config_org', 'external_provider_configs', ['organization_id', 'provider_type'], unique=False)
@@ -262,8 +252,7 @@ def upgrade():
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
sa.ForeignKeyConstraint(['target_user_id'], ['users.id'], ),
sa.ForeignKeyConstraint(['triggered_by_user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id')
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_kill_switch_events_organization_id'), 'kill_switch_events', ['organization_id'], unique=False)
op.create_index(op.f('ix_kill_switch_events_target_user_id'), 'kill_switch_events', ['target_user_id'], unique=False)
@@ -285,7 +274,6 @@ def upgrade():
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('user_id', 'organization_id', name='uix_user_org_compliance')
)
op.create_index(op.f('ix_mfa_policy_compliance_organization_id'), 'mfa_policy_compliance', ['organization_id'], unique=False)
@@ -310,8 +298,7 @@ def upgrade():
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id')
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_oauth_states_expires_at'), 'oauth_states', ['expires_at'], unique=False)
op.create_index(op.f('ix_oauth_states_organization_id'), 'oauth_states', ['organization_id'], unique=False)
@@ -340,8 +327,7 @@ def upgrade():
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id')
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_oidc_clients_client_id'), 'oidc_clients', ['client_id'], unique=True)
op.create_index(op.f('ix_oidc_clients_organization_id'), 'oidc_clients', ['organization_id'], unique=False)
@@ -359,8 +345,7 @@ def upgrade():
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['invited_by_id'], ['users.id'], ondelete='SET NULL'),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id')
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_org_invite_tokens_email'), 'org_invite_tokens', ['email'], unique=False)
op.create_index(op.f('ix_org_invite_tokens_organization_id'), 'org_invite_tokens', ['organization_id'], unique=False)
@@ -379,8 +364,7 @@ def upgrade():
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id')
sa.PrimaryKeyConstraint('id')
)
op.create_index('idx_api_key_last_used', 'organization_api_keys', ['last_used_at'], unique=False)
op.create_index('idx_org_api_key_org_active', 'organization_api_keys', ['organization_id', 'is_revoked'], unique=False)
@@ -402,7 +386,6 @@ def upgrade():
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('user_id', 'organization_id', name='uix_user_org')
)
op.create_index(op.f('ix_organization_members_organization_id'), 'organization_members', ['organization_id'], unique=False)
@@ -421,7 +404,6 @@ def upgrade():
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('organization_id', 'provider_type', name='uix_org_provider_override_type')
)
op.create_index(op.f('ix_organization_provider_overrides_organization_id'), 'organization_provider_overrides', ['organization_id'], unique=False)
@@ -439,8 +421,7 @@ def upgrade():
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
sa.ForeignKeyConstraint(['updated_by_user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id')
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_organization_security_policies_organization_id'), 'organization_security_policies', ['organization_id'], unique=True)
op.create_table('password_reset_tokens',
@@ -453,8 +434,7 @@ def upgrade():
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id')
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_password_reset_tokens_token'), 'password_reset_tokens', ['token'], unique=True)
op.create_index(op.f('ix_password_reset_tokens_user_id'), 'password_reset_tokens', ['user_id'], unique=False)
@@ -476,7 +456,6 @@ def upgrade():
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
sa.ForeignKeyConstraint(['owner_user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('organization_id', 'zerotier_network_id', name='uix_org_zt_network_id')
)
op.create_index(op.f('ix_portal_networks_organization_id'), 'portal_networks', ['organization_id'], unique=False)
@@ -491,7 +470,6 @@ def upgrade():
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('organization_id', 'name', name='uix_org_principal_name')
)
op.create_index(op.f('ix_principals_name'), 'principals', ['name'], unique=False)
@@ -513,8 +491,7 @@ def upgrade():
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id')
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_sessions_token'), 'sessions', ['token'], unique=True)
op.create_index(op.f('ix_sessions_user_id'), 'sessions', ['user_id'], unique=False)
@@ -536,7 +513,6 @@ def upgrade():
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('payload')
)
op.create_index('idx_ssh_key_user_verified', 'ssh_keys', ['user_id', 'verified'], unique=False)
@@ -556,7 +532,6 @@ def upgrade():
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('user_id', 'organization_id', name='uix_user_org_policy')
)
op.create_index(op.f('ix_user_security_policies_organization_id'), 'user_security_policies', ['organization_id'], unique=False)
@@ -572,8 +547,7 @@ def upgrade():
sa.ForeignKeyConstraint(['ca_id'], ['cas.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('ca_id', 'user_id', name='uix_ca_permission'),
sa.UniqueConstraint('id')
sa.UniqueConstraint('ca_id', 'user_id', name='uix_ca_permission')
)
op.create_index(op.f('ix_ca_permissions_ca_id'), 'ca_permissions', ['ca_id'], unique=False)
op.create_index(op.f('ix_ca_permissions_user_id'), 'ca_permissions', ['user_id'], unique=False)
@@ -589,8 +563,7 @@ def upgrade():
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['department_id'], ['departments.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id')
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_department_cert_policies_department_id'), 'department_cert_policies', ['department_id'], unique=True)
op.create_table('department_memberships',
@@ -603,7 +576,6 @@ def upgrade():
sa.ForeignKeyConstraint(['department_id'], ['departments.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('user_id', 'department_id', name='uix_user_dept')
)
op.create_index(op.f('ix_department_memberships_department_id'), 'department_memberships', ['department_id'], unique=False)
@@ -618,8 +590,7 @@ def upgrade():
sa.ForeignKeyConstraint(['department_id'], ['departments.id'], ),
sa.ForeignKeyConstraint(['principal_id'], ['principals.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('department_id', 'principal_id', name='uix_dept_principal'),
sa.UniqueConstraint('id')
sa.UniqueConstraint('department_id', 'principal_id', name='uix_dept_principal')
)
op.create_index(op.f('ix_department_principals_department_id'), 'department_principals', ['department_id'], unique=False)
op.create_index(op.f('ix_department_principals_principal_id'), 'department_principals', ['principal_id'], unique=False)
@@ -640,8 +611,7 @@ def upgrade():
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['client_id'], ['oidc_clients.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id')
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_oidc_audit_logs_client_id'), 'oidc_audit_logs', ['client_id'], unique=False)
op.create_index(op.f('ix_oidc_audit_logs_event_type'), 'oidc_audit_logs', ['event_type'], unique=False)
@@ -668,8 +638,7 @@ def upgrade():
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['client_id'], ['oidc_clients.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id')
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_oidc_authorization_codes_client_id'), 'oidc_authorization_codes', ['client_id'], unique=False)
op.create_index(op.f('ix_oidc_authorization_codes_expires_at'), 'oidc_authorization_codes', ['expires_at'], unique=False)
@@ -693,8 +662,7 @@ def upgrade():
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['client_id'], ['oidc_clients.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id')
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_oidc_refresh_tokens_access_token_id'), 'oidc_refresh_tokens', ['access_token_id'], unique=False)
op.create_index(op.f('ix_oidc_refresh_tokens_client_id'), 'oidc_refresh_tokens', ['client_id'], unique=False)
@@ -718,8 +686,7 @@ def upgrade():
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['client_id'], ['oidc_clients.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id')
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_oidc_sessions_client_id'), 'oidc_sessions', ['client_id'], unique=False)
op.create_index(op.f('ix_oidc_sessions_expires_at'), 'oidc_sessions', ['expires_at'], unique=False)
@@ -755,7 +722,6 @@ def upgrade():
sa.ForeignKeyConstraint(['principal_id'], ['principals.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('user_id', 'principal_id', name='uix_user_principal')
)
op.create_index(op.f('ix_principal_memberships_principal_id'), 'principal_memberships', ['principal_id'], unique=False)
@@ -787,8 +753,7 @@ def upgrade():
sa.ForeignKeyConstraint(['ssh_key_id'], ['ssh_keys.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('ca_id', 'serial', name='uq_ssh_certificates_ca_serial'),
sa.UniqueConstraint('id')
sa.UniqueConstraint('ca_id', 'serial', name='uq_ssh_certificates_ca_serial')
)
op.create_index('idx_cert_revoked', 'ssh_certificates', ['revoked', 'revoked_at'], unique=False)
op.create_index('idx_cert_user_status', 'ssh_certificates', ['user_id', 'status'], unique=False)
@@ -816,7 +781,6 @@ def upgrade():
sa.ForeignKeyConstraint(['portal_network_id'], ['portal_networks.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('user_id', 'portal_network_id', 'deleted_at', name='uix_user_network_approval')
)
op.create_index(op.f('ix_user_network_approvals_organization_id'), 'user_network_approvals', ['organization_id'], unique=False)
@@ -840,8 +804,7 @@ def upgrade():
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['certificate_id'], ['ssh_certificates.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id')
sa.PrimaryKeyConstraint('id')
)
op.create_index('idx_cert_audit_cert_action', 'certificate_audit_logs', ['certificate_id', 'action'], unique=False)
op.create_index('idx_cert_audit_user', 'certificate_audit_logs', ['user_id', 'created_at'], unique=False)
@@ -868,8 +831,7 @@ def upgrade():
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.ForeignKeyConstraint(['user_network_approval_id'], ['user_network_approvals.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('device_id', 'portal_network_id', 'deleted_at', name='uix_device_network'),
sa.UniqueConstraint('id')
sa.UniqueConstraint('device_id', 'portal_network_id', 'deleted_at', name='uix_device_network')
)
op.create_index(op.f('ix_device_network_memberships_device_id'), 'device_network_memberships', ['device_id'], unique=False)
op.create_index(op.f('ix_device_network_memberships_organization_id'), 'device_network_memberships', ['organization_id'], unique=False)
@@ -894,8 +856,7 @@ def upgrade():
sa.ForeignKeyConstraint(['device_network_membership_id'], ['device_network_memberships.id'], ),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id')
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_activation_sessions_device_network_membership_id'), 'activation_sessions', ['device_network_membership_id'], unique=False)
op.create_index(op.f('ix_activation_sessions_organization_id'), 'activation_sessions', ['organization_id'], unique=False)
@@ -917,7 +878,6 @@ def upgrade():
sa.ForeignKeyConstraint(['device_network_membership_id'], ['device_network_memberships.id'], ),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('zerotier_network_id', 'node_id', name='uix_zt_network_node')
)
op.create_index(op.f('ix_zerotier_memberships_device_network_membership_id'), 'zerotier_memberships', ['device_network_membership_id'], unique=False)
@@ -0,0 +1,50 @@
"""Add organization_id to certificate_audit_logs.
Revision ID: 8f2d9e4a7c1b
Revises: b4cd6c6b3b1c
Create Date: 2026-04-23 07:30:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '8f2d9e4a7c1b'
down_revision = 'b4cd6c6b3b1c'
branch_labels = None
depends_on = None
def upgrade():
# Add organization_id column to certificate_audit_logs
op.add_column(
'certificate_audit_logs',
sa.Column('organization_id', sa.String(length=36), nullable=True)
)
# Create index on organization_id
op.create_index(
op.f('ix_certificate_audit_logs_organization_id'),
'certificate_audit_logs',
['organization_id']
)
# Create foreign key constraint
op.create_foreign_key(
'fk_cert_audit_log_organization',
'certificate_audit_logs',
'organizations',
['organization_id'],
['id']
)
def downgrade():
# Drop foreign key constraint
op.drop_constraint('fk_cert_audit_log_organization', 'certificate_audit_logs', type_='foreignkey')
# Drop index
op.drop_index(op.f('ix_certificate_audit_logs_organization_id'), 'certificate_audit_logs')
# Drop organization_id column
op.drop_column('certificate_audit_logs', 'organization_id')
@@ -0,0 +1,44 @@
"""Per-user SSH key fingerprint uniqueness.
Revision ID: a1b2c3d4e5f6
Revises: 8f2d9e4a7c1b
Create Date: 2026-04-24 10:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'a1b2c3d4e5f6'
down_revision = '8f2d9e4a7c1b'
branch_labels = None
depends_on = None
def upgrade():
# Drop the global unique constraint on payload
op.drop_constraint('ssh_keys_payload_key', 'ssh_keys', type_='unique')
# Drop the global unique index on fingerprint
op.drop_index('ix_ssh_keys_fingerprint', table_name='ssh_keys')
# Create a non-unique index on fingerprint for query performance
op.create_index(op.f('ix_ssh_keys_fingerprint'), 'ssh_keys', ['fingerprint'], unique=False)
# Add composite unique constraint for per-user fingerprint uniqueness
op.create_unique_constraint('uix_user_fingerprint', 'ssh_keys', ['user_id', 'fingerprint'])
def downgrade():
# Drop the composite unique constraint
op.drop_constraint('uix_user_fingerprint', 'ssh_keys', type_='unique')
# Drop the non-unique index
op.drop_index(op.f('ix_ssh_keys_fingerprint'), table_name='ssh_keys')
# Recreate the global unique index
op.create_index('ix_ssh_keys_fingerprint', 'ssh_keys', ['fingerprint'], unique=True)
# Recreate the global unique constraint on payload
op.create_unique_constraint('ssh_keys_payload_key', 'ssh_keys', ['payload'])
@@ -0,0 +1,100 @@
"""Superadmin
Revision ID: b4cd6c6b3b1c
Revises: 6a4c4ed4a5c6
Create Date: 2026-04-08 16:55:52.646980
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'b4cd6c6b3b1c'
down_revision = '6a4c4ed4a5c6'
branch_labels = None
depends_on = None
def upgrade():
# --- Create superadmin tables (not captured by auto-generation) ---
op.create_table(
'superadmins',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('email', sa.String(length=255), nullable=False),
sa.Column('password_hash', sa.String(length=255), nullable=False),
sa.Column('full_name', sa.String(length=255), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=False, server_default=sa.text('true')),
sa.Column('last_login_at', sa.DateTime(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id'),
)
op.create_index(op.f('ix_superadmins_email'), 'superadmins', ['email'], unique=True)
op.create_table(
'superadmin_sessions',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('superadmin_id', sa.String(length=36), nullable=False),
sa.Column('token', sa.String(length=255), nullable=False),
sa.Column('expires_at', sa.DateTime(), nullable=False),
sa.Column('last_activity_at', sa.DateTime(), nullable=False),
sa.Column('ip_address', sa.String(length=45), nullable=True),
sa.Column('user_agent', sa.Text(), nullable=True),
sa.Column('revoked_at', sa.DateTime(), nullable=True),
sa.Column('revoked_reason', sa.String(length=255), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['superadmin_id'], ['superadmins.id']),
sa.PrimaryKeyConstraint('id'),
)
op.create_index(op.f('ix_superadmin_sessions_superadmin_id'), 'superadmin_sessions', ['superadmin_id'])
op.create_index(op.f('ix_superadmin_sessions_token'), 'superadmin_sessions', ['token'], unique=True)
op.create_table(
'superadmin_audit_logs',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('superadmin_id', sa.String(length=36), nullable=False),
sa.Column('action', sa.String(length=100), nullable=False),
sa.Column('resource_type', sa.String(length=50), nullable=False),
sa.Column('resource_id', sa.String(length=36), nullable=True),
sa.Column('org_id', sa.String(length=36), nullable=True),
sa.Column('user_id', sa.String(length=36), nullable=True),
sa.Column('ip_address', sa.String(length=45), nullable=True),
sa.Column('user_agent', sa.Text(), nullable=True),
sa.Column('request_id', sa.String(length=100), nullable=True),
sa.Column('extra_data', sa.JSON(), nullable=True),
sa.Column('success', sa.Boolean(), nullable=False, server_default=sa.text('true')),
sa.Column('error_message', sa.String(length=500), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['superadmin_id'], ['superadmins.id']),
sa.PrimaryKeyConstraint('id'),
)
op.create_index(op.f('ix_superadmin_audit_logs_superadmin_id'), 'superadmin_audit_logs', ['superadmin_id'])
op.create_index(op.f('ix_superadmin_audit_logs_action'), 'superadmin_audit_logs', ['action'])
op.create_index(op.f('ix_superadmin_audit_logs_resource_type'), 'superadmin_audit_logs', ['resource_type'])
op.create_index(op.f('ix_superadmin_audit_logs_resource_id'), 'superadmin_audit_logs', ['resource_id'])
op.create_index(op.f('ix_superadmin_audit_logs_org_id'), 'superadmin_audit_logs', ['org_id'])
op.create_index(op.f('ix_superadmin_audit_logs_user_id'), 'superadmin_audit_logs', ['user_id'])
def downgrade():
# --- Drop superadmin tables (reverse order due to FK dependencies) ---
op.drop_index(op.f('ix_superadmin_audit_logs_user_id'), table_name='superadmin_audit_logs')
op.drop_index(op.f('ix_superadmin_audit_logs_org_id'), table_name='superadmin_audit_logs')
op.drop_index(op.f('ix_superadmin_audit_logs_resource_id'), table_name='superadmin_audit_logs')
op.drop_index(op.f('ix_superadmin_audit_logs_resource_type'), table_name='superadmin_audit_logs')
op.drop_index(op.f('ix_superadmin_audit_logs_action'), table_name='superadmin_audit_logs')
op.drop_index(op.f('ix_superadmin_audit_logs_superadmin_id'), table_name='superadmin_audit_logs')
op.drop_table('superadmin_audit_logs')
op.drop_index(op.f('ix_superadmin_sessions_token'), table_name='superadmin_sessions')
op.drop_index(op.f('ix_superadmin_sessions_superadmin_id'), table_name='superadmin_sessions')
op.drop_table('superadmin_sessions')
op.drop_index(op.f('ix_superadmins_email'), table_name='superadmins')
op.drop_table('superadmins')
+3
View File
@@ -6,6 +6,9 @@ python_functions = test_*
addopts =
-v
--strict-markers
-p no:maas-django
-p no:maas-perftest
-p no:maas-seeds
--cov=gatehouse_app
--cov-report=term-missing
--cov-report=html
+106
View File
@@ -0,0 +1,106 @@
#!/usr/bin/env python3
"""Generic job runner for scheduled tasks in Docker containers.
Runs a Flask CLI command on a configurable interval with graceful shutdown support.
Environment Variables:
JOB_NAME: Name of the job to run (zerotier_reconciliation, mfa_compliance)
JOB_INTERVAL_SECONDS: Seconds between job runs (default: 300)
Usage:
docker run -e JOB_NAME=zerotier_reconciliation -e JOB_INTERVAL_SECONDS=120 app
"""
import os
import signal
import subprocess
import sys
import time
import logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)
JOB_COMMANDS = {
"zerotier_reconciliation": "python manage.py run_zerotier_reconciliation",
"mfa_compliance": "python manage.py run_mfa_compliance_job",
"cleanup_sessions": "python manage.py cleanup_sessions",
}
shutdown_requested = False
def signal_handler(signum, frame):
global shutdown_requested
logger.info(f"Received signal {signum}, initiating graceful shutdown...")
shutdown_requested = True
def run_job(job_name: str) -> bool:
command = JOB_COMMANDS.get(job_name)
if not command:
logger.error(f"Unknown job: {job_name}. Valid jobs: {list(JOB_COMMANDS.keys())}")
return False
logger.info(f"Running job: {job_name}")
start_time = time.monotonic()
try:
result = subprocess.run(
command,
shell=True,
cwd="/app",
capture_output=False,
)
elapsed = time.monotonic() - start_time
logger.info(f"Job {job_name} completed in {elapsed:.2f}s with exit code {result.returncode}")
return result.returncode == 0
except Exception as e:
elapsed = time.monotonic() - start_time
logger.error(f"Job {job_name} failed after {elapsed:.2f}s: {e}")
return False
def main():
job_name = os.getenv("JOB_NAME")
interval = int(os.getenv("JOB_INTERVAL_SECONDS", "300"))
if not job_name:
logger.error("JOB_NAME environment variable is required")
sys.exit(1)
if job_name not in JOB_COMMANDS:
logger.error(f"Unknown JOB_NAME: {job_name}. Valid: {list(JOB_COMMANDS.keys())}")
sys.exit(1)
if interval < 10:
logger.error(f"JOB_INTERVAL_SECONDS must be at least 10 seconds, got {interval}")
sys.exit(1)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
logger.info(f"Job runner started: {job_name}, interval={interval}s")
logger.info(f"Valid jobs: {list(JOB_COMMANDS.keys())}")
while not shutdown_requested:
run_job(job_name)
if shutdown_requested:
break
logger.info(f"Sleeping for {interval}s until next run...")
sleep_start = time.monotonic()
while time.monotonic() - sleep_start < interval and not shutdown_requested:
time.sleep(1)
logger.info("Job runner stopped")
if __name__ == "__main__":
main()
+89
View File
@@ -0,0 +1,89 @@
"""Seed script for creating superadmin user.
This script reads SUPERADMIN_EMAIL and SUPERADMIN_SECRET environment variables.
If both are present:
- Creates a superadmin with the given email and hashed credential
- Idempotent: updates existing superadmin if email already exists
If either is absent:
- No-op (logs that env vars are not set)
Usage:
export SUPERADMIN_EMAIL="admin@example.com"
export SUPERADMIN_SECRET="[SetYourSecureSecretHere]"
python scripts/seed_superadmin.py
"""
import sys
import os
import logging
from dotenv import load_dotenv
# Load environment variables FIRST before any app imports
load_dotenv()
from gatehouse_app import create_app
from gatehouse_app.extensions import db, bcrypt
from gatehouse_app.models.superadmin import Superadmin
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
def main():
"""Main entry point."""
app = create_app()
with app.app_context():
# Read environment variables
email = os.environ.get("SUPERADMIN_EMAIL")
credential = os.environ.get("SUPERADMIN_SECRET")
# Check if env vars are set
if not email or not credential:
logger.info("SUPERADMIN_EMAIL and/or SUPERADMIN_SECRET not set. No-op.")
logger.info("To create a superadmin, set both environment variables:")
logger.info(" export SUPERADMIN_EMAIL='admin@example.com'")
logger.info(" export SUPERADMIN_SECRET='[SetYourSecureSecretHere]'")
return
# Normalize email
email = email.lower().strip()
logger.info(f"Processing superadmin: {email}")
# Check if superadmin already exists
existing = Superadmin.query.filter_by(email=email).first()
if existing:
# Update existing superadmin
logger.info(" → Existing superadmin found, updating credential")
existing.password_hash = bcrypt.generate_password_hash(credential).decode("utf-8")
existing.is_active = True
existing.full_name = existing.full_name or "Super Admin"
db.session.commit()
logger.info(f" → Updated superadmin: {email} (id={existing.id})")
print(f"Updated existing superadmin: {email}")
else:
# Create new superadmin
logger.info(" → Creating new superadmin")
password_hash = bcrypt.generate_password_hash(credential).decode("utf-8")
superadmin = Superadmin(
email=email,
password_hash=password_hash,
full_name="Super Admin",
is_active=True,
)
db.session.add(superadmin)
db.session.commit()
logger.info(f" → Created superadmin: {email} (id={superadmin.id})")
print(f"Created new superadmin: {email}")
logger.info("Seed complete.")
if __name__ == "__main__":
main()
+228
View File
@@ -0,0 +1,228 @@
# Secuird Integration Test Suite
This directory contains the integration test suite for the Secuird IAM platform.
## Quick Start
Run all integration tests:
```bash
cd backend
pytest tests/integration/
```
Run a specific test file:
```bash
pytest tests/integration/test_ssh_workflows.py -v
```
Run without coverage (faster):
```bash
pytest tests/integration/ --no-cov
```
Fail fast (stop on first failure):
```bash
pytest tests/integration/ -x
```
Run previously failed tests first:
```bash
pytest tests/integration/ --ff
```
## Test Structure
```
tests/
├── conftest.py # Base pytest fixtures (app, client, test_user)
├── integration/ # Integration tests
│ ├── conftest.py # Integration-specific fixtures and factories
│ ├── client/ # Reusable API client library
│ │ ├── base.py # SecuirdClient with session management
│ │ ├── auth.py # Authentication operations
│ │ ├── users.py # User self-service operations
│ │ ├── orgs.py # Organization operations
│ │ ├── ssh.py # SSH key/cert operations
│ │ ├── mfa.py # TOTP/WebAuthn operations
│ │ ├── zerotier.py # ZeroTier network operations
│ │ └── admin.py # Admin operations
│ ├── fixtures/ # Test data and helpers
│ │ ├── ssh_keys.py # Test SSH key pairs and helpers
│ │ └── test_data.py # Common test data generators
│ ├── test_auth_flows.py # Authentication flows (24 tests)
│ ├── test_totp_workflows.py # TOTP MFA flows (15 tests)
│ ├── test_ssh_workflows.py # SSH key/cert flows (34 tests)
│ ├── test_org_workflows.py # Organization & invite flows (27 tests)
│ ├── test_multi_org.py # Multi-organization access (4 tests)
│ ├── test_self_service.py # User self-service features (9 tests)
│ ├── test_admin_ops.py # Admin user management (9 tests)
│ ├── test_authorization.py # RBAC & access control (8 tests)
│ ├── test_security.py # Security & edge cases (5 tests)
│ ├── test_dept_principal.py # Department & principal management (5 tests)
│ ├── test_ca_management.py # Certificate authority management (4 tests)
│ ├── test_policy_compliance.py # Security policy & compliance (4 tests)
│ ├── test_webauthn_workflows.py# WebAuthn passkey flows (5 tests)
│ └── test_zerotier.py # ZeroTier network access (8 tests)
└── unit/ # Unit tests (existing)
```
## Environment
- **Python**: 3.10+
- **Database**: SQLite in-memory (`sqlite:///:memory:`)
- **Rate Limiting**: Disabled in tests (`RATELIMIT_ENABLED = False`)
- **CSRF**: Disabled (`WTF_CSRF_ENABLED = False`)
- **Email**: Suppressed (`MAIL_SUPPRESS_SEND = True`)
## Configuration
The `pytest.ini` file configures:
- Verbose output (`-v`)
- Coverage reporting (`--cov=gatehouse_app`)
- Disabled maas plugins that cause import errors (see Known Issues below)
- Custom markers for `unit`, `integration`, `slow`, etc.
## Coverage
Coverage reports are generated automatically:
- **Terminal**: printed after each run
- **HTML**: `backend/htmlcov/index.html`
Target coverage: **85% minimum**.
```bash
pytest tests/integration/ --cov=gatehouse_app --cov-fail-under=85
```
## Known Issues
### maastesting Plugin Import Error
The `maas` system package installs pytest entry points that fail to load in this environment. The `pytest.ini` file disables them automatically with:
```ini
-p no:maas-django
-p no:maas-perftest
-p no:maas-seeds
```
If you see `ModuleNotFoundError: No module named 'maastesting'`, these flags are not being applied. Ensure you run pytest from the `backend/` directory.
### ssh-keygen Not Available
One test (`test_verify_key_positive` in `test_ssh_workflows.py`) requires `ssh-keygen` to generate real Ed25519 key pairs for signature verification. It is automatically skipped when `ssh-keygen` is not available:
```bash
sudo apt-get install openssh-client # Debian/Ubuntu
```
Other certificate signing tests use a DB helper (`_mark_key_verified`) to bypass the signature requirement in CI environments.
## Writing New Tests
### Pattern
Every test must include a verbose docstring with `WHAT`, `WHY`, and `EXPECTED`:
```python
def test_add_key_positive(self, integration_client, create_test_user):
"""TEST: SSH-KEY-01 — Add a new SSH public key.
WHAT: Authenticated user POSTs a valid public key with a description.
WHY: Users must be able to register their SSH keys for later
certificate signing and server access.
EXPECTED: 201 Created, response contains key id and metadata.
"""
```
### Fixtures
| Fixture | Purpose |
|---------|---------|
| `integration_client` | Fresh `SecuirdClient` instance per test |
| `create_test_user` | Factory returning `{"id", "email", "password", "full_name"}` |
| `create_test_org` | Factory returning `{"id", "name", "slug"}` |
| `create_test_membership` | Links user to org with a role |
| `create_test_ca` | Creates a Certificate Authority for an org |
### Client Usage
```python
# Authentication
integration_client.auth.register(email, password, full_name)
integration_client.auth.login(email, password)
integration_client.auth.logout()
# SSH
integration_client.ssh.add_key(public_key, description)
integration_client.ssh.sign_certificate(key_id=key_id, principals=["deploy"])
integration_client.ssh.revoke_certificate(cert_id)
# Organizations
integration_client.orgs.create(name, slug)
integration_client.orgs.create_principal(org_id, name)
integration_client.orgs.create_ca(org_id, name, ca_type="user")
```
### Assertions
Use the standard helpers:
```python
def assert_success(response: dict, message_contains: str = "") -> dict:
data = response.get("data", {})
assert response.get("success") is not False
if message_contains:
assert message_contains.lower() in response.get("message", "").lower()
return data
# Negative tests
with pytest.raises(ApiError) as exc_info:
integration_client.ssh.get_key(str(uuid.uuid4()))
assert exc_info.value.status_code == 404
assert exc_info.value.error_type == "NOT_FOUND"
```
## Test Counts
| Module | Tests | Focus |
|--------|-------|-------|
| test_auth_flows.py | 24 | Registration, login, logout, sessions, password reset, email verification |
| test_totp_workflows.py | 15 | TOTP enrollment, verification, backup codes, disable, regenerate |
| test_ssh_workflows.py | 34 | Key CRUD, verification, certificate signing & management |
| test_org_workflows.py | 27 | Org CRUD, members, roles, invites, ownership transfer |
| test_multi_org.py | 4 | Cross-org isolation, role-based access |
| test_self_service.py | 9 | Profile, password change, account deletion |
| test_admin_ops.py | 9 | Suspend, unsuspend, verify email, set password, remove MFA, hard delete |
| test_authorization.py | 8 | RBAC, cross-user isolation, soft-delete behavior |
| test_security.py | 5 | SQL injection, XSS, oversized payload, malformed JSON, empty body |
| test_dept_principal.py | 5 | Department/principal CRUD, membership, linking |
| test_ca_management.py | 4 | CA creation, listing, rotation |
| test_policy_compliance.py | 4 | Security policy, MFA compliance |
| test_webauthn_workflows.py | 5 | WebAuthn registration/login (mocked) |
| test_zerotier.py | 8 | Network CRUD, devices, approvals, memberships (mocked) |
| **Total** | **162** | |
## Pre-Commit Checklist
Before committing backend changes:
1. Run the integration suite: `pytest tests/integration/ -x`
2. Verify coverage hasn't decreased: `pytest tests/integration/ --cov=gatehouse_app --cov-fail-under=85`
3. If tests fail, fix before committing
## CI/CD
Integration tests run automatically on:
- Every pull request
- Every push to main
- Nightly builds
**Failure policy**: Integration test failures block merging.
+1
View File
@@ -0,0 +1 @@
# Tests package
+1
View File
@@ -0,0 +1 @@
# API tests package
+1
View File
@@ -0,0 +1 @@
# API v1 tests package
View File
@@ -0,0 +1,52 @@
import pytest
from gatehouse_app.utils.constants import AuthMethodType
from gatehouse_app.services.external_auth.models import ExternalAuthError
from gatehouse_app.api.v1.external_auth._helpers import (
get_provider_type,
_get_provider_endpoints,
)
class TestProviderType:
def test_google(self):
assert get_provider_type("google") == AuthMethodType.GOOGLE
def test_github(self):
assert get_provider_type("github") == AuthMethodType.GITHUB
def test_microsoft(self):
assert get_provider_type("microsoft") == AuthMethodType.MICROSOFT
def test_case_insensitive(self):
assert get_provider_type("GitHub") == AuthMethodType.GITHUB
def test_unknown_provider_raises(self):
with pytest.raises(ExternalAuthError) as exc_info:
get_provider_type("facebook")
assert exc_info.value.status_code == 400
assert "facebook" in exc_info.value.message.lower()
class TestProviderEndpoints:
def test_google_endpoints(self):
auth, token, userinfo = _get_provider_endpoints(AuthMethodType.GOOGLE)
assert "accounts.google.com" in auth
assert "oauth2.googleapis.com" in token
assert "googleapis.com" in userinfo
def test_github_endpoints(self):
auth, token, userinfo = _get_provider_endpoints(AuthMethodType.GITHUB)
assert "github.com/login" in auth
assert "github.com/login/oauth/access_token" in token
assert "api.github.com/user" in userinfo
def test_microsoft_endpoints(self):
auth, token, userinfo = _get_provider_endpoints(AuthMethodType.MICROSOFT)
assert "login.microsoftonline.com" in auth
assert "login.microsoftonline.com" in token
assert "graph.microsoft.com" in userinfo
def test_unknown_type_raises(self):
with pytest.raises(ExternalAuthError) as exc_info:
_get_provider_endpoints("nonexistent")
assert exc_info.value.status_code == 400
@@ -0,0 +1,59 @@
import pytest
from gatehouse_app.api.v1.organizations._helpers import _get_system_ca_dict
from gatehouse_app.config.ssh_ca_config import SSHCAConfig, reset_config_instance
# Ed25519 key fixture data
VALID_PRIVATE_KEY = (
"-----BEGIN OPENSSH PRIVATE KEY-----\n"
"b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\n"
"QyNTUxOQAAACCi+2CgIPgoFL5P6DZlNXztuHy3+TuS2shh/xIDkW89OgAAAJhDQd+ZQ0Hf\n"
"mQAAAAtzc2gtZWQyNTUxOQAAACCi+2CgIPgoFL5P6DZlNXztuHy3+TuS2shh/xIDkW89Og\n"
"AAAECMbnF+1E22w9Z1AOTUbUGspL8Pb0UyP+p8lSLpAwZSpaL7YKAg+CgUvk/oNmU1fO24\n"
"fLf5O5LayGH/EgORbz06AAAAD2NvcnlAbGFwdG9wLXZtMQECAwQFBg==\n"
"-----END OPENSSH PRIVATE KEY-----"
)
class FakeEmptyConfig(SSHCAConfig):
def get_str(self, key, default=""):
if key == "ca_key_path":
return ""
return default
class BadConfig(SSHCAConfig):
def get_str(self, key, default=""):
raise RuntimeError("config error")
class TestSystemCADict:
def test_no_key_available_returns_none(self, monkeypatch):
monkeypatch.delenv("SSH_CA_PRIVATE_KEY", raising=False)
reset_config_instance()
monkeypatch.setattr(
"gatehouse_app.config.ssh_ca_config.get_ssh_ca_config",
lambda: FakeEmptyConfig(),
)
result = _get_system_ca_dict()
assert result is None
def test_env_var_returns_dict(self, monkeypatch):
monkeypatch.setenv("SSH_CA_PRIVATE_KEY", VALID_PRIVATE_KEY)
result = _get_system_ca_dict()
assert result is not None
assert result["ca_type"] == "user"
assert result["is_system"] is True
assert "fingerprint" in result
assert result["public_key"]
assert result["public_key"].startswith("ssh-")
def test_exception_gracefully_returns_none(self, monkeypatch):
monkeypatch.delenv("SSH_CA_PRIVATE_KEY", raising=False)
reset_config_instance()
monkeypatch.setattr(
"gatehouse_app.config.ssh_ca_config.get_ssh_ca_config",
lambda: BadConfig(),
)
result = _get_system_ca_dict()
assert result is None
+1
View File
@@ -0,0 +1 @@
# SSH tests package
+79
View File
@@ -0,0 +1,79 @@
"""Pytest fixtures for API tests."""
import pytest
import uuid
from datetime import datetime, timezone
from gatehouse_app import create_app, db
from gatehouse_app.models.user.user import User
from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.models.ssh_ca.ca import CA, CaType, KeyType
from gatehouse_app.utils.constants import OrganizationRole
@pytest.fixture
def app():
"""Create test Flask app with in-memory SQLite."""
app = create_app(config_name="testing")
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
app.config["TESTING"] = True
app.config["WTF_CSRF_ENABLED"] = False
with app.app_context():
db.create_all()
yield app
db.session.remove()
db.drop_all()
@pytest.fixture
def test_user(app):
"""Create a test user."""
with app.app_context():
user = User(email="test_user@test.com", full_name="Test User")
db.session.add(user)
db.session.commit()
return user.id
@pytest.fixture
def test_org(app):
"""Create a test organization."""
with app.app_context():
org = Organization(name="Test Org", slug="test-org")
db.session.add(org)
db.session.commit()
return org.id
@pytest.fixture
def test_membership(app, test_user, test_org):
"""Create a test membership."""
with app.app_context():
membership = OrganizationMember(
user_id=test_user,
organization_id=test_org,
role=OrganizationRole.MEMBER,
)
db.session.add(membership)
db.session.commit()
return membership.id
@pytest.fixture
def test_ca(app, test_org, test_membership):
"""Create a test CA."""
with app.app_context():
ca = CA(
organization_id=test_org,
name="Test CA",
ca_type=CaType.USER,
key_type=KeyType.ED25519,
private_key="encrypted_private_key_placeholder",
public_key="ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAI...",
fingerprint="sha256:TEST123...",
is_active=True,
)
db.session.add(ca)
db.session.commit()
return ca.id
+143
View File
@@ -0,0 +1,143 @@
import pytest
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.user.user import User
from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.models.ssh_ca.ca import CA, CaType, KeyType
from gatehouse_app.api.v1.ssh._helpers import _get_org_ca_for_user
from gatehouse_app.utils.constants import OrganizationRole
class TestCASoftDelete:
"""Test CA soft delete handling."""
def test_active_ca_is_returned(self, app, test_user, test_org, test_ca, test_membership):
"""Active CA should be returned."""
with app.app_context():
user = db.session.get(User, test_user)
ca = _get_org_ca_for_user(user, ca_type='user')
assert ca is not None
assert ca.id == test_ca
def test_deleted_ca_is_not_returned(self, app, test_user, test_org, test_membership):
"""Deleted CA should not be returned."""
with app.app_context():
ca = CA(
organization_id=test_org,
name='Deleted CA',
ca_type=CaType.USER,
key_type=KeyType.ED25519,
private_key='key',
public_key='pubkey',
fingerprint='sha256:deleted123',
is_active=True,
deleted_at=datetime.now(timezone.utc)
)
db.session.add(ca)
db.session.commit()
user = db.session.get(User, test_user)
result = _get_org_ca_for_user(user, ca_type='user')
assert result is None
def test_deleted_membership_no_access(self, app, test_org, test_ca):
"""User with deleted membership should not access CA."""
with app.app_context():
user = User(email='deleted_member@test.com', full_name='Deleted Member')
db.session.add(user)
db.session.commit()
membership = OrganizationMember(
user_id=user.id,
organization_id=test_org,
role=OrganizationRole.MEMBER,
deleted_at=datetime.now(timezone.utc)
)
db.session.add(membership)
db.session.commit()
result = _get_org_ca_for_user(user, ca_type='user')
assert result is None
def test_deleted_org_no_access(self, app):
"""User in deleted org should not access CA."""
with app.app_context():
org = Organization(
name='Deleted Org',
slug='deleted-org',
deleted_at=datetime.now(timezone.utc)
)
db.session.add(org)
db.session.commit()
user = User(email='user@deleted.org', full_name='User')
db.session.add(user)
db.session.commit()
membership = OrganizationMember(
user_id=user.id,
organization_id=org.id,
role=OrganizationRole.MEMBER
)
db.session.add(membership)
ca = CA(
organization_id=org.id,
name='CA in Deleted Org',
ca_type=CaType.USER,
key_type=KeyType.ED25519,
private_key='key',
public_key='pubkey',
fingerprint='sha256:deletedorg123',
is_active=True
)
db.session.add(ca)
db.session.commit()
result = _get_org_ca_for_user(user, ca_type='user')
assert result is None
def test_get_active_memberships_excludes_deleted(self, app, test_user, test_org, test_membership):
"""User.get_active_memberships() should exclude deleted memberships."""
with app.app_context():
user = db.session.get(User, test_user)
org2 = Organization(name='Org 2', slug='org-2')
db.session.add(org2)
db.session.commit()
membership2 = OrganizationMember(
user_id=test_user,
organization_id=org2.id,
role=OrganizationRole.MEMBER,
deleted_at=datetime.now(timezone.utc)
)
db.session.add(membership2)
db.session.commit()
active = user.get_active_memberships()
assert len(active) == 1
assert active[0].organization_id == test_org
def test_get_organizations_excludes_deleted(self, app, test_user, test_org, test_membership):
"""User.get_organizations() should exclude deleted memberships/orgs."""
with app.app_context():
user = db.session.get(User, test_user)
org2 = Organization(name='Deleted Org', slug='deleted-org-2')
db.session.add(org2)
db.session.commit()
membership2 = OrganizationMember(
user_id=test_user,
organization_id=org2.id,
role=OrganizationRole.MEMBER,
deleted_at=datetime.now(timezone.utc)
)
db.session.add(membership2)
db.session.commit()
orgs = user.get_organizations()
assert len(orgs) == 1
assert orgs[0].id == test_org
@@ -0,0 +1,92 @@
import pytest
from gatehouse_app.api.v1.ssh._helpers import _classify_ssh_key_material
class TestClassifySSHKeyMaterial:
def test_classifies_certificate(self):
result = _classify_ssh_key_material("ssh-ed25519-cert-v01@openssh.com AAAA comment")
assert result == "certificate"
def test_classifies_ed25519_public_key(self):
result = _classify_ssh_key_material("ssh-ed25519 AAAAB3NzaC1lZDI1NTE5AAAAI... comment")
assert result == "public_key"
def test_classifies_rsa_public_key(self):
result = _classify_ssh_key_material("ssh-rsa AAAAB3NzaC1yc2E... comment")
assert result == "public_key"
def test_classifies_dss_public_key(self):
result = _classify_ssh_key_material("ssh-dss AAAAB3NzaC1kc3M... comment")
assert result == "public_key"
def test_classifies_ecdsa_nistp256_public_key(self):
result = _classify_ssh_key_material("ecdsa-sha2-nistp256 AAAAE2Vj... comment")
assert result == "public_key"
def test_classifies_ecdsa_nistp384_public_key(self):
result = _classify_ssh_key_material("ecdsa-sha2-nistp384 AAAAE2Vj... comment")
assert result == "public_key"
def test_classifies_ecdsa_nistp521_public_key(self):
result = _classify_ssh_key_material("ecdsa-sha2-nistp521 AAAAE2Vj... comment")
assert result == "public_key"
def test_classifies_sk_ed25519_public_key(self):
result = _classify_ssh_key_material(
"sk-ssh-ed25519@openssh.com AAAAGnNrLXNzaC1lZDI1NTE5... comment"
)
assert result == "public_key"
def test_classifies_openssh_private_key(self):
result = _classify_ssh_key_material(
"-----BEGIN OPENSSH PRIVATE KEY-----\n"
"base64data==\n"
"-----END OPENSSH PRIVATE KEY-----"
)
assert result == "private_key"
def test_classifies_rsa_private_key(self):
result = _classify_ssh_key_material(
"-----BEGIN RSA PRIVATE KEY-----\n"
"base64data==\n"
"-----END RSA PRIVATE KEY-----"
)
assert result == "private_key"
def test_unknown_for_empty_string(self):
result = _classify_ssh_key_material("")
assert result == "unknown"
def test_unknown_for_whitespace_string(self):
result = _classify_ssh_key_material(" \n ")
assert result == "unknown"
def test_unknown_for_gibberish(self):
result = _classify_ssh_key_material("not a valid ssh key")
assert result == "unknown"
def test_unknown_for_unsupported_key_type(self):
result = _classify_ssh_key_material("ssh-nonsense AAAABogus...")
assert result == "unknown"
@pytest.mark.parametrize("raw,expected", [
("ssh-rsa AAAAB3Nza... user@host", "public_key"),
("ssh-ed25519 AAAAC3... john@laptop", "public_key"),
("ecdsa-sha2-nistp256 AAAAE2Vj... me@box", "public_key"),
("sk-ssh-ed25519@openssh.com AAAAGn...", "public_key"),
("ssh-ed25519-cert-v01@openssh.com AAAAB3Nza cert for user", "certificate"),
(
"-----BEGIN OPENSSH PRIVATE KEY-----\n"
"abcdefghijklmnopqrstuvwxyz\n"
"-----END OPENSSH PRIVATE KEY-----",
"private_key",
),
("", "unknown"),
("totally random garbage here", "unknown"),
])
def test_parametrized_variants(self, raw, expected):
assert _classify_ssh_key_material(raw) == expected
def test_certificate_with_leading_whitespace(self):
raw = " ssh-ed25519-cert-v01@openssh.com AAAAB3Nza extra words"
assert _classify_ssh_key_material(raw) == "certificate"
+275
View File
@@ -0,0 +1,275 @@
import pytest
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.organization.department import (
Department,
DepartmentMembership,
)
from gatehouse_app.models.organization.department_cert_policy import DepartmentCertPolicy
from gatehouse_app.api.v1.ssh._helpers import _get_merged_dept_cert_policy
class TestDeptCertPolicy:
def test_no_departments_returns_none(self, app, test_user):
with app.app_context():
result = _get_merged_dept_cert_policy(test_user)
assert result is None
def test_department_without_policy_returns_none(self, app, test_user, test_org):
with app.app_context():
dept = Department(
organization_id=test_org,
name="No Policy Dept",
)
db.session.add(dept)
db.session.commit()
membership = DepartmentMembership(
user_id=test_user,
department_id=dept.id,
)
db.session.add(membership)
db.session.commit()
result = _get_merged_dept_cert_policy(test_user)
assert result is None
def test_single_department_policy(self, app, test_user, test_org):
with app.app_context():
dept = Department(
organization_id=test_org,
name="Engineering",
)
db.session.add(dept)
db.session.commit()
membership = DepartmentMembership(
user_id=test_user,
department_id=dept.id,
)
db.session.add(membership)
db.session.commit()
policy = DepartmentCertPolicy(
department_id=dept.id,
allow_user_expiry=True,
default_expiry_hours=4,
max_expiry_hours=48,
allowed_extensions=["permit-pty", "permit-agent-forwarding"],
)
db.session.add(policy)
db.session.commit()
result = _get_merged_dept_cert_policy(test_user)
assert result is not None
assert result["allow_user_expiry"] is True
assert result["default_expiry_hours"] == 4
assert result["max_expiry_hours"] == 48
assert set(result["extensions"]) == {"permit-pty", "permit-agent-forwarding"}
def test_both_departments_same_policies(self, app, test_user, test_org):
with app.app_context():
dept1 = Department(
organization_id=test_org,
name="Engineering",
)
dept2 = Department(
organization_id=test_org,
name="SRE",
)
db.session.add_all([dept1, dept2])
db.session.commit()
member1 = DepartmentMembership(user_id=test_user, department_id=dept1.id)
member2 = DepartmentMembership(user_id=test_user, department_id=dept2.id)
db.session.add_all([member1, member2])
db.session.commit()
policy1 = DepartmentCertPolicy(
department_id=dept1.id,
allow_user_expiry=True,
default_expiry_hours=4,
max_expiry_hours=48,
allowed_extensions=["permit-pty", "permit-agent-forwarding"],
)
policy2 = DepartmentCertPolicy(
department_id=dept2.id,
allow_user_expiry=True,
default_expiry_hours=4,
max_expiry_hours=48,
allowed_extensions=["permit-pty", "permit-agent-forwarding"],
)
db.session.add_all([policy1, policy2])
db.session.commit()
result = _get_merged_dept_cert_policy(test_user)
assert result["allow_user_expiry"] is True
assert result["default_expiry_hours"] == 4
assert result["max_expiry_hours"] == 48
def test_merges_min_expiry_across_departments(self, app, test_user, test_org):
with app.app_context():
dept1 = Department(
organization_id=test_org,
name="Engineering",
)
dept2 = Department(
organization_id=test_org,
name="SRE",
)
db.session.add_all([dept1, dept2])
db.session.commit()
member1 = DepartmentMembership(user_id=test_user, department_id=dept1.id)
member2 = DepartmentMembership(user_id=test_user, department_id=dept2.id)
db.session.add_all([member1, member2])
db.session.commit()
policy1 = DepartmentCertPolicy(
department_id=dept1.id,
allow_user_expiry=True,
default_expiry_hours=24,
max_expiry_hours=720,
)
policy2 = DepartmentCertPolicy(
department_id=dept2.id,
allow_user_expiry=True,
default_expiry_hours=1,
max_expiry_hours=72,
)
db.session.add_all([policy1, policy2])
db.session.commit()
result = _get_merged_dept_cert_policy(test_user)
assert result["default_expiry_hours"] == 1
assert result["max_expiry_hours"] == 72
def test_extends_intersection_across_departments(self, app, test_user, test_org):
with app.app_context():
dept1 = Department(
organization_id=test_org,
name="Engineering",
)
dept2 = Department(
organization_id=test_org,
name="SRE",
)
db.session.add_all([dept1, dept2])
db.session.commit()
member1 = DepartmentMembership(user_id=test_user, department_id=dept1.id)
member2 = DepartmentMembership(user_id=test_user, department_id=dept2.id)
db.session.add_all([member1, member2])
db.session.commit()
policy1 = DepartmentCertPolicy(
department_id=dept1.id,
allowed_extensions=["permit-pty", "permit-agent-forwarding"],
)
policy2 = DepartmentCertPolicy(
department_id=dept2.id,
allowed_extensions=["permit-pty", "permit-port-forwarding"],
)
db.session.add_all([policy1, policy2])
db.session.commit()
result = _get_merged_dept_cert_policy(test_user)
assert set(result["extensions"]) == {"permit-pty"}
def test_any_false_user_expiry_means_overall_false(
self, app, test_user, test_org
):
with app.app_context():
dept1 = Department(
organization_id=test_org,
name="Engineering",
)
dept2 = Department(
organization_id=test_org,
name="SRE",
)
db.session.add_all([dept1, dept2])
db.session.commit()
member1 = DepartmentMembership(user_id=test_user, department_id=dept1.id)
member2 = DepartmentMembership(user_id=test_user, department_id=dept2.id)
db.session.add_all([member1, member2])
db.session.commit()
policy1 = DepartmentCertPolicy(
department_id=dept1.id,
allow_user_expiry=True,
)
policy2 = DepartmentCertPolicy(
department_id=dept2.id,
allow_user_expiry=False,
)
db.session.add_all([policy1, policy2])
db.session.commit()
result = _get_merged_dept_cert_policy(test_user)
assert result["allow_user_expiry"] is False
def test_deleted_department_filtered(self, app, test_user, test_org):
with app.app_context():
active_dept = Department(
organization_id=test_org,
name="Active Dept",
)
deleted_dept = Department(
organization_id=test_org,
name="Deleted Dept",
deleted_at=datetime.now(timezone.utc),
)
db.session.add_all([active_dept, deleted_dept])
db.session.commit()
active_member = DepartmentMembership(
user_id=test_user, department_id=active_dept.id
)
deleted_member = DepartmentMembership(
user_id=test_user,
department_id=deleted_dept.id,
deleted_at=datetime.now(timezone.utc),
)
db.session.add_all([active_member, deleted_member])
db.session.commit()
policy = DepartmentCertPolicy(
department_id=active_dept.id,
allow_user_expiry=True,
default_expiry_hours=12,
max_expiry_hours=96,
)
db.session.add(policy)
db.session.commit()
result = _get_merged_dept_cert_policy(test_user)
assert result is not None
assert result["default_expiry_hours"] == 12
def test_single_department_no_extensions(self, app, test_user, test_org):
with app.app_context():
dept = Department(
organization_id=test_org,
name="Minimal Dept",
)
db.session.add(dept)
db.session.commit()
membership = DepartmentMembership(
user_id=test_user, department_id=dept.id
)
db.session.add(membership)
db.session.commit()
policy = DepartmentCertPolicy(
department_id=dept.id,
allowed_extensions=[],
)
db.session.add(policy)
db.session.commit()
result = _get_merged_dept_cert_policy(test_user)
assert result is not None
assert result["extensions"] == []
+118
View File
@@ -0,0 +1,118 @@
import pytest
from uuid import uuid4
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.user.user import User
from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.models.ssh_ca.ca import CA, CaType, KeyType
from gatehouse_app.api.v1.ssh._helpers import _get_org_ca_for_user
from gatehouse_app.utils.constants import OrganizationRole
class TestOrgCAForUser:
def test_organization_id_param_overrides_membership(self, app, test_user, test_org, test_ca, test_membership):
with app.app_context():
org2 = Organization(name="Org 2", slug="org-2")
db.session.add(org2)
db.session.commit()
ca2 = CA(
organization_id=org2.id,
name="Org 2 CA",
ca_type=CaType.USER,
key_type=KeyType.ED25519,
private_key="key2",
public_key="pubkey2",
fingerprint="sha256:org2...",
is_active=True,
)
db.session.add(ca2)
db.session.commit()
user = db.session.get(User, test_user)
result = _get_org_ca_for_user(user, ca_type="user", organization_id=test_org)
assert result is not None
assert result.organization_id == test_org
def test_multiple_orgs_returns_ca(self, app, test_user, test_org, test_ca, test_membership):
with app.app_context():
org2 = Organization(name="Org 2", slug="org-2")
db.session.add(org2)
db.session.commit()
user = db.session.get(User, test_user)
member2 = OrganizationMember(
user_id=test_user, organization_id=org2.id, role=OrganizationRole.MEMBER
)
db.session.add(member2)
db.session.commit()
result = _get_org_ca_for_user(user, ca_type="user")
assert result is not None
def test_user_with_no_memberships_returns_none(self, app):
with app.app_context():
user = User(email="lonely@test.com", full_name="Lonely User")
db.session.add(user)
db.session.commit()
result = _get_org_ca_for_user(user, ca_type="user")
assert result is None
def test_inactive_ca_not_returned(self, app, test_user, test_org, test_membership):
with app.app_context():
ca = CA(
organization_id=test_org,
name="Inactive CA",
ca_type=CaType.USER,
key_type=KeyType.ED25519,
private_key="key",
public_key="pubkey",
fingerprint="sha256:inactive123...",
is_active=False,
)
db.session.add(ca)
db.session.commit()
user = db.session.get(User, test_user)
result = _get_org_ca_for_user(user, ca_type="user")
assert result is None
def test_host_ca_not_returned_when_user_requested(self, app, test_user, test_org, test_membership):
with app.app_context():
ca = CA(
organization_id=test_org,
name="Host CA",
ca_type=CaType.HOST,
key_type=KeyType.ED25519,
private_key="key",
public_key="pubkey",
fingerprint="sha256:host123...",
is_active=True,
)
db.session.add(ca)
db.session.commit()
user = db.session.get(User, test_user)
result = _get_org_ca_for_user(user, ca_type="user")
assert result is None
def test_user_ca_not_returned_when_host_requested(self, app, test_user, test_org, test_membership):
with app.app_context():
ca = CA(
organization_id=test_org,
name="User CA",
ca_type=CaType.USER,
key_type=KeyType.ED25519,
private_key="key",
public_key="pubkey",
fingerprint="sha256:useronly...",
is_active=True,
)
db.session.add(ca)
db.session.commit()
user = db.session.get(User, test_user)
result = _get_org_ca_for_user(user, ca_type="host")
assert result is None
@@ -0,0 +1,180 @@
import pytest
from uuid import uuid4
from datetime import datetime, timezone, timedelta
from gatehouse_app.extensions import db
from gatehouse_app.models.user.user import User
from gatehouse_app.models.ssh_ca.ca import CA, CaType, KeyType, CertType
from gatehouse_app.models.ssh_ca.ssh_key import SSHKey
from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus
from gatehouse_app.services.ssh_ca_signing_service import SSHCertificateSigningResponse
from gatehouse_app.api.v1.ssh._helpers import _persist_certificate
class TestPersistCertificate:
def test_persists_valid_certificate(self, app, test_user, test_org):
with app.app_context():
ca = CA(
organization_id=test_org,
name="Signing CA",
ca_type=CaType.USER,
key_type=KeyType.ED25519,
private_key="enc_priv",
public_key="ssh-ed25519 AAAAB3Nza...",
fingerprint="sha256:abc123...",
is_active=True,
)
db.session.add(ca)
db.session.commit()
ssh_key = SSHKey(
user_id=test_user,
payload="ssh-ed25519 AAAAB3NzaC1lZDI1NTE5AAAAIKeyData comment",
fingerprint="sha256:keyfp123...",
)
db.session.add(ssh_key)
db.session.commit()
now = datetime.now(timezone.utc)
later = now + timedelta(hours=24)
response = SSHCertificateSigningResponse(
certificate="ssh-ed25519-cert-v01@openssh.com AAAACertData...",
serial="123456",
valid_after=now,
valid_before=later,
principals=["eng-prod"],
)
result = _persist_certificate(
user_id=test_user,
ssh_key_id=ssh_key.id,
ca=ca,
signing_response=response,
request_ip="10.0.0.1",
cert_type_str="user",
cert_identity="user@example.com",
)
assert result is not None
assert result.ca_id == ca.id
assert result.user_id == test_user
assert result.ssh_key_id == ssh_key.id
assert result.cert_type == CertType.USER
assert result.certificate == response.certificate
assert result.serial == response.serial
assert result.valid_after.replace(tzinfo=None) == now.replace(tzinfo=None)
assert result.valid_before.replace(tzinfo=None) == later.replace(tzinfo=None)
assert result.request_ip == "10.0.0.1"
assert result.key_id == "user@example.com"
assert sorted(result.principals) == ["eng-prod"]
assert result.revoked is False
assert result.status == CertificateStatus.ISSUED
def test_none_ca_returns_none(self, app, test_user):
with app.app_context():
now = datetime.now(timezone.utc)
response = SSHCertificateSigningResponse(
certificate="cert-data",
serial="1",
valid_after=now,
valid_before=now + timedelta(hours=1),
)
result = _persist_certificate(test_user, "keyid", None, response)
assert result is None
def test_invalid_cert_type_str_falls_back_to_user(self, app, test_user, test_org):
with app.app_context():
ca = CA(
organization_id=test_org,
name="Signing CA",
ca_type=CaType.USER,
key_type=KeyType.ED25519,
private_key="enc_priv",
public_key="ssh-ed25519 AAAAB3Nza...",
fingerprint="sha256:fallback123...",
is_active=True,
)
db.session.add(ca)
db.session.commit()
now = datetime.now(timezone.utc)
response = SSHCertificateSigningResponse(
certificate="cert-data",
serial="1",
valid_after=now,
valid_before=now + timedelta(hours=1),
)
result = _persist_certificate(
user_id=test_user,
ssh_key_id=None,
ca=ca,
signing_response=response,
cert_type_str="invalid_type",
)
assert result is not None
assert result.cert_type == CertType.USER
def test_none_ssh_key_id_defaults_to_host_cert_key_id(self, app, test_user, test_org):
with app.app_context():
ca = CA(
organization_id=test_org,
name="Host CA",
ca_type=CaType.HOST,
key_type=KeyType.ED25519,
private_key="enc_priv",
public_key="ssh-ed25519 AAAAB3Nza...",
fingerprint="sha256:hostca123...",
is_active=True,
)
db.session.add(ca)
db.session.commit()
now = datetime.now(timezone.utc)
response = SSHCertificateSigningResponse(
certificate="cert-data",
serial="1",
valid_after=now,
valid_before=now + timedelta(hours=1),
)
result = _persist_certificate(
user_id=test_user,
ssh_key_id=None,
ca=ca,
signing_response=response,
)
assert result is not None
assert result.key_id == "host-cert"
def test_request_ip_stored(self, app, test_user, test_org):
with app.app_context():
ca = CA(
organization_id=test_org,
name="Signing CA",
ca_type=CaType.USER,
key_type=KeyType.ED25519,
private_key="enc_priv",
public_key="ssh-ed25519 AAAAB3Nza...",
fingerprint="sha256:ip...",
is_active=True,
)
db.session.add(ca)
db.session.commit()
now = datetime.now(timezone.utc)
response = SSHCertificateSigningResponse(
certificate="cert-data",
serial="1",
valid_after=now,
valid_before=now + timedelta(hours=1),
)
result = _persist_certificate(
user_id=test_user,
ssh_key_id=None,
ca=ca,
signing_response=response,
request_ip="192.168.1.100",
)
assert result is not None
assert result.request_ip == "192.168.1.100"
+78
View File
@@ -0,0 +1,78 @@
import pytest
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.user.user import User
from gatehouse_app.models.ssh_ca.ssh_key import SSHKey
from gatehouse_app.services.ssh_key_service import SSHKeyService
from gatehouse_app.exceptions import UserNotFoundError, SSHKeyError
VALID_PUBLIC_KEY = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKL7YKAg+CgUvk/oNmU1fO24fLf5O5LayGH/EgORbz06"
class TestSSHKeyServiceAdd:
def test_add_new_key_returns_true(self, app, test_user):
with app.app_context():
service = SSHKeyService()
key, is_new = service.add_ssh_key(test_user, VALID_PUBLIC_KEY, "My laptop")
assert is_new is True
assert key.user_id == test_user
assert key.payload == VALID_PUBLIC_KEY
assert key.description == "My laptop"
assert key.verified is False
assert key.fingerprint is not None
assert key.key_type is not None
def test_add_duplicate_returns_existing(self, app, test_user):
with app.app_context():
service = SSHKeyService()
key1, _ = service.add_ssh_key(test_user, VALID_PUBLIC_KEY)
key2, is_new = service.add_ssh_key(test_user, VALID_PUBLIC_KEY)
assert is_new is False
assert key2.id == key1.id
def test_add_restores_soft_deleted_key(self, app, test_user):
with app.app_context():
service = SSHKeyService()
key1, _ = service.add_ssh_key(test_user, VALID_PUBLIC_KEY, "Original")
# Soft-delete the key
key1.deleted_at = datetime.now(timezone.utc)
db.session.commit()
# Re-add same key
key2, is_new = service.add_ssh_key(test_user, VALID_PUBLIC_KEY, "Restored")
assert is_new is False
assert key2.id == key1.id
assert key2.deleted_at is None
assert key2.description == "Restored"
assert key2.verified is False
assert key2.verified_at is None
def test_add_with_description(self, app, test_user):
with app.app_context():
service = SSHKeyService()
key, is_new = service.add_ssh_key(test_user, VALID_PUBLIC_KEY, "Work laptop")
assert is_new is True
assert key.description == "Work laptop"
def test_user_not_found_raises(self, app):
with app.app_context():
service = SSHKeyService()
with pytest.raises(UserNotFoundError):
service.add_ssh_key("nonexistent-user-id", VALID_PUBLIC_KEY)
def test_invalid_key_format_raises(self, app, test_user):
with app.app_context():
service = SSHKeyService()
with pytest.raises(SSHKeyError):
service.add_ssh_key(test_user, "not-a-valid-key")
def test_idempotent_second_call_no_error(self, app, test_user):
with app.app_context():
service = SSHKeyService()
service.add_ssh_key(test_user, VALID_PUBLIC_KEY)
key2, is_new = service.add_ssh_key(test_user, VALID_PUBLIC_KEY)
assert is_new is False
assert key2 is not None
+148
View File
@@ -0,0 +1,148 @@
import pytest
from gatehouse_app.services.ssh_ca_signing_service import SSHCertificateSigningRequest
VALID_PUBLIC_KEY = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKL7YKAg+CgUvk/oNmU1fO24fLf5O5LayGH/EgORbz06"
class TestCertSigningRequestValidate:
@pytest.fixture(autouse=True)
def patch_config(self, monkeypatch):
from gatehouse_app.config.ssh_ca_config import SSHCAConfig
class TestConfig(SSHCAConfig):
def get_int(self, key, default=0):
values = {
"max_cert_validity_hours": 720,
"max_principals_per_cert": 256,
"max_key_id_length": 255,
}
return values.get(key, default)
monkeypatch.setattr(
"gatehouse_app.config.ssh_ca_config.get_ssh_ca_config",
lambda: TestConfig(),
)
def test_valid_request_no_errors(self):
req = SSHCertificateSigningRequest(
ssh_public_key=VALID_PUBLIC_KEY,
principals=["eng-prod"],
key_id="user@example.com",
)
errors = req.validate()
assert errors == []
def test_valid_host_cert_no_errors(self):
req = SSHCertificateSigningRequest(
ssh_public_key=VALID_PUBLIC_KEY,
principals=["host1.example.com"],
key_id="host-identity",
cert_type="host",
)
errors = req.validate()
assert errors == []
def test_invalid_cert_type(self):
req = SSHCertificateSigningRequest(
ssh_public_key=VALID_PUBLIC_KEY,
principals=["eng-prod"],
key_id="user@example.com",
cert_type="invalid",
)
errors = req.validate()
assert any("cert_type" in e.lower() for e in errors)
def test_missing_public_key(self):
req = SSHCertificateSigningRequest(
ssh_public_key="",
principals=["eng-prod"],
key_id="user@example.com",
)
errors = req.validate()
assert any("public key" in e.lower() for e in errors)
def test_malformed_public_key(self):
req = SSHCertificateSigningRequest(
ssh_public_key="not-a-key",
principals=["eng-prod"],
key_id="user@example.com",
)
errors = req.validate()
assert any("public key" in e.lower() for e in errors)
def test_no_principals(self):
req = SSHCertificateSigningRequest(
ssh_public_key=VALID_PUBLIC_KEY,
principals=[],
key_id="user@example.com",
)
errors = req.validate()
assert any("principal" in e.lower() for e in errors)
def test_too_many_principals(self):
req = SSHCertificateSigningRequest(
ssh_public_key=VALID_PUBLIC_KEY,
principals=[f"p{i}" for i in range(300)],
key_id="user@example.com",
)
errors = req.validate()
assert any("too many" in e.lower() for e in errors)
def test_missing_key_id(self):
req = SSHCertificateSigningRequest(
ssh_public_key=VALID_PUBLIC_KEY,
principals=["eng-prod"],
key_id="",
)
errors = req.validate()
assert any("key_id" in e.lower() for e in errors)
def test_key_id_too_short(self):
req = SSHCertificateSigningRequest(
ssh_public_key=VALID_PUBLIC_KEY,
principals=["eng-prod"],
key_id="ab",
)
errors = req.validate()
assert any("key_id" in e.lower() for e in errors)
def test_key_id_exceeds_max_length(self):
req = SSHCertificateSigningRequest(
ssh_public_key=VALID_PUBLIC_KEY,
principals=["eng-prod"],
key_id="x" * 300,
)
errors = req.validate()
assert any("key_id" in e.lower() for e in errors)
def test_non_positive_expiry(self):
req = SSHCertificateSigningRequest(
ssh_public_key=VALID_PUBLIC_KEY,
principals=["eng-prod"],
key_id="user@example.com",
expiry_hours=0,
)
errors = req.validate()
assert any("expiry" in e.lower() for e in errors)
def test_expiry_exceeds_max(self):
req = SSHCertificateSigningRequest(
ssh_public_key=VALID_PUBLIC_KEY,
principals=["eng-prod"],
key_id="user@example.com",
expiry_hours=99999,
)
errors = req.validate()
assert any("expiry" in e.lower() for e in errors)
def test_none_expiry_is_ok(self):
req = SSHCertificateSigningRequest(
ssh_public_key=VALID_PUBLIC_KEY,
principals=["eng-prod"],
key_id="user@example.com",
expiry_hours=None,
)
errors = req.validate()
assert errors == []
+77
View File
@@ -0,0 +1,77 @@
from gatehouse_app.api.v1.superadmin.organizations import (
ListOrganizationsSchema,
UpdateOrganizationSchema,
)
class TestListOrganizationsSchema:
def test_defaults_when_empty(self):
result = ListOrganizationsSchema.load({})
assert result["page"] == 1
assert result["per_page"] == 20
assert result["search"] is None
assert result["status"] is None
assert result["plan_slug"] is None
def test_normal_pagination(self):
result = ListOrganizationsSchema.load({"page": 3, "per_page": 10})
assert result["page"] == 3
assert result["per_page"] == 10
def test_page_zero_clamped_to_one(self):
result = ListOrganizationsSchema.load({"page": 0})
assert result["page"] == 1
def test_negative_per_page_clamped_to_one(self):
result = ListOrganizationsSchema.load({"per_page": -5})
assert result["per_page"] == 1
def test_per_page_exceeds_max_clamped_to_100(self):
result = ListOrganizationsSchema.load({"per_page": 200})
assert result["per_page"] == 100
def test_non_integer_values_fallback(self):
result = ListOrganizationsSchema.load({"page": "abc", "per_page": "xyz"})
assert result["page"] == 1
assert result["per_page"] == 20
def test_search_passthrough(self):
result = ListOrganizationsSchema.load({"search": "acme"})
assert result["search"] == "acme"
def test_status_passthrough(self):
result = ListOrganizationsSchema.load({"status": "active"})
assert result["status"] == "active"
def test_plan_slug_passthrough(self):
result = ListOrganizationsSchema.load({"plan_slug": "pro"})
assert result["plan_slug"] == "pro"
class TestUpdateOrganizationSchema:
def test_all_fields(self):
result = UpdateOrganizationSchema.load({
"name": "New Name",
"description": "New Description",
"is_active": True,
})
assert result == {
"name": "New Name",
"description": "New Description",
"is_active": True,
}
def test_empty_dict(self):
result = UpdateOrganizationSchema.load({})
assert result == {}
def test_partial_data(self):
result = UpdateOrganizationSchema.load({"name": "Renamed Only"})
assert result == {"name": "Renamed Only"}
def test_is_active_coerced_to_bool(self):
result = UpdateOrganizationSchema.load({"is_active": "truthy"})
assert result["is_active"] is True
result = UpdateOrganizationSchema.load({"is_active": ""})
assert result["is_active"] is False

Some files were not shown because too many files have changed in this diff Show More