diff --git a/.gitignore b/.gitignore index 6fcc667..2652764 100644 --- a/.gitignore +++ b/.gitignore @@ -141,4 +141,9 @@ flask_session/ # Opencode files and folders .opencode/ .swarm/ -SWARM_PLAN.* \ No newline at end of file +SWARM_PLAN.* +# local backups / dumps / sessions +*.sql +*.dump +session-*.md +backups/ diff --git a/gatehouse_app/api/v1/oidc.py b/gatehouse_app/api/v1/oidc.py index e4f01f4..4002b07 100644 --- a/gatehouse_app/api/v1/oidc.py +++ b/gatehouse_app/api/v1/oidc.py @@ -29,6 +29,7 @@ from gatehouse_app.exceptions.auth_exceptions import ( from gatehouse_app.utils.constants import AuditAction from gatehouse_app.services.audit_service import AuditService from gatehouse_app.services.oidc_audit_service import OIDCAuditService +from gatehouse_app.utils.validators import validate_cors_origins logger = logging.getLogger(__name__) @@ -816,6 +817,11 @@ def oidc_register(): except Exception: return jsonify({"error": "invalid_request", "error_description": f"Invalid redirect_uri: {uri}"}), 400 + cors_origins_raw = data.get("allowed_cors_origins") + cors_origins, cors_error = validate_cors_origins(cors_origins_raw) + if cors_error: + return jsonify({"error": "invalid_request", "error_description": cors_error}), 400 + client_id = f"oidc_{secrets.token_urlsafe(16)}" client_secret = f"secret_{secrets.token_urlsafe(24)}" client_secret_hash = flask_bcrypt.generate_password_hash(client_secret).decode("utf-8") @@ -842,6 +848,7 @@ def oidc_register(): grant_types=data.get("grant_types", ["authorization_code", "refresh_token"]), response_types=data.get("response_types", ["code"]), scopes=data.get("scope", "openid profile email roles").split(), + allowed_cors_origins=cors_origins, is_active=True, is_confidential=True, require_pkce=True, @@ -871,6 +878,7 @@ def oidc_register(): "client_secret_expires_at": 0, "client_name": client_name, "redirect_uris": redirect_uris, + "allowed_cors_origins": client.allowed_cors_origins, "token_endpoint_auth_method": data.get("token_endpoint_auth_method", "client_secret_basic"), "grant_types": client.grant_types, "response_types": client.response_types, diff --git a/gatehouse_app/api/v1/organizations/clients.py b/gatehouse_app/api/v1/organizations/clients.py index 8fb7334..9e1f16c 100644 --- a/gatehouse_app/api/v1/organizations/clients.py +++ b/gatehouse_app/api/v1/organizations/clients.py @@ -7,6 +7,7 @@ from gatehouse_app.utils.decorators import login_required, require_admin, full_a from gatehouse_app.extensions import db, bcrypt from gatehouse_app.utils.constants import AuditAction from gatehouse_app.services.audit_service import AuditService +from gatehouse_app.utils.validators import validate_cors_origins @api_v1_bp.route("/organizations//clients", methods=["GET"]) @@ -63,6 +64,11 @@ def create_org_client(org_id): if not redirect_uris: return api_response(success=False, message="At least one redirect URI is required", status=400, error_type="VALIDATION_ERROR") + cors_origins_raw = data.get("allowed_cors_origins") + cors_origins, cors_error = validate_cors_origins(cors_origins_raw) + if cors_error: + return api_response(success=False, message=cors_error, status=400, error_type="VALIDATION_ERROR") + client_id = _secrets.token_hex(16) client_secret = _secrets.token_urlsafe(32) @@ -75,6 +81,7 @@ def create_org_client(org_id): grant_types=["authorization_code", "refresh_token"], response_types=["code"], scopes=["openid", "profile", "email"], + allowed_cors_origins=cors_origins, is_active=True, is_confidential=True, ) @@ -99,6 +106,7 @@ def create_org_client(org_id): "client_secret": client_secret, "redirect_uris": client.redirect_uris, "scopes": client.scopes, + "allowed_cors_origins": client.allowed_cors_origins, "created_at": client.created_at.isoformat() + "Z", } }, @@ -135,6 +143,12 @@ def update_org_client(org_id, client_id): return api_response(success=False, message="At least one redirect URI is required", status=400, error_type="VALIDATION_ERROR") client.redirect_uris = uris + if "allowed_cors_origins" in data: + cors_origins, cors_error = validate_cors_origins(data["allowed_cors_origins"]) + if cors_error: + return api_response(success=False, message=cors_error, status=400, error_type="VALIDATION_ERROR") + client.allowed_cors_origins = cors_origins + db.session.commit() AuditService.log_action( @@ -155,6 +169,7 @@ def update_org_client(org_id, client_id): "redirect_uris": client.redirect_uris, "scopes": client.scopes, "grant_types": client.grant_types, + "allowed_cors_origins": client.allowed_cors_origins, "is_active": client.is_active, "created_at": client.created_at.isoformat() + "Z", } diff --git a/gatehouse_app/utils/validators.py b/gatehouse_app/utils/validators.py new file mode 100644 index 0000000..260d33c --- /dev/null +++ b/gatehouse_app/utils/validators.py @@ -0,0 +1,94 @@ +"""Validation helpers for request data.""" +from urllib.parse import urlparse + +# Special sentinel values allowed in allowed_cors_origins +_CORS_SENTINELS = {"+", "*"} + + +def validate_cors_origins(origins): + """Validate a list of CORS origin values. + + Accepts: + None - means "use global CORS config" (pass-through) + ["+"] - derive origins from the client's redirect_uris + ["*"] - allow any origin + ["https://host"] - explicit allow-list of well-formed origins + + Each non-sentinel entry must be a well-formed origin: + scheme (http or https) + host + optional port, with NO path, + query string, or fragment. + + Returns: + (validated_value, None) on success, or + (None, error_message) on failure. + """ + if origins is None: + return None, None + + if not isinstance(origins, list): + return None, "allowed_cors_origins must be a list or null" + + validated = [] + for i, entry in enumerate(origins): + if not isinstance(entry, str): + return None, f"allowed_cors_origins[{i}]: expected a string, got {type(entry).__name__}" + + entry = entry.strip() + if not entry: + return None, f"allowed_cors_origins[{i}]: empty string is not allowed" + + # Sentinel values are accepted as-is + if entry in _CORS_SENTINELS: + validated.append(entry) + continue + + # Parse and validate as origin + error = _validate_single_origin(entry, i) + if error: + return None, error + + validated.append(entry) + + return validated, None + + +def _validate_single_origin(origin, index): + """Validate that a string is a well-formed browser origin. + + A valid origin is: scheme://host[:port] with no path, query, or fragment. + Only http and https schemes are accepted. + + Returns an error message string on failure, or None on success. + """ + try: + parsed = urlparse(origin) + except Exception: + return f"allowed_cors_origins[{index}]: '{origin}' is not a valid URL" + + if parsed.scheme not in ("http", "https"): + return ( + f"allowed_cors_origins[{index}]: '{origin}' has an invalid scheme " + f"'{parsed.scheme}'; only 'http' and 'https' are allowed" + ) + + if not parsed.hostname: + return f"allowed_cors_origins[{index}]: '{origin}' is missing a hostname" + + # Origins must not have a path (other than empty or "/"), query, or fragment + if parsed.path and parsed.path != "/": + return ( + f"allowed_cors_origins[{index}]: '{origin}' must not contain a path " + f"(got '{parsed.path}'). Specify only scheme://host[:port]" + ) + + if parsed.query: + return ( + f"allowed_cors_origins[{index}]: '{origin}' must not contain a query string" + ) + + if parsed.fragment: + return ( + f"allowed_cors_origins[{index}]: '{origin}' must not contain a fragment" + ) + + return None diff --git a/migrations/versions/c8d2e4f6a1b3_consolidate_sessions.py b/migrations/versions/c8d2e4f6a1b3_consolidate_sessions.py index 4a0af90..7d10d87 100644 --- a/migrations/versions/c8d2e4f6a1b3_consolidate_sessions.py +++ b/migrations/versions/c8d2e4f6a1b3_consolidate_sessions.py @@ -38,7 +38,7 @@ def upgrade(): is_compliance_only, created_at, updated_at, deleted_at ) SELECT - id, 'superadmin', superadmin_id, token, 'active', + id, 'superadmin', superadmin_id, token, 'ACTIVE', ip_address, user_agent, NULL, expires_at, last_activity_at, revoked_at, revoked_reason, FALSE, created_at, updated_at, deleted_at diff --git a/migrations/versions/e1f2a3b4c5d6_merge_branches.py b/migrations/versions/e1f2a3b4c5d6_merge_branches.py new file mode 100644 index 0000000..b79db54 --- /dev/null +++ b/migrations/versions/e1f2a3b4c5d6_merge_branches.py @@ -0,0 +1,23 @@ +"""Merge branches: consolidate_sessions + remove_sudo_api_keys. + +Revision ID: e1f2a3b4c5d6 +Revises: c8d2e4f6a1b3, d1e2f3g4h5i6 +Create Date: 2026-05-19 12:45:00.000000 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = 'e1f2a3b4c5d6' +down_revision = ('c8d2e4f6a1b3', 'd1e2f3g4h5i6') +branch_labels = None +depends_on = None + + +def upgrade(): + pass + + +def downgrade(): + pass diff --git a/migrations/versions/merge_approval_membership_tables.py b/migrations/versions/merge_approval_membership_tables.py index 6297bdd..15fa0f9 100644 --- a/migrations/versions/merge_approval_membership_tables.py +++ b/migrations/versions/merge_approval_membership_tables.py @@ -7,6 +7,7 @@ Create Date: 2026-05-02 00:00:00.000000 from alembic import op import sqlalchemy as sa +from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. @@ -21,6 +22,21 @@ depends_on = None # --------------------------------------------------------------------------- def upgrade(): + # ------------------------------------------------------------------ + # Step 0: Ensure enum types exist (they may already exist from old tables) + # ------------------------------------------------------------------ + op.execute(""" + DO $$ + BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'approval_grant_type') THEN + CREATE TYPE approval_grant_type AS ENUM ('requested', 'assigned'); + END IF; + IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'approval_state') THEN + CREATE TYPE approval_state AS ENUM ('pending', 'approved', 'rejected', 'revoked', 'suspended'); + END IF; + END$$; + """) + # ------------------------------------------------------------------ # Step 1: Create the new network_access_requests table # ------------------------------------------------------------------ @@ -34,12 +50,12 @@ def upgrade(): sa.Column('granted_by_user_id', sa.String(length=36), nullable=True), sa.Column( 'grant_type', - sa.Enum('requested', 'assigned', name='approval_grant_type', create_type=False), + postgresql.ENUM('requested', 'assigned', name='approval_grant_type', create_type=False), nullable=False, ), sa.Column( 'status', - sa.Enum( + postgresql.ENUM( 'pending', 'approved', 'rejected', 'revoked', 'suspended', name='approval_state', create_type=False, ), @@ -334,12 +350,12 @@ def downgrade(): sa.Column('granted_by_user_id', sa.String(length=36), nullable=True), sa.Column( 'grant_type', - sa.Enum('requested', 'assigned', name='approval_grant_type', create_type=False), + postgresql.ENUM('requested', 'assigned', name='approval_grant_type', create_type=False), nullable=False, ), sa.Column( 'state', - sa.Enum( + postgresql.ENUM( 'pending', 'approved', 'rejected', 'revoked', 'suspended', name='approval_state', create_type=False, ), @@ -437,7 +453,7 @@ def downgrade(): sa.Column('user_network_approval_id', sa.String(length=36), nullable=True), sa.Column( 'state', - sa.Enum( + postgresql.ENUM( 'pending_device_registration', 'pending_request', 'pending_manager_approval', diff --git a/tests/unit/test_cors_origin_validation.py b/tests/unit/test_cors_origin_validation.py new file mode 100644 index 0000000..c5fe3c2 --- /dev/null +++ b/tests/unit/test_cors_origin_validation.py @@ -0,0 +1,137 @@ +"""Unit tests for validate_cors_origins() helper. + +WHAT: Tests for the CORS origin validation function used by OIDC client + create/update endpoints. +WHY: Malformed origins would silently break per-client CORS; strict + validation catches mistakes at write-time rather than at runtime. +EXPECTED: Valid origins accepted, invalid origins rejected with clear errors. +""" +import pytest + +from gatehouse_app.utils.validators import validate_cors_origins + + +class TestValidateCorsOriginsAccepts: + """Cases that should pass validation.""" + + def test_none_passes_through(self): + value, error = validate_cors_origins(None) + assert value is None + assert error is None + + def test_empty_list_is_valid(self): + value, error = validate_cors_origins([]) + assert value == [] + assert error is None + + def test_sentinel_plus(self): + value, error = validate_cors_origins(["+"]) + assert value == ["+"] + assert error is None + + def test_sentinel_wildcard(self): + value, error = validate_cors_origins(["*"]) + assert value == ["*"] + assert error is None + + def test_https_origin(self): + value, error = validate_cors_origins(["https://example.com"]) + assert value == ["https://example.com"] + assert error is None + + def test_https_origin_with_port(self): + value, error = validate_cors_origins(["https://example.com:8443"]) + assert value == ["https://example.com:8443"] + assert error is None + + def test_http_localhost(self): + value, error = validate_cors_origins(["http://localhost:3000"]) + assert value == ["http://localhost:3000"] + assert error is None + + def test_multiple_valid_origins(self): + origins = ["https://app.example.com", "https://staging.example.com:8443"] + value, error = validate_cors_origins(origins) + assert value == origins + assert error is None + + def test_origin_with_trailing_slash_accepted(self): + # urlparse treats trailing "/" as path="/", which we allow + value, error = validate_cors_origins(["https://example.com/"]) + assert value == ["https://example.com/"] + assert error is None + + def test_sentinel_mixed_with_origins(self): + value, error = validate_cors_origins(["+", "https://extra.example.com"]) + assert value == ["+", "https://extra.example.com"] + assert error is None + + +class TestValidateCorsOriginsRejects: + """Cases that should fail validation.""" + + def test_not_a_list(self): + value, error = validate_cors_origins("https://example.com") + assert value is None + assert "must be a list" in error + + def test_non_string_entry(self): + value, error = validate_cors_origins([123]) + assert value is None + assert "expected a string" in error + + def test_empty_string_entry(self): + value, error = validate_cors_origins([""]) + assert value is None + assert "empty string" in error + + def test_whitespace_only_entry(self): + value, error = validate_cors_origins([" "]) + assert value is None + assert "empty string" in error + + def test_origin_with_path(self): + value, error = validate_cors_origins(["https://example.com/api/v1"]) + assert value is None + assert "must not contain a path" in error + + def test_origin_with_query_string(self): + value, error = validate_cors_origins(["https://example.com?q=1"]) + assert value is None + assert "query string" in error + + def test_origin_with_fragment(self): + value, error = validate_cors_origins(["https://example.com#section"]) + assert value is None + assert "fragment" in error + + def test_ftp_scheme_rejected(self): + value, error = validate_cors_origins(["ftp://example.com"]) + assert value is None + assert "invalid scheme" in error + + def test_no_scheme(self): + value, error = validate_cors_origins(["example.com"]) + assert value is None + # urlparse puts this in path, not hostname, so we get either + # scheme or hostname error + assert error is not None + + def test_bare_string_not_url(self): + value, error = validate_cors_origins(["not-a-url"]) + assert value is None + assert error is not None + + def test_mixed_valid_and_invalid(self): + """First invalid entry stops validation.""" + value, error = validate_cors_origins([ + "https://good.example.com", + "ftp://bad.example.com", + ]) + assert value is None + assert "allowed_cors_origins[1]" in error + + def test_dict_entry_rejected(self): + value, error = validate_cors_origins([{"url": "https://example.com"}]) + assert value is None + assert "expected a string" in error diff --git a/tests/unit/test_org_client_cors.py b/tests/unit/test_org_client_cors.py new file mode 100644 index 0000000..f35cf37 --- /dev/null +++ b/tests/unit/test_org_client_cors.py @@ -0,0 +1,340 @@ +"""Unit tests for allowed_cors_origins in OIDC client endpoints. + +WHAT: Tests that the create, update, and list endpoints correctly accept, + validate, persist, and return the allowed_cors_origins field. +WHY: The field was already on the model but was not wired into any API + endpoint; these tests verify the new wiring works end-to-end. +EXPECTED: Valid origins are stored and returned; invalid origins are rejected + with 400; omitting the field defaults to None (global fallback). +""" +import json +import secrets +from unittest.mock import patch, MagicMock + +import pytest + +from gatehouse_app import create_app, db +from gatehouse_app.extensions import limiter +from gatehouse_app.models.oidc.oidc_client import OIDCClient +from gatehouse_app.models.organization.organization import Organization +from gatehouse_app.models.organization.organization_member import OrganizationMember +from gatehouse_app.models.user.user import User +from gatehouse_app.utils.constants import OrganizationRole + + +# Disable rate limiter for tests +limiter.enabled = False + + +@pytest.fixture(scope="module") +def app(): + """Create a test Flask app with in-memory SQLite.""" + _app = create_app(config_name="testing") + _app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:" + _app.config["TESTING"] = True + _app.config["WTF_CSRF_ENABLED"] = False + _app.config["RATELIMIT_ENABLED"] = False + + with _app.app_context(): + db.create_all() + yield _app + db.session.remove() + db.drop_all() + + +@pytest.fixture +def auth_context(app): + """Create a user, org, and admin membership. + + The module-scoped app fixture already holds an active app_context, + so we don't push another one here. + """ + user = User( + email=f"admin_{secrets.token_hex(4)}@test.com", + full_name="Admin User", + email_verified=True, + ) + db.session.add(user) + db.session.commit() + + org = Organization( + name=f"Test Org {secrets.token_hex(4)}", + slug=f"test-org-{secrets.token_hex(4)}", + ) + db.session.add(org) + db.session.commit() + + membership = OrganizationMember( + user_id=user.id, + organization_id=org.id, + role=OrganizationRole.OWNER, + ) + db.session.add(membership) + db.session.commit() + + return user, org + + +def _auth_headers(): + return {"Authorization": "Bearer test-token", "Content-Type": "application/json"} + + +def _mock_session_for(user): + """Return a context manager that patches SessionService to authenticate *user*.""" + mock_session = MagicMock() + mock_session.user = user + mock_session.is_active.return_value = True + mock_session.is_compliance_only = False + mock_session.device_info = {} + + return patch( + "gatehouse_app.services.session_service.SessionService.get_active_session_by_token", + return_value=mock_session, + ) + + +def _create_oidc_client(org_id, **overrides): + """Insert a minimal OIDCClient directly into the DB.""" + from gatehouse_app.extensions import bcrypt + + defaults = dict( + organization_id=org_id, + name="Test Client", + client_id=secrets.token_hex(16), + client_secret_hash=bcrypt.generate_password_hash("secret").decode("utf-8"), + redirect_uris=["https://app.example.com/callback"], + grant_types=["authorization_code"], + response_types=["code"], + scopes=["openid"], + is_active=True, + is_confidential=True, + ) + defaults.update(overrides) + c = OIDCClient(**defaults) + db.session.add(c) + db.session.commit() + return c + + +# --------------------------------------------------------------------------- +# POST create +# --------------------------------------------------------------------------- + +class TestCreateClientCorsOrigins: + """POST /api/v1/organizations//clients with allowed_cors_origins.""" + + def test_create_with_cors_origins(self, app, auth_context): + user, org = auth_context + with app.test_client() as tc, _mock_session_for(user): + resp = tc.post( + f"/api/v1/organizations/{org.id}/clients", + headers=_auth_headers(), + data=json.dumps({ + "name": "CORS Test Client", + "redirect_uris": ["https://app.example.com/callback"], + "allowed_cors_origins": ["https://app.example.com"], + }), + ) + assert resp.status_code == 201 + data = resp.get_json() + assert data["data"]["client"]["allowed_cors_origins"] == ["https://app.example.com"] + + def test_create_with_sentinel_plus(self, app, auth_context): + user, org = auth_context + with app.test_client() as tc, _mock_session_for(user): + resp = tc.post( + f"/api/v1/organizations/{org.id}/clients", + headers=_auth_headers(), + data=json.dumps({ + "name": "Plus Client", + "redirect_uris": ["https://app.example.com/callback"], + "allowed_cors_origins": ["+"], + }), + ) + assert resp.status_code == 201 + assert resp.get_json()["data"]["client"]["allowed_cors_origins"] == ["+"] + + def test_create_without_cors_origins_defaults_to_none(self, app, auth_context): + user, org = auth_context + with app.test_client() as tc, _mock_session_for(user): + resp = tc.post( + f"/api/v1/organizations/{org.id}/clients", + headers=_auth_headers(), + data=json.dumps({ + "name": "No CORS Client", + "redirect_uris": ["https://app.example.com/callback"], + }), + ) + assert resp.status_code == 201 + assert resp.get_json()["data"]["client"]["allowed_cors_origins"] is None + + def test_create_with_null_cors_origins(self, app, auth_context): + user, org = auth_context + with app.test_client() as tc, _mock_session_for(user): + resp = tc.post( + f"/api/v1/organizations/{org.id}/clients", + headers=_auth_headers(), + data=json.dumps({ + "name": "Null CORS Client", + "redirect_uris": ["https://app.example.com/callback"], + "allowed_cors_origins": None, + }), + ) + assert resp.status_code == 201 + assert resp.get_json()["data"]["client"]["allowed_cors_origins"] is None + + def test_create_with_invalid_cors_origin_returns_400(self, app, auth_context): + user, org = auth_context + with app.test_client() as tc, _mock_session_for(user): + resp = tc.post( + f"/api/v1/organizations/{org.id}/clients", + headers=_auth_headers(), + data=json.dumps({ + "name": "Bad CORS Client", + "redirect_uris": ["https://app.example.com/callback"], + "allowed_cors_origins": ["ftp://bad.example.com"], + }), + ) + assert resp.status_code == 400 + assert "invalid scheme" in resp.get_json()["message"] + + def test_create_with_origin_containing_path_returns_400(self, app, auth_context): + user, org = auth_context + with app.test_client() as tc, _mock_session_for(user): + resp = tc.post( + f"/api/v1/organizations/{org.id}/clients", + headers=_auth_headers(), + data=json.dumps({ + "name": "Path CORS Client", + "redirect_uris": ["https://app.example.com/callback"], + "allowed_cors_origins": ["https://app.example.com/api"], + }), + ) + assert resp.status_code == 400 + assert "must not contain a path" in resp.get_json()["message"] + + def test_create_with_non_list_cors_origins_returns_400(self, app, auth_context): + user, org = auth_context + with app.test_client() as tc, _mock_session_for(user): + resp = tc.post( + f"/api/v1/organizations/{org.id}/clients", + headers=_auth_headers(), + data=json.dumps({ + "name": "String CORS Client", + "redirect_uris": ["https://app.example.com/callback"], + "allowed_cors_origins": "https://app.example.com", + }), + ) + assert resp.status_code == 400 + assert "must be a list" in resp.get_json()["message"] + + +# --------------------------------------------------------------------------- +# PATCH update +# --------------------------------------------------------------------------- + +class TestUpdateClientCorsOrigins: + """PATCH /api/v1/organizations//clients/ with allowed_cors_origins.""" + + def test_update_set_cors_origins(self, app, auth_context): + user, org = auth_context + oidc_client = _create_oidc_client(org.id) + assert oidc_client.allowed_cors_origins is None + + with app.test_client() as tc, _mock_session_for(user): + resp = tc.patch( + f"/api/v1/organizations/{org.id}/clients/{oidc_client.id}", + headers=_auth_headers(), + data=json.dumps({ + "allowed_cors_origins": ["https://new-app.example.com"], + }), + ) + assert resp.status_code == 200 + data = resp.get_json()["data"]["client"] + assert data["allowed_cors_origins"] == ["https://new-app.example.com"] + + def test_update_clear_cors_origins(self, app, auth_context): + user, org = auth_context + oidc_client = _create_oidc_client(org.id, allowed_cors_origins=["https://old.example.com"]) + + with app.test_client() as tc, _mock_session_for(user): + resp = tc.patch( + f"/api/v1/organizations/{org.id}/clients/{oidc_client.id}", + headers=_auth_headers(), + data=json.dumps({ + "allowed_cors_origins": None, + }), + ) + assert resp.status_code == 200 + assert resp.get_json()["data"]["client"]["allowed_cors_origins"] is None + + def test_update_cors_origins_with_invalid_value_returns_400(self, app, auth_context): + user, org = auth_context + oidc_client = _create_oidc_client(org.id) + + with app.test_client() as tc, _mock_session_for(user): + resp = tc.patch( + f"/api/v1/organizations/{org.id}/clients/{oidc_client.id}", + headers=_auth_headers(), + data=json.dumps({ + "allowed_cors_origins": ["https://good.com", "not-a-url"], + }), + ) + assert resp.status_code == 400 + + def test_update_without_cors_field_does_not_change_it(self, app, auth_context): + user, org = auth_context + oidc_client = _create_oidc_client(org.id, allowed_cors_origins=["https://keep-me.example.com"]) + + with app.test_client() as tc, _mock_session_for(user): + resp = tc.patch( + f"/api/v1/organizations/{org.id}/clients/{oidc_client.id}", + headers=_auth_headers(), + data=json.dumps({ + "name": "Renamed", + }), + ) + assert resp.status_code == 200 + assert resp.get_json()["data"]["client"]["allowed_cors_origins"] == ["https://keep-me.example.com"] + + def test_update_set_wildcard(self, app, auth_context): + user, org = auth_context + oidc_client = _create_oidc_client(org.id) + + with app.test_client() as tc, _mock_session_for(user): + resp = tc.patch( + f"/api/v1/organizations/{org.id}/clients/{oidc_client.id}", + headers=_auth_headers(), + data=json.dumps({ + "allowed_cors_origins": ["*"], + }), + ) + assert resp.status_code == 200 + assert resp.get_json()["data"]["client"]["allowed_cors_origins"] == ["*"] + + +# --------------------------------------------------------------------------- +# GET list +# --------------------------------------------------------------------------- + +class TestListClientsCorsOrigins: + """GET /api/v1/organizations//clients returns allowed_cors_origins.""" + + def test_list_includes_cors_origins(self, app, auth_context): + user, org = auth_context + oidc_client = _create_oidc_client( + org.id, + name="List Test", + allowed_cors_origins=["https://list.example.com"], + ) + + with app.test_client() as tc, _mock_session_for(user): + resp = tc.get( + f"/api/v1/organizations/{org.id}/clients", + headers=_auth_headers(), + ) + assert resp.status_code == 200 + clients_list = resp.get_json()["data"]["clients"] + matching = [c for c in clients_list if c["client_id"] == oidc_client.client_id] + assert len(matching) == 1 + assert matching[0]["allowed_cors_origins"] == ["https://list.example.com"]