Skip to content

Commit

Permalink
Specify a custom dial function per config
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronjheng committed Dec 19, 2023
1 parent 0004702 commit 31d874d
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 25 deletions.
24 changes: 17 additions & 7 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,20 +77,30 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
mc.parseTime = mc.cfg.ParseTime

// Connect to Server
dialsLock.RLock()
dial, ok := dials[mc.cfg.Net]
dialsLock.RUnlock()
if ok {
if c.cfg.DialFunc != nil {
dctx := ctx
if mc.cfg.Timeout > 0 {
var cancel context.CancelFunc
dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout)
defer cancel()
}
mc.netConn, err = dial(dctx, mc.cfg.Addr)
mc.netConn, err = c.cfg.DialFunc(dctx, mc.cfg.Net, mc.cfg.Addr)
} else {
nd := net.Dialer{Timeout: mc.cfg.Timeout}
mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
dialsLock.RLock()
dial, ok := dials[mc.cfg.Net]
dialsLock.RUnlock()
if ok {
dctx := ctx
if mc.cfg.Timeout > 0 {
var cancel context.CancelFunc
dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout)
defer cancel()
}
mc.netConn, err = dial(dctx, mc.cfg.Addr)
} else {
nd := net.Dialer{Timeout: mc.cfg.Timeout}
mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
}
}

if err != nil {
Expand Down
38 changes: 20 additions & 18 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package mysql

import (
"bytes"
"context"
"crypto/rsa"
"crypto/tls"
"errors"
Expand All @@ -34,24 +35,25 @@ var (
// If a new Config is created instead of being parsed from a DSN string,
// the NewConfig function should be used, which sets default values.
type Config struct {
User string // Username
Passwd string // Password (requires User)
Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp")
Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix")
DBName string // Database name
Params map[string]string // Connection parameters
ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
Logger Logger // Logger
User string // Username
Passwd string // Password (requires User)
Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp")
Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix")
DBName string // Database name
Params map[string]string // Connection parameters
ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
Logger Logger // Logger
DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) // Specifies the dial function for creating connections

AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
AllowCleartextPasswords bool // Allows the cleartext client side plugin
Expand Down

0 comments on commit 31d874d

Please sign in to comment.