fix(cors): handle wildcard origin with credentials and add unit tests
- Refactor CORS middleware to echo request origin when wildcard + credentials is configured (browsers reject Access-Control-Allow-Origin: * with Access-Control-Allow-Credentials: true) - Add _is_origin_allowed() and _cors_origin_header() helpers - Use CORS_SUPPORTS_CREDENTIALS config consistently - Ensure consistent Access-Control-Allow-Headers in all CORS paths - Fix redirect validation in get_token() to allow wildcard CORS origins - Add 46 unit tests covering encryption round-trips, idempotency, key derivation, thread safety, CORS origin matching, and preflight responses
This commit is contained in:
@@ -0,0 +1,164 @@
|
||||
"""Unit tests for encryption module (general-purpose Fernet encryption).
|
||||
|
||||
WHAT: Tests for the PBKDF2-based Fernet encryption/decryption used for
|
||||
OAuth tokens and client secrets.
|
||||
WHY: These utilities protect access tokens and client secrets; we need
|
||||
to verify round-trip correctness and error handling.
|
||||
EXPECTED: All encrypt/decrypt operations produce correct plaintext.
|
||||
"""
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
|
||||
from gatehouse_app.utils.encryption import (
|
||||
SALT_LENGTH,
|
||||
_get_fernet_key,
|
||||
decrypt,
|
||||
encrypt,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared fixture
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SECRET_KEY = "test-encryption-secret-key"
|
||||
SAMPLE_DATA = "access_token=eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.payload"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# encrypt / decrypt round-trip
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEncryptDecryptRoundTrip:
|
||||
"""Verify that encrypt -> decrypt returns the original plaintext."""
|
||||
|
||||
def test_basic_round_trip(self):
|
||||
"""TEST: ENC-RT-01 -- Encrypt then decrypt returns original data."""
|
||||
encrypted = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
|
||||
decrypted = decrypt(encrypted, secret_key=SECRET_KEY)
|
||||
assert decrypted == SAMPLE_DATA
|
||||
|
||||
def test_encrypted_is_base64(self):
|
||||
"""TEST: ENC-RT-02 -- Encrypted output is valid base64."""
|
||||
import base64
|
||||
encrypted = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
|
||||
# Should not raise
|
||||
base64.urlsafe_b64decode(encrypted.encode())
|
||||
|
||||
def test_different_ciphertext_each_time(self):
|
||||
"""TEST: ENC-RT-03 -- Two encryptions of the same plaintext differ."""
|
||||
enc1 = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
|
||||
enc2 = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
|
||||
assert enc1 != enc2
|
||||
assert decrypt(enc1, secret_key=SECRET_KEY) == SAMPLE_DATA
|
||||
assert decrypt(enc2, secret_key=SECRET_KEY) == SAMPLE_DATA
|
||||
|
||||
def test_round_trip_unicode(self):
|
||||
"""TEST: ENC-RT-04 -- Unicode data round-trips correctly."""
|
||||
data = "token=cafe\u00e9\u00f1\u00fc"
|
||||
encrypted = encrypt(data, secret_key=SECRET_KEY)
|
||||
assert decrypt(encrypted, secret_key=SECRET_KEY) == data
|
||||
|
||||
def test_round_trip_long_data(self):
|
||||
"""TEST: ENC-RT-05 -- Large data round-trips correctly."""
|
||||
data = "x" * 10000
|
||||
encrypted = encrypt(data, secret_key=SECRET_KEY)
|
||||
assert decrypt(encrypted, secret_key=SECRET_KEY) == data
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Empty / edge inputs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEdgeCases:
|
||||
def test_encrypt_empty_returns_empty(self):
|
||||
"""TEST: ENC-EDGE-01 -- Encrypting empty string returns empty."""
|
||||
assert encrypt("", secret_key=SECRET_KEY) == ""
|
||||
|
||||
def test_decrypt_empty_returns_empty(self):
|
||||
"""TEST: ENC-EDGE-02 -- Decrypting empty string returns empty."""
|
||||
assert decrypt("", secret_key=SECRET_KEY) == ""
|
||||
|
||||
def test_missing_key_raises_on_encrypt(self):
|
||||
"""TEST: ENC-EDGE-03 -- Missing key raises ValueError on encrypt."""
|
||||
with pytest.raises(ValueError, match="Encryption key not configured"):
|
||||
encrypt("data", secret_key="")
|
||||
|
||||
def test_missing_key_raises_on_decrypt(self):
|
||||
"""TEST: ENC-EDGE-04 -- Missing key raises ValueError on decrypt."""
|
||||
with pytest.raises(ValueError, match="Encryption key not configured"):
|
||||
decrypt("something", secret_key="")
|
||||
|
||||
def test_wrong_key_raises_on_decrypt(self):
|
||||
"""TEST: ENC-EDGE-05 -- Wrong key raises ValueError on decrypt."""
|
||||
encrypted = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
|
||||
with pytest.raises(ValueError, match="Failed to decrypt"):
|
||||
decrypt(encrypted, secret_key="wrong-key")
|
||||
|
||||
def test_corrupted_data_raises(self):
|
||||
"""TEST: ENC-EDGE-06 -- Corrupted ciphertext raises ValueError."""
|
||||
import base64
|
||||
bad = base64.urlsafe_b64encode(b"not-valid-fernet-data").decode()
|
||||
with pytest.raises(ValueError):
|
||||
decrypt(bad, secret_key=SECRET_KEY)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_fernet_key — PBKDF2 derivation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestKeyDerivation:
|
||||
def test_same_salt_same_key(self):
|
||||
"""TEST: ENC-KD-01 -- Same salt produces the same derived key."""
|
||||
salt = b"\x00" * SALT_LENGTH
|
||||
key1 = _get_fernet_key(SECRET_KEY, salt=salt)
|
||||
key2 = _get_fernet_key(SECRET_KEY, salt=salt)
|
||||
assert key1 == key2
|
||||
|
||||
def test_different_salt_different_key(self):
|
||||
"""TEST: ENC-KD-02 -- Different salts produce different keys."""
|
||||
salt1 = b"\x00" * SALT_LENGTH
|
||||
salt2 = b"\xff" * SALT_LENGTH
|
||||
key1 = _get_fernet_key(SECRET_KEY, salt=salt1)
|
||||
key2 = _get_fernet_key(SECRET_KEY, salt=salt2)
|
||||
assert key1 != key2
|
||||
|
||||
def test_auto_salt_length(self):
|
||||
"""TEST: ENC-KD-03 -- Auto-generated salt is 16 bytes."""
|
||||
key = _get_fernet_key(SECRET_KEY)
|
||||
# If it didn't raise, the salt was valid
|
||||
assert len(key) > 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thread safety
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestThreadSafety:
|
||||
"""Concurrent encrypt/decrypt calls must not corrupt state."""
|
||||
|
||||
def test_concurrent_encrypt_decrypt(self):
|
||||
"""TEST: ENC-TS-01 -- 50 threads encrypting/decrypting concurrently."""
|
||||
errors = []
|
||||
results = []
|
||||
|
||||
def worker(i):
|
||||
try:
|
||||
data = f"token-{i}-secret"
|
||||
enc = encrypt(data, secret_key=SECRET_KEY)
|
||||
dec = decrypt(enc, secret_key=SECRET_KEY)
|
||||
results.append((i, dec))
|
||||
except Exception as exc:
|
||||
errors.append((i, exc))
|
||||
|
||||
threads = [threading.Thread(target=worker, args=(i,)) for i in range(50)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join(timeout=10)
|
||||
|
||||
assert not errors, f"Thread errors: {errors}"
|
||||
assert len(results) == 50
|
||||
for i, dec in results:
|
||||
assert dec == f"token-{i}-secret", f"Thread {i}: mismatch"
|
||||
Reference in New Issue
Block a user