Skip to content

Commit

Permalink
Add keepalive setting for subs
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniomika committed Oct 2, 2024
1 parent 3ff0e57 commit 242caa6
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 14 deletions.
4 changes: 3 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/antoniomika/syncmap"
)

func NewClient(ID string, rw io.ReadWriter, direction ChannelDirection, blockWrite, replay bool) *Client {
func NewClient(ID string, rw io.ReadWriter, direction ChannelDirection, blockWrite, replay, keepAlive bool) *Client {
return &Client{
ID: ID,
ReadWriter: rw,
Expand All @@ -18,6 +18,7 @@ func NewClient(ID string, rw io.ReadWriter, direction ChannelDirection, blockWri
Data: make(chan ChannelMessage),
Replay: replay,
BlockWrite: blockWrite,
KeepAlive: keepAlive,
}
}

Expand All @@ -30,6 +31,7 @@ type Client struct {
Data chan ChannelMessage
Replay bool
BlockWrite bool
KeepAlive bool
once sync.Once
onceData sync.Once
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/authorized_keys/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func PubSubMiddleware(cfg *pubsub.Cfg) wish.Middleware {

clientID := uuid.NewString()

err := errors.Join(cfg.PubSub.Sub(sesh.Context(), clientID, sesh, chans))
err := errors.Join(cfg.PubSub.Sub(sesh.Context(), clientID, sesh, chans, args[len(args)-1] == "keepalive"))
if err != nil {
logger.Error("error during pub", slog.Any("error", err), slog.String("client", clientID))
}
Expand Down
4 changes: 3 additions & 1 deletion connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ func (b *BaseConnector) Connect(client *Client, channels []*Channel) (error, err

if count == 0 {
for _, cl := range dataChannel.GetClients() {
cl.Cleanup()
if !cl.KeepAlive {
cl.Cleanup()
}
}
}

Expand Down
12 changes: 6 additions & 6 deletions multicast.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ func (p *PubSubMulticast) GetSubs() iter.Seq2[string, *Client] {
return p.getClients(ChannelDirectionOutput)
}

func (p *PubSubMulticast) connect(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, direction ChannelDirection, blockWrite bool, replay bool) (error, error) {
client := NewClient(ID, rw, direction, blockWrite, replay)
func (p *PubSubMulticast) connect(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, direction ChannelDirection, blockWrite bool, replay, keepAlive bool) (error, error) {
client := NewClient(ID, rw, direction, blockWrite, replay, keepAlive)

go func() {
<-ctx.Done()
Expand All @@ -47,15 +47,15 @@ func (p *PubSubMulticast) connect(ctx context.Context, ID string, rw io.ReadWrit
}

func (p *PubSubMulticast) Pipe(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, replay bool) (error, error) {
return p.connect(ctx, ID, rw, channels, ChannelDirectionInputOutput, false, replay)
return p.connect(ctx, ID, rw, channels, ChannelDirectionInputOutput, false, replay, false)
}

func (p *PubSubMulticast) Pub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel) error {
return errors.Join(p.connect(ctx, ID, rw, channels, ChannelDirectionInput, true, false))
return errors.Join(p.connect(ctx, ID, rw, channels, ChannelDirectionInput, true, false, false))
}

func (p *PubSubMulticast) Sub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel) error {
return errors.Join(p.connect(ctx, ID, rw, channels, ChannelDirectionOutput, false, false))
func (p *PubSubMulticast) Sub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, keepAlive bool) error {
return errors.Join(p.connect(ctx, ID, rw, channels, ChannelDirectionOutput, false, false, keepAlive))
}

var _ PubSub = (*PubSubMulticast)(nil)
8 changes: 4 additions & 4 deletions multicast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func TestMulticastSubBlock(t *testing.T) {
go func() {
orderActual += "sub-"
syncer <- 0
fmt.Println(cast.Sub(context.TODO(), "1", actual, []*Channel{channel}))
fmt.Println(cast.Sub(context.TODO(), "1", actual, []*Channel{channel}, false))
wg.Done()
}()

Expand Down Expand Up @@ -109,7 +109,7 @@ func TestMulticastPubBlock(t *testing.T) {
go func() {
orderActual += "sub-"
wg.Done()
fmt.Println(cast.Sub(context.TODO(), "2", actual, []*Channel{channel}))
fmt.Println(cast.Sub(context.TODO(), "2", actual, []*Channel{channel}, false))
}()

wg.Wait()
Expand Down Expand Up @@ -146,7 +146,7 @@ func TestMulticastMultSubs(t *testing.T) {
go func() {
orderActual += "sub-"
syncer <- 0
fmt.Println(cast.Sub(context.TODO(), "1", actual, []*Channel{channel}))
fmt.Println(cast.Sub(context.TODO(), "1", actual, []*Channel{channel}, false))
wg.Done()
}()

Expand All @@ -155,7 +155,7 @@ func TestMulticastMultSubs(t *testing.T) {
go func() {
orderActual += "sub-"
syncer <- 0
fmt.Println(cast.Sub(context.TODO(), "2", actualOther, []*Channel{channel}))
fmt.Println(cast.Sub(context.TODO(), "2", actualOther, []*Channel{channel}, false))
wg.Done()
}()

Expand Down
2 changes: 1 addition & 1 deletion pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type PubSub interface {
GetSubs() iter.Seq2[string, *Client]
GetPipes() iter.Seq2[string, *Client]
Pipe(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, replay bool) (error, error)
Sub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel) error
Sub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, keepAlive bool) error
Pub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel) error
}

Expand Down

0 comments on commit 242caa6

Please sign in to comment.