diff --git a/go/vt/vtgateproxy/discovery.go b/go/vt/vtgateproxy/discovery.go index 532823e8330..c5ceae7f3ed 100644 --- a/go/vt/vtgateproxy/discovery.go +++ b/go/vt/vtgateproxy/discovery.go @@ -64,7 +64,7 @@ func (b *JSONGateConfigDiscovery) Build(target resolver.Target, cc resolver.Clie filters := hostFilters{} filters["type"] = gateType - for k, _ := range queryOpts { + for k := range queryOpts { if strings.HasPrefix(k, queryParamFilterPrefix) { filteredPrefix := strings.TrimPrefix(k, queryParamFilterPrefix) filters[filteredPrefix] = queryOpts.Get(k) @@ -91,11 +91,6 @@ func RegisterJsonDiscovery() { fmt.Printf("Registered %v scheme\n", jsonDiscovery.Scheme()) } -type resolveFilters struct { - gate_type string - az_id string -} - type hostFilters = map[string]string // exampleResolver is a @@ -109,12 +104,10 @@ type resolveJSONGateConfig struct { filters hostFilters } -type discoverySlackAZ struct{} -type discoverySlackType struct{} type matchesFilter struct{} func (r *resolveJSONGateConfig) loadConfig() (*[]resolver.Address, []byte, error) { - pairs := []map[string]string{} + pairs := []map[string]interface{}{} fmt.Printf("Loading config %v\n", r.jsonPath) data, err := os.ReadFile(r.jsonPath) @@ -130,20 +123,25 @@ func (r *resolveJSONGateConfig) loadConfig() (*[]resolver.Address, []byte, error addrs := []resolver.Address{} for _, pair := range pairs { - attributes := attributes.New(matchesFilter{}, true) + filterMatch := false for k, v := range r.filters { - if pair[k] != v { - fmt.Printf("Filtering out %v", pair) - attributes.WithValue(matchesFilter{}, false) - continue + if pair[k] == v { + filterMatch = true + } else { + filterMatch = false } } + attrs := attributes.New(matchesFilter{}, "nomatch") + if filterMatch { + attrs = attributes.New(matchesFilter{}, "match") + } + // Add matching hosts to registration list addrs = append(addrs, resolver.Address{ Addr: fmt.Sprintf("%s:%s", pair["nebula_address"], pair["grpc"]), - BalancerAttributes: attributes, + BalancerAttributes: attrs, }) } @@ -205,7 +203,7 @@ func (r *resolveJSONGateConfig) start() { } // Make sure this wasn't a spurious change by checking the hash - if bytes.Compare(hash, newHash) == 0 && newHash != nil { + if bytes.Equal(hash, newHash) && newHash != nil { fmt.Printf("No content changed in discovery file... ignoring\n") continue } diff --git a/go/vt/vtgateproxy/gate_balancer.go b/go/vt/vtgateproxy/gate_balancer.go index 5045622f407..0a6fbc2b139 100644 --- a/go/vt/vtgateproxy/gate_balancer.go +++ b/go/vt/vtgateproxy/gate_balancer.go @@ -1,18 +1,13 @@ package vtgateproxy import ( - "context" "errors" "fmt" - "strconv" - "strings" - "sync" "sync/atomic" "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer/base" "google.golang.org/grpc/grpclog" - "google.golang.org/grpc/metadata" ) // Name is the name of az affinity balancer. @@ -22,12 +17,6 @@ const MetadataDiscoveryFilterPrefix = "grpc_discovery_filter_" var logger = grpclog.Component("slack_affinity_balancer") -func WithSlackAZAffinityContext(ctx context.Context, numConnections string, filters metadata.MD) context.Context { - metadata.NewOutgoingContext(ctx, filters) - ctx = metadata.AppendToOutgoingContext(ctx, MetadataHostAffinityCount, numConnections) - return ctx -} - func newBuilder() balancer.Builder { return base.NewBalancerBuilder(Name, &slackAZAffinityBalancer{}, base.Config{HealthCheck: true}) } @@ -40,7 +29,7 @@ type slackAZAffinityBalancer struct{} func (*slackAZAffinityBalancer) Build(info base.PickerBuildInfo) balancer.Picker { logger.Infof("slackAZAffinityBalancer: Build called with info: %v", info) - fmt.Printf("Rebuilding picker\n") + fmt.Printf("Rebuilding picker: %v\n", info) if len(info.ReadySCs) == 0 { return base.NewErrPicker(balancer.ErrNoSubConnAvailable) @@ -49,15 +38,18 @@ func (*slackAZAffinityBalancer) Build(info base.PickerBuildInfo) balancer.Picker subConnsByFiltered := []balancer.SubConn{} for sc := range info.ReadySCs { - subConnInfo, _ := info.ReadySCs[sc] - matchesFilter := subConnInfo.Address.BalancerAttributes.Value(matchesFilter{}).(bool) + subConnInfo := info.ReadySCs[sc] + matchesFilter := subConnInfo.Address.BalancerAttributes.Value(matchesFilter{}).(string) allSubConns = append(allSubConns, sc) - if matchesFilter { + if matchesFilter == "match" { subConnsByFiltered = append(subConnsByFiltered, sc) } - } + + fmt.Printf("Filtered subcons: %v\n", len(subConnsByFiltered)) + fmt.Printf("All subcons: %v\n", len(allSubConns)) + return &slackAZAffinityPicker{ allSubConns: allSubConns, filteredSubConns: subConnsByFiltered, @@ -68,7 +60,6 @@ type slackAZAffinityPicker struct { // allSubConns is all subconns that were in the ready state when the picker was created allSubConns []balancer.SubConn filteredSubConns []balancer.SubConn - nextByAZ sync.Map next uint32 } @@ -77,60 +68,27 @@ func (p *slackAZAffinityPicker) pickFromSubconns(scList []balancer.SubConn, next subConnsLen := uint32(len(scList)) if subConnsLen == 0 { - return balancer.PickResult{}, errors.New("No hosts in list") + return balancer.PickResult{}, errors.New("no hosts in list") } - fmt.Printf("Select offset: %v %v %v\n", nextIndex, nextIndex%subConnsLen, len(scList)) - sc := scList[nextIndex%subConnsLen] + fmt.Printf("Select offset: %v %v %v %v\n", nextIndex, nextIndex%subConnsLen, len(scList), sc) + return balancer.PickResult{SubConn: sc}, nil } func (p *slackAZAffinityPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { - hdrs, _ := metadata.FromOutgoingContext(info.Ctx) - numConnections := 0 - keys := hdrs.Get(MetadataAZKey) - if len(keys) < 1 { - return p.pickFromSubconns(p.allSubConns, atomic.AddUint32(&p.next, 1)) - } - az := keys[0] - - filteredSubconns := p.allSubConns - for k, v := range hdrs { - if strings.HasPrefix(k, MetadataDiscoveryFilterPrefix) { - filterName := strings.TrimPrefix(k, MetadataDiscoveryFilterPrefix) - filterValue := v - } - } - - for _, s := range v { - - } - - if az == "" { - return p.pickFromSubconns(p.allSubConns, atomic.AddUint32(&p.next, 1)) - } - - keys = hdrs.Get(MetadataHostAffinityCount) - if len(keys) > 0 { - if i, err := strconv.Atoi(keys[0]); err != nil { - numConnections = i - } - } - - subConns := p.subConnsByAZ[az] - if len(subConns) == 0 { - fmt.Printf("No subconns in az and gate type, pick from anywhere\n") + filteredSubConns := p.filteredSubConns + numConnections := *numConnectionsInt + if len(filteredSubConns) == 0 { + fmt.Printf("No subconns in the filtered list, pick from anywhere in pool\n") return p.pickFromSubconns(p.allSubConns, atomic.AddUint32(&p.next, 1)) } - val, _ := p.nextByAZ.LoadOrStore(az, new(uint32)) - ptr := val.(*uint32) - atomic.AddUint32(ptr, 1) - if len(subConns) >= numConnections && numConnections > 0 { + if len(filteredSubConns) >= numConnections && numConnections > 0 { fmt.Printf("Limiting to first %v\n", numConnections) - return p.pickFromSubconns(subConns[0:numConnections], *ptr) + return p.pickFromSubconns(filteredSubConns[0:numConnections], atomic.AddUint32(&p.next, 1)) } else { - return p.pickFromSubconns(subConns, *ptr) + return p.pickFromSubconns(filteredSubConns, atomic.AddUint32(&p.next, 1)) } } diff --git a/go/vt/vtgateproxy/vtgateproxy.go b/go/vt/vtgateproxy/vtgateproxy.go index 0ac2885adaa..f6932fe888c 100644 --- a/go/vt/vtgateproxy/vtgateproxy.go +++ b/go/vt/vtgateproxy/vtgateproxy.go @@ -23,7 +23,7 @@ import ( "flag" "fmt" "io" - "strconv" + "net/url" "strings" "sync" "time" @@ -57,13 +57,12 @@ type VTGateProxy struct { mu sync.Mutex } -func (proxy *VTGateProxy) getConnection(ctx context.Context, target string, filters metadata.MD) (*vtgateconn.VTGateConn, error) { - numConnectionsString := strconv.Itoa(*numConnectionsInt) - fmt.Printf("Getting connection for %v in %v with %v filters\n", target, filters) +func (proxy *VTGateProxy) getConnection(ctx context.Context, target string) (*vtgateconn.VTGateConn, error) { + fmt.Printf("Getting connection for %v\n", target) // If the connection exists, return it proxy.mu.Lock() - existingConn, _ := proxy.targetConns[target] + existingConn := proxy.targetConns[target] if existingConn != nil { proxy.mu.Unlock() return existingConn, nil @@ -79,7 +78,7 @@ func (proxy *VTGateProxy) getConnection(ctx context.Context, target string, filt return append(opts, grpc.WithDefaultServiceConfig(`{"loadBalancingConfig": [{"slack_affinity_balancer":{}}]}`)), nil }) - conn, err := vtgateconn.DialProtocol(WithSlackAZAffinityContext(ctx, numConnectionsString, filters), "grpc", target) + conn, err := vtgateconn.DialProtocol(ctx, "grpc", target) if err != nil { return nil, err } @@ -97,14 +96,22 @@ func (proxy *VTGateProxy) NewSession(ctx context.Context, options *querypb.Execu return nil, vterrors.Errorf(vtrpcpb.Code_UNAVAILABLE, "no target string supplied by client") } + targetUrl := url.URL{ + Scheme: "vtgate", + Host: target, + } + filters := metadata.Pairs() + values := url.Values{} for k, v := range connectionAttributes { - if strings.HasPrefix(k, MetadataDiscoveryFilterPrefix) { + if strings.HasPrefix(k, queryParamFilterPrefix) { filters.Append(k, v) + values.Set(k, v) } } + targetUrl.RawQuery = values.Encode() - conn, err := proxy.getConnection(ctx, target, filters) + conn, err := proxy.getConnection(ctx, targetUrl.String()) if err != nil { return nil, err } @@ -116,7 +123,7 @@ func (proxy *VTGateProxy) NewSession(ctx context.Context, options *querypb.Execu // same effect as if a "rollback" statement was executed, but does not affect the query // statistics. func (proxy *VTGateProxy) CloseSession(ctx context.Context, session *vtgateconn.VTGateSession) error { - return session.CloseSession(WithSlackAZAffinityContext(ctx, proxy.azID, proxy.gateType)) + return session.CloseSession(ctx) } // ResolveTransaction resolves the specified 2PC transaction. @@ -138,11 +145,11 @@ func (proxy *VTGateProxy) Execute(ctx context.Context, session *vtgateconn.VTGat return &sqltypes.Result{}, nil } - return session.Execute(WithSlackAZAffinityContext(ctx, proxy.azID, proxy.gateType), sql, bindVariables) + return session.Execute(ctx, sql, bindVariables) } func (proxy *VTGateProxy) StreamExecute(ctx context.Context, session *vtgateconn.VTGateSession, sql string, bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) error { - stream, err := session.StreamExecute(WithSlackAZAffinityContext(ctx, proxy.azID, proxy.gateType), sql, bindVariables) + stream, err := session.StreamExecute(ctx, sql, bindVariables) if err != nil { return err }