diff --git a/cmd/authorized_keys/main.go b/cmd/authorized_keys/main.go index 827c795..12ea537 100644 --- a/cmd/authorized_keys/main.go +++ b/cmd/authorized_keys/main.go @@ -30,7 +30,7 @@ func PubSubMiddleware(cfg *pubsub.Cfg) wish.Middleware { return func(sesh ssh.Session) { args := sesh.Command() if len(args) < 2 { - wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub) {channel}") + wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub|pipe) {channel}") next(sesh) return } @@ -45,7 +45,7 @@ func PubSubMiddleware(cfg *pubsub.Cfg) wish.Middleware { logger.Info("running cli") if cmd == "help" { - wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub) {channel}") + wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub|pipe) {channel}") } else if cmd == "sub" { sub := &pubsub.Sub{ ID: uuid.NewString(), @@ -79,8 +79,29 @@ func PubSubMiddleware(cfg *pubsub.Cfg) wish.Middleware { if err != nil { logger.Error("error from pub", slog.Any("error", err), slog.String("pub", pub.ID)) } + } else if cmd == "pipe" { + pipeClient := &pubsub.PipeClient{ + ID: uuid.NewString(), + Done: make(chan struct{}), + Data: make(chan pubsub.PipeMessage), + Replay: args[len(args)-1] == "replay", + ReadWriter: sesh, + } + + go func() { + <-sesh.Context().Done() + pipeClient.Cleanup() + }() + + readErr, writeErr := cfg.PubSub.Pipe(channel, pipeClient) + if readErr != nil { + logger.Error("error reading from pipe", slog.Any("error", readErr), slog.String("pipeClient", pipeClient.ID)) + } + if writeErr != nil { + logger.Error("error writing to pipe", slog.Any("error", writeErr), slog.String("pipeClient", pipeClient.ID)) + } } else { - wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub) {channel}") + wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub|pipe) {channel}") } next(sesh) @@ -98,6 +119,7 @@ func main() { PubSub: &pubsub.PubSubMulticast{ Logger: logger, Channels: syncmap.New[string, *pubsub.Channel](), + Pipes: syncmap.New[string, *pubsub.Pipe](), }, } diff --git a/multicast.go b/multicast.go index fe391ff..25b66c8 100644 --- a/multicast.go +++ b/multicast.go @@ -5,13 +5,22 @@ import ( "io" "log/slog" "strings" + "sync" "github.com/antoniomika/syncmap" ) +type PipeDirection int + +const ( + PipeInput PipeDirection = iota + PipeOutput +) + type PubSubMulticast struct { Logger *slog.Logger Channels *syncmap.Map[string, *Channel] + Pipes *syncmap.Map[string, *Pipe] } func (b *PubSubMulticast) Cleanup() { @@ -39,6 +48,151 @@ func (b *PubSubMulticast) Cleanup() { for _, channel := range toRemove { b.Channels.Delete(channel) } + + pipesToRemove := []string{} + b.Pipes.Range(func(I string, J *Pipe) bool { + count := 0 + J.Clients.Range(func(K string, V *PipeClient) bool { + count++ + return true + }) + + if count == 0 { + J.Cleanup() + pipesToRemove = append(pipesToRemove, I) + } + + return true + }) + + for _, pipe := range pipesToRemove { + b.Pipes.Delete(pipe) + } +} + +func (b *PubSubMulticast) ensurePipe(pipe string) *Pipe { + dataPipe, _ := b.Pipes.LoadOrStore(pipe, &Pipe{ + Name: pipe, + Clients: syncmap.New[string, *PipeClient](), + Done: make(chan struct{}), + Data: make(chan PipeMessage), + }) + dataPipe.Handle() + + return dataPipe +} + +func (b *PubSubMulticast) GetPipes(pipePrefix string) []*Pipe { + var pipes []*Pipe + b.Pipes.Range(func(I string, J *Pipe) bool { + if strings.HasPrefix(I, pipePrefix) { + pipes = append(pipes, J) + } + + return true + }) + return pipes +} + +func (b *PubSubMulticast) GetPipe(pipe string) *Pipe { + pipeData, _ := b.Pipes.Load(pipe) + return pipeData +} + +func (b *PubSubMulticast) Pipe(pipe string, pipeClient *PipeClient) (error, error) { + pipeData := b.ensurePipe(pipe) + pipeData.Clients.Store(pipeClient.ID, pipeClient) + defer func() { + pipeClient.Cleanup() + pipeData.Clients.Delete(pipeClient.ID) + b.Cleanup() + }() + + var ( + readErr error + writeErr error + wg sync.WaitGroup + ) + + wg.Add(2) + + go func() { + defer wg.Done() + mainLoop: + for { + select { + case data, ok := <-pipeClient.Data: + if data.Direction == PipeInput { + select { + case pipeData.Data <- data: + case <-pipeClient.Done: + break mainLoop + case <-pipeData.Done: + break mainLoop + default: + continue + } + } else { + if data.ClientID == pipeClient.ID && !pipeClient.Replay { + continue + } + + _, err := pipeClient.ReadWriter.Write(data.Data) + if err != nil { + slog.Error("error writing to sub", slog.String("pipeClient", pipeClient.ID), slog.String("pipe", pipe), slog.Any("error", err)) + writeErr = err + return + } + } + + if !ok { + break mainLoop + } + case <-pipeClient.Done: + break mainLoop + case <-pipeData.Done: + break mainLoop + } + } + }() + + go func() { + defer wg.Done() + mainLoop: + for { + data := make([]byte, 32*1024) + n, err := pipeClient.ReadWriter.Read(data) + data = data[:n] + + pipeMessage := PipeMessage{ + Data: data, + ClientID: pipeClient.ID, + Direction: PipeInput, + } + + select { + case pipeClient.Data <- pipeMessage: + case <-pipeClient.Done: + break mainLoop + case <-pipeData.Done: + break mainLoop + } + + if err != nil { + if errors.Is(err, io.EOF) { + return + } + + slog.Error("error reading from pipe", slog.String("pipeClient", pipeClient.ID), slog.String("pipe", pipe), slog.Any("error", err)) + readErr = err + return + } + } + }() + + wg.Wait() + + return readErr, writeErr } func (b *PubSubMulticast) GetChannels(channelPrefix string) []*Channel { @@ -90,7 +244,7 @@ func (b *PubSubMulticast) GetSubs(channel string) []*Sub { return subs } -func (b *PubSubMulticast) ensure(channel string) *Channel { +func (b *PubSubMulticast) ensureChannel(channel string) *Channel { dataChannel, _ := b.Channels.LoadOrStore(channel, &Channel{ Name: channel, Done: make(chan struct{}), @@ -104,7 +258,7 @@ func (b *PubSubMulticast) ensure(channel string) *Channel { } func (b *PubSubMulticast) Sub(channel string, sub *Sub) error { - dataChannel := b.ensure(channel) + dataChannel := b.ensureChannel(channel) dataChannel.Subs.Store(sub.ID, sub) defer func() { sub.Cleanup() @@ -122,7 +276,7 @@ mainLoop: case data, ok := <-sub.Data: _, err := sub.Writer.Write(data) if err != nil { - slog.Error("error writing to sub", slog.Any("sub", sub.ID), slog.Any("channel", channel), slog.Any("error", err)) + slog.Error("error writing to sub", slog.String("sub", sub.ID), slog.String("channel", channel), slog.Any("error", err)) return err } @@ -136,7 +290,7 @@ mainLoop: } func (b *PubSubMulticast) Pub(channel string, pub *Pub) error { - dataChannel := b.ensure(channel) + dataChannel := b.ensureChannel(channel) dataChannel.Pubs.Store(pub.ID, pub) defer func() { pub.Cleanup() @@ -182,7 +336,7 @@ mainLoop: return nil } - slog.Error("error reading from pub", slog.Any("pub", pub.ID), slog.Any("channel", channel), slog.Any("error", err)) + slog.Error("error reading from pub", slog.String("pub", pub.ID), slog.String("channel", channel), slog.Any("error", err)) return err } } diff --git a/pubsub.go b/pubsub.go index 6cac450..2bc383e 100644 --- a/pubsub.go +++ b/pubsub.go @@ -15,7 +15,7 @@ type Channel struct { Data chan []byte Subs *syncmap.Map[string, *Sub] Pubs *syncmap.Map[string, *Pub] - once sync.Once + handleOnce sync.Once cleanupOnce sync.Once onceData sync.Once } @@ -30,7 +30,7 @@ func (c *Channel) Cleanup() { } func (c *Channel) Handle() { - c.once.Do(func() { + c.handleOnce.Do(func() { go func() { defer func() { c.Subs.Range(func(I string, J *Sub) bool { @@ -122,11 +122,104 @@ func (pub *Pub) Cleanup() { }) } +type PipeClient struct { + ID string + Done chan struct{} + Data chan PipeMessage + ReadWriter io.ReadWriter + Replay bool + once sync.Once + onceData sync.Once +} + +func (pipeClient *PipeClient) Cleanup() { + pipeClient.once.Do(func() { + close(pipeClient.Done) + }) +} + +type PipeMessage struct { + Data []byte + ClientID string + Direction PipeDirection +} + +type Pipe struct { + Name string + Clients *syncmap.Map[string, *PipeClient] + Done chan struct{} + Data chan PipeMessage + handleOnce sync.Once + cleanupOnce sync.Once +} + +func (pipe *Pipe) Handle() { + pipe.handleOnce.Do(func() { + go func() { + defer func() { + pipe.Clients.Range(func(I string, J *PipeClient) bool { + J.Cleanup() + return true + }) + }() + + for { + select { + case <-pipe.Done: + return + case data, ok := <-pipe.Data: + pipe.Clients.Range(func(I string, J *PipeClient) bool { + if !ok { + J.onceData.Do(func() { + close(J.Data) + }) + return true + } + + data.Direction = PipeOutput + + select { + case J.Data <- data: + return true + case <-J.Done: + return true + case <-pipe.Done: + return true + case <-time.After(1 * time.Second): + slog.Error("timeout writing to pipe", slog.String("pipeClient", I), slog.String("pipe", pipe.Name)) + return true + } + }) + case <-time.After(1 * time.Millisecond): + count := 0 + pipe.Clients.Range(func(I string, J *PipeClient) bool { + count++ + return true + }) + if count == 0 { + return + } + } + } + }() + }) +} + +func (pipe *Pipe) Cleanup() { + pipe.cleanupOnce.Do(func() { + close(pipe.Done) + close(pipe.Data) + }) +} + type PubSub interface { GetSubs(channel string) []*Sub GetPubs(channel string) []*Pub GetChannels(channelPrefix string) []*Channel + GetPipes(pipePrefix string) []*Pipe GetChannel(channel string) *Channel + GetPipe(pipe string) *Pipe + Pipe(pipe string, pipeClient *PipeClient) (error, error) Sub(channel string, sub *Sub) error Pub(channel string, pub *Pub) error }