diff --git a/pkg/ebpf/event_parameters.go b/pkg/ebpf/event_parameters.go index a148ae33e080..e0a227fe24ee 100644 --- a/pkg/ebpf/event_parameters.go +++ b/pkg/ebpf/event_parameters.go @@ -96,7 +96,7 @@ func attachSuspiciousSyscallSourceProbes(t *Tracee, eventParams []map[string]fil // Attach probes i = 0 for syscallName := range syscalls { - if err := probeGroup.Attach(probes.Handle(i), t.kernelSymbols); err != nil { + if err := probeGroup.Attach(probes.Handle(i), t.getKernelSymbols()); err != nil { // Report attachment errors but don't fail, because it may be a syscall that doesn't exist on this system logger.Warnw("Failed to attach suspicious_syscall_source kprobe", "syscall", syscallName, "error", err) } diff --git a/pkg/ebpf/hooked_syscall_table.go b/pkg/ebpf/hooked_syscall_table.go index a7382dd677ad..ae2a1ff94d57 100644 --- a/pkg/ebpf/hooked_syscall_table.go +++ b/pkg/ebpf/hooked_syscall_table.go @@ -160,11 +160,11 @@ func (t *Tracee) getSyscallNameByKerVer(restrictions []events.KernelRestrictions // populateExpectedSyscallTableArray fills the expected values of the syscall table func (t *Tracee) populateExpectedSyscallTableArray(tableMap *bpf.BPFMap) error { // Get address to the function that defines the not implemented sys call - niSyscallSymbol, err := t.kernelSymbols.GetSymbolByOwnerAndName("system", events.SyscallPrefix+"ni_syscall") + niSyscallSymbol, err := t.getKernelSymbols().GetSymbolByOwnerAndName("system", events.SyscallPrefix+"ni_syscall") if err != nil { e := err // RHEL 8.x uses sys_ni_syscall instead of __arch_ni_syscall - niSyscallSymbol, err = t.kernelSymbols.GetSymbolByOwnerAndName("system", "sys_ni_syscall") + niSyscallSymbol, err = t.getKernelSymbols().GetSymbolByOwnerAndName("system", "sys_ni_syscall") if err != nil { logger.Debugw("hooked_syscall: syscall symbol not found", "name", "sys_ni_syscall") return e @@ -188,7 +188,7 @@ func (t *Tracee) populateExpectedSyscallTableArray(tableMap *bpf.BPFMap) error { continue } - kernelSymbol, err := t.kernelSymbols.GetSymbolByOwnerAndName("system", events.SyscallPrefix+syscallName) + kernelSymbol, err := t.getKernelSymbols().GetSymbolByOwnerAndName("system", events.SyscallPrefix+syscallName) if err != nil { logger.Warnw(fmt.Sprintf("hooked_syscall: Unable to locate syscall symbol... permanently skipping hook check for syscall ID %d", index)) zero := 0 diff --git a/pkg/ebpf/ksymbols.go b/pkg/ebpf/ksymbols.go index 64fa239a696a..b82f4b4d1b0c 100644 --- a/pkg/ebpf/ksymbols.go +++ b/pkg/ebpf/ksymbols.go @@ -45,7 +45,7 @@ func (t *Tracee) UpdateKallsyms() error { // For every ksymbol required by tracee ... for _, required := range allReqSymbols { // ... get the symbol address from the kallsyms file ... - symbol, err := t.kernelSymbols.GetSymbolByOwnerAndName(globalSymbolOwner, required) + symbol, err := t.getKernelSymbols().GetSymbolByOwnerAndName(globalSymbolOwner, required) if err != nil { logger.Debugw("failed to get symbol", "symbol", required, "error", err) continue diff --git a/pkg/ebpf/probes/trace.go b/pkg/ebpf/probes/trace.go index ef67ac7351fd..f5df4d127267 100644 --- a/pkg/ebpf/probes/trace.go +++ b/pkg/ebpf/probes/trace.go @@ -111,7 +111,7 @@ func (p *TraceProbe) attach(module *bpf.Module, args ...interface{}) error { var err error var link *bpf.BPFLink var attachFunc func(uint64) (*bpf.BPFLink, error) - var syms []environment.KernelSymbol + var syms []*environment.KernelSymbol // https://github.com/aquasecurity/tracee/issues/3653#issuecomment-1832642225 // // After commit b022f0c7e404 ('tracing/kprobes: Return EADDRNOTAVAIL diff --git a/pkg/ebpf/processor_funcs.go b/pkg/ebpf/processor_funcs.go index 253b16fd5116..c5e817451df3 100644 --- a/pkg/ebpf/processor_funcs.go +++ b/pkg/ebpf/processor_funcs.go @@ -21,6 +21,7 @@ import ( "github.com/aquasecurity/tracee/pkg/logger" "github.com/aquasecurity/tracee/pkg/time" "github.com/aquasecurity/tracee/pkg/utils" + "github.com/aquasecurity/tracee/pkg/utils/environment" "github.com/aquasecurity/tracee/types/trace" ) @@ -233,10 +234,11 @@ func (t *Tracee) processDoInitModule(event *trace.Event) error { err := capabilities.GetInstance().EBPF( func() error { - err := t.kernelSymbols.Refresh() + newKernelSymbols, err := environment.NewKernelSymbolTable(true, true, t.requiredKsyms...) if err != nil { return errfmt.WrapError(err) } + t.setKernelSymbols(newKernelSymbols) return t.UpdateKallsyms() }, ) @@ -281,7 +283,7 @@ func (t *Tracee) processHookedProcFops(event *trace.Event) error { if addr == 0 { // address is in text segment, marked as 0 continue } - hookingFunction := utils.ParseSymbol(addr, t.kernelSymbols) + hookingFunction := t.getKernelSymbols().GetPotentiallyHiddenSymbolByAddr(addr)[0] if hookingFunction.Owner == "system" { continue } @@ -326,7 +328,7 @@ func (t *Tracee) processPrintMemDump(event *trace.Event) error { } addressUint64 := uint64(address) - symbol := utils.ParseSymbol(addressUint64, t.kernelSymbols) + symbol := t.getKernelSymbols().GetPotentiallyHiddenSymbolByAddr(addressUint64)[0] var utsName unix.Utsname arch := "" if err := unix.Uname(&utsName); err != nil { diff --git a/pkg/ebpf/tracee.go b/pkg/ebpf/tracee.go index 7167d6b13ef9..523344eab72a 100644 --- a/pkg/ebpf/tracee.go +++ b/pkg/ebpf/tracee.go @@ -78,9 +78,8 @@ type Tracee struct { writtenFiles map[string]string netCapturePcap *pcaps.Pcaps // Internal Data - readFiles map[string]string - pidsInMntns bucketscache.BucketsCache // first n PIDs in each mountns - kernelSymbols *environment.KernelSymbolTable + readFiles map[string]string + pidsInMntns bucketscache.BucketsCache // first n PIDs in each mountns // eBPF bpfModule *bpf.Module defaultProbes *probes.ProbeGroup @@ -123,6 +122,9 @@ type Tracee struct { policyManager *policy.Manager // The dependencies of events used by Tracee eventsDependencies *dependencies.Manager + // A reference to a environment.KernelSymbolTable that might change at runtime. + // This should only be accessed using t.getKernelSymbols() and t.setKernelSymbols() + kernelSymbols unsafe.Pointer // Ksymbols needed to be kept alive in table. // This does not mean they are required for tracee to function. // TODO: remove this in favor of dependency manager nodes @@ -137,6 +139,14 @@ func (t *Tracee) Engine() *engine.Engine { return t.sigEngine } +func (t *Tracee) getKernelSymbols() *environment.KernelSymbolTable { + return (*environment.KernelSymbolTable)(atomic.LoadPointer(&t.kernelSymbols)) +} + +func (t *Tracee) setKernelSymbols(kernelSymbols *environment.KernelSymbolTable) { + atomic.StorePointer(&t.kernelSymbols, unsafe.Pointer(kernelSymbols)) +} + // New creates a new Tracee instance based on a given valid Config. It is expected that it won't // cause external system side effects (reads, writes, etc). func New(cfg config.Config) (*Tracee, error) { @@ -362,12 +372,13 @@ func (t *Tracee) Init(ctx gocontext.Context) error { err = capabilities.GetInstance().Specific( func() error { - t.kernelSymbols, err = environment.NewKernelSymbolTable( - environment.WithRequiredSymbols(t.requiredKsyms), - ) - // Cleanup memory in list - t.requiredKsyms = []string{} - return err + // t.requiredKsyms may contain non-data symbols, but it doesn't affect the validity of this call + kernelSymbols, err := environment.NewKernelSymbolTable(true, true, t.requiredKsyms...) + if err != nil { + return err + } + t.setKernelSymbols(kernelSymbols) + return nil }, cap.SYSLOG, ) @@ -604,13 +615,13 @@ func (t *Tracee) initDerivationTable() error { events.SyscallTableCheck: { events.HookedSyscall: { Enabled: shouldSubmit(events.SyscallTableCheck), - DeriveFunction: derive.DetectHookedSyscall(t.kernelSymbols), + DeriveFunction: derive.DetectHookedSyscall(t.getKernelSymbols()), }, }, events.PrintNetSeqOps: { events.HookedSeqOps: { Enabled: shouldSubmit(events.HookedSeqOps), - DeriveFunction: derive.HookedSeqOps(t.kernelSymbols), + DeriveFunction: derive.HookedSeqOps(t.getKernelSymbols()), }, }, events.HiddenKernelModuleSeeker: { @@ -913,18 +924,12 @@ func getUnavailbaleKsymbols(ksymbols []events.KSymbol, kernelSymbols *environmen var unavailableSymbols []events.KSymbol for _, ksymbol := range ksymbols { - sym, err := kernelSymbols.GetSymbolByName(ksymbol.GetSymbolName()) + _, err := kernelSymbols.GetSymbolByName(ksymbol.GetSymbolName()) if err != nil { // If the symbol is not found, it means it's unavailable. unavailableSymbols = append(unavailableSymbols, ksymbol) continue } - for _, s := range sym { - if s.Address == 0 { - // Same if the symbol is found but its address is 0. - unavailableSymbols = append(unavailableSymbols, ksymbol) - } - } } return unavailableSymbols } @@ -944,7 +949,7 @@ func (t *Tracee) validateKallsymsDependencies() { } validateEvent := func(eventId events.ID) bool { - missingDepSyms := getUnavailbaleKsymbols(evtDefSymDeps(eventId), t.kernelSymbols) + missingDepSyms := getUnavailbaleKsymbols(evtDefSymDeps(eventId), t.getKernelSymbols()) shouldFailEvent := false for _, symDep := range missingDepSyms { if symDep.IsRequired() { @@ -1159,7 +1164,7 @@ func (t *Tracee) attachEvent(id events.ID) error { return err } for _, probe := range depsNode.GetDependencies().GetProbes() { - err := t.defaultProbes.Attach(probe.GetHandle(), t.cgroups, t.kernelSymbols) + err := t.defaultProbes.Attach(probe.GetHandle(), t.cgroups, t.getKernelSymbols()) if err == nil { continue } @@ -1192,7 +1197,7 @@ func (t *Tracee) attachProbes() error { logger.Errorw("Got node from type not requested") return nil } - err := t.defaultProbes.Attach(probeNode.GetHandle(), t.cgroups, t.kernelSymbols) + err := t.defaultProbes.Attach(probeNode.GetHandle(), t.cgroups, t.getKernelSymbols()) if err != nil { return []dependencies.Action{dependencies.NewCancelNodeAddAction(err)} } @@ -1722,7 +1727,7 @@ func (t *Tracee) triggerSeqOpsIntegrityCheck(event trace.Event) { } var seqOpsPointers [len(derive.NetSeqOps)]uint64 for i, seqName := range derive.NetSeqOps { - seqOpsStruct, err := t.kernelSymbols.GetSymbolByOwnerAndName("system", seqName) + seqOpsStruct, err := t.getKernelSymbols().GetSymbolByOwnerAndName("system", seqName) if err != nil { continue } @@ -1816,7 +1821,7 @@ func (t *Tracee) triggerMemDump(event trace.Event) []error { continue } - symbol, err := t.kernelSymbols.GetSymbolByOwnerAndName(owner, name) + symbol, err := t.getKernelSymbols().GetSymbolByOwnerAndName(owner, name) if err != nil { if owner != "system" { errs = append(errs, errfmt.Errorf("policy %d: invalid symbols provided to print_mem_dump event: %s - %v", p.ID, field, err)) @@ -1828,7 +1833,7 @@ func (t *Tracee) triggerMemDump(event trace.Event) []error { prefixes := []string{"sys_", "__x64_sys_", "__arm64_sys_"} var errSyscall error for _, prefix := range prefixes { - symbol, errSyscall = t.kernelSymbols.GetSymbolByOwnerAndName(owner, prefix+name) + symbol, errSyscall = t.getKernelSymbols().GetSymbolByOwnerAndName(owner, prefix+name) if errSyscall == nil { err = nil break diff --git a/pkg/events/derive/hooked_seq_ops.go b/pkg/events/derive/hooked_seq_ops.go index 08af6ba5fe5f..6fe24d56f796 100644 --- a/pkg/events/derive/hooked_seq_ops.go +++ b/pkg/events/derive/hooked_seq_ops.go @@ -4,7 +4,6 @@ import ( "github.com/aquasecurity/tracee/pkg/errfmt" "github.com/aquasecurity/tracee/pkg/events" "github.com/aquasecurity/tracee/pkg/events/parse" - "github.com/aquasecurity/tracee/pkg/utils" "github.com/aquasecurity/tracee/pkg/utils/environment" "github.com/aquasecurity/tracee/types/trace" ) @@ -43,7 +42,7 @@ func deriveHookedSeqOpsArgs(kernelSymbols *environment.KernelSymbolTable) derive if addr == 0 { continue } - hookingFunction := utils.ParseSymbol(addr, kernelSymbols) + hookingFunction := kernelSymbols.GetPotentiallyHiddenSymbolByAddr(addr)[0] seqOpsStruct := NetSeqOps[i/4] seqOpsFunc := NetSeqOpsFuncs[i%4] hookedSeqOps[seqOpsStruct+"_"+seqOpsFunc] = diff --git a/pkg/utils/environment/kernel_symbols.go b/pkg/utils/environment/kernel_symbols.go index cb2b119c99f4..d6687df2215b 100644 --- a/pkg/utils/environment/kernel_symbols.go +++ b/pkg/utils/environment/kernel_symbols.go @@ -2,319 +2,283 @@ package environment import ( "bufio" - "fmt" + "io" "os" "strconv" "strings" - "sync" + + "github.com/aquasecurity/tracee/pkg/errfmt" + "github.com/aquasecurity/tracee/pkg/utils" ) const ( - kallsymsPath = "/proc/kallsyms" - chanBuffer = 112800 // TODO: check if we really need this buffer size + // Kernel symbols do not have an associated size, so we define a sensible size + // limit to prevent unrelated symbols from being returned for an address lookup + maxSymbolSize = 0x100000 + + ownerShift = 48 // Number of bits to shift the owner into the upper 16 bits + addressMask = (1 << ownerShift) - 1 // Mask to extract the address from the addressAndOwner field + kernelAddressPrefix = uint64(0xffff) << ownerShift // Precomputed prefix for kernel addresses ) +// KernelSymbol is a friendly representation of a kernel symbol. type KernelSymbol struct { Name string - Type string Address uint64 Owner string } -type nameAndOwner struct { - name string - owner string -} -type addrAndOwner struct { - addr uint64 - owner string -} - -// KernelSymbolTable manages kernel symbols with multiple maps for fast lookup. -type KernelSymbolTable struct { - symbols map[string][]*KernelSymbol - addrs map[uint64][]*KernelSymbol - symByName map[nameAndOwner][]*KernelSymbol - symByAddr map[addrAndOwner][]*KernelSymbol - requiredSyms map[string]struct{} - requiredAddrs map[uint64]struct{} - onlyRequired bool - updateLock sync.Mutex - updateWg sync.WaitGroup -} -func symNotFoundErr(v interface{}) error { - return fmt.Errorf("symbol not found: %v", v) +// kernelSymbolInternal is a memory efficient representation of +// a kernel symbol, used internally for storing all symbols. +type kernelSymbolInternal struct { + name string + // We save only the low 48 bits of the address, as all (non-percpu) symbols are at 0xffffXXXXXXXXXXXX + // Owner is a 16-bit index into a slice of seen owners for the symbol table this symbol belongs to. + // It can only be translated to the owner name if we have the symbol table. + // To conserve memory, we encode both of them as a single 64-bit integer where the lower 48-bits + // are the address and the hight 16-bits are the owner index. + addressAndOwner uint64 } -// NewKernelSymbolTable initializes a KernelSymbolTable with optional configuration functions. -func NewKernelSymbolTable(opts ...KSymbTableOption) (*KernelSymbolTable, error) { - k := &KernelSymbolTable{} - for _, opt := range opts { - if err := opt(k); err != nil { - return nil, err - } +func newKernelSymbolInternal(name string, address uint64, owner uint16) *kernelSymbolInternal { + return &kernelSymbolInternal{ + name: name, + addressAndOwner: (uint64(owner) << ownerShift) | (address & addressMask), } +} - // Set onlyRequired to true if there are required symbols or addresses - k.onlyRequired = k.requiredAddrs != nil || k.requiredSyms != nil - - // Initialize maps if they are nil - if k.requiredSyms == nil { - k.requiredSyms = make(map[string]struct{}) - } - if k.requiredAddrs == nil { - k.requiredAddrs = make(map[uint64]struct{}) - } +func (ks kernelSymbolInternal) Name() string { + return ks.name +} - return k, k.Refresh() +func (ks kernelSymbolInternal) Address() uint64 { + // Convert truncated address to the real kernel address + return kernelAddressPrefix | (ks.addressAndOwner & addressMask) } -// KSymbTableOption defines a function signature for configuration options. -type KSymbTableOption func(k *KernelSymbolTable) error +func (ks kernelSymbolInternal) owner() uint16 { + return uint16(ks.addressAndOwner >> ownerShift) +} -// WithRequiredSymbols sets the required symbols for the KernelSymbolTable. -func WithRequiredSymbols(reqSyms []string) KSymbTableOption { - return func(k *KernelSymbolTable) error { - k.requiredSyms = sliceToValidationMap(reqSyms) - return nil - } +func (ks kernelSymbolInternal) Contains(address uint64) bool { + symbolAddr := ks.Address() + return symbolAddr <= address && symbolAddr+maxSymbolSize > address } -// WithRequiredAddresses sets the required addresses for the KernelSymbolTable. -func WithRequiredAddresses(reqAddrs []uint64) KSymbTableOption { - return func(k *KernelSymbolTable) error { - k.requiredAddrs = sliceToValidationMap(reqAddrs) - return nil +func (ks kernelSymbolInternal) Clone() kernelSymbolInternal { + return kernelSymbolInternal{ + name: ks.name, + addressAndOwner: ks.addressAndOwner, } } -// TextSegmentContains returns true if the given address is in the kernel text segment. -func (k *KernelSymbolTable) TextSegmentContains(addr uint64) (bool, error) { - k.updateLock.Lock() - defer k.updateLock.Unlock() - - segStart, segEnd, err := k.getTextSegmentAddresses() - if err != nil { - return false, err - } +type KernelSymbolTable struct { + symbols *utils.SymbolTable[kernelSymbolInternal] - return addr >= segStart && addr < segEnd, nil + // Used for memory efficient representation of symbol owners + idxToSymbolOwner []string + symbolOwnerToIdx map[string]uint16 } -// GetSymbolByName returns all the symbols with the given name. -func (k *KernelSymbolTable) GetSymbolByName(name string) ([]KernelSymbol, error) { - k.updateLock.Lock() - defer k.updateLock.Unlock() - - if err := k.validateOrAddRequiredSym(name); err != nil { - return nil, err +// Creates a new KernelSymbolTable that will be populated from a reader. +// If lazyNameLookup is true, the mapping from name to symbol will be populated +// only when a failed lookup occurs. This reduces memory footprint at the cost +// of the time it takes to lookup a symbol name for the first time. +// If requiredDataSymbolsOnly is true, only the data symbols passed in the +// optional requiredDataSymbols argument will be added. +func NewKernelSymbolTableFromReader(reader io.Reader, lazyNameLookup bool, requiredDataSymbolsOnly bool, requiredDataSymbols ...string) (*KernelSymbolTable, error) { + kst := &KernelSymbolTable{ + symbols: utils.NewSymbolTable[kernelSymbolInternal](lazyNameLookup), + idxToSymbolOwner: []string{"system"}, + symbolOwnerToIdx: map[string]uint16{"system": 0}, } - symbols, exist := k.symbols[name] - if !exist { - return nil, symNotFoundErr(name) + if err := kst.update(reader, requiredDataSymbolsOnly, requiredDataSymbols); err != nil { + return nil, err } - return copySliceOfPointersToSliceOfStructs(symbols), nil + return kst, nil } -// GetSymbolByAddr returns all the symbols with the given address. -func (k *KernelSymbolTable) GetSymbolByAddr(addr uint64) ([]KernelSymbol, error) { - k.updateLock.Lock() - defer k.updateLock.Unlock() - - if err := k.validateOrAddRequiredAddr(addr); err != nil { - return nil, err +// Creates a new KernelSymbolTable that will be populated from /proc/kallsyms. +// If lazyNameLookup is true, the mapping from name to symbol will be populated +// only when a failed lookup occurs. This reduces memory footprint at the cost +// of the time it takes to lookup a symbol name for the first time. +// If requiredDataSymbolsOnly is true, only the data symbols passed in the +// optional requiredDataSymbols argument will be added. +func NewKernelSymbolTable(lazyNameLookup bool, requiredDataSymbolsOnly bool, requiredDataSymbols ...string) (*KernelSymbolTable, error) { + file, err := os.Open("/proc/kallsyms") + if err != nil { + return nil, errfmt.WrapError(err) } + defer func() { + _ = file.Close() + }() - symbols, exist := k.addrs[addr] - if !exist { - return nil, symNotFoundErr(addr) + return NewKernelSymbolTableFromReader(file, lazyNameLookup, requiredDataSymbolsOnly, requiredDataSymbols...) +} + +// Read the contents of the given buffer and update the symbol table +func (kst *KernelSymbolTable) update(reader io.Reader, requiredDataSymbolsOnly bool, requiredDataSymbols []string) error { + // Build set of required data symbols for efficient lookup + requiredDataSymbolsSet := make(map[string]struct{}) + for _, symbolName := range requiredDataSymbols { + requiredDataSymbolsSet[symbolName] = struct{}{} } - return copySliceOfPointersToSliceOfStructs(symbols), nil -} + symbols := []*kernelSymbolInternal{} -// GetSymbolByOwnerAndName returns all the symbols with the given owner and name. -func (k *KernelSymbolTable) GetSymbolByOwnerAndName(owner, name string) ([]KernelSymbol, error) { - k.updateLock.Lock() - defer k.updateLock.Unlock() + // Make sure we hold the required privileges by checking if we see actual addresses + seenRealAddress := false - if err := k.validateOrAddRequiredSym(name); err != nil { - return nil, err - } + scanner := bufio.NewScanner(reader) + for scanner.Scan() { + fields := strings.Fields(scanner.Text()) + if len(fields) < 3 { + continue + } - symbols, exist := k.symByName[nameAndOwner{name, owner}] - if !exist { - return nil, symNotFoundErr(nameAndOwner{name, owner}) - } + symbolAddr, err := strconv.ParseUint(fields[0], 16, 64) + if err != nil { + continue + } + if symbolAddr != 0 { + seenRealAddress = true + } - return copySliceOfPointersToSliceOfStructs(symbols), nil -} + // All kernel symbols are at 0xffffXXXXXXXXXXXX, except percpu symbols which we ignore + if !validKernelAddr(symbolAddr) { + continue + } -// GetSymbolByOwnerAndAddr returns all the symbols with the given owner and address. -func (k *KernelSymbolTable) GetSymbolByOwnerAndAddr(owner string, addr uint64) ([]KernelSymbol, error) { - k.updateLock.Lock() - defer k.updateLock.Unlock() + symbolType := fields[1] + symbolName := fields[2] - if err := k.validateOrAddRequiredAddr(addr); err != nil { - return nil, err - } + symbolOwner := "system" + if len(fields) > 3 { + symbolOwner = fields[3] + symbolOwner = strings.TrimPrefix(symbolOwner, "[") + symbolOwner = strings.TrimSuffix(symbolOwner, "]") + } - symbols, exist := k.symByAddr[addrAndOwner{addr, owner}] - if !exist { - return nil, symNotFoundErr(addrAndOwner{addr, owner}) - } + // This is a data symbol, requiredDataSymbolsOnly is true, and this symbol isn't required + if requiredDataSymbolsOnly && strings.ContainsAny(symbolType, "DdBbRr") { + if _, exists := requiredDataSymbolsSet[symbolName]; !exists { + continue + } + } - return copySliceOfPointersToSliceOfStructs(symbols), nil -} + // Get index of symbol owner, or add it if it doesn't exist + ownerIdx := kst.getOrAddSymbolOwner(symbolOwner) -// getTextSegmentAddresses gets the start and end addresses of the kernel text segment. -func (k *KernelSymbolTable) getTextSegmentAddresses() (uint64, uint64, error) { - stext, exist1 := k.symByName[nameAndOwner{"_stext", "system"}] - etext, exist2 := k.symByName[nameAndOwner{"_etext", "system"}] + symbols = append(symbols, newKernelSymbolInternal(symbolName, symbolAddr, ownerIdx)) + } - if !exist1 || !exist2 { - return 0, 0, fmt.Errorf("kernel text segment symbol(s) not found") + // We didn't hold the required privileges + if len(symbols) > 0 && !seenRealAddress { + return errfmt.Errorf("insufficient privileges when reading from /proc/kallsyms") } - textSegStart := stext[0].Address - textSegEnd := etext[0].Address + // Update the symbol table + kst.symbols.AddSymbols(symbols) - return textSegStart, textSegEnd, nil + return nil } -// validateOrAddRequiredSym checks if the given symbol is in the required list and adds it if not. -func (k *KernelSymbolTable) validateOrAddRequiredSym(sym string) error { - return k.validateOrAddRequired(func() bool { - _, ok := k.requiredSyms[sym] - return ok - }, func() { - k.requiredSyms[sym] = struct{}{} - }) -} +func (kst *KernelSymbolTable) getOrAddSymbolOwner(ownerStr string) uint16 { + ownerIdx, found := kst.symbolOwnerToIdx[ownerStr] + if !found { + kst.idxToSymbolOwner = append(kst.idxToSymbolOwner, ownerStr) + ownerIdx = uint16(len(kst.idxToSymbolOwner) - 1) + kst.symbolOwnerToIdx[ownerStr] = ownerIdx + } -// validateOrAddRequiredAddr checks if the given address is in the required list and adds it if not. -func (k *KernelSymbolTable) validateOrAddRequiredAddr(addr uint64) error { - return k.validateOrAddRequired(func() bool { - _, ok := k.requiredAddrs[addr] - return ok - }, func() { - k.requiredAddrs[addr] = struct{}{} - }) + return ownerIdx } -// validateOrAddRequired is a common function to check and add required symbols or addresses. -func (k *KernelSymbolTable) validateOrAddRequired(checkRequired func() bool, addRequired func()) error { - if !k.onlyRequired { - return nil +func (kst *KernelSymbolTable) symbolFromInternal(symbol *kernelSymbolInternal) *KernelSymbol { + return &KernelSymbol{ + Name: symbol.Name(), + Address: symbol.Address(), + Owner: kst.idxToSymbolOwner[symbol.owner()], } +} - if !checkRequired() { - addRequired() - return k.refresh() +// GetSymbolByName returns all the symbols with the given name. +func (kst *KernelSymbolTable) GetSymbolByName(name string) ([]*KernelSymbol, error) { + symbolsInternal, err := kst.symbols.LookupByName(name) + if err != nil { + return nil, errfmt.WrapError(err) } - return nil -} + symbols := make([]*KernelSymbol, 0, len(symbolsInternal)) + for _, symbolInternal := range symbolsInternal { + symbols = append(symbols, kst.symbolFromInternal(symbolInternal)) + } -// Refresh is the exported method that acquires the lock and calls the internal refresh method. -func (k *KernelSymbolTable) Refresh() error { - k.updateLock.Lock() - defer k.updateLock.Unlock() - return k.refresh() + return symbols, nil } -// refresh refreshes the KernelSymbolTable, reading the symbols from /proc/kallsyms. -func (k *KernelSymbolTable) refresh() error { - // Re-initialize the maps to include all new symbols. - k.symbols = make(map[string][]*KernelSymbol) - k.addrs = make(map[uint64][]*KernelSymbol) - k.symByName = make(map[nameAndOwner][]*KernelSymbol) - k.symByAddr = make(map[addrAndOwner][]*KernelSymbol) - - // Open the kallsyms file. - file, err := os.Open(kallsymsPath) +// GetSymbolByOwnerAndName returns all the symbols with the given owner and name. +func (kst *KernelSymbolTable) GetSymbolByOwnerAndName(owner, name string) ([]*KernelSymbol, error) { + symbolsInternal, err := kst.symbols.LookupByName(name) if err != nil { - return err + return nil, errfmt.WrapError(err) } - defer func() { - _ = file.Close() - }() - // Read the kallsyms file line by line and process each line. - scanner := bufio.NewScanner(file) - for scanner.Scan() { - fields := strings.Fields(scanner.Text()) - if len(fields) < 3 { - continue - } - sym := parseKallsymsLine(fields) - if sym == nil { - continue + symbols := make([]*KernelSymbol, 0, len(symbolsInternal)) + for _, symbolInternal := range symbolsInternal { + symbol := kst.symbolFromInternal(symbolInternal) + // Return only symbols that have the requested owner + if symbol.Owner == owner { + symbols = append(symbols, symbol) } - - if k.onlyRequired { - _, symRequired := k.requiredSyms[sym.Name] - _, addrRequired := k.requiredAddrs[sym.Address] - if !symRequired && !addrRequired { - continue - } - } - - k.symbols[sym.Name] = append(k.symbols[sym.Name], sym) - k.addrs[sym.Address] = append(k.addrs[sym.Address], sym) - k.symByName[nameAndOwner{sym.Name, sym.Owner}] = append(k.symByName[nameAndOwner{sym.Name, sym.Owner}], sym) - k.symByAddr[addrAndOwner{sym.Address, sym.Owner}] = append(k.symByAddr[addrAndOwner{sym.Address, sym.Owner}], sym) } - err = scanner.Err() - return err + return symbols, nil } -// parseKallsymsLine parses a line from /proc/kallsyms and returns a KernelSymbol. -func parseKallsymsLine(line []string) *KernelSymbol { - if len(line) < 3 { - return nil +// GetSymbolByAddr returns all the symbols with the given address. +func (kst *KernelSymbolTable) GetSymbolByAddr(addr uint64) ([]*KernelSymbol, error) { + symbolsInternal, err := kst.symbols.LookupByAddressExact(addr) + if err != nil { + return nil, errfmt.WrapError(err) } - symbolAddr, err := strconv.ParseUint(line[0], 16, 64) - if err != nil { - return nil + symbols := make([]*KernelSymbol, 0, len(symbolsInternal)) + for _, symbolInternal := range symbolsInternal { + symbols = append(symbols, kst.symbolFromInternal(symbolInternal)) } - symbolType := line[1] - symbolName := line[2] + return symbols, nil +} - symbolOwner := "system" - if len(line) > 3 { - line[3] = strings.TrimPrefix(line[3], "[") - line[3] = strings.TrimSuffix(line[3], "]") - symbolOwner = line[3] +// GetPotentiallyHiddenSymbolByAddr returns all the symbols with the given address, +// or if none are found, a fake symbol with the "hidden" owner. +func (kst *KernelSymbolTable) GetPotentiallyHiddenSymbolByAddr(addr uint64) []*KernelSymbol { + symbolsInternal, err := kst.symbols.LookupByAddressExact(addr) + if err != nil || !validKernelAddr(addr) { + // No symbol found or address not in kernel range, return a fake "hidden" symbol + return []*KernelSymbol{{ + Address: addr, + Owner: "hidden", + }} } - return &KernelSymbol{ - Name: symbolName, - Type: symbolType, - Address: symbolAddr, - Owner: symbolOwner, + symbols := make([]*KernelSymbol, 0, len(symbolsInternal)) + for _, symbolInternal := range symbolsInternal { + symbols = append(symbols, kst.symbolFromInternal(symbolInternal)) } + + return symbols } -// copySliceOfPointersToSliceOfStructs converts a slice of pointers to a slice of structs. -func copySliceOfPointersToSliceOfStructs(s []*KernelSymbol) []KernelSymbol { - ret := make([]KernelSymbol, len(s)) - for i, v := range s { - ret[i] = *v - } - return ret +func (kst *KernelSymbolTable) ForEachSymbol(callback func(*KernelSymbol)) { + kst.symbols.ForEachSymbol(func(symbol *kernelSymbolInternal) { + callback(kst.symbolFromInternal(symbol)) + }) } -// sliceToValidationMap converts a slice to a map for validation purposes. -func sliceToValidationMap[T comparable](items []T) map[T]struct{} { - res := make(map[T]struct{}) - for _, item := range items { - res[item] = struct{}{} - } - return res +func validKernelAddr(addr uint64) bool { + return addr&kernelAddressPrefix == kernelAddressPrefix } diff --git a/pkg/utils/environment/kernel_symbols_test.go b/pkg/utils/environment/kernel_symbols_test.go index 834b58026a1f..6bf9e9d81c2f 100644 --- a/pkg/utils/environment/kernel_symbols_test.go +++ b/pkg/utils/environment/kernel_symbols_test.go @@ -2,32 +2,30 @@ package environment import ( "reflect" + "strings" "testing" ) -// TestParseLine tests the parseKallsymsLine function. -func TestParseKallsymsLine(t *testing.T) { - testCases := []struct { - line []string - expected *KernelSymbol - }{ - {[]string{"00000000", "t", "my_symbol", "[my_owner]"}, &KernelSymbol{Name: "my_symbol", Type: "t", Address: 0, Owner: "my_owner"}}, - {[]string{"00000001", "T", "another_symbol"}, &KernelSymbol{Name: "another_symbol", Type: "T", Address: 1, Owner: "system"}}, - {[]string{"invalid_address", "T", "invalid_symbol"}, nil}, - {[]string{"00000002", "T"}, nil}, - } +type symbolInfo struct { + name string + address uint64 + owner string +} - for _, tc := range testCases { - result := parseKallsymsLine(tc.line) - if !reflect.DeepEqual(result, tc.expected) { - t.Errorf("parseKallsymsLine(%v) = %v; want %v", tc.line, result, tc.expected) - } +func symbolToSymbolInfo(symbol *KernelSymbol) *symbolInfo { + if symbol == nil { + return nil + } + return &symbolInfo{ + name: symbol.Name, + address: symbol.Address, + owner: symbol.Owner, } } // TestNewKernelSymbolTable tests the NewKernelSymbolTable function. func TestNewKernelSymbolTable(t *testing.T) { - kst, err := NewKernelSymbolTable() + kst, err := NewKernelSymbolTable(true, false) if err != nil { t.Fatalf("NewKernelSymbolTable() failed: %v", err) } @@ -36,26 +34,56 @@ func TestNewKernelSymbolTable(t *testing.T) { t.Fatalf("NewKernelSymbolTable() returned nil") } - // Check if the onlyRequired flag is set correctly - if kst.onlyRequired { - t.Errorf("onlyRequired flag should be false by default") + // Check if symbols is initialized + if kst.symbols == nil { + t.Errorf("KernelSymbolTable is not initialized correctly") + } +} + +func getTheOnlySymbol(t *testing.T, kst *KernelSymbolTable) *KernelSymbol { + i := 0 + var foundSymbol *KernelSymbol + kst.ForEachSymbol(func(symbol *KernelSymbol) { + i++ + foundSymbol = symbol + }) + if i > 1 { + t.Errorf("multiple symbols found") + } + return foundSymbol +} + +// TestUpdate tests the kallsyms parsing logic. +func TestUpdate(t *testing.T) { + testCases := []struct { + buf string + expected *symbolInfo + }{ + {"ffffffff00000001 t my_symbol [my_owner]", &symbolInfo{name: "my_symbol", address: 0xffffffff00000001, owner: "my_owner"}}, + {"ffffffff00000002 T another_symbol", &symbolInfo{name: "another_symbol", address: 0xffffffff00000002, owner: "system"}}, + {"invalid_address T invalid_symbol", nil}, + {"ffffffff00000003 T", nil}, } - // Check if maps are initialized - if kst.symbols == nil || kst.addrs == nil || kst.symByName == nil || kst.symByAddr == nil { - t.Errorf("KernelSymbolTable maps are not initialized correctly") + for _, tc := range testCases { + kst, err := NewKernelSymbolTableFromReader(strings.NewReader(tc.buf), false, false) + if err != nil { + t.Fatalf("NewKernelSymbolTableFromReader() failed: %v", err) + } + symbol := getTheOnlySymbol(t, kst) + result := symbolToSymbolInfo(symbol) + if !reflect.DeepEqual(result, tc.expected) { + t.Errorf("update(%v) = %v; want %v", tc.buf, result, tc.expected) + } } } // TestGetSymbolByName tests the GetSymbolByName function. func TestGetSymbolByName(t *testing.T) { - kst, err := NewKernelSymbolTable() + buf := "ffffffff00000001 t test_symbol test_owner" + kst, err := NewKernelSymbolTableFromReader(strings.NewReader(buf), false, false) if err != nil { - t.Fatalf("NewKernelSymbolTable() failed: %v", err) - } - - kst.symbols["test_symbol"] = []*KernelSymbol{ - {Name: "test_symbol", Type: "t", Address: 0, Owner: "test_owner"}, + t.Fatalf("NewKernelSymbolTableFromReader() failed: %v", err) } symbols, err := kst.GetSymbolByName("test_symbol") @@ -67,110 +95,79 @@ func TestGetSymbolByName(t *testing.T) { t.Errorf("Expected 1 symbol, got %d", len(symbols)) } - expectedSymbol := KernelSymbol{Name: "test_symbol", Type: "t", Address: 0, Owner: "test_owner"} - if !reflect.DeepEqual(symbols[0], expectedSymbol) { - t.Errorf("GetSymbolByName() = %v; want %v", symbols[0], expectedSymbol) + expected := &symbolInfo{name: "test_symbol", address: 0xffffffff00000001, owner: "test_owner"} + result := symbolToSymbolInfo(symbols[0]) + if !reflect.DeepEqual(result, expected) { + t.Errorf("GetSymbolByName() = %v; want %v", result, expected) } } -// TestGetSymbolByAddr tests the GetSymbolByAddr function. -func TestGetSymbolByAddr(t *testing.T) { - kst, err := NewKernelSymbolTable() +// TestGetSymbolByOwnerAndName tests the GetSymbolByOwnerAndName function. +func TestGetSymbolByOwnerAndName(t *testing.T) { + buf := `ffffffff00000001 t test_symbol test_owner1 +ffffffff00000002 t test_symbol test_owner2` + kst, err := NewKernelSymbolTableFromReader(strings.NewReader(buf), false, false) if err != nil { - t.Fatalf("NewKernelSymbolTable() failed: %v", err) + t.Fatalf("NewKernelSymbolTableFromReader() failed: %v", err) } - kst.addrs[0x1234] = []*KernelSymbol{ - {Name: "test_symbol", Type: "t", Address: 0x1234, Owner: "test_owner"}, - } - - symbols, err := kst.GetSymbolByAddr(0x1234) + symbols, err := kst.GetSymbolByOwnerAndName("test_owner1", "test_symbol") if err != nil { - t.Fatalf("GetSymbolByAddr() failed: %v", err) + t.Fatalf("GetSymbolByName() failed: %v", err) } if len(symbols) != 1 { t.Errorf("Expected 1 symbol, got %d", len(symbols)) } - expectedSymbol := KernelSymbol{Name: "test_symbol", Type: "t", Address: 0x1234, Owner: "test_owner"} - if !reflect.DeepEqual(symbols[0], expectedSymbol) { - t.Errorf("GetSymbolByAddr() = %v; want %v", symbols[0], expectedSymbol) + expected := &symbolInfo{name: "test_symbol", address: 0xffffffff00000001, owner: "test_owner1"} + result := symbolToSymbolInfo(symbols[0]) + if !reflect.DeepEqual(result, expected) { + t.Errorf("GetSymbolByOwnerAndName() = %v; want %v", result, expected) } } -// TestRefresh tests the Refresh function. -func TestRefresh(t *testing.T) { - // Creating a mock KernelSymbolTable with required symbols to test Refresh - kst, err := NewKernelSymbolTable(WithRequiredSymbols([]string{"_stext", "_etext"})) +// TestGetSymbolByAddr tests the GetSymbolByAddr function. +func TestGetSymbolByAddr(t *testing.T) { + buf := "ffffffff00001234 t test_symbol test_owner" + kst, err := NewKernelSymbolTableFromReader(strings.NewReader(buf), false, false) if err != nil { - t.Fatalf("NewKernelSymbolTable() failed: %v", err) - } - - // Simulate the presence of these symbols - kst.symbols["_stext"] = []*KernelSymbol{{Name: "_stext", Type: "T", Address: 0x1000, Owner: "system"}} - kst.symbols["_etext"] = []*KernelSymbol{{Name: "_etext", Type: "T", Address: 0x2000, Owner: "system"}} - - // Call Refresh to update the symbol table - if err := kst.Refresh(); err != nil { - t.Fatalf("Refresh() failed: %v", err) - } - - // Check if symbols were added correctly - symbolsToTest := []string{"_stext", "_etext"} - for _, symbol := range symbolsToTest { - if syms, err := kst.GetSymbolByName(symbol); err != nil || len(syms) == 0 { - t.Errorf("Expected to find symbol %s, but it was not found", symbol) - } + t.Fatalf("NewKernelSymbolTableFromReader() failed: %v", err) } -} -// TestTextSegmentContains tests the TextSegmentContains function. -func TestTextSegmentContains(t *testing.T) { - // Creating a mock KernelSymbolTable with text segment addresses - kst, err := NewKernelSymbolTable() + symbols, err := kst.GetSymbolByAddr(0xffffffff00001234) if err != nil { - t.Fatalf("NewKernelSymbolTable() failed: %v", err) + t.Fatalf("GetSymbolByAddr() failed: %v", err) } - kst.symByName[nameAndOwner{"_stext", "system"}] = []*KernelSymbol{{Name: "_stext", Type: "T", Address: 0x1000, Owner: "system"}} - kst.symByName[nameAndOwner{"_etext", "system"}] = []*KernelSymbol{{Name: "_etext", Type: "T", Address: 0x2000, Owner: "system"}} - - tests := []struct { - addr uint64 - expected bool - }{ - {0x1000, true}, - {0x1500, true}, - {0x2000, false}, - {0x0999, false}, + if len(symbols) != 1 { + t.Errorf("Expected 1 symbol, got %d", len(symbols)) } - for _, tt := range tests { - result, err := kst.TextSegmentContains(tt.addr) - if err != nil { - t.Errorf("TextSegmentContains(%v) failed: %v", tt.addr, err) - } - if result != tt.expected { - t.Errorf("TextSegmentContains(%v) = %v; want %v", tt.addr, result, tt.expected) - } + expected := &symbolInfo{name: "test_symbol", address: 0xffffffff00001234, owner: "test_owner"} + result := symbolToSymbolInfo(symbols[0]) + if !reflect.DeepEqual(result, expected) { + t.Errorf("GetSymbolByAddr() = %v; want %v", result, expected) } } -// Helper function to test required symbols or addresses. -func TestValidateOrAddRequired(t *testing.T) { - kst, err := NewKernelSymbolTable(WithRequiredSymbols([]string{"test_symbol"})) +// TestGetPotentiallyHiddenSymbolByAddr tests the GetPotentiallyHiddenSymbolByAddr function. +func TestGetPotentiallyHiddenSymbolByAddr(t *testing.T) { + buf := "ffffffff00000001 t test_symbol test_owner" + kst, err := NewKernelSymbolTableFromReader(strings.NewReader(buf), false, false) if err != nil { - t.Fatalf("NewKernelSymbolTable() failed: %v", err) + t.Fatalf("NewKernelSymbolTableFromReader() failed: %v", err) } - kst.requiredSyms["test_symbol"] = struct{}{} + symbols := kst.GetPotentiallyHiddenSymbolByAddr(0xffffffff00000002) - if err := kst.validateOrAddRequiredSym("test_symbol"); err != nil { - t.Errorf("validateOrAddRequiredSym() failed: %v", err) + if len(symbols) != 1 { + t.Errorf("Expected 1 symbol, got %d", len(symbols)) } - if err := kst.validateOrAddRequiredAddr(0x1234); err != nil { - t.Errorf("validateOrAddRequiredAddr() failed: %v", err) + expected := &symbolInfo{name: "", address: 0xffffffff00000002, owner: "hidden"} + result := symbolToSymbolInfo(symbols[0]) + if !reflect.DeepEqual(result, expected) { + t.Errorf("GetSymbolByAddr() = %v; want %v", result, expected) } } diff --git a/pkg/utils/symbol_table.go b/pkg/utils/symbol_table.go new file mode 100644 index 000000000000..bb10f6053068 --- /dev/null +++ b/pkg/utils/symbol_table.go @@ -0,0 +1,209 @@ +package utils + +import ( + "errors" + "sort" + "sync" +) + +// The Symbol interface defines what is needed from a symbol implementation in +// order to facilitate the lookup functionalities provided by SymbolTable. +// Implementations of Symbol can hold various types of information relevant to +// the type of symbol they represent. +type Symbol[T any] interface { + // Name returns the symbol's name + Name() string + // Address returns the base address of the symbol + Address() uint64 + // Contains returns whether a given address belongs to the symbol's + // address range, which is defined by the symbol's implementation + Contains(address uint64) bool + Cloner[T] +} + +// SymbolTable is used to hold information about symbols (mapping from symbolic +// names used in code to their address) in a certain executable. +// It can be used to hold symbols from an ELF binary, or symbols of the entire +// kernel and its modules. +// It provides functions to lookup symbols by address and name. +type SymbolTable[T Symbol[T]] struct { + mu sync.RWMutex + // All symbols sorted by their address in descending order, + // for quick binary searches by address. + sortedSymbols []*T + // If lazyNameLookup is true, the symbolsByName map + // will be populated only when a failed lookup occurs. + symbolsByName map[string][]*T + lazyNameLookup bool +} + +var ErrSymbolNotFound = errors.New("symbol not found") + +// Creates a new SymbolTable. If lazyNameLookup is true, the mapping from +// name to symbol will be populated only when a failed lookup occurs. +// This reduces memory footprint at the cost of the time it takes to lookup +// a symbol name for the first time. +func NewSymbolTable[T Symbol[T]](lazyNameLookup bool) *SymbolTable[T] { + return &SymbolTable[T]{ + sortedSymbols: make([]*T, 0), + symbolsByName: make(map[string][]*T), + lazyNameLookup: lazyNameLookup, + } +} + +// Adds a slice of symbols to the symbol table. +func (st *SymbolTable[T]) AddSymbols(symbols []*T) { + st.mu.Lock() + defer st.mu.Unlock() + + // Add the new symbols to the sorted slice (which now becomes unsorted). + // Allocate the slice with the needed capacity to avoid overallocation. + oldSymbols := st.sortedSymbols + newLen := len(oldSymbols) + len(symbols) + st.sortedSymbols = make([]*T, 0, newLen) + st.sortedSymbols = append(st.sortedSymbols, oldSymbols...) + st.sortedSymbols = append(st.sortedSymbols, symbols...) + + // If lazyNameLookup is false, we update the name to symbol mapping for + // each new symbol + if !st.lazyNameLookup { + for _, symbol := range symbols { + name := (*symbol).Name() + if symbols, found := st.symbolsByName[name]; found { + st.symbolsByName[name] = append(symbols, symbol) + } else { + st.symbolsByName[name] = []*T{symbol} + } + } + } + + // Sort the symbols slice by address in descending order + sort.Slice(st.sortedSymbols, + func(i, j int) bool { + return (*st.sortedSymbols[i]).Address() > (*st.sortedSymbols[j]).Address() + }) +} + +// Lookup a symbol in the table by its name. +// Because there may be multiple symbols with the same name, a slice of all +// matching symbols is returned. +func (st *SymbolTable[T]) LookupByName(name string) ([]*T, error) { + st.mu.RLock() + // We call RUnlock manually and not using defer because we may need to upgrade to a write lock later + + // Lookup the name in the name to symbol mapping + if symbols, found := st.symbolsByName[name]; found { + st.mu.RUnlock() + return symbols, nil + } + + // Lazy name lookup is disabled, the lookup failed + if !st.lazyNameLookup { + st.mu.RUnlock() + return nil, ErrSymbolNotFound + } + + // Lazy name lookup is enabled, perform a linear search to find the requested name + symbols := []*T{} + for _, symbol := range st.sortedSymbols { + if (*symbol).Name() == name { + symbols = append(symbols, symbol) + } + } + + if len(symbols) > 0 { + // We found symbols with this name, update the mapping + st.mu.RUnlock() + st.mu.Lock() + defer st.mu.Unlock() + st.symbolsByName[name] = symbols + return symbols, nil + } + + st.mu.RUnlock() + return nil, ErrSymbolNotFound +} + +// Lookup a symbol in the table by its exact address. +// Because there may be multiple symbols at the same address, a slice of all +// matching symbols is returned. +func (st *SymbolTable[T]) LookupByAddressExact(address uint64) ([]*T, error) { + st.mu.RLock() + defer st.mu.RUnlock() + + // Find the first symbol at an address smaller than or equal to the requested address + idx := sort.Search(len(st.sortedSymbols), + func(i int) bool { + return address >= (*st.sortedSymbols[i]).Address() + }) + + // Not found or not exact match + if idx == len(st.sortedSymbols) || (*st.sortedSymbols[idx]).Address() != address { + return nil, ErrSymbolNotFound + } + + // The search result is the first symbol with the requested address, + // find any additional symbols with the same address. + syms := []*T{st.sortedSymbols[idx]} + for i := idx + 1; i < len(st.sortedSymbols); i++ { + if (*st.sortedSymbols[i]).Address() != address { + break + } + syms = append(syms, st.sortedSymbols[i]) + } + + return syms, nil +} + +// Find the symbol which contains the given address. +// If multiple symbols at different addresses contain the requested address, +// the symbol with the highest address will be returned. +// If multiple symbols at the same address contain the requested address, +// one of them will be returned, but there is no guarantee which one. +// This function assumes that symbols don't overlap in a way that a symbol with +// a smaller address contains the requested address while a symbol with a larger +// address (but still smaller that requested) doesn't contain it. +// For example, the following situation is assumed to be impossible: +// +// Requested Address +// | +// | +// +---------------+--+ +// |Symbol 1 | | +// +---------------+--+ +// +--------+ | +// |Symbol 2| | +// +--------+ v +// <----------------------> +// +// Smaller Larger +// Address Address +// +// If the above situation happens, no symbol will be returned. +func (st *SymbolTable[T]) LookupByAddressContains(address uint64) (*T, error) { + st.mu.RLock() + defer st.mu.RUnlock() + + // Find the first symbol at an address smaller than or equal to the requested address + idx := sort.Search(len(st.sortedSymbols), + func(i int) bool { + return address >= (*st.sortedSymbols[i]).Address() + }) + + // Not found or the symbol doesn't contain this address + if idx == len(st.sortedSymbols) || !(*st.sortedSymbols[idx]).Contains(address) { + return nil, ErrSymbolNotFound + } + + return st.sortedSymbols[idx], nil +} + +func (st *SymbolTable[T]) ForEachSymbol(callback func(symbol *T)) { + st.mu.RLock() + defer st.mu.RUnlock() + + for i := range len(st.sortedSymbols) { + sym := (*st.sortedSymbols[i]).Clone() + callback(&sym) + } +} diff --git a/pkg/utils/symbol_table_test.go b/pkg/utils/symbol_table_test.go new file mode 100644 index 000000000000..a06077f5de1c --- /dev/null +++ b/pkg/utils/symbol_table_test.go @@ -0,0 +1,344 @@ +package utils + +import ( + "reflect" + "testing" +) + +type testSymbol struct { + name string + addr uint64 + size uint64 +} + +func (s testSymbol) Name() string { + return s.name +} + +func (s testSymbol) Address() uint64 { + return s.addr +} + +func (s testSymbol) Contains(address uint64) bool { + return s.addr <= address && s.addr+s.size > address +} + +func (s testSymbol) Clone() testSymbol { + return testSymbol{ + name: s.name, + addr: s.addr, + size: s.size, + } +} + +// TestNewSymbolTable tests the NewSymbolTable function. +func TestNewSymbolTable(t *testing.T) { + st := NewSymbolTable[testSymbol](true) + if st == nil { + t.Fatalf("NewSymbolTable() returned nil") + } + + if !st.lazyNameLookup { + t.Errorf("lazyNameLookup was not set to true") + } + + if st.sortedSymbols == nil || st.symbolsByName == nil { + t.Errorf("data structures are nil") + } +} + +// TestAddSymbols tests the AddSymbols function +func TestAddSymbols(t *testing.T) { + testCases := []struct { + symbols []*testSymbol + expectedOrder []int + }{ + {[]*testSymbol{ + {name: "symbol1", addr: 1, size: 1}, + {name: "symbol2", addr: 1, size: 1}, + }, []int{0, 1}}, + {[]*testSymbol{ + {name: "symbol1", addr: 2, size: 1}, + {name: "symbol2", addr: 1, size: 1}, + }, []int{0, 1}}, + {[]*testSymbol{ + {name: "symbol1", addr: 1, size: 1}, + {name: "symbol2", addr: 2, size: 1}, + }, []int{1, 0}}, + } + + for _, tc := range testCases { + st := NewSymbolTable[testSymbol](false) + st.AddSymbols(tc.symbols) + + if len(st.sortedSymbols) != len(tc.symbols) { + t.Errorf("len(st.sortedSymbol) = %d, want %d", len(st.sortedSymbols), len(tc.symbols)) + continue + } + + for i := range st.sortedSymbols { + if !reflect.DeepEqual(*st.sortedSymbols[i], *tc.symbols[tc.expectedOrder[i]]) { + t.Errorf("AddSymbols(%v) = symbol %d: %v; want %v", tc.symbols, i, st.sortedSymbols[i], tc.symbols[tc.expectedOrder[i]]) + } + } + } +} + +// TestLookupByName tests the LookupByName function +func TestLookupByName(t *testing.T) { + testCases := []struct { + symbols []*testSymbol + lookupName string + expectLookupError bool + expected []testSymbol + }{ + { + []*testSymbol{{name: "symbol1", addr: 1, size: 1}}, + "symbol1", + false, + []testSymbol{{name: "symbol1", addr: 1, size: 1}}, + }, + { + []*testSymbol{}, + "symbol2", + true, + []testSymbol{}, + }, + { + []*testSymbol{{name: "symbol3", addr: 1, size: 1}}, + "symbol4", + true, + []testSymbol{}, + }, + { + []*testSymbol{{name: "symbol5", addr: 1, size: 1}, {name: "symbol6", addr: 2, size: 2}}, + "symbol6", + false, + []testSymbol{{name: "symbol6", addr: 2, size: 2}}, + }, + { + []*testSymbol{{name: "symbol7", addr: 1, size: 1}, {name: "symbol7", addr: 2, size: 2}}, + "symbol7", + false, + []testSymbol{{name: "symbol7", addr: 1, size: 1}, {name: "symbol7", addr: 2, size: 2}}, + }, + } + + for _, tc := range testCases { + st := NewSymbolTable[testSymbol](false) + st.AddSymbols(tc.symbols) + result, err := st.LookupByName(tc.lookupName) + if !tc.expectLookupError && err != nil { + t.Errorf("LookupByName(%s) failed: %v", tc.lookupName, err) + continue + } else if tc.expectLookupError { + if err == nil { + t.Errorf("LookupByName(%s) expected to fail but didn't", tc.lookupName) + } + continue + } + if !reflect.DeepEqual(copySliceOfPointersToSliceOfStructs(result), tc.expected) { + t.Errorf("LookupByName(%s) = %v, expected %v", tc.lookupName, copySliceOfPointersToSliceOfStructs(result), tc.expected) + } + } +} + +// TestLazyNameLookup tests the lazy name lookup functionality +func TestLazyNameLookup(t *testing.T) { + testCases := []struct { + symbols []*testSymbol + lazyNameLookup bool + lookups []string + expectedMappings []int + }{ + { + []*testSymbol{{name: "symbol", addr: 1, size: 1}}, + false, + []string{}, + []int{0}, + }, + { + []*testSymbol{{name: "symbol", addr: 1, size: 1}}, + true, + []string{}, + []int{}, + }, + { + []*testSymbol{{name: "symbol1", addr: 1, size: 1}, {name: "symbol2", addr: 2, size: 1}}, + true, + []string{"symbol1", "symbol2"}, + []int{0, 1}, + }, + { + []*testSymbol{{name: "symbol1", addr: 1, size: 1}, {name: "symbol2", addr: 2, size: 1}}, + true, + []string{"symbol2"}, + []int{1}, + }, + } + +testLoop: + for _, tc := range testCases { + st := NewSymbolTable[testSymbol](tc.lazyNameLookup) + st.AddSymbols(tc.symbols) + if tc.lazyNameLookup { + if len(st.symbolsByName) != 0 { + t.Errorf("len(st.symbolsByName) = %d, expected 0", len(st.symbolsByName)) + continue + } + } else { + if len(st.symbolsByName) != len(tc.symbols) { + t.Errorf("len(st.symbolsByName) = %d, expected %d", len(st.symbolsByName), len(tc.symbols)) + continue + } + } + for _, lookupName := range tc.lookups { + _, err := st.LookupByName(lookupName) + if err != nil { + t.Errorf("LookupByName(%s) failed: %v", lookupName, err) + continue testLoop + } + } + for i := range tc.expectedMappings { + if !reflect.DeepEqual(*(st.symbolsByName[tc.symbols[tc.expectedMappings[i]].name][0]), *tc.symbols[tc.expectedMappings[i]]) { + t.Errorf("st.symbolsByName[\"%s\"] = %v, expected %v", tc.symbols[tc.expectedMappings[i]].name, *(st.symbolsByName[tc.symbols[tc.expectedMappings[i]].name][0]), *tc.symbols[tc.expectedMappings[i]]) + continue + } + } + } +} + +// TestLookupByAddressExact tests the LookupByAddressExact function +func TestLookupByAddressExact(t *testing.T) { + testCases := []struct { + symbols []*testSymbol + lookupAddr uint64 + expectLookupError bool + expected []testSymbol + }{ + { + []*testSymbol{{name: "symbol1", addr: 1, size: 1}}, + 1, + false, + []testSymbol{{name: "symbol1", addr: 1, size: 1}}, + }, + { + []*testSymbol{}, + 2, + true, + []testSymbol{}, + }, + { + []*testSymbol{{name: "symbol3", addr: 3, size: 1}}, + 4, + true, + []testSymbol{}, + }, + { + []*testSymbol{{name: "symbol5", addr: 5, size: 1}, {name: "symbol6", addr: 6, size: 2}}, + 6, + false, + []testSymbol{{name: "symbol6", addr: 6, size: 2}}, + }, + { + []*testSymbol{{name: "symbol7", addr: 7, size: 1}, {name: "symbol8", addr: 7, size: 2}}, + 7, + false, + []testSymbol{{name: "symbol7", addr: 7, size: 1}, {name: "symbol8", addr: 7, size: 2}}, + }, + } + + for _, tc := range testCases { + st := NewSymbolTable[testSymbol](false) + st.AddSymbols(tc.symbols) + result, err := st.LookupByAddressExact(tc.lookupAddr) + if !tc.expectLookupError && err != nil { + t.Errorf("LookupByAddressExact(%d) failed: %v", tc.lookupAddr, err) + continue + } else if tc.expectLookupError && err == nil { + t.Errorf("LookupByAddressExact(%d) expected to fail but didn't", tc.lookupAddr) + continue + } + if !reflect.DeepEqual(copySliceOfPointersToSliceOfStructs(result), tc.expected) { + t.Errorf("LookupByAddressExact(%d) = %v, expected %v", tc.lookupAddr, copySliceOfPointersToSliceOfStructs(result), tc.expected) + } + } +} + +// TestLookupByAddressContains tests the LookupByAddressContains function +func TestLookupByAddressContains(t *testing.T) { + testCases := []struct { + symbols []*testSymbol + lookupAddr uint64 + expected *testSymbol + }{ + { + []*testSymbol{}, + 1, + nil, + }, + { + []*testSymbol{{name: "symbol1", addr: 2, size: 2}}, + 2, + &testSymbol{name: "symbol1", addr: 2, size: 2}, + }, + { + []*testSymbol{{name: "symbol2", addr: 3, size: 2}}, + 4, + &testSymbol{name: "symbol2", addr: 3, size: 2}, + }, + { + []*testSymbol{{name: "symbol3", addr: 4, size: 2}}, + 6, + nil, + }, + { + []*testSymbol{{name: "symbol4", addr: 10, size: 2}}, + 8, + nil, + }, + { + []*testSymbol{{name: "symbol5", addr: 11, size: 2}}, + 14, + nil, + }, + { + []*testSymbol{{name: "symbol6", addr: 15, size: 5}, {name: "symbol7", addr: 17, size: 3}}, + 18, + &testSymbol{name: "symbol7", addr: 17, size: 3}, + }, + { // this is a special case assumed to be impossible in practice, see the docstring of LookupByAddressContains() + []*testSymbol{{name: "symbol8", addr: 20, size: 5}, {name: "symbol9", addr: 21, size: 2}}, + 23, + nil, + }, + } + + for _, tc := range testCases { + st := NewSymbolTable[testSymbol](false) + st.AddSymbols(tc.symbols) + result, err := st.LookupByAddressContains(tc.lookupAddr) + if tc.expected != nil && err != nil { + t.Errorf("LookupByAddressContains(%d) failed: %v", tc.lookupAddr, err) + continue + } + if tc.expected == nil { + if err == nil { + t.Errorf("LookupByAddressContains(%d) expected to fail, but returned %v", tc.lookupAddr, *result) + } + continue + } + if !reflect.DeepEqual(*result, *tc.expected) { + t.Errorf("LookupByAddressContains(%d) = %v, expected %v", tc.lookupAddr, *result, *tc.expected) + } + } +} + +// copySliceOfPointersToSliceOfStructs converts a slice of pointers to a slice of structs. +func copySliceOfPointersToSliceOfStructs(s []*testSymbol) []testSymbol { + ret := make([]testSymbol, len(s)) + for i, v := range s { + ret[i] = *v + } + return ret +} diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 3a88eb74232b..e72529dcb34f 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -5,10 +5,7 @@ import ( "io" "math/rand" "reflect" - "strings" "time" - - "github.com/aquasecurity/tracee/pkg/utils/environment" ) // Cloner is a generic interface for objects that can clone themselves. @@ -25,23 +22,6 @@ type Iterator[T any] interface { Next() T } -func ParseSymbol(address uint64, table *environment.KernelSymbolTable) environment.KernelSymbol { - var hookingFunction environment.KernelSymbol - - symbols, err := table.GetSymbolByAddr(address) - if err != nil { - hookingFunction = environment.KernelSymbol{} - hookingFunction.Owner = "hidden" - } else { - hookingFunction = symbols[0] - } - - hookingFunction.Owner = strings.TrimPrefix(hookingFunction.Owner, "[") - hookingFunction.Owner = strings.TrimSuffix(hookingFunction.Owner, "]") - - return hookingFunction -} - func HasBit(n uint64, offset uint) bool { return (n & (1 << offset)) > 0 } diff --git a/tests/e2e-inst-signatures/e2e-set_fs_pwd.go b/tests/e2e-inst-signatures/e2e-set_fs_pwd.go index f9cda190b64b..99f7b6cac83b 100644 --- a/tests/e2e-inst-signatures/e2e-set_fs_pwd.go +++ b/tests/e2e-inst-signatures/e2e-set_fs_pwd.go @@ -4,6 +4,9 @@ import ( "fmt" "strings" + "kernel.org/pub/linux/libs/security/libcap/cap" + + "github.com/aquasecurity/tracee/pkg/capabilities" "github.com/aquasecurity/tracee/pkg/utils/environment" "github.com/aquasecurity/tracee/signatures/helpers" "github.com/aquasecurity/tracee/types/detect" @@ -21,7 +24,15 @@ func (sig *e2eSetFsPwd) Init(ctx detect.SignatureContext) error { // Find if this system has the bpf_probe_read_user_str helper. // If it doesn't we won't expect the unresolved path to contain anything - ksyms, err := environment.NewKernelSymbolTable() + var ksyms *environment.KernelSymbolTable + err := capabilities.GetInstance().Specific( + func() error { + var err error + ksyms, err = environment.NewKernelSymbolTable(false, false) + return err + }, + cap.SYSLOG, + ) if err != nil { return err }