Skip to content

Commit

Permalink
Add progress_interval as an optional parameter.
Browse files Browse the repository at this point in the history
  • Loading branch information
robinjhuang committed Aug 8, 2024
1 parent c1d78d6 commit 5537d25
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 13 deletions.
15 changes: 12 additions & 3 deletions model_filemanager/download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from folder_paths import models_dir
import re
from typing import Callable, Any, Optional, Awaitable, Tuple
from typing import Callable, Any, Optional, Awaitable, Tuple, Dict
from enum import Enum
import time
from dataclasses import dataclass
Expand All @@ -27,12 +27,21 @@ def __init__(self, status: DownloadStatusType, progress_percentage: float, messa
self.progress_percentage = progress_percentage
self.message = message
self.already_existed = already_existed

def to_dict(self) -> Dict[str, Any]:
return {
"status": self.status,
"progress_percentage": self.progress_percentage,
"message": self.message,
"already_existed": self.already_existed
}

async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
model_name: str,
model_url: str,
model_sub_directory: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]) -> DownloadModelStatus:
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
progress_interval: float = 1.0) -> DownloadModelStatus:
"""
Download a model file from a given URL into the models directory.
Expand Down Expand Up @@ -77,7 +86,7 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
await progress_callback(relative_path, status)
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)

return await track_download_progress(response, file_path, model_name, progress_callback, relative_path)
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)

except Exception as e:
logging.error(f"Error in downloading model: {e}")
Expand Down
16 changes: 6 additions & 10 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import node_helpers
from app.frontend_management import FrontendManager
from app.user_manager import UserManager
from model_filemanager import download_model, DownloadStatus
from model_filemanager import download_model, DownloadModelStatus
from typing import Optional

class BinaryEventTypes:
Expand Down Expand Up @@ -563,18 +563,14 @@ async def post_history(request):

@routes.post("/models/download")
async def download_handler(request):
async def report_progress(filename: str, status: DownloadStatus):
await self.send_json("download_progress", {
"filename": filename,
"progress_percentage": status.progress_percentage,
"status": status.status,
"message": status.message
})
async def report_progress(filename: str, status: DownloadModelStatus):
await self.send_json("download_progress", status.to_dict())

data = await request.json()
url = data.get('url')
model_directory = data.get('model_directory')
model_filename = data.get('model_filename')
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.

if not url or not model_directory or not model_filename:
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
Expand All @@ -584,10 +580,10 @@ async def report_progress(filename: str, status: DownloadStatus):
logging.error("Client session is not initialized")
return web.Response(status=500)

task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress))
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval))
await task

return web.Response(status=200)
return web.json_response(task.result().to_dict())

async def setup(self):
timeout = aiohttp.ClientTimeout(total=None) # no timeout
Expand Down

0 comments on commit 5537d25

Please sign in to comment.