-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
functioning packaging and entry point
- Loading branch information
Showing
8 changed files
with
149 additions
and
86 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from ganslate.engines.trainer import Trainer | ||
from ganslate.engines.validator_tester import Tester | ||
from ganslate.engines.inferer import Inferer | ||
from ganslate.utils import communication, environment | ||
from ganslate.utils.builders import build_conf | ||
|
||
|
||
ENGINES = { | ||
'train': Trainer, | ||
'test': Tester, | ||
'infer': Inferer | ||
} | ||
|
||
def init_engine(mode): | ||
assert mode in ENGINES.keys() | ||
|
||
# inits distributed mode if ran with torch.distributed.launch | ||
communication.init_distributed() | ||
environment.setup_threading() | ||
|
||
conf = build_conf() | ||
return ENGINES[mode](conf) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import sys | ||
from ganslate.engines.utils import init_engine | ||
|
||
|
||
def run_interface(): | ||
Interface().run() | ||
|
||
|
||
class Interface: | ||
def __init__(self): | ||
|
||
self.COMMANDS = { | ||
'help': self._help, | ||
'train': self._train_test_infer, | ||
'test': self._train_test_infer, | ||
'infer': self._train_test_infer, | ||
'new-project': self._new_project, | ||
} | ||
self.mode = sys.argv[1] if len(sys.argv) > 1 else 'help' | ||
|
||
def _help(self): | ||
# TODO: this is a temporary help. find what's the best convention and way to do it | ||
msg = "\nGANSLATE COMMAND LIST\n\n" | ||
msg += "help List all available commands and info on them.\n" | ||
msg += "train Run training.\n" | ||
msg += "test Test a trained model. Possible only with paired dataset.\n" | ||
msg += "infer Perform inference with a trained model.\n" | ||
msg += "new-project Generate a project template. Specify the name and the path where it should be generated\n" | ||
sys.stdout.write(msg + "\n") | ||
|
||
def _train_test_infer(self): | ||
engine = init_engine(self.mode) | ||
engine.run() | ||
|
||
def _new_project(self): | ||
raise NotImplementedError | ||
|
||
def run(self): | ||
if self.mode not in self.COMMANDS.keys(): | ||
sys.stderr.write(f"\nUnknown command `{self.mode}`\n") | ||
self.mode = 'help' | ||
|
||
command = self.COMMANDS[self.mode] | ||
command() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import sys | ||
|
||
from pkg_resources import VersionConflict, require | ||
from setuptools import setup | ||
|
||
try: | ||
require('setuptools>=38.3') | ||
except VersionConflict: | ||
print("Error: version of setuptools is too old (<38.3)!") | ||
sys.exit(1) | ||
|
||
if __name__ == "__main__": | ||
setup() |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.