-
Notifications
You must be signed in to change notification settings - Fork 61
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add connection pool so we don't leak connections
- Loading branch information
haisum
committed
Nov 13, 2023
1 parent
b6bb02b
commit 3f17554
Showing
11 changed files
with
491 additions
and
108 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,273 @@ | ||
package jetstream | ||
|
||
import ( | ||
"crypto/sha256" | ||
"crypto/tls" | ||
"encoding/json" | ||
"fmt" | ||
"os" | ||
"sync" | ||
|
||
"github.com/nats-io/nats.go" | ||
"github.com/sirupsen/logrus" | ||
"golang.org/x/sync/singleflight" | ||
) | ||
|
||
type natsContext struct { | ||
Name string `json:"name"` | ||
URL string `json:"url"` | ||
JWT string `json:"jwt"` | ||
Seed string `json:"seed"` | ||
Credentials string `json:"credential"` | ||
Nkey string `json:"nkey"` | ||
Token string `json:"token"` | ||
Username string `json:"username"` | ||
Password string `json:"password"` | ||
TLSCAs []string `json:"tls_ca"` | ||
TLSCert string `json:"tls_cert"` | ||
TLSKey string `json:"tls_key"` | ||
} | ||
|
||
func (c *natsContext) copy() *natsContext { | ||
if c == nil { | ||
return nil | ||
} | ||
cp := *c | ||
return &cp | ||
} | ||
|
||
func (c *natsContext) hash() (string, error) { | ||
b, err := json.Marshal(c) | ||
if err != nil { | ||
return "", fmt.Errorf("error marshaling context to json: %v", err) | ||
} | ||
if c.Nkey != "" { | ||
fb, err := os.ReadFile(c.Nkey) | ||
if err != nil { | ||
return "", fmt.Errorf("error opening nkey file %s: %v", c.Nkey, err) | ||
} | ||
b = append(b, fb...) | ||
} | ||
if c.Credentials != "" { | ||
fb, err := os.ReadFile(c.Credentials) | ||
if err != nil { | ||
return "", fmt.Errorf("error opening creds file %s: %v", c.Credentials, err) | ||
} | ||
b = append(b, fb...) | ||
} | ||
if len(c.TLSCAs) > 0 { | ||
for _, cert := range c.TLSCAs { | ||
fb, err := os.ReadFile(cert) | ||
if err != nil { | ||
return "", fmt.Errorf("error opening ca file %s: %v", cert, err) | ||
} | ||
b = append(b, fb...) | ||
} | ||
} | ||
if c.TLSCert != "" { | ||
fb, err := os.ReadFile(c.TLSCert) | ||
if err != nil { | ||
return "", fmt.Errorf("error opening cert file %s: %v", c.TLSCert, err) | ||
} | ||
b = append(b, fb...) | ||
} | ||
if c.TLSKey != "" { | ||
fb, err := os.ReadFile(c.TLSKey) | ||
if err != nil { | ||
return "", fmt.Errorf("error opening key file %s: %v", c.TLSKey, err) | ||
} | ||
b = append(b, fb...) | ||
} | ||
hash := sha256.New() | ||
hash.Write(b) | ||
return fmt.Sprintf("%x", hash.Sum(nil)), nil | ||
} | ||
|
||
type natsContextDefaults struct { | ||
Name string | ||
URL string | ||
TLSCAs []string | ||
TLSCert string | ||
TLSKey string | ||
TLSConfig *tls.Config | ||
} | ||
|
||
type pooledNatsConn struct { | ||
nc *nats.Conn | ||
cp *natsConnPool | ||
key string | ||
count uint64 | ||
closed bool | ||
} | ||
|
||
func (pc *pooledNatsConn) ReturnToPool() { | ||
pc.cp.Lock() | ||
pc.count-- | ||
if pc.count == 0 { | ||
if pooledConn, ok := pc.cp.cache[pc.key]; ok && pc == pooledConn { | ||
delete(pc.cp.cache, pc.key) | ||
} | ||
pc.closed = true | ||
pc.cp.Unlock() | ||
pc.nc.Close() | ||
return | ||
} | ||
pc.cp.Unlock() | ||
} | ||
|
||
type natsConnPool struct { | ||
sync.Mutex | ||
cache map[string]*pooledNatsConn | ||
logger *logrus.Logger | ||
group *singleflight.Group | ||
natsDefaults *natsContextDefaults | ||
natsOpts []nats.Option | ||
} | ||
|
||
func newNatsConnPool(logger *logrus.Logger, natsDefaults *natsContextDefaults, natsOpts []nats.Option) *natsConnPool { | ||
return &natsConnPool{ | ||
cache: map[string]*pooledNatsConn{}, | ||
group: &singleflight.Group{}, | ||
logger: logger, | ||
natsDefaults: natsDefaults, | ||
natsOpts: natsOpts, | ||
} | ||
} | ||
|
||
const getPooledConnMaxTries = 10 | ||
|
||
// Get returns a *pooledNatsConn | ||
func (cp *natsConnPool) Get(cfg *natsContext) (*pooledNatsConn, error) { | ||
if cfg == nil { | ||
return nil, fmt.Errorf("nats context must not be nil") | ||
} | ||
|
||
// copy cfg | ||
cfg = cfg.copy() | ||
|
||
// set defaults | ||
if cfg.Name == "" { | ||
cfg.Name = cp.natsDefaults.Name | ||
} | ||
if cfg.URL == "" { | ||
cfg.URL = cp.natsDefaults.URL | ||
} | ||
if len(cfg.TLSCAs) == 0 { | ||
cfg.TLSCAs = cp.natsDefaults.TLSCAs | ||
} | ||
if cfg.TLSCert == "" { | ||
cfg.TLSCert = cp.natsDefaults.TLSCert | ||
} | ||
if cfg.TLSKey == "" { | ||
cfg.TLSKey = cp.natsDefaults.TLSKey | ||
} | ||
|
||
// get hash | ||
key, err := cfg.hash() | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
for i := 0; i < getPooledConnMaxTries; i++ { | ||
connection, err := cp.getPooledConn(key, cfg) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
cp.Lock() | ||
if connection.closed { | ||
// ReturnToPool closed this while lock not held, try again | ||
cp.Unlock() | ||
continue | ||
} | ||
|
||
// increment count out of the pool | ||
connection.count++ | ||
cp.Unlock() | ||
return connection, nil | ||
} | ||
|
||
return nil, fmt.Errorf("failed to get pooled connection after %d attempts", getPooledConnMaxTries) | ||
} | ||
|
||
// getPooledConn gets or establishes a *pooledNatsConn in a singleflight group, but does not increment its count | ||
func (cp *natsConnPool) getPooledConn(key string, cfg *natsContext) (*pooledNatsConn, error) { | ||
conn, err, _ := cp.group.Do(key, func() (interface{}, error) { | ||
cp.Lock() | ||
pooledConn, ok := cp.cache[key] | ||
if ok && pooledConn.nc.IsConnected() { | ||
cp.Unlock() | ||
return pooledConn, nil | ||
} | ||
cp.Unlock() | ||
|
||
opts := cp.natsOpts | ||
opts = append(opts, func(options *nats.Options) error { | ||
if cfg.Name != "" { | ||
options.Name = cfg.Name | ||
} | ||
if cfg.Token != "" { | ||
options.Token = cfg.Token | ||
} | ||
if cfg.Username != "" { | ||
options.User = cfg.Username | ||
} | ||
if cfg.Password != "" { | ||
options.Password = cfg.Password | ||
} | ||
return nil | ||
}) | ||
|
||
if cfg.JWT != "" && cfg.Seed != "" { | ||
opts = append(opts, nats.UserJWTAndSeed(cfg.JWT, cfg.Seed)) | ||
} | ||
|
||
if cfg.Nkey != "" { | ||
opt, err := nats.NkeyOptionFromSeed(cfg.Nkey) | ||
if err != nil { | ||
return nil, fmt.Errorf("unable to load nkey: %v", err) | ||
} | ||
opts = append(opts, opt) | ||
} | ||
|
||
if cfg.Credentials != "" { | ||
opts = append(opts, nats.UserCredentials(cfg.Credentials)) | ||
} | ||
|
||
if len(cfg.TLSCAs) > 0 { | ||
opts = append(opts, nats.RootCAs(cfg.TLSCAs...)) | ||
} | ||
|
||
if cfg.TLSCert != "" && cfg.TLSKey != "" { | ||
opts = append(opts, nats.ClientCert(cfg.TLSCert, cfg.TLSKey)) | ||
} | ||
|
||
nc, err := nats.Connect(cfg.URL, opts...) | ||
if err != nil { | ||
return nil, err | ||
} | ||
cp.logger.Infof("%s connected to NATS Deployment: %s", cfg.Name, nc.ConnectedAddr()) | ||
|
||
connection := &pooledNatsConn{ | ||
nc: nc, | ||
cp: cp, | ||
key: key, | ||
} | ||
|
||
cp.Lock() | ||
cp.cache[key] = connection | ||
cp.Unlock() | ||
|
||
return connection, err | ||
}) | ||
|
||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
connection, ok := conn.(*pooledNatsConn) | ||
if !ok { | ||
return nil, fmt.Errorf("not a pooledNatsConn") | ||
} | ||
return connection, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
package jetstream | ||
|
||
import ( | ||
"sync" | ||
"testing" | ||
"time" | ||
|
||
"github.com/nats-io/nats.go" | ||
|
||
natsservertest "github.com/nats-io/nats-server/v2/test" | ||
"github.com/sirupsen/logrus" | ||
testifyAssert "github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestConnPool(t *testing.T) { | ||
t.Parallel() | ||
|
||
s := natsservertest.RunRandClientPortServer() | ||
defer s.Shutdown() | ||
o1 := &natsContext{ | ||
Name: "Client 1", | ||
} | ||
o2 := &natsContext{ | ||
Name: "Client 1", | ||
} | ||
o3 := &natsContext{ | ||
Name: "Client 2", | ||
} | ||
|
||
natsDefaults := &natsContextDefaults{ | ||
URL: s.ClientURL(), | ||
} | ||
natsOptions := []nats.Option{ | ||
nats.MaxReconnects(10240), | ||
} | ||
cp := newNatsConnPool(logrus.New(), natsDefaults, natsOptions) | ||
|
||
var c1, c2, c3 *pooledNatsConn | ||
var c1e, c2e, c3e error | ||
wg := &sync.WaitGroup{} | ||
wg.Add(3) | ||
go func() { | ||
c1, c1e = cp.Get(o1) | ||
wg.Done() | ||
}() | ||
go func() { | ||
c2, c2e = cp.Get(o2) | ||
wg.Done() | ||
}() | ||
go func() { | ||
c3, c3e = cp.Get(o3) | ||
wg.Done() | ||
}() | ||
wg.Wait() | ||
|
||
assert := testifyAssert.New(t) | ||
if assert.NoError(c1e) && assert.NoError(c2e) { | ||
assert.Same(c1, c2) | ||
} | ||
if assert.NoError(c3e) { | ||
assert.NotSame(c1, c3) | ||
assert.NotSame(c2, c3) | ||
} | ||
|
||
c1.ReturnToPool() | ||
c3.ReturnToPool() | ||
time.Sleep(1 * time.Second) | ||
assert.False(c1.nc.IsClosed()) | ||
assert.False(c2.nc.IsClosed()) | ||
assert.True(c3.nc.IsClosed()) | ||
|
||
c4, c4e := cp.Get(o1) | ||
if assert.NoError(c4e) { | ||
assert.Same(c2, c4) | ||
} | ||
|
||
c2.ReturnToPool() | ||
c4.ReturnToPool() | ||
time.Sleep(1 * time.Second) | ||
assert.True(c1.nc.IsClosed()) | ||
assert.True(c2.nc.IsClosed()) | ||
assert.True(c4.nc.IsClosed()) | ||
|
||
c5, c5e := cp.Get(o1) | ||
if assert.NoError(c5e) { | ||
assert.NotSame(c1, c5) | ||
} | ||
|
||
c5.ReturnToPool() | ||
time.Sleep(1 * time.Second) | ||
assert.True(c5.nc.IsClosed()) | ||
} |
Oops, something went wrong.