diff --git a/docker-compose.override.yml b/docker-compose.override.yml new file mode 100644 index 0000000..d5ca2fa --- /dev/null +++ b/docker-compose.override.yml @@ -0,0 +1,17 @@ +version: '3.8' + +services: + api: + environment: + - FLASK_ENV=development + - FLASK_DEBUG=1 + volumes: + - .:/app + command: > + flask run --host=0.0.0.0 --port=5000 --reload + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:5000/api/health"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 60s diff --git a/gatehouse_app/api/__init__.py b/gatehouse_app/api/__init__.py index 527dcb3..b2d6a57 100644 --- a/gatehouse_app/api/__init__.py +++ b/gatehouse_app/api/__init__.py @@ -1,12 +1,14 @@ """API package.""" from flask import Blueprint from gatehouse_app.utils.response import api_response +from gatehouse_app.extensions import limiter # Create main API blueprint api_bp = Blueprint("api", __name__) @api_bp.route("/health", methods=["GET"]) +@limiter.exempt def health_check(): """Health check endpoint.""" return api_response( diff --git a/gatehouse_app/api/v1/__init__.py b/gatehouse_app/api/v1/__init__.py index 7a45c53..3a59bfc 100644 --- a/gatehouse_app/api/v1/__init__.py +++ b/gatehouse_app/api/v1/__init__.py @@ -5,7 +5,7 @@ from flask import Blueprint api_v1_bp = Blueprint("api_v1", __name__) # Import route modules to register them -from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth, departments, principals, ssh, zerotier, sudo, oidc, contact +from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth, departments, principals, ssh, zerotier, oidc, contact from gatehouse_app.api.v1 import superadmin api_v1_bp.register_blueprint(ssh.ssh_bp) diff --git a/gatehouse_app/api/v1/departments.py b/gatehouse_app/api/v1/departments.py index 4ced781..d305c66 100644 --- a/gatehouse_app/api/v1/departments.py +++ b/gatehouse_app/api/v1/departments.py @@ -16,15 +16,12 @@ class DepartmentCreateSchema(Schema): """Schema for creating a department.""" name = fields.Str(required=True, validate=validate.Length(min=1, max=255)) description = fields.Str(allow_none=True, validate=validate.Length(max=2000)) - can_sudo = fields.Bool(allow_none=True, load_default=False) - class DepartmentUpdateSchema(Schema): """Schema for updating a department.""" name = fields.Str(validate=validate.Length(min=1, max=255)) description = fields.Str(allow_none=True, validate=validate.Length(max=2000)) - can_sudo = fields.Bool(allow_none=True) class AddDepartmentMemberSchema(Schema): @@ -122,7 +119,6 @@ def create_department(org_id): organization_id=org_id, name=data["name"], description=data.get("description"), - can_sudo=data.get("can_sudo", False), ) db.session.add(dept) db.session.commit() diff --git a/gatehouse_app/api/v1/organizations/__init__.py b/gatehouse_app/api/v1/organizations/__init__.py index fba555b..76f6fdd 100644 --- a/gatehouse_app/api/v1/organizations/__init__.py +++ b/gatehouse_app/api/v1/organizations/__init__.py @@ -1,4 +1,4 @@ """Organization routes package.""" -from gatehouse_app.api.v1.organizations import core, members, invites, clients, cas, audit, roles, api_keys +from gatehouse_app.api.v1.organizations import core, members, invites, clients, cas, audit, roles -__all__ = ["core", "members", "invites", "clients", "cas", "audit", "roles", "api_keys"] +__all__ = ["core", "members", "invites", "clients", "cas", "audit", "roles"] diff --git a/gatehouse_app/api/v1/organizations/api_keys.py b/gatehouse_app/api/v1/organizations/api_keys.py deleted file mode 100644 index 90d83ee..0000000 --- a/gatehouse_app/api/v1/organizations/api_keys.py +++ /dev/null @@ -1,299 +0,0 @@ -"""Organization API Key management endpoints.""" -from flask import g, request -from marshmallow import Schema, fields, validate, ValidationError - -from gatehouse_app.api.v1 import api_v1_bp -from gatehouse_app.utils.response import api_response -from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required -from gatehouse_app.models.organization import OrganizationApiKey -from gatehouse_app.services.organization_service import OrganizationService -from gatehouse_app.extensions import db - - -class ApiKeyCreateSchema(Schema): - """Schema for creating an API key.""" - name = fields.Str(required=True, validate=validate.Length(min=1, max=255)) - description = fields.Str(allow_none=True, validate=validate.Length(max=2000)) - - -class ApiKeyUpdateSchema(Schema): - """Schema for updating an API key.""" - name = fields.Str(validate=validate.Length(min=1, max=255)) - description = fields.Str(allow_none=True, validate=validate.Length(max=2000)) - - -@api_v1_bp.route("/organizations//api-keys", methods=["GET"]) -@login_required -@require_admin -@full_access_required -def list_api_keys(org_id): - """ - List all API keys for an organization. - - Only accessible by organization admins. - - Args: - org_id: Organization ID - - Returns: - 200: List of API keys (without key values) - 401: Not authenticated - 403: Not an admin - 404: Organization not found - """ - org = OrganizationService.get_organization_by_id(org_id) - - # Check if user is an admin - from gatehouse_app.models.organization.organization_member import OrganizationMember - from gatehouse_app.utils.constants import OrganizationRole - - membership = OrganizationMember.query.filter_by( - user_id=g.current_user.id, - organization_id=org_id, - deleted_at=None - ).first() - - if not membership or membership.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]: - return api_response( - success=False, - message="You do not have permission to manage API keys", - status=403, - error_type="AUTHORIZATION_ERROR", - ) - - api_keys = OrganizationApiKey.query.filter_by( - organization_id=org_id, - deleted_at=None - ).all() - - return api_response( - data={ - "api_keys": [k.to_dict() for k in api_keys], - "count": len(api_keys), - }, - message="API keys retrieved successfully", - ) - - -@api_v1_bp.route("/organizations//api-keys", methods=["POST"]) -@login_required -@require_admin -@full_access_required -def create_api_key(org_id): - """ - Create a new API key for an organization. - - Only accessible by organization admins. - The plain text key is returned only on creation and should be stored securely. - - Args: - org_id: Organization ID - - Request body: - name: API key name (required) - description: Optional description - - Returns: - 201: API key created successfully - 400: Validation error - 401: Not authenticated - 403: Not an admin - 404: Organization not found - """ - try: - org = OrganizationService.get_organization_by_id(org_id) - - # Check if user is an admin - from gatehouse_app.models.organization.organization_member import OrganizationMember - from gatehouse_app.utils.constants import OrganizationRole - - membership = OrganizationMember.query.filter_by( - user_id=g.current_user.id, - organization_id=org_id, - deleted_at=None - ).first() - - if not membership or membership.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]: - return api_response( - success=False, - message="You do not have permission to create API keys", - status=403, - error_type="AUTHORIZATION_ERROR", - ) - - schema = ApiKeyCreateSchema() - data = schema.load(request.json or {}) - - # Create the API key - api_key, plain_key = OrganizationApiKey.create_key( - organization_id=org_id, - name=data["name"], - description=data.get("description"), - ) - - # Return the key data with the plain text key (only on creation) - key_dict = api_key.to_dict() - key_dict["key"] = plain_key # Include plain text only on creation - - return api_response( - data={"api_key": key_dict}, - message="API key created successfully. Store the key value securely - it cannot be retrieved later.", - status=201, - ) - - except ValidationError as e: - return api_response( - success=False, - message="Validation failed", - status=400, - error_type="VALIDATION_ERROR", - error_details=e.messages, - ) - - -@api_v1_bp.route("/organizations//api-keys/", methods=["PATCH"]) -@login_required -@require_admin -@full_access_required -def update_api_key(org_id, key_id): - """ - Update an API key. - - Only accessible by organization admins. - - Args: - org_id: Organization ID - key_id: API Key ID - - Request body: - name: New name (optional) - description: New description (optional) - - Returns: - 200: API key updated successfully - 400: Validation error - 401: Not authenticated - 403: Not an admin - 404: Organization or API key not found - """ - try: - org = OrganizationService.get_organization_by_id(org_id) - - # Check if user is an admin - from gatehouse_app.models.organization.organization_member import OrganizationMember - from gatehouse_app.utils.constants import OrganizationRole - - membership = OrganizationMember.query.filter_by( - user_id=g.current_user.id, - organization_id=org_id, - deleted_at=None - ).first() - - if not membership or membership.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]: - return api_response( - success=False, - message="You do not have permission to update API keys", - status=403, - error_type="AUTHORIZATION_ERROR", - ) - - api_key = OrganizationApiKey.query.filter_by( - id=key_id, - organization_id=org_id, - deleted_at=None - ).first() - - if not api_key: - return api_response( - success=False, - message="API key not found", - status=404, - error_type="NOT_FOUND", - ) - - schema = ApiKeyUpdateSchema() - data = schema.load(request.json or {}) - - # Update fields - if "name" in data: - api_key.name = data["name"] - if "description" in data: - api_key.description = data["description"] - - api_key.save() - - return api_response( - data={"api_key": api_key.to_dict()}, - message="API key updated successfully", - ) - - except ValidationError as e: - return api_response( - success=False, - message="Validation failed", - status=400, - error_type="VALIDATION_ERROR", - error_details=e.messages, - ) - - -@api_v1_bp.route("/organizations//api-keys/", methods=["DELETE"]) -@login_required -@require_admin -@full_access_required -def delete_api_key(org_id, key_id): - """ - Delete/revoke an API key. - - Only accessible by organization admins. - - Args: - org_id: Organization ID - key_id: API Key ID - - Returns: - 200: API key deleted successfully - 401: Not authenticated - 403: Not an admin - 404: Organization or API key not found - """ - org = OrganizationService.get_organization_by_id(org_id) - - # Check if user is an admin - from gatehouse_app.models.organization.organization_member import OrganizationMember - from gatehouse_app.utils.constants import OrganizationRole - - membership = OrganizationMember.query.filter_by( - user_id=g.current_user.id, - organization_id=org_id, - deleted_at=None - ).first() - - if not membership or membership.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]: - return api_response( - success=False, - message="You do not have permission to delete API keys", - status=403, - error_type="AUTHORIZATION_ERROR", - ) - - api_key = OrganizationApiKey.query.filter_by( - id=key_id, - organization_id=org_id, - deleted_at=None - ).first() - - if not api_key: - return api_response( - success=False, - message="API key not found", - status=404, - error_type="NOT_FOUND", - ) - - # Soft delete the API key - api_key.delete(soft=True) - - return api_response( - message="API key deleted successfully", - ) diff --git a/gatehouse_app/api/v1/sudo.py b/gatehouse_app/api/v1/sudo.py deleted file mode 100644 index f828587..0000000 --- a/gatehouse_app/api/v1/sudo.py +++ /dev/null @@ -1,137 +0,0 @@ -"""Sudoer check and sudo-related endpoints.""" -from flask import request -from gatehouse_app.api.v1 import api_v1_bp -from gatehouse_app.utils.response import api_response -from gatehouse_app.models.organization import OrganizationApiKey -from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate -from gatehouse_app.models.organization import Department, DepartmentMembership - - -@api_v1_bp.route("/sudo/check", methods=["POST"]) -def check_sudoer(): - """ - Check if a user with a given certificate can sudo. - - This endpoint validates an API key for an organization, retrieves the certificate - by serial ID, finds the user and their departments, and checks if any of their - departments have sudo capability. - - Request body: - api_key: Organization API key (required) - certificate_serial: Certificate serial ID (required) - - Returns: - 200: Sudoer status returned - 400: Invalid request body - 401: Invalid API key - 403: Certificate not found or user not found - 404: Organization or certificate not found - """ - try: - data = request.get_json() - - if not data: - return api_response( - success=False, - message="Request body is required", - status=400, - error_type="INVALID_REQUEST", - ) - - api_key = data.get("api_key") - certificate_serial = data.get("certificate_serial") - - if not api_key or certificate_serial is None: - return api_response( - success=False, - message="api_key and certificate_serial are required", - status=400, - error_type="MISSING_REQUIRED_FIELDS", - ) - - # Find the certificate by serial - certificate = SSHCertificate.query.filter_by( - serial=certificate_serial, - deleted_at=None - ).first() - - if not certificate: - return api_response( - success=False, - message="Certificate not found", - status=404, - error_type="NOT_FOUND", - ) - - # Get the CA and organization - ca = certificate.ca - if not ca: - return api_response( - success=False, - message="Certificate CA not found", - status=404, - error_type="NOT_FOUND", - ) - - org_id = ca.organization_id - - # Verify the API key for this organization - org_api_key = OrganizationApiKey.verify_key(org_id, api_key) - - if not org_api_key: - return api_response( - success=False, - message="Invalid API key for organization", - status=401, - error_type="UNAUTHORIZED", - ) - - # Get the user from the certificate - user = certificate.user - if not user: - return api_response( - success=False, - message="Certificate user not found", - status=404, - error_type="NOT_FOUND", - ) - - # Get all departments the user belongs to - user_departments = DepartmentMembership.query.filter_by( - user_id=user.id, - deleted_at=None - ).all() - - # Check if any of the user's departments have sudo capability - can_sudo = False - sudoer_departments = [] - - for dept_membership in user_departments: - dept = dept_membership.department - if dept and dept.can_sudo and dept.deleted_at is None: - can_sudo = True - sudoer_departments.append({ - "id": dept.id, - "name": dept.name, - }) - - return api_response( - data={ - "can_sudo": can_sudo, - "user_id": user.id, - "user_email": user.email, - "certificate_serial": certificate.serial, - "sudoer_departments": sudoer_departments, - "all_departments_count": len(user_departments), - }, - message="Sudoer status retrieved successfully", - status=200, - ) - - except Exception as e: - return api_response( - success=False, - message=f"An error occurred: {str(e)}", - status=500, - error_type="INTERNAL_ERROR", - ) diff --git a/gatehouse_app/api/v1/users/admin.py b/gatehouse_app/api/v1/users/admin.py index 116519e..8a2b671 100644 --- a/gatehouse_app/api/v1/users/admin.py +++ b/gatehouse_app/api/v1/users/admin.py @@ -710,6 +710,128 @@ def admin_set_user_password(user_id): return api_response(data={"user": {"id": str(target.id), "email": target.email}}, message=f"Password updated for {target.email}") +@api_v1_bp.route("/admin/users//ssh-certificates", methods=["GET"]) +@login_required +@full_access_required +def admin_get_user_ssh_certificates(user_id): + """List all SSH certificates for a user (admin view). + + Returns all certificates — active, expired, revoked — with relevant + metrics for admin visibility. Includes SSH key metadata (fingerprint, + type, description) via the ssh_key relationship. + + Query parameters: + status: Filter by certificate status (issued, revoked, expired, superseded) + active: If "true", return only currently valid certificates + cert_type: Filter by certificate type (user, host) + """ + from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus + from gatehouse_app.models.ssh_ca.ca import CertType + + caller = g.current_user + target = _find_user_for_admin(user_id) + if not target: + return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") + + if not _get_admin_access(caller, target): + return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR") + + query = SSHCertificate.query.filter_by(user_id=user_id, deleted_at=None) + + # Filter by explicit status (e.g. ?status=revoked) + status_param = request.args.get("status", "").strip().lower() + if status_param: + try: + status_enum = CertificateStatus(status_param) + query = query.filter(SSHCertificate.status == status_enum) + except ValueError: + valid_statuses = [s.value for s in CertificateStatus] + return api_response( + success=False, + message=f"Invalid status '{status_param}'. Must be one of: {', '.join(valid_statuses)}", + status=400, error_type="VALIDATION_ERROR", + ) + + # Filter for only currently valid certs (?active=true) + active_param = request.args.get("active", "").strip().lower() + if active_param == "true": + now = datetime.now(timezone.utc) + query = query.filter( + SSHCertificate.revoked == False, + SSHCertificate.valid_after <= now, + SSHCertificate.valid_before >= now, + ) + elif active_param == "false": + now = datetime.now(timezone.utc) + query = query.filter( + (SSHCertificate.revoked == True) | + (SSHCertificate.valid_before < now) + ) + + # Filter by certificate type (?cert_type=host) + cert_type_param = request.args.get("cert_type", "").strip().lower() + if cert_type_param: + try: + cert_type_enum = CertType(cert_type_param) + query = query.filter(SSHCertificate.cert_type == cert_type_enum) + except ValueError: + return api_response( + success=False, + message=f"Invalid cert_type '{cert_type_param}'. Must be one of: user, host", + status=400, error_type="VALIDATION_ERROR", + ) + + # Pagination + try: + page = max(1, int(request.args.get("page", 1))) + per_page = min(100, max(1, int(request.args.get("per_page", 50)))) + except ValueError: + page, per_page = 1, 50 + + total = query.count() + certs = ( + query.order_by(SSHCertificate.created_at.desc()) + .offset((page - 1) * per_page) + .limit(per_page) + .all() + ) + + now = datetime.now(timezone.utc) + certs_data = [] + for cert in certs: + d = cert.to_dict() + # Enrich with SSH key metadata + if cert.ssh_key: + d["ssh_key"] = { + "id": str(cert.ssh_key.id), + "fingerprint": cert.ssh_key.fingerprint, + "key_type": cert.ssh_key.key_type, + "key_bits": cert.ssh_key.key_bits, + "key_comment": cert.ssh_key.key_comment, + "description": cert.ssh_key.description, + "verified": cert.ssh_key.verified, + } + else: + d["ssh_key"] = None + certs_data.append(d) + + return api_response( + data={ + "user": { + "id": str(target.id), + "email": target.email, + "full_name": target.full_name, + }, + "certificates": certs_data, + "count": total, + "page": page, + "per_page": per_page, + "pages": (total + per_page - 1) // per_page, + }, + message="SSH certificates retrieved successfully", + ) + + @api_v1_bp.route("/admin/users//linked-accounts", methods=["GET"]) @login_required @full_access_required diff --git a/gatehouse_app/api/v1/zerotier.py b/gatehouse_app/api/v1/zerotier.py index 1ba82c9..244784f 100644 --- a/gatehouse_app/api/v1/zerotier.py +++ b/gatehouse_app/api/v1/zerotier.py @@ -13,12 +13,12 @@ from gatehouse_app.services import device_service from gatehouse_app.services import network_access_service from gatehouse_app.services import zerotier_api_service as zt from gatehouse_app.services import zerotier_reconciliation_service +from gatehouse_app.services.user_service import UserService from gatehouse_app.models import ( PortalNetwork, Device, - DeviceNetworkMembership, - UserNetworkApproval, ActivationSession, + NetworkAccessRequest, ) from gatehouse_app.models.organization import Organization from gatehouse_app.models.organization.organization_member import OrganizationMember @@ -30,7 +30,6 @@ from gatehouse_app.exceptions import ( DeviceNotFoundError, DeviceAlreadyExistsError, ApprovalNotFoundError, - MembershipNotFoundError, ) @@ -347,6 +346,47 @@ def list_devices(org_id): ) +@api_v1_bp.route("/organizations//users//devices", methods=["GET"]) +@login_required +@require_admin +@full_access_required +def list_user_devices(org_id, user_id): + """List all ZeroTier devices for a specific user in the organization (admin only).""" + org, err = _org_check(org_id) + if err: + return err + + # Verify target user exists + from gatehouse_app.exceptions.validation_exceptions import UserNotFoundError + try: + target_user = UserService.get_user_by_id(user_id) + except UserNotFoundError: + return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND") + + # Verify target user is a member of the org + is_member = OrganizationMember.query.filter( + OrganizationMember.organization_id == org_id, + OrganizationMember.user_id == user_id, + OrganizationMember.deleted_at.is_(None), + ).first() is not None + + if not is_member: + return api_response(success=False, message="User is not a member of this organization", status=404, error_type="NOT_FOUND") + + # Get devices for the user in this org + devices = device_service.list_user_devices(user_id, org_id) + + return api_response( + data={ + "devices": [d.to_dict() for d in devices], + "count": len(devices), + "user_id": user_id, + "organization_id": org_id, + }, + message="User devices retrieved successfully", + ) + + @api_v1_bp.route("/organizations//devices", methods=["POST"]) @login_required @full_access_required @@ -373,11 +413,8 @@ def register_device(org_id): serial_number=data.get("serial_number"), ) - from gatehouse_app.services.network_access_service import materialize_device_memberships - memberships = materialize_device_memberships(device.id) - return api_response( - data={"device": device.to_dict(), "memberships_created": len(memberships)}, + data={"device": device.to_dict()}, message="Device registered successfully", status=201, ) @@ -486,7 +523,7 @@ def list_my_approvals(org_id): if err: return err - approvals = network_access_service.list_user_approvals(g.current_user.id, org_id) + approvals = network_access_service.list_user_requests(g.current_user.id, org_id) return api_response( data={"approvals": [a.to_dict() for a in approvals], "count": len(approvals)}, message="Approvals retrieved successfully", @@ -549,18 +586,18 @@ def reject_request(org_id, approval_id): return api_response(success=False, message=str(e.message), status=400, error_type=e.error_type) -@api_v1_bp.route("/organizations//approvals//revoke", methods=["POST"]) +@api_v1_bp.route("/organizations//approvals//revoke", methods=["POST"]) @login_required @require_admin @full_access_required -def revoke_approval(org_id, approval_id): +def revoke_approval(org_id, request_id): """Revoke an approved access record (admin only).""" org, err = _org_check(org_id) if err: return err try: - approval = network_access_service.revoke_approval(approval_id, g.current_user.id) + approval = network_access_service.revoke_access(request_id, g.current_user.id) return api_response(data={"approval": approval.to_dict()}, message="Approval revoked successfully") except ApprovalNotFoundError as e: return api_response(success=False, message=str(e), status=404, error_type=e.error_type) @@ -607,7 +644,7 @@ def admin_list_all_approvals(org_id): network_id = request.args.get("network_id") state = request.args.get("state") - approvals = network_access_service.list_all_org_approvals(org_id, network_id=network_id, state=state) + approvals = network_access_service.list_all_org_requests(org_id, network_id=network_id, state=state) return api_response( data={"approvals": [a.to_dict() for a in approvals], "count": len(approvals)}, message="Approvals retrieved successfully", @@ -626,10 +663,10 @@ def list_memberships(org_id): if err: return err - memberships = DeviceNetworkMembership.query.filter( - DeviceNetworkMembership.user_id == g.current_user.id, - DeviceNetworkMembership.organization_id == org_id, - DeviceNetworkMembership.deleted_at.is_(None), + memberships = NetworkAccessRequest.query.filter( + NetworkAccessRequest.user_id == g.current_user.id, + NetworkAccessRequest.organization_id == org_id, + NetworkAccessRequest.deleted_at.is_(None), ).all() return api_response( @@ -656,15 +693,14 @@ def activate_membership(org_id, membership_id): is_admin = _is_org_admin(org_id, g.current_user.id) try: - session = network_access_service.activate_device_membership( - membership_id=membership_id, + session = network_access_service.activate_request( + request_id=membership_id, user_id=g.current_user.id, lifetime_minutes=data.get("lifetime_minutes"), admin_override=is_admin, ) - membership = DeviceNetworkMembership.query.get(membership_id) - return api_response(data={"session": session.to_dict(), "membership": membership.to_dict()}, message="Membership activated successfully") - except MembershipNotFoundError as e: + return api_response(data={"session": session.to_dict()}, message="Request activated successfully") + except ApprovalNotFoundError as e: return api_response(success=False, message=str(e), status=404, error_type=e.error_type) except AppValidationError as e: return api_response(success=False, message=str(e.message), status=400, error_type=e.error_type) @@ -681,22 +717,22 @@ def deactivate_membership(org_id, membership_id): # Verify ownership for non-admins if not _is_org_admin(org_id, g.current_user.id): - membership_check = DeviceNetworkMembership.query.filter( - DeviceNetworkMembership.id == membership_id, - DeviceNetworkMembership.user_id == g.current_user.id, - DeviceNetworkMembership.deleted_at.is_(None), + membership_check = NetworkAccessRequest.query.filter( + NetworkAccessRequest.id == membership_id, + NetworkAccessRequest.user_id == g.current_user.id, + NetworkAccessRequest.deleted_at.is_(None), ).first() if not membership_check: return api_response(success=False, message="Membership not found", status=404, error_type="NOT_FOUND") try: - membership = network_access_service.deactivate_membership( - membership_id=membership_id, + req = network_access_service.deactivate_request( + request_id=membership_id, reason="manual_revoke", deactivated_by_user_id=g.current_user.id, ) - return api_response(data={"membership": membership.to_dict()}, message="Membership deactivated successfully") - except MembershipNotFoundError as e: + return api_response(data={"request": req.to_dict()}, message="Request deactivated successfully") + except ApprovalNotFoundError as e: return api_response(success=False, message=str(e), status=404, error_type=e.error_type) @@ -730,17 +766,21 @@ def activate_all_memberships(org_id): @login_required @full_access_required def join_network(org_id, device_id, portal_network_id): - """Join an open network directly with a registered device.""" + """Join an open network directly with a registered device. Admins can override for any network.""" org, err = _org_check(org_id) if err: return err + is_admin = _is_org_admin(org_id, g.current_user.id) + try: membership = network_access_service.join_network_for_device( user_id=g.current_user.id, organization_id=org_id, device_id=device_id, portal_network_id=portal_network_id, + admin_override=is_admin, + granted_by_user_id=g.current_user.id if is_admin else None, ) return api_response(data={"membership": membership.to_dict()}, message="Joined network successfully", status=201) except AppValidationError as e: @@ -759,12 +799,12 @@ def delete_membership(org_id, membership_id): return err try: - network_access_service.revoke_membership_soft( - membership_id=membership_id, - revoked_by_user_id=g.current_user.id, + network_access_service.revoke_request_soft( + request_id=membership_id, + revoker_user_id=g.current_user.id, ) - return api_response(message="Membership removed successfully") - except MembershipNotFoundError as e: + return api_response(message="Request revoked successfully") + except ApprovalNotFoundError as e: return api_response(success=False, message=str(e), status=404, error_type=e.error_type) @@ -820,10 +860,8 @@ def end_session(org_id, session_id): _end_session(session, ActivationEndReason.LOGOUT) - membership = DeviceNetworkMembership.query.get(session.device_network_membership_id) - if membership: - from gatehouse_app.services.network_access_service import deactivate_membership - deactivate_membership(membership.id, reason="logout") + if session.network_access_request_id: + network_access_service.deactivate_request(session.network_access_request_id, reason="logout") return api_response(message="Session ended successfully") @@ -848,15 +886,16 @@ def trigger_kill_switch(org_id): return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) try: - event = network_access_service.kill_switch( - target_user_id=data["target_user_id"], - triggered_by_user_id=g.current_user.id, - organization_id=org_id, - scope=data.get("scope", "organization"), - reason=data.get("reason"), + from gatehouse_app.utils.constants import KillSwitchScope + scope = data.get("scope", "organization") + scope_enum = KillSwitchScope(scope) if scope in KillSwitchScope._value2member_map_ else KillSwitchScope.ORGANIZATION + count = network_access_service.kill_switch( + user_id=data["target_user_id"], + org_id=org_id, + scope=scope_enum, network_ids=data.get("network_ids"), ) - return api_response(data={"event": event.to_dict()}, message="Kill switch triggered successfully") + return api_response(data={"affected_count": count}, message="Kill switch triggered successfully") except AppValidationError as e: return api_response(success=False, message=str(e.message), status=400, error_type=e.error_type) @@ -873,10 +912,10 @@ def admin_list_memberships(org_id): if err: return err - memberships = network_access_service.get_all_memberships_with_details(org_id) + requests = network_access_service.get_all_requests_with_details(org_id) return api_response( - data={"memberships": memberships, "count": len(memberships)}, - message="All memberships retrieved successfully", + data={"requests": requests, "count": len(requests)}, + message="All requests retrieved successfully", ) @@ -891,9 +930,9 @@ def admin_delete_membership(org_id, membership_id): return err try: - network_access_service.hard_delete_membership(membership_id) - return api_response(message="Membership permanently deleted") - except MembershipNotFoundError as e: + network_access_service.hard_delete_request(membership_id) + return api_response(message="Request permanently deleted") + except ApprovalNotFoundError as e: return api_response(success=False, message=str(e), status=404, error_type=e.error_type) diff --git a/gatehouse_app/models/__init__.py b/gatehouse_app/models/__init__.py index fe62307..84a6784 100644 --- a/gatehouse_app/models/__init__.py +++ b/gatehouse_app/models/__init__.py @@ -17,9 +17,8 @@ models.ssh_ca — CA, KeyType, CertType, CaType, CAPermission, CertificateAuditLog models.security — OrganizationSecurityPolicy, UserSecurityPolicy, MfaPolicyCompliance -models.zerotier — PortalNetwork, Device, UserNetworkApproval, - DeviceNetworkMembership, ActivationSession, - ZeroTierMembership, KillSwitchEvent +models.zerotier — PortalNetwork, Device, NetworkAccessRequest, + ActivationSession, ZeroTierMembership, KillSwitchEvent All names are re-exported here so that existing code using the flat import style (``from gatehouse_app.models import X``) or the old per-file style @@ -107,8 +106,7 @@ from gatehouse_app.models.security.mfa_policy_compliance import ( from gatehouse_app.models.zerotier import ( # noqa: F401 PortalNetwork, Device, - UserNetworkApproval, - DeviceNetworkMembership, + NetworkAccessRequest, ActivationSession, ZeroTierMembership, KillSwitchEvent, @@ -178,8 +176,7 @@ __all__ = [ # ZeroTier "PortalNetwork", "Device", - "UserNetworkApproval", - "DeviceNetworkMembership", + "NetworkAccessRequest", "ActivationSession", "ZeroTierMembership", "KillSwitchEvent", diff --git a/gatehouse_app/models/organization/__init__.py b/gatehouse_app/models/organization/__init__.py index 52e29cf..aa33f8e 100644 --- a/gatehouse_app/models/organization/__init__.py +++ b/gatehouse_app/models/organization/__init__.py @@ -12,7 +12,6 @@ from gatehouse_app.models.organization.department_cert_policy import ( ) from gatehouse_app.models.organization.principal import Principal, PrincipalMembership from gatehouse_app.models.organization.org_invite_token import OrgInviteToken -from gatehouse_app.models.organization.organization_api_key import OrganizationApiKey __all__ = [ "Organization", @@ -25,5 +24,4 @@ __all__ = [ "Principal", "PrincipalMembership", "OrgInviteToken", - "OrganizationApiKey", ] diff --git a/gatehouse_app/models/organization/department.py b/gatehouse_app/models/organization/department.py index f46385a..800780b 100644 --- a/gatehouse_app/models/organization/department.py +++ b/gatehouse_app/models/organization/department.py @@ -27,7 +27,6 @@ class Department(BaseModel): ) name = db.Column(db.String(255), nullable=False, index=True) description = db.Column(db.Text, nullable=True) - can_sudo = db.Column(db.Boolean, default=False, nullable=False) # Relationships organization = db.relationship("Organization", back_populates="departments") diff --git a/gatehouse_app/models/organization/organization.py b/gatehouse_app/models/organization/organization.py index f8ae26e..ce96741 100644 --- a/gatehouse_app/models/organization/organization.py +++ b/gatehouse_app/models/organization/organization.py @@ -47,9 +47,6 @@ class Organization(BaseModel): cas = db.relationship( "CA", back_populates="organization", cascade="all, delete-orphan" ) - api_keys = db.relationship( - "OrganizationApiKey", back_populates="organization", cascade="all, delete-orphan" - ) def __repr__(self): """String representation of Organization.""" @@ -110,11 +107,3 @@ class Organization(BaseModel): """ return [ca for ca in self.cas if ca.deleted_at is None] - def get_active_api_keys(self): - """Get active (non-deleted) API keys. - - Returns: - List of OrganizationApiKey instances where deleted_at is None. - """ - return [k for k in self.api_keys if k.deleted_at is None] - diff --git a/gatehouse_app/models/organization/organization_api_key.py b/gatehouse_app/models/organization/organization_api_key.py deleted file mode 100644 index 64feefa..0000000 --- a/gatehouse_app/models/organization/organization_api_key.py +++ /dev/null @@ -1,158 +0,0 @@ -"""Organization API Key model — API keys for organizations for external integrations.""" -import secrets -from datetime import datetime, timezone -from gatehouse_app.extensions import db -from gatehouse_app.models.base import BaseModel - - -class OrganizationApiKey(BaseModel): - """API Key model representing an API key for an organization. - - API keys are used to authenticate external integrations or services - that need programmatic access to the organization's resources. - Each key is tied to an organization and can be revoked/deleted as needed. - """ - - __tablename__ = "organization_api_keys" - - organization_id = db.Column( - db.String(36), - db.ForeignKey("organizations.id"), - nullable=False, - index=True, - ) - - # Human-readable name for the API key - name = db.Column(db.String(255), nullable=False) - - # Hashed key value (never store plain text) - key_hash = db.Column(db.String(255), nullable=False, unique=True, index=True) - - # Last used timestamp for tracking activity - last_used_at = db.Column(db.DateTime, nullable=True) - - # Revocation status - is_revoked = db.Column(db.Boolean, default=False, nullable=False, index=True) - revoked_at = db.Column(db.DateTime, nullable=True) - revoke_reason = db.Column(db.String(255), nullable=True) - - # Description/purpose of the key - description = db.Column(db.Text, nullable=True) - - # Relationships - organization = db.relationship("Organization", back_populates="api_keys") - - __table_args__ = ( - db.Index("idx_org_api_key_org_active", "organization_id", "is_revoked"), - db.Index("idx_api_key_last_used", "last_used_at"), - ) - - def __repr__(self): - """String representation of OrganizationApiKey.""" - return f"" - - @staticmethod - def generate_key() -> str: - """Generate a random API key. - - Returns: - A random 32-byte hex string suitable for use as an API key - """ - return secrets.token_hex(32) - - @classmethod - def create_key( - cls, - organization_id: str, - name: str, - description: str = None, - ) -> tuple: - """Create and store a new API key for an organization. - - Args: - organization_id: ID of the organization - name: Human-readable name for the key - description: Optional description/purpose of the key - - Returns: - Tuple of (OrganizationApiKey instance, plain_text_key_string) - The plain text key is only returned on creation and should be - stored securely by the user. It cannot be retrieved later. - """ - # Generate a plain text key - plain_key = cls.generate_key() - - # Hash it using the key_hash method - key_hash = cls.hash_key(plain_key) - - # Create the database record - api_key = cls( - organization_id=organization_id, - name=name, - key_hash=key_hash, - description=description, - ) - api_key.save() - - return api_key, plain_key - - @staticmethod - def hash_key(plain_key: str) -> str: - """Hash an API key for storage. - - Args: - plain_key: The plain text API key - - Returns: - Hashed version of the key - """ - import hashlib - return hashlib.sha256(plain_key.encode()).hexdigest() - - @classmethod - def verify_key(cls, organization_id: str, plain_key: str) -> "OrganizationApiKey": - """Verify an API key for an organization. - - Args: - organization_id: ID of the organization - plain_key: The plain text API key to verify - - Returns: - OrganizationApiKey instance if valid and active, None otherwise - """ - key_hash = cls.hash_key(plain_key) - - api_key = cls.query.filter_by( - organization_id=organization_id, - key_hash=key_hash, - is_revoked=False, - deleted_at=None, - ).first() - - if api_key: - # Update last used timestamp - api_key.last_used_at = datetime.now(timezone.utc) - api_key.save() - - return api_key - - def revoke(self, reason: str = None) -> None: - """Revoke this API key. - - Args: - reason: Optional reason for revocation - """ - self.is_revoked = True - self.revoked_at = datetime.now(timezone.utc) - self.revoke_reason = reason - self.save() - - def to_dict(self, exclude=None): - """Convert API key to dictionary. - - The key_hash is excluded by default for security. - """ - exclude = exclude or [] - if "key_hash" not in exclude: - exclude.append("key_hash") - return super().to_dict(exclude=exclude) diff --git a/gatehouse_app/models/zerotier/__init__.py b/gatehouse_app/models/zerotier/__init__.py index d6b3f5f..36caa81 100644 --- a/gatehouse_app/models/zerotier/__init__.py +++ b/gatehouse_app/models/zerotier/__init__.py @@ -2,8 +2,7 @@ PortalNetwork — manager-created network bound to a ZT network ID Device — user-registered ZeroTier node endpoint -UserNetworkApproval — durable manager approval for network access -DeviceNetworkMembership — per-device per-network workflow record +NetworkAccessRequest — unified per-device, per-network access record ActivationSession — temporary activation window ZeroTierMembership — observed controller-side member state KillSwitchEvent — explicit rapid deactivation record @@ -11,8 +10,7 @@ KillSwitchEvent — explicit rapid deactivation record from gatehouse_app.models.zerotier.activation_session import ActivationSession # noqa: F401 from gatehouse_app.models.zerotier.device import Device # noqa: F401 -from gatehouse_app.models.zerotier.device_network_membership import DeviceNetworkMembership # noqa: F401 from gatehouse_app.models.zerotier.kill_switch_event import KillSwitchEvent # noqa: F401 +from gatehouse_app.models.zerotier.network_access_request import NetworkAccessRequest # noqa: F401 from gatehouse_app.models.zerotier.portal_network import PortalNetwork # noqa: F401 -from gatehouse_app.models.zerotier.user_network_approval import UserNetworkApproval # noqa: F401 from gatehouse_app.models.zerotier.zerotier_membership import ZeroTierMembership # noqa: F401 diff --git a/gatehouse_app/models/zerotier/activation_session.py b/gatehouse_app/models/zerotier/activation_session.py index 00e4086..d396951 100644 --- a/gatehouse_app/models/zerotier/activation_session.py +++ b/gatehouse_app/models/zerotier/activation_session.py @@ -16,7 +16,7 @@ class ActivationSession(BaseModel): Attributes: organization_id: FK to the organization user_id: FK to the user who owns the session - device_network_membership_id: FK to the related membership + network_access_request_id: FK to the related network access request authenticated_at: When the user re-authenticated to start this session expires_at: When the activation window ends ended_at: When the session was explicitly ended (null if still active) @@ -38,9 +38,9 @@ class ActivationSession(BaseModel): nullable=False, index=True, ) - device_network_membership_id = db.Column( + network_access_request_id = db.Column( db.String(36), - db.ForeignKey("device_network_memberships.id"), + db.ForeignKey("network_access_requests.id"), nullable=False, index=True, ) @@ -75,14 +75,14 @@ class ActivationSession(BaseModel): foreign_keys=[created_by], backref="created_activation_sessions", ) - membership = db.relationship( - "DeviceNetworkMembership", + access_request = db.relationship( + "NetworkAccessRequest", back_populates="activation_sessions", ) def __repr__(self): return ( - f"" ) diff --git a/gatehouse_app/models/zerotier/device.py b/gatehouse_app/models/zerotier/device.py index 5955890..9604021 100644 --- a/gatehouse_app/models/zerotier/device.py +++ b/gatehouse_app/models/zerotier/device.py @@ -2,7 +2,7 @@ from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel -from gatehouse_app.utils.constants import DeviceStatus +from gatehouse_app.utils.constants import ApprovalState, DeviceStatus class Device(BaseModel): @@ -55,8 +55,8 @@ class Device(BaseModel): # Relationships user = db.relationship("User", backref="devices") organization = db.relationship("Organization", backref="devices") - memberships = db.relationship( - "DeviceNetworkMembership", + network_access_requests = db.relationship( + "NetworkAccessRequest", back_populates="device", cascade="all, delete-orphan", ) @@ -73,7 +73,7 @@ class Device(BaseModel): data = super().to_dict(exclude=exclude) data["display_name"] = self.display_name data["active_membership_count"] = sum( - 1 for m in self.memberships - if m.state == "active_authorized" and m.deleted_at is None + 1 for r in self.network_access_requests + if r.active and r.status == ApprovalState.APPROVED and r.deleted_at is None ) return data diff --git a/gatehouse_app/models/zerotier/device_network_membership.py b/gatehouse_app/models/zerotier/device_network_membership.py deleted file mode 100644 index cc6b85d..0000000 --- a/gatehouse_app/models/zerotier/device_network_membership.py +++ /dev/null @@ -1,129 +0,0 @@ -"""Device network membership — per-device, per-network workflow object.""" - -from gatehouse_app.extensions import db -from gatehouse_app.models.base import BaseModel -from gatehouse_app.utils.constants import MembershipState - - -class DeviceNetworkMembership(BaseModel): - """The main per-device, per-network workflow record. - - This binds a specific Device to a specific PortalNetwork through a - UserNetworkApproval. It tracks both the internal portal state and the - observed ZeroTier membership state. - - States: - pending_device_registration — approval exists but no device registered yet - pending_request — user has requested access but not yet approved - pending_manager_approval — approval pending manager sign-off - approved_inactive — approved but not currently active - joined_deauthorized — device has joined ZT network but not authorized - active_authorized — authorized and actively connected - activation_expired — activation window ended (member still in ZT, deauth'd) - suspended — temporarily suspended - revoked — permanently revoked - rejected — request was rejected - """ - - __tablename__ = "device_network_memberships" - - organization_id = db.Column( - db.String(36), - db.ForeignKey("organizations.id"), - nullable=False, - index=True, - ) - user_id = db.Column( - db.String(36), - db.ForeignKey("users.id"), - nullable=False, - index=True, - ) - device_id = db.Column( - db.String(36), - db.ForeignKey("devices.id"), - nullable=False, - index=True, - ) - portal_network_id = db.Column( - db.String(36), - db.ForeignKey("portal_networks.id"), - nullable=False, - index=True, - ) - user_network_approval_id = db.Column( - db.String(36), - db.ForeignKey("user_network_approvals.id"), - nullable=True, - index=True, - ) - state = db.Column( - db.Enum(MembershipState, name="membership_state", values_callable=lambda x: [e.value for e in x]), - default=MembershipState.PENDING_DEVICE_REGISTRATION, - nullable=False, - index=True, - ) - join_seen = db.Column(db.Boolean, default=False, nullable=False) - currently_authorized = db.Column(db.Boolean, default=False, nullable=False) - approved_for_activation = db.Column(db.Boolean, default=True, nullable=False) - - # Relationships - organization = db.relationship("Organization", backref="network_memberships") - user = db.relationship("User", backref="network_memberships") - device = db.relationship("Device", back_populates="memberships") - portal_network = db.relationship( - "PortalNetwork", - back_populates="memberships", - ) - approval = db.relationship( - "UserNetworkApproval", - back_populates="memberships", - ) - activation_sessions = db.relationship( - "ActivationSession", - back_populates="membership", - cascade="all, delete-orphan", - ) - zerotier_membership = db.relationship( - "ZeroTierMembership", - back_populates="device_network_membership", - uselist=False, - cascade="all, delete-orphan", - ) - - __table_args__ = ( - db.UniqueConstraint( - "device_id", - "portal_network_id", - "deleted_at", - name="uix_device_network", - ), - ) - - def __repr__(self): - return ( - f"" - ) - - @property - def active_session(self): - """Return the current active ActivationSession, if any.""" - for s in self.activation_sessions: - if s.ended_at is None and s.expires_at is not None: - from datetime import datetime, timezone - now = datetime.now(timezone.utc) - exp = s.expires_at - if exp.tzinfo is None: - exp = exp.replace(tzinfo=timezone.utc) - if exp > now: - return s - return None - - def to_dict(self, exclude=None): - exclude = exclude or [] - data = super().to_dict(exclude=exclude) - data["active_session"] = ( - self.active_session.to_dict() if self.active_session else None - ) - return data diff --git a/gatehouse_app/models/zerotier/network_access_request.py b/gatehouse_app/models/zerotier/network_access_request.py new file mode 100644 index 0000000..69b80e0 --- /dev/null +++ b/gatehouse_app/models/zerotier/network_access_request.py @@ -0,0 +1,147 @@ +"""Network access request model — unified per-device, per-network access record.""" + +from datetime import datetime, timezone +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel +from gatehouse_app.utils.constants import ApprovalGrantType, ApprovalState + + +class NetworkAccessRequest(BaseModel): + """A unified access record binding a user's device to a portal network. + + Replaces the separate UserNetworkApproval and DeviceNetworkMembership + tables with a single per-device, per-network row. Each row tracks both + the business-level approval status and the device-level active/inactive + toggle. + + Attributes: + organization_id: FK to the organization + user_id: FK to the requesting user + device_id: FK to the specific device + portal_network_id: FK to the portal network + granted_by_user_id: FK to the manager who approved (null for user-initiated) + grant_type: requested (user-initiated) or assigned (manager-initiated) + status: pending / approved / rejected / revoked / suspended + active: whether the device connection is currently live + justification: Business reason for the request + join_seen: Whether the device has been seen joining the ZeroTier network + """ + + __tablename__ = "network_access_requests" + + organization_id = db.Column( + db.String(36), + db.ForeignKey("organizations.id"), + nullable=False, + index=True, + ) + user_id = db.Column( + db.String(36), + db.ForeignKey("users.id"), + nullable=False, + index=True, + ) + device_id = db.Column( + db.String(36), + db.ForeignKey("devices.id"), + nullable=False, + index=True, + ) + portal_network_id = db.Column( + db.String(36), + db.ForeignKey("portal_networks.id"), + nullable=False, + index=True, + ) + granted_by_user_id = db.Column( + db.String(36), + db.ForeignKey("users.id"), + nullable=True, + ) + grant_type = db.Column( + db.Enum(ApprovalGrantType, name="approval_grant_type", values_callable=lambda x: [e.value for e in x]), + default=ApprovalGrantType.REQUESTED, + nullable=False, + ) + status = db.Column( + db.Enum(ApprovalState, name="approval_state", values_callable=lambda x: [e.value for e in x]), + default=ApprovalState.PENDING, + nullable=False, + index=True, + ) + active = db.Column( + db.Boolean, + default=False, + nullable=False, + ) + justification = db.Column(db.Text, nullable=True) + join_seen = db.Column(db.Boolean, default=False, nullable=False) + + # Relationships + organization = db.relationship("Organization", backref="network_access_requests") + user = db.relationship( + "User", + foreign_keys=[user_id], + backref="network_access_requests", + ) + granted_by = db.relationship( + "User", + foreign_keys=[granted_by_user_id], + backref="granted_network_requests", + ) + device = db.relationship( + "Device", + back_populates="network_access_requests", + ) + portal_network = db.relationship( + "PortalNetwork", + backref="access_requests", + ) + activation_sessions = db.relationship( + "ActivationSession", + back_populates="access_request", + cascade="all, delete-orphan", + ) + zerotier_membership = db.relationship( + "ZeroTierMembership", + back_populates="access_request", + uselist=False, + cascade="all, delete-orphan", + ) + + __table_args__ = ( + db.UniqueConstraint( + "user_id", + "device_id", + "portal_network_id", + "deleted_at", + name="uix_user_device_network", + ), + ) + + def __repr__(self): + return ( + f"" + ) + + @property + def active_session(self): + """Return the current active ActivationSession, if any.""" + for s in self.activation_sessions: + if s.ended_at is None and s.expires_at is not None: + now = datetime.now(timezone.utc) + exp = s.expires_at + if exp.tzinfo is None: + exp = exp.replace(tzinfo=timezone.utc) + if exp > now: + return s + return None + + def to_dict(self, exclude=None): + exclude = exclude or [] + data = super().to_dict(exclude=exclude) + session = self.active_session + data["active_session"] = session.to_dict() if session else None + return data diff --git a/gatehouse_app/models/zerotier/portal_network.py b/gatehouse_app/models/zerotier/portal_network.py index bea0971..c71fb1a 100644 --- a/gatehouse_app/models/zerotier/portal_network.py +++ b/gatehouse_app/models/zerotier/portal_network.py @@ -2,7 +2,7 @@ from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel -from gatehouse_app.utils.constants import NetworkEnvironment, NetworkRequestMode +from gatehouse_app.utils.constants import ApprovalState, NetworkEnvironment, NetworkRequestMode class PortalNetwork(BaseModel): @@ -65,16 +65,6 @@ class PortalNetwork(BaseModel): # Relationships organization = db.relationship("Organization", backref="portal_networks") owner = db.relationship("User", backref="owned_networks") - approvals = db.relationship( - "UserNetworkApproval", - back_populates="portal_network", - cascade="all, delete-orphan", - ) - memberships = db.relationship( - "DeviceNetworkMembership", - back_populates="portal_network", - cascade="all, delete-orphan", - ) __table_args__ = ( db.UniqueConstraint( @@ -91,10 +81,11 @@ class PortalNetwork(BaseModel): exclude = exclude or [] data = super().to_dict(exclude=exclude) data["approved_user_count"] = sum( - 1 for a in self.approvals if a.state == "approved" and a.deleted_at is None + 1 for a in self.access_requests + if a.status == ApprovalState.APPROVED and a.deleted_at is None ) data["active_membership_count"] = sum( - 1 for m in self.memberships - if m.state == "active_authorized" and m.deleted_at is None + 1 for r in self.access_requests + if r.active and r.status == ApprovalState.APPROVED and r.deleted_at is None ) return data diff --git a/gatehouse_app/models/zerotier/user_network_approval.py b/gatehouse_app/models/zerotier/user_network_approval.py deleted file mode 100644 index dfe559e..0000000 --- a/gatehouse_app/models/zerotier/user_network_approval.py +++ /dev/null @@ -1,106 +0,0 @@ -"""User network approval model — durable manager approval for network access.""" - -from gatehouse_app.extensions import db -from gatehouse_app.models.base import BaseModel -from gatehouse_app.utils.constants import ApprovalGrantType, ApprovalState - - -class UserNetworkApproval(BaseModel): - """A durable approval record binding a user to a portal network. - - This is the business-level approval — separate from any device and separate - from activation sessions. Manager approval survives across days and only - needs to be issued once unless explicitly revoked. - - Attributes: - organization_id: FK to the organization - user_id: FK to the approved user - portal_network_id: FK to the portal network - granted_by_user_id: FK to the manager who approved (null for system-assigned) - grant_type: requested (user-initiated) or assigned (manager-initiated) - state: pending / approved / rejected / revoked / suspended - justification: Business reason for the approval - """ - - __tablename__ = "user_network_approvals" - - organization_id = db.Column( - db.String(36), - db.ForeignKey("organizations.id"), - nullable=False, - index=True, - ) - user_id = db.Column( - db.String(36), - db.ForeignKey("users.id"), - nullable=False, - index=True, - ) - portal_network_id = db.Column( - db.String(36), - db.ForeignKey("portal_networks.id"), - nullable=False, - index=True, - ) - granted_by_user_id = db.Column( - db.String(36), - db.ForeignKey("users.id"), - nullable=True, - ) - grant_type = db.Column( - db.Enum(ApprovalGrantType, name="approval_grant_type", values_callable=lambda x: [e.value for e in x]), - default=ApprovalGrantType.REQUESTED, - nullable=False, - ) - state = db.Column( - db.Enum(ApprovalState, name="approval_state", values_callable=lambda x: [e.value for e in x]), - default=ApprovalState.PENDING, - nullable=False, - index=True, - ) - justification = db.Column(db.Text, nullable=True) - - # Relationships - organization = db.relationship("Organization", backref="network_approvals") - user = db.relationship( - "User", - foreign_keys=[user_id], - backref="network_approvals", - ) - granted_by = db.relationship( - "User", - foreign_keys=[granted_by_user_id], - backref="granted_approvals", - ) - portal_network = db.relationship( - "PortalNetwork", - back_populates="approvals", - ) - memberships = db.relationship( - "DeviceNetworkMembership", - back_populates="approval", - cascade="all, delete-orphan", - ) - - __table_args__ = ( - db.UniqueConstraint( - "user_id", - "portal_network_id", - "deleted_at", - name="uix_user_network_approval", - ), - ) - - def __repr__(self): - return ( - f"" - ) - - def to_dict(self, exclude=None): - exclude = exclude or [] - data = super().to_dict(exclude=exclude) - data["active_membership_count"] = sum( - 1 for m in self.memberships if m.deleted_at is None - ) - return data diff --git a/gatehouse_app/models/zerotier/zerotier_membership.py b/gatehouse_app/models/zerotier/zerotier_membership.py index ad2b7e9..5c39e9d 100644 --- a/gatehouse_app/models/zerotier/zerotier_membership.py +++ b/gatehouse_app/models/zerotier/zerotier_membership.py @@ -15,7 +15,7 @@ class ZeroTierMembership(BaseModel): Attributes: organization_id: FK to the organization - device_network_membership_id: FK to the portal's membership record (nullable) + network_access_request_id: FK to the portal's access request record (nullable) zerotier_network_id: The 16-char hex ZeroTier network ID node_id: The 10-char hex ZeroTier node ID member_seen: Whether the controller has ever seen this member @@ -33,9 +33,9 @@ class ZeroTierMembership(BaseModel): nullable=False, index=True, ) - device_network_membership_id = db.Column( + network_access_request_id = db.Column( db.String(36), - db.ForeignKey("device_network_memberships.id"), + db.ForeignKey("network_access_requests.id"), nullable=True, index=True, ) @@ -57,8 +57,8 @@ class ZeroTierMembership(BaseModel): # Relationships organization = db.relationship("Organization", backref="zerotier_memberships") - device_network_membership = db.relationship( - "DeviceNetworkMembership", + access_request = db.relationship( + "NetworkAccessRequest", back_populates="zerotier_membership", ) diff --git a/gatehouse_app/services/device_service.py b/gatehouse_app/services/device_service.py index 573c8bf..706f134 100644 --- a/gatehouse_app/services/device_service.py +++ b/gatehouse_app/services/device_service.py @@ -167,10 +167,10 @@ def remove_device(device_id: str, user_id: str) -> None: raise DeviceNotFoundError("Device not found.") # Soft-delete all memberships (deactivates active ones first) - for membership in device.memberships: - if membership.deleted_at is None: - from gatehouse_app.services.network_access_service import revoke_membership_soft - revoke_membership_soft(membership.id, revoked_by_user_id=user_id) + for request in device.network_access_requests: + if request.deleted_at is None: + from gatehouse_app.services.network_access_service import revoke_request_soft + revoke_request_soft(request.id, revoker_user_id=user_id) device.delete(soft=True) @@ -180,7 +180,7 @@ def remove_device(device_id: str, user_id: str) -> None: organization_id=device.organization_id, resource_type="device", resource_id=device.id, - metadata={"node_id": device.node_id, "memberships_removed": len([m for m in device.memberships if m.deleted_at is None])}, + metadata={"node_id": device.node_id, "memberships_removed": len([m for m in device.network_access_requests if m.deleted_at is None])}, description=f"Device {device.node_id} removed", success=True, ) diff --git a/gatehouse_app/services/network_access_service.py b/gatehouse_app/services/network_access_service.py index 99b7132..58d7db5 100644 --- a/gatehouse_app/services/network_access_service.py +++ b/gatehouse_app/services/network_access_service.py @@ -10,8 +10,7 @@ from datetime import datetime, timedelta, timezone from gatehouse_app.extensions import db from gatehouse_app.models import ( Device, - DeviceNetworkMembership, - UserNetworkApproval, + NetworkAccessRequest, ActivationSession, ZeroTierMembership, KillSwitchEvent, @@ -23,7 +22,6 @@ from gatehouse_app.utils.constants import ( ApprovalGrantType, ApprovalState, ActivationEndReason, - MembershipState, KillSwitchScope, ) from gatehouse_app.exceptions import ( @@ -47,15 +45,17 @@ def request_access( portal_network_id: str, device_id: str, justification: str | None = None, -) -> UserNetworkApproval: +) -> NetworkAccessRequest: """Create a pending access request for a user's specific device to a network. - Creates a UserNetworkApproval and a DeviceNetworkMembership pinned to the device - in one transaction. For open-mode networks the approval is immediate; + For open-mode networks the approval is immediate; for approval_required networks it starts pending. """ network = _validate_org_network(org_id=organization_id, network_id=portal_network_id) + if network.request_mode.value == "invite_only": + raise ValidationError("This network is invite-only. Access can only be assigned by a manager.") + device = Device.query.filter( Device.id == device_id, Device.user_id == user_id, @@ -65,67 +65,59 @@ def request_access( if not device: raise DeviceNotFoundError(f"Device {device_id} not found or does not belong to this user.") - existing = UserNetworkApproval.query.filter( - UserNetworkApproval.user_id == user_id, - UserNetworkApproval.portal_network_id == portal_network_id, - UserNetworkApproval.deleted_at.is_(None), + is_open = network.request_mode.value == "open" + + existing = NetworkAccessRequest.query.filter( + NetworkAccessRequest.user_id == user_id, + NetworkAccessRequest.device_id == device_id, + NetworkAccessRequest.portal_network_id == portal_network_id, + NetworkAccessRequest.deleted_at.is_(None), ).first() + if existing: - if existing.state in (ApprovalState.APPROVED, ApprovalState.PENDING): + if existing.status in (ApprovalState.APPROVED, ApprovalState.PENDING, ApprovalState.SUSPENDED): raise ApprovalAlreadyExistsError( - "An access request or approval already exists for this user and network." + "An access request or approval already exists for this user, device, and network." ) - is_open = network.request_mode.value == "open" - existing.state = ApprovalState.APPROVED if is_open else ApprovalState.PENDING + existing.status = ApprovalState.APPROVED if is_open else ApprovalState.PENDING + existing.grant_type = ApprovalGrantType.REQUESTED existing.justification = justification + existing.active = False + existing.granted_by_user_id = None existing.save() - existing_membership = DeviceNetworkMembership.query.filter( - DeviceNetworkMembership.user_network_approval_id == existing.id, - DeviceNetworkMembership.device_id == device_id, - DeviceNetworkMembership.deleted_at.is_(None), - ).first() - if not existing_membership: - membership_state = MembershipState.APPROVED_INACTIVE if is_open else MembershipState.PENDING_DEVICE_REGISTRATION - membership = DeviceNetworkMembership( - organization_id=organization_id, - user_id=user_id, - device_id=device_id, - portal_network_id=portal_network_id, - user_network_approval_id=existing.id, - state=membership_state, - approved_for_activation=is_open, - ) - membership.save() - _ensure_zerotier_member(device.node_id, portal_network_id, authorized=False) + _ensure_zerotier_member(device.node_id, portal_network_id, authorized=False) + + AuditService.log_action( + action="zt.approval.reopened", + user_id=user_id, + organization_id=organization_id, + resource_type="network_access_request", + resource_id=existing.id, + metadata={ + "portal_network_id": portal_network_id, + "device_id": device_id, + "device_node_id": device.node_id, + "justification": justification, + "is_open_network": is_open, + }, + description=f"Network access request reopened for device {device.node_id}", + success=True, + ) return existing - is_open = network.request_mode.value == "open" - approval_state = ApprovalState.APPROVED if is_open else ApprovalState.PENDING - - approval = UserNetworkApproval( - organization_id=organization_id, - user_id=user_id, - portal_network_id=portal_network_id, - grant_type=ApprovalGrantType.REQUESTED, - state=approval_state, - justification=justification, - ) - approval.save() - - membership_state = MembershipState.APPROVED_INACTIVE if is_open else MembershipState.PENDING_DEVICE_REGISTRATION - - membership = DeviceNetworkMembership( + request = NetworkAccessRequest( organization_id=organization_id, user_id=user_id, device_id=device_id, portal_network_id=portal_network_id, - user_network_approval_id=approval.id, - state=membership_state, - approved_for_activation=is_open, + grant_type=ApprovalGrantType.REQUESTED, + status=ApprovalState.APPROVED if is_open else ApprovalState.PENDING, + active=False, + justification=justification, ) - membership.save() + request.save() _ensure_zerotier_member(device.node_id, portal_network_id, authorized=False) @@ -133,8 +125,8 @@ def request_access( action="zt.approval.requested", user_id=user_id, organization_id=organization_id, - resource_type="user_network_approval", - resource_id=approval.id, + resource_type="network_access_request", + resource_id=request.id, metadata={ "portal_network_id": portal_network_id, "device_id": device_id, @@ -146,24 +138,7 @@ def request_access( success=True, ) - if is_open: - AuditService.log_action( - action="zt.membership.created", - user_id=user_id, - organization_id=organization_id, - resource_type="device_network_membership", - resource_id=membership.id, - metadata={ - "device_id": device_id, - "device_node_id": device.node_id, - "portal_network_id": portal_network_id, - "source": "open_network_join", - }, - description=f"Device membership created (open network) for {device.node_id}", - success=True, - ) - - return approval + return request def assign_access( @@ -172,43 +147,70 @@ def assign_access( granted_by_user_id: str, organization_id: str, justification: str | None = None, -) -> UserNetworkApproval: - """Manager directly assigns access to a user (no approval needed).""" - network = _validate_org_network(org_id=organization_id, network_id=portal_network_id) +) -> NetworkAccessRequest: + """Manager directly assigns access to a user (no approval needed). - existing = UserNetworkApproval.query.filter( - UserNetworkApproval.user_id == target_user_id, - UserNetworkApproval.portal_network_id == portal_network_id, - UserNetworkApproval.deleted_at.is_(None), - ).first() - if existing: - if existing.state == ApprovalState.APPROVED: - return existing - existing.state = ApprovalState.APPROVED - existing.granted_by_user_id = granted_by_user_id - existing.justification = justification - existing.save() - _materialize_memberships_for_approval(existing) - return existing + Creates one NetworkAccessRequest per device for the target user. + Returns the first created or updated request for backward compatibility. + """ + _validate_org_network(org_id=organization_id, network_id=portal_network_id) - approval = UserNetworkApproval( - organization_id=organization_id, - user_id=target_user_id, - portal_network_id=portal_network_id, - granted_by_user_id=granted_by_user_id, - grant_type=ApprovalGrantType.ASSIGNED, - state=ApprovalState.APPROVED, - justification=justification, - ) - approval.save() - _materialize_memberships_for_approval(approval) + devices = Device.query.filter( + Device.user_id == target_user_id, + Device.organization_id == organization_id, + Device.deleted_at.is_(None), + ).all() + + if not devices: + raise DeviceNotFoundError(f"User {target_user_id} has no registered devices.") + + first_returned: NetworkAccessRequest | None = None + + for device in devices: + existing = NetworkAccessRequest.query.filter( + NetworkAccessRequest.user_id == target_user_id, + NetworkAccessRequest.device_id == device.id, + NetworkAccessRequest.portal_network_id == portal_network_id, + NetworkAccessRequest.deleted_at.is_(None), + ).first() + + if existing: + if existing.status == ApprovalState.APPROVED: + if first_returned is None: + first_returned = existing + continue + existing.status = ApprovalState.APPROVED + existing.grant_type = ApprovalGrantType.ASSIGNED + existing.granted_by_user_id = granted_by_user_id + existing.justification = justification + existing.active = False + existing.save() + if first_returned is None: + first_returned = existing + else: + req = NetworkAccessRequest( + organization_id=organization_id, + user_id=target_user_id, + device_id=device.id, + portal_network_id=portal_network_id, + granted_by_user_id=granted_by_user_id, + grant_type=ApprovalGrantType.ASSIGNED, + status=ApprovalState.APPROVED, + active=False, + justification=justification, + ) + req.save() + if first_returned is None: + first_returned = req + + _ensure_zerotier_member(device.node_id, portal_network_id, authorized=False) AuditService.log_action( action="zt.approval.granted", user_id=granted_by_user_id, organization_id=organization_id, - resource_type="user_network_approval", - resource_id=approval.id, + resource_type="network_access_request", + resource_id=first_returned.id if first_returned else None, metadata={ "target_user_id": target_user_id, "portal_network_id": portal_network_id, @@ -218,311 +220,210 @@ def assign_access( success=True, ) - return approval + return first_returned def approve_request( - approval_id: str, + request_id: str, approver_user_id: str, -) -> UserNetworkApproval: - """Approve a pending access request. Updates the pre-created device membership to approved_inactive.""" - approval = _get_approval(approval_id) +) -> NetworkAccessRequest: + """Approve a pending access request.""" + request = _get_request(request_id) - if approval.state != ApprovalState.PENDING: - raise ValidationError(f"Approval is not pending (current state: {approval.state.value}).") + if request.status != ApprovalState.PENDING: + raise ValidationError(f"Request is not pending (current status: {request.status.value}).") - approval.state = ApprovalState.APPROVED - approval.granted_by_user_id = approver_user_id - approval.save() - - membership = DeviceNetworkMembership.query.filter( - DeviceNetworkMembership.user_network_approval_id == approval_id, - DeviceNetworkMembership.deleted_at.is_(None), - ).first() - - if membership: - membership.state = MembershipState.APPROVED_INACTIVE - membership.approved_for_activation = True - membership.save() - else: - logger.warning(f"[approve_request] No pre-created membership found for approval {approval_id}") + request.status = ApprovalState.APPROVED + request.granted_by_user_id = approver_user_id + request.save() AuditService.log_action( action="zt.approval.granted", user_id=approver_user_id, - organization_id=approval.organization_id, - resource_type="user_network_approval", - resource_id=approval.id, - metadata={"target_user_id": approval.user_id, "grant_type": "requested"}, - description=f"Network access approved for user {approval.user_id}", + organization_id=request.organization_id, + resource_type="network_access_request", + resource_id=request.id, + metadata={"target_user_id": request.user_id, "grant_type": "requested"}, + description=f"Network access approved for user {request.user_id}", success=True, ) - return approval + return request def reject_request( - approval_id: str, + request_id: str, rejecter_user_id: str, -) -> UserNetworkApproval: - """Reject a pending access request and remove the pre-created device membership.""" - approval = _get_approval(approval_id) +) -> NetworkAccessRequest: + """Reject a pending access request and soft-delete it.""" + request = _get_request(request_id) - if approval.state != ApprovalState.PENDING: - raise ValidationError(f"Approval is not pending (current state: {approval.state.value}).") + if request.status != ApprovalState.PENDING: + raise ValidationError(f"Request is not pending (current status: {request.status.value}).") - membership = DeviceNetworkMembership.query.filter( - DeviceNetworkMembership.user_network_approval_id == approval_id, - DeviceNetworkMembership.deleted_at.is_(None), - ).first() - - if membership: - from datetime import datetime, timezone - membership.deleted_at = datetime.now(timezone.utc) - membership.save() - - approval.state = ApprovalState.REJECTED - approval.save() + request.status = ApprovalState.REJECTED + request.deleted_at = datetime.now(timezone.utc) + request.save() AuditService.log_action( action="zt.approval.rejected", user_id=rejecter_user_id, - organization_id=approval.organization_id, - resource_type="user_network_approval", - resource_id=approval.id, - metadata={"target_user_id": approval.user_id}, - description=f"Network access rejected for user {approval.user_id}", + organization_id=request.organization_id, + resource_type="network_access_request", + resource_id=request.id, + metadata={"target_user_id": request.user_id}, + description=f"Network access rejected for user {request.user_id}", success=True, ) - return approval + return request -def revoke_approval( - approval_id: str, +def revoke_access( + request_id: str, revoker_user_id: str, -) -> UserNetworkApproval: - """Revoke an approved access record and deactivate all related memberships.""" - approval = _get_approval(approval_id) +) -> NetworkAccessRequest: + """Revoke an approved access request and deactivate it.""" + request = _get_request(request_id) - approval.state = ApprovalState.REVOKED - approval.save() + if request.status not in (ApprovalState.APPROVED, ApprovalState.SUSPENDED): + raise ValidationError(f"Cannot revoke request in status {request.status.value}.") - # Deactivate all memberships - for membership in approval.memberships: - if membership.deleted_at is None: - deactivate_membership(membership.id, reason="approval_revoked") + request.status = ApprovalState.REVOKED + request.active = False + request.save() + + # End any active activation session + _end_active_session(request) # defaults to APPROVAL_REVOKED + + # Deauthorize in ZeroTier + device = Device.query.get(request.device_id) + network = PortalNetwork.query.get(request.portal_network_id) + if device and network: + try: + zt.deauthorize_member(network.zerotier_network_id, device.node_id, + organization_id=request.organization_id) + except Exception as exc: + logger.warning(f"[revoke_access] Could not deauthorize {device.node_id}: {exc}") AuditService.log_action( action="zt.approval.revoked", user_id=revoker_user_id, - organization_id=approval.organization_id, - resource_type="user_network_approval", - resource_id=approval.id, - metadata={"target_user_id": approval.user_id}, - description=f"Network access revoked for user {approval.user_id}", + organization_id=request.organization_id, + resource_type="network_access_request", + resource_id=request.id, + metadata={"target_user_id": request.user_id}, + description=f"Network access revoked for user {request.user_id}", success=True, ) - return approval + return request def list_pending_approvals( organization_id: str, network_id: str | None = None, -) -> list[UserNetworkApproval]: - """List pending approval requests for managers.""" - q = UserNetworkApproval.query.filter( - UserNetworkApproval.organization_id == organization_id, - UserNetworkApproval.state == ApprovalState.PENDING, - UserNetworkApproval.deleted_at.is_(None), +) -> list[NetworkAccessRequest]: + """List pending access requests for managers.""" + q = NetworkAccessRequest.query.filter( + NetworkAccessRequest.organization_id == organization_id, + NetworkAccessRequest.status == ApprovalState.PENDING, + NetworkAccessRequest.deleted_at.is_(None), ) if network_id: - q = q.filter(UserNetworkApproval.portal_network_id == network_id) + q = q.filter(NetworkAccessRequest.portal_network_id == network_id) return q.all() -def list_user_approvals(user_id: str, organization_id: str) -> list[UserNetworkApproval]: - """List all approval records for a user in an org.""" - return UserNetworkApproval.query.filter( - UserNetworkApproval.user_id == user_id, - UserNetworkApproval.organization_id == organization_id, - UserNetworkApproval.deleted_at.is_(None), +def list_user_requests(user_id: str, organization_id: str) -> list[NetworkAccessRequest]: + """List all access requests for a user.""" + return NetworkAccessRequest.query.filter( + NetworkAccessRequest.user_id == user_id, + NetworkAccessRequest.organization_id == organization_id, + NetworkAccessRequest.deleted_at.is_(None), ).all() -def list_all_org_approvals( +def list_all_org_requests( organization_id: str, network_id: str | None = None, state: str | None = None, -) -> list[UserNetworkApproval]: - """List all approval records across all users in an org (admin use).""" - q = UserNetworkApproval.query.filter( - UserNetworkApproval.organization_id == organization_id, - UserNetworkApproval.deleted_at.is_(None), +) -> list[NetworkAccessRequest]: + """List all access requests for an organization.""" + q = NetworkAccessRequest.query.filter( + NetworkAccessRequest.organization_id == organization_id, + NetworkAccessRequest.deleted_at.is_(None), ) if network_id: - q = q.filter(UserNetworkApproval.portal_network_id == network_id) + q = q.filter(NetworkAccessRequest.portal_network_id == network_id) if state: - q = q.filter(UserNetworkApproval.state == state) - return q.order_by(UserNetworkApproval.created_at.desc()).all() - - -# ── Membership materialisation ─────────────────────────────────────────────── - - -def _materialize_memberships_for_approval(approval: UserNetworkApproval) -> None: - """Create DeviceNetworkMembership records for all of a user's devices on a network.""" - devices = Device.query.filter( - Device.user_id == approval.user_id, - Device.organization_id == approval.organization_id, - Device.deleted_at.is_(None), - ).all() - - for device in devices: - existing = DeviceNetworkMembership.query.filter( - DeviceNetworkMembership.device_id == device.id, - DeviceNetworkMembership.portal_network_id == approval.portal_network_id, - DeviceNetworkMembership.deleted_at.is_(None), - ).first() - - if not existing: - membership = DeviceNetworkMembership( - organization_id=approval.organization_id, - user_id=approval.user_id, - device_id=device.id, - portal_network_id=approval.portal_network_id, - user_network_approval_id=approval.id, - state=MembershipState.APPROVED_INACTIVE, - approved_for_activation=True, - ) - membership.save() - - # Pre-provision the member in ZeroTier (de-authorized) - _ensure_zerotier_member(device.node_id, approval.portal_network_id, authorized=False) - - AuditService.log_action( - action="zt.membership.created", - user_id=approval.user_id, - organization_id=approval.organization_id, - resource_type="device_network_membership", - resource_id=membership.id, - metadata={ - "device_id": device.id, - "device_node_id": device.node_id, - "portal_network_id": approval.portal_network_id, - }, - description=f"Device membership created for network", - success=True, - ) - - -def materialize_device_memberships(device_id: str) -> list[DeviceNetworkMembership]: - """When a device is newly registered, create memberships for all approved networks.""" - device = Device.query.filter(Device.id == device_id, Device.deleted_at.is_(None)).first() - if not device: - raise DeviceNotFoundError(f"Device {device_id} not found.") - - created = [] - approvals = UserNetworkApproval.query.filter( - UserNetworkApproval.user_id == device.user_id, - UserNetworkApproval.organization_id == device.organization_id, - UserNetworkApproval.state == ApprovalState.APPROVED, - UserNetworkApproval.deleted_at.is_(None), - ).all() - - for approval in approvals: - existing = DeviceNetworkMembership.query.filter( - DeviceNetworkMembership.device_id == device_id, - DeviceNetworkMembership.portal_network_id == approval.portal_network_id, - DeviceNetworkMembership.deleted_at.is_(None), - ).first() - if existing: - continue - - membership = DeviceNetworkMembership( - organization_id=device.organization_id, - user_id=device.user_id, - device_id=device_id, - portal_network_id=approval.portal_network_id, - user_network_approval_id=approval.id, - state=MembershipState.APPROVED_INACTIVE, - approved_for_activation=True, - ) - membership.save() - _ensure_zerotier_member( - device.node_id, - approval.portal_network_id, - authorized=False, - ) - created.append(membership) - - return created + try: + state_enum = ApprovalState(state) + except ValueError: + raise ValidationError(f"Invalid state filter: {state}") + q = q.filter(NetworkAccessRequest.status == state_enum) + return q.order_by(NetworkAccessRequest.created_at.desc()).all() # ── Activation ─────────────────────────────────────────────────────────────── -def activate_device_membership( - membership_id: str, +def activate_request( + request_id: str, user_id: str, lifetime_minutes: int | None = None, admin_override: bool = False, ) -> ActivationSession: - """Activate an approved device on a network. Creates an activation session and authorizes in ZT.""" - membership = _get_membership(membership_id) + """Activate an approved network access request. Creates an ActivationSession.""" + request = _get_request(request_id) - if not admin_override and membership.user_id != user_id: - raise MembershipNotFoundError("Membership not found.") + if not admin_override and request.user_id != user_id: + raise ApprovalNotFoundError("Request not found.") - # Check approval is still active - if membership.user_network_approval_id: - approval = UserNetworkApproval.query.get(membership.user_network_approval_id) - if not approval or approval.state != ApprovalState.APPROVED: - raise ValidationError("Network access approval is not active.") + if request.status != ApprovalState.APPROVED: + raise ValidationError(f"Request is not approved (current status: {request.status.value}).") + + if request.active: + raise ValidationError("Request is already active.") # Determine lifetime - network = PortalNetwork.query.get(membership.portal_network_id) + network = PortalNetwork.query.get(request.portal_network_id) if lifetime_minutes is None: lifetime_minutes = network.default_activation_lifetime_minutes if network.max_activation_lifetime_minutes and lifetime_minutes > network.max_activation_lifetime_minutes: lifetime_minutes = network.max_activation_lifetime_minutes # End any existing active session - for session in membership.activation_sessions: - if session.ended_at is None: - _end_session(session, ActivationEndReason.MANUAL_REVOKE) + _end_active_session(request, reason=ActivationEndReason.LOGOUT) - # Create session now = datetime.now(timezone.utc) expires = now + timedelta(minutes=lifetime_minutes) session = ActivationSession( - organization_id=membership.organization_id, - user_id=membership.user_id, - device_network_membership_id=membership.id, + organization_id=request.organization_id, + user_id=request.user_id, + network_access_request_id=request.id, authenticated_at=now, expires_at=expires, created_by=user_id, ) session.save() - # Update membership state - membership.state = MembershipState.ACTIVE_AUTHORIZED - membership.currently_authorized = True - membership.save() + # Update request + request.active = True + request.save() # Authorize in ZeroTier - device = Device.query.get(membership.device_id) - _authorize_in_zerotier(device.node_id, network.zerotier_network_id, membership) + device = Device.query.get(request.device_id) + _authorize_in_zerotier(device.node_id, network.zerotier_network_id, request) AuditService.log_action( action="zt.membership.activated", user_id=user_id, - organization_id=membership.organization_id, + organization_id=request.organization_id, resource_type="activation_session", resource_id=session.id, metadata={ - "membership_id": membership.id, + "request_id": request.id, "device_node_id": device.node_id, "network_id": network.zerotier_network_id, "expires_at": expires.isoformat(), @@ -539,182 +440,138 @@ def activate_all_approved( organization_id: str, lifetime_minutes: int | None = None, ) -> list[ActivationSession]: - """Bulk-activate all approved inactive memberships for a user.""" - memberships = DeviceNetworkMembership.query.filter( - DeviceNetworkMembership.user_id == user_id, - DeviceNetworkMembership.organization_id == organization_id, - DeviceNetworkMembership.state == MembershipState.APPROVED_INACTIVE, - DeviceNetworkMembership.approved_for_activation.is_(True), - DeviceNetworkMembership.deleted_at.is_(None), + """Bulk-activate all approved inactive requests for a user.""" + requests = NetworkAccessRequest.query.filter( + NetworkAccessRequest.user_id == user_id, + NetworkAccessRequest.organization_id == organization_id, + NetworkAccessRequest.status == ApprovalState.APPROVED, + NetworkAccessRequest.active == False, + NetworkAccessRequest.deleted_at.is_(None), ).all() sessions = [] - for m in memberships: + for r in requests: try: - s = activate_device_membership(m.id, user_id, lifetime_minutes=lifetime_minutes) + s = activate_request(r.id, user_id, lifetime_minutes=lifetime_minutes) sessions.append(s) except Exception as exc: - logger.warning(f"[Activation] Failed to activate membership {m.id}: {exc}") + logger.warning(f"[Activation] Failed to activate request {r.id}: {exc}") return sessions -def deactivate_membership( - membership_id: str, +def deactivate_request( + request_id: str, reason: str, deactivated_by_user_id: str | None = None, -) -> DeviceNetworkMembership: - """Deactivate a device membership: end session, deauthorize in ZT, update state.""" - membership = _get_membership(membership_id) +) -> NetworkAccessRequest: + """Deactivate a network access request: end session, deauthorize in ZT.""" + request = _get_request(request_id) + + if not request.active: + raise ValidationError("Request is not active.") # End any active session - for session in membership.activation_sessions: - if session.ended_at is None: - end_reason = ActivationEndReason(reason) if reason in ActivationEndReason._value2member_map_ else ActivationEndReason.MANUAL_REVOKE - _end_session(session, end_reason) + _end_active_session(request) # Deauthorize in ZeroTier - device = Device.query.get(membership.device_id) - network = PortalNetwork.query.get(membership.portal_network_id) - _deauthorize_in_zerotier(device.node_id, network.zerotier_network_id, - organization_id=membership.organization_id) + device = Device.query.get(request.device_id) + network = PortalNetwork.query.get(request.portal_network_id) + if device and network: + _deauthorize_in_zerotier(device.node_id, network.zerotier_network_id, + organization_id=request.organization_id) - membership.state = MembershipState.APPROVED_INACTIVE - membership.currently_authorized = False - membership.save() + request.active = False + request.save() AuditService.log_action( action="zt.membership.deactivated", user_id=deactivated_by_user_id, - organization_id=membership.organization_id, - resource_type="device_network_membership", - resource_id=membership.id, + organization_id=request.organization_id, + resource_type="network_access_request", + resource_id=request.id, metadata={ "reason": reason, - "device_node_id": device.node_id, - "network_id": network.zerotier_network_id, + "device_id": request.device_id, + "portal_network_id": request.portal_network_id, }, - description=f"Device membership deactivated: {reason}", + description="Device membership deactivated", success=True, ) - return membership + return request # ── Kill switch ─────────────────────────────────────────────────────────────── def kill_switch( - target_user_id: str, - triggered_by_user_id: str, - scope: str, - organization_id: str | None = None, - reason: str | None = None, + user_id: str, + org_id: str | None = None, network_ids: list[str] | None = None, -) -> KillSwitchEvent: - """Immediately deauthorize all active memberships for a user.""" - scope_enum = KillSwitchScope(scope) if scope in KillSwitchScope._value2member_map_ else KillSwitchScope.ORGANIZATION - - q = DeviceNetworkMembership.query.filter( - DeviceNetworkMembership.user_id == target_user_id, - DeviceNetworkMembership.state == MembershipState.ACTIVE_AUTHORIZED, - DeviceNetworkMembership.deleted_at.is_(None), + scope: KillSwitchScope = KillSwitchScope.ORGANIZATION, +) -> int: + """Emergency kill switch: deactivate all active requests for a user.""" + q = NetworkAccessRequest.query.filter( + NetworkAccessRequest.user_id == user_id, + NetworkAccessRequest.active == True, + NetworkAccessRequest.deleted_at.is_(None), ) + if org_id: + q = q.filter(NetworkAccessRequest.organization_id == org_id) + if network_ids: + q = q.filter(NetworkAccessRequest.portal_network_id.in_(network_ids)) - org_id = organization_id # Use caller-supplied org_id as the primary source - if scope_enum == KillSwitchScope.ORGANIZATION: - if not org_id: - # Fall back to deriving from first active membership - first = q.first() - org_id = first.organization_id if first else None - else: - # Scope query to the specified org - q = q.filter(DeviceNetworkMembership.organization_id == org_id) - elif scope_enum == KillSwitchScope.SELECTED_NETWORKS and network_ids: - q = q.filter(DeviceNetworkMembership.portal_network_id.in_(network_ids)) - if not org_id: - first_network = PortalNetwork.query.filter( - PortalNetwork.id.in_(network_ids), - PortalNetwork.deleted_at.is_(None), - ).first() - org_id = first_network.organization_id if first_network else None + requests = q.all() + count = 0 - if not org_id: - raise ValidationError("Cannot determine organization for kill switch event.") + for r in requests: + # End active session + _end_active_session(r, reason=ActivationEndReason.KILL_SWITCH) - # Create kill switch event - event = KillSwitchEvent( - organization_id=org_id or "", - target_user_id=target_user_id, - scope=scope_enum, - triggered_by_user_id=triggered_by_user_id, - reason=reason, - network_ids=network_ids, - ) - event.save() + # Deauthorize in ZT + device = Device.query.get(r.device_id) + network = PortalNetwork.query.get(r.portal_network_id) + if device and network: + try: + zt.deauthorize_member(network.zerotier_network_id, device.node_id, + organization_id=r.organization_id) + except Exception as exc: + logger.warning(f"[kill_switch] Could not deauthorize {device.node_id}: {exc}") - # Suspend all approvals - approvals = UserNetworkApproval.query.filter( - UserNetworkApproval.user_id == target_user_id, - UserNetworkApproval.state == ApprovalState.APPROVED, - UserNetworkApproval.deleted_at.is_(None), - ).all() - for approval in approvals: - if scope_enum == KillSwitchScope.ORGANIZATION and org_id: - if approval.organization_id != org_id: - continue - elif scope_enum == KillSwitchScope.SELECTED_NETWORKS and network_ids: - if approval.portal_network_id not in network_ids: - continue - approval.state = ApprovalState.SUSPENDED - approval.save() - - # Deactivate memberships - memberships = q.all() - for membership in memberships: - deactivate_membership(membership.id, reason="kill_switch") + # Update request + r.active = False + if r.status == ApprovalState.APPROVED: + r.status = ApprovalState.SUSPENDED + r.save() + count += 1 + # Log audit AuditService.log_action( - action="zt.kill_switch.triggered", - user_id=triggered_by_user_id, + action="zt.kill_switch.activated", + user_id=user_id, organization_id=org_id, - resource_type="kill_switch_event", - resource_id=event.id, - metadata={ - "target_user_id": target_user_id, - "scope": scope, - "reason": reason, - "network_ids": network_ids, - "memberships_deactivated": len(memberships), - }, - description=f"Kill switch triggered for user {target_user_id}: {len(memberships)} memberships deactivated", + resource_type="network_access_request", + metadata={"scope": scope.value, "affected_count": count, "network_ids": network_ids}, + description=f"Kill switch activated: {count} requests deactivated", success=True, ) - return event + return count # ── Helpers ──────────────────────────────────────────────────────────────────── -def _get_approval(approval_id: str) -> UserNetworkApproval: - approval = UserNetworkApproval.query.filter( - UserNetworkApproval.id == approval_id, - UserNetworkApproval.deleted_at.is_(None), +def _get_request(request_id: str) -> NetworkAccessRequest: + """Get a non-deleted network access request by ID.""" + request = NetworkAccessRequest.query.filter( + NetworkAccessRequest.id == request_id, + NetworkAccessRequest.deleted_at.is_(None), ).first() - if not approval: - raise ApprovalNotFoundError(f"Approval {approval_id} not found.") - return approval - - -def _get_membership(membership_id: str) -> DeviceNetworkMembership: - membership = DeviceNetworkMembership.query.filter( - DeviceNetworkMembership.id == membership_id, - DeviceNetworkMembership.deleted_at.is_(None), - ).first() - if not membership: - raise MembershipNotFoundError(f"Membership {membership_id} not found.") - return membership + if not request: + raise ApprovalNotFoundError(f"Network access request {request_id} not found.") + return request def _validate_org_network(org_id: str, network_id: str) -> PortalNetwork: @@ -751,11 +608,11 @@ def _ensure_zerotier_member( def _authorize_in_zerotier( node_id: str, zerotier_network_id: str, - membership: DeviceNetworkMembership, + request: NetworkAccessRequest, ) -> None: try: zt.authorize_member(zerotier_network_id, node_id, - organization_id=membership.organization_id) + organization_id=request.organization_id) # Update zerotier_membership cache zt_membership = ZeroTierMembership.query.filter( @@ -769,8 +626,8 @@ def _authorize_in_zerotier( zt_membership.save() else: zt_membership = ZeroTierMembership( - organization_id=membership.organization_id, - device_network_membership_id=membership.id, + organization_id=request.organization_id, + network_access_request_id=request.id, zerotier_network_id=zerotier_network_id, node_id=node_id, authorized=True, @@ -781,8 +638,8 @@ def _authorize_in_zerotier( AuditService.log_action( action="zt.member.authorized", - user_id=membership.user_id, - organization_id=membership.organization_id, + user_id=request.user_id, + organization_id=request.organization_id, resource_type="zerotier_membership", resource_id=zt_membership.id, metadata={"node_id": node_id, "network_id": zerotier_network_id}, @@ -843,6 +700,18 @@ def _end_session(session: ActivationSession, reason: ActivationEndReason) -> Non session.save() +def _end_active_session(request: NetworkAccessRequest, reason: ActivationEndReason = ActivationEndReason.APPROVAL_REVOKED) -> None: + """End any active activation session for a network access request.""" + session = ActivationSession.query.filter( + ActivationSession.network_access_request_id == request.id, + ActivationSession.ended_at.is_(None), + ).first() + if session: + session.ended_at = datetime.now(timezone.utc) + session.end_reason = reason + session.save() + + # ── Open network join ────────────────────────────────────────────────────────── @@ -851,216 +720,143 @@ def join_network_for_device( organization_id: str, device_id: str, portal_network_id: str, -) -> DeviceNetworkMembership: - """Join an open network with a specific registered device. - - Creates an immediately-approved UserNetworkApproval and DeviceNetworkMembership - in approved_inactive state. User can then activate. - """ + admin_override: bool = False, + granted_by_user_id: str | None = None, +) -> NetworkAccessRequest: + """Direct join for open networks or admin override. Creates an immediately-approved request.""" network = _validate_org_network(org_id=organization_id, network_id=portal_network_id) - if network.request_mode.value != "open": - raise ValidationError("Network does not support direct join. Use request_access instead.") + if not admin_override and network.request_mode.value != "open": + raise ValidationError("Network is not open. Use request_access() instead.") + + # For admin override, don't filter by user_id - admin can join any device in the org + if admin_override: + device = Device.query.filter( + Device.id == device_id, + Device.organization_id == organization_id, + Device.deleted_at.is_(None), + ).first() + # Use the device owner's user_id for the request + if device: + user_id = device.user_id + else: + device = Device.query.filter( + Device.id == device_id, + Device.user_id == user_id, + Device.organization_id == organization_id, + Device.deleted_at.is_(None), + ).first() - device = Device.query.filter( - Device.id == device_id, - Device.user_id == user_id, - Device.organization_id == organization_id, - Device.deleted_at.is_(None), - ).first() if not device: - raise DeviceNotFoundError(f"Device {device_id} not found or does not belong to this user.") + raise DeviceNotFoundError(f"Device {device_id} not found.") - existing = DeviceNetworkMembership.query.filter( - DeviceNetworkMembership.device_id == device_id, - DeviceNetworkMembership.portal_network_id == portal_network_id, - DeviceNetworkMembership.deleted_at.is_(None), + # Check for existing request + existing = NetworkAccessRequest.query.filter( + NetworkAccessRequest.user_id == user_id, + NetworkAccessRequest.device_id == device_id, + NetworkAccessRequest.portal_network_id == portal_network_id, + NetworkAccessRequest.deleted_at.is_(None), ).first() + if existing: - raise ValidationError("Device already has a membership for this network.") + if existing.status in (ApprovalState.APPROVED, ApprovalState.PENDING): + raise ApprovalAlreadyExistsError("Already have access or pending request.") + # Re-open + existing.status = ApprovalState.APPROVED + existing.active = False + if admin_override: + existing.grant_type = ApprovalGrantType.ASSIGNED + existing.granted_by_user_id = granted_by_user_id + existing.save() + _ensure_zerotier_member(device.node_id, portal_network_id, authorized=False) + return existing - approval = UserNetworkApproval( - organization_id=organization_id, - user_id=user_id, - portal_network_id=portal_network_id, - grant_type=ApprovalGrantType.REQUESTED, - state=ApprovalState.APPROVED, - justification="Direct join (open network)", - ) - approval.save() - - membership = DeviceNetworkMembership( + request = NetworkAccessRequest( organization_id=organization_id, user_id=user_id, device_id=device_id, portal_network_id=portal_network_id, - user_network_approval_id=approval.id, - state=MembershipState.APPROVED_INACTIVE, - approved_for_activation=True, + grant_type=ApprovalGrantType.ASSIGNED if admin_override else ApprovalGrantType.REQUESTED, + status=ApprovalState.APPROVED, + active=False, + granted_by_user_id=granted_by_user_id if admin_override else None, ) - membership.save() - + request.save() _ensure_zerotier_member(device.node_id, portal_network_id, authorized=False) AuditService.log_action( action="zt.membership.created", user_id=user_id, organization_id=organization_id, - resource_type="device_network_membership", - resource_id=membership.id, - metadata={ - "device_id": device_id, - "device_node_id": device.node_id, - "portal_network_id": portal_network_id, - "source": "open_network_join", - }, - description=f"Device membership created (direct join) for {device.node_id}", + resource_type="network_access_request", + resource_id=request.id, + metadata={"device_id": device_id, "device_node_id": device.node_id, "source": "admin_override" if admin_override else "open_network_join"}, + description=f"Direct join for device {device.node_id}" + (" (admin override)" if admin_override else ""), success=True, ) - return membership + return request # ── Admin membership management ──────────────────────────────────────────────── -def get_all_memberships_with_details(organization_id: str) -> list[dict]: - """Return all memberships for an org with enriched user/device/network info. - - Used by managers to see every device membership across all users and networks. - Returns a list of plain dicts (not model objects) for easy serialisation. - """ - memberships = DeviceNetworkMembership.query.filter( - DeviceNetworkMembership.organization_id == organization_id, - DeviceNetworkMembership.deleted_at.is_(None), +def get_all_requests_with_details(organization_id: str) -> list[dict]: + """List all access requests with details for an organization.""" + requests = NetworkAccessRequest.query.filter( + NetworkAccessRequest.organization_id == organization_id, + NetworkAccessRequest.deleted_at.is_(None), ).all() result = [] - for m in memberships: - device = Device.query.get(m.device_id) - network = PortalNetwork.query.get(m.portal_network_id) - approval = UserNetworkApproval.query.get(m.user_network_approval_id) if m.user_network_approval_id else None - - active_session = None - for sess in m.activation_sessions: - if sess.ended_at is None and sess.deleted_at is None: - active_session = sess.to_dict() - break - - from gatehouse_app.models.user.user import User - user = User.query.get(m.user_id) - - result.append({ - "id": m.id, - "user_id": m.user_id, - "user_email": user.email if user else m.user_id, - "user_full_name": user.full_name if user else None, - "device_id": m.device_id, - "device_nickname": device.device_nickname if device else None, - "device_hostname": device.hostname if device else None, - "device_node_id": device.node_id if device else None, - "device_status": device.status.value if device and device.status else None, - "portal_network_id": m.portal_network_id, - "network_name": network.name if network else m.portal_network_id, - "network_environment": network.environment.value if network and network.environment else None, - "state": m.state.value if m.state else None, - "join_seen": m.join_seen, - "currently_authorized": m.currently_authorized, - "approved_for_activation": m.approved_for_activation, - "user_network_approval_id": m.user_network_approval_id, - "approval_state": approval.state.value if approval and approval.state else None, - "active_session": active_session, - "created_at": m.created_at.isoformat() if m.created_at else None, - "updated_at": m.updated_at.isoformat() if m.updated_at else None, - }) + for r in requests: + d = r.to_dict() + d["device_node_id"] = r.device.node_id if r.device else None + result.append(d) return result -def revoke_membership_soft( - membership_id: str, - revoked_by_user_id: str | None = None, -) -> DeviceNetworkMembership: - """Soft-delete a membership (user or admin initiated). Sets deleted_at. +def revoke_request_soft( + request_id: str, + revoker_user_id: str, +) -> NetworkAccessRequest: + """Soft-delete a network access request.""" + request = _get_request(request_id) - The membership is marked deleted and the ZeroTier member will be removed - by the reconciliation job. - """ - membership = _get_membership(membership_id) + # End active session and deactivate + if request.active: + deactivate_request(request_id, reason="manual_revoke", deactivated_by_user_id=revoker_user_id) - for session in membership.activation_sessions: - if session.ended_at is None: - _end_session(session, ActivationEndReason.MANUAL_REVOKE) - - device = Device.query.get(membership.device_id) - network = PortalNetwork.query.get(membership.portal_network_id) - - if device and network: - try: - zt.deauthorize_member(network.zerotier_network_id, device.node_id, - organization_id=membership.organization_id) - except Exception as exc: - logger.warning(f"[revoke_membership_soft] ZT deauthorize failed for {device.node_id}: {exc}") - - membership.currently_authorized = False - membership.deleted_at = datetime.now(timezone.utc) - membership.save() + request.deleted_at = datetime.now(timezone.utc) + request.save() AuditService.log_action( - action="zt.membership.revoked", - user_id=revoked_by_user_id, - organization_id=membership.organization_id, - resource_type="device_network_membership", - resource_id=membership.id, - metadata={ - "device_node_id": device.node_id if device else None, - "network_id": network.zerotier_network_id if network else None, - }, - description=f"Membership revoked for device {device.node_id if device else membership.device_id}", + action="zt.request.revoked", + user_id=revoker_user_id, + organization_id=request.organization_id, + resource_type="network_access_request", + resource_id=request.id, + metadata={"target_user_id": request.user_id}, + description=f"Network access request revoked for user {request.user_id}", success=True, ) - return membership + return request -def hard_delete_membership(membership_id: str) -> None: - """Soft delete a membership after ZeroTier has been cleaned up. - - Called by the reconciliation job after successfully removing the member - from the ZeroTier controller. This marks the membership as deleted. - """ - membership = DeviceNetworkMembership.query.filter( - DeviceNetworkMembership.id == membership_id, +def hard_delete_request( + request_id: str, +) -> None: + """Hard-delete a soft-deleted network access request.""" + request = NetworkAccessRequest.query.filter( + NetworkAccessRequest.id == request_id, ).first() + if not request: + raise ApprovalNotFoundError(f"Request {request_id} not found.") - if not membership: - logger.warning(f"[hard_delete_membership] Membership {membership_id} not found or already deleted, skipping.") - return + if request.deleted_at is None: + raise ValidationError("Cannot hard-delete a non-soft-deleted request.") - device = Device.query.get(membership.device_id) - network = PortalNetwork.query.get(membership.portal_network_id) - - if device and network: - try: - zt.delete_network_member(network.zerotier_network_id, device.node_id, - organization_id=membership.organization_id) - logger.info(f"[hard_delete_membership] Deleted {device.node_id} from ZT network {network.zerotier_network_id}") - except Exception as exc: - logger.warning(f"[hard_delete_membership] ZT delete failed for {device.node_id}: {exc}") - - membership.delete(soft=True) + db.session.delete(request) db.session.commit() - - AuditService.log_action( - action="zt.membership.deleted", - user_id=None, - organization_id=membership.organization_id, - resource_type="device_network_membership", - resource_id=membership_id, - metadata={ - "device_node_id": device.node_id if device else None, - "network_id": network.zerotier_network_id if network else None, - }, - description=f"Membership deleted: device {device.node_id if device else 'unknown'} from network", - success=True, - ) diff --git a/gatehouse_app/services/portal_network_service.py b/gatehouse_app/services/portal_network_service.py index 0e2a140..4be9e31 100644 --- a/gatehouse_app/services/portal_network_service.py +++ b/gatehouse_app/services/portal_network_service.py @@ -262,47 +262,33 @@ def update_network( def delete_network(network_id: str, user_id: str) -> None: """Soft-delete a portal network and deactivate/clean up all related records.""" from datetime import datetime, timezone - from gatehouse_app.models import UserNetworkApproval from gatehouse_app.extensions import db network = get_network(network_id) # Deauthorize all active memberships in ZeroTier - for membership in network.memberships: - if membership.deleted_at is None and membership.state.value == "active_authorized": - from gatehouse_app.services.network_access_service import deactivate_membership - deactivate_membership(membership.id, reason="network_deleted") + for request in network.access_requests: + if request.deleted_at is None and request.active: + from gatehouse_app.services.network_access_service import deactivate_request + deactivate_request(request.id, reason="network_deleted") network.delete(soft=True) - # Cascade soft-delete all active approvals and memberships for this network. + # Cascade soft-delete all active access requests for this network. now = datetime.now(timezone.utc) db.session.execute( db.text( - "UPDATE user_network_approvals AS a " + "UPDATE network_access_requests AS a " "SET deleted_at = :now + (s.rn * interval '1 microsecond') " "FROM (" " SELECT id, row_number() OVER () AS rn " - " FROM user_network_approvals " + " FROM network_access_requests " " WHERE portal_network_id = :network_id AND deleted_at IS NULL" ") s " "WHERE a.id = s.id" ), {"now": now, "network_id": network_id}, ) - db.session.execute( - db.text( - "UPDATE device_network_memberships AS m " - "SET deleted_at = :now + (s.rn * interval '1 microsecond') " - "FROM (" - " SELECT id, row_number() OVER () AS rn " - " FROM device_network_memberships " - " WHERE portal_network_id = :network_id AND deleted_at IS NULL" - ") s " - "WHERE m.id = s.id" - ), - {"now": now, "network_id": network_id}, - ) db.session.commit() AuditService.log_action( @@ -318,22 +304,25 @@ def delete_network(network_id: str, user_id: str) -> None: def get_network_members(network_id: str) -> list: - """Return all DeviceNetworkMemberships for a network with user and device info.""" - from gatehouse_app.models import DeviceNetworkMembership + """Return all approved and active NetworkAccessRequests for a network.""" + from gatehouse_app.models import NetworkAccessRequest + from gatehouse_app.utils.constants import ApprovalState - return DeviceNetworkMembership.query.filter( - DeviceNetworkMembership.portal_network_id == network_id, - DeviceNetworkMembership.deleted_at.is_(None), + return NetworkAccessRequest.query.filter( + NetworkAccessRequest.portal_network_id == network_id, + NetworkAccessRequest.status == ApprovalState.APPROVED, + NetworkAccessRequest.active == True, + NetworkAccessRequest.deleted_at.is_(None), ).all() def get_network_pending_requests(network_id: str) -> list: - """Return pending UserNetworkApprovals for a network.""" - from gatehouse_app.models import UserNetworkApproval + """Return pending NetworkAccessRequests for a network.""" + from gatehouse_app.models import NetworkAccessRequest from gatehouse_app.utils.constants import ApprovalState - return UserNetworkApproval.query.filter( - UserNetworkApproval.portal_network_id == network_id, - UserNetworkApproval.state == ApprovalState.PENDING, - UserNetworkApproval.deleted_at.is_(None), + return NetworkAccessRequest.query.filter( + NetworkAccessRequest.portal_network_id == network_id, + NetworkAccessRequest.status == ApprovalState.PENDING, + NetworkAccessRequest.deleted_at.is_(None), ).all() diff --git a/gatehouse_app/services/zerotier_reconciliation_service.py b/gatehouse_app/services/zerotier_reconciliation_service.py index 9238018..32fe366 100644 --- a/gatehouse_app/services/zerotier_reconciliation_service.py +++ b/gatehouse_app/services/zerotier_reconciliation_service.py @@ -7,16 +7,14 @@ from datetime import datetime, timezone from gatehouse_app.extensions import db from gatehouse_app.models import ( Device, - DeviceNetworkMembership, + NetworkAccessRequest, ActivationSession, ZeroTierMembership, PortalNetwork, - UserNetworkApproval, ) from gatehouse_app.services import zerotier_api_service as zt from gatehouse_app.utils.constants import ( ActivationEndReason, - MembershipState, ApprovalState, ) @@ -45,7 +43,7 @@ def reconcile_expired_activations() -> int: except Exception as exc: logger.error( f"[Reconciliation] Failed to expire session {session.id} " - f"(user={session.user_id} membership={session.device_network_membership_id}): {exc}", + f"(user={session.user_id} request={session.network_access_request_id}): {exc}", exc_info=True, ) @@ -104,9 +102,9 @@ def reconcile_network(portal_network_id: str) -> dict: # Get our portal memberships for this network our_memberships = { m.device.node_id: m - for m in DeviceNetworkMembership.query.filter( - DeviceNetworkMembership.portal_network_id == portal_network_id, - DeviceNetworkMembership.deleted_at.is_(None), + for m in NetworkAccessRequest.query.filter( + NetworkAccessRequest.portal_network_id == portal_network_id, + NetworkAccessRequest.deleted_at.is_(None), ).all() if m.device and m.device.deleted_at is None } @@ -124,7 +122,7 @@ def reconcile_network(portal_network_id: str) -> dict: # Member not seen in ZT yet — could be freshly joined or never connected logger.debug( f"[Reconciliation] {network_label}: node {node_id} " - f"(device={device.display_name!r}, state={membership.state}) not yet seen in ZT controller." + f"(device={device.display_name!r}, active={membership.active}) not yet seen in ZT controller." ) continue @@ -134,11 +132,11 @@ def reconcile_network(portal_network_id: str) -> dict: _sync_zt_membership(membership, zt_member) # Sync authorization state - if membership.state == MembershipState.ACTIVE_AUTHORIZED: + if membership.active: if not zt_member.is_authorized: # Portal says active but ZT disagrees — drift, re-authorize logger.warning( - f"[Reconciliation] {network_label}: DRIFT detected — portal=ACTIVE_AUTHORIZED " + f"[Reconciliation] {network_label}: DRIFT detected — portal=active " f"but ZT says unauthorized for node {node_id} (device={device.display_name!r}). Re-authorizing." ) try: @@ -154,13 +152,13 @@ def reconcile_network(portal_network_id: str) -> dict: ) else: logger.debug( - f"[Reconciliation] {network_label}: node {node_id} — portal=ACTIVE_AUTHORIZED, ZT=authorized. OK." + f"[Reconciliation] {network_label}: node {node_id} — portal=active, ZT=authorized. OK." ) else: if zt_member.is_authorized: # ZT says authorized but portal doesn't — could be manual override in ZT console logger.warning( - f"[Reconciliation] {network_label}: DRIFT detected — portal state={membership.state} " + f"[Reconciliation] {network_label}: DRIFT detected — portal=inactive " f"but ZT says authorized for node {node_id} (device={device.display_name!r}). Deauthorizing." ) try: @@ -177,7 +175,7 @@ def reconcile_network(portal_network_id: str) -> dict: else: logger.debug( f"[Reconciliation] {network_label}: node {node_id} — " - f"portal={membership.state}, ZT=unauthorized. OK." + f"portal=inactive, ZT=unauthorized. OK." ) # Unknown ZT members not in our portal — log only, do not touch @@ -261,11 +259,11 @@ def reconcile_deleted_memberships() -> dict: """Find soft-deleted memberships and hard-delete them after ZeroTier cleanup. Only processes memberships whose ZeroTier members are already de-authorized - (the de-authorize step happened in revoke_membership_soft). This function + (the de-authorize step happened in revoke_request_soft). This function removes the member from ZeroTier entirely and then hard-deletes the DB record. """ - deleted = DeviceNetworkMembership.query.filter( - DeviceNetworkMembership.deleted_at.isnot(None), + deleted = NetworkAccessRequest.query.filter( + NetworkAccessRequest.deleted_at.isnot(None), ).all() if not deleted: @@ -328,7 +326,7 @@ def reconcile_deleted_memberships() -> dict: return results -def _sync_zt_membership(membership: DeviceNetworkMembership, zt_member) -> None: +def _sync_zt_membership(membership: NetworkAccessRequest, zt_member) -> None: """Update the ZeroTierMembership cache record from a ZT API response.""" device = membership.device network = membership.portal_network @@ -347,7 +345,7 @@ def _sync_zt_membership(membership: DeviceNetworkMembership, zt_member) -> None: ) zt_membership = ZeroTierMembership( organization_id=membership.organization_id, - device_network_membership_id=membership.id, + network_access_request_id=membership.id, zerotier_network_id=network.zerotier_network_id, node_id=device.node_id, ) @@ -377,10 +375,10 @@ def _sync_zt_membership(membership: DeviceNetworkMembership, zt_member) -> None: logger.info( f"[Reconciliation] First join seen for node {device.node_id} " f"(device={device.display_name!r}, membership={membership.id}). " - f"State: {membership.state} → {MembershipState.JOINED_DEAUTHORIZED}" + f"Setting join_seen=True, active=False" ) membership.join_seen = True - membership.state = MembershipState.JOINED_DEAUTHORIZED + membership.active = False membership.save() else: logger.debug( @@ -397,23 +395,22 @@ def _expire_session(session: ActivationSession) -> None: logger.info( f"[Reconciliation] Expiring activation session {session.id} " - f"(user={session.user_id}, membership={session.device_network_membership_id}, " + f"(user={session.user_id}, request={session.network_access_request_id}, " f"expired_at={session.expires_at.isoformat()})." ) - membership = DeviceNetworkMembership.query.get(session.device_network_membership_id) - if not membership: + request = NetworkAccessRequest.query.get(session.network_access_request_id) + if not request: logger.warning( - f"[Reconciliation] Session {session.id}: membership " - f"{session.device_network_membership_id} not found — skipping ZT deauth." + f"[Reconciliation] Session {session.id}: request " + f"{session.network_access_request_id} not found — skipping ZT deauth." ) else: - membership.state = MembershipState.ACTIVATION_EXPIRED - membership.currently_authorized = False - membership.save() + request.active = False + request.save() - device = Device.query.get(membership.device_id) - network = PortalNetwork.query.get(membership.portal_network_id) + device = Device.query.get(request.device_id) + network = PortalNetwork.query.get(request.portal_network_id) if device and network: network_label = f"{network.name} ({network.zerotier_network_id})" try: @@ -449,8 +446,8 @@ def _expire_session(session: ActivationSession) -> None: else: logger.warning( f"[Reconciliation] Session {session.id}: missing " - f"{'device' if not device else 'network'} for membership " - f"{membership.id} — ZT deauth skipped." + f"{'device' if not device else 'network'} for request " + f"{request.id} — ZT deauth skipped." ) from gatehouse_app.services.audit_service import AuditService @@ -460,7 +457,7 @@ def _expire_session(session: ActivationSession) -> None: organization_id=session.organization_id, resource_type="activation_session", resource_id=session.id, - metadata={"membership_id": session.device_network_membership_id}, + metadata={"request_id": session.network_access_request_id}, description="Activation session expired", success=True, ) diff --git a/gatehouse_app/utils/constants.py b/gatehouse_app/utils/constants.py index 1d600b7..adbb669 100644 --- a/gatehouse_app/utils/constants.py +++ b/gatehouse_app/utils/constants.py @@ -253,21 +253,6 @@ class ApprovalState(str, Enum): SUSPENDED = "suspended" -class MembershipState(str, Enum): - """State of a device network membership record.""" - - PENDING_DEVICE_REGISTRATION = "pending_device_registration" - PENDING_REQUEST = "pending_request" - PENDING_MANAGER_APPROVAL = "pending_manager_approval" - APPROVED_INACTIVE = "approved_inactive" - JOINED_DEAUTHORIZED = "joined_deauthorized" - ACTIVE_AUTHORIZED = "active_authorized" - ACTIVATION_EXPIRED = "activation_expired" - SUSPENDED = "suspended" - REVOKED = "revoked" - REJECTED = "rejected" - - class ActivationEndReason(str, Enum): """Why an activation session ended.""" diff --git a/migrations/versions/d1e2f3g4h5i6_remove_sudo_and_api_keys.py b/migrations/versions/d1e2f3g4h5i6_remove_sudo_and_api_keys.py new file mode 100644 index 0000000..aab9efa --- /dev/null +++ b/migrations/versions/d1e2f3g4h5i6_remove_sudo_and_api_keys.py @@ -0,0 +1,71 @@ +"""Remove sudo: drop can_sudo column and organization_api_keys table. + +Revision ID: d1e2f3g4h5i6 +Revises: c0a1b2c3d4e5 +Create Date: 2026-05-03 10:20:00.000000 +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'd1e2f3g4h5i6' +down_revision = 'c0a1b2c3d4e5' +branch_labels = None +depends_on = None + + +def upgrade(): + # ------------------------------------------------------------------ + # Step 1: Drop organization_api_keys table and all its indexes + # ------------------------------------------------------------------ + op.drop_index('idx_api_key_last_used', table_name='organization_api_keys') + op.drop_index('idx_org_api_key_org_active', table_name='organization_api_keys') + op.drop_index(op.f('ix_organization_api_keys_is_revoked'), table_name='organization_api_keys') + op.drop_index(op.f('ix_organization_api_keys_key_hash'), table_name='organization_api_keys') + op.drop_index(op.f('ix_organization_api_keys_organization_id'), table_name='organization_api_keys') + op.drop_table('organization_api_keys') + + # ------------------------------------------------------------------ + # Step 2: Drop can_sudo column from departments table + # ------------------------------------------------------------------ + op.drop_column('departments', 'can_sudo') + + +def downgrade(): + # ------------------------------------------------------------------ + # Step 1: Recreate can_sudo column in departments table + # ------------------------------------------------------------------ + op.add_column( + 'departments', + sa.Column('can_sudo', sa.Boolean(), nullable=False, server_default='false') + ) + + # ------------------------------------------------------------------ + # Step 2: Recreate organization_api_keys table + # ------------------------------------------------------------------ + op.create_table( + 'organization_api_keys', + sa.Column('id', sa.String(length=36), nullable=False), + sa.Column('organization_id', sa.String(length=36), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('key_hash', sa.String(length=255), nullable=False), + sa.Column('last_used_at', sa.DateTime(), nullable=True), + sa.Column('is_revoked', sa.Boolean(), nullable=False), + sa.Column('revoked_at', sa.DateTime(), nullable=True), + sa.Column('revoke_reason', sa.String(length=255), nullable=True), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('deleted_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], name='fk_organization_api_keys_organization'), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('key_hash', name='uq_organization_api_keys_key_hash'), + ) + + # Recreate indexes on organization_api_keys + op.create_index('idx_org_api_key_org_active', 'organization_api_keys', ['organization_id', 'is_revoked']) + op.create_index('idx_api_key_last_used', 'organization_api_keys', ['last_used_at']) + op.create_index(op.f('ix_organization_api_keys_is_revoked'), 'organization_api_keys', ['is_revoked']) + op.create_index(op.f('ix_organization_api_keys_key_hash'), 'organization_api_keys', ['key_hash'], unique=True) + op.create_index(op.f('ix_organization_api_keys_organization_id'), 'organization_api_keys', ['organization_id']) diff --git a/migrations/versions/merge_approval_membership_tables.py b/migrations/versions/merge_approval_membership_tables.py new file mode 100644 index 0000000..6297bdd --- /dev/null +++ b/migrations/versions/merge_approval_membership_tables.py @@ -0,0 +1,691 @@ +"""Merge user_network_approvals and device_network_memberships into network_access_requests. + +Revision ID: c0a1b2c3d4e5 +Revises: a1b2c3d4e5f6 +Create Date: 2026-05-02 00:00:00.000000 +""" + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'c0a1b2c3d4e5' +down_revision = 'a1b2c3d4e5f6' +branch_labels = None +depends_on = None + + +# --------------------------------------------------------------------------- +# UPGRADE +# --------------------------------------------------------------------------- + +def upgrade(): + # ------------------------------------------------------------------ + # Step 1: Create the new network_access_requests table + # ------------------------------------------------------------------ + op.create_table( + 'network_access_requests', + sa.Column('id', sa.String(length=36), nullable=False), + sa.Column('organization_id', sa.String(length=36), nullable=False), + sa.Column('user_id', sa.String(length=36), nullable=False), + sa.Column('device_id', sa.String(length=36), nullable=False), + sa.Column('portal_network_id', sa.String(length=36), nullable=False), + sa.Column('granted_by_user_id', sa.String(length=36), nullable=True), + sa.Column( + 'grant_type', + sa.Enum('requested', 'assigned', name='approval_grant_type', create_type=False), + nullable=False, + ), + sa.Column( + 'status', + sa.Enum( + 'pending', 'approved', 'rejected', 'revoked', 'suspended', + name='approval_state', create_type=False, + ), + nullable=False, + ), + sa.Column('active', sa.Boolean(), nullable=False, server_default='false'), + sa.Column('justification', sa.Text(), nullable=True), + sa.Column('join_seen', sa.Boolean(), nullable=False, server_default='false'), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('deleted_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint( + ['device_id'], ['devices.id'], + name='fk_network_access_requests_device', + ), + sa.ForeignKeyConstraint( + ['granted_by_user_id'], ['users.id'], + name='fk_network_access_requests_granted_by_user', + ), + sa.ForeignKeyConstraint( + ['organization_id'], ['organizations.id'], + name='fk_network_access_requests_organization', + ), + sa.ForeignKeyConstraint( + ['portal_network_id'], ['portal_networks.id'], + name='fk_network_access_requests_portal_network', + ), + sa.ForeignKeyConstraint( + ['user_id'], ['users.id'], + name='fk_network_access_requests_user', + ), + sa.PrimaryKeyConstraint('id', name='pk_network_access_requests'), + sa.UniqueConstraint( + 'user_id', 'device_id', 'portal_network_id', 'deleted_at', + name='uix_user_device_network', + ), + ) + + # Indexes on network_access_requests + op.create_index( + 'ix_network_access_requests_device_id', + 'network_access_requests', + ['device_id'], + unique=False, + ) + op.create_index( + 'ix_network_access_requests_organization_id', + 'network_access_requests', + ['organization_id'], + unique=False, + ) + op.create_index( + 'ix_network_access_requests_portal_network_id', + 'network_access_requests', + ['portal_network_id'], + unique=False, + ) + op.create_index( + 'ix_network_access_requests_status', + 'network_access_requests', + ['status'], + unique=False, + ) + op.create_index( + 'ix_network_access_requests_user_id', + 'network_access_requests', + ['user_id'], + unique=False, + ) + + # ------------------------------------------------------------------ + # Step 2: Migrate data from old tables into the new table + # ------------------------------------------------------------------ + op.execute( + """ + INSERT INTO network_access_requests ( + id, organization_id, user_id, device_id, portal_network_id, + granted_by_user_id, grant_type, status, active, justification, + join_seen, created_at, updated_at, deleted_at + ) + SELECT + dnm.id, + dnm.organization_id, + dnm.user_id, + dnm.device_id, + dnm.portal_network_id, + COALESCE(una.granted_by_user_id, NULL), + COALESCE(una.grant_type, 'requested'), + COALESCE(una.state, 'pending'), + CASE + WHEN dnm.currently_authorized = true AND una.state = 'approved' + THEN true + ELSE false + END, + una.justification, + dnm.join_seen, + COALESCE(dnm.created_at, una.created_at), + COALESCE(dnm.updated_at, una.updated_at), + dnm.deleted_at + FROM device_network_memberships dnm + LEFT JOIN user_network_approvals una + ON una.id = dnm.user_network_approval_id; + """ + ) + + # ------------------------------------------------------------------ + # Step 3: Update activation_sessions FK + # ------------------------------------------------------------------ + # 3a. Add the new nullable column + op.add_column( + 'activation_sessions', + sa.Column('network_access_request_id', sa.String(length=36), nullable=True), + ) + + # 3b. Populate the new column from the old column + op.execute( + """ + UPDATE activation_sessions + SET network_access_request_id = device_network_membership_id; + """ + ) + + # 3c. Drop the old foreign-key constraint + op.drop_constraint( + 'activation_sessions_device_network_membership_id_fkey', + 'activation_sessions', + type_='foreignkey', + ) + + # 3d. Drop the old column + op.drop_column('activation_sessions', 'device_network_membership_id') + + # 3d-alt. Enforce NOT NULL on the new column before FK creation + op.alter_column('activation_sessions', 'network_access_request_id', nullable=False) + + # 3e. Create the new foreign-key constraint + op.create_foreign_key( + 'fk_activation_sessions_network_access_request', + 'activation_sessions', + 'network_access_requests', + ['network_access_request_id'], + ['id'], + ) + + # 3f. Create the new index + op.create_index( + 'ix_activation_sessions_network_access_request_id', + 'activation_sessions', + ['network_access_request_id'], + unique=False, + ) + + # ------------------------------------------------------------------ + # Step 4: Update zerotier_memberships FK + # ------------------------------------------------------------------ + # 4a. Add the new nullable column + op.add_column( + 'zerotier_memberships', + sa.Column('network_access_request_id', sa.String(length=36), nullable=True), + ) + + # 4b. Populate the new column from the old column + op.execute( + """ + UPDATE zerotier_memberships + SET network_access_request_id = device_network_membership_id; + """ + ) + + # 4c. Drop the old foreign-key constraint + op.drop_constraint( + 'zerotier_memberships_device_network_membership_id_fkey', + 'zerotier_memberships', + type_='foreignkey', + ) + + # 4d. Drop the old column + op.drop_column('zerotier_memberships', 'device_network_membership_id') + + # 4e. Create the new foreign-key constraint + op.create_foreign_key( + 'fk_zerotier_memberships_network_access_request', + 'zerotier_memberships', + 'network_access_requests', + ['network_access_request_id'], + ['id'], + ) + + # 4f. Create the new index + op.create_index( + 'ix_zerotier_memberships_network_access_request_id', + 'zerotier_memberships', + ['network_access_request_id'], + unique=False, + ) + + # ------------------------------------------------------------------ + # Step 5: Drop old tables and the membership_state enum + # ------------------------------------------------------------------ + # 5a. Drop device_network_memberships and all its indexes + op.drop_index( + 'ix_device_network_memberships_user_network_approval_id', + table_name='device_network_memberships', + ) + op.drop_index( + 'ix_device_network_memberships_user_id', + table_name='device_network_memberships', + ) + op.drop_index( + 'ix_device_network_memberships_state', + table_name='device_network_memberships', + ) + op.drop_index( + 'ix_device_network_memberships_portal_network_id', + table_name='device_network_memberships', + ) + op.drop_index( + 'ix_device_network_memberships_organization_id', + table_name='device_network_memberships', + ) + op.drop_index( + 'ix_device_network_memberships_device_id', + table_name='device_network_memberships', + ) + op.drop_table('device_network_memberships') + + # 5b. Drop user_network_approvals and all its indexes + op.drop_index( + 'ix_user_network_approvals_user_id', + table_name='user_network_approvals', + ) + op.drop_index( + 'ix_user_network_approvals_state', + table_name='user_network_approvals', + ) + op.drop_index( + 'ix_user_network_approvals_portal_network_id', + table_name='user_network_approvals', + ) + op.drop_index( + 'ix_user_network_approvals_organization_id', + table_name='user_network_approvals', + ) + op.drop_table('user_network_approvals') + + # 5c. Drop the membership_state enum type if it exists + op.execute( + """ + DO $$ + BEGIN + IF EXISTS ( + SELECT 1 FROM pg_type WHERE typname = 'membership_state' + ) THEN + DROP TYPE membership_state; + END IF; + END$$; + """ + ) + + +# --------------------------------------------------------------------------- +# DOWNGRADE +# --------------------------------------------------------------------------- + +def downgrade(): + # ------------------------------------------------------------------ + # Step 1: Recreate the membership_state enum (used by old tables) + # ------------------------------------------------------------------ + membership_state = sa.Enum( + 'pending_device_registration', + 'pending_request', + 'pending_manager_approval', + 'approved_inactive', + 'joined_deauthorized', + 'active_authorized', + 'activation_expired', + 'suspended', + 'revoked', + 'rejected', + name='membership_state', + ) + membership_state.create(op.get_bind(), checkfirst=True) + + # ------------------------------------------------------------------ + # Step 2: Recreate user_network_approvals table + # ------------------------------------------------------------------ + op.create_table( + 'user_network_approvals', + sa.Column('organization_id', sa.String(length=36), nullable=False), + sa.Column('user_id', sa.String(length=36), nullable=False), + sa.Column('portal_network_id', sa.String(length=36), nullable=False), + sa.Column('granted_by_user_id', sa.String(length=36), nullable=True), + sa.Column( + 'grant_type', + sa.Enum('requested', 'assigned', name='approval_grant_type', create_type=False), + nullable=False, + ), + sa.Column( + 'state', + sa.Enum( + 'pending', 'approved', 'rejected', 'revoked', 'suspended', + name='approval_state', create_type=False, + ), + nullable=False, + ), + sa.Column('justification', sa.Text(), nullable=True), + sa.Column('id', sa.String(length=36), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('deleted_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint( + ['granted_by_user_id'], ['users.id'], + ), + sa.ForeignKeyConstraint( + ['organization_id'], ['organizations.id'], + ), + sa.ForeignKeyConstraint( + ['portal_network_id'], ['portal_networks.id'], + ), + sa.ForeignKeyConstraint( + ['user_id'], ['users.id'], + ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint( + 'user_id', 'portal_network_id', 'deleted_at', + name='uix_user_network_approval', + ), + ) + + # Recreate indexes on user_network_approvals + op.create_index( + 'ix_user_network_approvals_organization_id', + 'user_network_approvals', + ['organization_id'], + unique=False, + ) + op.create_index( + 'ix_user_network_approvals_portal_network_id', + 'user_network_approvals', + ['portal_network_id'], + unique=False, + ) + op.create_index( + 'ix_user_network_approvals_state', + 'user_network_approvals', + ['state'], + unique=False, + ) + op.create_index( + 'ix_user_network_approvals_user_id', + 'user_network_approvals', + ['user_id'], + unique=False, + ) + + # ------------------------------------------------------------------ + # Step 3: Migrate data back into user_network_approvals + # ------------------------------------------------------------------ + # Derive one approval row per (user_id, portal_network_id, deleted_at). + # We use gen_random_uuid() to generate new approval IDs because the + # original approval IDs were lost during the upgrade. + op.execute( + """ + INSERT INTO user_network_approvals ( + id, organization_id, user_id, portal_network_id, + granted_by_user_id, grant_type, state, justification, + created_at, updated_at, deleted_at + ) + SELECT + gen_random_uuid()::text, + (array_agg(organization_id ORDER BY created_at))[1], + user_id, + portal_network_id, + (array_agg(granted_by_user_id ORDER BY created_at))[1], + (array_agg(grant_type ORDER BY created_at))[1], + (array_agg(status ORDER BY created_at))[1], + (array_agg(justification ORDER BY created_at))[1], + MIN(created_at), + MAX(updated_at), + deleted_at + FROM network_access_requests + GROUP BY user_id, portal_network_id, deleted_at; + """ + ) + + # ------------------------------------------------------------------ + # Step 4: Recreate device_network_memberships table + # ------------------------------------------------------------------ + op.create_table( + 'device_network_memberships', + sa.Column('organization_id', sa.String(length=36), nullable=False), + sa.Column('user_id', sa.String(length=36), nullable=False), + sa.Column('device_id', sa.String(length=36), nullable=False), + sa.Column('portal_network_id', sa.String(length=36), nullable=False), + sa.Column('user_network_approval_id', sa.String(length=36), nullable=True), + sa.Column( + 'state', + sa.Enum( + 'pending_device_registration', + 'pending_request', + 'pending_manager_approval', + 'approved_inactive', + 'joined_deauthorized', + 'active_authorized', + 'activation_expired', + 'suspended', + 'revoked', + 'rejected', + name='membership_state', create_type=False, + ), + nullable=False, + ), + sa.Column('join_seen', sa.Boolean(), nullable=False), + sa.Column('currently_authorized', sa.Boolean(), nullable=False), + sa.Column('approved_for_activation', sa.Boolean(), nullable=False), + sa.Column('id', sa.String(length=36), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('deleted_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint( + ['device_id'], ['devices.id'], + ), + sa.ForeignKeyConstraint( + ['organization_id'], ['organizations.id'], + ), + sa.ForeignKeyConstraint( + ['portal_network_id'], ['portal_networks.id'], + ), + sa.ForeignKeyConstraint( + ['user_id'], ['users.id'], + ), + sa.ForeignKeyConstraint( + ['user_network_approval_id'], ['user_network_approvals.id'], + ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint( + 'device_id', 'portal_network_id', 'deleted_at', + name='uix_device_network', + ), + ) + + # Recreate indexes on device_network_memberships + op.create_index( + 'ix_device_network_memberships_device_id', + 'device_network_memberships', + ['device_id'], + unique=False, + ) + op.create_index( + 'ix_device_network_memberships_organization_id', + 'device_network_memberships', + ['organization_id'], + unique=False, + ) + op.create_index( + 'ix_device_network_memberships_portal_network_id', + 'device_network_memberships', + ['portal_network_id'], + unique=False, + ) + op.create_index( + 'ix_device_network_memberships_state', + 'device_network_memberships', + ['state'], + unique=False, + ) + op.create_index( + 'ix_device_network_memberships_user_id', + 'device_network_memberships', + ['user_id'], + unique=False, + ) + op.create_index( + 'ix_device_network_memberships_user_network_approval_id', + 'device_network_memberships', + ['user_network_approval_id'], + unique=False, + ) + + # ------------------------------------------------------------------ + # Step 5: Migrate data back into device_network_memberships + # ------------------------------------------------------------------ + # Map network_access_requests rows back to device_network_memberships. + # Reverse the status/active mapping using a best-effort approach. + op.execute( + """ + INSERT INTO device_network_memberships ( + id, organization_id, user_id, device_id, portal_network_id, + user_network_approval_id, state, join_seen, currently_authorized, + approved_for_activation, created_at, updated_at, deleted_at + ) + SELECT + nar.id, + nar.organization_id, + nar.user_id, + nar.device_id, + nar.portal_network_id, + una.id AS user_network_approval_id, + CASE nar.status + WHEN 'approved' THEN + CASE WHEN nar.active = true + THEN 'active_authorized' + ELSE 'approved_inactive' + END + WHEN 'pending' THEN 'pending_request' + ELSE nar.status + END AS state, + nar.join_seen, + nar.active AS currently_authorized, + CASE + WHEN nar.status = 'approved' THEN true + ELSE false + END AS approved_for_activation, + nar.created_at, + nar.updated_at, + nar.deleted_at + FROM network_access_requests nar + JOIN user_network_approvals una + ON una.user_id = nar.user_id + AND una.portal_network_id = nar.portal_network_id + AND (una.deleted_at IS NOT DISTINCT FROM nar.deleted_at); + """ + ) + + # ------------------------------------------------------------------ + # Step 6: Restore activation_sessions FK + # ------------------------------------------------------------------ + # 6a. Add the old column (nullable first so we can populate) + op.add_column( + 'activation_sessions', + sa.Column('device_network_membership_id', sa.String(length=36), nullable=True), + ) + + # 6b. Populate the old column from the new column before it disappears + op.execute( + """ + UPDATE activation_sessions + SET device_network_membership_id = network_access_request_id + WHERE network_access_request_id IS NOT NULL; + """ + ) + + # 6c. Drop the new column, FK, and index + op.drop_constraint( + 'fk_activation_sessions_network_access_request', + 'activation_sessions', + type_='foreignkey', + ) + op.drop_index( + 'ix_activation_sessions_network_access_request_id', + table_name='activation_sessions', + ) + op.drop_column('activation_sessions', 'network_access_request_id') + + # 6d. Alter the old column to NOT NULL + op.alter_column( + 'activation_sessions', + 'device_network_membership_id', + nullable=False, + ) + + # 6d. Recreate the old foreign key + op.create_foreign_key( + None, + 'activation_sessions', + 'device_network_memberships', + ['device_network_membership_id'], + ['id'], + ) + + # 6e. Recreate the old index + op.create_index( + 'ix_activation_sessions_device_network_membership_id', + 'activation_sessions', + ['device_network_membership_id'], + unique=False, + ) + + # ------------------------------------------------------------------ + # Step 7: Restore zerotier_memberships FK + # ------------------------------------------------------------------ + # 7a. Add the old column (nullable first so we can populate) + op.add_column( + 'zerotier_memberships', + sa.Column('device_network_membership_id', sa.String(length=36), nullable=True), + ) + + # 7b. Populate the old column from the new column before it disappears + op.execute( + """ + UPDATE zerotier_memberships + SET device_network_membership_id = network_access_request_id + WHERE network_access_request_id IS NOT NULL; + """ + ) + + # 7c. Drop the new column, FK, and index + op.drop_constraint( + 'fk_zerotier_memberships_network_access_request', + 'zerotier_memberships', + type_='foreignkey', + ) + op.drop_index( + 'ix_zerotier_memberships_network_access_request_id', + table_name='zerotier_memberships', + ) + op.drop_column('zerotier_memberships', 'network_access_request_id') + + # 7d. Recreate the old foreign key + op.create_foreign_key( + None, + 'zerotier_memberships', + 'device_network_memberships', + ['device_network_membership_id'], + ['id'], + ) + + # 7e. Recreate the old index + op.create_index( + 'ix_zerotier_memberships_device_network_membership_id', + 'zerotier_memberships', + ['device_network_membership_id'], + unique=False, + ) + + # ------------------------------------------------------------------ + # Step 8: Drop the new network_access_requests table and indexes + # ------------------------------------------------------------------ + op.drop_index( + 'ix_network_access_requests_user_id', + table_name='network_access_requests', + ) + op.drop_index( + 'ix_network_access_requests_status', + table_name='network_access_requests', + ) + op.drop_index( + 'ix_network_access_requests_portal_network_id', + table_name='network_access_requests', + ) + op.drop_index( + 'ix_network_access_requests_organization_id', + table_name='network_access_requests', + ) + op.drop_index( + 'ix_network_access_requests_device_id', + table_name='network_access_requests', + ) + op.drop_table('network_access_requests') diff --git a/tests/integration/client/admin.py b/tests/integration/client/admin.py index 66bc212..fb84097 100644 --- a/tests/integration/client/admin.py +++ b/tests/integration/client/admin.py @@ -48,6 +48,21 @@ class AdminClient: data={"confirm": confirm}, ) + def get_user_ssh_certificates(self, user_id: str, **params) -> dict: + """List all SSH certificates for a user (admin view). + + Args: + user_id: Target user ID + **params: Optional query parameters — status, active, cert_type, page, per_page + """ + path = f"/admin/users/{user_id}/ssh-certificates" + if params: + from urllib.parse import urlencode + query = urlencode({k: v for k, v in params.items() if v is not None}) + if query: + path = f"{path}?{query}" + return self._client.get(path) + def list_audit_logs(self) -> dict: """List system-wide audit logs.""" return self._client.get("/audit-logs") diff --git a/tests/integration/test_admin_ops.py b/tests/integration/test_admin_ops.py index 0e50b66..9f66b9a 100644 --- a/tests/integration/test_admin_ops.py +++ b/tests/integration/test_admin_ops.py @@ -211,3 +211,309 @@ class TestAdminUserManagement: with pytest.raises(ApiError) as exc_info: integration_client.auth.login(email=victim["email"], password="VictimPass123!") assert exc_info.value.status_code in (400, 401) + + +class TestAdminSSHCertificates: + """Test admin SSH certificate listing endpoints.""" + + def _create_test_cert( + self, integration_app, user_id: str, ca_id: str, *, ssh_key_id=None, + status="issued", revoked=False, valid_after=None, valid_before=None, + cert_type="user", principals=None, + ): + """Create a test SSH certificate record.""" + from datetime import datetime, timezone, timedelta + from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus + from gatehouse_app.models.ssh_ca.ca import CertType + + now = datetime.now(timezone.utc) + valid_after = valid_after or (now - timedelta(hours=1)) + valid_before = valid_before or (now + timedelta(hours=23)) + principals = principals or ["prod-servers"] + + with integration_app.app_context(): + cert = SSHCertificate( + ca_id=ca_id, + user_id=user_id, + ssh_key_id=ssh_key_id, + certificate=f"ssh-ed25519-cert-v01@openssh.com AAAA...test_serial_{uuid.uuid4().hex[:8]}", + serial=str(uuid.uuid4().int)[:20], + key_id=f"test@example.com-{uuid.uuid4().hex[:8]}", + cert_type=CertType(cert_type), + principals=principals, + valid_after=valid_after, + valid_before=valid_before, + revoked=revoked, + status=CertificateStatus(status), + request_ip="192.168.1.100", + request_user_agent="OpenSSH_9.0", + ) + if revoked: + cert.revoked_at = now + cert.revoke_reason = "test revocation" + db.session.add(cert) + db.session.commit() + return str(cert.id) + + def _create_test_ssh_key(self, integration_app, user_id: str, fingerprint: str = None): + """Create a test SSH key record.""" + from gatehouse_app.models.ssh_ca.ssh_key import SSHKey + + fingerprint = fingerprint or f"SHA256:{uuid.uuid4().hex[:43]}" + with integration_app.app_context(): + key = SSHKey( + user_id=user_id, + payload=f"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAI...test", + fingerprint=fingerprint, + description="Test laptop key", + verified=True, + key_type="ssh-ed25519", + key_bits=256, + key_comment="test@laptop", + ) + db.session.add(key) + db.session.commit() + return str(key.id) + + def test_list_user_ssh_certs_positive(self, integration_app, integration_client, create_test_user, create_test_org, create_test_membership, create_test_ca): + """TEST: ADMIN-SSH-01 — List all SSH certificates for a user as admin. + + WHAT: Create a user with two certs (one active, one expired), + admin lists all certs via the new endpoint. + WHY: Admin needs full visibility of user SSH certificate history. + EXPECTED: 200 OK with certificates array containing both certs. + """ + admin = create_test_user(password="AdminPass123!") + victim = create_test_user(password="VictimPass123!") + org = create_test_org() + create_test_membership(admin["id"], org["id"], OrganizationRole.OWNER) + create_test_membership(victim["id"], org["id"], OrganizationRole.MEMBER) + ca = create_test_ca(org_id=org["id"]) + + from datetime import datetime, timezone, timedelta + now = datetime.now(timezone.utc) + + # Create an active cert + self._create_test_cert( + integration_app, victim["id"], ca["id"], + status="issued", valid_after=now - timedelta(hours=1), + valid_before=now + timedelta(hours=23), + ) + # Create an expired cert + self._create_test_cert( + integration_app, victim["id"], ca["id"], + status="expired", valid_after=now - timedelta(days=7), + valid_before=now - timedelta(days=1), + ) + + integration_client.auth.login(email=admin["email"], password="AdminPass123!") + result = integration_client.admin.get_user_ssh_certificates(victim["id"]) + data = assert_success(result) + assert "certificates" in data + assert data["count"] == 2 + assert len(data["certificates"]) == 2 + + def test_list_user_ssh_certs_with_key_metadata(self, integration_app, integration_client, create_test_user, create_test_org, create_test_membership, create_test_ca): + """TEST: ADMIN-SSH-02 — Certificate includes SSH key metadata. + + WHAT: Create a cert linked to an SSH key, verify key details + appear in the response. + WHY: Admin needs to see which key was used to request the cert. + EXPECTED: ssh_key object with fingerprint, key_type, key_bits. + """ + admin = create_test_user(password="AdminPass123!") + victim = create_test_user(password="VictimPass123!") + org = create_test_org() + create_test_membership(admin["id"], org["id"], OrganizationRole.OWNER) + create_test_membership(victim["id"], org["id"], OrganizationRole.MEMBER) + ca = create_test_ca(org_id=org["id"]) + + key_id = self._create_test_ssh_key(integration_app, victim["id"]) + self._create_test_cert(integration_app, victim["id"], ca["id"], ssh_key_id=key_id) + + integration_client.auth.login(email=admin["email"], password="AdminPass123!") + result = integration_client.admin.get_user_ssh_certificates(victim["id"]) + data = assert_success(result) + + cert = data["certificates"][0] + assert cert["ssh_key"] is not None + assert cert["ssh_key"]["key_type"] == "ssh-ed25519" + assert cert["ssh_key"]["fingerprint"] is not None + assert cert["ssh_key"]["description"] == "Test laptop key" + + def test_list_user_ssh_certs_non_admin_negative(self, integration_app, integration_client, create_test_user, create_test_org, create_test_membership, create_test_ca): + """TEST: ADMIN-SSH-03 — Non-admin cannot list another user's certs. + + WHAT: Regular member tries to list admin's certs. + WHY: Certificate data is sensitive and admin-only. + EXPECTED: 403 Forbidden. + """ + member = create_test_user(password="MemberPass123!") + admin_user = create_test_user(password="AdminPass123!") + org = create_test_org() + create_test_membership(member["id"], org["id"], OrganizationRole.MEMBER) + create_test_membership(admin_user["id"], org["id"], OrganizationRole.OWNER) + + integration_client.auth.login(email=member["email"], password="MemberPass123!") + + with pytest.raises(ApiError) as exc_info: + integration_client.admin.get_user_ssh_certificates(admin_user["id"]) + + assert exc_info.value.status_code == 403 + + def test_list_user_ssh_certs_filter_by_status(self, integration_app, integration_client, create_test_user, create_test_org, create_test_membership, create_test_ca): + """TEST: ADMIN-SSH-04 — Filter certificates by status. + + WHAT: Create certs with different statuses, filter by status=revoked. + WHY: Admin may want to see only revoked certs to audit access. + EXPECTED: Only revoked certs returned. + """ + admin = create_test_user(password="AdminPass123!") + victim = create_test_user(password="VictimPass123!") + org = create_test_org() + create_test_membership(admin["id"], org["id"], OrganizationRole.OWNER) + create_test_membership(victim["id"], org["id"], OrganizationRole.MEMBER) + ca = create_test_ca(org_id=org["id"]) + + from datetime import datetime, timezone, timedelta + now = datetime.now(timezone.utc) + + self._create_test_cert(integration_app, victim["id"], ca["id"], status="issued") + self._create_test_cert(integration_app, victim["id"], ca["id"], status="revoked", revoked=True) + self._create_test_cert(integration_app, victim["id"], ca["id"], status="expired") + + integration_client.auth.login(email=admin["email"], password="AdminPass123!") + result = integration_client.admin.get_user_ssh_certificates(victim["id"], status="revoked") + data = assert_success(result) + + assert data["count"] == 1 + assert data["certificates"][0]["status"] == "revoked" + assert data["certificates"][0]["revoked"] is True + + def test_list_user_ssh_certs_filter_active_only(self, integration_app, integration_client, create_test_user, create_test_org, create_test_membership, create_test_ca): + """TEST: ADMIN-SSH-05 — Filter for only currently valid certificates. + + WHAT: Create active and expired certs, filter by active=true. + WHY: Admin needs quick view of currently active certs. + EXPECTED: Only valid (non-revoked, non-expired) certs returned. + """ + admin = create_test_user(password="AdminPass123!") + victim = create_test_user(password="VictimPass123!") + org = create_test_org() + create_test_membership(admin["id"], org["id"], OrganizationRole.OWNER) + create_test_membership(victim["id"], org["id"], OrganizationRole.MEMBER) + ca = create_test_ca(org_id=org["id"]) + + from datetime import datetime, timezone, timedelta + now = datetime.now(timezone.utc) + + self._create_test_cert( + integration_app, victim["id"], ca["id"], status="issued", + valid_after=now - timedelta(hours=1), valid_before=now + timedelta(hours=23), + ) + self._create_test_cert( + integration_app, victim["id"], ca["id"], status="expired", + valid_after=now - timedelta(days=7), valid_before=now - timedelta(days=1), + ) + self._create_test_cert( + integration_app, victim["id"], ca["id"], status="revoked", revoked=True, + valid_after=now - timedelta(hours=1), valid_before=now + timedelta(hours=23), + ) + + integration_client.auth.login(email=admin["email"], password="AdminPass123!") + result = integration_client.admin.get_user_ssh_certificates(victim["id"], active="true") + data = assert_success(result) + + assert data["count"] == 1 + cert = data["certificates"][0] + assert cert["is_valid"] is True + assert cert["revoked"] is False + + def test_list_user_ssh_certs_user_not_found(self, integration_app, integration_client, create_test_user, create_test_org, create_test_membership): + """TEST: ADMIN-SSH-06 — Return 404 for non-existent user. + + WHAT: Admin requests certs for a user ID that doesn't exist. + WHY: Clear error for missing resources. + EXPECTED: 404 NOT_FOUND. + """ + admin = create_test_user(password="AdminPass123!") + org = create_test_org() + create_test_membership(admin["id"], org["id"], OrganizationRole.OWNER) + + integration_client.auth.login(email=admin["email"], password="AdminPass123!") + + with pytest.raises(ApiError) as exc_info: + integration_client.admin.get_user_ssh_certificates("non-existent-user-id") + + assert exc_info.value.status_code == 404 + assert exc_info.value.error_type == "NOT_FOUND" + + def test_list_user_ssh_certs_empty_result(self, integration_app, integration_client, create_test_user, create_test_org, create_test_membership): + """TEST: ADMIN-SSH-07 — Empty result when user has no certs. + + WHAT: Admin lists certs for a user who has never requested one. + WHY: Endpoint should handle gracefully, not error. + EXPECTED: 200 OK with empty certificates array and count=0. + """ + admin = create_test_user(password="AdminPass123!") + victim = create_test_user(password="VictimPass123!") + org = create_test_org() + create_test_membership(admin["id"], org["id"], OrganizationRole.OWNER) + create_test_membership(victim["id"], org["id"], OrganizationRole.MEMBER) + + integration_client.auth.login(email=admin["email"], password="AdminPass123!") + result = integration_client.admin.get_user_ssh_certificates(victim["id"]) + data = assert_success(result) + + assert data["certificates"] == [] + assert data["count"] == 0 + + def test_list_user_ssh_certs_revoked_cert_details(self, integration_app, integration_client, create_test_user, create_test_org, create_test_membership, create_test_ca): + """TEST: ADMIN-SSH-08 — Revoked certificate shows revocation details. + + WHAT: Create a revoked cert, verify revoke metadata is present. + WHY: Admin needs to know when and why a cert was revoked. + EXPECTED: revoked=True, revoked_at populated, revoke_reason present. + """ + admin = create_test_user(password="AdminPass123!") + victim = create_test_user(password="VictimPass123!") + org = create_test_org() + create_test_membership(admin["id"], org["id"], OrganizationRole.OWNER) + create_test_membership(victim["id"], org["id"], OrganizationRole.MEMBER) + ca = create_test_ca(org_id=org["id"]) + + self._create_test_cert( + integration_app, victim["id"], ca["id"], + status="revoked", revoked=True, + ) + + integration_client.auth.login(email=admin["email"], password="AdminPass123!") + result = integration_client.admin.get_user_ssh_certificates(victim["id"]) + data = assert_success(result) + + cert = data["certificates"][0] + assert cert["revoked"] is True + assert cert["revoked_at"] is not None + assert cert["revoke_reason"] == "test revocation" + assert cert["status"] == "revoked" + + def test_list_user_ssh_certs_invalid_status_filter(self, integration_app, integration_client, create_test_user, create_test_org, create_test_membership): + """TEST: ADMIN-SSH-09 — Invalid status filter returns 400. + + WHAT: Admin passes an invalid status value. + WHY: Input validation prevents confusing queries. + EXPECTED: 400 VALIDATION_ERROR. + """ + admin = create_test_user(password="AdminPass123!") + victim = create_test_user(password="VictimPass123!") + org = create_test_org() + create_test_membership(admin["id"], org["id"], OrganizationRole.OWNER) + create_test_membership(victim["id"], org["id"], OrganizationRole.MEMBER) + + integration_client.auth.login(email=admin["email"], password="AdminPass123!") + + with pytest.raises(ApiError) as exc_info: + integration_client.admin.get_user_ssh_certificates(victim["id"], status="bogus") + + assert exc_info.value.status_code == 400 + assert exc_info.value.error_type == "VALIDATION_ERROR" diff --git a/tests/integration/test_zerotier.py b/tests/integration/test_zerotier.py index e5f0444..269d003 100644 --- a/tests/integration/test_zerotier.py +++ b/tests/integration/test_zerotier.py @@ -201,3 +201,145 @@ class TestZeroTierMembership: except ApiError as exc: # Accept errors when no active memberships to kill assert exc.status_code in (400, 500) + + +class TestAdminUserDevices: + """Test admin endpoint to list devices for a specific user.""" + + def test_list_user_devices_positive( + self, integration_client, create_test_user, create_test_org, create_test_membership, integration_app + ): + """TEST: ZT-10 — Admin lists devices for a user with devices. + + WHAT: Admin GET /organizations//users//devices. + WHY: Admins need to see what devices a user has registered. + EXPECTED: 200 OK with devices array. + """ + from gatehouse_app.models.zerotier.device import Device + + admin = create_test_user(password="AdminPass123!") + member = create_test_user(password="MemberPass123!") + org = create_test_org() + + create_test_membership(admin["id"], org["id"], OrganizationRole.ADMIN) + create_test_membership(member["id"], org["id"], OrganizationRole.MEMBER) + + # Create test devices for the member + from gatehouse_app.extensions import db as _db + with integration_app.app_context(): + device1 = Device( + user_id=member["id"], + organization_id=org["id"], + node_id="1234567890", + device_nickname="Member Laptop", + hostname="member-laptop", + ) + device2 = Device( + user_id=member["id"], + organization_id=org["id"], + node_id="0987654321", + device_nickname="Member Phone", + hostname="member-phone", + ) + _db.session.add_all([device1, device2]) + _db.session.commit() + + integration_client.auth.login(email=admin["email"], password="AdminPass123!") + result = integration_client.get(f"/organizations/{org['id']}/users/{member['id']}/devices") + data = assert_success(result, "devices retrieved") + + assert "devices" in data + assert data["count"] == 2 + assert data["user_id"] == member["id"] + assert data["organization_id"] == org["id"] + device_node_ids = [d["node_id"] for d in data["devices"]] + assert "1234567890" in device_node_ids + assert "0987654321" in device_node_ids + + def test_list_user_devices_no_devices( + self, integration_client, create_test_user, create_test_org, create_test_membership + ): + """TEST: ZT-11 — Admin lists devices for a user with no devices. + + WHAT: Admin GET /organizations//users//devices for user with no devices. + WHY: Endpoint should return empty list, not error. + EXPECTED: 200 OK with empty devices array. + """ + admin = create_test_user(password="AdminPass123!") + member = create_test_user(password="MemberPass123!") + org = create_test_org() + + create_test_membership(admin["id"], org["id"], OrganizationRole.ADMIN) + create_test_membership(member["id"], org["id"], OrganizationRole.MEMBER) + + integration_client.auth.login(email=admin["email"], password="AdminPass123!") + result = integration_client.get(f"/organizations/{org['id']}/users/{member['id']}/devices") + data = assert_success(result) + + assert data["count"] == 0 + assert data["devices"] == [] + + def test_list_user_devices_non_admin_negative( + self, integration_client, create_test_user, create_test_org, create_test_membership + ): + """TEST: ZT-12 — Non-admin cannot list another user's devices. + + WHAT: Member attempts GET /organizations//users//devices. + WHY: This endpoint is admin-only. + EXPECTED: 403 Forbidden. + """ + member1 = create_test_user(password="Member1Pass123!") + member2 = create_test_user(password="Member2Pass123!") + org = create_test_org() + + create_test_membership(member1["id"], org["id"], OrganizationRole.MEMBER) + create_test_membership(member2["id"], org["id"], OrganizationRole.MEMBER) + + integration_client.auth.login(email=member1["email"], password="Member1Pass123!") + with pytest.raises(ApiError) as exc_info: + integration_client.get(f"/organizations/{org['id']}/users/{member2['id']}/devices") + assert exc_info.value.status_code == 403 + + def test_list_user_devices_user_not_in_org_negative( + self, integration_client, create_test_user, create_test_org, create_test_membership + ): + """TEST: ZT-13 — Cannot list devices for user not in organization. + + WHAT: Admin GET /organizations//users//devices for user not in org. + WHY: User must be a member of the organization. + EXPECTED: 404 Not Found. + """ + admin = create_test_user(password="AdminPass123!") + outside_user = create_test_user(password="OutsidePass123!") + org = create_test_org() + + create_test_membership(admin["id"], org["id"], OrganizationRole.ADMIN) + # outside_user is NOT added to the org + + integration_client.auth.login(email=admin["email"], password="AdminPass123!") + with pytest.raises(ApiError) as exc_info: + integration_client.get(f"/organizations/{org['id']}/users/{outside_user['id']}/devices") + assert exc_info.value.status_code == 404 + + def test_list_user_devices_user_not_found_negative( + self, integration_client, create_test_user, create_test_org, create_test_membership + ): + """TEST: ZT-14 — Cannot list devices for non-existent user. + + WHAT: Admin GET /organizations//users//devices. + WHY: User must exist. + EXPECTED: 404 Not Found. + """ + import uuid + + admin = create_test_user(password="AdminPass123!") + org = create_test_org() + + create_test_membership(admin["id"], org["id"], OrganizationRole.ADMIN) + + non_existent_id = str(uuid.uuid4()) + + integration_client.auth.login(email=admin["email"], password="AdminPass123!") + with pytest.raises(ApiError) as exc_info: + integration_client.get(f"/organizations/{org['id']}/users/{non_existent_id}/devices") + assert exc_info.value.status_code == 404 diff --git a/tests/unit/test_migration_merge_approval_membership.py b/tests/unit/test_migration_merge_approval_membership.py new file mode 100644 index 0000000..a9985ad --- /dev/null +++ b/tests/unit/test_migration_merge_approval_membership.py @@ -0,0 +1,88 @@ +"""Verify the structure of the Alembic migration that merges +user_network_approvals and device_network_memberships into network_access_requests. + +These are STRUCTURAL tests only — no database connection is required. +""" + +import importlib +import importlib.util +import os +import sys + + +# ── helpers ──────────────────────────────────────────────────────────────── + +def _load_migration_module(): + """Load the migration module by file path without executing Alembic.""" + migration_path = os.path.join( + os.path.dirname(__file__), + '..', '..', 'migrations', 'versions', + 'merge_approval_membership_tables.py', + ) + migration_path = os.path.abspath(migration_path) + + spec = importlib.util.spec_from_file_location( + 'merge_approval_membership_tables', migration_path, + ) + assert spec is not None, f'Could not create module spec for {migration_path}' + assert spec.loader is not None, f'Module spec has no loader for {migration_path}' + + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +# ── structural tests ─────────────────────────────────────────────────────── + +def test_migration_file_can_be_imported(): + """The migration module MUST import without raising any exception.""" + mod = _load_migration_module() + assert mod is not None + + +def test_upgrade_function_exists(): + """upgrade() must be a callable in the module.""" + mod = _load_migration_module() + assert hasattr(mod, 'upgrade'), 'module is missing upgrade()' + assert callable(mod.upgrade), 'upgrade is not callable' + + +def test_downgrade_function_exists(): + """downgrade() must be a callable in the module.""" + mod = _load_migration_module() + assert hasattr(mod, 'downgrade'), 'module is missing downgrade()' + assert callable(mod.downgrade), 'downgrade is not callable' + + +def test_revision_is_set_correctly(): + """revision must equal the documented value 'c0a1b2c3d4e5'.""" + mod = _load_migration_module() + assert hasattr(mod, 'revision'), 'module is missing revision' + assert mod.revision == 'c0a1b2c3d4e5', ( + f"Expected revision 'c0a1b2c3d4e5', got '{mod.revision}'" + ) + + +def test_down_revision_is_set_correctly(): + """down_revision must equal the documented value 'a1b2c3d4e5f6'.""" + mod = _load_migration_module() + assert hasattr(mod, 'down_revision'), 'module is missing down_revision' + assert mod.down_revision == 'a1b2c3d4e5f6', ( + f"Expected down_revision 'a1b2c3d4e5f6', got '{mod.down_revision}'" + ) + + +def test_branch_labels_is_none(): + """branch_labels should be None for a standard linear migration.""" + mod = _load_migration_module() + assert mod.branch_labels is None, ( + f"Expected branch_labels None, got {mod.branch_labels!r}" + ) + + +def test_depends_on_is_none(): + """depends_on should be None — this migration has no cross-dependencies.""" + mod = _load_migration_module() + assert mod.depends_on is None, ( + f"Expected depends_on None, got {mod.depends_on!r}" + ) diff --git a/tests/unit/test_network_access_request_model.py b/tests/unit/test_network_access_request_model.py new file mode 100644 index 0000000..481a762 --- /dev/null +++ b/tests/unit/test_network_access_request_model.py @@ -0,0 +1,340 @@ +"""Unit tests for NetworkAccessRequest model structure. + +WHAT: Verifies the model class can be imported, has the expected columns, + constraints, and enum types. +WHY: Structural correctness of the model is a prerequisite for Phase 2+ + work; catching missing columns or constraints early prevents + migration/runtime failures. + +APPROACH: gatehouse_app/__init__.py calls create_app() at module level which + requires psycopg2 (PostgreSQL driver). We prevent this by pre-loading + gatehouse_app as a bare namespace package, then selectively providing + the real submodules (utils.constants) and fakes (extensions, models.base). + + We do NOT call db.create_all() — the table metadata is fully populated + during class definition. FK target tables don't exist in our test + metadata, so we check FK presence without table resolution. +""" + +import sys +import importlib.util +import pytest +from flask import Flask +from flask_sqlalchemy import SQLAlchemy + +# ═══════════════════════════════════════════════════════════════════════════════ +# Step 1: Pre-load gatehouse_app as a bare namespace (prevents __init__.py) +# ═══════════════════════════════════════════════════════════════════════════════ + +_gatehouse = type(sys)("gatehouse_app") +_gatehouse.__path__ = [] +sys.modules["gatehouse_app"] = _gatehouse + +# ═══════════════════════════════════════════════════════════════════════════════ +# Step 2: Load the real gatehouse_app.utils.constants (self-contained, no deps) +# ═══════════════════════════════════════════════════════════════════════════════ + +_constants_spec = importlib.util.spec_from_file_location( + "gatehouse_app.utils.constants", + "/home/ubuntu/securid/gatehouse-api/gatehouse_app/utils/constants.py", + submodule_search_locations=[], +) +_constants_mod = importlib.util.module_from_spec(_constants_spec) +sys.modules["gatehouse_app.utils"] = type(sys)("gatehouse_app.utils") +sys.modules["gatehouse_app.utils.constants"] = _constants_mod +_constants_spec.loader.exec_module(_constants_mod) + +ApprovalGrantType = _constants_mod.ApprovalGrantType +ApprovalState = _constants_mod.ApprovalState + +# ═══════════════════════════════════════════════════════════════════════════════ +# Step 3: Build fake extensions.db and models.base +# ═══════════════════════════════════════════════════════════════════════════════ + +_fake_db = SQLAlchemy() + + +class FakeBaseModel(_fake_db.Model): + """Minimal BaseModel matching the real one's column definitions.""" + __abstract__ = True + id = _fake_db.Column(_fake_db.String(36), primary_key=True, default=lambda: "test-uuid", nullable=False) + created_at = _fake_db.Column(_fake_db.DateTime, nullable=False) + updated_at = _fake_db.Column(_fake_db.DateTime, nullable=False) + deleted_at = _fake_db.Column(_fake_db.DateTime, nullable=True) + + def to_dict(self, exclude=None): + """Mimic the real BaseModel.to_dict — iterates __table__.columns.""" + from datetime import datetime, timezone + exclude = exclude or [] + result = {} + for column in self.__table__.columns: + if column.name not in exclude: + value = getattr(self, column.name) + if isinstance(value, datetime): + result[column.name] = value.isoformat() + else: + result[column.name] = value + return result + + +_fake_extensions = type(sys)("gatehouse_app.extensions") +_fake_extensions.db = _fake_db + +_fake_models_base = type(sys)("gatehouse_app.models.base") +_fake_models_base.BaseModel = FakeBaseModel + +sys.modules["gatehouse_app.extensions"] = _fake_extensions +sys.modules["gatehouse_app.models"] = type(sys)("gatehouse_app.models") +sys.modules["gatehouse_app.models.base"] = _fake_models_base + +# ═══════════════════════════════════════════════════════════════════════════════ +# Step 3b: Create stub models for relationship targets so ORM mapper +# can resolve 'Organization', 'User', 'Device', 'PortalNetwork' +# ═══════════════════════════════════════════════════════════════════════════════ + +class Organization(_fake_db.Model): + __tablename__ = "organizations" + id = _fake_db.Column(_fake_db.String(36), primary_key=True) + + +class User(_fake_db.Model): + __tablename__ = "users" + id = _fake_db.Column(_fake_db.String(36), primary_key=True) + + +class Device(_fake_db.Model): + __tablename__ = "devices" + id = _fake_db.Column(_fake_db.String(36), primary_key=True) + + +class PortalNetwork(_fake_db.Model): + __tablename__ = "portal_networks" + id = _fake_db.Column(_fake_db.String(36), primary_key=True) + +# ═══════════════════════════════════════════════════════════════════════════════ +# Step 4: Load the real network_access_request module from file +# ═══════════════════════════════════════════════════════════════════════════════ + +_model_spec = importlib.util.spec_from_file_location( + "gatehouse_app.models.zerotier.network_access_request", + "/home/ubuntu/securid/gatehouse-api/gatehouse_app/models/zerotier/network_access_request.py", + submodule_search_locations=[], +) +_model_mod = importlib.util.module_from_spec(_model_spec) +sys.modules["gatehouse_app.models.zerotier"] = type(sys)("gatehouse_app.models.zerotier") +sys.modules["gatehouse_app.models.zerotier.network_access_request"] = _model_mod +_model_spec.loader.exec_module(_model_mod) +NetworkAccessRequest = _model_mod.NetworkAccessRequest + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Fixture +# ═══════════════════════════════════════════════════════════════════════════════ + +@pytest.fixture(scope="module") +def model_class(): + """Return the model class — table metadata is already built at definition time.""" + return NetworkAccessRequest + +@pytest.fixture(scope="module") +def app(): + """Minimal Flask app for to_dict (BaseModel.to_dict iterates __table__.columns).""" + app = Flask(__name__) + app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:" + app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False + _fake_db.init_app(app) + return app + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Test data +# ═══════════════════════════════════════════════════════════════════════════════ + +EXPECTED_LOCAL_COLUMNS = { + "organization_id", "user_id", "device_id", "portal_network_id", + "granted_by_user_id", "grant_type", "status", "active", + "justification", "join_seen", +} + +EXPECTED_INHERITED_COLUMNS = {"id", "created_at", "updated_at", "deleted_at"} +ALL_EXPECTED = EXPECTED_LOCAL_COLUMNS | EXPECTED_INHERITED_COLUMNS + +# FK columns that should have foreign keys (table name, FK target) +EXPECTED_FKS = { + "organization_id": "organizations.id", + "user_id": "users.id", + "device_id": "devices.id", + "portal_network_id": "portal_networks.id", + "granted_by_user_id": "users.id", +} + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Test: Module importability +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestImport: + def test_model_importable(self, model_class): + assert model_class is not None + assert isinstance(model_class, type) + + def test_model_tablename(self, model_class): + assert model_class.__tablename__ == "network_access_requests" + + def test_model_inherits_base(self, model_class): + assert issubclass(model_class, FakeBaseModel) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Test: Columns +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestColumns: + def test_all_expected_columns_present(self, model_class): + actual = {c.name for c in model_class.__table__.columns} + missing = ALL_EXPECTED - actual + assert missing == set(), f"Missing columns: {missing}" + + def test_no_extra_columns(self, model_class): + actual = {c.name for c in model_class.__table__.columns} + extra = actual - ALL_EXPECTED + assert extra == set(), f"Unexpected columns: {extra}" + + def test_exact_column_count(self, model_class): + assert len(model_class.__table__.columns) == 14, ( + f"Expected 14 columns, got {len(model_class.__table__.columns)}: " + f"{sorted(c.name for c in model_class.__table__.columns)}" + ) + + def test_organization_id_is_fk_string_not_null(self, model_class): + col = model_class.__table__.columns["organization_id"] + assert not col.nullable + assert _has_foreign_key(col) + + def test_user_id_is_fk_string_not_null(self, model_class): + col = model_class.__table__.columns["user_id"] + assert not col.nullable + assert _has_foreign_key(col) + + def test_device_id_is_fk_string_not_null(self, model_class): + col = model_class.__table__.columns["device_id"] + assert not col.nullable + assert _has_foreign_key(col) + + def test_portal_network_id_is_fk_string_not_null(self, model_class): + col = model_class.__table__.columns["portal_network_id"] + assert not col.nullable + assert _has_foreign_key(col) + + def test_granted_by_user_id_nullable_fk(self, model_class): + col = model_class.__table__.columns["granted_by_user_id"] + assert col.nullable + assert _has_foreign_key(col) + + def test_justification_is_text_nullable(self, model_class): + col = model_class.__table__.columns["justification"] + assert col.nullable + assert "TEXT" in str(col.type).upper() + + def test_active_is_boolean_not_null(self, model_class): + col = model_class.__table__.columns["active"] + assert str(col.type) in ("BOOLEAN", "INTEGER") + assert not col.nullable + + def test_join_seen_is_boolean_not_null(self, model_class): + col = model_class.__table__.columns["join_seen"] + assert str(col.type) in ("BOOLEAN", "INTEGER") + assert not col.nullable + + def test_fk_count(self, model_class): + """Verify exactly the expected FK columns have foreign keys.""" + fk_cols = {c.name for c in model_class.__table__.columns if _has_foreign_key(c)} + assert fk_cols == set(EXPECTED_FKS.keys()), ( + f"FK columns {sorted(fk_cols)} != expected {sorted(EXPECTED_FKS.keys())}" + ) + + +def _has_foreign_key(column): + """Check if column has at least one ForeignKey, without resolving target table.""" + return bool(column.foreign_keys) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Test: UniqueConstraint +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestConstraints: + def test_unique_constraint_exists(self, model_class): + from sqlalchemy import UniqueConstraint + ucs = [c for c in model_class.__table__.constraints if isinstance(c, UniqueConstraint)] + assert len(ucs) >= 1, "No UniqueConstraint found" + + def test_unique_constraint_columns(self, model_class): + from sqlalchemy import UniqueConstraint + ucs = [c for c in model_class.__table__.constraints if isinstance(c, UniqueConstraint)] + assert len(ucs) == 1, f"Expected 1, found {len(ucs)}" + cols = {col.name for col in ucs[0].columns} + expected = {"user_id", "device_id", "portal_network_id", "deleted_at"} + assert cols == expected, f"UniqueConstraint columns {cols} != {expected}" + + def test_unique_constraint_name(self, model_class): + from sqlalchemy import UniqueConstraint + ucs = [c for c in model_class.__table__.constraints if isinstance(c, UniqueConstraint)] + assert len(ucs) == 1 + assert ucs[0].name == "uix_user_device_network", ( + f"Expected 'uix_user_device_network', got '{ucs[0].name}'" + ) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Test: Enum types +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestEnumTypes: + def test_status_column_uses_approval_state_enum(self, model_class): + col = model_class.__table__.columns["status"] + assert hasattr(col.type, "enum_class"), ( + f"status column type {type(col.type)} has no enum_class" + ) + assert col.type.enum_class is ApprovalState, ( + f"status enum is {col.type.enum_class}, expected ApprovalState" + ) + + def test_grant_type_column_uses_approval_grant_type_enum(self, model_class): + col = model_class.__table__.columns["grant_type"] + assert hasattr(col.type, "enum_class"), ( + f"grant_type column type {type(col.type)} has no enum_class" + ) + assert col.type.enum_class is ApprovalGrantType, ( + f"grant_type enum is {col.type.enum_class}, expected ApprovalGrantType" + ) + + def test_status_column_not_nullable(self, model_class): + assert not model_class.__table__.columns["status"].nullable + + def test_grant_type_column_not_nullable(self, model_class): + assert not model_class.__table__.columns["grant_type"].nullable + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Test: Properties and methods +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestMethods: + def test_repr_returns_string(self, model_class): + instance = model_class() + result = repr(instance) + assert isinstance(result, str) + assert "NetworkAccessRequest" in result + + def test_active_session_property_returns_none(self, model_class): + instance = model_class() + assert instance.active_session is None + + def test_to_dict_returns_dict(self, model_class, app): + with app.app_context(): + instance = model_class() + result = instance.to_dict() + assert isinstance(result, dict) + for col_name in EXPECTED_LOCAL_COLUMNS: + assert col_name in result, f"Missing '{col_name}' in to_dict output"