Skip to content

Commit

Permalink
collection: support embedded structs when assigning
Browse files Browse the repository at this point in the history
Allow using embedded structs when assigning maps and programs from
a CollectionSpec. For example:

    type maps struct {
        Bar *ebpf.Map `ebpf:"bar"`
    }

    var objs struct {
        maps
        Prog *ebpf.Program `ebpf:"prog"`
    }

    spec.LoadAndAssign(&objs)

This is useful for separating out maps and programs, since they tend
to have different lifetimes: programs are loaded, attached and then
discarded. Maps stick around, since they are used to communicate with
the program.

The same functionality now applies to CollectionSpec.Assign as well.
  • Loading branch information
lmb committed Mar 10, 2021
1 parent 2c42beb commit 70f4a6d
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 32 deletions.
105 changes: 77 additions & 28 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,16 +441,8 @@ func (coll *Collection) DetachProgram(name string) *Program {

// Assign the contents of a collection to a struct.
//
// `to` must be a pointer to a struct like the following:
//
// struct {
// Foo *ebpf.Program `ebpf:"xdp_foo"`
// Bar *ebpf.Map `ebpf:"bar_map"`
// Ignored int
// }
//
// DetachMap and DetachProgram is invoked for all assigned elements
// if the function is successful.
// Deprecated: use CollectionSpec.Assign instead. It provides the same
// functionality but creates only the maps and programs requested.
func (coll *Collection) Assign(to interface{}) error {
assignedMaps := make(map[string]struct{})
assignedPrograms := make(map[string]struct{})
Expand Down Expand Up @@ -491,28 +483,86 @@ func (coll *Collection) Assign(to interface{}) error {
}

func assignValues(to interface{}, valueOf func(reflect.Type, string) (reflect.Value, error)) error {
v := reflect.ValueOf(to)
if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
return fmt.Errorf("%T is not a pointer to a struct", to)
type structField struct {
reflect.StructField
value reflect.Value
}

var (
fields []structField
visitedTypes = make(map[reflect.Type]bool)
flattenStruct func(reflect.Value) error
)

flattenStruct = func(structVal reflect.Value) error {
structType := structVal.Type()
if structType.Kind() != reflect.Struct {
return fmt.Errorf("%s is not a struct", structType)
}

if visitedTypes[structType] {
return fmt.Errorf("recursion on type %s", structType)
}

for i := 0; i < structType.NumField(); i++ {
field := structField{structType.Field(i), structVal.Field(i)}

name := field.Tag.Get("ebpf")
if name != "" {
fields = append(fields, field)
continue
}

var err error
switch field.Type.Kind() {
case reflect.Ptr:
if field.Type.Elem().Kind() != reflect.Struct {
continue
}

if field.value.IsNil() {
return fmt.Errorf("nil pointer to %s", structType)
}

err = flattenStruct(field.value.Elem())

case reflect.Struct:
err = flattenStruct(field.value)

default:
continue
}

if err != nil {
return fmt.Errorf("field %s: %s", field.Name, err)
}
}

return nil
}

toValue := reflect.ValueOf(to)
if toValue.Type().Kind() != reflect.Ptr {
return fmt.Errorf("%T is not a pointer to struct", to)
}

if toValue.IsNil() {
return fmt.Errorf("nil pointer to %T", to)
}

if err := flattenStruct(toValue.Elem()); err != nil {
return err
}

type elem struct {
// Either *Map or *Program
typ reflect.Type
name string
}

var (
s = v.Elem()
sT = s.Type()
assignedTo = make(map[elem]string)
)
for i := 0; i < sT.NumField(); i++ {
field := sT.Field(i)

assignedTo := make(map[elem]string)
for _, field := range fields {
name := field.Tag.Get("ebpf")
if name == "" {
continue
}
if strings.Contains(name, ",") {
return fmt.Errorf("field %s: ebpf tag contains a comma", field.Name)
}
Expand All @@ -527,12 +577,11 @@ func assignValues(to interface{}, valueOf func(reflect.Type, string) (reflect.Va
return fmt.Errorf("field %s: %w", field.Name, err)
}

fieldValue := s.Field(i)
if !fieldValue.CanSet() {
return fmt.Errorf("can't set value of field %s", field.Name)
if !field.value.CanSet() {
return fmt.Errorf("field %s: can't set value", field.Name)
}

fieldValue.Set(value)
field.value.Set(value)
assignedTo[e] = field.Name
}

Expand Down
95 changes: 91 additions & 4 deletions collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ebpf

import (
"fmt"
"reflect"
"testing"

"github.com/cilium/ebpf/asm"
Expand Down Expand Up @@ -276,11 +277,23 @@ func TestCollectionAssign(t *testing.T) {
Map *Map `ebpf:"map1"`
}

prog1 := coll.Programs["prog1"]
defer prog1.Close()

map1 := coll.Maps["map1"]
defer map1.Close()

if err := coll.Assign(&objs); err != nil {
t.Fatal("Can't Assign objects:", err)
}
objs.Program.Close()
objs.Map.Close()

if objs.Program != prog1 {
t.Errorf("Program is %p not %p", objs.Program, prog1)
}

if objs.Map != map1 {
t.Errorf("Map is %p not %p", objs.Map, map1)
}

if coll.Programs["prog1"] != nil {
t.Fatal("Assign doesn't detach Program")
Expand All @@ -291,6 +304,72 @@ func TestCollectionAssign(t *testing.T) {
}
}

func TestAssignValues(t *testing.T) {
zero := func(t reflect.Type, name string) (reflect.Value, error) {
return reflect.Zero(t), nil
}

type t1 struct {
Bar int `ebpf:"bar"`
}

type t2 struct {
t1
Foo int `ebpf:"foo"`
}

type t2ptr struct {
*t1
Foo int `ebpf:"foo"`
}

invalid := []struct {
name string
to interface{}
}{
{"non-struct", 1},
{"non-pointer struct", t1{}},
{"pointer to non-struct", new(int)},
{"embedded nil pointer", &t2ptr{}},
{"unexported field", new(struct {
foo int `ebpf:"foo"`
})},
{"identical tag", new(struct {
Foo1 int `ebpf:"foo"`
Foo2 int `ebpf:"foo"`
})},
}

for _, testcase := range invalid {
t.Run(testcase.name, func(t *testing.T) {
if err := assignValues(testcase.to, zero); err == nil {
t.Fatal("assignValues didn't return an error")
} else {
t.Log(err)
}
})
}

valid := []struct {
name string
to interface{}
}{
{"pointer to struct", new(t1)},
{"embedded struct", new(t2)},
{"embedded struct pointer", &t2ptr{t1: new(t1)}},
{"untagged field", new(struct{ Foo int })},
}

for _, testcase := range valid {
t.Run(testcase.name, func(t *testing.T) {
if err := assignValues(testcase.to, zero); err != nil {
t.Fatal("assignValues returned", err)
}
})
}

}

func ExampleCollectionSpec_Assign() {
spec := &CollectionSpec{
Maps: map[string]*MapSpec{
Expand All @@ -313,9 +392,13 @@ func ExampleCollectionSpec_Assign() {
},
}

type maps struct {
Map *MapSpec `ebpf:"map1"`
}

var specs struct {
maps
Program *ProgramSpec `ebpf:"prog1"`
Map *MapSpec `ebpf:"map1"`
}

if err := spec.Assign(&specs); err != nil {
Expand Down Expand Up @@ -392,9 +475,13 @@ func ExampleCollection_Assign() {
panic(err)
}

type maps struct {
Map *Map `ebpf:"map1"`
}

var objs struct {
maps
Program *Program `ebpf:"prog1"`
Map *Map `ebpf:"map1"`
}

if err := coll.Assign(&objs); err != nil {
Expand Down

0 comments on commit 70f4a6d

Please sign in to comment.