From d59475da10ddd98c91f84237ea4c1bf8e43b1ff0 Mon Sep 17 00:00:00 2001 From: Patrick Boyd Date: Thu, 19 Sep 2024 19:39:14 -0500 Subject: [PATCH] More optimizations --- dbus.go | 2 +- encoder.go | 67 ++++++++++++++++++++++++++++++++++++++++++++++- message.go | 50 ++++++++++++++--------------------- transport_unix.go | 13 ++++++--- 4 files changed, 96 insertions(+), 36 deletions(-) diff --git a/dbus.go b/dbus.go index 8bc2799..e163990 100644 --- a/dbus.go +++ b/dbus.go @@ -365,7 +365,7 @@ func alignment(t reflect.Type) int { return 1 case reflect.Uint16, reflect.Int16: return 2 - case reflect.Uint, reflect.Int, reflect.Uint32, reflect.Int32, reflect.String, reflect.Array, reflect.Slice, reflect.Map: + case reflect.Uint, reflect.Int, reflect.Uint32, reflect.Int32, reflect.String, reflect.Array, reflect.Slice, reflect.Map, reflect.Bool: return 4 case reflect.Uint64, reflect.Int64, reflect.Float64, reflect.Struct: return 8 diff --git a/encoder.go b/encoder.go index 2de96ad..de31fd8 100644 --- a/encoder.go +++ b/encoder.go @@ -258,6 +258,18 @@ func (enc *encoder) Encode(vs ...interface{}) (err error) { return nil } +func CountFDs(vs ...interface{}) (int, error) { + var err error + defer func() { + err, _ = recover().(error) + }() + count := 0 + for _, v := range vs { + count += fdCounter(reflect.ValueOf(v), 0) + } + return count, err +} + // encode encodes the given value to the writer and panics on error. depth holds // the depth of the container nesting. func (enc *encoder) encode(v reflect.Value, depth int) { @@ -273,7 +285,7 @@ func (enc *encoder) encode(v reflect.Value, depth int) { if v.Bool() { enc.binWriteIntType(uint32(1)) } else { - enc.binWriteIntType(uint32(1)) + enc.binWriteIntType(uint32(0)) } enc.pos += 4 case reflect.Int16: @@ -414,3 +426,56 @@ func (enc *encoder) encode(v reflect.Value, depth int) { panic(InvalidTypeError{v.Type()}) } } + +func fdCounter(v reflect.Value, depth int) int { + if depth > 64 { + panic(FormatError("input exceeds depth limitation")) + } + switch v.Kind() { + case reflect.Int, reflect.Int32: + if v.Type() == unixFDType { + return 1 + } + return 0 + case reflect.Ptr: + return fdCounter(v.Elem(), depth) + case reflect.Slice, reflect.Array: + // we don't really need the child encoder in this case since we aren't actually messing with the buffer at all + count := 0 + for i := 0; i < v.Len(); i++ { + count += fdCounter(v.Index(i), depth+1) + } + return count + case reflect.Struct: + switch t := v.Type(); t { + case variantType: + variant := v.Interface().(Variant) + return fdCounter(reflect.ValueOf(variant.value), depth+1) + default: + count := 0 + for i := 0; i < v.Type().NumField(); i++ { + field := t.Field(i) + if field.PkgPath == "" && field.Tag.Get("dbus") != "-" { + count += fdCounter(v.Field(i), depth+1) + } + } + return count + } + case reflect.Map: + // Maps are arrays of structures, so they actually increase the depth by + // 2. + // we don't really need the child encoder in this case since we aren't actually messing with the buffer at all + iter := v.MapRange() + count := 0 + for iter.Next() { + count += fdCounter(iter.Key(), depth+2) + count += fdCounter(iter.Value(), depth+2) + } + return count + case reflect.Interface: + return fdCounter(reflect.ValueOf(MakeVariant(v.Interface())), depth) + default: + // do nothing we are skipping most types + return 0 + } +} diff --git a/message.go b/message.go index 2b05f9f..5f164ea 100644 --- a/message.go +++ b/message.go @@ -3,7 +3,6 @@ package dbus import ( "bytes" "encoding/binary" - "errors" "io" "reflect" "strconv" @@ -203,33 +202,20 @@ func DecodeMessage(rd io.Reader) (msg *Message, err error) { return DecodeMessageWithFDs(rd, make([]int, 0)) } -type nullwriter struct{} - -func (nullwriter) Write(p []byte) (cnt int, err error) { - return len(p), nil -} - func (msg *Message) CountFds() (int, error) { if len(msg.Body) == 0 { return 0, nil } - enc := newEncoder(nullwriter{}, nativeEndian, make([]int, 0)) - err := enc.Encode(msg.Body...) - return len(enc.fds), err + return CountFDs(msg.Body...) } func (msg *Message) EncodeToWithFDs(out io.Writer, order binary.ByteOrder) (fds []int, err error) { if err := msg.validateHeader(); err != nil { return nil, err } - var vs [7]interface{} - switch order { - case binary.LittleEndian: - vs[0] = byte('l') - case binary.BigEndian: - vs[0] = byte('B') - default: - return nil, errors.New("dbus: invalid byte order") + endianByte := byte('l') + if order == binary.BigEndian { + endianByte = byte('B') } body := new(bytes.Buffer) fds = make([]int, 0) @@ -240,32 +226,34 @@ func (msg *Message) EncodeToWithFDs(out io.Writer, order binary.ByteOrder) (fds return } } - vs[1] = msg.Type - vs[2] = msg.Flags - vs[3] = protoVersion - vs[4] = uint32(len(body.Bytes())) - vs[5] = msg.serial headers := make([]header, 0, len(msg.Headers)) for k, v := range msg.Headers { headers = append(headers, header{byte(k), v}) } - vs[6] = headers - var buf bytes.Buffer - enc = newEncoder(&buf, order, enc.fds) - err = enc.Encode(vs[:]...) + buf := bytes.NewBuffer(make([]byte, 0, 128)) + // No need to alloc a new encoder, just reset the old one + enc.Reset(buf, order, enc.fds) + buf.WriteByte(endianByte) + buf.WriteByte(byte(msg.Type)) + buf.WriteByte(byte(msg.Flags)) + buf.WriteByte(protoVersion) + enc.binWriteIntType(uint32(len(body.Bytes()))) + enc.binWriteIntType(msg.serial) + enc.pos = 12 + err = enc.Encode(headers) if err != nil { return } enc.align(8) - if _, err := body.WriteTo(&buf); err != nil { - return nil, err - } - if buf.Len() > 1<<27 { + if buf.Len()+body.Len() > 1<<27 { return nil, InvalidMessageError("message is too long") } if _, err := buf.WriteTo(out); err != nil { return nil, err } + if _, err := body.WriteTo(out); err != nil { + return nil, err + } return enc.fds, nil } diff --git a/transport_unix.go b/transport_unix.go index f0a80dc..ce0e195 100644 --- a/transport_unix.go +++ b/transport_unix.go @@ -9,6 +9,7 @@ import ( "errors" "io" "net" + "sync" "syscall" ) @@ -31,7 +32,6 @@ type oobReader struct { // The following fields are used to reduce memory allocs. csheader []byte b *bytes.Buffer - dec *decoder msghead } @@ -92,6 +92,12 @@ func (t *unixTransport) EnableUnixFDs() { t.hasUnixFDs = true } +var decodePool = sync.Pool{ + New: func() interface{} { + return new(decoder) + }, +} + func (t *unixTransport) ReadMessage() (*Message, error) { // To be sure that all bytes of out-of-band data are read, we use a special // reader that uses ReadUnix on the underlying connection instead of Read @@ -102,15 +108,16 @@ func (t *unixTransport) ReadMessage() (*Message, error) { // This buffer is used to decode the part of the header that has a constant size. csheader: make([]byte, 16), b: bytes.NewBuffer(make([]byte, defaultBufferSize)), - dec: &decoder{}, } } else { t.rdr.oob = t.rdr.oob[:0] } var ( b = t.rdr.b - dec = t.rdr.dec + dec = decodePool.Get().(*decoder) ) + // Put the decoder back in the pool for others to use. + defer decodePool.Put(dec) b.Reset() if _, err := io.CopyN(b, t.rdr, 16); err != nil {