Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add global connection pool for mysql #6327

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 105 additions & 7 deletions pkg/scalers/mysql_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net"
"strings"
"time"

"github.com/go-logr/logr"
"github.com/go-sql-driver/mysql"
Expand All @@ -16,6 +17,16 @@ import (
kedautil "github.com/kedacore/keda/v2/pkg/util"
)

var (
// A map that holds MySQL connection pools, keyed by connection string
connectionPools *kedautil.RefMap[string, *sql.DB]
)

func init() {
// Initialize the global connectionPools map
connectionPools = kedautil.NewRefMap[string, *sql.DB]()
}

type mySQLScaler struct {
metricType v2.MetricTargetType
metadata *mySQLMetadata
Expand All @@ -34,6 +45,12 @@ type mySQLMetadata struct {
QueryValue float64 `keda:"name=queryValue, order=triggerMetadata"`
ActivationQueryValue float64 `keda:"name=activationQueryValue, order=triggerMetadata, default=0"`
MetricName string `keda:"name=metricName, order=triggerMetadata, optional"`

// Connection pool settings
UseGlobalConnPools bool `keda:"name=useGlobalConnPools, order=triggerMetadata, optional"`
MaxOpenConns int `keda:"name=maxOpenConns, order=triggerMetadata, optional"`
MaxIdleConns int `keda:"name=maxIdleConns, order=triggerMetadata, optional"`
ConnMaxIdleTime int `keda:"name=connMaxIdleTime, order=triggerMetadata, optional"` // seconds
}

// NewMySQLScaler creates a new MySQL scaler
Expand All @@ -50,10 +67,19 @@ func NewMySQLScaler(config *scalersconfig.ScalerConfig) (Scaler, error) {
return nil, fmt.Errorf("error parsing MySQL metadata: %w", err)
}

conn, err := newMySQLConnection(meta, logger)
// Create MySQL connection, if useGlobalConnPools is set to true, it will use
// the global connection pool for the given connection string, otherwise it
// will create a new local connection pool for the given connection string
var conn *sql.DB
if meta.UseGlobalConnPools {
conn, err = getConnectionPool(meta, logger)
} else {
conn, err = newMySQLConnection(meta, logger)
}
if err != nil {
return nil, fmt.Errorf("error establishing MySQL connection: %w", err)
return nil, fmt.Errorf("error creating MySQL connection: %w", err)
}

return &mySQLScaler{
metricType: metricType,
metadata: meta,
Expand Down Expand Up @@ -96,6 +122,40 @@ func metadataToConnectionStr(meta *mySQLMetadata) string {
return connStr
}

// getConnectionPool will check if the connection pool has already been
// created for the given connection string and return it. If it has not
// been created, it will create a new connection pool and store it in the
// connectionPools map.
func getConnectionPool(meta *mySQLMetadata, logger logr.Logger) (*sql.DB, error) {
connStr := metadataToConnectionStr(meta)
// Try to load an existing pool and increment its reference count if found
if pool, ok := connectionPools.Load(connStr); ok {
err := connectionPools.AddRef(connStr)
if err != nil {
logger.Error(err, "Error increasing connection pool reference count")
return nil, err
}

return pool, nil
}

// If pool does not exist, create a new one and store it in RefMap
newPool, err := newMySQLConnection(meta, logger)
if err != nil {
return nil, err
}
err = connectionPools.Store(connStr, newPool, func(db *sql.DB) error {
logger.Info("Closing MySQL connection pool", "connectionString", connStr)
return db.Close()
})
if err != nil {
logger.Error(err, "Error storing connection pool in RefMap")
return nil, err
}

return newPool, nil
}

// newMySQLConnection creates MySQL db connection
func newMySQLConnection(meta *mySQLMetadata, logger logr.Logger) (*sql.DB, error) {
connStr := metadataToConnectionStr(meta)
Expand All @@ -104,14 +164,35 @@ func newMySQLConnection(meta *mySQLMetadata, logger logr.Logger) (*sql.DB, error
logger.Error(err, fmt.Sprintf("Found error when opening connection: %s", err))
return nil, err
}

err = db.Ping()
if err != nil {
logger.Error(err, fmt.Sprintf("Found error when pinging database: %s", err))
return nil, err
}

setConnectionPoolConfiguration(meta, db)

return db, nil
}

// setConnectionPoolConfiguration configures the MySQL connection pool settings
// based on the parameters provided in mySQLMetadata. If a setting is zero, it
// is left at its default value.
func setConnectionPoolConfiguration(meta *mySQLMetadata, db *sql.DB) {
Copy link
Contributor Author

@robpickerill robpickerill Nov 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: Reconcile these values when scalers share pools but present different connection pool settings

  • select max/min/etc
  • OR create new connection and add to pool

if meta.MaxOpenConns > 0 {
db.SetMaxOpenConns(meta.MaxOpenConns)
}

if meta.MaxIdleConns > 0 {
db.SetMaxIdleConns(meta.MaxIdleConns)
}

if meta.ConnMaxIdleTime > 0 {
db.SetConnMaxIdleTime(time.Duration(meta.ConnMaxIdleTime) * time.Second)
}
}

// parseMySQLDbNameFromConnectionStr returns dbname from connection string
// in it is not able to parse it, it returns "dbname" string
func parseMySQLDbNameFromConnectionStr(connectionString string) string {
Expand All @@ -123,13 +204,30 @@ func parseMySQLDbNameFromConnectionStr(connectionString string) string {
return "dbname"
}

// Close disposes of MySQL connections
func (s *mySQLScaler) Close(context.Context) error {
err := s.connection.Close()
if err != nil {
s.logger.Error(err, "Error closing MySQL connection")
// Close disposes of MySQL connections, closing either the global pool if used
// or the local connection pool
func (s *mySQLScaler) Close(ctx context.Context) error {
if s.metadata.UseGlobalConnPools {
if err := s.closeGlobalPool(ctx); err != nil {
return fmt.Errorf("error closing MySQL connection: %w", err)
}
} else {
if err := s.connection.Close(); err != nil {
return fmt.Errorf("error closing MySQL connection: %w", err)
}
}

return nil
}

// closeGlobalPool closes all MySQL connections in the global pool
func (s *mySQLScaler) closeGlobalPool(_ context.Context) error {
connStr := metadataToConnectionStr(s.metadata)
if err := connectionPools.RemoveRef(connStr); err != nil {
s.logger.Error(err, "Error decreasing connection pool reference count")
return err
}

return nil
}

Expand Down
21 changes: 21 additions & 0 deletions pkg/scalers/mysql_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,27 @@ var testMySQLMetadata = []parseMySQLMetadataTestData{
resolvedEnv: map[string]string{},
raisesError: true,
},
// use global pool
{
metadata: map[string]string{"query": "query", "queryValue": "12", "useGlobalConnPools": "true"},
authParams: map[string]string{"host": "test_host", "port": "test_port", "username": "test_username", "password": "MYSQL_PASSWORD", "dbName": "test_dbname"},
resolvedEnv: testMySQLResolvedEnv,
raisesError: false,
},
// use connection pool settings
{
metadata: map[string]string{"query": "query", "queryValue": "12", "maxOpenConns": "10", "maxIdleConns": "5", "connMaxIdleTime": "10"},
authParams: map[string]string{"host": "test_host", "port": "test_port", "username": "test_username", "password": "MYSQL_PASSWORD", "dbName": "test_dbname"},
resolvedEnv: testMySQLResolvedEnv,
raisesError: false,
},
// use connection pool settings and global pool
{
metadata: map[string]string{"query": "query", "queryValue": "12", "maxOpenConns": "10", "maxIdleConns": "5", "connMaxIdleTime": "10", "useGlobalConnPools": "true"},
authParams: map[string]string{"host": "test_host", "port": "test_port", "username": "test_username", "password": "MYSQL_PASSWORD", "dbName": "test_dbname"},
resolvedEnv: testMySQLResolvedEnv,
raisesError: false,
},
}

var mySQLMetricIdentifiers = []mySQLMetricIdentifier{
Expand Down
123 changes: 123 additions & 0 deletions pkg/util/refmap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package util

//nolint:depguard // sync/atomic
import (
"fmt"
"sync"
"sync/atomic"
)

// refCountedValue manages a reference-counted value with a cleanup function.
type refCountedValue[V any] struct {
value V
refCount atomic.Int64
closeFunc func(V) error // Cleanup function to call when count reaches zero
}

// Add increments the reference count.
func (r *refCountedValue[V]) Add() {
r.refCount.Add(1)
}

// Remove decrements the reference count and invokes closeFunc if the count
// reaches zero.
func (r *refCountedValue[V]) Remove() error {
if r.refCount.Add(-1) == 0 {
return r.closeFunc(r.value)
}

return nil
}

// Value returns the underlying value.
func (r *refCountedValue[V]) Value() V {
return r.value
}

// RefMap manages reference-counted items in a concurrent-safe map.
type RefMap[K comparable, V any] struct {
data map[K]*refCountedValue[V]
mu sync.RWMutex
}

// NewRefMap initializes a new RefMap. A RefMap is an atomic reference-counted
// concurrent hashmap. The general usage pattern is to Store a value with a
// close function, once the value is contained within the RefMap, it can be
// accessed via the Load method. The AddRef method signals ownership of the
// value and increments the reference count. The RemoveRef method decrements
// the reference count. When the reference count reaches zero, the close
// function is called and the value is removed from the map.
func NewRefMap[K comparable, V any]() *RefMap[K, V] {
return &RefMap[K, V]{
data: make(map[K]*refCountedValue[V]),
}
}

// Store adds a new item with an initial reference count of 1 and a close
// function.
func (r *RefMap[K, V]) Store(key K, value V, closeFunc func(V) error) error {
r.mu.Lock()
defer r.mu.Unlock()

if _, exists := r.data[key]; exists {
return fmt.Errorf("key already exists: %v", key)
}

r.data[key] = &refCountedValue[V]{value: value, refCount: atomic.Int64{}, closeFunc: closeFunc}
r.data[key].Add() // Set initial reference count to 1

return nil
}

// Load retrieves a value by key without modifying the reference count,
// returning the value and a boolean indicating if it was found. The reference
// count not being modified means that a check for the existence of a key
// can be performed without signalling ownership of the value. If the value is
// used after this method, it is recommended to call AddRef to increment the
// reference
func (r *RefMap[K, V]) Load(key K) (V, bool) {
r.mu.RLock()
defer r.mu.RUnlock()

if refValue, found := r.data[key]; found {
return refValue.Value(), true
}
var zero V

return zero, false
}

// AddRef increments the reference count for a key if it exists. Ensure
// to call RemoveRef when done with the value to prevent memory leaks.
func (r *RefMap[K, V]) AddRef(key K) error {
r.mu.RLock()
defer r.mu.RUnlock()

refValue, found := r.data[key]
if !found {
return fmt.Errorf("key not found: %v", key)
}

refValue.Add()
return nil
}

// RemoveRef decrements the reference count and deletes the entry if count
// reaches zero.
func (r *RefMap[K, V]) RemoveRef(key K) error {
r.mu.Lock()
defer r.mu.Unlock()

refValue, found := r.data[key]
if !found {
return fmt.Errorf("key not found: %v", key)
}

err := refValue.Remove()

if refValue.refCount.Load() == 0 {
delete(r.data, key)
}

return err // returns the error from closeFunc
}
Loading
Loading