Skip to content

Commit

Permalink
server/subscription: fix handling of PWYW subscriptions
Browse files Browse the repository at this point in the history
  • Loading branch information
frankie567 committed Feb 19, 2025
1 parent 98be449 commit 54a50b9
Show file tree
Hide file tree
Showing 11 changed files with 563 additions and 598 deletions.
11 changes: 6 additions & 5 deletions docs/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,21 @@
"scripts": {
"test": "echo \"Error: no test specified\" && exit 1",
"dev": "mintlify dev",
"generate-webhooks": "tsx .polar/generate-webhooks.mts openapi.yaml snippets/webhooks"
"generate-webhooks": "tsx .polar/generate-webhooks.mts openapi.yaml snippets/webhooks",
"broken-links": "mintlify broken-links"
},
"keywords": [],
"author": "",
"license": "ISC",
"packageManager": "[email protected]+sha512.b2dc20e2fc72b3e18848459b37359a32064663e5627a51e4c74b2c29dd8e8e0491483c3abb40789cfd578bf362fb6ba8261b05f0387d76792ed6e23ea3b1b6a0",
"dependencies": {
"mintlify": "^4.0.388"
"mintlify": "^4.0.395"
},
"devDependencies": {
"@scalar/openapi-parser": "^0.10.6",
"@types/node": "^22.13.2",
"@scalar/openapi-parser": "^0.10.7",
"@types/node": "^22.13.4",
"js-yaml": "^4.1.0",
"openapi-sampler": "^1.6.1",
"tsx": "^4.19.2"
"tsx": "^4.19.3"
}
}
831 changes: 324 additions & 507 deletions docs/pnpm-lock.yaml

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Update Stripe Subscriptions
Revision ID: 21585ed16305
Revises: 69d1834e6285
Create Date: 2025-02-19 15:26:53.346054
"""

import concurrent.futures
import random
import time
import uuid
from typing import Any

import sqlalchemy as sa
import stripe as stripe_lib
from alembic import op

# Polar Custom Imports

# revision identifiers, used by Alembic.
revision = "21585ed16305"
down_revision = "69d1834e6285"
branch_labels: tuple[str] | None = None
depends_on: tuple[str] | None = None


def process_subscription(
subscription: tuple[str, uuid.UUID, uuid.UUID], retry: int = 1
) -> None:
stripe_id, product_id, price_id = subscription
metadata = {
"type": "product",
"product_id": str(product_id),
"product_price_id": str(price_id),
}
try:
stripe_lib.Subscription.modify(
stripe_id,
metadata=metadata,
)
except stripe_lib.RateLimitError:
time.sleep(retry + random.random())
return process_subscription(subscription, retry=retry + 1)


def process_subscriptions(results: sa.CursorResult[Any]) -> None:
with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
for result in results:
executor.submit(process_subscription, result._tuple())


def upgrade() -> None:
connection = op.get_bind()
results = connection.execute(
sa.text("""
SELECT stripe_subscription_id, product_id, price_id
FROM subscriptions
WHERE stripe_subscription_id IS NOT NULL
""")
)
process_subscriptions(results)


def downgrade() -> None:
pass
2 changes: 2 additions & 0 deletions server/polar/integrations/stripe/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ async def update_subscription_price(
new_price: str,
proration_behavior: Literal["always_invoice", "create_prorations", "none"],
error_if_incomplete: bool,
metadata: dict[str, str],
) -> stripe_lib.Subscription:
subscription = await stripe_lib.Subscription.retrieve_async(id)

Expand All @@ -420,6 +421,7 @@ async def update_subscription_price(
payment_behavior=(
"error_if_incomplete" if error_if_incomplete else "allow_incomplete"
),
metadata=metadata,
)
except stripe_lib.InvalidRequestError as e:
error = e.error
Expand Down
10 changes: 8 additions & 2 deletions server/polar/kit/repository/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from sqlalchemy import Select, func, over, select
from sqlalchemy.orm import Mapped
from sqlalchemy.sql.base import ExecutableOption

from polar.kit.db.postgres import AsyncSession
from polar.kit.utils import utc_now
Expand Down Expand Up @@ -117,9 +118,14 @@ def from_session(cls, session: AsyncSession) -> Self:

class RepositoryIDMixin(Generic[MODEL_ID, ID_TYPE]):
async def get_by_id(
self: RepositoryProtocol[MODEL_ID], id: ID_TYPE
self: RepositoryProtocol[MODEL_ID],
id: ID_TYPE,
*,
options: Sequence[ExecutableOption] = (),
) -> MODEL_ID | None:
statement = self.get_base_statement().where(self.model.id == id)
statement = (
self.get_base_statement().where(self.model.id == id).options(*options)
)
return await self.get_one_or_none(statement)


Expand Down
68 changes: 58 additions & 10 deletions server/polar/subscription/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
PolarRequestValidationError,
ResourceUnavailable,
)
from polar.integrations.stripe.schemas import ProductType
from polar.integrations.stripe.service import stripe as stripe_service
from polar.integrations.stripe.utils import get_expandable_id
from polar.kit.db.postgres import AsyncSession
Expand Down Expand Up @@ -73,7 +74,6 @@
from polar.worker import enqueue_job

from ..product.service.product import product as product_service
from ..product.service.product_price import product_price as product_price_service
from .schemas import (
SubscriptionUpdate,
SubscriptionUpdateProduct,
Expand All @@ -85,6 +85,27 @@
class SubscriptionError(PolarError): ...


class InvalidSubscriptionMetadata(SubscriptionError):
def __init__(self, stripe_subscription_id: str) -> None:
self.stripe_subscription_id = stripe_subscription_id
message = (
f"Received the subscription {stripe_subscription_id} from Stripe, "
"but the metadata doesn't contain the product ID and product price ID."
)
super().__init__(message)


class AssociatedProductDoesNotExist(SubscriptionError):
def __init__(self, stripe_subscription_id: str, product_id: str) -> None:
self.subscription_id = stripe_subscription_id
self.product_id = product_id
message = (
f"Received the subscription {stripe_subscription_id} from Stripe "
f"with product {product_id}, but no associated Product exists."
)
super().__init__(message)


class AssociatedPriceDoesNotExist(SubscriptionError):
def __init__(self, stripe_subscription_id: str, stripe_price_id: str) -> None:
self.subscription_id = stripe_subscription_id
Expand Down Expand Up @@ -342,17 +363,25 @@ async def get_by_stripe_subscription_id(
async def create_subscription_from_stripe(
self, session: AsyncSession, *, stripe_subscription: stripe_lib.Subscription
) -> Subscription:
price_id = stripe_subscription["items"].data[0].price.id
price = await product_price_service.get_by_stripe_price_id(session, price_id)
if price is None:
raise AssociatedPriceDoesNotExist(stripe_subscription.id, price_id)
product_id = stripe_subscription.metadata.get("product_id")
price_id = stripe_subscription.metadata.get("product_price_id")
if product_id is None or price_id is None:
raise InvalidSubscriptionMetadata(stripe_subscription.id)

product = price.product
product_repository = ProductRepository.from_session(session)
product = await product_repository.get_by_id(
uuid.UUID(product_id), options=product_repository.get_eager_options()
)
if product is None:
raise AssociatedProductDoesNotExist(stripe_subscription.id, product_id)
if not product.is_recurring:
raise NotARecurringProduct(stripe_subscription.id, price_id)

organization = await organization_service.get(session, product.organization_id)
assert organization is not None
price = product.get_price(uuid.UUID(price_id))
if price is None:
raise AssociatedPriceDoesNotExist(stripe_subscription.id, price_id)

organization = product.organization

# Get Discount if available
discount: Discount | None = None
Expand Down Expand Up @@ -492,10 +521,24 @@ async def update_subscription_from_stripe(
subscription.set_started_at()
self.update_cancellation_from_stripe(subscription, stripe_subscription)

price_id = stripe_subscription["items"].data[0].price.id
price = await product_price_service.get_by_stripe_price_id(session, price_id)
product_id = stripe_subscription.metadata.get("product_id")
price_id = stripe_subscription.metadata.get("product_price_id")
if product_id is None or price_id is None:
raise InvalidSubscriptionMetadata(stripe_subscription.id)

product_repository = ProductRepository.from_session(session)
product = await product_repository.get_by_id(
uuid.UUID(product_id), options=product_repository.get_eager_options()
)
if product is None:
raise AssociatedProductDoesNotExist(stripe_subscription.id, product_id)
if not product.is_recurring:
raise NotARecurringProduct(stripe_subscription.id, price_id)

price = product.get_price(uuid.UUID(price_id))
if price is None:
raise AssociatedPriceDoesNotExist(stripe_subscription.id, price_id)

subscription.price = price
subscription.product = price.product

Expand Down Expand Up @@ -685,6 +728,11 @@ async def update_product(
new_price=price.stripe_price_id,
proration_behavior=proration_behavior.to_stripe(),
error_if_incomplete=isinstance(subscription.price, ProductPriceFree),
metadata={
"type": ProductType.product,
"product_id": str(product.id),
"product_price_id": str(price.id),
},
)
subscription.product = product
subscription.price = price
Expand Down
5 changes: 5 additions & 0 deletions server/tests/customer_portal/endpoints/test_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,11 @@ async def test_valid(
new_price=new_price.stripe_price_id,
proration_behavior=organization.proration_behavior.to_stripe(),
error_if_incomplete=previous_free,
metadata={
"type": "product",
"product_id": str(product_second.id),
"product_price_id": str(new_price_id),
},
)

updated_subscription = response.json()
Expand Down
5 changes: 5 additions & 0 deletions server/tests/customer_portal/service/test_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ async def test_valid(
new_price=product_second.prices[0].stripe_price_id,
proration_behavior="create_prorations",
error_if_incomplete=False,
metadata={
"type": "product",
"product_id": str(product_second.id),
"product_price_id": str(new_price.id),
},
)


Expand Down
31 changes: 19 additions & 12 deletions server/tests/fixtures/stripe.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
Customer,
Discount,
Organization,
Product,
ProductPrice,
Subscription,
)
from polar.models.subscription import SubscriptionStatus
Expand All @@ -30,20 +32,19 @@ def cloned_stripe_subscription(
subscription: Subscription,
*,
customer: Customer | None = None,
price_id: str | None = None,
product: Product | None = None,
price: ProductPrice | None = None,
status: SubscriptionStatus | None = None,
cancel_at_period_end: bool | None = None,
revoke: bool = False,
) -> stripe_lib.Subscription:
if price_id is None:
price_id = subscription.price.stripe_price_id

if cancel_at_period_end is None:
cancel_at_period_end = subscription.cancel_at_period_end

return construct_stripe_subscription(
customer=customer if customer else subscription.customer,
price_id=price_id,
product=product if product else subscription.product,
price=price if price else subscription.price,
status=status if status else subscription.status,
cancel_at_period_end=cancel_at_period_end,
revoke=revoke,
Expand All @@ -52,9 +53,10 @@ def cloned_stripe_subscription(

def construct_stripe_subscription(
*,
product: Product | None,
price: ProductPrice | None = None,
customer: Customer | None = None,
organization: Organization | None = None,
price_id: str = "PRICE_ID",
status: SubscriptionStatus = SubscriptionStatus.incomplete,
latest_invoice: stripe_lib.Invoice | None = None,
cancel_at_period_end: bool = False,
Expand All @@ -63,12 +65,11 @@ def construct_stripe_subscription(
revoke: bool = False,
) -> stripe_lib.Subscription:
now_timestamp = datetime.now(UTC).timestamp()
price = price or product.prices[0] if product else None
stripe_price_id = price.stripe_price_id if price else "PRICE_ID"
base_metadata: dict[str, str] = {
**(
{"organization_subscriber_id": str(organization.id)}
if organization is not None
else {}
),
**({"product_id": str(product.id)} if product is not None else {}),
**({"product_price_id": str(price.id)} if price is not None else {}),
}

canceled_at = None
Expand All @@ -90,7 +91,13 @@ def construct_stripe_subscription(
"status": status,
"items": {
"data": [
{"price": {"id": price_id, "currency": "USD", "unit_amount": 1000}}
{
"price": {
"id": stripe_price_id,
"currency": "USD",
"unit_amount": 1000,
}
}
]
},
"current_period_start": now_timestamp,
Expand Down
5 changes: 5 additions & 0 deletions server/tests/subscription/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,11 @@ async def test_valid(
new_price=new_price.stripe_price_id,
proration_behavior=organization.proration_behavior.to_stripe(),
error_if_incomplete=previous_free,
metadata={
"type": "product",
"product_id": str(product_second.id),
"product_price_id": str(new_price_id),
},
)

updated_subscription = response.json()
Expand Down
Loading

0 comments on commit 54a50b9

Please sign in to comment.