From 352626aee62b69922271d9b618fa1179f3c57428 Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Sun, 29 Aug 2021 21:30:56 +0200 Subject: [PATCH] functioning packaging and entry point --- ganslate/engines/utils.py | 22 +++++++++++++ ganslate/utils/builders.py | 11 +++++-- ganslate/utils/interface.py | 44 ++++++++++++++++++++++++++ setup.cfg | 62 +++++++++++++++++++++++++++++++++++++ setup.py | 13 ++++++++ tools/infer.py | 26 ---------------- tools/test.py | 29 ----------------- tools/train.py | 28 ----------------- 8 files changed, 149 insertions(+), 86 deletions(-) create mode 100644 ganslate/engines/utils.py create mode 100644 ganslate/utils/interface.py create mode 100644 setup.py delete mode 100644 tools/infer.py delete mode 100644 tools/test.py delete mode 100644 tools/train.py diff --git a/ganslate/engines/utils.py b/ganslate/engines/utils.py new file mode 100644 index 00000000..973c3577 --- /dev/null +++ b/ganslate/engines/utils.py @@ -0,0 +1,22 @@ +from ganslate.engines.trainer import Trainer +from ganslate.engines.validator_tester import Tester +from ganslate.engines.inferer import Inferer +from ganslate.utils import communication, environment +from ganslate.utils.builders import build_conf + + +ENGINES = { + 'train': Trainer, + 'test': Tester, + 'infer': Inferer +} + +def init_engine(mode): + assert mode in ENGINES.keys() + + # inits distributed mode if ran with torch.distributed.launch + communication.init_distributed() + environment.setup_threading() + + conf = build_conf() + return ENGINES[mode](conf) diff --git a/ganslate/utils/builders.py b/ganslate/utils/builders.py index d87fd64e..a6d717ce 100644 --- a/ganslate/utils/builders.py +++ b/ganslate/utils/builders.py @@ -1,4 +1,5 @@ import copy +import sys import torch from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler @@ -13,9 +14,13 @@ def build_conf(): - cli = omegaconf.OmegaConf.from_cli() - conf = init_config(cli.pop("config"), config_class=Config) - return omegaconf.OmegaConf.merge(conf, cli) + cli = omegaconf.OmegaConf.from_dotlist(sys.argv[2:]) + + yaml_conf = cli.pop("config") + cli_conf_overrides = cli + + conf = init_config(yaml_conf, config_class=Config) + return omegaconf.OmegaConf.merge(conf, cli_conf_overrides) def build_loader(conf): diff --git a/ganslate/utils/interface.py b/ganslate/utils/interface.py new file mode 100644 index 00000000..7dfe9c7a --- /dev/null +++ b/ganslate/utils/interface.py @@ -0,0 +1,44 @@ +import sys +from ganslate.engines.utils import init_engine + + +def run_interface(): + Interface().run() + + +class Interface: + def __init__(self): + + self.COMMANDS = { + 'help': self._help, + 'train': self._train_test_infer, + 'test': self._train_test_infer, + 'infer': self._train_test_infer, + 'new-project': self._new_project, + } + self.mode = sys.argv[1] if len(sys.argv) > 1 else 'help' + + def _help(self): + # TODO: this is a temporary help. find what's the best convention and way to do it + msg = "\nGANSLATE COMMAND LIST\n\n" + msg += "help List all available commands and info on them.\n" + msg += "train Run training.\n" + msg += "test Test a trained model. Possible only with paired dataset.\n" + msg += "infer Perform inference with a trained model.\n" + msg += "new-project Generate a project template. Specify the name and the path where it should be generated\n" + sys.stdout.write(msg + "\n") + + def _train_test_infer(self): + engine = init_engine(self.mode) + engine.run() + + def _new_project(self): + raise NotImplementedError + + def run(self): + if self.mode not in self.COMMANDS.keys(): + sys.stderr.write(f"\nUnknown command `{self.mode}`\n") + self.mode = 'help' + + command = self.COMMANDS[self.mode] + command() diff --git a/setup.cfg b/setup.cfg index c0d3207a..718fb386 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,3 +1,65 @@ +[metadata] +name = ganslate +version = 0.1.0 +author = "ganslate team" +# author-email = +url = https://github.com/Maastro-CDS-Imaging-Group/ganslate +description = Add a short description here! +long-description = file: README.rst +long-description-content-type = text/markdown +license = mit +platforms = any +keywords = one, two +classifiers = + Development Status :: 2 - Pre-Alpha, + Intended Audience :: Science/Research, + Natural Language :: English, + Operating System :: OS Independent, + Programming Language :: Python :: 3.6, + Programming Language :: Python :: 3.7, + Programming Language :: Python :: 3.8, + Programming Language :: Python :: 3.9, +project-urls = + Documentation = https://ganslate.netlify.app/ # TODO: Update + +[options] +zip_safe = False +packages = find: +include_package_data = True +#setup_requires = +# Add here dependencies of your project (semicolon/line-separated), e.g. +install_requires = + torch + opencv-python + simpleitk + opencv-python + memcnn + loguru + wandb + tensorboard + +# The usage of test_requires is discouraged, see `Dependency Management` docs +# tests_require = pytest; pytest-cov +# Require a specific Python version, e.g. Python 2.7 or >= 3.4 +# python_requires = >=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.* + +[options.packages.find] +exclude = + tests + +[options.extras_require] +# Add here additional requirements for extra features, to install with: +# `pip install clinical-evaluation[PDF]` like: +# PDF = ReportLab; RXP +# Add here test requirements (semicolon/line-separated) +testing = + pytest + pytest-cov + +[options.entry_points] +console_scripts = + ganslate = ganslate.utils.interface:run_interface + [yapf] based_on_style = google column_limit = 100 diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..0ac365cb --- /dev/null +++ b/setup.py @@ -0,0 +1,13 @@ +import sys + +from pkg_resources import VersionConflict, require +from setuptools import setup + +try: + require('setuptools>=38.3') +except VersionConflict: + print("Error: version of setuptools is too old (<38.3)!") + sys.exit(1) + +if __name__ == "__main__": + setup() \ No newline at end of file diff --git a/tools/infer.py b/tools/infer.py deleted file mode 100644 index 9784f45c..00000000 --- a/tools/infer.py +++ /dev/null @@ -1,26 +0,0 @@ -from loguru import logger -import sys - -try: - import ganslate -except ImportError: - logger.warning("ganslate not installed as a package, importing it from the local directory.") - sys.path.append('./') - import ganslate - -from ganslate.engines.inferer import Inferer -from ganslate.utils import communication -from ganslate.utils.builders import build_conf - - -def main(): - # inits distributed mode if ran with torch.distributed.launch - communication.init_distributed() - - conf = build_conf() - inferer = Inferer(conf) - inferer.run() - - -if __name__ == '__main__': - main() diff --git a/tools/test.py b/tools/test.py deleted file mode 100644 index 9bcc54c7..00000000 --- a/tools/test.py +++ /dev/null @@ -1,29 +0,0 @@ -from loguru import logger -import sys - -try: - import ganslate -except ImportError: - logger.warning("ganslate not installed as a package, importing it from the local directory.") - sys.path.append('./') - import ganslate - -#from ganslate import configs -from ganslate.utils.builders import build_conf, build_gan -from ganslate.engines.validator_tester import Tester -from ganslate.utils import communication, environment - - -def main(): - # inits distributed mode if ran with torch.distributed.launch - communication.init_distributed() - environment.setup_threading() - - conf = build_conf() - - tester = Tester(conf) - tester.run() - - -if __name__ == '__main__': - main() diff --git a/tools/train.py b/tools/train.py deleted file mode 100644 index 7e0cf2f3..00000000 --- a/tools/train.py +++ /dev/null @@ -1,28 +0,0 @@ -from loguru import logger -import sys - -try: - import ganslate -except ImportError: - logger.warning("ganslate not installed as a package, importing it from the local directory.") - sys.path.append('./') - import ganslate - -from ganslate.engines.trainer import Trainer -from ganslate.utils import communication, environment -from ganslate.utils.builders import build_conf - - -def main(): - # inits distributed mode if ran with torch.distributed.launch - communication.init_distributed() - environment.setup_threading() - - conf = build_conf() - - trainer = Trainer(conf) - trainer.run() - - -if __name__ == '__main__': - main()