From d686e1ee77b0a1de2ea108a12259eae6497def5c Mon Sep 17 00:00:00 2001 From: Karmanyaah Malhotra Date: Tue, 14 Feb 2023 13:07:32 -0600 Subject: [PATCH 01/12] Use visitor instead of UserID in topicSubscription --- server/server.go | 4 ++-- server/topic.go | 13 +++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/server/server.go b/server/server.go index 288b13885..890690fbf 100644 --- a/server/server.go +++ b/server/server.go @@ -1023,7 +1023,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * defer cancel() subscriberIDs := make([]int, 0) for _, t := range topics { - subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel)) + subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel)) } defer func() { for i, subscriberID := range subscriberIDs { @@ -1155,7 +1155,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi } subscriberIDs := make([]int, 0) for _, t := range topics { - subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel)) + subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel)) } defer func() { for i, subscriberID := range subscriberIDs { diff --git a/server/topic.go b/server/topic.go index 150a185b9..fca5ee0a1 100644 --- a/server/topic.go +++ b/server/topic.go @@ -15,8 +15,8 @@ type topic struct { } type topicSubscriber struct { - userID string // User ID associated with this subscription, may be empty subscriber subscriber + visitor *visitor // User ID associated with this subscription, may be empty cancel func() } @@ -32,12 +32,12 @@ func newTopic(id string) *topic { } // Subscribe subscribes to this topic -func (t *topic) Subscribe(s subscriber, userID string, cancel func()) int { +func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func()) int { t.mu.Lock() defer t.mu.Unlock() subscriberID := rand.Int() t.subscribers[subscriberID] = &topicSubscriber{ - userID: userID, // May be empty + visitor: visitor, // May be empty subscriber: s, cancel: cancel, } @@ -87,8 +87,9 @@ func (t *topic) CancelSubscribers(exceptUserID string) { t.mu.Lock() defer t.mu.Unlock() for _, s := range t.subscribers { - if s.userID != exceptUserID { - log.Tag(tagSubscribe).Field("topic", t.ID).Debug("Canceling subscriber %s", s.userID) + if s.visitor.MaybeUserID() != exceptUserID { + // TODO: Shouldn't this log the IP for anonymous visitors? It was s.userID before my change. + log.Tag(tagSubscribe).Field("topic", t.ID).Debug("Canceling subscriber %s", s.visitor.MaybeUserID()) s.cancel() } } @@ -101,7 +102,7 @@ func (t *topic) subscribersCopy() map[int]*topicSubscriber { subscribers := make(map[int]*topicSubscriber) for k, sub := range t.subscribers { subscribers[k] = &topicSubscriber{ - userID: sub.userID, + visitor: sub.visitor, subscriber: sub.subscriber, cancel: sub.cancel, } From 28b654ae2746b6a48deddab9df69dfc5fa88895e Mon Sep 17 00:00:00 2001 From: Karmanyaah Malhotra Date: Tue, 14 Feb 2023 13:58:13 -0600 Subject: [PATCH 02/12] Keep track of lastVisitor to a topic --- server/server.go | 16 +++++++++------- server/topic.go | 32 ++++++++++++++++++++++++++++---- 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/server/server.go b/server/server.go index 890690fbf..a22ed7ba8 100644 --- a/server/server.go +++ b/server/server.go @@ -114,13 +114,15 @@ var ( ) const ( - firebaseControlTopic = "~control" // See Android if changed - firebasePollTopic = "~poll" // See iOS if changed - emptyMessageBody = "triggered" // Used if message body is empty - newMessageBody = "New message" // Used in poll requests as generic message - defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment - encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages - jsonBodyBytesLimit = 16384 + firebaseControlTopic = "~control" // See Android if changed + firebasePollTopic = "~poll" // See iOS if changed + emptyMessageBody = "triggered" // Used if message body is empty + newMessageBody = "New message" // Used in poll requests as generic message + defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment + encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages + jsonBodyBytesLimit = 16384 + subscriberBilledTopicPrefix = "up_" + subscriberBilledValidity = 12 * time.Hour ) // WebSocket constants diff --git a/server/topic.go b/server/topic.go index fca5ee0a1..6911ec0ba 100644 --- a/server/topic.go +++ b/server/topic.go @@ -1,17 +1,21 @@ package server import ( - "heckel.io/ntfy/log" "math/rand" "sync" + "time" + + "heckel.io/ntfy/log" ) // topic represents a channel to which subscribers can subscribe, and publishers // can publish a message type topic struct { - ID string - subscribers map[int]*topicSubscriber - mu sync.Mutex + ID string + subscribers map[int]*topicSubscriber + lastVisitor *visitor + lastVisitorExpires time.Time + mu sync.Mutex } type topicSubscriber struct { @@ -44,10 +48,30 @@ func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func()) int { return subscriberID } +func (t *topic) Stale() bool { + return t.getBillee() == nil +} + +func (t *topic) getBillee() *visitor { + for _, this_subscriber := range t.subscribers { + return this_subscriber.visitor + } + if t.lastVisitor != nil && t.lastVisitorExpires.After(time.Now()) { + t.lastVisitor = nil + } + return t.lastVisitor + +} + // Unsubscribe removes the subscription from the list of subscribers func (t *topic) Unsubscribe(id int) { t.mu.Lock() defer t.mu.Unlock() + + if len(t.subscribers) == 1 { + t.lastVisitor = t.subscribers[id].visitor + t.lastVisitorExpires = time.Now().Add(subscriberBilledValidity) + } delete(t.subscribers, id) } From fb2fa4c478778ab0d09154b5d274ae37caa6a4b6 Mon Sep 17 00:00:00 2001 From: Karmanyaah Malhotra Date: Tue, 14 Feb 2023 14:00:43 -0600 Subject: [PATCH 03/12] Fix m.Expires and prune stale topics based on lastVisitorExpires --- server/server.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/server/server.go b/server/server.go index a22ed7ba8..619e30cb4 100644 --- a/server/server.go +++ b/server/server.go @@ -622,7 +622,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes } m.Sender = v.IP() m.User = v.MaybeUserID() - m.Expires = time.Now().Add(v.Limits().MessageExpiryDuration).Unix() + m.Expires = time.Unix(m.Time, 0).Add(v.Limits().MessageExpiryDuration).Unix() if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil { return nil, err } @@ -666,6 +666,8 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes if err := s.messageCache.AddMessage(m); err != nil { return nil, err } + } else { + m.Expires = m.Time } u := v.User() if s.userManager != nil && u != nil && u.Tier != nil { @@ -1404,9 +1406,10 @@ func (s *Server) execManager() { defer s.mu.Unlock() for _, t := range s.topics { subs := t.SubscribersCount() - log.Tag(tagManager).Trace("- topic %s: %d subscribers", t.ID, subs) + expiryTime := time.Until(t.lastVisitorExpires) + log.Tag(tagManager).Trace("- topic %s: %d subscribers, expires in %s", t.ID, subs, expiryTime) msgs, exists := messageCounts[t.ID] - if subs == 0 && (!exists || msgs == 0) { + if t.Stale() && (!exists || msgs == 0) { log.Tag(tagManager).Trace("Deleting empty topic %s", t.ID) emptyTopics++ delete(s.topics, t.ID) From 6bfe4a97797bcb3f513f63b5dcc944decb60d926 Mon Sep 17 00:00:00 2001 From: Karmanyaah Malhotra Date: Tue, 14 Feb 2023 14:07:02 -0600 Subject: [PATCH 04/12] Bill to visitor and set TTL in response --- server/errors.go | 1 + server/server.go | 16 +++++++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/server/errors.go b/server/errors.go index a00105c37..f50a4df9f 100644 --- a/server/errors.go +++ b/server/errors.go @@ -92,4 +92,5 @@ var ( errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""} errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", ""} errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/"} + errHTTPWontStoreMessage = &errHTTP{50701, http.StatusInsufficientStorage, "topic is inactive; no device available to recieve message", ""} ) diff --git a/server/server.go b/server/server.go index 619e30cb4..16771332f 100644 --- a/server/server.go +++ b/server/server.go @@ -372,6 +372,7 @@ func (s *Server) handleError(w http.ResponseWriter, r *http.Request, v *visitor, } w.Header().Set("Content-Type", "application/json") w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests + w.Header().Set("TTL", "0") // if message is not being stored because of an error, tell them w.WriteHeader(httpErr.HTTPCode) io.WriteString(w, httpErr.JSON()+"\n") } @@ -605,6 +606,14 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes if err != nil { return nil, err } + v_old := v + if strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) { + v = t.getBillee() + if v == nil { + return nil, errHTTPWontStoreMessage + } + } + if !v.MessageAllowed() { return nil, errHTTPTooManyRequestsLimitMessages } @@ -639,8 +648,9 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes "message_email": email, }). Debug("Received message") + //Where should I log the original visitor vs the billing visitor if log.IsTrace() { - logvrm(v, r, m). + logvrm(v_old, r, m). Tag(tagPublish). Field("message_body", util.MaybeMarshalJSON(m)). Trace("Message body") @@ -684,6 +694,10 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito if err != nil { return err } + + w.Header().Set("TTL", strconv.FormatInt(m.Expires-m.Time, 10)) // return how long a message will be stored for + + // using m.Time, not time.Now() so the value isn't negative if the request is processed at a second boundary return s.writeJSON(w, m) } From 7c5b9c0e62fc1ec113f12b81e70735e4fb84dc1e Mon Sep 17 00:00:00 2001 From: Karmanyaah Malhotra Date: Tue, 14 Feb 2023 14:21:33 -0600 Subject: [PATCH 05/12] only log expiry if applicable --- server/server.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/server/server.go b/server/server.go index 16771332f..8a4655bea 100644 --- a/server/server.go +++ b/server/server.go @@ -1420,8 +1420,12 @@ func (s *Server) execManager() { defer s.mu.Unlock() for _, t := range s.topics { subs := t.SubscribersCount() - expiryTime := time.Until(t.lastVisitorExpires) - log.Tag(tagManager).Trace("- topic %s: %d subscribers, expires in %s", t.ID, subs, expiryTime) + expiryMessage := "" + if subs == 0 { + expiryTime := time.Until(t.lastVisitorExpires) + expiryMessage = ", expires in " + expiryTime.String() + } + log.Tag(tagManager).Trace("- topic %s: %d subscribers%s", t.ID, subs, expiryMessage) msgs, exists := messageCounts[t.ID] if t.Stale() && (!exists || msgs == 0) { log.Tag(tagManager).Trace("Deleting empty topic %s", t.ID) From c6b64df662c8ea9a29d44571dcf7acba09cb8b1e Mon Sep 17 00:00:00 2001 From: Karmanyaah Malhotra Date: Wed, 15 Feb 2023 03:31:59 -0600 Subject: [PATCH 06/12] remove ttl --- server/server.go | 1 - 1 file changed, 1 deletion(-) diff --git a/server/server.go b/server/server.go index 8a4655bea..abb08a8b3 100644 --- a/server/server.go +++ b/server/server.go @@ -372,7 +372,6 @@ func (s *Server) handleError(w http.ResponseWriter, r *http.Request, v *visitor, } w.Header().Set("Content-Type", "application/json") w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests - w.Header().Set("TTL", "0") // if message is not being stored because of an error, tell them w.WriteHeader(httpErr.HTTPCode) io.WriteString(w, httpErr.JSON()+"\n") } From b9badee6dbb4d43d6f1abee2f433e7113bb86ab8 Mon Sep 17 00:00:00 2001 From: Karmanyaah Malhotra Date: Wed, 15 Feb 2023 03:38:24 -0600 Subject: [PATCH 07/12] remove TTL, will make a seperate PR --- server/server.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/server/server.go b/server/server.go index abb08a8b3..4a4be262f 100644 --- a/server/server.go +++ b/server/server.go @@ -694,9 +694,6 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito return err } - w.Header().Set("TTL", strconv.FormatInt(m.Expires-m.Time, 10)) // return how long a message will be stored for - - // using m.Time, not time.Now() so the value isn't negative if the request is processed at a second boundary return s.writeJSON(w, m) } From 36685e9df9b9792d85ca7b448ff29a570f6389df Mon Sep 17 00:00:00 2001 From: Karmanyaah Malhotra Date: Mon, 20 Feb 2023 17:55:21 -0600 Subject: [PATCH 08/12] Suggested changes - https://github.com/binwiederhier/ntfy/pull/609/files/b9badee6dbb4d43d6f1abee2f433e7113bb86ab8#r1111115151 - https://github.com/binwiederhier/ntfy/pull/609/files/b9badee6dbb4d43d6f1abee2f433e7113bb86ab8#r1111114771 --- server/server.go | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/server/server.go b/server/server.go index 4a4be262f..aa261d5e9 100644 --- a/server/server.go +++ b/server/server.go @@ -675,8 +675,6 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes if err := s.messageCache.AddMessage(m); err != nil { return nil, err } - } else { - m.Expires = m.Time } u := v.User() if s.userManager != nil && u != nil && u.Tier != nil { @@ -1416,12 +1414,15 @@ func (s *Server) execManager() { defer s.mu.Unlock() for _, t := range s.topics { subs := t.SubscribersCount() - expiryMessage := "" - if subs == 0 { - expiryTime := time.Until(t.lastVisitorExpires) - expiryMessage = ", expires in " + expiryTime.String() + ev := log.Tag(tagManager) + if ev.IsTrace() { + expiryMessage := "" + if subs == 0 { + expiryTime := time.Until(t.lastVisitorExpires) + expiryMessage = ", expires in " + expiryTime.String() + } + ev.Trace("- topic %s: %d subscribers%s", t.ID, subs, expiryMessage) } - log.Tag(tagManager).Trace("- topic %s: %d subscribers%s", t.ID, subs, expiryMessage) msgs, exists := messageCounts[t.ID] if t.Stale() && (!exists || msgs == 0) { log.Tag(tagManager).Trace("Deleting empty topic %s", t.ID) From 1655f584f9382239ab89f78f3250bc37aa5d62f3 Mon Sep 17 00:00:00 2001 From: Karmanyaah Malhotra Date: Tue, 21 Feb 2023 20:04:56 -0600 Subject: [PATCH 09/12] rate limiting impl 2.0? --- server/errors.go | 4 +-- server/server.go | 75 ++++++++++++++++++++++++------------------------ server/topic.go | 56 ++++++++++++++++++++++++------------ server/util.go | 23 +++++++++++++-- 4 files changed, 97 insertions(+), 61 deletions(-) diff --git a/server/errors.go b/server/errors.go index f50a4df9f..819f972e5 100644 --- a/server/errors.go +++ b/server/errors.go @@ -3,8 +3,9 @@ package server import ( "encoding/json" "fmt" - "heckel.io/ntfy/log" "net/http" + + "heckel.io/ntfy/log" ) // errHTTP is a generic HTTP error for any non-200 HTTP error @@ -92,5 +93,4 @@ var ( errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""} errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", ""} errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/"} - errHTTPWontStoreMessage = &errHTTP{50701, http.StatusInsufficientStorage, "topic is inactive; no device available to recieve message", ""} ) diff --git a/server/server.go b/server/server.go index aa261d5e9..a5307964a 100644 --- a/server/server.go +++ b/server/server.go @@ -9,12 +9,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/emersion/go-smtp" - "github.com/gorilla/websocket" - "golang.org/x/sync/errgroup" - "heckel.io/ntfy/log" - "heckel.io/ntfy/user" - "heckel.io/ntfy/util" "io" "net" "net/http" @@ -30,6 +24,13 @@ import ( "sync" "time" "unicode/utf8" + + "github.com/emersion/go-smtp" + "github.com/gorilla/websocket" + "golang.org/x/sync/errgroup" + "heckel.io/ntfy/log" + "heckel.io/ntfy/user" + "heckel.io/ntfy/util" ) /* @@ -605,23 +606,23 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes if err != nil { return nil, err } - v_old := v - if strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) { - v = t.getBillee() - if v == nil { - return nil, errHTTPWontStoreMessage - } + vRate := v + if topicCountsAgainst := t.Billee(); topicCountsAgainst != nil { + vRate = topicCountsAgainst } - if !v.MessageAllowed() { - return nil, errHTTPTooManyRequestsLimitMessages + if !vRate.MessageAllowed() { + vRate = v + if !v.MessageAllowed() { + return nil, errHTTPTooManyRequestsLimitMessages + } } body, err := util.Peek(r.Body, s.config.MessageLimit) if err != nil { return nil, err } m := newDefaultMessage(t.ID, "") - cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, v, m) + cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, vRate, m) if err != nil { return nil, err } @@ -630,7 +631,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes } m.Sender = v.IP() m.User = v.MaybeUserID() - m.Expires = time.Unix(m.Time, 0).Add(v.Limits().MessageExpiryDuration).Unix() + m.Expires = time.Unix(m.Time, 0).Add(vRate.Limits().MessageExpiryDuration).Unix() if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil { return nil, err } @@ -638,18 +639,18 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes m.Message = emptyMessageBody } delayed := m.Time > time.Now().Unix() - logvrm(v, r, m). + logvrm(vRate, r, m). Tag(tagPublish). Fields(log.Context{ - "message_delayed": delayed, - "message_firebase": firebase, - "message_unifiedpush": unifiedpush, - "message_email": email, + "message_delayed": delayed, + "message_firebase": firebase, + "message_unifiedpush": unifiedpush, + "message_email": email, + "message_subscriber_rate_limited": vRate != v, }). Debug("Received message") - //Where should I log the original visitor vs the billing visitor if log.IsTrace() { - logvrm(v_old, r, m). + logvrm(vRate, r, m). Tag(tagPublish). Field("message_body", util.MaybeMarshalJSON(m)). Trace("Message body") @@ -659,7 +660,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes return nil, err } if s.firebaseClient != nil && firebase { - go s.sendToFirebase(v, m) + go s.sendToFirebase(vRate, m) } if s.smtpSender != nil && email != "" { go s.sendEmail(v, m, email) @@ -745,7 +746,7 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) { } } -func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) { +func (s *Server) parsePublishParams(r *http.Request, vRate *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) { cache = readBoolParam(r, true, "x-cache", "cache") firebase = readBoolParam(r, true, "x-firebase", "firebase") m.Title = readParam(r, "x-title", "title", "t") @@ -785,7 +786,7 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca } email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e") if email != "" { - if !v.EmailAllowed() { + if !vRate.EmailAllowed() { return false, false, "", false, errHTTPTooManyRequestsLimitEmails } } @@ -800,13 +801,7 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca if err != nil { return false, false, "", false, errHTTPBadRequestPriorityInvalid } - tagsStr := readParam(r, "x-tags", "tags", "tag", "ta") - if tagsStr != "" { - m.Tags = make([]string, 0) - for _, s := range util.SplitNoEmpty(tagsStr, ",") { - m.Tags = append(m.Tags, strings.TrimSpace(s)) - } - } + m.Tags = readCommaSeperatedParam(r, "x-tags", "tags", "tag", "ta") delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in") if delayStr != "" { if !cache { @@ -996,7 +991,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * if err != nil { return err } - poll, since, scheduled, filters, err := parseSubscribeParams(r) + poll, since, scheduled, filters, subscriberRateTopics, err := parseSubscribeParams(r) if err != nil { return err } @@ -1035,7 +1030,8 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * defer cancel() subscriberIDs := make([]int, 0) for _, t := range topics { - subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel)) + subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) // temporarily do prefix as well + subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited)) } defer func() { for i, subscriberID := range subscriberIDs { @@ -1078,7 +1074,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi if err != nil { return err } - poll, since, scheduled, filters, err := parseSubscribeParams(r) + poll, since, scheduled, filters, subscriberRateTopics, err := parseSubscribeParams(r) if err != nil { return err } @@ -1167,7 +1163,8 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi } subscriberIDs := make([]int, 0) for _, t := range topics { - subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel)) + subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) // temporarily do prefix as well + subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited)) } defer func() { for i, subscriberID := range subscriberIDs { @@ -1188,7 +1185,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi return err } -func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, err error) { +func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, subscriberTopics []string, err error) { poll = readBoolParam(r, false, "x-poll", "poll", "po") scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched") since, err = parseSince(r, poll) @@ -1199,6 +1196,8 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu if err != nil { return } + + subscriberTopics = readCommaSeperatedParam(r, "subscriber-rate-limit-topics", "x-subscriber-rate-limit-topics", "srlt") return } diff --git a/server/topic.go b/server/topic.go index 6911ec0ba..19501a87f 100644 --- a/server/topic.go +++ b/server/topic.go @@ -19,9 +19,10 @@ type topic struct { } type topicSubscriber struct { - subscriber subscriber - visitor *visitor // User ID associated with this subscription, may be empty - cancel func() + subscriber subscriber + visitor *visitor // User ID associated with this subscription, may be empty + cancel func() + subscriberRateLimit bool } // subscriber is a function that is called for every new message on a topic @@ -36,31 +37,36 @@ func newTopic(id string) *topic { } // Subscribe subscribes to this topic -func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func()) int { +func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func(), subscriberRateLimit bool) int { t.mu.Lock() defer t.mu.Unlock() subscriberID := rand.Int() t.subscribers[subscriberID] = &topicSubscriber{ - visitor: visitor, // May be empty - subscriber: s, - cancel: cancel, + visitor: visitor, // May be empty + subscriber: s, + cancel: cancel, + subscriberRateLimit: subscriberRateLimit, } + + // if no subscriber is already handling the rate limit + if t.lastVisitor == nil && subscriberRateLimit { + t.lastVisitor = visitor + t.lastVisitorExpires = time.Time{} + } + return subscriberID } func (t *topic) Stale() bool { - return t.getBillee() == nil -} - -func (t *topic) getBillee() *visitor { - for _, this_subscriber := range t.subscribers { - return this_subscriber.visitor - } - if t.lastVisitor != nil && t.lastVisitorExpires.After(time.Now()) { + // if Time is initialized (not the zero value) and the expiry time has passed + if !t.lastVisitorExpires.IsZero() && t.lastVisitorExpires.Before(time.Now()) { t.lastVisitor = nil } - return t.lastVisitor + return len(t.subscribers) == 0 && t.lastVisitor == nil +} +func (t *topic) Billee() *visitor { + return t.lastVisitor } // Unsubscribe removes the subscription from the list of subscribers @@ -68,11 +74,23 @@ func (t *topic) Unsubscribe(id int) { t.mu.Lock() defer t.mu.Unlock() - if len(t.subscribers) == 1 { - t.lastVisitor = t.subscribers[id].visitor + deletingSub := t.subscribers[id] + delete(t.subscribers, id) + + // look for an active subscriber (in random order) that wants to handle the rate limit + for _, v := range t.subscribers { + if v.subscriberRateLimit { + t.lastVisitor = v.visitor + t.lastVisitorExpires = time.Time{} + return + } + } + + // if no active subscriber is found, count it towards the leaving subscriber + if deletingSub.subscriberRateLimit { + t.lastVisitor = deletingSub.visitor t.lastVisitorExpires = time.Now().Add(subscriberBilledValidity) } - delete(t.subscribers, id) } // Publish asynchronously publishes to all subscribers diff --git a/server/util.go b/server/util.go index 8fbfaefa4..048e2f931 100644 --- a/server/util.go +++ b/server/util.go @@ -1,12 +1,13 @@ package server import ( - "heckel.io/ntfy/log" - "heckel.io/ntfy/util" "io" "net/http" "net/netip" "strings" + + "heckel.io/ntfy/log" + "heckel.io/ntfy/util" ) func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool { @@ -17,6 +18,17 @@ func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool { return value == "1" || value == "yes" || value == "true" } +func readCommaSeperatedParam(r *http.Request, names ...string) (params []string) { + paramStr := readParam(r, names...) + if paramStr != "" { + params = make([]string, 0) + for _, s := range util.SplitNoEmpty(paramStr, ",") { + params = append(params, strings.TrimSpace(s)) + } + } + return params +} + func readParam(r *http.Request, names ...string) string { value := readHeaderParam(r, names...) if value != "" { @@ -35,6 +47,13 @@ func readHeaderParam(r *http.Request, names ...string) string { return "" } +func readHeaderParamValues(r *http.Request, names ...string) (values []string) { + for _, name := range names { + values = append(values, r.Header.Values(name)...) + } + return +} + func readQueryParam(r *http.Request, names ...string) string { for _, name := range names { value := r.URL.Query().Get(strings.ToLower(name)) From bc3d897d7a315153f9dcdbf742435551abbeda4e Mon Sep 17 00:00:00 2001 From: Karmanyaah Malhotra Date: Tue, 21 Feb 2023 20:16:03 -0600 Subject: [PATCH 10/12] Use mutexes in topic --- server/topic.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/server/topic.go b/server/topic.go index 19501a87f..b168c70f6 100644 --- a/server/topic.go +++ b/server/topic.go @@ -58,6 +58,8 @@ func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func(), subscri } func (t *topic) Stale() bool { + t.mu.Lock() + defer t.mu.Unlock() // if Time is initialized (not the zero value) and the expiry time has passed if !t.lastVisitorExpires.IsZero() && t.lastVisitorExpires.Before(time.Now()) { t.lastVisitor = nil @@ -66,6 +68,8 @@ func (t *topic) Stale() bool { } func (t *topic) Billee() *visitor { + t.mu.Lock() + defer t.mu.Unlock() return t.lastVisitor } From 0e4044b7477c7083556d6097ef02ae6436cc110f Mon Sep 17 00:00:00 2001 From: Karmanyaah Malhotra Date: Tue, 21 Feb 2023 20:18:04 -0600 Subject: [PATCH 11/12] rename lastVisitor to vRate --- server/server.go | 2 +- server/topic.go | 32 ++++++++++++++++---------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/server/server.go b/server/server.go index a5307964a..e0be8d414 100644 --- a/server/server.go +++ b/server/server.go @@ -1417,7 +1417,7 @@ func (s *Server) execManager() { if ev.IsTrace() { expiryMessage := "" if subs == 0 { - expiryTime := time.Until(t.lastVisitorExpires) + expiryTime := time.Until(t.vRateExpires) expiryMessage = ", expires in " + expiryTime.String() } ev.Trace("- topic %s: %d subscribers%s", t.ID, subs, expiryMessage) diff --git a/server/topic.go b/server/topic.go index b168c70f6..315453401 100644 --- a/server/topic.go +++ b/server/topic.go @@ -11,11 +11,11 @@ import ( // topic represents a channel to which subscribers can subscribe, and publishers // can publish a message type topic struct { - ID string - subscribers map[int]*topicSubscriber - lastVisitor *visitor - lastVisitorExpires time.Time - mu sync.Mutex + ID string + subscribers map[int]*topicSubscriber + vRate *visitor + vRateExpires time.Time + mu sync.Mutex } type topicSubscriber struct { @@ -49,9 +49,9 @@ func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func(), subscri } // if no subscriber is already handling the rate limit - if t.lastVisitor == nil && subscriberRateLimit { - t.lastVisitor = visitor - t.lastVisitorExpires = time.Time{} + if t.vRate == nil && subscriberRateLimit { + t.vRate = visitor + t.vRateExpires = time.Time{} } return subscriberID @@ -61,16 +61,16 @@ func (t *topic) Stale() bool { t.mu.Lock() defer t.mu.Unlock() // if Time is initialized (not the zero value) and the expiry time has passed - if !t.lastVisitorExpires.IsZero() && t.lastVisitorExpires.Before(time.Now()) { - t.lastVisitor = nil + if !t.vRateExpires.IsZero() && t.vRateExpires.Before(time.Now()) { + t.vRate = nil } - return len(t.subscribers) == 0 && t.lastVisitor == nil + return len(t.subscribers) == 0 && t.vRate == nil } func (t *topic) Billee() *visitor { t.mu.Lock() defer t.mu.Unlock() - return t.lastVisitor + return t.vRate } // Unsubscribe removes the subscription from the list of subscribers @@ -84,16 +84,16 @@ func (t *topic) Unsubscribe(id int) { // look for an active subscriber (in random order) that wants to handle the rate limit for _, v := range t.subscribers { if v.subscriberRateLimit { - t.lastVisitor = v.visitor - t.lastVisitorExpires = time.Time{} + t.vRate = v.visitor + t.vRateExpires = time.Time{} return } } // if no active subscriber is found, count it towards the leaving subscriber if deletingSub.subscriberRateLimit { - t.lastVisitor = deletingSub.visitor - t.lastVisitorExpires = time.Now().Add(subscriberBilledValidity) + t.vRate = deletingSub.visitor + t.vRateExpires = time.Now().Add(subscriberBilledValidity) } } From ce7d447f16b6cf4c029c919d69df72fbdfb53292 Mon Sep 17 00:00:00 2001 From: Karmanyaah Malhotra Date: Tue, 21 Feb 2023 22:40:15 -0600 Subject: [PATCH 12/12] limitRequestsWithTopic --- server/server.go | 26 ++++++++++++-------------- server/server_middleware.go | 26 +++++++++++++++++++++++++- 2 files changed, 37 insertions(+), 15 deletions(-) diff --git a/server/server.go b/server/server.go index e0be8d414..642f90256 100644 --- a/server/server.go +++ b/server/server.go @@ -437,13 +437,14 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit } else if r.Method == http.MethodOptions { return s.limitRequests(s.handleOptions)(w, r, v) // Should work even if the web app is not enabled, see #598 } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && r.URL.Path == "/" { - return s.limitRequests(s.transformBodyJSON(s.authorizeTopicWrite(s.handlePublish)))(w, r, v) + // So I don't *really* have to switch this order, since this is unrelated to UP; But, making this and matrix inconsistent is just calling for confusion, no? + return s.transformBodyJSON(s.limitRequestsWithTopic(s.authorizeTopicWrite(s.handlePublish)))(w, r, v) } else if r.Method == http.MethodPost && r.URL.Path == matrixPushPath { - return s.limitRequests(s.transformMatrixJSON(s.authorizeTopicWrite(s.handlePublishMatrix)))(w, r, v) + return s.transformMatrixJSON(s.limitRequestsWithTopic(s.authorizeTopicWrite(s.handlePublishMatrix)))(w, r, v) } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) { - return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v) + return s.limitRequestsWithTopic(s.authorizeTopicWrite(s.handlePublish))(w, r, v) } else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) { - return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v) + return s.limitRequestsWithTopic(s.authorizeTopicWrite(s.handlePublish))(w, r, v) } else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) { return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeJSON))(w, r, v) } else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) { @@ -602,20 +603,17 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error { } func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) { - t, err := s.topicFromPath(r.URL.Path) - if err != nil { - return nil, err + vRate, ok := r.Context().Value("vRate").(*visitor) + if !ok { + return nil, errHTTPInternalError } - vRate := v - if topicCountsAgainst := t.Billee(); topicCountsAgainst != nil { - vRate = topicCountsAgainst + t, ok := r.Context().Value("topic").(*topic) + if !ok { + return nil, errHTTPInternalError } if !vRate.MessageAllowed() { - vRate = v - if !v.MessageAllowed() { - return nil, errHTTPTooManyRequestsLimitMessages - } + return nil, errHTTPTooManyRequestsLimitMessages } body, err := util.Peek(r.Body, s.config.MessageLimit) if err != nil { diff --git a/server/server_middleware.go b/server/server_middleware.go index 684253ad9..e4f658b5b 100644 --- a/server/server_middleware.go +++ b/server/server_middleware.go @@ -1,8 +1,10 @@ package server import ( - "heckel.io/ntfy/util" + "context" "net/http" + + "heckel.io/ntfy/util" ) func (s *Server) limitRequests(next handleFunc) handleFunc { @@ -16,6 +18,28 @@ func (s *Server) limitRequests(next handleFunc) handleFunc { } } +// limitRequestsWithTopic limits requests with a topic and stores the rate-limiting-subscriber and topic into request.Context +func (s *Server) limitRequestsWithTopic(next handleFunc) handleFunc { + return func(w http.ResponseWriter, r *http.Request, v *visitor) error { + t, err := s.topicFromPath(r.URL.Path) + if err != nil { + return err + } + vRate := v + if topicCountsAgainst := t.Billee(); topicCountsAgainst != nil { + vRate = topicCountsAgainst + } + r.WithContext(context.WithValue(context.WithValue(r.Context(), "vRate", vRate), "topic", t)) + + if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) { + return next(w, r, v) + } else if !vRate.RequestAllowed() { + return errHTTPTooManyRequestsLimitRequests + } + return next(w, r, v) + } +} + func (s *Server) ensureWebEnabled(next handleFunc) handleFunc { return func(w http.ResponseWriter, r *http.Request, v *visitor) error { if !s.config.EnableWeb {