diff --git a/client.go b/client.go index 3bef8a8..aada70d 100644 --- a/client.go +++ b/client.go @@ -8,7 +8,7 @@ import ( "github.com/antoniomika/syncmap" ) -func NewClient(ID string, rw io.ReadWriter, direction ChannelDirection, blockWrite, replay bool) *Client { +func NewClient(ID string, rw io.ReadWriter, direction ChannelDirection, blockWrite, replay, keepAlive bool) *Client { return &Client{ ID: ID, ReadWriter: rw, @@ -18,6 +18,7 @@ func NewClient(ID string, rw io.ReadWriter, direction ChannelDirection, blockWri Data: make(chan ChannelMessage), Replay: replay, BlockWrite: blockWrite, + KeepAlive: keepAlive, } } @@ -30,6 +31,7 @@ type Client struct { Data chan ChannelMessage Replay bool BlockWrite bool + KeepAlive bool once sync.Once onceData sync.Once } diff --git a/cmd/authorized_keys/main.go b/cmd/authorized_keys/main.go index 2cdadcf..8c082d3 100644 --- a/cmd/authorized_keys/main.go +++ b/cmd/authorized_keys/main.go @@ -59,7 +59,7 @@ func PubSubMiddleware(cfg *pubsub.Cfg) wish.Middleware { clientID := uuid.NewString() - err := errors.Join(cfg.PubSub.Sub(sesh.Context(), clientID, sesh, chans)) + err := errors.Join(cfg.PubSub.Sub(sesh.Context(), clientID, sesh, chans, args[len(args)-1] == "keepalive")) if err != nil { logger.Error("error during pub", slog.Any("error", err), slog.String("client", clientID)) } diff --git a/connector.go b/connector.go index 68cf81a..450e2e3 100644 --- a/connector.go +++ b/connector.go @@ -72,7 +72,9 @@ func (b *BaseConnector) Connect(client *Client, channels []*Channel) (error, err if count == 0 { for _, cl := range dataChannel.GetClients() { - cl.Cleanup() + if !cl.KeepAlive { + cl.Cleanup() + } } } diff --git a/multicast.go b/multicast.go index 6578d7b..fd32eab 100644 --- a/multicast.go +++ b/multicast.go @@ -35,8 +35,8 @@ func (p *PubSubMulticast) GetSubs() iter.Seq2[string, *Client] { return p.getClients(ChannelDirectionOutput) } -func (p *PubSubMulticast) connect(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, direction ChannelDirection, blockWrite bool, replay bool) (error, error) { - client := NewClient(ID, rw, direction, blockWrite, replay) +func (p *PubSubMulticast) connect(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, direction ChannelDirection, blockWrite bool, replay, keepAlive bool) (error, error) { + client := NewClient(ID, rw, direction, blockWrite, replay, keepAlive) go func() { <-ctx.Done() @@ -47,15 +47,15 @@ func (p *PubSubMulticast) connect(ctx context.Context, ID string, rw io.ReadWrit } func (p *PubSubMulticast) Pipe(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, replay bool) (error, error) { - return p.connect(ctx, ID, rw, channels, ChannelDirectionInputOutput, false, replay) + return p.connect(ctx, ID, rw, channels, ChannelDirectionInputOutput, false, replay, false) } func (p *PubSubMulticast) Pub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel) error { - return errors.Join(p.connect(ctx, ID, rw, channels, ChannelDirectionInput, true, false)) + return errors.Join(p.connect(ctx, ID, rw, channels, ChannelDirectionInput, true, false, false)) } -func (p *PubSubMulticast) Sub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel) error { - return errors.Join(p.connect(ctx, ID, rw, channels, ChannelDirectionOutput, false, false)) +func (p *PubSubMulticast) Sub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, keepAlive bool) error { + return errors.Join(p.connect(ctx, ID, rw, channels, ChannelDirectionOutput, false, false, keepAlive)) } var _ PubSub = (*PubSubMulticast)(nil) diff --git a/multicast_test.go b/multicast_test.go index 3ddb3c0..dea572f 100644 --- a/multicast_test.go +++ b/multicast_test.go @@ -55,7 +55,7 @@ func TestMulticastSubBlock(t *testing.T) { go func() { orderActual += "sub-" syncer <- 0 - fmt.Println(cast.Sub(context.TODO(), "1", actual, []*Channel{channel})) + fmt.Println(cast.Sub(context.TODO(), "1", actual, []*Channel{channel}, false)) wg.Done() }() @@ -109,7 +109,7 @@ func TestMulticastPubBlock(t *testing.T) { go func() { orderActual += "sub-" wg.Done() - fmt.Println(cast.Sub(context.TODO(), "2", actual, []*Channel{channel})) + fmt.Println(cast.Sub(context.TODO(), "2", actual, []*Channel{channel}, false)) }() wg.Wait() @@ -146,7 +146,7 @@ func TestMulticastMultSubs(t *testing.T) { go func() { orderActual += "sub-" syncer <- 0 - fmt.Println(cast.Sub(context.TODO(), "1", actual, []*Channel{channel})) + fmt.Println(cast.Sub(context.TODO(), "1", actual, []*Channel{channel}, false)) wg.Done() }() @@ -155,7 +155,7 @@ func TestMulticastMultSubs(t *testing.T) { go func() { orderActual += "sub-" syncer <- 0 - fmt.Println(cast.Sub(context.TODO(), "2", actualOther, []*Channel{channel})) + fmt.Println(cast.Sub(context.TODO(), "2", actualOther, []*Channel{channel}, false)) wg.Done() }() diff --git a/pubsub.go b/pubsub.go index 8a587cf..54c8ea5 100644 --- a/pubsub.go +++ b/pubsub.go @@ -13,7 +13,7 @@ type PubSub interface { GetSubs() iter.Seq2[string, *Client] GetPipes() iter.Seq2[string, *Client] Pipe(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, replay bool) (error, error) - Sub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel) error + Sub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, keepAlive bool) error Pub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel) error }