193 lines
6.4 KiB
Python
193 lines
6.4 KiB
Python
"""Billing service for superadmin operations."""
|
|
import logging
|
|
from datetime import datetime, timedelta, timezone
|
|
from gatehouse_app.models.organization.organization import Organization
|
|
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
|
from gatehouse_app.models.billing.plan import Plan
|
|
from gatehouse_app.models.billing.subscription import Subscription, SubscriptionStatus, BillingCycle
|
|
from gatehouse_app.extensions import db
|
|
|
|
# Module-level logger; records carry this module's dotted name (``__name__``).
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class BillingService:
    """Service for superadmin billing operations.

    Every method is a ``@staticmethod``. Methods that modify a subscription
    commit ``db.session`` before returning; lookup and validation failures
    are reported as ``ValueError``.
    """

    # Length of one billing period, in days, for each accepted ``billing_cycle``
    # string. Months and years are approximated as 30 and 365 days.
    _PERIOD_DAYS = {"monthly": 30, "yearly": 365}

    @staticmethod
    def _get_subscription_or_raise(organization_id: str) -> Subscription:
        """Return the organization's subscription.

        Raises:
            ValueError: If the organization has no subscription.
        """
        subscription = Subscription.query.filter_by(organization_id=organization_id).first()
        if not subscription:
            raise ValueError("No subscription found for organization")
        return subscription

    @staticmethod
    def _as_aware_utc(value: datetime) -> datetime:
        """Return *value* as a timezone-aware datetime.

        Aware values are returned unchanged. Naive values are tagged as UTC
        (NOTE(review): assumes naive DB timestamps are stored in UTC, matching
        the aware-UTC values this service writes — confirm against the model).
        """
        return value.replace(tzinfo=timezone.utc) if value.tzinfo is None else value

    @staticmethod
    def get_plan(plan_id: str) -> Plan:
        """Get a plan by ID.

        Args:
            plan_id: Plan UUID

        Returns:
            The matching plan

        Raises:
            ValueError: If no plan has this ID.
        """
        plan = Plan.query.get(plan_id)
        if not plan:
            raise ValueError("Plan not found")
        return plan

    @staticmethod
    def list_plans() -> list:
        """List all active plans, ordered by ascending monthly price."""
        return Plan.query.filter_by(is_active=True).order_by(Plan.price_monthly.asc()).all()

    @staticmethod
    def create_subscription(
        organization_id: str,
        plan_id: str,
        billing_cycle: str = "monthly",
    ) -> Subscription:
        """Create a new, active subscription for an organization.

        The first billing period starts now (UTC) and lasts 30 days for
        ``'monthly'`` or 365 days for ``'yearly'``.

        Args:
            organization_id: Organization UUID
            plan_id: Plan UUID
            billing_cycle: 'monthly' or 'yearly'

        Returns:
            New subscription

        Raises:
            ValueError: If ``billing_cycle`` is not 'monthly' or 'yearly', the
                organization or plan does not exist, or the organization
                already has a subscription.
        """
        # Reject unknown cycles up front; otherwise the period length and the
        # stored BillingCycle enum could disagree with each other.
        period_days = BillingService._PERIOD_DAYS.get(billing_cycle)
        if period_days is None:
            raise ValueError(
                f"Invalid billing cycle {billing_cycle!r}; expected 'monthly' or 'yearly'"
            )

        if not Organization.query.get(organization_id):
            raise ValueError("Organization not found")

        BillingService.get_plan(plan_id)  # raises ValueError("Plan not found")

        if Subscription.query.filter_by(organization_id=organization_id).first():
            raise ValueError("Organization already has a subscription")

        now = datetime.now(timezone.utc)
        subscription = Subscription(
            organization_id=organization_id,
            plan_id=plan_id,
            status=SubscriptionStatus.ACTIVE,
            billing_cycle=BillingCycle.YEARLY if billing_cycle == "yearly" else BillingCycle.MONTHLY,
            current_period_start=now,
            current_period_end=now + timedelta(days=period_days),
        )

        db.session.add(subscription)
        db.session.commit()
        logger.info(
            "Created %s subscription for organization %s on plan %s",
            billing_cycle,
            organization_id,
            plan_id,
        )
        return subscription

    @staticmethod
    def change_plan(organization_id: str, new_plan_id: str) -> Subscription:
        """Change the plan of an organization's subscription.

        Only ``plan_id`` is updated; billing cycle, status and period dates
        are left untouched.

        Args:
            organization_id: Organization UUID
            new_plan_id: New plan UUID

        Returns:
            Updated subscription

        Raises:
            ValueError: If the organization has no subscription or the new
                plan does not exist.
        """
        subscription = BillingService._get_subscription_or_raise(organization_id)
        BillingService.get_plan(new_plan_id)  # raises ValueError("Plan not found")

        old_plan_id = subscription.plan_id
        subscription.plan_id = new_plan_id
        db.session.commit()
        logger.info(
            "Changed plan for organization %s from %s to %s",
            organization_id,
            old_plan_id,
            new_plan_id,
        )
        return subscription

    @staticmethod
    def cancel_subscription(organization_id: str) -> Subscription:
        """Cancel an organization's subscription.

        Sets ``cancel_at_period_end`` to True and the status to
        ``SubscriptionStatus.CANCELLED``.

        Args:
            organization_id: Organization UUID

        Returns:
            Updated subscription

        Raises:
            ValueError: If the organization has no subscription.
        """
        subscription = BillingService._get_subscription_or_raise(organization_id)

        subscription.cancel_at_period_end = True
        # NOTE(review): the intent is "cancel at period end", yet the status is
        # switched to CANCELLED immediately. Confirm whether consumers of
        # ``status`` expect the subscription to stay ACTIVE until
        # ``current_period_end``.
        subscription.status = SubscriptionStatus.CANCELLED
        db.session.commit()
        logger.info("Cancelled subscription for organization %s", organization_id)
        return subscription

    @staticmethod
    def extend_trial(organization_id: str, days: int) -> Subscription:
        """Extend an organization's trial period and mark it as in trial.

        The extension is added to the later of the current trial end and
        now, so an already-expired trial ends ``days`` days from now.

        Args:
            organization_id: Organization UUID
            days: Number of days to extend (must be positive)

        Returns:
            Updated subscription

        Raises:
            ValueError: If ``days`` is not positive or the organization has
                no subscription.
        """
        if days <= 0:
            raise ValueError("Trial extension must be a positive number of days")

        subscription = BillingService._get_subscription_or_raise(organization_id)

        now = datetime.now(timezone.utc)
        current_end = subscription.trial_ends_at
        if current_end is None:
            base = now
        else:
            # Normalize before comparing: mixing naive and aware datetimes raises.
            base = max(BillingService._as_aware_utc(current_end), now)

        subscription.trial_ends_at = base + timedelta(days=days)
        subscription.status = SubscriptionStatus.TRIAL
        db.session.commit()
        logger.info(
            "Extended trial for organization %s by %d days (now ends %s)",
            organization_id,
            days,
            subscription.trial_ends_at,
        )
        return subscription

    @staticmethod
    def calculate_overage(organization_id: str) -> dict:
        """Calculate overage charges for an organization.

        Members are counted from ``OrganizationMember`` rows that are not
        soft-deleted (``deleted_at`` is NULL). Members beyond the plan's
        ``included_users`` are billed at ``overage_rate_per_user``.

        Args:
            organization_id: Organization UUID

        Returns:
            Dict with keys ``has_overage``, ``user_count``, ``included_users``,
            ``overage_users``, ``overage_rate_per_user`` and ``overage_cost``.
            When the organization has no subscription or no resolvable plan,
            every numeric field is 0 and ``has_overage`` is False.
        """
        no_overage = {
            "has_overage": False,
            "overage_cost": 0,
            "user_count": 0,
            "included_users": 0,
            "overage_users": 0,
            "overage_rate_per_user": 0,
        }

        subscription = Subscription.query.filter_by(organization_id=organization_id).first()
        if not subscription:
            return no_overage

        plan = Plan.query.get(subscription.plan_id) if subscription.plan_id else None
        if not plan:
            return no_overage

        user_count = OrganizationMember.query.filter(
            OrganizationMember.organization_id == organization_id,
            OrganizationMember.deleted_at.is_(None),
        ).count()

        included_users = plan.included_users
        overage_users = max(0, user_count - included_users)
        rate = plan.overage_rate_per_user

        # ``or 0`` guards a NULL rate, which would otherwise raise on ``> 0``.
        has_overage = overage_users > 0 and (rate or 0) > 0
        overage_cost = overage_users * rate if has_overage else 0

        return {
            "has_overage": has_overage,
            "user_count": user_count,
            "included_users": included_users,
            "overage_users": overage_users,
            "overage_rate_per_user": rate,
            "overage_cost": overage_cost,
        }