Skip to content

Commit

Permalink
Merge pull request #179 from dispatchrun/use-proto-for-serialization-…
Browse files Browse the repository at this point in the history
…if-available

Avoid pickling primitive values and proto messages
  • Loading branch information
chriso authored Jun 26, 2024
2 parents bb6ec79 + 6b283a4 commit 90b97d9
Show file tree
Hide file tree
Showing 5 changed files with 293 additions and 59 deletions.
170 changes: 170 additions & 0 deletions src/dispatch/any.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from __future__ import annotations

import pickle
from datetime import datetime, timedelta, timezone
from typing import Any

import google.protobuf.any_pb2
import google.protobuf.duration_pb2
import google.protobuf.empty_pb2
import google.protobuf.message
import google.protobuf.struct_pb2
import google.protobuf.timestamp_pb2
import google.protobuf.wrappers_pb2
from google.protobuf import descriptor_pool, message_factory

from dispatch.sdk.python.v1 import pickled_pb2 as pickled_pb

# Inclusive bounds of a signed 64-bit integer. Integers outside this range
# cannot be carried by an Int64Value wrapper.
INT64_MIN = -(2**63)
INT64_MAX = 2**63 - 1


def marshal_any(value: Any) -> google.protobuf.any_pb2.Any:
    """Serialize a Python value into a ``google.protobuf.Any``.

    Values are mapped to protobuf messages where a natural mapping exists,
    so that they do not have to be pickled:

    * ``None`` -> ``Empty``
    * ``bool`` -> ``BoolValue``
    * ``int`` within the signed 64-bit range -> ``Int64Value``
    * ``float`` -> ``DoubleValue``
    * ``str`` -> ``StringValue``
    * ``bytes`` -> ``BytesValue``
    * ``datetime`` -> ``Timestamp``
    * ``timedelta`` -> ``Duration``
    * ``list`` / ``dict`` -> ``struct_pb2.Value`` when ``as_struct_value``
      accepts the contents
    * protobuf messages are packed as-is

    Anything else is pickled and wrapped in a ``Pickled`` message.

    Args:
        value: The Python value to serialize.

    Returns:
        An ``Any`` message containing the packed value.
    """
    if value is None:
        value = google.protobuf.empty_pb2.Empty()
    elif isinstance(value, bool):
        # bool is a subclass of int, so it has to be tested before int.
        value = google.protobuf.wrappers_pb2.BoolValue(value=value)
    elif isinstance(value, int) and INT64_MIN <= value <= INT64_MAX:
        # To keep things simple, serialize all integers as int64 on the wire.
        # For larger integers, fall through and use pickle.
        value = google.protobuf.wrappers_pb2.Int64Value(value=value)
    elif isinstance(value, float):
        value = google.protobuf.wrappers_pb2.DoubleValue(value=value)
    elif isinstance(value, str):
        value = google.protobuf.wrappers_pb2.StringValue(value=value)
    elif isinstance(value, bytes):
        value = google.protobuf.wrappers_pb2.BytesValue(value=value)
    elif isinstance(value, datetime):
        # Note: datetime only supports microsecond granularity.
        #
        # Use integer arithmetic on the offset from the Unix epoch rather than
        # int(value.timestamp()): truncating the float toward zero yields the
        # wrong second for instants before 1970 that have a fractional part,
        # and float rounding can bump the second for far-future instants.
        # timedelta normalization guarantees 0 <= seconds < 86400 and
        # 0 <= microseconds < 10**6, so nanos is always non-negative and
        # counts forward from `seconds`.
        # astimezone() interprets naive datetimes as local time, which is the
        # same interpretation datetime.timestamp() applies.
        epoch = datetime(1970, 1, 1, tzinfo=timezone.utc)
        since_epoch = value.astimezone(timezone.utc) - epoch
        seconds = since_epoch.days * 86400 + since_epoch.seconds
        nanos = since_epoch.microseconds * 1000
        value = google.protobuf.timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos)
    elif isinstance(value, timedelta):
        # Note: timedelta only supports microsecond granularity.
        #
        # timedelta stores negative durations as negative days plus
        # non-negative seconds/microseconds, so `value.microseconds` cannot be
        # combined with a truncated total_seconds(). Work from the exact
        # microsecond count and give seconds and nanos the same sign.
        total_micros = value // timedelta(microseconds=1)
        sign = -1 if total_micros < 0 else 1
        whole_seconds, micros = divmod(abs(total_micros), 1_000_000)
        value = google.protobuf.duration_pb2.Duration(
            seconds=sign * whole_seconds, nanos=sign * micros * 1000
        )

    if isinstance(value, (list, dict)):
        try:
            value = as_struct_value(value)
        except (ValueError, OverflowError):
            # The container holds something a struct value cannot carry
            # (including integers too large to convert to float); fall
            # through and pickle the original object instead.
            pass

    if not isinstance(value, google.protobuf.message.Message):
        value = pickled_pb.Pickled(pickled_value=pickle.dumps(value))

    packed = google.protobuf.any_pb2.Any()
    if value.DESCRIPTOR.full_name.startswith("dispatch.sdk."):
        # Messages from the dispatch SDK are packed with a custom type URL
        # prefix; everything else uses the default prefix.
        packed.Pack(value, type_url_prefix="buf.build/stealthrocket/dispatch-proto/")
    else:
        packed.Pack(value)

    return packed


def unmarshal_any(any: google.protobuf.any_pb2.Any) -> Any:
    """Deserialize a ``google.protobuf.Any`` into a Python value.

    The packed message is resolved through the default descriptor pool and
    then converted back to a native Python value where a mapping exists:
    ``Pickled`` containers are unpickled, ``Empty`` becomes ``None``, wrapper
    messages yield their ``value``, ``Timestamp`` becomes a UTC ``datetime``,
    ``Duration`` becomes a ``timedelta`` and ``struct_pb2.Value`` is converted
    with ``from_struct_value``. Any other message is returned unchanged.

    Args:
        any: The ``Any`` message to unpack.

    Returns:
        The decoded Python value, or the unpacked protobuf message when no
        native mapping applies. An ``Any`` with neither a type URL nor a
        payload decodes to ``None``.
    """
    # An Any that was never populated has no type name to look up in the
    # descriptor pool; treat it as the absence of a value.
    if not any.type_url and not any.value:
        return None

    pool = descriptor_pool.Default()
    msg_descriptor = pool.FindMessageTypeByName(any.TypeName())
    proto = message_factory.GetMessageClass(msg_descriptor)()
    any.Unpack(proto)

    wrappers = google.protobuf.wrappers_pb2

    if isinstance(proto, pickled_pb.Pickled):
        # NOTE(security): pickle.loads can execute arbitrary code. Only
        # payloads from trusted peers should ever reach this point.
        return pickle.loads(proto.pickled_value)

    if isinstance(proto, google.protobuf.empty_pb2.Empty):
        return None

    if isinstance(proto, wrappers.BytesValue):
        # NOTE(review): BytesValue is both the legacy container for pickled
        # values and the encoding of literal bytes, and the two cannot be
        # told apart on the wire. Literal bytes that happen to form a valid
        # pickle are therefore unpickled rather than returned verbatim.
        try:
            # Assume it's the legacy container for pickled values.
            return pickle.loads(proto.value)
        except Exception:
            # Otherwise, return the literal bytes.
            return proto.value

    scalar_wrappers = (
        wrappers.BoolValue,
        wrappers.Int32Value,
        wrappers.Int64Value,
        wrappers.UInt32Value,
        wrappers.UInt64Value,
        wrappers.FloatValue,
        wrappers.DoubleValue,
        wrappers.StringValue,
    )
    if isinstance(proto, scalar_wrappers):
        return proto.value

    if isinstance(proto, google.protobuf.timestamp_pb2.Timestamp):
        return proto.ToDatetime(tzinfo=timezone.utc)

    if isinstance(proto, google.protobuf.duration_pb2.Duration):
        return proto.ToTimedelta()

    if isinstance(proto, google.protobuf.struct_pb2.Value):
        return from_struct_value(proto)

    return proto


def as_struct_value(value: Any) -> google.protobuf.struct_pb2.Value:
    """Convert a JSON-like Python value into a ``struct_pb2.Value``.

    Supported inputs are ``None``, ``bool``, ``int``, ``float``, ``str``,
    lists of supported values and dicts with string keys and supported
    values. Numbers are stored in the ``number_value`` field as floats, so
    integers come back as floats when the value is decoded.

    Args:
        value: The Python value to convert.

    Returns:
        The equivalent ``struct_pb2.Value`` message.

    Raises:
        ValueError: If the value (or anything nested inside it) has an
            unsupported type, if a dict key is not a string, or if an
            integer cannot be represented exactly as a float.
    """
    struct_pb = google.protobuf.struct_pb2

    if value is None:
        return struct_pb.Value(null_value=struct_pb.NullValue.NULL_VALUE)

    # bool is a subclass of int, so it has to be tested before int.
    if isinstance(value, bool):
        return struct_pb.Value(bool_value=value)

    if isinstance(value, int):
        # number_value is a double. Reject integers that would be altered by
        # the conversion (or that overflow it) so callers can pick a lossless
        # encoding instead of silently corrupting the value.
        try:
            number = float(value)
        except OverflowError:
            raise ValueError("unsupported value") from None
        if int(number) != value:
            raise ValueError("unsupported value")
        return struct_pb.Value(number_value=number)

    if isinstance(value, float):
        return struct_pb.Value(number_value=float(value))

    if isinstance(value, str):
        return struct_pb.Value(string_value=value)

    if isinstance(value, list):
        list_value = struct_pb.ListValue(values=[as_struct_value(v) for v in value])
        return struct_pb.Value(list_value=list_value)

    if isinstance(value, dict):
        # Struct field names must be strings; validate before converting any
        # of the values.
        if not all(isinstance(key, str) for key in value):
            raise ValueError("unsupported object key")

        struct_value = struct_pb.Struct(
            fields={k: as_struct_value(v) for k, v in value.items()}
        )
        return struct_pb.Value(struct_value=struct_value)

    raise ValueError("unsupported value")


def from_struct_value(value: google.protobuf.struct_pb2.Value) -> Any:
    """Convert a ``struct_pb2.Value`` back into a plain Python value.

    Null becomes ``None``; bool, number and string values are returned
    directly; list values become lists and struct values become dicts, with
    nested values converted recursively.

    Raises:
        RuntimeError: If none of the value's fields is set.
    """
    if value.HasField("null_value"):
        return None

    # Scalar kinds are returned verbatim from the field that is set.
    for scalar_field in ("bool_value", "number_value", "string_value"):
        if value.HasField(scalar_field):
            return getattr(value, scalar_field)

    if value.HasField("list_value"):
        return [from_struct_value(item) for item in value.list_value.values]

    if value.HasField("struct_value"):
        entries = value.struct_value.fields.items()
        return {name: from_struct_value(member) for name, member in entries}

    raise RuntimeError(f"invalid struct_pb2.Value: {value}")
64 changes: 9 additions & 55 deletions src/dispatch/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import tblib # type: ignore[import-untyped]
from google.protobuf import descriptor_pool, duration_pb2, message_factory

from dispatch.any import marshal_any, unmarshal_any
from dispatch.error import IncompatibleStateError, InvalidArgumentError
from dispatch.id import DispatchID
from dispatch.sdk.python.v1 import pickled_pb2 as pickled_pb
Expand Down Expand Up @@ -78,11 +79,11 @@ def __init__(self, req: function_pb.RunRequest):

self._has_input = req.HasField("input")
if self._has_input:
self._input = _pb_any_unpack(req.input)
self._input = unmarshal_any(req.input)
else:
if req.poll_result.coroutine_state:
raise IncompatibleStateError # coroutine_state is deprecated
self._coroutine_state = _any_unpickle(req.poll_result.typed_coroutine_state)
self._coroutine_state = unmarshal_any(req.poll_result.typed_coroutine_state)
self._call_results = [
CallResult._from_proto(r) for r in req.poll_result.results
]
Expand Down Expand Up @@ -141,7 +142,7 @@ def from_input_arguments(cls, function: str, *args, **kwargs):
return Input(
req=function_pb.RunRequest(
function=function,
input=_pb_any_pickle(input),
input=marshal_any(input),
)
)

Expand All @@ -157,7 +158,7 @@ def from_poll_results(
req=function_pb.RunRequest(
function=function,
poll_result=poll_pb.PollResult(
typed_coroutine_state=_pb_any_pickle(coroutine_state),
typed_coroutine_state=marshal_any(coroutine_state),
results=[result._as_proto() for result in call_results],
error=error._as_proto() if error else None,
),
Expand Down Expand Up @@ -241,7 +242,7 @@ def poll(
else None
)
poll = poll_pb.Poll(
typed_coroutine_state=_pb_any_pickle(coroutine_state),
typed_coroutine_state=marshal_any(coroutine_state),
min_results=min_results,
max_results=max_results,
max_wait=max_wait,
Expand Down Expand Up @@ -279,7 +280,7 @@ class Call:
correlation_id: Optional[int] = None

def _as_proto(self) -> call_pb.Call:
input_bytes = _pb_any_pickle(self.input)
input_bytes = marshal_any(self.input)
return call_pb.Call(
correlation_id=self.correlation_id,
endpoint=self.endpoint,
Expand All @@ -301,7 +302,7 @@ def _as_proto(self) -> call_pb.CallResult:
output_any = None
error_proto = None
if self.output is not None:
output_any = _pb_any_pickle(self.output)
output_any = marshal_any(self.output)
if self.error is not None:
error_proto = self.error._as_proto()

Expand All @@ -317,7 +318,7 @@ def _from_proto(cls, proto: call_pb.CallResult) -> CallResult:
output = None
error = None
if proto.HasField("output"):
output = _any_unpickle(proto.output)
output = unmarshal_any(proto.output)
if proto.HasField("error"):
error = Error._from_proto(proto.error)

Expand Down Expand Up @@ -438,50 +439,3 @@ def _as_proto(self) -> error_pb.Error:
return error_pb.Error(
type=self.type, message=self.message, value=value, traceback=self.traceback
)


def _any_unpickle(any: google.protobuf.any_pb2.Any) -> Any:
    """Unpickle the value carried by an ``Any`` container.

    Accepts the ``Pickled`` container and the legacy ``BytesValue``
    container. An ``Any`` with neither a type URL nor a payload yields
    ``None``.

    Raises:
        InvalidArgumentError: If the ``Any`` holds any other message type.
    """
    # (container class, name of the field holding the pickled bytes)
    containers = (
        (pickled_pb.Pickled, "pickled_value"),
        (google.protobuf.wrappers_pb2.BytesValue, "value"),  # legacy container
    )
    for container_class, payload_field in containers:
        if any.Is(container_class.DESCRIPTOR):
            container = container_class()
            any.Unpack(container)
            return pickle.loads(getattr(container, payload_field))

    if not (any.type_url or any.value):
        return None

    raise InvalidArgumentError(f"unsupported pickled value container: {any.type_url}")


def _pb_any_pickle(value: Any) -> google.protobuf.any_pb2.Any:
    """Pickle ``value`` and pack it into an ``Any`` as a ``Pickled`` message."""
    container = google.protobuf.any_pb2.Any()
    container.Pack(
        pickled_pb.Pickled(pickled_value=pickle.dumps(value)),
        type_url_prefix="buf.build/stealthrocket/dispatch-proto/",
    )
    return container


def _pb_any_unpack(any: google.protobuf.any_pb2.Any) -> Any:
    """Unpack an ``Any`` into a Python value or a protobuf message.

    ``Pickled`` containers are unpickled. ``BytesValue`` payloads are
    unpickled when possible and otherwise returned as raw bytes. Any other
    type is resolved through the default descriptor pool and returned as the
    unpacked message.
    """
    if any.Is(pickled_pb.Pickled.DESCRIPTOR):
        pickled = pickled_pb.Pickled()
        any.Unpack(pickled)
        return pickle.loads(pickled.pickled_value)

    if any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR):
        wrapper = google.protobuf.wrappers_pb2.BytesValue()
        any.Unpack(wrapper)
        try:
            # Assume it's the legacy container for pickled values.
            return pickle.loads(wrapper.value)
        except Exception:
            # Otherwise, return the literal bytes.
            return wrapper.value

    # Generic path: instantiate the message class named by the type URL.
    descriptor = descriptor_pool.Default().FindMessageTypeByName(any.TypeName())
    message = message_factory.GetMessageClass(descriptor)()
    any.Unpack(message)
    return message
Loading

0 comments on commit 90b97d9

Please sign in to comment.