From 49728f709d7a171479c352a8fffe9828e70a896c Mon Sep 17 00:00:00 2001 From: Marcel van Lohuizen Date: Wed, 23 Aug 2023 13:40:32 +0200 Subject: [PATCH] internal/tdtest: make function detection more robust MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Instead of detecting the package (which was already brittle), test the first argument of the closure of the second argument of Run. This allows other packages to wrap the tdtest.Run function, as long as they keep the same signature. This is necessary to handle errors in cuetest.Run. Signed-off-by: Marcel van Lohuizen Change-Id: If8dea69244fec9111916df667b0a8c09dc85fa4d Reviewed-on: https://review.gerrithub.io/c/cue-lang/cue/+/1167818 TryBot-Result: CUEcueckoo Unity-Result: CUE porcuepine Reviewed-by: Daniel Martí --- internal/tdtest/update.go | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/internal/tdtest/update.go b/internal/tdtest/update.go index e3f9b96d750..42bff91489f 100644 --- a/internal/tdtest/update.go +++ b/internal/tdtest/update.go @@ -204,10 +204,7 @@ func findFileAndPackage(path string, pkgs []*packages.Package) (*ast.File, *pack return nil, nil } -const ( - typeT = "*cuelang.org/go/internal/tdtest.T" - tdtestParen = `("cuelang.org/go/internal/tdtest")` -) +const typeT = "*cuelang.org/go/internal/tdtest.T" // findCalls finds all call expressions within a given block for functions // or methods defined within the tdtest package. @@ -229,10 +226,19 @@ func (i *info) findCalls(block *ast.BlockStmt, names ...string) []*callInfo { info := i.testPkg.TypesInfo for _, name := range names { if sel.Sel.Name == name { - if info.TypeOf(sel.X).String() == typeT { - } else if ident, ok := sel.X.(*ast.Ident); !ok { - return true // Run method. - } else if id, ok := info.Uses[ident].(*types.PkgName); ok && strings.Contains(id.String(), tdtestParen) { + receiver := info.TypeOf(sel.X).String() + if receiver == typeT { + // Method. + } else if len(c.Args) == 3 { + // Run function. + fn := c.Args[2].(*ast.FuncLit) + if len(fn.Type.Params.List) != 2 { + return true + } + argType := info.TypeOf(fn.Type.Params.List[0].Type).String() + if argType != typeT { + return true + } } else { return true }