Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: align naming with mqtt terminology #3

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions connector.go → broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@ import (
"github.com/antoniomika/syncmap"
)

type Connector interface {
type Broker interface {
GetChannels() iter.Seq2[string, *Channel]
GetClients() iter.Seq2[string, *Client]
Connect(*Client, []*Channel) (error, error)
}

type BaseConnector struct {
type BaseBroker struct {
Channels *syncmap.Map[string, *Channel]
}

func (b *BaseConnector) Cleanup() {
func (b *BaseBroker) Cleanup() {
toRemove := []string{}
for _, channel := range b.GetChannels() {
count := 0
Expand All @@ -31,7 +31,7 @@ func (b *BaseConnector) Cleanup() {

if count == 0 {
channel.Cleanup()
toRemove = append(toRemove, channel.ID)
toRemove = append(toRemove, channel.Topic)
}
}

Expand All @@ -40,25 +40,25 @@ func (b *BaseConnector) Cleanup() {
}
}

func (b *BaseConnector) GetChannels() iter.Seq2[string, *Channel] {
func (b *BaseBroker) GetChannels() iter.Seq2[string, *Channel] {
return b.Channels.Range
}

func (b *BaseConnector) GetClients() iter.Seq2[string, *Client] {
func (b *BaseBroker) GetClients() iter.Seq2[string, *Client] {
return func(yield func(string, *Client) bool) {
for _, channel := range b.GetChannels() {
channel.Clients.Range(yield)
}
}
}

func (b *BaseConnector) Connect(client *Client, channels []*Channel) (error, error) {
func (b *BaseBroker) Connect(client *Client, channels []*Channel) (error, error) {
for _, channel := range channels {
dataChannel := b.ensureChannel(channel)
dataChannel.Clients.Store(client.ID, client)
client.Channels.Store(dataChannel.ID, dataChannel)
client.Channels.Store(dataChannel.Topic, dataChannel)
defer func() {
client.Channels.Delete(channel.ID)
client.Channels.Delete(channel.Topic)
dataChannel.Clients.Delete(client.ID)

client.Cleanup()
Expand Down Expand Up @@ -186,10 +186,10 @@ func (b *BaseConnector) Connect(client *Client, channels []*Channel) (error, err
return inputErr, outputErr
}

func (b *BaseConnector) ensureChannel(channel *Channel) *Channel {
dataChannel, _ := b.Channels.LoadOrStore(channel.ID, channel)
func (b *BaseBroker) ensureChannel(channel *Channel) *Channel {
dataChannel, _ := b.Channels.LoadOrStore(channel.Topic, channel)
dataChannel.Handle()
return dataChannel
}

var _ Connector = (*BaseConnector)(nil)
var _ Broker = (*BaseBroker)(nil)
6 changes: 3 additions & 3 deletions channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,17 @@ type ChannelMessage struct {
Action ChannelAction
}

func NewChannel(name string) *Channel {
func NewChannel(topic string) *Channel {
return &Channel{
ID: name,
Topic: topic,
Done: make(chan struct{}),
Data: make(chan ChannelMessage),
Clients: syncmap.New[string, *Client](),
}
}

type Channel struct {
ID string
Topic string
Done chan struct{}
Data chan ChannelMessage
Clients *syncmap.Map[string, *Client]
Expand Down
53 changes: 22 additions & 31 deletions cmd/authorized_keys/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"syscall"
"time"

"github.com/antoniomika/syncmap"
"github.com/charmbracelet/ssh"
"github.com/charmbracelet/wish"
"github.com/google/uuid"
Expand All @@ -26,66 +25,66 @@ func GetEnv(key string, defaultVal string) string {
return defaultVal
}

func PubSubMiddleware(cfg *pubsub.Cfg) wish.Middleware {
func PubSubMiddleware(broker pubsub.PubSub, logger *slog.Logger) wish.Middleware {
return func(next ssh.Handler) ssh.Handler {
return func(sesh ssh.Session) {
args := sesh.Command()
if len(args) < 2 {
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub|pipe) {channel}")
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub|pipe) {topic}")
next(sesh)
return
}

cmd := strings.TrimSpace(args[0])
channel := args[1]
topicsRaw := args[1]

channels := strings.Split(channel, ",")
topics := strings.Split(topicsRaw, ",")

logger := cfg.Logger.With(
logger := logger.With(
"cmd", cmd,
"channel", channels,
"topics", topics,
)

logger.Info("running cli")

if cmd == "help" {
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub|pipe) {channel}")
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub|pipe) {topic}")
} else if cmd == "sub" {
var chans []*pubsub.Channel

for _, c := range channels {
chans = append(chans, pubsub.NewChannel(c))
for _, topic := range topics {
chans = append(chans, pubsub.NewChannel(topic))
}

clientID := uuid.NewString()

err := errors.Join(cfg.PubSub.Sub(sesh.Context(), clientID, sesh, chans, args[len(args)-1] == "keepalive"))
err := errors.Join(broker.Sub(sesh.Context(), clientID, sesh, chans, args[len(args)-1] == "keepalive"))
if err != nil {
logger.Error("error during pub", slog.Any("error", err), slog.String("client", clientID))
}
} else if cmd == "pub" {
var chans []*pubsub.Channel

for _, c := range channels {
chans = append(chans, pubsub.NewChannel(c))
for _, topic := range topics {
chans = append(chans, pubsub.NewChannel(topic))
}

clientID := uuid.NewString()

err := errors.Join(cfg.PubSub.Pub(sesh.Context(), clientID, sesh, chans))
err := errors.Join(broker.Pub(sesh.Context(), clientID, sesh, chans))
if err != nil {
logger.Error("error during pub", slog.Any("error", err), slog.String("client", clientID))
}
} else if cmd == "pipe" {
var chans []*pubsub.Channel

for _, c := range channels {
chans = append(chans, pubsub.NewChannel(c))
for _, topics := range topics {
chans = append(chans, pubsub.NewChannel(topics))
}

clientID := uuid.NewString()

err := errors.Join(cfg.PubSub.Pipe(sesh.Context(), clientID, sesh, chans, args[len(args)-1] == "replay"))
err := errors.Join(broker.Pipe(sesh.Context(), clientID, sesh, chans, args[len(args)-1] == "replay"))
if err != nil {
logger.Error(
"pipe error",
Expand All @@ -94,7 +93,7 @@ func PubSubMiddleware(cfg *pubsub.Cfg) wish.Middleware {
)
}
} else {
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub|pipe) {channel}")
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub|pipe) {topic}")
}

next(sesh)
Expand All @@ -107,23 +106,15 @@ func main() {
host := GetEnv("SSH_HOST", "0.0.0.0")
port := GetEnv("SSH_PORT", "2222")
keyPath := GetEnv("SSH_AUTHORIZED_KEYS", "./ssh_data/authorized_keys")
cfg := &pubsub.Cfg{
Logger: logger,
PubSub: &pubsub.PubSubMulticast{
Logger: logger,
Connector: &pubsub.BaseConnector{
Channels: syncmap.New[string, *pubsub.Channel](),
},
},
}
broker := pubsub.NewMulticast(logger)

s, err := wish.NewServer(
ssh.NoPty(),
wish.WithAddress(fmt.Sprintf("%s:%s", host, port)),
wish.WithHostKeyPath("ssh_data/term_info_ed25519"),
wish.WithAuthorizedKeys(keyPath),
wish.WithMiddleware(
PubSubMiddleware(cfg),
PubSubMiddleware(broker, logger),
),
)
if err != nil {
Expand All @@ -149,10 +140,10 @@ func main() {
slog.Info("Debug Info", slog.Int("goroutines", runtime.NumGoroutine()))
select {
case <-time.After(5 * time.Second):
for _, channel := range cfg.PubSub.GetChannels() {
slog.Info("channel online", slog.Any("channel", channel.ID))
for _, channel := range broker.GetChannels() {
slog.Info("channel online", slog.Any("channel topic", channel.Topic))
for _, client := range channel.GetClients() {
slog.Info("client online", slog.Any("channel", channel.ID), slog.Any("client", client.ID), slog.String("direction", client.Direction.String()))
slog.Info("client online", slog.Any("channel topic", channel.Topic), slog.Any("client", client.ID), slog.String("direction", client.Direction.String()))
}
}
case <-done:
Expand Down
33 changes: 22 additions & 11 deletions multicast.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,25 @@ import (
"io"
"iter"
"log/slog"

"github.com/antoniomika/syncmap"
)

type PubSubMulticast struct {
Connector
type Multicast struct {
Broker
Logger *slog.Logger
}

func (p *PubSubMulticast) getClients(direction ChannelDirection) iter.Seq2[string, *Client] {
func NewMulticast(logger *slog.Logger) *Multicast {
return &Multicast{
Logger: logger,
Broker: &BaseBroker{
Channels: syncmap.New[string, *Channel](),
},
}
}

func (p *Multicast) getClients(direction ChannelDirection) iter.Seq2[string, *Client] {
return func(yield func(string, *Client) bool) {
for clientID, client := range p.GetClients() {
if client.Direction == direction {
Expand All @@ -23,19 +34,19 @@ func (p *PubSubMulticast) getClients(direction ChannelDirection) iter.Seq2[strin
}
}

func (p *PubSubMulticast) GetPipes() iter.Seq2[string, *Client] {
func (p *Multicast) GetPipes() iter.Seq2[string, *Client] {
return p.getClients(ChannelDirectionInputOutput)
}

func (p *PubSubMulticast) GetPubs() iter.Seq2[string, *Client] {
func (p *Multicast) GetPubs() iter.Seq2[string, *Client] {
return p.getClients(ChannelDirectionInput)
}

func (p *PubSubMulticast) GetSubs() iter.Seq2[string, *Client] {
func (p *Multicast) GetSubs() iter.Seq2[string, *Client] {
return p.getClients(ChannelDirectionOutput)
}

func (p *PubSubMulticast) connect(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, direction ChannelDirection, blockWrite bool, replay, keepAlive bool) (error, error) {
func (p *Multicast) connect(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, direction ChannelDirection, blockWrite bool, replay, keepAlive bool) (error, error) {
client := NewClient(ID, rw, direction, blockWrite, replay, keepAlive)

go func() {
Expand All @@ -46,16 +57,16 @@ func (p *PubSubMulticast) connect(ctx context.Context, ID string, rw io.ReadWrit
return p.Connect(client, channels)
}

func (p *PubSubMulticast) Pipe(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, replay bool) (error, error) {
func (p *Multicast) Pipe(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, replay bool) (error, error) {
return p.connect(ctx, ID, rw, channels, ChannelDirectionInputOutput, false, replay, false)
}

func (p *PubSubMulticast) Pub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel) error {
func (p *Multicast) Pub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel) error {
return errors.Join(p.connect(ctx, ID, rw, channels, ChannelDirectionInput, true, false, false))
}

func (p *PubSubMulticast) Sub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, keepAlive bool) error {
func (p *Multicast) Sub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, keepAlive bool) error {
return errors.Join(p.connect(ctx, ID, rw, channels, ChannelDirectionOutput, false, false, keepAlive))
}

var _ PubSub = (*PubSubMulticast)(nil)
var _ = (*Multicast)(nil)
12 changes: 6 additions & 6 deletions multicast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ func TestMulticastSubBlock(t *testing.T) {
name := "test-channel"
syncer := make(chan int)

cast := &PubSubMulticast{
cast := &Multicast{
Logger: slog.Default(),
Connector: &BaseConnector{
Broker: &BaseBroker{
Channels: syncmap.New[string, *Channel](),
},
}
Expand Down Expand Up @@ -85,9 +85,9 @@ func TestMulticastPubBlock(t *testing.T) {
name := "test-channel"
syncer := make(chan int)

cast := &PubSubMulticast{
cast := &Multicast{
Logger: slog.Default(),
Connector: &BaseConnector{
Broker: &BaseBroker{
Channels: syncmap.New[string, *Channel](),
},
}
Expand Down Expand Up @@ -131,9 +131,9 @@ func TestMulticastMultSubs(t *testing.T) {
name := "test-channel"
syncer := make(chan int)

cast := &PubSubMulticast{
cast := &Multicast{
Logger: slog.Default(),
Connector: &BaseConnector{
Broker: &BaseBroker{
Channels: syncmap.New[string, *Channel](),
},
}
Expand Down
8 changes: 1 addition & 7 deletions pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,14 @@ import (
"context"
"io"
"iter"
"log/slog"
)

type PubSub interface {
Connector
Broker
GetPubs() iter.Seq2[string, *Client]
GetSubs() iter.Seq2[string, *Client]
GetPipes() iter.Seq2[string, *Client]
Pipe(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, replay bool) (error, error)
Sub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, keepAlive bool) error
Pub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel) error
}

type Cfg struct {
Logger *slog.Logger
PubSub PubSub
}
Loading