diff --git a/controllers/jetstream/conn_pool.go b/controllers/jetstream/conn_pool.go new file mode 100644 index 00000000..ddcb006e --- /dev/null +++ b/controllers/jetstream/conn_pool.go @@ -0,0 +1,273 @@ +package jetstream + +import ( + "crypto/sha256" + "crypto/tls" + "encoding/json" + "fmt" + "os" + "sync" + + "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" + "golang.org/x/sync/singleflight" +) + +type natsContext struct { + Name string `json:"name"` + URL string `json:"url"` + JWT string `json:"jwt"` + Seed string `json:"seed"` + Credentials string `json:"credential"` + Nkey string `json:"nkey"` + Token string `json:"token"` + Username string `json:"username"` + Password string `json:"password"` + TLSCAs []string `json:"tls_ca"` + TLSCert string `json:"tls_cert"` + TLSKey string `json:"tls_key"` +} + +func (c *natsContext) copy() *natsContext { + if c == nil { + return nil + } + cp := *c + return &cp +} + +func (c *natsContext) hash() (string, error) { + b, err := json.Marshal(c) + if err != nil { + return "", fmt.Errorf("error marshaling context to json: %v", err) + } + if c.Nkey != "" { + fb, err := os.ReadFile(c.Nkey) + if err != nil { + return "", fmt.Errorf("error opening nkey file %s: %v", c.Nkey, err) + } + b = append(b, fb...) + } + if c.Credentials != "" { + fb, err := os.ReadFile(c.Credentials) + if err != nil { + return "", fmt.Errorf("error opening creds file %s: %v", c.Credentials, err) + } + b = append(b, fb...) + } + if len(c.TLSCAs) > 0 { + for _, cert := range c.TLSCAs { + fb, err := os.ReadFile(cert) + if err != nil { + return "", fmt.Errorf("error opening ca file %s: %v", cert, err) + } + b = append(b, fb...) + } + } + if c.TLSCert != "" { + fb, err := os.ReadFile(c.TLSCert) + if err != nil { + return "", fmt.Errorf("error opening cert file %s: %v", c.TLSCert, err) + } + b = append(b, fb...) + } + if c.TLSKey != "" { + fb, err := os.ReadFile(c.TLSKey) + if err != nil { + return "", fmt.Errorf("error opening key file %s: %v", c.TLSKey, err) + } + b = append(b, fb...) + } + hash := sha256.New() + hash.Write(b) + return fmt.Sprintf("%x", hash.Sum(nil)), nil +} + +type natsContextDefaults struct { + Name string + URL string + TLSCAs []string + TLSCert string + TLSKey string + TLSConfig *tls.Config +} + +type pooledNatsConn struct { + nc *nats.Conn + cp *natsConnPool + key string + count uint64 + closed bool +} + +func (pc *pooledNatsConn) ReturnToPool() { + pc.cp.Lock() + pc.count-- + if pc.count == 0 { + if pooledConn, ok := pc.cp.cache[pc.key]; ok && pc == pooledConn { + delete(pc.cp.cache, pc.key) + } + pc.closed = true + pc.cp.Unlock() + pc.nc.Close() + return + } + pc.cp.Unlock() +} + +type natsConnPool struct { + sync.Mutex + cache map[string]*pooledNatsConn + logger *logrus.Logger + group *singleflight.Group + natsDefaults *natsContextDefaults + natsOpts []nats.Option +} + +func newNatsConnPool(logger *logrus.Logger, natsDefaults *natsContextDefaults, natsOpts []nats.Option) *natsConnPool { + return &natsConnPool{ + cache: map[string]*pooledNatsConn{}, + group: &singleflight.Group{}, + logger: logger, + natsDefaults: natsDefaults, + natsOpts: natsOpts, + } +} + +const getPooledConnMaxTries = 10 + +// Get returns a *pooledNatsConn +func (cp *natsConnPool) Get(cfg *natsContext) (*pooledNatsConn, error) { + if cfg == nil { + return nil, fmt.Errorf("nats context must not be nil") + } + + // copy cfg + cfg = cfg.copy() + + // set defaults + if cfg.Name == "" { + cfg.Name = cp.natsDefaults.Name + } + if cfg.URL == "" { + cfg.URL = cp.natsDefaults.URL + } + if len(cfg.TLSCAs) == 0 { + cfg.TLSCAs = cp.natsDefaults.TLSCAs + } + if cfg.TLSCert == "" { + cfg.TLSCert = cp.natsDefaults.TLSCert + } + if cfg.TLSKey == "" { + cfg.TLSKey = cp.natsDefaults.TLSKey + } + + // get hash + key, err := cfg.hash() + if err != nil { + return nil, err + } + + for i := 0; i < getPooledConnMaxTries; i++ { + connection, err := cp.getPooledConn(key, cfg) + if err != nil { + return nil, err + } + + cp.Lock() + if connection.closed { + // ReturnToPool closed this while lock not held, try again + cp.Unlock() + continue + } + + // increment count out of the pool + connection.count++ + cp.Unlock() + return connection, nil + } + + return nil, fmt.Errorf("failed to get pooled connection after %d attempts", getPooledConnMaxTries) +} + +// getPooledConn gets or establishes a *pooledNatsConn in a singleflight group, but does not increment its count +func (cp *natsConnPool) getPooledConn(key string, cfg *natsContext) (*pooledNatsConn, error) { + conn, err, _ := cp.group.Do(key, func() (interface{}, error) { + cp.Lock() + pooledConn, ok := cp.cache[key] + if ok && pooledConn.nc.IsConnected() { + cp.Unlock() + return pooledConn, nil + } + cp.Unlock() + + opts := cp.natsOpts + opts = append(opts, func(options *nats.Options) error { + if cfg.Name != "" { + options.Name = cfg.Name + } + if cfg.Token != "" { + options.Token = cfg.Token + } + if cfg.Username != "" { + options.User = cfg.Username + } + if cfg.Password != "" { + options.Password = cfg.Password + } + return nil + }) + + if cfg.JWT != "" && cfg.Seed != "" { + opts = append(opts, nats.UserJWTAndSeed(cfg.JWT, cfg.Seed)) + } + + if cfg.Nkey != "" { + opt, err := nats.NkeyOptionFromSeed(cfg.Nkey) + if err != nil { + return nil, fmt.Errorf("unable to load nkey: %v", err) + } + opts = append(opts, opt) + } + + if cfg.Credentials != "" { + opts = append(opts, nats.UserCredentials(cfg.Credentials)) + } + + if len(cfg.TLSCAs) > 0 { + opts = append(opts, nats.RootCAs(cfg.TLSCAs...)) + } + + if cfg.TLSCert != "" && cfg.TLSKey != "" { + opts = append(opts, nats.ClientCert(cfg.TLSCert, cfg.TLSKey)) + } + + nc, err := nats.Connect(cfg.URL, opts...) + if err != nil { + return nil, err + } + cp.logger.Infof("%s connected to NATS Deployment: %s", cfg.Name, nc.ConnectedAddr()) + + connection := &pooledNatsConn{ + nc: nc, + cp: cp, + key: key, + } + + cp.Lock() + cp.cache[key] = connection + cp.Unlock() + + return connection, err + }) + + if err != nil { + return nil, err + } + + connection, ok := conn.(*pooledNatsConn) + if !ok { + return nil, fmt.Errorf("not a pooledNatsConn") + } + return connection, nil +} diff --git a/controllers/jetstream/conn_pool_test.go b/controllers/jetstream/conn_pool_test.go new file mode 100644 index 00000000..02ec8880 --- /dev/null +++ b/controllers/jetstream/conn_pool_test.go @@ -0,0 +1,92 @@ +package jetstream + +import ( + "sync" + "testing" + "time" + + "github.com/nats-io/nats.go" + + natsservertest "github.com/nats-io/nats-server/v2/test" + "github.com/sirupsen/logrus" + testifyAssert "github.com/stretchr/testify/assert" +) + +func TestConnPool(t *testing.T) { + t.Parallel() + + s := natsservertest.RunRandClientPortServer() + defer s.Shutdown() + o1 := &natsContext{ + Name: "Client 1", + } + o2 := &natsContext{ + Name: "Client 1", + } + o3 := &natsContext{ + Name: "Client 2", + } + + natsDefaults := &natsContextDefaults{ + URL: s.ClientURL(), + } + natsOptions := []nats.Option{ + nats.MaxReconnects(10240), + } + cp := newNatsConnPool(logrus.New(), natsDefaults, natsOptions) + + var c1, c2, c3 *pooledNatsConn + var c1e, c2e, c3e error + wg := &sync.WaitGroup{} + wg.Add(3) + go func() { + c1, c1e = cp.Get(o1) + wg.Done() + }() + go func() { + c2, c2e = cp.Get(o2) + wg.Done() + }() + go func() { + c3, c3e = cp.Get(o3) + wg.Done() + }() + wg.Wait() + + assert := testifyAssert.New(t) + if assert.NoError(c1e) && assert.NoError(c2e) { + assert.Same(c1, c2) + } + if assert.NoError(c3e) { + assert.NotSame(c1, c3) + assert.NotSame(c2, c3) + } + + c1.ReturnToPool() + c3.ReturnToPool() + time.Sleep(1 * time.Second) + assert.False(c1.nc.IsClosed()) + assert.False(c2.nc.IsClosed()) + assert.True(c3.nc.IsClosed()) + + c4, c4e := cp.Get(o1) + if assert.NoError(c4e) { + assert.Same(c2, c4) + } + + c2.ReturnToPool() + c4.ReturnToPool() + time.Sleep(1 * time.Second) + assert.True(c1.nc.IsClosed()) + assert.True(c2.nc.IsClosed()) + assert.True(c4.nc.IsClosed()) + + c5, c5e := cp.Get(o1) + if assert.NoError(c5e) { + assert.NotSame(c1, c5) + } + + c5.ReturnToPool() + time.Sleep(1 * time.Second) + assert.True(c5.nc.IsClosed()) +} diff --git a/controllers/jetstream/consumer.go b/controllers/jetstream/consumer.go index c26393d2..16be2e5b 100644 --- a/controllers/jetstream/consumer.go +++ b/controllers/jetstream/consumer.go @@ -14,8 +14,6 @@ import ( jsmapi "github.com/nats-io/jsm.go/api" apis "github.com/nats-io/nack/pkg/jetstream/apis/jetstream/v1beta2" typed "github.com/nats-io/nack/pkg/jetstream/generated/clientset/versioned/typed/jetstream/v1beta2" - "github.com/nats-io/nats.go" - k8sapi "k8s.io/api/core/v1" k8serrors "k8s.io/apimachinery/pkg/api/errors" k8smeta "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -25,11 +23,11 @@ import ( func (c *Controller) runConsumerQueue() { for { - processQueueNext(c.cnsQueue, &realJsmClient{jm: c.jm}, c.processConsumer) + processQueueNext(c.cnsQueue, c.RealJSMC, c.processConsumer) } } -func (c *Controller) processConsumer(ns, name string, jsmc jsmClient) (err error) { +func (c *Controller) processConsumer(ns, name string, jsmClient jsmClientFunc) (err error) { cns, err := c.cnsLister.Consumers(ns).Get(name) if err != nil && k8serrors.IsNotFound(err) { return nil @@ -37,10 +35,10 @@ func (c *Controller) processConsumer(ns, name string, jsmc jsmClient) (err error return err } - return c.processConsumerObject(cns, jsmc) + return c.processConsumerObject(cns, jsmClient) } -func (c *Controller) processConsumerObject(cns *apis.Consumer, jsmc jsmClient) (err error) { +func (c *Controller) processConsumerObject(cns *apis.Consumer, jsm jsmClientFunc) (err error) { defer func() { if err != nil { err = fmt.Errorf("failed to process consumer: %w", err) @@ -133,56 +131,51 @@ func (c *Controller) processConsumerObject(cns *apis.Consumer, jsmc jsmClient) ( servers := spec.Servers if c.opts.CRDConnect { // Create a new client - opts := make([]nats.Option, 0) - opts = append(opts, nats.Name(fmt.Sprintf("%s-con-%s-%d", c.opts.NATSClientName, spec.DurableName, cns.Generation))) + natsCtx := &natsContext{} + natsCtx.Name = fmt.Sprintf("%s-con-%s-%d", c.opts.NATSClientName, spec.DurableName, cns.Generation) // Use JWT/NKEYS based credentials if present. if spec.Creds != "" { - opts = append(opts, nats.UserCredentials(spec.Creds)) + natsCtx.Credentials = spec.Creds } else if spec.Nkey != "" { - opt, err := nats.NkeyOptionFromSeed(spec.Nkey) - if err != nil { - return err - } - opts = append(opts, opt) + natsCtx.Nkey = spec.Nkey } if spec.TLS.ClientCert != "" && spec.TLS.ClientKey != "" { - opts = append(opts, nats.ClientCert(spec.TLS.ClientCert, spec.TLS.ClientKey)) + natsCtx.TLSCert = spec.TLS.ClientCert + natsCtx.TLSKey = spec.TLS.ClientKey } // Use fetched secrets for the account and server if defined. if remoteClientCert != "" && remoteClientKey != "" { - opts = append(opts, nats.ClientCert(remoteClientCert, remoteClientKey)) + natsCtx.TLSCert = remoteClientCert + natsCtx.TLSKey = remoteClientKey } if remoteRootCA != "" { - opts = append(opts, nats.RootCAs(remoteRootCA)) + natsCtx.TLSCAs = []string{remoteRootCA} } if accUserCreds != "" { - opts = append(opts, nats.UserCredentials(accUserCreds)) + natsCtx.Credentials = accUserCreds } if len(spec.TLS.RootCAs) > 0 { - opts = append(opts, nats.RootCAs(spec.TLS.RootCAs...)) + natsCtx.TLSCAs = spec.TLS.RootCAs } - opts = append(opts, nats.MaxReconnects(-1)) - natsServers := strings.Join(append(servers, accServers...), ",") - newNc, err := nats.Connect(natsServers, opts...) - if err != nil { - return fmt.Errorf("failed to connect to leaf nats(%s): %w", natsServers, err) - } - + natsCtx.URL = natsServers c.normalEvent(cns, "Connecting", "Connecting to new nats-servers") - newJm, err := jsm.New(newNc) + jsmc, err := jsm(natsCtx) if err != nil { return err } - newJsmc := &realJsmClient{nc: newNc, jm: newJm} + defer jsmc.Close() - if err := op(c.ctx, newJsmc, spec); err != nil { + if err := op(c.ctx, jsmc, spec); err != nil { return err } - newJsmc.Close() } else { + jsmc, err := jsm(&natsContext{}) + if err != nil { + return err + } if err := op(c.ctx, jsmc, spec); err != nil { return err } diff --git a/controllers/jetstream/consumer_test.go b/controllers/jetstream/consumer_test.go index b0fcbec1..b6ca9874 100644 --- a/controllers/jetstream/consumer_test.go +++ b/controllers/jetstream/consumer_test.go @@ -81,7 +81,9 @@ func TestProcessConsumer(t *testing.T) { newConsumerErr: nil, newConsumer: &mockConsumer{}, } - if err := ctrl.processConsumer(ns, name, jsmc); err != nil { + if err := ctrl.processConsumer(ns, name, func(n *natsContext) (jsmClient, error) { + return jsmc, nil + }); err != nil { t.Fatal(err) } @@ -138,7 +140,7 @@ func TestProcessConsumer(t *testing.T) { newConsumerErr: nil, newConsumer: &mockConsumer{}, } - if err := ctrl.processConsumer(ns, name, jsmc); err == nil || !strings.Contains(err.Error(), `failed to create consumer "my-consumer" on stream `) { + if err := ctrl.processConsumer(ns, name, testWrapJSMC(jsmc)); err == nil || !strings.Contains(err.Error(), `failed to create consumer "my-consumer" on stream `) { t.Fatal(err) } @@ -193,7 +195,7 @@ func TestProcessConsumer(t *testing.T) { loadConsumerErr: nil, loadConsumer: &mockConsumer{}, } - if err := ctrl.processConsumer(ns, name, jsmc); err != nil { + if err := ctrl.processConsumer(ns, name, testWrapJSMC(jsmc)); err != nil { t.Fatal(err) } @@ -248,7 +250,7 @@ func TestProcessConsumer(t *testing.T) { loadConsumerErr: nil, loadConsumer: &mockConsumer{}, } - if err := ctrl.processConsumer(ns, name, jsmc); err != nil { + if err := ctrl.processConsumer(ns, name, testWrapJSMC(jsmc)); err != nil { t.Fatal(err) } @@ -322,7 +324,7 @@ func TestProcessConsumer(t *testing.T) { jsmc := &mockJsmClient{ loadConsumerErr: errors.New("failed to load consumer"), } - if err := ctrl.processConsumer(ns, name, jsmc); err == nil { + if err := ctrl.processConsumer(ns, name, testWrapJSMC(jsmc)); err == nil { t.Fatal("unexpected success") } }) @@ -432,3 +434,9 @@ func TestConsumerSpecToOpts(t *testing.T) { }) } } + +func testWrapJSMC(jsm jsmClient) jsmClientFunc { + return func(n *natsContext) (jsmClient, error) { + return jsm, nil + } +} diff --git a/controllers/jetstream/controller.go b/controllers/jetstream/controller.go index 8895584e..d2c1aedb 100644 --- a/controllers/jetstream/controller.go +++ b/controllers/jetstream/controller.go @@ -23,6 +23,7 @@ import ( "github.com/nats-io/jsm.go" jsmapi "github.com/nats-io/jsm.go/api" "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" apis "github.com/nats-io/nack/pkg/jetstream/apis/jetstream/v1beta2" clientset "github.com/nats-io/nack/pkg/jetstream/generated/clientset/versioned" @@ -83,10 +84,9 @@ type Options struct { } type Controller struct { - ctx context.Context - opts Options - nc *nats.Conn - jm *jsm.Manager + ctx context.Context + opts Options + connPool *natsConnPool ki k8styped.CoreV1Interface ji typed.JetstreamV1beta2Interface @@ -175,12 +175,13 @@ func NewController(opt Options) *Controller { } func (c *Controller) Run() error { - if !c.opts.CRDConnect { - // Connect to NATS. - opts := make([]nats.Option, 0) - - opts = append(opts, nats.Name(c.opts.NATSClientName)) + // Connect to NATS. + opts := make([]nats.Option, 0) + opts = append(opts, nats.Name(c.opts.NATSClientName)) + // Always attempt to have a connection to NATS. + opts = append(opts, nats.MaxReconnects(-1)) + if !c.opts.CRDConnect { // Use JWT/NKEYS based credentials if present. if c.opts.NATSCredentials != "" { opts = append(opts, nats.UserCredentials(c.opts.NATSCredentials)) @@ -199,20 +200,15 @@ func (c *Controller) Run() error { if c.opts.NATSCA != "" { opts = append(opts, nats.RootCAs(c.opts.NATSCA)) } - - // Always attempt to have a connection to NATS. - opts = append(opts, nats.MaxReconnects(-1)) - - nc, err := nats.Connect(c.opts.NATSServerURL, opts...) + ncp := newNatsConnPool(logrus.New(), &natsContextDefaults{URL: c.opts.NATSServerURL}, opts) + pooledNc, err := ncp.Get(&natsContext{}) if err != nil { return fmt.Errorf("failed to connect to nats: %w", err) } - c.nc = nc - jm, err := jsm.New(c.nc) - if err != nil { - return err - } - c.jm = jm + pooledNc.ReturnToPool() + c.connPool = ncp + } else { + c.connPool = newNatsConnPool(logrus.New(), &natsContextDefaults{Name: c.opts.NATSClientName}, opts) } defer utilruntime.HandleCrash() @@ -240,6 +236,25 @@ func (c *Controller) Run() error { return nil } +// RealJSMC creates a new JSM client from pooled nats connections +// Providing a blank string for servers, defaults to c.opts.NATSServerUrls +// call deferred jsmC.Close() on returned instance to return the nats connection to pool +func (c *Controller) RealJSMC(cfg *natsContext) (jsmClient, error) { + if cfg == nil { + cfg = &natsContext{} + } + pooledNc, err := c.connPool.Get(cfg) + if err != nil { + return nil, err + } + jm, err := jsm.New(pooledNc.nc) + if err != nil { + return nil, err + } + jsmc := &realJsmClient{pooledNc: pooledNc, jm: jm} + return jsmc, nil +} + func selectMissingStreamsFromList(prev, cur map[string]*apis.Stream) []*apis.Stream { var deleted []*apis.Stream for name, ps := range prev { @@ -293,7 +308,7 @@ func (c *Controller) cleanupStreams() error { klog.Infof("stream %s/%s was not found anymore, deleting from JetStream", s.Namespace, s.Name) t := k8smeta.NewTime(time.Now()) s.DeletionTimestamp = &t - if err := c.processStreamObject(s, &realJsmClient{jm: c.jm}); err != nil && !k8serrors.IsNotFound(err) { + if err := c.processStreamObject(s, c.RealJSMC); err != nil && !k8serrors.IsNotFound(err) { klog.Infof("failed to delete stream %s/%s: %s", s.Namespace, s.Name, err) continue } @@ -363,7 +378,7 @@ func (c *Controller) cleanupConsumers() error { klog.Infof("consumer %s/%s was not found anymore, deleting from JetStream", cns.Namespace, cns.Name) t := k8smeta.NewTime(time.Now()) cns.DeletionTimestamp = &t - if err := c.processConsumerObject(cns, &realJsmClient{jm: c.jm}); err != nil && !k8serrors.IsNotFound(err) { + if err := c.processConsumerObject(cns, c.RealJSMC); err != nil && !k8serrors.IsNotFound(err) { klog.Infof("failed to delete consumer %s/%s: %s", cns.Namespace, cns.Name, err) continue } @@ -433,9 +448,10 @@ func enqueueWork(q workqueue.RateLimitingInterface, item interface{}) (err error return nil } -type processorFunc func(ns, name string, c jsmClient) error +type jsmClientFunc func(*natsContext) (jsmClient, error) +type processorFunc func(ns, name string, jmsClient jsmClientFunc) error -func processQueueNext(q workqueue.RateLimitingInterface, c jsmClient, process processorFunc) { +func processQueueNext(q workqueue.RateLimitingInterface, jmsClient jsmClientFunc, process processorFunc) { item, shutdown := q.Get() if shutdown { return @@ -450,7 +466,7 @@ func processQueueNext(q workqueue.RateLimitingInterface, c jsmClient, process pr return } - err = process(ns, name, c) + err = process(ns, name, jmsClient) if err == nil { // Item processed successfully, don't requeue. q.Forget(item) diff --git a/controllers/jetstream/controller_test.go b/controllers/jetstream/controller_test.go index 7f130ae9..94f831bc 100644 --- a/controllers/jetstream/controller_test.go +++ b/controllers/jetstream/controller_test.go @@ -103,7 +103,7 @@ func TestProcessQueueNext(t *testing.T) { key := "this/is/a/bad/key" q.Add(key) - processQueueNext(q, &mockJsmClient{}, func(ns, name string, c jsmClient) error { + processQueueNext(q, testWrapJSMC(&mockJsmClient{}), func(ns, name string, c jsmClientFunc) error { return nil }) @@ -136,7 +136,7 @@ func TestProcessQueueNext(t *testing.T) { numRequeues = q.NumRequeues(key) } - processQueueNext(q, &mockJsmClient{}, func(ns, name string, c jsmClient) error { + processQueueNext(q, testWrapJSMC(&mockJsmClient{}), func(ns, name string, c jsmClientFunc) error { return fmt.Errorf("processing error") }) } @@ -164,7 +164,7 @@ func TestProcessQueueNext(t *testing.T) { q.Add(key) numRequeues := q.NumRequeues(key) - processQueueNext(q, &mockJsmClient{}, func(ns, name string, c jsmClient) error { + processQueueNext(q, testWrapJSMC(&mockJsmClient{}), func(ns, name string, c jsmClientFunc) error { return nil }) diff --git a/controllers/jetstream/jsmclient.go b/controllers/jetstream/jsmclient.go index 2f9cde03..4607a8f8 100644 --- a/controllers/jetstream/jsmclient.go +++ b/controllers/jetstream/jsmclient.go @@ -6,6 +6,7 @@ import ( "github.com/nats-io/jsm.go" jsmapi "github.com/nats-io/jsm.go/api" "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" ) type jsmClient interface { @@ -29,23 +30,20 @@ type jsmConsumer interface { Delete() error } -type jsmDeleter interface { - Delete() error -} - type realJsmClient struct { - nc *nats.Conn - jm *jsm.Manager + pooledNc *pooledNatsConn + jm *jsm.Manager } func (c *realJsmClient) Connect(servers string, opts ...nats.Option) error { - nc, err := nats.Connect(servers, opts...) + connPool := newNatsConnPool(logrus.New(), &natsContextDefaults{URL: servers}, opts) + pooledNc, err := connPool.Get(&natsContext{}) if err != nil { return err } - c.nc = nc + c.pooledNc = pooledNc - m, err := jsm.New(nc) + m, err := jsm.New(pooledNc.nc) if err != nil { return err } @@ -55,7 +53,7 @@ func (c *realJsmClient) Connect(servers string, opts ...nats.Option) error { } func (c *realJsmClient) Close() { - _ = c.nc.Drain() + c.pooledNc.ReturnToPool() } func (c *realJsmClient) LoadStream(_ context.Context, name string) (jsmStream, error) { diff --git a/controllers/jetstream/stream.go b/controllers/jetstream/stream.go index 4a757060..ad0a530d 100644 --- a/controllers/jetstream/stream.go +++ b/controllers/jetstream/stream.go @@ -27,8 +27,6 @@ import ( jsmapi "github.com/nats-io/jsm.go/api" apis "github.com/nats-io/nack/pkg/jetstream/apis/jetstream/v1beta2" typed "github.com/nats-io/nack/pkg/jetstream/generated/clientset/versioned/typed/jetstream/v1beta2" - "github.com/nats-io/nats.go" - k8sapi "k8s.io/api/core/v1" k8serrors "k8s.io/apimachinery/pkg/api/errors" k8smeta "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -38,11 +36,11 @@ import ( func (c *Controller) runStreamQueue() { for { - processQueueNext(c.strQueue, &realJsmClient{jm: c.jm}, c.processStream) + processQueueNext(c.strQueue, c.RealJSMC, c.processStream) } } -func (c *Controller) processStream(ns, name string, jsmc jsmClient) (err error) { +func (c *Controller) processStream(ns, name string, jsm jsmClientFunc) (err error) { str, err := c.strLister.Streams(ns).Get(name) if err != nil && k8serrors.IsNotFound(err) { return nil @@ -50,10 +48,10 @@ func (c *Controller) processStream(ns, name string, jsmc jsmClient) (err error) return err } - return c.processStreamObject(str, jsmc) + return c.processStreamObject(str, jsm) } -func (c *Controller) processStreamObject(str *apis.Stream, jsmc jsmClient) (err error) { +func (c *Controller) processStreamObject(str *apis.Stream, jsm jsmClientFunc) (err error) { defer func() { if err != nil { err = fmt.Errorf("failed to process stream: %w", err) @@ -148,56 +146,50 @@ func (c *Controller) processStreamObject(str *apis.Stream, jsmc jsmClient) (err servers := spec.Servers if c.opts.CRDConnect { // Create a new client - opts := make([]nats.Option, 0) - opts = append(opts, nats.Name(fmt.Sprintf("%s-str-%s-%d", c.opts.NATSClientName, spec.Name, str.Generation))) + natsCtx := &natsContext{} + natsCtx.Name = fmt.Sprintf("%s-str-%s-%d", c.opts.NATSClientName, spec.Name, str.Generation) // Use JWT/NKEYS based credentials if present. if spec.Creds != "" { - opts = append(opts, nats.UserCredentials(spec.Creds)) + natsCtx.Credentials = spec.Creds } else if spec.Nkey != "" { - opt, err := nats.NkeyOptionFromSeed(spec.Nkey) - if err != nil { - return err - } - opts = append(opts, opt) + natsCtx.Nkey = spec.Nkey } if spec.TLS.ClientCert != "" && spec.TLS.ClientKey != "" { - opts = append(opts, nats.ClientCert(spec.TLS.ClientCert, spec.TLS.ClientKey)) + natsCtx.TLSCert = spec.TLS.ClientCert + natsCtx.TLSKey = spec.TLS.ClientKey } // Use fetched secrets for the account and server if defined. if remoteClientCert != "" && remoteClientKey != "" { - opts = append(opts, nats.ClientCert(remoteClientCert, remoteClientKey)) + natsCtx.TLSCert = remoteClientCert + natsCtx.TLSKey = remoteClientKey } if remoteRootCA != "" { - opts = append(opts, nats.RootCAs(remoteRootCA)) + natsCtx.TLSCAs = []string{remoteRootCA} } if accUserCreds != "" { - opts = append(opts, nats.UserCredentials(accUserCreds)) + natsCtx.Credentials = accUserCreds } if len(spec.TLS.RootCAs) > 0 { - opts = append(opts, nats.RootCAs(spec.TLS.RootCAs...)) + natsCtx.TLSCAs = spec.TLS.RootCAs } - opts = append(opts, nats.MaxReconnects(-1)) - natsServers := strings.Join(append(servers, accServers...), ",") - newNc, err := nats.Connect(natsServers, opts...) + natsCtx.URL = natsServers + c.normalEvent(str, "Connecting", "Connecting to new nats-servers") + jsmc, err := jsm(natsCtx) if err != nil { return fmt.Errorf("failed to connect to nats-servers(%s): %w", natsServers, err) } - - c.normalEvent(str, "Connecting", "Connecting to new nats-servers") - newJm, err := jsm.New(newNc) - if err != nil { + defer jsmc.Close() + if err := op(c.ctx, jsmc, spec); err != nil { return err } - newJsmc := &realJsmClient{nc: newNc, jm: newJm} - - if err := op(c.ctx, newJsmc, spec); err != nil { + } else { + jsmc, err := jsm(&natsContext{}) + if err != nil { return err } - newJsmc.Close() - } else { if err := op(c.ctx, jsmc, spec); err != nil { return err } diff --git a/controllers/jetstream/stream_test.go b/controllers/jetstream/stream_test.go index 794ee620..23d5e4b6 100644 --- a/controllers/jetstream/stream_test.go +++ b/controllers/jetstream/stream_test.go @@ -68,7 +68,7 @@ func TestProcessStream(t *testing.T) { jsmc := &mockJsmClient{ loadStreamErr: notFoundErr, } - if err := ctrl.processStream(ns, name, jsmc); err != nil { + if err := ctrl.processStream(ns, name, testWrapJSMC(jsmc)); err != nil { t.Fatal(err) } @@ -129,7 +129,7 @@ func TestProcessStream(t *testing.T) { loadStreamErr: nil, loadStream: &mockStream{}, } - if err := ctrl.processStream(ns, name, jsmc); err != nil { + if err := ctrl.processStream(ns, name, testWrapJSMC(jsmc)); err != nil { t.Fatal(err) } @@ -190,7 +190,7 @@ func TestProcessStream(t *testing.T) { loadStreamErr: nil, loadStream: &mockStream{}, } - if err := ctrl.processStream(ns, name, jsmc); err != nil { + if err := ctrl.processStream(ns, name, testWrapJSMC(jsmc)); err != nil { t.Fatal(err) } @@ -268,7 +268,7 @@ func TestProcessStream(t *testing.T) { jsmc := &mockJsmClient{ loadStreamErr: errors.New("failed to load stream"), } - if err := ctrl.processStream(ns, name, jsmc); err == nil { + if err := ctrl.processStream(ns, name, testWrapJSMC(jsmc)); err == nil { t.Fatal("unexpected success") } }) diff --git a/go.mod b/go.mod index 88a8ead8..d0921bd3 100644 --- a/go.mod +++ b/go.mod @@ -5,9 +5,11 @@ go 1.20 require ( github.com/fsnotify/fsnotify v1.6.0 github.com/nats-io/jsm.go v0.1.0 + github.com/nats-io/nats-server/v2 v2.10.0 github.com/nats-io/nats.go v1.30.0 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.8.4 + golang.org/x/sync v0.1.0 k8s.io/api v0.28.2 k8s.io/apimachinery v0.28.2 k8s.io/client-go v0.28.2 @@ -38,9 +40,11 @@ require ( github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.17.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect + github.com/minio/highwayhash v1.0.2 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/nats-io/jwt/v2 v2.5.2 // indirect github.com/nats-io/nkeys v0.4.5 // indirect github.com/nats-io/nuid v1.0.1 // indirect github.com/pkg/errors v0.9.1 // indirect diff --git a/go.sum b/go.sum index 81293b09..bc071373 100644 --- a/go.sum +++ b/go.sum @@ -64,6 +64,7 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g= +github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -74,7 +75,9 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8m github.com/nats-io/jsm.go v0.1.0 h1:H2gYCee/iyBDjUftPOr5fEPWAcG/+fyVl89IWiy6AC4= github.com/nats-io/jsm.go v0.1.0/go.mod h1:snnYORje42cEDCX5QygzeoVA2KiWVbiIJbLfGIvXW08= github.com/nats-io/jwt/v2 v2.5.2 h1:DhGH+nKt+wIkDxM6qnVSKjokq5t59AZV5HRcFW0zJwU= +github.com/nats-io/jwt/v2 v2.5.2/go.mod h1:24BeQtRwxRV8ruvC4CojXlx/WQ/VjuwlYiH+vu/+ibI= github.com/nats-io/nats-server/v2 v2.10.0 h1:rcU++Hzo+wARxtJugrV3J5z5iGdHeVG8tT8Chb3bKDg= +github.com/nats-io/nats-server/v2 v2.10.0/go.mod h1:3PMvMSu2cuK0J9YInRLWdFpFsswKKGUS77zVSAudRto= github.com/nats-io/nats.go v1.30.0 h1:bj/rVsRCrFXxmm9mJiDhb74UKl2HhKpDwKRBtvCjZjc= github.com/nats-io/nats.go v1.30.0/go.mod h1:dcfhUgmQNN4GJEfIb2f9R7Fow+gzBF4emzDHrVBd5qM= github.com/nats-io/nkeys v0.4.5 h1:Zdz2BUlFm4fJlierwvGK+yl20IAKUm7eV6AAZXEhkPk= @@ -109,6 +112,7 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= @@ -125,6 +129,9 @@ golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190130150945-aca44879d564/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=