Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AVRO-3803: Python patched up ipc issues #2319

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 142 additions & 67 deletions lang/py/avro/ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io
import os
import struct
from typing import Any, Dict, NamedTuple, Tuple

import avro.errors
import avro.io
Expand Down Expand Up @@ -185,7 +186,7 @@ def read_call_response(self, message_name, decoder):
the error, serialized per the message's error union schema.
"""
# response metadata
response_metadata = META_READER.read(decoder)
META_READER.read(decoder)

# remote response schema
remote_message_schema = self.remote_protocol.messages.get(message_name)
Expand Down Expand Up @@ -226,6 +227,11 @@ def issue_request(self, call_request, message_name, request_datum):
return self.request(message_name, request_datum)


class AvroHandshake(NamedTuple):
remote_protocol: avro.protocol.Protocol
handshake_response: Dict[str, str]


class Responder:
"""Base class for the server side of a protocol interaction."""

Expand All @@ -247,71 +253,15 @@ def get_protocol_cache(self, hash):
def set_protocol_cache(self, hash, protocol):
self.protocol_cache[hash] = protocol

def respond(self, call_request):
"""
Called by a server to deserialize a request, compute and serialize
a response or error. Compare to 'handle()' in Thrift.
"""
buffer_reader = io.BytesIO(call_request)
buffer_decoder = avro.io.BinaryDecoder(buffer_reader)
buffer_writer = io.BytesIO()
buffer_encoder = avro.io.BinaryEncoder(buffer_writer)
error = None
response_metadata = {}

try:
remote_protocol = self.process_handshake(buffer_decoder, buffer_encoder)
# handshake failure
if remote_protocol is None:
return buffer_writer.getvalue()

# read request using remote protocol
request_metadata = META_READER.read(buffer_decoder)
remote_message_name = buffer_decoder.read_utf8()

# get remote and local request schemas so we can do
# schema resolution (one fine day)
remote_message = remote_protocol.messages.get(remote_message_name)
if remote_message is None:
fail_msg = f"Unknown remote message: {remote_message_name}"
raise avro.errors.AvroException(fail_msg)
local_message = self.local_protocol.messages.get(remote_message_name)
if local_message is None:
fail_msg = f"Unknown local message: {remote_message_name}"
raise avro.errors.AvroException(fail_msg)
writers_schema = remote_message.request
readers_schema = local_message.request
request = self.read_request(writers_schema, readers_schema, buffer_decoder)

# perform server logic
try:
response = self.invoke(local_message, request)
except avro.errors.AvroRemoteException as e:
error = e
except Exception as e:
error = avro.errors.AvroRemoteException(str(e))

# write response using local protocol
META_WRITER.write(response_metadata, buffer_encoder)
buffer_encoder.write_boolean(error is not None)
if error is None:
writers_schema = local_message.response
self.write_response(writers_schema, response, buffer_encoder)
else:
writers_schema = local_message.errors
self.write_error(writers_schema, error, buffer_encoder)
except schema.AvroException as e:
error = avro.errors.AvroRemoteException(str(e))
buffer_encoder = avro.io.BinaryEncoder(io.BytesIO())
META_WRITER.write(response_metadata, buffer_encoder)
buffer_encoder.write_boolean(True)
self.write_error(SYSTEM_ERROR_SCHEMA, error, buffer_encoder)
return buffer_writer.getvalue()

def process_handshake(self, decoder, encoder):
def process_handshake(
self,
decoder: avro.io.BinaryDecoder,
) -> AvroHandshake:
handshake_request = HANDSHAKE_RESPONDER_READER.read(decoder)
handshake_response = {}
handshake_response: Dict[str, str] = {}

if not isinstance(handshake_request, dict):
raise avro.errors.AvroTypeException(f"invalid handshake request - {handshake_request}")
# determine the remote protocol
client_hash = handshake_request.get("clientHash")
client_protocol = handshake_request.get("clientProtocol")
Expand All @@ -337,8 +287,133 @@ def process_handshake(self, decoder, encoder):
handshake_response["serverProtocol"] = str(self.local_protocol)
handshake_response["serverHash"] = self.local_hash

HANDSHAKE_RESPONDER_WRITER.write(handshake_response, encoder)
return remote_protocol
return AvroHandshake(remote_protocol, handshake_response)

def extract_messages_from_handshake(
self,
handshake: AvroHandshake,
buffer_decoder: avro.io.BinaryDecoder,
) -> Tuple[avro.protocol.Message, avro.protocol.Message]:
remote_protocol, handshake_response = handshake

# read request using remote protocol
META_READER.read(buffer_decoder)
remote_message_name = buffer_decoder.read_utf8()

# get remote and local request schemas so we can do
# schema resolution (one fine day)
if remote_protocol.messages is None:
raise avro.errors.AvroTypeException("Missing messages in remote_protocol")
remote_message = remote_protocol.messages.get(remote_message_name)
if remote_message is None:
fail_msg = f"Unknown remote message: {remote_message_name}"
raise avro.errors.AvroException(fail_msg)
if self.local_protocol.messages is None:
raise avro.errors.AvroTypeException("Missing messages in local_protocol")
local_message = self.local_protocol.messages.get(remote_message_name)
if local_message is None:
fail_msg = f"Unknown local message: {remote_message_name}"
raise avro.errors.AvroException(fail_msg)

return local_message, remote_message

def handle_request(
self,
local_message: avro.protocol.Message,
remote_message: avro.protocol.Message,
buffer_decoder: avro.io.BinaryDecoder,
) -> Any:
writers_schema = remote_message.request
readers_schema = local_message.request
request = self.read_request(writers_schema, readers_schema, buffer_decoder)
response = self.invoke(local_message, request)

return response

def handle_response(
self,
response_metadata: dict,
handshake: AvroHandshake,
local_message: avro.protocol.Message,
response: Any,
) -> bytes:
buffer_writer = io.BytesIO()
buffer_encoder = avro.io.BinaryEncoder(buffer_writer)
HANDSHAKE_RESPONDER_WRITER.write(handshake.handshake_response, buffer_encoder)
META_WRITER.write(response_metadata, buffer_encoder)
buffer_encoder.write_boolean(False)
writers_schema = local_message.response
self.write_response(writers_schema, response, buffer_encoder)
return buffer_writer.getvalue()

def handle_error(
self,
response_metadata: dict,
writers_schema: avro.schema.Schema,
handshake: AvroHandshake,
error: Any,
):
buffer_writer = io.BytesIO()
buffer_encoder = avro.io.BinaryEncoder(buffer_writer)
HANDSHAKE_RESPONDER_WRITER.write(handshake.handshake_response, buffer_encoder)
META_WRITER.write(response_metadata, buffer_encoder)
buffer_encoder.write_boolean(True)
self.write_error(writers_schema, error, buffer_encoder)
return buffer_writer.getvalue()

def respond(self, call_request: bytes):
"""
Called by a server to deserialize a request, compute and serialize
a response or error. Compare to 'handle()' in Thrift.
"""
buffer_reader = io.BytesIO(call_request)
buffer_decoder = avro.io.BinaryDecoder(buffer_reader)
response_metadata: dict = {}

handshake = self.process_handshake(buffer_decoder)

try:
# handshake failure
local_message, remote_message = self.extract_messages_from_handshake(
handshake,
buffer_decoder,
)

# perform server logic
try:
response = self.handle_request(
local_message=local_message,
remote_message=remote_message,
buffer_decoder=buffer_decoder,
)
except avro.errors.AvroRemoteException as err:
return self.handle_error(
response_metadata=response_metadata,
writers_schema=local_message.errors,
handshake=handshake,
error=err,
)
except Exception as e:
return self.handle_error(
response_metadata=response_metadata,
writers_schema=local_message.errors,
handshake=handshake,
error=avro.errors.AvroRemoteException(str(e)),
)

return self.handle_response(
response_metadata=response_metadata,
local_message=local_message,
response=response,
handshake=handshake,
)
except avro.errors.AvroException as e:
return self.handle_error(
response_metadata=response_metadata,
writers_schema=SYSTEM_ERROR_SCHEMA,
handshake=handshake,
error=avro.errors.AvroRemoteException(str(e)),
)

def invoke(self, local_message, request):
"""
Expand Down Expand Up @@ -382,7 +457,7 @@ def read_framed_message(self):
return b"".join(message)
while buffer.tell() < buffer_length:
chunk = self.reader.read(buffer_length - buffer.tell())
if chunk == "":
if chunk == b"":
kojiromike marked this conversation as resolved.
Show resolved Hide resolved
raise avro.errors.ConnectionClosedException("Reader read 0 bytes.")
buffer.write(chunk)
message.append(buffer.getvalue())
Expand Down
27 changes: 22 additions & 5 deletions lang/py/avro/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,20 +873,37 @@ def __init__(
if schema_type == "request":
Schema.__init__(self, schema_type, other_props)
else:
NamedSchema.__init__(self, schema_type, name, namespace, names, other_props, validate_names=validate_names)
NamedSchema.__init__(
self,
schema_type,
name,
namespace,
names,
other_props,
validate_names=validate_names,
)

names = names or Names(validate_names=self.validate_names)
if schema_type == "record":
if (schema_type == "record") or (schema_type == "error"):
asosnovsky marked this conversation as resolved.
Show resolved Hide resolved
old_default = names.default_namespace
names.default_namespace = Name(name, namespace, names.default_namespace, validate_name=validate_names).space
names.default_namespace = Name(
name,
namespace,
names.default_namespace,
validate_name=validate_names,
).space

# Add class members
field_objects = RecordSchema.make_field_objects(fields, names, validate_names=validate_names)
field_objects = RecordSchema.make_field_objects(
fields,
names,
validate_names=validate_names,
)
self.set_prop("fields", field_objects)
if doc is not None:
self.set_prop("doc", doc)

if schema_type == "record":
if (schema_type == "record") or (schema_type == "error"):
asosnovsky marked this conversation as resolved.
Show resolved Hide resolved
names.default_namespace = old_default

# read-only properties
Expand Down
8 changes: 8 additions & 0 deletions lang/py/avro/test/test_ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
servers yet available.
"""

import io
import unittest

import avro.errors
import avro.ipc


Expand All @@ -38,6 +40,12 @@ def test_server_with_path(self):
client_with_default_path = avro.ipc.HTTPTransceiver("apache.org", 80)
self.assertEqual("/", client_with_default_path.req_resource)

def test_empty_reader(self):
response_reader = avro.ipc.FramedReader(io.BytesIO(b"Bad Response"))
with self.assertRaises(avro.errors.ConnectionClosedException) as cm:
response_reader.read_framed_message()
assert str(cm.exception) == "Reader read 0 bytes."
kojiromike marked this conversation as resolved.
Show resolved Hide resolved


if __name__ == "__main__": # pragma: no coverage
unittest.main()