From 5cf97f2a5890adbbff3f918317c6e0c9cb3593d3 Mon Sep 17 00:00:00 2001 From: Robin Huang Date: Tue, 13 Aug 2024 12:29:35 -0700 Subject: [PATCH] Restrict downloading to safetensor files only. --- model_filemanager/download_models.py | 5 +- .../download_models_test.py | 54 +++++++++---------- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/model_filemanager/download_models.py b/model_filemanager/download_models.py index 621cbeae64a..712d59328f6 100644 --- a/model_filemanager/download_models.py +++ b/model_filemanager/download_models.py @@ -204,7 +204,7 @@ def validate_model_subdirectory(model_subdirectory: str) -> bool: return True -def validate_filename(filename): +def validate_filename(filename: str)-> bool: """ Validate a filename to ensure it's safe and doesn't contain any path traversal attempts. @@ -214,6 +214,9 @@ def validate_filename(filename): Returns: bool: True if the filename is valid, False otherwise """ + if not filename.lower().endswith(('.sft', '.safetensors')): + return False + # Check if the filename is empty, None, or just whitespace if not filename or not filename.strip(): return False diff --git a/tests-unit/prompt_server_test/download_models_test.py b/tests-unit/prompt_server_test/download_models_test.py index 09d8fdcb4f7..66150a4682f 100644 --- a/tests-unit/prompt_server_test/download_models_test.py +++ b/tests-unit/prompt_server_test/download_models_test.py @@ -59,22 +59,22 @@ async def test_download_model_success(): mock_open.return_value.__enter__.return_value = mock_file time_values = itertools.count(0, 0.1) - with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.bin', 'checkpoints/model.bin')), \ + with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \ patch('model_filemanager.check_file_exists', return_value=None), \ patch('builtins.open', mock_open), \ patch('time.time', side_effect=time_values): # Simulate time passing result = await download_model( mock_make_request, - 'model.bin', - 'http://example.com/model.bin', + 'model.sft', + 'http://example.com/model.sft', 'checkpoints', mock_progress_callback ) # Assert the result assert isinstance(result, DownloadModelStatus) - assert result.message == 'Successfully downloaded model.bin' + assert result.message == 'Successfully downloaded model.sft' assert result.status == 'completed' assert result.already_existed is False @@ -83,14 +83,14 @@ async def test_download_model_success(): # Check initial call mock_progress_callback.assert_any_call( - 'checkpoints/model.bin', - DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.bin", False) + 'checkpoints/model.sft', + DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False) ) # Check final call mock_progress_callback.assert_any_call( - 'checkpoints/model.bin', - DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.bin", False) + 'checkpoints/model.sft', + DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False) ) # Verify file writing @@ -99,7 +99,7 @@ async def test_download_model_success(): mock_file.write.assert_any_call(b'c' * 200) # Verify request was made - mock_make_request.assert_called_once_with('http://example.com/model.bin') + mock_make_request.assert_called_once_with('http://example.com/model.sft') @pytest.mark.asyncio async def test_download_model_url_request_failure(): @@ -160,8 +160,8 @@ async def test_download_model_invalid_model_subdirectory(): result = await download_model( mock_make_request, - 'model.bin', - 'http://example.com/model.bin', + 'model.sft', + 'http://example.com/model.sft', '../bad_path', mock_progress_callback ) @@ -178,7 +178,7 @@ def test_create_model_path(tmp_path, monkeypatch): mock_models_dir = tmp_path / "models" monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir)) - model_name = "test_model.bin" + model_name = "test_model.sft" model_directory = "test_dir" file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir) @@ -190,30 +190,30 @@ def test_create_model_path(tmp_path, monkeypatch): @pytest.mark.asyncio async def test_check_file_exists_when_file_exists(tmp_path): - file_path = tmp_path / "existing_model.bin" + file_path = tmp_path / "existing_model.sft" file_path.touch() # Create an empty file mock_callback = AsyncMock() - result = await check_file_exists(str(file_path), "existing_model.bin", mock_callback, "test/existing_model.bin") + result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft") assert result is not None assert result.status == "completed" - assert result.message == "existing_model.bin already exists" + assert result.message == "existing_model.sft already exists" assert result.already_existed is True mock_callback.assert_called_once_with( - "test/existing_model.bin", - DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.bin already exists", already_existed=True) + "test/existing_model.sft", + DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True) ) @pytest.mark.asyncio async def test_check_file_exists_when_file_does_not_exist(tmp_path): - file_path = tmp_path / "non_existing_model.bin" + file_path = tmp_path / "non_existing_model.sft" mock_callback = AsyncMock() - result = await check_file_exists(str(file_path), "non_existing_model.bin", mock_callback, "test/non_existing_model.bin") + result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft") assert result is None mock_callback.assert_not_called() @@ -229,15 +229,15 @@ async def test_track_download_progress_no_content_length(): with patch('builtins.open', mock_open): result = await track_download_progress( - mock_response, '/mock/path/model.bin', 'model.bin', - mock_callback, 'models/model.bin', interval=0.1 + mock_response, '/mock/path/model.sft', 'model.sft', + mock_callback, 'models/model.sft', interval=0.1 ) assert result.status == "completed" # Check that progress was reported even without knowing the total size mock_callback.assert_any_call( - 'models/model.bin', - DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.bin", already_existed=False) + 'models/model.sft', + DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False) ) @pytest.mark.asyncio @@ -256,8 +256,8 @@ async def test_track_download_progress_interval(): with patch('builtins.open', mock_open), \ patch('time.time', mock_time): await track_download_progress( - mock_response, '/mock/path/model.bin', 'model.bin', - mock_callback, 'models/model.bin', interval=1.0 + mock_response, '/mock/path/model.sft', 'model.sft', + mock_callback, 'models/model.sft', interval=1.0 ) # Print out the actual call count and the arguments of each call for debugging @@ -303,12 +303,10 @@ def test_empty_subdirectory(): @pytest.mark.parametrize("filename, expected", [ ("valid_model.safetensors", True), ("valid_model.sft", True), - ("another-valid_model.ckpt", True), ("valid model.safetensors", True), # Test with space ("UPPERCASE_MODEL.SAFETENSORS", True), - ("model_with.multiple.dots.pt", True), + ("model_with.multiple.dots.pt", False), ("", False), # Empty string - (None, False), # None value ("../../../etc/passwd", False), # Path traversal attempt ("/etc/passwd", False), # Absolute path ("\\windows\\system32\\config\\sam", False), # Windows path