Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove panic from crdt counter #598

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions api/converter/from_bytes.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,12 +267,19 @@ func fromJSONCounter(pbCnt *api.JSONElement_Counter) (*crdt.Counter, error) {
if err != nil {
return nil, err
}
counterValue, err := crdt.CounterValueFromBytes(counterType, pbCnt.Value)
if err != nil {
return nil, err
}

counter := crdt.NewCounter(
counter, err := crdt.NewCounter(
counterType,
crdt.CounterValueFromBytes(counterType, pbCnt.Value),
counterValue,
createdAt,
)
if err != nil {
return nil, err
}
counter.SetMovedAt(movedAt)
counter.SetRemovedAt(removedAt)

Expand Down
15 changes: 12 additions & 3 deletions api/converter/from_pb.go
Original file line number Diff line number Diff line change
Expand Up @@ -737,11 +737,20 @@ func fromElement(pbElement *api.JSONElementSimple) (crdt.Element, error) {
if err != nil {
return nil, err
}
return crdt.NewCounter(
counterValue, err := crdt.CounterValueFromBytes(counterType, pbElement.Value)
if err != nil {
return nil, err
}

counter, err := crdt.NewCounter(
counterType,
crdt.CounterValueFromBytes(counterType, pbElement.Value),
counterValue,
createdAt,
), nil
)
if err != nil {
return nil, err
}
return counter, nil
case api.ValueType_VALUE_TYPE_TREE:
return BytesToTree(pbElement.Value)
}
Expand Down
6 changes: 5 additions & 1 deletion api/converter/to_bytes.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,15 @@ func toCounter(counter *crdt.Counter) (*api.JSONElement, error) {
if err != nil {
return nil, err
}
counterValue, err := counter.Bytes()
if err != nil {
return nil, err
}

return &api.JSONElement{
Body: &api.JSONElement_Counter_{Counter: &api.JSONElement_Counter{
Type: pbCounterType,
Value: counter.Bytes(),
Value: counterValue,
CreatedAt: ToTimeTicket(counter.CreatedAt()),
MovedAt: ToTimeTicket(counter.MovedAt()),
RemovedAt: ToTimeTicket(counter.RemovedAt()),
Expand Down
6 changes: 5 additions & 1 deletion api/converter/to_pb.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,11 +435,15 @@ func toJSONElementSimple(elem crdt.Element) (*api.JSONElementSimple, error) {
if err != nil {
return nil, err
}
counterValue, err := elem.Bytes()
if err != nil {
return nil, err
}

return &api.JSONElementSimple{
Type: pbCounterType,
CreatedAt: ToTimeTicket(elem.CreatedAt()),
Value: elem.Bytes(),
Value: counterValue,
}, nil
case *crdt.Tree:
bytes, err := TreeToBytes(elem)
Expand Down
94 changes: 57 additions & 37 deletions pkg/document/crdt/counter.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@ package crdt

import (
"encoding/binary"
"errors"
"fmt"

"github.com/yorkie-team/yorkie/pkg/document/time"
)

// ErrUnsupportedType returned when the given type is not supported.
var ErrUnsupportedType = errors.New("unsupported type")

// CounterType represents any type that can be used as a counter.
type CounterType int

Expand All @@ -33,16 +37,16 @@ const (
)

// CounterValueFromBytes parses the given bytes into value.
func CounterValueFromBytes(counterType CounterType, value []byte) interface{} {
func CounterValueFromBytes(counterType CounterType, value []byte) (interface{}, error) {
switch counterType {
case IntegerCnt:
val := int32(binary.LittleEndian.Uint32(value))
return int(val)
return int(val), nil
case LongCnt:
return int64(binary.LittleEndian.Uint64(value))
return int64(binary.LittleEndian.Uint64(value)), nil
default:
return nil, ErrUnsupportedType
}

panic("unsupported type")
}

// Counter represents changeable number data type.
Expand All @@ -55,39 +59,47 @@ type Counter struct {
}

// NewCounter creates a new instance of Counter.
func NewCounter(valueType CounterType, value interface{}, createdAt *time.Ticket) *Counter {
func NewCounter(valueType CounterType, value interface{}, createdAt *time.Ticket) (*Counter, error) {
switch valueType {
case IntegerCnt:
intValue, err := castToInt(value)
if err != nil {
return nil, err
}
return &Counter{
valueType: IntegerCnt,
value: castToInt(value),
value: intValue,
createdAt: createdAt,
}
}, nil
case LongCnt:
longValue, err := castToLong(value)
if err != nil {
return nil, err
}
return &Counter{
valueType: LongCnt,
value: castToLong(value),
value: longValue,
createdAt: createdAt,
}
}, nil
default:
return nil, ErrUnsupportedType
}

panic("unsupported type")
}

// Bytes creates an array representing the value.
func (p *Counter) Bytes() []byte {
func (p *Counter) Bytes() ([]byte, error) {
switch val := p.value.(type) {
case int32:
bytes := [4]byte{}
binary.LittleEndian.PutUint32(bytes[:], uint32(val))
return bytes[:]
return bytes[:], nil
case int64:
bytes := [8]byte{}
binary.LittleEndian.PutUint64(bytes[:], uint64(val))
return bytes[:]
return bytes[:], nil
default:
return nil, ErrUnsupportedType
}

panic("unsupported type")
}

// Marshal returns the JSON encoding of the value.
Expand Down Expand Up @@ -146,20 +158,28 @@ func (p *Counter) ValueType() CounterType {
// than MinInt32, Counter's value type can be changed Integer to Long.
// Because in golang, int can be either int32 or int64.
// So we need to assert int to int32.
func (p *Counter) Increase(v *Primitive) *Counter {
func (p *Counter) Increase(v *Primitive) (*Counter, error) {
if !p.IsNumericType() || !v.IsNumericType() {
panic("unsupported type")
return nil, ErrUnsupportedType
}
switch p.valueType {
case IntegerCnt:
p.value = p.value.(int32) + castToInt(v.value)
intValue, err := castToInt(v.value)
if err != nil {
return nil, err
}
p.value = p.value.(int32) + intValue
case LongCnt:
p.value = p.value.(int64) + castToLong(v.value)
longValue, err := castToLong(v.value)
if err != nil {
return nil, err
}
p.value = p.value.(int64) + longValue
default:
panic("unsupported type")
return nil, ErrUnsupportedType
}

return p
return p, nil
}

// IsNumericType checks for numeric types.
Expand All @@ -169,37 +189,37 @@ func (p *Counter) IsNumericType() bool {
}

// castToInt casts numeric type to int32.
func castToInt(value interface{}) int32 {
func castToInt(value interface{}) (int32, error) {
switch val := value.(type) {
case int32:
return val
return val, nil
case int64:
return int32(val)
return int32(val), nil
case int:
return int32(val)
return int32(val), nil
case float32:
return int32(val)
return int32(val), nil
case float64:
return int32(val)
return int32(val), nil
default:
panic("unsupported type")
return 0, ErrUnsupportedType
}
}

// castToLong casts numeric type to int64.
func castToLong(value interface{}) int64 {
func castToLong(value interface{}) (int64, error) {
switch val := value.(type) {
case int64:
return val
return val, nil
case int32:
return int64(val)
return int64(val), nil
case int:
return int64(val)
return int64(val), nil
case float32:
return int64(val)
return int64(val), nil
case float64:
return int64(val)
return int64(val), nil
default:
panic("unsupported type")
return 0, ErrUnsupportedType
}
}
Loading
Loading