Skip to content

Commit

Permalink
Merge pull request #88 from jmisilo/jmisilo/fix/model-weights-url-mis…
Browse files Browse the repository at this point in the history
…match

Model weights url mismatch
  • Loading branch information
jmisilo authored Oct 29, 2023
2 parents 8afd2a6 + 4663547 commit 4f74c7a
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/utils/downloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@

import gdown

MODEL_WEIGHTS = {
"L": "1Gh32arzhW06C1ZJyzcJSSfdJDi3RgWoG",
"S": "1pSQruQyg8KJq6VmzhMLFbT_VaHJMdlWF",
}


def download_weights(checkpoint_fpath, model_size="L"):
"""
Downloads weights from Google Drive.
"""

download_id = (
"1pSQruQyg8KJq6VmzhMLFbT_VaHJMdlWF"
if model_size.strip().upper() == "L"
else "1Gh32arzhW06C1ZJyzcJSSfdJDi3RgWoG"
)
download_id = MODEL_WEIGHTS[model_size.strip().upper()]

gdown.download(
f"https://drive.google.com/uc?id={download_id}", checkpoint_fpath, quiet=False
Expand Down

0 comments on commit 4f74c7a

Please sign in to comment.