Skip to content

Commit

Permalink
Added pipe support
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniomika committed Sep 19, 2024
1 parent 977bd6b commit b923038
Show file tree
Hide file tree
Showing 3 changed files with 279 additions and 10 deletions.
28 changes: 25 additions & 3 deletions cmd/authorized_keys/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func PubSubMiddleware(cfg *pubsub.Cfg) wish.Middleware {
return func(sesh ssh.Session) {
args := sesh.Command()
if len(args) < 2 {
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub) {channel}")
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub|pipe) {channel}")
next(sesh)
return
}
Expand All @@ -45,7 +45,7 @@ func PubSubMiddleware(cfg *pubsub.Cfg) wish.Middleware {
logger.Info("running cli")

if cmd == "help" {
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub) {channel}")
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub|pipe) {channel}")
} else if cmd == "sub" {
sub := &pubsub.Sub{
ID: uuid.NewString(),
Expand Down Expand Up @@ -79,8 +79,29 @@ func PubSubMiddleware(cfg *pubsub.Cfg) wish.Middleware {
if err != nil {
logger.Error("error from pub", slog.Any("error", err), slog.String("pub", pub.ID))
}
} else if cmd == "pipe" {
pipeClient := &pubsub.PipeClient{
ID: uuid.NewString(),
Done: make(chan struct{}),
Data: make(chan pubsub.PipeMessage),
Replay: args[len(args)-1] == "replay",
ReadWriter: sesh,
}

go func() {
<-sesh.Context().Done()
pipeClient.Cleanup()
}()

readErr, writeErr := cfg.PubSub.Pipe(channel, pipeClient)
if readErr != nil {
logger.Error("error reading from pipe", slog.Any("error", readErr), slog.String("pipeClient", pipeClient.ID))
}
if writeErr != nil {
logger.Error("error writing to pipe", slog.Any("error", writeErr), slog.String("pipeClient", pipeClient.ID))
}
} else {
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub) {channel}")
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub|pipe) {channel}")
}

next(sesh)
Expand All @@ -98,6 +119,7 @@ func main() {
PubSub: &pubsub.PubSubMulticast{
Logger: logger,
Channels: syncmap.New[string, *pubsub.Channel](),
Pipes: syncmap.New[string, *pubsub.Pipe](),
},
}

Expand Down
164 changes: 159 additions & 5 deletions multicast.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,22 @@ import (
"io"
"log/slog"
"strings"
"sync"

"github.com/antoniomika/syncmap"
)

type PipeDirection int

const (
PipeInput PipeDirection = iota
PipeOutput
)

type PubSubMulticast struct {
Logger *slog.Logger
Channels *syncmap.Map[string, *Channel]
Pipes *syncmap.Map[string, *Pipe]
}

func (b *PubSubMulticast) Cleanup() {
Expand Down Expand Up @@ -39,6 +48,151 @@ func (b *PubSubMulticast) Cleanup() {
for _, channel := range toRemove {
b.Channels.Delete(channel)
}

pipesToRemove := []string{}
b.Pipes.Range(func(I string, J *Pipe) bool {
count := 0
J.Clients.Range(func(K string, V *PipeClient) bool {
count++
return true
})

if count == 0 {
J.Cleanup()
pipesToRemove = append(pipesToRemove, I)
}

return true
})

for _, pipe := range pipesToRemove {
b.Pipes.Delete(pipe)
}
}

func (b *PubSubMulticast) ensurePipe(pipe string) *Pipe {
dataPipe, _ := b.Pipes.LoadOrStore(pipe, &Pipe{
Name: pipe,
Clients: syncmap.New[string, *PipeClient](),
Done: make(chan struct{}),
Data: make(chan PipeMessage),
})
dataPipe.Handle()

return dataPipe
}

func (b *PubSubMulticast) GetPipes(pipePrefix string) []*Pipe {
var pipes []*Pipe
b.Pipes.Range(func(I string, J *Pipe) bool {
if strings.HasPrefix(I, pipePrefix) {
pipes = append(pipes, J)
}

return true
})
return pipes
}

func (b *PubSubMulticast) GetPipe(pipe string) *Pipe {
pipeData, _ := b.Pipes.Load(pipe)
return pipeData
}

func (b *PubSubMulticast) Pipe(pipe string, pipeClient *PipeClient) (error, error) {
pipeData := b.ensurePipe(pipe)
pipeData.Clients.Store(pipeClient.ID, pipeClient)
defer func() {
pipeClient.Cleanup()
pipeData.Clients.Delete(pipeClient.ID)
b.Cleanup()
}()

var (
readErr error
writeErr error
wg sync.WaitGroup
)

wg.Add(2)

go func() {
defer wg.Done()
mainLoop:
for {
select {
case data, ok := <-pipeClient.Data:
if data.Direction == PipeInput {
select {
case pipeData.Data <- data:
case <-pipeClient.Done:
break mainLoop
case <-pipeData.Done:
break mainLoop
default:
continue
}
} else {
if data.ClientID == pipeClient.ID && !pipeClient.Replay {
continue
}

_, err := pipeClient.ReadWriter.Write(data.Data)
if err != nil {
slog.Error("error writing to sub", slog.String("pipeClient", pipeClient.ID), slog.String("pipe", pipe), slog.Any("error", err))
writeErr = err
return
}
}

if !ok {
break mainLoop
}
case <-pipeClient.Done:
break mainLoop
case <-pipeData.Done:
break mainLoop
}
}
}()

go func() {
defer wg.Done()
mainLoop:
for {
data := make([]byte, 32*1024)
n, err := pipeClient.ReadWriter.Read(data)
data = data[:n]

pipeMessage := PipeMessage{
Data: data,
ClientID: pipeClient.ID,
Direction: PipeInput,
}

select {
case pipeClient.Data <- pipeMessage:
case <-pipeClient.Done:
break mainLoop
case <-pipeData.Done:
break mainLoop
}

if err != nil {
if errors.Is(err, io.EOF) {
return
}

slog.Error("error reading from pipe", slog.String("pipeClient", pipeClient.ID), slog.String("pipe", pipe), slog.Any("error", err))
readErr = err
return
}
}
}()

wg.Wait()

return readErr, writeErr
}

func (b *PubSubMulticast) GetChannels(channelPrefix string) []*Channel {
Expand Down Expand Up @@ -90,7 +244,7 @@ func (b *PubSubMulticast) GetSubs(channel string) []*Sub {
return subs
}

func (b *PubSubMulticast) ensure(channel string) *Channel {
func (b *PubSubMulticast) ensureChannel(channel string) *Channel {
dataChannel, _ := b.Channels.LoadOrStore(channel, &Channel{
Name: channel,
Done: make(chan struct{}),
Expand All @@ -104,7 +258,7 @@ func (b *PubSubMulticast) ensure(channel string) *Channel {
}

func (b *PubSubMulticast) Sub(channel string, sub *Sub) error {
dataChannel := b.ensure(channel)
dataChannel := b.ensureChannel(channel)
dataChannel.Subs.Store(sub.ID, sub)
defer func() {
sub.Cleanup()
Expand All @@ -122,7 +276,7 @@ mainLoop:
case data, ok := <-sub.Data:
_, err := sub.Writer.Write(data)
if err != nil {
slog.Error("error writing to sub", slog.Any("sub", sub.ID), slog.Any("channel", channel), slog.Any("error", err))
slog.Error("error writing to sub", slog.String("sub", sub.ID), slog.String("channel", channel), slog.Any("error", err))
return err
}

Expand All @@ -136,7 +290,7 @@ mainLoop:
}

func (b *PubSubMulticast) Pub(channel string, pub *Pub) error {
dataChannel := b.ensure(channel)
dataChannel := b.ensureChannel(channel)
dataChannel.Pubs.Store(pub.ID, pub)
defer func() {
pub.Cleanup()
Expand Down Expand Up @@ -182,7 +336,7 @@ mainLoop:
return nil
}

slog.Error("error reading from pub", slog.Any("pub", pub.ID), slog.Any("channel", channel), slog.Any("error", err))
slog.Error("error reading from pub", slog.String("pub", pub.ID), slog.String("channel", channel), slog.Any("error", err))
return err
}
}
Expand Down
97 changes: 95 additions & 2 deletions pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type Channel struct {
Data chan []byte
Subs *syncmap.Map[string, *Sub]
Pubs *syncmap.Map[string, *Pub]
once sync.Once
handleOnce sync.Once
cleanupOnce sync.Once
onceData sync.Once
}
Expand All @@ -30,7 +30,7 @@ func (c *Channel) Cleanup() {
}

func (c *Channel) Handle() {
c.once.Do(func() {
c.handleOnce.Do(func() {
go func() {
defer func() {
c.Subs.Range(func(I string, J *Sub) bool {
Expand Down Expand Up @@ -122,11 +122,104 @@ func (pub *Pub) Cleanup() {
})
}

type PipeClient struct {
ID string
Done chan struct{}
Data chan PipeMessage
ReadWriter io.ReadWriter
Replay bool
once sync.Once
onceData sync.Once
}

func (pipeClient *PipeClient) Cleanup() {
pipeClient.once.Do(func() {
close(pipeClient.Done)
})
}

type PipeMessage struct {
Data []byte
ClientID string
Direction PipeDirection
}

type Pipe struct {
Name string
Clients *syncmap.Map[string, *PipeClient]
Done chan struct{}
Data chan PipeMessage
handleOnce sync.Once
cleanupOnce sync.Once
}

func (pipe *Pipe) Handle() {
pipe.handleOnce.Do(func() {
go func() {
defer func() {
pipe.Clients.Range(func(I string, J *PipeClient) bool {
J.Cleanup()
return true
})
}()

for {
select {
case <-pipe.Done:
return
case data, ok := <-pipe.Data:
pipe.Clients.Range(func(I string, J *PipeClient) bool {
if !ok {
J.onceData.Do(func() {
close(J.Data)
})
return true
}

data.Direction = PipeOutput

select {
case J.Data <- data:
return true
case <-J.Done:
return true
case <-pipe.Done:
return true
case <-time.After(1 * time.Second):
slog.Error("timeout writing to pipe", slog.String("pipeClient", I), slog.String("pipe", pipe.Name))
return true
}
})
case <-time.After(1 * time.Millisecond):
count := 0
pipe.Clients.Range(func(I string, J *PipeClient) bool {
count++
return true
})
if count == 0 {
return
}
}
}
}()
})
}

func (pipe *Pipe) Cleanup() {
pipe.cleanupOnce.Do(func() {
close(pipe.Done)
close(pipe.Data)
})
}

type PubSub interface {
GetSubs(channel string) []*Sub
GetPubs(channel string) []*Pub
GetChannels(channelPrefix string) []*Channel
GetPipes(pipePrefix string) []*Pipe
GetChannel(channel string) *Channel
GetPipe(pipe string) *Pipe
Pipe(pipe string, pipeClient *PipeClient) (error, error)
Sub(channel string, sub *Sub) error
Pub(channel string, pub *Pub) error
}
Expand Down

0 comments on commit b923038

Please sign in to comment.