remove junk
This commit is contained in:
@@ -1,135 +0,0 @@
|
|||||||
# OIDC Extension to Seed Data Script
|
|
||||||
|
|
||||||
## Summary
|
|
||||||
|
|
||||||
Extended [`scripts/seed_data.py`](scripts/seed_data.py) to include OIDC client seeding functionality.
|
|
||||||
|
|
||||||
## Changes Made
|
|
||||||
|
|
||||||
### 1. Added Imports
|
|
||||||
- `import secrets` - For generating secure random values
|
|
||||||
- `import hashlib` - For hashing client secrets
|
|
||||||
- `from app.models.oidc_client import OIDCClient` - OIDC client model
|
|
||||||
|
|
||||||
### 2. New Helper Function: `create_or_get_oidc_client()`
|
|
||||||
Creates OIDC clients with proper configuration or returns existing ones. Features:
|
|
||||||
- Checks for existing clients by `client_id`
|
|
||||||
- Hashes client secrets using SHA256
|
|
||||||
- Supports all OIDC client configuration options
|
|
||||||
- Proper error handling and logging
|
|
||||||
|
|
||||||
### 3. New Seed Step: Step 5 - Create OIDC Clients
|
|
||||||
|
|
||||||
Added 4 OIDC clients across the 3 seeded organizations:
|
|
||||||
|
|
||||||
#### Acme Corporation (2 clients)
|
|
||||||
1. **Acme Internal Portal** (`acme-portal-001`)
|
|
||||||
- Confidential client
|
|
||||||
- Grant types: authorization_code, refresh_token
|
|
||||||
- Scopes: openid, profile, email, offline_access
|
|
||||||
- PKCE required
|
|
||||||
- Redirect URIs for production and localhost
|
|
||||||
|
|
||||||
2. **Acme Mobile App** (`acme-mobile-001`)
|
|
||||||
- Public client (mobile app)
|
|
||||||
- Shorter token lifetimes for security
|
|
||||||
- PKCE required
|
|
||||||
- Custom URL scheme for mobile redirect
|
|
||||||
|
|
||||||
#### Tech Startup Inc (1 client)
|
|
||||||
3. **Tech Startup Dashboard** (`tech-dashboard-001`)
|
|
||||||
- Confidential client
|
|
||||||
- Standard OIDC configuration
|
|
||||||
- PKCE required
|
|
||||||
|
|
||||||
#### Data Systems Inc (1 client)
|
|
||||||
4. **Data Systems API Client** (`data-api-001`)
|
|
||||||
- Confidential server-to-server client
|
|
||||||
- Additional grant type: client_credentials
|
|
||||||
- Custom scopes: api:read, api:write
|
|
||||||
- PKCE not required (server-to-server)
|
|
||||||
|
|
||||||
## OIDC Client Test Credentials
|
|
||||||
|
|
||||||
All clients are configured with test credentials for development:
|
|
||||||
|
|
||||||
| Client | Client ID | Client Secret |
|
|
||||||
|--------|-----------|---------------|
|
|
||||||
| Acme Portal | `acme-portal-001` | `acme_secret_portal_2024` |
|
|
||||||
| Acme Mobile | `acme-mobile-001` | `acme_secret_mobile_2024` |
|
|
||||||
| Tech Dashboard | `tech-dashboard-001` | `tech_secret_dashboard_2024` |
|
|
||||||
| Data API | `data-api-001` | `data_secret_api_2024` |
|
|
||||||
|
|
||||||
## Enhanced Summary Output
|
|
||||||
|
|
||||||
The seed script now displays:
|
|
||||||
- Total count of OIDC clients created
|
|
||||||
- Detailed information for each client including:
|
|
||||||
- Client name and ID
|
|
||||||
- Organization
|
|
||||||
- Configured grant types
|
|
||||||
- Configured scopes
|
|
||||||
- Number of redirect URIs
|
|
||||||
- Complete test credentials table
|
|
||||||
|
|
||||||
## Example Output
|
|
||||||
|
|
||||||
```
|
|
||||||
[Step 5] Creating OIDC Clients...
|
|
||||||
|
|
||||||
Acme Corporation OIDC Clients:
|
|
||||||
→ Created OIDC client: Acme Internal Portal
|
|
||||||
→ Created OIDC client: Acme Mobile App
|
|
||||||
|
|
||||||
Tech Startup OIDC Clients:
|
|
||||||
→ Created OIDC client: Tech Startup Dashboard
|
|
||||||
|
|
||||||
Data Systems OIDC Clients:
|
|
||||||
→ Created OIDC client: Data Systems API Client
|
|
||||||
|
|
||||||
Created 4 OIDC clients
|
|
||||||
|
|
||||||
============================================================
|
|
||||||
Seed Complete!
|
|
||||||
============================================================
|
|
||||||
|
|
||||||
📊 Summary:
|
|
||||||
Organizations: 3
|
|
||||||
Admin Users: 2
|
|
||||||
Regular Users: 9
|
|
||||||
OIDC Clients: 4
|
|
||||||
|
|
||||||
🔐 OIDC Clients:
|
|
||||||
Acme Internal Portal
|
|
||||||
Client ID: acme-portal-001
|
|
||||||
Organization: Acme Corporation
|
|
||||||
Grant Types: authorization_code, refresh_token
|
|
||||||
Scopes: openid, profile, email, offline_access
|
|
||||||
Redirect URIs: 2 configured
|
|
||||||
...
|
|
||||||
```
|
|
||||||
|
|
||||||
## Features
|
|
||||||
|
|
||||||
- **Idempotent**: Running the script multiple times won't create duplicate clients
|
|
||||||
- **Comprehensive**: Creates diverse client types (confidential, public, server-to-server)
|
|
||||||
- **Production-ready**: Includes proper secret hashing and security configurations
|
|
||||||
- **Developer-friendly**: Includes localhost URLs and clear test credentials
|
|
||||||
- **Well-documented**: Clear console output showing what was created
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
Run the seed script as usual:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python scripts/seed_data.py
|
|
||||||
```
|
|
||||||
|
|
||||||
The OIDC clients will be automatically created along with users and organizations.
|
|
||||||
|
|
||||||
## Security Notes
|
|
||||||
|
|
||||||
- Client secrets are hashed using SHA256 before storage
|
|
||||||
- Test credentials are clearly marked and should **not** be used in production
|
|
||||||
- PKCE is enabled by default for web and mobile clients
|
|
||||||
- Token lifetimes are configured appropriately for each client type
|
|
||||||
@@ -1,171 +0,0 @@
|
|||||||
# TOTP End-to-End Test Proposal
|
|
||||||
|
|
||||||
## Test Objective
|
|
||||||
Test ALL aspects of TOTP functionality regardless of current state (TOTP enabled or disabled).
|
|
||||||
|
|
||||||
## Test Flow
|
|
||||||
|
|
||||||
### Scenario A: TOTP Currently Enabled (Bob already enrolled)
|
|
||||||
|
|
||||||
1. **Login** with email/password
|
|
||||||
- Response: `requires_totp: true`
|
|
||||||
|
|
||||||
2. **Get Secret from DB** (or use environment variable)
|
|
||||||
- Since secret is encrypted/hashed in DB, we need to either:
|
|
||||||
- Store it in environment/file from previous enrollment, OR
|
|
||||||
- User provides it as input, OR
|
|
||||||
- Use backup code from previous enrollment
|
|
||||||
|
|
||||||
3. **Generate TOTP Code** using stored secret/backup code
|
|
||||||
|
|
||||||
4. **Verify TOTP** to complete login
|
|
||||||
- Endpoint: `/auth/totp/verify`
|
|
||||||
- Get auth_token
|
|
||||||
|
|
||||||
5. **Check TOTP Status**
|
|
||||||
- Endpoint: `/auth/totp/status`
|
|
||||||
- Confirm: `totp_enabled: true`
|
|
||||||
|
|
||||||
6. **Disable TOTP**
|
|
||||||
- Endpoint: `/auth/totp/disable`
|
|
||||||
- Provide password
|
|
||||||
|
|
||||||
7. **Logout**
|
|
||||||
|
|
||||||
8. **Continue to Scenario B steps 2-14**
|
|
||||||
|
|
||||||
### Scenario B: TOTP Currently Disabled (or after completing Scenario A)
|
|
||||||
|
|
||||||
1. **Login** with email/password
|
|
||||||
- Response: `token` (no TOTP required)
|
|
||||||
|
|
||||||
2. **Check TOTP Status**
|
|
||||||
- Endpoint: `/auth/totp/status`
|
|
||||||
- Confirm: `totp_enabled: false`
|
|
||||||
|
|
||||||
3. **Enroll in TOTP**
|
|
||||||
- Endpoint: `/auth/totp/enroll`
|
|
||||||
- Store: secret, backup_codes, provisioning_uri, qr_code
|
|
||||||
|
|
||||||
4. **Generate TOTP Code** from new secret
|
|
||||||
- Use timezone-aware UTC
|
|
||||||
|
|
||||||
5. **Verify Enrollment**
|
|
||||||
- Endpoint: `/auth/totp/verify-enrollment`
|
|
||||||
- Provide generated code
|
|
||||||
|
|
||||||
6. **Check TOTP Status Again**
|
|
||||||
- Confirm: `totp_enabled: true`
|
|
||||||
- Confirm: `backup_codes_remaining: 10`
|
|
||||||
- Confirm: `verified_at` is set
|
|
||||||
|
|
||||||
7. **Logout**
|
|
||||||
|
|
||||||
8. **Login** with email/password
|
|
||||||
- Response: `requires_totp: true`
|
|
||||||
|
|
||||||
9. **Generate TOTP Code** from stored secret
|
|
||||||
|
|
||||||
10. **Verify TOTP** to complete login
|
|
||||||
- Endpoint: `/auth/totp/verify`
|
|
||||||
- Get auth_token
|
|
||||||
|
|
||||||
11. **Confirm Logged In**
|
|
||||||
- Endpoint: `/auth/me`
|
|
||||||
- Verify user data returned
|
|
||||||
|
|
||||||
12. **Test Backup Code** (new login)
|
|
||||||
- Logout
|
|
||||||
- Login with email/password
|
|
||||||
- Use backup code instead of TOTP
|
|
||||||
- Endpoint: `/auth/totp/verify` with `is_backup_code: true`
|
|
||||||
|
|
||||||
13. **Check Backup Codes Remaining**
|
|
||||||
- Should be 9 (one consumed)
|
|
||||||
|
|
||||||
14. **Regenerate Backup Codes**
|
|
||||||
- Endpoint: `/auth/totp/regenerate-backup-codes`
|
|
||||||
- Provide password
|
|
||||||
- Get new set of 10 codes
|
|
||||||
|
|
||||||
## Implementation Strategy
|
|
||||||
|
|
||||||
### Secret Persistence Between Test Runs
|
|
||||||
|
|
||||||
**Option 1: Environment Variable** (Recommended)
|
|
||||||
```python
|
|
||||||
import os
|
|
||||||
|
|
||||||
# Save secret after first successful enrollment
|
|
||||||
SECRET_FILE = ".totp_test_secret"
|
|
||||||
|
|
||||||
if os.path.exists(SECRET_FILE):
|
|
||||||
with open(SECRET_FILE) as f:
|
|
||||||
data = json.load(f)
|
|
||||||
known_secret = data.get("secret")
|
|
||||||
known_backup_codes = data.get("backup_codes", [])
|
|
||||||
else:
|
|
||||||
known_secret = None
|
|
||||||
known_backup_codes = []
|
|
||||||
|
|
||||||
# After enrollment, save for next run
|
|
||||||
with open(SECRET_FILE, 'w') as f:
|
|
||||||
json.dump({
|
|
||||||
"secret": new_secret,
|
|
||||||
"backup_codes": new_backup_codes
|
|
||||||
}, f)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Option 2: Test Database State**
|
|
||||||
- Include SQL query to fetch secret from DB (if stored in plain text for testing)
|
|
||||||
- Or decrypt if encrypted
|
|
||||||
|
|
||||||
**Option 3: Manual Input**
|
|
||||||
- Prompt user for secret/backup code if TOTP already enabled
|
|
||||||
- Less automated but more flexible
|
|
||||||
|
|
||||||
## Expected Assertions
|
|
||||||
|
|
||||||
1. ✅ Login without TOTP works when disabled
|
|
||||||
2. ✅ Enrollment generates secret, QR code, backup codes
|
|
||||||
3. ✅ Enrollment verification accepts valid TOTP code
|
|
||||||
4. ✅ TOTP status shows enabled after verification
|
|
||||||
5. ✅ Login requires TOTP when enabled
|
|
||||||
6. ✅ TOTP verification works during login
|
|
||||||
7. ✅ Backup code works for authentication
|
|
||||||
8. ✅ Backup codes decrement when used
|
|
||||||
9. ✅ Backup code regeneration works
|
|
||||||
10. ✅ TOTP disable works with correct password
|
|
||||||
11. ✅ Login works without TOTP after disabling
|
|
||||||
|
|
||||||
## Test Data Management
|
|
||||||
|
|
||||||
Store in `.totp_test_data.json` (gitignored):
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"user": "bob@acme-corp.com",
|
|
||||||
"secret": "BWAQAP55...",
|
|
||||||
"backup_codes": ["code1", "code2", ...],
|
|
||||||
"enrollment_date": "2026-01-14T03:12:00Z",
|
|
||||||
"last_test_run": "2026-01-14T03:15:00Z"
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Error Handling
|
|
||||||
|
|
||||||
- Connection errors → clear message about server not running
|
|
||||||
- 401 errors → check if token/credentials are correct
|
|
||||||
- TOTP code failures → check time synchronization
|
|
||||||
- Backup code failures → check if already used
|
|
||||||
|
|
||||||
## Success Criteria
|
|
||||||
|
|
||||||
Test passes when:
|
|
||||||
1. All 14 steps complete without errors
|
|
||||||
2. All assertions pass
|
|
||||||
3. Test can run multiple times (idempotent)
|
|
||||||
4. Works from both initial states (TOTP enabled/disabled)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
**Please review this proposal. Once approved, I'll implement it.**
|
|
||||||
@@ -1,47 +0,0 @@
|
|||||||
# Manual TOTP Reset for Testing
|
|
||||||
|
|
||||||
Since Bob has TOTP enabled, you have two options to run the full test:
|
|
||||||
|
|
||||||
## Option 1: Restart Flask Server (Easiest)
|
|
||||||
The Flask server running on port 8888 uses an in-memory SQLite database.
|
|
||||||
Simply restart it to clear all data:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Stop the server (Ctrl+C in the terminal)
|
|
||||||
# Then restart it
|
|
||||||
cd gatehouse-api
|
|
||||||
.venv/bin/flask run --debug --port 8888
|
|
||||||
```
|
|
||||||
|
|
||||||
Then run the test:
|
|
||||||
```bash
|
|
||||||
.venv/bin/python test_totp_full.py
|
|
||||||
```
|
|
||||||
|
|
||||||
## Option 2: Use the TOTP Secret
|
|
||||||
|
|
||||||
If you have the secret from the previous enrollment (check `.totp_test_data.json` if it exists):
|
|
||||||
|
|
||||||
1. Edit `test_totp_full.py`
|
|
||||||
2. Update the `test_data` initialization:
|
|
||||||
```python
|
|
||||||
test_data = {
|
|
||||||
"secret": "YOUR_SECRET_HERE", # From previous enrollment
|
|
||||||
"backup_codes": ["CODE1", "CODE2", ...], # From previous enrollment
|
|
||||||
"last_run": None
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
3. Run the test
|
|
||||||
|
|
||||||
## Option 3: Database Direct Access (if file-based DB)
|
|
||||||
|
|
||||||
If using PostgreSQL or file-based SQLite:
|
|
||||||
|
|
||||||
```sql
|
|
||||||
DELETE FROM authentication_methods
|
|
||||||
WHERE user_id = (SELECT id FROM users WHERE email = 'bob@acme-corp.com')
|
|
||||||
AND method_type = 'totp';
|
|
||||||
```
|
|
||||||
|
|
||||||
The test will then run through the complete flow and save the new secret/codes to `.totp_test_data.json` for subsequent runs.
|
|
||||||
@@ -1,69 +0,0 @@
|
|||||||
"""Add application-wide external auth provider config tables
|
|
||||||
|
|
||||||
Revision ID: 4edc2fce47c5
|
|
||||||
Revises: a4d4a17a5d15
|
|
||||||
Create Date: 2026-01-20 16:02:34.196815
|
|
||||||
|
|
||||||
"""
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision = '4edc2fce47c5'
|
|
||||||
down_revision = 'a4d4a17a5d15'
|
|
||||||
branch_labels = None
|
|
||||||
depends_on = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.create_table('application_provider_configs',
|
|
||||||
sa.Column('provider_type', sa.String(length=50), nullable=False),
|
|
||||||
sa.Column('client_id', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('client_secret_encrypted', sa.String(length=512), nullable=True),
|
|
||||||
sa.Column('is_enabled', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('default_redirect_url', sa.String(length=2048), nullable=True),
|
|
||||||
sa.Column('additional_config', sa.JSON(), nullable=True),
|
|
||||||
sa.Column('id', sa.String(length=36), nullable=False),
|
|
||||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
|
||||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
|
||||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
|
||||||
sa.PrimaryKeyConstraint('id'),
|
|
||||||
sa.UniqueConstraint('id')
|
|
||||||
)
|
|
||||||
op.create_index(op.f('ix_application_provider_configs_provider_type'), 'application_provider_configs', ['provider_type'], unique=True)
|
|
||||||
op.create_table('organization_provider_overrides',
|
|
||||||
sa.Column('organization_id', sa.String(length=36), nullable=False),
|
|
||||||
sa.Column('provider_type', sa.String(length=50), nullable=False),
|
|
||||||
sa.Column('client_id', sa.String(length=255), nullable=True),
|
|
||||||
sa.Column('client_secret_encrypted', sa.String(length=512), nullable=True),
|
|
||||||
sa.Column('is_enabled', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('redirect_url_override', sa.String(length=2048), nullable=True),
|
|
||||||
sa.Column('additional_config', sa.JSON(), nullable=True),
|
|
||||||
sa.Column('id', sa.String(length=36), nullable=False),
|
|
||||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
|
||||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
|
||||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
|
||||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
|
||||||
sa.PrimaryKeyConstraint('id'),
|
|
||||||
sa.UniqueConstraint('id'),
|
|
||||||
sa.UniqueConstraint('organization_id', 'provider_type', name='uix_org_provider_type')
|
|
||||||
)
|
|
||||||
op.create_index(op.f('ix_organization_provider_overrides_organization_id'), 'organization_provider_overrides', ['organization_id'], unique=False)
|
|
||||||
op.create_index(op.f('ix_organization_provider_overrides_provider_type'), 'organization_provider_overrides', ['provider_type'], unique=False)
|
|
||||||
op.add_column('oauth_states', sa.Column('return_url', sa.String(length=2048), nullable=True))
|
|
||||||
op.drop_index(op.f('ix_oauth_states_user_id'), table_name='oauth_states')
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.create_index(op.f('ix_oauth_states_user_id'), 'oauth_states', ['user_id'], unique=False)
|
|
||||||
op.drop_column('oauth_states', 'return_url')
|
|
||||||
op.drop_index(op.f('ix_organization_provider_overrides_provider_type'), table_name='organization_provider_overrides')
|
|
||||||
op.drop_index(op.f('ix_organization_provider_overrides_organization_id'), table_name='organization_provider_overrides')
|
|
||||||
op.drop_table('organization_provider_overrides')
|
|
||||||
op.drop_index(op.f('ix_application_provider_configs_provider_type'), table_name='application_provider_configs')
|
|
||||||
op.drop_table('application_provider_configs')
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
@@ -1,86 +0,0 @@
|
|||||||
"""empty message
|
|
||||||
|
|
||||||
Revision ID: a4d4a17a5d15
|
|
||||||
Revises: 004
|
|
||||||
Create Date: 2026-01-20 14:30:36.898886
|
|
||||||
|
|
||||||
"""
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision = 'a4d4a17a5d15'
|
|
||||||
down_revision = '004'
|
|
||||||
branch_labels = None
|
|
||||||
depends_on = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.create_table('external_provider_configs',
|
|
||||||
sa.Column('organization_id', sa.String(length=36), nullable=False),
|
|
||||||
sa.Column('provider_type', sa.String(length=50), nullable=False),
|
|
||||||
sa.Column('client_id', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('client_secret_encrypted', sa.String(length=512), nullable=True),
|
|
||||||
sa.Column('auth_url', sa.String(length=2048), nullable=False),
|
|
||||||
sa.Column('token_url', sa.String(length=2048), nullable=False),
|
|
||||||
sa.Column('userinfo_url', sa.String(length=2048), nullable=True),
|
|
||||||
sa.Column('jwks_url', sa.String(length=2048), nullable=True),
|
|
||||||
sa.Column('scopes', sa.JSON(), nullable=False),
|
|
||||||
sa.Column('redirect_uris', sa.JSON(), nullable=False),
|
|
||||||
sa.Column('settings', sa.JSON(), nullable=True),
|
|
||||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('id', sa.String(length=36), nullable=False),
|
|
||||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
|
||||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
|
||||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
|
||||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
|
||||||
sa.PrimaryKeyConstraint('id'),
|
|
||||||
sa.UniqueConstraint('id'),
|
|
||||||
sa.UniqueConstraint('organization_id', 'provider_type', name='uix_org_provider_type')
|
|
||||||
)
|
|
||||||
op.create_index('idx_provider_config_org', 'external_provider_configs', ['organization_id', 'provider_type'], unique=False)
|
|
||||||
op.create_index(op.f('ix_external_provider_configs_organization_id'), 'external_provider_configs', ['organization_id'], unique=False)
|
|
||||||
op.create_index(op.f('ix_external_provider_configs_provider_type'), 'external_provider_configs', ['provider_type'], unique=False)
|
|
||||||
op.create_table('oauth_states',
|
|
||||||
sa.Column('state', sa.String(length=64), nullable=False),
|
|
||||||
sa.Column('flow_type', sa.String(length=50), nullable=False),
|
|
||||||
sa.Column('user_id', sa.String(length=36), nullable=True),
|
|
||||||
sa.Column('organization_id', sa.String(length=36), nullable=True),
|
|
||||||
sa.Column('provider_type', sa.String(length=50), nullable=False),
|
|
||||||
sa.Column('nonce', sa.String(length=128), nullable=True),
|
|
||||||
sa.Column('code_verifier', sa.String(length=128), nullable=True),
|
|
||||||
sa.Column('code_challenge', sa.String(length=128), nullable=True),
|
|
||||||
sa.Column('redirect_uri', sa.String(length=2048), nullable=True),
|
|
||||||
sa.Column('extra_data', sa.JSON(), nullable=True),
|
|
||||||
sa.Column('expires_at', sa.DateTime(), nullable=False),
|
|
||||||
sa.Column('used', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('id', sa.String(length=36), nullable=False),
|
|
||||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
|
||||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
|
||||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
|
||||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
|
||||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
|
||||||
sa.PrimaryKeyConstraint('id'),
|
|
||||||
sa.UniqueConstraint('id')
|
|
||||||
)
|
|
||||||
op.create_index(op.f('ix_oauth_states_expires_at'), 'oauth_states', ['expires_at'], unique=False)
|
|
||||||
op.create_index(op.f('ix_oauth_states_organization_id'), 'oauth_states', ['organization_id'], unique=False)
|
|
||||||
op.create_index(op.f('ix_oauth_states_state'), 'oauth_states', ['state'], unique=True)
|
|
||||||
op.create_index(op.f('ix_oauth_states_user_id'), 'oauth_states', ['user_id'], unique=False)
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.drop_index(op.f('ix_oauth_states_user_id'), table_name='oauth_states')
|
|
||||||
op.drop_index(op.f('ix_oauth_states_state'), table_name='oauth_states')
|
|
||||||
op.drop_index(op.f('ix_oauth_states_organization_id'), table_name='oauth_states')
|
|
||||||
op.drop_index(op.f('ix_oauth_states_expires_at'), table_name='oauth_states')
|
|
||||||
op.drop_table('oauth_states')
|
|
||||||
op.drop_index(op.f('ix_external_provider_configs_provider_type'), table_name='external_provider_configs')
|
|
||||||
op.drop_index(op.f('ix_external_provider_configs_organization_id'), table_name='external_provider_configs')
|
|
||||||
op.drop_index('idx_provider_config_org', table_name='external_provider_configs')
|
|
||||||
op.drop_table('external_provider_configs')
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
-102
@@ -1,102 +0,0 @@
|
|||||||
#!/usr/bin/env bash
|
|
||||||
set -euo pipefail
|
|
||||||
|
|
||||||
ISSUER="https://oidctest.wsweet.org"
|
|
||||||
CLIENT_ID="secret"
|
|
||||||
CLIENT_SECRET="tardis"
|
|
||||||
REDIRECT_URI="http://127.0.0.1:5556/callback"
|
|
||||||
SCOPE="openid profile email offline_access"
|
|
||||||
|
|
||||||
# ---------------------------
|
|
||||||
# Discover OIDC endpoints
|
|
||||||
# ---------------------------
|
|
||||||
DISCOVERY=$(curl -s "$ISSUER/.well-known/openid-configuration")
|
|
||||||
|
|
||||||
AUTH_ENDPOINT=$(echo "$DISCOVERY" | jq -r .authorization_endpoint)
|
|
||||||
TOKEN_ENDPOINT=$(echo "$DISCOVERY" | jq -r .token_endpoint)
|
|
||||||
USERINFO_ENDPOINT=$(echo "$DISCOVERY" | jq -r .userinfo_endpoint)
|
|
||||||
|
|
||||||
echo "Auth endpoint : $AUTH_ENDPOINT"
|
|
||||||
echo "Token endpoint: $TOKEN_ENDPOINT"
|
|
||||||
echo
|
|
||||||
|
|
||||||
# ---------------------------
|
|
||||||
# PKCE
|
|
||||||
# ---------------------------
|
|
||||||
CODE_VERIFIER=$(openssl rand -base64 32 | tr -d '=+/')
|
|
||||||
CODE_CHALLENGE=$(echo -n "$CODE_VERIFIER" | openssl dgst -sha256 -binary | openssl base64 | tr -d '=+/' | tr '/+' '_-')
|
|
||||||
|
|
||||||
STATE=$(openssl rand -hex 16)
|
|
||||||
NONCE=$(openssl rand -hex 16)
|
|
||||||
|
|
||||||
# ---------------------------
|
|
||||||
# Build auth URL
|
|
||||||
# ---------------------------
|
|
||||||
AUTH_URL="$AUTH_ENDPOINT?response_type=code\
|
|
||||||
&client_id=$CLIENT_ID\
|
|
||||||
&redirect_uri=$(printf '%s' "$REDIRECT_URI" | jq -s -R -r @uri)\
|
|
||||||
&scope=$(printf '%s' "$SCOPE" | jq -s -R -r @uri)\
|
|
||||||
&state=$STATE\
|
|
||||||
&nonce=$NONCE\
|
|
||||||
&code_challenge=$CODE_CHALLENGE\
|
|
||||||
&code_challenge_method=S256"
|
|
||||||
|
|
||||||
echo "Open this URL in a browser:"
|
|
||||||
echo
|
|
||||||
echo "$AUTH_URL"
|
|
||||||
echo
|
|
||||||
echo "After login you will be redirected to:"
|
|
||||||
echo "$REDIRECT_URI?code=XXXX&state=YYYY"
|
|
||||||
echo
|
|
||||||
read -p "Paste the full redirect URL: " REDIRECT
|
|
||||||
|
|
||||||
CODE=$(echo "$REDIRECT" | sed -n 's/.*code=\([^&]*\).*/\1/p')
|
|
||||||
RETURNED_STATE=$(echo "$REDIRECT" | sed -n 's/.*state=\([^&]*\).*/\1/p')
|
|
||||||
|
|
||||||
if [ "$RETURNED_STATE" != "$STATE" ]; then
|
|
||||||
echo "STATE MISMATCH"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
# ---------------------------
|
|
||||||
# Exchange code for tokens
|
|
||||||
# ---------------------------
|
|
||||||
TOKENS=$(curl -s -X POST "$TOKEN_ENDPOINT" \
|
|
||||||
-u "$CLIENT_ID:$CLIENT_SECRET" \
|
|
||||||
-H "Content-Type: application/x-www-form-urlencoded" \
|
|
||||||
-d "grant_type=authorization_code" \
|
|
||||||
-d "code=$CODE" \
|
|
||||||
-d "redirect_uri=$REDIRECT_URI" \
|
|
||||||
-d "code_verifier=$CODE_VERIFIER")
|
|
||||||
|
|
||||||
echo
|
|
||||||
echo "Token response:"
|
|
||||||
echo "$TOKENS" | jq .
|
|
||||||
|
|
||||||
ACCESS_TOKEN=$(echo "$TOKENS" | jq -r .access_token)
|
|
||||||
ID_TOKEN=$(echo "$TOKENS" | jq -r .id_token)
|
|
||||||
|
|
||||||
# ---------------------------
|
|
||||||
# JWT decode function
|
|
||||||
# ---------------------------
|
|
||||||
decode() {
|
|
||||||
echo "$1" | awk -F. '{print $2}' | tr '_-' '/+' | base64 -d 2>/dev/null | jq .
|
|
||||||
}
|
|
||||||
|
|
||||||
echo
|
|
||||||
echo "================ ID TOKEN ================"
|
|
||||||
decode "$ID_TOKEN"
|
|
||||||
|
|
||||||
echo
|
|
||||||
echo "============== ACCESS TOKEN =============="
|
|
||||||
decode "$ACCESS_TOKEN"
|
|
||||||
|
|
||||||
# ---------------------------
|
|
||||||
# Userinfo (optional)
|
|
||||||
# ---------------------------
|
|
||||||
if [ "$USERINFO_ENDPOINT" != "null" ]; then
|
|
||||||
echo
|
|
||||||
echo "=============== USERINFO ================="
|
|
||||||
curl -s -H "Authorization: Bearer $ACCESS_TOKEN" "$USERINFO_ENDPOINT" | jq .
|
|
||||||
fi
|
|
||||||
|
|
||||||
@@ -1,70 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# Test script to verify OAuth endpoints work without organization_id
|
|
||||||
# This tests the fix for the "Google OAuth is not configured for this organization" error
|
|
||||||
|
|
||||||
API_BASE="http://localhost:5001/api/v1"
|
|
||||||
|
|
||||||
echo "=== Testing OAuth Authorization Endpoint (without organization_id) ==="
|
|
||||||
echo ""
|
|
||||||
echo "1. Initiating Google OAuth login flow (NO organization_id)..."
|
|
||||||
RESPONSE=$(curl -s -X GET "${API_BASE}/auth/external/google/authorize?flow=login")
|
|
||||||
echo "Response: $RESPONSE"
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
# Check if we get an authorization URL
|
|
||||||
if echo "$RESPONSE" | grep -q "authorization_url"; then
|
|
||||||
echo "✅ SUCCESS: Got authorization URL without requiring organization_id"
|
|
||||||
AUTH_URL=$(echo "$RESPONSE" | jq -r '.data.authorization_url')
|
|
||||||
STATE=$(echo "$RESPONSE" | jq -r '.data.state')
|
|
||||||
echo "Authorization URL: $AUTH_URL"
|
|
||||||
echo "State: $STATE"
|
|
||||||
else
|
|
||||||
echo "❌ FAILED: Did not get authorization URL"
|
|
||||||
echo "Error: $(echo "$RESPONSE" | jq -r '.message')"
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "=== Testing with organization_id hint (should still work) ==="
|
|
||||||
echo ""
|
|
||||||
echo "2. Initiating Google OAuth login flow (WITH organization_id hint)..."
|
|
||||||
# You'll need to replace this with an actual organization ID from your database
|
|
||||||
ORG_ID="test-org-id"
|
|
||||||
RESPONSE=$(curl -s -X GET "${API_BASE}/auth/external/google/authorize?flow=login&organization_id=${ORG_ID}")
|
|
||||||
echo "Response: $RESPONSE"
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
if echo "$RESPONSE" | grep -q "authorization_url"; then
|
|
||||||
echo "✅ SUCCESS: OAuth works with organization_id hint (backward compatible)"
|
|
||||||
else
|
|
||||||
echo "⚠️ Note: This may fail if the organization ID doesn't exist or if app-level config is not set"
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "=== Testing Register Flow ==="
|
|
||||||
echo ""
|
|
||||||
echo "3. Initiating Google OAuth register flow (NO organization_id)..."
|
|
||||||
RESPONSE=$(curl -s -X GET "${API_BASE}/auth/external/google/authorize?flow=register")
|
|
||||||
echo "Response: $RESPONSE"
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
if echo "$RESPONSE" | grep -q "authorization_url"; then
|
|
||||||
echo "✅ SUCCESS: Register flow works without organization_id"
|
|
||||||
else
|
|
||||||
echo "❌ FAILED: Register flow did not work"
|
|
||||||
echo "Error: $(echo "$RESPONSE" | jq -r '.message')"
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "=== Summary ==="
|
|
||||||
echo ""
|
|
||||||
echo "The key fix addresses the error:"
|
|
||||||
echo " 'Google OAuth is not configured for this organization'"
|
|
||||||
echo ""
|
|
||||||
echo "Now OAuth flows work at the APPLICATION level, not requiring"
|
|
||||||
echo "an organization context during initial authentication."
|
|
||||||
echo ""
|
|
||||||
echo "After OAuth callback:"
|
|
||||||
echo " - Single org user → Automatic login"
|
|
||||||
echo " - Multi org user → Organization selection UI"
|
|
||||||
echo " - New user → Organization creation/selection UI"
|
|
||||||
@@ -1,501 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
COMPREHENSIVE TOTP END-TO-END FUNCTIONAL TEST
|
|
||||||
Tests all aspects of TOTP functionality regardless of current state.
|
|
||||||
|
|
||||||
Based on approved proposal in TOTP_TEST_PROPOSAL.md
|
|
||||||
"""
|
|
||||||
import requests
|
|
||||||
import pyotp
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
# Configuration
|
|
||||||
BASE_URL = "http://localhost:8888/api/v1"
|
|
||||||
CREDENTIALS = {
|
|
||||||
"email": "bob@acme-corp.com",
|
|
||||||
"password": "UserPass123!"
|
|
||||||
}
|
|
||||||
DATA_FILE = ".totp_test_data.json"
|
|
||||||
|
|
||||||
# Test state
|
|
||||||
test_data = {
|
|
||||||
"secret": None,
|
|
||||||
"backup_codes": [],
|
|
||||||
"last_run": None
|
|
||||||
}
|
|
||||||
|
|
||||||
def load_test_data():
|
|
||||||
"""Load test data from previous run."""
|
|
||||||
global test_data
|
|
||||||
if os.path.exists(DATA_FILE):
|
|
||||||
with open(DATA_FILE, 'r') as f:
|
|
||||||
test_data = json.load(f)
|
|
||||||
print(f"📂 Loaded test data from {DATA_FILE}")
|
|
||||||
print(f" Secret: {test_data['secret'][:20] if test_data['secret'] else 'None'}...")
|
|
||||||
print(f" Backup codes: {len(test_data.get('backup_codes', []))}")
|
|
||||||
else:
|
|
||||||
print(f"📂 No previous test data found")
|
|
||||||
|
|
||||||
def save_test_data():
|
|
||||||
"""Save test data for next run."""
|
|
||||||
test_data['last_run'] = datetime.now(timezone.utc).isoformat()
|
|
||||||
with open(DATA_FILE, 'w') as f:
|
|
||||||
json.dump(test_data, f, indent=2)
|
|
||||||
print(f"\n💾 Saved test data to {DATA_FILE}")
|
|
||||||
|
|
||||||
def print_section(step, title):
|
|
||||||
"""Print test section header."""
|
|
||||||
print(f"\n{'='*70}")
|
|
||||||
print(f"[STEP {step}] {title}")
|
|
||||||
print('='*70)
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Run comprehensive TOTP test."""
|
|
||||||
|
|
||||||
print("\n" + "="*70)
|
|
||||||
print("COMPREHENSIVE TOTP END-TO-END TEST")
|
|
||||||
print(f"User: {CREDENTIALS['email']}")
|
|
||||||
print(f"Server: {BASE_URL}")
|
|
||||||
print(f"Time: {datetime.now(timezone.utc).isoformat()}")
|
|
||||||
print("="*70)
|
|
||||||
|
|
||||||
load_test_data()
|
|
||||||
|
|
||||||
session = requests.Session()
|
|
||||||
auth_token = None
|
|
||||||
totp = None
|
|
||||||
step = 0
|
|
||||||
|
|
||||||
try:
|
|
||||||
# ==================== PHASE 1: INITIAL LOGIN ====================
|
|
||||||
|
|
||||||
step += 1
|
|
||||||
print_section(step, "Initial Login")
|
|
||||||
|
|
||||||
login_response = session.post(f"{BASE_URL}/auth/login", json=CREDENTIALS)
|
|
||||||
|
|
||||||
if login_response.status_code != 200:
|
|
||||||
print(f"❌ Login failed: {login_response.status_code}")
|
|
||||||
print(json.dumps(login_response.json(), indent=2))
|
|
||||||
return False
|
|
||||||
|
|
||||||
login_data = login_response.json()
|
|
||||||
|
|
||||||
# Check if TOTP is required
|
|
||||||
totp_required = login_data.get("data", {}).get("requires_totp", False)
|
|
||||||
|
|
||||||
if totp_required:
|
|
||||||
print("⚠️ TOTP is ENABLED - login requires verification")
|
|
||||||
|
|
||||||
# We need either saved secret or backup code
|
|
||||||
if test_data.get('secret'):
|
|
||||||
print("ℹ️ Using saved secret to generate TOTP code")
|
|
||||||
totp = pyotp.TOTP(test_data['secret'])
|
|
||||||
utc_now = datetime.now(timezone.utc)
|
|
||||||
code = totp.at(utc_now)
|
|
||||||
print(f" Generated code: {code}")
|
|
||||||
print(f" At time: {utc_now.isoformat()}")
|
|
||||||
|
|
||||||
verify_response = session.post(
|
|
||||||
f"{BASE_URL}/auth/totp/verify",
|
|
||||||
json={"code": code}
|
|
||||||
)
|
|
||||||
|
|
||||||
if verify_response.status_code != 200:
|
|
||||||
print("❌ TOTP code verification failed")
|
|
||||||
print(" Trying backup code...")
|
|
||||||
|
|
||||||
if test_data.get('backup_codes'):
|
|
||||||
# Try first unused backup code
|
|
||||||
for backup_code in test_data['backup_codes']:
|
|
||||||
verify_response = session.post(
|
|
||||||
f"{BASE_URL}/auth/totp/verify",
|
|
||||||
json={"code": backup_code, "is_backup_code": True}
|
|
||||||
)
|
|
||||||
if verify_response.status_code == 200:
|
|
||||||
print(f"✅ Authenticated with backup code: {backup_code}")
|
|
||||||
# Remove used code
|
|
||||||
test_data['backup_codes'].remove(backup_code)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
print("❌ All backup codes failed")
|
|
||||||
print("\nPlease manually delete Bob's TOTP from database:")
|
|
||||||
print("DELETE FROM authentication_methods WHERE user_id = (SELECT id FROM users WHERE email = 'bob@acme-corp.com') AND method_type = 'totp';")
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
print("❌ No backup codes available")
|
|
||||||
return False
|
|
||||||
|
|
||||||
auth_token = verify_response.json()["data"]["token"]
|
|
||||||
print("✅ Logged in with TOTP verification")
|
|
||||||
|
|
||||||
elif test_data.get('backup_codes'):
|
|
||||||
print("ℹ️ Using backup code to authenticate")
|
|
||||||
|
|
||||||
for backup_code in test_data['backup_codes']:
|
|
||||||
verify_response = session.post(
|
|
||||||
f"{BASE_URL}/auth/totp/verify",
|
|
||||||
json={"code": backup_code, "is_backup_code": True}
|
|
||||||
)
|
|
||||||
if verify_response.status_code == 200:
|
|
||||||
auth_token = verify_response.json()["data"]["token"]
|
|
||||||
print(f"✅ Authenticated with backup code: {backup_code}")
|
|
||||||
test_data['backup_codes'].remove(backup_code)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
print("❌ No valid backup codes")
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
print("❌ TOTP enabled but no secret or backup codes available")
|
|
||||||
print("\nPlease manually delete Bob's TOTP from database:")
|
|
||||||
print("DELETE FROM authentication_methods WHERE user_id = (SELECT id FROM users WHERE email = 'bob@acme-corp.com') AND method_type = 'totp';")
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
auth_token = login_data["data"]["token"]
|
|
||||||
print("✅ Logged in (TOTP not required)")
|
|
||||||
|
|
||||||
# ==================== PHASE 2: CHECK STATUS AND DISABLE IF ENABLED ====================
|
|
||||||
|
|
||||||
step += 1
|
|
||||||
print_section(step, "Check TOTP Status")
|
|
||||||
|
|
||||||
status_response = session.get(
|
|
||||||
f"{BASE_URL}/auth/totp/status",
|
|
||||||
headers={"Authorization": f"Bearer {auth_token}"}
|
|
||||||
)
|
|
||||||
|
|
||||||
if status_response.status_code != 200:
|
|
||||||
print("❌ Failed to get TOTP status")
|
|
||||||
return False
|
|
||||||
|
|
||||||
status_data = status_response.json()["data"]
|
|
||||||
print(f"TOTP Enabled: {status_data['totp_enabled']}")
|
|
||||||
print(f"Verified At: {status_data.get('verified_at', 'N/A')}")
|
|
||||||
print(f"Backup Codes Remaining: {status_data['backup_codes_remaining']}")
|
|
||||||
|
|
||||||
# If TOTP is enabled, disable it
|
|
||||||
if status_data['totp_enabled']:
|
|
||||||
step += 1
|
|
||||||
print_section(step, "Disable TOTP")
|
|
||||||
|
|
||||||
disable_response = session.delete(
|
|
||||||
f"{BASE_URL}/auth/totp/disable",
|
|
||||||
headers={"Authorization": f"Bearer {auth_token}"},
|
|
||||||
json={"password": CREDENTIALS["password"]}
|
|
||||||
)
|
|
||||||
|
|
||||||
if disable_response.status_code != 200:
|
|
||||||
print("❌ Failed to disable TOTP")
|
|
||||||
print(json.dumps(disable_response.json(), indent=2))
|
|
||||||
return False
|
|
||||||
|
|
||||||
print("✅ TOTP disabled")
|
|
||||||
|
|
||||||
# Clear saved secret/codes since we're starting fresh
|
|
||||||
test_data['secret'] = None
|
|
||||||
test_data['backup_codes'] = []
|
|
||||||
else:
|
|
||||||
print("ℹ️ TOTP already disabled, skipping disable step")
|
|
||||||
|
|
||||||
# ==================== PHASE 3: LOGOUT AND RE-LOGIN ====================
|
|
||||||
|
|
||||||
step += 1
|
|
||||||
print_section(step, "Logout")
|
|
||||||
|
|
||||||
logout_response = session.post(
|
|
||||||
f"{BASE_URL}/auth/logout",
|
|
||||||
headers={"Authorization": f"Bearer {auth_token}"}
|
|
||||||
)
|
|
||||||
print(f"✅ Logged out (status: {logout_response.status_code})")
|
|
||||||
|
|
||||||
step += 1
|
|
||||||
print_section(step, "Re-login (TOTP should NOT be required)")
|
|
||||||
|
|
||||||
session = requests.Session() # Fresh session
|
|
||||||
login2_response = session.post(f"{BASE_URL}/auth/login", json=CREDENTIALS)
|
|
||||||
|
|
||||||
if login2_response.status_code != 200:
|
|
||||||
print("❌ Re-login failed")
|
|
||||||
return False
|
|
||||||
|
|
||||||
login2_data = login2_response.json()
|
|
||||||
if login2_data.get("data", {}).get("requires_totp"):
|
|
||||||
print("❌ Login still requires TOTP (should not after disabling)")
|
|
||||||
return False
|
|
||||||
|
|
||||||
auth_token = login2_data["data"]["token"]
|
|
||||||
print("✅ Logged in successfully (no TOTP required)")
|
|
||||||
|
|
||||||
# ==================== PHASE 4: ENROLL IN TOTP ====================
|
|
||||||
|
|
||||||
step += 1
|
|
||||||
print_section(step, "Enroll in TOTP")
|
|
||||||
|
|
||||||
enroll_response = session.post(
|
|
||||||
f"{BASE_URL}/auth/totp/enroll",
|
|
||||||
headers={"Authorization": f"Bearer {auth_token}"}
|
|
||||||
)
|
|
||||||
|
|
||||||
if enroll_response.status_code != 201:
|
|
||||||
print(f"❌ Enrollment failed: {enroll_response.status_code}")
|
|
||||||
print(json.dumps(enroll_response.json(), indent=2))
|
|
||||||
return False
|
|
||||||
|
|
||||||
enroll_data = enroll_response.json()["data"]
|
|
||||||
new_secret = enroll_data["secret"]
|
|
||||||
new_backup_codes = enroll_data["backup_codes"]
|
|
||||||
provisioning_uri = enroll_data["provisioning_uri"]
|
|
||||||
qr_code = enroll_data.get("qr_code", "")
|
|
||||||
|
|
||||||
print(f"✅ Enrollment initiated")
|
|
||||||
print(f" Secret: {new_secret}")
|
|
||||||
print(f" Provisioning URI: {provisioning_uri}")
|
|
||||||
print(f" QR Code: {'Present (%d bytes)' % len(qr_code) if qr_code else 'Missing'}")
|
|
||||||
print(f" Backup Codes: {len(new_backup_codes)}")
|
|
||||||
|
|
||||||
# Save for later use
|
|
||||||
test_data['secret'] = new_secret
|
|
||||||
test_data['backup_codes'] = new_backup_codes.copy()
|
|
||||||
|
|
||||||
# ==================== PHASE 5: VERIFY ENROLLMENT ====================
|
|
||||||
|
|
||||||
step += 1
|
|
||||||
print_section(step, "Verify TOTP Enrollment")
|
|
||||||
|
|
||||||
totp = pyotp.TOTP(new_secret)
|
|
||||||
utc_now = datetime.now(timezone.utc)
|
|
||||||
code = totp.at(utc_now)
|
|
||||||
|
|
||||||
print(f"Generated TOTP code: {code}")
|
|
||||||
print(f"At UTC time: {utc_now.isoformat()}")
|
|
||||||
print(f"Timestamp: {utc_now.timestamp()}")
|
|
||||||
|
|
||||||
verify_enrollment_response = session.post(
|
|
||||||
f"{BASE_URL}/auth/totp/verify-enrollment",
|
|
||||||
headers={"Authorization": f"Bearer {auth_token}"},
|
|
||||||
json={"code": code}
|
|
||||||
)
|
|
||||||
|
|
||||||
if verify_enrollment_response.status_code != 200:
|
|
||||||
print(f"❌ Verification failed: {verify_enrollment_response.status_code}")
|
|
||||||
print(json.dumps(verify_enrollment_response.json(), indent=2))
|
|
||||||
return False
|
|
||||||
|
|
||||||
print("✅ TOTP enrollment verified successfully!")
|
|
||||||
|
|
||||||
# ==================== PHASE 6: CONFIRM ENROLLMENT ====================
|
|
||||||
|
|
||||||
step += 1
|
|
||||||
print_section(step, "Confirm TOTP is Enabled")
|
|
||||||
|
|
||||||
final_status_response = session.get(
|
|
||||||
f"{BASE_URL}/auth/totp/status",
|
|
||||||
headers={"Authorization": f"Bearer {auth_token}"}
|
|
||||||
)
|
|
||||||
|
|
||||||
final_status = final_status_response.json()["data"]
|
|
||||||
if not final_status["totp_enabled"]:
|
|
||||||
print("❌ TOTP not enabled after verification!")
|
|
||||||
return False
|
|
||||||
|
|
||||||
print(f"✅ TOTP is enabled")
|
|
||||||
print(f" Verified at: {final_status['verified_at']}")
|
|
||||||
print(f" Backup codes remaining: {final_status['backup_codes_remaining']}")
|
|
||||||
|
|
||||||
# ==================== PHASE 7: TEST LOGIN WITH TOTP ====================
|
|
||||||
|
|
||||||
step += 1
|
|
||||||
print_section(step, "Logout")
|
|
||||||
|
|
||||||
session.post(f"{BASE_URL}/auth/logout", headers={"Authorization": f"Bearer {auth_token}"})
|
|
||||||
print("✅ Logged out")
|
|
||||||
|
|
||||||
step += 1
|
|
||||||
print_section(step, "Login (should REQUIRE TOTP)")
|
|
||||||
|
|
||||||
session2 = requests.Session()
|
|
||||||
login3_response = session2.post(f"{BASE_URL}/auth/login", json=CREDENTIALS)
|
|
||||||
|
|
||||||
if login3_response.status_code != 200:
|
|
||||||
print("❌ Login failed")
|
|
||||||
return False
|
|
||||||
|
|
||||||
login3_data = login3_response.json()
|
|
||||||
if not login3_data.get("data", {}).get("requires_totp"):
|
|
||||||
print("❌ Login did NOT require TOTP (it should!)")
|
|
||||||
return False
|
|
||||||
|
|
||||||
print("✅ Login correctly requires TOTP")
|
|
||||||
|
|
||||||
# ==================== PHASE 8: VERIFY TOTP DURING LOGIN ====================
|
|
||||||
|
|
||||||
step += 1
|
|
||||||
print_section(step, "Verify TOTP Code During Login")
|
|
||||||
|
|
||||||
utc_now = datetime.now(timezone.utc)
|
|
||||||
login_code = totp.at(utc_now)
|
|
||||||
|
|
||||||
print(f"Generated TOTP code: {login_code}")
|
|
||||||
print(f"At UTC time: {utc_now.isoformat()}")
|
|
||||||
|
|
||||||
verify_login_response = session2.post(
|
|
||||||
f"{BASE_URL}/auth/totp/verify",
|
|
||||||
json={"code": login_code}
|
|
||||||
)
|
|
||||||
|
|
||||||
if verify_login_response.status_code != 200:
|
|
||||||
print(f"❌ TOTP login verification failed: {verify_login_response.status_code}")
|
|
||||||
print(json.dumps(verify_login_response.json(), indent=2))
|
|
||||||
return False
|
|
||||||
|
|
||||||
final_token = verify_login_response.json()["data"]["token"]
|
|
||||||
print("✅ Successfully logged in with TOTP!")
|
|
||||||
print(f" Token: {final_token[:30]}...")
|
|
||||||
|
|
||||||
# ==================== PHASE 9: TEST /auth/me ====================
|
|
||||||
|
|
||||||
step += 1
|
|
||||||
print_section(step, "Confirm Logged In (/auth/me)")
|
|
||||||
|
|
||||||
me_response = session2.get(
|
|
||||||
f"{BASE_URL}/auth/me",
|
|
||||||
headers={"Authorization": f"Bearer {final_token}"}
|
|
||||||
)
|
|
||||||
|
|
||||||
if me_response.status_code != 200:
|
|
||||||
print("❌ /auth/me failed")
|
|
||||||
return False
|
|
||||||
|
|
||||||
me_data = me_response.json()["data"]
|
|
||||||
print(f"✅ Confirmed logged in as: {me_data['user']['email']}")
|
|
||||||
print(f" User ID: {me_data['user']['id']}")
|
|
||||||
|
|
||||||
# ==================== PHASE 10: TEST BACKUP CODE ====================
|
|
||||||
|
|
||||||
step += 1
|
|
||||||
print_section(step, "Test Backup Code Login")
|
|
||||||
|
|
||||||
# Logout
|
|
||||||
session2.post(f"{BASE_URL}/auth/logout", headers={"Authorization": f"Bearer {final_token}"})
|
|
||||||
|
|
||||||
# Fresh login
|
|
||||||
session3 = requests.Session()
|
|
||||||
login4_response = session3.post(f"{BASE_URL}/auth/login", json=CREDENTIALS)
|
|
||||||
|
|
||||||
if not login4_response.json().get("data", {}).get("requires_totp"):
|
|
||||||
print("❌ Login should require TOTP")
|
|
||||||
return False
|
|
||||||
|
|
||||||
print(f"ℹ️ Using backup code: {test_data['backup_codes'][0]}")
|
|
||||||
|
|
||||||
backup_verify_response = session3.post(
|
|
||||||
f"{BASE_URL}/auth/totp/verify",
|
|
||||||
json={"code": test_data['backup_codes'][0], "is_backup_code": True}
|
|
||||||
)
|
|
||||||
|
|
||||||
if backup_verify_response.status_code != 200:
|
|
||||||
print("❌ Backup code login failed")
|
|
||||||
print(json.dumps(backup_verify_response.json(), indent=2))
|
|
||||||
return False
|
|
||||||
|
|
||||||
backup_token = backup_verify_response.json()["data"]["token"]
|
|
||||||
print(f"✅ Logged in with backup code!")
|
|
||||||
|
|
||||||
# Remove used code
|
|
||||||
used_code = test_data['backup_codes'].pop(0)
|
|
||||||
|
|
||||||
# ==================== PHASE 11: CHECK BACKUP CODES REMAINING ====================
|
|
||||||
|
|
||||||
step += 1
|
|
||||||
print_section(step, "Check Backup Codes Remaining")
|
|
||||||
|
|
||||||
status3_response = session3.get(
|
|
||||||
f"{BASE_URL}/auth/totp/status",
|
|
||||||
headers={"Authorization": f"Bearer {backup_token}"}
|
|
||||||
)
|
|
||||||
|
|
||||||
status3_data = status3_response.json()["data"]
|
|
||||||
if status3_data['backup_codes_remaining'] != 9:
|
|
||||||
print(f"❌ Expected 9 backup codes, got {status3_data['backup_codes_remaining']}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
print(f"✅ Backup codes remaining: {status3_data['backup_codes_remaining']} (was 10, now 9)")
|
|
||||||
|
|
||||||
# ==================== PHASE 12: REGENERATE BACKUP CODES ====================
|
|
||||||
|
|
||||||
step += 1
|
|
||||||
print_section(step, "Regenerate Backup Codes")
|
|
||||||
|
|
||||||
regen_response = session3.post(
|
|
||||||
f"{BASE_URL}/auth/totp/regenerate-backup-codes",
|
|
||||||
headers={"Authorization": f"Bearer {backup_token}"},
|
|
||||||
json={"password": CREDENTIALS["password"]}
|
|
||||||
)
|
|
||||||
|
|
||||||
if regen_response.status_code != 200:
|
|
||||||
print("❌ Failed to regenerate backup codes")
|
|
||||||
print(json.dumps(regen_response.json(), indent=2))
|
|
||||||
return False
|
|
||||||
|
|
||||||
regenerated_codes = regen_response.json()["data"]["backup_codes"]
|
|
||||||
print(f"✅ Regenerated {len(regenerated_codes)} backup codes")
|
|
||||||
|
|
||||||
# Update saved codes
|
|
||||||
test_data['backup_codes'] = regenerated_codes.copy()
|
|
||||||
|
|
||||||
# ==================== SUCCESS ====================
|
|
||||||
|
|
||||||
save_test_data()
|
|
||||||
|
|
||||||
print("\n" + "="*70)
|
|
||||||
print("🎉 ALL TESTS PASSED!")
|
|
||||||
print("="*70)
|
|
||||||
|
|
||||||
print("\n✅ TEST SUMMARY:")
|
|
||||||
print(f" 1. ✅ Initial login (with/without TOTP)")
|
|
||||||
print(f" 2. ✅ Check TOTP status")
|
|
||||||
print(f" 3. ✅ Disable TOTP")
|
|
||||||
print(f" 4. ✅ Logout")
|
|
||||||
print(f" 5. ✅ Re-login without TOTP")
|
|
||||||
print(f" 6. ✅ Enroll in TOTP")
|
|
||||||
print(f" 7. ✅ Verify enrollment")
|
|
||||||
print(f" 8. ✅ Confirm TOTP enabled")
|
|
||||||
print(f" 9. ✅ Logout")
|
|
||||||
print(f" 10. ✅ Login with TOTP required")
|
|
||||||
print(f" 11. ✅ Verify TOTP during login")
|
|
||||||
print(f" 12. ✅ Confirm logged in (/auth/me)")
|
|
||||||
print(f" 13. ✅ Login with backup code")
|
|
||||||
print(f" 14. ✅ Check backup codes decremented")
|
|
||||||
print(f" 15. ✅ Regenerate backup codes")
|
|
||||||
|
|
||||||
print(f"\n📱 Current TOTP Secret:")
|
|
||||||
print(f" {test_data['secret']}")
|
|
||||||
|
|
||||||
print(f"\n🔑 Current Backup Codes ({len(test_data['backup_codes'])}):")
|
|
||||||
for i, code in enumerate(test_data['backup_codes'], 1):
|
|
||||||
print(f" {i:2d}. {code}")
|
|
||||||
|
|
||||||
print("\n" + "="*70)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except requests.exceptions.ConnectionError:
|
|
||||||
print(f"\n❌ CONNECTION ERROR - Server not running at {BASE_URL}")
|
|
||||||
return False
|
|
||||||
except KeyError as e:
|
|
||||||
print(f"\n❌ UNEXPECTED RESPONSE STRUCTURE: Missing key {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n❌ UNEXPECTED ERROR: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
return False
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
success = main()
|
|
||||||
sys.exit(0 if success else 1)
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
"""Tests package."""
|
|
||||||
@@ -1,375 +0,0 @@
|
|||||||
"""Pytest configuration and fixtures."""
|
|
||||||
import pytest
|
|
||||||
from unittest.mock import Mock, patch
|
|
||||||
from datetime import datetime, timedelta, timezone
|
|
||||||
|
|
||||||
from gatehouse_app import create_app
|
|
||||||
from gatehouse_app.extensions import db as _db
|
|
||||||
from gatehouse_app.models import User, Organization, OrganizationMember, AuthenticationMethod
|
|
||||||
from gatehouse_app.services.auth_service import AuthService
|
|
||||||
from gatehouse_app.utils.constants import OrganizationRole, AuthMethodType
|
|
||||||
from gatehouse_app.services.external_auth_service import ExternalProviderConfig, OAuthState
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def app():
|
|
||||||
"""Create application for testing."""
|
|
||||||
app = create_app("testing")
|
|
||||||
return app
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
def db(app):
|
|
||||||
"""Create database for testing."""
|
|
||||||
with app.app_context():
|
|
||||||
_db.create_all()
|
|
||||||
yield _db
|
|
||||||
_db.session.remove()
|
|
||||||
_db.drop_all()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
def client(app, db):
|
|
||||||
"""Create test client."""
|
|
||||||
return app.test_client()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
def test_user(db):
|
|
||||||
"""Create a test user."""
|
|
||||||
email = "test@example.com"
|
|
||||||
password = "TestPassword123!"
|
|
||||||
full_name = "Test User"
|
|
||||||
|
|
||||||
user = AuthService.register_user(
|
|
||||||
email=email,
|
|
||||||
password=password,
|
|
||||||
full_name=full_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store password for testing
|
|
||||||
user._test_password = password
|
|
||||||
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
def test_organization(db, test_user):
|
|
||||||
"""Create a test organization."""
|
|
||||||
from gatehouse_app.services.organization_service import OrganizationService
|
|
||||||
|
|
||||||
org = OrganizationService.create_organization(
|
|
||||||
name="Test Organization",
|
|
||||||
slug="test-org",
|
|
||||||
owner_user_id=test_user.id,
|
|
||||||
description="A test organization",
|
|
||||||
)
|
|
||||||
|
|
||||||
return org
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
def authenticated_client(client, test_user):
|
|
||||||
"""Create authenticated test client."""
|
|
||||||
# Login
|
|
||||||
response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
json={
|
|
||||||
"email": test_user.email,
|
|
||||||
"password": test_user._test_password,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
return client
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
def second_test_user(db):
|
|
||||||
"""Create a second test user."""
|
|
||||||
email = "second@example.com"
|
|
||||||
password = "TestPassword123!"
|
|
||||||
full_name = "Second User"
|
|
||||||
|
|
||||||
user = AuthService.register_user(
|
|
||||||
email=email,
|
|
||||||
password=password,
|
|
||||||
full_name=full_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
user._test_password = password
|
|
||||||
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# External Auth Testing Fixtures
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
def google_provider_config(db, test_organization):
|
|
||||||
"""Create a Google OAuth provider configuration."""
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-google-client-id",
|
|
||||||
client_secret_encrypted="encrypted-google-secret",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=[
|
|
||||||
"http://localhost:3000/callback",
|
|
||||||
"http://localhost:5173/callback",
|
|
||||||
"https://myapp.example.com/callback",
|
|
||||||
],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
def github_provider_config(db, test_organization):
|
|
||||||
"""Create a GitHub OAuth provider configuration."""
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GITHUB.value,
|
|
||||||
client_id="test-github-client-id",
|
|
||||||
client_secret_encrypted="encrypted-github-secret",
|
|
||||||
auth_url="https://github.com/login/oauth/authorize",
|
|
||||||
token_url="https://github.com/login/oauth/access_token",
|
|
||||||
userinfo_url="https://api.github.com/user",
|
|
||||||
scopes=["read:user", "user:email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
def microsoft_provider_config(db, test_organization):
|
|
||||||
"""Create a Microsoft OAuth provider configuration."""
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.MICROSOFT.value,
|
|
||||||
client_id="test-microsoft-client-id",
|
|
||||||
client_secret_encrypted="encrypted-microsoft-secret",
|
|
||||||
auth_url="https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
|
|
||||||
token_url="https://login.microsoftonline.com/common/oauth2/v2.0/token",
|
|
||||||
userinfo_url="https://graph.microsoft.com/oidc/userinfo",
|
|
||||||
scopes=["openid", "profile", "email", "User.Read"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
def user_with_google_link(db, test_user):
|
|
||||||
"""Create a test user with a linked Google account."""
|
|
||||||
auth_method = AuthenticationMethod(
|
|
||||||
user_id=test_user.id,
|
|
||||||
method_type=AuthMethodType.GOOGLE,
|
|
||||||
provider_user_id="google-123456789",
|
|
||||||
provider_data={
|
|
||||||
"email": test_user.email,
|
|
||||||
"name": "Test User",
|
|
||||||
"picture": "https://example.com/avatar.jpg",
|
|
||||||
},
|
|
||||||
verified=True,
|
|
||||||
is_primary=False,
|
|
||||||
)
|
|
||||||
auth_method.save()
|
|
||||||
return test_user
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
def user_with_multiple_providers(db, test_user):
|
|
||||||
"""Create a test user with multiple linked external accounts."""
|
|
||||||
# Google account
|
|
||||||
google_method = AuthenticationMethod(
|
|
||||||
user_id=test_user.id,
|
|
||||||
method_type=AuthMethodType.GOOGLE,
|
|
||||||
provider_user_id="google-123",
|
|
||||||
provider_data={
|
|
||||||
"email": test_user.email,
|
|
||||||
"name": "Test User",
|
|
||||||
},
|
|
||||||
verified=True,
|
|
||||||
)
|
|
||||||
google_method.save()
|
|
||||||
|
|
||||||
# GitHub account
|
|
||||||
github_method = AuthenticationMethod(
|
|
||||||
user_id=test_user.id,
|
|
||||||
method_type=AuthMethodType.GITHUB,
|
|
||||||
provider_user_id="github-456",
|
|
||||||
provider_data={
|
|
||||||
"email": "user@github.com",
|
|
||||||
"name": "Test User",
|
|
||||||
},
|
|
||||||
verified=True,
|
|
||||||
)
|
|
||||||
github_method.save()
|
|
||||||
|
|
||||||
return test_user
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_google_oauth_token_response():
|
|
||||||
"""Mock Google OAuth token response."""
|
|
||||||
return {
|
|
||||||
"access_token": "ya29.mock-access-token",
|
|
||||||
"refresh_token": "1//mock-refresh-token",
|
|
||||||
"id_token": "eyJ.mock-id-token",
|
|
||||||
"token_type": "Bearer",
|
|
||||||
"expires_in": 3600,
|
|
||||||
"scope": "openid profile email",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_google_oauth_user_info():
|
|
||||||
"""Mock Google OAuth user info response."""
|
|
||||||
return {
|
|
||||||
"sub": "google-123456789",
|
|
||||||
"name": "Test User",
|
|
||||||
"given_name": "Test",
|
|
||||||
"family_name": "User",
|
|
||||||
"picture": "https://example.com/avatar.jpg",
|
|
||||||
"email": "testuser@gmail.com",
|
|
||||||
"email_verified": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_github_oauth_token_response():
|
|
||||||
"""Mock GitHub OAuth token response."""
|
|
||||||
return {
|
|
||||||
"access_token": "gho_mock-access-token",
|
|
||||||
"token_type": "bearer",
|
|
||||||
"scope": "read:user,user:email",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_github_oauth_user_info():
|
|
||||||
"""Mock GitHub OAuth user info response."""
|
|
||||||
return {
|
|
||||||
"id": 123456789,
|
|
||||||
"login": "testuser",
|
|
||||||
"name": "Test User",
|
|
||||||
"email": "testuser@github.com",
|
|
||||||
"avatar_url": "https://example.com/avatar.jpg",
|
|
||||||
"type": "User",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def oauth_login_state(db, test_organization):
|
|
||||||
"""Create an OAuth state for login flow."""
|
|
||||||
state = OAuthState.create_state(
|
|
||||||
flow_type="login",
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
nonce="mock-nonce",
|
|
||||||
code_verifier="mock-code-verifier",
|
|
||||||
code_challenge="mock-code-challenge",
|
|
||||||
lifetime_seconds=600,
|
|
||||||
)
|
|
||||||
return state
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def oauth_register_state(db, test_organization):
|
|
||||||
"""Create an OAuth state for register flow."""
|
|
||||||
state = OAuthState.create_state(
|
|
||||||
flow_type="register",
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
lifetime_seconds=600,
|
|
||||||
)
|
|
||||||
return state
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def oauth_link_state(db, test_user, test_organization):
|
|
||||||
"""Create an OAuth state for link flow."""
|
|
||||||
state = OAuthState.create_state(
|
|
||||||
flow_type="link",
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
lifetime_seconds=600,
|
|
||||||
)
|
|
||||||
return state
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def expired_oauth_state(db, test_organization):
|
|
||||||
"""Create an expired OAuth state."""
|
|
||||||
state = OAuthState.create_state(
|
|
||||||
flow_type="login",
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
lifetime_seconds=-1, # Already expired
|
|
||||||
)
|
|
||||||
return state
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def used_oauth_state(db, test_organization):
|
|
||||||
"""Create a used OAuth state."""
|
|
||||||
state = OAuthState.create_state(
|
|
||||||
flow_type="login",
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
lifetime_seconds=600,
|
|
||||||
)
|
|
||||||
state.mark_used()
|
|
||||||
return state
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_oauth_flow_mocks():
|
|
||||||
"""Common mocks for OAuth flow tests."""
|
|
||||||
with patch.object(
|
|
||||||
ExternalProviderConfig, 'get_client_secret', return_value='mock-secret'
|
|
||||||
) as mock_get_secret, patch(
|
|
||||||
'requests.post'
|
|
||||||
) as mock_post, patch(
|
|
||||||
'requests.get'
|
|
||||||
) as mock_get:
|
|
||||||
# Mock token exchange response
|
|
||||||
mock_post.return_value.json.return_value = {
|
|
||||||
"access_token": "mock-access-token",
|
|
||||||
"refresh_token": "mock-refresh-token",
|
|
||||||
"id_token": "mock-id-token",
|
|
||||||
"expires_in": 3600,
|
|
||||||
}
|
|
||||||
mock_post.return_value.raise_for_status = Mock()
|
|
||||||
|
|
||||||
# Mock user info response
|
|
||||||
mock_get.return_value.json.return_value = {
|
|
||||||
"sub": "google-123",
|
|
||||||
"email": "testuser@gmail.com",
|
|
||||||
"email_verified": True,
|
|
||||||
"name": "Test User",
|
|
||||||
"picture": "https://example.com/avatar.jpg",
|
|
||||||
}
|
|
||||||
mock_get.return_value.raise_for_status = Mock()
|
|
||||||
|
|
||||||
yield {
|
|
||||||
'get_secret': mock_get_secret,
|
|
||||||
'post': mock_post,
|
|
||||||
'get': mock_get,
|
|
||||||
}
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
"""Integration tests package."""
|
|
||||||
@@ -1,107 +0,0 @@
|
|||||||
"""Integration tests for authentication flow."""
|
|
||||||
import pytest
|
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
|
||||||
class TestAuthFlow:
|
|
||||||
"""Integration tests for authentication endpoints."""
|
|
||||||
|
|
||||||
def test_register_login_logout_flow(self, client, db):
|
|
||||||
"""Test complete registration, login, and logout flow."""
|
|
||||||
# Register
|
|
||||||
register_data = {
|
|
||||||
"email": "integration@example.com",
|
|
||||||
"password": "TestPassword123!",
|
|
||||||
"password_confirm": "TestPassword123!",
|
|
||||||
"full_name": "Integration Test",
|
|
||||||
}
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/api/v1/auth/register",
|
|
||||||
data=json.dumps(register_data),
|
|
||||||
content_type="application/json",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 201
|
|
||||||
data = response.get_json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert "user" in data["data"]
|
|
||||||
assert data["data"]["user"]["email"] == "integration@example.com"
|
|
||||||
|
|
||||||
# Logout
|
|
||||||
response = client.post("/api/v1/auth/logout")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
# Login
|
|
||||||
login_data = {
|
|
||||||
"email": "integration@example.com",
|
|
||||||
"password": "TestPassword123!",
|
|
||||||
}
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
data=json.dumps(login_data),
|
|
||||||
content_type="application/json",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.get_json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert "user" in data["data"]
|
|
||||||
|
|
||||||
# Logout again
|
|
||||||
response = client.post("/api/v1/auth/logout")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
def test_get_current_user_authenticated(self, authenticated_client):
|
|
||||||
"""Test getting current user when authenticated."""
|
|
||||||
response = authenticated_client.get("/api/v1/auth/me")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.get_json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert "user" in data["data"]
|
|
||||||
|
|
||||||
def test_get_current_user_unauthenticated(self, client):
|
|
||||||
"""Test getting current user when not authenticated."""
|
|
||||||
response = client.get("/api/v1/auth/me")
|
|
||||||
|
|
||||||
assert response.status_code == 401
|
|
||||||
data = response.get_json()
|
|
||||||
assert data["success"] is False
|
|
||||||
|
|
||||||
def test_invalid_credentials(self, client, test_user):
|
|
||||||
"""Test login with invalid credentials."""
|
|
||||||
login_data = {
|
|
||||||
"email": test_user.email,
|
|
||||||
"password": "WrongPassword123!",
|
|
||||||
}
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
data=json.dumps(login_data),
|
|
||||||
content_type="application/json",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 401
|
|
||||||
data = response.get_json()
|
|
||||||
assert data["success"] is False
|
|
||||||
|
|
||||||
def test_duplicate_registration(self, client, test_user):
|
|
||||||
"""Test registering with existing email."""
|
|
||||||
register_data = {
|
|
||||||
"email": test_user.email,
|
|
||||||
"password": "TestPassword123!",
|
|
||||||
"password_confirm": "TestPassword123!",
|
|
||||||
}
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/api/v1/auth/register",
|
|
||||||
data=json.dumps(register_data),
|
|
||||||
content_type="application/json",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 409
|
|
||||||
data = response.get_json()
|
|
||||||
assert data["success"] is False
|
|
||||||
@@ -1,696 +0,0 @@
|
|||||||
"""Integration tests for external authentication API flows."""
|
|
||||||
import pytest
|
|
||||||
import json
|
|
||||||
from unittest.mock import patch, Mock
|
|
||||||
|
|
||||||
from gatehouse_app.services.external_auth_service import (
|
|
||||||
ExternalAuthService,
|
|
||||||
ExternalProviderConfig,
|
|
||||||
OAuthState,
|
|
||||||
)
|
|
||||||
from gatehouse_app.services.audit_service import AuditService
|
|
||||||
from gatehouse_app.utils.constants import AuthMethodType, OrganizationRole
|
|
||||||
from gatehouse_app.models import User, AuthenticationMethod, OrganizationMember
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
|
||||||
class TestExternalAuthApiFlows:
|
|
||||||
"""Integration tests for external auth API flows."""
|
|
||||||
|
|
||||||
def test_complete_account_linking_flow(
|
|
||||||
self, app, db, client, test_user, test_organization
|
|
||||||
):
|
|
||||||
"""Test complete account linking flow: initiate → callback → complete."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
client_secret_encrypted="encrypted-secret",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
# Create organization membership
|
|
||||||
member = OrganizationMember(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
role=OrganizationRole.MEMBER,
|
|
||||||
)
|
|
||||||
member.save()
|
|
||||||
|
|
||||||
# Login to get token
|
|
||||||
login_response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
json={
|
|
||||||
"email": test_user.email,
|
|
||||||
"password": test_user._test_password,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert login_response.status_code == 200
|
|
||||||
token = login_response.get_json()["data"]["token"]
|
|
||||||
|
|
||||||
with patch.object(
|
|
||||||
ExternalAuthService, '_exchange_code'
|
|
||||||
) as mock_exchange, patch.object(
|
|
||||||
ExternalAuthService, '_get_user_info'
|
|
||||||
) as mock_get_user_info:
|
|
||||||
# Mock external provider responses
|
|
||||||
mock_exchange.return_value = {
|
|
||||||
"access_token": "mock-access-token",
|
|
||||||
"refresh_token": "mock-refresh-token",
|
|
||||||
"id_token": "mock-id-token",
|
|
||||||
"expires_in": 3600,
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_get_user_info.return_value = {
|
|
||||||
"provider_user_id": "google-123",
|
|
||||||
"email": "user@gmail.com",
|
|
||||||
"email_verified": True,
|
|
||||||
"name": "Test User",
|
|
||||||
"picture": "https://example.com/avatar.jpg",
|
|
||||||
"raw_data": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
# Step 1: Initiate link flow
|
|
||||||
initiate_response = client.post(
|
|
||||||
"/api/v1/auth/external/google/link",
|
|
||||||
json={},
|
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
|
||||||
)
|
|
||||||
assert initiate_response.status_code == 200
|
|
||||||
initiate_data = initiate_response.get_json()
|
|
||||||
assert "authorization_url" in initiate_data["data"]
|
|
||||||
assert "state" in initiate_data["data"]
|
|
||||||
state = initiate_data["data"]["state"]
|
|
||||||
|
|
||||||
# Step 2: Simulate callback (complete link flow)
|
|
||||||
with patch.object(AuditService, 'log_external_auth_link_completed'):
|
|
||||||
complete_response = client.get(
|
|
||||||
f"/api/v1/auth/external/google/callback",
|
|
||||||
query_string={
|
|
||||||
"code": "mock-auth-code",
|
|
||||||
"state": state,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# The callback returns 200 on success
|
|
||||||
assert complete_response.status_code == 200
|
|
||||||
|
|
||||||
# Verify account is linked
|
|
||||||
auth_method = AuthenticationMethod.query.filter_by(
|
|
||||||
user_id=test_user.id,
|
|
||||||
method_type=AuthMethodType.GOOGLE,
|
|
||||||
provider_user_id="google-123",
|
|
||||||
).first()
|
|
||||||
assert auth_method is not None
|
|
||||||
|
|
||||||
def test_complete_login_flow(
|
|
||||||
self, app, db, client, test_user, test_organization
|
|
||||||
):
|
|
||||||
"""Test complete login flow: initiate → callback → authenticate."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
client_secret_encrypted="encrypted-secret",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
# Create authentication method for user
|
|
||||||
auth_method = AuthenticationMethod(
|
|
||||||
user_id=test_user.id,
|
|
||||||
method_type=AuthMethodType.GOOGLE,
|
|
||||||
provider_user_id="google-123",
|
|
||||||
provider_data={"email": test_user.email},
|
|
||||||
verified=True,
|
|
||||||
)
|
|
||||||
auth_method.save()
|
|
||||||
|
|
||||||
with patch.object(
|
|
||||||
ExternalAuthService, '_exchange_code'
|
|
||||||
) as mock_exchange, patch.object(
|
|
||||||
ExternalAuthService, '_get_user_info'
|
|
||||||
) as mock_get_user_info:
|
|
||||||
# Mock external provider responses
|
|
||||||
mock_exchange.return_value = {
|
|
||||||
"access_token": "mock-access-token",
|
|
||||||
"refresh_token": "mock-refresh-token",
|
|
||||||
"id_token": "mock-id-token",
|
|
||||||
"expires_in": 3600,
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_get_user_info.return_value = {
|
|
||||||
"provider_user_id": "google-123",
|
|
||||||
"email": test_user.email,
|
|
||||||
"email_verified": True,
|
|
||||||
"name": "Test User",
|
|
||||||
"picture": "https://example.com/avatar.jpg",
|
|
||||||
"raw_data": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
# Initiate login flow
|
|
||||||
login_init_response = client.get(
|
|
||||||
"/api/v1/auth/external/google/authorize",
|
|
||||||
query_string={"flow": "login"},
|
|
||||||
)
|
|
||||||
assert login_init_response.status_code == 200
|
|
||||||
login_init_data = login_init_response.get_json()
|
|
||||||
assert "authorization_url" in login_init_data["data"]
|
|
||||||
state = login_init_data["data"]["state"]
|
|
||||||
|
|
||||||
# Simulate callback
|
|
||||||
callback_response = client.get(
|
|
||||||
f"/api/v1/auth/external/google/callback",
|
|
||||||
query_string={
|
|
||||||
"code": "mock-auth-code",
|
|
||||||
"state": state,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert callback_response.status_code == 200
|
|
||||||
callback_data = callback_response.get_json()
|
|
||||||
|
|
||||||
assert callback_data["success"] is True
|
|
||||||
assert callback_data["flow_type"] == "login"
|
|
||||||
assert "token" in callback_data["data"]
|
|
||||||
assert callback_data["data"]["user"]["id"] == test_user.id
|
|
||||||
|
|
||||||
def test_account_unlinking_flow(
|
|
||||||
self, app, db, client, test_user, test_organization
|
|
||||||
):
|
|
||||||
"""Test account unlinking flow."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
# Create organization membership
|
|
||||||
member = OrganizationMember(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
role=OrganizationRole.MEMBER,
|
|
||||||
)
|
|
||||||
member.save()
|
|
||||||
|
|
||||||
# Create password auth method
|
|
||||||
password_method = AuthenticationMethod(
|
|
||||||
user_id=test_user.id,
|
|
||||||
method_type=AuthMethodType.PASSWORD,
|
|
||||||
provider_user_id=test_user.id,
|
|
||||||
)
|
|
||||||
password_method.save()
|
|
||||||
|
|
||||||
# Create Google auth method
|
|
||||||
google_method = AuthenticationMethod(
|
|
||||||
user_id=test_user.id,
|
|
||||||
method_type=AuthMethodType.GOOGLE,
|
|
||||||
provider_user_id="google-123",
|
|
||||||
provider_data={"email": test_user.email},
|
|
||||||
verified=True,
|
|
||||||
)
|
|
||||||
google_method.save()
|
|
||||||
|
|
||||||
# Login to get token
|
|
||||||
login_response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
json={
|
|
||||||
"email": test_user.email,
|
|
||||||
"password": test_user._test_password,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
token = login_response.get_json()["data"]["token"]
|
|
||||||
|
|
||||||
# Unlink Google account
|
|
||||||
with patch.object(AuditService, 'log_external_auth_unlink'):
|
|
||||||
unlink_response = client.delete(
|
|
||||||
"/api/v1/auth/external/google/unlink",
|
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
|
||||||
)
|
|
||||||
assert unlink_response.status_code == 200
|
|
||||||
unlink_data = unlink_response.get_json()
|
|
||||||
assert "success" in unlink_data or "message" in unlink_data
|
|
||||||
|
|
||||||
# Verify account is unlinked
|
|
||||||
auth_method = AuthenticationMethod.query.filter_by(
|
|
||||||
user_id=test_user.id,
|
|
||||||
method_type=AuthMethodType.GOOGLE,
|
|
||||||
).first()
|
|
||||||
assert auth_method is None
|
|
||||||
|
|
||||||
def test_provider_configuration_crud(
|
|
||||||
self, app, db, client, test_user, test_organization
|
|
||||||
):
|
|
||||||
"""Test provider configuration CRUD operations."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create organization membership as admin
|
|
||||||
member = OrganizationMember(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
role=OrganizationRole.ADMIN,
|
|
||||||
)
|
|
||||||
member.save()
|
|
||||||
|
|
||||||
# Login to get token
|
|
||||||
login_response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
json={
|
|
||||||
"email": test_user.email,
|
|
||||||
"password": test_user._test_password,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
token = login_response.get_json()["data"]["token"]
|
|
||||||
|
|
||||||
# Step 1: Create provider config
|
|
||||||
with patch.object(AuditService, 'log_external_auth_config_create'):
|
|
||||||
create_response = client.post(
|
|
||||||
"/api/v1/auth/external/google/config",
|
|
||||||
json={
|
|
||||||
"client_id": "new-client-id",
|
|
||||||
"client_secret": "new-client-secret",
|
|
||||||
"scopes": ["openid", "profile", "email"],
|
|
||||||
"redirect_uris": ["http://localhost:3000/callback"],
|
|
||||||
},
|
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
|
||||||
)
|
|
||||||
assert create_response.status_code == 201
|
|
||||||
create_data = create_response.get_json()
|
|
||||||
assert create_data["data"]["provider_type"] == "google"
|
|
||||||
assert create_data["data"]["client_id"] == "new-client-id"
|
|
||||||
|
|
||||||
config_id = create_data["data"]["id"]
|
|
||||||
|
|
||||||
# Step 2: List providers
|
|
||||||
list_response = client.get(
|
|
||||||
"/api/v1/auth/external/providers",
|
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
|
||||||
)
|
|
||||||
assert list_response.status_code == 200
|
|
||||||
list_data = list_response.get_json()
|
|
||||||
google_provider = next(
|
|
||||||
p for p in list_data["data"]["providers"] if p["id"] == "google"
|
|
||||||
)
|
|
||||||
assert google_provider["is_configured"] is True
|
|
||||||
|
|
||||||
# Step 3: Get provider config
|
|
||||||
get_response = client.get(
|
|
||||||
"/api/v1/auth/external/google/config",
|
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
|
||||||
)
|
|
||||||
assert get_response.status_code == 200
|
|
||||||
get_data = get_response.get_json()
|
|
||||||
assert get_data["data"]["client_id"] == "new-client-id"
|
|
||||||
|
|
||||||
# Step 4: Update provider config
|
|
||||||
with patch.object(AuditService, 'log_external_auth_config_update'):
|
|
||||||
update_response = client.post(
|
|
||||||
"/api/v1/auth/external/google/config",
|
|
||||||
json={
|
|
||||||
"client_id": "updated-client-id",
|
|
||||||
"client_secret": "updated-client-secret",
|
|
||||||
},
|
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
|
||||||
)
|
|
||||||
assert update_response.status_code == 200
|
|
||||||
update_data = update_response.get_json()
|
|
||||||
assert update_data["data"]["client_id"] == "updated-client-id"
|
|
||||||
|
|
||||||
# Step 5: Delete provider config
|
|
||||||
with patch.object(AuditService, 'log_external_auth_config_delete'):
|
|
||||||
delete_response = client.delete(
|
|
||||||
"/api/v1/auth/external/google/config",
|
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
|
||||||
)
|
|
||||||
assert delete_response.status_code == 200
|
|
||||||
|
|
||||||
# Verify deletion
|
|
||||||
get_deleted_response = client.get(
|
|
||||||
"/api/v1/auth/external/google/config",
|
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
|
||||||
)
|
|
||||||
assert get_deleted_response.status_code == 404
|
|
||||||
|
|
||||||
def test_invalid_state_error(self, app, db, client, test_user, test_organization):
|
|
||||||
"""Test error handling for invalid OAuth state."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
# Try callback with invalid state
|
|
||||||
callback_response = client.get(
|
|
||||||
"/api/v1/auth/external/google/callback",
|
|
||||||
query_string={
|
|
||||||
"code": "mock-auth-code",
|
|
||||||
"state": "invalid-state",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert callback_response.status_code == 400
|
|
||||||
callback_data = callback_response.get_json()
|
|
||||||
assert callback_data["error_type"] == "INVALID_STATE"
|
|
||||||
|
|
||||||
def test_expired_state_error(self, app, db, client, test_user, test_organization):
|
|
||||||
"""Test error handling for expired OAuth state."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
# Create expired state
|
|
||||||
state = OAuthState.create_state(
|
|
||||||
flow_type="login",
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
lifetime_seconds=-1, # Already expired
|
|
||||||
)
|
|
||||||
|
|
||||||
# Try callback with expired state
|
|
||||||
callback_response = client.get(
|
|
||||||
"/api/v1/auth/external/google/callback",
|
|
||||||
query_string={
|
|
||||||
"code": "mock-auth-code",
|
|
||||||
"state": state.state,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert callback_response.status_code == 400
|
|
||||||
callback_data = callback_response.get_json()
|
|
||||||
assert callback_data["error_type"] == "INVALID_STATE"
|
|
||||||
|
|
||||||
def test_provider_not_configured_error(
|
|
||||||
self, app, db, client, test_user, test_organization
|
|
||||||
):
|
|
||||||
"""Test error handling when provider is not configured."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create organization membership
|
|
||||||
member = OrganizationMember(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
role=OrganizationRole.MEMBER,
|
|
||||||
)
|
|
||||||
member.save()
|
|
||||||
|
|
||||||
# Login to get token
|
|
||||||
login_response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
json={
|
|
||||||
"email": test_user.email,
|
|
||||||
"password": test_user._test_password,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
token = login_response.get_json()["data"]["token"]
|
|
||||||
|
|
||||||
# Try to link with unconfigured provider
|
|
||||||
link_response = client.post(
|
|
||||||
"/api/v1/auth/external/google/link",
|
|
||||||
json={},
|
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
|
||||||
)
|
|
||||||
assert link_response.status_code == 400
|
|
||||||
link_data = link_response.get_json()
|
|
||||||
assert link_data["error_type"] == "PROVIDER_NOT_CONFIGURED"
|
|
||||||
|
|
||||||
def test_linked_accounts_list(self, app, db, client, test_user, test_organization):
|
|
||||||
"""Test listing linked accounts."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create organization membership
|
|
||||||
member = OrganizationMember(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
role=OrganizationRole.MEMBER,
|
|
||||||
)
|
|
||||||
member.save()
|
|
||||||
|
|
||||||
# Create authentication methods
|
|
||||||
google_method = AuthenticationMethod(
|
|
||||||
user_id=test_user.id,
|
|
||||||
method_type=AuthMethodType.GOOGLE,
|
|
||||||
provider_user_id="google-123",
|
|
||||||
provider_data={
|
|
||||||
"email": test_user.email,
|
|
||||||
"name": "Test User",
|
|
||||||
"picture": "https://example.com/avatar.jpg",
|
|
||||||
},
|
|
||||||
verified=True,
|
|
||||||
)
|
|
||||||
google_method.save()
|
|
||||||
|
|
||||||
github_method = AuthenticationMethod(
|
|
||||||
user_id=test_user.id,
|
|
||||||
method_type=AuthMethodType.GITHUB,
|
|
||||||
provider_user_id="github-456",
|
|
||||||
provider_data={
|
|
||||||
"email": "user@github.com",
|
|
||||||
"name": "Test User",
|
|
||||||
},
|
|
||||||
verified=True,
|
|
||||||
)
|
|
||||||
github_method.save()
|
|
||||||
|
|
||||||
# Login to get token
|
|
||||||
login_response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
json={
|
|
||||||
"email": test_user.email,
|
|
||||||
"password": test_user._test_password,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
token = login_response.get_json()["data"]["token"]
|
|
||||||
|
|
||||||
# List linked accounts
|
|
||||||
list_response = client.get(
|
|
||||||
"/api/v1/auth/external/linked-accounts",
|
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
|
||||||
)
|
|
||||||
assert list_response.status_code == 200
|
|
||||||
list_data = list_response.get_json()
|
|
||||||
|
|
||||||
assert len(list_data["data"]["linked_accounts"]) == 2
|
|
||||||
assert list_data["data"]["unlink_available"] is True
|
|
||||||
|
|
||||||
def test_non_admin_cannot_manage_providers(
|
|
||||||
self, app, db, client, test_user, test_organization
|
|
||||||
):
|
|
||||||
"""Test that non-admin users cannot manage provider configurations."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create organization membership as regular member
|
|
||||||
member = OrganizationMember(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
role=OrganizationRole.MEMBER,
|
|
||||||
)
|
|
||||||
member.save()
|
|
||||||
|
|
||||||
# Login to get token
|
|
||||||
login_response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
json={
|
|
||||||
"email": test_user.email,
|
|
||||||
"password": test_user._test_password,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
token = login_response.get_json()["data"]["token"]
|
|
||||||
|
|
||||||
# Try to create provider config (should fail)
|
|
||||||
create_response = client.post(
|
|
||||||
"/api/v1/auth/external/google/config",
|
|
||||||
json={
|
|
||||||
"client_id": "client-id",
|
|
||||||
"client_secret": "client-secret",
|
|
||||||
},
|
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
|
||||||
)
|
|
||||||
assert create_response.status_code == 403
|
|
||||||
assert create_response.get_json()["error_type"] == "FORBIDDEN"
|
|
||||||
|
|
||||||
def test_unsupported_provider_error(
|
|
||||||
self, app, db, client, test_user, test_organization
|
|
||||||
):
|
|
||||||
"""Test error handling for unsupported provider."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create organization membership
|
|
||||||
member = OrganizationMember(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
role=OrganizationRole.MEMBER,
|
|
||||||
)
|
|
||||||
member.save()
|
|
||||||
|
|
||||||
# Login to get token
|
|
||||||
login_response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
json={
|
|
||||||
"email": test_user.email,
|
|
||||||
"password": test_user._test_password,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
token = login_response.get_json()["data"]["token"]
|
|
||||||
|
|
||||||
# Try to link with unsupported provider
|
|
||||||
link_response = client.post(
|
|
||||||
"/api/v1/auth/external/unsupported/link",
|
|
||||||
json={},
|
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
|
||||||
)
|
|
||||||
assert link_response.status_code == 400
|
|
||||||
link_data = link_response.get_json()
|
|
||||||
assert link_data["error_type"] == "UNSUPPORTED_PROVIDER"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
|
||||||
class TestExternalAuthAuditLogging:
|
|
||||||
"""Integration tests for audit logging in external auth flows."""
|
|
||||||
|
|
||||||
@patch('gatehouse_app.services.audit_service.AuditService')
|
|
||||||
def test_audit_log_on_link_initiated(
|
|
||||||
self, mock_audit, app, db, client, test_user, test_organization
|
|
||||||
):
|
|
||||||
"""Test audit log is created when link flow is initiated."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
# Create organization membership
|
|
||||||
member = OrganizationMember(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
role=OrganizationRole.MEMBER,
|
|
||||||
)
|
|
||||||
member.save()
|
|
||||||
|
|
||||||
# Login to get token
|
|
||||||
login_response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
json={
|
|
||||||
"email": test_user.email,
|
|
||||||
"password": test_user._test_password,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
token = login_response.get_json()["data"]["token"]
|
|
||||||
|
|
||||||
# Initiate link flow
|
|
||||||
link_response = client.post(
|
|
||||||
"/api/v1/auth/external/google/link",
|
|
||||||
json={},
|
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify audit log was called
|
|
||||||
mock_audit.log_external_auth_link_initiated.assert_called_once()
|
|
||||||
|
|
||||||
@patch('gatehouse_app.services.audit_service.AuditService')
|
|
||||||
def test_audit_log_on_unlink(
|
|
||||||
self, mock_audit, app, db, client, test_user, test_organization
|
|
||||||
):
|
|
||||||
"""Test audit log is created when account is unlinked."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
# Create organization membership
|
|
||||||
member = OrganizationMember(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
role=OrganizationRole.MEMBER,
|
|
||||||
)
|
|
||||||
member.save()
|
|
||||||
|
|
||||||
# Create password auth method
|
|
||||||
password_method = AuthenticationMethod(
|
|
||||||
user_id=test_user.id,
|
|
||||||
method_type=AuthMethodType.PASSWORD,
|
|
||||||
provider_user_id=test_user.id,
|
|
||||||
)
|
|
||||||
password_method.save()
|
|
||||||
|
|
||||||
# Create Google auth method
|
|
||||||
google_method = AuthenticationMethod(
|
|
||||||
user_id=test_user.id,
|
|
||||||
method_type=AuthMethodType.GOOGLE,
|
|
||||||
provider_user_id="google-123",
|
|
||||||
provider_data={"email": test_user.email},
|
|
||||||
verified=True,
|
|
||||||
)
|
|
||||||
google_method.save()
|
|
||||||
|
|
||||||
# Login to get token
|
|
||||||
login_response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
json={
|
|
||||||
"email": test_user.email,
|
|
||||||
"password": test_user._test_password,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
token = login_response.get_json()["data"]["token"]
|
|
||||||
|
|
||||||
# Unlink Google account
|
|
||||||
unlink_response = client.delete(
|
|
||||||
"/api/v1/auth/external/google/unlink",
|
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify audit log was called
|
|
||||||
mock_audit.log_external_auth_unlink.assert_called_once()
|
|
||||||
@@ -1,933 +0,0 @@
|
|||||||
"""Integration tests for MFA compliance enforcement."""
|
|
||||||
import pytest
|
|
||||||
import json
|
|
||||||
from datetime import datetime, timezone, timedelta
|
|
||||||
from gatehouse_app.models.user import User
|
|
||||||
from gatehouse_app.models.organization import Organization
|
|
||||||
from gatehouse_app.models.organization_member import OrganizationMember
|
|
||||||
from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy
|
|
||||||
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
|
|
||||||
from gatehouse_app.models.user_security_policy import UserSecurityPolicy
|
|
||||||
from gatehouse_app.models.session import Session
|
|
||||||
from gatehouse_app.utils.constants import MfaPolicyMode, MfaComplianceStatus, UserStatus, MfaRequirementOverride
|
|
||||||
from gatehouse_app.services.mfa_policy_service import MfaPolicyService
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
|
||||||
class TestMfaComplianceLogin:
|
|
||||||
"""Integration tests for MFA compliance during login."""
|
|
||||||
|
|
||||||
def test_login_with_no_policy(self, client, db, test_user):
|
|
||||||
"""Test login with no MFA policy (should work normally)."""
|
|
||||||
login_data = {
|
|
||||||
"email": test_user.email,
|
|
||||||
"password": "TestPassword123!",
|
|
||||||
}
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
data=json.dumps(login_data),
|
|
||||||
content_type="application/json",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.get_json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert "user" in data["data"]
|
|
||||||
assert "token" in data["data"]
|
|
||||||
# No MFA compliance info should be present when no policy exists
|
|
||||||
assert "mfa_compliance" not in data["data"]
|
|
||||||
assert "requires_mfa_enrollment" not in data["data"]
|
|
||||||
|
|
||||||
def test_login_with_optional_policy(self, client, db, test_user, test_organization):
|
|
||||||
"""Test login with optional MFA policy (should work normally)."""
|
|
||||||
# Create an optional MFA policy
|
|
||||||
policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
db.session.add(policy)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
login_data = {
|
|
||||||
"email": test_user.email,
|
|
||||||
"password": "TestPassword123!",
|
|
||||||
}
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
data=json.dumps(login_data),
|
|
||||||
content_type="application/json",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.get_json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert "user" in data["data"]
|
|
||||||
assert "token" in data["data"]
|
|
||||||
# MFA compliance should be present but status should be not_applicable
|
|
||||||
assert "mfa_compliance" in data["data"]
|
|
||||||
assert data["data"]["mfa_compliance"]["overall_status"] == "not_applicable"
|
|
||||||
assert "requires_mfa_enrollment" not in data["data"]
|
|
||||||
|
|
||||||
def test_login_with_required_policy_in_grace_period(self, client, db, test_user, test_organization):
|
|
||||||
"""Test login with required policy within grace period (should work with warning)."""
|
|
||||||
# Create a required MFA policy
|
|
||||||
policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
db.session.add(policy)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
login_data = {
|
|
||||||
"email": test_user.email,
|
|
||||||
"password": "TestPassword123!",
|
|
||||||
}
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
data=json.dumps(login_data),
|
|
||||||
content_type="application/json",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.get_json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert "user" in data["data"]
|
|
||||||
assert "token" in data["data"]
|
|
||||||
# MFA compliance should be present with in_grace status
|
|
||||||
assert "mfa_compliance" in data["data"]
|
|
||||||
assert data["data"]["mfa_compliance"]["overall_status"] == "in_grace"
|
|
||||||
assert "requires_mfa_enrollment" not in data["data"]
|
|
||||||
assert "totp" in data["data"]["mfa_compliance"]["missing_methods"]
|
|
||||||
|
|
||||||
def test_login_with_required_policy_after_deadline(self, client, db, test_user, test_organization):
|
|
||||||
"""Test login with required policy after deadline (should get compliance-only session)."""
|
|
||||||
# Create a required MFA policy with past deadline
|
|
||||||
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
|
|
||||||
policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
db.session.add(policy)
|
|
||||||
|
|
||||||
# Create compliance record past due
|
|
||||||
compliance = MfaPolicyCompliance(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
status=MfaComplianceStatus.PAST_DUE,
|
|
||||||
policy_version=1,
|
|
||||||
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
|
|
||||||
deadline_at=past_deadline,
|
|
||||||
)
|
|
||||||
db.session.add(compliance)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
login_data = {
|
|
||||||
"email": test_user.email,
|
|
||||||
"password": "TestPassword123!",
|
|
||||||
}
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
data=json.dumps(login_data),
|
|
||||||
content_type="application/json",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.get_json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert "user" in data["data"]
|
|
||||||
assert "token" in data["data"]
|
|
||||||
# Should have compliance-only session
|
|
||||||
assert data["data"]["requires_mfa_enrollment"] is True
|
|
||||||
assert "mfa_compliance" in data["data"]
|
|
||||||
assert data["data"]["mfa_compliance"]["overall_status"] in ["past_due", "suspended"]
|
|
||||||
|
|
||||||
def test_login_with_suspended_user(self, client, db, test_user, test_organization):
|
|
||||||
"""Test login with compliance suspended user (should get compliance-only session)."""
|
|
||||||
# Set user status to compliance suspended
|
|
||||||
test_user.status = UserStatus.COMPLIANCE_SUSPENDED
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
login_data = {
|
|
||||||
"email": test_user.email,
|
|
||||||
"password": "TestPassword123!",
|
|
||||||
}
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
data=json.dumps(login_data),
|
|
||||||
content_type="application/json",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.get_json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert "user" in data["data"]
|
|
||||||
assert "token" in data["data"]
|
|
||||||
# Should have compliance-only session
|
|
||||||
assert data["data"]["requires_mfa_enrollment"] is True
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
|
||||||
class TestMfaComplianceAccess:
|
|
||||||
"""Integration tests for MFA compliance access control."""
|
|
||||||
|
|
||||||
def test_compliance_only_session_denied_full_access(self, client, db, test_user, test_organization):
|
|
||||||
"""Test that compliance-only session cannot access full access endpoints."""
|
|
||||||
# Create a required MFA policy with past deadline
|
|
||||||
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
|
|
||||||
policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
db.session.add(policy)
|
|
||||||
|
|
||||||
# Create compliance record past due
|
|
||||||
compliance = MfaPolicyCompliance(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
status=MfaComplianceStatus.PAST_DUE,
|
|
||||||
policy_version=1,
|
|
||||||
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
|
|
||||||
deadline_at=past_deadline,
|
|
||||||
)
|
|
||||||
db.session.add(compliance)
|
|
||||||
|
|
||||||
# Create a compliance-only session
|
|
||||||
session = Session(
|
|
||||||
user_id=test_user.id,
|
|
||||||
token="compliance_only_token",
|
|
||||||
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
|
|
||||||
is_compliance_only=True,
|
|
||||||
)
|
|
||||||
db.session.add(session)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Try to access a full-access endpoint (get_my_organizations)
|
|
||||||
response = client.get(
|
|
||||||
"/api/v1/users/me/organizations",
|
|
||||||
headers={"Authorization": "Bearer compliance_only_token"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 403
|
|
||||||
data = response.get_json()
|
|
||||||
assert data["success"] is False
|
|
||||||
assert data["error_type"] == "MFA_COMPLIANCE_REQUIRED"
|
|
||||||
|
|
||||||
def test_compliance_only_session_can_access_mfa_enrollment(self, client, db, test_user, test_organization):
|
|
||||||
"""Test that compliance-only session can access MFA enrollment endpoints."""
|
|
||||||
# Create a required MFA policy with past deadline
|
|
||||||
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
|
|
||||||
policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
db.session.add(policy)
|
|
||||||
|
|
||||||
# Create compliance record past due
|
|
||||||
compliance = MfaPolicyCompliance(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
status=MfaComplianceStatus.PAST_DUE,
|
|
||||||
policy_version=1,
|
|
||||||
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
|
|
||||||
deadline_at=past_deadline,
|
|
||||||
)
|
|
||||||
db.session.add(compliance)
|
|
||||||
|
|
||||||
# Create a compliance-only session
|
|
||||||
session = Session(
|
|
||||||
user_id=test_user.id,
|
|
||||||
token="compliance_only_token",
|
|
||||||
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
|
|
||||||
is_compliance_only=True,
|
|
||||||
)
|
|
||||||
db.session.add(session)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Try to access MFA enrollment endpoint (should work)
|
|
||||||
response = client.get(
|
|
||||||
"/api/v1/auth/totp/status",
|
|
||||||
headers={"Authorization": "Bearer compliance_only_token"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.get_json()
|
|
||||||
assert data["success"] is True
|
|
||||||
|
|
||||||
def test_compliance_only_session_can_access_logout(self, client, db, test_user, test_organization):
|
|
||||||
"""Test that compliance-only session can access logout endpoint."""
|
|
||||||
# Create a required MFA policy with past deadline
|
|
||||||
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
|
|
||||||
policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
db.session.add(policy)
|
|
||||||
|
|
||||||
# Create compliance record past due
|
|
||||||
compliance = MfaPolicyCompliance(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
status=MfaComplianceStatus.PAST_DUE,
|
|
||||||
policy_version=1,
|
|
||||||
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
|
|
||||||
deadline_at=past_deadline,
|
|
||||||
)
|
|
||||||
db.session.add(compliance)
|
|
||||||
|
|
||||||
# Create a compliance-only session
|
|
||||||
session = Session(
|
|
||||||
user_id=test_user.id,
|
|
||||||
token="compliance_only_token",
|
|
||||||
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
|
|
||||||
is_compliance_only=True,
|
|
||||||
)
|
|
||||||
db.session.add(session)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Try to access logout endpoint (should work)
|
|
||||||
response = client.post(
|
|
||||||
"/api/v1/auth/logout",
|
|
||||||
headers={"Authorization": "Bearer compliance_only_token"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.get_json()
|
|
||||||
assert data["success"] is True
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
|
||||||
class TestMfaComplianceWebAuthn:
|
|
||||||
"""Integration tests for MFA compliance with WebAuthn login."""
|
|
||||||
|
|
||||||
def test_webauthn_login_with_required_policy_in_grace_period(self, client, db, test_user, test_organization):
|
|
||||||
"""Test WebAuthn login with required policy within grace period."""
|
|
||||||
# Create a required MFA policy
|
|
||||||
policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
db.session.add(policy)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Note: Full WebAuthn login test would require WebAuthn setup
|
|
||||||
# This test verifies the compliance response structure
|
|
||||||
login_data = {
|
|
||||||
"email": test_user.email,
|
|
||||||
"password": "TestPassword123!",
|
|
||||||
}
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
data=json.dumps(login_data),
|
|
||||||
content_type="application/json",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.get_json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert "mfa_compliance" in data["data"]
|
|
||||||
assert data["data"]["mfa_compliance"]["overall_status"] == "in_grace"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
|
||||||
class TestMfaComplianceOIDC:
|
|
||||||
"""Integration tests for MFA compliance with OIDC authorization."""
|
|
||||||
|
|
||||||
def test_oidc_authorize_with_compliance_required(self, client, db, test_user, test_organization, app):
|
|
||||||
"""Test OIDC authorize with compliance required (should show error)."""
|
|
||||||
# Create a required MFA policy with past deadline
|
|
||||||
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
|
|
||||||
policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
db.session.add(policy)
|
|
||||||
|
|
||||||
# Create compliance record past due
|
|
||||||
compliance = MfaPolicyCompliance(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
status=MfaComplianceStatus.PAST_DUE,
|
|
||||||
policy_version=1,
|
|
||||||
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
|
|
||||||
deadline_at=past_deadline,
|
|
||||||
)
|
|
||||||
db.session.add(compliance)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Try OIDC authorize with credentials
|
|
||||||
response = client.post(
|
|
||||||
"/oidc/authorize",
|
|
||||||
data={
|
|
||||||
"client_id": "test_client",
|
|
||||||
"redirect_uri": "http://localhost:8080/callback",
|
|
||||||
"response_type": "code",
|
|
||||||
"scope": "openid profile email",
|
|
||||||
"state": "test_state",
|
|
||||||
"email": test_user.email,
|
|
||||||
"password": "TestPassword123!",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should return login page with error
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert b"Your account requires multi factor enrollment before using single sign on" in response.data
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Phase 4: Edge Case Tests
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
|
||||||
class TestMfaComplianceMultiOrg:
|
|
||||||
"""Integration tests for multi-organization MFA compliance edge cases."""
|
|
||||||
|
|
||||||
def test_user_with_multiple_orgs_different_policies(self, client, db, test_user):
|
|
||||||
"""Test user belonging to multiple orgs with different MFA policies."""
|
|
||||||
# Create two organizations
|
|
||||||
org1 = Organization(
|
|
||||||
name="Org1",
|
|
||||||
slug="org1-test-multi",
|
|
||||||
)
|
|
||||||
org2 = Organization(
|
|
||||||
name="Org2",
|
|
||||||
slug="org2-test-multi",
|
|
||||||
)
|
|
||||||
db.session.add_all([org1, org2])
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Add user to both orgs
|
|
||||||
membership1 = OrganizationMember(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=org1.id,
|
|
||||||
role="member",
|
|
||||||
)
|
|
||||||
membership2 = OrganizationMember(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=org2.id,
|
|
||||||
role="member",
|
|
||||||
)
|
|
||||||
db.session.add_all([membership1, membership2])
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Create different policies for each org
|
|
||||||
# Org1: OPTIONAL (no requirement)
|
|
||||||
policy1 = OrganizationSecurityPolicy(
|
|
||||||
organization_id=org1.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
# Org2: REQUIRE_TOTP (strictest)
|
|
||||||
policy2 = OrganizationSecurityPolicy(
|
|
||||||
organization_id=org2.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
db.session.add_all([policy1, policy2])
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Evaluate user MFA state
|
|
||||||
compliance_summary = MfaPolicyService.evaluate_user_mfa_state(test_user)
|
|
||||||
|
|
||||||
# Overall status should reflect the strictest policy (REQUIRE_TOTP from org2)
|
|
||||||
assert compliance_summary.overall_status == MfaComplianceStatus.IN_GRACE.value
|
|
||||||
assert "totp" in compliance_summary.missing_methods
|
|
||||||
|
|
||||||
# Verify per-org breakdown
|
|
||||||
assert len(compliance_summary.orgs) == 2
|
|
||||||
org1_status = next((o for o in compliance_summary.orgs if o.organization_id == org1.id), None)
|
|
||||||
org2_status = next((o for o in compliance_summary.orgs if o.organization_id == org2.id), None)
|
|
||||||
|
|
||||||
assert org1_status is not None
|
|
||||||
assert org2_status is not None
|
|
||||||
assert org1_status.status == MfaComplianceStatus.NOT_APPLICABLE.value
|
|
||||||
assert org2_status.status == MfaComplianceStatus.IN_GRACE.value
|
|
||||||
|
|
||||||
def test_user_with_multiple_orgs_all_suspended(self, client, db, test_user):
|
|
||||||
"""Test user with multiple orgs where all require MFA and are past due."""
|
|
||||||
# Create two organizations
|
|
||||||
org1 = Organization(
|
|
||||||
name="Org1",
|
|
||||||
slug="org1-test-suspended",
|
|
||||||
)
|
|
||||||
org2 = Organization(
|
|
||||||
name="Org2",
|
|
||||||
slug="org2-test-suspended",
|
|
||||||
)
|
|
||||||
db.session.add_all([org1, org2])
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Add user to both orgs
|
|
||||||
membership1 = OrganizationMember(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=org1.id,
|
|
||||||
role="member",
|
|
||||||
)
|
|
||||||
membership2 = OrganizationMember(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=org2.id,
|
|
||||||
role="member",
|
|
||||||
)
|
|
||||||
db.session.add_all([membership1, membership2])
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Create required policies
|
|
||||||
policy1 = OrganizationSecurityPolicy(
|
|
||||||
organization_id=org1.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
policy2 = OrganizationSecurityPolicy(
|
|
||||||
organization_id=org2.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_WEBAUTHN,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
db.session.add_all([policy1, policy2])
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Create past-due compliance records for both
|
|
||||||
past_deadline = datetime.now(timezone.utc) - timedelta(days=1)
|
|
||||||
compliance1 = MfaPolicyCompliance(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=org1.id,
|
|
||||||
status=MfaComplianceStatus.SUSPENDED,
|
|
||||||
policy_version=1,
|
|
||||||
applied_at=datetime.now(timezone.utc) - timedelta(days=30),
|
|
||||||
deadline_at=past_deadline,
|
|
||||||
suspended_at=past_deadline,
|
|
||||||
)
|
|
||||||
compliance2 = MfaPolicyCompliance(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=org2.id,
|
|
||||||
status=MfaComplianceStatus.SUSPENDED,
|
|
||||||
policy_version=1,
|
|
||||||
applied_at=datetime.now(timezone.utc) - timedelta(days=30),
|
|
||||||
deadline_at=past_deadline,
|
|
||||||
suspended_at=past_deadline,
|
|
||||||
)
|
|
||||||
db.session.add_all([compliance1, compliance2])
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Evaluate user MFA state
|
|
||||||
compliance_summary = MfaPolicyService.evaluate_user_mfa_state(test_user)
|
|
||||||
|
|
||||||
# Overall status should be SUSPENDED
|
|
||||||
assert compliance_summary.overall_status == MfaComplianceStatus.SUSPENDED.value
|
|
||||||
|
|
||||||
def test_strictest_mode_selection(self):
|
|
||||||
"""Test that get_strictest_mode returns the most restrictive policy."""
|
|
||||||
modes = [
|
|
||||||
MfaPolicyMode.DISABLED.value,
|
|
||||||
MfaPolicyMode.OPTIONAL.value,
|
|
||||||
MfaPolicyMode.REQUIRE_TOTP.value,
|
|
||||||
]
|
|
||||||
result = MfaPolicyService.get_strictest_mode(modes)
|
|
||||||
assert result == MfaPolicyMode.REQUIRE_TOTP.value
|
|
||||||
|
|
||||||
# Test with REQUIRE_TOTP_OR_WEBAUTHN (strictest)
|
|
||||||
modes_strictest = [
|
|
||||||
MfaPolicyMode.REQUIRE_TOTP.value,
|
|
||||||
MfaPolicyMode.REQUIRE_WEBAUTHN.value,
|
|
||||||
MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value,
|
|
||||||
]
|
|
||||||
result = MfaPolicyService.get_strictest_mode(modes_strictest)
|
|
||||||
assert result == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
|
||||||
class TestMfaComplianceUserOverrides:
|
|
||||||
"""Integration tests for user override edge cases."""
|
|
||||||
|
|
||||||
def test_user_override_inherit_mode(self, client, db, test_user, test_organization):
|
|
||||||
"""Test INHERIT mode - org policy applies as is."""
|
|
||||||
# Create a required policy
|
|
||||||
policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
db.session.add(policy)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Create INHERIT override (default behavior)
|
|
||||||
override = UserSecurityPolicy(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_override_mode=MfaRequirementOverride.INHERIT,
|
|
||||||
)
|
|
||||||
db.session.add(override)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Get effective policy
|
|
||||||
effective = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
|
|
||||||
|
|
||||||
# Should inherit org policy
|
|
||||||
assert effective.effective_mode == MfaPolicyMode.REQUIRE_TOTP.value
|
|
||||||
assert effective.requires_totp is True
|
|
||||||
assert effective.is_exempt is False
|
|
||||||
|
|
||||||
def test_user_override_required_mode(self, client, db, test_user, test_organization):
|
|
||||||
"""Test REQUIRED mode - user always required to have MFA."""
|
|
||||||
# Create an optional policy
|
|
||||||
policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
db.session.add(policy)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Create REQUIRED override
|
|
||||||
override = UserSecurityPolicy(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_override_mode=MfaRequirementOverride.REQUIRED,
|
|
||||||
)
|
|
||||||
db.session.add(override)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Get effective policy
|
|
||||||
effective = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
|
|
||||||
|
|
||||||
# Should be upgraded to REQUIRE_TOTP_OR_WEBAUTHN
|
|
||||||
assert effective.effective_mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value
|
|
||||||
assert effective.requires_totp is True
|
|
||||||
assert effective.requires_webauthn is True
|
|
||||||
assert effective.is_exempt is False
|
|
||||||
|
|
||||||
def test_user_override_exempt_mode(self, client, db, test_user, test_organization):
|
|
||||||
"""Test EXEMPT mode - org policy does not apply."""
|
|
||||||
# Create a required policy
|
|
||||||
policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
db.session.add(policy)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Create EXEMPT override
|
|
||||||
override = UserSecurityPolicy(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_override_mode=MfaRequirementOverride.EXEMPT,
|
|
||||||
)
|
|
||||||
db.session.add(override)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Get effective policy
|
|
||||||
effective = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
|
|
||||||
|
|
||||||
# Should be exempt from policy
|
|
||||||
assert effective.is_exempt is True
|
|
||||||
assert effective.effective_mode == MfaPolicyMode.DISABLED.value
|
|
||||||
assert effective.requires_totp is False
|
|
||||||
assert effective.requires_webauthn is False
|
|
||||||
|
|
||||||
def test_get_override_summary(self, client, db, test_user, test_organization):
|
|
||||||
"""Test getting override summary for a user."""
|
|
||||||
# No override exists
|
|
||||||
summary = MfaPolicyService.get_override_summary(test_user.id, test_organization.id)
|
|
||||||
|
|
||||||
assert summary["has_override"] is False
|
|
||||||
assert summary["mode"] == "inherit"
|
|
||||||
|
|
||||||
# Create an override
|
|
||||||
override = UserSecurityPolicy(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_override_mode=MfaRequirementOverride.EXEMPT,
|
|
||||||
)
|
|
||||||
db.session.add(override)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Get summary again
|
|
||||||
summary = MfaPolicyService.get_override_summary(test_user.id, test_organization.id)
|
|
||||||
|
|
||||||
assert summary["has_override"] is True
|
|
||||||
assert summary["mode"] == "exempt"
|
|
||||||
assert summary["is_exempt"] is True
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
|
||||||
class TestMfaCompliancePolicyChanges:
|
|
||||||
"""Integration tests for policy changes affecting existing users."""
|
|
||||||
|
|
||||||
def test_policy_change_triggers_compliance_reevaluation(self, client, db, test_user, test_organization):
|
|
||||||
"""Test that policy change triggers compliance reevaluation."""
|
|
||||||
# Create initial optional policy
|
|
||||||
policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
db.session.add(policy)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Create compliance record (should be NOT_APPLICABLE)
|
|
||||||
compliance = MfaPolicyCompliance(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
status=MfaComplianceStatus.NOT_APPLICABLE,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
db.session.add(compliance)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Update policy to REQUIRE_TOTP
|
|
||||||
MfaPolicyService.create_org_policy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
updated_by_user_id=test_user.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Reevaluate all compliance
|
|
||||||
updated_count = MfaPolicyService.reevaluate_all_org_compliance(test_organization.id)
|
|
||||||
|
|
||||||
# Should have updated at least one record
|
|
||||||
assert updated_count >= 1
|
|
||||||
|
|
||||||
# Check compliance status was updated
|
|
||||||
updated_compliance = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id)
|
|
||||||
assert updated_compliance.status == MfaComplianceStatus.IN_GRACE.value
|
|
||||||
assert updated_compliance.deadline_at is not None
|
|
||||||
|
|
||||||
def test_policy_relaxation_clears_requirements(self, client, db, test_user, test_organization):
|
|
||||||
"""Test that relaxing policy clears compliance requirements."""
|
|
||||||
# Create required policy
|
|
||||||
policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
db.session.add(policy)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Create IN_GRACE compliance record
|
|
||||||
compliance = MfaPolicyCompliance(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
status=MfaComplianceStatus.IN_GRACE,
|
|
||||||
policy_version=1,
|
|
||||||
applied_at=datetime.now(timezone.utc),
|
|
||||||
deadline_at=datetime.now(timezone.utc) + timedelta(days=14),
|
|
||||||
)
|
|
||||||
db.session.add(compliance)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Update policy to OPTIONAL
|
|
||||||
MfaPolicyService.create_org_policy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
updated_by_user_id=test_user.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Reevaluate compliance
|
|
||||||
MfaPolicyService.reevaluate_all_org_compliance(test_organization.id)
|
|
||||||
|
|
||||||
# Check compliance status was updated to NOT_APPLICABLE
|
|
||||||
updated_compliance = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id)
|
|
||||||
assert updated_compliance.status == MfaComplianceStatus.NOT_APPLICABLE.value
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
|
||||||
class TestMfaComplianceScheduledJob:
|
|
||||||
"""Integration tests for the MFA compliance scheduled job."""
|
|
||||||
|
|
||||||
def test_transition_to_suspended(self, client, db, test_user, test_organization):
|
|
||||||
"""Test that past-due users are transitioned to suspended."""
|
|
||||||
# Create required policy
|
|
||||||
policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
db.session.add(policy)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Create past-due compliance record
|
|
||||||
past_deadline = datetime.now(timezone.utc) - timedelta(hours=1)
|
|
||||||
compliance = MfaPolicyCompliance(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
status=MfaComplianceStatus.PAST_DUE,
|
|
||||||
policy_version=1,
|
|
||||||
applied_at=datetime.now(timezone.utc) - timedelta(days=15),
|
|
||||||
deadline_at=past_deadline,
|
|
||||||
)
|
|
||||||
db.session.add(compliance)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Run the job
|
|
||||||
now = datetime.now(timezone.utc)
|
|
||||||
suspended_count = MfaPolicyService.transition_to_suspended_if_past_due(now)
|
|
||||||
|
|
||||||
# Should have suspended the user
|
|
||||||
assert suspended_count >= 1
|
|
||||||
|
|
||||||
# Check compliance status
|
|
||||||
updated_compliance = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id)
|
|
||||||
assert updated_compliance.status == MfaComplianceStatus.SUSPENDED.value
|
|
||||||
assert updated_compliance.suspended_at is not None
|
|
||||||
|
|
||||||
# Check user status
|
|
||||||
db.refresh(test_user)
|
|
||||||
assert test_user.status == UserStatus.COMPLIANCE_SUSPENDED
|
|
||||||
|
|
||||||
def test_check_and_restore_user_status(self, client, db, test_user, test_organization):
|
|
||||||
"""Test that suspended users are restored when they become compliant."""
|
|
||||||
# Create required policy
|
|
||||||
policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
db.session.add(policy)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# User is suspended
|
|
||||||
test_user.status = UserStatus.COMPLIANCE_SUSPENDED
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Create EXEMPT override to clear requirement
|
|
||||||
override = UserSecurityPolicy(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_override_mode=MfaRequirementOverride.EXEMPT,
|
|
||||||
)
|
|
||||||
db.session.add(override)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Check and restore status
|
|
||||||
restored = MfaPolicyService.check_and_restore_user_status(test_user.id)
|
|
||||||
|
|
||||||
# Should have restored user
|
|
||||||
assert restored is True
|
|
||||||
db.refresh(test_user)
|
|
||||||
assert test_user.status == UserStatus.ACTIVE
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
|
||||||
class TestMfaComplianceMultiOrgAggregate:
|
|
||||||
"""Integration tests for multi-org aggregate state calculation."""
|
|
||||||
|
|
||||||
def test_get_multi_org_aggregate_state(self, client, db, test_user):
|
|
||||||
"""Test aggregate state calculation for multi-org user."""
|
|
||||||
# Create two organizations
|
|
||||||
org1 = Organization(
|
|
||||||
name="AggOrg1",
|
|
||||||
slug="agg-org1-test",
|
|
||||||
)
|
|
||||||
org2 = Organization(
|
|
||||||
name="AggOrg2",
|
|
||||||
slug="agg-org2-test",
|
|
||||||
)
|
|
||||||
db.session.add_all([org1, org2])
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Add user to both
|
|
||||||
membership1 = OrganizationMember(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=org1.id,
|
|
||||||
role="member",
|
|
||||||
)
|
|
||||||
membership2 = OrganizationMember(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=org2.id,
|
|
||||||
role="member",
|
|
||||||
)
|
|
||||||
db.session.add_all([membership1, membership2])
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Create policies
|
|
||||||
policy1 = OrganizationSecurityPolicy(
|
|
||||||
organization_id=org1.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
policy2 = OrganizationSecurityPolicy(
|
|
||||||
organization_id=org2.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_WEBAUTHN,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
db.session.add_all([policy1, policy2])
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Get aggregate state
|
|
||||||
aggregate = MfaPolicyService.get_multi_org_aggregate_state(test_user)
|
|
||||||
|
|
||||||
# Verify structure
|
|
||||||
assert "overall_status" in aggregate
|
|
||||||
assert "strictest_mode" in aggregate
|
|
||||||
assert "missing_methods" in aggregate
|
|
||||||
assert "requiring_org_count" in aggregate
|
|
||||||
assert "requiring_orgs" in aggregate
|
|
||||||
assert "per_org_details" in aggregate
|
|
||||||
|
|
||||||
# Strictest mode should be REQUIRE_TOTP_OR_WEBAUTHN
|
|
||||||
assert aggregate["strictest_mode"] == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value
|
|
||||||
|
|
||||||
# Both orgs should require MFA
|
|
||||||
assert aggregate["requiring_org_count"] == 2
|
|
||||||
assert len(aggregate["requiring_orgs"]) == 2
|
|
||||||
assert len(aggregate["per_org_details"]) == 2
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1 +0,0 @@
|
|||||||
"""Unit tests package."""
|
|
||||||
@@ -1,295 +0,0 @@
|
|||||||
"""Unit tests for MFA policy models."""
|
|
||||||
import pytest
|
|
||||||
from datetime import datetime, timezone, timedelta
|
|
||||||
from gatehouse_app.models import (
|
|
||||||
User,
|
|
||||||
Organization,
|
|
||||||
OrganizationMember,
|
|
||||||
OrganizationSecurityPolicy,
|
|
||||||
UserSecurityPolicy,
|
|
||||||
MfaPolicyCompliance,
|
|
||||||
Session,
|
|
||||||
)
|
|
||||||
from gatehouse_app.utils.constants import (
|
|
||||||
UserStatus,
|
|
||||||
MfaPolicyMode,
|
|
||||||
MfaComplianceStatus,
|
|
||||||
MfaRequirementOverride,
|
|
||||||
SessionStatus,
|
|
||||||
OrganizationRole,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
class TestOrganizationSecurityPolicyModel:
|
|
||||||
"""Tests for OrganizationSecurityPolicy model."""
|
|
||||||
|
|
||||||
def test_create_org_security_policy(self, db, test_organization):
|
|
||||||
"""Test creating an organization security policy."""
|
|
||||||
policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
)
|
|
||||||
policy.save()
|
|
||||||
|
|
||||||
assert policy.id is not None
|
|
||||||
assert policy.organization_id == test_organization.id
|
|
||||||
assert policy.mfa_policy_mode == MfaPolicyMode.OPTIONAL
|
|
||||||
assert policy.mfa_grace_period_days == 14
|
|
||||||
assert policy.notify_days_before == 7
|
|
||||||
assert policy.policy_version == 1
|
|
||||||
assert policy.created_at is not None
|
|
||||||
|
|
||||||
def test_org_security_policy_to_dict(self, db, test_organization):
|
|
||||||
"""Test organization security policy to_dict method."""
|
|
||||||
policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
|
||||||
mfa_grace_period_days=7,
|
|
||||||
notify_days_before=3,
|
|
||||||
)
|
|
||||||
policy.save()
|
|
||||||
|
|
||||||
policy_dict = policy.to_dict()
|
|
||||||
|
|
||||||
assert "id" in policy_dict
|
|
||||||
assert "organization_id" in policy_dict
|
|
||||||
assert policy_dict["organization_id"] == test_organization.id
|
|
||||||
assert "mfa_policy_mode" in policy_dict
|
|
||||||
assert "mfa_grace_period_days" in policy_dict
|
|
||||||
|
|
||||||
def test_org_security_policy_relationships(self, db, test_organization):
|
|
||||||
"""Test organization security policy relationships."""
|
|
||||||
policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
|
|
||||||
)
|
|
||||||
policy.save()
|
|
||||||
|
|
||||||
# Test relationship
|
|
||||||
assert policy.organization is not None
|
|
||||||
assert policy.organization.id == test_organization.id
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
class TestUserSecurityPolicyModel:
|
|
||||||
"""Tests for UserSecurityPolicy model."""
|
|
||||||
|
|
||||||
def test_create_user_security_policy(self, db, test_user, test_organization):
|
|
||||||
"""Test creating a user security policy."""
|
|
||||||
policy = UserSecurityPolicy(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_override_mode=MfaRequirementOverride.INHERIT,
|
|
||||||
)
|
|
||||||
policy.save()
|
|
||||||
|
|
||||||
assert policy.id is not None
|
|
||||||
assert policy.user_id == test_user.id
|
|
||||||
assert policy.organization_id == test_organization.id
|
|
||||||
assert policy.mfa_override_mode == MfaRequirementOverride.INHERIT
|
|
||||||
assert policy.force_totp is False
|
|
||||||
assert policy.force_webauthn is False
|
|
||||||
|
|
||||||
def test_user_security_policy_with_overrides(self, db, test_user, test_organization):
|
|
||||||
"""Test user security policy with override settings."""
|
|
||||||
policy = UserSecurityPolicy(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_override_mode=MfaRequirementOverride.REQUIRED,
|
|
||||||
force_totp=True,
|
|
||||||
force_webauthn=False,
|
|
||||||
)
|
|
||||||
policy.save()
|
|
||||||
|
|
||||||
assert policy.mfa_override_mode == MfaRequirementOverride.REQUIRED
|
|
||||||
assert policy.force_totp is True
|
|
||||||
assert policy.force_webauthn is False
|
|
||||||
|
|
||||||
def test_user_security_policy_exempt(self, db, test_user, test_organization):
|
|
||||||
"""Test user security policy with exempt override."""
|
|
||||||
policy = UserSecurityPolicy(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_override_mode=MfaRequirementOverride.EXEMPT,
|
|
||||||
)
|
|
||||||
policy.save()
|
|
||||||
|
|
||||||
assert policy.mfa_override_mode == MfaRequirementOverride.EXEMPT
|
|
||||||
|
|
||||||
def test_user_security_policy_relationships(self, db, test_user, test_organization):
|
|
||||||
"""Test user security policy relationships."""
|
|
||||||
policy = UserSecurityPolicy(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_override_mode=MfaRequirementOverride.INHERIT,
|
|
||||||
)
|
|
||||||
policy.save()
|
|
||||||
|
|
||||||
# Test relationships
|
|
||||||
assert policy.user is not None
|
|
||||||
assert policy.user.id == test_user.id
|
|
||||||
assert policy.organization is not None
|
|
||||||
assert policy.organization.id == test_organization.id
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
class TestMfaPolicyComplianceModel:
|
|
||||||
"""Tests for MfaPolicyCompliance model."""
|
|
||||||
|
|
||||||
def test_create_mfa_policy_compliance(self, db, test_user, test_organization):
|
|
||||||
"""Test creating an MFA policy compliance record."""
|
|
||||||
compliance = MfaPolicyCompliance(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
status=MfaComplianceStatus.NOT_APPLICABLE,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
compliance.save()
|
|
||||||
|
|
||||||
assert compliance.id is not None
|
|
||||||
assert compliance.user_id == test_user.id
|
|
||||||
assert compliance.organization_id == test_organization.id
|
|
||||||
assert compliance.status == MfaComplianceStatus.NOT_APPLICABLE
|
|
||||||
assert compliance.policy_version == 1
|
|
||||||
assert compliance.notification_count == 0
|
|
||||||
|
|
||||||
def test_mfa_policy_compliance_in_grace(self, db, test_user, test_organization):
|
|
||||||
"""Test MFA compliance record in grace period."""
|
|
||||||
now = datetime.now(timezone.utc)
|
|
||||||
compliance = MfaPolicyCompliance(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
status=MfaComplianceStatus.IN_GRACE,
|
|
||||||
policy_version=1,
|
|
||||||
applied_at=now,
|
|
||||||
deadline_at=now + timedelta(days=14),
|
|
||||||
)
|
|
||||||
compliance.save()
|
|
||||||
|
|
||||||
assert compliance.status == MfaComplianceStatus.IN_GRACE
|
|
||||||
assert compliance.applied_at is not None
|
|
||||||
assert compliance.deadline_at is not None
|
|
||||||
assert compliance.deadline_at > now
|
|
||||||
|
|
||||||
def test_mfa_policy_compliance_compliant(self, db, test_user, test_organization):
|
|
||||||
"""Test MFA compliance record when compliant."""
|
|
||||||
now = datetime.now(timezone.utc)
|
|
||||||
compliance = MfaPolicyCompliance(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
status=MfaComplianceStatus.COMPLIANT,
|
|
||||||
policy_version=1,
|
|
||||||
applied_at=now - timedelta(days=30),
|
|
||||||
deadline_at=now - timedelta(days=16),
|
|
||||||
compliant_at=now - timedelta(days=16),
|
|
||||||
)
|
|
||||||
compliance.save()
|
|
||||||
|
|
||||||
assert compliance.status == MfaComplianceStatus.COMPLIANT
|
|
||||||
assert compliance.compliant_at is not None
|
|
||||||
|
|
||||||
def test_mfa_policy_compliance_suspended(self, db, test_user, test_organization):
|
|
||||||
"""Test MFA compliance record when suspended."""
|
|
||||||
now = datetime.now(timezone.utc)
|
|
||||||
compliance = MfaPolicyCompliance(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
status=MfaComplianceStatus.SUSPENDED,
|
|
||||||
policy_version=1,
|
|
||||||
applied_at=now - timedelta(days=30),
|
|
||||||
deadline_at=now - timedelta(days=16),
|
|
||||||
suspended_at=now - timedelta(days=16),
|
|
||||||
)
|
|
||||||
compliance.save()
|
|
||||||
|
|
||||||
assert compliance.status == MfaComplianceStatus.SUSPENDED
|
|
||||||
assert compliance.suspended_at is not None
|
|
||||||
|
|
||||||
def test_mfa_policy_compliance_relationships(self, db, test_user, test_organization):
|
|
||||||
"""Test MFA compliance relationships."""
|
|
||||||
compliance = MfaPolicyCompliance(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
status=MfaComplianceStatus.NOT_APPLICABLE,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
compliance.save()
|
|
||||||
|
|
||||||
# Test relationships
|
|
||||||
assert compliance.user is not None
|
|
||||||
assert compliance.user.id == test_user.id
|
|
||||||
assert compliance.organization is not None
|
|
||||||
assert compliance.organization.id == test_organization.id
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
class TestSessionModelComplianceFlag:
|
|
||||||
"""Tests for Session model compliance flag."""
|
|
||||||
|
|
||||||
def test_session_default_not_compliance_only(self, db, test_user):
|
|
||||||
"""Test that sessions are not compliance only by default."""
|
|
||||||
session = Session(
|
|
||||||
user_id=test_user.id,
|
|
||||||
token="test-token-123",
|
|
||||||
status=SessionStatus.ACTIVE,
|
|
||||||
expires_at=datetime.now(timezone.utc) + timedelta(hours=8),
|
|
||||||
last_activity_at=datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
session.save()
|
|
||||||
|
|
||||||
assert session.is_compliance_only is False
|
|
||||||
|
|
||||||
def test_session_compliance_only(self, db, test_user):
|
|
||||||
"""Test creating a compliance-only session."""
|
|
||||||
session = Session(
|
|
||||||
user_id=test_user.id,
|
|
||||||
token="compliance-token-123",
|
|
||||||
status=SessionStatus.ACTIVE,
|
|
||||||
expires_at=datetime.now(timezone.utc) + timedelta(hours=8),
|
|
||||||
last_activity_at=datetime.now(timezone.utc),
|
|
||||||
is_compliance_only=True,
|
|
||||||
)
|
|
||||||
session.save()
|
|
||||||
|
|
||||||
assert session.is_compliance_only is True
|
|
||||||
|
|
||||||
def test_session_to_dict_excludes_token(self, db, test_user):
|
|
||||||
"""Test that session to_dict excludes the token."""
|
|
||||||
session = Session(
|
|
||||||
user_id=test_user.id,
|
|
||||||
token="test-token-456",
|
|
||||||
status=SessionStatus.ACTIVE,
|
|
||||||
expires_at=datetime.now(timezone.utc) + timedelta(hours=8),
|
|
||||||
last_activity_at=datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
session.save()
|
|
||||||
|
|
||||||
session_dict = session.to_dict()
|
|
||||||
|
|
||||||
assert "id" in session_dict
|
|
||||||
assert "user_id" in session_dict
|
|
||||||
assert "is_compliance_only" in session_dict
|
|
||||||
assert session_dict["is_compliance_only"] is False
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
class TestUserStatusComplianceSuspended:
|
|
||||||
"""Tests for UserStatus.COMPLIANCE_SUSPENDED."""
|
|
||||||
|
|
||||||
def test_compliance_suspended_status_exists(self):
|
|
||||||
"""Test that COMPLIANCE_SUSPENDED status exists."""
|
|
||||||
assert UserStatus.COMPLIANCE_SUSPENDED.value == "compliance_suspended"
|
|
||||||
|
|
||||||
def test_create_compliance_suspended_user(self, db):
|
|
||||||
"""Test creating a compliance suspended user."""
|
|
||||||
user = User(
|
|
||||||
email="suspended@example.com",
|
|
||||||
full_name="Suspended User",
|
|
||||||
status=UserStatus.COMPLIANCE_SUSPENDED,
|
|
||||||
)
|
|
||||||
user.save()
|
|
||||||
|
|
||||||
assert user.status == UserStatus.COMPLIANCE_SUSPENDED
|
|
||||||
@@ -1,76 +0,0 @@
|
|||||||
"""Unit tests for models."""
|
|
||||||
import pytest
|
|
||||||
from datetime import datetime
|
|
||||||
from gatehouse_app.models import User, Organization
|
|
||||||
from gatehouse_app.utils.constants import UserStatus
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
class TestUserModel:
|
|
||||||
"""Tests for User model."""
|
|
||||||
|
|
||||||
def test_create_user(self, db):
|
|
||||||
"""Test creating a user."""
|
|
||||||
user = User(
|
|
||||||
email="test@example.com",
|
|
||||||
full_name="Test User",
|
|
||||||
status=UserStatus.ACTIVE,
|
|
||||||
)
|
|
||||||
user.save()
|
|
||||||
|
|
||||||
assert user.id is not None
|
|
||||||
assert user.email == "test@example.com"
|
|
||||||
assert user.full_name == "Test User"
|
|
||||||
assert user.status == UserStatus.ACTIVE
|
|
||||||
assert user.created_at is not None
|
|
||||||
assert user.deleted_at is None
|
|
||||||
|
|
||||||
def test_user_to_dict(self, test_user):
|
|
||||||
"""Test user to_dict method."""
|
|
||||||
user_dict = test_user.to_dict()
|
|
||||||
|
|
||||||
assert "id" in user_dict
|
|
||||||
assert "email" in user_dict
|
|
||||||
assert user_dict["email"] == test_user.email
|
|
||||||
assert "created_at" in user_dict
|
|
||||||
|
|
||||||
def test_user_soft_delete(self, test_user):
|
|
||||||
"""Test soft deleting a user."""
|
|
||||||
test_user.delete(soft=True)
|
|
||||||
|
|
||||||
assert test_user.deleted_at is not None
|
|
||||||
assert isinstance(test_user.deleted_at, datetime)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
class TestOrganizationModel:
|
|
||||||
"""Tests for Organization model."""
|
|
||||||
|
|
||||||
def test_create_organization(self, db):
|
|
||||||
"""Test creating an organization."""
|
|
||||||
org = Organization(
|
|
||||||
name="Test Org",
|
|
||||||
slug="test-org",
|
|
||||||
description="Test organization",
|
|
||||||
)
|
|
||||||
org.save()
|
|
||||||
|
|
||||||
assert org.id is not None
|
|
||||||
assert org.name == "Test Org"
|
|
||||||
assert org.slug == "test-org"
|
|
||||||
assert org.is_active is True
|
|
||||||
assert org.created_at is not None
|
|
||||||
|
|
||||||
def test_organization_to_dict(self, test_organization):
|
|
||||||
"""Test organization to_dict method."""
|
|
||||||
org_dict = test_organization.to_dict()
|
|
||||||
|
|
||||||
assert "id" in org_dict
|
|
||||||
assert "name" in org_dict
|
|
||||||
assert org_dict["name"] == test_organization.name
|
|
||||||
assert "slug" in org_dict
|
|
||||||
|
|
||||||
def test_get_member_count(self, test_organization):
|
|
||||||
"""Test getting member count."""
|
|
||||||
count = test_organization.get_member_count()
|
|
||||||
assert count == 1 # Only the owner
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
"""Services unit tests package."""
|
|
||||||
@@ -1,102 +0,0 @@
|
|||||||
"""Unit tests for AuthService."""
|
|
||||||
import pytest
|
|
||||||
from gatehouse_app.services.auth_service import AuthService
|
|
||||||
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
|
|
||||||
from gatehouse_app.exceptions.validation_exceptions import EmailAlreadyExistsError
|
|
||||||
from gatehouse_app.utils.constants import UserStatus, AuthMethodType
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
class TestAuthService:
|
|
||||||
"""Tests for AuthService."""
|
|
||||||
|
|
||||||
def test_register_user(self, db):
|
|
||||||
"""Test user registration."""
|
|
||||||
email = "newuser@example.com"
|
|
||||||
password = "SecurePassword123!"
|
|
||||||
full_name = "New User"
|
|
||||||
|
|
||||||
user = AuthService.register_user(
|
|
||||||
email=email,
|
|
||||||
password=password,
|
|
||||||
full_name=full_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert user.id is not None
|
|
||||||
assert user.email == email.lower()
|
|
||||||
assert user.full_name == full_name
|
|
||||||
assert user.status == UserStatus.ACTIVE
|
|
||||||
assert user.has_password_auth()
|
|
||||||
|
|
||||||
def test_register_duplicate_email(self, db, test_user):
|
|
||||||
"""Test registering with duplicate email."""
|
|
||||||
with pytest.raises(EmailAlreadyExistsError):
|
|
||||||
AuthService.register_user(
|
|
||||||
email=test_user.email,
|
|
||||||
password="SomePassword123!",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_authenticate_success(self, db, test_user):
|
|
||||||
"""Test successful authentication."""
|
|
||||||
user = AuthService.authenticate(
|
|
||||||
email=test_user.email,
|
|
||||||
password=test_user._test_password,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert user.id == test_user.id
|
|
||||||
assert user.last_login_at is not None
|
|
||||||
|
|
||||||
def test_authenticate_wrong_password(self, db, test_user):
|
|
||||||
"""Test authentication with wrong password."""
|
|
||||||
with pytest.raises(InvalidCredentialsError):
|
|
||||||
AuthService.authenticate(
|
|
||||||
email=test_user.email,
|
|
||||||
password="WrongPassword123!",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_authenticate_nonexistent_user(self, db):
|
|
||||||
"""Test authentication with non-existent email."""
|
|
||||||
with pytest.raises(InvalidCredentialsError):
|
|
||||||
AuthService.authenticate(
|
|
||||||
email="nonexistent@example.com",
|
|
||||||
password="SomePassword123!",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_create_session(self, app, db, test_user):
|
|
||||||
"""Test creating a session."""
|
|
||||||
with app.test_request_context():
|
|
||||||
session = AuthService.create_session(test_user)
|
|
||||||
|
|
||||||
assert session.id is not None
|
|
||||||
assert session.user_id == test_user.id
|
|
||||||
assert session.token is not None
|
|
||||||
assert session.is_active()
|
|
||||||
|
|
||||||
def test_change_password(self, app, db, test_user):
|
|
||||||
"""Test changing password."""
|
|
||||||
with app.test_request_context():
|
|
||||||
new_password = "NewPassword456!"
|
|
||||||
|
|
||||||
AuthService.change_password(
|
|
||||||
user=test_user,
|
|
||||||
current_password=test_user._test_password,
|
|
||||||
new_password=new_password,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify can login with new password
|
|
||||||
user = AuthService.authenticate(
|
|
||||||
email=test_user.email,
|
|
||||||
password=new_password,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert user.id == test_user.id
|
|
||||||
|
|
||||||
def test_change_password_wrong_current(self, app, db, test_user):
|
|
||||||
"""Test changing password with wrong current password."""
|
|
||||||
with app.test_request_context():
|
|
||||||
with pytest.raises(InvalidCredentialsError):
|
|
||||||
AuthService.change_password(
|
|
||||||
user=test_user,
|
|
||||||
current_password="WrongPassword123!",
|
|
||||||
new_password="NewPassword456!",
|
|
||||||
)
|
|
||||||
@@ -1,698 +0,0 @@
|
|||||||
"""Unit tests for ExternalAuthService."""
|
|
||||||
import pytest
|
|
||||||
from unittest.mock import Mock, patch, MagicMock
|
|
||||||
from datetime import datetime, timedelta, timezone
|
|
||||||
|
|
||||||
from gatehouse_app.services.external_auth_service import (
|
|
||||||
ExternalAuthService,
|
|
||||||
ExternalAuthError,
|
|
||||||
OAuthState,
|
|
||||||
ExternalProviderConfig,
|
|
||||||
)
|
|
||||||
from gatehouse_app.utils.constants import AuthMethodType
|
|
||||||
from gatehouse_app.models import User, AuthenticationMethod
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
class TestExternalAuthService:
|
|
||||||
"""Tests for ExternalAuthService."""
|
|
||||||
|
|
||||||
def test_get_provider_config_success(self, app, db, test_organization):
|
|
||||||
"""Test getting provider configuration successfully."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
client_secret_encrypted="encrypted-secret",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
# Get config
|
|
||||||
result = ExternalAuthService.get_provider_config(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.id == config.id
|
|
||||||
assert result.client_id == "test-client-id"
|
|
||||||
assert result.is_active is True
|
|
||||||
|
|
||||||
def test_get_provider_config_not_configured(self, app, db, test_organization):
|
|
||||||
"""Test getting provider configuration when not configured."""
|
|
||||||
with app.app_context():
|
|
||||||
with pytest.raises(ExternalAuthError) as exc_info:
|
|
||||||
ExternalAuthService.get_provider_config(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert exc_info.value.error_type == "PROVIDER_NOT_CONFIGURED"
|
|
||||||
assert exc_info.value.status_code == 400
|
|
||||||
|
|
||||||
def test_get_provider_config_inactive(self, app, db, test_organization):
|
|
||||||
"""Test getting provider configuration when inactive."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create inactive provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=False,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
with pytest.raises(ExternalAuthError) as exc_info:
|
|
||||||
ExternalAuthService.get_provider_config(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert exc_info.value.error_type == "PROVIDER_NOT_CONFIGURED"
|
|
||||||
|
|
||||||
@patch('gatehouse_app.services.external_auth_service.AuditService')
|
|
||||||
def test_initiate_link_flow_success(self, mock_audit, app, db, test_user, test_organization):
|
|
||||||
"""Test initiating account linking flow successfully."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
# Initiate link flow
|
|
||||||
auth_url, state = ExternalAuthService.initiate_link_flow(
|
|
||||||
user_id=test_user.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert auth_url is not None
|
|
||||||
assert state is not None
|
|
||||||
assert len(state) == 43 # Base64 URL-safe token length
|
|
||||||
|
|
||||||
# Verify state was created
|
|
||||||
state_record = OAuthState.query.filter_by(state=state).first()
|
|
||||||
assert state_record is not None
|
|
||||||
assert state_record.flow_type == "link"
|
|
||||||
assert state_record.user_id == test_user.id
|
|
||||||
assert state_record.provider_type == AuthMethodType.GOOGLE.value
|
|
||||||
|
|
||||||
# Verify audit log
|
|
||||||
mock_audit.log_external_auth_link_initiated.assert_called_once()
|
|
||||||
|
|
||||||
def test_initiate_link_flow_invalid_redirect_uri(self, app, db, test_user, test_organization):
|
|
||||||
"""Test initiating link flow with invalid redirect URI."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
with pytest.raises(ExternalAuthError) as exc_info:
|
|
||||||
ExternalAuthService.initiate_link_flow(
|
|
||||||
user_id=test_user.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
redirect_uri="http://malicious-site.com/callback",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert exc_info.value.error_type == "INVALID_REDIRECT_URI"
|
|
||||||
|
|
||||||
@patch('gatehouse_app.services.external_auth_service.ExternalAuthService._exchange_code')
|
|
||||||
@patch('gatehouse_app.services.external_auth_service.ExternalAuthService._get_user_info')
|
|
||||||
@patch('gatehouse_app.services.external_auth_service.AuditService')
|
|
||||||
def test_complete_link_flow_success(
|
|
||||||
self, mock_audit, mock_get_user_info, mock_exchange_code,
|
|
||||||
app, db, test_user, test_organization
|
|
||||||
):
|
|
||||||
"""Test completing account linking flow successfully."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
# Create OAuth state
|
|
||||||
state = OAuthState.create_state(
|
|
||||||
flow_type="link",
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock external provider responses
|
|
||||||
mock_exchange_code.return_value = {
|
|
||||||
"access_token": "mock-access-token",
|
|
||||||
"refresh_token": "mock-refresh-token",
|
|
||||||
"id_token": "mock-id-token",
|
|
||||||
"expires_in": 3600,
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_get_user_info.return_value = {
|
|
||||||
"provider_user_id": "google-123",
|
|
||||||
"email": "user@gmail.com",
|
|
||||||
"email_verified": True,
|
|
||||||
"name": "Test User",
|
|
||||||
"picture": "https://example.com/avatar.jpg",
|
|
||||||
"raw_data": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
# Complete link flow
|
|
||||||
auth_method = ExternalAuthService.complete_link_flow(
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
authorization_code="mock-auth-code",
|
|
||||||
state=state.state,
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert auth_method is not None
|
|
||||||
assert auth_method.user_id == test_user.id
|
|
||||||
assert auth_method.method_type == AuthMethodType.GOOGLE
|
|
||||||
assert auth_method.provider_user_id == "google-123"
|
|
||||||
|
|
||||||
# Verify state is marked as used
|
|
||||||
state_record = OAuthState.query.get(state.id)
|
|
||||||
assert state_record.used is True
|
|
||||||
|
|
||||||
# Verify audit log
|
|
||||||
mock_audit.log_external_auth_link_completed.assert_called_once()
|
|
||||||
|
|
||||||
def test_complete_link_flow_invalid_state(self, app, db):
|
|
||||||
"""Test completing link flow with invalid state."""
|
|
||||||
with app.app_context():
|
|
||||||
with pytest.raises(ExternalAuthError) as exc_info:
|
|
||||||
ExternalAuthService.complete_link_flow(
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
authorization_code="mock-auth-code",
|
|
||||||
state="invalid-state",
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert exc_info.value.error_type == "INVALID_STATE"
|
|
||||||
|
|
||||||
def test_complete_link_flow_wrong_flow_type(self, app, db, test_organization):
|
|
||||||
"""Test completing link flow with wrong flow type state."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
# Create login flow state instead of link
|
|
||||||
state = OAuthState.create_state(
|
|
||||||
flow_type="login",
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(ExternalAuthError) as exc_info:
|
|
||||||
ExternalAuthService.complete_link_flow(
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
authorization_code="mock-auth-code",
|
|
||||||
state=state.state,
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert exc_info.value.error_type == "INVALID_FLOW_TYPE"
|
|
||||||
|
|
||||||
def test_complete_link_flow_provider_mismatch(self, app, db, test_organization):
|
|
||||||
"""Test completing link flow with provider mismatch."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
# Create state with different provider
|
|
||||||
state = OAuthState.create_state(
|
|
||||||
flow_type="link",
|
|
||||||
provider_type=AuthMethodType.GITHUB,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(ExternalAuthError) as exc_info:
|
|
||||||
ExternalAuthService.complete_link_flow(
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
authorization_code="mock-auth-code",
|
|
||||||
state=state.state,
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert exc_info.value.error_type == "PROVIDER_MISMATCH"
|
|
||||||
|
|
||||||
@patch('gatehouse_app.services.external_auth_service.ExternalAuthService._exchange_code')
|
|
||||||
@patch('gatehouse_app.services.external_auth_service.ExternalAuthService._get_user_info')
|
|
||||||
@patch('gatehouse_app.services.external_auth_service.AuditService')
|
|
||||||
def test_authenticate_with_provider_success(
|
|
||||||
self, mock_audit, mock_get_user_info, mock_exchange_code,
|
|
||||||
app, db, test_user, test_organization
|
|
||||||
):
|
|
||||||
"""Test authenticating with provider successfully."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
# Create authentication method for user
|
|
||||||
auth_method = AuthenticationMethod(
|
|
||||||
user_id=test_user.id,
|
|
||||||
method_type=AuthMethodType.GOOGLE,
|
|
||||||
provider_user_id="google-123",
|
|
||||||
provider_data={"email": test_user.email},
|
|
||||||
verified=True,
|
|
||||||
)
|
|
||||||
auth_method.save()
|
|
||||||
|
|
||||||
# Create OAuth state
|
|
||||||
state = OAuthState.create_state(
|
|
||||||
flow_type="login",
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock external provider responses
|
|
||||||
mock_exchange_code.return_value = {
|
|
||||||
"access_token": "mock-access-token",
|
|
||||||
"refresh_token": "mock-refresh-token",
|
|
||||||
"id_token": "mock-id-token",
|
|
||||||
"expires_in": 3600,
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_get_user_info.return_value = {
|
|
||||||
"provider_user_id": "google-123",
|
|
||||||
"email": test_user.email,
|
|
||||||
"email_verified": True,
|
|
||||||
"name": "Test User",
|
|
||||||
"picture": "https://example.com/avatar.jpg",
|
|
||||||
"raw_data": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
# Authenticate
|
|
||||||
user, session_data = ExternalAuthService.authenticate_with_provider(
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
authorization_code="mock-auth-code",
|
|
||||||
state=state.state,
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert user.id == test_user.id
|
|
||||||
assert session_data is not None
|
|
||||||
assert "token" in session_data
|
|
||||||
|
|
||||||
@patch('gatehouse_app.services.external_auth_service.ExternalAuthService._exchange_code')
|
|
||||||
@patch('gatehouse_app.services.external_auth_service.ExternalAuthService._get_user_info')
|
|
||||||
@patch('gatehouse_app.services.external_auth_service.AuditService')
|
|
||||||
def test_authenticate_with_provider_account_not_found(
|
|
||||||
self, mock_audit, mock_get_user_info, mock_exchange_code,
|
|
||||||
app, db, test_organization
|
|
||||||
):
|
|
||||||
"""Test authenticating with provider when account not found."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
# Create OAuth state
|
|
||||||
state = OAuthState.create_state(
|
|
||||||
flow_type="login",
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock external provider responses
|
|
||||||
mock_exchange_code.return_value = {
|
|
||||||
"access_token": "mock-access-token",
|
|
||||||
"refresh_token": "mock-refresh-token",
|
|
||||||
"id_token": "mock-id-token",
|
|
||||||
"expires_in": 3600,
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_get_user_info.return_value = {
|
|
||||||
"provider_user_id": "google-456",
|
|
||||||
"email": "newuser@gmail.com",
|
|
||||||
"email_verified": True,
|
|
||||||
"name": "New User",
|
|
||||||
"picture": "https://example.com/avatar.jpg",
|
|
||||||
"raw_data": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
with pytest.raises(ExternalAuthError) as exc_info:
|
|
||||||
ExternalAuthService.authenticate_with_provider(
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
authorization_code="mock-auth-code",
|
|
||||||
state=state.state,
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert exc_info.value.error_type == "ACCOUNT_NOT_FOUND"
|
|
||||||
|
|
||||||
@patch('gatehouse_app.services.external_auth_service.AuditService')
|
|
||||||
def test_unlink_provider_success(self, mock_audit, app, db, test_user):
|
|
||||||
"""Test unlinking provider successfully."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create password auth method first (so user has other methods)
|
|
||||||
password_method = AuthenticationMethod(
|
|
||||||
user_id=test_user.id,
|
|
||||||
method_type=AuthMethodType.PASSWORD,
|
|
||||||
provider_user_id=test_user.id,
|
|
||||||
)
|
|
||||||
password_method.save()
|
|
||||||
|
|
||||||
# Create Google auth method
|
|
||||||
google_method = AuthenticationMethod(
|
|
||||||
user_id=test_user.id,
|
|
||||||
method_type=AuthMethodType.GOOGLE,
|
|
||||||
provider_user_id="google-123",
|
|
||||||
provider_data={"email": test_user.email},
|
|
||||||
verified=True,
|
|
||||||
)
|
|
||||||
google_method.save()
|
|
||||||
|
|
||||||
# Unlink Google
|
|
||||||
result = ExternalAuthService.unlink_provider(
|
|
||||||
user_id=test_user.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
# Verify auth method is deleted
|
|
||||||
method = AuthenticationMethod.query.filter_by(
|
|
||||||
user_id=test_user.id,
|
|
||||||
method_type=AuthMethodType.GOOGLE,
|
|
||||||
).first()
|
|
||||||
assert method is None
|
|
||||||
|
|
||||||
# Verify audit log
|
|
||||||
mock_audit.log_external_auth_unlink.assert_called_once()
|
|
||||||
|
|
||||||
def test_unlink_provider_not_linked(self, app, db, test_user):
|
|
||||||
"""Test unlinking provider that is not linked."""
|
|
||||||
with app.app_context():
|
|
||||||
with pytest.raises(ExternalAuthError) as exc_info:
|
|
||||||
ExternalAuthService.unlink_provider(
|
|
||||||
user_id=test_user.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert exc_info.value.error_type == "PROVIDER_NOT_LINKED"
|
|
||||||
|
|
||||||
def test_unlink_provider_last_method(self, app, db, test_user):
|
|
||||||
"""Test unlinking last authentication method."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create only Google auth method
|
|
||||||
google_method = AuthenticationMethod(
|
|
||||||
user_id=test_user.id,
|
|
||||||
method_type=AuthMethodType.GOOGLE,
|
|
||||||
provider_user_id="google-123",
|
|
||||||
provider_data={"email": test_user.email},
|
|
||||||
verified=True,
|
|
||||||
)
|
|
||||||
google_method.save()
|
|
||||||
|
|
||||||
with pytest.raises(ExternalAuthError) as exc_info:
|
|
||||||
ExternalAuthService.unlink_provider(
|
|
||||||
user_id=test_user.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert exc_info.value.error_type == "CANNOT_UNLINK_LAST"
|
|
||||||
|
|
||||||
def test_get_linked_accounts(self, app, db, test_user):
|
|
||||||
"""Test getting linked accounts for user."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create Google auth method
|
|
||||||
google_method = AuthenticationMethod(
|
|
||||||
user_id=test_user.id,
|
|
||||||
method_type=AuthMethodType.GOOGLE,
|
|
||||||
provider_user_id="google-123",
|
|
||||||
provider_data={
|
|
||||||
"email": test_user.email,
|
|
||||||
"name": "Test User",
|
|
||||||
"picture": "https://example.com/avatar.jpg",
|
|
||||||
},
|
|
||||||
verified=True,
|
|
||||||
)
|
|
||||||
google_method.save()
|
|
||||||
|
|
||||||
# Create GitHub auth method
|
|
||||||
github_method = AuthenticationMethod(
|
|
||||||
user_id=test_user.id,
|
|
||||||
method_type=AuthMethodType.GITHUB,
|
|
||||||
provider_user_id="github-456",
|
|
||||||
provider_data={
|
|
||||||
"email": "user@github.com",
|
|
||||||
"name": "Test User",
|
|
||||||
},
|
|
||||||
verified=True,
|
|
||||||
)
|
|
||||||
github_method.save()
|
|
||||||
|
|
||||||
# Get linked accounts
|
|
||||||
accounts = ExternalAuthService.get_linked_accounts(test_user.id)
|
|
||||||
|
|
||||||
assert len(accounts) == 2
|
|
||||||
|
|
||||||
google_account = next(a for a in accounts if a["provider_type"] == "google")
|
|
||||||
assert google_account["provider_user_id"] == "google-123"
|
|
||||||
assert google_account["email"] == test_user.email
|
|
||||||
|
|
||||||
github_account = next(a for a in accounts if a["provider_type"] == "github")
|
|
||||||
assert github_account["provider_user_id"] == "github-456"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
class TestOAuthState:
|
|
||||||
"""Tests for OAuthState model."""
|
|
||||||
|
|
||||||
def test_create_state(self, app, db):
|
|
||||||
"""Test creating OAuth state."""
|
|
||||||
with app.app_context():
|
|
||||||
state = OAuthState.create_state(
|
|
||||||
flow_type="login",
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
user_id="user-123",
|
|
||||||
organization_id="org-456",
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert state.state is not None
|
|
||||||
assert len(state.state) == 43
|
|
||||||
assert state.flow_type == "login"
|
|
||||||
assert state.provider_type == AuthMethodType.GOOGLE.value
|
|
||||||
assert state.user_id == "user-123"
|
|
||||||
assert state.organization_id == "org-456"
|
|
||||||
assert state.redirect_uri == "http://localhost:3000/callback"
|
|
||||||
assert state.used is False
|
|
||||||
assert state.expires_at > datetime.now(timezone.utc)
|
|
||||||
|
|
||||||
def test_is_valid(self, app, db):
|
|
||||||
"""Test OAuth state validity check."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create valid state
|
|
||||||
state = OAuthState.create_state(
|
|
||||||
flow_type="login",
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert state.is_valid() is True
|
|
||||||
|
|
||||||
# Mark as used
|
|
||||||
state.mark_used()
|
|
||||||
|
|
||||||
assert state.is_valid() is False
|
|
||||||
|
|
||||||
def test_is_valid_expired(self, app, db):
|
|
||||||
"""Test OAuth state validity with expiration."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create expired state
|
|
||||||
state = OAuthState.create_state(
|
|
||||||
flow_type="login",
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
lifetime_seconds=-1, # Already expired
|
|
||||||
)
|
|
||||||
|
|
||||||
assert state.is_valid() is False
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
class TestExternalProviderConfig:
|
|
||||||
"""Tests for ExternalProviderConfig model."""
|
|
||||||
|
|
||||||
def test_is_redirect_uri_allowed(self, app, db, test_organization):
|
|
||||||
"""Test redirect URI validation."""
|
|
||||||
with app.app_context():
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=[
|
|
||||||
"http://localhost:3000/callback",
|
|
||||||
"https://myapp.com/callback",
|
|
||||||
],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
assert config.is_redirect_uri_allowed("http://localhost:3000/callback") is True
|
|
||||||
assert config.is_redirect_uri_allowed("https://myapp.com/callback") is True
|
|
||||||
assert config.is_redirect_uri_allowed("http://malicious.com/callback") is False
|
|
||||||
|
|
||||||
def test_to_dict(self, app, db, test_organization):
|
|
||||||
"""Test converting config to dictionary."""
|
|
||||||
with app.app_context():
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
result = config.to_dict()
|
|
||||||
|
|
||||||
assert result["organization_id"] == test_organization.id
|
|
||||||
assert result["provider_type"] == AuthMethodType.GOOGLE.value
|
|
||||||
assert result["client_id"] == "test-client-id"
|
|
||||||
assert "client_secret" not in result
|
|
||||||
assert result["is_active"] is True
|
|
||||||
|
|
||||||
def test_to_dict_include_secrets(self, app, db, test_organization):
|
|
||||||
"""Test converting config to dictionary with secrets."""
|
|
||||||
with app.app_context():
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
client_secret_encrypted="encrypted-secret",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
result = config.to_dict(include_secrets=True)
|
|
||||||
|
|
||||||
assert "client_secret" in result
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
class TestExternalAuthError:
|
|
||||||
"""Tests for ExternalAuthError exception."""
|
|
||||||
|
|
||||||
def test_error_creation(self):
|
|
||||||
"""Test creating ExternalAuthError."""
|
|
||||||
error = ExternalAuthError(
|
|
||||||
message="Test error message",
|
|
||||||
error_type="TEST_ERROR",
|
|
||||||
status_code=400,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert error.message == "Test error message"
|
|
||||||
assert error.error_type == "TEST_ERROR"
|
|
||||||
assert error.status_code == 400
|
|
||||||
|
|
||||||
def test_error_default_status_code(self):
|
|
||||||
"""Test ExternalAuthError with default status code."""
|
|
||||||
error = ExternalAuthError(
|
|
||||||
message="Test error message",
|
|
||||||
error_type="TEST_ERROR",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert error.status_code == 400
|
|
||||||
@@ -1,476 +0,0 @@
|
|||||||
"""Unit tests for MfaPolicyService."""
|
|
||||||
import pytest
|
|
||||||
from datetime import datetime, timezone, timedelta
|
|
||||||
from unittest.mock import patch, MagicMock
|
|
||||||
|
|
||||||
from gatehouse_app.models import (
|
|
||||||
User,
|
|
||||||
Organization,
|
|
||||||
OrganizationMember,
|
|
||||||
OrganizationSecurityPolicy,
|
|
||||||
UserSecurityPolicy,
|
|
||||||
MfaPolicyCompliance,
|
|
||||||
Session,
|
|
||||||
)
|
|
||||||
from gatehouse_app.services.mfa_policy_service import (
|
|
||||||
MfaPolicyService,
|
|
||||||
OrgPolicyDto,
|
|
||||||
EffectiveUserPolicyDto,
|
|
||||||
AggregateMfaStateDto,
|
|
||||||
LoginPolicyResult,
|
|
||||||
)
|
|
||||||
from gatehouse_app.utils.constants import (
|
|
||||||
UserStatus,
|
|
||||||
MfaPolicyMode,
|
|
||||||
MfaComplianceStatus,
|
|
||||||
MfaRequirementOverride,
|
|
||||||
SessionStatus,
|
|
||||||
OrganizationRole,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
class TestMfaPolicyService:
|
|
||||||
"""Tests for MfaPolicyService."""
|
|
||||||
|
|
||||||
def test_get_org_policy_not_found(self, db, test_organization):
|
|
||||||
"""Test getting organization policy when none exists."""
|
|
||||||
policy = MfaPolicyService.get_org_policy(test_organization.id)
|
|
||||||
assert policy is None
|
|
||||||
|
|
||||||
def test_get_org_policy_found(self, db, test_organization):
|
|
||||||
"""Test getting organization policy when it exists."""
|
|
||||||
# Create policy
|
|
||||||
org_policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
)
|
|
||||||
org_policy.save()
|
|
||||||
|
|
||||||
policy = MfaPolicyService.get_org_policy(test_organization.id)
|
|
||||||
|
|
||||||
assert policy is not None
|
|
||||||
assert policy.organization_id == test_organization.id
|
|
||||||
assert policy.mfa_policy_mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value
|
|
||||||
assert policy.mfa_grace_period_days == 14
|
|
||||||
assert policy.notify_days_before == 7
|
|
||||||
assert policy.policy_version == 1
|
|
||||||
|
|
||||||
def test_get_effective_user_policy_no_org_policy(self, db, test_user, test_organization):
|
|
||||||
"""Test effective user policy when no org policy exists."""
|
|
||||||
policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
|
|
||||||
|
|
||||||
assert policy is not None
|
|
||||||
assert policy.organization_id == test_organization.id
|
|
||||||
assert policy.effective_mode == MfaPolicyMode.DISABLED.value
|
|
||||||
assert policy.requires_totp is False
|
|
||||||
assert policy.requires_webauthn is False
|
|
||||||
assert policy.is_exempt is True
|
|
||||||
|
|
||||||
def test_get_effective_user_policy_with_org_policy(self, db, test_user, test_organization):
|
|
||||||
"""Test effective user policy with org policy and no override."""
|
|
||||||
# Create org policy
|
|
||||||
org_policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
)
|
|
||||||
org_policy.save()
|
|
||||||
|
|
||||||
policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
|
|
||||||
|
|
||||||
assert policy is not None
|
|
||||||
assert policy.effective_mode == MfaPolicyMode.REQUIRE_TOTP.value
|
|
||||||
assert policy.requires_totp is True
|
|
||||||
assert policy.requires_webauthn is False
|
|
||||||
assert policy.is_exempt is False
|
|
||||||
|
|
||||||
def test_get_effective_user_policy_with_override_inherit(self, db, test_user, test_organization):
|
|
||||||
"""Test effective user policy with INHERIT override."""
|
|
||||||
# Create org policy
|
|
||||||
org_policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_WEBAUTHN,
|
|
||||||
mfa_grace_period_days=7,
|
|
||||||
)
|
|
||||||
org_policy.save()
|
|
||||||
|
|
||||||
# Create user override
|
|
||||||
user_override = UserSecurityPolicy(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_override_mode=MfaRequirementOverride.INHERIT,
|
|
||||||
)
|
|
||||||
user_override.save()
|
|
||||||
|
|
||||||
policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
|
|
||||||
|
|
||||||
assert policy.effective_mode == MfaPolicyMode.REQUIRE_WEBAUTHN.value
|
|
||||||
assert policy.requires_webauthn is True
|
|
||||||
|
|
||||||
def test_get_effective_user_policy_with_override_exempt(self, db, test_user, test_organization):
|
|
||||||
"""Test effective user policy with EXEMPT override."""
|
|
||||||
# Create org policy
|
|
||||||
org_policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
)
|
|
||||||
org_policy.save()
|
|
||||||
|
|
||||||
# Create user override
|
|
||||||
user_override = UserSecurityPolicy(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_override_mode=MfaRequirementOverride.EXEMPT,
|
|
||||||
)
|
|
||||||
user_override.save()
|
|
||||||
|
|
||||||
policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
|
|
||||||
|
|
||||||
assert policy.effective_mode == MfaPolicyMode.DISABLED.value
|
|
||||||
assert policy.is_exempt is True
|
|
||||||
|
|
||||||
def test_get_effective_user_policy_with_override_required(self, db, test_user, test_organization):
|
|
||||||
"""Test effective user policy with REQUIRED override."""
|
|
||||||
# Create org policy
|
|
||||||
org_policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
)
|
|
||||||
org_policy.save()
|
|
||||||
|
|
||||||
# Create user override
|
|
||||||
user_override = UserSecurityPolicy(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_override_mode=MfaRequirementOverride.REQUIRED,
|
|
||||||
)
|
|
||||||
user_override.save()
|
|
||||||
|
|
||||||
policy = MfaPolicyService.get_effective_user_policy(test_user.id, test_organization.id)
|
|
||||||
|
|
||||||
assert policy.effective_mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value
|
|
||||||
assert policy.requires_totp is True
|
|
||||||
assert policy.requires_webauthn is True
|
|
||||||
assert policy.is_exempt is False
|
|
||||||
|
|
||||||
def test_evaluate_user_mfa_state_no_policy(self, db, test_user, test_organization):
|
|
||||||
"""Test evaluating user MFA state with no policy."""
|
|
||||||
# Create membership
|
|
||||||
membership = OrganizationMember(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
role=OrganizationRole.MEMBER,
|
|
||||||
)
|
|
||||||
membership.save()
|
|
||||||
|
|
||||||
state = MfaPolicyService.evaluate_user_mfa_state(test_user)
|
|
||||||
|
|
||||||
assert state is not None
|
|
||||||
assert state.overall_status == MfaComplianceStatus.COMPLIANT.value
|
|
||||||
assert len(state.missing_methods) == 0
|
|
||||||
assert len(state.orgs) == 1
|
|
||||||
|
|
||||||
def test_evaluate_user_mfa_state_with_policy(self, db, test_user, test_organization):
|
|
||||||
"""Test evaluating user MFA state with policy."""
|
|
||||||
# Create membership
|
|
||||||
membership = OrganizationMember(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
role=OrganizationRole.MEMBER,
|
|
||||||
)
|
|
||||||
membership.save()
|
|
||||||
|
|
||||||
# Create org policy
|
|
||||||
org_policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
)
|
|
||||||
org_policy.save()
|
|
||||||
|
|
||||||
state = MfaPolicyService.evaluate_user_mfa_state(test_user)
|
|
||||||
|
|
||||||
assert state is not None
|
|
||||||
assert state.overall_status == MfaComplianceStatus.IN_GRACE.value
|
|
||||||
assert "totp" in state.missing_methods
|
|
||||||
assert len(state.orgs) == 1
|
|
||||||
assert state.orgs[0].effective_mode == MfaPolicyMode.REQUIRE_TOTP.value
|
|
||||||
|
|
||||||
def test_after_primary_auth_success_no_required_policy(self, db, test_user, test_organization):
|
|
||||||
"""Test after_primary_auth_success with no required policy."""
|
|
||||||
# Create membership
|
|
||||||
membership = OrganizationMember(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
role=OrganizationRole.MEMBER,
|
|
||||||
)
|
|
||||||
membership.save()
|
|
||||||
|
|
||||||
result = MfaPolicyService.after_primary_auth_success(test_user)
|
|
||||||
|
|
||||||
assert result.can_create_full_session is True
|
|
||||||
assert result.create_compliance_only_session is False
|
|
||||||
assert result.compliance_summary.overall_status == MfaComplianceStatus.COMPLIANT.value
|
|
||||||
|
|
||||||
def test_after_primary_auth_success_in_grace(self, db, test_user, test_organization):
|
|
||||||
"""Test after_primary_auth_success when user is in grace period."""
|
|
||||||
# Create membership
|
|
||||||
membership = OrganizationMember(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
role=OrganizationRole.MEMBER,
|
|
||||||
)
|
|
||||||
membership.save()
|
|
||||||
|
|
||||||
# Create org policy
|
|
||||||
org_policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
)
|
|
||||||
org_policy.save()
|
|
||||||
|
|
||||||
result = MfaPolicyService.after_primary_auth_success(test_user)
|
|
||||||
|
|
||||||
assert result.can_create_full_session is True
|
|
||||||
assert result.create_compliance_only_session is False
|
|
||||||
assert result.compliance_summary.overall_status == MfaComplianceStatus.IN_GRACE.value
|
|
||||||
|
|
||||||
def test_after_primary_auth_success_past_due(self, db, test_user, test_organization):
|
|
||||||
"""Test after_primary_auth_success when user is past due."""
|
|
||||||
# Create membership
|
|
||||||
membership = OrganizationMember(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
role=OrganizationRole.MEMBER,
|
|
||||||
)
|
|
||||||
membership.save()
|
|
||||||
|
|
||||||
# Create org policy
|
|
||||||
org_policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
)
|
|
||||||
org_policy.save()
|
|
||||||
|
|
||||||
# Create compliance record past due
|
|
||||||
compliance = MfaPolicyCompliance(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
status=MfaComplianceStatus.PAST_DUE,
|
|
||||||
policy_version=1,
|
|
||||||
applied_at=datetime.now(timezone.utc) - timedelta(days=30),
|
|
||||||
deadline_at=datetime.now(timezone.utc) - timedelta(days=1),
|
|
||||||
)
|
|
||||||
compliance.save()
|
|
||||||
|
|
||||||
result = MfaPolicyService.after_primary_auth_success(test_user)
|
|
||||||
|
|
||||||
assert result.can_create_full_session is False
|
|
||||||
assert result.create_compliance_only_session is True
|
|
||||||
|
|
||||||
def test_create_org_policy_new(self, db, test_organization):
|
|
||||||
"""Test creating a new organization policy."""
|
|
||||||
policy = MfaPolicyService.create_org_policy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
updated_by_user_id=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert policy is not None
|
|
||||||
assert policy.organization_id == test_organization.id
|
|
||||||
assert policy.mfa_policy_mode == MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN
|
|
||||||
assert policy.policy_version == 1
|
|
||||||
|
|
||||||
def test_create_org_policy_update(self, db, test_organization):
|
|
||||||
"""Test updating an existing organization policy."""
|
|
||||||
# Create initial policy
|
|
||||||
initial_policy = OrganizationSecurityPolicy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.OPTIONAL,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
)
|
|
||||||
initial_policy.save()
|
|
||||||
|
|
||||||
# Update policy
|
|
||||||
updated_policy = MfaPolicyService.create_org_policy(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP,
|
|
||||||
mfa_grace_period_days=7,
|
|
||||||
updated_by_user_id=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert updated_policy.mfa_policy_mode == MfaPolicyMode.REQUIRE_TOTP
|
|
||||||
assert updated_policy.mfa_grace_period_days == 7
|
|
||||||
assert updated_policy.policy_version == 2
|
|
||||||
|
|
||||||
def test_set_user_override_new(self, db, test_user, test_organization):
|
|
||||||
"""Test setting a new user override."""
|
|
||||||
override = MfaPolicyService.set_user_override(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_override_mode=MfaRequirementOverride.REQUIRED,
|
|
||||||
force_totp=True,
|
|
||||||
force_webauthn=False,
|
|
||||||
updated_by_user_id=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert override is not None
|
|
||||||
assert override.user_id == test_user.id
|
|
||||||
assert override.organization_id == test_organization.id
|
|
||||||
assert override.mfa_override_mode == MfaRequirementOverride.REQUIRED
|
|
||||||
assert override.force_totp is True
|
|
||||||
|
|
||||||
def test_set_user_override_update(self, db, test_user, test_organization):
|
|
||||||
"""Test updating an existing user override."""
|
|
||||||
# Create initial override
|
|
||||||
initial_override = UserSecurityPolicy(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_override_mode=MfaRequirementOverride.INHERIT,
|
|
||||||
)
|
|
||||||
initial_override.save()
|
|
||||||
|
|
||||||
# Update override
|
|
||||||
updated_override = MfaPolicyService.set_user_override(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
mfa_override_mode=MfaRequirementOverride.EXEMPT,
|
|
||||||
updated_by_user_id=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert updated_override.mfa_override_mode == MfaRequirementOverride.EXEMPT
|
|
||||||
|
|
||||||
def test_get_user_compliance(self, db, test_user, test_organization):
|
|
||||||
"""Test getting user compliance record."""
|
|
||||||
# Create compliance record
|
|
||||||
compliance = MfaPolicyCompliance(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
status=MfaComplianceStatus.COMPLIANT,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
compliance.save()
|
|
||||||
|
|
||||||
result = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id)
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
assert result.status == MfaComplianceStatus.COMPLIANT
|
|
||||||
|
|
||||||
def test_get_user_compliance_not_found(self, db, test_user, test_organization):
|
|
||||||
"""Test getting user compliance record when none exists."""
|
|
||||||
result = MfaPolicyService.get_user_compliance(test_user.id, test_organization.id)
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
def test_get_org_compliance_list(self, db, test_user, test_organization):
|
|
||||||
"""Test getting organization compliance list."""
|
|
||||||
# Create compliance record
|
|
||||||
compliance = MfaPolicyCompliance(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
status=MfaComplianceStatus.IN_GRACE,
|
|
||||||
policy_version=1,
|
|
||||||
deadline_at=datetime.now(timezone.utc) + timedelta(days=14),
|
|
||||||
)
|
|
||||||
compliance.save()
|
|
||||||
|
|
||||||
results = MfaPolicyService.get_org_compliance_list(test_organization.id)
|
|
||||||
|
|
||||||
assert len(results) == 1
|
|
||||||
assert results[0]["user_id"] == test_user.id
|
|
||||||
assert results[0]["status"] == MfaComplianceStatus.IN_GRACE.value
|
|
||||||
|
|
||||||
def test_get_org_compliance_list_with_status_filter(self, db, test_user, test_organization):
|
|
||||||
"""Test getting organization compliance list with status filter."""
|
|
||||||
# Create compliance record
|
|
||||||
compliance = MfaPolicyCompliance(
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
status=MfaComplianceStatus.COMPLIANT,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
compliance.save()
|
|
||||||
|
|
||||||
# Filter by different status
|
|
||||||
results = MfaPolicyService.get_org_compliance_list(
|
|
||||||
test_organization.id, status=MfaComplianceStatus.IN_GRACE
|
|
||||||
)
|
|
||||||
assert len(results) == 0
|
|
||||||
|
|
||||||
# Filter by correct status
|
|
||||||
results = MfaPolicyService.get_org_compliance_list(
|
|
||||||
test_organization.id, status=MfaComplianceStatus.COMPLIANT
|
|
||||||
)
|
|
||||||
assert len(results) == 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
class TestMfaPolicyServiceDto:
|
|
||||||
"""Tests for MfaPolicyService DTOs."""
|
|
||||||
|
|
||||||
def test_org_policy_dto(self):
|
|
||||||
"""Test OrgPolicyDto creation."""
|
|
||||||
dto = OrgPolicyDto(
|
|
||||||
organization_id="org-123",
|
|
||||||
mfa_policy_mode=MfaPolicyMode.REQUIRE_TOTP.value,
|
|
||||||
mfa_grace_period_days=14,
|
|
||||||
notify_days_before=7,
|
|
||||||
policy_version=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert dto.organization_id == "org-123"
|
|
||||||
assert dto.mfa_policy_mode == "require_totp"
|
|
||||||
assert dto.mfa_grace_period_days == 14
|
|
||||||
|
|
||||||
def test_effective_user_policy_dto(self):
|
|
||||||
"""Test EffectiveUserPolicyDto creation."""
|
|
||||||
dto = EffectiveUserPolicyDto(
|
|
||||||
organization_id="org-123",
|
|
||||||
effective_mode=MfaPolicyMode.REQUIRE_TOTP_OR_WEBAUTHN.value,
|
|
||||||
requires_totp=True,
|
|
||||||
requires_webauthn=True,
|
|
||||||
grace_period_days=14,
|
|
||||||
is_exempt=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert dto.requires_totp is True
|
|
||||||
assert dto.requires_webauthn is True
|
|
||||||
assert dto.is_exempt is False
|
|
||||||
|
|
||||||
def test_aggregate_mfa_state_dto(self):
|
|
||||||
"""Test AggregateMfaStateDto creation."""
|
|
||||||
dto = AggregateMfaStateDto(
|
|
||||||
overall_status=MfaComplianceStatus.IN_GRACE.value,
|
|
||||||
missing_methods=["totp"],
|
|
||||||
deadline_at="2025-02-01T00:00:00Z",
|
|
||||||
orgs=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert dto.overall_status == "in_grace"
|
|
||||||
assert "totp" in dto.missing_methods
|
|
||||||
assert dto.deadline_at == "2025-02-01T00:00:00Z"
|
|
||||||
|
|
||||||
def test_login_policy_result(self):
|
|
||||||
"""Test LoginPolicyResult creation."""
|
|
||||||
summary = AggregateMfaStateDto(
|
|
||||||
overall_status=MfaComplianceStatus.IN_GRACE.value,
|
|
||||||
missing_methods=["totp"],
|
|
||||||
orgs=[],
|
|
||||||
)
|
|
||||||
result = LoginPolicyResult(
|
|
||||||
can_create_full_session=True,
|
|
||||||
create_compliance_only_session=False,
|
|
||||||
compliance_summary=summary,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.can_create_full_session is True
|
|
||||||
assert result.create_compliance_only_session is False
|
|
||||||
assert result.compliance_summary.overall_status == "in_grace"
|
|
||||||
@@ -1,533 +0,0 @@
|
|||||||
"""Unit tests for OAuthFlowService."""
|
|
||||||
import pytest
|
|
||||||
from unittest.mock import Mock, patch, MagicMock
|
|
||||||
from datetime import datetime, timedelta, timezone
|
|
||||||
|
|
||||||
from gatehouse_app.services.oauth_flow_service import (
|
|
||||||
OAuthFlowService,
|
|
||||||
OAuthFlowError,
|
|
||||||
)
|
|
||||||
from gatehouse_app.services.external_auth_service import OAuthState, ExternalProviderConfig
|
|
||||||
from gatehouse_app.utils.constants import AuthMethodType
|
|
||||||
from gatehouse_app.models import User, AuthenticationMethod
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
class TestOAuthFlowService:
|
|
||||||
"""Tests for OAuthFlowService."""
|
|
||||||
|
|
||||||
@patch('gatehouse_app.services.oauth_flow_service.AuditService')
|
|
||||||
def test_initiate_login_flow_success(self, mock_audit, app, db, test_organization):
|
|
||||||
"""Test initiating login flow successfully."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
with app.test_request_context():
|
|
||||||
auth_url, state = OAuthFlowService.initiate_login_flow(
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert auth_url is not None
|
|
||||||
assert state is not None
|
|
||||||
assert len(state) == 43
|
|
||||||
|
|
||||||
# Verify state was created with correct flow type
|
|
||||||
state_record = OAuthState.query.filter_by(state=state).first()
|
|
||||||
assert state_record is not None
|
|
||||||
assert state_record.flow_type == "login"
|
|
||||||
assert state_record.organization_id == test_organization.id
|
|
||||||
|
|
||||||
def test_initiate_login_flow_invalid_redirect_uri(self, app, db, test_organization):
|
|
||||||
"""Test initiating login flow with invalid redirect URI."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
with app.test_request_context():
|
|
||||||
with pytest.raises(OAuthFlowError) as exc_info:
|
|
||||||
OAuthFlowService.initiate_login_flow(
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
redirect_uri="http://malicious.com/callback",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert exc_info.value.error_type == "INVALID_REDIRECT_URI"
|
|
||||||
|
|
||||||
@patch('gatehouse_app.services.oauth_flow_service.AuditService')
|
|
||||||
def test_initiate_register_flow_success(self, mock_audit, app, db, test_organization):
|
|
||||||
"""Test initiating register flow successfully."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
with app.test_request_context():
|
|
||||||
auth_url, state = OAuthFlowService.initiate_register_flow(
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert auth_url is not None
|
|
||||||
assert state is not None
|
|
||||||
|
|
||||||
# Verify state was created with correct flow type
|
|
||||||
state_record = OAuthState.query.filter_by(state=state).first()
|
|
||||||
assert state_record is not None
|
|
||||||
assert state_record.flow_type == "register"
|
|
||||||
|
|
||||||
@patch('gatehouse_app.services.oauth_flow_service.ExternalAuthService.authenticate_with_provider')
|
|
||||||
@patch('gatehouse_app.services.oauth_flow_service.AuditService')
|
|
||||||
def test_handle_callback_login_flow(
|
|
||||||
self, mock_audit, mock_authenticate,
|
|
||||||
app, db, test_user, test_organization
|
|
||||||
):
|
|
||||||
"""Test handling callback for login flow."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
# Create authentication method
|
|
||||||
auth_method = AuthenticationMethod(
|
|
||||||
user_id=test_user.id,
|
|
||||||
method_type=AuthMethodType.GOOGLE,
|
|
||||||
provider_user_id="google-123",
|
|
||||||
provider_data={"email": test_user.email},
|
|
||||||
verified=True,
|
|
||||||
)
|
|
||||||
auth_method.save()
|
|
||||||
|
|
||||||
# Create login state
|
|
||||||
state = OAuthState.create_state(
|
|
||||||
flow_type="login",
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock authentication
|
|
||||||
mock_authenticate.return_value = (test_user, {"token": "session-token", "expires_in": 86400})
|
|
||||||
|
|
||||||
with app.test_request_context():
|
|
||||||
result = OAuthFlowService.handle_callback(
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
authorization_code="mock-auth-code",
|
|
||||||
state=state.state,
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result["success"] is True
|
|
||||||
assert result["flow_type"] == "login"
|
|
||||||
assert result["user"]["id"] == test_user.id
|
|
||||||
assert result["session"]["token"] == "session-token"
|
|
||||||
|
|
||||||
@patch('gatehouse_app.services.oauth_flow_service.ExternalAuthService.complete_link_flow')
|
|
||||||
@patch('gatehouse_app.services.oauth_flow_service.AuditService')
|
|
||||||
def test_handle_callback_link_flow(
|
|
||||||
self, mock_audit, mock_complete_link,
|
|
||||||
app, db, test_user, test_organization
|
|
||||||
):
|
|
||||||
"""Test handling callback for link flow."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
# Create link state
|
|
||||||
state = OAuthState.create_state(
|
|
||||||
flow_type="link",
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
user_id=test_user.id,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock complete link
|
|
||||||
mock_auth_method = Mock()
|
|
||||||
mock_auth_method.id = "auth-method-123"
|
|
||||||
mock_auth_method.provider_user_id = "google-123"
|
|
||||||
mock_auth_method.verified = True
|
|
||||||
mock_complete_link.return_value = mock_auth_method
|
|
||||||
|
|
||||||
with app.test_request_context():
|
|
||||||
result = OAuthFlowService.handle_callback(
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
authorization_code="mock-auth-code",
|
|
||||||
state=state.state,
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result["success"] is True
|
|
||||||
assert result["flow_type"] == "link"
|
|
||||||
assert result["linked_account"]["id"] == "auth-method-123"
|
|
||||||
|
|
||||||
@patch('gatehouse_app.services.oauth_flow_service.ExternalAuthService._exchange_code')
|
|
||||||
@patch('gatehouse_app.services.oauth_flow_service.ExternalAuthService._get_user_info')
|
|
||||||
@patch('gatehouse_app.services.oauth_flow_service.ExternalAuthService._encrypt_provider_data')
|
|
||||||
@patch('gatehouse_app.services.oauth_flow_service.AuditService')
|
|
||||||
@patch('gatehouse_app.services.auth_service.AuthService.create_session')
|
|
||||||
def test_handle_callback_register_flow(
|
|
||||||
self, mock_create_session, mock_audit, mock_encrypt,
|
|
||||||
mock_get_user_info, mock_exchange_code,
|
|
||||||
app, db, test_organization
|
|
||||||
):
|
|
||||||
"""Test handling callback for register flow."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
# Create register state
|
|
||||||
state = OAuthState.create_state(
|
|
||||||
flow_type="register",
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock external provider responses
|
|
||||||
mock_exchange_code.return_value = {
|
|
||||||
"access_token": "mock-access-token",
|
|
||||||
"refresh_token": "mock-refresh-token",
|
|
||||||
"id_token": "mock-id-token",
|
|
||||||
"expires_in": 3600,
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_get_user_info.return_value = {
|
|
||||||
"provider_user_id": "google-new-123",
|
|
||||||
"email": "newuser@gmail.com",
|
|
||||||
"email_verified": True,
|
|
||||||
"name": "New User",
|
|
||||||
"picture": "https://example.com/avatar.jpg",
|
|
||||||
"raw_data": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_encrypt.return_value = {
|
|
||||||
"access_token": "mock-access-token",
|
|
||||||
"email": "newuser@gmail.com",
|
|
||||||
"name": "New User",
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_session = Mock()
|
|
||||||
mock_session.to_dict.return_value = {"token": "session-token", "expires_in": 86400}
|
|
||||||
mock_create_session.return_value = mock_session
|
|
||||||
|
|
||||||
with app.test_request_context():
|
|
||||||
result = OAuthFlowService.handle_callback(
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
authorization_code="mock-auth-code",
|
|
||||||
state=state.state,
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result["success"] is True
|
|
||||||
assert result["flow_type"] == "register"
|
|
||||||
assert result["user"]["email"] == "newuser@gmail.com"
|
|
||||||
assert result["session"]["token"] == "session-token"
|
|
||||||
|
|
||||||
@patch('gatehouse_app.services.oauth_flow_service.ExternalAuthService._exchange_code')
|
|
||||||
@patch('gatehouse_app.services.oauth_flow_service.ExternalAuthService._get_user_info')
|
|
||||||
@patch('gatehouse_app.services.oauth_flow_service.AuditService')
|
|
||||||
def test_handle_callback_register_flow_email_exists(
|
|
||||||
self, mock_audit, mock_get_user_info, mock_exchange_code,
|
|
||||||
app, db, test_user, test_organization
|
|
||||||
):
|
|
||||||
"""Test handling callback for register flow when email already exists."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
# Create register state
|
|
||||||
state = OAuthState.create_state(
|
|
||||||
flow_type="register",
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock external provider responses
|
|
||||||
mock_exchange_code.return_value = {
|
|
||||||
"access_token": "mock-access-token",
|
|
||||||
"refresh_token": "mock-refresh-token",
|
|
||||||
"id_token": "mock-id-token",
|
|
||||||
"expires_in": 3600,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Return email that matches existing user
|
|
||||||
mock_get_user_info.return_value = {
|
|
||||||
"provider_user_id": "google-new-123",
|
|
||||||
"email": test_user.email, # Existing email
|
|
||||||
"email_verified": True,
|
|
||||||
"name": "Test User",
|
|
||||||
"picture": "https://example.com/avatar.jpg",
|
|
||||||
"raw_data": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
with app.test_request_context():
|
|
||||||
with pytest.raises(OAuthFlowError) as exc_info:
|
|
||||||
OAuthFlowService.handle_callback(
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
authorization_code="mock-auth-code",
|
|
||||||
state=state.state,
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert exc_info.value.error_type == "EMAIL_EXISTS"
|
|
||||||
|
|
||||||
def test_handle_callback_invalid_state(self, app, db):
|
|
||||||
"""Test handling callback with invalid state."""
|
|
||||||
with app.app_context():
|
|
||||||
with app.test_request_context():
|
|
||||||
with pytest.raises(OAuthFlowError) as exc_info:
|
|
||||||
OAuthFlowService.handle_callback(
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
authorization_code="mock-auth-code",
|
|
||||||
state="invalid-state",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert exc_info.value.error_type == "INVALID_STATE"
|
|
||||||
|
|
||||||
def test_handle_callback_provider_error(self, app, db):
|
|
||||||
"""Test handling callback with provider error."""
|
|
||||||
with app.app_context():
|
|
||||||
with app.test_request_context():
|
|
||||||
with pytest.raises(OAuthFlowError) as exc_info:
|
|
||||||
OAuthFlowService.handle_callback(
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
authorization_code=None,
|
|
||||||
state=None,
|
|
||||||
error="access_denied",
|
|
||||||
error_description="User denied access",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert exc_info.value.error_type == "ACCESS_DENIED"
|
|
||||||
|
|
||||||
def test_handle_callback_unknown_flow_type(self, app, db, test_organization):
|
|
||||||
"""Test handling callback with unknown flow type."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
# Create state with unknown flow type
|
|
||||||
state = OAuthState.create_state(
|
|
||||||
flow_type="unknown",
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
redirect_uri="http://localhost:3000/callback",
|
|
||||||
)
|
|
||||||
|
|
||||||
with app.test_request_context():
|
|
||||||
with pytest.raises(OAuthFlowError) as exc_info:
|
|
||||||
OAuthFlowService.handle_callback(
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
authorization_code="mock-auth-code",
|
|
||||||
state=state.state,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert exc_info.value.error_type == "INVALID_FLOW_TYPE"
|
|
||||||
|
|
||||||
def test_validate_state_valid(self, app, db, test_organization):
|
|
||||||
"""Test validating a valid state."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
# Create state
|
|
||||||
state = OAuthState.create_state(
|
|
||||||
flow_type="login",
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
result = OAuthFlowService.validate_state(state.state)
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
assert result.id == state.id
|
|
||||||
|
|
||||||
def test_validate_state_invalid(self, app, db):
|
|
||||||
"""Test validating an invalid state."""
|
|
||||||
with app.app_context():
|
|
||||||
result = OAuthFlowService.validate_state("nonexistent-state")
|
|
||||||
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
def test_validate_state_expired(self, app, db, test_organization):
|
|
||||||
"""Test validating an expired state."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
# Create expired state
|
|
||||||
state = OAuthState.create_state(
|
|
||||||
flow_type="login",
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
lifetime_seconds=-1,
|
|
||||||
)
|
|
||||||
|
|
||||||
result = OAuthFlowService.validate_state(state.state)
|
|
||||||
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
def test_validate_state_used(self, app, db, test_organization):
|
|
||||||
"""Test validating a used state."""
|
|
||||||
with app.app_context():
|
|
||||||
# Create provider config
|
|
||||||
config = ExternalProviderConfig(
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
provider_type=AuthMethodType.GOOGLE.value,
|
|
||||||
client_id="test-client-id",
|
|
||||||
auth_url="https://accounts.google.com/o/oauth2/v2/auth",
|
|
||||||
token_url="https://oauth2.googleapis.com/token",
|
|
||||||
userinfo_url="https://www.googleapis.com/oauth2/v3/userinfo",
|
|
||||||
scopes=["openid", "profile", "email"],
|
|
||||||
redirect_uris=["http://localhost:3000/callback"],
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
# Create and mark state as used
|
|
||||||
state = OAuthState.create_state(
|
|
||||||
flow_type="login",
|
|
||||||
provider_type=AuthMethodType.GOOGLE,
|
|
||||||
organization_id=test_organization.id,
|
|
||||||
)
|
|
||||||
state.mark_used()
|
|
||||||
|
|
||||||
result = OAuthFlowService.validate_state(state.state)
|
|
||||||
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
class TestOAuthFlowError:
|
|
||||||
"""Tests for OAuthFlowError exception."""
|
|
||||||
|
|
||||||
def test_error_creation(self):
|
|
||||||
"""Test creating OAuthFlowError."""
|
|
||||||
error = OAuthFlowError(
|
|
||||||
message="Test error message",
|
|
||||||
error_type="TEST_ERROR",
|
|
||||||
status_code=400,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert error.message == "Test error message"
|
|
||||||
assert error.error_type == "TEST_ERROR"
|
|
||||||
assert error.status_code == 400
|
|
||||||
|
|
||||||
def test_error_default_status_code(self):
|
|
||||||
"""Test OAuthFlowError with default status code."""
|
|
||||||
error = OAuthFlowError(
|
|
||||||
message="Test error message",
|
|
||||||
error_type="TEST_ERROR",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert error.status_code == 400
|
|
||||||
@@ -1,285 +0,0 @@
|
|||||||
"""Unit tests for TOTPService."""
|
|
||||||
import base64
|
|
||||||
import pytest
|
|
||||||
from gatehouse_app.services.totp_service import TOTPService
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
class TestTOTPService:
|
|
||||||
"""Tests for TOTPService."""
|
|
||||||
|
|
||||||
# Test generate_secret()
|
|
||||||
def test_generate_secret_returns_string(self):
|
|
||||||
"""Test that generate_secret returns a string."""
|
|
||||||
secret = TOTPService.generate_secret()
|
|
||||||
assert isinstance(secret, str)
|
|
||||||
|
|
||||||
def test_generate_secret_length(self):
|
|
||||||
"""Test that generate_secret returns a 32-character string."""
|
|
||||||
secret = TOTPService.generate_secret()
|
|
||||||
assert len(secret) == 32
|
|
||||||
|
|
||||||
def test_generate_secret_base32_encoded(self):
|
|
||||||
"""Test that generate_secret returns a base32 encoded string."""
|
|
||||||
secret = TOTPService.generate_secret()
|
|
||||||
# Base32 characters are A-Z and 2-7
|
|
||||||
valid_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567")
|
|
||||||
assert all(c in valid_chars for c in secret)
|
|
||||||
|
|
||||||
def test_generate_secret_unique(self):
|
|
||||||
"""Test that generate_secret produces unique secrets."""
|
|
||||||
secret1 = TOTPService.generate_secret()
|
|
||||||
secret2 = TOTPService.generate_secret()
|
|
||||||
assert secret1 != secret2
|
|
||||||
|
|
||||||
# Test generate_provisioning_uri()
|
|
||||||
def test_generate_provisioning_uri_format(self):
|
|
||||||
"""Test that provisioning URI is generated correctly."""
|
|
||||||
email = "user@example.com"
|
|
||||||
secret = "JBSWY3DPEHPK3PXP"
|
|
||||||
issuer = "Gatehouse"
|
|
||||||
|
|
||||||
uri = TOTPService.generate_provisioning_uri(email, secret, issuer)
|
|
||||||
|
|
||||||
assert isinstance(uri, str)
|
|
||||||
assert uri.startswith("otpauth://totp/")
|
|
||||||
|
|
||||||
def test_generate_provisioning_uri_contains_email(self):
|
|
||||||
"""Test that provisioning URI contains the user email."""
|
|
||||||
email = "user@example.com"
|
|
||||||
secret = "JBSWY3DPEHPK3PXP"
|
|
||||||
issuer = "Gatehouse"
|
|
||||||
|
|
||||||
uri = TOTPService.generate_provisioning_uri(email, secret, issuer)
|
|
||||||
|
|
||||||
assert email in uri
|
|
||||||
|
|
||||||
def test_generate_provisioning_uri_contains_secret(self):
|
|
||||||
"""Test that provisioning URI contains the secret."""
|
|
||||||
email = "user@example.com"
|
|
||||||
secret = "JBSWY3DPEHPK3PXP"
|
|
||||||
issuer = "Gatehouse"
|
|
||||||
|
|
||||||
uri = TOTPService.generate_provisioning_uri(email, secret, issuer)
|
|
||||||
|
|
||||||
assert secret in uri
|
|
||||||
|
|
||||||
def test_generate_provisioning_uri_contains_issuer(self):
|
|
||||||
"""Test that provisioning URI contains the issuer."""
|
|
||||||
email = "user@example.com"
|
|
||||||
secret = "JBSWY3DPEHPK3PXP"
|
|
||||||
issuer = "Gatehouse"
|
|
||||||
|
|
||||||
uri = TOTPService.generate_provisioning_uri(email, secret, issuer)
|
|
||||||
|
|
||||||
assert issuer in uri
|
|
||||||
|
|
||||||
def test_generate_provisioning_uri_custom_issuer(self):
|
|
||||||
"""Test that provisioning URI uses custom issuer."""
|
|
||||||
email = "user@example.com"
|
|
||||||
secret = "JBSWY3DPEHPK3PXP"
|
|
||||||
custom_issuer = "MyApp"
|
|
||||||
|
|
||||||
uri = TOTPService.generate_provisioning_uri(email, secret, custom_issuer)
|
|
||||||
|
|
||||||
assert custom_issuer in uri
|
|
||||||
|
|
||||||
# Test verify_code()
|
|
||||||
def test_verify_code_valid(self):
|
|
||||||
"""Test that a valid TOTP code is accepted."""
|
|
||||||
secret = TOTPService.generate_secret()
|
|
||||||
# Generate a valid code using pyotp
|
|
||||||
import pyotp
|
|
||||||
totp = pyotp.TOTP(secret)
|
|
||||||
valid_code = totp.now()
|
|
||||||
|
|
||||||
result = TOTPService.verify_code(secret, valid_code)
|
|
||||||
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
def test_verify_code_invalid(self):
|
|
||||||
"""Test that an invalid TOTP code is rejected."""
|
|
||||||
secret = TOTPService.generate_secret()
|
|
||||||
invalid_code = "000000"
|
|
||||||
|
|
||||||
result = TOTPService.verify_code(secret, invalid_code)
|
|
||||||
|
|
||||||
assert result is False
|
|
||||||
|
|
||||||
def test_verify_code_window_parameter(self):
|
|
||||||
"""Test that the time window parameter works correctly."""
|
|
||||||
secret = TOTPService.generate_secret()
|
|
||||||
import pyotp
|
|
||||||
totp = pyotp.TOTP(secret)
|
|
||||||
|
|
||||||
# Get current code
|
|
||||||
current_code = totp.now()
|
|
||||||
|
|
||||||
# Verify with window=1 (default) - should accept current code
|
|
||||||
result = TOTPService.verify_code(secret, current_code, window=1)
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
# Verify with window=0 - should only accept exact time match
|
|
||||||
result = TOTPService.verify_code(secret, current_code, window=0)
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
def test_verify_code_wrong_length(self):
|
|
||||||
"""Test that codes with wrong length are rejected."""
|
|
||||||
secret = TOTPService.generate_secret()
|
|
||||||
wrong_length_code = "12345" # 5 digits instead of 6
|
|
||||||
|
|
||||||
result = TOTPService.verify_code(secret, wrong_length_code)
|
|
||||||
|
|
||||||
assert result is False
|
|
||||||
|
|
||||||
# Test generate_backup_codes()
|
|
||||||
def test_generate_backup_codes_default_count(self):
|
|
||||||
"""Test that generate_backup_codes generates 10 codes by default."""
|
|
||||||
plain_codes, hashed_codes = TOTPService.generate_backup_codes()
|
|
||||||
|
|
||||||
assert len(plain_codes) == 10
|
|
||||||
assert len(hashed_codes) == 10
|
|
||||||
|
|
||||||
def test_generate_backup_codes_custom_count(self):
|
|
||||||
"""Test that generate_backup_codes generates the specified number of codes."""
|
|
||||||
count = 5
|
|
||||||
plain_codes, hashed_codes = TOTPService.generate_backup_codes(count)
|
|
||||||
|
|
||||||
assert len(plain_codes) == count
|
|
||||||
assert len(hashed_codes) == count
|
|
||||||
|
|
||||||
def test_generate_backup_codes_plain_are_strings(self):
|
|
||||||
"""Test that plain backup codes are strings."""
|
|
||||||
plain_codes, hashed_codes = TOTPService.generate_backup_codes()
|
|
||||||
|
|
||||||
assert all(isinstance(code, str) for code in plain_codes)
|
|
||||||
|
|
||||||
def test_generate_backup_codes_plain_length(self):
|
|
||||||
"""Test that plain backup codes are 16 characters long."""
|
|
||||||
plain_codes, hashed_codes = TOTPService.generate_backup_codes()
|
|
||||||
|
|
||||||
assert all(len(code) == 16 for code in plain_codes)
|
|
||||||
|
|
||||||
def test_generate_backup_codes_hashed_different_from_plain(self):
|
|
||||||
"""Test that hashed codes are different from plain codes."""
|
|
||||||
plain_codes, hashed_codes = TOTPService.generate_backup_codes()
|
|
||||||
|
|
||||||
for plain, hashed in zip(plain_codes, hashed_codes):
|
|
||||||
assert plain != hashed
|
|
||||||
|
|
||||||
def test_generate_backup_codes_are_bcrypt_hashes(self):
|
|
||||||
"""Test that hashed codes are bcrypt hashes."""
|
|
||||||
plain_codes, hashed_codes = TOTPService.generate_backup_codes()
|
|
||||||
|
|
||||||
# Bcrypt hashes start with $2a$, $2b$, or $2y$
|
|
||||||
for hashed in hashed_codes:
|
|
||||||
assert hashed.startswith("$2")
|
|
||||||
|
|
||||||
def test_generate_backup_codes_unique(self):
|
|
||||||
"""Test that generated backup codes are unique."""
|
|
||||||
plain_codes, hashed_codes = TOTPService.generate_backup_codes()
|
|
||||||
|
|
||||||
assert len(set(plain_codes)) == len(plain_codes)
|
|
||||||
assert len(set(hashed_codes)) == len(hashed_codes)
|
|
||||||
|
|
||||||
# Test verify_backup_code()
|
|
||||||
def test_verify_backup_code_valid(self):
|
|
||||||
"""Test that a valid backup code is accepted and removed."""
|
|
||||||
plain_codes, hashed_codes = TOTPService.generate_backup_codes(count=3)
|
|
||||||
code_to_verify = plain_codes[0]
|
|
||||||
|
|
||||||
is_valid, remaining_codes = TOTPService.verify_backup_code(hashed_codes, code_to_verify)
|
|
||||||
|
|
||||||
assert is_valid is True
|
|
||||||
assert len(remaining_codes) == 2
|
|
||||||
|
|
||||||
def test_verify_backup_code_invalid(self):
|
|
||||||
"""Test that an invalid backup code is rejected."""
|
|
||||||
plain_codes, hashed_codes = TOTPService.generate_backup_codes(count=3)
|
|
||||||
invalid_code = "INVALIDCODE1234"
|
|
||||||
|
|
||||||
is_valid, remaining_codes = TOTPService.verify_backup_code(hashed_codes, invalid_code)
|
|
||||||
|
|
||||||
assert is_valid is False
|
|
||||||
assert len(remaining_codes) == 3
|
|
||||||
|
|
||||||
def test_verify_backup_code_remaining_updated(self):
|
|
||||||
"""Test that the remaining codes list is updated correctly."""
|
|
||||||
plain_codes, hashed_codes = TOTPService.generate_backup_codes(count=5)
|
|
||||||
code_to_verify = plain_codes[2]
|
|
||||||
|
|
||||||
is_valid, remaining_codes = TOTPService.verify_backup_code(hashed_codes, code_to_verify)
|
|
||||||
|
|
||||||
assert is_valid is True
|
|
||||||
# The verified code should be removed
|
|
||||||
assert len(remaining_codes) == 4
|
|
||||||
# The remaining codes should not include the verified code's hash
|
|
||||||
assert hashed_codes[2] not in remaining_codes
|
|
||||||
|
|
||||||
def test_verify_backup_code_case_sensitive(self):
|
|
||||||
"""Test that backup code verification is case sensitive."""
|
|
||||||
plain_codes, hashed_codes = TOTPService.generate_backup_codes(count=1)
|
|
||||||
code_to_verify = plain_codes[0].lower() # Convert to lowercase
|
|
||||||
|
|
||||||
is_valid, remaining_codes = TOTPService.verify_backup_code(hashed_codes, code_to_verify)
|
|
||||||
|
|
||||||
assert is_valid is False
|
|
||||||
assert len(remaining_codes) == 1
|
|
||||||
|
|
||||||
def test_verify_backup_code_single_use(self):
|
|
||||||
"""Test that a backup code can only be used once."""
|
|
||||||
plain_codes, hashed_codes = TOTPService.generate_backup_codes(count=1)
|
|
||||||
code_to_verify = plain_codes[0]
|
|
||||||
|
|
||||||
# First use - should succeed
|
|
||||||
is_valid1, remaining1 = TOTPService.verify_backup_code(hashed_codes, code_to_verify)
|
|
||||||
assert is_valid1 is True
|
|
||||||
assert len(remaining1) == 0
|
|
||||||
|
|
||||||
# Second use - should fail (code already consumed)
|
|
||||||
is_valid2, remaining2 = TOTPService.verify_backup_code(remaining1, code_to_verify)
|
|
||||||
assert is_valid2 is False
|
|
||||||
assert len(remaining2) == 0
|
|
||||||
|
|
||||||
# Test generate_qr_code_data_uri()
|
|
||||||
def test_generate_qr_code_data_uri_format(self):
|
|
||||||
"""Test that a data URI is generated."""
|
|
||||||
provisioning_uri = "otpauth://totp/Gatehouse:user@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse"
|
|
||||||
|
|
||||||
data_uri = TOTPService.generate_qr_code_data_uri(provisioning_uri)
|
|
||||||
|
|
||||||
assert isinstance(data_uri, str)
|
|
||||||
|
|
||||||
def test_generate_qr_code_data_uri_starts_with_prefix(self):
|
|
||||||
"""Test that the data URI starts with the correct prefix."""
|
|
||||||
provisioning_uri = "otpauth://totp/Gatehouse:user@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse"
|
|
||||||
|
|
||||||
data_uri = TOTPService.generate_qr_code_data_uri(provisioning_uri)
|
|
||||||
|
|
||||||
assert data_uri.startswith("data:image/png;base64,")
|
|
||||||
|
|
||||||
def test_generate_qr_code_data_uri_contains_base64(self):
|
|
||||||
"""Test that the data URI contains base64 encoded data."""
|
|
||||||
provisioning_uri = "otpauth://totp/Gatehouse:user@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse"
|
|
||||||
|
|
||||||
data_uri = TOTPService.generate_qr_code_data_uri(provisioning_uri)
|
|
||||||
|
|
||||||
# Extract the base64 part (after the prefix)
|
|
||||||
base64_part = data_uri.split("data:image/png;base64,")[1]
|
|
||||||
|
|
||||||
# Verify it's valid base64
|
|
||||||
try:
|
|
||||||
base64.b64decode(base64_part)
|
|
||||||
assert True
|
|
||||||
except Exception:
|
|
||||||
assert False, "Data URI does not contain valid base64 data"
|
|
||||||
|
|
||||||
def test_generate_qr_code_data_uri_different_uris(self):
|
|
||||||
"""Test that different provisioning URIs generate different QR codes."""
|
|
||||||
uri1 = "otpauth://totp/Gatehouse:user1@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse"
|
|
||||||
uri2 = "otpauth://totp/Gatehouse:user2@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse"
|
|
||||||
|
|
||||||
data_uri1 = TOTPService.generate_qr_code_data_uri(uri1)
|
|
||||||
data_uri2 = TOTPService.generate_qr_code_data_uri(uri2)
|
|
||||||
|
|
||||||
assert data_uri1 != data_uri2
|
|
||||||
Reference in New Issue
Block a user