diff --git a/CHANGELOG.md b/CHANGELOG.md index f62080c5d..c03587756 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release. ### Changed +- Made MessagePack extension encoding and decoding functions exportable, + allowing users to reuse the logic for custom extensions (#421). + ### Fixed ## [v2.2.0] - 2024-12-16 diff --git a/arrow/arrow.go b/arrow/arrow.go index aaeaccca9..b02a3ee57 100644 --- a/arrow/arrow.go +++ b/arrow/arrow.go @@ -7,8 +7,8 @@ import ( "github.com/vmihailenco/msgpack/v5" ) -// Arrow MessagePack extension type. -const arrowExtId = 8 +// ExtID represents the Arrow MessagePack extension type identifier. +const ExtID = 8 // Arrow struct wraps a raw arrow data buffer. type Arrow struct { @@ -26,31 +26,31 @@ func (a Arrow) Raw() []byte { return a.data } -func arrowDecoder(d *msgpack.Decoder, v reflect.Value, extLen int) error { +func DecodeExt(d *msgpack.Decoder, v reflect.Value, extLen int) error { arrow := Arrow{ data: make([]byte, extLen), } n, err := d.Buffered().Read(arrow.data) if err != nil { - return fmt.Errorf("arrowDecoder: can't read bytes on Arrow decode: %w", err) + return fmt.Errorf("decode: can't read bytes on Arrow decode: %w", err) } if n < extLen || n != len(arrow.data) { - return fmt.Errorf("arrowDecoder: unexpected end of stream after %d Arrow bytes", n) + return fmt.Errorf("decode: unexpected end of stream after %d Arrow bytes", n) } v.Set(reflect.ValueOf(arrow)) return nil } -func arrowEncoder(e *msgpack.Encoder, v reflect.Value) ([]byte, error) { +func EncodeExt(_ *msgpack.Encoder, v reflect.Value) ([]byte, error) { arr, ok := v.Interface().(Arrow) if !ok { - return []byte{}, fmt.Errorf("arrowEncoder: not an Arrow type") + return []byte{}, fmt.Errorf("encode: not an Arrow type") } return arr.data, nil } func init() { - msgpack.RegisterExtDecoder(arrowExtId, Arrow{}, arrowDecoder) - msgpack.RegisterExtEncoder(arrowExtId, Arrow{}, arrowEncoder) + msgpack.RegisterExtDecoder(ExtID, Arrow{}, DecodeExt) + msgpack.RegisterExtEncoder(ExtID, Arrow{}, EncodeExt) } diff --git a/datetime/datetime.go b/datetime/datetime.go index f5a2a8278..c70b089a6 100644 --- a/datetime/datetime.go +++ b/datetime/datetime.go @@ -1,4 +1,4 @@ -// Package with support of Tarantool's datetime data type. +// Package datetime provides support for Tarantool's datetime data type. // // Datetime data type supported in Tarantool since 2.10. // @@ -34,9 +34,10 @@ import ( // * [optional] all the other fields (nsec, tzoffset, tzindex) if any of them // were having not 0 value. They are packed naturally in little-endian order; -// Datetime external type. Supported since Tarantool 2.10. See more details in +// ExtID represents the Datetime MessagePack extension type identifier. +// Supported since Tarantool 2.10. See more details in // issue https://github.com/tarantool/tarantool/issues/5946. -const datetimeExtID = 4 +const ExtID = 4 // datetime structure keeps a number of seconds and nanoseconds since Unix Epoch. // Time is normalized by UTC, so time-zone offset is informative only. @@ -242,7 +243,7 @@ func (d *Datetime) ToTime() time.Time { return d.time } -func datetimeEncoder(e *msgpack.Encoder, v reflect.Value) ([]byte, error) { +func EncodeExt(_ *msgpack.Encoder, v reflect.Value) ([]byte, error) { dtime := v.Interface().(Datetime) tm := dtime.ToTime() @@ -275,7 +276,7 @@ func datetimeEncoder(e *msgpack.Encoder, v reflect.Value) ([]byte, error) { return buf, nil } -func datetimeDecoder(d *msgpack.Decoder, v reflect.Value, extLen int) error { +func DecodeExt(d *msgpack.Decoder, v reflect.Value, extLen int) error { if extLen != maxSize && extLen != secondsSize { return fmt.Errorf("invalid data length: got %d, wanted %d or %d", extLen, secondsSize, maxSize) @@ -333,6 +334,6 @@ func datetimeDecoder(d *msgpack.Decoder, v reflect.Value, extLen int) error { } func init() { - msgpack.RegisterExtDecoder(datetimeExtID, Datetime{}, datetimeDecoder) - msgpack.RegisterExtEncoder(datetimeExtID, Datetime{}, datetimeEncoder) + msgpack.RegisterExtDecoder(ExtID, Datetime{}, DecodeExt) + msgpack.RegisterExtEncoder(ExtID, Datetime{}, EncodeExt) } diff --git a/datetime/interval.go b/datetime/interval.go index bcd052383..b243b9137 100644 --- a/datetime/interval.go +++ b/datetime/interval.go @@ -8,7 +8,8 @@ import ( "github.com/vmihailenco/msgpack/v5" ) -const interval_extId = 6 +// IntervalExtID represents the Interval MessagePack extension type identifier. +const IntervalExtID = 6 const ( fieldYear = 0 @@ -74,7 +75,7 @@ func encodeIntervalValue(e *msgpack.Encoder, typ uint64, value int64) (err error if err == nil { if value > 0 { err = e.EncodeUint(uint64(value)) - } else if value < 0 { + } else { err = e.EncodeInt(value) } } @@ -181,20 +182,22 @@ func decodeInterval(d *msgpack.Decoder, v reflect.Value) (err error) { return nil } -func init() { - msgpack.RegisterExtEncoder(interval_extId, Interval{}, - func(e *msgpack.Encoder, v reflect.Value) (ret []byte, err error) { - var b bytes.Buffer +func EncodeIntervalExt(_ *msgpack.Encoder, v reflect.Value) (ret []byte, err error) { + var b bytes.Buffer - enc := msgpack.NewEncoder(&b) - if err = encodeInterval(enc, v); err == nil { - ret = b.Bytes() - } + enc := msgpack.NewEncoder(&b) + if err = encodeInterval(enc, v); err == nil { + ret = b.Bytes() + } - return - }) - msgpack.RegisterExtDecoder(interval_extId, Interval{}, - func(d *msgpack.Decoder, v reflect.Value, extLen int) error { - return decodeInterval(d, v) - }) + return +} + +func DecodeIntervalExt(d *msgpack.Decoder, v reflect.Value, _ int) error { + return decodeInterval(d, v) +} + +func init() { + msgpack.RegisterExtEncoder(IntervalExtID, Interval{}, EncodeIntervalExt) + msgpack.RegisterExtDecoder(IntervalExtID, Interval{}, DecodeIntervalExt) } diff --git a/decimal/decimal.go b/decimal/decimal.go index 3a1abb76e..f75aceb54 100644 --- a/decimal/decimal.go +++ b/decimal/decimal.go @@ -1,4 +1,4 @@ -// Package decimal with support of Tarantool's decimal data type. +// Package decimal provides support for Tarantool's decimal data type. // // Decimal data type supported in Tarantool since 2.2. // @@ -37,11 +37,10 @@ import ( // - Tarantool module decimal: // https://www.tarantool.io/en/doc/latest/reference/reference_lua/decimal/ -const ( - // Decimal external type. - decimalExtID = 1 - decimalPrecision = 38 -) +// ExtID represents the Decimal MessagePack extension type identifier. +const ExtID = 1 + +const decimalPrecision = 38 var ( one decimal.Decimal = decimal.NewFromInt(1) @@ -71,7 +70,7 @@ func MakeDecimalFromString(src string) (Decimal, error) { return result, nil } -func decimalEncoder(e *msgpack.Encoder, v reflect.Value) ([]byte, error) { +func EncodeExt(_ *msgpack.Encoder, v reflect.Value) ([]byte, error) { dec := v.Interface().(Decimal) if dec.GreaterThan(maxSupportedDecimal) { return nil, @@ -94,7 +93,7 @@ func decimalEncoder(e *msgpack.Encoder, v reflect.Value) ([]byte, error) { return bcdBuf, nil } -func decimalDecoder(d *msgpack.Decoder, v reflect.Value, extLen int) error { +func DecodeExt(d *msgpack.Decoder, v reflect.Value, extLen int) error { b := make([]byte, extLen) n, err := d.Buffered().Read(b) if err != nil { @@ -131,6 +130,6 @@ func decimalDecoder(d *msgpack.Decoder, v reflect.Value, extLen int) error { } func init() { - msgpack.RegisterExtDecoder(decimalExtID, Decimal{}, decimalDecoder) - msgpack.RegisterExtEncoder(decimalExtID, Decimal{}, decimalEncoder) + msgpack.RegisterExtDecoder(ExtID, Decimal{}, DecodeExt) + msgpack.RegisterExtEncoder(ExtID, Decimal{}, EncodeExt) } diff --git a/uuid/uuid.go b/uuid/uuid.go index cc2be736f..f34976d3c 100644 --- a/uuid/uuid.go +++ b/uuid/uuid.go @@ -1,4 +1,4 @@ -// Package with support of Tarantool's UUID data type. +// Package uuid provides support for Tarantool's UUID data type. // // UUID data type supported in Tarantool since 2.4.1. // @@ -24,8 +24,8 @@ import ( "github.com/vmihailenco/msgpack/v5" ) -// UUID external type. -const uuid_extID = 2 +// ExtID represents the UUID MessagePack extension type identifier. +const ExtID = 2 func encodeUUID(e *msgpack.Encoder, v reflect.Value) error { id := v.Interface().(uuid.UUID) @@ -64,15 +64,17 @@ func decodeUUID(d *msgpack.Decoder, v reflect.Value) error { return nil } +func EncodeExt(_ *msgpack.Encoder, v reflect.Value) ([]byte, error) { + u := v.Interface().(uuid.UUID) + return u.MarshalBinary() +} + +func DecodeExt(d *msgpack.Decoder, v reflect.Value, _ int) error { + return decodeUUID(d, v) +} + func init() { msgpack.Register(reflect.TypeOf((*uuid.UUID)(nil)).Elem(), encodeUUID, decodeUUID) - msgpack.RegisterExtEncoder(uuid_extID, uuid.UUID{}, - func(e *msgpack.Encoder, v reflect.Value) ([]byte, error) { - uuid := v.Interface().(uuid.UUID) - return uuid.MarshalBinary() - }) - msgpack.RegisterExtDecoder(uuid_extID, uuid.UUID{}, - func(d *msgpack.Decoder, v reflect.Value, extLen int) error { - return decodeUUID(d, v) - }) + msgpack.RegisterExtEncoder(ExtID, uuid.UUID{}, EncodeExt) + msgpack.RegisterExtDecoder(ExtID, uuid.UUID{}, DecodeExt) }