Skip to content

Commit

Permalink
add feature: acceptorTemplate
Browse files Browse the repository at this point in the history
  • Loading branch information
haoyang1994 committed Aug 14, 2024
1 parent c07597e commit 80a93d4
Show file tree
Hide file tree
Showing 6 changed files with 823 additions and 113 deletions.
85 changes: 0 additions & 85 deletions accepter_test.go

This file was deleted.

113 changes: 85 additions & 28 deletions acceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ type Acceptor struct {
storeFactory MessageStoreFactory
globalLog Log
sessions map[SessionID]*session
sessionsLock sync.RWMutex
sessionGroup sync.WaitGroup
listenerShutdown sync.WaitGroup
dynamicSessions bool
Expand All @@ -48,6 +49,7 @@ type Acceptor struct {
sessionHostPort map[SessionID]int
listeners map[string]net.Listener
connectionValidator ConnectionValidator
sessionProvider AcceptorSessionProvider
sessionFactory
}

Expand Down Expand Up @@ -104,14 +106,8 @@ func (a *Acceptor) Start() (err error) {
a.listeners[address] = &proxyproto.Listener{Listener: a.listeners[address]}
}
}
a.startSessions()

for _, s := range a.sessions {
a.sessionGroup.Add(1)
go func(s *session) {
s.run()
a.sessionGroup.Done()
}(s)
}
if a.dynamicSessions {
a.dynamicSessionChan = make(chan *session)
a.sessionGroup.Add(1)
Expand Down Expand Up @@ -140,17 +136,7 @@ func (a *Acceptor) Stop() {
if a.dynamicSessions {
close(a.dynamicSessionChan)
}
for _, session := range a.sessions {
session.stop()
}
a.sessionGroup.Wait()

for sessionID := range a.sessions {
err := UnregisterSession(sessionID)
if err != nil {
return
}
}
a.stopSessions()
}

// RemoteAddr gets remote IP address for a given session.
Expand Down Expand Up @@ -191,6 +177,15 @@ func NewAcceptor(app Application, storeFactory MessageStoreFactory, settings *Se
}

for sessionID, sessionSettings := range settings.SessionSettings() {
if sessionSettings.HasSetting(config.AcceptorTemplate) {
var acceptorTemplate bool
if acceptorTemplate, err = sessionSettings.BoolSetting(config.AcceptorTemplate); err != nil {
return
}
if acceptorTemplate {
continue
}
}
sessID := sessionID
sessID.Qualifier = ""

Expand Down Expand Up @@ -331,18 +326,35 @@ func (a *Acceptor) handleConnection(netConn net.Conn) {
}
session, ok := a.sessions[sessID]
if !ok {
if !a.dynamicSessions {
a.globalLog.OnEventf("Session %v not found for incoming message: %s", sessID, msgBytes)
return
var dynamicSessionCreated bool
if a.sessionProvider != nil {
session, err = a.sessionProvider.GetSession(sessID)
if err != nil {
if err == errUnknownSession && a.dynamicSessions {
goto CREATE_SHORT_LIVED_DYNAMIC_SESSION
}
a.globalLog.OnEventf("Failed to get session %v from provider: %v", sessID, err)
return
}
a.addMngdDynamicSession(sessID, session)
dynamicSessionCreated = true
}
dynamicSession, err := a.sessionFactory.createSession(sessID, a.storeFactory, a.settings.globalSettings.clone(), a.logFactory, a.app)
if err != nil {
a.globalLog.OnEventf("Dynamic session %v failed to create: %v", sessID, err)
return
CREATE_SHORT_LIVED_DYNAMIC_SESSION:
if !dynamicSessionCreated {
if !a.dynamicSessions {
a.globalLog.OnEventf("Session %v not found for incoming message: %s", sessID, msgBytes)
return
}
dynamicSession, err := a.sessionFactory.createSession(sessID, a.storeFactory, a.settings.globalSettings.clone(), a.logFactory, a.app)
if err != nil {
a.globalLog.OnEventf("Dynamic session %v failed to create: %v", sessID, err)
return
}
a.dynamicSessionChan <- dynamicSession
session = dynamicSession
defer session.stop()
}
a.dynamicSessionChan <- dynamicSession
session = dynamicSession
defer session.stop()

}

a.sessionAddr.Store(sessID, netConn.RemoteAddr())
Expand Down Expand Up @@ -412,6 +424,46 @@ LOOP:
}
}

func (a *Acceptor) startSessions() {
a.sessionsLock.RLock()
defer a.sessionsLock.RUnlock()
for _, s := range a.sessions {
a.sessionGroup.Add(1)
go func(s *session) {
s.run()
a.sessionGroup.Done()
}(s)
}
}

func (a *Acceptor) stopSessions() {
a.sessionsLock.RLock()
defer a.sessionsLock.RUnlock()
for _, session := range a.sessions {
session.stop()
}
a.sessionGroup.Wait()

for sessionID := range a.sessions {
err := UnregisterSession(sessionID)
if err != nil {
return
}
}
}

func (a *Acceptor) addMngdDynamicSession(sessID SessionID, session *session) {
a.sessionsLock.Lock()
defer a.sessionsLock.Unlock()

a.sessions[sessID] = session
a.sessionGroup.Add(1)
go func() {
session.run()
a.sessionGroup.Done()
}()
}

// SetConnectionValidator sets an optional connection validator.
// Use it when you need a custom authentication logic that includes lower level interactions,
// like mTLS auth or IP whitelistening.
Expand All @@ -421,3 +473,8 @@ LOOP:
func (a *Acceptor) SetConnectionValidator(validator ConnectionValidator) {
a.connectionValidator = validator
}

// SetSessionProvider sets an optional session provider.
func (a *Acceptor) SetSessionProvider(sessionProvider AcceptorSessionProvider) {
a.sessionProvider = sessionProvider
}
115 changes: 115 additions & 0 deletions acceptor_session_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package quickfix

import "github.com/quickfixgo/quickfix/config"

const (
WildcardPattern string = "*"
)

type AcceptorSessionProvider interface {
GetSession(SessionID) (*session, error)
}

type StaticAcceptorSessionProvider struct {
sessions map[SessionID]*session
}

func (p *StaticAcceptorSessionProvider) GetSession(sessionID SessionID) (*session, error) {
s, ok := p.sessions[sessionID]
if !ok {
return nil, errUnknownSession
}
return s, nil
}

// DynamicAcceptorSessionProvider dynamically defines sessions for an acceptor. This can be useful for
// applications like simulators that want to accept any connection and
// dynamically create an associated session.
//
// For more complex situations, you can use this class as a starting
// point for implementing your own AcceptorSessionProvider.
type DynamicAcceptorSessionProvider struct {
settings *Settings
messageStoreFactory MessageStoreFactory
logFactory LogFactory
sessionFactory *sessionFactory
application Application
templateMappings []*TemplateMapping
}

func NewDynamicAcceptorSessionProvider(settings *Settings, messageStoreFactory MessageStoreFactory, logFactory LogFactory,
application Application, templateMappings []*TemplateMapping,
) *DynamicAcceptorSessionProvider {
return &DynamicAcceptorSessionProvider{
settings: settings,
messageStoreFactory: messageStoreFactory,
logFactory: logFactory,
sessionFactory: &sessionFactory{},
application: application,
templateMappings: templateMappings,
}
}

func (p *DynamicAcceptorSessionProvider) FindTemplateID(sessionID SessionID) *SessionID {
return p.lookupTemplateID(sessionID)
}

func (p *DynamicAcceptorSessionProvider) GetSession(sessionID SessionID) (*session, error) {
s, ok := lookupSession(sessionID)
if !ok {
templateID := p.lookupTemplateID(sessionID)
if templateID == nil {
return nil, errUnknownSession
}
dynamicSessionSettings := p.settings.globalSettings.clone()
templateSettings := p.settings.sessionSettings[*templateID]
dynamicSessionSettings.overlay(templateSettings)
dynamicSessionSettings.Set(config.BeginString, sessionID.BeginString)
dynamicSessionSettings.Set(config.SenderCompID, sessionID.SenderCompID)
dynamicSessionSettings.Set(config.SenderSubID, sessionID.SenderSubID)
dynamicSessionSettings.Set(config.SenderLocationID, sessionID.SenderLocationID)
dynamicSessionSettings.Set(config.TargetCompID, sessionID.TargetCompID)
dynamicSessionSettings.Set(config.TargetSubID, sessionID.TargetSubID)
dynamicSessionSettings.Set(config.TargetLocationID, sessionID.TargetLocationID)
var err error
s, err = p.sessionFactory.createSession(sessionID,
p.messageStoreFactory,
dynamicSessionSettings,
p.logFactory,
p.application,
)
if err != nil {
return nil, err
}
}
return s, nil
}

func (provider *DynamicAcceptorSessionProvider) lookupTemplateID(sessionID SessionID) *SessionID {
for _, mapping := range provider.templateMappings {
if isTemplateMatching(mapping.Pattern, sessionID) {
return &mapping.TemplateID
}
}
return nil
}

func isTemplateMatching(pattern SessionID, sessionID SessionID) bool {
return matches(pattern.BeginString, sessionID.BeginString) &&
matches(pattern.SenderCompID, sessionID.SenderCompID) &&
matches(pattern.SenderSubID, sessionID.SenderSubID) &&
matches(pattern.SenderLocationID, sessionID.SenderLocationID) &&
matches(pattern.TargetCompID, sessionID.TargetCompID) &&
matches(pattern.TargetSubID, sessionID.TargetSubID) &&
matches(pattern.TargetLocationID, sessionID.TargetLocationID)
}

func matches(pattern string, value string) bool {
return WildcardPattern == pattern || pattern == value
}

// TemplateMapping mapping from a sessionID pattern to a session template ID.
type TemplateMapping struct {
Pattern SessionID
TemplateID SessionID
}
Loading

0 comments on commit 80a93d4

Please sign in to comment.