Skip to content

Commit

Permalink
Refactor as_json in asyncio client
Browse files Browse the repository at this point in the history
  • Loading branch information
Tabrizian committed Jul 27, 2023
1 parent 3af995f commit 63d89e5
Showing 1 changed file with 17 additions and 66 deletions.
83 changes: 17 additions & 66 deletions src/python/library/tritonclient/grpc/aio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ def __init__(
self._client_stub = service_pb2_grpc.GRPCInferenceServiceStub(self._channel)
self._verbose = verbose

def _return_response(self, response, as_json):
if as_json:
return json.loads(MessageToJson(response, preserving_proto_field_name=True))
else:
return response

async def __aenter__(self):
return self

Expand Down Expand Up @@ -198,12 +204,7 @@ async def get_server_metadata(self, headers=None, as_json=False):
)
if self._verbose:
print(response)
if as_json:
return json.loads(
MessageToJson(response, preserving_proto_field_name=True)
)
else:
return response
return self._return_response(response)

Check failure

Code scanning / CodeQL

Wrong number of arguments in a call Error

Call to
method InferenceServerClient._return_response
with too few arguments; should be no fewer than 2.
except grpc.RpcError as rpc_error:
raise_error_grpc(rpc_error)

Expand All @@ -225,12 +226,7 @@ async def get_model_metadata(
)
if self._verbose:
print(response)
if as_json:
return json.loads(
MessageToJson(response, preserving_proto_field_name=True)
)
else:
return response
return self._return_response(response, as_json)
except grpc.RpcError as rpc_error:
raise_error_grpc(rpc_error)

Expand All @@ -252,12 +248,7 @@ async def get_model_config(
)
if self._verbose:
print(response)
if as_json:
return json.loads(
MessageToJson(response, preserving_proto_field_name=True)
)
else:
return response
return self._return_response(response, as_json)
except grpc.RpcError as rpc_error:
raise_error_grpc(rpc_error)

Expand All @@ -277,12 +268,7 @@ async def get_model_repository_index(self, headers=None, as_json=False):
)
if self._verbose:
print(response)
if as_json:
return json.loads(
MessageToJson(response, preserving_proto_field_name=True)
)
else:
return response
return self._return_response(response, as_json)
except grpc.RpcError as rpc_error:
raise_error_grpc(rpc_error)

Expand Down Expand Up @@ -349,12 +335,7 @@ async def get_inference_statistics(
)
if self._verbose:
print(response)
if as_json:
return json.loads(
MessageToJson(response, preserving_proto_field_name=True)
)
else:
return response
return self._return_response(response, as_json)
except grpc.RpcError as rpc_error:
raise_error_grpc(rpc_error)

Expand Down Expand Up @@ -384,12 +365,7 @@ async def update_trace_settings(
)
if self._verbose:
print(response)
if as_json:
return json.loads(
MessageToJson(response, preserving_proto_field_name=True)
)
else:
return response
return self._return_response(response, as_json)
except grpc.RpcError as rpc_error:
raise_error_grpc(rpc_error)

Expand All @@ -407,12 +383,7 @@ async def get_trace_settings(self, model_name=None, headers=None, as_json=False)
)
if self._verbose:
print(response)
if as_json:
return json.loads(
MessageToJson(response, preserving_proto_field_name=True)
)
else:
return response
return self._return_response(response, as_json)
except grpc.RpcError as rpc_error:
raise_error_grpc(rpc_error)

Expand All @@ -439,12 +410,7 @@ async def update_log_settings(self, settings, headers=None, as_json=False):
)
if self._verbose:
print(response)
if as_json:
return json.loads(
MessageToJson(response, preserving_proto_field_name=True)
)
else:
return response
return self._return_response(response, as_json)
except grpc.RpcError as rpc_error:
raise_error_grpc(rpc_error)

Expand All @@ -460,12 +426,7 @@ async def get_log_settings(self, headers=None, as_json=False):
)
if self._verbose:
print(response)
if as_json:
return json.loads(
MessageToJson(response, preserving_proto_field_name=True)
)
else:
return response
return self._return_response(response, as_json)
except grpc.RpcError as rpc_error:
raise_error_grpc(rpc_error)

Expand All @@ -487,12 +448,7 @@ async def get_system_shared_memory_status(
)
if self._verbose:
print(response)
if as_json:
return json.loads(
MessageToJson(response, preserving_proto_field_name=True)
)
else:
return response
return self._return_response(response, as_json)
except grpc.RpcError as rpc_error:
raise_error_grpc(rpc_error)

Expand Down Expand Up @@ -562,12 +518,7 @@ async def get_cuda_shared_memory_status(
)
if self._verbose:
print(response)
if as_json:
return json.loads(
MessageToJson(response, preserving_proto_field_name=True)
)
else:
return response
return self._return_response(response, as_json)
except grpc.RpcError as rpc_error:
raise_error_grpc(rpc_error)

Expand Down

0 comments on commit 63d89e5

Please sign in to comment.