diff --git a/server/migrations/versions/2024-05-15-1206_update_stripe_s_product_metadata.py b/server/migrations/versions/2024-05-15-1206_update_stripe_s_product_metadata.py new file mode 100644 index 0000000000..aef2535922 --- /dev/null +++ b/server/migrations/versions/2024-05-15-1206_update_stripe_s_product_metadata.py @@ -0,0 +1,61 @@ +"""Update Stripe's product metadata + +Revision ID: 40f397ad512a +Revises: 8ca7adb6786e +Create Date: 2024-05-15 12:06:39.431048 + +""" + +import sqlalchemy as sa +from alembic import op + +from polar.integrations.stripe.service import stripe as stripe_service + +# Polar Custom Imports +from polar.kit.extensions.sqlalchemy import PostgresUUID + +# revision identifiers, used by Alembic. +revision = "40f397ad512a" +down_revision = "8ca7adb6786e" +branch_labels: tuple[str] | None = None +depends_on: tuple[str] | None = None + + +def upgrade() -> None: + connection = op.get_bind() + result = connection.execute( + sa.text( + """ + SELECT products.id, products.stripe_product_id + FROM products + WHERE stripe_product_id IS NOT NULL; + """ + ) + ) + + for product_id, stripe_product_id in result: + metadata = { + "product_id": str(product_id), + "subscription_tier_id": "", + } + stripe_service.update_product(stripe_product_id, metadata=metadata) + + +def downgrade() -> None: + connection = op.get_bind() + result = connection.execute( + sa.text( + """ + SELECT products.id, products.stripe_product_id + FROM products + WHERE stripe_product_id IS NOT NULL; + """ + ) + ) + + for product_id, stripe_product_id in result: + metadata = { + "subscription_tier_id": str(product_id), + "product_id": "", + } + stripe_service.update_product(stripe_product_id, metadata=metadata) diff --git a/server/polar/integrations/stripe/service.py b/server/polar/integrations/stripe/service.py index 4cdb6f97cb..59dd4385a6 100644 --- a/server/polar/integrations/stripe/service.py +++ b/server/polar/integrations/stripe/service.py @@ -1,6 +1,6 @@ import uuid from collections.abc import Iterator -from typing import Literal, TypedDict, Unpack, cast +from typing import Literal, Unpack, cast import stripe as stripe_lib @@ -24,12 +24,6 @@ stripe_lib.default_http_client = stripe_http_client -class ProductUpdateKwargs(TypedDict, total=False): - name: str - description: str - default_price: str - - class MissingOrganizationBillingEmail(PolarError): def __init__(self, organization_id: uuid.UUID) -> None: self.organization_id = organization_id @@ -371,7 +365,7 @@ def create_price_for_product( return price def update_product( - self, product: str, **kwargs: Unpack[ProductUpdateKwargs] + self, product: str, **kwargs: Unpack[stripe_lib.Product.ModifyParams] ) -> stripe_lib.Product: return stripe_lib.Product.modify(product, **kwargs) diff --git a/server/polar/product/service/product.py b/server/polar/product/service/product.py index 6f2f065628..48dc418c73 100644 --- a/server/polar/product/service/product.py +++ b/server/polar/product/service/product.py @@ -2,6 +2,7 @@ from collections.abc import Sequence from typing import Any, List, Literal, TypeVar # noqa: UP035 +import stripe from sqlalchemy import Select, and_, case, or_, select, update from sqlalchemy.exc import InvalidRequestError from sqlalchemy.orm import contains_eager, joinedload @@ -11,7 +12,6 @@ from polar.authz.service import AccessType, Authz from polar.benefit.service.benefit import benefit as benefit_service from polar.exceptions import NotPermitted, PolarError, PolarRequestValidationError -from polar.integrations.stripe.service import ProductUpdateKwargs from polar.integrations.stripe.service import stripe as stripe_service from polar.kit.db.postgres import AsyncSession from polar.kit.pagination import PaginationParams, paginate @@ -236,7 +236,7 @@ async def user_update( if product.is_archived and update_schema.is_archived is False: product = await self._unarchive(product) - product_update: ProductUpdateKwargs = {} + product_update: stripe.Product.ModifyParams = {} if update_schema.name is not None and update_schema.name != product.name: product.name = update_schema.name product_update["name"] = product.get_stripe_name()