Added OIDC client CORS attributes

This commit is contained in:
2026-05-19 15:15:47 +00:00
parent 78bae3c2bb
commit 2342a1aab6
9 changed files with 645 additions and 7 deletions
+137
View File
@@ -0,0 +1,137 @@
"""Unit tests for validate_cors_origins() helper.
WHAT: Tests for the CORS origin validation function used by OIDC client
create/update endpoints.
WHY: Malformed origins would silently break per-client CORS; strict
validation catches mistakes at write-time rather than at runtime.
EXPECTED: Valid origins accepted, invalid origins rejected with clear errors.
"""
import pytest
from gatehouse_app.utils.validators import validate_cors_origins
class TestValidateCorsOriginsAccepts:
"""Cases that should pass validation."""
def test_none_passes_through(self):
value, error = validate_cors_origins(None)
assert value is None
assert error is None
def test_empty_list_is_valid(self):
value, error = validate_cors_origins([])
assert value == []
assert error is None
def test_sentinel_plus(self):
value, error = validate_cors_origins(["+"])
assert value == ["+"]
assert error is None
def test_sentinel_wildcard(self):
value, error = validate_cors_origins(["*"])
assert value == ["*"]
assert error is None
def test_https_origin(self):
value, error = validate_cors_origins(["https://example.com"])
assert value == ["https://example.com"]
assert error is None
def test_https_origin_with_port(self):
value, error = validate_cors_origins(["https://example.com:8443"])
assert value == ["https://example.com:8443"]
assert error is None
def test_http_localhost(self):
value, error = validate_cors_origins(["http://localhost:3000"])
assert value == ["http://localhost:3000"]
assert error is None
def test_multiple_valid_origins(self):
origins = ["https://app.example.com", "https://staging.example.com:8443"]
value, error = validate_cors_origins(origins)
assert value == origins
assert error is None
def test_origin_with_trailing_slash_accepted(self):
# urlparse treats trailing "/" as path="/", which we allow
value, error = validate_cors_origins(["https://example.com/"])
assert value == ["https://example.com/"]
assert error is None
def test_sentinel_mixed_with_origins(self):
value, error = validate_cors_origins(["+", "https://extra.example.com"])
assert value == ["+", "https://extra.example.com"]
assert error is None
class TestValidateCorsOriginsRejects:
"""Cases that should fail validation."""
def test_not_a_list(self):
value, error = validate_cors_origins("https://example.com")
assert value is None
assert "must be a list" in error
def test_non_string_entry(self):
value, error = validate_cors_origins([123])
assert value is None
assert "expected a string" in error
def test_empty_string_entry(self):
value, error = validate_cors_origins([""])
assert value is None
assert "empty string" in error
def test_whitespace_only_entry(self):
value, error = validate_cors_origins([" "])
assert value is None
assert "empty string" in error
def test_origin_with_path(self):
value, error = validate_cors_origins(["https://example.com/api/v1"])
assert value is None
assert "must not contain a path" in error
def test_origin_with_query_string(self):
value, error = validate_cors_origins(["https://example.com?q=1"])
assert value is None
assert "query string" in error
def test_origin_with_fragment(self):
value, error = validate_cors_origins(["https://example.com#section"])
assert value is None
assert "fragment" in error
def test_ftp_scheme_rejected(self):
value, error = validate_cors_origins(["ftp://example.com"])
assert value is None
assert "invalid scheme" in error
def test_no_scheme(self):
value, error = validate_cors_origins(["example.com"])
assert value is None
# urlparse puts this in path, not hostname, so we get either
# scheme or hostname error
assert error is not None
def test_bare_string_not_url(self):
value, error = validate_cors_origins(["not-a-url"])
assert value is None
assert error is not None
def test_mixed_valid_and_invalid(self):
"""First invalid entry stops validation."""
value, error = validate_cors_origins([
"https://good.example.com",
"ftp://bad.example.com",
])
assert value is None
assert "allowed_cors_origins[1]" in error
def test_dict_entry_rejected(self):
value, error = validate_cors_origins([{"url": "https://example.com"}])
assert value is None
assert "expected a string" in error
+340
View File
@@ -0,0 +1,340 @@
"""Unit tests for allowed_cors_origins in OIDC client endpoints.
WHAT: Tests that the create, update, and list endpoints correctly accept,
validate, persist, and return the allowed_cors_origins field.
WHY: The field was already on the model but was not wired into any API
endpoint; these tests verify the new wiring works end-to-end.
EXPECTED: Valid origins are stored and returned; invalid origins are rejected
with 400; omitting the field defaults to None (global fallback).
"""
import json
import secrets
from unittest.mock import patch, MagicMock
import pytest
from gatehouse_app import create_app, db
from gatehouse_app.extensions import limiter
from gatehouse_app.models.oidc.oidc_client import OIDCClient
from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.models.user.user import User
from gatehouse_app.utils.constants import OrganizationRole
# Disable rate limiter for tests
limiter.enabled = False
@pytest.fixture(scope="module")
def app():
"""Create a test Flask app with in-memory SQLite."""
_app = create_app(config_name="testing")
_app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
_app.config["TESTING"] = True
_app.config["WTF_CSRF_ENABLED"] = False
_app.config["RATELIMIT_ENABLED"] = False
with _app.app_context():
db.create_all()
yield _app
db.session.remove()
db.drop_all()
@pytest.fixture
def auth_context(app):
"""Create a user, org, and admin membership.
The module-scoped app fixture already holds an active app_context,
so we don't push another one here.
"""
user = User(
email=f"admin_{secrets.token_hex(4)}@test.com",
full_name="Admin User",
email_verified=True,
)
db.session.add(user)
db.session.commit()
org = Organization(
name=f"Test Org {secrets.token_hex(4)}",
slug=f"test-org-{secrets.token_hex(4)}",
)
db.session.add(org)
db.session.commit()
membership = OrganizationMember(
user_id=user.id,
organization_id=org.id,
role=OrganizationRole.OWNER,
)
db.session.add(membership)
db.session.commit()
return user, org
def _auth_headers():
return {"Authorization": "Bearer test-token", "Content-Type": "application/json"}
def _mock_session_for(user):
"""Return a context manager that patches SessionService to authenticate *user*."""
mock_session = MagicMock()
mock_session.user = user
mock_session.is_active.return_value = True
mock_session.is_compliance_only = False
mock_session.device_info = {}
return patch(
"gatehouse_app.services.session_service.SessionService.get_active_session_by_token",
return_value=mock_session,
)
def _create_oidc_client(org_id, **overrides):
"""Insert a minimal OIDCClient directly into the DB."""
from gatehouse_app.extensions import bcrypt
defaults = dict(
organization_id=org_id,
name="Test Client",
client_id=secrets.token_hex(16),
client_secret_hash=bcrypt.generate_password_hash("secret").decode("utf-8"),
redirect_uris=["https://app.example.com/callback"],
grant_types=["authorization_code"],
response_types=["code"],
scopes=["openid"],
is_active=True,
is_confidential=True,
)
defaults.update(overrides)
c = OIDCClient(**defaults)
db.session.add(c)
db.session.commit()
return c
# ---------------------------------------------------------------------------
# POST create
# ---------------------------------------------------------------------------
class TestCreateClientCorsOrigins:
"""POST /api/v1/organizations/<org_id>/clients with allowed_cors_origins."""
def test_create_with_cors_origins(self, app, auth_context):
user, org = auth_context
with app.test_client() as tc, _mock_session_for(user):
resp = tc.post(
f"/api/v1/organizations/{org.id}/clients",
headers=_auth_headers(),
data=json.dumps({
"name": "CORS Test Client",
"redirect_uris": ["https://app.example.com/callback"],
"allowed_cors_origins": ["https://app.example.com"],
}),
)
assert resp.status_code == 201
data = resp.get_json()
assert data["data"]["client"]["allowed_cors_origins"] == ["https://app.example.com"]
def test_create_with_sentinel_plus(self, app, auth_context):
user, org = auth_context
with app.test_client() as tc, _mock_session_for(user):
resp = tc.post(
f"/api/v1/organizations/{org.id}/clients",
headers=_auth_headers(),
data=json.dumps({
"name": "Plus Client",
"redirect_uris": ["https://app.example.com/callback"],
"allowed_cors_origins": ["+"],
}),
)
assert resp.status_code == 201
assert resp.get_json()["data"]["client"]["allowed_cors_origins"] == ["+"]
def test_create_without_cors_origins_defaults_to_none(self, app, auth_context):
user, org = auth_context
with app.test_client() as tc, _mock_session_for(user):
resp = tc.post(
f"/api/v1/organizations/{org.id}/clients",
headers=_auth_headers(),
data=json.dumps({
"name": "No CORS Client",
"redirect_uris": ["https://app.example.com/callback"],
}),
)
assert resp.status_code == 201
assert resp.get_json()["data"]["client"]["allowed_cors_origins"] is None
def test_create_with_null_cors_origins(self, app, auth_context):
user, org = auth_context
with app.test_client() as tc, _mock_session_for(user):
resp = tc.post(
f"/api/v1/organizations/{org.id}/clients",
headers=_auth_headers(),
data=json.dumps({
"name": "Null CORS Client",
"redirect_uris": ["https://app.example.com/callback"],
"allowed_cors_origins": None,
}),
)
assert resp.status_code == 201
assert resp.get_json()["data"]["client"]["allowed_cors_origins"] is None
def test_create_with_invalid_cors_origin_returns_400(self, app, auth_context):
user, org = auth_context
with app.test_client() as tc, _mock_session_for(user):
resp = tc.post(
f"/api/v1/organizations/{org.id}/clients",
headers=_auth_headers(),
data=json.dumps({
"name": "Bad CORS Client",
"redirect_uris": ["https://app.example.com/callback"],
"allowed_cors_origins": ["ftp://bad.example.com"],
}),
)
assert resp.status_code == 400
assert "invalid scheme" in resp.get_json()["message"]
def test_create_with_origin_containing_path_returns_400(self, app, auth_context):
user, org = auth_context
with app.test_client() as tc, _mock_session_for(user):
resp = tc.post(
f"/api/v1/organizations/{org.id}/clients",
headers=_auth_headers(),
data=json.dumps({
"name": "Path CORS Client",
"redirect_uris": ["https://app.example.com/callback"],
"allowed_cors_origins": ["https://app.example.com/api"],
}),
)
assert resp.status_code == 400
assert "must not contain a path" in resp.get_json()["message"]
def test_create_with_non_list_cors_origins_returns_400(self, app, auth_context):
user, org = auth_context
with app.test_client() as tc, _mock_session_for(user):
resp = tc.post(
f"/api/v1/organizations/{org.id}/clients",
headers=_auth_headers(),
data=json.dumps({
"name": "String CORS Client",
"redirect_uris": ["https://app.example.com/callback"],
"allowed_cors_origins": "https://app.example.com",
}),
)
assert resp.status_code == 400
assert "must be a list" in resp.get_json()["message"]
# ---------------------------------------------------------------------------
# PATCH update
# ---------------------------------------------------------------------------
class TestUpdateClientCorsOrigins:
"""PATCH /api/v1/organizations/<org_id>/clients/<client_id> with allowed_cors_origins."""
def test_update_set_cors_origins(self, app, auth_context):
user, org = auth_context
oidc_client = _create_oidc_client(org.id)
assert oidc_client.allowed_cors_origins is None
with app.test_client() as tc, _mock_session_for(user):
resp = tc.patch(
f"/api/v1/organizations/{org.id}/clients/{oidc_client.id}",
headers=_auth_headers(),
data=json.dumps({
"allowed_cors_origins": ["https://new-app.example.com"],
}),
)
assert resp.status_code == 200
data = resp.get_json()["data"]["client"]
assert data["allowed_cors_origins"] == ["https://new-app.example.com"]
def test_update_clear_cors_origins(self, app, auth_context):
user, org = auth_context
oidc_client = _create_oidc_client(org.id, allowed_cors_origins=["https://old.example.com"])
with app.test_client() as tc, _mock_session_for(user):
resp = tc.patch(
f"/api/v1/organizations/{org.id}/clients/{oidc_client.id}",
headers=_auth_headers(),
data=json.dumps({
"allowed_cors_origins": None,
}),
)
assert resp.status_code == 200
assert resp.get_json()["data"]["client"]["allowed_cors_origins"] is None
def test_update_cors_origins_with_invalid_value_returns_400(self, app, auth_context):
user, org = auth_context
oidc_client = _create_oidc_client(org.id)
with app.test_client() as tc, _mock_session_for(user):
resp = tc.patch(
f"/api/v1/organizations/{org.id}/clients/{oidc_client.id}",
headers=_auth_headers(),
data=json.dumps({
"allowed_cors_origins": ["https://good.com", "not-a-url"],
}),
)
assert resp.status_code == 400
def test_update_without_cors_field_does_not_change_it(self, app, auth_context):
user, org = auth_context
oidc_client = _create_oidc_client(org.id, allowed_cors_origins=["https://keep-me.example.com"])
with app.test_client() as tc, _mock_session_for(user):
resp = tc.patch(
f"/api/v1/organizations/{org.id}/clients/{oidc_client.id}",
headers=_auth_headers(),
data=json.dumps({
"name": "Renamed",
}),
)
assert resp.status_code == 200
assert resp.get_json()["data"]["client"]["allowed_cors_origins"] == ["https://keep-me.example.com"]
def test_update_set_wildcard(self, app, auth_context):
user, org = auth_context
oidc_client = _create_oidc_client(org.id)
with app.test_client() as tc, _mock_session_for(user):
resp = tc.patch(
f"/api/v1/organizations/{org.id}/clients/{oidc_client.id}",
headers=_auth_headers(),
data=json.dumps({
"allowed_cors_origins": ["*"],
}),
)
assert resp.status_code == 200
assert resp.get_json()["data"]["client"]["allowed_cors_origins"] == ["*"]
# ---------------------------------------------------------------------------
# GET list
# ---------------------------------------------------------------------------
class TestListClientsCorsOrigins:
"""GET /api/v1/organizations/<org_id>/clients returns allowed_cors_origins."""
def test_list_includes_cors_origins(self, app, auth_context):
user, org = auth_context
oidc_client = _create_oidc_client(
org.id,
name="List Test",
allowed_cors_origins=["https://list.example.com"],
)
with app.test_client() as tc, _mock_session_for(user):
resp = tc.get(
f"/api/v1/organizations/{org.id}/clients",
headers=_auth_headers(),
)
assert resp.status_code == 200
clients_list = resp.get_json()["data"]["clients"]
matching = [c for c in clients_list if c["client_id"] == oidc_client.client_id]
assert len(matching) == 1
assert matching[0]["allowed_cors_origins"] == ["https://list.example.com"]