Skip to content

Commit

Permalink
⚠️ always use up-to-date context
Browse files Browse the repository at this point in the history
Signed-off-by: Pranav Gaikwad <[email protected]>
  • Loading branch information
pranavgaikwad committed Sep 14, 2023
1 parent 46fc5b7 commit 5a1c361
Show file tree
Hide file tree
Showing 18 changed files with 100 additions and 94 deletions.
4 changes: 2 additions & 2 deletions cmd/dep/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func main() {
}

if treeOutput {
deps, err := prov.GetDependenciesDAG()
deps, err := prov.GetDependenciesDAG(ctx)
if err != nil {
log.Error(err, "failed to get list of dependencies for provider", "provider", name)
continue
Expand All @@ -116,7 +116,7 @@ func main() {
})
}
} else {
deps, err := prov.GetDependencies()
deps, err := prov.GetDependencies(ctx)
if err != nil {
log.Error(err, "failed to get list of dependencies for provider", "provider", name)
continue
Expand Down
8 changes: 4 additions & 4 deletions engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func (r *ruleEngine) RunRules(ctx context.Context, ruleSets []RuleSet, selectors
rs.Errors[response.Rule.RuleID] = response.Err.Error()
}
} else if response.ConditionResponse.Matched && len(response.ConditionResponse.Incidents) > 0 {
violation, err := r.createViolation(response.ConditionResponse, response.Rule)
violation, err := r.createViolation(ctx, response.ConditionResponse, response.Rule)
if err != nil {
r.logger.Error(err, "unable to create violation from response")
}
Expand Down Expand Up @@ -388,7 +388,7 @@ func processRule(ctx context.Context, rule Rule, ruleCtx ConditionContext, log l

}

func (r *ruleEngine) createViolation(conditionResponse ConditionResponse, rule Rule) (konveyor.Violation, error) {
func (r *ruleEngine) createViolation(ctx context.Context, conditionResponse ConditionResponse, rule Rule) (konveyor.Violation, error) {
incidents := []konveyor.Incident{}
fileCodeSnipCount := map[string]int{}
incidentsSet := map[string]struct{}{} // Set of incidents
Expand All @@ -409,7 +409,7 @@ func (r *ruleEngine) createViolation(conditionResponse ConditionResponse, rule R
// Some violations may not have a location in code.
limitSnip := (r.codeSnipLimit != 0 && fileCodeSnipCount[string(m.FileURI)] == r.codeSnipLimit)
if !limitSnip {
codeSnip, err := r.getCodeLocation(m, rule)
codeSnip, err := r.getCodeLocation(ctx, m, rule)
if err != nil || codeSnip == "" {
r.logger.V(6).Error(err, "unable to get code location")
} else {
Expand Down Expand Up @@ -496,7 +496,7 @@ func (r *ruleEngine) createViolation(conditionResponse ConditionResponse, rule R
}, nil
}

func (r *ruleEngine) getCodeLocation(m IncidentContext, rule Rule) (codeSnip string, err error) {
func (r *ruleEngine) getCodeLocation(ctx context.Context, m IncidentContext, rule Rule) (codeSnip string, err error) {
if m.CodeLocation == nil {
r.logger.V(6).Info("unable to get the code snip", "URI", m.FileURI)
return "", nil
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package generic

import (
"context"
"encoding/json"
"fmt"
"os/exec"
Expand All @@ -9,7 +10,7 @@ import (
"go.lsp.dev/uri"
)

func (g *genericServiceClient) GetDependencies() (map[uri.URI][]*provider.Dep, error) {
func (g *genericServiceClient) GetDependencies(ctx context.Context) (map[uri.URI][]*provider.Dep, error) {
cmdStr, isString := g.config.ProviderSpecificConfig["dependencyProviderPath"].(string)
if !isString {
return nil, fmt.Errorf("dependency provider path is not a string")
Expand All @@ -33,6 +34,6 @@ func (g *genericServiceClient) GetDependencies() (map[uri.URI][]*provider.Dep, e
return m, err
}

func (p *genericServiceClient) GetDependenciesDAG() (map[uri.URI][]provider.DepDAGItem, error) {
func (p *genericServiceClient) GetDependenciesDAG(ctx context.Context) (map[uri.URI][]provider.DepDAGItem, error) {
return nil, nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ func (p *genericProvider) Init(ctx context.Context, log logr.Logger, c provider.

svcClient := genericServiceClient{
rpc: rpc,
ctx: ctx,
cancelFunc: cancelFunc,
cmd: cmd,
config: c,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (p *genericServiceClient) Stop() {
p.cancelFunc()
p.cmd.Wait()
}
func (p *genericServiceClient) Evaluate(cap string, conditionInfo []byte) (provider.ProviderEvaluateResponse, error) {
func (p *genericServiceClient) Evaluate(ctx context.Context, cap string, conditionInfo []byte) (provider.ProviderEvaluateResponse, error) {
var cond genericCondition
err := yaml.Unmarshal(conditionInfo, &cond)
if err != nil {
Expand All @@ -41,11 +41,11 @@ func (p *genericServiceClient) Evaluate(cap string, conditionInfo []byte) (provi
return provider.ProviderEvaluateResponse{}, fmt.Errorf("unable to get query info")
}

symbols := p.GetAllSymbols(query)
symbols := p.GetAllSymbols(ctx, query)

incidents := []provider.IncidentContext{}
for _, s := range symbols {
references := p.GetAllReferences(s)
references := p.GetAllReferences(ctx, s)
for _, ref := range references {
// Look for things that are in the location loaded, //Note may need to filter out vendor at some point
if strings.Contains(ref.URI, p.config.Location) {
Expand Down Expand Up @@ -74,15 +74,15 @@ func (p *genericServiceClient) Evaluate(cap string, conditionInfo []byte) (provi
}, nil
}

func (p *genericServiceClient) GetAllSymbols(query string) []protocol.WorkspaceSymbol {
func (p *genericServiceClient) GetAllSymbols(ctx context.Context, query string) []protocol.WorkspaceSymbol {

wsp := &protocol.WorkspaceSymbolParams{
Query: query,
}

var refs []protocol.WorkspaceSymbol
fmt.Printf("\nrpc call\n")
err := p.rpc.Call(context.TODO(), "workspace/symbol", wsp, &refs)
err := p.rpc.Call(ctx, "workspace/symbol", wsp, &refs)
fmt.Printf("\nrpc called\n")
if err != nil {
fmt.Printf("\n\nerror: %v\n", err)
Expand All @@ -91,7 +91,7 @@ func (p *genericServiceClient) GetAllSymbols(query string) []protocol.WorkspaceS
return refs
}

func (p *genericServiceClient) GetAllReferences(symbol protocol.WorkspaceSymbol) []protocol.Location {
func (p *genericServiceClient) GetAllReferences(ctx context.Context, symbol protocol.WorkspaceSymbol) []protocol.Location {
params := &protocol.ReferenceParams{
TextDocumentPositionParams: protocol.TextDocumentPositionParams{
TextDocument: protocol.TextDocumentIdentifier{
Expand All @@ -102,7 +102,7 @@ func (p *genericServiceClient) GetAllReferences(symbol protocol.WorkspaceSymbol)
}

res := []protocol.Location{}
err := p.rpc.Call(p.ctx, "textDocument/references", params, &res)
err := p.rpc.Call(ctx, "textDocument/references", params, &res)
if err != nil {
fmt.Printf("Error rpc: %v", err)
}
Expand Down
6 changes: 3 additions & 3 deletions parser/rule_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ func (t testProvider) Init(ctx context.Context, log logr.Logger, config provider
return nil, nil
}

func (t testProvider) Evaluate(cap string, conditionInfo []byte) (provider.ProviderEvaluateResponse, error) {
func (t testProvider) Evaluate(ctx context.Context, cap string, conditionInfo []byte) (provider.ProviderEvaluateResponse, error) {
return provider.ProviderEvaluateResponse{}, nil
}

func (t testProvider) GetDependencies() (map[uri.URI][]*provider.Dep, error) {
func (t testProvider) GetDependencies(ctx context.Context) (map[uri.URI][]*provider.Dep, error) {
return nil, nil
}

func (t testProvider) GetDependenciesDAG() (map[uri.URI][]provider.DepDAGItem, error) {
func (t testProvider) GetDependenciesDAG(ctx context.Context) (map[uri.URI][]provider.DepDAGItem, error) {
return nil, nil
}

Expand Down
14 changes: 6 additions & 8 deletions provider/grpc/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ func NewGRPCClient(config provider.Config, log logr.Logger) *grpcProvider {
}

func (g *grpcProvider) ProviderInit(ctx context.Context) error {
g.ctx = ctx
for _, c := range g.config.InitConfig {
s, err := g.Init(ctx, g.log, c)
if err != nil {
Expand Down Expand Up @@ -99,22 +98,21 @@ func (g *grpcProvider) Init(ctx context.Context, log logr.Logger, config provide
}
return &grpcServiceClient{
id: r.Id,
ctx: ctx,
config: config,
client: g.Client,
}, nil
}

func (g *grpcProvider) Evaluate(cap string, conditionInfo []byte) (provider.ProviderEvaluateResponse, error) {
return provider.FullResponseFromServiceClients(g.serviceClients, cap, conditionInfo)
func (g *grpcProvider) Evaluate(ctx context.Context, cap string, conditionInfo []byte) (provider.ProviderEvaluateResponse, error) {
return provider.FullResponseFromServiceClients(ctx, g.serviceClients, cap, conditionInfo)
}

func (g *grpcProvider) GetDependencies() (map[uri.URI][]*provider.Dep, error) {
return provider.FullDepsResponse(g.serviceClients)
func (g *grpcProvider) GetDependencies(ctx context.Context) (map[uri.URI][]*provider.Dep, error) {
return provider.FullDepsResponse(ctx, g.serviceClients)
}

func (g *grpcProvider) GetDependenciesDAG() (map[uri.URI][]provider.DepDAGItem, error) {
return provider.FullDepDAGResponse(g.serviceClients)
func (g *grpcProvider) GetDependenciesDAG(ctx context.Context) (map[uri.URI][]provider.DepDAGItem, error) {
return provider.FullDepDAGResponse(ctx, g.serviceClients)
}

func (g *grpcProvider) Stop() {
Expand Down
13 changes: 6 additions & 7 deletions provider/grpc/service_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,19 @@ import (

type grpcServiceClient struct {
id int64
ctx context.Context
config provider.InitConfig
client pb.ProviderServiceClient
}

var _ provider.ServiceClient = &grpcServiceClient{}

func (g *grpcServiceClient) Evaluate(cap string, conditionInfo []byte) (provider.ProviderEvaluateResponse, error) {
func (g *grpcServiceClient) Evaluate(ctx context.Context, cap string, conditionInfo []byte) (provider.ProviderEvaluateResponse, error) {
m := pb.EvaluateRequest{
Cap: cap,
ConditionInfo: string(conditionInfo),
Id: g.id,
}
r, err := g.client.Evaluate(g.ctx, &m)
r, err := g.client.Evaluate(ctx, &m)
if err != nil {
return provider.ProviderEvaluateResponse{}, err
}
Expand Down Expand Up @@ -85,8 +84,8 @@ func (g *grpcServiceClient) Evaluate(cap string, conditionInfo []byte) (provider
}

// We don't have dependencies
func (g *grpcServiceClient) GetDependencies() (map[uri.URI][]*provider.Dep, error) {
d, err := g.client.GetDependencies(g.ctx, &pb.ServiceRequest{Id: g.id})
func (g *grpcServiceClient) GetDependencies(ctx context.Context) (map[uri.URI][]*provider.Dep, error) {
d, err := g.client.GetDependencies(ctx, &pb.ServiceRequest{Id: g.id})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -140,8 +139,8 @@ func recreateDAGAddedItems(items []*pb.DependencyDAGItem) []provider.DepDAGItem
}

// We don't have dependencies
func (g *grpcServiceClient) GetDependenciesDAG() (map[uri.URI][]provider.DepDAGItem, error) {
d, err := g.client.GetDependenciesDAG(g.ctx, &pb.ServiceRequest{Id: g.id})
func (g *grpcServiceClient) GetDependenciesDAG(ctx context.Context) (map[uri.URI][]provider.DepDAGItem, error) {
d, err := g.client.GetDependenciesDAG(ctx, &pb.ServiceRequest{Id: g.id})
if err != nil {
return nil, err
}
Expand Down
10 changes: 5 additions & 5 deletions provider/internal/builtin/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,14 @@ func (p *builtinProvider) Capabilities() []provider.Capability {
return capabilities
}

func (p *builtinProvider) ProviderInit(context.Context) error {
func (p *builtinProvider) ProviderInit(ctx context.Context) error {
// First load all the tags for all init configs.
for _, c := range p.config.InitConfig {
p.loadTags(c)
}

for _, c := range p.config.InitConfig {
client, err := p.Init(p.ctx, p.log, c)
client, err := p.Init(ctx, p.log, c)
if err != nil {
return nil
}
Expand All @@ -124,7 +124,7 @@ func (p *builtinProvider) Init(ctx context.Context, log logr.Logger, config prov
if config.AnalysisMode != provider.AnalysisMode("") {
p.log.V(5).Info("skipping analysis mode setting for builtin")
}
return &builtintServiceClient{
return &builtinServiceClient{
config: config,
tags: p.tags,
UnimplementedDependenciesComponent: provider.UnimplementedDependenciesComponent{},
Expand Down Expand Up @@ -157,8 +157,8 @@ func (p *builtinProvider) loadTags(config provider.InitConfig) error {
return nil
}

func (p *builtinProvider) Evaluate(cap string, conditionInfo []byte) (provider.ProviderEvaluateResponse, error) {
return provider.FullResponseFromServiceClients(p.clients, cap, conditionInfo)
func (p *builtinProvider) Evaluate(ctx context.Context, cap string, conditionInfo []byte) (provider.ProviderEvaluateResponse, error) {
return provider.FullResponseFromServiceClients(ctx, p.clients, cap, conditionInfo)
}

func (p *builtinProvider) Stop() {
Expand Down
9 changes: 5 additions & 4 deletions provider/internal/builtin/service_client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package builtin

import (
"context"
"fmt"
"io/fs"
"os"
Expand All @@ -18,19 +19,19 @@ import (
"gopkg.in/yaml.v2"
)

type builtintServiceClient struct {
type builtinServiceClient struct {
config provider.InitConfig
tags map[string]bool
provider.UnimplementedDependenciesComponent
}

var _ provider.ServiceClient = &builtintServiceClient{}
var _ provider.ServiceClient = &builtinServiceClient{}

func (p *builtintServiceClient) Stop() {
func (p *builtinServiceClient) Stop() {
return
}

func (p *builtintServiceClient) Evaluate(cap string, conditionInfo []byte) (provider.ProviderEvaluateResponse, error) {
func (p *builtinServiceClient) Evaluate(ctx context.Context, cap string, conditionInfo []byte) (provider.ProviderEvaluateResponse, error) {
var cond builtinCondition
err := yaml.Unmarshal(conditionInfo, &cond)
if err != nil {
Expand Down
13 changes: 7 additions & 6 deletions provider/internal/java/dependency.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package java
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"io/fs"
Expand Down Expand Up @@ -42,7 +43,7 @@ func (p *javaServiceClient) findPom() string {
return f
}

func (p *javaServiceClient) GetDependencies() (map[uri.URI][]*provider.Dep, error) {
func (p *javaServiceClient) GetDependencies(ctx context.Context) (map[uri.URI][]*provider.Dep, error) {
if p.depsCache != nil {
return p.depsCache, nil
}
Expand All @@ -54,12 +55,12 @@ func (p *javaServiceClient) GetDependencies() (map[uri.URI][]*provider.Dep, erro
// for binaries we only find JARs embedded in archive
p.discoverDepsFromJars(p.config.DependencyPath, ll)
} else {
ll, err = p.GetDependenciesDAG()
ll, err = p.GetDependenciesDAG(ctx)
if err != nil {
return p.GetDependencyFallback()
return p.GetDependencyFallback(ctx)
}
if len(ll) == 0 {
return p.GetDependencyFallback()
return p.GetDependencyFallback(ctx)
}
}
for f, ds := range ll {
Expand Down Expand Up @@ -94,7 +95,7 @@ func (p *javaServiceClient) getLocalRepoPath() string {
return string(outb.String())
}

func (p *javaServiceClient) GetDependencyFallback() (map[uri.URI][]*provider.Dep, error) {
func (p *javaServiceClient) GetDependencyFallback(ctx context.Context) (map[uri.URI][]*provider.Dep, error) {
pomDependencyQuery := "//dependencies/dependency/*"
path := p.findPom()
file := uri.File(path)
Expand Down Expand Up @@ -143,7 +144,7 @@ func (p *javaServiceClient) GetDependencyFallback() (map[uri.URI][]*provider.Dep
return m, nil
}

func (p *javaServiceClient) GetDependenciesDAG() (map[uri.URI][]provider.DepDAGItem, error) {
func (p *javaServiceClient) GetDependenciesDAG(ctx context.Context) (map[uri.URI][]provider.DepDAGItem, error) {
localRepoPath := p.getLocalRepoPath()

path := p.findPom()
Expand Down
Loading

0 comments on commit 5a1c361

Please sign in to comment.