diff --git a/caveat.go b/caveat.go index f8a16ef..f011134 100644 --- a/caveat.go +++ b/caveat.go @@ -1,8 +1,8 @@ package macaroon import ( - "fmt" "reflect" + "strconv" ) // A numeric identifier for caveat types. Values less than @@ -40,12 +40,12 @@ const ( // pull requests to this repository. Add a meaningful name of the caveat // type (e.g. CavAcmeCorpWidgetID) on the line prior to // CavMaxUserRegisterable. - CavMinUserRegisterable = 1 << 32 - CavMaxUserRegisterable = 1<<48 - 1 + CavMinUserRegisterable CaveatType = 1 << 32 + CavMaxUserRegisterable CaveatType = 1<<48 - 1 - CavMinUserDefined = 1 << 48 - CavMaxUserDefined = 1<<64 - 2 - CavUnregistered = 1<<64 - 1 + CavMinUserDefined CaveatType = 1 << 48 + CavMaxUserDefined CaveatType = 1<<64 - 2 + CavUnregistered CaveatType = 1<<64 - 1 ) // Caveat is the interface implemented by all caveats. @@ -112,6 +112,14 @@ func RegisterCaveatType(zeroValue Caveat) { s2t[name] = typ } +func unregisterCaveatType(zeroValue Caveat) { + typ := zeroValue.CaveatType() + name := zeroValue.Name() + delete(t2c, typ) + delete(t2s, typ) + delete(s2t, name) +} + // Register an alternate name for this caveat type that will be recognized when // decoding JSON. func RegisterCaveatJSONAlias(typ CaveatType, alias string) { @@ -124,30 +132,33 @@ func RegisterCaveatJSONAlias(typ CaveatType, alias string) { s2t[alias] = typ } -func typeToCaveat(t CaveatType) (Caveat, error) { +func typeToCaveat(t CaveatType) Caveat { cav, ok := t2c[t] if !ok { - return nil, fmt.Errorf("unregistered caveat type %d", t) + return &UnregisteredCaveat{Type: t} } ct := reflect.TypeOf(cav) if ct.Kind() == reflect.Pointer { - return reflect.New(ct.Elem()).Interface().(Caveat), nil + return reflect.New(ct.Elem()).Interface().(Caveat) } - return reflect.Zero(ct).Interface().(Caveat), nil + return reflect.Zero(ct).Interface().(Caveat) } func caveatTypeFromString(s string) CaveatType { if t, ok := s2t[s]; ok { return t } + if t, err := strconv.ParseUint(s, 10, 64); err == nil { + return CaveatType(t) + } return CavUnregistered } func caveatTypeToString(t CaveatType) string { - if s, ok := t2s[t]; ok { + if s, ok := t2s[t]; ok && t < CavMinUserDefined { return s } - return "[unregistered]" + return strconv.FormatUint(uint64(t), 10) } diff --git a/caveat_set.go b/caveat_set.go index c8ef60f..66c5ab5 100644 --- a/caveat_set.go +++ b/caveat_set.go @@ -130,11 +130,7 @@ func (c *CaveatSet) DecodeMsgpack(dec *msgpack.Decoder) error { return err } - cav, err := typeToCaveat(CaveatType(t)) - if err != nil { - return err - } - + cav := typeToCaveat(CaveatType(t)) if err := dec.Decode(cav); err != nil { return err } @@ -181,10 +177,7 @@ func (c *CaveatSet) UnmarshalJSON(b []byte) error { for i := range jcavs { t := caveatTypeFromString(jcavs[i].Type) - if c.Caveats[i], _ = typeToCaveat(t); c.Caveats[i] == nil { - return fmt.Errorf("bad caveat type: %s", jcavs[i].Type) - } - + c.Caveats[i] = typeToCaveat(t) if err := json.Unmarshal(jcavs[i].Body, &c.Caveats[i]); err != nil { return err } diff --git a/caveat_test.go b/caveat_test.go index 849b913..3f73942 100644 --- a/caveat_test.go +++ b/caveat_test.go @@ -19,8 +19,6 @@ func TestCaveatRegistry(t *testing.T) { assert.Equal(t, 1, len(cs.Caveats)) assert.Equal(t, c, cs.Caveats[0]) - assert.Error(t, json.Unmarshal(j2, cs)) - RegisterCaveatJSONAlias(cavTestParentResource, "Foobar") assert.NoError(t, json.Unmarshal(j1, cs)) diff --git a/caveats.go b/caveats.go index 0ef9e6e..97f2965 100644 --- a/caveats.go +++ b/caveats.go @@ -1,8 +1,11 @@ package macaroon import ( + "encoding/json" "fmt" "time" + + msgpack "github.com/vmihailenco/msgpack/v5" ) // Caveat3P is a requirement that the token be presented along with a 3P discharge token. @@ -70,3 +73,47 @@ func (c *BindToParentToken) Prohibits(f Access) error { // access validation. return fmt.Errorf("%w (bind-to-parent)", ErrBadCaveat) } + +type UnregisteredCaveat struct { + Type CaveatType + Body any + RawJSON []byte + RawMsgpack []byte +} + +func (c *UnregisteredCaveat) CaveatType() CaveatType { return c.Type } +func (c *UnregisteredCaveat) Name() string { return "Unregistered" } + +func (c *UnregisteredCaveat) Prohibits(f Access) error { + return fmt.Errorf("%w (unregistered)", ErrBadCaveat) +} + +func (c UnregisteredCaveat) MarshalMsgpack() ([]byte, error) { + // JSON is just for user-readability, but msgpack is what's used for + // signature verification. With struct tags, etc, it's lossy to encode + // things from json<->msgpack, so we just don't allow it. + if len(c.RawMsgpack) == 0 { + return nil, fmt.Errorf("cannot convert unregistered caveats from JSON to msgpack") + } + return c.RawMsgpack, nil +} + +func (c *UnregisteredCaveat) UnmarshalMsgpack(data []byte) error { + c.RawMsgpack = data + return msgpack.Unmarshal(data, &c.Body) +} + +func (c UnregisteredCaveat) MarshalJSON() ([]byte, error) { + // JSON is just for user-readability, but msgpack is what's used for + // signature verification. With struct tags, etc, it's lossy to encode + // things from json<->msgpack, so we just don't allow it. + if len(c.RawJSON) == 0 { + return nil, fmt.Errorf("cannot convert unregistered caveats from msgpack to JSON") + } + return c.RawJSON, nil +} + +func (c *UnregisteredCaveat) UnmarshalJSON(data []byte) error { + c.RawJSON = data + return json.Unmarshal(data, &c.Body) +} diff --git a/caveats_test.go b/caveats_test.go index b1960b2..f1f45da 100644 --- a/caveats_test.go +++ b/caveats_test.go @@ -1,6 +1,8 @@ package macaroon import ( + "encoding/json" + "errors" "testing" "time" @@ -143,6 +145,102 @@ func TestSimple(t *testing.T) { }) } +type myUnregistered struct { + Bar map[string]string `json:"bar"` + Foo int `json:"foo"` +} + +func (c *myUnregistered) CaveatType() CaveatType { return cavMyUnregistered } +func (c *myUnregistered) Name() string { return "MyUnregistered" } +func (c *myUnregistered) Prohibits(f Access) error { return nil } + +func TestUnregisteredCaveatJSON(t *testing.T) { + RegisterCaveatType(&myUnregistered{}) + c := &myUnregistered{Foo: 1, Bar: map[string]string{"a": "b"}} + cs := NewCaveatSet(c) + b, err := json.Marshal(cs) + assert.NoError(t, err) + unregisterCaveatType(&myUnregistered{}) + + cs2 := NewCaveatSet() + err = json.Unmarshal(b, cs2) + assert.NoError(t, err) + assert.Equal(t, 1, len(cs2.Caveats)) + + uc, ok := cs2.Caveats[0].(*UnregisteredCaveat) + assert.True(t, ok) + assert.Equal(t, cavMyUnregistered, uc.Type) + + assert.Equal(t, + any(map[string]any{ + "bar": map[string]any{ + "a": "b", + }, + "foo": float64(1), + }), + uc.Body, + ) + + _, err = cs2.MarshalMsgpack() + assert.EqualError(t, err, "cannot convert unregistered caveats from JSON to msgpack") + + b2, err := json.Marshal(cs2) + assert.NoError(t, err) + assert.Equal(t, string(b), string(b2)) + + RegisterCaveatType(&myUnregistered{}) + t.Cleanup(func() { unregisterCaveatType(&myUnregistered{}) }) + + cs3 := NewCaveatSet() + err = json.Unmarshal(b2, cs3) + assert.NoError(t, err) + assert.Equal(t, cs, cs3) +} + +func TestUnregisteredCaveatMsgpack(t *testing.T) { + RegisterCaveatType(&myUnregistered{}) + c := &myUnregistered{Foo: 1, Bar: map[string]string{"a": "b"}} + cs := NewCaveatSet(c) + b, err := cs.MarshalMsgpack() + assert.NoError(t, err) + unregisterCaveatType(&myUnregistered{}) + + cs2, err := DecodeCaveats(b) + assert.NoError(t, err) + assert.Equal(t, 1, len(cs2.Caveats)) + + uc, ok := cs2.Caveats[0].(*UnregisteredCaveat) + assert.True(t, ok) + assert.Equal(t, cavMyUnregistered, uc.Type) + + assert.Equal(t, + any([]any{ + map[string]any{ + "a": "b", + }, + int8(1), + }), + uc.Body, + ) + + b2, err := cs2.MarshalMsgpack() + assert.NoError(t, err) + assert.Equal(t, b, b2) + + _, err = json.Marshal(cs2) + assert.EqualError(t, errors.Unwrap(errors.Unwrap(err)), "cannot convert unregistered caveats from msgpack to JSON") + + RegisterCaveatType(&myUnregistered{}) + t.Cleanup(func() { unregisterCaveatType(&myUnregistered{}) }) + + cs3, err := DecodeCaveats(b2) + assert.NoError(t, err) + assert.Equal(t, 1, len(cs3.Caveats)) + mucs := GetCaveats[*myUnregistered](cs3) + assert.Equal(t, 1, len(mucs)) + assert.Equal(t, c, mucs[0]) +} + func ptr[T any](t T) *T { return &t } diff --git a/macaroon_test.go b/macaroon_test.go index da845ea..ff42045 100644 --- a/macaroon_test.go +++ b/macaroon_test.go @@ -27,6 +27,7 @@ func cavExpiry(d time.Duration) Caveat { const ( cavTestParentResource = iota + CavMinUserDefined cavTestChildResource + cavMyUnregistered ) type testCaveatParentResource struct {