fix(cors): handle wildcard origin with credentials and add unit tests

- Refactor CORS middleware to echo request origin when wildcard + credentials
  is configured (browsers reject Access-Control-Allow-Origin: * with
  Access-Control-Allow-Credentials: true)
- Add _is_origin_allowed() and _cors_origin_header() helpers
- Use CORS_SUPPORTS_CREDENTIALS config consistently
- Ensure consistent Access-Control-Allow-Headers in all CORS paths
- Fix redirect validation in get_token() to allow wildcard CORS origins
- Add 46 unit tests covering encryption round-trips, idempotency, key
  derivation, thread safety, CORS origin matching, and preflight responses
This commit is contained in:
2026-04-26 01:12:39 +09:30
parent 9738765258
commit 60799bbc52
5 changed files with 555 additions and 37 deletions
+2 -1
View File
@@ -246,7 +246,8 @@ def get_token():
parsed_redirect = urlparse(redirect_url)
redirect_origin = f"{parsed_redirect.scheme}://{parsed_redirect.netloc}"
if redirect_origin not in allowed_origins:
wildcard = "*" in allowed_origins
if not wildcard and redirect_origin not in allowed_origins:
return api_response(success=False, message="Redirect URL is not allowed.", status=400, error_type="INVALID_REDIRECT")
sep = "&" if "?" in redirect_url else "?"
+59 -36
View File
@@ -1,6 +1,44 @@
"""CORS middleware configuration."""
from flask import request, make_response
ALLOWED_METHODS = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
ALLOWED_HEADERS = (
"Content-Type, Authorization, X-Requested-With, X-Request-ID, "
"Cache-Control, Pragma, X-WebAuthn-Session-Token"
)
def _is_origin_allowed(origin, cors_origins):
"""Return True if the origin is permitted by the CORS config.
Handles both wildcard ("*") and explicit origin lists.
"""
if not origin:
return False
if cors_origins == "*":
return True
if isinstance(cors_origins, list):
if "*" in cors_origins:
return True
return origin in cors_origins
return False
def _cors_origin_header(cors_origins, request_origin):
"""Return the value for Access-Control-Allow-Origin.
Per the CORS spec, browsers reject ``*`` when credentials are involved,
so we echo the request origin when wildcard + credentials is configured.
"""
allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
if allow_all and request_origin:
return request_origin
if allow_all:
return "*"
if request_origin and request_origin in cors_origins:
return request_origin
return None
def setup_cors(app):
"""
@@ -9,6 +47,7 @@ def setup_cors(app):
Args:
app: Flask application instance
"""
supports_credentials = app.config.get("CORS_SUPPORTS_CREDENTIALS", True)
@app.before_request
def handle_preflight():
@@ -16,49 +55,33 @@ def setup_cors(app):
if request.method == "OPTIONS":
origin = request.headers.get("Origin")
cors_origins = app.config.get("CORS_ORIGINS", [])
# Allow all origins if CORS_ORIGINS is "*" (string) or ["*"] (list with wildcard)
allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
if allow_all:
response = make_response("", 204)
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma"
response.headers["Access-Control-Max-Age"] = "3600"
response.headers["Cache-Control"] = "no-cache, no-store"
return response
elif origin and origin in cors_origins:
response = make_response("", 204)
response.headers["Access-Control-Allow-Origin"] = origin
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma, X-WebAuthn-Session-Token"
if not _is_origin_allowed(origin, cors_origins):
return None
response = make_response("", 204)
response.headers["Access-Control-Allow-Origin"] = _cors_origin_header(cors_origins, origin)
response.headers["Access-Control-Allow-Methods"] = ALLOWED_METHODS
response.headers["Access-Control-Allow-Headers"] = ALLOWED_HEADERS
if supports_credentials:
response.headers["Access-Control-Allow-Credentials"] = "true"
response.headers["Access-Control-Max-Age"] = "3600"
response.headers["Cache-Control"] = "no-cache, no-store"
return response
response.headers["Access-Control-Max-Age"] = "3600"
response.headers["Cache-Control"] = "no-cache, no-store"
return response
@app.after_request
def after_request_cors(response):
"""Add additional CORS headers if needed."""
"""Add CORS headers to non-preflight responses."""
origin = request.headers.get("Origin")
cors_origins = app.config.get("CORS_ORIGINS", [])
# Allow all origins if CORS_ORIGINS is "*" (string) or ["*"] (list with wildcard)
allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
if allow_all:
# When allowing all origins, set header to "*"
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma"
response.headers["Access-Control-Max-Age"] = "3600"
elif origin and origin in cors_origins:
# When allowing specific origins, echo the request origin
response.headers["Access-Control-Allow-Origin"] = origin
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma, X-WebAuthn-Session-Token"
response.headers["Access-Control-Allow-Credentials"] = "true"
allow_origin = _cors_origin_header(cors_origins, origin)
if allow_origin:
response.headers["Access-Control-Allow-Origin"] = allow_origin
response.headers["Access-Control-Allow-Methods"] = ALLOWED_METHODS
response.headers["Access-Control-Allow-Headers"] = ALLOWED_HEADERS
if supports_credentials:
response.headers["Access-Control-Allow-Credentials"] = "true"
response.headers["Access-Control-Max-Age"] = "3600"
return response
+205
View File
@@ -0,0 +1,205 @@
"""Unit tests for ca_key_encryption module.
WHAT: Tests for the Fernet-based CA private key encryption/decryption
utility functions.
WHY: CA private keys are the most sensitive data in the system; we need
to verify round-trip correctness, idempotency, and error handling.
EXPECTED: All encrypt/decrypt operations produce correct plaintext.
"""
import os
import threading
from unittest.mock import patch
import pytest
from gatehouse_app.utils.ca_key_encryption import (
CAKeyEncryptionError,
_FERNET_PREFIX,
_get_fernet,
decrypt_ca_key,
encrypt_ca_key,
is_encrypted,
reencrypt_ca_key,
)
# ---------------------------------------------------------------------------
# Shared fixture
# ---------------------------------------------------------------------------
SAMPLE_PEM = (
"-----BEGIN OPENSSH PRIVATE KEY-----\n"
"b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtz\n"
"c2gtZWQyNTUxOQAAACBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBAAA\n"
"-----END OPENSSH PRIVATE KEY-----"
)
@pytest.fixture(autouse=True)
def _set_ca_encryption_key():
"""Ensure CA_ENCRYPTION_KEY is set for every test."""
with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": "test-secret-key-for-unit-tests"}):
yield
# ---------------------------------------------------------------------------
# encrypt / decrypt round-trip
# ---------------------------------------------------------------------------
class TestEncryptDecryptRoundTrip:
"""Verify that encrypt -> decrypt returns the original plaintext."""
def test_basic_round_trip(self):
"""TEST: ENC-RT-01 -- Encrypt then decrypt returns original PEM."""
encrypted = encrypt_ca_key(SAMPLE_PEM)
decrypted = decrypt_ca_key(encrypted)
assert decrypted == SAMPLE_PEM
def test_encrypted_value_has_prefix(self):
"""TEST: ENC-RT-02 -- Encrypted output carries the $fernet$ envelope."""
encrypted = encrypt_ca_key(SAMPLE_PEM)
assert encrypted.startswith(_FERNET_PREFIX)
def test_different_ciphertext_each_time(self):
"""TEST: ENC-RT-03 -- Two encryptions of the same plaintext differ."""
enc1 = encrypt_ca_key(SAMPLE_PEM)
enc2 = encrypt_ca_key(SAMPLE_PEM)
assert enc1 != enc2
assert decrypt_ca_key(enc1) == SAMPLE_PEM
assert decrypt_ca_key(enc2) == SAMPLE_PEM
# ---------------------------------------------------------------------------
# Idempotency
# ---------------------------------------------------------------------------
class TestIdempotency:
"""The module must not double-encrypt or double-decrypt."""
def test_encrypt_idempotent(self):
"""TEST: ENC-IDEM-01 -- Encrypting an already-encrypted value is a no-op."""
encrypted = encrypt_ca_key(SAMPLE_PEM)
double = encrypt_ca_key(encrypted)
assert double == encrypted
def test_decrypt_plaintext_passthrough(self):
"""TEST: ENC-IDEM-02 -- Decrypting a plaintext (legacy) value returns it as-is."""
result = decrypt_ca_key(SAMPLE_PEM)
assert result == SAMPLE_PEM
# ---------------------------------------------------------------------------
# is_encrypted helper
# ---------------------------------------------------------------------------
class TestIsEncrypted:
def test_encrypted_value(self):
"""TEST: ENC-IE-01 -- is_encrypted returns True for $fernet$ values."""
encrypted = encrypt_ca_key(SAMPLE_PEM)
assert is_encrypted(encrypted) is True
def test_plaintext_value(self):
"""TEST: ENC-IE-02 -- is_encrypted returns False for plain PEM."""
assert is_encrypted(SAMPLE_PEM) is False
def test_empty_string(self):
"""TEST: ENC-IE-03 -- is_encrypted returns False for empty string."""
assert is_encrypted("") is False
def test_none_value(self):
"""TEST: ENC-IE-04 -- is_encrypted returns False for None."""
assert is_encrypted(None) is False
# ---------------------------------------------------------------------------
# Error handling
# ---------------------------------------------------------------------------
class TestErrorHandling:
def test_encrypt_empty_raises(self):
"""TEST: ENC-ERR-01 -- Encrypting empty string raises CAKeyEncryptionError."""
with pytest.raises(CAKeyEncryptionError, match="empty"):
encrypt_ca_key("")
def test_decrypt_empty_raises(self):
"""TEST: ENC-ERR-02 -- Decrypting empty string raises CAKeyEncryptionError."""
with pytest.raises(CAKeyEncryptionError, match="empty"):
decrypt_ca_key("")
def test_missing_key_raises(self):
"""TEST: ENC-ERR-03 -- Operations fail when CA_ENCRYPTION_KEY is unset."""
with patch.dict(os.environ, {}, clear=True):
os.environ.pop("CA_ENCRYPTION_KEY", None)
with pytest.raises(CAKeyEncryptionError, match="not set"):
encrypt_ca_key(SAMPLE_PEM)
def test_wrong_key_raises_on_decrypt(self):
"""TEST: ENC-ERR-04 -- Decrypting with the wrong key raises."""
encrypted = encrypt_ca_key(SAMPLE_PEM)
with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": "wrong-key"}):
with pytest.raises(CAKeyEncryptionError, match="decryption failed"):
decrypt_ca_key(encrypted)
def test_corrupted_data_raises(self):
"""TEST: ENC-ERR-05 -- Decrypting corrupted ciphertext raises."""
with pytest.raises(CAKeyEncryptionError):
decrypt_ca_key("$fernet$not-a-real-token")
# ---------------------------------------------------------------------------
# reencrypt_ca_key -- key rotation
# ---------------------------------------------------------------------------
class TestReencrypt:
def test_reencrypt_round_trip(self):
"""TEST: ENC-RE-01 -- Re-encrypted value decrypts with the new key."""
old_key = "old-secret-key"
new_key = "new-secret-key"
encrypted = reencrypt_ca_key(SAMPLE_PEM, "any-old-key", old_key)
reencrypted = reencrypt_ca_key(encrypted, old_key, new_key)
# Verify it decrypts with the new key
with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": new_key}):
decrypted = decrypt_ca_key(reencrypted)
assert decrypted == SAMPLE_PEM
def test_reencrypt_plaintext_key(self):
"""TEST: ENC-RE-02 -- Re-encrypting a legacy plaintext key works."""
new_key = "brand-new-key"
reencrypted = reencrypt_ca_key(SAMPLE_PEM, "any-old-key", new_key)
with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": new_key}):
decrypted = decrypt_ca_key(reencrypted)
assert decrypted == SAMPLE_PEM
# ---------------------------------------------------------------------------
# Thread safety
# ---------------------------------------------------------------------------
class TestThreadSafety:
"""Concurrent encrypt/decrypt calls must not corrupt state."""
def test_concurrent_encrypt_decrypt(self):
"""TEST: ENC-TS-01 -- 50 threads encrypting/decrypting concurrently."""
errors = []
results = []
def worker(i):
try:
data = f"key-data-{i}"
enc = encrypt_ca_key(data)
dec = decrypt_ca_key(enc)
results.append((i, dec))
except Exception as exc:
errors.append((i, exc))
threads = [threading.Thread(target=worker, args=(i,)) for i in range(50)]
for t in threads:
t.start()
for t in threads:
t.join(timeout=10)
assert not errors, f"Thread errors: {errors}"
assert len(results) == 50
for i, dec in results:
assert dec == f"key-data-{i}", f"Thread {i}: expected 'key-data-{i}', got {dec!r}"
+125
View File
@@ -0,0 +1,125 @@
"""Unit tests for CORS middleware.
WHAT: Tests for the CORS middleware configuration, including wildcard
origin handling, credentials support, and preflight responses.
WHY: CORS misconfiguration can break browser clients or leak credentials.
EXPECTED: Correct Access-Control-* headers for all origin configurations.
"""
import pytest
from flask import Flask
from gatehouse_app.middleware.cors import (
_is_origin_allowed,
_cors_origin_header,
setup_cors,
)
# ---------------------------------------------------------------------------
# _is_origin_allowed
# ---------------------------------------------------------------------------
class TestIsOriginAllowed:
def test_empty_origin_rejected(self):
"""TEST: CORS-01 -- Empty origin is never allowed."""
assert _is_origin_allowed("", ["https://example.com"]) is False
assert _is_origin_allowed(None, "*") is False
def test_wildcard_string(self):
"""TEST: CORS-02 -- Wildcard string allows any origin."""
assert _is_origin_allowed("https://evil.com", "*") is True
def test_wildcard_in_list(self):
"""TEST: CORS-03 -- Wildcard in list allows any origin."""
assert _is_origin_allowed("https://evil.com", ["*", "https://example.com"]) is True
def test_explicit_origin_match(self):
"""TEST: CORS-04 -- Explicit list matches exact origin."""
origins = ["https://example.com", "http://localhost:3000"]
assert _is_origin_allowed("https://example.com", origins) is True
assert _is_origin_allowed("https://evil.com", origins) is False
def test_empty_origins_list(self):
"""TEST: CORS-05 -- Empty list rejects everything."""
assert _is_origin_allowed("https://example.com", []) is False
# ---------------------------------------------------------------------------
# _cors_origin_header
# ---------------------------------------------------------------------------
class TestCorsOriginHeader:
def test_wildcard_with_origin_echoes(self):
"""TEST: CORS-HDR-01 -- Wildcard echoes request origin (for credentials)."""
assert _cors_origin_header("*", "https://example.com") == "https://example.com"
def test_wildcard_without_origin(self):
"""TEST: CORS-HDR-02 -- Wildcard with no origin returns *."""
assert _cors_origin_header("*", None) == "*"
def test_wildcard_in_list_with_origin(self):
"""TEST: CORS-HDR-03 -- Wildcard in list echoes request origin."""
result = _cors_origin_header(["*", "https://example.com"], "https://any.com")
assert result == "https://any.com"
def test_specific_origin_match(self):
"""TEST: CORS-HDR-04 -- Matching origin is echoed."""
origins = ["https://example.com"]
assert _cors_origin_header(origins, "https://example.com") == "https://example.com"
def test_specific_origin_no_match(self):
"""TEST: CORS-HDR-05 -- Non-matching origin returns None."""
origins = ["https://example.com"]
assert _cors_origin_header(origins, "https://evil.com") is None
def test_no_origin_no_match(self):
"""TEST: CORS-HDR-06 -- No origin with specific list returns None."""
origins = ["https://example.com"]
assert _cors_origin_header(origins, None) is None
# ---------------------------------------------------------------------------
# Integration: preflight response
# ---------------------------------------------------------------------------
class TestPreflightIntegration:
@pytest.fixture
def app_wildcard(self):
app = Flask(__name__)
app.config["CORS_ORIGINS"] = "*"
app.config["CORS_SUPPORTS_CREDENTIALS"] = True
setup_cors(app)
app.config["TESTING"] = True
return app
@pytest.fixture
def app_specific(self):
app = Flask(__name__)
app.config["CORS_ORIGINS"] = ["https://example.com"]
app.config["CORS_SUPPORTS_CREDENTIALS"] = True
setup_cors(app)
app.config["TESTING"] = True
return app
def test_wildcard_preflight_echoes_origin(self, app_wildcard):
"""TEST: CORS-PF-01 -- Wildcard preflight echoes request origin."""
with app_wildcard.test_client() as client:
resp = client.options("/", headers={"Origin": "https://example.com"})
assert resp.status_code == 204
assert resp.headers.get("Access-Control-Allow-Origin") == "https://example.com"
assert resp.headers.get("Access-Control-Allow-Credentials") == "true"
def test_specific_origin_preflight(self, app_specific):
"""TEST: CORS-PF-02 -- Specific origin preflight allows matching origin."""
with app_specific.test_client() as client:
resp = client.options("/", headers={"Origin": "https://example.com"})
assert resp.status_code == 204
assert resp.headers.get("Access-Control-Allow-Origin") == "https://example.com"
assert resp.headers.get("Access-Control-Allow-Credentials") == "true"
def test_specific_origin_rejects_unknown(self, app_specific):
"""TEST: CORS-PF-03 -- Non-matching origin gets no CORS headers."""
with app_specific.test_client() as client:
resp = client.options("/", headers={"Origin": "https://evil.com"})
# No preflight handler runs, Flask returns default
assert resp.headers.get("Access-Control-Allow-Origin") is None
+164
View File
@@ -0,0 +1,164 @@
"""Unit tests for encryption module (general-purpose Fernet encryption).
WHAT: Tests for the PBKDF2-based Fernet encryption/decryption used for
OAuth tokens and client secrets.
WHY: These utilities protect access tokens and client secrets; we need
to verify round-trip correctness and error handling.
EXPECTED: All encrypt/decrypt operations produce correct plaintext.
"""
import threading
import pytest
from gatehouse_app.utils.encryption import (
SALT_LENGTH,
_get_fernet_key,
decrypt,
encrypt,
)
# ---------------------------------------------------------------------------
# Shared fixture
# ---------------------------------------------------------------------------
SECRET_KEY = "test-encryption-secret-key"
SAMPLE_DATA = "access_token=eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.payload"
# ---------------------------------------------------------------------------
# encrypt / decrypt round-trip
# ---------------------------------------------------------------------------
class TestEncryptDecryptRoundTrip:
"""Verify that encrypt -> decrypt returns the original plaintext."""
def test_basic_round_trip(self):
"""TEST: ENC-RT-01 -- Encrypt then decrypt returns original data."""
encrypted = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
decrypted = decrypt(encrypted, secret_key=SECRET_KEY)
assert decrypted == SAMPLE_DATA
def test_encrypted_is_base64(self):
"""TEST: ENC-RT-02 -- Encrypted output is valid base64."""
import base64
encrypted = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
# Should not raise
base64.urlsafe_b64decode(encrypted.encode())
def test_different_ciphertext_each_time(self):
"""TEST: ENC-RT-03 -- Two encryptions of the same plaintext differ."""
enc1 = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
enc2 = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
assert enc1 != enc2
assert decrypt(enc1, secret_key=SECRET_KEY) == SAMPLE_DATA
assert decrypt(enc2, secret_key=SECRET_KEY) == SAMPLE_DATA
def test_round_trip_unicode(self):
"""TEST: ENC-RT-04 -- Unicode data round-trips correctly."""
data = "token=cafe\u00e9\u00f1\u00fc"
encrypted = encrypt(data, secret_key=SECRET_KEY)
assert decrypt(encrypted, secret_key=SECRET_KEY) == data
def test_round_trip_long_data(self):
"""TEST: ENC-RT-05 -- Large data round-trips correctly."""
data = "x" * 10000
encrypted = encrypt(data, secret_key=SECRET_KEY)
assert decrypt(encrypted, secret_key=SECRET_KEY) == data
# ---------------------------------------------------------------------------
# Empty / edge inputs
# ---------------------------------------------------------------------------
class TestEdgeCases:
def test_encrypt_empty_returns_empty(self):
"""TEST: ENC-EDGE-01 -- Encrypting empty string returns empty."""
assert encrypt("", secret_key=SECRET_KEY) == ""
def test_decrypt_empty_returns_empty(self):
"""TEST: ENC-EDGE-02 -- Decrypting empty string returns empty."""
assert decrypt("", secret_key=SECRET_KEY) == ""
def test_missing_key_raises_on_encrypt(self):
"""TEST: ENC-EDGE-03 -- Missing key raises ValueError on encrypt."""
with pytest.raises(ValueError, match="Encryption key not configured"):
encrypt("data", secret_key="")
def test_missing_key_raises_on_decrypt(self):
"""TEST: ENC-EDGE-04 -- Missing key raises ValueError on decrypt."""
with pytest.raises(ValueError, match="Encryption key not configured"):
decrypt("something", secret_key="")
def test_wrong_key_raises_on_decrypt(self):
"""TEST: ENC-EDGE-05 -- Wrong key raises ValueError on decrypt."""
encrypted = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
with pytest.raises(ValueError, match="Failed to decrypt"):
decrypt(encrypted, secret_key="wrong-key")
def test_corrupted_data_raises(self):
"""TEST: ENC-EDGE-06 -- Corrupted ciphertext raises ValueError."""
import base64
bad = base64.urlsafe_b64encode(b"not-valid-fernet-data").decode()
with pytest.raises(ValueError):
decrypt(bad, secret_key=SECRET_KEY)
# ---------------------------------------------------------------------------
# _get_fernet_key — PBKDF2 derivation
# ---------------------------------------------------------------------------
class TestKeyDerivation:
def test_same_salt_same_key(self):
"""TEST: ENC-KD-01 -- Same salt produces the same derived key."""
salt = b"\x00" * SALT_LENGTH
key1 = _get_fernet_key(SECRET_KEY, salt=salt)
key2 = _get_fernet_key(SECRET_KEY, salt=salt)
assert key1 == key2
def test_different_salt_different_key(self):
"""TEST: ENC-KD-02 -- Different salts produce different keys."""
salt1 = b"\x00" * SALT_LENGTH
salt2 = b"\xff" * SALT_LENGTH
key1 = _get_fernet_key(SECRET_KEY, salt=salt1)
key2 = _get_fernet_key(SECRET_KEY, salt=salt2)
assert key1 != key2
def test_auto_salt_length(self):
"""TEST: ENC-KD-03 -- Auto-generated salt is 16 bytes."""
key = _get_fernet_key(SECRET_KEY)
# If it didn't raise, the salt was valid
assert len(key) > 0
# ---------------------------------------------------------------------------
# Thread safety
# ---------------------------------------------------------------------------
class TestThreadSafety:
"""Concurrent encrypt/decrypt calls must not corrupt state."""
def test_concurrent_encrypt_decrypt(self):
"""TEST: ENC-TS-01 -- 50 threads encrypting/decrypting concurrently."""
errors = []
results = []
def worker(i):
try:
data = f"token-{i}-secret"
enc = encrypt(data, secret_key=SECRET_KEY)
dec = decrypt(enc, secret_key=SECRET_KEY)
results.append((i, dec))
except Exception as exc:
errors.append((i, exc))
threads = [threading.Thread(target=worker, args=(i,)) for i in range(50)]
for t in threads:
t.start()
for t in threads:
t.join(timeout=10)
assert not errors, f"Thread errors: {errors}"
assert len(results) == 50
for i, dec in results:
assert dec == f"token-{i}-secret", f"Thread {i}: mismatch"