diff --git a/downstream.go b/downstream.go index 8527e94a..16c8292c 100644 --- a/downstream.go +++ b/downstream.go @@ -248,6 +248,7 @@ var needAllDownstreamCaps = map[string]string{ "chghost": "", "extended-join": "", "extended-monitor": "", + "labeled-response": "", "message-tags": "", "multi-prefix": "", @@ -355,6 +356,10 @@ type downstreamConn struct { casemap xirc.CaseMapping monitored xirc.CaseMappingMap[struct{}] + + label string + labelBatch string + labelPending *irc.Message } func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn { @@ -474,6 +479,51 @@ func (dc *downstreamConn) readMessages(ch chan<- event) error { return nil } +// DeferredResponse must be called during handleMessage when no immediate message is sent +// to the downstream in response to its message, but that the response will come later as +// a response from the upstream. +// +// This essentially keeps the label of the eventual current labeled-response "open" so that +// we can later fill it with the upstream response, instead of sending an empty "ACK" when +// handleMessage exits. +func (dc *downstreamConn) DeferredResponse() { + if dc.labelPending != nil { + panic(fmt.Sprintf("called DeferredResponse after buffering a message: %v", dc.labelPending)) + } + if dc.labelBatch != "" { + panic("called DeferredResponse after sending messages") + } + dc.label = "" +} + +func (dc *downstreamConn) FlushBatch() { + if dc.labelPending != nil { + dc.srv.metrics.downstreamOutMessagesTotal.Inc() + m := dc.labelPending.Copy() + m.Tags["label"] = dc.label + dc.conn.SendMessage(context.TODO(), m) + } else if dc.labelBatch != "" && dc.label != "" { + dc.srv.metrics.downstreamOutMessagesTotal.Inc() + dc.conn.SendMessage(context.TODO(), &irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BATCH", + Params: []string{fmt.Sprintf("-%s", dc.labelBatch)}, + }) + } else if dc.label != "" { + dc.srv.metrics.downstreamOutMessagesTotal.Inc() + dc.conn.SendMessage(context.TODO(), &irc.Message{ + Prefix: dc.srv.prefix(), + Command: "ACK", + Tags: irc.Tags{ + "label": dc.label, + }, + }) + } + dc.label = "" + dc.labelPending = nil + dc.labelBatch = "" +} + // SendMessage sends an outgoing message. // // This can only called from the user goroutine. @@ -537,8 +587,48 @@ func (dc *downstreamConn) SendMessage(ctx context.Context, msg *irc.Message) { msg.Prefix = dc.srv.prefix() } - dc.srv.metrics.downstreamOutMessagesTotal.Inc() - dc.conn.SendMessage(ctx, msg) + if dc.labelPending != nil { + // create a batch + dc.lastBatchRef++ + dc.labelBatch = fmt.Sprintf("%v", dc.lastBatchRef) + dc.srv.metrics.downstreamOutMessagesTotal.Inc() + dc.conn.SendMessage(ctx, &irc.Message{ + Tags: irc.Tags{"label": dc.label}, + Prefix: dc.srv.prefix(), + Command: "BATCH", + Params: []string{"+" + dc.labelBatch, "labeled-response"}, + }) + + // send the buffered message + m := dc.labelPending.Copy() + if m.Tags["batch"] == "" { + if m.Tags == nil { + m.Tags = make(irc.Tags) + } + m.Tags["batch"] = dc.labelBatch + } + dc.srv.metrics.downstreamOutMessagesTotal.Inc() + dc.conn.SendMessage(ctx, m) + dc.labelPending = nil + } + if dc.labelBatch != "" { + // send the current message in the batch + m := msg.Copy() + if m.Tags["batch"] == "" { + if m.Tags == nil { + m.Tags = make(irc.Tags) + } + m.Tags["batch"] = dc.labelBatch + } + dc.srv.metrics.downstreamOutMessagesTotal.Inc() + dc.conn.SendMessage(ctx, m) + } else if dc.label != "" { + // first message we're sending: buffer it + dc.labelPending = msg + } else { + dc.srv.metrics.downstreamOutMessagesTotal.Inc() + dc.conn.SendMessage(ctx, msg) + } } func (dc *downstreamConn) SendBatch(ctx context.Context, typ string, params []string, tags irc.Tags, f func(batchRef string)) { @@ -625,16 +715,30 @@ func (dc *downstreamConn) handleMessage(ctx context.Context, msg *irc.Message) e ctx, cancel = context.WithTimeout(ctx, handleDownstreamMessageTimeout) defer cancel() + if dc.caps.IsEnabled("labeled-response") { + dc.label = msg.Tags["label"] + } + defer func() { + dc.FlushBatch() + }() + switch msg.Command { case "QUIT": dc.conn.Shutdown(ctx) return nil // TODO: stop handling commands default: + var err error if dc.registered { - return dc.handleMessageRegistered(ctx, msg) + err = dc.handleMessageRegistered(ctx, msg) } else { - return dc.handleMessageUnregistered(ctx, msg) + err = dc.handleMessageUnregistered(ctx, msg) + } + if ircErr, ok := err.(ircError); ok { + ircErr.Message.Prefix = dc.srv.prefix() + dc.SendMessage(ctx, ircErr.Message) + err = nil } + return err } } @@ -1699,10 +1803,7 @@ func (dc *downstreamConn) runUntilRegistered() error { } err = dc.handleMessage(ctx, msg) - if ircErr, ok := err.(ircError); ok { - ircErr.Message.Prefix = dc.srv.prefix() - dc.SendMessage(ctx, ircErr.Message) - } else if err != nil { + if err != nil { return fmt.Errorf("failed to handle IRC command %q: %v", msg, err) } @@ -1791,10 +1892,11 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if dc.network != nil { dc.network.Network.Nick = nick if uc := dc.upstream(); uc != nil { - uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + uc.SendMessageLabeled(ctx, dc.id, dc.label, &irc.Message{ Command: "NICK", Params: []string{nick}, }) + dc.DeferredResponse() } else { dc.updateNick(ctx) } @@ -1832,10 +1934,11 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if uc := dc.upstream(); uc != nil && uc.caps.IsEnabled("setname") { // Upstream will reply with a SETNAME message on success - uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + uc.SendMessageLabeled(ctx, dc.id, dc.label, &irc.Message{ Command: "SETNAME", Params: []string{realname}, }) + dc.DeferredResponse() err = dc.srv.db.StoreNetwork(ctx, dc.user.ID, &record) } else { @@ -1880,7 +1983,8 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. keys = strings.Split(msg.Params[1], ",") } - for i, name := range strings.Split(namesStr, ",") { + names := strings.Split(namesStr, ",") + for i, name := range names { var key string if len(keys) > i { key = keys[i] @@ -1910,10 +2014,20 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if key != "" { params = append(params, key) } - uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ - Command: "JOIN", - Params: params, - }) + if len(names) == 1 { + // only one channel: defer the labeled-response to the upstream + uc.SendMessageLabeled(ctx, dc.id, dc.label, &irc.Message{ + Command: "JOIN", + Params: params, + }) + dc.DeferredResponse() + } else { + // general case: respond to labeled-response locally + uc.SendMessageLabeled(ctx, dc.id, "", &irc.Message{ + Command: "JOIN", + Params: params, + }) + } } ch := uc.network.channels.Get(name) @@ -1951,7 +2065,8 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. reason = msg.Params[1] } - for _, name := range strings.Split(namesStr, ",") { + names := strings.Split(namesStr, ",") + for _, name := range names { if strings.EqualFold(reason, "detach") { ch := uc.network.channels.Get(name) if ch != nil { @@ -1971,10 +2086,20 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if reason != "" { params = append(params, reason) } - uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ - Command: "PART", - Params: params, - }) + if len(names) == 1 { + // only one channel: defer the labeled-response to the upstream + uc.SendMessageLabeled(ctx, dc.id, dc.label, &irc.Message{ + Command: "PART", + Params: params, + }) + dc.DeferredResponse() + } else { + // general case: respond to labeled-response locally + uc.SendMessageLabeled(ctx, dc.id, "", &irc.Message{ + Command: "PART", + Params: params, + }) + } if err := uc.network.deleteChannel(ctx, name); err != nil { dc.logger.Printf("failed to delete channel %q: %v", name, err) @@ -1989,7 +2114,8 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. return err } - uc.SendMessageLabeled(ctx, dc.id, msg) + uc.SendMessageLabeled(ctx, dc.id, dc.label, msg) + dc.DeferredResponse() case "MODE": var name string if err := parseMessageParams(msg, &name); err != nil { @@ -2007,7 +2133,8 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if err != nil { return err } - uc.SendMessageLabeled(ctx, dc.id, msg) + uc.SendMessageLabeled(ctx, dc.id, dc.label, msg) + dc.DeferredResponse() } else { var userMode string if uc := dc.upstream(); uc != nil { @@ -2037,15 +2164,17 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if modeStr != "" { params := []string{name, modeStr} params = append(params, msg.Params[2:]...) - uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + uc.SendMessageLabeled(ctx, dc.id, dc.label, &irc.Message{ Command: "MODE", Params: params, }) + dc.DeferredResponse() } else { ch := uc.channels.Get(name) if ch == nil { // we're not on that channel, pass command to upstream - uc.SendMessageLabeled(ctx, dc.id, msg) + uc.SendMessageLabeled(ctx, dc.id, dc.label, msg) + dc.DeferredResponse() return nil } @@ -2083,15 +2212,17 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if len(msg.Params) > 1 { // setting topic topic := msg.Params[1] - uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + uc.SendMessageLabeled(ctx, dc.id, dc.label, &irc.Message{ Command: "TOPIC", Params: []string{name, topic}, }) + dc.DeferredResponse() } else { // getting topic ch := uc.channels.Get(name) if ch == nil { // we're not on that channel, pass command to upstream - uc.SendMessageLabeled(ctx, dc.id, msg) + uc.SendMessageLabeled(ctx, dc.id, dc.label, msg) + dc.DeferredResponse() } else { sendTopic(ctx, dc, ch) } @@ -2103,6 +2234,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } uc.enqueueCommand(dc, msg) + dc.DeferredResponse() case "NAMES": uc, err := dc.upstreamForCommand(msg.Command) if err != nil { @@ -2124,10 +2256,20 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. sendNames(ctx, dc, ch) } else { // NAMES on a channel we have not joined, ask upstream - uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ - Command: "NAMES", - Params: []string{name}, - }) + if len(channels) == 1 { + // only one channel: defer the labeled-response to the upstream + uc.SendMessageLabeled(ctx, dc.id, dc.label, &irc.Message{ + Command: "NAMES", + Params: []string{name}, + }) + dc.DeferredResponse() + } else { + // general case: respond to labeled-response locally + uc.SendMessageLabeled(ctx, dc.id, "", &irc.Message{ + Command: "NAMES", + Params: []string{name}, + }) + } } } case "WHO": @@ -2259,6 +2401,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } uc.enqueueCommand(dc, msg) + dc.DeferredResponse() case "WHOIS": if len(msg.Params) == 0 { return ircError{&irc.Message{ @@ -2337,6 +2480,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } uc.enqueueCommand(dc, msg) + dc.DeferredResponse() case "PRIVMSG", "NOTICE", "TAGMSG": var targetsStr, text string if msg.Command != "TAGMSG" { @@ -2351,7 +2495,8 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. tags := copyClientTags(msg.Tags) - for _, name := range strings.Split(targetsStr, ",") { + targets := strings.Split(targetsStr, ",") + for _, name := range targets { params := []string{name} if msg.Command != "TAGMSG" { params = append(params, text) @@ -2436,11 +2581,22 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. upstreamParams = append(upstreamParams, text) } - uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ - Tags: tags, - Command: msg.Command, - Params: upstreamParams, - }) + if len(targets) == 1 && uc.caps.IsEnabled("echo-message") { + // only one target and echo-message is supported: defer the labeled-response to the upstream + uc.SendMessageLabeled(ctx, dc.id, dc.label, &irc.Message{ + Tags: tags, + Command: msg.Command, + Params: upstreamParams, + }) + dc.DeferredResponse() + } else { + // general case: respond to labeled-response locally + uc.SendMessageLabeled(ctx, dc.id, "", &irc.Message{ + Tags: tags, + Command: msg.Command, + Params: upstreamParams, + }) + } // If the upstream supports echo message, we'll produce the message // when it is echoed from the upstream. @@ -2477,7 +2633,8 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. return err } - uc.SendMessageLabeled(ctx, dc.id, msg) + uc.SendMessageLabeled(ctx, dc.id, dc.label, msg) + dc.DeferredResponse() case "AUTHENTICATE": // Post-connection-registration AUTHENTICATE is only supported if an // upstream is bound and supports SASL @@ -2512,6 +2669,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. Command: "AUTHENTICATE", Params: []string{"PLAIN"}, }) + dc.DeferredResponse() case "ANONYMOUS": if uc.network.SASL.Mechanism != "" { record := uc.network.Network // copy network record because we'll mutate it @@ -2554,6 +2712,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. uc.logger.Printf("starting %v with account name %v", msg.Command, msg.Params[0]) uc.enqueueCommand(dc, msg) + dc.DeferredResponse() case "AWAY": if len(msg.Params) > 0 { dc.away = &msg.Params[0] @@ -2586,7 +2745,8 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. Params: []string{dc.nick, msg.Command, "Disconnected from upstream network"}, }} } else { - uc.SendMessageLabeled(ctx, dc.id, msg) + uc.SendMessageLabeled(ctx, dc.id, dc.label, msg) + dc.DeferredResponse() } case "MONITOR": uc := dc.upstream() @@ -2602,6 +2762,8 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. return err } + // TODO: support MONITOR labeled-response through upstream + switch strings.ToUpper(subcommand) { case "+", "-": var targets string @@ -3301,7 +3463,8 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. }} } - uc.SendMessageLabeled(ctx, dc.id, msg) + uc.SendMessageLabeled(ctx, dc.id, dc.label, msg) + dc.DeferredResponse() } return nil } diff --git a/upstream.go b/upstream.go index b3c92276..2c8f7adf 100644 --- a/upstream.go +++ b/upstream.go @@ -185,9 +185,10 @@ func (uu *upstreamUser) updateFrom(update *upstreamUser) { } type pendingUpstreamCommand struct { - downstreamID uint64 - msg *irc.Message - sentAt time.Time + downstreamID uint64 + downstreamLabel string + msg *irc.Message + sentAt time.Time } type upstreamConn struct { @@ -450,6 +451,8 @@ func (uc *upstreamConn) forwardMsgByID(ctx context.Context, id uint64, msg *irc. } func (uc *upstreamConn) abortPendingCommands() { + // TODO: support labeled-response + ctx := context.TODO() for _, l := range uc.pendingCmds { for _, pendingCmd := range l { @@ -503,7 +506,7 @@ func (uc *upstreamConn) sendNextPendingCommand(cmd string) { return } pendingCmd := &uc.pendingCmds[cmd][0] - uc.SendMessageLabeled(context.TODO(), pendingCmd.downstreamID, pendingCmd.msg) + uc.SendMessageLabeled(context.TODO(), pendingCmd.downstreamID, pendingCmd.downstreamLabel, pendingCmd.msg) pendingCmd.sentAt = time.Now() } @@ -516,8 +519,9 @@ func (uc *upstreamConn) enqueueCommand(dc *downstreamConn, msg *irc.Message) { } uc.pendingCmds[msg.Command] = append(uc.pendingCmds[msg.Command], pendingUpstreamCommand{ - downstreamID: dc.id, - msg: msg, + downstreamID: dc.id, + downstreamLabel: dc.label, + msg: msg, }) // If we didn't get a reply after a while, just give up @@ -580,6 +584,26 @@ func (uc *upstreamConn) parseMembershipPrefix(s string) (ms xirc.MembershipSet, return memberships, s[i:] } +func (uc *upstreamConn) parseLabel(label string) (downstreamID uint64, downstreamLabel string, err error) { + if label == "" { + return + } + parts := strings.SplitN(label, "-", 4) + if len(parts) < 4 { + err = errors.New("not enough arguments") + } else if parts[0] != "sd" { + err = fmt.Errorf("expected %v, got %v", "sd", parts[0]) + } else { + downstreamID, err = strconv.ParseUint(parts[1], 10, 64) + } + if err != nil { + err = fmt.Errorf("unexpected message label: invalid downstream reference for label %q: %v", label, err) + return + } + downstreamLabel = parts[3] + return +} + func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) error { var label string if l, ok := msg.Tags["label"]; ok { @@ -588,6 +612,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err } var msgBatch *upstreamBatch + batchLabel := false if batchName, ok := msg.Tags["batch"]; ok { b, ok := uc.batches[batchName] if !ok { @@ -595,22 +620,34 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err } msgBatch = &b if label == "" { + batchLabel = true label = msgBatch.Label } delete(msg.Tags, "batch") } - var downstreamID uint64 - if label != "" { - var labelOffset uint64 - n, err := fmt.Sscanf(label, "sd-%d-%d", &downstreamID, &labelOffset) - if err == nil && n < 2 { - err = errors.New("not enough arguments") - } - if err != nil { - return fmt.Errorf("unexpected message label: invalid downstream reference for label %q: %v", label, err) + downstreamID, downstreamLabel, err := uc.parseLabel(label) + if err != nil { + return err + } + if downstreamID != 0 { + if batchLabel { + // label comes from batch, send the message as part of that batch + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.labelBatch = msg.Tags["batch"] + }) + } else { + // label comes from this message, respond to the message with that label + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.label = downstreamLabel + }) } } + defer func() { + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.FlushBatch() + }) + }() if msg.Prefix == nil { msg.Prefix = uc.serverPrefix @@ -1064,12 +1101,39 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err Outer: msgBatch, Label: label, } + if downstreamID != 0 && downstreamLabel != "" { + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.label = "" + dc.SendMessage(ctx, &irc.Message{ + Prefix: uc.srv.prefix(), + Command: "BATCH", + Params: msg.Params, + Tags: irc.Tags{ + "label": downstreamLabel, + }, + }) + }) + } } else if strings.HasPrefix(tag, "-") { tag = tag[1:] if _, ok := uc.batches[tag]; !ok { return fmt.Errorf("unknown BATCH reference tag: %q", tag) } + label := uc.batches[tag].Label delete(uc.batches, tag) + + // BATCH - does not have @label/@batch attached to it, so downstreamID is empty. + // extract it back from the batch struct. + downstreamID, downstreamLabel, err := uc.parseLabel(label) + if downstreamID != 0 && downstreamLabel != "" && err == nil { + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.SendMessage(ctx, &irc.Message{ + Prefix: uc.srv.prefix(), + Command: "BATCH", + Params: msg.Params, + }) + }) + } } else { return fmt.Errorf("unexpected BATCH reference tag: missing +/- prefix: %q", tag) } @@ -2100,12 +2164,12 @@ func (uc *upstreamConn) SendMessage(ctx context.Context, msg *irc.Message) { uc.conn.SendMessage(ctx, msg) } -func (uc *upstreamConn) SendMessageLabeled(ctx context.Context, downstreamID uint64, msg *irc.Message) { +func (uc *upstreamConn) SendMessageLabeled(ctx context.Context, downstreamID uint64, label string, msg *irc.Message) { if uc.caps.IsEnabled("labeled-response") { if msg.Tags == nil { msg.Tags = make(irc.Tags) } - msg.Tags["label"] = fmt.Sprintf("sd-%d-%d", downstreamID, uc.nextLabelID) + msg.Tags["label"] = fmt.Sprintf("sd-%d-%d-%s", downstreamID, uc.nextLabelID, label) uc.nextLabelID++ } uc.SendMessage(ctx, msg) diff --git a/user.go b/user.go index c2d2da5b..2cf121cd 100644 --- a/user.go +++ b/user.go @@ -798,10 +798,7 @@ func (u *user) run() { break } err := dc.handleMessage(context.TODO(), msg) - if ircErr, ok := err.(ircError); ok { - ircErr.Message.Prefix = dc.srv.prefix() - dc.SendMessage(context.TODO(), ircErr.Message) - } else if err != nil { + if err != nil { dc.logger.Printf("failed to handle message %q: %v", msg, err) dc.Close() }