Skip to content

Commit

Permalink
server/product: allow to create product with one_time price
Browse files Browse the repository at this point in the history
  • Loading branch information
frankie567 committed May 14, 2024
1 parent 2ad3fbd commit 571440b
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 35 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Add ProductPrice.type
Revision ID: 8ca7adb6786e
Revises: e437db540330
Create Date: 2024-05-14 16:32:20.907470
"""

import sqlalchemy as sa
from alembic import op

# Polar Custom Imports
from polar.kit.extensions.sqlalchemy import PostgresUUID

# revision identifiers, used by Alembic.
revision = "8ca7adb6786e"
down_revision = "e437db540330"
branch_labels: tuple[str] | None = None
depends_on: tuple[str] | None = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("product_prices", sa.Column("type", sa.String(), nullable=True))
op.execute("UPDATE product_prices SET type = 'recurring'")
op.alter_column("product_prices", "type", existing_type=sa.String(), nullable=False)

op.alter_column(
"product_prices",
"recurring_interval",
existing_type=sa.VARCHAR(),
nullable=True,
)
op.create_index(
op.f("ix_product_prices_type"), "product_prices", ["type"], unique=False
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_product_prices_type"), table_name="product_prices")
op.alter_column(
"product_prices",
"recurring_interval",
existing_type=sa.VARCHAR(),
nullable=False,
)
op.drop_column("product_prices", "type")
# ### end Alembic commands ###
16 changes: 9 additions & 7 deletions server/polar/integrations/stripe/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,16 +354,18 @@ def create_price_for_product(
product: str,
price_amount: int,
price_currency: str,
interval: Literal["day", "month", "week", "year"],
recurring_interval: Literal["day", "month", "week", "year"] | None,
*,
set_default: bool = False,
) -> stripe_lib.Price:
price = stripe_lib.Price.create(
currency=price_currency,
product=product,
unit_amount=price_amount,
recurring={"interval": interval},
)
params: stripe_lib.Price.CreateParams = {
"currency": price_currency,
"product": product,
"unit_amount": price_amount,
}
if recurring_interval is not None:
params["recurring"] = {"interval": recurring_interval}
price = stripe_lib.Price.create(**params)
if set_default:
stripe_lib.Product.modify(product, default_price=price.id)
return price
Expand Down
11 changes: 10 additions & 1 deletion server/polar/models/product_price.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@
from polar.models import Product, Subscription


class ProductPriceType(StrEnum):
one_time = "one_time"
recurring = "recurring"

def as_literal(self) -> Literal["one_time", "recurring"]:
return cast(Literal["one_time", "recurring"], self.value)


class ProductPriceRecurringInterval(StrEnum):
month = "month"
year = "year"
Expand All @@ -23,8 +31,9 @@ def as_literal(self) -> Literal["month", "year"]:
class ProductPrice(RecordModel):
__tablename__ = "product_prices"

type: Mapped[ProductPriceType] = mapped_column(String, nullable=False, index=True)
recurring_interval: Mapped[ProductPriceRecurringInterval] = mapped_column(
String, nullable=False, index=True
String, nullable=True, index=True
)
price_amount: Mapped[int] = mapped_column(Integer, nullable=False)
price_currency: Mapped[str] = mapped_column(String(3), nullable=False)
Expand Down
23 changes: 18 additions & 5 deletions server/polar/product/schemas.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Literal
from typing import Annotated, Literal

from pydantic import UUID4, Field
from pydantic import UUID4, Discriminator, Field

from polar.benefit.schemas import BenefitPublic, BenefitSubscriber
from polar.kit.schemas import EmptyStrToNone, Schema, TimestampedSchema
from polar.models.product import SubscriptionTierType
from polar.models.product_price import ProductPriceRecurringInterval
from polar.models.product_price import ProductPriceRecurringInterval, ProductPriceType

PRODUCT_NAME_MIN_LENGTH = 3
PRODUCT_NAME_MAX_LENGTH = 24
Expand All @@ -17,12 +17,24 @@
MAXIMUM_PRICE_AMOUNT = 99999999


class ProductPriceCreate(Schema):
class ProductPriceRecurringCreate(Schema):
type: Literal[ProductPriceType.recurring]
recurring_interval: ProductPriceRecurringInterval
price_amount: int = Field(..., gt=0, le=MAXIMUM_PRICE_AMOUNT)
price_currency: str = Field("usd", pattern="usd")


class ProductPriceOneTimeCreate(Schema):
type: Literal[ProductPriceType.one_time]
price_amount: int = Field(..., gt=0, le=MAXIMUM_PRICE_AMOUNT)
price_currency: str = Field("usd", pattern="usd")


ProductPriceCreate = Annotated[
ProductPriceRecurringCreate | ProductPriceOneTimeCreate, Discriminator("type")
]


class ProductCreate(Schema):
type: Literal[
SubscriptionTierType.individual,
Expand Down Expand Up @@ -62,7 +74,8 @@ class ProductBenefitsUpdate(Schema):

class ProductPrice(TimestampedSchema):
id: UUID4
recurring_interval: ProductPriceRecurringInterval
type: ProductPriceType
recurring_interval: ProductPriceRecurringInterval | None = None
price_amount: int
price_currency: str
is_archived: bool
Expand Down
15 changes: 12 additions & 3 deletions server/polar/product/service/product.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@
from polar.webhook.webhooks import WebhookTypeObject
from polar.worker import enqueue_job

from ..schemas import ExistingProductPrice, ProductCreate, ProductUpdate
from ..schemas import (
ExistingProductPrice,
ProductCreate,
ProductPriceRecurringCreate,
ProductUpdate,
)


class ProductError(PolarError): ...
Expand Down Expand Up @@ -221,7 +226,9 @@ async def user_create(
stripe_product.id,
price_create.price_amount,
price_create.price_currency,
price_create.recurring_interval.as_literal(),
price_create.recurring_interval.as_literal()
if isinstance(price_create, ProductPriceRecurringCreate)
else None,
)
price = ProductPrice(
**price_create.model_dump(),
Expand Down Expand Up @@ -284,7 +291,9 @@ async def user_update(
product.stripe_product_id,
price_update.price_amount,
price_update.price_currency,
price_update.recurring_interval.as_literal(),
price_update.recurring_interval.as_literal()
if isinstance(price_update, ProductPriceRecurringCreate)
else None,
)
price = ProductPrice(
**price_update.model_dump(),
Expand Down
4 changes: 3 additions & 1 deletion server/tests/fixtures/random_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from polar.models.issue import Issue
from polar.models.pledge import Pledge, PledgeState, PledgeType
from polar.models.product import SubscriptionTierType
from polar.models.product_price import ProductPriceRecurringInterval
from polar.models.product_price import ProductPriceRecurringInterval, ProductPriceType
from polar.models.pull_request import PullRequest
from polar.models.subscription import SubscriptionStatus
from polar.models.user import OAuthAccount, OAuthPlatform
Expand Down Expand Up @@ -458,12 +458,14 @@ async def create_product_price(
save_fixture: SaveFixture,
*,
product: Product,
type: ProductPriceType = ProductPriceType.recurring,
recurring_interval: ProductPriceRecurringInterval = ProductPriceRecurringInterval.month,
amount: int = 1000,
) -> ProductPrice:
price = ProductPrice(
price_amount=amount,
price_currency="usd",
type=type,
recurring_interval=recurring_interval,
stripe_price_id=rstr("PRICE_ID"),
product=product,
Expand Down
31 changes: 20 additions & 11 deletions server/tests/product/service/test_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
from polar.models import Benefit, Organization, Product, User, UserOrganization
from polar.models.benefit import BenefitType
from polar.models.product import SubscriptionTierType
from polar.models.product_price import ProductPriceRecurringInterval
from polar.models.product_price import ProductPriceRecurringInterval, ProductPriceType
from polar.postgres import AsyncSession
from polar.product.schemas import (
ExistingProductPrice,
ProductCreate,
ProductPriceCreate,
ProductPriceRecurringCreate,
ProductUpdate,
)
from polar.product.service.product import (
Expand Down Expand Up @@ -545,7 +545,8 @@ async def test_user_not_existing_organization(
name="Product",
organization_id=uuid.uuid4(),
prices=[
ProductPriceCreate(
ProductPriceRecurringCreate(
type=ProductPriceType.recurring,
recurring_interval=ProductPriceRecurringInterval.month,
price_amount=1000,
price_currency="usd",
Expand Down Expand Up @@ -574,7 +575,8 @@ async def test_user_not_writable_organization(
name="Product",
organization_id=organization.id,
prices=[
ProductPriceCreate(
ProductPriceRecurringCreate(
type=ProductPriceType.recurring,
recurring_interval=ProductPriceRecurringInterval.month,
price_amount=1000,
price_currency="usd",
Expand Down Expand Up @@ -613,7 +615,8 @@ async def test_user_valid_organization(
name="Product",
organization_id=organization.id,
prices=[
ProductPriceCreate(
ProductPriceRecurringCreate(
type=ProductPriceType.recurring,
recurring_interval=ProductPriceRecurringInterval.month,
price_amount=1000,
price_currency="usd",
Expand Down Expand Up @@ -667,7 +670,8 @@ async def test_user_valid_highlighted(
organization_id=organization.id,
is_highlighted=True,
prices=[
ProductPriceCreate(
ProductPriceRecurringCreate(
type=ProductPriceType.recurring,
recurring_interval=ProductPriceRecurringInterval.month,
price_amount=1000,
price_currency="usd",
Expand Down Expand Up @@ -712,7 +716,8 @@ async def test_user_empty_description(
description="",
organization_id=organization.id,
prices=[
ProductPriceCreate(
ProductPriceRecurringCreate(
type=ProductPriceType.recurring,
recurring_interval=ProductPriceRecurringInterval.month,
price_amount=1000,
price_currency="usd",
Expand Down Expand Up @@ -741,7 +746,8 @@ async def test_organization_set_organization_id(
name="Product",
organization_id=organization.id,
prices=[
ProductPriceCreate(
ProductPriceRecurringCreate(
type=ProductPriceType.recurring,
recurring_interval=ProductPriceRecurringInterval.month,
price_amount=1000,
price_currency="usd",
Expand Down Expand Up @@ -778,7 +784,8 @@ async def test_organization_valid(
type=SubscriptionTierType.individual,
name="Product",
prices=[
ProductPriceCreate(
ProductPriceRecurringCreate(
type=ProductPriceType.recurring,
recurring_interval=ProductPriceRecurringInterval.month,
price_amount=1000,
price_currency="usd",
Expand Down Expand Up @@ -959,7 +966,8 @@ async def test_valid_price_added(
update_schema = ProductUpdate(
prices=[
ExistingProductPrice(id=product_organization_loaded.prices[0].id),
ProductPriceCreate(
ProductPriceRecurringCreate(
type=ProductPriceType.recurring,
recurring_interval=ProductPriceRecurringInterval.year,
price_amount=12000,
price_currency="usd",
Expand Down Expand Up @@ -1013,7 +1021,8 @@ async def test_valid_price_deleted(

update_schema = ProductUpdate(
prices=[
ProductPriceCreate(
ProductPriceRecurringCreate(
type=ProductPriceType.recurring,
recurring_interval=ProductPriceRecurringInterval.year,
price_amount=12000,
price_currency="usd",
Expand Down
26 changes: 19 additions & 7 deletions server/tests/product/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ async def test_validation(
"organization_id": str(organization.id),
"prices": [
{
"type": "recurring",
"recurring_interval": "month",
"price_amount": 1000,
"price_currency": "usd",
Expand All @@ -242,8 +243,25 @@ async def test_validation(
assert response.status_code == 422

@pytest.mark.auth
@pytest.mark.parametrize(
"prices",
(
[
{
"type": "recurring",
"recurring_interval": "month",
"price_amount": 1000,
"price_currency": "usd",
}
],
[
{"type": "one_time", "price_amount": 1000, "price_currency": "usd"},
],
),
)
async def test_valid(
self,
prices: list[dict[str, Any]],
client: AsyncClient,
organization: Organization,
user_organization_admin: UserOrganization,
Expand All @@ -268,13 +286,7 @@ async def test_valid(
"name": "Product",
"price_amount": 1000,
"organization_id": str(organization.id),
"prices": [
{
"recurring_interval": "month",
"price_amount": 1000,
"price_currency": "usd",
}
],
"prices": prices,
},
)

Expand Down

0 comments on commit 571440b

Please sign in to comment.