Skip to content

Commit

Permalink
Second attempt at fixing model_repo issue
Browse files Browse the repository at this point in the history
  • Loading branch information
nv-braf committed Aug 11, 2023
1 parent f640e10 commit 6fccc79
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
4 changes: 1 addition & 3 deletions model_analyzer/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,7 @@ def get_triton_handles(config, gpus):

client = get_client_handle(config)
fail_if_server_already_running(client, config)
server = TritonServerFactory.get_server_handle(
config, gpus, use_model_repository=bool(config.model_repository)
)
server = TritonServerFactory.get_server_handle(config, gpus)

return client, server

Expand Down
6 changes: 4 additions & 2 deletions model_analyzer/triton/server/server_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,11 @@ def get_server_handle(config, gpus, use_model_repository=False):
server = TritonServerFactory._get_remote_server_handle(config)
elif config.triton_launch_mode == "local":
server = TritonServerFactory._get_local_server_handle(
config, gpus, use_model_repository
config, gpus, use_model_repository=True
)
elif config.triton_launch_mode == "docker":
server = TritonServerFactory._get_docker_server_handle(
config, gpus, use_model_repository
config, gpus, use_model_repository=True
)
elif config.triton_launch_mode == "c_api":
server = TritonServerFactory._get_c_api_server_handle(
Expand Down Expand Up @@ -178,6 +178,8 @@ def _get_local_server_handle(config, gpus, use_model_repository):
triton_config = TritonServerConfig()
triton_config.update_config(config.triton_server_flags)

assert use_model_repository and config.model_repository

triton_config["model-repository"] = (
config.model_repository
if use_model_repository
Expand Down
6 changes: 5 additions & 1 deletion tests/test_triton_server_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ class TestTritonServerFactory(trc.TestResultCollector):
def setUp(self):
# Mock path validation
self.mock_os = MockOSMethods(
mock_paths=["model_analyzer.triton.server.server_factory"]
mock_paths=[
"model_analyzer.triton.server.server_factory",
"model_analyzer.config.input.config_utils",
]
)
self.mock_os.start()

Expand All @@ -53,6 +56,7 @@ def _test_get_server_handle_helper(
"""

config = ConfigCommandProfile()
config.model_repository = "/fake_model_repository"
config.triton_launch_mode = launch_mode
config.triton_http_endpoint = "fake_address:2345"
config.triton_grpc_endpoint = "fake_address:4567"
Expand Down

0 comments on commit 6fccc79

Please sign in to comment.