diff --git a/go/vt/vtgateproxy/discovery.go b/go/vt/vtgateproxy/discovery.go index 0235c4f1ac2..735f03572b7 100644 --- a/go/vt/vtgateproxy/discovery.go +++ b/go/vt/vtgateproxy/discovery.go @@ -59,6 +59,7 @@ import ( // const PoolTypeAttr = "PoolType" +const ZoneLocalAttr = "ZoneLocal" // Resolver(https://godoc.org/google.golang.org/grpc/resolver#Resolver). type JSONGateResolver struct { @@ -83,6 +84,7 @@ type JSONGateResolverBuilder struct { affinityField string affinityValue string numConnections int + numBackupConns int mu sync.RWMutex targets map[string][]targetHost @@ -98,6 +100,7 @@ type targetHost struct { Addr string PoolType string Affinity string + IsLocal bool } var ( @@ -113,6 +116,7 @@ func RegisterJSONGateResolver( affinityField string, affinityValue string, numConnections int, + numBackupConns int, ) (*JSONGateResolverBuilder, error) { jsonDiscovery := &JSONGateResolverBuilder{ targets: map[string][]targetHost{}, @@ -123,6 +127,7 @@ func RegisterJSONGateResolver( affinityField: affinityField, affinityValue: affinityValue, numConnections: numConnections, + numBackupConns: numBackupConns, sorter: newShuffleSorter(), } @@ -263,7 +268,7 @@ func (b *JSONGateResolverBuilder) parse() (bool, error) { return false, fmt.Errorf("error parsing JSON discovery file %s: %v", b.jsonPath, err) } - var targets = map[string][]targetHost{} + var allTargets = map[string][]targetHost{} for _, host := range hosts { hostname, hasHostname := host["host"] address, hasAddress := host[b.addressField] @@ -309,8 +314,8 @@ func (b *JSONGateResolverBuilder) parse() (bool, error) { return false, fmt.Errorf("error parsing JSON discovery file %s: port field %s has invalid value %v", b.jsonPath, b.portField, port) } - target := targetHost{hostname.(string), fmt.Sprintf("%s:%s", address, port), poolType.(string), affinity.(string)} - targets[target.PoolType] = append(targets[target.PoolType], target) + target := targetHost{hostname.(string), fmt.Sprintf("%s:%s", address, port), poolType.(string), affinity.(string), affinity == b.affinityValue} + allTargets[target.PoolType] = append(allTargets[target.PoolType], target) } // If a pool disappears, the metric will not record this unless all counts @@ -320,16 +325,25 @@ func (b *JSONGateResolverBuilder) parse() (bool, error) { // targets and only resetting pools which disappear. targetCount.ResetAll() - for poolType := range targets { - b.sorter.shuffleSort(targets[poolType], b.affinityField, b.affinityValue) - if len(targets[poolType]) > *numConnections { - targets[poolType] = targets[poolType][:b.numConnections] + var selected = map[string][]targetHost{} + + for poolType := range allTargets { + b.sorter.shuffleSort(allTargets[poolType]) + + // try to pick numConnections from the front of the list (local zone) and numBackupConnections + // from the tail (remote zone). if that's not possible, just take the whole set + if len(allTargets[poolType]) >= b.numConnections+b.numBackupConns { + remoteOffset := len(allTargets[poolType]) - b.numBackupConns + selected[poolType] = append(allTargets[poolType][:b.numConnections], allTargets[poolType][remoteOffset:]...) + } else { + selected[poolType] = allTargets[poolType] } - targetCount.Set(poolType, int64(len(targets[poolType]))) + + targetCount.Set(poolType, int64(len(selected[poolType]))) } b.mu.Lock() - b.targets = targets + b.targets = selected b.mu.Unlock() return true, nil @@ -353,7 +367,7 @@ func (b *JSONGateResolverBuilder) getTargets(poolType string) []targetHost { targets = append(targets, b.targets[poolType]...) b.mu.RUnlock() - b.sorter.shuffleSort(targets, b.affinityField, b.affinityValue) + b.sorter.shuffleSort(targets) return targets } @@ -373,7 +387,7 @@ func newShuffleSorter() *shuffleSorter { // shuffleSort shuffles a slice of targetHost to ensure every host has a // different order to iterate through, putting the affinity matching (e.g. same // az) hosts at the front and the non-matching ones at the end. -func (s *shuffleSorter) shuffleSort(targets []targetHost, affinityField, affinityValue string) { +func (s *shuffleSorter) shuffleSort(targets []targetHost) { n := len(targets) head := 0 // Only need to do n-1 swaps since the last host is always in the right place. @@ -383,7 +397,7 @@ func (s *shuffleSorter) shuffleSort(targets []targetHost, affinityField, affinit j := head + s.rand.Intn(tail-head+1) s.mu.Unlock() - if affinityField != "" && affinityValue == targets[j].Affinity { + if targets[j].IsLocal { targets[head], targets[j] = targets[j], targets[head] head++ } else { @@ -406,7 +420,8 @@ func (b *JSONGateResolverBuilder) update(r *JSONGateResolver) error { var addrs []resolver.Address for _, target := range targets { - addrs = append(addrs, resolver.Address{Addr: target.Addr, Attributes: attributes.New(PoolTypeAttr, r.poolType)}) + attrs := attributes.New(PoolTypeAttr, r.poolType).WithValue(ZoneLocalAttr, target.IsLocal) + addrs = append(addrs, resolver.Address{Addr: target.Addr, Attributes: attrs}) } // If we've already selected some targets, give the new addresses some time to warm up before removing @@ -488,12 +503,13 @@ const ( {{range $i, $p := .Pools}} - + {{range index $.Targets $p}} + {{end}} {{end}}
{{$p}}{{$p}}
{{.Hostname}} {{.Addr}} {{.Affinity}}{{.IsLocal}}
diff --git a/go/vt/vtgateproxy/firstready_balancer.go b/go/vt/vtgateproxy/firstready_balancer.go index 2885ae93d5f..8009f3dee86 100644 --- a/go/vt/vtgateproxy/firstready_balancer.go +++ b/go/vt/vtgateproxy/firstready_balancer.go @@ -41,12 +41,13 @@ import ( ) // newBuilder creates a new first_ready balancer builder. -func newBuilder() balancer.Builder { +func newFirstReadyBuilder() balancer.Builder { return base.NewBalancerBuilder("first_ready", &frPickerBuilder{currentConns: map[string]balancer.SubConn{}}, base.Config{HealthCheck: true}) } func init() { - balancer.Register(newBuilder()) + log.V(1).Infof("registering first_ready balancer") + balancer.Register(newFirstReadyBuilder()) } // frPickerBuilder implements both the Builder and the Picker interfaces. diff --git a/go/vt/vtgateproxy/mysql_server.go b/go/vt/vtgateproxy/mysql_server.go index e6169649582..65771400036 100644 --- a/go/vt/vtgateproxy/mysql_server.go +++ b/go/vt/vtgateproxy/mysql_server.go @@ -219,6 +219,8 @@ func (ph *proxyHandler) ComQuery(c *mysql.Conn, query string, callback func(*sql } }() + ctx = context.WithValue(ctx, CONN_ID_KEY, int(c.ConnectionID)) + if session.SessionPb().Options.Workload == querypb.ExecuteOptions_OLAP { err := ph.proxy.StreamExecute(ctx, session, query, make(map[string]*querypb.BindVariable), callback) return sqlerror.NewSQLErrorFromError(err) @@ -285,6 +287,8 @@ func (ph *proxyHandler) ComPrepare(c *mysql.Conn, query string, bindVars map[str } }(session) + ctx = context.WithValue(ctx, CONN_ID_KEY, int(c.ConnectionID)) + _, fld, err := ph.proxy.Prepare(ctx, session, query, bindVars) err = sqlerror.NewSQLErrorFromError(err) if err != nil { @@ -332,6 +336,8 @@ func (ph *proxyHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData } }() + ctx = context.WithValue(ctx, CONN_ID_KEY, int(c.ConnectionID)) + if session.SessionPb().Options.Workload == querypb.ExecuteOptions_OLAP { err := ph.proxy.StreamExecute(ctx, session, prepare.PrepareStmt, prepare.BindVars, callback) return sqlerror.NewSQLErrorFromError(err) @@ -396,6 +402,8 @@ func (ph *proxyHandler) getSession(ctx context.Context, c *mysql.Conn) (*vtgatec options.ClientFoundRows = true } + ctx = context.WithValue(ctx, CONN_ID_KEY, int(c.ConnectionID)) + var err error session, err = ph.proxy.NewSession(ctx, options, c.Attributes) if err != nil { @@ -420,6 +428,9 @@ func (ph *proxyHandler) closeSession(ctx context.Context, c *mysql.Conn) { if session.SessionPb().InTransaction { defer atomic.AddInt32(&busyConnections, -1) } + + ctx = context.WithValue(ctx, CONN_ID_KEY, int(c.ConnectionID)) + err := ph.proxy.CloseSession(ctx, session) if err != nil { log.Errorf("Error happened in transaction rollback: %v", err) diff --git a/go/vt/vtgateproxy/sim/vtgateproxysim.go b/go/vt/vtgateproxy/sim/vtgateproxysim.go new file mode 100644 index 00000000000..e8c6faa934d --- /dev/null +++ b/go/vt/vtgateproxy/sim/vtgateproxysim.go @@ -0,0 +1,101 @@ +package main + +import ( + "flag" + "fmt" + "math/rand" + "sort" + "time" + + "github.com/guptarohit/asciigraph" +) + +var ( + numClients = flag.Int("c", 9761, "Number of clients") + numVtgates = flag.Int("v", 1068, "Number of vtgates") + numConnections = flag.Int("n", 4, "number of connections per client host") + numZones = flag.Int("z", 4, "number of zones") +) + +func main() { + rnd := rand.New(rand.NewSource(time.Now().UnixNano())) + + flag.Parse() + + fmt.Printf("Simulating %d clients => %d vtgates with %d zones %d conns per client\n\n", + *numClients, *numVtgates, *numZones, *numConnections) + + var clients []string + for i := 0; i < *numClients; i++ { + clients = append(clients, fmt.Sprintf("client-%03d", i)) + } + + var vtgates []string + for i := 0; i < *numVtgates; i++ { + vtgates = append(vtgates, fmt.Sprintf("vtgate-%03d", i)) + } + + // for now just consider 1/N of the s "local" + localClients := clients[:*numClients / *numZones] + localVtgates := vtgates[:*numVtgates / *numZones] + + conns := map[string][]string{} + + // Simulate "discovery" + for _, client := range localClients { + var clientConns []string + + for i := 0; i < *numConnections; i++ { + vtgate := localVtgates[rnd.Intn(len(localVtgates))] + clientConns = append(clientConns, vtgate) + } + + conns[client] = clientConns + } + + counts := map[string]int{} + for _, conns := range conns { + for _, vtgate := range conns { + counts[vtgate]++ + } + } + + histogram := map[int]int{} + max := 0 + min := -1 + for _, count := range counts { + histogram[count]++ + if count > max { + max = count + } + if min == -1 || count < min { + min = count + } + } + + fmt.Printf("Conns per vtgate\n%v\n\n", counts) + fmt.Printf("Histogram of conn counts\n%v\n\n", histogram) + + plot := []float64{} + for i := 0; i < len(localVtgates); i++ { + plot = append(plot, float64(counts[localVtgates[i]])) + } + sort.Float64s(plot) + graph := asciigraph.Plot(plot) + fmt.Println("Number of conns per vtgate host") + fmt.Println(graph) + fmt.Println("") + fmt.Println("") + + fmt.Printf("Conn count per vtgate distribution [%d - %d] (%d clients => %d vtgates with %d zones %d conns\n\n", + min, max, *numClients, *numVtgates, *numZones, *numConnections) + plot = []float64{} + for i := min; i < max; i++ { + plot = append(plot, float64(histogram[i])) + } + graph = asciigraph.Plot(plot) + fmt.Println(graph) + + fmt.Printf("\nConn stats: min %d max %d spread %d spread/min %f spread/avg %f\n", + min, max, max-min, float64(max-min)/float64(min), float64(max-min)/float64((max+min)/2)) +} diff --git a/go/vt/vtgateproxy/sticky_random_balancer.go b/go/vt/vtgateproxy/sticky_random_balancer.go new file mode 100644 index 00000000000..ce6232df568 --- /dev/null +++ b/go/vt/vtgateproxy/sticky_random_balancer.go @@ -0,0 +1,110 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Sticky random is a derivative based on the round_robin balancer which uses a Context +// variable to maintain client-side affinity to a given connection. + +package vtgateproxy + +import ( + "math/rand/v2" + + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/base" + "vitess.io/vitess/go/vt/log" +) + +type ConnIdKey string + +const CONN_ID_KEY = ConnIdKey("ConnId") + +// newBuilder creates a new roundrobin balancer builder. +func newStickyRandomBuilder() balancer.Builder { + return base.NewBalancerBuilder("sticky_random", &stickyPickerBuilder{}, base.Config{HealthCheck: true}) +} + +func init() { + log.V(1).Infof("registering sticky_random balancer") + balancer.Register(newStickyRandomBuilder()) +} + +type stickyPickerBuilder struct{} + +// Would be nice if this were easier in golang +func boolValue(val interface{}) bool { + switch val := val.(type) { + case bool: + return val + } + return false +} + +func (*stickyPickerBuilder) Build(info base.PickerBuildInfo) balancer.Picker { + // log.V(100).Infof("stickyRandomPicker: Build called with info: %v", info) + if len(info.ReadySCs) == 0 { + return base.NewErrPicker(balancer.ErrNoSubConnAvailable) + } + + scs := make([]balancer.SubConn, 0, len(info.ReadySCs)) + + // Where possible filter to only ready conns in the local zone, using the remote + // zone only if there are no local conns available. + for sc, scInfo := range info.ReadySCs { + local := boolValue(scInfo.Address.Attributes.Value(ZoneLocalAttr)) + if local { + scs = append(scs, sc) + } + } + + // Otherwise use all the ready conns regardless of locality + if len(scs) == 0 { + for sc := range info.ReadySCs { + scs = append(scs, sc) + } + } + + return &stickyPicker{ + subConns: scs, + } +} + +type stickyPicker struct { + // subConns is the snapshot of the balancer when this picker was + // created. The slice is immutable. + subConns []balancer.SubConn +} + +func (p *stickyPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { + + subConnsLen := len(p.subConns) + + var connId int + connIdVal := info.Ctx.Value(CONN_ID_KEY) + if connIdVal != nil { + connId = connIdVal.(int) + log.V(100).Infof("stickyRandomPicker: using connId %d", connId) + } else { + log.V(100).Infof("stickyRandomPicker: nonexistent connId -- using random") + connId = rand.IntN(subConnsLen) // shouldn't happen + } + + // XXX/demmer might want to hash the connId rather than just mod + sc := p.subConns[connId%subConnsLen] + + return balancer.PickResult{SubConn: sc}, nil +} diff --git a/go/vt/vtgateproxy/vtgateproxy.go b/go/vt/vtgateproxy/vtgateproxy.go index c0afdf422a7..883e744058f 100644 --- a/go/vt/vtgateproxy/vtgateproxy.go +++ b/go/vt/vtgateproxy/vtgateproxy.go @@ -53,6 +53,7 @@ const ( var ( vtgateHostsFile = flag.String("vtgate_hosts_file", "", "json file describing the host list to use for vtgate:// resolution") numConnections = flag.Int("num_connections", 4, "number of outbound GPRC connections to maintain") + numBackupConns = flag.Int("num_backup_conns", 1, "number of backup remote-zone GPRC connections to maintain") poolTypeField = flag.String("pool_type_field", "", "Field name used to specify the target vtgate type and filter the hosts") affinityField = flag.String("affinity_field", "", "Attribute (JSON file) used to specify the routing affinity , e.g. 'az_id'") affinityValue = flag.String("affinity_value", "", "Value to match for routing affinity , e.g. 'use-az1'") @@ -218,6 +219,7 @@ func Init() { case "round_robin": case "first_ready": case "pick_first": + case "sticky_random": break default: log.Fatalf("invalid balancer type %s", *balancerType) @@ -235,6 +237,7 @@ func Init() { *affinityField, *affinityValue, *numConnections, + *numBackupConns, ) if err != nil {