Skip to content

Commit

Permalink
add connection pool so we don't leak connections
Browse files Browse the repository at this point in the history
  • Loading branch information
haisum committed Nov 13, 2023
1 parent b6bb02b commit 3f17554
Show file tree
Hide file tree
Showing 11 changed files with 491 additions and 108 deletions.
273 changes: 273 additions & 0 deletions controllers/jetstream/conn_pool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
package jetstream

import (
"crypto/sha256"
"crypto/tls"
"encoding/json"
"fmt"
"os"
"sync"

"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
"golang.org/x/sync/singleflight"
)

// natsContext describes how to reach and authenticate against a NATS
// deployment.
//
// Nkey, Credentials, TLSCAs, TLSCert and TLSKey are file paths: hash reads
// each of them from disk. The JSON encoding of the struct is the first input
// to hash, so the field order and tags below are part of the resulting key.
type natsContext struct {
	Name        string   `json:"name"`
	URL         string   `json:"url"`
	JWT         string   `json:"jwt"`
	Seed        string   `json:"seed"`
	Credentials string   `json:"credential"`
	Nkey        string   `json:"nkey"`
	Token       string   `json:"token"`
	Username    string   `json:"username"`
	Password    string   `json:"password"`
	TLSCAs      []string `json:"tls_ca"`
	TLSCert     string   `json:"tls_cert"`
	TLSKey      string   `json:"tls_key"`
}

// copy returns an independent copy of c, or nil when c is nil. The TLSCAs
// slice is duplicated too, so writing to the copy's list never reaches the
// original.
func (c *natsContext) copy() *natsContext {
	if c == nil {
		return nil
	}
	cp := *c
	if c.TLSCAs != nil {
		// Only clone non-nil slices: nil and empty marshal differently
		// ("null" vs "[]") and the JSON form feeds hash.
		cp.TLSCAs = append(make([]string, 0, len(c.TLSCAs)), c.TLSCAs...)
	}
	return &cp
}

// hash returns the hex-encoded SHA-256 of the context's JSON encoding followed
// by the contents of every file it references, in the fixed order nkey,
// credentials, CA files, client certificate, client key.
//
// It returns an error if the context cannot be marshaled or a referenced file
// cannot be read. The underlying error is wrapped, so callers can inspect it
// with errors.Is / errors.As.
func (c *natsContext) hash() (string, error) {
	b, err := json.Marshal(c)
	if err != nil {
		return "", fmt.Errorf("error marshaling context to json: %w", err)
	}

	type input struct{ kind, path string }
	var inputs []input
	// optional queues a file only when its path is configured.
	optional := func(kind, path string) {
		if path != "" {
			inputs = append(inputs, input{kind, path})
		}
	}

	optional("nkey", c.Nkey)
	optional("creds", c.Credentials)
	for _, ca := range c.TLSCAs {
		// Every listed CA is read, so an empty entry surfaces as an error
		// instead of being skipped.
		inputs = append(inputs, input{"ca", ca})
	}
	optional("cert", c.TLSCert)
	optional("key", c.TLSKey)

	for _, in := range inputs {
		fb, err := os.ReadFile(in.path)
		if err != nil {
			return "", fmt.Errorf("error opening %s file %s: %w", in.kind, in.path, err)
		}
		b = append(b, fb...)
	}

	return fmt.Sprintf("%x", sha256.Sum256(b)), nil
}

// natsContextDefaults holds pool-wide fallback values. natsConnPool.Get copies
// Name, URL, TLSCAs, TLSCert and TLSKey from here into a context whose
// corresponding field is empty. TLSConfig is carried along but is not read by
// Get.
type natsContextDefaults struct {
	Name, URL       string
	TLSCAs          []string
	TLSCert, TLSKey string
	TLSConfig       *tls.Config
}

// pooledNatsConn is a NATS connection handed out by a natsConnPool together
// with the bookkeeping the pool needs to share it between callers.
type pooledNatsConn struct {
// nc is the underlying connection; ReturnToPool closes it once count drops to zero.
nc *nats.Conn
// cp is the owning pool. Its mutex is held whenever count and closed are touched.
cp *natsConnPool
// key is the entry in cp.cache under which this connection is stored.
key string
// count is incremented by natsConnPool.Get and decremented by ReturnToPool.
count uint64
// closed is set by ReturnToPool when the last reference is returned; Get
// checks it and retries instead of handing out a closed connection.
closed bool
}

// ReturnToPool gives back one reference obtained from natsConnPool.Get.
//
// When the last reference is returned the connection is removed from the pool
// cache (only if the cache still points at this very connection), marked as
// closed and the underlying NATS connection is closed.
//
// A call that is not matched by a previous Get is ignored: count is unsigned,
// so decrementing it at zero would wrap around to MaxUint64.
func (pc *pooledNatsConn) ReturnToPool() {
	pc.cp.Lock()
	if pc.count == 0 {
		// Nothing outstanding to release.
		pc.cp.Unlock()
		return
	}

	pc.count--
	if pc.count > 0 {
		// Other holders remain; keep the connection open.
		pc.cp.Unlock()
		return
	}

	// Last reference. The cache entry may already have been replaced by a
	// newer connection for the same key, in which case it must stay.
	if cached, ok := pc.cp.cache[pc.key]; ok && cached == pc {
		delete(pc.cp.cache, pc.key)
	}
	pc.closed = true
	pc.cp.Unlock()

	// Close after releasing the pool lock, as the original code does.
	pc.nc.Close()
}

// natsConnPool hands out shared NATS connections, one per distinct
// natsContext hash.
type natsConnPool struct {
// The embedded mutex is held around accesses to cache and to the count and
// closed fields of pooled connections.
sync.Mutex
// cache maps a natsContext hash to the connection currently pooled for it.
cache map[string]*pooledNatsConn
// logger receives a message each time a new connection is established.
logger *logrus.Logger
// group runs getPooledConn's lookup-or-connect step through singleflight, keyed by hash.
group *singleflight.Group
// natsDefaults supplies values for context fields that Get finds empty.
natsDefaults *natsContextDefaults
// natsOpts are the base options every connection starts from.
natsOpts []nats.Option
}

// newNatsConnPool returns an empty pool that logs through logger, fills
// missing context fields from natsDefaults and starts every connection from
// natsOpts.
func newNatsConnPool(logger *logrus.Logger, natsDefaults *natsContextDefaults, natsOpts []nats.Option) *natsConnPool {
	pool := new(natsConnPool)
	pool.cache = make(map[string]*pooledNatsConn)
	pool.group = new(singleflight.Group)
	pool.logger = logger
	pool.natsDefaults = natsDefaults
	pool.natsOpts = natsOpts
	return pool
}

// getPooledConnMaxTries is the number of attempts natsConnPool.Get makes
// before giving up when the connection it obtained turns out to be closed.
const getPooledConnMaxTries = 10

// Get returns a *pooledNatsConn for cfg and increments its reference count.
// Every successful Get must be matched by one ReturnToPool call.
//
// cfg itself is not modified: Get works on a copy, fills empty Name, URL,
// TLSCAs, TLSCert and TLSKey fields from the pool defaults (when the pool has
// any) and uses the hash of the result as the pool key.
//
// It returns an error when cfg is nil, when the context cannot be hashed,
// when establishing a connection fails, or when every one of
// getPooledConnMaxTries attempts produced a connection that was closed in the
// meantime.
func (cp *natsConnPool) Get(cfg *natsContext) (*pooledNatsConn, error) {
	if cfg == nil {
		return nil, fmt.Errorf("nats context must not be nil")
	}

	// Work on a private copy so the caller's context stays untouched.
	cfg = cfg.copy()

	// Fill in defaults. A pool created without defaults skips this instead of
	// dereferencing a nil pointer.
	if d := cp.natsDefaults; d != nil {
		if cfg.Name == "" {
			cfg.Name = d.Name
		}
		if cfg.URL == "" {
			cfg.URL = d.URL
		}
		if len(cfg.TLSCAs) == 0 {
			cfg.TLSCAs = d.TLSCAs
		}
		if cfg.TLSCert == "" {
			cfg.TLSCert = d.TLSCert
		}
		if cfg.TLSKey == "" {
			cfg.TLSKey = d.TLSKey
		}
	}

	key, err := cfg.hash()
	if err != nil {
		return nil, err
	}

	for i := 0; i < getPooledConnMaxTries; i++ {
		connection, err := cp.getPooledConn(key, cfg)
		if err != nil {
			return nil, err
		}

		cp.Lock()
		if connection.closed {
			// ReturnToPool closed this while the lock was not held, try again.
			cp.Unlock()
			continue
		}

		// Account for the reference being handed out.
		connection.count++
		cp.Unlock()
		return connection, nil
	}

	return nil, fmt.Errorf("failed to get pooled connection after %d attempts", getPooledConnMaxTries)
}

// getPooledConn gets or establishes a *pooledNatsConn in a singleflight group,
// but does not increment its count.
//
// A cached connection for key is reused only while it reports IsConnected;
// otherwise a new connection is established with cfg, stored in the cache
// under key (replacing any previous entry) and returned.
func (cp *natsConnPool) getPooledConn(key string, cfg *natsContext) (*pooledNatsConn, error) {
	conn, err, _ := cp.group.Do(key, func() (interface{}, error) {
		cp.Lock()
		pooledConn, ok := cp.cache[key]
		if ok && pooledConn.nc.IsConnected() {
			cp.Unlock()
			return pooledConn, nil
		}
		cp.Unlock()

		opts, err := cp.connOptions(cfg)
		if err != nil {
			return nil, err
		}

		nc, err := nats.Connect(cfg.URL, opts...)
		if err != nil {
			return nil, err
		}
		cp.logger.Infof("%s connected to NATS Deployment: %s", cfg.Name, nc.ConnectedAddr())

		connection := &pooledNatsConn{
			nc:  nc,
			cp:  cp,
			key: key,
		}

		cp.Lock()
		cp.cache[key] = connection
		cp.Unlock()

		return connection, nil
	})
	if err != nil {
		return nil, err
	}

	connection, ok := conn.(*pooledNatsConn)
	if !ok {
		return nil, fmt.Errorf("not a pooledNatsConn")
	}
	return connection, nil
}

// connOptions returns the pool's base options followed by the options derived
// from cfg, in the order: name/token/user/password, JWT+seed, nkey,
// credentials, root CAs, client certificate.
func (cp *natsConnPool) connOptions(cfg *natsContext) ([]nats.Option, error) {
	// Start from a copy. Appending to cp.natsOpts directly would write into its
	// spare capacity, which calls for other keys running concurrently share.
	opts := make([]nats.Option, len(cp.natsOpts), len(cp.natsOpts)+6)
	copy(opts, cp.natsOpts)

	opts = append(opts, func(options *nats.Options) error {
		if cfg.Name != "" {
			options.Name = cfg.Name
		}
		if cfg.Token != "" {
			options.Token = cfg.Token
		}
		if cfg.Username != "" {
			options.User = cfg.Username
		}
		if cfg.Password != "" {
			options.Password = cfg.Password
		}
		return nil
	})

	if cfg.JWT != "" && cfg.Seed != "" {
		opts = append(opts, nats.UserJWTAndSeed(cfg.JWT, cfg.Seed))
	}

	if cfg.Nkey != "" {
		opt, err := nats.NkeyOptionFromSeed(cfg.Nkey)
		if err != nil {
			return nil, fmt.Errorf("unable to load nkey: %w", err)
		}
		opts = append(opts, opt)
	}

	if cfg.Credentials != "" {
		opts = append(opts, nats.UserCredentials(cfg.Credentials))
	}

	if len(cfg.TLSCAs) > 0 {
		opts = append(opts, nats.RootCAs(cfg.TLSCAs...))
	}

	if cfg.TLSCert != "" && cfg.TLSKey != "" {
		opts = append(opts, nats.ClientCert(cfg.TLSCert, cfg.TLSKey))
	}

	return opts, nil
}
92 changes: 92 additions & 0 deletions controllers/jetstream/conn_pool_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package jetstream

import (
"sync"
"testing"
"time"

"github.com/nats-io/nats.go"

natsservertest "github.com/nats-io/nats-server/v2/test"
"github.com/sirupsen/logrus"
testifyAssert "github.com/stretchr/testify/assert"
)

// TestConnPool checks that equal contexts share one pooled connection, that
// different contexts get different ones, and that a connection is closed
// exactly when its last holder returns it.
func TestConnPool(t *testing.T) {
	t.Parallel()

	s := natsservertest.RunRandClientPortServer()
	defer s.Shutdown()

	// o1 and o2 are equal on purpose so that they map to the same pool key.
	o1 := &natsContext{
		Name: "Client 1",
	}
	o2 := &natsContext{
		Name: "Client 1",
	}
	o3 := &natsContext{
		Name: "Client 2",
	}

	natsDefaults := &natsContextDefaults{
		URL: s.ClientURL(),
	}
	natsOptions := []nats.Option{
		nats.MaxReconnects(10240),
	}
	cp := newNatsConnPool(logrus.New(), natsDefaults, natsOptions)

	var c1, c2, c3 *pooledNatsConn
	var c1e, c2e, c3e error
	wg := &sync.WaitGroup{}
	wg.Add(3)
	go func() {
		c1, c1e = cp.Get(o1)
		wg.Done()
	}()
	go func() {
		c2, c2e = cp.Get(o2)
		wg.Done()
	}()
	go func() {
		c3, c3e = cp.Get(o3)
		wg.Done()
	}()
	wg.Wait()

	assert := testifyAssert.New(t)

	// Report every failed Get, then stop: the connections are nil on failure
	// and using them below would panic instead of failing the test cleanly.
	ok1 := assert.NoError(c1e)
	ok2 := assert.NoError(c2e)
	ok3 := assert.NoError(c3e)
	if ok1 && ok2 {
		assert.Same(c1, c2)
	}
	if ok3 {
		assert.NotSame(c1, c3)
		assert.NotSame(c2, c3)
	}
	if !ok1 || !ok2 || !ok3 {
		t.FailNow()
	}

	// c1/c2 still has one holder (c2); c3 has none left.
	c1.ReturnToPool()
	c3.ReturnToPool()
	time.Sleep(1 * time.Second)
	assert.False(c1.nc.IsClosed())
	assert.False(c2.nc.IsClosed())
	assert.True(c3.nc.IsClosed())

	c4, c4e := cp.Get(o1)
	if !assert.NoError(c4e) {
		t.FailNow()
	}
	assert.Same(c2, c4)

	c2.ReturnToPool()
	c4.ReturnToPool()
	time.Sleep(1 * time.Second)
	assert.True(c1.nc.IsClosed())
	assert.True(c2.nc.IsClosed())
	assert.True(c4.nc.IsClosed())

	// The previous connection was closed, so a new one must be created.
	c5, c5e := cp.Get(o1)
	if !assert.NoError(c5e) {
		t.FailNow()
	}
	assert.NotSame(c1, c5)

	c5.ReturnToPool()
	time.Sleep(1 * time.Second)
	assert.True(c5.nc.IsClosed())
}
Loading

0 comments on commit 3f17554

Please sign in to comment.