diff --git a/go/vt/vtgateproxy/discovery.go b/go/vt/vtgateproxy/discovery.go
index 0235c4f1ac2..735f03572b7 100644
--- a/go/vt/vtgateproxy/discovery.go
+++ b/go/vt/vtgateproxy/discovery.go
@@ -59,6 +59,7 @@ import (
//
const PoolTypeAttr = "PoolType"
+const ZoneLocalAttr = "ZoneLocal"
// Resolver(https://godoc.org/google.golang.org/grpc/resolver#Resolver).
type JSONGateResolver struct {
@@ -83,6 +84,7 @@ type JSONGateResolverBuilder struct {
affinityField string
affinityValue string
numConnections int
+ numBackupConns int
mu sync.RWMutex
targets map[string][]targetHost
@@ -98,6 +100,7 @@ type targetHost struct {
Addr string
PoolType string
Affinity string
+ IsLocal bool
}
var (
@@ -113,6 +116,7 @@ func RegisterJSONGateResolver(
affinityField string,
affinityValue string,
numConnections int,
+ numBackupConns int,
) (*JSONGateResolverBuilder, error) {
jsonDiscovery := &JSONGateResolverBuilder{
targets: map[string][]targetHost{},
@@ -123,6 +127,7 @@ func RegisterJSONGateResolver(
affinityField: affinityField,
affinityValue: affinityValue,
numConnections: numConnections,
+ numBackupConns: numBackupConns,
sorter: newShuffleSorter(),
}
@@ -263,7 +268,7 @@ func (b *JSONGateResolverBuilder) parse() (bool, error) {
return false, fmt.Errorf("error parsing JSON discovery file %s: %v", b.jsonPath, err)
}
- var targets = map[string][]targetHost{}
+ var allTargets = map[string][]targetHost{}
for _, host := range hosts {
hostname, hasHostname := host["host"]
address, hasAddress := host[b.addressField]
@@ -309,8 +314,8 @@ func (b *JSONGateResolverBuilder) parse() (bool, error) {
return false, fmt.Errorf("error parsing JSON discovery file %s: port field %s has invalid value %v", b.jsonPath, b.portField, port)
}
- target := targetHost{hostname.(string), fmt.Sprintf("%s:%s", address, port), poolType.(string), affinity.(string)}
- targets[target.PoolType] = append(targets[target.PoolType], target)
+ target := targetHost{hostname.(string), fmt.Sprintf("%s:%s", address, port), poolType.(string), affinity.(string), affinity == b.affinityValue}
+ allTargets[target.PoolType] = append(allTargets[target.PoolType], target)
}
// If a pool disappears, the metric will not record this unless all counts
@@ -320,16 +325,25 @@ func (b *JSONGateResolverBuilder) parse() (bool, error) {
// targets and only resetting pools which disappear.
targetCount.ResetAll()
- for poolType := range targets {
- b.sorter.shuffleSort(targets[poolType], b.affinityField, b.affinityValue)
- if len(targets[poolType]) > *numConnections {
- targets[poolType] = targets[poolType][:b.numConnections]
+ var selected = map[string][]targetHost{}
+
+ for poolType := range allTargets {
+ b.sorter.shuffleSort(allTargets[poolType])
+
+ // try to pick numConnections from the front of the list (local zone) and numBackupConns
+ // from the tail (remote zone). if that's not possible, just take the whole set
+ if len(allTargets[poolType]) >= b.numConnections+b.numBackupConns {
+ remoteOffset := len(allTargets[poolType]) - b.numBackupConns
+ selected[poolType] = append(allTargets[poolType][:b.numConnections:b.numConnections], allTargets[poolType][remoteOffset:]...)
+ } else {
+ selected[poolType] = allTargets[poolType]
}
- targetCount.Set(poolType, int64(len(targets[poolType])))
+
+ targetCount.Set(poolType, int64(len(selected[poolType])))
}
b.mu.Lock()
- b.targets = targets
+ b.targets = selected
b.mu.Unlock()
return true, nil
@@ -353,7 +367,7 @@ func (b *JSONGateResolverBuilder) getTargets(poolType string) []targetHost {
targets = append(targets, b.targets[poolType]...)
b.mu.RUnlock()
- b.sorter.shuffleSort(targets, b.affinityField, b.affinityValue)
+ b.sorter.shuffleSort(targets)
return targets
}
@@ -373,7 +387,7 @@ func newShuffleSorter() *shuffleSorter {
// shuffleSort shuffles a slice of targetHost to ensure every host has a
// different order to iterate through, putting the affinity matching (e.g. same
// az) hosts at the front and the non-matching ones at the end.
-func (s *shuffleSorter) shuffleSort(targets []targetHost, affinityField, affinityValue string) {
+func (s *shuffleSorter) shuffleSort(targets []targetHost) {
n := len(targets)
head := 0
// Only need to do n-1 swaps since the last host is always in the right place.
@@ -383,7 +397,7 @@ func (s *shuffleSorter) shuffleSort(targets []targetHost, affinityField, affinit
j := head + s.rand.Intn(tail-head+1)
s.mu.Unlock()
- if affinityField != "" && affinityValue == targets[j].Affinity {
+ if targets[j].IsLocal {
targets[head], targets[j] = targets[j], targets[head]
head++
} else {
@@ -406,7 +420,8 @@ func (b *JSONGateResolverBuilder) update(r *JSONGateResolver) error {
var addrs []resolver.Address
for _, target := range targets {
- addrs = append(addrs, resolver.Address{Addr: target.Addr, Attributes: attributes.New(PoolTypeAttr, r.poolType)})
+ attrs := attributes.New(PoolTypeAttr, r.poolType).WithValue(ZoneLocalAttr, target.IsLocal)
+ addrs = append(addrs, resolver.Address{Addr: target.Addr, Attributes: attrs})
}
// If we've already selected some targets, give the new addresses some time to warm up before removing
@@ -488,12 +503,13 @@ const (
{{range $i, $p := .Pools}}
- {{$p}} |
+ {{$p}} |
{{range index $.Targets $p}}
{{.Hostname}} |
{{.Addr}} |
{{.Affinity}} |
+ {{.IsLocal}} |
{{end}}
{{end}}
diff --git a/go/vt/vtgateproxy/firstready_balancer.go b/go/vt/vtgateproxy/firstready_balancer.go
index 2885ae93d5f..8009f3dee86 100644
--- a/go/vt/vtgateproxy/firstready_balancer.go
+++ b/go/vt/vtgateproxy/firstready_balancer.go
@@ -41,12 +41,13 @@ import (
)
// newBuilder creates a new first_ready balancer builder.
-func newBuilder() balancer.Builder {
+func newFirstReadyBuilder() balancer.Builder {
return base.NewBalancerBuilder("first_ready", &frPickerBuilder{currentConns: map[string]balancer.SubConn{}}, base.Config{HealthCheck: true})
}
func init() {
- balancer.Register(newBuilder())
+ log.V(1).Infof("registering first_ready balancer")
+ balancer.Register(newFirstReadyBuilder())
}
// frPickerBuilder implements both the Builder and the Picker interfaces.
diff --git a/go/vt/vtgateproxy/mysql_server.go b/go/vt/vtgateproxy/mysql_server.go
index e6169649582..65771400036 100644
--- a/go/vt/vtgateproxy/mysql_server.go
+++ b/go/vt/vtgateproxy/mysql_server.go
@@ -219,6 +219,8 @@ func (ph *proxyHandler) ComQuery(c *mysql.Conn, query string, callback func(*sql
}
}()
+ ctx = context.WithValue(ctx, CONN_ID_KEY, int(c.ConnectionID))
+
if session.SessionPb().Options.Workload == querypb.ExecuteOptions_OLAP {
err := ph.proxy.StreamExecute(ctx, session, query, make(map[string]*querypb.BindVariable), callback)
return sqlerror.NewSQLErrorFromError(err)
@@ -285,6 +287,8 @@ func (ph *proxyHandler) ComPrepare(c *mysql.Conn, query string, bindVars map[str
}
}(session)
+ ctx = context.WithValue(ctx, CONN_ID_KEY, int(c.ConnectionID))
+
_, fld, err := ph.proxy.Prepare(ctx, session, query, bindVars)
err = sqlerror.NewSQLErrorFromError(err)
if err != nil {
@@ -332,6 +336,8 @@ func (ph *proxyHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData
}
}()
+ ctx = context.WithValue(ctx, CONN_ID_KEY, int(c.ConnectionID))
+
if session.SessionPb().Options.Workload == querypb.ExecuteOptions_OLAP {
err := ph.proxy.StreamExecute(ctx, session, prepare.PrepareStmt, prepare.BindVars, callback)
return sqlerror.NewSQLErrorFromError(err)
@@ -396,6 +402,8 @@ func (ph *proxyHandler) getSession(ctx context.Context, c *mysql.Conn) (*vtgatec
options.ClientFoundRows = true
}
+ ctx = context.WithValue(ctx, CONN_ID_KEY, int(c.ConnectionID))
+
var err error
session, err = ph.proxy.NewSession(ctx, options, c.Attributes)
if err != nil {
@@ -420,6 +428,9 @@ func (ph *proxyHandler) closeSession(ctx context.Context, c *mysql.Conn) {
if session.SessionPb().InTransaction {
defer atomic.AddInt32(&busyConnections, -1)
}
+
+ ctx = context.WithValue(ctx, CONN_ID_KEY, int(c.ConnectionID))
+
err := ph.proxy.CloseSession(ctx, session)
if err != nil {
log.Errorf("Error happened in transaction rollback: %v", err)
diff --git a/go/vt/vtgateproxy/sim/vtgateproxysim.go b/go/vt/vtgateproxy/sim/vtgateproxysim.go
new file mode 100644
index 00000000000..e8c6faa934d
--- /dev/null
+++ b/go/vt/vtgateproxy/sim/vtgateproxysim.go
@@ -0,0 +1,101 @@
+package main
+
+import (
+ "flag"
+ "fmt"
+ "math/rand"
+ "sort"
+ "time"
+
+ "github.com/guptarohit/asciigraph"
+)
+
+var (
+ numClients = flag.Int("c", 9761, "Number of clients")
+ numVtgates = flag.Int("v", 1068, "Number of vtgates")
+ numConnections = flag.Int("n", 4, "number of connections per client host")
+ numZones = flag.Int("z", 4, "number of zones")
+)
+
+func main() {
+ rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
+
+ flag.Parse()
+
+ fmt.Printf("Simulating %d clients => %d vtgates with %d zones %d conns per client\n\n",
+ *numClients, *numVtgates, *numZones, *numConnections)
+
+ var clients []string
+ for i := 0; i < *numClients; i++ {
+ clients = append(clients, fmt.Sprintf("client-%03d", i))
+ }
+
+ var vtgates []string
+ for i := 0; i < *numVtgates; i++ {
+ vtgates = append(vtgates, fmt.Sprintf("vtgate-%03d", i))
+ }
+
+ // for now just consider 1/N of the hosts "local"
+ localClients := clients[:*numClients / *numZones]
+ localVtgates := vtgates[:*numVtgates / *numZones]
+
+ conns := map[string][]string{}
+
+ // Simulate "discovery"
+ for _, client := range localClients {
+ var clientConns []string
+
+ for i := 0; i < *numConnections; i++ {
+ vtgate := localVtgates[rnd.Intn(len(localVtgates))]
+ clientConns = append(clientConns, vtgate)
+ }
+
+ conns[client] = clientConns
+ }
+
+ counts := map[string]int{}
+ for _, conns := range conns {
+ for _, vtgate := range conns {
+ counts[vtgate]++
+ }
+ }
+
+ histogram := map[int]int{}
+ max := 0
+ min := -1
+ for _, count := range counts {
+ histogram[count]++
+ if count > max {
+ max = count
+ }
+ if min == -1 || count < min {
+ min = count
+ }
+ }
+
+ fmt.Printf("Conns per vtgate\n%v\n\n", counts)
+ fmt.Printf("Histogram of conn counts\n%v\n\n", histogram)
+
+ plot := []float64{}
+ for i := 0; i < len(localVtgates); i++ {
+ plot = append(plot, float64(counts[localVtgates[i]]))
+ }
+ sort.Float64s(plot)
+ graph := asciigraph.Plot(plot)
+ fmt.Println("Number of conns per vtgate host")
+ fmt.Println(graph)
+ fmt.Println("")
+ fmt.Println("")
+
+ fmt.Printf("Conn count per vtgate distribution [%d - %d] (%d clients => %d vtgates with %d zones %d conns)\n\n",
+ min, max, *numClients, *numVtgates, *numZones, *numConnections)
+ plot = []float64{}
+ for i := min; i <= max; i++ {
+ plot = append(plot, float64(histogram[i]))
+ }
+ graph = asciigraph.Plot(plot)
+ fmt.Println(graph)
+
+ fmt.Printf("\nConn stats: min %d max %d spread %d spread/min %f spread/avg %f\n",
+ min, max, max-min, float64(max-min)/float64(min), float64(max-min)/(float64(max+min)/2))
+}
diff --git a/go/vt/vtgateproxy/sticky_random_balancer.go b/go/vt/vtgateproxy/sticky_random_balancer.go
new file mode 100644
index 00000000000..ce6232df568
--- /dev/null
+++ b/go/vt/vtgateproxy/sticky_random_balancer.go
@@ -0,0 +1,110 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+// Sticky random is a derivative based on the round_robin balancer which uses a Context
+// variable to maintain client-side affinity to a given connection.
+
+package vtgateproxy
+
+import (
+ "math/rand/v2"
+
+ "google.golang.org/grpc/balancer"
+ "google.golang.org/grpc/balancer/base"
+ "vitess.io/vitess/go/vt/log"
+)
+
+type ConnIdKey string
+
+const CONN_ID_KEY = ConnIdKey("ConnId")
+
+// newStickyRandomBuilder creates a new sticky_random balancer builder.
+func newStickyRandomBuilder() balancer.Builder {
+ return base.NewBalancerBuilder("sticky_random", &stickyPickerBuilder{}, base.Config{HealthCheck: true})
+}
+
+func init() {
+ log.V(1).Infof("registering sticky_random balancer")
+ balancer.Register(newStickyRandomBuilder())
+}
+
+type stickyPickerBuilder struct{}
+
+// Would be nice if this were easier in golang
+func boolValue(val interface{}) bool {
+ switch val := val.(type) {
+ case bool:
+ return val
+ }
+ return false
+}
+
+func (*stickyPickerBuilder) Build(info base.PickerBuildInfo) balancer.Picker {
+ // log.V(100).Infof("stickyRandomPicker: Build called with info: %v", info)
+ if len(info.ReadySCs) == 0 {
+ return base.NewErrPicker(balancer.ErrNoSubConnAvailable)
+ }
+
+ scs := make([]balancer.SubConn, 0, len(info.ReadySCs))
+
+ // Where possible filter to only ready conns in the local zone, using the remote
+ // zone only if there are no local conns available.
+ for sc, scInfo := range info.ReadySCs {
+ local := boolValue(scInfo.Address.Attributes.Value(ZoneLocalAttr))
+ if local {
+ scs = append(scs, sc)
+ }
+ }
+
+ // Otherwise use all the ready conns regardless of locality
+ if len(scs) == 0 {
+ for sc := range info.ReadySCs {
+ scs = append(scs, sc)
+ }
+ }
+
+ return &stickyPicker{
+ subConns: scs,
+ }
+}
+
+type stickyPicker struct {
+ // subConns is the snapshot of the balancer when this picker was
+ // created. The slice is immutable.
+ subConns []balancer.SubConn
+}
+
+func (p *stickyPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
+
+ subConnsLen := len(p.subConns)
+
+ var connId int
+ connIdVal := info.Ctx.Value(CONN_ID_KEY)
+ if connIdVal != nil {
+ connId = connIdVal.(int)
+ log.V(100).Infof("stickyRandomPicker: using connId %d", connId)
+ } else {
+ log.V(100).Infof("stickyRandomPicker: nonexistent connId -- using random")
+ connId = rand.IntN(subConnsLen) // shouldn't happen
+ }
+
+ // XXX/demmer might want to hash the connId rather than just mod
+ sc := p.subConns[connId%subConnsLen]
+
+ return balancer.PickResult{SubConn: sc}, nil
+}
diff --git a/go/vt/vtgateproxy/vtgateproxy.go b/go/vt/vtgateproxy/vtgateproxy.go
index c0afdf422a7..883e744058f 100644
--- a/go/vt/vtgateproxy/vtgateproxy.go
+++ b/go/vt/vtgateproxy/vtgateproxy.go
@@ -53,6 +53,7 @@ const (
var (
vtgateHostsFile = flag.String("vtgate_hosts_file", "", "json file describing the host list to use for vtgate:// resolution")
numConnections = flag.Int("num_connections", 4, "number of outbound GPRC connections to maintain")
+ numBackupConns = flag.Int("num_backup_conns", 1, "number of backup remote-zone gRPC connections to maintain")
poolTypeField = flag.String("pool_type_field", "", "Field name used to specify the target vtgate type and filter the hosts")
affinityField = flag.String("affinity_field", "", "Attribute (JSON file) used to specify the routing affinity , e.g. 'az_id'")
affinityValue = flag.String("affinity_value", "", "Value to match for routing affinity , e.g. 'use-az1'")
@@ -218,6 +219,7 @@ func Init() {
case "round_robin":
case "first_ready":
case "pick_first":
+ case "sticky_random":
break
default:
log.Fatalf("invalid balancer type %s", *balancerType)
@@ -235,6 +237,7 @@ func Init() {
*affinityField,
*affinityValue,
*numConnections,
+ *numBackupConns,
)
if err != nil {