Skip to content

Commit

Permalink
chore: simplify trampoline generation
Browse files Browse the repository at this point in the history
  • Loading branch information
y1yang0 committed Jul 30, 2024
1 parent b31b9e2 commit 1af9e32
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 49 deletions.
84 changes: 35 additions & 49 deletions tool/instrument/trampoline.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,10 @@ func getHookParamTraits(t *api.InstFuncRule, onEnter bool) ([]ParamTrait, error)
// Find which parameter is type of interface{}
for i, field := range target.Type.Params.List {
attr := ParamTrait{Index: i}
if _, ok := field.Type.(*dst.InterfaceType); ok {
if shared.IsInterfaceType(field.Type) {
attr.IsInterfaceAny = true
}
if _, ok := field.Type.(*dst.Ellipsis); ok {
if shared.IsEllipsis(field.Type) {
attr.IsVaradic = true
}
attrs = append(attrs, attr)
Expand Down Expand Up @@ -243,32 +243,18 @@ func rectifyAnyType(paramList *dst.FieldList, traits []ParamTrait) error {
return nil
}

func (rp *RuleProcessor) addOnEnterHookVarDecl(t *api.InstFuncRule, traits []ParamTrait) error {
paramTypes := rp.buildTrampolineType(true)
func (rp *RuleProcessor) addHookFuncVar(t *api.InstFuncRule, traits []ParamTrait, onEnter bool) error {
paramTypes := rp.buildTrampolineType(onEnter)
addCallContext(paramTypes)
// Hook functions may uses interface{} as parameter type, as some types of
// raw function is not exposed, we need to use interface{} to represent them.
// raw function is not exposed
err := rectifyAnyType(paramTypes, traits)
if err != nil {
return fmt.Errorf("failed to rectify any type on enter: %w", err)
}

// Generate onEnter var decl
varDecl := shared.NewVarDecl(rp.makeOnXName(t, true), paramTypes)
rp.addDecl(varDecl)
return nil
}

func (rp *RuleProcessor) addOnExitVarHookDecl(t *api.InstFuncRule, traits []ParamTrait) error {
paramTypes := rp.buildTrampolineType(false)
addCallContext(paramTypes)
err := rectifyAnyType(paramTypes, traits)
if err != nil {
return fmt.Errorf("failed to rectify any type on exit: %w", err)
}

// Generate onExit var decl
varDecl := shared.NewVarDecl(rp.makeOnXName(t, false), paramTypes)
// Generate var decl
varDecl := shared.NewVarDecl(rp.makeOnXName(t, onEnter), paramTypes)
rp.addDecl(varDecl)
return nil
}
Expand Down Expand Up @@ -416,7 +402,7 @@ func setValue(field string, idx int, typ dst.Expr) *dst.CaseClause {
de := shared.DereferenceOf(pe)
val := shared.Ident(TrampolineValIdentifier)
assign := shared.AssignStmt(de, shared.TypeAssertExpr(val, typ))
if _, ok := typ.(*dst.InterfaceType); ok {
if shared.IsInterfaceType(typ) {
assign = shared.AssignStmt(ie, val)
}
caseClause := &dst.CaseClause{
Expand All @@ -435,7 +421,7 @@ func getValue(field string, idx int, typ dst.Expr) *dst.CaseClause {
pe := shared.ParenExpr(te)
de := shared.DereferenceOf(pe)
ret := shared.ReturnStmt(shared.Exprs(de))
if _, ok := typ.(*dst.InterfaceType); ok {
if shared.IsInterfaceType(typ) {
ret = shared.ReturnStmt(shared.Exprs(ie))
}
caseClause := &dst.CaseClause{
Expand Down Expand Up @@ -536,6 +522,29 @@ func (rp *RuleProcessor) rewriteCallContextImpl() {
}
}

func (rp *RuleProcessor) callHookFunc(t *api.InstFuncRule, onEnter bool) error {
traits, err := getHookParamTraits(t, onEnter)
if err != nil {
return fmt.Errorf("failed to get hook param traits: %w", err)
}
err = rp.addHookFuncVar(t, traits, onEnter)
if err != nil {
return fmt.Errorf("failed to add onEnter var hook decl: %w", err)
}
if onEnter {
err = rp.callOnEnterHook(t, traits)
} else {
err = rp.callOnExitHook(t, traits)
}
if err != nil {
return fmt.Errorf("failed to call onEnter: %w", err)
}
if !rp.replenishCallContext(onEnter) {
return errors.New("failed to replenish context in onEnter hook")
}
return nil
}

func (rp *RuleProcessor) generateTrampoline(t *api.InstFuncRule, funcDecl *dst.FuncDecl) error {
rp.rawFunc = funcDecl
// Materialize various declarations from template file, no one wants to see
Expand All @@ -548,45 +557,22 @@ func (rp *RuleProcessor) generateTrampoline(t *api.InstFuncRule, funcDecl *dst.F
rp.implementCallContext(t)
// Rewrite type-aware CallContext APIs
rp.rewriteCallContextImpl()

// Rename trampoline functions
rp.renameFunc(t)
// Rectify types of trampoline functions
rp.rectifyTypes()
// Generate calls to hook functions within trampoline functions
// Generate calls to hook functions
if t.OnEnter != "" {
traits, err := getHookParamTraits(t, true)
if err != nil {
return fmt.Errorf("failed to get hook param traits: %w", err)
}
err = rp.addOnEnterHookVarDecl(t, traits)
if err != nil {
return fmt.Errorf("failed to add onEnter var hook decl: %w", err)
}
err = rp.callOnEnterHook(t, traits)
err = rp.callHookFunc(t, true)
if err != nil {
return fmt.Errorf("failed to call onEnter: %w", err)
}
if !rp.replenishCallContext(true) {
return errors.New("failed to replenish context in onEnter hook")
}
}
if t.OnExit != "" {
traits, err := getHookParamTraits(t, false)
if err != nil {
return fmt.Errorf("failed to get hook param traits: %w", err)
}
err = rp.addOnExitVarHookDecl(t, traits)
if err != nil {
return fmt.Errorf("failed to add onExit var hook decl: %w", err)
}
err = rp.callOnExitHook(t, traits)
err = rp.callHookFunc(t, false)
if err != nil {
return fmt.Errorf("failed to call onExit: %w", err)
}
if !rp.replenishCallContext(false) {
return errors.New("failed to replenish context in onExit hook")
}
}
return nil
}
Expand Down
10 changes: 10 additions & 0 deletions tool/shared/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,16 @@ func BoolFalse() *dst.BasicLit {
return &dst.BasicLit{Value: "false"}
}

func IsInterfaceType(typ dst.Expr) bool {
_, ok := typ.(*dst.InterfaceType)
return ok
}

func IsEllipsis(typ dst.Expr) bool {
_, ok := typ.(*dst.Ellipsis)
return ok
}

func InterfaceType() *dst.InterfaceType {
return &dst.InterfaceType{Methods: &dst.FieldList{List: nil}}
}
Expand Down

0 comments on commit 1af9e32

Please sign in to comment.