From cc36bed00059f817c988e352ebda3bd1d4d86606 Mon Sep 17 00:00:00 2001 From: Arran Schlosberg Date: Wed, 11 Sep 2024 17:12:54 +0100 Subject: [PATCH] feat: `pseudo.PointerTo()` --- libevm/pseudo/type.go | 21 +++++++++++++++++++++ libevm/pseudo/type_test.go | 23 +++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/libevm/pseudo/type.go b/libevm/pseudo/type.go index 8c453f4cb0e7..8d1568638ed5 100644 --- a/libevm/pseudo/type.go +++ b/libevm/pseudo/type.go @@ -60,6 +60,27 @@ func Zero[T any]() *Pseudo[T] { return From[T](x) } +// PointerTo is equivalent to [From] called with a pointer to the payload +// carried by `t`. It first confirms that the payload is of type `T`. +func PointerTo[T any](t *Type) (*Pseudo[*T], error) { + c, ok := t.val.(*concrete[T]) + if !ok { + var want *T + return nil, fmt.Errorf("cannot create *Pseudo[%T] from *Type carrying %T", want, t.val.get()) + } + return From(&c.val), nil +} + +// MustPointerTo is equivalent to [PointerTo] except that it panics instead of +// returning an error. +func MustPointerTo[T any](t *Type) *Pseudo[*T] { + p, err := PointerTo[T](t) + if err != nil { + panic(err) + } + return p +} + // Interface returns the wrapped value as an `any`, equivalent to // [reflect.Value.Interface]. Prefer [Value.Get]. func (t *Type) Interface() any { return t.val.get() } diff --git a/libevm/pseudo/type_test.go b/libevm/pseudo/type_test.go index 27ecf7e497ea..0b25c945ce29 100644 --- a/libevm/pseudo/type_test.go +++ b/libevm/pseudo/type_test.go @@ -77,3 +77,26 @@ func ExamplePseudo_TypeAndValue() { _ = typ _ = val } + +func TestPointer(t *testing.T) { + type carrier struct { + payload int + } + + typ, val := From(carrier{42}).TypeAndValue() + + t.Run("invalid type", func(t *testing.T) { + _, err := PointerTo[int](typ) + require.Errorf(t, err, "PointerTo[int](%T)", carrier{}) + }) + + t.Run("valid type", func(t *testing.T) { + ptrVal := MustPointerTo[carrier](typ).Value + + assert.Equal(t, 42, val.Get().payload, "before setting via pointer") + var ptr *carrier = ptrVal.Get() + ptr.payload = 314159 + assert.Equal(t, 314159, val.Get().payload, "after setting via pointer") + }) + +}