diff --git a/complex/complex.go b/complex/complex.go new file mode 100644 index 0000000..cc88668 --- /dev/null +++ b/complex/complex.go @@ -0,0 +1,20 @@ +package complex + +import "github.com/shimmeringbee/persistence" + +func Store[T any](section persistence.Section, key string, val T, enc func(persistence.Section, string, T) error) error { + return enc(section, key, val) +} + +func Retrieve[T any](section persistence.Section, key string, dec func(persistence.Section, string) (T, bool), defValue ...T) (T, bool) { + if v, ok := dec(section, key); ok { + return v, ok + } else { + if len(defValue) > 0 { + return defValue[0], false + } else { + v = *new(T) + return v, false + } + } +} diff --git a/complex/complex_test.go b/complex/complex_test.go new file mode 100644 index 0000000..b26490b --- /dev/null +++ b/complex/complex_test.go @@ -0,0 +1,30 @@ +package complex + +import ( + "github.com/shimmeringbee/persistence/impl/memory" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestRetrieve(t *testing.T) { + t.Run("default value provided is returned if not found and default provided", func(t *testing.T) { + s := memory.New() + + expected := time.Duration(1) + + actual, found := Retrieve(s, Key, DurationDecoder, expected) + assert.False(t, found) + assert.Equal(t, expected, actual) + }) + + t.Run("zero value provided is returned if not found and no default", func(t *testing.T) { + s := memory.New() + + expected := time.Duration(0) + + actual, found := Retrieve(s, Key, DurationDecoder) + assert.False(t, found) + assert.Equal(t, expected, actual) + }) +} diff --git a/complex/time.go b/complex/time.go new file mode 100644 index 0000000..522852e --- /dev/null +++ b/complex/time.go @@ -0,0 +1,30 @@ +package complex + +import ( + "github.com/shimmeringbee/persistence" + "time" +) + +func TimeEncoder(s persistence.Section, k string, v time.Time) error { + return s.Set(k, v.UnixMilli()) +} + +func TimeDecoder(s persistence.Section, k string) (time.Time, bool) { + if ev, found := s.Int(k); found { + return time.UnixMilli(ev), true + } else { + return time.Time{}, false + } +} + +func DurationEncoder(s persistence.Section, k string, v time.Duration) error { + return s.Set(k, v.Milliseconds()) +} + +func DurationDecoder(s persistence.Section, k string) (time.Duration, bool) { + if ev, found := s.Int(k); found { + return time.Duration(ev) * time.Millisecond, true + } else { + return time.Duration(0), false + } +} diff --git a/complex/time_test.go b/complex/time_test.go new file mode 100644 index 0000000..e5ebacf --- /dev/null +++ b/complex/time_test.go @@ -0,0 +1,40 @@ +package complex + +import ( + "github.com/shimmeringbee/persistence/impl/memory" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +const Key = "key" + +func TestTime(t *testing.T) { + t.Run("time is stored and retrieved to the millisecond level", func(t *testing.T) { + s := memory.New() + + expected := time.UnixMilli(time.Now().UnixMilli()) + + err := Store(s, Key, expected, TimeEncoder) + assert.NoError(t, err) + + actual, found := Retrieve(s, Key, TimeDecoder) + assert.True(t, found) + assert.Equal(t, expected, actual) + }) +} + +func TestDuration(t *testing.T) { + t.Run("duration is stored and retrieved to the millisecond level", func(t *testing.T) { + s := memory.New() + + expected := time.Duration(1234) * time.Millisecond + + err := Store(s, Key, expected, DurationEncoder) + assert.NoError(t, err) + + actual, found := Retrieve(s, Key, DurationDecoder) + assert.True(t, found) + assert.Equal(t, expected, actual) + }) +} diff --git a/impl/memory/memory.go b/impl/memory/memory.go index 63c404a..d237a65 100644 --- a/impl/memory/memory.go +++ b/impl/memory/memory.go @@ -86,11 +86,11 @@ func genericRetrieve[T any](m *memory, key string, defValue ...T) (T, bool) { } } -func (m *memory) Int(key string, defValue ...int) (int, bool) { +func (m *memory) Int(key string, defValue ...int64) (int64, bool) { return genericRetrieve(m, key, defValue...) } -func (m *memory) UInt(key string, defValue ...uint) (uint, bool) { +func (m *memory) UInt(key string, defValue ...uint64) (uint64, bool) { return genericRetrieve(m, key, defValue...) } @@ -117,25 +117,25 @@ func (m *memory) Set(key string, value interface{}) error { case string: sV = v case int: - sV = v + sV = int64(v) case int8: - sV = int(v) + sV = int64(v) case int16: - sV = int(v) + sV = int64(v) case int32: - sV = int(v) + sV = int64(v) case int64: - sV = int(v) - case uint: sV = v + case uint: + sV = uint64(v) case uint8: - sV = uint(v) + sV = uint64(v) case uint16: - sV = uint(v) + sV = uint64(v) case uint32: - sV = uint(v) + sV = uint64(v) case uint64: - sV = uint(v) + sV = v case float32: sV = float64(v) case float64: diff --git a/impl/memory/memory_test.go b/impl/memory/memory_test.go index a1f285f..409ea7d 100644 --- a/impl/memory/memory_test.go +++ b/impl/memory/memory_test.go @@ -130,41 +130,41 @@ func TestMemory_Int(t *testing.T) { s := New() val, found := s.Int("intKey") - assert.Equal(t, 0, val) + assert.Equal(t, int64(0), val) assert.False(t, found) val, found = s.Int("intKey", 1) - assert.Equal(t, 1, val) + assert.Equal(t, int64(1), val) assert.False(t, found) assert.NoError(t, s.Set("intKey", 2)) val, found = s.Int("intKey", 1) - assert.Equal(t, 2, val) + assert.Equal(t, int64(2), val) assert.True(t, found) assert.NoError(t, s.Set("int8Key", int8(2))) val, found = s.Int("int8Key", 1) - assert.Equal(t, 2, val) + assert.Equal(t, int64(2), val) assert.True(t, found) assert.NoError(t, s.Set("int16Key", int16(2))) val, found = s.Int("int16Key", 1) - assert.Equal(t, 2, val) + assert.Equal(t, int64(2), val) assert.True(t, found) assert.NoError(t, s.Set("int32Key", int32(2))) val, found = s.Int("int32Key", 1) - assert.Equal(t, 2, val) + assert.Equal(t, int64(2), val) assert.True(t, found) assert.NoError(t, s.Set("int64Key", int64(2))) val, found = s.Int("int64Key", 1) - assert.Equal(t, 2, val) + assert.Equal(t, int64(2), val) assert.True(t, found) }) } @@ -174,41 +174,41 @@ func TestMemory_UInt(t *testing.T) { s := New() val, found := s.UInt("intKey") - assert.Equal(t, uint(0), val) + assert.Equal(t, uint64(0), val) assert.False(t, found) val, found = s.UInt("intKey", 1) - assert.Equal(t, uint(1), val) + assert.Equal(t, uint64(1), val) assert.False(t, found) assert.NoError(t, s.Set("intKey", uint(2))) val, found = s.UInt("intKey", 1) - assert.Equal(t, uint(2), val) + assert.Equal(t, uint64(2), val) assert.True(t, found) assert.NoError(t, s.Set("int8Key", uint8(2))) val, found = s.UInt("int8Key", 1) - assert.Equal(t, uint(2), val) + assert.Equal(t, uint64(2), val) assert.True(t, found) assert.NoError(t, s.Set("int16Key", uint16(2))) val, found = s.UInt("int16Key", 1) - assert.Equal(t, uint(2), val) + assert.Equal(t, uint64(2), val) assert.True(t, found) assert.NoError(t, s.Set("int32Key", uint32(2))) val, found = s.UInt("int32Key", 1) - assert.Equal(t, uint(2), val) + assert.Equal(t, uint64(2), val) assert.True(t, found) assert.NoError(t, s.Set("int64Key", uint64(2))) val, found = s.UInt("int64Key", 1) - assert.Equal(t, uint(2), val) + assert.Equal(t, uint64(2), val) assert.True(t, found) }) } diff --git a/interface.go b/interface.go index 0a08492..96a0477 100644 --- a/interface.go +++ b/interface.go @@ -6,8 +6,8 @@ type Section interface { Section(key ...string) Section - Int(key string, defValue ...int) (int, bool) - UInt(key string, defValue ...uint) (uint, bool) + Int(key string, defValue ...int64) (int64, bool) + UInt(key string, defValue ...uint64) (uint64, bool) String(key string, defValue ...string) (string, bool) Bool(key string, defValue ...bool) (bool, bool) Float(key string, defValue ...float64) (float64, bool) @@ -17,20 +17,3 @@ type Section interface { Delete(key string) bool } - -func StoreComplex[T any](section Section, key string, val T, enc func(Section, string, T) error) error { - return enc(section, key, val) -} - -func RetrieveComplex[T any](section Section, key string, dec func(Section, string) (T, bool), defValue ...T) (T, bool) { - if v, ok := dec(section, key); ok { - return v, ok - } else { - if len(defValue) > 0 { - return defValue[0], false - } else { - v = *new(T) - return v, false - } - } -}