From 2c57614f801fb424db5937e73f549f4b5edb84c3 Mon Sep 17 00:00:00 2001 From: Jack Kleeman Date: Sat, 13 Jul 2024 11:56:02 +0200 Subject: [PATCH] Add rand implementation --- example/checkout.go | 10 +---- internal/rand/rand.go | 67 +++++++++++++++++++++++++++++++++ internal/rand/rand_test.go | 76 ++++++++++++++++++++++++++++++++++++++ internal/state/state.go | 8 ++++ router.go | 6 +++ 5 files changed, 158 insertions(+), 9 deletions(-) create mode 100644 internal/rand/rand.go create mode 100644 internal/rand/rand_test.go diff --git a/example/checkout.go b/example/checkout.go index 2a22c4e..d6cd8a4 100644 --- a/example/checkout.go +++ b/example/checkout.go @@ -4,7 +4,6 @@ import ( "fmt" "math/rand" - "github.com/google/uuid" restate "github.com/restatedev/sdk-go" ) @@ -27,17 +26,10 @@ func (c *checkout) Name() string { const CheckoutServiceName = "Checkout" func (c *checkout) Payment(ctx restate.Context, request PaymentRequest) (response PaymentResponse, err error) { - uuid, err := restate.RunAs(ctx, func(ctx restate.RunContext) (string, error) { - uuid := uuid.New() - return uuid.String(), nil - }) + uuid := ctx.Rand().UUID().String() response.ID = uuid - if err != nil { - return response, err - } - // We are a uniform shop where everything costs 30 USD // that is cheaper than the official example :P price := len(request.Tickets) * 30 diff --git a/internal/rand/rand.go b/internal/rand/rand.go new file mode 100644 index 0000000..946d548 --- /dev/null +++ b/internal/rand/rand.go @@ -0,0 +1,67 @@ +package rand + +import ( + "crypto/sha256" + "encoding/binary" + "math/rand/v2" + + "github.com/google/uuid" +) + +type Rand struct { + *rand.Rand +} + +func New(invocationID []byte) *Rand { + return &Rand{rand.New(newSource(invocationID))} +} + +func (r *Rand) UUID() uuid.UUID { + var uuid [16]byte + binary.LittleEndian.PutUint64(uuid[:8], r.Uint64()) + binary.LittleEndian.PutUint64(uuid[8:], r.Uint64()) + uuid[6] = (uuid[6] & 0x0f) | 0x40 // Version 4 + uuid[8] = (uuid[8] & 0x3f) | 0x80 // Variant is 10 + return uuid +} + +type Source struct { + state [4]uint64 +} + +func newSource(invocationID []byte) *Source { + hash := sha256.New() + hash.Write(invocationID) + var sum [32]byte + hash.Sum(sum[:0]) + + return &Source{state: [4]uint64{ + binary.LittleEndian.Uint64(sum[:8]), + binary.LittleEndian.Uint64(sum[8:16]), + binary.LittleEndian.Uint64(sum[16:24]), + binary.LittleEndian.Uint64(sum[24:32]), + }} +} + +func (s *Source) Uint64() uint64 { + result := rotl((s.state[0]+s.state[3]), 23) + s.state[0] + + t := (s.state[1] << 17) + + s.state[2] ^= s.state[0] + s.state[3] ^= s.state[1] + s.state[1] ^= s.state[2] + s.state[0] ^= s.state[3] + + s.state[2] ^= t + + s.state[3] = rotl(s.state[3], 45) + + return result +} + +func rotl(x uint64, k uint64) uint64 { + return (x << k) | (x >> (64 - k)) +} + +var _ rand.Source = (*Source)(nil) diff --git a/internal/rand/rand_test.go b/internal/rand/rand_test.go new file mode 100644 index 0000000..6a91faa --- /dev/null +++ b/internal/rand/rand_test.go @@ -0,0 +1,76 @@ +package rand + +import ( + "encoding/hex" + "math/rand/v2" + "testing" +) + +func TestUint64(t *testing.T) { + id, err := hex.DecodeString("f311f1fdcb9863f0018bd3400ecd7d69b547204e776218b2") + if err != nil { + t.Fatal(err) + } + rand := New(id) + + expected := []uint64{ + 6541268553928124324, + 1632128201851599825, + 3999496359968271420, + 9099219592091638755, + 2609122094717920550, + 16569362788292807660, + 14955958648458255954, + 15581072429430901841, + 4951852598761288088, + 2380816196140950843, + } + + for _, e := range expected { + if found := rand.Uint64(); e != found { + t.Fatalf("Unexpected uint64 %d, expected %d", found, e) + } + } +} + +func TestFloat64(t *testing.T) { + source := &Source{state: [4]uint64{1, 2, 3, 4}} + rand := &Rand{rand.New(source)} + + expected := []float64{ + 4.656612984099695e-9, 6.519269457605503e-9, 0.39843750651926946, + 0.3986824029416509, 0.5822761557370711, 0.2997488042907357, + 0.5336032865255543, 0.36335061693258097, 0.5968067925950846, + 0.18570456306457928, + } + + for _, e := range expected { + if found := rand.Float64(); e != found { + t.Fatalf("Unexpected float64 %v, expected %v", found, e) + } + } +} + +func TestUUID(t *testing.T) { + source := &Source{state: [4]uint64{1, 2, 3, 4}} + rand := &Rand{rand.New(source)} + + expected := []string{ + "01008002-0000-4000-a700-800300000000", + "67008003-00c0-4c00-b200-449901c20c00", + "cd33c49a-01a2-4280-ba33-eecd8a97698a", + "bd4a1533-4713-41c2-979e-167991a02bac", + "d83f078f-0a19-43db-a092-22b24af10591", + "677c91f7-146e-4769-a4fd-df3793e717e8", + "f15179b2-f220-4427-8d90-7b5437d9828d", + "9e97720f-42b8-4d09-a449-914cf221df26", + "09d0a109-6f11-4ef9-93fa-f013d0ad3808", + "41eb0e0c-41c9-4828-85d0-59fb901b4df4", + } + + for _, e := range expected { + if found := rand.UUID().String(); e != found { + t.Fatalf("Unexpected uuid %s, expected %s", found, e) + } + } +} diff --git a/internal/state/state.go b/internal/state/state.go index fa20bee..b270988 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -17,6 +17,7 @@ import ( "github.com/restatedev/sdk-go/internal/errors" "github.com/restatedev/sdk-go/internal/futures" "github.com/restatedev/sdk-go/internal/log" + "github.com/restatedev/sdk-go/internal/rand" "github.com/restatedev/sdk-go/internal/wire" "github.com/restatedev/sdk-go/rcontext" ) @@ -46,6 +47,10 @@ func (c *Context) Log() *slog.Logger { return c.machine.userLog } +func (c *Context) Rand() *rand.Rand { + return c.machine.rand +} + func (c *Context) Set(key string, value []byte) { c.machine.set(key, value) } @@ -172,6 +177,8 @@ type Machine struct { pendingAcks map[uint32]wire.AckableMessage pendingMutex sync.RWMutex + rand *rand.Rand + failure any } @@ -204,6 +211,7 @@ func (m *Machine) Start(inner context.Context, dropReplayLogs bool, logHandler s m.ctx = inner m.suspensionCtx, m.suspend = context.WithCancelCause(m.ctx) m.id = start.Id + m.rand = rand.New(m.id) m.key = start.Key logHandler = logHandler.WithAttrs([]slog.Attr{slog.String("invocationID", start.DebugId)}) diff --git a/router.go b/router.go index 6b02c8d..abfc3fc 100644 --- a/router.go +++ b/router.go @@ -9,6 +9,7 @@ import ( "github.com/restatedev/sdk-go/internal" "github.com/restatedev/sdk-go/internal/futures" + "github.com/restatedev/sdk-go/internal/rand" "github.com/vmihailenco/msgpack/v5" ) @@ -50,6 +51,11 @@ type Selector interface { type Context interface { RunContext + // Returns a random source which will give deterministic results for a given invocation + // The source wraps the stdlib rand.Rand but with some extra helper methods + // This source is not safe for use inside .Run() + Rand() *rand.Rand + // Sleep for the duration d Sleep(d time.Duration) // Return a handle on a sleep duration which can be combined