diff --git a/concurrency/atomicmap.go b/concurrency/atomicmap.go index 529a605..758bb20 100644 --- a/concurrency/atomicmap.go +++ b/concurrency/atomicmap.go @@ -54,29 +54,29 @@ func NewAtomicMapStringInt64() *AtomicMap[string, int64] { } } -func NewAtomicMapStringUint64() *AtomicMap[string, uint64] { - return &AtomicMap[string, uint64]{ - items: make(map[string]*AtomicValue[uint64]), - } -} - func NewAtomicMapStringInt32() *AtomicMap[string, int32] { return &AtomicMap[string, int32]{ items: make(map[string]*AtomicValue[int32]), } } +func NewAtomicMapStringUint64() *AtomicMap[string, uint64] { + return &AtomicMap[string, uint64]{ + items: make(map[string]*AtomicValue[uint64]), + } +} + func NewAtomicMapStringUint32() *AtomicMap[string, uint32] { return &AtomicMap[string, uint32]{ items: make(map[string]*AtomicValue[uint32]), } } - func (a *AtomicMap[K, T]) Get(key K) (*AtomicValue[T], bool) { a.lock.RLock() + defer a.lock.RUnlock() + item, ok := a.items[key] - a.lock.RUnlock() if !ok { return nil, false } @@ -90,9 +90,10 @@ func (a *AtomicMap[K, T]) GetOrCreate(key K, createT T) *AtomicValue[T] { if !ok { a.lock.Lock() // Double-check the key exists to avoid race condition - if item, ok = a.items[key]; !ok { - a.items[key] = &AtomicValue[T]{value: createT} - item = a.items[key] + item, ok = a.items[key] + if !ok { + item = &AtomicValue[T]{value: createT} + a.items[key] = item } a.lock.Unlock() } @@ -105,10 +106,16 @@ func (a *AtomicMap[K, T]) Delete(key K) { a.lock.Unlock() } -func (a *AtomicMap[K, T]) ForEach(fn func(key K, val *AtomicValue[T])) { +func (a *AtomicMap[K, T]) ForEach(fn func(key K, value *AtomicValue[T])) { a.lock.RLock() - for key, val := range a.items { - fn(key, val) + defer a.lock.RUnlock() + for k, v := range a.items { + fn(k, v) } - a.lock.RUnlock() -} \ No newline at end of file +} + +func (a *AtomicMap[K, T]) Clear() { + a.lock.Lock() + defer a.lock.Unlock() + a.items = make(map[K]*AtomicValue[T]) +} diff --git a/concurrency/atomicmap_test.go b/concurrency/atomicmap_test.go index ed11048..016530a 100644 --- a/concurrency/atomicmap_test.go +++ b/concurrency/atomicmap_test.go @@ -9,29 +9,29 @@ import ( ) func TestAtomicMapInt32_New_Get_Delete(t *testing.T) { - m := NewAtomicMapInt32() + m := NewAtomicMapStringInt32() require.NotNil(t, m) - require.NotNil(t, m.Items) - require.Empty(t, m.Items) + require.NotNil(t, m.items) + require.Empty(t, m.items) t.Run("basic operations", func(t *testing.T) { key := "key1" value := int32(10) // Initially, the key should not exist - _, err := m.Get(key) - require.Error(t, err) + _, ok := m.Get(key) + require.False(t, ok) // Add a value and check it - m.GetOrCreate(key).Store(value) - result, err := m.Get(key) - require.NoError(t, err) + m.GetOrCreate(key, 0).Store(value) + result, ok := m.Get(key) + require.True(t, ok) assert.Equal(t, value, result.Load()) // Delete the key and check it no longer exists m.Delete(key) - _, err = m.Get(key) - require.Error(t, err) + _, ok = m.Get(key) + require.False(t, ok) }) t.Run("concurrent access multiple keys", func(t *testing.T) { @@ -44,21 +44,21 @@ func TestAtomicMapInt32_New_Get_Delete(t *testing.T) { go func(k string) { defer wg.Done() for i := 0; i < iterations; i++ { - m.GetOrCreate(k).Add(1) + m.GetOrCreate(k, 0).Add(1) } }(key) go func(k string) { defer wg.Done() for i := 0; i < iterations; i++ { - m.GetOrCreate(k).Add(-1) + m.GetOrCreate(k, 0).Add(-1) } }(key) } wg.Wait() for _, key := range keys { - val, err := m.Get(key) - require.NoError(t, err) + val, ok := m.Get(key) + require.True(t, ok) require.Equal(t, int32(0), val.Load()) } }) diff --git a/concurrency/mutexmap.go b/concurrency/mutexmap.go index 5dbd561..3ca6207 100644 --- a/concurrency/mutexmap.go +++ b/concurrency/mutexmap.go @@ -1,102 +1,91 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package concurrency -import "sync" +import ( + "sync" +) -type MutexMap struct { - mu sync.RWMutex // outer lock - mutex map[string]*sync.RWMutex // inner locks +type MutexMap[T comparable] struct { + lock sync.RWMutex + items map[T]*sync.RWMutex } -func NewMutexMap() *MutexMap { - return &MutexMap{ - mutex: make(map[string]*sync.RWMutex), +func NewMutexMapString() *MutexMap[string] { + return &MutexMap[string]{ + items: make(map[string]*sync.RWMutex), } } -func (mm *MutexMap) Lock(key string) { - mm.OuterRLock() - lock, ok := mm.mutex[key] - mm.OuterRUnlock() - +func (a *MutexMap[T]) Lock(key T) { + a.lock.RLock() + mutex, ok := a.items[key] + a.lock.RUnlock() if !ok { - mm.OuterLock() - lock, ok = mm.mutex[key] + a.lock.Lock() + mutex, ok = a.items[key] if !ok { - mm.mutex[key] = &sync.RWMutex{} - lock = mm.mutex[key] + mutex = &sync.RWMutex{} + a.items[key] = mutex } - mm.OuterUnlock() + a.lock.Unlock() } - lock.Lock() + mutex.Lock() } -func (mm *MutexMap) Unlock(key string) { - mm.OuterLock() - defer mm.OuterUnlock() - - if _, ok := mm.mutex[key]; ok { - mm.mutex[key].Unlock() +func (a *MutexMap[T]) Unlock(key T) { + a.lock.RLock() + mutex, ok := a.items[key] + a.lock.RUnlock() + if ok { + mutex.Unlock() } } -func (mm *MutexMap) RLock(key string) { - mm.OuterRLock() - lock, ok := mm.mutex[key] - mm.OuterRUnlock() - +func (a *MutexMap[T]) RLock(key T) { + a.lock.RLock() + mutex, ok := a.items[key] + a.lock.RUnlock() if !ok { - mm.OuterLock() - lock, ok = mm.mutex[key] + a.lock.Lock() + mutex, ok = a.items[key] if !ok { - mm.mutex[key] = &sync.RWMutex{} - lock = mm.mutex[key] + mutex = &sync.RWMutex{} + a.items[key] = mutex } - mm.OuterUnlock() - } - lock.RLock() -} - -func (mm *MutexMap) RUnlock(key string) { - mm.OuterLock() - defer mm.OuterUnlock() - - if _, ok := mm.mutex[key]; ok { - mm.mutex[key].RUnlock() + a.lock.Unlock() } + mutex.Lock() } -// Add adds a new mutex to the map -// If the calling code already holds the outer lock, set lock parameter to false -func (mm *MutexMap) Add(key string) { - mm.OuterLock() - defer mm.OuterUnlock() - - if _, ok := mm.mutex[key]; !ok { - mm.mutex[key] = &sync.RWMutex{} +func (a *MutexMap[T]) RUnlock(key T) { + a.lock.RLock() + mutex, ok := a.items[key] + a.lock.RUnlock() + if ok { + mutex.Unlock() } } -// Delete deletes a mutex from the map -// If the calling code already holds the outer lock, set lock parameter to false -func (mm *MutexMap) Delete(key string) { - mm.OuterLock() - defer mm.OuterUnlock() - - delete(mm.mutex, key) -} - -func (mm *MutexMap) OuterLock() { - mm.mu.Lock() -} - -func (mm *MutexMap) OuterUnlock() { - mm.mu.Unlock() -} - -func (mm *MutexMap) OuterRLock() { - mm.mu.RLock() +func (a *MutexMap[T]) Delete(key T) { + a.lock.Lock() + delete(a.items, key) + a.lock.Unlock() } -func (mm *MutexMap) OuterRUnlock() { - mm.mu.RUnlock() +func (a *MutexMap[T]) Clear() { + a.lock.Lock() + a.items = make(map[T]*sync.RWMutex) + a.lock.Unlock() } diff --git a/concurrency/mutexmap_test.go b/concurrency/mutexmap_test.go index 8b3824d..7eb5436 100644 --- a/concurrency/mutexmap_test.go +++ b/concurrency/mutexmap_test.go @@ -8,68 +8,63 @@ import ( ) func TestNewMutexMap_Add_Delete(t *testing.T) { - mm := NewMutexMap() + mm := NewMutexMapString() t.Run("New mutex map", func(t *testing.T) { require.NotNil(t, mm) - require.NotNil(t, mm.mutex) - require.Empty(t, mm.mutex) + require.NotNil(t, mm.items) + require.Empty(t, mm.items) }) - t.Run("Add mutex ", func(t *testing.T) { - mm.Add("key1") - require.Len(t, mm.mutex, 1) - _, ok := mm.mutex["key1"] + t.Run("Lock and unlock mutex", func(t *testing.T) { + mm.Lock("key1") + _, ok := mm.items["key1"] require.True(t, ok) + mm.Unlock("key1") }) - t.Run("Delete mutex", func(t *testing.T) { - mm.Delete("key1") - require.Empty(t, mm.mutex) - _, ok := mm.mutex["key1"] - require.False(t, ok) - }) + t.Run("Concurrently lock and unlock mutexes", func(t *testing.T) { + var counter int + var wg sync.WaitGroup - t.Run("Concurrently add and delete mutexes", func(t *testing.T) { numGoroutines := 10 - keys := []string{"key1", "key2", "key3"} - - var wg sync.WaitGroup wg.Add(numGoroutines) - // Concurrently add and delete keys + // Concurrently lock and unlock for each key for i := 0; i < numGoroutines; i++ { go func() { defer wg.Done() - for _, key := range keys { - mm.Add(key) - mm.Delete(key) - } + mm.Lock("key1") + counter++ + mm.Unlock("key1") }() } wg.Wait() - // Additional check that all keys have been deleted - for _, key := range keys { - _, ok := mm.mutex[key] - require.False(t, ok) - } + require.Equal(t, 10, counter) }) - t.Run("Concurrently lock and unlock mutexes", func(t *testing.T) { + t.Run("RLock and RUnlock mutex", func(t *testing.T) { + mm.RLock("key1") + _, ok := mm.items["key1"] + require.True(t, ok) + mm.RUnlock("key1") + }) + + t.Run("Concurrently RLock and RUnlock mutexes", func(t *testing.T) { var counter int var wg sync.WaitGroup numGoroutines := 10 wg.Add(numGoroutines) - // Concurrently lock and unlock for each key + // Concurrently RLock and RUnlock for each key for i := 0; i < numGoroutines; i++ { go func() { defer wg.Done() - mm.Lock("key1") + mm.RLock("key1") counter++ - mm.Unlock("key1") + mm.RUnlock("key1") }() } wg.Wait() @@ -77,11 +72,20 @@ func TestNewMutexMap_Add_Delete(t *testing.T) { require.Equal(t, 10, counter) }) - t.Run("Lock and unlock nonexistent mutexes", func(t *testing.T) { - mm.Lock("non-existent-key") - _, ok := mm.mutex["non-existent-key"] - mm.Unlock("non-existent-key") + t.Run("Delete mutex", func(t *testing.T) { + mm.Lock("key1") + mm.Unlock("key1") + mm.Delete("key1") + _, ok := mm.items["key1"] + require.False(t, ok) + }) - require.True(t, ok) + t.Run("Clear all mutexes", func(t *testing.T) { + mm.Lock("key1") + mm.Unlock("key1") + mm.Lock("key2") + mm.Unlock("key2") + mm.Clear() + require.Empty(t, mm.items) }) }