Skip to content

Commit

Permalink
feat: Async/Lazy for easy using
Browse files Browse the repository at this point in the history
  • Loading branch information
jizhuozhi committed Jul 12, 2024
1 parent a9ba0ce commit 1489f49
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 6 deletions.
17 changes: 17 additions & 0 deletions async.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package future

func Async[T any](f func() (T, error)) *Future[T] {
p := NewPromise[T]()
go func() {
val, err := f()
p.Set(val, err)
}()
return p.Future()
}

func Lazy[T any](f func() (T, error)) *Future[T] {
p := NewPromise[T]()
p.state.state |= flagLazy
p.state.f = f
return p.Future()
}
50 changes: 50 additions & 0 deletions async_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package future

import (
"runtime"
"sync"
"sync/atomic"
"testing"

"github.com/stretchr/testify/assert"
)

func TestAsync(t *testing.T) {
f := Async(func() (int, error) {
return 1, nil
})
val, err := f.Get()
assert.Equal(t, 1, val)
assert.Equal(t, nil, err)
}

func TestLazy(t *testing.T) {
f := Lazy(func() (int, error) {
return 1, nil
})
val, err := f.Get()
assert.Equal(t, 1, val)
assert.Equal(t, nil, err)
}

func TestLazyConcurrency(t *testing.T) {
n := runtime.NumCPU() - 1

var counter int32
f := Lazy(func() (int, error) {
c := atomic.AddInt32(&counter, 1)
return int(c), nil
})

wg := sync.WaitGroup{}
for i := 0; i < n; i++ {
wg.Add(1)
go func() {
defer wg.Done()
val, err := f.Get()
assert.Equal(t, val, 1)
assert.Equal(t, err, nil)
}()
}
wg.Wait()
}
35 changes: 29 additions & 6 deletions future.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,24 @@ const (
stateDone
)

const stateDelta = 1 << 32

const (
maskCounter = 1<<32 - 1
maskState = 1<<34 - 1
)

const flagLazy uint64 = 1 << 63

type state[T any] struct {
noCopy noCopy

state uint64 // high 32 bits are state, low 32 bits are waiter count.
state uint64 // high 30 bits are flags, mid 2 bits are state, low 32 bits are waiter count.
sema uint32

val T
err error
f func() (T, error)
}

type Promise[T any] struct {
Expand All @@ -31,14 +41,14 @@ type Future[T any] struct {
func (s *state[T]) set(val T, err error) {
for {
st := atomic.LoadUint64(&s.state)
if (st >> 32) > stateFree {
if ((st & maskState) >> 32) > stateFree {
panic("promise already satisfied")
}
if atomic.CompareAndSwapUint64(&s.state, st, st+(1<<32)) {
if atomic.CompareAndSwapUint64(&s.state, st, st+stateDelta) {
s.val = val
s.err = err
st = atomic.AddUint64(&s.state, 1<<32)
for w := st & (1<<32 - 1); w > 0; w-- {
st = atomic.AddUint64(&s.state, stateDelta)
for w := st & maskCounter; w > 0; w-- {
runtime_Semrelease(&s.sema, false, 0)
}
return
Expand All @@ -47,9 +57,22 @@ func (s *state[T]) set(val T, err error) {
}

func (s *state[T]) get() (T, error) {
if atomic.LoadUint64(&s.state)&flagLazy == flagLazy {
for {
st := atomic.LoadUint64(&s.state)
if st&flagLazy != flagLazy {
break
}
if atomic.CompareAndSwapUint64(&s.state, st, st&(^flagLazy)) {
val, err := s.f()
s.set(val, err)
return val, err
}
}
}
for {
st := atomic.LoadUint64(&s.state)
if (st >> 32) == stateDone {
if ((st & maskState) >> 32) == stateDone {
return s.val, s.err
}
if atomic.CompareAndSwapUint64(&s.state, st, st+1) {
Expand Down

0 comments on commit 1489f49

Please sign in to comment.