diff --git a/pkg/frame/native.go b/pkg/frame/native.go index a4e6363..6a11aab 100644 --- a/pkg/frame/native.go +++ b/pkg/frame/native.go @@ -169,7 +169,8 @@ func (n *NativeFrame[I]) Equals(target INativeFrame) bool { } if n.Rows() != target.Rows() || n.Cols() != target.Cols() || - n.BitsPerSample() != n.BitsPerSample() { + n.BitsPerSample() != target.BitsPerSample() || + n.SamplesPerPixel() != target.SamplesPerPixel() { return false } diff --git a/pkg/frame/native_test.go b/pkg/frame/native_test.go index 3ce576c..1669e2a 100644 --- a/pkg/frame/native_test.go +++ b/pkg/frame/native_test.go @@ -204,6 +204,92 @@ func TestNativeFrame_RawDataSlice(t *testing.T) { } } +func TestNativeFrame_Equals(t *testing.T) { + cases := []struct { + name string + a frame.NativeFrame[int] + b frame.NativeFrame[int] + equal bool + }{ + { + name: "equal", + a: frame.NativeFrame[int]{ + RawData: []int{1, 2, 3}, + InternalSamplesPerPixel: 2, + InternalCols: 3, + InternalRows: 4, + InternalBitsPerSample: 64, + }, + b: frame.NativeFrame[int]{ + RawData: []int{1, 2, 3}, + InternalSamplesPerPixel: 2, + InternalCols: 3, + InternalRows: 4, + InternalBitsPerSample: 64, + }, + equal: true, + }, + { + name: "mismatched data", + a: frame.NativeFrame[int]{ + RawData: []int{1, 2, 3}, + }, + b: frame.NativeFrame[int]{ + RawData: []int{2, 2, 3}, + }, + equal: false, + }, + { + name: "mismatched BitsPerSample", + a: frame.NativeFrame[int]{ + InternalBitsPerSample: 2, + }, + b: frame.NativeFrame[int]{ + InternalBitsPerSample: 4, + }, + equal: false, + }, + { + name: "mismatched SamplesPerPixel", + a: frame.NativeFrame[int]{ + InternalSamplesPerPixel: 2, + }, + b: frame.NativeFrame[int]{ + InternalSamplesPerPixel: 4, + }, + equal: false, + }, + { + name: "mismatched Rows", + a: frame.NativeFrame[int]{ + InternalRows: 2, + }, + b: frame.NativeFrame[int]{ + InternalRows: 4, + }, + equal: false, + }, + { + name: "mismatched Cols", + a: frame.NativeFrame[int]{ + InternalCols: 2, + }, + b: frame.NativeFrame[int]{ + InternalCols: 4, + }, + equal: false, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := tc.a.Equals(&tc.b) + if got != tc.equal { + t.Errorf("Equals(%+v, %+v) got unexpected value. got: %v, want: %v", tc.a, tc.b, got, tc.equal) + } + }) + } +} + // within returns true if pt is in the []point func within(pt point, set []point) bool { for _, item := range set {