diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index 40bedb08..ec7ca57b 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -75,11 +75,11 @@ type Config struct { PacketSize uint16 } -func SetupTLS(certificate string, insecureSkipVerify bool, hostInCertificate string) (*tls.Config, error) { +func SetupTLS(certificate string, insecureSkipVerify bool, hostInCertificate string, tlsMinVer uint16) (*tls.Config, error) { config := tls.Config{ ServerName: hostInCertificate, InsecureSkipVerify: insecureSkipVerify, - + MinVersion: tlsMinVer, // fix for https://github.com/denisenkom/go-mssqldb/issues/166 // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments, // while SQL Server seems to expect one TCP segment per encrypted TDS package. @@ -254,10 +254,29 @@ func Parse(dsn string) (Config, map[string]string, error) { hostInCertificate = p.Host p.HostInCertificateProvided = false } + tlsversion, ok := params["tlsminversion"] + tlsMinVer := uint16(0) + if ok { + tlsversion = strings.ToUpper(tlsversion) + switch tlsversion { + case "TLS1.0": + tlsMinVer = tls.VersionTLS10 + case "TLS1.1": + tlsMinVer = tls.VersionTLS11 + case "TLS1.2": + tlsMinVer = tls.VersionTLS12 + /*comment by go1.8 ~ go1.11 has no tls.VersionTLS13 + case "TLS1.3": + tlsMinVer = tls.VersionTLS13 + */ + default: + tlsMinVer = 0 + } + } if p.Encryption != EncryptionDisabled { var err error - p.TLSConfig, err = SetupTLS(certificate, trustServerCert, hostInCertificate) + p.TLSConfig, err = SetupTLS(certificate, trustServerCert, hostInCertificate, tlsMinVer) if err != nil { return p, params, fmt.Errorf("failed to setup TLS: %w", err) } diff --git a/msdsn/conn_str_test.go b/msdsn/conn_str_test.go index 594b5b3d..6b789914 100644 --- a/msdsn/conn_str_test.go +++ b/msdsn/conn_str_test.go @@ -1,6 +1,7 @@ package msdsn import ( + "crypto/tls" "reflect" "testing" "time" @@ -196,3 +197,31 @@ func TestConnParseRoundTripFixed(t *testing.T) { t.Fatal("Parameters do not match after roundtrip", params, rtParams) } } + +func TestConnParseWithTlsVersion(t *testing.T) { + tests := []struct { + name string + connStr string + wantCfg *Config + }{ + {name: "1.TLS1.0", connStr: "sqlserver://someuser@somehost?tlsminversion=tls1.0", wantCfg: &Config{TLSConfig: &tls.Config{MinVersion: tls.VersionTLS10}}}, + {name: "2.TLS1.1", connStr: "sqlserver://someuser@somehost?tlsminversion=tls1.1", wantCfg: &Config{TLSConfig: &tls.Config{MinVersion: tls.VersionTLS11}}}, + {name: "3.TLS1.2", connStr: "sqlserver://someuser@somehost?tlsminversion=tls1.2", wantCfg: &Config{TLSConfig: &tls.Config{MinVersion: tls.VersionTLS12}}}, + {name: "4.no tlsminversion parameter", connStr: "sqlserver://someuser@somehost", wantCfg: &Config{TLSConfig: &tls.Config{MinVersion: 0}}}, + {name: "5.wrong tlsminversion parameter", connStr: "sqlserver://someuser@somehost?tlsminversion=wrongtlsversion", wantCfg: &Config{TLSConfig: &tls.Config{MinVersion: 0}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, _, err := Parse(tt.connStr) + if err != nil { + t.Errorf("%s Parse Error:%+v", tt.name, err) + return + } + if got.TLSConfig.MinVersion != tt.wantCfg.TLSConfig.MinVersion { + t.Errorf("%s Parse MinVersion not match. want:%d, got:%d", tt.name, tt.wantCfg.TLSConfig.MinVersion, got.TLSConfig.MinVersion) + return + } + }) + } + +} diff --git a/tds.go b/tds.go index dbe95272..0bbe7cbd 100644 --- a/tds.go +++ b/tds.go @@ -1050,50 +1050,7 @@ func connect(ctx context.Context, c *Connector, logger ContextLogger, p msdsn.Co dialCtx, cancel = context.WithTimeout(ctx, dt) defer cancel() } - // if instance is specified use instance resolution service - if len(p.Instance) > 0 && p.Port != 0 && uint64(p.LogFlags)&logDebug != 0 { - // both instance name and port specified - // when port is specified instance name is not used - // you should not provide instance name when you provide port - logger.Log(ctx, msdsn.LogDebug, "WARN: You specified both instance name and port in the connection string, port will be used and instance name will be ignored") - } - if len(p.Instance) > 0 { - p.Instance = strings.ToUpper(p.Instance) - d := c.getDialer(&p) - instances, err := getInstances(dialCtx, d, p.Host) - if err != nil { - f := "unable to get instances from Sql Server Browser on host %v: %v" - return nil, fmt.Errorf(f, p.Host, err.Error()) - } - strport, ok := instances[p.Instance]["tcp"] - if !ok { - f := "no instance matching '%v' returned from host '%v'" - return nil, fmt.Errorf(f, p.Instance, p.Host) - } - port, err := strconv.ParseUint(strport, 0, 16) - if err != nil { - f := "invalid tcp port returned from Sql Server Browser '%v': %v" - return nil, fmt.Errorf(f, strport, err.Error()) - } - p.Port = port - } - if p.Port == 0 { - p.Port = defaultServerPort - } - - packetSize := p.PacketSize - if packetSize == 0 { - packetSize = defaultPacketSize - } - // Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes - // NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request - // a higher packet size, the server will respond with an ENVCHANGE request to - // alter the packet size to 16383 bytes. - if packetSize < 512 { - packetSize = 512 - } else if packetSize > 32767 { - packetSize = 32767 - } + err = prepareMSDSN(ctx, c, logger, &p) initiate_connection: conn, err := dialConnection(dialCtx, c, p) @@ -1103,7 +1060,7 @@ initiate_connection: toconn := newTimeoutConn(conn, p.ConnTimeout) - outbuf := newTdsBuffer(packetSize, toconn) + outbuf := newTdsBuffer(p.PacketSize, toconn) sess := tdsSession{ buf: outbuf, logger: logger, @@ -1136,25 +1093,8 @@ initiate_connection: } if encrypt != encryptNotSup { - var config *tls.Config - if pc := p.TLSConfig; pc != nil { - config = pc - if config.DynamicRecordSizingDisabled == false { - config = config.Clone() - - // fix for https://github.com/denisenkom/go-mssqldb/issues/166 - // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments, - // while SQL Server seems to expect one TCP segment per encrypted TDS package. - // Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package - config.DynamicRecordSizingDisabled = true - } - } - if config == nil { - config, err = msdsn.SetupTLS("", false, p.Host) - if err != nil { - return nil, err - } - } + //refactor tls config build. + config := prepareTLSConfig(p) // setting up connection handler which will allow wrapping of TLS handshake packets inside TDS stream handshakeConn := tlsHandshakeConn{buf: outbuf} @@ -1288,3 +1228,75 @@ func resolveServerPort(port uint64) uint64 { return port } + +func prepareTLSConfig(p msdsn.Config) (config *tls.Config) { + if pc := p.TLSConfig; pc != nil { + config = pc + if config.DynamicRecordSizingDisabled == false { + config = config.Clone() + + // fix for https://github.com/denisenkom/go-mssqldb/issues/166 + // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments, + // while SQL Server seems to expect one TCP segment per encrypted TDS package. + // Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package + config.DynamicRecordSizingDisabled = true + } + } + if config == nil { + //In this scenario, error will not appear + config, _ = msdsn.SetupTLS("", false, p.Host, 0) + } + return +} + +func prepareMSDSN(dialCtx context.Context, c *Connector, logger ContextLogger, p *msdsn.Config) (err error) { + + // if instance is specified use instance resolution service + if len(p.Instance) > 0 && p.Port != 0 && uint64(p.LogFlags)&logDebug != 0 { + // both instance name and port specified + // when port is specified instance name is not used + // you should not provide instance name when you provide port + logger.Log(dialCtx, msdsn.LogDebug, "WARN: You specified both instance name and port in the connection string, port will be used and instance name will be ignored") + } + if len(p.Instance) > 0 { + p.Instance = strings.ToUpper(p.Instance) + d := c.getDialer(p) + instances, err := getInstances(dialCtx, d, p.Host) + if err != nil { + const f = "unable to get instances from Sql Server Browser on host %v: %v" + return fmt.Errorf(f, p.Host, err.Error()) + } + strport, ok := instances[p.Instance]["tcp"] + if !ok { + const f = "no instance matching '%v' returned from host '%v'" + return fmt.Errorf(f, p.Instance, p.Host) + } + port, err := strconv.ParseUint(strport, 0, 16) + if err != nil { + const f = "invalid tcp port returned from Sql Server Browser '%v': %v" + return fmt.Errorf(f, strport, err.Error()) + } + p.Port = port + } + if p.Port == 0 { + p.Port = defaultServerPort + } + + packetSize := p.PacketSize + if packetSize == 0 { + packetSize = defaultPacketSize + } + // Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes + // NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request + // a higher packet size, the server will respond with an ENVCHANGE request to + // alter the packet size to 16383 bytes. + if packetSize < 512 { + packetSize = 512 + } else if packetSize > 32767 { + packetSize = 32767 + } + + p.PacketSize = packetSize + return err + +} diff --git a/tds_test.go b/tds_test.go index 7d9af0de..9603635e 100644 --- a/tds_test.go +++ b/tds_test.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "context" + "crypto/tls" "database/sql" "encoding/hex" "fmt" @@ -12,6 +13,7 @@ import ( "net/url" "os" "path" + "reflect" "runtime" "sync" "testing" @@ -758,3 +760,128 @@ func runBatch(t testing.TB, p msdsn.Config) { } } } + +func Test_prepareTLSConfig(t *testing.T) { + + tests := []struct { + name string + p msdsn.Config + wantConfig *tls.Config + wantErr bool + }{ + {name: "1.TLSConfig is null ", p: msdsn.Config{Host: "testserver"}, wantConfig: &tls.Config{ServerName: "testserver"}, wantErr: false}, + {name: "2.TLSConfig not null ,DynamicRecordSizingDisabled=false", p: msdsn.Config{TLSConfig: &tls.Config{DynamicRecordSizingDisabled: false, ServerName: "testserver", MinVersion: tls.VersionTLS10}}, wantConfig: &tls.Config{ServerName: "testserver", DynamicRecordSizingDisabled: true, MinVersion: tls.VersionTLS10}, wantErr: false}, + {name: "3.TLSConfig not null ,DynamicRecordSizingDisabled=true", p: msdsn.Config{TLSConfig: &tls.Config{DynamicRecordSizingDisabled: true, ServerName: "testserver", MinVersion: tls.VersionTLS10}}, wantConfig: &tls.Config{ServerName: "testserver", DynamicRecordSizingDisabled: true, MinVersion: tls.VersionTLS10}, wantErr: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotConfig := prepareTLSConfig(tt.p) + + if gotConfig.ServerName != tt.wantConfig.ServerName || + gotConfig.MinVersion != tt.wantConfig.MinVersion { + t.Errorf("prepareTLSConfig() = %v, want %v", gotConfig, tt.wantConfig) + } + }) + } +} + +func Test_prepareMSDSN(t *testing.T) { + + tl := testLogger{t: t} + defer tl.StopLogging() + SetLogger(&tl) + var newdialerCallErr = func() Dialer { + return NewMockTransportDialer( + []string{ + " 03", + }, + []string{ + " 05 00 00 49 6e 73 74 61 6e 63 65 4e 61 6d 65 3b 62 3b 3b", + }, + ) + } + var newdialerCallErr2 = func() Dialer { + return NewMockTransportDialer( + []string{ + " 03", + }, + []string{ + " 05 00 00 49 6e 73 74 61 6e 63 65 4e 61 6d 65 3b 62 3b 74 63 70 3b 61 62 63 3b 3b", + }, + ) + } + var newdialerCallErr3 = func() Dialer { + return NewMockTransportDialer( + []string{ + " 04", + }, + []string{ + " 05 00 00 49 6e 73 74 61 6e 63 65 4e 61 6d 65 3b 62 3b 74 63 70 3b 61 62 63 3b 3b", + }, + ) + } + var newdialerCallSuc = func() Dialer { + return NewMockTransportDialer( + []string{ + " 03", + }, + []string{ + " 05 00 00 49 6e 73 74 61 6e 63 65 4e 61 6d 65 3b 62 3b 74 63 70 3b 31 34 33 33 3b 3b", + }, + ) + } + + type args struct { + ctx context.Context + p *msdsn.Config + } + + tests := []struct { + name string + args args + dialCall func() Dialer + wantDialCtx context.Context + wantErr bool + }{ + {name: "1.", dialCall: newdialerCallErr, args: args{ctx: context.Background(), p: &msdsn.Config{Instance: "test", Port: 1433, LogFlags: msdsn.LogDebug}}, wantDialCtx: nil, wantErr: true}, + {name: "2.", dialCall: newdialerCallErr, args: args{ctx: context.Background(), p: &msdsn.Config{Instance: "test", Port: 0, LogFlags: msdsn.LogErrors}}, wantDialCtx: nil, wantErr: true}, + {name: "3.", dialCall: newdialerCallErr, args: args{ctx: context.Background(), p: &msdsn.Config{Instance: "", PacketSize: 1, LogFlags: msdsn.LogErrors}}, wantDialCtx: nil, wantErr: false}, + {name: "4.", dialCall: newdialerCallErr, args: args{ctx: context.Background(), p: &msdsn.Config{Instance: "", PacketSize: 32768, LogFlags: msdsn.LogErrors}}, wantDialCtx: nil, wantErr: false}, + {name: "5.", dialCall: newdialerCallErr, args: args{ctx: context.Background(), p: &msdsn.Config{Instance: "", PacketSize: 32768, LogFlags: msdsn.LogErrors}}, wantDialCtx: nil, wantErr: false}, + {name: "6.", dialCall: newdialerCallSuc, args: args{ctx: context.Background(), p: &msdsn.Config{Instance: "B", PacketSize: 4096, LogFlags: msdsn.LogErrors}}, wantDialCtx: nil, wantErr: false}, + {name: "7.", dialCall: newdialerCallErr2, args: args{ctx: context.Background(), p: &msdsn.Config{Instance: "B", PacketSize: 4096, LogFlags: msdsn.LogErrors}}, wantDialCtx: nil, wantErr: true}, + {name: "8.", dialCall: newdialerCallErr3, args: args{ctx: context.Background(), p: &msdsn.Config{Instance: "B", PacketSize: 4096, LogFlags: msdsn.LogErrors}}, wantDialCtx: nil, wantErr: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + c := &Connector{params: *tt.args.p, Dialer: tt.dialCall()} + err := prepareMSDSN(tt.args.ctx, c, driverInstanceNoProcess.logger, tt.args.p) + if (err != nil) != tt.wantErr { + t.Errorf("prepareMSDSN() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func Test_parseInstances(t *testing.T) { + + tests := []struct { + name string + args []byte + want map[string]map[string]string + }{ + {name: "1.len<=3", args: []byte(`abc`), want: map[string]map[string]string{}}, + {name: "2.len byte[0]!=5", args: []byte{1, 0, 1, 1}, want: map[string]map[string]string{}}, + {name: "3.normal-1", args: append([]byte{5, 0, 0}, []byte(`;b;`)...), want: map[string]map[string]string{}}, + {name: "3.normal-2", args: append([]byte{5, 0, 0}, []byte(`InstanceName;b;;`)...), want: map[string]map[string]string{"B": map[string]string{"InstanceName": "b"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := parseInstances(tt.args); !reflect.DeepEqual(got, tt.want) { + t.Errorf("parseInstances() = %v, want %v", got, tt.want) + } + }) + } +}