From 60799bbc52f1ce9e01d22eb5f58ad332e303c892 Mon Sep 17 00:00:00 2001 From: Cory Hawklvelt Date: Sun, 26 Apr 2026 01:12:39 +0930 Subject: [PATCH] fix(cors): handle wildcard origin with credentials and add unit tests - Refactor CORS middleware to echo request origin when wildcard + credentials is configured (browsers reject Access-Control-Allow-Origin: * with Access-Control-Allow-Credentials: true) - Add _is_origin_allowed() and _cors_origin_header() helpers - Use CORS_SUPPORTS_CREDENTIALS config consistently - Ensure consistent Access-Control-Allow-Headers in all CORS paths - Fix redirect validation in get_token() to allow wildcard CORS origins - Add 46 unit tests covering encryption round-trips, idempotency, key derivation, thread safety, CORS origin matching, and preflight responses --- gatehouse_app/api/v1/auth/core.py | 3 +- gatehouse_app/middleware/cors.py | 95 ++++++++----- tests/unit/test_ca_key_encryption.py | 205 +++++++++++++++++++++++++++ tests/unit/test_cors.py | 125 ++++++++++++++++ tests/unit/test_encryption.py | 164 +++++++++++++++++++++ 5 files changed, 555 insertions(+), 37 deletions(-) create mode 100644 tests/unit/test_ca_key_encryption.py create mode 100644 tests/unit/test_cors.py create mode 100644 tests/unit/test_encryption.py diff --git a/gatehouse_app/api/v1/auth/core.py b/gatehouse_app/api/v1/auth/core.py index 42f11c4..417a4ab 100644 --- a/gatehouse_app/api/v1/auth/core.py +++ b/gatehouse_app/api/v1/auth/core.py @@ -246,7 +246,8 @@ def get_token(): parsed_redirect = urlparse(redirect_url) redirect_origin = f"{parsed_redirect.scheme}://{parsed_redirect.netloc}" - if redirect_origin not in allowed_origins: + wildcard = "*" in allowed_origins + if not wildcard and redirect_origin not in allowed_origins: return api_response(success=False, message="Redirect URL is not allowed.", status=400, error_type="INVALID_REDIRECT") sep = "&" if "?" in redirect_url else "?" diff --git a/gatehouse_app/middleware/cors.py b/gatehouse_app/middleware/cors.py index defe68c..797d026 100644 --- a/gatehouse_app/middleware/cors.py +++ b/gatehouse_app/middleware/cors.py @@ -1,6 +1,44 @@ """CORS middleware configuration.""" from flask import request, make_response +ALLOWED_METHODS = "GET, POST, PUT, PATCH, DELETE, OPTIONS" +ALLOWED_HEADERS = ( + "Content-Type, Authorization, X-Requested-With, X-Request-ID, " + "Cache-Control, Pragma, X-WebAuthn-Session-Token" +) + + +def _is_origin_allowed(origin, cors_origins): + """Return True if the origin is permitted by the CORS config. + + Handles both wildcard ("*") and explicit origin lists. + """ + if not origin: + return False + if cors_origins == "*": + return True + if isinstance(cors_origins, list): + if "*" in cors_origins: + return True + return origin in cors_origins + return False + + +def _cors_origin_header(cors_origins, request_origin): + """Return the value for Access-Control-Allow-Origin. + + Per the CORS spec, browsers reject ``*`` when credentials are involved, + so we echo the request origin when wildcard + credentials is configured. + """ + allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins) + if allow_all and request_origin: + return request_origin + if allow_all: + return "*" + if request_origin and request_origin in cors_origins: + return request_origin + return None + def setup_cors(app): """ @@ -9,6 +47,7 @@ def setup_cors(app): Args: app: Flask application instance """ + supports_credentials = app.config.get("CORS_SUPPORTS_CREDENTIALS", True) @app.before_request def handle_preflight(): @@ -16,49 +55,33 @@ def setup_cors(app): if request.method == "OPTIONS": origin = request.headers.get("Origin") cors_origins = app.config.get("CORS_ORIGINS", []) - - # Allow all origins if CORS_ORIGINS is "*" (string) or ["*"] (list with wildcard) - allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins) - - if allow_all: - response = make_response("", 204) - response.headers["Access-Control-Allow-Origin"] = "*" - response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS" - response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma" - response.headers["Access-Control-Max-Age"] = "3600" - response.headers["Cache-Control"] = "no-cache, no-store" - return response - elif origin and origin in cors_origins: - response = make_response("", 204) - response.headers["Access-Control-Allow-Origin"] = origin - response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS" - response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma, X-WebAuthn-Session-Token" + + if not _is_origin_allowed(origin, cors_origins): + return None + + response = make_response("", 204) + response.headers["Access-Control-Allow-Origin"] = _cors_origin_header(cors_origins, origin) + response.headers["Access-Control-Allow-Methods"] = ALLOWED_METHODS + response.headers["Access-Control-Allow-Headers"] = ALLOWED_HEADERS + if supports_credentials: response.headers["Access-Control-Allow-Credentials"] = "true" - response.headers["Access-Control-Max-Age"] = "3600" - response.headers["Cache-Control"] = "no-cache, no-store" - return response + response.headers["Access-Control-Max-Age"] = "3600" + response.headers["Cache-Control"] = "no-cache, no-store" + return response @app.after_request def after_request_cors(response): - """Add additional CORS headers if needed.""" + """Add CORS headers to non-preflight responses.""" origin = request.headers.get("Origin") cors_origins = app.config.get("CORS_ORIGINS", []) - # Allow all origins if CORS_ORIGINS is "*" (string) or ["*"] (list with wildcard) - allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins) - - if allow_all: - # When allowing all origins, set header to "*" - response.headers["Access-Control-Allow-Origin"] = "*" - response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS" - response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma" - response.headers["Access-Control-Max-Age"] = "3600" - elif origin and origin in cors_origins: - # When allowing specific origins, echo the request origin - response.headers["Access-Control-Allow-Origin"] = origin - response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS" - response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma, X-WebAuthn-Session-Token" - response.headers["Access-Control-Allow-Credentials"] = "true" + allow_origin = _cors_origin_header(cors_origins, origin) + if allow_origin: + response.headers["Access-Control-Allow-Origin"] = allow_origin + response.headers["Access-Control-Allow-Methods"] = ALLOWED_METHODS + response.headers["Access-Control-Allow-Headers"] = ALLOWED_HEADERS + if supports_credentials: + response.headers["Access-Control-Allow-Credentials"] = "true" response.headers["Access-Control-Max-Age"] = "3600" return response diff --git a/tests/unit/test_ca_key_encryption.py b/tests/unit/test_ca_key_encryption.py new file mode 100644 index 0000000..3f19353 --- /dev/null +++ b/tests/unit/test_ca_key_encryption.py @@ -0,0 +1,205 @@ +"""Unit tests for ca_key_encryption module. + +WHAT: Tests for the Fernet-based CA private key encryption/decryption + utility functions. +WHY: CA private keys are the most sensitive data in the system; we need + to verify round-trip correctness, idempotency, and error handling. +EXPECTED: All encrypt/decrypt operations produce correct plaintext. +""" +import os +import threading +from unittest.mock import patch + +import pytest + +from gatehouse_app.utils.ca_key_encryption import ( + CAKeyEncryptionError, + _FERNET_PREFIX, + _get_fernet, + decrypt_ca_key, + encrypt_ca_key, + is_encrypted, + reencrypt_ca_key, +) + + +# --------------------------------------------------------------------------- +# Shared fixture +# --------------------------------------------------------------------------- + +SAMPLE_PEM = ( + "-----BEGIN OPENSSH PRIVATE KEY-----\n" + "b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtz\n" + "c2gtZWQyNTUxOQAAACBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBAAA\n" + "-----END OPENSSH PRIVATE KEY-----" +) + + +@pytest.fixture(autouse=True) +def _set_ca_encryption_key(): + """Ensure CA_ENCRYPTION_KEY is set for every test.""" + with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": "test-secret-key-for-unit-tests"}): + yield + + +# --------------------------------------------------------------------------- +# encrypt / decrypt round-trip +# --------------------------------------------------------------------------- + +class TestEncryptDecryptRoundTrip: + """Verify that encrypt -> decrypt returns the original plaintext.""" + + def test_basic_round_trip(self): + """TEST: ENC-RT-01 -- Encrypt then decrypt returns original PEM.""" + encrypted = encrypt_ca_key(SAMPLE_PEM) + decrypted = decrypt_ca_key(encrypted) + assert decrypted == SAMPLE_PEM + + def test_encrypted_value_has_prefix(self): + """TEST: ENC-RT-02 -- Encrypted output carries the $fernet$ envelope.""" + encrypted = encrypt_ca_key(SAMPLE_PEM) + assert encrypted.startswith(_FERNET_PREFIX) + + def test_different_ciphertext_each_time(self): + """TEST: ENC-RT-03 -- Two encryptions of the same plaintext differ.""" + enc1 = encrypt_ca_key(SAMPLE_PEM) + enc2 = encrypt_ca_key(SAMPLE_PEM) + assert enc1 != enc2 + assert decrypt_ca_key(enc1) == SAMPLE_PEM + assert decrypt_ca_key(enc2) == SAMPLE_PEM + + +# --------------------------------------------------------------------------- +# Idempotency +# --------------------------------------------------------------------------- + +class TestIdempotency: + """The module must not double-encrypt or double-decrypt.""" + + def test_encrypt_idempotent(self): + """TEST: ENC-IDEM-01 -- Encrypting an already-encrypted value is a no-op.""" + encrypted = encrypt_ca_key(SAMPLE_PEM) + double = encrypt_ca_key(encrypted) + assert double == encrypted + + def test_decrypt_plaintext_passthrough(self): + """TEST: ENC-IDEM-02 -- Decrypting a plaintext (legacy) value returns it as-is.""" + result = decrypt_ca_key(SAMPLE_PEM) + assert result == SAMPLE_PEM + + +# --------------------------------------------------------------------------- +# is_encrypted helper +# --------------------------------------------------------------------------- + +class TestIsEncrypted: + def test_encrypted_value(self): + """TEST: ENC-IE-01 -- is_encrypted returns True for $fernet$ values.""" + encrypted = encrypt_ca_key(SAMPLE_PEM) + assert is_encrypted(encrypted) is True + + def test_plaintext_value(self): + """TEST: ENC-IE-02 -- is_encrypted returns False for plain PEM.""" + assert is_encrypted(SAMPLE_PEM) is False + + def test_empty_string(self): + """TEST: ENC-IE-03 -- is_encrypted returns False for empty string.""" + assert is_encrypted("") is False + + def test_none_value(self): + """TEST: ENC-IE-04 -- is_encrypted returns False for None.""" + assert is_encrypted(None) is False + + +# --------------------------------------------------------------------------- +# Error handling +# --------------------------------------------------------------------------- + +class TestErrorHandling: + def test_encrypt_empty_raises(self): + """TEST: ENC-ERR-01 -- Encrypting empty string raises CAKeyEncryptionError.""" + with pytest.raises(CAKeyEncryptionError, match="empty"): + encrypt_ca_key("") + + def test_decrypt_empty_raises(self): + """TEST: ENC-ERR-02 -- Decrypting empty string raises CAKeyEncryptionError.""" + with pytest.raises(CAKeyEncryptionError, match="empty"): + decrypt_ca_key("") + + def test_missing_key_raises(self): + """TEST: ENC-ERR-03 -- Operations fail when CA_ENCRYPTION_KEY is unset.""" + with patch.dict(os.environ, {}, clear=True): + os.environ.pop("CA_ENCRYPTION_KEY", None) + with pytest.raises(CAKeyEncryptionError, match="not set"): + encrypt_ca_key(SAMPLE_PEM) + + def test_wrong_key_raises_on_decrypt(self): + """TEST: ENC-ERR-04 -- Decrypting with the wrong key raises.""" + encrypted = encrypt_ca_key(SAMPLE_PEM) + with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": "wrong-key"}): + with pytest.raises(CAKeyEncryptionError, match="decryption failed"): + decrypt_ca_key(encrypted) + + def test_corrupted_data_raises(self): + """TEST: ENC-ERR-05 -- Decrypting corrupted ciphertext raises.""" + with pytest.raises(CAKeyEncryptionError): + decrypt_ca_key("$fernet$not-a-real-token") + + +# --------------------------------------------------------------------------- +# reencrypt_ca_key -- key rotation +# --------------------------------------------------------------------------- + +class TestReencrypt: + def test_reencrypt_round_trip(self): + """TEST: ENC-RE-01 -- Re-encrypted value decrypts with the new key.""" + old_key = "old-secret-key" + new_key = "new-secret-key" + encrypted = reencrypt_ca_key(SAMPLE_PEM, "any-old-key", old_key) + reencrypted = reencrypt_ca_key(encrypted, old_key, new_key) + + # Verify it decrypts with the new key + with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": new_key}): + decrypted = decrypt_ca_key(reencrypted) + assert decrypted == SAMPLE_PEM + + def test_reencrypt_plaintext_key(self): + """TEST: ENC-RE-02 -- Re-encrypting a legacy plaintext key works.""" + new_key = "brand-new-key" + reencrypted = reencrypt_ca_key(SAMPLE_PEM, "any-old-key", new_key) + with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": new_key}): + decrypted = decrypt_ca_key(reencrypted) + assert decrypted == SAMPLE_PEM + + +# --------------------------------------------------------------------------- +# Thread safety +# --------------------------------------------------------------------------- + +class TestThreadSafety: + """Concurrent encrypt/decrypt calls must not corrupt state.""" + + def test_concurrent_encrypt_decrypt(self): + """TEST: ENC-TS-01 -- 50 threads encrypting/decrypting concurrently.""" + errors = [] + results = [] + + def worker(i): + try: + data = f"key-data-{i}" + enc = encrypt_ca_key(data) + dec = decrypt_ca_key(enc) + results.append((i, dec)) + except Exception as exc: + errors.append((i, exc)) + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(50)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=10) + + assert not errors, f"Thread errors: {errors}" + assert len(results) == 50 + for i, dec in results: + assert dec == f"key-data-{i}", f"Thread {i}: expected 'key-data-{i}', got {dec!r}" diff --git a/tests/unit/test_cors.py b/tests/unit/test_cors.py new file mode 100644 index 0000000..46a55fe --- /dev/null +++ b/tests/unit/test_cors.py @@ -0,0 +1,125 @@ +"""Unit tests for CORS middleware. + +WHAT: Tests for the CORS middleware configuration, including wildcard + origin handling, credentials support, and preflight responses. +WHY: CORS misconfiguration can break browser clients or leak credentials. +EXPECTED: Correct Access-Control-* headers for all origin configurations. +""" +import pytest +from flask import Flask + +from gatehouse_app.middleware.cors import ( + _is_origin_allowed, + _cors_origin_header, + setup_cors, +) + + +# --------------------------------------------------------------------------- +# _is_origin_allowed +# --------------------------------------------------------------------------- + +class TestIsOriginAllowed: + def test_empty_origin_rejected(self): + """TEST: CORS-01 -- Empty origin is never allowed.""" + assert _is_origin_allowed("", ["https://example.com"]) is False + assert _is_origin_allowed(None, "*") is False + + def test_wildcard_string(self): + """TEST: CORS-02 -- Wildcard string allows any origin.""" + assert _is_origin_allowed("https://evil.com", "*") is True + + def test_wildcard_in_list(self): + """TEST: CORS-03 -- Wildcard in list allows any origin.""" + assert _is_origin_allowed("https://evil.com", ["*", "https://example.com"]) is True + + def test_explicit_origin_match(self): + """TEST: CORS-04 -- Explicit list matches exact origin.""" + origins = ["https://example.com", "http://localhost:3000"] + assert _is_origin_allowed("https://example.com", origins) is True + assert _is_origin_allowed("https://evil.com", origins) is False + + def test_empty_origins_list(self): + """TEST: CORS-05 -- Empty list rejects everything.""" + assert _is_origin_allowed("https://example.com", []) is False + + +# --------------------------------------------------------------------------- +# _cors_origin_header +# --------------------------------------------------------------------------- + +class TestCorsOriginHeader: + def test_wildcard_with_origin_echoes(self): + """TEST: CORS-HDR-01 -- Wildcard echoes request origin (for credentials).""" + assert _cors_origin_header("*", "https://example.com") == "https://example.com" + + def test_wildcard_without_origin(self): + """TEST: CORS-HDR-02 -- Wildcard with no origin returns *.""" + assert _cors_origin_header("*", None) == "*" + + def test_wildcard_in_list_with_origin(self): + """TEST: CORS-HDR-03 -- Wildcard in list echoes request origin.""" + result = _cors_origin_header(["*", "https://example.com"], "https://any.com") + assert result == "https://any.com" + + def test_specific_origin_match(self): + """TEST: CORS-HDR-04 -- Matching origin is echoed.""" + origins = ["https://example.com"] + assert _cors_origin_header(origins, "https://example.com") == "https://example.com" + + def test_specific_origin_no_match(self): + """TEST: CORS-HDR-05 -- Non-matching origin returns None.""" + origins = ["https://example.com"] + assert _cors_origin_header(origins, "https://evil.com") is None + + def test_no_origin_no_match(self): + """TEST: CORS-HDR-06 -- No origin with specific list returns None.""" + origins = ["https://example.com"] + assert _cors_origin_header(origins, None) is None + + +# --------------------------------------------------------------------------- +# Integration: preflight response +# --------------------------------------------------------------------------- + +class TestPreflightIntegration: + @pytest.fixture + def app_wildcard(self): + app = Flask(__name__) + app.config["CORS_ORIGINS"] = "*" + app.config["CORS_SUPPORTS_CREDENTIALS"] = True + setup_cors(app) + app.config["TESTING"] = True + return app + + @pytest.fixture + def app_specific(self): + app = Flask(__name__) + app.config["CORS_ORIGINS"] = ["https://example.com"] + app.config["CORS_SUPPORTS_CREDENTIALS"] = True + setup_cors(app) + app.config["TESTING"] = True + return app + + def test_wildcard_preflight_echoes_origin(self, app_wildcard): + """TEST: CORS-PF-01 -- Wildcard preflight echoes request origin.""" + with app_wildcard.test_client() as client: + resp = client.options("/", headers={"Origin": "https://example.com"}) + assert resp.status_code == 204 + assert resp.headers.get("Access-Control-Allow-Origin") == "https://example.com" + assert resp.headers.get("Access-Control-Allow-Credentials") == "true" + + def test_specific_origin_preflight(self, app_specific): + """TEST: CORS-PF-02 -- Specific origin preflight allows matching origin.""" + with app_specific.test_client() as client: + resp = client.options("/", headers={"Origin": "https://example.com"}) + assert resp.status_code == 204 + assert resp.headers.get("Access-Control-Allow-Origin") == "https://example.com" + assert resp.headers.get("Access-Control-Allow-Credentials") == "true" + + def test_specific_origin_rejects_unknown(self, app_specific): + """TEST: CORS-PF-03 -- Non-matching origin gets no CORS headers.""" + with app_specific.test_client() as client: + resp = client.options("/", headers={"Origin": "https://evil.com"}) + # No preflight handler runs, Flask returns default + assert resp.headers.get("Access-Control-Allow-Origin") is None diff --git a/tests/unit/test_encryption.py b/tests/unit/test_encryption.py new file mode 100644 index 0000000..d54ee19 --- /dev/null +++ b/tests/unit/test_encryption.py @@ -0,0 +1,164 @@ +"""Unit tests for encryption module (general-purpose Fernet encryption). + +WHAT: Tests for the PBKDF2-based Fernet encryption/decryption used for + OAuth tokens and client secrets. +WHY: These utilities protect access tokens and client secrets; we need + to verify round-trip correctness and error handling. +EXPECTED: All encrypt/decrypt operations produce correct plaintext. +""" +import threading + +import pytest + +from gatehouse_app.utils.encryption import ( + SALT_LENGTH, + _get_fernet_key, + decrypt, + encrypt, +) + + +# --------------------------------------------------------------------------- +# Shared fixture +# --------------------------------------------------------------------------- + +SECRET_KEY = "test-encryption-secret-key" +SAMPLE_DATA = "access_token=eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.payload" + + +# --------------------------------------------------------------------------- +# encrypt / decrypt round-trip +# --------------------------------------------------------------------------- + +class TestEncryptDecryptRoundTrip: + """Verify that encrypt -> decrypt returns the original plaintext.""" + + def test_basic_round_trip(self): + """TEST: ENC-RT-01 -- Encrypt then decrypt returns original data.""" + encrypted = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY) + decrypted = decrypt(encrypted, secret_key=SECRET_KEY) + assert decrypted == SAMPLE_DATA + + def test_encrypted_is_base64(self): + """TEST: ENC-RT-02 -- Encrypted output is valid base64.""" + import base64 + encrypted = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY) + # Should not raise + base64.urlsafe_b64decode(encrypted.encode()) + + def test_different_ciphertext_each_time(self): + """TEST: ENC-RT-03 -- Two encryptions of the same plaintext differ.""" + enc1 = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY) + enc2 = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY) + assert enc1 != enc2 + assert decrypt(enc1, secret_key=SECRET_KEY) == SAMPLE_DATA + assert decrypt(enc2, secret_key=SECRET_KEY) == SAMPLE_DATA + + def test_round_trip_unicode(self): + """TEST: ENC-RT-04 -- Unicode data round-trips correctly.""" + data = "token=cafe\u00e9\u00f1\u00fc" + encrypted = encrypt(data, secret_key=SECRET_KEY) + assert decrypt(encrypted, secret_key=SECRET_KEY) == data + + def test_round_trip_long_data(self): + """TEST: ENC-RT-05 -- Large data round-trips correctly.""" + data = "x" * 10000 + encrypted = encrypt(data, secret_key=SECRET_KEY) + assert decrypt(encrypted, secret_key=SECRET_KEY) == data + + +# --------------------------------------------------------------------------- +# Empty / edge inputs +# --------------------------------------------------------------------------- + +class TestEdgeCases: + def test_encrypt_empty_returns_empty(self): + """TEST: ENC-EDGE-01 -- Encrypting empty string returns empty.""" + assert encrypt("", secret_key=SECRET_KEY) == "" + + def test_decrypt_empty_returns_empty(self): + """TEST: ENC-EDGE-02 -- Decrypting empty string returns empty.""" + assert decrypt("", secret_key=SECRET_KEY) == "" + + def test_missing_key_raises_on_encrypt(self): + """TEST: ENC-EDGE-03 -- Missing key raises ValueError on encrypt.""" + with pytest.raises(ValueError, match="Encryption key not configured"): + encrypt("data", secret_key="") + + def test_missing_key_raises_on_decrypt(self): + """TEST: ENC-EDGE-04 -- Missing key raises ValueError on decrypt.""" + with pytest.raises(ValueError, match="Encryption key not configured"): + decrypt("something", secret_key="") + + def test_wrong_key_raises_on_decrypt(self): + """TEST: ENC-EDGE-05 -- Wrong key raises ValueError on decrypt.""" + encrypted = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY) + with pytest.raises(ValueError, match="Failed to decrypt"): + decrypt(encrypted, secret_key="wrong-key") + + def test_corrupted_data_raises(self): + """TEST: ENC-EDGE-06 -- Corrupted ciphertext raises ValueError.""" + import base64 + bad = base64.urlsafe_b64encode(b"not-valid-fernet-data").decode() + with pytest.raises(ValueError): + decrypt(bad, secret_key=SECRET_KEY) + + +# --------------------------------------------------------------------------- +# _get_fernet_key — PBKDF2 derivation +# --------------------------------------------------------------------------- + +class TestKeyDerivation: + def test_same_salt_same_key(self): + """TEST: ENC-KD-01 -- Same salt produces the same derived key.""" + salt = b"\x00" * SALT_LENGTH + key1 = _get_fernet_key(SECRET_KEY, salt=salt) + key2 = _get_fernet_key(SECRET_KEY, salt=salt) + assert key1 == key2 + + def test_different_salt_different_key(self): + """TEST: ENC-KD-02 -- Different salts produce different keys.""" + salt1 = b"\x00" * SALT_LENGTH + salt2 = b"\xff" * SALT_LENGTH + key1 = _get_fernet_key(SECRET_KEY, salt=salt1) + key2 = _get_fernet_key(SECRET_KEY, salt=salt2) + assert key1 != key2 + + def test_auto_salt_length(self): + """TEST: ENC-KD-03 -- Auto-generated salt is 16 bytes.""" + key = _get_fernet_key(SECRET_KEY) + # If it didn't raise, the salt was valid + assert len(key) > 0 + + +# --------------------------------------------------------------------------- +# Thread safety +# --------------------------------------------------------------------------- + +class TestThreadSafety: + """Concurrent encrypt/decrypt calls must not corrupt state.""" + + def test_concurrent_encrypt_decrypt(self): + """TEST: ENC-TS-01 -- 50 threads encrypting/decrypting concurrently.""" + errors = [] + results = [] + + def worker(i): + try: + data = f"token-{i}-secret" + enc = encrypt(data, secret_key=SECRET_KEY) + dec = decrypt(enc, secret_key=SECRET_KEY) + results.append((i, dec)) + except Exception as exc: + errors.append((i, exc)) + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(50)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=10) + + assert not errors, f"Thread errors: {errors}" + assert len(results) == 50 + for i, dec in results: + assert dec == f"token-{i}-secret", f"Thread {i}: mismatch"