Skip to content

Commit

Permalink
refactor: align naming with mqtt terminology
Browse files Browse the repository at this point in the history
  • Loading branch information
neurosnap committed Oct 3, 2024
1 parent 242caa6 commit c319ff1
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 70 deletions.
24 changes: 12 additions & 12 deletions connector.go → broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@ import (
"github.com/antoniomika/syncmap"
)

type Connector interface {
type Broker interface {
GetChannels() iter.Seq2[string, *Channel]
GetClients() iter.Seq2[string, *Client]
Connect(*Client, []*Channel) (error, error)
}

type BaseConnector struct {
type BaseBroker struct {
Channels *syncmap.Map[string, *Channel]
}

func (b *BaseConnector) Cleanup() {
func (b *BaseBroker) Cleanup() {
toRemove := []string{}
for _, channel := range b.GetChannels() {
count := 0
Expand All @@ -31,7 +31,7 @@ func (b *BaseConnector) Cleanup() {

if count == 0 {
channel.Cleanup()
toRemove = append(toRemove, channel.ID)
toRemove = append(toRemove, channel.Topic)
}
}

Expand All @@ -40,25 +40,25 @@ func (b *BaseConnector) Cleanup() {
}
}

func (b *BaseConnector) GetChannels() iter.Seq2[string, *Channel] {
func (b *BaseBroker) GetChannels() iter.Seq2[string, *Channel] {
return b.Channels.Range
}

func (b *BaseConnector) GetClients() iter.Seq2[string, *Client] {
func (b *BaseBroker) GetClients() iter.Seq2[string, *Client] {
return func(yield func(string, *Client) bool) {
for _, channel := range b.GetChannels() {
channel.Clients.Range(yield)
}
}
}

func (b *BaseConnector) Connect(client *Client, channels []*Channel) (error, error) {
func (b *BaseBroker) Connect(client *Client, channels []*Channel) (error, error) {
for _, channel := range channels {
dataChannel := b.ensureChannel(channel)
dataChannel.Clients.Store(client.ID, client)
client.Channels.Store(dataChannel.ID, dataChannel)
client.Channels.Store(dataChannel.Topic, dataChannel)
defer func() {
client.Channels.Delete(channel.ID)
client.Channels.Delete(channel.Topic)
dataChannel.Clients.Delete(client.ID)

client.Cleanup()
Expand Down Expand Up @@ -186,10 +186,10 @@ func (b *BaseConnector) Connect(client *Client, channels []*Channel) (error, err
return inputErr, outputErr
}

func (b *BaseConnector) ensureChannel(channel *Channel) *Channel {
dataChannel, _ := b.Channels.LoadOrStore(channel.ID, channel)
func (b *BaseBroker) ensureChannel(channel *Channel) *Channel {
dataChannel, _ := b.Channels.LoadOrStore(channel.Topic, channel)
dataChannel.Handle()
return dataChannel
}

var _ Connector = (*BaseConnector)(nil)
var _ Broker = (*BaseBroker)(nil)
6 changes: 3 additions & 3 deletions channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,17 @@ type ChannelMessage struct {
Action ChannelAction
}

func NewChannel(name string) *Channel {
func NewChannel(topic string) *Channel {
return &Channel{
ID: name,
Topic: topic,
Done: make(chan struct{}),
Data: make(chan ChannelMessage),
Clients: syncmap.New[string, *Client](),
}
}

type Channel struct {
ID string
Topic string
Done chan struct{}
Data chan ChannelMessage
Clients *syncmap.Map[string, *Client]
Expand Down
53 changes: 22 additions & 31 deletions cmd/authorized_keys/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"syscall"
"time"

"github.com/antoniomika/syncmap"
"github.com/charmbracelet/ssh"
"github.com/charmbracelet/wish"
"github.com/google/uuid"
Expand All @@ -26,66 +25,66 @@ func GetEnv(key string, defaultVal string) string {
return defaultVal
}

func PubSubMiddleware(cfg *pubsub.Cfg) wish.Middleware {
func PubSubMiddleware(broker pubsub.PubSub, logger *slog.Logger) wish.Middleware {
return func(next ssh.Handler) ssh.Handler {
return func(sesh ssh.Session) {
args := sesh.Command()
if len(args) < 2 {
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub|pipe) {channel}")
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub|pipe) {topic}")
next(sesh)
return
}

cmd := strings.TrimSpace(args[0])
channel := args[1]
topicsRaw := args[1]

channels := strings.Split(channel, ",")
topics := strings.Split(topicsRaw, ",")

logger := cfg.Logger.With(
logger := logger.With(
"cmd", cmd,
"channel", channels,
"topics", topics,
)

logger.Info("running cli")

if cmd == "help" {
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub|pipe) {channel}")
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub|pipe) {topic}")
} else if cmd == "sub" {
var chans []*pubsub.Channel

for _, c := range channels {
chans = append(chans, pubsub.NewChannel(c))
for _, topic := range topics {
chans = append(chans, pubsub.NewChannel(topic))
}

clientID := uuid.NewString()

err := errors.Join(cfg.PubSub.Sub(sesh.Context(), clientID, sesh, chans, args[len(args)-1] == "keepalive"))
err := errors.Join(broker.Sub(sesh.Context(), clientID, sesh, chans, args[len(args)-1] == "keepalive"))
if err != nil {
logger.Error("error during pub", slog.Any("error", err), slog.String("client", clientID))
}
} else if cmd == "pub" {
var chans []*pubsub.Channel

for _, c := range channels {
chans = append(chans, pubsub.NewChannel(c))
for _, topic := range topics {
chans = append(chans, pubsub.NewChannel(topic))
}

clientID := uuid.NewString()

err := errors.Join(cfg.PubSub.Pub(sesh.Context(), clientID, sesh, chans))
err := errors.Join(broker.Pub(sesh.Context(), clientID, sesh, chans))
if err != nil {
logger.Error("error during pub", slog.Any("error", err), slog.String("client", clientID))
}
} else if cmd == "pipe" {
var chans []*pubsub.Channel

for _, c := range channels {
chans = append(chans, pubsub.NewChannel(c))
for _, topics := range topics {
chans = append(chans, pubsub.NewChannel(topics))
}

clientID := uuid.NewString()

err := errors.Join(cfg.PubSub.Pipe(sesh.Context(), clientID, sesh, chans, args[len(args)-1] == "replay"))
err := errors.Join(broker.Pipe(sesh.Context(), clientID, sesh, chans, args[len(args)-1] == "replay"))
if err != nil {
logger.Error(
"pipe error",
Expand All @@ -94,7 +93,7 @@ func PubSubMiddleware(cfg *pubsub.Cfg) wish.Middleware {
)
}
} else {
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub|pipe) {channel}")
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub|pipe) {topic}")
}

next(sesh)
Expand All @@ -107,23 +106,15 @@ func main() {
host := GetEnv("SSH_HOST", "0.0.0.0")
port := GetEnv("SSH_PORT", "2222")
keyPath := GetEnv("SSH_AUTHORIZED_KEYS", "./ssh_data/authorized_keys")
cfg := &pubsub.Cfg{
Logger: logger,
PubSub: &pubsub.PubSubMulticast{
Logger: logger,
Connector: &pubsub.BaseConnector{
Channels: syncmap.New[string, *pubsub.Channel](),
},
},
}
broker := pubsub.NewMulticast(logger)

s, err := wish.NewServer(
ssh.NoPty(),
wish.WithAddress(fmt.Sprintf("%s:%s", host, port)),
wish.WithHostKeyPath("ssh_data/term_info_ed25519"),
wish.WithAuthorizedKeys(keyPath),
wish.WithMiddleware(
PubSubMiddleware(cfg),
PubSubMiddleware(broker, logger),
),
)
if err != nil {
Expand All @@ -149,10 +140,10 @@ func main() {
slog.Info("Debug Info", slog.Int("goroutines", runtime.NumGoroutine()))
select {
case <-time.After(5 * time.Second):
for _, channel := range cfg.PubSub.GetChannels() {
slog.Info("channel online", slog.Any("channel", channel.ID))
for _, channel := range broker.GetChannels() {
slog.Info("channel online", slog.Any("channel topic", channel.Topic))
for _, client := range channel.GetClients() {
slog.Info("client online", slog.Any("channel", channel.ID), slog.Any("client", client.ID), slog.String("direction", client.Direction.String()))
slog.Info("client online", slog.Any("channel topic", channel.Topic), slog.Any("client", client.ID), slog.String("direction", client.Direction.String()))
}
}
case <-done:
Expand Down
33 changes: 22 additions & 11 deletions multicast.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,25 @@ import (
"io"
"iter"
"log/slog"

"github.com/antoniomika/syncmap"
)

type PubSubMulticast struct {
Connector
type Multicast struct {
Broker
Logger *slog.Logger
}

func (p *PubSubMulticast) getClients(direction ChannelDirection) iter.Seq2[string, *Client] {
func NewMulticast(logger *slog.Logger) *Multicast {
return &Multicast{
Logger: logger,
Broker: &BaseBroker{
Channels: syncmap.New[string, *Channel](),
},
}
}

func (p *Multicast) getClients(direction ChannelDirection) iter.Seq2[string, *Client] {
return func(yield func(string, *Client) bool) {
for clientID, client := range p.GetClients() {
if client.Direction == direction {
Expand All @@ -23,19 +34,19 @@ func (p *PubSubMulticast) getClients(direction ChannelDirection) iter.Seq2[strin
}
}

func (p *PubSubMulticast) GetPipes() iter.Seq2[string, *Client] {
func (p *Multicast) GetPipes() iter.Seq2[string, *Client] {
return p.getClients(ChannelDirectionInputOutput)
}

func (p *PubSubMulticast) GetPubs() iter.Seq2[string, *Client] {
func (p *Multicast) GetPubs() iter.Seq2[string, *Client] {
return p.getClients(ChannelDirectionInput)
}

func (p *PubSubMulticast) GetSubs() iter.Seq2[string, *Client] {
func (p *Multicast) GetSubs() iter.Seq2[string, *Client] {
return p.getClients(ChannelDirectionOutput)
}

func (p *PubSubMulticast) connect(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, direction ChannelDirection, blockWrite bool, replay, keepAlive bool) (error, error) {
func (p *Multicast) connect(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, direction ChannelDirection, blockWrite bool, replay, keepAlive bool) (error, error) {
client := NewClient(ID, rw, direction, blockWrite, replay, keepAlive)

go func() {
Expand All @@ -46,16 +57,16 @@ func (p *PubSubMulticast) connect(ctx context.Context, ID string, rw io.ReadWrit
return p.Connect(client, channels)
}

func (p *PubSubMulticast) Pipe(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, replay bool) (error, error) {
func (p *Multicast) Pipe(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, replay bool) (error, error) {
return p.connect(ctx, ID, rw, channels, ChannelDirectionInputOutput, false, replay, false)
}

func (p *PubSubMulticast) Pub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel) error {
func (p *Multicast) Pub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel) error {
return errors.Join(p.connect(ctx, ID, rw, channels, ChannelDirectionInput, true, false, false))
}

func (p *PubSubMulticast) Sub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, keepAlive bool) error {
func (p *Multicast) Sub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, keepAlive bool) error {
return errors.Join(p.connect(ctx, ID, rw, channels, ChannelDirectionOutput, false, false, keepAlive))
}

var _ PubSub = (*PubSubMulticast)(nil)
var _ = (*Multicast)(nil)
12 changes: 6 additions & 6 deletions multicast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ func TestMulticastSubBlock(t *testing.T) {
name := "test-channel"
syncer := make(chan int)

cast := &PubSubMulticast{
cast := &Multicast{
Logger: slog.Default(),
Connector: &BaseConnector{
Broker: &BaseBroker{
Channels: syncmap.New[string, *Channel](),
},
}
Expand Down Expand Up @@ -85,9 +85,9 @@ func TestMulticastPubBlock(t *testing.T) {
name := "test-channel"
syncer := make(chan int)

cast := &PubSubMulticast{
cast := &Multicast{
Logger: slog.Default(),
Connector: &BaseConnector{
Broker: &BaseBroker{
Channels: syncmap.New[string, *Channel](),
},
}
Expand Down Expand Up @@ -131,9 +131,9 @@ func TestMulticastMultSubs(t *testing.T) {
name := "test-channel"
syncer := make(chan int)

cast := &PubSubMulticast{
cast := &Multicast{
Logger: slog.Default(),
Connector: &BaseConnector{
Broker: &BaseBroker{
Channels: syncmap.New[string, *Channel](),
},
}
Expand Down
8 changes: 1 addition & 7 deletions pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,14 @@ import (
"context"
"io"
"iter"
"log/slog"
)

type PubSub interface {
Connector
Broker
GetPubs() iter.Seq2[string, *Client]
GetSubs() iter.Seq2[string, *Client]
GetPipes() iter.Seq2[string, *Client]
Pipe(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, replay bool) (error, error)
Sub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, keepAlive bool) error
Pub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel) error
}

type Cfg struct {
Logger *slog.Logger
PubSub PubSub
}

0 comments on commit c319ff1

Please sign in to comment.