Added OIDC client CORS attributes
This commit is contained in:
+6
-1
@@ -141,4 +141,9 @@ flask_session/
|
|||||||
# Opencode files and folders
|
# Opencode files and folders
|
||||||
.opencode/
|
.opencode/
|
||||||
.swarm/
|
.swarm/
|
||||||
SWARM_PLAN.*
|
SWARM_PLAN.*
|
||||||
|
# local backups / dumps / sessions
|
||||||
|
*.sql
|
||||||
|
*.dump
|
||||||
|
session-*.md
|
||||||
|
backups/
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from gatehouse_app.exceptions.auth_exceptions import (
|
|||||||
from gatehouse_app.utils.constants import AuditAction
|
from gatehouse_app.utils.constants import AuditAction
|
||||||
from gatehouse_app.services.audit_service import AuditService
|
from gatehouse_app.services.audit_service import AuditService
|
||||||
from gatehouse_app.services.oidc_audit_service import OIDCAuditService
|
from gatehouse_app.services.oidc_audit_service import OIDCAuditService
|
||||||
|
from gatehouse_app.utils.validators import validate_cors_origins
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -816,6 +817,11 @@ def oidc_register():
|
|||||||
except Exception:
|
except Exception:
|
||||||
return jsonify({"error": "invalid_request", "error_description": f"Invalid redirect_uri: {uri}"}), 400
|
return jsonify({"error": "invalid_request", "error_description": f"Invalid redirect_uri: {uri}"}), 400
|
||||||
|
|
||||||
|
cors_origins_raw = data.get("allowed_cors_origins")
|
||||||
|
cors_origins, cors_error = validate_cors_origins(cors_origins_raw)
|
||||||
|
if cors_error:
|
||||||
|
return jsonify({"error": "invalid_request", "error_description": cors_error}), 400
|
||||||
|
|
||||||
client_id = f"oidc_{secrets.token_urlsafe(16)}"
|
client_id = f"oidc_{secrets.token_urlsafe(16)}"
|
||||||
client_secret = f"secret_{secrets.token_urlsafe(24)}"
|
client_secret = f"secret_{secrets.token_urlsafe(24)}"
|
||||||
client_secret_hash = flask_bcrypt.generate_password_hash(client_secret).decode("utf-8")
|
client_secret_hash = flask_bcrypt.generate_password_hash(client_secret).decode("utf-8")
|
||||||
@@ -842,6 +848,7 @@ def oidc_register():
|
|||||||
grant_types=data.get("grant_types", ["authorization_code", "refresh_token"]),
|
grant_types=data.get("grant_types", ["authorization_code", "refresh_token"]),
|
||||||
response_types=data.get("response_types", ["code"]),
|
response_types=data.get("response_types", ["code"]),
|
||||||
scopes=data.get("scope", "openid profile email roles").split(),
|
scopes=data.get("scope", "openid profile email roles").split(),
|
||||||
|
allowed_cors_origins=cors_origins,
|
||||||
is_active=True,
|
is_active=True,
|
||||||
is_confidential=True,
|
is_confidential=True,
|
||||||
require_pkce=True,
|
require_pkce=True,
|
||||||
@@ -871,6 +878,7 @@ def oidc_register():
|
|||||||
"client_secret_expires_at": 0,
|
"client_secret_expires_at": 0,
|
||||||
"client_name": client_name,
|
"client_name": client_name,
|
||||||
"redirect_uris": redirect_uris,
|
"redirect_uris": redirect_uris,
|
||||||
|
"allowed_cors_origins": client.allowed_cors_origins,
|
||||||
"token_endpoint_auth_method": data.get("token_endpoint_auth_method", "client_secret_basic"),
|
"token_endpoint_auth_method": data.get("token_endpoint_auth_method", "client_secret_basic"),
|
||||||
"grant_types": client.grant_types,
|
"grant_types": client.grant_types,
|
||||||
"response_types": client.response_types,
|
"response_types": client.response_types,
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from gatehouse_app.utils.decorators import login_required, require_admin, full_a
|
|||||||
from gatehouse_app.extensions import db, bcrypt
|
from gatehouse_app.extensions import db, bcrypt
|
||||||
from gatehouse_app.utils.constants import AuditAction
|
from gatehouse_app.utils.constants import AuditAction
|
||||||
from gatehouse_app.services.audit_service import AuditService
|
from gatehouse_app.services.audit_service import AuditService
|
||||||
|
from gatehouse_app.utils.validators import validate_cors_origins
|
||||||
|
|
||||||
|
|
||||||
@api_v1_bp.route("/organizations/<org_id>/clients", methods=["GET"])
|
@api_v1_bp.route("/organizations/<org_id>/clients", methods=["GET"])
|
||||||
@@ -63,6 +64,11 @@ def create_org_client(org_id):
|
|||||||
if not redirect_uris:
|
if not redirect_uris:
|
||||||
return api_response(success=False, message="At least one redirect URI is required", status=400, error_type="VALIDATION_ERROR")
|
return api_response(success=False, message="At least one redirect URI is required", status=400, error_type="VALIDATION_ERROR")
|
||||||
|
|
||||||
|
cors_origins_raw = data.get("allowed_cors_origins")
|
||||||
|
cors_origins, cors_error = validate_cors_origins(cors_origins_raw)
|
||||||
|
if cors_error:
|
||||||
|
return api_response(success=False, message=cors_error, status=400, error_type="VALIDATION_ERROR")
|
||||||
|
|
||||||
client_id = _secrets.token_hex(16)
|
client_id = _secrets.token_hex(16)
|
||||||
client_secret = _secrets.token_urlsafe(32)
|
client_secret = _secrets.token_urlsafe(32)
|
||||||
|
|
||||||
@@ -75,6 +81,7 @@ def create_org_client(org_id):
|
|||||||
grant_types=["authorization_code", "refresh_token"],
|
grant_types=["authorization_code", "refresh_token"],
|
||||||
response_types=["code"],
|
response_types=["code"],
|
||||||
scopes=["openid", "profile", "email"],
|
scopes=["openid", "profile", "email"],
|
||||||
|
allowed_cors_origins=cors_origins,
|
||||||
is_active=True,
|
is_active=True,
|
||||||
is_confidential=True,
|
is_confidential=True,
|
||||||
)
|
)
|
||||||
@@ -99,6 +106,7 @@ def create_org_client(org_id):
|
|||||||
"client_secret": client_secret,
|
"client_secret": client_secret,
|
||||||
"redirect_uris": client.redirect_uris,
|
"redirect_uris": client.redirect_uris,
|
||||||
"scopes": client.scopes,
|
"scopes": client.scopes,
|
||||||
|
"allowed_cors_origins": client.allowed_cors_origins,
|
||||||
"created_at": client.created_at.isoformat() + "Z",
|
"created_at": client.created_at.isoformat() + "Z",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -135,6 +143,12 @@ def update_org_client(org_id, client_id):
|
|||||||
return api_response(success=False, message="At least one redirect URI is required", status=400, error_type="VALIDATION_ERROR")
|
return api_response(success=False, message="At least one redirect URI is required", status=400, error_type="VALIDATION_ERROR")
|
||||||
client.redirect_uris = uris
|
client.redirect_uris = uris
|
||||||
|
|
||||||
|
if "allowed_cors_origins" in data:
|
||||||
|
cors_origins, cors_error = validate_cors_origins(data["allowed_cors_origins"])
|
||||||
|
if cors_error:
|
||||||
|
return api_response(success=False, message=cors_error, status=400, error_type="VALIDATION_ERROR")
|
||||||
|
client.allowed_cors_origins = cors_origins
|
||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
AuditService.log_action(
|
AuditService.log_action(
|
||||||
@@ -155,6 +169,7 @@ def update_org_client(org_id, client_id):
|
|||||||
"redirect_uris": client.redirect_uris,
|
"redirect_uris": client.redirect_uris,
|
||||||
"scopes": client.scopes,
|
"scopes": client.scopes,
|
||||||
"grant_types": client.grant_types,
|
"grant_types": client.grant_types,
|
||||||
|
"allowed_cors_origins": client.allowed_cors_origins,
|
||||||
"is_active": client.is_active,
|
"is_active": client.is_active,
|
||||||
"created_at": client.created_at.isoformat() + "Z",
|
"created_at": client.created_at.isoformat() + "Z",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,94 @@
|
|||||||
|
"""Validation helpers for request data."""
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
# Special sentinel values allowed in allowed_cors_origins
|
||||||
|
_CORS_SENTINELS = {"+", "*"}
|
||||||
|
|
||||||
|
|
||||||
|
def validate_cors_origins(origins):
|
||||||
|
"""Validate a list of CORS origin values.
|
||||||
|
|
||||||
|
Accepts:
|
||||||
|
None - means "use global CORS config" (pass-through)
|
||||||
|
["+"] - derive origins from the client's redirect_uris
|
||||||
|
["*"] - allow any origin
|
||||||
|
["https://host"] - explicit allow-list of well-formed origins
|
||||||
|
|
||||||
|
Each non-sentinel entry must be a well-formed origin:
|
||||||
|
scheme (http or https) + host + optional port, with NO path,
|
||||||
|
query string, or fragment.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(validated_value, None) on success, or
|
||||||
|
(None, error_message) on failure.
|
||||||
|
"""
|
||||||
|
if origins is None:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
if not isinstance(origins, list):
|
||||||
|
return None, "allowed_cors_origins must be a list or null"
|
||||||
|
|
||||||
|
validated = []
|
||||||
|
for i, entry in enumerate(origins):
|
||||||
|
if not isinstance(entry, str):
|
||||||
|
return None, f"allowed_cors_origins[{i}]: expected a string, got {type(entry).__name__}"
|
||||||
|
|
||||||
|
entry = entry.strip()
|
||||||
|
if not entry:
|
||||||
|
return None, f"allowed_cors_origins[{i}]: empty string is not allowed"
|
||||||
|
|
||||||
|
# Sentinel values are accepted as-is
|
||||||
|
if entry in _CORS_SENTINELS:
|
||||||
|
validated.append(entry)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Parse and validate as origin
|
||||||
|
error = _validate_single_origin(entry, i)
|
||||||
|
if error:
|
||||||
|
return None, error
|
||||||
|
|
||||||
|
validated.append(entry)
|
||||||
|
|
||||||
|
return validated, None
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_single_origin(origin, index):
|
||||||
|
"""Validate that a string is a well-formed browser origin.
|
||||||
|
|
||||||
|
A valid origin is: scheme://host[:port] with no path, query, or fragment.
|
||||||
|
Only http and https schemes are accepted.
|
||||||
|
|
||||||
|
Returns an error message string on failure, or None on success.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
parsed = urlparse(origin)
|
||||||
|
except Exception:
|
||||||
|
return f"allowed_cors_origins[{index}]: '{origin}' is not a valid URL"
|
||||||
|
|
||||||
|
if parsed.scheme not in ("http", "https"):
|
||||||
|
return (
|
||||||
|
f"allowed_cors_origins[{index}]: '{origin}' has an invalid scheme "
|
||||||
|
f"'{parsed.scheme}'; only 'http' and 'https' are allowed"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not parsed.hostname:
|
||||||
|
return f"allowed_cors_origins[{index}]: '{origin}' is missing a hostname"
|
||||||
|
|
||||||
|
# Origins must not have a path (other than empty or "/"), query, or fragment
|
||||||
|
if parsed.path and parsed.path != "/":
|
||||||
|
return (
|
||||||
|
f"allowed_cors_origins[{index}]: '{origin}' must not contain a path "
|
||||||
|
f"(got '{parsed.path}'). Specify only scheme://host[:port]"
|
||||||
|
)
|
||||||
|
|
||||||
|
if parsed.query:
|
||||||
|
return (
|
||||||
|
f"allowed_cors_origins[{index}]: '{origin}' must not contain a query string"
|
||||||
|
)
|
||||||
|
|
||||||
|
if parsed.fragment:
|
||||||
|
return (
|
||||||
|
f"allowed_cors_origins[{index}]: '{origin}' must not contain a fragment"
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
@@ -38,7 +38,7 @@ def upgrade():
|
|||||||
is_compliance_only, created_at, updated_at, deleted_at
|
is_compliance_only, created_at, updated_at, deleted_at
|
||||||
)
|
)
|
||||||
SELECT
|
SELECT
|
||||||
id, 'superadmin', superadmin_id, token, 'active',
|
id, 'superadmin', superadmin_id, token, 'ACTIVE',
|
||||||
ip_address, user_agent, NULL,
|
ip_address, user_agent, NULL,
|
||||||
expires_at, last_activity_at, revoked_at, revoked_reason,
|
expires_at, last_activity_at, revoked_at, revoked_reason,
|
||||||
FALSE, created_at, updated_at, deleted_at
|
FALSE, created_at, updated_at, deleted_at
|
||||||
|
|||||||
@@ -0,0 +1,23 @@
|
|||||||
|
"""Merge branches: consolidate_sessions + remove_sudo_api_keys.
|
||||||
|
|
||||||
|
Revision ID: e1f2a3b4c5d6
|
||||||
|
Revises: c8d2e4f6a1b3, d1e2f3g4h5i6
|
||||||
|
Create Date: 2026-05-19 12:45:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'e1f2a3b4c5d6'
|
||||||
|
down_revision = ('c8d2e4f6a1b3', 'd1e2f3g4h5i6')
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
pass
|
||||||
@@ -7,6 +7,7 @@ Create Date: 2026-05-02 00:00:00.000000
|
|||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
@@ -21,6 +22,21 @@ depends_on = None
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Step 0: Ensure enum types exist (they may already exist from old tables)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
op.execute("""
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'approval_grant_type') THEN
|
||||||
|
CREATE TYPE approval_grant_type AS ENUM ('requested', 'assigned');
|
||||||
|
END IF;
|
||||||
|
IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'approval_state') THEN
|
||||||
|
CREATE TYPE approval_state AS ENUM ('pending', 'approved', 'rejected', 'revoked', 'suspended');
|
||||||
|
END IF;
|
||||||
|
END$$;
|
||||||
|
""")
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Step 1: Create the new network_access_requests table
|
# Step 1: Create the new network_access_requests table
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -34,12 +50,12 @@ def upgrade():
|
|||||||
sa.Column('granted_by_user_id', sa.String(length=36), nullable=True),
|
sa.Column('granted_by_user_id', sa.String(length=36), nullable=True),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
'grant_type',
|
'grant_type',
|
||||||
sa.Enum('requested', 'assigned', name='approval_grant_type', create_type=False),
|
postgresql.ENUM('requested', 'assigned', name='approval_grant_type', create_type=False),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
),
|
),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
'status',
|
'status',
|
||||||
sa.Enum(
|
postgresql.ENUM(
|
||||||
'pending', 'approved', 'rejected', 'revoked', 'suspended',
|
'pending', 'approved', 'rejected', 'revoked', 'suspended',
|
||||||
name='approval_state', create_type=False,
|
name='approval_state', create_type=False,
|
||||||
),
|
),
|
||||||
@@ -334,12 +350,12 @@ def downgrade():
|
|||||||
sa.Column('granted_by_user_id', sa.String(length=36), nullable=True),
|
sa.Column('granted_by_user_id', sa.String(length=36), nullable=True),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
'grant_type',
|
'grant_type',
|
||||||
sa.Enum('requested', 'assigned', name='approval_grant_type', create_type=False),
|
postgresql.ENUM('requested', 'assigned', name='approval_grant_type', create_type=False),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
),
|
),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
'state',
|
'state',
|
||||||
sa.Enum(
|
postgresql.ENUM(
|
||||||
'pending', 'approved', 'rejected', 'revoked', 'suspended',
|
'pending', 'approved', 'rejected', 'revoked', 'suspended',
|
||||||
name='approval_state', create_type=False,
|
name='approval_state', create_type=False,
|
||||||
),
|
),
|
||||||
@@ -437,7 +453,7 @@ def downgrade():
|
|||||||
sa.Column('user_network_approval_id', sa.String(length=36), nullable=True),
|
sa.Column('user_network_approval_id', sa.String(length=36), nullable=True),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
'state',
|
'state',
|
||||||
sa.Enum(
|
postgresql.ENUM(
|
||||||
'pending_device_registration',
|
'pending_device_registration',
|
||||||
'pending_request',
|
'pending_request',
|
||||||
'pending_manager_approval',
|
'pending_manager_approval',
|
||||||
|
|||||||
@@ -0,0 +1,137 @@
|
|||||||
|
"""Unit tests for validate_cors_origins() helper.
|
||||||
|
|
||||||
|
WHAT: Tests for the CORS origin validation function used by OIDC client
|
||||||
|
create/update endpoints.
|
||||||
|
WHY: Malformed origins would silently break per-client CORS; strict
|
||||||
|
validation catches mistakes at write-time rather than at runtime.
|
||||||
|
EXPECTED: Valid origins accepted, invalid origins rejected with clear errors.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from gatehouse_app.utils.validators import validate_cors_origins
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateCorsOriginsAccepts:
|
||||||
|
"""Cases that should pass validation."""
|
||||||
|
|
||||||
|
def test_none_passes_through(self):
|
||||||
|
value, error = validate_cors_origins(None)
|
||||||
|
assert value is None
|
||||||
|
assert error is None
|
||||||
|
|
||||||
|
def test_empty_list_is_valid(self):
|
||||||
|
value, error = validate_cors_origins([])
|
||||||
|
assert value == []
|
||||||
|
assert error is None
|
||||||
|
|
||||||
|
def test_sentinel_plus(self):
|
||||||
|
value, error = validate_cors_origins(["+"])
|
||||||
|
assert value == ["+"]
|
||||||
|
assert error is None
|
||||||
|
|
||||||
|
def test_sentinel_wildcard(self):
|
||||||
|
value, error = validate_cors_origins(["*"])
|
||||||
|
assert value == ["*"]
|
||||||
|
assert error is None
|
||||||
|
|
||||||
|
def test_https_origin(self):
|
||||||
|
value, error = validate_cors_origins(["https://example.com"])
|
||||||
|
assert value == ["https://example.com"]
|
||||||
|
assert error is None
|
||||||
|
|
||||||
|
def test_https_origin_with_port(self):
|
||||||
|
value, error = validate_cors_origins(["https://example.com:8443"])
|
||||||
|
assert value == ["https://example.com:8443"]
|
||||||
|
assert error is None
|
||||||
|
|
||||||
|
def test_http_localhost(self):
|
||||||
|
value, error = validate_cors_origins(["http://localhost:3000"])
|
||||||
|
assert value == ["http://localhost:3000"]
|
||||||
|
assert error is None
|
||||||
|
|
||||||
|
def test_multiple_valid_origins(self):
|
||||||
|
origins = ["https://app.example.com", "https://staging.example.com:8443"]
|
||||||
|
value, error = validate_cors_origins(origins)
|
||||||
|
assert value == origins
|
||||||
|
assert error is None
|
||||||
|
|
||||||
|
def test_origin_with_trailing_slash_accepted(self):
|
||||||
|
# urlparse treats trailing "/" as path="/", which we allow
|
||||||
|
value, error = validate_cors_origins(["https://example.com/"])
|
||||||
|
assert value == ["https://example.com/"]
|
||||||
|
assert error is None
|
||||||
|
|
||||||
|
def test_sentinel_mixed_with_origins(self):
|
||||||
|
value, error = validate_cors_origins(["+", "https://extra.example.com"])
|
||||||
|
assert value == ["+", "https://extra.example.com"]
|
||||||
|
assert error is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateCorsOriginsRejects:
|
||||||
|
"""Cases that should fail validation."""
|
||||||
|
|
||||||
|
def test_not_a_list(self):
|
||||||
|
value, error = validate_cors_origins("https://example.com")
|
||||||
|
assert value is None
|
||||||
|
assert "must be a list" in error
|
||||||
|
|
||||||
|
def test_non_string_entry(self):
|
||||||
|
value, error = validate_cors_origins([123])
|
||||||
|
assert value is None
|
||||||
|
assert "expected a string" in error
|
||||||
|
|
||||||
|
def test_empty_string_entry(self):
|
||||||
|
value, error = validate_cors_origins([""])
|
||||||
|
assert value is None
|
||||||
|
assert "empty string" in error
|
||||||
|
|
||||||
|
def test_whitespace_only_entry(self):
|
||||||
|
value, error = validate_cors_origins([" "])
|
||||||
|
assert value is None
|
||||||
|
assert "empty string" in error
|
||||||
|
|
||||||
|
def test_origin_with_path(self):
|
||||||
|
value, error = validate_cors_origins(["https://example.com/api/v1"])
|
||||||
|
assert value is None
|
||||||
|
assert "must not contain a path" in error
|
||||||
|
|
||||||
|
def test_origin_with_query_string(self):
|
||||||
|
value, error = validate_cors_origins(["https://example.com?q=1"])
|
||||||
|
assert value is None
|
||||||
|
assert "query string" in error
|
||||||
|
|
||||||
|
def test_origin_with_fragment(self):
|
||||||
|
value, error = validate_cors_origins(["https://example.com#section"])
|
||||||
|
assert value is None
|
||||||
|
assert "fragment" in error
|
||||||
|
|
||||||
|
def test_ftp_scheme_rejected(self):
|
||||||
|
value, error = validate_cors_origins(["ftp://example.com"])
|
||||||
|
assert value is None
|
||||||
|
assert "invalid scheme" in error
|
||||||
|
|
||||||
|
def test_no_scheme(self):
|
||||||
|
value, error = validate_cors_origins(["example.com"])
|
||||||
|
assert value is None
|
||||||
|
# urlparse puts this in path, not hostname, so we get either
|
||||||
|
# scheme or hostname error
|
||||||
|
assert error is not None
|
||||||
|
|
||||||
|
def test_bare_string_not_url(self):
|
||||||
|
value, error = validate_cors_origins(["not-a-url"])
|
||||||
|
assert value is None
|
||||||
|
assert error is not None
|
||||||
|
|
||||||
|
def test_mixed_valid_and_invalid(self):
|
||||||
|
"""First invalid entry stops validation."""
|
||||||
|
value, error = validate_cors_origins([
|
||||||
|
"https://good.example.com",
|
||||||
|
"ftp://bad.example.com",
|
||||||
|
])
|
||||||
|
assert value is None
|
||||||
|
assert "allowed_cors_origins[1]" in error
|
||||||
|
|
||||||
|
def test_dict_entry_rejected(self):
|
||||||
|
value, error = validate_cors_origins([{"url": "https://example.com"}])
|
||||||
|
assert value is None
|
||||||
|
assert "expected a string" in error
|
||||||
@@ -0,0 +1,340 @@
|
|||||||
|
"""Unit tests for allowed_cors_origins in OIDC client endpoints.
|
||||||
|
|
||||||
|
WHAT: Tests that the create, update, and list endpoints correctly accept,
|
||||||
|
validate, persist, and return the allowed_cors_origins field.
|
||||||
|
WHY: The field was already on the model but was not wired into any API
|
||||||
|
endpoint; these tests verify the new wiring works end-to-end.
|
||||||
|
EXPECTED: Valid origins are stored and returned; invalid origins are rejected
|
||||||
|
with 400; omitting the field defaults to None (global fallback).
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import secrets
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from gatehouse_app import create_app, db
|
||||||
|
from gatehouse_app.extensions import limiter
|
||||||
|
from gatehouse_app.models.oidc.oidc_client import OIDCClient
|
||||||
|
from gatehouse_app.models.organization.organization import Organization
|
||||||
|
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||||
|
from gatehouse_app.models.user.user import User
|
||||||
|
from gatehouse_app.utils.constants import OrganizationRole
|
||||||
|
|
||||||
|
|
||||||
|
# Disable rate limiter for tests
|
||||||
|
limiter.enabled = False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def app():
|
||||||
|
"""Create a test Flask app with in-memory SQLite."""
|
||||||
|
_app = create_app(config_name="testing")
|
||||||
|
_app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
|
||||||
|
_app.config["TESTING"] = True
|
||||||
|
_app.config["WTF_CSRF_ENABLED"] = False
|
||||||
|
_app.config["RATELIMIT_ENABLED"] = False
|
||||||
|
|
||||||
|
with _app.app_context():
|
||||||
|
db.create_all()
|
||||||
|
yield _app
|
||||||
|
db.session.remove()
|
||||||
|
db.drop_all()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def auth_context(app):
|
||||||
|
"""Create a user, org, and admin membership.
|
||||||
|
|
||||||
|
The module-scoped app fixture already holds an active app_context,
|
||||||
|
so we don't push another one here.
|
||||||
|
"""
|
||||||
|
user = User(
|
||||||
|
email=f"admin_{secrets.token_hex(4)}@test.com",
|
||||||
|
full_name="Admin User",
|
||||||
|
email_verified=True,
|
||||||
|
)
|
||||||
|
db.session.add(user)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
org = Organization(
|
||||||
|
name=f"Test Org {secrets.token_hex(4)}",
|
||||||
|
slug=f"test-org-{secrets.token_hex(4)}",
|
||||||
|
)
|
||||||
|
db.session.add(org)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
membership = OrganizationMember(
|
||||||
|
user_id=user.id,
|
||||||
|
organization_id=org.id,
|
||||||
|
role=OrganizationRole.OWNER,
|
||||||
|
)
|
||||||
|
db.session.add(membership)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return user, org
|
||||||
|
|
||||||
|
|
||||||
|
def _auth_headers():
|
||||||
|
return {"Authorization": "Bearer test-token", "Content-Type": "application/json"}
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_session_for(user):
|
||||||
|
"""Return a context manager that patches SessionService to authenticate *user*."""
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.user = user
|
||||||
|
mock_session.is_active.return_value = True
|
||||||
|
mock_session.is_compliance_only = False
|
||||||
|
mock_session.device_info = {}
|
||||||
|
|
||||||
|
return patch(
|
||||||
|
"gatehouse_app.services.session_service.SessionService.get_active_session_by_token",
|
||||||
|
return_value=mock_session,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_oidc_client(org_id, **overrides):
|
||||||
|
"""Insert a minimal OIDCClient directly into the DB."""
|
||||||
|
from gatehouse_app.extensions import bcrypt
|
||||||
|
|
||||||
|
defaults = dict(
|
||||||
|
organization_id=org_id,
|
||||||
|
name="Test Client",
|
||||||
|
client_id=secrets.token_hex(16),
|
||||||
|
client_secret_hash=bcrypt.generate_password_hash("secret").decode("utf-8"),
|
||||||
|
redirect_uris=["https://app.example.com/callback"],
|
||||||
|
grant_types=["authorization_code"],
|
||||||
|
response_types=["code"],
|
||||||
|
scopes=["openid"],
|
||||||
|
is_active=True,
|
||||||
|
is_confidential=True,
|
||||||
|
)
|
||||||
|
defaults.update(overrides)
|
||||||
|
c = OIDCClient(**defaults)
|
||||||
|
db.session.add(c)
|
||||||
|
db.session.commit()
|
||||||
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST create
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestCreateClientCorsOrigins:
|
||||||
|
"""POST /api/v1/organizations/<org_id>/clients with allowed_cors_origins."""
|
||||||
|
|
||||||
|
def test_create_with_cors_origins(self, app, auth_context):
|
||||||
|
user, org = auth_context
|
||||||
|
with app.test_client() as tc, _mock_session_for(user):
|
||||||
|
resp = tc.post(
|
||||||
|
f"/api/v1/organizations/{org.id}/clients",
|
||||||
|
headers=_auth_headers(),
|
||||||
|
data=json.dumps({
|
||||||
|
"name": "CORS Test Client",
|
||||||
|
"redirect_uris": ["https://app.example.com/callback"],
|
||||||
|
"allowed_cors_origins": ["https://app.example.com"],
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
data = resp.get_json()
|
||||||
|
assert data["data"]["client"]["allowed_cors_origins"] == ["https://app.example.com"]
|
||||||
|
|
||||||
|
def test_create_with_sentinel_plus(self, app, auth_context):
|
||||||
|
user, org = auth_context
|
||||||
|
with app.test_client() as tc, _mock_session_for(user):
|
||||||
|
resp = tc.post(
|
||||||
|
f"/api/v1/organizations/{org.id}/clients",
|
||||||
|
headers=_auth_headers(),
|
||||||
|
data=json.dumps({
|
||||||
|
"name": "Plus Client",
|
||||||
|
"redirect_uris": ["https://app.example.com/callback"],
|
||||||
|
"allowed_cors_origins": ["+"],
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
assert resp.get_json()["data"]["client"]["allowed_cors_origins"] == ["+"]
|
||||||
|
|
||||||
|
def test_create_without_cors_origins_defaults_to_none(self, app, auth_context):
|
||||||
|
user, org = auth_context
|
||||||
|
with app.test_client() as tc, _mock_session_for(user):
|
||||||
|
resp = tc.post(
|
||||||
|
f"/api/v1/organizations/{org.id}/clients",
|
||||||
|
headers=_auth_headers(),
|
||||||
|
data=json.dumps({
|
||||||
|
"name": "No CORS Client",
|
||||||
|
"redirect_uris": ["https://app.example.com/callback"],
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
assert resp.get_json()["data"]["client"]["allowed_cors_origins"] is None
|
||||||
|
|
||||||
|
def test_create_with_null_cors_origins(self, app, auth_context):
|
||||||
|
user, org = auth_context
|
||||||
|
with app.test_client() as tc, _mock_session_for(user):
|
||||||
|
resp = tc.post(
|
||||||
|
f"/api/v1/organizations/{org.id}/clients",
|
||||||
|
headers=_auth_headers(),
|
||||||
|
data=json.dumps({
|
||||||
|
"name": "Null CORS Client",
|
||||||
|
"redirect_uris": ["https://app.example.com/callback"],
|
||||||
|
"allowed_cors_origins": None,
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
assert resp.get_json()["data"]["client"]["allowed_cors_origins"] is None
|
||||||
|
|
||||||
|
def test_create_with_invalid_cors_origin_returns_400(self, app, auth_context):
|
||||||
|
user, org = auth_context
|
||||||
|
with app.test_client() as tc, _mock_session_for(user):
|
||||||
|
resp = tc.post(
|
||||||
|
f"/api/v1/organizations/{org.id}/clients",
|
||||||
|
headers=_auth_headers(),
|
||||||
|
data=json.dumps({
|
||||||
|
"name": "Bad CORS Client",
|
||||||
|
"redirect_uris": ["https://app.example.com/callback"],
|
||||||
|
"allowed_cors_origins": ["ftp://bad.example.com"],
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 400
|
||||||
|
assert "invalid scheme" in resp.get_json()["message"]
|
||||||
|
|
||||||
|
def test_create_with_origin_containing_path_returns_400(self, app, auth_context):
|
||||||
|
user, org = auth_context
|
||||||
|
with app.test_client() as tc, _mock_session_for(user):
|
||||||
|
resp = tc.post(
|
||||||
|
f"/api/v1/organizations/{org.id}/clients",
|
||||||
|
headers=_auth_headers(),
|
||||||
|
data=json.dumps({
|
||||||
|
"name": "Path CORS Client",
|
||||||
|
"redirect_uris": ["https://app.example.com/callback"],
|
||||||
|
"allowed_cors_origins": ["https://app.example.com/api"],
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 400
|
||||||
|
assert "must not contain a path" in resp.get_json()["message"]
|
||||||
|
|
||||||
|
def test_create_with_non_list_cors_origins_returns_400(self, app, auth_context):
|
||||||
|
user, org = auth_context
|
||||||
|
with app.test_client() as tc, _mock_session_for(user):
|
||||||
|
resp = tc.post(
|
||||||
|
f"/api/v1/organizations/{org.id}/clients",
|
||||||
|
headers=_auth_headers(),
|
||||||
|
data=json.dumps({
|
||||||
|
"name": "String CORS Client",
|
||||||
|
"redirect_uris": ["https://app.example.com/callback"],
|
||||||
|
"allowed_cors_origins": "https://app.example.com",
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 400
|
||||||
|
assert "must be a list" in resp.get_json()["message"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# PATCH update
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestUpdateClientCorsOrigins:
|
||||||
|
"""PATCH /api/v1/organizations/<org_id>/clients/<client_id> with allowed_cors_origins."""
|
||||||
|
|
||||||
|
def test_update_set_cors_origins(self, app, auth_context):
|
||||||
|
user, org = auth_context
|
||||||
|
oidc_client = _create_oidc_client(org.id)
|
||||||
|
assert oidc_client.allowed_cors_origins is None
|
||||||
|
|
||||||
|
with app.test_client() as tc, _mock_session_for(user):
|
||||||
|
resp = tc.patch(
|
||||||
|
f"/api/v1/organizations/{org.id}/clients/{oidc_client.id}",
|
||||||
|
headers=_auth_headers(),
|
||||||
|
data=json.dumps({
|
||||||
|
"allowed_cors_origins": ["https://new-app.example.com"],
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.get_json()["data"]["client"]
|
||||||
|
assert data["allowed_cors_origins"] == ["https://new-app.example.com"]
|
||||||
|
|
||||||
|
def test_update_clear_cors_origins(self, app, auth_context):
|
||||||
|
user, org = auth_context
|
||||||
|
oidc_client = _create_oidc_client(org.id, allowed_cors_origins=["https://old.example.com"])
|
||||||
|
|
||||||
|
with app.test_client() as tc, _mock_session_for(user):
|
||||||
|
resp = tc.patch(
|
||||||
|
f"/api/v1/organizations/{org.id}/clients/{oidc_client.id}",
|
||||||
|
headers=_auth_headers(),
|
||||||
|
data=json.dumps({
|
||||||
|
"allowed_cors_origins": None,
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.get_json()["data"]["client"]["allowed_cors_origins"] is None
|
||||||
|
|
||||||
|
def test_update_cors_origins_with_invalid_value_returns_400(self, app, auth_context):
|
||||||
|
user, org = auth_context
|
||||||
|
oidc_client = _create_oidc_client(org.id)
|
||||||
|
|
||||||
|
with app.test_client() as tc, _mock_session_for(user):
|
||||||
|
resp = tc.patch(
|
||||||
|
f"/api/v1/organizations/{org.id}/clients/{oidc_client.id}",
|
||||||
|
headers=_auth_headers(),
|
||||||
|
data=json.dumps({
|
||||||
|
"allowed_cors_origins": ["https://good.com", "not-a-url"],
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 400
|
||||||
|
|
||||||
|
def test_update_without_cors_field_does_not_change_it(self, app, auth_context):
|
||||||
|
user, org = auth_context
|
||||||
|
oidc_client = _create_oidc_client(org.id, allowed_cors_origins=["https://keep-me.example.com"])
|
||||||
|
|
||||||
|
with app.test_client() as tc, _mock_session_for(user):
|
||||||
|
resp = tc.patch(
|
||||||
|
f"/api/v1/organizations/{org.id}/clients/{oidc_client.id}",
|
||||||
|
headers=_auth_headers(),
|
||||||
|
data=json.dumps({
|
||||||
|
"name": "Renamed",
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.get_json()["data"]["client"]["allowed_cors_origins"] == ["https://keep-me.example.com"]
|
||||||
|
|
||||||
|
def test_update_set_wildcard(self, app, auth_context):
|
||||||
|
user, org = auth_context
|
||||||
|
oidc_client = _create_oidc_client(org.id)
|
||||||
|
|
||||||
|
with app.test_client() as tc, _mock_session_for(user):
|
||||||
|
resp = tc.patch(
|
||||||
|
f"/api/v1/organizations/{org.id}/clients/{oidc_client.id}",
|
||||||
|
headers=_auth_headers(),
|
||||||
|
data=json.dumps({
|
||||||
|
"allowed_cors_origins": ["*"],
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.get_json()["data"]["client"]["allowed_cors_origins"] == ["*"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GET list
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestListClientsCorsOrigins:
|
||||||
|
"""GET /api/v1/organizations/<org_id>/clients returns allowed_cors_origins."""
|
||||||
|
|
||||||
|
def test_list_includes_cors_origins(self, app, auth_context):
|
||||||
|
user, org = auth_context
|
||||||
|
oidc_client = _create_oidc_client(
|
||||||
|
org.id,
|
||||||
|
name="List Test",
|
||||||
|
allowed_cors_origins=["https://list.example.com"],
|
||||||
|
)
|
||||||
|
|
||||||
|
with app.test_client() as tc, _mock_session_for(user):
|
||||||
|
resp = tc.get(
|
||||||
|
f"/api/v1/organizations/{org.id}/clients",
|
||||||
|
headers=_auth_headers(),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
clients_list = resp.get_json()["data"]["clients"]
|
||||||
|
matching = [c for c in clients_list if c["client_id"] == oidc_client.client_id]
|
||||||
|
assert len(matching) == 1
|
||||||
|
assert matching[0]["allowed_cors_origins"] == ["https://list.example.com"]
|
||||||
Reference in New Issue
Block a user