Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WAMP session acceptor #6

Merged
merged 3 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions acceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
package xconn

import (
"bytes"
"fmt"
"net"
"sync"

"github.com/gobwas/ws"
"golang.org/x/exp/maps"

"github.com/xconnio/wampproto-go"
"github.com/xconnio/wampproto-go/auth"
"github.com/xconnio/wampproto-go/messages"
"github.com/xconnio/wampproto-go/serializers"
)

var compiledWSProtocols = [][]byte{ //nolint:gochecknoglobals
[]byte(JsonWebsocketProtocol),
[]byte(MsgpackWebsocketProtocol),
[]byte(CborWebsocketProtocol),
}

var serializersByWSSubProtocol = map[string]serializers.Serializer{ //nolint:gochecknoglobals
JsonWebsocketProtocol: &serializers.JSONSerializer{},
MsgpackWebsocketProtocol: &serializers.MsgPackSerializer{},
CborWebsocketProtocol: &serializers.CBORSerializer{},
}

type WebSocketAcceptor struct {
specs map[string]serializers.Serializer
once sync.Once

Authenticator auth.ServerAuthenticator
}

func (w *WebSocketAcceptor) init() {
if w.specs == nil {
w.specs = serializersByWSSubProtocol
}
}

func (w *WebSocketAcceptor) protocols() []string {
w.once.Do(w.init)

return maps.Keys(w.specs)
}

func (w *WebSocketAcceptor) RegisterSpec(subProtocol string, serializer serializers.Serializer) error {
w.once.Do(w.init)

_, exists := w.specs[subProtocol]
if exists {
return fmt.Errorf("spec for %s is alraedy registered", subProtocol)
}

w.specs[subProtocol] = serializer
return nil
}

func (w *WebSocketAcceptor) Spec(subProtocol string) (serializers.Serializer, error) {
w.once.Do(w.init)

serializer, exists := w.specs[subProtocol]
if !exists {
return nil, fmt.Errorf("spec for %s is not registered", subProtocol)
}

return serializer, nil
}

func (w *WebSocketAcceptor) Accept(conn net.Conn) (BaseSession, error) {
config := DefaultWebSocketServerConfig()
config.SubProtocols = w.protocols()
peer, err := UpgradeWebSocket(conn, config)
if err != nil {
return nil, fmt.Errorf("failed to init reader/writer: %w", err)
}

wsPeer := peer.(*WebSocketPeer)
serializer, err := w.Spec(wsPeer.Protocol())
if err != nil {
return nil, fmt.Errorf("unknown subprotocol: %w", err)
}

hello, err := ReadHello(peer, serializer)
if err != nil {
return nil, fmt.Errorf("")
}

return Accept(peer, hello, serializer, w.Authenticator)
}

func Accept(peer Peer, hello *messages.Hello, serializer serializers.Serializer,
authenticator auth.ServerAuthenticator) (BaseSession, error) {

a := wampproto.NewAcceptor(serializer, authenticator)
toSend, err := a.ReceiveMessage(hello)
if err != nil {
return nil, err
}

if err = WriteMessage(peer, toSend, serializer); err != nil {
return nil, err
}

if toSend.Type() == messages.MessageTypeWelcome {
goto Welcomed
}

for {
payload, err := peer.Read()
if err != nil {
return nil, err
}

toSend, welcomed, err := a.Receive(payload)
if err != nil {
return nil, err
}

if err = peer.Write(toSend); err != nil {
return nil, err
}

if welcomed {
goto Welcomed
}
}

Welcomed:
d, _ := a.SessionDetails()
details := NewBaseSession(d.ID(), d.Realm(), d.AuthID(), d.AuthRole(), peer)
return details, nil
}

func UpgradeWebSocket(conn net.Conn, config *WebSocketServerConfig) (Peer, error) {
wsUpgrader := ws.Upgrader{
Protocol: func(protoBytes []byte) bool {
if config == nil {
for _, protocol := range compiledWSProtocols {
if bytes.Equal(protoBytes, protocol) {
return true
}
}
} else {
for _, protocol := range config.SubProtocols {
if bytes.Equal(protoBytes, []byte(protocol)) {
return true
}
}
}

return false
},
}

hs, err := wsUpgrader.Upgrade(conn)
if err != nil {
return nil, fmt.Errorf("failed to upgrade to websocket: %w", err)
}

isBinary := hs.Protocol != JsonWebsocketProtocol
peer, err := NewWebSocketPeer(conn, hs.Protocol, isBinary, true)
if err != nil {
return nil, fmt.Errorf("failed to init reader/writer: %w", err)
}

return peer, nil
}
40 changes: 40 additions & 0 deletions acceptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package xconn_test

import (
"context"
"fmt"
"net"
"testing"

"github.com/gammazero/nexus/v3/client"
"github.com/stretchr/testify/require"

"github.com/xconnio/xconn-go"
)

func TestAccept(t *testing.T) {
listener, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
require.NotNil(t, listener)

accepted := make(chan xconn.BaseSession, 1)

go func() {
conn, err := listener.Accept()
require.NoError(t, err)
require.NotNil(t, conn)

acceptor := xconn.WebSocketAcceptor{}
session, err := acceptor.Accept(conn)
require.NoError(t, err)
require.NotNil(t, session)

accepted <- session
}()

wsURL := fmt.Sprintf("ws://%s/ws", listener.Addr().String())
config := client.Config{Realm: "realm1"}
cl, err := client.ConnectNet(context.Background(), wsURL, config)
require.NoError(t, err)
require.NotNil(t, cl)
}
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ require (
github.com/gammazero/workerpool v1.1.3
github.com/gobwas/ws v1.4.0
github.com/stretchr/testify v1.8.4
github.com/xconnio/wampproto-go v0.0.0-20240530202948-a758eb534226
github.com/xconnio/wampproto-go v0.0.0-20240531231532-d8fa7f588c4e
golang.org/x/exp v0.0.0-20240525044651-4c93da0ed11d
)

require (
Expand All @@ -23,7 +24,6 @@ require (
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
github.com/x448/float16 v0.8.4 // indirect
golang.org/x/crypto v0.23.0 // indirect
golang.org/x/exp v0.0.0-20240525044651-4c93da0ed11d // indirect
golang.org/x/sys v0.20.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAh
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xconnio/wampproto-go v0.0.0-20240530202948-a758eb534226 h1:1UFs+1ev6G1qDVgf5tGqnhsm5e9btuon41ogyrR1QD4=
github.com/xconnio/wampproto-go v0.0.0-20240530202948-a758eb534226/go.mod h1:BH0AFRLJ9POvVfxsFd9GyvA15U9o0XYQfq8TdkqO2vQ=
github.com/xconnio/wampproto-go v0.0.0-20240531231532-d8fa7f588c4e h1:15wgqkrASYTouf37nDshH9TjTSDNjB6EOfPSytHq9kg=
github.com/xconnio/wampproto-go v0.0.0-20240531231532-d8fa7f588c4e/go.mod h1:BH0AFRLJ9POvVfxsFd9GyvA15U9o0XYQfq8TdkqO2vQ=
go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
Expand Down
50 changes: 49 additions & 1 deletion helpers.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
package xconn

import "github.com/gobwas/ws/wsutil"
import (
"fmt"

"github.com/gobwas/ws/wsutil"

"github.com/xconnio/wampproto-go/messages"
"github.com/xconnio/wampproto-go/serializers"
)

func ClientSideWSReaderWriter(binary bool) (ReaderFunc, WriterFunc, error) {
if !binary {
Expand All @@ -17,3 +24,44 @@ func ServerSideWSReaderWriter(binary bool) (ReaderFunc, WriterFunc, error) {

return wsutil.ReadClientBinary, wsutil.WriteServerBinary, nil
}

func ReadMessage(peer Peer, serializer serializers.Serializer) (messages.Message, error) {
payload, err := peer.Read()
if err != nil {
return nil, err
}

msg, err := serializer.Deserialize(payload)
if err != nil {
return nil, err
}

return msg, nil
}

func ReadHello(peer Peer, serializer serializers.Serializer) (*messages.Hello, error) {
msg, err := ReadMessage(peer, serializer)
if err != nil {
return nil, err
}

if msg.Type() != messages.MessageTypeHello {
return nil, fmt.Errorf("first message must be HELLO, but was %d", msg.Type())
}

hello := msg.(*messages.Hello)
return hello, nil
}

func WriteMessage(peer Peer, message messages.Message, serializer serializers.Serializer) error {
payload, err := serializer.Serialize(message)
if err != nil {
return fmt.Errorf("failed to serialize message: %w", err)
}

if err = peer.Write(payload); err != nil {
return fmt.Errorf("failed to write message: %w", err)
}

return nil
}
10 changes: 10 additions & 0 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@ type WebSocketServerConfig struct {
SubProtocols []string
}

func DefaultWebSocketServerConfig() *WebSocketServerConfig {
return &WebSocketServerConfig{
SubProtocols: []string{
JsonWebsocketProtocol,
MsgpackWebsocketProtocol,
CborWebsocketProtocol,
},
}
}

type WSSerializerSpec interface {
SubProtocol() string
Serializer() serializers.Serializer
Expand Down
Loading