Merge pull request #14 from johnbradley/12-dataverse-models
Download models from Imageomics Dataverse
thibaulttabarin authored Sep 22, 2022
2 parents 5dc0ab3 + a637cea commit 68cc0f4
Showing 4 changed files with 87 additions and 6 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/deploy-image.yml
@@ -3,7 +3,6 @@ name: Create and publish a Docker image
on:
  release:
    types: [published]

env:
  REGISTRY: ghcr.io
  IMAGE_NAME: ${{ github.repository }}
@@ -36,6 +35,8 @@ jobs:
        uses: docker/build-push-action@ad44023a93711e3deb337508980b4b5e9bcdc5dc
        with:
          context: Segment_mini
          build-args: |
            DATAVERSE_API_TOKEN=${{ secrets.DATAVERSE_API_TOKEN }}
          push: true
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
15 changes: 11 additions & 4 deletions Segment_mini/Dockerfile
@@ -1,4 +1,5 @@
FROM ubuntu:20.04
ARG DATAVERSE_API_TOKEN

# Label
LABEL org.opencontainers.image.title="fish trait segmentation"
@@ -45,10 +46,16 @@ RUN curl -sLo ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-py38
WORKDIR /pipeline
ENV TORCH_HOME=/pipeline/.cache/torch/

# Download Maruf Model
RUN gdown -O /pipeline/saved_models/ https://drive.google.com/uc?id=1HBSGXbWw5Vorj82buF-gCi6S2DpF4mFL
# Download pretrained model; this pre-builds the torch cache instead of fetching it at container runtime
RUN wget -c --no-check-certificate -P /pipeline/.cache/torch/hub/checkpoints http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth
# Download Maruf Model from Imageomics Dataverse

ENV DATAVERSE_URL=https://datacommons.tdai.osu.edu
ADD scripts/dataverse_download.py /pipeline/dataverse_download.py

# Download Trained_model_SM.pth
RUN python /pipeline/dataverse_download.py ${DATAVERSE_URL} doi:10.5072/FK2/SWV0YL Trained_model_SM.pth saved_models/Trained_model_SM.pth

# Download se_resnext50_32x4d-a260b3a4.pth (dependency of Trained_model_SM.pth)
RUN python /pipeline/dataverse_download.py ${DATAVERSE_URL} doi:10.5072/FK2/CGWDW4 se_resnext50_32x4d-a260b3a4.pth .cache/torch/hub/checkpoints/se_resnext50_32x4d-a260b3a4.pth

# Setup pipeline specific scripts
ENV PATH="/pipeline:${PATH}"
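For reference, the two RUN lines above just drive the new helper script shown in full further down. A minimal sketch of performing the same downloads directly from Python, assuming Segment_mini/scripts/dataverse_download.py is importable; the URL, DOIs, and file paths are copied verbatim from the Dockerfile, and the token may be None for a publicly readable dataset:

    # Equivalent of the two RUN download steps above, called as a library rather than a CLI.
    import os
    from dataverse_download import download_file_in_dataset

    base_url = "https://datacommons.tdai.osu.edu"
    token = os.environ.get("DATAVERSE_API_TOKEN")  # optional for public datasets

    # Trained segmentation model
    download_file_in_dataset(base_url, token, "doi:10.5072/FK2/SWV0YL",
                             "Trained_model_SM.pth", "saved_models/Trained_model_SM.pth")

    # Backbone weights that the trained model depends on
    download_file_in_dataset(base_url, token, "doi:10.5072/FK2/CGWDW4",
                             "se_resnext50_32x4d-a260b3a4.pth",
                             ".cache/torch/hub/checkpoints/se_resnext50_32x4d-a260b3a4.pth")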
2 changes: 1 addition & 1 deletion Segment_mini/env_segment_mini.yml
@@ -31,7 +31,7 @@ dependencies:
- certifi==2021.10.8
- efficientnet-pytorch==0.6.3
- filelock==3.6.0
- gdown==4.4.0
- pyDataverse==0.3.1
- munch==2.5.0
- packaging==21.3
- pretrainedmodels==0.7.4
73 changes: 73 additions & 0 deletions Segment_mini/scripts/dataverse_download.py
@@ -0,0 +1,73 @@
# Script to download a dataset from a Dataverse (https://dataverse.org/)
import os
import sys
import hashlib
from pyDataverse.api import NativeApi, DataAccessApi


def download_file_in_dataset(base_url, api_token, doi, src, dest):
    # Find the file whose path within the dataset matches src, download it to dest,
    # and verify its checksum.
    api = NativeApi(base_url, api_token)
    data_api = DataAccessApi(base_url, api_token)
    dataset = api.get_dataset(doi)
    files_list = dataset.json()['data']['latestVersion']['files']
    for dv_file in files_list:
        remote_path = get_directory_path(dv_file)
        if remote_path == src:
            os.makedirs(os.path.dirname(dest), exist_ok=True)
            download_file(data_api, dv_file, dest)
            verify_checksum(dv_file, dest)
            return
    raise ValueError(f"Unable to find path {src} within {doi}.")


def get_directory_path(dv_file):
    directory_label = dv_file.get("directoryLabel")
    filename = dv_file["dataFile"]["filename"]
    if directory_label:
        return f"{directory_label}/{filename}"
    return filename


def download_file(data_api, dv_file, filepath):
    file_id = dv_file["dataFile"]["id"]
    print("Downloading file {}, id {}".format(filepath, file_id))
    response = data_api.get_datafile(file_id)
    with open(filepath, "wb") as f:
        f.write(response.content)
    return filepath


def verify_checksum(dv_file, filepath):
    checksum = dv_file["dataFile"]["checksum"]
    checksum_type = checksum["type"]
    checksum_value = checksum["value"]
    if checksum_type != "MD5":
        raise ValueError(f"Unsupported checksum type {checksum_type}")

    with open(filepath, 'rb') as infile:
        hash = hashlib.md5(infile.read()).hexdigest()
    if checksum_value == hash:
        print(f"Verified file checksum for {filepath}.")
    else:
        raise ValueError(f"Hash value mismatch for {filepath}: {checksum_value} vs {hash}")


def show_usage():
    print()
    print(f"Usage: python {sys.argv[0]} <dataverse_base_url> <doi> <source_path> <destination_path>\n")
    print("To specify an API token set the DATAVERSE_API_TOKEN environment variable.")
    print()


if __name__ == '__main__':
    if len(sys.argv) != 5:
        show_usage()
        sys.exit(1)
    else:
        base_url = sys.argv[1]
        doi = sys.argv[2]
        source = sys.argv[3]
        dest = sys.argv[4]
        api_token = os.environ.get('DATAVERSE_API_TOKEN')
        download_file_in_dataset(base_url, api_token, doi, source, dest)
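As a quick follow-up check, the MD5 logic in verify_checksum() can be reproduced by hand against the checksum shown on the file's Dataverse page; a small sketch, assuming the model was downloaded to the path used in the Dockerfile:

    # Recompute the MD5 of a downloaded file; mirrors verify_checksum() above.
    import hashlib

    with open("saved_models/Trained_model_SM.pth", "rb") as f:
        print(hashlib.md5(f.read()).hexdigest())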
