Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
fpetrini15 committed Mar 27, 2024
1 parent a905021 commit ea16c37
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 48 deletions.
34 changes: 17 additions & 17 deletions qa/L0_shared_memory/shared_memory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def tearDown(self):
util.kill_server(self._server_process)
# Restore stdout / stderr so we can print to console and see server
# output in CI even after logs expire. Print test result to client
# before doing so.
# before doing so for legibility.
if not self._test_passed:
print("*\n*\n*\nTest Failed\n*\n*\n*\n")
util.stream_to_console(self._original_stdout, self._original_stderr)
Expand Down Expand Up @@ -294,9 +294,9 @@ def test_valid_create_set_register(self, protocol):
)
shm_status = self._triton_client.get_system_shared_memory_status()
if self._protocol == "http":
self.assertTrue(len(shm_status) == 1)
self.assertEqual(len(shm_status), 1)
else:
self.assertTrue(len(shm_status.regions) == 1)
self.assertEqual(len(shm_status.regions), 1)
shm.destroy_shared_memory_region(shm_op0_handle)
self._test_passed = True

Expand All @@ -316,9 +316,9 @@ def test_unregister_before_register(self, protocol):
self._triton_client.unregister_system_shared_memory("dummy_data")
shm_status = self._triton_client.get_system_shared_memory_status()
if self._protocol == "http":
self.assertTrue(len(shm_status) == 0)
self.assertEqual(len(shm_status), 0)
else:
self.assertTrue(len(shm_status.regions) == 0)
self.assertEqual(len(shm_status.regions), 0)
shm.destroy_shared_memory_region(shm_op0_handle)
self._test_passed = True

Expand All @@ -340,9 +340,9 @@ def test_unregister_after_register(self, protocol):
self._triton_client.unregister_system_shared_memory("dummy_data")
shm_status = self._triton_client.get_system_shared_memory_status()
if self._protocol == "http":
self.assertTrue(len(shm_status) == 0)
self.assertEqual(len(shm_status), 0)
else:
self.assertTrue(len(shm_status.regions) == 0)
self.assertEqual(len(shm_status.regions), 0)
shm.destroy_shared_memory_region(shm_op0_handle)
self._test_passed = True

Expand Down Expand Up @@ -371,9 +371,9 @@ def test_reregister_after_register(self, protocol):
)

Check notice

Code scanning / CodeQL

Imprecise assert Note

assertTrue(a in b) cannot provide an informative message. Using assertIn(a, b) instead will give more informative messages.
shm_status = self._triton_client.get_system_shared_memory_status()
if self._protocol == "http":
self.assertTrue(len(shm_status) == 1)
self.assertEqual(len(shm_status), 1)
else:
self.assertTrue(len(shm_status.regions) == 1)
self.assertEqual(len(shm_status.regions), 1)
shm.destroy_shared_memory_region(shm_op0_handle)
self._test_passed = True

Expand All @@ -400,9 +400,9 @@ def test_unregister_after_inference(self, protocol):
self._triton_client.unregister_system_shared_memory("output0_data")
shm_status = self._triton_client.get_system_shared_memory_status()
if self._protocol == "http":
self.assertTrue(len(shm_status) == 3)
self.assertEqual(len(shm_status), 3)
else:
self.assertTrue(len(shm_status.regions) == 3)
self.assertEqual(len(shm_status.regions), 3)
self._cleanup_server(shm_handles)
self._test_passed = True

Expand Down Expand Up @@ -433,9 +433,9 @@ def test_register_after_inference(self, protocol):
)
shm_status = self._triton_client.get_system_shared_memory_status()
if self._protocol == "http":
self.assertTrue(len(shm_status) == 5)
self.assertEqual(len(shm_status), 5)
else:
self.assertTrue(len(shm_status.regions) == 5)
self.assertEqual(len(shm_status.regions), 5)
shm_handles.append(shm_ip2_handle)
self._cleanup_server(shm_handles)
self._test_passed = True
Expand Down Expand Up @@ -502,15 +502,15 @@ def test_unregisterall(self, protocol):

status_before = self._triton_client.get_system_shared_memory_status()
if self._protocol == "http":
self.assertTrue(len(status_before) == 4)
self.assertEqual(len(status_before), 4)
else:
self.assertTrue(len(status_before.regions) == 4)
self.assertEqual(len(status_before.regions), 4)
self._triton_client.unregister_system_shared_memory()
status_after = self._triton_client.get_system_shared_memory_status()
if self._protocol == "http":
self.assertTrue(len(status_after) == 0)
self.assertEqual(len(status_after), 0)
else:
self.assertTrue(len(status_after.regions) == 0)
self.assertEqual(len(status_after.regions), 0)
self._cleanup_server(shm_handles)
self._test_passed = True

Expand Down
29 changes: 14 additions & 15 deletions qa/L0_shared_memory/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,10 @@ CLIENT_LOG="./client.log"
SHM_TEST=shared_memory_test.py
TEST_RESULT_FILE='test_results.txt'

if [[ ${TEST_WINDOWS} == 1 ]]; then
TRITON_DIR=${TRITON_DIR:=c:/tritonserver}
SERVER=${SERVER:=c:/tritonserver/bin/tritonserver.exe}
BACKEND_DIR=${BACKEND_DIR:=c:/tritonserver/backends}
MODELDIR=${MODELDIR:=c:/}
else
# Configure to support test on jetson as well
TRITON_DIR=${TRITON_DIR:="/opt/tritonserver"}
SERVER=${TRITON_DIR}/bin/tritonserver
BACKEND_DIR=${TRITON_DIR}/backends
MODELDIR=${MODELDIR:=`pwd`}
fi
# Configure to support test on jetson as well
TRITON_DIR=${TRITON_DIR:="/opt/tritonserver"}
SERVER=${TRITON_DIR}/bin/tritonserver
BACKEND_DIR=${TRITON_DIR}/backends
SERVER_ARGS_EXTRA="--backend-directory=${BACKEND_DIR}"
source ../common/util.sh

Expand All @@ -49,7 +41,6 @@ rm -fr *.log

for i in \
test_invalid_create_shm \
test_invalid_registration \
test_valid_create_set_register \
test_unregister_before_register \
test_unregister_after_register \
Expand All @@ -61,7 +52,7 @@ for i in \
test_unregisterall \
test_infer_offset_out_of_bound; do
for client_type in http grpc; do
SERVER_ARGS="--model-repository=${MODELDIR} --log-verbose=1 ${SERVER_ARGS_EXTRA}"
SERVER_ARGS="--model-repository=`pwd`/models --log-verbose=1 ${SERVER_ARGS_EXTRA}"
SERVER_LOG="./$i.$client_type.server.log"
run_server
if [ "$SERVER_PID" == "0" ]; then
Expand All @@ -78,10 +69,18 @@ for i in \
if [ $? -ne 0 ]; then
echo -e "\n***\n*** Test Failed\n***"
RET=1
else
check_test_results $TEST_RESULT_FILE 1
if [ $? -ne 0 ]; then
cat $CLIENT_LOG
echo -e "\n***\n*** Test Result Verification Failed\n***"
RET=1
fi
fi
set -e

kill_server
kill $SERVER_PID
wait $SERVER_PID
done
done

Expand Down
41 changes: 31 additions & 10 deletions qa/common/util.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
from pathlib import Path
import shutil
import signal
import subprocess
import time
import sys
import signal
import shutil
import time
from pathlib import Path

from tritonclient.utils import InferenceServerException


def run_server(server_executable: str, launch_command: str, log_file):
if not Path(server_executable).is_file():
raise Exception(f"{server_executable} does not exist")
print(f"=== Running {launch_command}")
if sys.platform == "win32":
server = subprocess.Popen(launch_command, text=True, stdout=log_file, stderr=log_file)
server = subprocess.Popen(
launch_command, text=True, stdout=log_file, stderr=log_file
)
else:
server = subprocess.Popen(launch_command.split(), text=True, stdout=log_file, stderr=log_file)
server = subprocess.Popen(
launch_command.split(), text=True, stdout=log_file, stderr=log_file
)
time.sleep(3)
return server


def wait_for_server_ready(server_process, triton_client, timeout):
start = time.time()
while time.time() - start < timeout:
Expand All @@ -34,36 +40,47 @@ def wait_for_server_ready(server_process, triton_client, timeout):
print("=== Server is ready", flush=True)
return True
except InferenceServerException:
pass # Host not ready
pass # Host not ready
raise Exception(f"=== Timeout {timeout} secs. Server not ready. ===")


def kill_server(server_process):
# Only kill process if it's stil running
if server_process and not server_process.poll():
print("*\n*\n*\nTerminating server\n*\n*\n*\n")
# Terminate gracefully for Linux
if sys.platform == "win32":
server_process.kill()
server_process.kill()
else:
server_process.send_signal(signal.SIGINT)
server_process.wait()


def stream_to_log(client_log):
original_stdout = sys.stdout
original_sterr = sys.stderr
sys.stdout = sys.stderr = client_log
return original_stdout, original_sterr


def stream_to_console(original_stdout, original_sterr):
sys.stdout = original_stdout
sys.stderr = original_sterr


def remove_model_dir(model_dir_path: Path):
if not model_dir_path.is_dir():
return
shutil.rmtree(model_dir_path)

def create_model_dir(model_dir_path: Path, model_name: str, model_version: int, model_source_path: Path, model_config_path: Path):

def create_model_dir(
model_dir_path: Path,
model_name: str,
model_version: int,
model_source_path: Path,
model_config_path: Path,
):
remove_model_dir(model_dir_path)
model_dir_path = model_dir_path / model_name / str(model_version)
model_dir_path.mkdir(parents=True)
Expand All @@ -72,11 +89,15 @@ def create_model_dir(model_dir_path: Path, model_name: str, model_version: int,
shutil.copy(model_source_path, model_dir_path)
shutil.copy(model_config_path, model_dir_path.parent)

def replace_config_attribute(model_config_path: Path, current_attribute: str, desired_attribute: str):

def replace_config_attribute(
model_config_path: Path, current_attribute: str, desired_attribute: str
):
original_config = model_config_path.read_text()
new_config = original_config.replace(current_attribute, desired_attribute)
model_config_path.write_text(new_config)


def add_config_attribute(model_config_path: Path, new_attribute: str):
with model_config_path.open("a") as f:
f.write(new_attribute)
12 changes: 6 additions & 6 deletions src/shared_memory_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,11 @@ MapSharedMemory(
if (*mapped_addr == NULL) {
CloseSharedMemoryRegion(shm_handle);
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL, std::string(
"unable to process address space: " +
std::to_string(GetLastError()))
.c_str());
TRITONSERVER_ERROR_INTERNAL,
std::string(
"unable to process address space, error code: " +
std::to_string(GetLastError()))
.c_str());
}
return nullptr;
}
Expand Down Expand Up @@ -191,7 +192,6 @@ SharedMemoryManager::RegisterSystemSharedMemory(
// Map and then close the shared memory handle
TRITONSERVER_Error* err_map =
MapSharedMemory(shm_handle, offset, byte_size, &mapped_addr);
// TODO: Test if we can close windows handles without invalidating mapping
TRITONSERVER_Error* err_close = CloseSharedMemoryRegion(shm_handle);
if (err_map != nullptr) {
return TRITONSERVER_ErrorNew(
Expand All @@ -207,7 +207,7 @@ SharedMemoryManager::RegisterSystemSharedMemory(
TRITONSERVER_ERROR_INVALID_ARG,
std::string(
"failed to register shared memory region '" + name +
"': " + std::to_string(GetLastError()))
"' with error code: " + std::to_string(GetLastError()))
.c_str());
}

Expand Down

0 comments on commit ea16c37

Please sign in to comment.