diff --git a/acceptor.go b/acceptor.go index 98db720..be3ae1a 100644 --- a/acceptor.go +++ b/acceptor.go @@ -69,7 +69,7 @@ func (w *WebSocketAcceptor) Spec(subProtocol string) (serializers.Serializer, er return serializer, nil } -func (w *WebSocketAcceptor) Accept(conn net.Conn, config *WebSocketServerConfig) (BaseSession, error) { +func (w *WebSocketAcceptor) Accept(conn net.Conn, router *Router, config *WebSocketServerConfig) (BaseSession, error) { if config == nil { config = DefaultWebSocketServerConfig() } @@ -90,6 +90,19 @@ func (w *WebSocketAcceptor) Accept(conn net.Conn, config *WebSocketServerConfig) return nil, fmt.Errorf("") } + if !router.HasRealm(hello.Realm()) { + abortMessage := messages.NewAbort(map[string]any{}, wampproto.ErrNoSuchRealm, nil, nil) + serializedAbort, err := serializer.Serialize(abortMessage) + if err != nil { + return nil, fmt.Errorf("failed to serialize abort: %w", err) + } + if err = peer.Write(serializedAbort); err != nil { + return nil, fmt.Errorf("failed to send abort: %w", err) + } + + return nil, fmt.Errorf(wampproto.ErrNoSuchRealm) + } + return Accept(peer, hello, serializer, w.Authenticator) } diff --git a/acceptor_test.go b/acceptor_test.go index 68a7b52..7e8011a 100644 --- a/acceptor_test.go +++ b/acceptor_test.go @@ -24,8 +24,11 @@ func TestAccept(t *testing.T) { require.NoError(t, err) require.NotNil(t, conn) + rout := xconn.NewRouter() + rout.AddRealm("realm1") + acceptor := xconn.WebSocketAcceptor{} - session, err := acceptor.Accept(conn, nil) + session, err := acceptor.Accept(conn, rout, nil) require.NoError(t, err) require.NotNil(t, session) diff --git a/server.go b/server.go index c848cd3..ded9707 100644 --- a/server.go +++ b/server.go @@ -64,7 +64,7 @@ func (s *Server) HandleClient(conn net.Conn) { config := DefaultWebSocketServerConfig() config.KeepAliveInterval = s.keepAliveInterval config.KeepAliveTimeout = s.keepAliveTimeout - base, err := s.acceptor.Accept(conn, config) + base, err := s.acceptor.Accept(conn, s.router, config) if err != nil { return }