Skip to content

Commit

Permalink
Merge pull request #136 from mckayward/fix/data-mount-shortnames
Browse files Browse the repository at this point in the history
skip data config check in dataset name resolution in job-related commands
  • Loading branch information
narenst authored Nov 20, 2017
2 parents 64cd83b + 8037df4 commit 19bac48
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 17 deletions.
2 changes: 1 addition & 1 deletion floyd/cli/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def clone(id):
Download the code for the job to the current path
"""

data_source = DataClient().get(normalize_data_name(id))
data_source = DataClient().get(normalize_data_name(id, use_data_config=False))
if id and not data_source:
# Try with the raw ID
data_source = DataClient().get(id)
Expand Down
4 changes: 2 additions & 2 deletions floyd/cli/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ def clone(id):
Download the code for the experiment to the current path
"""
try:
experiment = ExperimentClient().get(normalize_job_name(id))
experiment = ExperimentClient().get(normalize_job_name(id, use_config=False))
except FloydException:
experiment = ExperimentClient().get(id)
experiment = ExperimentClient().get(id, use_config=False)

task_instance_id = get_module_task_instance_id(experiment.task_instances)
task_instance = TaskInstanceClient().get(task_instance_id) if task_instance_id else None
Expand Down
8 changes: 4 additions & 4 deletions floyd/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from floyd.client.data import DataClient
from floyd.client.project import ProjectClient
from floyd.cli.utils import (
get_mode_parameter, get_data_name, normalize_job_name
get_mode_parameter, get_data_name, normalize_data_name, normalize_job_name
)
from floyd.client.experiment import ExperimentClient
from floyd.client.module import ModuleClient
Expand Down Expand Up @@ -46,9 +46,9 @@ def process_data_ids(data):
path = None
if ':' in data_name_or_id:
data_name_or_id, path = data_name_or_id.split(':')
data_name_or_id = data_name_or_id
data_name_or_id = normalize_data_name(data_name_or_id, use_data_config=False)

data_obj = DataClient().get(data_name_or_id)
data_obj = DataClient().get(normalize_data_name(data_name_or_id, use_data_config=False))

if not data_obj:
# Try with the raw ID
Expand Down Expand Up @@ -244,7 +244,7 @@ def get_command_line(instance_type, env, message, data, mode, open_notebook, ten
parts = data_item.split(':')

if len(parts) > 1:
data_item = parts[0] + ':' + parts[1]
data_item = normalize_data_name(parts[0], use_data_config=False) + ':' + parts[1]

floyd_command += ["--data", data_item]
if tensorboard:
Expand Down
23 changes: 17 additions & 6 deletions floyd/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,18 @@ def get_data_id(data_str):
return data_str


def normalize_data_name(raw_name, default_username=None, default_dataset_name=None):
def normalize_data_name(raw_name, default_username='', default_dataset_name='', use_data_config=True):
raw_name = raw_name or ''
if use_data_config:
default_dataset_name = default_dataset_name or current_dataset_name()

if raw_name.endswith('/output'):
return normalize_job_name(raw_name[:-len('/output')], default_username, default_dataset_name) + '/output'

name_parts = raw_name.split('/')

username = default_username or current_username()
name = default_dataset_name or current_dataset_name()
name = default_dataset_name
number = None # current version number

# When nothing is passed, use all the defaults
Expand Down Expand Up @@ -108,17 +110,23 @@ def normalize_data_name(raw_name, default_username=None, default_dataset_name=No
if number is not None:
name_parts.append(number)

if not name:
raise FloydException('Dataset name resolution: Could not infer a name from "%s". Please include a name to identify the dataset' % raw_name)

return '/'.join(name_parts)


def normalize_job_name(raw_job_name, default_username=None, default_project_name=None):
def normalize_job_name(raw_job_name, default_username='', default_project_name='', use_config=True):
raw_job_name = raw_job_name or ''

if use_config:
default_project_name = default_project_name or current_experiment_name()

name_parts = raw_job_name.split('/')

username = default_username or current_username()
project_name = default_project_name or current_experiment_name()
number = None # current job number
project_name = default_project_name
number = '' # current job number

# When nothing is passed, use all the defaults
if not raw_job_name:
Expand Down Expand Up @@ -153,12 +161,15 @@ def normalize_job_name(raw_job_name, default_username=None, default_project_name
return raw_job_name

# If no number is found, query the API for the most recent job number
if number is None:
if not number:
job_name_from_api = get_latest_job_name_for_project(username, project_name)
if not job_name_from_api:
raise FloydException("Could not resolve %s. Make sure the project exists and has jobs." % raw_job_name)
return job_name_from_api

if not project_name:
raise FloydException('Job name resolution: Could not infer a project name from "%s". Please include a name to identify the project' % raw_job_name)

return '/'.join([username, 'projects', project_name, number])


Expand Down
9 changes: 5 additions & 4 deletions tests/cli/run/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,13 @@ def test_with_multiple_data_ids(self,
result = self.runner.invoke(run, ['command', '--data', 'data-id1', '--data', 'data-id2'], catch_exceptions=False)
assert(result.exit_code == 0)

def test_get_command_line(self):
@patch('floyd.cli.run.normalize_data_name', return_value='mckay/datasets/foo/1')
def test_get_command_line(self, _):
re = get_command_line(
instance_type='g1p',
env='pytorch-2.0:py2',
message='test\' message',
data=['mckay/datasets/foo/1:input'],
data=['foo:input'],
mode='job',
open_notebook=False,
tensorboard=True,
Expand All @@ -79,7 +80,7 @@ def test_get_command_line(self):
instance_type='c1',
env='tensorflow',
message=None,
data=['mckay/datasets/foo/1:input', 'bar'],
data=['foo:input', 'bar'],
mode='jupyter',
open_notebook=True,
tensorboard=False,
Expand All @@ -91,7 +92,7 @@ def test_get_command_line(self):
instance_type='g1',
env='tensorflow',
message=None,
data=['mckay/datasets/foo/1:input'],
data=['foo:input'],
mode='job',
open_notebook=False,
tensorboard=True,
Expand Down

0 comments on commit 19bac48

Please sign in to comment.