Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Subscriber Billed Topics #609

Closed
wants to merge 12 commits into from
1 change: 1 addition & 0 deletions server/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,5 @@ var (
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", ""}
errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/"}
errHTTPWontStoreMessage = &errHTTP{50701, http.StatusInsufficientStorage, "topic is inactive; no device available to recieve message", ""}
karmanyaahm marked this conversation as resolved.
Show resolved Hide resolved
)
45 changes: 32 additions & 13 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,15 @@ var (
)

const (
firebaseControlTopic = "~control" // See Android if changed
firebasePollTopic = "~poll" // See iOS if changed
emptyMessageBody = "triggered" // Used if message body is empty
newMessageBody = "New message" // Used in poll requests as generic message
defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages
jsonBodyBytesLimit = 16384
firebaseControlTopic = "~control" // See Android if changed
firebasePollTopic = "~poll" // See iOS if changed
emptyMessageBody = "triggered" // Used if message body is empty
newMessageBody = "New message" // Used in poll requests as generic message
defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages
jsonBodyBytesLimit = 16384
subscriberBilledTopicPrefix = "up_"
karmanyaahm marked this conversation as resolved.
Show resolved Hide resolved
karmanyaahm marked this conversation as resolved.
Show resolved Hide resolved
subscriberBilledValidity = 12 * time.Hour
)

// WebSocket constants
Expand Down Expand Up @@ -370,6 +372,7 @@ func (s *Server) handleError(w http.ResponseWriter, r *http.Request, v *visitor,
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
w.Header().Set("TTL", "0") // if message is not being stored because of an error, tell them
w.WriteHeader(httpErr.HTTPCode)
io.WriteString(w, httpErr.JSON()+"\n")
}
Expand Down Expand Up @@ -603,6 +606,14 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
if err != nil {
return nil, err
}
v_old := v
if strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) {
v = t.getBillee()
if v == nil {
return nil, errHTTPWontStoreMessage
}
}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know I said to just replace v in an older comment, but the more I think about it, the more I think that this is literally just about rate limiting, right? So maybe maybe our Billee is just a vrate (= visitor used for rate limiting), and for all the rate limiting functions, we'd use vrate, but for everything else, we'd use v?

I have also noticed that you haven't tackled the limitRequests middleware at all. The rate limiting for number of messages (visitor-message-daily-limit) and number of emails (visitor-email-limit-*) happens in handlePublish*, but the rate limiting for requests (visitor-request-limit-*) happens in limitRequests, so this logic would have to happen there, not here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did this using context (as suggested below), but isn't that supposed to be an anti-pattern? In any case, I guess it's just one thing so it shouldn't matter much. I feel like the ideal solution to this would be making visitor an interface and having a compositeVisitor in addition to visitor...but, that's probably overengineering.


if !v.MessageAllowed() {
return nil, errHTTPTooManyRequestsLimitMessages
}
Expand All @@ -620,7 +631,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
}
m.Sender = v.IP()
m.User = v.MaybeUserID()
m.Expires = time.Now().Add(v.Limits().MessageExpiryDuration).Unix()
m.Expires = time.Unix(m.Time, 0).Add(v.Limits().MessageExpiryDuration).Unix()
if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
return nil, err
}
Expand All @@ -637,8 +648,9 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
"message_email": email,
}).
Debug("Received message")
//Where should I log the original visitor vs the billing visitor
if log.IsTrace() {
logvrm(v, r, m).
logvrm(v_old, r, m).
karmanyaahm marked this conversation as resolved.
Show resolved Hide resolved
Tag(tagPublish).
Field("message_body", util.MaybeMarshalJSON(m)).
Trace("Message body")
Expand All @@ -664,6 +676,8 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
if err := s.messageCache.AddMessage(m); err != nil {
return nil, err
}
} else {
m.Expires = m.Time
karmanyaahm marked this conversation as resolved.
Show resolved Hide resolved
}
u := v.User()
if s.userManager != nil && u != nil && u.Tier != nil {
Expand All @@ -680,6 +694,10 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
if err != nil {
return err
}

w.Header().Set("TTL", strconv.FormatInt(m.Expires-m.Time, 10)) // return how long a message will be stored for

// using m.Time, not time.Now() so the value isn't negative if the request is processed at a second boundary
return s.writeJSON(w, m)
}

Expand Down Expand Up @@ -1023,7 +1041,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
defer cancel()
subscriberIDs := make([]int, 0)
for _, t := range topics {
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel))
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel))
}
defer func() {
for i, subscriberID := range subscriberIDs {
Expand Down Expand Up @@ -1155,7 +1173,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
}
subscriberIDs := make([]int, 0)
for _, t := range topics {
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel))
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel))
}
defer func() {
for i, subscriberID := range subscriberIDs {
Expand Down Expand Up @@ -1402,9 +1420,10 @@ func (s *Server) execManager() {
defer s.mu.Unlock()
for _, t := range s.topics {
subs := t.SubscribersCount()
log.Tag(tagManager).Trace("- topic %s: %d subscribers", t.ID, subs)
expiryTime := time.Until(t.lastVisitorExpires)
log.Tag(tagManager).Trace("- topic %s: %d subscribers, expires in %s", t.ID, subs, expiryTime)
msgs, exists := messageCounts[t.ID]
binwiederhier marked this conversation as resolved.
Show resolved Hide resolved
if subs == 0 && (!exists || msgs == 0) {
if t.Stale() && (!exists || msgs == 0) {
log.Tag(tagManager).Trace("Deleting empty topic %s", t.ID)
emptyTopics++
delete(s.topics, t.ID)
Expand Down
45 changes: 35 additions & 10 deletions server/topic.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
package server

import (
"heckel.io/ntfy/log"
"math/rand"
"sync"
"time"

"heckel.io/ntfy/log"
)

// topic represents a channel to which subscribers can subscribe, and publishers
// can publish a message
type topic struct {
ID string
subscribers map[int]*topicSubscriber
mu sync.Mutex
ID string
subscribers map[int]*topicSubscriber
lastVisitor *visitor
lastVisitorExpires time.Time
mu sync.Mutex
}

type topicSubscriber struct {
userID string // User ID associated with this subscription, may be empty
subscriber subscriber
visitor *visitor // User ID associated with this subscription, may be empty
cancel func()
}

Expand All @@ -32,22 +36,42 @@ func newTopic(id string) *topic {
}

// Subscribe subscribes to this topic
func (t *topic) Subscribe(s subscriber, userID string, cancel func()) int {
func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func()) int {
t.mu.Lock()
defer t.mu.Unlock()
subscriberID := rand.Int()
t.subscribers[subscriberID] = &topicSubscriber{
userID: userID, // May be empty
visitor: visitor, // May be empty
subscriber: s,
cancel: cancel,
}
return subscriberID
}

func (t *topic) Stale() bool {
return t.getBillee() == nil
}

func (t *topic) getBillee() *visitor {
for _, this_subscriber := range t.subscribers {
return this_subscriber.visitor
}
if t.lastVisitor != nil && t.lastVisitorExpires.After(time.Now()) {
t.lastVisitor = nil
}
return t.lastVisitor

}
Copy link
Owner

@binwiederhier binwiederhier Feb 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Style: Don't use snake_case, use camelCase. And: don't use getters (no getX), use X instead, here: func billee()
  2. This function should use the mutex!
  3. This is a strange way of retrieving the billee, IMHO. It took me very very very long to understand, which is probably not a good sign. So maybe we can do something where we assign it in Subscribe and pick a new one in Unsubscribe (if we're removing the the topicSubscriber with the same visitor as in topic.lastVisitor

Something like this:

type topic struct {
  vrate *visitor
  vrateExpires time.Time

(like you have, but rename it)

And then just assign it when you first subscribe:

  if t.vrate = nil {
    t.vrate = v
    t.vrateExpires = ...
  }
  t.subscribers[subscriberID] = &topicSubscriber{
    visitor: v,
    ..   
  }

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated my comment from this afternoon.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 and 2 are done, and I think the latest rewrite achieves 3?


// Unsubscribe removes the subscription from the list of subscribers
func (t *topic) Unsubscribe(id int) {
t.mu.Lock()
defer t.mu.Unlock()

if len(t.subscribers) == 1 {
t.lastVisitor = t.subscribers[id].visitor
t.lastVisitorExpires = time.Now().Add(subscriberBilledValidity)
}
delete(t.subscribers, id)
}

Expand Down Expand Up @@ -87,8 +111,9 @@ func (t *topic) CancelSubscribers(exceptUserID string) {
t.mu.Lock()
defer t.mu.Unlock()
for _, s := range t.subscribers {
if s.userID != exceptUserID {
log.Tag(tagSubscribe).Field("topic", t.ID).Debug("Canceling subscriber %s", s.userID)
if s.visitor.MaybeUserID() != exceptUserID {
// TODO: Shouldn't this log the IP for anonymous visitors? It was s.userID before my change.
log.Tag(tagSubscribe).Field("topic", t.ID).Debug("Canceling subscriber %s", s.visitor.MaybeUserID())
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should that log be visitorID instead of userID, to also log anonymous visitors properly?

s.cancel()
}
}
Expand All @@ -101,7 +126,7 @@ func (t *topic) subscribersCopy() map[int]*topicSubscriber {
subscribers := make(map[int]*topicSubscriber)
for k, sub := range t.subscribers {
subscribers[k] = &topicSubscriber{
userID: sub.userID,
visitor: sub.visitor,
subscriber: sub.subscriber,
cancel: sub.cancel,
}
Expand Down