diff --git a/api/converter/from_bytes.go b/api/converter/from_bytes.go index 04662aded..3c18f9971 100644 --- a/api/converter/from_bytes.go +++ b/api/converter/from_bytes.go @@ -267,12 +267,19 @@ func fromJSONCounter(pbCnt *api.JSONElement_Counter) (*crdt.Counter, error) { if err != nil { return nil, err } + counterValue, err := crdt.CounterValueFromBytes(counterType, pbCnt.Value) + if err != nil { + return nil, err + } - counter := crdt.NewCounter( + counter, err := crdt.NewCounter( counterType, - crdt.CounterValueFromBytes(counterType, pbCnt.Value), + counterValue, createdAt, ) + if err != nil { + return nil, err + } counter.SetMovedAt(movedAt) counter.SetRemovedAt(removedAt) diff --git a/api/converter/from_pb.go b/api/converter/from_pb.go index f63585a2d..868aefd59 100644 --- a/api/converter/from_pb.go +++ b/api/converter/from_pb.go @@ -737,11 +737,20 @@ func fromElement(pbElement *api.JSONElementSimple) (crdt.Element, error) { if err != nil { return nil, err } - return crdt.NewCounter( + counterValue, err := crdt.CounterValueFromBytes(counterType, pbElement.Value) + if err != nil { + return nil, err + } + + counter, err := crdt.NewCounter( counterType, - crdt.CounterValueFromBytes(counterType, pbElement.Value), + counterValue, createdAt, - ), nil + ) + if err != nil { + return nil, err + } + return counter, nil case api.ValueType_VALUE_TYPE_TREE: return BytesToTree(pbElement.Value) } diff --git a/api/converter/to_bytes.go b/api/converter/to_bytes.go index 7313e9f51..42d3ce711 100644 --- a/api/converter/to_bytes.go +++ b/api/converter/to_bytes.go @@ -158,11 +158,15 @@ func toCounter(counter *crdt.Counter) (*api.JSONElement, error) { if err != nil { return nil, err } + counterValue, err := counter.Bytes() + if err != nil { + return nil, err + } return &api.JSONElement{ Body: &api.JSONElement_Counter_{Counter: &api.JSONElement_Counter{ Type: pbCounterType, - Value: counter.Bytes(), + Value: counterValue, CreatedAt: ToTimeTicket(counter.CreatedAt()), MovedAt: ToTimeTicket(counter.MovedAt()), RemovedAt: ToTimeTicket(counter.RemovedAt()), diff --git a/api/converter/to_pb.go b/api/converter/to_pb.go index 49d736ff0..b4d22e570 100644 --- a/api/converter/to_pb.go +++ b/api/converter/to_pb.go @@ -435,11 +435,15 @@ func toJSONElementSimple(elem crdt.Element) (*api.JSONElementSimple, error) { if err != nil { return nil, err } + counterValue, err := elem.Bytes() + if err != nil { + return nil, err + } return &api.JSONElementSimple{ Type: pbCounterType, CreatedAt: ToTimeTicket(elem.CreatedAt()), - Value: elem.Bytes(), + Value: counterValue, }, nil case *crdt.Tree: bytes, err := TreeToBytes(elem) diff --git a/pkg/document/crdt/counter.go b/pkg/document/crdt/counter.go index 0bf7c5159..f43dac3fa 100644 --- a/pkg/document/crdt/counter.go +++ b/pkg/document/crdt/counter.go @@ -18,11 +18,15 @@ package crdt import ( "encoding/binary" + "errors" "fmt" "github.com/yorkie-team/yorkie/pkg/document/time" ) +// ErrUnsupportedType is returned when the given type is not supported. +var ErrUnsupportedType = errors.New("unsupported type") + // CounterType represents any type that can be used as a counter. type CounterType int @@ -33,16 +37,16 @@ const ( ) // CounterValueFromBytes parses the given bytes into value. -func CounterValueFromBytes(counterType CounterType, value []byte) interface{} { +func CounterValueFromBytes(counterType CounterType, value []byte) (interface{}, error) { switch counterType { case IntegerCnt: val := int32(binary.LittleEndian.Uint32(value)) - return int(val) + return int(val), nil case LongCnt: - return int64(binary.LittleEndian.Uint64(value)) + return int64(binary.LittleEndian.Uint64(value)), nil + default: + return nil, ErrUnsupportedType } - - panic("unsupported type") } // Counter represents changeable number data type. @@ -55,39 +59,47 @@ type Counter struct { } // NewCounter creates a new instance of Counter. -func NewCounter(valueType CounterType, value interface{}, createdAt *time.Ticket) *Counter { +func NewCounter(valueType CounterType, value interface{}, createdAt *time.Ticket) (*Counter, error) { switch valueType { case IntegerCnt: + intValue, err := castToInt(value) + if err != nil { + return nil, err + } return &Counter{ valueType: IntegerCnt, - value: castToInt(value), + value: intValue, createdAt: createdAt, - } + }, nil case LongCnt: + longValue, err := castToLong(value) + if err != nil { + return nil, err + } return &Counter{ valueType: LongCnt, - value: castToLong(value), + value: longValue, createdAt: createdAt, - } + }, nil + default: + return nil, ErrUnsupportedType } - - panic("unsupported type") } // Bytes creates an array representing the value. -func (p *Counter) Bytes() []byte { +func (p *Counter) Bytes() ([]byte, error) { switch val := p.value.(type) { case int32: bytes := [4]byte{} binary.LittleEndian.PutUint32(bytes[:], uint32(val)) - return bytes[:] + return bytes[:], nil case int64: bytes := [8]byte{} binary.LittleEndian.PutUint64(bytes[:], uint64(val)) - return bytes[:] + return bytes[:], nil + default: + return nil, ErrUnsupportedType } - - panic("unsupported type") } // Marshal returns the JSON encoding of the value. @@ -146,20 +158,28 @@ func (p *Counter) ValueType() CounterType { // than MinInt32, Counter's value type can be changed Integer to Long. // Because in golang, int can be either int32 or int64. // So we need to assert int to int32. -func (p *Counter) Increase(v *Primitive) *Counter { +func (p *Counter) Increase(v *Primitive) (*Counter, error) { if !p.IsNumericType() || !v.IsNumericType() { - panic("unsupported type") + return nil, ErrUnsupportedType } switch p.valueType { case IntegerCnt: - p.value = p.value.(int32) + castToInt(v.value) + intValue, err := castToInt(v.value) + if err != nil { + return nil, err + } + p.value = p.value.(int32) + intValue case LongCnt: - p.value = p.value.(int64) + castToLong(v.value) + longValue, err := castToLong(v.value) + if err != nil { + return nil, err + } + p.value = p.value.(int64) + longValue default: - panic("unsupported type") + return nil, ErrUnsupportedType } - return p + return p, nil } // IsNumericType checks for numeric types. @@ -169,37 +189,37 @@ func (p *Counter) IsNumericType() bool { } // castToInt casts numeric type to int32. -func castToInt(value interface{}) int32 { +func castToInt(value interface{}) (int32, error) { switch val := value.(type) { case int32: - return val + return val, nil case int64: - return int32(val) + return int32(val), nil case int: - return int32(val) + return int32(val), nil case float32: - return int32(val) + return int32(val), nil case float64: - return int32(val) + return int32(val), nil default: - panic("unsupported type") + return 0, ErrUnsupportedType } } // castToLong casts numeric type to int64. -func castToLong(value interface{}) int64 { +func castToLong(value interface{}) (int64, error) { switch val := value.(type) { case int64: - return val + return val, nil case int32: - return int64(val) + return int64(val), nil case int: - return int64(val) + return int64(val), nil case float32: - return int64(val) + return int64(val), nil case float64: - return int64(val) + return int64(val), nil default: - panic("unsupported type") + return 0, ErrUnsupportedType } } diff --git a/pkg/document/crdt/counter_test.go b/pkg/document/crdt/counter_test.go index de3345dfb..d83a2b066 100644 --- a/pkg/document/crdt/counter_test.go +++ b/pkg/document/crdt/counter_test.go @@ -30,81 +30,94 @@ import ( func TestCounter(t *testing.T) { t.Run("new counter test", func(t *testing.T) { - intCntWithInt32Value := crdt.NewCounter(crdt.IntegerCnt, int32(math.MaxInt32), time.InitialTicket) + intCntWithInt32Value, err := crdt.NewCounter(crdt.IntegerCnt, int32(math.MaxInt32), time.InitialTicket) + assert.NoError(t, err) assert.Equal(t, crdt.IntegerCnt, intCntWithInt32Value.ValueType()) - intCntWithInt64Value := crdt.NewCounter(crdt.IntegerCnt, int64(math.MaxInt32+1), time.InitialTicket) + intCntWithInt64Value, err := crdt.NewCounter(crdt.IntegerCnt, int64(math.MaxInt32+1), time.InitialTicket) + assert.NoError(t, err) assert.Equal(t, crdt.IntegerCnt, intCntWithInt64Value.ValueType()) - intCntWithIntValue := crdt.NewCounter(crdt.IntegerCnt, math.MaxInt32, time.InitialTicket) + intCntWithIntValue, err := crdt.NewCounter(crdt.IntegerCnt, math.MaxInt32, time.InitialTicket) + assert.NoError(t, err) assert.Equal(t, crdt.IntegerCnt, intCntWithIntValue.ValueType()) - intCntWithDoubleValue := crdt.NewCounter(crdt.IntegerCnt, 0.5, time.InitialTicket) + intCntWithDoubleValue, err := crdt.NewCounter(crdt.IntegerCnt, 0.5, time.InitialTicket) + assert.NoError(t, err) assert.Equal(t, crdt.IntegerCnt, intCntWithDoubleValue.ValueType()) - intCntWithUnsupportedValue := func() { crdt.NewCounter(crdt.IntegerCnt, "", time.InitialTicket) } - assert.Panics(t, intCntWithUnsupportedValue) + _, err = crdt.NewCounter(crdt.IntegerCnt, "", time.InitialTicket) + assert.ErrorIs(t, err, crdt.ErrUnsupportedType) - longCntWithInt32Value := crdt.NewCounter(crdt.LongCnt, int32(math.MaxInt32), time.InitialTicket) + longCntWithInt32Value, err := crdt.NewCounter(crdt.LongCnt, int32(math.MaxInt32), time.InitialTicket) + assert.NoError(t, err) assert.Equal(t, crdt.LongCnt, longCntWithInt32Value.ValueType()) - longCntWithInt64Value := crdt.NewCounter(crdt.LongCnt, int64(math.MaxInt32+1), time.InitialTicket) + longCntWithInt64Value, err := crdt.NewCounter(crdt.LongCnt, int64(math.MaxInt32+1), time.InitialTicket) + assert.NoError(t, err) assert.Equal(t, crdt.LongCnt, longCntWithInt64Value.ValueType()) - longCntWithIntValue := crdt.NewCounter(crdt.LongCnt, math.MaxInt32+1, time.InitialTicket) + longCntWithIntValue, err := crdt.NewCounter(crdt.LongCnt, math.MaxInt32+1, time.InitialTicket) + assert.NoError(t, err) assert.Equal(t, crdt.LongCnt, longCntWithIntValue.ValueType()) - longCntWithDoubleValue := crdt.NewCounter(crdt.LongCnt, 0.5, time.InitialTicket) + longCntWithDoubleValue, err := crdt.NewCounter(crdt.LongCnt, 0.5, time.InitialTicket) + assert.NoError(t, err) assert.Equal(t, crdt.LongCnt, longCntWithDoubleValue.ValueType()) - longCntWithUnsupportedValue := func() { crdt.NewCounter(crdt.LongCnt, "", time.InitialTicket) } - assert.Panics(t, longCntWithUnsupportedValue) + _, err = crdt.NewCounter(crdt.LongCnt, "", time.InitialTicket) + assert.ErrorIs(t, err, crdt.ErrUnsupportedType) }) t.Run("increase test", func(t *testing.T) { var x = 5 var y int64 = 10 var z = 3.14 - integer := crdt.NewCounter(crdt.IntegerCnt, x, time.InitialTicket) - long := crdt.NewCounter(crdt.LongCnt, y, time.InitialTicket) - double := crdt.NewCounter(crdt.IntegerCnt, z, time.InitialTicket) + integer, err := crdt.NewCounter(crdt.IntegerCnt, x, time.InitialTicket) + assert.NoError(t, err) + long, err := crdt.NewCounter(crdt.LongCnt, y, time.InitialTicket) + assert.NoError(t, err) + double, err := crdt.NewCounter(crdt.IntegerCnt, z, time.InitialTicket) + assert.NoError(t, err) integerOperand := crdt.NewPrimitive(x, time.InitialTicket) longOperand := crdt.NewPrimitive(y, time.InitialTicket) doubleOperand := crdt.NewPrimitive(z, time.InitialTicket) // normal process test - integer.Increase(integerOperand) - integer.Increase(longOperand) - integer.Increase(doubleOperand) + _, err = integer.Increase(integerOperand) + assert.NoError(t, err) + _, err = integer.Increase(longOperand) + assert.NoError(t, err) + _, err = integer.Increase(doubleOperand) + assert.NoError(t, err) assert.Equal(t, integer.Marshal(), "23") - long.Increase(integerOperand) - long.Increase(longOperand) - long.Increase(doubleOperand) + _, err = long.Increase(integerOperand) + assert.NoError(t, err) + _, err = long.Increase(longOperand) + assert.NoError(t, err) + _, err = long.Increase(doubleOperand) + assert.NoError(t, err) assert.Equal(t, long.Marshal(), "28") - double.Increase(integerOperand) - double.Increase(longOperand) - double.Increase(doubleOperand) + _, err = double.Increase(integerOperand) + assert.NoError(t, err) + _, err = double.Increase(longOperand) + assert.NoError(t, err) + _, err = double.Increase(doubleOperand) + assert.NoError(t, err) assert.Equal(t, double.Marshal(), "21") // error process test - // TODO: it should be modified to error check - // when 'Remove panic from server code (#50)' is completed. - unsupportedTypePanicTest := func() { - r := recover() - assert.NotNil(t, r) - assert.Equal(t, r, "unsupported type") + unsupportedTypeErrorTest := func(v interface{}) { + _, err = crdt.NewCounter(crdt.IntegerCnt, v, time.InitialTicket) + assert.ErrorIs(t, err, crdt.ErrUnsupportedType) } - unsupportedTest := func(v interface{}) { - defer unsupportedTypePanicTest() - crdt.NewCounter(crdt.IntegerCnt, v, time.InitialTicket) - } - unsupportedTest("str") - unsupportedTest(true) - unsupportedTest([]byte{2}) - unsupportedTest(gotime.Now()) + unsupportedTypeErrorTest("str") + unsupportedTypeErrorTest(true) + unsupportedTypeErrorTest([]byte{2}) + unsupportedTypeErrorTest(gotime.Now()) assert.Equal(t, integer.Marshal(), "23") assert.Equal(t, long.Marshal(), "28") @@ -112,11 +125,13 @@ func TestCounter(t *testing.T) { }) t.Run("Counter value overflow test", func(t *testing.T) { - integer := crdt.NewCounter(crdt.IntegerCnt, math.MaxInt32, time.InitialTicket) + integer, err := crdt.NewCounter(crdt.IntegerCnt, math.MaxInt32, time.InitialTicket) + assert.NoError(t, err) assert.Equal(t, integer.ValueType(), crdt.IntegerCnt) operand := crdt.NewPrimitive(1, time.InitialTicket) - integer.Increase(operand) + _, err = integer.Increase(operand) + assert.NoError(t, err) assert.Equal(t, integer.ValueType(), crdt.IntegerCnt) assert.Equal(t, integer.Marshal(), strconv.FormatInt(math.MinInt32, 10)) }) diff --git a/pkg/document/json/counter.go b/pkg/document/json/counter.go index 668ebd23c..0ae72809a 100644 --- a/pkg/document/json/counter.go +++ b/pkg/document/json/counter.go @@ -71,7 +71,9 @@ func (p *Counter) Increase(v interface{}) *Counter { panic("unsupported type") } - p.Counter.Increase(primitive) + if _, err := p.Counter.Increase(primitive); err != nil { + panic(err) + } p.context.Push(operations.NewIncrease( p.CreatedAt(), diff --git a/pkg/document/json/object.go b/pkg/document/json/object.go index 2cf157480..11b44bf24 100644 --- a/pkg/document/json/object.go +++ b/pkg/document/json/object.go @@ -75,14 +75,22 @@ func (p *Object) SetNewCounter(k string, t crdt.CounterType, n interface{}) *Cou v := p.setInternal(k, func(ticket *time.Ticket) crdt.Element { switch t { case crdt.IntegerCnt: + counter, err := crdt.NewCounter(crdt.IntegerCnt, n, ticket) + if err != nil { + panic(err) + } return NewCounter( p.context, - crdt.NewCounter(crdt.IntegerCnt, n, ticket), + counter, ) case crdt.LongCnt: + counter, err := crdt.NewCounter(crdt.LongCnt, n, ticket) + if err != nil { + panic(err) + } return NewCounter( p.context, - crdt.NewCounter(crdt.LongCnt, n, ticket), + counter, ) default: panic("unsupported type") diff --git a/pkg/document/operations/increase.go b/pkg/document/operations/increase.go index cff99aeea..5b3d126a6 100644 --- a/pkg/document/operations/increase.go +++ b/pkg/document/operations/increase.go @@ -51,7 +51,9 @@ func (o *Increase) Execute(root *crdt.Root) error { } value := o.value.(*crdt.Primitive) - cnt.Increase(value) + if _, err := cnt.Increase(value); err != nil { + return err + } return nil }