Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AVRO-3803: Python patched up ipc issues #2319

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 143 additions & 68 deletions lang/py/avro/ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io
import os
import struct
from typing import Any, Dict, NamedTuple, Tuple

import avro.errors
import avro.io
Expand Down Expand Up @@ -185,7 +186,7 @@ def read_call_response(self, message_name, decoder):
the error, serialized per the message's error union schema.
"""
# response metadata
response_metadata = META_READER.read(decoder)
META_READER.read(decoder)

# remote response schema
remote_message_schema = self.remote_protocol.messages.get(message_name)
Expand Down Expand Up @@ -226,6 +227,11 @@ def issue_request(self, call_request, message_name, request_datum):
return self.request(message_name, request_datum)


class AvroHandshake(NamedTuple):
remote_protocol: avro.protocol.Protocol
handshake_response: Dict[str, str]


class Responder:
"""Base class for the server side of a protocol interaction."""

Expand All @@ -247,71 +253,15 @@ def get_protocol_cache(self, hash):
def set_protocol_cache(self, hash, protocol):
self.protocol_cache[hash] = protocol

def respond(self, call_request):
"""
Called by a server to deserialize a request, compute and serialize
a response or error. Compare to 'handle()' in Thrift.
"""
buffer_reader = io.BytesIO(call_request)
buffer_decoder = avro.io.BinaryDecoder(buffer_reader)
buffer_writer = io.BytesIO()
buffer_encoder = avro.io.BinaryEncoder(buffer_writer)
error = None
response_metadata = {}

try:
remote_protocol = self.process_handshake(buffer_decoder, buffer_encoder)
# handshake failure
if remote_protocol is None:
return buffer_writer.getvalue()

# read request using remote protocol
request_metadata = META_READER.read(buffer_decoder)
remote_message_name = buffer_decoder.read_utf8()

# get remote and local request schemas so we can do
# schema resolution (one fine day)
remote_message = remote_protocol.messages.get(remote_message_name)
if remote_message is None:
fail_msg = f"Unknown remote message: {remote_message_name}"
raise avro.errors.AvroException(fail_msg)
local_message = self.local_protocol.messages.get(remote_message_name)
if local_message is None:
fail_msg = f"Unknown local message: {remote_message_name}"
raise avro.errors.AvroException(fail_msg)
writers_schema = remote_message.request
readers_schema = local_message.request
request = self.read_request(writers_schema, readers_schema, buffer_decoder)

# perform server logic
try:
response = self.invoke(local_message, request)
except avro.errors.AvroRemoteException as e:
error = e
except Exception as e:
error = avro.errors.AvroRemoteException(str(e))

# write response using local protocol
META_WRITER.write(response_metadata, buffer_encoder)
buffer_encoder.write_boolean(error is not None)
if error is None:
writers_schema = local_message.response
self.write_response(writers_schema, response, buffer_encoder)
else:
writers_schema = local_message.errors
self.write_error(writers_schema, error, buffer_encoder)
except schema.AvroException as e:
error = avro.errors.AvroRemoteException(str(e))
buffer_encoder = avro.io.BinaryEncoder(io.BytesIO())
META_WRITER.write(response_metadata, buffer_encoder)
buffer_encoder.write_boolean(True)
self.write_error(SYSTEM_ERROR_SCHEMA, error, buffer_encoder)
return buffer_writer.getvalue()

def process_handshake(self, decoder, encoder):
def process_handshake(
self,
decoder: avro.io.BinaryDecoder,
) -> AvroHandshake:
handshake_request = HANDSHAKE_RESPONDER_READER.read(decoder)
handshake_response = {}
handshake_response: Dict[str, str] = {}

if not isinstance(handshake_request, dict):
raise avro.errors.AvroTypeException(f"invalid handshake request - {handshake_request}")
# determine the remote protocol
client_hash = handshake_request.get("clientHash")
client_protocol = handshake_request.get("clientProtocol")
Expand All @@ -337,8 +287,133 @@ def process_handshake(self, decoder, encoder):
handshake_response["serverProtocol"] = str(self.local_protocol)
handshake_response["serverHash"] = self.local_hash

HANDSHAKE_RESPONDER_WRITER.write(handshake_response, encoder)
return remote_protocol
return AvroHandshake(remote_protocol, handshake_response)

def extract_messages_from_handshake(
self,
handshake: AvroHandshake,
buffer_decoder: avro.io.BinaryDecoder,
) -> Tuple[avro.protocol.Message, avro.protocol.Message]:
remote_protocol, handshake_response = handshake

# read request using remote protocol
META_READER.read(buffer_decoder)
remote_message_name = buffer_decoder.read_utf8()

# get remote and local request schemas so we can do
# schema resolution (one fine day)
if remote_protocol.messages is None:
raise avro.errors.AvroTypeException("Missing messages in remote_protocol")
remote_message = remote_protocol.messages.get(remote_message_name)
if remote_message is None:
fail_msg = f"Unknown remote message: {remote_message_name}"
raise avro.errors.AvroException(fail_msg)
if self.local_protocol.messages is None:
raise avro.errors.AvroTypeException("Missing messages in local_protocol")
local_message = self.local_protocol.messages.get(remote_message_name)
if local_message is None:
fail_msg = f"Unknown local message: {remote_message_name}"
raise avro.errors.AvroException(fail_msg)

return local_message, remote_message

def handle_request(
self,
local_message: avro.protocol.Message,
remote_message: avro.protocol.Message,
buffer_decoder: avro.io.BinaryDecoder,
) -> Any:
writers_schema = remote_message.request
readers_schema = local_message.request
request = self.read_request(writers_schema, readers_schema, buffer_decoder)
response = self.invoke(local_message, request)

return response

def handle_response(
self,
response_metadata: dict,
handshake: AvroHandshake,
local_message: avro.protocol.Message,
response: Any,
) -> bytes:
buffer_writer = io.BytesIO()
buffer_encoder = avro.io.BinaryEncoder(buffer_writer)
HANDSHAKE_RESPONDER_WRITER.write(handshake.handshake_response, buffer_encoder)
META_WRITER.write(response_metadata, buffer_encoder)
buffer_encoder.write_boolean(False)
writers_schema = local_message.response
self.write_response(writers_schema, response, buffer_encoder)
return buffer_writer.getvalue()

def handle_error(
self,
response_metadata: dict,
writers_schema: avro.schema.Schema,
handshake: AvroHandshake,
error: Any,
):
buffer_writer = io.BytesIO()
buffer_encoder = avro.io.BinaryEncoder(buffer_writer)
HANDSHAKE_RESPONDER_WRITER.write(handshake.handshake_response, buffer_encoder)
META_WRITER.write(response_metadata, buffer_encoder)
buffer_encoder.write_boolean(True)
self.write_error(writers_schema, error, buffer_encoder)
return buffer_writer.getvalue()

def respond(self, call_request: bytes):
"""
Called by a server to deserialize a request, compute and serialize
a response or error. Compare to 'handle()' in Thrift.
"""
buffer_reader = io.BytesIO(call_request)
buffer_decoder = avro.io.BinaryDecoder(buffer_reader)
response_metadata: dict = {}

handshake = self.process_handshake(buffer_decoder)

try:
# handshake failure
local_message, remote_message = self.extract_messages_from_handshake(
handshake,
buffer_decoder,
)

# perform server logic
try:
response = self.handle_request(
local_message=local_message,
remote_message=remote_message,
buffer_decoder=buffer_decoder,
)
except avro.errors.AvroRemoteException as err:
return self.handle_error(
response_metadata=response_metadata,
writers_schema=local_message.errors,
handshake=handshake,
error=err,
)
except Exception as e:
return self.handle_error(
response_metadata=response_metadata,
writers_schema=local_message.errors,
handshake=handshake,
error=avro.errors.AvroRemoteException(str(e)),
)

return self.handle_response(
response_metadata=response_metadata,
local_message=local_message,
response=response,
handshake=handshake,
)
except avro.errors.AvroException as e:
return self.handle_error(
response_metadata=response_metadata,
writers_schema=SYSTEM_ERROR_SCHEMA,
handshake=handshake,
error=avro.errors.AvroRemoteException(str(e)),
)

def invoke(self, local_message, request):
"""
Expand Down Expand Up @@ -382,14 +457,14 @@ def read_framed_message(self):
return b"".join(message)
while buffer.tell() < buffer_length:
chunk = self.reader.read(buffer_length - buffer.tell())
if chunk == "":
if chunk == b"":
kojiromike marked this conversation as resolved.
Show resolved Hide resolved
raise avro.errors.ConnectionClosedException("Reader read 0 bytes.")
buffer.write(chunk)
message.append(buffer.getvalue())

def _read_buffer_length(self):
read = self.reader.read(BUFFER_HEADER_LENGTH)
if read == "":
if read == b"":
raise avro.errors.ConnectionClosedException("Reader read 0 bytes.")
return BIG_ENDIAN_INT_STRUCT.unpack(read)[0]

Expand Down
27 changes: 22 additions & 5 deletions lang/py/avro/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,7 @@

def match(self, writer):
"""Return True if the current schema (as reader) matches the other schema.

Check warning

Code scanning / CodeQL

`__eq__` not overridden when adding attributes Warning

The class 'UnionSchema' does not override
'__eq__'
, but adds the new attribute
_schemas
.
The class 'UnionSchema' does not override
'__eq__'
, but adds the new attribute
_schemas
.
@arg writer: the schema to match against
@return bool
"""
Expand Down Expand Up @@ -873,20 +873,37 @@
if schema_type == "request":
Schema.__init__(self, schema_type, other_props)
else:
NamedSchema.__init__(self, schema_type, name, namespace, names, other_props, validate_names=validate_names)
NamedSchema.__init__(
self,
schema_type,
name,
namespace,
names,
other_props,
validate_names=validate_names,
)

names = names or Names(validate_names=self.validate_names)
if schema_type == "record":
if schema_type in ("record", "error"):
old_default = names.default_namespace
names.default_namespace = Name(name, namespace, names.default_namespace, validate_name=validate_names).space
names.default_namespace = Name(
name,
namespace,
names.default_namespace,
validate_name=validate_names,
).space

# Add class members
field_objects = RecordSchema.make_field_objects(fields, names, validate_names=validate_names)
field_objects = RecordSchema.make_field_objects(
fields,
names,
validate_names=validate_names,
)
self.set_prop("fields", field_objects)
if doc is not None:
self.set_prop("doc", doc)

if schema_type == "record":
if schema_type in ("record", "error"):
names.default_namespace = old_default

# read-only properties
Expand Down
8 changes: 8 additions & 0 deletions lang/py/avro/test/test_ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
servers yet available.
"""

import io
import unittest

import avro.errors
import avro.ipc


Expand All @@ -38,6 +40,12 @@ def test_server_with_path(self):
client_with_default_path = avro.ipc.HTTPTransceiver("apache.org", 80)
self.assertEqual("/", client_with_default_path.req_resource)

def test_empty_reader(self):
response_reader = avro.ipc.FramedReader(io.BytesIO(b"Bad Response"))
with self.assertRaises(avro.errors.ConnectionClosedException) as cm:
response_reader.read_framed_message()
assert str(cm.exception) == "Reader read 0 bytes."
kojiromike marked this conversation as resolved.
Show resolved Hide resolved


if __name__ == "__main__": # pragma: no coverage
unittest.main()