diff --git a/internal/tasks/handler.go b/internal/tasks/handler.go
index 0143760..e49ddc5 100644
--- a/internal/tasks/handler.go
+++ b/internal/tasks/handler.go
@@ -38,7 +38,6 @@ func NewHandler(federationService service.FederationService, repositories reposi
 
 func (t *Handler) Register(tq *taskqueue.Client) {
 	tq.RegisterHandler(FederateNote, parse(t.FederateNote))
-	tq.RegisterHandler(FederateFollow, parse(t.FederateFollow))
 }
 
 func parse[T any](handler func(context.Context, T) error) func(context.Context, taskqueue.Task) error {
diff --git a/internal/tasks/note_tasks/federate_note.go b/internal/tasks/note_tasks/federate_note.go
new file mode 100644
index 0000000..6ce3225
--- /dev/null
+++ b/internal/tasks/note_tasks/federate_note.go
@@ -0,0 +1,52 @@
+package note_tasks
+
+import (
+	"context"
+	"github.com/versia-pub/versia-go/internal/repository"
+
+	"github.com/google/uuid"
+	"github.com/versia-pub/versia-go/internal/entity"
+)
+
+type FederateNoteData struct {
+	NoteID uuid.UUID `json:"noteID"`
+}
+
+func (t *Handler) FederateNote(ctx context.Context, data FederateNoteData) error {
+	s := t.telemetry.StartSpan(ctx, "function", "tasks/Handler.FederateNote")
+	defer s.End()
+	ctx = s.Context()
+
+	var n *entity.Note
+	if err := t.repositories.Atomic(ctx, func(ctx context.Context, tx repository.Manager) error {
+		var err error
+		n, err = tx.Notes().GetByID(ctx, data.NoteID)
+		if err != nil {
+			return err
+		}
+		if n == nil {
+			t.log.V(-1).Info("Could not find note", "id", data.NoteID)
+			return nil
+		}
+
+		for _, uu := range n.Mentions {
+			if !uu.IsRemote {
+				t.log.V(2).Info("User is not remote", "user", uu.ID)
+				continue
+			}
+
+			res, err := t.federationService.SendToInbox(ctx, n.Author, &uu, n.ToVersia())
+			if err != nil {
+				t.log.Error(err, "Failed to send note to remote user", "res", res, "user", uu.ID)
+			} else {
+				t.log.V(2).Info("Sent note to remote user", "res", res, "user", uu.ID)
+			}
+		}
+
+		return nil
+	}); err != nil {
+		return err
+	}
+
+	return nil
+}
diff --git a/internal/tasks/note_tasks/handler.go b/internal/tasks/note_tasks/handler.go
new file mode 100644
index 0000000..d103678
--- /dev/null
+++ b/internal/tasks/note_tasks/handler.go
@@ -0,0 +1,51 @@
+package note_tasks
+
+import (
+	"context"
+	"encoding/json"
+	"github.com/versia-pub/versia-go/internal/repository"
+	"github.com/versia-pub/versia-go/internal/service"
+
+	"git.devminer.xyz/devminer/unitel"
+	"github.com/go-logr/logr"
+	"github.com/versia-pub/versia-go/pkg/taskqueue"
+)
+
+const (
+	FederateNote = "federate_note"
+)
+
+type Handler struct {
+	federationService service.FederationService
+
+	repositories repository.Manager
+
+	telemetry *unitel.Telemetry
+	log       logr.Logger
+}
+
+func NewHandler(federationService service.FederationService, repositories repository.Manager, telemetry *unitel.Telemetry, log logr.Logger) *Handler {
+	return &Handler{
+		federationService: federationService,
+
+		repositories: repositories,
+
+		telemetry: telemetry,
+		log:       log,
+	}
+}
+
+func (t *Handler) Register(tq *taskqueue.Client) {
+	tq.RegisterHandler(FederateNote, parse(t.FederateNote))
+}
+
+func parse[T any](handler func(context.Context, T) error) func(context.Context, taskqueue.Task) error {
+	return func(ctx context.Context, task taskqueue.Task) error {
+		var data T
+		if err := json.Unmarshal(task.Payload, &data); err != nil {
+			return err
+		}
+
+		return handler(ctx, data)
+	}
+}
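For orientation, a minimal sketch of how a note-federation task would be enqueued against the new API. `NewTask`, `TaskSet.Submit`, `FederateNote`, and `FederateNoteData` are taken from this diff; the surrounding function and the `ts` wiring are hypothetical.

```go
package main

import (
	"context"

	"github.com/google/uuid"

	"github.com/versia-pub/versia-go/internal/tasks/note_tasks"
	"github.com/versia-pub/versia-go/pkg/taskqueue"
)

// enqueueFederateNote is a hypothetical producer-side helper. On the consumer
// side, the generic parse() adapter above json.Unmarshals task.Payload back
// into FederateNoteData before invoking Handler.FederateNote.
func enqueueFederateNote(ctx context.Context, ts *taskqueue.TaskSet, noteID uuid.UUID) error {
	task, err := taskqueue.NewTask(note_tasks.FederateNote, note_tasks.FederateNoteData{NoteID: noteID})
	if err != nil {
		return err
	}
	return ts.Submit(ctx, task)
}
```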
diff --git a/pkg/taskqueue/client.go b/pkg/taskqueue/client.go
index f3ccd56..951c595 100644
--- a/pkg/taskqueue/client.go
+++ b/pkg/taskqueue/client.go
@@ -4,8 +4,7 @@ import (
 	"context"
 	"encoding/json"
 	"errors"
-	"strings"
-	"sync"
+	"fmt"
 	"time"
 
 	"git.devminer.xyz/devminer/unitel"
@@ -53,12 +52,9 @@ func NewTask(type_ string, payload any) (Task, error) {
 	}, nil
 }
 
-type Handler func(ctx context.Context, task Task) error
-
 type Client struct {
-	name     string
-	subject  string
-	handlers map[string][]Handler
+	name    string
+	subject string
 
 	nc *nats.Conn
 	js jetstream.JetStream
@@ -71,15 +67,27 @@ type Client struct {
 	log logr.Logger
 }
 
-func NewClient(ctx context.Context, streamName string, natsClient *nats.Conn, telemetry *unitel.Telemetry, log logr.Logger) (*Client, error) {
+func NewClient(streamName string, natsClient *nats.Conn, telemetry *unitel.Telemetry, log logr.Logger) (*Client, error) {
 	js, err := jetstream.New(natsClient)
 	if err != nil {
 		return nil, err
 	}
 
-	s, err := js.CreateStream(ctx, jetstream.StreamConfig{
+	return &Client{
+		name: streamName,
+
+		js: js,
+
+		telemetry: telemetry,
+		log:       log,
+	}, nil
+}
+
+func (c *Client) TaskSet(ctx context.Context, name string) (*TaskSet, error) {
+	streamName := fmt.Sprintf("%s_%s", c.name, name)
+
+	s, err := c.js.CreateStream(ctx, jetstream.StreamConfig{
 		Name:         streamName,
-		Subjects:     []string{streamName + ".*"},
 		MaxConsumers: -1,
 		MaxMsgs:      -1,
 		Discard:      jetstream.DiscardOld,
@@ -89,7 +97,7 @@ func NewClient(ctx context.Context, streamName string, natsClient *nats.Conn, te
 		AllowDirect: true,
 	})
 	if errors.Is(err, nats.ErrStreamNameAlreadyInUse) {
-		s, err = js.Stream(ctx, streamName)
+		s, err = c.js.Stream(ctx, streamName)
 		if err != nil {
 			return nil, err
 		}
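The hunks above split stream management out of `NewClient`: the client now only wraps the JetStream handle, and each `TaskSet` creates (or reopens) its own stream named `<client>_<name>`. A rough bootstrap sketch under those assumptions (telemetry and logger construction elided, names illustrative):

```go
// Assumed wiring for the new two-step setup.
func newNotesTaskSet(ctx context.Context, telemetry *unitel.Telemetry, logger logr.Logger) (*taskqueue.TaskSet, error) {
	nc, err := nats.Connect(nats.DefaultURL)
	if err != nil {
		return nil, err
	}

	// NewClient no longer takes a context and no longer creates a stream.
	client, err := taskqueue.NewClient("versia", nc, telemetry, logger)
	if err != nil {
		return nil, err
	}

	// Creates (or reopens) the JetStream stream "versia_notes".
	return client.TaskSet(ctx, "notes")
}
```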
- AddAttribute("messaging.destination.name", c.subject) - defer s.End() - ctx = s.Context() - - s.AddAttribute("jobID", task.ID) - - data, err := json.Marshal(c.newTaskWrapper(ctx, task)) - if err != nil { - return err - } - - s.AddAttribute("messaging.message.body.size", len(data)) - - msg, err := c.js.PublishMsg(ctx, &nats.Msg{Subject: c.subject, Data: data}) - if err != nil { - return err - } - c.log.V(2).Info("Submitted task", "id", task.ID, "type", task.Type, "sequence", msg.Sequence) - - s.AddAttribute("messaging.message.id", msg.Sequence) - - return nil -} - -func (c *Client) RegisterHandler(type_ string, handler Handler) { - c.log.V(2).Info("Registering handler", "type", type_) - - if _, ok := c.handlers[type_]; !ok { - c.handlers[type_] = []Handler{} - } - c.handlers[type_] = append(c.handlers[type_], handler) -} - -func (c *Client) StartConsumer(ctx context.Context, consumerGroup string) error { - c.log.Info("Starting consumer") - - sub, err := c.js.CreateConsumer(ctx, c.name, jetstream.ConsumerConfig{ - Durable: consumerGroup, - DeliverPolicy: jetstream.DeliverAllPolicy, - ReplayPolicy: jetstream.ReplayInstantPolicy, - AckPolicy: jetstream.AckExplicitPolicy, - FilterSubject: c.subject, - MaxWaiting: 1, - MaxAckPending: 1, - HeadersOnly: false, - MemoryStorage: false, - }) - if err != nil { - return err - } - - m, err := sub.Messages(jetstream.PullMaxMessages(1)) - if err != nil { - return err - } - - go func() { - for { - msg, err := m.Next() - if err != nil { - if errors.Is(err, jetstream.ErrMsgIteratorClosed) { - c.log.Info("Stopping") - return - } - - c.log.Error(err, "Failed to get next message") - break - } - - if err := c.handleTask(ctx, msg); err != nil { - c.log.Error(err, "Failed to handle task") - break - } - } - }() - go func() { - <-c.stopCh - m.Drain() - }() - - return nil -} - -type Consumer struct { - telemetry *unitel.Telemetry - log logr.Logger -} - -func (c *Consumer) handleTask(ctx context.Context, msg jetstream.Msg) error { - msgMeta, err := msg.Metadata() - if err != nil { - return err - } - - data := msg.Data() - - var w taskWrapper - if err := json.Unmarshal(data, &w); err != nil { - if err := msg.Nak(); err != nil { - c.log.Error(err, "Failed to nak message") - } - - return err - } - - s := c.telemetry.StartSpan( - context.Background(), - "queue.process", - "taskqueue/Client.handleTask", - c.telemetry.ContinueFromMap(w.TraceInfo), - ). - AddAttribute("messaging.destination.name", c.subject). - AddAttribute("messaging.message.id", msgMeta.Sequence.Stream). - AddAttribute("messaging.message.retry.count", msgMeta.NumDelivered). - AddAttribute("messaging.message.body.size", len(data)). 
- AddAttribute("messaging.message.receive.latency", time.Since(w.EnqueuedAt).Milliseconds()) - defer s.End() - ctx = s.Context() - - handlers, ok := c.handlers[w.Task.Type] - if !ok { - c.log.V(2).Info("No handler for task", "type", w.Task.Type) - return msg.Nak() - } - - var errs CombinedError - for _, handler := range handlers { - if err := handler(ctx, w.Task); err != nil { - c.log.Error(err, "Handler failed", "type", w.Task.Type) - errs.Errors = append(errs.Errors, err) - } - } - - if len(errs.Errors) > 0 { - if err := msg.Nak(); err != nil { - c.log.Error(err, "Failed to nak message") - errs.Errors = append(errs.Errors, err) - } - - return errs - } - - return msg.Ack() -} - -type CombinedError struct { - Errors []error -} - -func (e CombinedError) Error() string { - sb := strings.Builder{} - sb.WriteRune('[') - for i, err := range e.Errors { - if i > 0 { - sb.WriteRune(',') - } - sb.WriteString(err.Error()) - } - sb.WriteRune(']') - return sb.String() + return newTaskSet(streamName, c, s, c.log.WithName(fmt.Sprintf("taskset(%s)", streamName)), c.telemetry), nil } diff --git a/pkg/taskqueue/errors.go b/pkg/taskqueue/errors.go new file mode 100644 index 0000000..427bc7d --- /dev/null +++ b/pkg/taskqueue/errors.go @@ -0,0 +1,20 @@ +package taskqueue + +import "strings" + +type CombinedError struct { + Errors []error +} + +func (e CombinedError) Error() string { + sb := strings.Builder{} + sb.WriteRune('[') + for i, err := range e.Errors { + if i > 0 { + sb.WriteRune(',') + } + sb.WriteString(err.Error()) + } + sb.WriteRune(']') + return sb.String() +} diff --git a/pkg/taskqueue/taskset.go b/pkg/taskqueue/taskset.go new file mode 100644 index 0000000..f515787 --- /dev/null +++ b/pkg/taskqueue/taskset.go @@ -0,0 +1,225 @@ +package taskqueue + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "git.devminer.xyz/devminer/unitel" + "github.com/go-logr/logr" + "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" + "sync" + "time" +) + +type TaskHandler = func(ctx context.Context, task Task) error + +type TaskSet struct { + handlers map[string][]TaskHandler + + streamName string + c *Client + s jetstream.Stream + log logr.Logger + telemetry *unitel.Telemetry +} + +func newTaskSet(streamName string, client *Client, s jetstream.Stream, log logr.Logger, telemetry *unitel.Telemetry) *TaskSet { + return &TaskSet{ + handlers: make(map[string][]TaskHandler), + + streamName: streamName, + c: client, + s: s, + log: log, + telemetry: telemetry, + } +} + +func (t *TaskSet) RegisterHandler(type_ string, handler TaskHandler) { + t.log.V(2).Info("Registering handler", "type", type_) + + if _, ok := t.handlers[type_]; !ok { + t.handlers[type_] = []TaskHandler{} + } + t.handlers[type_] = append(t.handlers[type_], handler) +} + +func (t *TaskSet) Submit(ctx context.Context, task Task) error { + s := t.telemetry.StartSpan(ctx, "queue.publish", "taskqueue/TaskSet.Submit"). 
+ AddAttribute("messaging.destination.name", t.streamName) + defer s.End() + ctx = s.Context() + + s.AddAttribute("jobID", task.ID) + + data, err := json.Marshal(t.c.newTaskWrapper(ctx, task)) + if err != nil { + return err + } + + s.AddAttribute("messaging.message.body.size", len(data)) + + // TODO: Refactor + msg, err := t.c.js.PublishMsg(ctx, &nats.Msg{Subject: t.streamName, Data: data}) + if err != nil { + return err + } + t.log.V(2).Info("Submitted task", "id", task.ID, "type", task.Type, "sequence", msg.Sequence) + + s.AddAttribute("messaging.message.id", msg.Sequence) + + return nil +} + +func (t *TaskSet) StartConsumer(name string) *Consumer { + return newConsumer(name, t.streamName, t, t.log.WithName(fmt.Sprintf("consumer(%s)", name)), t.telemetry) +} + +type Consumer struct { + stopCh chan struct{} + stopOnce func() + + name string + streamName string + telemetry *unitel.Telemetry + log logr.Logger + t *TaskSet +} + +func newConsumer(name, streamName string, t *TaskSet, log logr.Logger, telemetry *unitel.Telemetry) *Consumer { + stopCh := make(chan struct{}) + stopOnce := sync.OnceFunc(func() { + close(stopCh) + }) + + return &Consumer{ + stopCh: stopCh, + stopOnce: stopOnce, + + name: name, + streamName: streamName, + telemetry: telemetry, + log: log, + t: t, + } +} + +func (c *Consumer) Close() { + c.stopOnce() +} + +func (c *Consumer) Start(ctx context.Context) error { + c.log.Info("Starting consumer") + + sub, err := c.t.c.js.CreateConsumer(ctx, c.streamName, jetstream.ConsumerConfig{ + Durable: c.name, + DeliverPolicy: jetstream.DeliverAllPolicy, + ReplayPolicy: jetstream.ReplayInstantPolicy, + AckPolicy: jetstream.AckExplicitPolicy, + MaxWaiting: 1, + MaxAckPending: 1, + HeadersOnly: false, + MemoryStorage: false, + }) + if err != nil { + return err + } + + m, err := sub.Messages(jetstream.PullMaxMessages(1)) + if err != nil { + return err + } + + go c.handleMessages(m) + + go func() { + <-ctx.Done() + c.Close() + }() + + go func() { + <-c.stopCh + m.Drain() + }() + + return nil +} + +func (c *Consumer) handleMessages(m jetstream.MessagesContext) { + for { + msg, err := m.Next() + if err != nil { + if errors.Is(err, jetstream.ErrMsgIteratorClosed) { + c.log.Info("Stopping") + return + } + + c.log.Error(err, "Failed to get next message") + break + } + + if err := c.handleTask(msg); err != nil { + c.log.Error(err, "Failed to handle task") + break + } + } +} + +func (c *Consumer) handleTask(msg jetstream.Msg) error { + msgMeta, err := msg.Metadata() + if err != nil { + return err + } + + data := msg.Data() + + var w taskWrapper + if err := json.Unmarshal(data, &w); err != nil { + if err := msg.Nak(); err != nil { + c.log.Error(err, "Failed to nak message") + } + + return err + } + + s := c.telemetry.StartSpan( + context.Background(), + "queue.process", + "taskqueue/Consumer.handleTask", + c.telemetry.ContinueFromMap(w.TraceInfo), + ). + AddAttribute("messaging.destination.name", msg.Subject()). + AddAttribute("messaging.message.id", msgMeta.Sequence.Stream). + AddAttribute("messaging.message.retry.count", msgMeta.NumDelivered). + AddAttribute("messaging.message.body.size", len(data)). 
+ AddAttribute("messaging.message.receive.latency", time.Since(w.EnqueuedAt).Milliseconds()) + defer s.End() + ctx := s.Context() + + handlers, ok := c.t.handlers[w.Task.Type] + if !ok { + c.log.V(2).Info("No handler for task", "type", w.Task.Type) + return msg.Nak() + } + + var errs CombinedError + for _, handler := range handlers { + if err := handler(ctx, w.Task); err != nil { + c.log.Error(err, "Handler failed", "type", w.Task.Type) + errs.Errors = append(errs.Errors, err) + } + } + + if len(errs.Errors) > 0 { + if err := msg.Nak(); err != nil { + c.log.Error(err, "Failed to nak message") + errs.Errors = append(errs.Errors, err) + } + + return errs + } + + return msg.Ack() +}