diff --git a/pubsub/jetstream/jetstream.go b/pubsub/jetstream/jetstream.go index 6652266bcf..08cd2a54bb 100644 --- a/pubsub/jetstream/jetstream.go +++ b/pubsub/jetstream/jetstream.go @@ -112,6 +112,89 @@ func (js *jetstreamPubSub) Features() []pubsub.Feature { return nil } +// A wrapper for nats.PubAckFuture that allows us to associate the message ID with the specific ack. +type pubAckWrapped struct { + ack nats.PubAckFuture + id string +} + +func (js *jetstreamPubSub) BulkPublish(ctx context.Context, req *pubsub.BulkPublishRequest) (pubsub.BulkPublishResponse, error) { + + if js.closed.Load() { + return pubsub.BulkPublishResponse{}, errors.New("component is closed") + } + + acks := []pubAckWrapped{} + errs := []pubsub.BulkPublishResponseFailedEntry{} + errsMutex := sync.Mutex{} + for _, entry := range req.Entries { + var opts []nats.PubOpt + var msgID string + + event, err := pubsub.FromCloudEvent(entry.Event, "", "", "", "") + if err != nil { + js.l.Debugf("error unmarshalling cloudevent: %v", err) + } else { + // Use the cloudevent id as the Nats-MsgId for deduplication + if id, ok := event["id"].(string); ok { + msgID = id + opts = append(opts, nats.MsgId(msgID)) + } + } + if msgID == "" { + js.l.Warn("empty message ID, Jetstream deduplication will not be possible") + } + + js.l.Debugf("Publishing to topic %v id: %s", req.Topic, msgID) + ack, err := js.jsc.PublishAsync(req.Topic, entry.Event, opts...) + if err != nil { + errs = append(errs, pubsub.BulkPublishResponseFailedEntry{ + EntryId: entry.EntryId, + Error: err, + }) + continue + } + ackWrapped := pubAckWrapped{ + ack: ack, + id: entry.EntryId, + } + acks = append(acks, ackWrapped) + } + + // Wait for all acks to be processed + var wg sync.WaitGroup + wg.Add(len(acks)) + for _, ack := range acks { + // We're spawning goroutines for each ack, as if there is some connectivity problem, + // we could end up timing out acks one by one, resulting in a very long operation. + go func(ack pubAckWrapped) { + select { + case <-ack.ack.Ok(): + case err := <-ack.ack.Err(): + if err != nil { + errsMutex.Lock() + errs = append(errs, pubsub.BulkPublishResponseFailedEntry{ + EntryId: ack.id, + Error: err, + }) + errsMutex.Unlock() + } + case <-ctx.Done(): + errsMutex.Lock() + // Context timed out or canceled + errs = append(errs, pubsub.BulkPublishResponseFailedEntry{ + EntryId: ack.id, + Error: ctx.Err(), + }) + errsMutex.Unlock() + } + wg.Done() + }(ack) + } + wg.Wait() + return pubsub.BulkPublishResponse{FailedEntries: errs}, nil +} + func (js *jetstreamPubSub) Publish(ctx context.Context, req *pubsub.PublishRequest) error { if js.closed.Load() { return errors.New("component is closed") diff --git a/pubsub/jetstream/jetstream_test.go b/pubsub/jetstream/jetstream_test.go index 82f9c10258..67a9d9e95e 100644 --- a/pubsub/jetstream/jetstream_test.go +++ b/pubsub/jetstream/jetstream_test.go @@ -15,6 +15,7 @@ package jetstream import ( "context" + "fmt" "testing" "time" @@ -57,6 +58,59 @@ func setupServerAndStream(t *testing.T) (*server.Server, *nats.Conn) { return ns, nc } +func TestNewJetStream_BulkPublish(t *testing.T) { + ns, nc := setupServerAndStream(t) + defer ns.Shutdown() + defer nc.Drain() + + bus := NewJetStream(logger.NewLogger("test")) + defer bus.Close() + + err := bus.Init(context.Background(), pubsub.Metadata{ + Base: mdata.Base{ + Properties: map[string]string{ + "natsURL": ns.ClientURL(), + }, + }, + }) + require.NoError(t, err) + + msgs := []pubsub.BulkMessageEntry{} + + for i := 0; i < 100; i++ { + msgs = append(msgs, pubsub.BulkMessageEntry{ + Event: []byte("test"), + EntryId: fmt.Sprintf("%d", i), + }) + } + + ctx := context.Background() + + bP, ok := bus.(pubsub.BulkPublisher) + if !ok { + t.Fatal("expected BulkPublisher") + } + + req := pubsub.BulkPublishRequest{ + PubsubName: "test", + Topic: "test", + Entries: msgs, + } + + res, err := bP.BulkPublish(ctx, &req) + require.NoError(t, err) + + assert.Len(t, res.FailedEntries, 0) + + js, err := nc.JetStream() + require.NoError(t, err) + + info, err := js.StreamInfo("test") + require.NoError(t, err) + + assert.Equal(t, uint64(100), info.State.Msgs) +} + func TestNewJetStream_EmphemeralPushConsumer(t *testing.T) { ns, nc := setupServerAndStream(t) defer ns.Shutdown()