diff --git a/protobuf-go-lite.go b/protobuf-go-lite.go index 6f407dbd..98ae13fa 100644 --- a/protobuf-go-lite.go +++ b/protobuf-go-lite.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "math/bits" + "slices" ) var ( @@ -43,6 +44,18 @@ type CloneVT[T comparable] interface { CloneVT() T } +// CloneVTSlice clones a slice of CloneVT messages. +func CloneVTSlice[S ~[]E, E CloneVT[E]](s S) S { + out := make([]E, len(s)) + var empty E + for i := range s { + if s[i] != empty { + out[i] = s[i].CloneVT() + } + } + return out +} + // EqualVT is a message with a EqualVT function (VTProtobuf). type EqualVT[T comparable] interface { comparable @@ -50,17 +63,17 @@ type EqualVT[T comparable] interface { EqualVT(other T) bool } -// CompareEqualVT returns a compare function to compare two VTProtobuf messages. -func CompareEqualVT[T EqualVT[T]]() func(t1, t2 T) bool { +// CompareComparable returns a compare function to compare two comparable types. +func CompareComparable[T comparable]() func(t1, t2 T) bool { return func(t1, t2 T) bool { - return IsEqualVT(t1, t2) + return t1 == t2 } } -// CompareComparable returns a compare function to compare two comparable types. -func CompareComparable[T comparable]() func(t1, t2 T) bool { +// CompareEqualVT returns a compare function to compare two VTProtobuf messages. +func CompareEqualVT[T EqualVT[T]]() func(t1, t2 T) bool { return func(t1, t2 T) bool { - return t1 == t2 + return IsEqualVT(t1, t2) } } @@ -77,6 +90,11 @@ func IsEqualVT[T EqualVT[T]](t1, t2 T) bool { return t1.EqualVT(t2) } +// IsEqualVTSlice checks if two slices of EqualVT messages are equal. +func IsEqualVTSlice[S ~[]E, E EqualVT[E]](s1, s2 S) bool { + return slices.EqualFunc(s1, s2, CompareEqualVT[E]()) +} + // EncodeVarint encodes a uint64 into a varint-encoded byte slice and returns the offset of the encoded value. // The provided offset is the offset after the last byte of the encoded value. func EncodeVarint(dAtA []byte, offset int, v uint64) int { diff --git a/protobuf-go-lite_test.go b/protobuf-go-lite_test.go index a8f8b6b9..8a5a5809 100644 --- a/protobuf-go-lite_test.go +++ b/protobuf-go-lite_test.go @@ -7,7 +7,16 @@ type testCase struct { } func (t *testCase) EqualVT(ot *testCase) bool { - return t == ot + if t == ot { + return true + } + if (ot == nil) != (t == nil) { + return false + } + if ot == nil { + return true + } + return t.val == ot.val } func TestCompareVT(t *testing.T) { @@ -26,3 +35,53 @@ func TestCompareVT(t *testing.T) { t.Fail() } } + +func TestIsEqualVTSlice(t *testing.T) { + testCases := []struct { + s1, s2 []*testCase + expect bool + }{ + { + s1: []*testCase{{val: 1}, {val: 2}}, + s2: []*testCase{{val: 1}, {val: 2}}, + expect: true, + }, + { + s1: []*testCase{{val: 1}, {val: 2}}, + s2: []*testCase{{val: 1}, {val: 3}}, + expect: false, + }, + { + s1: []*testCase{{val: 1}, {val: 2}}, + s2: []*testCase{{val: 1}, {val: 2}, {val: 3}}, + expect: false, + }, + { + s1: []*testCase{{val: 1}, nil}, + s2: []*testCase{{val: 1}, nil}, + expect: true, + }, + { + s1: []*testCase{{val: 1}, nil}, + s2: []*testCase{{val: 1}, {val: 2}}, + expect: false, + }, + { + s1: []*testCase{}, + s2: []*testCase{}, + expect: true, + }, + { + s1: nil, + s2: nil, + expect: true, + }, + } + + for _, tc := range testCases { + actual := IsEqualVTSlice(tc.s1, tc.s2) + if actual != tc.expect { + t.Errorf("IsEqualVTSlice(%v, %v) = %v; want %v", tc.s1, tc.s2, actual, tc.expect) + } + } +}