diff --git a/CivitAI_Model.py b/CivitAI_Model.py index f187d01..4a02fc7 100644 --- a/CivitAI_Model.py +++ b/CivitAI_Model.py @@ -11,14 +11,12 @@ import comfy.utils import folder_paths - ROOT_PATH = os.path.dirname(os.path.abspath(__file__)) MSG_PREFIX = '\33[1m\33[34m[CivitAI] \33[0m' WARN_PREFIX = '\33[1m\33[34m[CivitAI]\33[0m\33[93m Warning: \33[0m' ERR_PREFIX = '\33[1m\33[31m[CivitAI]\33[0m\33[1m Error: \33[0m' - class CivitAI_Model: ''' CivitAI Model Class © Civitai 2023 @@ -30,8 +28,8 @@ class CivitAI_Model: num_chunks = 8 chunk_size = 1024 max_retries = 120 - debug_response = False - warning = False + debug_response = True + warning = True def __init__(self, model_id, save_path, model_paths, model_types=[], token=None, model_version=None, download_chunks=None, max_download_retries=None, warning=True, debug_response=False): self.model_id = model_id @@ -49,25 +47,27 @@ def __init__(self, model_id, save_path, model_paths, model_types=[], token=None, self.file_size = 0 self.trained_words = None self.warning = warning - + if download_chunks: self.num_chunks = int(download_chunks) + self.token = None if token: self.token = token + else: + self.token = os.environ.get('CIVITAI_TOKEN', None) if max_download_retries: self.max_retries = int(max_download_retries) - + if debug_response: self.debug_response = True self.details() def details(self): - + # CHECK FOR EXISTING MODEL DATA - model_name = self.model_cached_name(self.model_id, self.version) if model_name and self.model_exists_disk(model_name): history_file_path = os.path.join(ROOT_PATH, 'download_history.json') @@ -93,7 +93,9 @@ def details(self): if file_id and file_id == file_version: self.name = name self.name_friendly = file.get('name_friendly') - self.download_url = f"{file.get('downloadUrl')}?token={self.token}" + self.download_url = file.get('downloadUrl') + if self.token: + self.download_url = self.download_url + f"?token={self.token}" self.trained_words = file.get('trained_words') self.file_details = file self.file_id = file_version @@ -121,9 +123,9 @@ def details(self): if hashes: self.file_sha256 = hashes.get('SHA256') return self.name, self.file_details - + del download_history - + # NO CACHE DATA FOUND | DOWNLOAD MODEL DETAILS model_url = f'{self.api}/models/{self.model_id}' @@ -131,14 +133,14 @@ def details(self): if response.status_code == 200: model_data = response.json() - + if self.debug_response: print(f"{MSG_PREFIX}API Response:") - print(''); print('') + print('') from pprint import pprint pprint(model_data, indent=4) - print(''); print('') - + print('') + model_versions = model_data.get('modelVersions') model_type = model_data.get('type', 'Model') self.type = model_type @@ -148,16 +150,26 @@ def details(self): raise Exception(f"{ERR_PREFIX}The model you requested is not a valid `{', '.join(self.valid_types)}`. Aborting!") if self.version: + print("self.version found") for version in model_versions: version_id = version.get('id') + print(f"Version IDs: {version_id} -> {self.version}") files = version.get('files') trained_words = version.get('trainedWords', []) model_download_url = version.get('downloadUrl', '') - if version_id == self.version and files: + print("Model download url:", model_download_url) + print(files) + if ( model_download_url and len(files) > 0 ) or ( version_id == self.version and len(files) > 0 ): for file in files: + if self.debug_response: + print("File Info:") + pprint(file, indent=4) download_url = file.get('downloadUrl') - if download_url == model_download_url: - self.download_url = download_url + f"?token={self.token}" + print("File download url:", download_url) + if download_url: + self.download_url = download_url + if self.token: + self.download_url = self.download_url + f"?token={self.token}" self.file_details = file self.file_id = file.get('id') self.name = file.get('name') @@ -169,16 +181,25 @@ def details(self): self.file_sha256 = hashes.get('SHA256') return self.download_url, self.file_details else: + print("No version found") version = model_versions[0] if version: version_id = version.get('id') + print(f"Version IDs: {version_id} -> {version}") files = version.get('files') model_download_url = version.get('downloadUrl', '') + print("Model download url:", model_download_url) trained_words = version.get('trainedWords', []) for file in files: + if self.debug_response: + print("File Info:") + pprint(file, indent=4) download_url = file.get('downloadUrl') - if download_url == model_download_url: - self.download_url = download_url + f"?token={self.token}" + print("File download url:", download_url) + if download_url: + self.download_url = download_url + if self.token: + self.download_url = self.download_url + f"?token={self.token}" self.file_details = file self.file_id = file.get('id') self.name = file.get('name') @@ -189,15 +210,12 @@ def details(self): if hashes: self.file_sha256 = hashes.get('SHA256') return self.download_url, self.file_details - else: raise Exception(f"{ERR_PREFIX}No cached model or model data found, and unable to reach CivitAI! Response Code: {response.status_code}\n Please try again later.") def download(self): - - # DOWNLAOD BYTE CHUNK - + def download_chunk(chunk_id, url, chunk_size, start_byte, end_byte, file_path, total_pbar, comfy_pbar, max_retries=30): retries = 0 retry_delay = 5 @@ -216,7 +234,7 @@ def download_chunk(chunk_id, url, chunk_size, start_byte, end_byte, file_path, t time.sleep(retry_delay * 10) total_pbar.set_postfix_str('') - file.seek(start_byte + downloaded_bytes) + file.seek(start_byte + downloaded_bytes) for chunk in response.iter_content(chunk_size=chunk_size): file.write(chunk) total_pbar.update(len(chunk)) @@ -231,9 +249,6 @@ def download_chunk(chunk_id, url, chunk_size, start_byte, end_byte, file_path, t else: raise Exception(f"{ERR_PREFIX}Unable to establish download connection.") except (requests.exceptions.RequestException, Exception) as e: - # We shouldn't warn on chunk loss, since end chunks may not be able to be established due to remaining filesize - #print(f"{WARN_PREFIX}Chunk {chunk_id} connection lost") - total_pbar.update() time.sleep(retry_delay) retries += 1 if retries > max_retries: @@ -242,13 +257,10 @@ def download_chunk(chunk_id, url, chunk_size, start_byte, end_byte, file_path, t if chunk_complete: break - + if not chunk_complete: raise Exception(f"{ERR_PREFIX}Unable to re-establish connection to CivitAI.") - - # GET FILE SIZE - def get_total_file_size(url): response = requests.get(url, stream=True) content_length = response.headers.get('Content-Length') @@ -260,16 +272,14 @@ def get_total_file_size(url): if content_range: total_bytes = int(re.search(r'/(\d+)', content_range).group(1)) return total_bytes - + if self.file_size: return self.file_size return None - # RESOLVE MODEL ID/VERSION TO FILENAME - model_name = self.model_cached_name(self.model_id, self.version) - + if model_name: model_path = self.model_exists_disk(model_name) if model_path: @@ -289,13 +299,9 @@ def get_total_file_size(url): else: self.name = self.download_url.split('/')[-1] - # NO MODEL FOUND! | DOWNLOAD MODEL FROM CIVITAI - print(f"{MSG_PREFIX}Downloading `{self.name}` from `{self.download_url}`") - save_path = os.path.join(self.model_path, self.name) # Assume default comfy folder, unless we take user input on extra paths - - # EXISTING MODEL FOUND -- CHECK SHA256 - + save_path = os.path.join(self.model_path, self.name) + if os.path.exists(save_path): print(f"{MSG_PREFIX}{self.type} file already exists at: {save_path}") self.dump_file_details() @@ -306,15 +312,13 @@ def get_total_file_size(url): else: print(f"{ERR_PREFIX}Existing {self.type} file's SHA256 does not match. Retrying download...") - # NO MODEL OR MODEL DATA AVAILABLE -- DOWNLOAD MODEL FROM CIVITAI - response = requests.head(self.download_url) - total_file_size = total_file_size = get_total_file_size(self.download_url) - + total_file_size = get_total_file_size(self.download_url) + response = requests.get(self.download_url, stream=True) if response.status_code != requests.codes.ok: raise Exception(f"{ERR_PREFIX}Failed to download {self.type} file from CivitAI. Status code: {response.status_code}") - + with open(save_path, 'wb') as file: file.seek(total_file_size - 1) file.write(b'\0') @@ -334,7 +338,7 @@ def get_total_file_size(url): for future in futures: future.result() - + total_pbar.close() model_sha256 = CivitAI_Model.calculate_sha256(save_path) if model_sha256 == self.file_sha256: @@ -343,11 +347,9 @@ def get_total_file_size(url): self.dump_file_details() return True else: - os.remove(save_path) # Remove Invalid / Broken / Insecure download file + os.remove(save_path) raise Exception(f"{ERR_PREFIX}{self.type} file's SHA256 does not match expected value after retry. Aborting download.") - - # DUMP MODEL DETAILS TO DOWNLOAD HISTORY - + def dump_file_details(self): history_file_path = os.path.join(ROOT_PATH, 'download_history.json') @@ -385,8 +387,6 @@ def dump_file_details(self): with open(history_file_path, 'w', encoding='utf-8') as history_file: json.dump(download_history, history_file, indent=4, ensure_ascii=False) - - # RESOLVE ID/VERSION TO FILENAME def model_cached_name(self, model_id, version_id): history_file_path = os.path.join(ROOT_PATH, 'download_history.json') @@ -398,7 +398,7 @@ def model_cached_name(self, model_id, version_id): version_id = int(version_id) if version_id else None if model_id_str in download_history: file_details_list = download_history[model_id_str] - for file_details in file_details_list: + for file_details in file_details_list: version = file_details.get('id') files = file_details.get('files') if files: @@ -409,23 +409,18 @@ def model_cached_name(self, model_id, version_id): elif self.model_exists_disk(name): return name return None - - - # CEHCK FOR MODEL ON DISK def model_exists_disk(self, name): for path in self.model_paths: if path and name: full_path = os.path.join(path, name) - if os.path.exists(full_path): + if os.path.exists(full_path): if os.path.getsize(full_path) <= 0: os.remove(full_path) else: return full_path return False - # CEHCK FOR MODEL ON DISK - def model_path(filename, search_paths): filename, _ = os.path.splitext(filename) for path in search_paths: @@ -438,8 +433,6 @@ def model_path(filename, search_paths): return os.path.join(root, file) return None - # CALCULATE SHA256 - @staticmethod def calculate_sha256(file_path): sha256_hash = hashlib.sha256() @@ -449,8 +442,7 @@ def calculate_sha256(file_path): sha256_hash.update(byte_block) return sha256_hash.hexdigest().upper() return 0 - - # STATIC HASH LOOKUP FOR MANUAL LOADING + @staticmethod def sha256_lookup(file_path): hash_value = CivitAI_Model.calculate_sha256(file_path) @@ -502,8 +494,6 @@ def sha256_lookup(file_path): return (None, None, None) - # STATIC DOWNLOAD HISTORY PUSH - @staticmethod def push_download_history(model_id, model_type, file_details): history_file_path = os.path.join(ROOT_PATH, 'download_history.json') diff --git a/README.md b/README.md index f5ad2c8..38b12f6 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ Want to be share a fully reproducable workflow? ### Checkpoint Loader -- Load checkpoints directly from Civitai using just a Model `AIR` (`model id` or `version id`) +- Load checkpoints directly from Civitai using just a Model `AIR` tag (`model id` or `version id`), or by URL ex: https://civitai.com/models/112902/dreamshaper-xl?modelVersionId=126688 - Resources used in images will be automatically detected on image upload - Workflows copied from Civitai or shared via image metadata will include everything needed to generate the image including all resources diff --git a/civitai_checkpoint_loader.py b/civitai_checkpoint_loader.py index e802f0e..b0a1ee8 100644 --- a/civitai_checkpoint_loader.py +++ b/civitai_checkpoint_loader.py @@ -1,18 +1,10 @@ -import hashlib -import json import os -import requests -import sys -import time -from tqdm import tqdm import folder_paths -import comfy.sd -import comfy.utils from nodes import CheckpointLoaderSimple from .CivitAI_Model import CivitAI_Model -from .utils import short_paths_map, model_path +from .utils import short_paths_map, model_path, get_model_ids ROOT_PATH = os.path.dirname(os.path.abspath(__file__)) @@ -67,13 +59,8 @@ def load_checkpoint(self, ckpt_air, ckpt_name, api_key=None, download_chunks=Non ckpt_id = None version_id = None - if '@' in ckpt_air: - ckpt_id, version_id = ckpt_air.split('@') - else: - ckpt_id = ckpt_air - - ckpt_id = int(ckpt_id) if ckpt_id else None - version_id = int(version_id) if version_id else None + ckpt_id, version_id = get_model_ids(ckpt_air) + print(f"CKPT_ID: {ckpt_id} VERSION: {version_id}") checkpoint_paths = short_paths_map(CHECKPOINTS) if download_path: diff --git a/civitai_lora_loader.py b/civitai_lora_loader.py index f25eafc..e265e9d 100644 --- a/civitai_lora_loader.py +++ b/civitai_lora_loader.py @@ -11,7 +11,7 @@ from nodes import LoraLoader from .CivitAI_Model import CivitAI_Model -from .utils import short_paths_map, model_path +from .utils import short_paths_map, model_path, get_model_ids ROOT_PATH = os.path.dirname(os.path.abspath(__file__)) @@ -71,14 +71,9 @@ def load_lora(self, model, clip, lora_air, lora_name, strength_model, strength_c lora_id = None version_id = None - if '@' in lora_air: - lora_id, version_id = lora_air.split('@') - else: - lora_id = lora_air - - lora_id = int(lora_id) if lora_id else None - version_id = int(version_id) if version_id else None - + lora_id, version_id = get_model_ids(lora_air) + print(f"CKPT_ID: {lora_id} VERSION: {version_id}") + lora_paths = short_paths_map(LORAS) if download_path: if lora_paths.__contains__(download_path): diff --git a/utils.py b/utils.py index f041db2..44283f6 100644 --- a/utils.py +++ b/utils.py @@ -1,4 +1,5 @@ import os +from urllib.parse import urlparse, parse_qs def short_paths_map(paths): short_paths_map_dict = {} @@ -22,4 +23,24 @@ def model_path(filename, search_paths): if name.lower().strip() == filename or full_filename.lower().strip() == filename: return os.path.join(root, file) return None - \ No newline at end of file + +def get_model_ids(url_or_air): + """Extract the model ID and model version ID from a Civitai URL or AIR tag.""" + parsed_url = urlparse(url_or_air) + if parsed_url.scheme and parsed_url.netloc: + if "civitai.com" in parsed_url.netloc: + path_parts = parsed_url.path.strip('/').split('/') + if len(path_parts) >= 2 and 'models' in path_parts: + model_id = path_parts[path_parts.index('models') + 1] + query_params = parse_qs(parsed_url.query) + model_version_id = query_params.get('modelVersionId', [None])[0] + return model_id, model_version_id + elif '@' in url_or_air: + air_parts = url_or_air.split('@') + try: + model_id = int(air_parts[0]) + model_version_id = int(air_parts[1]) if air_parts[1] else None + return model_id, model_version_id + except ValueError: + raise ValueError(f"Invalid AIR tag format: {url_or_air}. Must be `model@modelVersionId`.") + raise ValueError(f"Unable to determine model ID, and version ID from input: {url_or_air}")