diff --git a/accepter_test.go b/accepter_test.go deleted file mode 100644 index 54bfff845..000000000 --- a/accepter_test.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright (c) quickfixengine.org All rights reserved. -// -// This file may be distributed under the terms of the quickfixengine.org -// license as defined by quickfixengine.org and appearing in the file -// LICENSE included in the packaging of this file. -// -// This file is provided AS IS with NO WARRANTY OF ANY KIND, INCLUDING -// THE WARRANTY OF DESIGN, MERCHANTABILITY AND FITNESS FOR A -// PARTICULAR PURPOSE. -// -// See http://www.quickfixengine.org/LICENSE for licensing information. -// -// Contact ask@quickfixengine.org if any conditions of this licensing -// are not clear to you. - -package quickfix - -import ( - "net" - "testing" - - "github.com/quickfixgo/quickfix/config" - - proxyproto "github.com/pires/go-proxyproto" - "github.com/stretchr/testify/assert" -) - -func TestAcceptor_Start(t *testing.T) { - sessionSettings := NewSessionSettings() - sessionSettings.Set(config.BeginString, BeginStringFIX42) - sessionSettings.Set(config.SenderCompID, "sender") - sessionSettings.Set(config.TargetCompID, "target") - - settingsWithTCPProxy := NewSettings() - settingsWithTCPProxy.GlobalSettings().Set("UseTCPProxy", "Y") - - settingsWithNoTCPProxy := NewSettings() - settingsWithNoTCPProxy.GlobalSettings().Set("UseTCPProxy", "N") - - genericSettings := NewSettings() - - const ( - GenericListener = iota - ProxyListener - ) - - acceptorStartTests := []struct { - name string - settings *Settings - listenerType int - }{ - {"with TCP proxy set", settingsWithTCPProxy, ProxyListener}, - {"with no TCP proxy set", settingsWithNoTCPProxy, GenericListener}, - {"no TCP proxy configuration set", genericSettings, GenericListener}, - } - - for _, tt := range acceptorStartTests { - t.Run(tt.name, func(t *testing.T) { - tt.settings.GlobalSettings().Set("SocketAcceptPort", "5001") - if _, err := tt.settings.AddSession(sessionSettings); err != nil { - assert.Nil(t, err) - } - - acceptor := &Acceptor{settings: tt.settings} - if err := acceptor.Start(); err != nil { - assert.NotNil(t, err) - } - assert.Len(t, acceptor.listeners, 1) - - for _, listener := range acceptor.listeners { - if tt.listenerType == ProxyListener { - _, ok := listener.(*proxyproto.Listener) - assert.True(t, ok) - } - - if tt.listenerType == GenericListener { - _, ok := listener.(*net.TCPListener) - assert.True(t, ok) - } - } - - acceptor.Stop() - }) - } -} diff --git a/acceptor.go b/acceptor.go index f5b9b281c..c48ff5f46 100644 --- a/acceptor.go +++ b/acceptor.go @@ -38,6 +38,7 @@ type Acceptor struct { storeFactory MessageStoreFactory globalLog Log sessions map[SessionID]*session + sessionsLock sync.RWMutex sessionGroup sync.WaitGroup listenerShutdown sync.WaitGroup dynamicSessions bool @@ -48,6 +49,7 @@ type Acceptor struct { sessionHostPort map[SessionID]int listeners map[string]net.Listener connectionValidator ConnectionValidator + sessionProvider AcceptorSessionProvider sessionFactory } @@ -104,14 +106,8 @@ func (a *Acceptor) Start() (err error) { a.listeners[address] = &proxyproto.Listener{Listener: a.listeners[address]} } } + a.startSessions() - for _, s := range a.sessions { - a.sessionGroup.Add(1) - go func(s *session) { - s.run() - a.sessionGroup.Done() - }(s) - } if a.dynamicSessions { a.dynamicSessionChan = make(chan *session) a.sessionGroup.Add(1) @@ -140,17 +136,7 @@ func (a *Acceptor) Stop() { if a.dynamicSessions { close(a.dynamicSessionChan) } - for _, session := range a.sessions { - session.stop() - } - a.sessionGroup.Wait() - - for sessionID := range a.sessions { - err := UnregisterSession(sessionID) - if err != nil { - return - } - } + a.stopSessions() } // RemoteAddr gets remote IP address for a given session. @@ -191,6 +177,15 @@ func NewAcceptor(app Application, storeFactory MessageStoreFactory, settings *Se } for sessionID, sessionSettings := range settings.SessionSettings() { + if sessionSettings.HasSetting(config.AcceptorTemplate) { + var acceptorTemplate bool + if acceptorTemplate, err = sessionSettings.BoolSetting(config.AcceptorTemplate); err != nil { + return + } + if acceptorTemplate { + continue + } + } sessID := sessionID sessID.Qualifier = "" @@ -331,18 +326,35 @@ func (a *Acceptor) handleConnection(netConn net.Conn) { } session, ok := a.sessions[sessID] if !ok { - if !a.dynamicSessions { - a.globalLog.OnEventf("Session %v not found for incoming message: %s", sessID, msgBytes) - return + var dynamicSessionCreated bool + if a.sessionProvider != nil { + session, err = a.sessionProvider.GetSession(sessID) + if err != nil { + if err == errUnknownSession && a.dynamicSessions { + goto CREATE_SHORT_LIVED_DYNAMIC_SESSION + } + a.globalLog.OnEventf("Failed to get session %v from provider: %v", sessID, err) + return + } + a.addMngdDynamicSession(sessID, session) + dynamicSessionCreated = true } - dynamicSession, err := a.sessionFactory.createSession(sessID, a.storeFactory, a.settings.globalSettings.clone(), a.logFactory, a.app) - if err != nil { - a.globalLog.OnEventf("Dynamic session %v failed to create: %v", sessID, err) - return + CREATE_SHORT_LIVED_DYNAMIC_SESSION: + if !dynamicSessionCreated { + if !a.dynamicSessions { + a.globalLog.OnEventf("Session %v not found for incoming message: %s", sessID, msgBytes) + return + } + dynamicSession, err := a.sessionFactory.createSession(sessID, a.storeFactory, a.settings.globalSettings.clone(), a.logFactory, a.app) + if err != nil { + a.globalLog.OnEventf("Dynamic session %v failed to create: %v", sessID, err) + return + } + a.dynamicSessionChan <- dynamicSession + session = dynamicSession + defer session.stop() } - a.dynamicSessionChan <- dynamicSession - session = dynamicSession - defer session.stop() + } a.sessionAddr.Store(sessID, netConn.RemoteAddr()) @@ -412,6 +424,46 @@ LOOP: } } +func (a *Acceptor) startSessions() { + a.sessionsLock.RLock() + defer a.sessionsLock.RUnlock() + for _, s := range a.sessions { + a.sessionGroup.Add(1) + go func(s *session) { + s.run() + a.sessionGroup.Done() + }(s) + } +} + +func (a *Acceptor) stopSessions() { + a.sessionsLock.RLock() + defer a.sessionsLock.RUnlock() + for _, session := range a.sessions { + session.stop() + } + a.sessionGroup.Wait() + + for sessionID := range a.sessions { + err := UnregisterSession(sessionID) + if err != nil { + return + } + } +} + +func (a *Acceptor) addMngdDynamicSession(sessID SessionID, session *session) { + a.sessionsLock.Lock() + defer a.sessionsLock.Unlock() + + a.sessions[sessID] = session + a.sessionGroup.Add(1) + go func() { + session.run() + a.sessionGroup.Done() + }() +} + // SetConnectionValidator sets an optional connection validator. // Use it when you need a custom authentication logic that includes lower level interactions, // like mTLS auth or IP whitelistening. @@ -421,3 +473,8 @@ LOOP: func (a *Acceptor) SetConnectionValidator(validator ConnectionValidator) { a.connectionValidator = validator } + +// SetSessionProvider sets an optional session provider. +func (a *Acceptor) SetSessionProvider(sessionProvider AcceptorSessionProvider) { + a.sessionProvider = sessionProvider +} diff --git a/acceptor_session_provider.go b/acceptor_session_provider.go new file mode 100644 index 000000000..2547a0535 --- /dev/null +++ b/acceptor_session_provider.go @@ -0,0 +1,115 @@ +package quickfix + +import "github.com/quickfixgo/quickfix/config" + +const ( + WildcardPattern string = "*" +) + +type AcceptorSessionProvider interface { + GetSession(SessionID) (*session, error) +} + +type StaticAcceptorSessionProvider struct { + sessions map[SessionID]*session +} + +func (p *StaticAcceptorSessionProvider) GetSession(sessionID SessionID) (*session, error) { + s, ok := p.sessions[sessionID] + if !ok { + return nil, errUnknownSession + } + return s, nil +} + +// DynamicAcceptorSessionProvider dynamically defines sessions for an acceptor. This can be useful for +// applications like simulators that want to accept any connection and +// dynamically create an associated session. +// +// For more complex situations, you can use this class as a starting +// point for implementing your own AcceptorSessionProvider. +type DynamicAcceptorSessionProvider struct { + settings *Settings + messageStoreFactory MessageStoreFactory + logFactory LogFactory + sessionFactory *sessionFactory + application Application + templateMappings []*TemplateMapping +} + +func NewDynamicAcceptorSessionProvider(settings *Settings, messageStoreFactory MessageStoreFactory, logFactory LogFactory, + application Application, templateMappings []*TemplateMapping, +) *DynamicAcceptorSessionProvider { + return &DynamicAcceptorSessionProvider{ + settings: settings, + messageStoreFactory: messageStoreFactory, + logFactory: logFactory, + sessionFactory: &sessionFactory{}, + application: application, + templateMappings: templateMappings, + } +} + +func (p *DynamicAcceptorSessionProvider) FindTemplateID(sessionID SessionID) *SessionID { + return p.lookupTemplateID(sessionID) +} + +func (p *DynamicAcceptorSessionProvider) GetSession(sessionID SessionID) (*session, error) { + s, ok := lookupSession(sessionID) + if !ok { + templateID := p.lookupTemplateID(sessionID) + if templateID == nil { + return nil, errUnknownSession + } + dynamicSessionSettings := p.settings.globalSettings.clone() + templateSettings := p.settings.sessionSettings[*templateID] + dynamicSessionSettings.overlay(templateSettings) + dynamicSessionSettings.Set(config.BeginString, sessionID.BeginString) + dynamicSessionSettings.Set(config.SenderCompID, sessionID.SenderCompID) + dynamicSessionSettings.Set(config.SenderSubID, sessionID.SenderSubID) + dynamicSessionSettings.Set(config.SenderLocationID, sessionID.SenderLocationID) + dynamicSessionSettings.Set(config.TargetCompID, sessionID.TargetCompID) + dynamicSessionSettings.Set(config.TargetSubID, sessionID.TargetSubID) + dynamicSessionSettings.Set(config.TargetLocationID, sessionID.TargetLocationID) + var err error + s, err = p.sessionFactory.createSession(sessionID, + p.messageStoreFactory, + dynamicSessionSettings, + p.logFactory, + p.application, + ) + if err != nil { + return nil, err + } + } + return s, nil +} + +func (provider *DynamicAcceptorSessionProvider) lookupTemplateID(sessionID SessionID) *SessionID { + for _, mapping := range provider.templateMappings { + if isTemplateMatching(mapping.Pattern, sessionID) { + return &mapping.TemplateID + } + } + return nil +} + +func isTemplateMatching(pattern SessionID, sessionID SessionID) bool { + return matches(pattern.BeginString, sessionID.BeginString) && + matches(pattern.SenderCompID, sessionID.SenderCompID) && + matches(pattern.SenderSubID, sessionID.SenderSubID) && + matches(pattern.SenderLocationID, sessionID.SenderLocationID) && + matches(pattern.TargetCompID, sessionID.TargetCompID) && + matches(pattern.TargetSubID, sessionID.TargetSubID) && + matches(pattern.TargetLocationID, sessionID.TargetLocationID) +} + +func matches(pattern string, value string) bool { + return WildcardPattern == pattern || pattern == value +} + +// TemplateMapping mapping from a sessionID pattern to a session template ID. +type TemplateMapping struct { + Pattern SessionID + TemplateID SessionID +} diff --git a/acceptor_session_provider_test.go b/acceptor_session_provider_test.go new file mode 100644 index 000000000..6cc221000 --- /dev/null +++ b/acceptor_session_provider_test.go @@ -0,0 +1,247 @@ +package quickfix + +import ( + "reflect" + "testing" + + "github.com/quickfixgo/quickfix/config" + "github.com/stretchr/testify/suite" +) + +type DynamicAcceptorSessionProviderTestSuite struct { + suite.Suite + + provider *DynamicAcceptorSessionProvider + + settings *Settings + messageStoreFactory MessageStoreFactory + logFactory LogFactory + app Application + sessionFactory *sessionFactory + TemplateMapping []*TemplateMapping +} + +func (suite *DynamicAcceptorSessionProviderTestSuite) SetupTest() { + suite.settings = NewSettings() + suite.messageStoreFactory = NewMemoryStoreFactory() + suite.logFactory = nullLogFactory{} + suite.app = &noopApp{} + suite.sessionFactory = &sessionFactory{} + suite.TemplateMapping = make([]*TemplateMapping, 0) + + templateId1 := SessionID{BeginString: "FIX.4.2", SenderCompID: "ANY", TargetCompID: "ANY"} + suite.TemplateMapping = append( + suite.TemplateMapping, + &TemplateMapping{Pattern: SessionID{BeginString: WildcardPattern, SenderCompID: "S1", TargetCompID: WildcardPattern}, TemplateID: templateId1}, + ) + suite.setUpSettings(templateId1, "ResetOnLogout", "Y") + + templateId2 := SessionID{BeginString: "FIX.4.4", SenderCompID: "S1", TargetCompID: "ANY"} + suite.TemplateMapping = append( + suite.TemplateMapping, + &TemplateMapping{Pattern: SessionID{BeginString: "FIX.4.4", SenderCompID: WildcardPattern, TargetCompID: WildcardPattern}, TemplateID: templateId2}, + ) + suite.setUpSettings(templateId2, "RefreshOnLogon", "Y") + + templateId3 := SessionID{BeginString: "FIX.4.4", SenderCompID: "ANY", TargetCompID: "ANY"} + suite.TemplateMapping = append( + suite.TemplateMapping, + &TemplateMapping{Pattern: SessionID{BeginString: "FIX.4.2", SenderCompID: WildcardPattern, SenderSubID: WildcardPattern, SenderLocationID: WildcardPattern, + TargetCompID: WildcardPattern, TargetSubID: WildcardPattern, TargetLocationID: WildcardPattern, Qualifier: WildcardPattern, + }, TemplateID: templateId3}, + ) + suite.setUpSettings(templateId3, "ResetOnDisconnect", "Y") + + suite.provider = NewDynamicAcceptorSessionProvider(suite.settings, suite.messageStoreFactory, + suite.logFactory, suite.app, suite.TemplateMapping) +} + +func (suite *DynamicAcceptorSessionProviderTestSuite) setUpSettings(TemplateID SessionID, key, value string) { + sessionSettings := NewSessionSettings() + sessionSettings.Set(config.BeginString, TemplateID.BeginString) + sessionSettings.Set(config.SenderCompID, TemplateID.SenderCompID) + sessionSettings.Set(config.SenderSubID, TemplateID.SenderSubID) + sessionSettings.Set(config.SenderLocationID, TemplateID.SenderLocationID) + sessionSettings.Set(config.TargetCompID, TemplateID.TargetCompID) + sessionSettings.Set(config.TargetSubID, TemplateID.TargetSubID) + sessionSettings.Set(config.TargetLocationID, TemplateID.TargetLocationID) + sessionSettings.Set(config.SessionQualifier, TemplateID.Qualifier) + + sessionSettings.Set("StartTime", "00:00:00") + sessionSettings.Set("EndTime", "00:00:00") + sessionSettings.Set(key, value) + suite.settings.AddSession(sessionSettings) +} + +func (suite *DynamicAcceptorSessionProviderTestSuite) TestSessionCreation() { + type expected struct { + sessionID SessionID + resetOnLogout bool + refreshOnLogon bool + resetOnDisconnect bool + } + var tests = []struct { + name string + input SessionID + expected expected + }{ + { + name: "session created - matched", + input: SessionID{ + BeginString: "FIX.4.2", SenderCompID: "SENDER", SenderSubID: "SENDERSUB", SenderLocationID: "SENDERLOC", + TargetCompID: "TARGET", TargetSubID: "TARGETSUB", TargetLocationID: "TARGETLOC", Qualifier: "", + }, + expected: expected{ + sessionID: SessionID{ + BeginString: "FIX.4.2", SenderCompID: "SENDER", SenderSubID: "SENDERSUB", SenderLocationID: "SENDERLOC", + TargetCompID: "TARGET", TargetSubID: "TARGETSUB", TargetLocationID: "TARGETLOC", Qualifier: "", + }, + resetOnLogout: false, + refreshOnLogon: false, + resetOnDisconnect: true, + }, + }, + { + name: "create session - matching the first", + input: SessionID{ + BeginString: "FIX.4.4", SenderCompID: "S1", TargetCompID: "T", + }, + expected: expected{ + sessionID: SessionID{ + BeginString: "FIX.4.4", SenderCompID: "S1", TargetCompID: "T", + }, + resetOnLogout: true, + refreshOnLogon: false, + resetOnDisconnect: false, + }, + }, + { + name: "create session - matching the second", + input: SessionID{ + BeginString: "FIX.4.4", SenderCompID: "X", TargetCompID: "Y", + }, + expected: expected{ + sessionID: SessionID{ + BeginString: "FIX.4.4", SenderCompID: "X", TargetCompID: "Y", + }, + resetOnLogout: false, + refreshOnLogon: true, + resetOnDisconnect: false, + }, + }, + } + + for _, test := range tests { + session, err := suite.provider.GetSession(test.input) + suite.NoError(err) + suite.NotNil(session) + sessionID := session.sessionID + suite.Equal(test.expected.sessionID, sessionID, test.name+": created sessionID not expected") + suite.Equal(test.expected.resetOnLogout, session.ResetOnLogout, test.name+":ResetOnLogout not expected") + suite.Equal(test.expected.refreshOnLogon, session.RefreshOnLogon, test.name+":RefreshOnLogon not expected") + suite.Equal(test.expected.resetOnDisconnect, session.ResetOnDisconnect, test.name+":ResetOnDisconnect not expected") + } +} + +func (suite *DynamicAcceptorSessionProviderTestSuite) TestTemplateNotFound() { + var tests = []struct { + name string + input SessionID + }{ + { + name: "template not found", + input: SessionID{ + BeginString: "FIX.4.3", SenderCompID: "S", TargetCompID: "T", + }, + }, + } + + for _, test := range tests { + _, err := suite.provider.GetSession(test.input) + suite.Error(err, test.name+": expected error for template not found") + } +} + +func TestDynamicAcceptorSessionProviderTestSuite(t *testing.T) { + suite.Run(t, new(DynamicAcceptorSessionProviderTestSuite)) +} + +func TestStaticSessionProvider_GetSession(t *testing.T) { + sessions := make(map[SessionID]*session) + sessionID1 := SessionID{BeginString: "FIX.4.2", SenderCompID: "SENDER", TargetCompID: "TARGET"} + session1 := &session{sessionID: sessionID1} + sessions[sessionID1] = session1 + + type args struct { + sessionID SessionID + } + tests := []struct { + name string + args args + want *session + wantErr bool + }{ + { + name: "session found", + args: args{ + sessionID: sessionID1, + }, + want: session1, + wantErr: false, + }, + { + name: "session not found", + args: args{ + sessionID: SessionID{ + BeginString: "FIX.4.2", SenderCompID: "X", TargetCompID: "Y", + }, + }, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &StaticAcceptorSessionProvider{ + sessions: sessions, + } + got, err := p.GetSession(tt.args.sessionID) + if (err != nil) != tt.wantErr { + t.Errorf("StaticSessionProvider.GetSession() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("StaticSessionProvider.GetSession() = %v, want %v", got, tt.want) + } + }) + } +} + +var _ Application = &noopApp{} + +type noopApp struct { +} + +func (n *noopApp) FromAdmin(message *Message, sessionID SessionID) MessageRejectError { + return nil +} + +func (n *noopApp) FromApp(message *Message, sessionID SessionID) MessageRejectError { + return nil +} + +func (n *noopApp) OnCreate(sessionID SessionID) { +} + +func (n *noopApp) OnLogon(sessionID SessionID) { +} + +func (n *noopApp) OnLogout(sessionID SessionID) { +} + +func (n *noopApp) ToAdmin(message *Message, sessionID SessionID) { +} + +func (n *noopApp) ToApp(message *Message, sessionID SessionID) error { + return nil +} diff --git a/acceptor_test.go b/acceptor_test.go new file mode 100644 index 000000000..ffb015c8c --- /dev/null +++ b/acceptor_test.go @@ -0,0 +1,373 @@ +// Copyright (c) quickfixengine.org All rights reserved. +// +// This file may be distributed under the terms of the quickfixengine.org +// license as defined by quickfixengine.org and appearing in the file +// LICENSE included in the packaging of this file. +// +// This file is provided AS IS with NO WARRANTY OF ANY KIND, INCLUDING +// THE WARRANTY OF DESIGN, MERCHANTABILITY AND FITNESS FOR A +// PARTICULAR PURPOSE. +// +// See http://www.quickfixengine.org/LICENSE for licensing information. +// +// Contact ask@quickfixengine.org if any conditions of this licensing +// are not clear to you. + +package quickfix + +import ( + "bytes" + "io" + "net" + "testing" + "time" + + "github.com/quickfixgo/quickfix/config" + + proxyproto "github.com/pires/go-proxyproto" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +func TestAcceptor_Start(t *testing.T) { + sessionSettings := NewSessionSettings() + sessionSettings.Set(config.BeginString, BeginStringFIX42) + sessionSettings.Set(config.SenderCompID, "sender") + sessionSettings.Set(config.TargetCompID, "target") + + settingsWithTCPProxy := NewSettings() + settingsWithTCPProxy.GlobalSettings().Set("UseTCPProxy", "Y") + + settingsWithNoTCPProxy := NewSettings() + settingsWithNoTCPProxy.GlobalSettings().Set("UseTCPProxy", "N") + + genericSettings := NewSettings() + + const ( + GenericListener = iota + ProxyListener + ) + + acceptorStartTests := []struct { + name string + settings *Settings + listenerType int + }{ + {"with TCP proxy set", settingsWithTCPProxy, ProxyListener}, + {"with no TCP proxy set", settingsWithNoTCPProxy, GenericListener}, + {"no TCP proxy configuration set", genericSettings, GenericListener}, + } + + for _, tt := range acceptorStartTests { + t.Run(tt.name, func(t *testing.T) { + tt.settings.GlobalSettings().Set("SocketAcceptPort", "5001") + if _, err := tt.settings.AddSession(sessionSettings); err != nil { + assert.Nil(t, err) + } + + acceptor := &Acceptor{settings: tt.settings} + if err := acceptor.Start(); err != nil { + assert.NotNil(t, err) + } + assert.Len(t, acceptor.listeners, 1) + + for _, listener := range acceptor.listeners { + if tt.listenerType == ProxyListener { + _, ok := listener.(*proxyproto.Listener) + assert.True(t, ok) + } + + if tt.listenerType == GenericListener { + _, ok := listener.(*net.TCPListener) + assert.True(t, ok) + } + } + + acceptor.Stop() + }) + } +} + +var _ net.Conn = &mockConn{} + +type mockConn struct { + closeChan chan struct{} + localAddr net.Addr + remoteAddr net.Addr + + onWriteback func([]byte) + inboundMessages []*Message +} + +func (c *mockConn) Read(b []byte) (n int, err error) { + if len(c.inboundMessages) > 0 { + messageBytes := c.inboundMessages[0].build() + copy(b, messageBytes) + c.inboundMessages = c.inboundMessages[1:] + return len(messageBytes), err + } + <-c.closeChan + return 0, io.EOF +} + +func (c *mockConn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *mockConn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *mockConn) SetDeadline(t time.Time) error { + return nil +} + +func (c *mockConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *mockConn) SetWriteDeadline(t time.Time) error { + return nil +} + +func (c *mockConn) Write(b []byte) (n int, err error) { + if c.onWriteback != nil { + c.onWriteback(b) + } + return len(b), nil +} + +func (c *mockConn) Close() error { + return nil +} + +func mockLogonMessage(sessionID SessionID, msgSeqNum int) *Message { + msg := NewMessage() + msg.Header.SetField(tagMsgType, FIXString("A")) + msg.Header.SetInt(tagMsgSeqNum, msgSeqNum) + msg.Header.SetString(tagBeginString, sessionID.BeginString) + msg.Header.SetString(tagSenderCompID, sessionID.SenderCompID) + msg.Header.SetString(tagSenderSubID, sessionID.SenderSubID) + msg.Header.SetString(tagSenderLocationID, sessionID.SenderLocationID) + msg.Header.SetString(tagTargetCompID, sessionID.TargetCompID) + msg.Header.SetString(tagTargetSubID, sessionID.TargetSubID) + msg.Header.SetString(tagTargetLocationID, sessionID.TargetLocationID) + msg.Header.SetField(tagSendingTime, FIXUTCTimestamp{Time: time.Now()}) + msg.Body.SetInt(tagHeartBtInt, 30) + return msg +} + +type AcceptorTemplateTestSuite struct { + suite.Suite + acceptor *Acceptor + + sessionId1 SessionID + sessionId2 SessionID + sessionId3 SessionID + + testDynamicSessionID SessionID + logonSessionID SessionID + seqNum int + + dynamicSessionProvider AcceptorSessionProvider +} + +func (suite *AcceptorTemplateTestSuite) BeforeTest(suiteName, testName string) { + settings := NewSettings() + settings.globalSettings.Set(config.SocketAcceptPort, "5001") + sessionId1 := SessionID{BeginString: BeginStringFIX42, SenderCompID: "sender1", TargetCompID: "target1"} + sessionSettings1 := NewSessionSettings() + sessionSettings1.Set(config.BeginString, sessionId1.BeginString) + sessionSettings1.Set(config.SenderCompID, sessionId1.SenderCompID) + sessionSettings1.Set(config.TargetCompID, sessionId1.TargetCompID) + suite.sessionId1 = sessionId1 + settings.AddSession(sessionSettings1) + + sessionId2 := SessionID{BeginString: BeginStringFIX43, SenderCompID: "sender2", TargetCompID: "target2"} + sessionSettings2 := NewSessionSettings() + sessionSettings2.Set(config.BeginString, sessionId2.BeginString) + sessionSettings2.Set(config.SenderCompID, sessionId2.SenderCompID) + sessionSettings2.Set(config.TargetCompID, sessionId2.TargetCompID) + suite.sessionId2 = sessionId2 + settings.AddSession(sessionSettings2) + + // acceptor template + sessionId3 := SessionID{BeginString: BeginStringFIX43, SenderCompID: "*", SenderSubID: "*", SenderLocationID: "*", + TargetCompID: "target3", TargetSubID: "*", TargetLocationID: "*"} + sessionSettings3 := NewSessionSettings() + sessionSettings3.Set(config.BeginString, sessionId3.BeginString) + sessionSettings3.Set(config.SenderCompID, sessionId3.SenderCompID) + sessionSettings3.Set(config.SenderSubID, sessionId3.SenderSubID) + sessionSettings3.Set(config.SenderLocationID, sessionId3.SenderLocationID) + sessionSettings3.Set(config.TargetCompID, sessionId3.TargetCompID) + sessionSettings3.Set(config.TargetSubID, sessionId3.TargetSubID) + sessionSettings3.Set(config.TargetLocationID, sessionId3.TargetLocationID) + sessionSettings3.Set(config.ResetOnLogout, "Y") + sessionSettings3.Set(config.AcceptorTemplate, "Y") + suite.sessionId3 = sessionId3 + settings.AddSession(sessionSettings3) + + app := &noopApp{} + a, err := NewAcceptor(app, memoryStoreFactory{}, settings, NewScreenLogFactory()) + if err != nil { + suite.Fail("Failed to create acceptor: %v", err) + } + suite.acceptor = a + + templateMappings := make([]*TemplateMapping, 0) + templateMappings = append(templateMappings, &TemplateMapping{ + Pattern: suite.sessionId3, + TemplateID: suite.sessionId3, + }) + suite.dynamicSessionProvider = NewDynamicAcceptorSessionProvider(suite.acceptor.settings, suite.acceptor.storeFactory, suite.acceptor.logFactory, suite.acceptor.app, templateMappings) + suite.acceptor.SetSessionProvider(suite.dynamicSessionProvider) + + suite.testDynamicSessionID = SessionID{BeginString: BeginStringFIX43, SenderCompID: "target3", TargetCompID: "dynamicSender"} + suite.logonSessionID = SessionID{BeginString: BeginStringFIX43, SenderCompID: "dynamicSender", TargetCompID: "target3"} + if err := suite.acceptor.Start(); err != nil { + suite.FailNow("acceptor start failed: %v", err) + } + + suite.verifySessionCount(2) + suite.seqNum = 1 +} + +func (suite *AcceptorTemplateTestSuite) logonAndDisconnectAfterCheck(sessionID SessionID, + checkFuncAfterLogon func(), + mustHaveResponse bool) { + inboundMessages := []*Message{mockLogonMessage(sessionID, suite.seqNum)} + suite.seqNum++ + var respondedLogonMessageReceived bool + mockConn1 := &mockConn{ + closeChan: make(chan struct{}), + inboundMessages: inboundMessages, + localAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5001}, + remoteAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5002}, + } + mockConn1.onWriteback = func(b []byte) { + reponseMsg := NewMessage() + err := ParseMessage(reponseMsg, bytes.NewBuffer(b)) + suite.Require().NoError(err, "parse responding message failed") + msgType, err := reponseMsg.Header.GetString(tagMsgType) + suite.Require().NoError(err, "unexpected mssage") + suite.Require().Equalf("A", msgType, "expected logon message in reponse %s", reponseMsg.String()) + respondedLogonMessageReceived = true + if checkFuncAfterLogon != nil { + checkFuncAfterLogon() + } + close(mockConn1.closeChan) + } + suite.acceptor.handleConnection(mockConn1) + if mustHaveResponse { + suite.Require().Equal(true, respondedLogonMessageReceived, "expected responding logon message") + } +} + +func (suite *AcceptorTemplateTestSuite) verifySessionCount(expectedSessionCount int) { + suite.Require().Equalf(expectedSessionCount, len(suite.acceptor.sessions), "expected %v sessions but found %v", expectedSessionCount, len(suite.acceptor.sessions)) + suite.Require().Equalf(expectedSessionCount, len(sessions), "expected %v sessions but found %v in registry", expectedSessionCount, len(suite.acceptor.sessions)) +} + +func (suite *AcceptorTemplateTestSuite) TestCreateDynamicSessionBySessionProvider() { + logonSessionID := suite.logonSessionID + suite.logonAndDisconnectAfterCheck(suite.testDynamicSessionID, func() { + suite.verifySessionCount(3) + + createdSession, ok := suite.acceptor.sessions[logonSessionID] + suite.Require().Equal(true, ok, "expected dynamic session to be created") + suite.Require().Equal(logonSessionID, createdSession.sessionID, "expected session ID to match inbound session ID") + suite.Require().Equal(createdSession.ResetOnLogout, true, "expected ResetOnLogout=Y for createdSession") + + remoteAddr, ok := suite.acceptor.RemoteAddr(logonSessionID) + if !ok { + suite.Fail("Failed to get remote address for dynamic session") + } + suite.Require().Equal("127.0.0.1:5002", remoteAddr.String(), "expect remoteAddr for dynamic session to be 127.0.0.1:5002 but got %v", remoteAddr.String()) + }, true) + suite.acceptor.Stop() +} + +func (suite *AcceptorTemplateTestSuite) TestSessionCreatedBySessionProviderShouldBeKept() { + logonSessionID := suite.logonSessionID + suite.logonAndDisconnectAfterCheck(suite.testDynamicSessionID, func() { + suite.verifySessionCount(3) + }, true) + err := SendToTarget(createFIX43NewOrderSingle(), logonSessionID) + suite.NoError(err, "expected message can still be sent after session disconnected") + suite.acceptor.Stop() +} + +func (suite *AcceptorTemplateTestSuite) TestNoNewSessionCreatedWhenSameSessionIDLogons() { + suite.logonAndDisconnectAfterCheck(suite.testDynamicSessionID, func() { + suite.verifySessionCount(3) + }, true) + suite.logonAndDisconnectAfterCheck(suite.testDynamicSessionID, func() { + suite.verifySessionCount(3) + }, true) + suite.logonAndDisconnectAfterCheck(suite.testDynamicSessionID, func() { + suite.verifySessionCount(3) + }, true) + suite.acceptor.Stop() +} + +func (suite *AcceptorTemplateTestSuite) TestSessionNotFoundBySessionProvider() { + sessionID := SessionID{BeginString: BeginStringFIX43, SenderCompID: "unknownSender", TargetCompID: "unknownTarget"} + suite.logonAndDisconnectAfterCheck(sessionID, func() {}, false) + suite.verifySessionCount(2) + suite.acceptor.Stop() +} + +func TestAcceptorTemplateTestSuite(t *testing.T) { + suite.Run(t, new(AcceptorTemplateTestSuite)) +} + +type DynamicSessionTestSuite struct { + suite.Suite +} + +func (suite *DynamicSessionTestSuite) TestDynamicSession() { + settings := NewSettings() + settings.globalSettings.Set(config.SocketAcceptPort, "5001") + settings.globalSettings.Set(config.DynamicSessions, "Y") + sessionId1 := SessionID{BeginString: BeginStringFIX42, SenderCompID: "sender1", TargetCompID: "target1"} + sessionSettings1 := NewSessionSettings() + sessionSettings1.Set(config.BeginString, sessionId1.BeginString) + sessionSettings1.Set(config.SenderCompID, sessionId1.SenderCompID) + sessionSettings1.Set(config.TargetCompID, sessionId1.TargetCompID) + settings.AddSession(sessionSettings1) + + a, err := NewAcceptor(&noopApp{}, memoryStoreFactory{}, settings, NewNullLogFactory()) + suite.Require().NoError(err, "create acceptor with DynamicSession=Y failed") + + if err := a.Start(); err != nil { + suite.FailNow("acceptor start failed: %v", err) + } + + inboundSessionID := SessionID{BeginString: BeginStringFIX43, SenderCompID: "X", TargetCompID: "Y"} + inboundMessages := []*Message{mockLogonMessage(inboundSessionID, 1)} + reversedInboundSessionID := SessionID{BeginString: BeginStringFIX43, SenderCompID: "Y", TargetCompID: "X"} + + mockConn1 := &mockConn{ + closeChan: make(chan struct{}), + inboundMessages: inboundMessages, + localAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5001}, + remoteAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5002}, + } + + var respondedLogonMessageReceived bool + mockConn1.onWriteback = func(_ []byte) { + respondedLogonMessageReceived = true + // close conn + close(mockConn1.closeChan) + } + + a.handleConnection(mockConn1) + suite.Require().Equal(true, respondedLogonMessageReceived, "expected responding logon message") + err = SendToTarget(createFIX43NewOrderSingle(), reversedInboundSessionID) + suite.Error(err, "session created by DynamicSession is unregistered after session connected") + a.Stop() +} + +func TestDynamicSessionTestSuite(t *testing.T) { + suite.Run(t, new(DynamicSessionTestSuite)) +} diff --git a/config/configuration.go b/config/configuration.go index f524f5005..506641cf9 100644 --- a/config/configuration.go +++ b/config/configuration.go @@ -697,6 +697,9 @@ const ( // - Y // - N DynamicQualifier string = "DynamicQualifier" + + // AcceptorTemplate designates a template Acceptor session. + AcceptorTemplate string = "AcceptorTemplate" ) const (