Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix product_price_type filter and deprecate it in favor of product_billing_type #5130

Merged
merged 3 commits into from
Feb 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ export default async function Page({
params: {
query: {
organization_id: organization.id,
product_billing_type: 'one_time',
limit: 100,
},
},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { getServerSideAPI } from '@/utils/client/serverside'
import { DataTableSearchParams, parseSearchParams } from '@/utils/datatable'
import { getOrganizationBySlugOrNotFound } from '@/utils/organization'
import { schemas } from '@polar-sh/client'
import { Metadata } from 'next'
import ClientPage from './ClientPage'

Expand All @@ -18,7 +17,6 @@ export default async function Page({
params: { organization: string }
searchParams: DataTableSearchParams & {
product_id?: string[] | string
product_price_type?: schemas['ProductPriceType']
metadata?: string[]
}
}) {
Expand Down
8 changes: 6 additions & 2 deletions clients/packages/client/src/v1.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20047,7 +20047,9 @@ export interface operations {
organization_id?: string | string[] | null;
/** @description Filter by product ID. */
product_id?: string | string[] | null;
/** @description Filter by product price type. `recurring` will return orders corresponding to subscriptions creations or renewals. `one_time` will return orders corresponding to one-time purchases. */
/** @description Filter by product billing type. `recurring` will filter data corresponding to subscriptions creations or renewals. `one_time` will filter data corresponding to one-time purchases. */
product_billing_type?: components["schemas"]["ProductBillingType"] | components["schemas"]["ProductBillingType"][] | null;
/** @deprecated */
product_price_type?: components["schemas"]["ProductPriceType"] | components["schemas"]["ProductPriceType"][] | null;
/** @description Filter by discount ID. */
discount_id?: string | string[] | null;
Expand Down Expand Up @@ -22960,7 +22962,9 @@ export interface operations {
organization_id?: string | string[] | null;
/** @description Filter by product ID. */
product_id?: string | string[] | null;
/** @description Filter by product price type. `recurring` will return orders corresponding to subscriptions creations or renewals. `one_time` will return orders corresponding to one-time purchases. */
/** @description Filter by product billing type. `recurring` will filter data corresponding to subscriptions creations or renewals. `one_time` will filter data corresponding to one-time purchases. */
product_billing_type?: components["schemas"]["ProductBillingType"] | components["schemas"]["ProductBillingType"][] | null;
/** @deprecated */
product_price_type?: components["schemas"]["ProductPriceType"] | components["schemas"]["ProductPriceType"][] | null;
/** @description Filter by subscription ID. */
subscription_id?: string | string[] | null;
Expand Down
19 changes: 12 additions & 7 deletions server/polar/customer_portal/endpoints/order.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Annotated
from typing import Annotated, cast

from fastapi import Depends, Query

Expand All @@ -8,6 +8,7 @@
from polar.kit.schemas import MultipleQueryFilter
from polar.kit.sorting import Sorting, SortingGetter
from polar.models import Order
from polar.models.product import ProductBillingType
from polar.models.product_price import ProductPriceType
from polar.openapi import APITag
from polar.order.schemas import OrderID
Expand Down Expand Up @@ -43,16 +44,19 @@ async def list(
product_id: MultipleQueryFilter[ProductID] | None = Query(
None, title="ProductID Filter", description="Filter by product ID."
),
product_price_type: MultipleQueryFilter[ProductPriceType] | None = Query(
product_billing_type: MultipleQueryFilter[ProductBillingType] | None = Query(
None,
title="ProductPriceType Filter",
title="ProductBillingType Filter",
description=(
"Filter by product price type. "
"`recurring` will return orders corresponding "
"Filter by product billing type. "
"`recurring` will filter data corresponding "
"to subscriptions creations or renewals. "
"`one_time` will return orders corresponding to one-time purchases."
"`one_time` will filter data corresponding to one-time purchases."
),
),
product_price_type: MultipleQueryFilter[ProductPriceType] | None = Query(
None, title="ProductPriceType Filter", deprecated="Use `product_billing_type"
),
subscription_id: MultipleQueryFilter[SubscriptionID] | None = Query(
None, title="SubscriptionID Filter", description="Filter by subscription ID."
),
Expand All @@ -67,7 +71,8 @@ async def list(
auth_subject,
organization_id=organization_id,
product_id=product_id,
product_price_type=product_price_type,
product_billing_type=product_billing_type
or cast(MultipleQueryFilter[ProductBillingType] | None, product_price_type),
subscription_id=subscription_id,
query=query,
pagination=pagination,
Expand Down
8 changes: 4 additions & 4 deletions server/polar/customer_portal/service/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from polar.kit.services import ResourceServiceReader
from polar.kit.sorting import Sorting
from polar.models import Customer, Order, Organization, Product, ProductPrice
from polar.models.product_price import ProductPriceType
from polar.models.product import ProductBillingType


class CustomerOrderError(PolarError): ...
Expand Down Expand Up @@ -43,7 +43,7 @@ async def list(
*,
organization_id: Sequence[uuid.UUID] | None = None,
product_id: Sequence[uuid.UUID] | None = None,
product_price_type: Sequence[ProductPriceType] | None = None,
product_billing_type: Sequence[ProductBillingType] | None = None,
subscription_id: Sequence[uuid.UUID] | None = None,
query: str | None = None,
pagination: PaginationParams,
Expand Down Expand Up @@ -75,8 +75,8 @@ async def list(
if product_id is not None:
statement = statement.where(Order.product_id.in_(product_id))

if product_price_type is not None:
statement = statement.where(OrderProductPrice.type.in_(product_price_type))
if product_billing_type is not None:
statement = statement.where(Product.billing_type.in_(product_billing_type))

if subscription_id is not None:
statement = statement.where(Order.subscription_id.in_(subscription_id))
Expand Down
1 change: 1 addition & 0 deletions server/polar/kit/extensions/sqlalchemy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class StringEnum(EnumType):

class StrEnumType(TypeDecorator):
impl = sa.String
cache_ok = True

def __init__(self, enum_klass: type[StrEnum], **kwargs: Any) -> None:
super().__init__(**kwargs)
Expand Down
19 changes: 13 additions & 6 deletions server/polar/order/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import cast

from fastapi import Depends, Query
from pydantic import UUID4

Expand All @@ -6,6 +8,7 @@
from polar.kit.pagination import ListResource, PaginationParamsQuery
from polar.kit.schemas import MultipleQueryFilter
from polar.models import Order
from polar.models.product import ProductBillingType
from polar.models.product_price import ProductPriceType
from polar.openapi import APITag
from polar.organization.schemas import OrganizationID
Expand All @@ -32,16 +35,19 @@ async def list(
product_id: MultipleQueryFilter[ProductID] | None = Query(
None, title="ProductID Filter", description="Filter by product ID."
),
product_price_type: MultipleQueryFilter[ProductPriceType] | None = Query(
product_billing_type: MultipleQueryFilter[ProductBillingType] | None = Query(
None,
title="ProductPriceType Filter",
title="ProductBillingType Filter",
description=(
"Filter by product price type. "
"`recurring` will return orders corresponding "
"Filter by product billing type. "
"`recurring` will filter data corresponding "
"to subscriptions creations or renewals. "
"`one_time` will return orders corresponding to one-time purchases."
"`one_time` will filter data corresponding to one-time purchases."
),
),
product_price_type: MultipleQueryFilter[ProductPriceType] | None = Query(
None, title="ProductPriceType Filter", deprecated="Use `product_billing_type"
),
discount_id: MultipleQueryFilter[UUID4] | None = Query(
None, title="DiscountID Filter", description="Filter by discount ID."
),
Expand All @@ -59,7 +65,8 @@ async def list(
auth_subject,
organization_id=organization_id,
product_id=product_id,
product_price_type=product_price_type,
product_billing_type=product_billing_type
or cast(MultipleQueryFilter[ProductBillingType] | None, product_price_type),
discount_id=discount_id,
customer_id=customer_id,
checkout_id=checkout_id,
Expand Down
8 changes: 4 additions & 4 deletions server/polar/order/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
UserOrganization,
)
from polar.models.order import OrderBillingReason
from polar.models.product_price import ProductPriceType
from polar.models.product import ProductBillingType
from polar.models.transaction import TransactionType
from polar.models.webhook_endpoint import WebhookEventType
from polar.notifications.notification import (
Expand Down Expand Up @@ -177,7 +177,7 @@ async def list(
*,
organization_id: Sequence[uuid.UUID] | None = None,
product_id: Sequence[uuid.UUID] | None = None,
product_price_type: Sequence[ProductPriceType] | None = None,
product_billing_type: Sequence[ProductBillingType] | None = None,
discount_id: Sequence[uuid.UUID] | None = None,
customer_id: Sequence[uuid.UUID] | None = None,
checkout_id: Sequence[uuid.UUID] | None = None,
Expand Down Expand Up @@ -209,8 +209,8 @@ async def list(
if product_id is not None:
statement = statement.where(Order.product_id.in_(product_id))

if product_price_type is not None:
statement = statement.where(OrderProductPrice.type.in_(product_price_type))
if product_billing_type is not None:
statement = statement.where(Product.billing_type.in_(product_billing_type))

if discount_id is not None:
statement = statement.where(Order.discount_id.in_(discount_id))
Expand Down
21 changes: 21 additions & 0 deletions server/tests/order/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,27 @@ async def test_organization(self, client: AsyncClient, orders: list[Order]) -> N
json = response.json()
assert json["pagination"]["total_count"] == len(orders)

@pytest.mark.auth
@pytest.mark.parametrize(
"product_price_type,expected", [("one_time", 0), ("recurring", 1)]
)
async def test_deprecated_product_price_type_filter(
self,
product_price_type: str,
expected: int,
client: AsyncClient,
user_organization: UserOrganization,
orders: list[Order],
) -> None:
response = await client.get(
"/v1/orders/", params={"product_price_type": product_price_type}
)

assert response.status_code == 200

json = response.json()
assert json["pagination"]["total_count"] == expected


@pytest.mark.asyncio
class TestGetOrder:
Expand Down
48 changes: 48 additions & 0 deletions server/tests/order/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from polar.models.checkout import CheckoutStatus
from polar.models.order import OrderBillingReason
from polar.models.organization import Organization
from polar.models.product import ProductBillingType
from polar.models.transaction import TransactionType
from polar.order.service import (
CantDetermineInvoicePrice,
Expand Down Expand Up @@ -211,6 +212,53 @@ async def test_organization(
assert len(orders) == 1
assert orders[0].id == order.id

@pytest.mark.auth
async def test_product_billing_type_filter(
self,
auth_subject: AuthSubject[User],
save_fixture: SaveFixture,
session: AsyncSession,
user_organization: UserOrganization,
product: Product,
product_one_time_custom_price: Product,
product_one_time_free_price: Product,
customer: Customer,
) -> None:
order1 = await create_order(
save_fixture,
product=product,
customer=customer,
stripe_invoice_id="INVOICE_1",
)
order2 = await create_order(
save_fixture,
product=product_one_time_custom_price,
customer=customer,
stripe_invoice_id="INVOICE_2",
)

orders, count = await order_service.list(
session,
auth_subject,
product_billing_type=(ProductBillingType.recurring,),
pagination=PaginationParams(1, 10),
)

assert count == 1
assert len(orders) == 1
assert orders[0].id == order1.id

orders, count = await order_service.list(
session,
auth_subject,
product_billing_type=(ProductBillingType.one_time,),
pagination=PaginationParams(1, 10),
)

assert count == 1
assert len(orders) == 1
assert orders[0].id == order2.id


@pytest.mark.asyncio
class TestCreateOrderFromStripe:
Expand Down