diff --git a/cli/any.go b/cli/any.go index 11db5b2..0bb603c 100644 --- a/cli/any.go +++ b/cli/any.go @@ -3,10 +3,15 @@ package cli import ( "fmt" "log/slog" + "strconv" + "strings" pythonv1 "buf.build/gen/go/stealthrocket/dispatch-proto/protocolbuffers/go/dispatch/sdk/python/v1" - "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/emptypb" + "google.golang.org/protobuf/types/known/structpb" + "google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/wrapperspb" ) @@ -15,51 +20,130 @@ func anyString(any *anypb.Any) string { return "nil" } - var s string - var err error - switch any.TypeUrl { - case "buf.build/stealthrocket/dispatch-proto/dispatch.sdk.python.v1.Pickled": - var pickled proto.Message - pickled, err = any.UnmarshalNew() - if err == nil { - if p, ok := pickled.(*pythonv1.Pickled); ok { - s, err = pythonPickleString(p.PickledValue) - } else { - err = fmt.Errorf("invalid pickled message: %T", p) - } + m, err := any.UnmarshalNew() + if err != nil { + return unsupportedAny(any, err) + } + + switch mm := m.(type) { + case *wrapperspb.BytesValue: + // The Python SDK originally wrapped pickled values in a + // wrapperspb.BytesValue. Try to unpickle the bytes first, + // and return literal bytes if they cannot be unpickled. + s, err := pythonPickleString(mm.Value) + if err != nil { + s = fmt.Sprintf("bytes(%s)", truncateBytes(mm.Value)) + } + return s + + case *wrapperspb.Int32Value: + return strconv.FormatInt(int64(mm.Value), 10) + + case *wrapperspb.Int64Value: + return strconv.FormatInt(mm.Value, 10) + + case *wrapperspb.UInt32Value: + return strconv.FormatUint(uint64(mm.Value), 10) + + case *wrapperspb.UInt64Value: + return strconv.FormatUint(mm.Value, 10) + + case *wrapperspb.StringValue: + return fmt.Sprintf("%q", mm.Value) + + case *wrapperspb.BoolValue: + return strconv.FormatBool(mm.Value) + + case *wrapperspb.FloatValue: + return fmt.Sprintf("%v", mm.Value) + + case *wrapperspb.DoubleValue: + return fmt.Sprintf("%v", mm.Value) + + case *emptypb.Empty: + return "empty()" + + case *timestamppb.Timestamp: + return mm.AsTime().String() + + case *durationpb.Duration: + return mm.AsDuration().String() + + case *structpb.Struct: + return structpbStructString(mm) + + case *structpb.ListValue: + return structpbListString(mm) + + case *structpb.Value: + return structpbValueString(mm) + + case *pythonv1.Pickled: + s, err := pythonPickleString(mm.PickledValue) + if err != nil { + return unsupportedAny(any, fmt.Errorf("pickle error: %w", err)) } - case "type.googleapis.com/google.protobuf.BytesValue": - s, err = anyBytesString(any) + return s + default: - // TODO: support unpacking other types of serialized values - err = fmt.Errorf("not implemented: %s", any.TypeUrl) + return unsupportedAny(any, fmt.Errorf("not implemented: %T", m)) } - if err != nil { - slog.Debug("cannot parse input/output value", "error", err) - return fmt.Sprintf("%s(?)", any.TypeUrl) +} + +func structpbStructString(s *structpb.Struct) string { + var b strings.Builder + b.WriteByte('{') + i := 0 + for name, value := range s.Fields { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(fmt.Sprintf("%q", name)) + b.WriteString(": ") + b.WriteString(structpbValueString(value)) + i++ } - return s + b.WriteByte('}') + return b.String() } -func anyBytesString(any *anypb.Any) (string, error) { - m, err := anypb.UnmarshalNew(any, proto.UnmarshalOptions{}) - if err != nil { - return "", err +func structpbListString(s *structpb.ListValue) string { + var b strings.Builder + b.WriteByte('[') + for i, value := range s.Values { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(structpbValueString(value)) } - bv, ok := m.(*wrapperspb.BytesValue) - if !ok { - return "", fmt.Errorf("invalid bytes value: %T", m) + b.WriteByte(']') + return b.String() +} + +func structpbValueString(s *structpb.Value) string { + switch v := s.Kind.(type) { + case *structpb.Value_StructValue: + return structpbStructString(v.StructValue) + case *structpb.Value_ListValue: + return structpbListString(v.ListValue) + case *structpb.Value_BoolValue: + return strconv.FormatBool(v.BoolValue) + case *structpb.Value_NumberValue: + return fmt.Sprintf("%v", v.NumberValue) + case *structpb.Value_StringValue: + return fmt.Sprintf("%q", v.StringValue) + case *structpb.Value_NullValue: + return "null" + default: + panic("unreachable") } - b := bv.Value +} - // The Python SDK originally wrapped pickled values in a - // wrapperspb.BytesValue. Try to unpickle the bytes first, - // and return literal bytes if they cannot be unpickled. - s, err := pythonPickleString(b) +func unsupportedAny(any *anypb.Any, err error) string { if err != nil { - s = string(truncateBytes(b)) + slog.Debug("cannot parse input/output value", "error", err) } - return s, nil + return fmt.Sprintf("%s(?)", any.TypeUrl) } func truncateBytes(b []byte) []byte { diff --git a/cli/any_test.go b/cli/any_test.go new file mode 100644 index 0000000..a30fca6 --- /dev/null +++ b/cli/any_test.go @@ -0,0 +1,159 @@ +package cli + +import ( + "testing" + "time" + + pythonv1 "buf.build/gen/go/stealthrocket/dispatch-proto/protocolbuffers/go/dispatch/sdk/python/v1" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/emptypb" + "google.golang.org/protobuf/types/known/structpb" + "google.golang.org/protobuf/types/known/timestamppb" + "google.golang.org/protobuf/types/known/wrapperspb" +) + +func TestAnyString(t *testing.T) { + for _, test := range []struct { + input *anypb.Any + want string + }{ + { + input: asAny(wrapperspb.Bool(true)), + want: "true", + }, + { + input: asAny(wrapperspb.Int32(-1)), + want: "-1", + }, + { + input: asAny(wrapperspb.Int64(2)), + want: "2", + }, + { + input: asAny(wrapperspb.UInt32(3)), + want: "3", + }, + { + input: asAny(wrapperspb.UInt64(4)), + want: "4", + }, + { + input: asAny(wrapperspb.Float(1.25)), + want: "1.25", + }, + { + input: asAny(wrapperspb.Double(3.14)), + want: "3.14", + }, + { + input: asAny(wrapperspb.String("foo")), + want: `"foo"`, + }, + { + input: asAny(wrapperspb.Bytes([]byte("foobar"))), + want: "bytes(foob...)", + }, + { + input: asAny(timestamppb.New(time.Date(2024, time.June, 25, 10, 56, 11, 1234, time.UTC))), + want: "2024-06-25 10:56:11.000001234 +0000 UTC", + }, + { + input: asAny(durationpb.New(1 * time.Second)), + want: "1s", + }, + { + // $ python3 -c 'import pickle; print(pickle.dumps(1))' + // b'\x80\x04K\x01.' + input: pickled([]byte("\x80\x04K\x01.")), + want: "1", + }, + { + // Legacy way that the Python SDK wrapped pickled values: + input: asAny(wrapperspb.Bytes([]byte("\x80\x04K\x01."))), + want: "1", + }, + { + // $ python3 -c 'import pickle; print(pickle.dumps("bar"))' + // b'\x80\x04\x95\x07\x00\x00\x00\x00\x00\x00\x00\x8c\x03foo\x94.' + input: pickled([]byte("\x80\x04\x95\x07\x00\x00\x00\x00\x00\x00\x00\x8c\x03bar\x94.")), + want: `"bar"`, + }, + { + input: pickled([]byte("!!!invalid!!!")), + want: "buf.build/stealthrocket/dispatch-proto/dispatch.sdk.python.v1.Pickled(?)", + }, + { + input: &anypb.Any{TypeUrl: "com.example/some.Message"}, + want: "com.example/some.Message(?)", + }, + { + input: asAny(&emptypb.Empty{}), + want: "empty()", + }, + { + input: asAny(structpb.NewNullValue()), + want: "null", + }, + { + input: asAny(structpb.NewBoolValue(false)), + want: "false", + }, + { + input: asAny(structpb.NewNumberValue(1111)), + want: "1111", + }, + { + input: asAny(structpb.NewNumberValue(3.14)), + want: "3.14", + }, + { + input: asAny(structpb.NewStringValue("foobar")), + want: `"foobar"`, + }, + { + input: asStructValue([]any{1, true, "abc", nil, map[string]any{}, []any{}}), + want: `[1, true, "abc", null, {}, []]`, + }, + { + input: asStructValue(map[string]any{"foo": []any{"bar", "baz"}}), + want: `{"foo": ["bar", "baz"]}`, + }, + } { + t.Run(test.want, func(*testing.T) { + got := anyString(test.input) + if got != test.want { + t.Errorf("unexpected string: got %v, want %v", got, test.want) + } + }) + } +} + +func asAny(m proto.Message) *anypb.Any { + any, err := anypb.New(m) + if err != nil { + panic(err) + } + return any +} + +func asStructValue(v any) *anypb.Any { + m, err := structpb.NewValue(v) + if err != nil { + panic(err) + } + return asAny(m) +} + +func pickled(b []byte) *anypb.Any { + m := &pythonv1.Pickled{PickledValue: b} + mb, err := proto.Marshal(m) + if err != nil { + panic(err) + } + return &anypb.Any{ + TypeUrl: "buf.build/stealthrocket/dispatch-proto/" + string(m.ProtoReflect().Descriptor().FullName()), + Value: mb, + } +}