
Introduction of support for CivitAI URLs #22

Open · wants to merge 2 commits into main
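This PR adds support for pointing the loader at a full Civitai model page URL (see the README change at the end of the diff) in addition to a bare model or version id. The URL-to-id resolution itself is not part of the `CivitAI_Model.py` excerpt below; the following is only a minimal, illustrative sketch of how a URL such as `https://civitai.com/models/112902/dreamshaper-xl?modelVersionId=126688` could be reduced to the two ids the model class works with (the function and variable names are assumptions, not the PR's code):

```python
# Hypothetical sketch: extracting model id and version id from a CivitAI URL.
# The PR's actual parsing is not shown in this excerpt; names are illustrative.
import re
from urllib.parse import urlparse, parse_qs


def parse_civitai_url(url: str):
    """Return (model_id, version_id) from a URL like
    https://civitai.com/models/112902/dreamshaper-xl?modelVersionId=126688
    version_id is None when the URL has no modelVersionId query parameter."""
    parsed = urlparse(url)
    match = re.search(r'/models/(\d+)', parsed.path)
    if not match:
        raise ValueError(f"Not a recognizable CivitAI model URL: {url}")
    model_id = int(match.group(1))
    version_values = parse_qs(parsed.query).get('modelVersionId')
    version_id = int(version_values[0]) if version_values else None
    return model_id, version_id


# Example:
# parse_civitai_url("https://civitai.com/models/112902/dreamshaper-xl?modelVersionId=126688")
# -> (112902, 126688)
```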
124 changes: 57 additions & 67 deletions CivitAI_Model.py
@@ -11,14 +11,12 @@
import comfy.utils
import folder_paths


ROOT_PATH = os.path.dirname(os.path.abspath(__file__))

MSG_PREFIX = '\33[1m\33[34m[CivitAI] \33[0m'
WARN_PREFIX = '\33[1m\33[34m[CivitAI]\33[0m\33[93m Warning: \33[0m'
ERR_PREFIX = '\33[1m\33[31m[CivitAI]\33[0m\33[1m Error: \33[0m'


class CivitAI_Model:
'''
CivitAI Model Class © Civitai 2023
@@ -30,8 +28,8 @@ class CivitAI_Model:
num_chunks = 8
chunk_size = 1024
max_retries = 120
debug_response = False
warning = False
debug_response = True
warning = True

def __init__(self, model_id, save_path, model_paths, model_types=[], token=None, model_version=None, download_chunks=None, max_download_retries=None, warning=True, debug_response=False):
self.model_id = model_id
@@ -49,25 +47,27 @@ def __init__(self, model_id, save_path, model_paths, model_types=[], token=None,
self.file_size = 0
self.trained_words = None
self.warning = warning

if download_chunks:
self.num_chunks = int(download_chunks)

self.token = None
if token:
self.token = token
else:
self.token = os.environ.get('CIVITAI_TOKEN', None)

if max_download_retries:
self.max_retries = int(max_download_retries)

if debug_response:
self.debug_response = True

self.details()

def details(self):

# CHECK FOR EXISTING MODEL DATA

model_name = self.model_cached_name(self.model_id, self.version)
if model_name and self.model_exists_disk(model_name):
history_file_path = os.path.join(ROOT_PATH, 'download_history.json')
@@ -93,7 +93,9 @@ def details(self):
if file_id and file_id == file_version:
self.name = name
self.name_friendly = file.get('name_friendly')
self.download_url = f"{file.get('downloadUrl')}?token={self.token}"
self.download_url = file.get('downloadUrl')
if self.token:
self.download_url = self.download_url + f"?token={self.token}"
self.trained_words = file.get('trained_words')
self.file_details = file
self.file_id = file_version
@@ -121,24 +123,24 @@
if hashes:
self.file_sha256 = hashes.get('SHA256')
return self.name, self.file_details

del download_history

# NO CACHE DATA FOUND | DOWNLOAD MODEL DETAILS

model_url = f'{self.api}/models/{self.model_id}'
response = requests.get(model_url)

if response.status_code == 200:
model_data = response.json()

if self.debug_response:
print(f"{MSG_PREFIX}API Response:")
print(''); print('')
print('')
from pprint import pprint
pprint(model_data, indent=4)
print(''); print('')
print('')

model_versions = model_data.get('modelVersions')
model_type = model_data.get('type', 'Model')
self.type = model_type
@@ -148,16 +150,26 @@
raise Exception(f"{ERR_PREFIX}The model you requested is not a valid `{', '.join(self.valid_types)}`. Aborting!")

if self.version:
print("self.version found")
for version in model_versions:
version_id = version.get('id')
print(f"Version IDs: {version_id} -> {self.version}")
files = version.get('files')
trained_words = version.get('trainedWords', [])
model_download_url = version.get('downloadUrl', '')
if version_id == self.version and files:
print("Model download url:", model_download_url)
print(files)
if ( model_download_url and len(files) > 0 ) or ( version_id == self.version and len(files) > 0 ):
for file in files:
if self.debug_response:
print("File Info:")
pprint(file, indent=4)
download_url = file.get('downloadUrl')
if download_url == model_download_url:
self.download_url = download_url + f"?token={self.token}"
print("File download url:", download_url)
if download_url:
self.download_url = download_url
if self.token:
self.download_url = self.download_url + f"?token={self.token}"
self.file_details = file
self.file_id = file.get('id')
self.name = file.get('name')
@@ -169,16 +181,25 @@
self.file_sha256 = hashes.get('SHA256')
return self.download_url, self.file_details
else:
print("No version found")
version = model_versions[0]
if version:
version_id = version.get('id')
print(f"Version IDs: {version_id} -> {version}")
files = version.get('files')
model_download_url = version.get('downloadUrl', '')
print("Model download url:", model_download_url)
trained_words = version.get('trainedWords', [])
for file in files:
if self.debug_response:
print("File Info:")
pprint(file, indent=4)
download_url = file.get('downloadUrl')
if download_url == model_download_url:
self.download_url = download_url + f"?token={self.token}"
print("File download url:", download_url)
if download_url:
self.download_url = download_url
if self.token:
self.download_url = self.download_url + f"?token={self.token}"
self.file_details = file
self.file_id = file.get('id')
self.name = file.get('name')
@@ -189,15 +210,12 @@
if hashes:
self.file_sha256 = hashes.get('SHA256')
return self.download_url, self.file_details


else:
raise Exception(f"{ERR_PREFIX}No cached model or model data found, and unable to reach CivitAI! Response Code: {response.status_code}\n Please try again later.")

def download(self):

# DOWNLOAD BYTE CHUNK


def download_chunk(chunk_id, url, chunk_size, start_byte, end_byte, file_path, total_pbar, comfy_pbar, max_retries=30):
retries = 0
retry_delay = 5
@@ -216,7 +234,7 @@ def download_chunk(chunk_id, url, chunk_size, start_byte, end_byte, file_path, t
time.sleep(retry_delay * 10)
total_pbar.set_postfix_str('')

file.seek(start_byte + downloaded_bytes)
file.seek(start_byte + downloaded_bytes)
for chunk in response.iter_content(chunk_size=chunk_size):
file.write(chunk)
total_pbar.update(len(chunk))
@@ -231,9 +249,6 @@ def download_chunk(chunk_id, url, chunk_size, start_byte, end_byte, file_path, t
else:
raise Exception(f"{ERR_PREFIX}Unable to establish download connection.")
except (requests.exceptions.RequestException, Exception) as e:
# We shouldn't warn on chunk loss, since end chunks may not be able to be established due to remaining filesize
#print(f"{WARN_PREFIX}Chunk {chunk_id} connection lost")
total_pbar.update()
time.sleep(retry_delay)
retries += 1
if retries > max_retries:
@@ -242,13 +257,10 @@ def download_chunk(chunk_id, url, chunk_size, start_byte, end_byte, file_path, t

if chunk_complete:
break

if not chunk_complete:
raise Exception(f"{ERR_PREFIX}Unable to re-establish connection to CivitAI.")


# GET FILE SIZE

def get_total_file_size(url):
response = requests.get(url, stream=True)
content_length = response.headers.get('Content-Length')
@@ -260,16 +272,14 @@ def get_total_file_size(url):
if content_range:
total_bytes = int(re.search(r'/(\d+)', content_range).group(1))
return total_bytes

if self.file_size:
return self.file_size

return None

# RESOLVE MODEL ID/VERSION TO FILENAME

model_name = self.model_cached_name(self.model_id, self.version)

if model_name:
model_path = self.model_exists_disk(model_name)
if model_path:
@@ -289,13 +299,9 @@ def get_total_file_size(url):
else:
self.name = self.download_url.split('/')[-1]

# NO MODEL FOUND! | DOWNLOAD MODEL FROM CIVITAI

print(f"{MSG_PREFIX}Downloading `{self.name}` from `{self.download_url}`")
save_path = os.path.join(self.model_path, self.name) # Assume default comfy folder, unless we take user input on extra paths

# EXISTING MODEL FOUND -- CHECK SHA256

save_path = os.path.join(self.model_path, self.name)

if os.path.exists(save_path):
print(f"{MSG_PREFIX}{self.type} file already exists at: {save_path}")
self.dump_file_details()
@@ -306,15 +312,13 @@ def get_total_file_size(url):
else:
print(f"{ERR_PREFIX}Existing {self.type} file's SHA256 does not match. Retrying download...")

# NO MODEL OR MODEL DATA AVAILABLE -- DOWNLOAD MODEL FROM CIVITAI

response = requests.head(self.download_url)
total_file_size = total_file_size = get_total_file_size(self.download_url)
total_file_size = get_total_file_size(self.download_url)

response = requests.get(self.download_url, stream=True)
if response.status_code != requests.codes.ok:
raise Exception(f"{ERR_PREFIX}Failed to download {self.type} file from CivitAI. Status code: {response.status_code}")

with open(save_path, 'wb') as file:
file.seek(total_file_size - 1)
file.write(b'\0')
@@ -334,7 +338,7 @@ def get_total_file_size(url):

for future in futures:
future.result()

total_pbar.close()
model_sha256 = CivitAI_Model.calculate_sha256(save_path)
if model_sha256 == self.file_sha256:
@@ -343,11 +347,9 @@ def get_total_file_size(url):
self.dump_file_details()
return True
else:
os.remove(save_path) # Remove Invalid / Broken / Insecure download file
os.remove(save_path)
raise Exception(f"{ERR_PREFIX}{self.type} file's SHA256 does not match expected value after retry. Aborting download.")

# DUMP MODEL DETAILS TO DOWNLOAD HISTORY


def dump_file_details(self):
history_file_path = os.path.join(ROOT_PATH, 'download_history.json')

@@ -385,8 +387,6 @@ def dump_file_details(self):

with open(history_file_path, 'w', encoding='utf-8') as history_file:
json.dump(download_history, history_file, indent=4, ensure_ascii=False)

# RESOLVE ID/VERSION TO FILENAME

def model_cached_name(self, model_id, version_id):
history_file_path = os.path.join(ROOT_PATH, 'download_history.json')
@@ -398,7 +398,7 @@ def model_cached_name(self, model_id, version_id):
version_id = int(version_id) if version_id else None
if model_id_str in download_history:
file_details_list = download_history[model_id_str]
for file_details in file_details_list:
for file_details in file_details_list:
version = file_details.get('id')
files = file_details.get('files')
if files:
@@ -409,23 +409,18 @@
elif self.model_exists_disk(name):
return name
return None


# CHECK FOR MODEL ON DISK

def model_exists_disk(self, name):
for path in self.model_paths:
if path and name:
full_path = os.path.join(path, name)
if os.path.exists(full_path):
if os.path.exists(full_path):
if os.path.getsize(full_path) <= 0:
os.remove(full_path)
else:
return full_path
return False

# CHECK FOR MODEL ON DISK

def model_path(filename, search_paths):
filename, _ = os.path.splitext(filename)
for path in search_paths:
@@ -438,8 +433,6 @@ def model_path(filename, search_paths):
return os.path.join(root, file)
return None

# CALCULATE SHA256

@staticmethod
def calculate_sha256(file_path):
sha256_hash = hashlib.sha256()
@@ -449,8 +442,7 @@ def calculate_sha256(file_path):
sha256_hash.update(byte_block)
return sha256_hash.hexdigest().upper()
return 0

# STATIC HASH LOOKUP FOR MANUAL LOADING

@staticmethod
def sha256_lookup(file_path):
hash_value = CivitAI_Model.calculate_sha256(file_path)
@@ -502,8 +494,6 @@ def sha256_lookup(file_path):

return (None, None, None)

# STATIC DOWNLOAD HISTORY PUSH

@staticmethod
def push_download_history(model_id, model_type, file_details):
history_file_path = os.path.join(ROOT_PATH, 'download_history.json')
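One behavioral detail visible in the `CivitAI_Model.py` changes above: the `?token=` query string is now appended to the download URL only when a token is actually available (passed to the constructor or read from the `CIVITAI_TOKEN` environment variable), rather than unconditionally. A minimal sketch of that pattern in isolation (the helper name is illustrative, not the PR's code):

```python
import os
from typing import Optional


def with_civitai_token(download_url: str, token: Optional[str] = None) -> str:
    """Append a CivitAI API token to a download URL only when one is available.

    Falls back to the CIVITAI_TOKEN environment variable, mirroring the
    constructor logic in CivitAI_Model; the URL is left untouched when no
    token is set, instead of ending up with '?token=None'.
    """
    token = token or os.environ.get('CIVITAI_TOKEN')
    if token:
        return f"{download_url}?token={token}"
    return download_url


# Example: with CIVITAI_TOKEN unset, the URL passes through unchanged.
# with_civitai_token("https://civitai.com/api/download/models/126688")
```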
2 changes: 1 addition & 1 deletion README.md
@@ -11,7 +11,7 @@ Want to be share a fully reproducable workflow?
<img width="300" src="https://github.com/civitai/comfy-nodes/assets/607609/a581144d-8e6f-4798-96ec-eba92ceef927"/>

### Checkpoint Loader
- Load checkpoints directly from Civitai using just a Model `AIR` (`model id` or `version id`)
- Load checkpoints directly from Civitai using just a Model `AIR` tag (`model id` or `version id`), or by URL, e.g. https://civitai.com/models/112902/dreamshaper-xl?modelVersionId=126688
- Resources used in images will be automatically detected on image upload
- Workflows copied from Civitai or shared via image metadata will include everything needed to generate the image including all resources
