diff --git a/server/pkg/controller/billing.go b/server/pkg/controller/billing.go index 0d01aeeb9a..8b9b3f75f5 100644 --- a/server/pkg/controller/billing.go +++ b/server/pkg/controller/billing.go @@ -178,7 +178,7 @@ func (c *BillingController) IsActivePayingSubscriber(userID int64) error { } // HasActiveSelfOrFamilySubscription validates if the user or user's family admin has active subscription -func (c *BillingController) HasActiveSelfOrFamilySubscription(userID int64) error { +func (c *BillingController) HasActiveSelfOrFamilySubscription(userID int64, mustBeOnPaidPlan bool) error { var subscriptionUserID int64 familyAdminID, err := c.UserRepo.GetFamilyAdminID(userID) if err != nil { @@ -202,6 +202,15 @@ func (c *BillingController) HasActiveSelfOrFamilySubscription(userID int64) erro } return stacktrace.Propagate(err, "") } + if mustBeOnPaidPlan { + isPayingUser, err := c.BillingRepo.IsUserOnPaidPlan(subscriptionUserID) + if err != nil { + return stacktrace.Propagate(err, "failed to check if user is on paid plan") + } + if !isPayingUser { + return ente.ErrSharingDisabledForFreeAccounts + } + } return nil } diff --git a/server/pkg/controller/collection.go b/server/pkg/controller/collection.go index 911afc6d77..dca53f80c3 100644 --- a/server/pkg/controller/collection.go +++ b/server/pkg/controller/collection.go @@ -166,7 +166,7 @@ func (c *CollectionController) Share(ctx *gin.Context, req ente.AlterShareReques if fromUserID != collection.Owner.ID { return nil, stacktrace.Propagate(ente.ErrPermissionDenied, "") } - err = c.BillingCtrl.HasActiveSelfOrFamilySubscription(fromUserID) + err = c.BillingCtrl.HasActiveSelfOrFamilySubscription(fromUserID, true) if err != nil { return nil, stacktrace.Propagate(err, "") } @@ -270,7 +270,7 @@ func (c *CollectionController) ShareURL(ctx context.Context, userID int64, req e if userID != collection.Owner.ID { return ente.PublicURL{}, stacktrace.Propagate(ente.ErrPermissionDenied, "") } - err = c.BillingCtrl.HasActiveSelfOrFamilySubscription(userID) + err = c.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, true) if err != nil { return ente.PublicURL{}, stacktrace.Propagate(err, "") } @@ -287,7 +287,7 @@ func (c *CollectionController) UpdateShareURL(ctx context.Context, userID int64, if err := c.verifyOwnership(req.CollectionID, userID); err != nil { return ente.PublicURL{}, stacktrace.Propagate(err, "") } - err := c.BillingCtrl.HasActiveSelfOrFamilySubscription(userID) + err := c.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, true) if err != nil { return ente.PublicURL{}, stacktrace.Propagate(err, "") } diff --git a/server/pkg/middleware/access_token.go b/server/pkg/middleware/access_token.go index c1ca120167..702af77db8 100644 --- a/server/pkg/middleware/access_token.go +++ b/server/pkg/middleware/access_token.go @@ -117,7 +117,7 @@ func (m *AccessTokenMiddleware) validateOwnersSubscription(cID int64) error { if err != nil { return stacktrace.Propagate(err, "") } - return m.BillingCtrl.HasActiveSelfOrFamilySubscription(userID) + return m.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, false) } func (m *AccessTokenMiddleware) isDeviceLimitReached(ctx context.Context, diff --git a/server/pkg/repo/billing.go b/server/pkg/repo/billing.go index 12ca041e0b..d66dce90d7 100644 --- a/server/pkg/repo/billing.go +++ b/server/pkg/repo/billing.go @@ -3,6 +3,7 @@ package repo import ( "database/sql" "encoding/json" + "fmt" "github.com/ente-io/stacktrace" @@ -108,6 +109,43 @@ func (repo *BillingRepository) LogAppStorePush(userID int64, notification appsto return stacktrace.Propagate(err, "") } +func (repo *BillingRepository) IsUserOnPaidPlan(userID int64) (bool, error) { + query := ` + SELECT CASE + WHEN NOT EXISTS ( + SELECT 1 + FROM users u + WHERE u.user_id = 1 + ) THEN true + ELSE EXISTS ( + SELECT 1 + FROM users u + WHERE u.user_id = $1 + AND ( + EXISTS ( + SELECT 1 + FROM subscriptions s + WHERE s.user_id = COALESCE(u.family_admin_id, u.user_id) + AND s.product_id <> 'free' + ) + OR EXISTS ( + SELECT 1 + FROM storage_bonus sb + WHERE sb.user_id = COALESCE(u.family_admin_id, u.user_id) + AND sb.type NOT IN ('SIGN_UP', 'REFERRAL') + ) + ) + ) + END + ` + var isPaidPlan bool + err := repo.DB.QueryRow(query, userID).Scan(&isPaidPlan) + if err != nil { + return false, fmt.Errorf("error checking paid plan status: %v", err) + } + return isPaidPlan, nil +} + // LogStripePush logs a notification from Stripe func (repo *BillingRepository) LogStripePush(eventLog ente.StripeEventLog) error { notificationJSON, _ := json.Marshal(eventLog.Event)