refactor: taskset WIP
TheDevMinerTV committed Aug 29, 2024
1 parent c701950 commit 6d9641d
Showing 6 changed files with 368 additions and 209 deletions.
1 change: 0 additions & 1 deletion internal/tasks/handler.go
@@ -38,7 +38,6 @@ func NewHandler(federationService service.FederationService, repositories reposi

func (t *Handler) Register(tq *taskqueue.Client) {
    tq.RegisterHandler(FederateNote, parse(t.FederateNote))
    tq.RegisterHandler(FederateFollow, parse(t.FederateFollow))
}

func parse[T any](handler func(context.Context, T) error) func(context.Context, taskqueue.Task) error {
52 changes: 52 additions & 0 deletions internal/tasks/note_tasks/federate_note.go
@@ -0,0 +1,52 @@
package note_tasks

import (
    "context"

    "github.com/google/uuid"

    "github.com/versia-pub/versia-go/internal/entity"
    "github.com/versia-pub/versia-go/internal/repository"
)

type FederateNoteData struct {
    NoteID uuid.UUID `json:"noteID"`
}

func (t *Handler) FederateNote(ctx context.Context, data FederateNoteData) error {
    s := t.telemetry.StartSpan(ctx, "function", "tasks/Handler.FederateNote")
    defer s.End()
    ctx = s.Context()

    var n *entity.Note
    if err := t.repositories.Atomic(ctx, func(ctx context.Context, tx repository.Manager) error {
        var err error
        n, err = tx.Notes().GetByID(ctx, data.NoteID)
        if err != nil {
            return err
        }
        if n == nil {
            t.log.V(-1).Info("Could not find note", "id", data.NoteID)
            return nil
        }

        for _, uu := range n.Mentions {
            if !uu.IsRemote {
                t.log.V(2).Info("User is not remote", "user", uu.ID)
                continue
            }

            res, err := t.federationService.SendToInbox(ctx, n.Author, &uu, n.ToVersia())
            if err != nil {
                t.log.Error(err, "Failed to send note to remote user", "res", res, "user", uu.ID)
            } else {
                t.log.V(2).Info("Sent note to remote user", "res", res, "user", uu.ID)
            }
        }

        return nil
    }); err != nil {
        return err
    }

    return nil
}
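
A minimal sketch of how a FederateNote task might be enqueued with the surrounding taskqueue API. taskqueue.NewTask and the FederateNote constant appear in this commit; the ts value and its Submit method are assumptions, since Client.Submit is deleted in the client.go diff below and its TaskSet-based replacement is not shown in this commit view.

// Sketch: enqueue a FederateNote task for a freshly created note.
task, err := taskqueue.NewTask(note_tasks.FederateNote, note_tasks.FederateNoteData{
    NoteID: noteID, // uuid.UUID of the note to federate
})
if err != nil {
    return err
}
// Hypothetical: assumes submission moves from Client to the new TaskSet type.
return ts.Submit(ctx, task)
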
51 changes: 51 additions & 0 deletions internal/tasks/note_tasks/handler.go
@@ -0,0 +1,51 @@
package note_tasks

import (
    "context"
    "encoding/json"

    "git.devminer.xyz/devminer/unitel"
    "github.com/go-logr/logr"

    "github.com/versia-pub/versia-go/internal/repository"
    "github.com/versia-pub/versia-go/internal/service"
    "github.com/versia-pub/versia-go/pkg/taskqueue"
)

const (
    FederateNote = "federate_note"
)

type Handler struct {
    federationService service.FederationService

    repositories repository.Manager

    telemetry *unitel.Telemetry
    log       logr.Logger
}

func NewHandler(federationService service.FederationService, repositories repository.Manager, telemetry *unitel.Telemetry, log logr.Logger) *Handler {
    return &Handler{
        federationService: federationService,

        repositories: repositories,

        telemetry: telemetry,
        log:       log,
    }
}

func (t *Handler) Register(tq *taskqueue.Client) {
    tq.RegisterHandler(FederateNote, parse(t.FederateNote))
}

func parse[T any](handler func(context.Context, T) error) func(context.Context, taskqueue.Task) error {
    return func(ctx context.Context, task taskqueue.Task) error {
        var data T
        if err := json.Unmarshal(task.Payload, &data); err != nil {
            return err
        }

        return handler(ctx, data)
    }
}
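
For context, a short sketch of how this new handler could be wired up, assuming the constructor dependencies are available at the call site. Note that this WIP commit deletes taskqueue.Client.RegisterHandler in the client.go diff below while Register here still takes a *taskqueue.Client, so this wiring reflects the pre-refactor registration path and will presumably migrate to the new TaskSet type.

// Sketch: construct and register the note_tasks handler (names at the call site are assumptions).
h := note_tasks.NewHandler(federationService, repositories, telemetry, log.WithName("note_tasks"))
h.Register(tq) // tq *taskqueue.Client; RegisterHandler is removed below, so this path is transitional
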
228 changes: 20 additions & 208 deletions pkg/taskqueue/client.go
@@ -4,8 +4,7 @@ import (
"context"
"encoding/json"
"errors"
"strings"
"sync"
"fmt"
"time"

"git.devminer.xyz/devminer/unitel"
@@ -53,12 +52,9 @@ func NewTask(type_ string, payload any) (Task, error) {
    }, nil
}

type Handler func(ctx context.Context, task Task) error

type Client struct {
    name     string
    subject  string
    handlers map[string][]Handler
    name    string
    subject string

    nc *nats.Conn
    js jetstream.JetStream
@@ -71,15 +67,27 @@ type Client struct {
    log logr.Logger
}

func NewClient(ctx context.Context, streamName string, natsClient *nats.Conn, telemetry *unitel.Telemetry, log logr.Logger) (*Client, error) {
func NewClient(streamName string, natsClient *nats.Conn, telemetry *unitel.Telemetry, log logr.Logger) (*Client, error) {
    js, err := jetstream.New(natsClient)
    if err != nil {
        return nil, err
    }

    s, err := js.CreateStream(ctx, jetstream.StreamConfig{
    return &Client{
        name: streamName,

        js: js,

        telemetry: telemetry,
        log:       log,
    }, nil
}

func (c *Client) TaskSet(ctx context.Context, name string) (*TaskSet, error) {
    streamName := fmt.Sprintf("%s_%s", c.name, name)

    s, err := c.js.CreateStream(ctx, jetstream.StreamConfig{
        Name:         streamName,
        Subjects:     []string{streamName + ".*"},
        MaxConsumers: -1,
        MaxMsgs:      -1,
        Discard:      jetstream.DiscardOld,
@@ -89,209 +97,13 @@ func NewClient(ctx context.Context, streamName string, natsClient *nats.Conn, te
        AllowDirect: true,
    })
    if errors.Is(err, nats.ErrStreamNameAlreadyInUse) {
        s, err = js.Stream(ctx, streamName)
        s, err = c.js.Stream(ctx, streamName)
        if err != nil {
            return nil, err
        }
    } else if err != nil {
        return nil, err
    }

    stopCh := make(chan struct{})

    c := &Client{
        name:    streamName,
        subject: streamName + ".tasks",

        handlers: map[string][]Handler{},

        stopCh: stopCh,
        closeOnce: sync.OnceFunc(func() {
            close(stopCh)
        }),

        nc: natsClient,
        js: js,
        s:  s,

        telemetry: telemetry,
        log:       log,
    }

    go func() {
        <-ctx.Done()

        c.Close()
    }()

    return c, nil
}

func (c *Client) Close() {
    c.closeOnce()
    c.nc.Close()
}

func (c *Client) Submit(ctx context.Context, task Task) error {
    s := c.telemetry.StartSpan(ctx, "queue.publish", "taskqueue/Client.Submit").
        AddAttribute("messaging.destination.name", c.subject)
    defer s.End()
    ctx = s.Context()

    s.AddAttribute("jobID", task.ID)

    data, err := json.Marshal(c.newTaskWrapper(ctx, task))
    if err != nil {
        return err
    }

    s.AddAttribute("messaging.message.body.size", len(data))

    msg, err := c.js.PublishMsg(ctx, &nats.Msg{Subject: c.subject, Data: data})
    if err != nil {
        return err
    }
    c.log.V(2).Info("Submitted task", "id", task.ID, "type", task.Type, "sequence", msg.Sequence)

    s.AddAttribute("messaging.message.id", msg.Sequence)

    return nil
}

func (c *Client) RegisterHandler(type_ string, handler Handler) {
    c.log.V(2).Info("Registering handler", "type", type_)

    if _, ok := c.handlers[type_]; !ok {
        c.handlers[type_] = []Handler{}
    }
    c.handlers[type_] = append(c.handlers[type_], handler)
}

func (c *Client) StartConsumer(ctx context.Context, consumerGroup string) error {
    c.log.Info("Starting consumer")

    sub, err := c.js.CreateConsumer(ctx, c.name, jetstream.ConsumerConfig{
        Durable:       consumerGroup,
        DeliverPolicy: jetstream.DeliverAllPolicy,
        ReplayPolicy:  jetstream.ReplayInstantPolicy,
        AckPolicy:     jetstream.AckExplicitPolicy,
        FilterSubject: c.subject,
        MaxWaiting:    1,
        MaxAckPending: 1,
        HeadersOnly:   false,
        MemoryStorage: false,
    })
    if err != nil {
        return err
    }

    m, err := sub.Messages(jetstream.PullMaxMessages(1))
    if err != nil {
        return err
    }

    go func() {
        for {
            msg, err := m.Next()
            if err != nil {
                if errors.Is(err, jetstream.ErrMsgIteratorClosed) {
                    c.log.Info("Stopping")
                    return
                }

                c.log.Error(err, "Failed to get next message")
                break
            }

            if err := c.handleTask(ctx, msg); err != nil {
                c.log.Error(err, "Failed to handle task")
                break
            }
        }
    }()
    go func() {
        <-c.stopCh
        m.Drain()
    }()

    return nil
}

type Consumer struct {
    telemetry *unitel.Telemetry
    log       logr.Logger
}

func (c *Consumer) handleTask(ctx context.Context, msg jetstream.Msg) error {
    msgMeta, err := msg.Metadata()
    if err != nil {
        return err
    }

    data := msg.Data()

    var w taskWrapper
    if err := json.Unmarshal(data, &w); err != nil {
        if err := msg.Nak(); err != nil {
            c.log.Error(err, "Failed to nak message")
        }

        return err
    }

    s := c.telemetry.StartSpan(
        context.Background(),
        "queue.process",
        "taskqueue/Client.handleTask",
        c.telemetry.ContinueFromMap(w.TraceInfo),
    ).
        AddAttribute("messaging.destination.name", c.subject).
        AddAttribute("messaging.message.id", msgMeta.Sequence.Stream).
        AddAttribute("messaging.message.retry.count", msgMeta.NumDelivered).
        AddAttribute("messaging.message.body.size", len(data)).
        AddAttribute("messaging.message.receive.latency", time.Since(w.EnqueuedAt).Milliseconds())
    defer s.End()
    ctx = s.Context()

    handlers, ok := c.handlers[w.Task.Type]
    if !ok {
        c.log.V(2).Info("No handler for task", "type", w.Task.Type)
        return msg.Nak()
    }

    var errs CombinedError
    for _, handler := range handlers {
        if err := handler(ctx, w.Task); err != nil {
            c.log.Error(err, "Handler failed", "type", w.Task.Type)
            errs.Errors = append(errs.Errors, err)
        }
    }

    if len(errs.Errors) > 0 {
        if err := msg.Nak(); err != nil {
            c.log.Error(err, "Failed to nak message")
            errs.Errors = append(errs.Errors, err)
        }

        return errs
    }

    return msg.Ack()
}

type CombinedError struct {
    Errors []error
}

func (e CombinedError) Error() string {
    sb := strings.Builder{}
    sb.WriteRune('[')
    for i, err := range e.Errors {
        if i > 0 {
            sb.WriteRune(',')
        }
        sb.WriteString(err.Error())
    }
    sb.WriteRune(']')
    return sb.String()
    return newTaskSet(streamName, c, s, c.log.WithName(fmt.Sprintf("taskset(%s)", streamName)), c.telemetry), nil
}
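
The added return statement references newTaskSet, which is not shown in this commit view (it presumably lives in a new taskset.go among the 6 changed files). A hypothetical sketch of the TaskSet type follows, with field names and the constructor signature inferred only from the call site above.

// Hypothetical sketch; the real definition is not part of this commit view.
type TaskSet struct {
    name   string
    client *Client
    stream jetstream.Stream

    log       logr.Logger
    telemetry *unitel.Telemetry
}

// newTaskSet mirrors the argument order at the call site above; everything here is an assumption.
func newTaskSet(name string, client *Client, stream jetstream.Stream, log logr.Logger, telemetry *unitel.Telemetry) *TaskSet {
    return &TaskSet{
        name:      name,
        client:    client,
        stream:    stream,
        log:       log,
        telemetry: telemetry,
    }
}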