Skip to content

Commit

Permalink
Restrict downloading to safetensor files only.
Browse files Browse the repository at this point in the history
  • Loading branch information
robinjhuang committed Aug 13, 2024
1 parent bc85c0d commit 5cf97f2
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 29 deletions.
5 changes: 4 additions & 1 deletion model_filemanager/download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def validate_model_subdirectory(model_subdirectory: str) -> bool:

return True

def validate_filename(filename):
def validate_filename(filename: str)-> bool:
"""
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
Expand All @@ -214,6 +214,9 @@ def validate_filename(filename):
Returns:
bool: True if the filename is valid, False otherwise
"""
if not filename.lower().endswith(('.sft', '.safetensors')):
return False

# Check if the filename is empty, None, or just whitespace
if not filename or not filename.strip():
return False
Expand Down
54 changes: 26 additions & 28 deletions tests-unit/prompt_server_test/download_models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,22 +59,22 @@ async def test_download_model_success():
mock_open.return_value.__enter__.return_value = mock_file
time_values = itertools.count(0, 0.1)

with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.bin', 'checkpoints/model.bin')), \
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \
patch('model_filemanager.check_file_exists', return_value=None), \
patch('builtins.open', mock_open), \
patch('time.time', side_effect=time_values): # Simulate time passing

result = await download_model(
mock_make_request,
'model.bin',
'http://example.com/model.bin',
'model.sft',
'http://example.com/model.sft',
'checkpoints',
mock_progress_callback
)

# Assert the result
assert isinstance(result, DownloadModelStatus)
assert result.message == 'Successfully downloaded model.bin'
assert result.message == 'Successfully downloaded model.sft'
assert result.status == 'completed'
assert result.already_existed is False

Expand All @@ -83,14 +83,14 @@ async def test_download_model_success():

# Check initial call
mock_progress_callback.assert_any_call(
'checkpoints/model.bin',
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.bin", False)
'checkpoints/model.sft',
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
)

# Check final call
mock_progress_callback.assert_any_call(
'checkpoints/model.bin',
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.bin", False)
'checkpoints/model.sft',
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
)

# Verify file writing
Expand All @@ -99,7 +99,7 @@ async def test_download_model_success():
mock_file.write.assert_any_call(b'c' * 200)

# Verify request was made
mock_make_request.assert_called_once_with('http://example.com/model.bin')
mock_make_request.assert_called_once_with('http://example.com/model.sft')

@pytest.mark.asyncio
async def test_download_model_url_request_failure():
Expand Down Expand Up @@ -160,8 +160,8 @@ async def test_download_model_invalid_model_subdirectory():

result = await download_model(
mock_make_request,
'model.bin',
'http://example.com/model.bin',
'model.sft',
'http://example.com/model.sft',
'../bad_path',
mock_progress_callback
)
Expand All @@ -178,7 +178,7 @@ def test_create_model_path(tmp_path, monkeypatch):
mock_models_dir = tmp_path / "models"
monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir))

model_name = "test_model.bin"
model_name = "test_model.sft"
model_directory = "test_dir"

file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir)
Expand All @@ -190,30 +190,30 @@ def test_create_model_path(tmp_path, monkeypatch):

@pytest.mark.asyncio
async def test_check_file_exists_when_file_exists(tmp_path):
file_path = tmp_path / "existing_model.bin"
file_path = tmp_path / "existing_model.sft"
file_path.touch() # Create an empty file

mock_callback = AsyncMock()

result = await check_file_exists(str(file_path), "existing_model.bin", mock_callback, "test/existing_model.bin")
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft")

assert result is not None
assert result.status == "completed"
assert result.message == "existing_model.bin already exists"
assert result.message == "existing_model.sft already exists"
assert result.already_existed is True

mock_callback.assert_called_once_with(
"test/existing_model.bin",
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.bin already exists", already_existed=True)
"test/existing_model.sft",
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
)

@pytest.mark.asyncio
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
file_path = tmp_path / "non_existing_model.bin"
file_path = tmp_path / "non_existing_model.sft"

mock_callback = AsyncMock()

result = await check_file_exists(str(file_path), "non_existing_model.bin", mock_callback, "test/non_existing_model.bin")
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft")

assert result is None
mock_callback.assert_not_called()
Expand All @@ -229,15 +229,15 @@ async def test_track_download_progress_no_content_length():

with patch('builtins.open', mock_open):
result = await track_download_progress(
mock_response, '/mock/path/model.bin', 'model.bin',
mock_callback, 'models/model.bin', interval=0.1
mock_response, '/mock/path/model.sft', 'model.sft',
mock_callback, 'models/model.sft', interval=0.1
)

assert result.status == "completed"
# Check that progress was reported even without knowing the total size
mock_callback.assert_any_call(
'models/model.bin',
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.bin", already_existed=False)
'models/model.sft',
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
)

@pytest.mark.asyncio
Expand All @@ -256,8 +256,8 @@ async def test_track_download_progress_interval():
with patch('builtins.open', mock_open), \
patch('time.time', mock_time):
await track_download_progress(
mock_response, '/mock/path/model.bin', 'model.bin',
mock_callback, 'models/model.bin', interval=1.0
mock_response, '/mock/path/model.sft', 'model.sft',
mock_callback, 'models/model.sft', interval=1.0
)

# Print out the actual call count and the arguments of each call for debugging
Expand Down Expand Up @@ -303,12 +303,10 @@ def test_empty_subdirectory():
@pytest.mark.parametrize("filename, expected", [
("valid_model.safetensors", True),
("valid_model.sft", True),
("another-valid_model.ckpt", True),
("valid model.safetensors", True), # Test with space
("UPPERCASE_MODEL.SAFETENSORS", True),
("model_with.multiple.dots.pt", True),
("model_with.multiple.dots.pt", False),
("", False), # Empty string
(None, False), # None value
("../../../etc/passwd", False), # Path traversal attempt
("/etc/passwd", False), # Absolute path
("\\windows\\system32\\config\\sam", False), # Windows path
Expand Down

0 comments on commit 5cf97f2

Please sign in to comment.