Skip to content

Commit

Permalink
install nvidia apex via cli
Browse files Browse the repository at this point in the history
ibro45 committed Aug 31, 2021

Verified

This commit was signed with the committer’s verified signature. The key has expired.
Czaki Grzegorz Bokota
1 parent 3718e83 commit a4e374d
Showing 6 changed files with 24 additions and 116 deletions.
11 changes: 0 additions & 11 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -112,18 +112,7 @@ venv.bak/
checkpoints/
\.wandb
\wandb
core.login18*
IGNORE/
output_*
log.txt
pad_crop_info.json

output/

apex
results
*.nrrd
Untitled.ipynb
core.*

projects/maastro_lung_proton_cbct_to_ct/experiments/2d_vnet_local.yaml
97 changes: 0 additions & 97 deletions environment.yml

This file was deleted.

11 changes: 8 additions & 3 deletions ganslate/nn/gans/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from loguru import logger
import os
import sys
from abc import ABC, abstractmethod
from pathlib import Path

@@ -115,9 +116,13 @@ def setup(self):
"The (main) generator has to be named `G` or `G_AB`."

if self.conf[self.conf.mode].mixed_precision:
from apex import amp

# Allow the methods to access AMP that was imported here
try:
from apex import amp
except ModuleNotFoundError:
sys.exit("\nMixed precision not installed! "
"Install Nvidia Apex mixed precision support "
"by running `ganslate install-nvidia-apex`'\n")
# Allow the methods to access AMP that's imported here
globals()["amp"] = amp

# Initialize Generators and Discriminators
16 changes: 16 additions & 0 deletions ganslate/utils/cli/interface.py
Original file line number Diff line number Diff line change
@@ -55,5 +55,21 @@ def download_dataset(name, path):
download_script_path = "ganslate/utils/scripts/download_cyclegan_datasets.sh"
subprocess.call(["bash", download_script_path, name, path])

# Install Nvidia Apex
@interface.command(help="Install Nvidia Apex for mixed precision support.")
@click.option(
"--cpp/--python",
default=True,
help=("C++ support is faster and preferred, use Python fallback "
"only when CUDA is not installed natively.")
)
def install_nvidia_apex(cpp):
# TODO: (Ibro) I need to verify this in a few days when I have access to the GPU
cmd = 'pip install -v --disable-pip-version-check --no-cache-dir'
if cpp:
cmd += ' --global-option="--cpp_ext" --global-option="--cuda_ext"'
cmd += ' git+https://github.com/NVIDIA/apex.git'
subprocess.run(cmd.split(' '))

if __name__ == "__main__":
interface()
3 changes: 0 additions & 3 deletions install_apex.sh

This file was deleted.

2 changes: 0 additions & 2 deletions requirements.txt

This file was deleted.

0 comments on commit a4e374d

Please sign in to comment.