Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Subscriber Billed Topics #609

Closed
wants to merge 12 commits into from
3 changes: 2 additions & 1 deletion server/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ package server
import (
"encoding/json"
"fmt"
"heckel.io/ntfy/log"
"net/http"

"heckel.io/ntfy/log"
)

// errHTTP is a generic HTTP error for any non-200 HTTP error
Expand Down
109 changes: 63 additions & 46 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,6 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/emersion/go-smtp"
"github.com/gorilla/websocket"
"golang.org/x/sync/errgroup"
"heckel.io/ntfy/log"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
"io"
"net"
"net/http"
Expand All @@ -30,6 +24,13 @@ import (
"sync"
"time"
"unicode/utf8"

"github.com/emersion/go-smtp"
"github.com/gorilla/websocket"
"golang.org/x/sync/errgroup"
"heckel.io/ntfy/log"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
)

/*
Expand Down Expand Up @@ -114,13 +115,15 @@ var (
)

const (
firebaseControlTopic = "~control" // See Android if changed
firebasePollTopic = "~poll" // See iOS if changed
emptyMessageBody = "triggered" // Used if message body is empty
newMessageBody = "New message" // Used in poll requests as generic message
defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages
jsonBodyBytesLimit = 16384
firebaseControlTopic = "~control" // See Android if changed
firebasePollTopic = "~poll" // See iOS if changed
emptyMessageBody = "triggered" // Used if message body is empty
newMessageBody = "New message" // Used in poll requests as generic message
defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages
jsonBodyBytesLimit = 16384
subscriberBilledTopicPrefix = "up_"
karmanyaahm marked this conversation as resolved.
Show resolved Hide resolved
karmanyaahm marked this conversation as resolved.
Show resolved Hide resolved
subscriberBilledValidity = 12 * time.Hour
)

// WebSocket constants
Expand Down Expand Up @@ -434,13 +437,14 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
} else if r.Method == http.MethodOptions {
return s.limitRequests(s.handleOptions)(w, r, v) // Should work even if the web app is not enabled, see #598
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && r.URL.Path == "/" {
return s.limitRequests(s.transformBodyJSON(s.authorizeTopicWrite(s.handlePublish)))(w, r, v)
// So I don't *really* have to switch this order, since this is unrelated to UP; But, making this and matrix inconsistent is just calling for confusion, no?
return s.transformBodyJSON(s.limitRequestsWithTopic(s.authorizeTopicWrite(s.handlePublish)))(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == matrixPushPath {
return s.limitRequests(s.transformMatrixJSON(s.authorizeTopicWrite(s.handlePublishMatrix)))(w, r, v)
return s.transformMatrixJSON(s.limitRequestsWithTopic(s.authorizeTopicWrite(s.handlePublishMatrix)))(w, r, v)
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) {
return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
return s.limitRequestsWithTopic(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
} else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) {
return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
return s.limitRequestsWithTopic(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
} else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) {
return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeJSON))(w, r, v)
} else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) {
Expand Down Expand Up @@ -599,19 +603,24 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
}

func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) {
t, err := s.topicFromPath(r.URL.Path)
if err != nil {
return nil, err
vRate, ok := r.Context().Value("vRate").(*visitor)
if !ok {
return nil, errHTTPInternalError
}
if !v.MessageAllowed() {
t, ok := r.Context().Value("topic").(*topic)
if !ok {
return nil, errHTTPInternalError
}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know I said to just replace v in an older comment, but the more I think about it, the more I think that this is literally just about rate limiting, right? So maybe maybe our Billee is just a vrate (= visitor used for rate limiting), and for all the rate limiting functions, we'd use vrate, but for everything else, we'd use v?

I have also noticed that you haven't tackled the limitRequests middleware at all. The rate limiting for number of messages (visitor-message-daily-limit) and number of emails (visitor-email-limit-*) happens in handlePublish*, but the rate limiting for requests (visitor-request-limit-*) happens in limitRequests, so this logic would have to happen there, not here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did this using context (as suggested below), but isn't that supposed to be an anti-pattern? In any case, I guess it's just one thing so it shouldn't matter much. I feel like the ideal solution to this would be making visitor an interface and having a compositeVisitor in addition to visitor...but, that's probably overengineering.


if !vRate.MessageAllowed() {
return nil, errHTTPTooManyRequestsLimitMessages
}
body, err := util.Peek(r.Body, s.config.MessageLimit)
if err != nil {
return nil, err
}
m := newDefaultMessage(t.ID, "")
cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, v, m)
cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, vRate, m)
if err != nil {
return nil, err
}
Expand All @@ -620,25 +629,26 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
}
m.Sender = v.IP()
m.User = v.MaybeUserID()
m.Expires = time.Now().Add(v.Limits().MessageExpiryDuration).Unix()
m.Expires = time.Unix(m.Time, 0).Add(vRate.Limits().MessageExpiryDuration).Unix()
if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
return nil, err
}
if m.Message == "" {
m.Message = emptyMessageBody
}
delayed := m.Time > time.Now().Unix()
logvrm(v, r, m).
logvrm(vRate, r, m).
Tag(tagPublish).
Fields(log.Context{
"message_delayed": delayed,
"message_firebase": firebase,
"message_unifiedpush": unifiedpush,
"message_email": email,
"message_delayed": delayed,
"message_firebase": firebase,
"message_unifiedpush": unifiedpush,
"message_email": email,
"message_subscriber_rate_limited": vRate != v,
}).
Debug("Received message")
if log.IsTrace() {
logvrm(v, r, m).
logvrm(vRate, r, m).
Tag(tagPublish).
Field("message_body", util.MaybeMarshalJSON(m)).
Trace("Message body")
Expand All @@ -648,7 +658,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
return nil, err
}
if s.firebaseClient != nil && firebase {
go s.sendToFirebase(v, m)
go s.sendToFirebase(vRate, m)
}
if s.smtpSender != nil && email != "" {
go s.sendEmail(v, m, email)
Expand Down Expand Up @@ -680,6 +690,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
if err != nil {
return err
}

return s.writeJSON(w, m)
}

Expand Down Expand Up @@ -733,7 +744,7 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) {
}
}

func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) {
func (s *Server) parsePublishParams(r *http.Request, vRate *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) {
cache = readBoolParam(r, true, "x-cache", "cache")
firebase = readBoolParam(r, true, "x-firebase", "firebase")
m.Title = readParam(r, "x-title", "title", "t")
Expand Down Expand Up @@ -773,7 +784,7 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca
}
email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
if email != "" {
if !v.EmailAllowed() {
if !vRate.EmailAllowed() {
return false, false, "", false, errHTTPTooManyRequestsLimitEmails
}
}
Expand All @@ -788,13 +799,7 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca
if err != nil {
return false, false, "", false, errHTTPBadRequestPriorityInvalid
}
tagsStr := readParam(r, "x-tags", "tags", "tag", "ta")
if tagsStr != "" {
m.Tags = make([]string, 0)
for _, s := range util.SplitNoEmpty(tagsStr, ",") {
m.Tags = append(m.Tags, strings.TrimSpace(s))
}
}
m.Tags = readCommaSeperatedParam(r, "x-tags", "tags", "tag", "ta")
delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in")
if delayStr != "" {
if !cache {
Expand Down Expand Up @@ -984,7 +989,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
if err != nil {
return err
}
poll, since, scheduled, filters, err := parseSubscribeParams(r)
poll, since, scheduled, filters, subscriberRateTopics, err := parseSubscribeParams(r)
if err != nil {
return err
}
Expand Down Expand Up @@ -1023,7 +1028,8 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
defer cancel()
subscriberIDs := make([]int, 0)
for _, t := range topics {
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel))
subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) // temporarily do prefix as well
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited))
}
defer func() {
for i, subscriberID := range subscriberIDs {
Expand Down Expand Up @@ -1066,7 +1072,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
if err != nil {
return err
}
poll, since, scheduled, filters, err := parseSubscribeParams(r)
poll, since, scheduled, filters, subscriberRateTopics, err := parseSubscribeParams(r)
if err != nil {
return err
}
Expand Down Expand Up @@ -1155,7 +1161,8 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
}
subscriberIDs := make([]int, 0)
for _, t := range topics {
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel))
subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) // temporarily do prefix as well
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited))
}
defer func() {
for i, subscriberID := range subscriberIDs {
Expand All @@ -1176,7 +1183,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
return err
}

func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, err error) {
func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, subscriberTopics []string, err error) {
poll = readBoolParam(r, false, "x-poll", "poll", "po")
scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched")
since, err = parseSince(r, poll)
Expand All @@ -1187,6 +1194,8 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu
if err != nil {
return
}

subscriberTopics = readCommaSeperatedParam(r, "subscriber-rate-limit-topics", "x-subscriber-rate-limit-topics", "srlt")
return
}

Expand Down Expand Up @@ -1402,9 +1411,17 @@ func (s *Server) execManager() {
defer s.mu.Unlock()
for _, t := range s.topics {
subs := t.SubscribersCount()
log.Tag(tagManager).Trace("- topic %s: %d subscribers", t.ID, subs)
ev := log.Tag(tagManager)
if ev.IsTrace() {
expiryMessage := ""
if subs == 0 {
expiryTime := time.Until(t.vRateExpires)
expiryMessage = ", expires in " + expiryTime.String()
}
ev.Trace("- topic %s: %d subscribers%s", t.ID, subs, expiryMessage)
}
msgs, exists := messageCounts[t.ID]
binwiederhier marked this conversation as resolved.
Show resolved Hide resolved
if subs == 0 && (!exists || msgs == 0) {
if t.Stale() && (!exists || msgs == 0) {
log.Tag(tagManager).Trace("Deleting empty topic %s", t.ID)
emptyTopics++
delete(s.topics, t.ID)
Expand Down
26 changes: 25 additions & 1 deletion server/server_middleware.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package server

import (
"heckel.io/ntfy/util"
"context"
"net/http"

"heckel.io/ntfy/util"
)

func (s *Server) limitRequests(next handleFunc) handleFunc {
Expand All @@ -16,6 +18,28 @@ func (s *Server) limitRequests(next handleFunc) handleFunc {
}
}

// limitRequestsWithTopic limits requests with a topic and stores the rate-limiting-subscriber and topic into request.Context
func (s *Server) limitRequestsWithTopic(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
t, err := s.topicFromPath(r.URL.Path)
if err != nil {
return err
}
vRate := v
if topicCountsAgainst := t.Billee(); topicCountsAgainst != nil {
vRate = topicCountsAgainst
}
r.WithContext(context.WithValue(context.WithValue(r.Context(), "vRate", vRate), "topic", t))

if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
return next(w, r, v)
} else if !vRate.RequestAllowed() {
return errHTTPTooManyRequestsLimitRequests
}
return next(w, r, v)
}
}

func (s *Server) ensureWebEnabled(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if !s.config.EnableWeb {
Expand Down
Loading