diff --git a/.gitignore b/.gitignore index 1eed167..6ac667c 100644 --- a/.gitignore +++ b/.gitignore @@ -112,18 +112,7 @@ venv.bak/ checkpoints/ \.wandb \wandb -core.login18* -IGNORE/ -output_* -log.txt -pad_crop_info.json - -output/ apex -results *.nrrd -Untitled.ipynb core.* - -projects/maastro_lung_proton_cbct_to_ct/experiments/2d_vnet_local.yaml diff --git a/environment.yml b/environment.yml deleted file mode 100644 index 6eee75e..0000000 --- a/environment.yml +++ /dev/null @@ -1,97 +0,0 @@ -name: gan_env -channels: - - pytorch - - defaults -dependencies: - - _libgcc_mutex=0.1=main - - blas=1.0=mkl - - ca-certificates=2020.7.22=0 - - certifi=2020.6.20=py38_0 - - conda-env=2.6.0=1 - - cudatoolkit=10.2.89=hfd86e86_1 - - freetype=2.10.2=h5ab3b9f_0 - - intel-openmp=2020.2=254 - - jpeg=9b=h024ee3a_2 - - lcms2=2.11=h396b838_0 - - ld_impl_linux-64=2.33.1=h53a641e_7 - - libedit=3.1.20191231=h14c3975_1 - - libffi=3.3=he6710b0_2 - - libgcc-ng=9.1.0=hdf63c60_0 - - libpng=1.6.37=hbc83047_0 - - libstdcxx-ng=9.1.0=hdf63c60_0 - - libtiff=4.1.0=h2733197_1 - - lz4-c=1.9.2=he6710b0_1 - - mkl=2020.2=256 - - mkl-service=2.3.0=py38he904b0f_0 - - mkl_fft=1.2.0=py38h23d657b_0 - - mkl_random=1.1.1=py38h0573a6f_0 - - ncurses=6.2=he6710b0_1 - - ninja=1.10.1=py38hfd86e86_0 - - numpy=1.19.1=py38hbc911f0_0 - - numpy-base=1.19.1=py38hfa32c7d_0 - - olefile=0.46=py_0 - - openssl=1.1.1h=h7b6447c_0 - - pip=20.2.3=py38_0 - - python=3.8.5=h7579374_1 - - pytorch=1.6.0=py3.8_cuda10.2.89_cudnn7.6.5_0 - - readline=8.0=h7b6447c_0 - - setuptools=49.6.0=py38_1 - - six=1.15.0=py_0 - - sqlite=3.33.0=h62c20be_0 - - tk=8.6.10=hbc83047_0 - - torchvision=0.7.0=py38_cu102 - - wheel=0.35.1=py_0 - - xz=5.2.5=h7b6447c_0 - - zlib=1.2.11=h7b6447c_3 - - zstd=1.4.5=h9ceee32_0 - - pip: - - absl-py==0.10.0 - - aiohttp==3.6.2 - - async-timeout==3.0.1 - - attrs==20.2.0 - - cachetools==4.1.1 - - chardet==3.0.4 - - click==7.1.2 - - configparser==5.0.0 - - docker-pycreds==0.4.0 - - future==0.18.2 - - gitdb==4.0.5 - - gitpython==3.1.9 - - google-auth==1.22.0 - - google-auth-oauthlib==0.4.1 - - grpcio==1.32.0 - - idna==2.10 - - markdown==3.2.2 - - memcnn==1.4.0 - - multidict==4.7.6 - - oauthlib==3.1.0 - - omegaconf==2.0.2 - - pathlib2==2.3.5 - - pathtools==0.1.2 - - pillow==6.2.2 - - promise==2.3 - - protobuf==3.13.0 - - psutil==5.7.2 - - pyasn1==0.4.8 - - pyasn1-modules==0.2.8 - - python-dateutil==2.8.1 - - pyyaml==5.3.1 - - requests==2.24.0 - - requests-oauthlib==1.3.0 - - rsa==4.6 - - sentry-sdk==0.18.0 - - shortuuid==1.0.1 - - simpleitk==2.0.0 - - smmap==3.0.4 - - subprocess32==3.5.4 - - tensorboard==2.3.0 - - tensorboard-plugin-wit==1.7.0 - - tqdm==4.50.0 - - typing-extensions==3.7.4.3 - - urllib3==1.25.10 - - wandb==0.10.4 - - watchdog==0.10.3 - - werkzeug==1.0.1 - - yarl==1.6.0 -prefix: /workspace/miniconda3/envs/gan_env - diff --git a/ganslate/nn/gans/base.py b/ganslate/nn/gans/base.py index f60d4c4..624c583 100644 --- a/ganslate/nn/gans/base.py +++ b/ganslate/nn/gans/base.py @@ -1,5 +1,6 @@ from loguru import logger import os +import sys from abc import ABC, abstractmethod from pathlib import Path @@ -115,9 +116,13 @@ def setup(self): "The (main) generator has to be named `G` or `G_AB`." if self.conf[self.conf.mode].mixed_precision: - from apex import amp - - # Allow the methods to access AMP that was imported here + try: + from apex import amp + except ModuleNotFoundError: + sys.exit("\nMixed precision not installed! " + "Install Nvidia Apex mixed precision support " + "by running `ganslate install-nvidia-apex`'\n") + # Allow the methods to access AMP that's imported here globals()["amp"] = amp # Initialize Generators and Discriminators diff --git a/ganslate/utils/cli/interface.py b/ganslate/utils/cli/interface.py index 08b3630..933ac5f 100644 --- a/ganslate/utils/cli/interface.py +++ b/ganslate/utils/cli/interface.py @@ -55,5 +55,21 @@ def download_dataset(name, path): download_script_path = "ganslate/utils/scripts/download_cyclegan_datasets.sh" subprocess.call(["bash", download_script_path, name, path]) +# Install Nvidia Apex +@interface.command(help="Install Nvidia Apex for mixed precision support.") +@click.option( + "--cpp/--python", + default=True, + help=("C++ support is faster and preferred, use Python fallback " + "only when CUDA is not installed natively.") +) +def install_nvidia_apex(cpp): + # TODO: (Ibro) I need to verify this in a few days when I have access to the GPU + cmd = 'pip install -v --disable-pip-version-check --no-cache-dir' + if cpp: + cmd += ' --global-option="--cpp_ext" --global-option="--cuda_ext"' + cmd += ' git+https://github.com/NVIDIA/apex.git' + subprocess.run(cmd.split(' ')) + if __name__ == "__main__": interface() \ No newline at end of file diff --git a/install_apex.sh b/install_apex.sh deleted file mode 100755 index 7811dcc..0000000 --- a/install_apex.sh +++ /dev/null @@ -1,3 +0,0 @@ -git clone https://github.com/NVIDIA/apex -cd apex -pip install -v --no-cache-dir ./ diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 9a8a4ca..0000000 --- a/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -mkdocs -mkdocs-material