Files
coryHawkvelt 60799bbc52 fix(cors): handle wildcard origin with credentials and add unit tests
- Refactor CORS middleware to echo request origin when wildcard + credentials
  is configured (browsers reject Access-Control-Allow-Origin: * with
  Access-Control-Allow-Credentials: true)
- Add _is_origin_allowed() and _cors_origin_header() helpers
- Use CORS_SUPPORTS_CREDENTIALS config consistently
- Ensure consistent Access-Control-Allow-Headers in all CORS paths
- Fix redirect validation in get_token() to allow wildcard CORS origins
- Add 46 unit tests covering encryption round-trips, idempotency, key
  derivation, thread safety, CORS origin matching, and preflight responses
2026-04-26 01:12:39 +09:30

165 lines
6.4 KiB
Python

"""Unit tests for encryption module (general-purpose Fernet encryption).
WHAT: Tests for the PBKDF2-based Fernet encryption/decryption used for
OAuth tokens and client secrets.
WHY: These utilities protect access tokens and client secrets; we need
to verify round-trip correctness and error handling.
EXPECTED: All encrypt/decrypt operations produce correct plaintext.
"""
import threading
import pytest
from gatehouse_app.utils.encryption import (
SALT_LENGTH,
_get_fernet_key,
decrypt,
encrypt,
)
# ---------------------------------------------------------------------------
# Shared test constants
# ---------------------------------------------------------------------------
# Plain module-level constants (not pytest fixtures) reused by every test
# class below. SECRET_KEY is a dummy secret that exists only for these tests.
SECRET_KEY = "test-encryption-secret-key"
# Token-shaped sample payload: the segment after "access_token=" is a
# base64url-encoded JWT header, so round-trips exercise '=', '.' and the
# base64url alphabet like a real OAuth token would.
SAMPLE_DATA = "access_token=eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.payload"
# ---------------------------------------------------------------------------
# encrypt / decrypt round-trip
# ---------------------------------------------------------------------------
class TestEncryptDecryptRoundTrip:
    """Verify that encrypt -> decrypt returns the original plaintext."""

    def test_basic_round_trip(self):
        """TEST: ENC-RT-01 -- Encrypt then decrypt returns original data."""
        encrypted = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
        decrypted = decrypt(encrypted, secret_key=SECRET_KEY)
        assert decrypted == SAMPLE_DATA

    def test_encrypted_is_base64(self):
        """TEST: ENC-RT-02 -- Encrypted output is valid URL-safe base64."""
        import base64
        import re

        encrypted = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
        # urlsafe_b64decode() silently discards characters outside the
        # base64 alphabet, so decoding alone would accept almost any string.
        # Check the alphabet (and trailing padding) explicitly first.
        assert re.fullmatch(r"[A-Za-z0-9_-]+={0,2}", encrypted), encrypted
        # Should not raise (this still rejects incorrect padding).
        assert base64.urlsafe_b64decode(encrypted.encode())

    def test_different_ciphertext_each_time(self):
        """TEST: ENC-RT-03 -- Two encryptions of the same plaintext differ."""
        enc1 = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
        enc2 = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
        assert enc1 != enc2
        # Both distinct ciphertexts must still decrypt to the same plaintext.
        assert decrypt(enc1, secret_key=SECRET_KEY) == SAMPLE_DATA
        assert decrypt(enc2, secret_key=SECRET_KEY) == SAMPLE_DATA

    def test_round_trip_unicode(self):
        """TEST: ENC-RT-04 -- Unicode data round-trips correctly."""
        # Mix 2-byte (Latin-1), 3-byte (CJK) and 4-byte (emoji) UTF-8
        # sequences: Latin-1 characters alone would not catch a text-encoding
        # bug such as using latin-1 instead of utf-8.
        data = "token=cafe\u00e9\u00f1\u00fc-\u65e5\u672c-\U0001f511"
        encrypted = encrypt(data, secret_key=SECRET_KEY)
        assert decrypt(encrypted, secret_key=SECRET_KEY) == data

    def test_round_trip_long_data(self):
        """TEST: ENC-RT-05 -- Large data round-trips correctly."""
        data = "x" * 10000
        encrypted = encrypt(data, secret_key=SECRET_KEY)
        assert decrypt(encrypted, secret_key=SECRET_KEY) == data
# ---------------------------------------------------------------------------
# Empty / edge inputs
# ---------------------------------------------------------------------------
class TestEdgeCases:
    """Empty strings pass straight through; bad keys or data raise ValueError."""

    def test_encrypt_empty_returns_empty(self):
        """TEST: ENC-EDGE-01 -- Encrypting empty string returns empty."""
        result = encrypt("", secret_key=SECRET_KEY)
        assert result == ""

    def test_decrypt_empty_returns_empty(self):
        """TEST: ENC-EDGE-02 -- Decrypting empty string returns empty."""
        result = decrypt("", secret_key=SECRET_KEY)
        assert result == ""

    def test_missing_key_raises_on_encrypt(self):
        """TEST: ENC-EDGE-03 -- Missing key raises ValueError on encrypt."""
        blank_key = ""
        with pytest.raises(ValueError, match="Encryption key not configured"):
            encrypt("data", secret_key=blank_key)

    def test_missing_key_raises_on_decrypt(self):
        """TEST: ENC-EDGE-04 -- Missing key raises ValueError on decrypt."""
        blank_key = ""
        with pytest.raises(ValueError, match="Encryption key not configured"):
            decrypt("something", secret_key=blank_key)

    def test_wrong_key_raises_on_decrypt(self):
        """TEST: ENC-EDGE-05 -- Wrong key raises ValueError on decrypt."""
        token = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
        # Ciphertext produced under SECRET_KEY must not open under another key.
        with pytest.raises(ValueError, match="Failed to decrypt"):
            decrypt(token, secret_key="wrong-key")

    def test_corrupted_data_raises(self):
        """TEST: ENC-EDGE-06 -- Corrupted ciphertext raises ValueError."""
        from base64 import urlsafe_b64encode

        # Well-formed base64 whose decoded payload is not a real token.
        garbage = urlsafe_b64encode(b"not-valid-fernet-data").decode()
        with pytest.raises(ValueError):
            decrypt(garbage, secret_key=SECRET_KEY)
# ---------------------------------------------------------------------------
# _get_fernet_key — PBKDF2 derivation
# ---------------------------------------------------------------------------
class TestKeyDerivation:
    """Check that _get_fernet_key output depends on both the secret and the salt."""

    def test_same_salt_same_key(self):
        """TEST: ENC-KD-01 -- Same salt produces the same derived key."""
        salt = b"\x00" * SALT_LENGTH
        key1 = _get_fernet_key(SECRET_KEY, salt=salt)
        key2 = _get_fernet_key(SECRET_KEY, salt=salt)
        assert key1 == key2

    def test_different_salt_different_key(self):
        """TEST: ENC-KD-02 -- Different salts produce different keys."""
        salt1 = b"\x00" * SALT_LENGTH
        salt2 = b"\xff" * SALT_LENGTH
        key1 = _get_fernet_key(SECRET_KEY, salt=salt1)
        key2 = _get_fernet_key(SECRET_KEY, salt=salt2)
        assert key1 != key2

    def test_auto_salt_length(self):
        """TEST: ENC-KD-03 -- Deriving a key without an explicit salt succeeds.

        NOTE(review): despite the test name, this does not inspect the salt
        that gets used (expected size: SALT_LENGTH); it only checks that a
        non-empty key comes back. TODO: assert on the salt directly if
        _get_fernet_key exposes it.
        """
        key = _get_fernet_key(SECRET_KEY)
        assert len(key) > 0

    def test_different_secret_different_key(self):
        """TEST: ENC-KD-04 -- Different secrets with the same salt differ."""
        # Guards against a derivation that ignores the secret and depends
        # only on the salt.
        salt = b"\x00" * SALT_LENGTH
        key1 = _get_fernet_key(SECRET_KEY, salt=salt)
        key2 = _get_fernet_key("another-secret-key", salt=salt)
        assert key1 != key2
# ---------------------------------------------------------------------------
# Thread safety
# ---------------------------------------------------------------------------
class TestThreadSafety:
    """Concurrent encrypt/decrypt calls must not corrupt state."""

    THREAD_COUNT = 50
    JOIN_TIMEOUT = 10  # seconds allowed per Thread.join() call

    def test_concurrent_encrypt_decrypt(self):
        """TEST: ENC-TS-01 -- 50 threads encrypting/decrypting concurrently."""
        errors = []
        results = []

        def worker(i):
            try:
                data = f"token-{i}-secret"
                enc = encrypt(data, secret_key=SECRET_KEY)
                dec = decrypt(enc, secret_key=SECRET_KEY)
                results.append((i, dec))
            except Exception as exc:
                # Exceptions in worker threads don't fail the test on their
                # own; collect them and assert on the main thread instead.
                errors.append((i, exc))

        # daemon=True so a deadlocked worker cannot keep the process alive
        # after the test has already failed.
        threads = [
            threading.Thread(target=worker, args=(i,), daemon=True)
            for i in range(self.THREAD_COUNT)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=self.JOIN_TIMEOUT)

        # join(timeout=...) returns silently on timeout, so report hung
        # workers explicitly instead of as a confusing count mismatch.
        hung = [t.name for t in threads if t.is_alive()]
        assert not hung, f"Threads did not finish: {hung}"
        assert not errors, f"Thread errors: {errors}"
        assert len(results) == self.THREAD_COUNT
        for i, dec in results:
            assert dec == f"token-{i}-secret", f"Thread {i}: mismatch ({dec!r})"