From a189729e52398678cacc8002d6e54d08a213e4c5 Mon Sep 17 00:00:00 2001 From: Aidan Gomez Date: Thu, 19 Mar 2020 18:28:19 -0400 Subject: [PATCH] Update formatting and support multiprocessing (#30) --- cloud/__init__.py | 2 + cloud/cloud.py | 44 ++-- cloud/envs/aws.py | 45 ++-- cloud/envs/azure.py | 47 ++-- cloud/envs/env.py | 216 +++++++++-------- cloud/envs/gcp.py | 524 +++++++++++++++++++---------------------- cloud/envs/registry.py | 14 +- cloud/envs/utils.py | 132 +++++------ setup.py | 4 +- 9 files changed, 497 insertions(+), 531 deletions(-) diff --git a/cloud/__init__.py b/cloud/__init__.py index 9a7aca4..bf28938 100644 --- a/cloud/__init__.py +++ b/cloud/__init__.py @@ -11,6 +11,8 @@ logging.getLogger('googleapiclient.discovery_cache').setLevel(logging.ERROR) logging.getLogger('googleapiclient.discovery').setLevel(logging.CRITICAL) +socket_path = None + from cloud.envs import registry from cloud.envs import env diff --git a/cloud/cloud.py b/cloud/cloud.py index 80903c3..f19de30 100644 --- a/cloud/cloud.py +++ b/cloud/cloud.py @@ -1,9 +1,11 @@ import atexit import logging +import os -import cloud import libcloud import toml + +import cloud from cloud import Instance from cloud import registry as reg from cloud.envs import utils @@ -11,33 +13,39 @@ logger = logging.getLogger(__name__) -def connect(): - config_filepath = utils.config_path() - if config_filepath is None: - logger.error( - "ASSUMING LOCAL: Configuration file not found in any of the locations" - " below.\n See github.com/for-ai/cloud/tree/master/configs for " - "example configurations to fill in and use; copy and place in a file " - "named `cloud.toml` at `/cloud.toml` or `$HOME/cloud.toml`.") - return +def connect(socket_path=None): + config_filepath = utils.config_path() + if config_filepath is None: + logger.error("ASSUMING LOCAL: Configuration file not found in any of the locations" + " below.\n See github.com/for-ai/cloud/tree/master/configs for " + "example configurations to fill in and use; copy and place in a file " + "named `cloud.toml` at `/cloud.toml` or `$HOME/cloud.toml`.") + return + + with open(config_filepath, "r") as cf: + config = toml.load(cf) + provider = config.pop("provider").lower() + if socket_path: + _set_socket_path(socket_path) + cloud.instance = reg.retrieve(provider, config=config) + - with open(config_filepath, "r") as cf: - config = toml.load(cf) - provider = config.pop("provider").lower() - cloud.instance = reg.retrieve(provider, config=config) +def _set_socket_path(local_socket_path): + os.makedirs(os.path.dirname(local_socket_path), exist_ok=True) + cloud.socket_path = local_socket_path def close(): - utils.kill_transport() - utils.kill_server() + utils.kill_transport() + utils.kill_server() def down(): - cloud.instance.down() + cloud.instance.down() def delete(confirm=True): - cloud.instance.delete(confirm) + cloud.instance.delete(confirm) atexit.register(close) diff --git a/cloud/envs/aws.py b/cloud/envs/aws.py index 5567982..91cc27b 100644 --- a/cloud/envs/aws.py +++ b/cloud/envs/aws.py @@ -1,34 +1,31 @@ import os -import requests - -from cloud.envs import env -from cloud.envs import registry -from cloud.envs import utils -from libcloud.compute.types import Provider +import requests from libcloud.compute.providers import get_driver +from libcloud.compute.types import Provider + +from cloud.envs import env, registry, utils @registry.register("aws") class AWSInstance(env.Instance): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.access_key = config["access_key"] - self.secret_key = config["secret_key"] - self.region = config["region"] + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.access_key = config["access_key"] + self.secret_key = config["secret_key"] + self.region = config["region"] - @property - def driver(self): - if getattr(self, '_driver', None) is None: - self._driver = get_driver(Provider.EC2)( - self.access_key, self.secret_key, region=self.region) - return self._driver + @property + def driver(self): + if getattr(self, '_driver', None) is None: + self._driver = get_driver(Provider.EC2)(self.access_key, self.secret_key, region=self.region) + return self._driver - @property - def name(self): - if getattr(self, '_name', None) is None: - # https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html - r = requests.get("http://169.254.169.254/latest/meta-data/instance-id") - self._name = r.text - return self._name + @property + def name(self): + if getattr(self, '_name', None) is None: + # https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html + r = requests.get("http://169.254.169.254/latest/meta-data/instance-id") + self._name = r.text + return self._name diff --git a/cloud/envs/azure.py b/cloud/envs/azure.py index feb1e79..6331e25 100644 --- a/cloud/envs/azure.py +++ b/cloud/envs/azure.py @@ -1,33 +1,30 @@ -from cloud.envs import env -from cloud.envs import registry -from cloud.envs import utils - -from libcloud.compute.types import Provider from libcloud.compute.providers import get_driver +from libcloud.compute.types import Provider + +from cloud.envs import env, registry, utils @registry.register("azure") class AzureInstance(env.Instance): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.application_id = config["application_id"] - self.subscription_id = config["subscription_id"] - self.tenant_id = config["tenant_id"] - self.key = config["key"] + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.application_id = config["application_id"] + self.subscription_id = config["subscription_id"] + self.tenant_id = config["tenant_id"] + self.key = config["key"] - @property - def driver(self): - if getattr(self, '_driver', None) is None: - self._driver = get_driver(Provider.AZURE_ARM)( - tenant_id=self.tenant_id, - subscription_id=self.subscription_id, - key=self.application_id, - secret=self.key) - return self._driver + @property + def driver(self): + if getattr(self, '_driver', None) is None: + self._driver = get_driver(Provider.AZURE_ARM)(tenant_id=self.tenant_id, + subscription_id=self.subscription_id, + key=self.application_id, + secret=self.key) + return self._driver - @property - def name(self): - if getattr(self, '_name', None) is None: - self._name = utils.call(["hostname"])[1].strip() - return self._name \ No newline at end of file + @property + def name(self): + if getattr(self, '_name', None) is None: + self._name = utils.call(["hostname"])[1].strip() + return self._name diff --git a/cloud/envs/env.py b/cloud/envs/env.py index c018401..cc8126b 100644 --- a/cloud/envs/env.py +++ b/cloud/envs/env.py @@ -1,7 +1,7 @@ import logging -import traceback import sys import time +import traceback from cloud.envs import utils @@ -10,139 +10,135 @@ class Resource(object): - def __init__(self, manager=None): - super().__init__() - self.manager = manager + def __init__(self, manager=None): + super().__init__() + self.manager = manager - @property - def name(self): - raise NotImplementedError + @property + def name(self): + raise NotImplementedError - @property - def usable(self): - return True + @property + def usable(self): + return True - def up(self, background=False): - raise NotImplementedError + def up(self, background=False): + raise NotImplementedError - def down(self, background=True): - raise NotImplementedError + def down(self, background=True): + raise NotImplementedError - def delete(self, background=True): - if self.manager: - self.manager.remove(self) + def delete(self, background=True): + if self.manager: + self.manager.remove(self) class Instance(Resource): - def __init__(self, manager=None, **kwargs): - super().__init__(manager=manager) - self.resource_managers = [] + def __init__(self, manager=None, **kwargs): + super().__init__(manager=manager) + self.resource_managers = [] - assert utils.get_server().is_alive() + assert utils.get_server().is_alive() - def _kill_command_server(self): - utils.kill_transport() - utils.kill_server() + def _kill_command_server(self): + utils.kill_transport() + utils.kill_server() - @property - def driver(self): - raise NotImplementedError + @property + def driver(self): + raise NotImplementedError - @property - def node(self): - if getattr(self, '_node', None) is None: - nodes = self.driver.list_nodes() - if len(nodes) == 0: - raise Exception( - "list_nodes returned an empty list, did you set up your cloud permissions correctly?" - ) + @property + def node(self): + if getattr(self, '_node', None) is None: + nodes = self.driver.list_nodes() + if len(nodes) == 0: + raise Exception("list_nodes returned an empty list, did you set up your cloud permissions correctly?") - for n in nodes: - if n.name == self.name: - self._node = n + for n in nodes: + if n.name == self.name: + self._node = n - # node not found - if self._node is None: - raise Exception( - "current node could not be found - name: {} node_list: {}".format( - self.name, nodes)) + # node not found + if self._node is None: + raise Exception("current node could not be found - name: {} node_list: {}".format(self.name, nodes)) - return self._node + return self._node - def clean(self, background=True, delete_resources=True): - for rm in self.resource_managers: - if delete_resources: - rm.delete(background=background) - else: - rm.down(background=background) + def clean(self, background=True, delete_resources=True): + for rm in self.resource_managers: + if delete_resources: + rm.delete(background=background) + else: + rm.down(background=background) - self._kill_command_server() + self._kill_command_server() - def down(self, background=True, delete_resources=True): - self.clean(background=background, delete_resources=delete_resources) - self.driver.ex_stop_node(self.node) + def down(self, background=True, delete_resources=True): + self.clean(background=background, delete_resources=delete_resources) + self.driver.ex_stop_node(self.node) - def delete(self, background=True, confirm=True): - while confirm: - r = input("Are you sure you wish to delete this instance (y/[n]): ") + def delete(self, background=True, confirm=True): + while confirm: + r = input("Are you sure you wish to delete this instance (y/[n]): ") - if r == "y": - break - elif r in ["n", ""]: - logger.info("Aborting deletion...") - return + if r == "y": + break + elif r in ["n", ""]: + logger.info("Aborting deletion...") + return - super().delete(background=background) + super().delete(background=background) - self.clean(background=background, delete_resources=True) - self.driver.destroy_node(self.node, destroy_boot_disk=True) + self.clean(background=background, delete_resources=True) + self.driver.destroy_node(self.node, destroy_boot_disk=True) class ResourceManager(object): - def __init__(self, instance, resource_cls): - super().__init__() - self.instance = instance - self.resource_cls = resource_cls - self.resources = [] - - def __getitem__(self, idx): - return self.resources[idx] - - def add(self, *args, **kwargs): - if len(args) == 1: - arg = args[0] - if isinstance(arg, self.resource_cls): - self.resources += [arg] - return arg - - raise NotImplementedError - - def remove(self, *args, **kwargs): - if len(args) == 1: - arg = args[0] - if isinstance(arg, self.resource_cls): - self.resources.remove(arg) - return - - raise NotImplementedError - - def up(self, background=False): - raise NotImplementedError - - def down(self, background=True): - for r in self.resources: - try: - r.down(background=background) - except Exception as e: - logger.error("Failed to shutdown resource: %s" % r) - logger.error(traceback.format_exc()) - - def delete(self, background=True): - for r in self.resources: - try: - r.delete(background=background) - except Exception as e: - logger.error("Failed to delete resource: %s" % r) - logger.error(traceback.format_exc()) + def __init__(self, instance, resource_cls): + super().__init__() + self.instance = instance + self.resource_cls = resource_cls + self.resources = [] + + def __getitem__(self, idx): + return self.resources[idx] + + def add(self, *args, **kwargs): + if len(args) == 1: + arg = args[0] + if isinstance(arg, self.resource_cls): + self.resources += [arg] + return arg + + raise NotImplementedError + + def remove(self, *args, **kwargs): + if len(args) == 1: + arg = args[0] + if isinstance(arg, self.resource_cls): + self.resources.remove(arg) + return + + raise NotImplementedError + + def up(self, background=False): + raise NotImplementedError + + def down(self, background=True): + for r in self.resources: + try: + r.down(background=background) + except Exception as e: + logger.error("Failed to shutdown resource: %s" % r) + logger.error(traceback.format_exc()) + + def delete(self, background=True): + for r in self.resources: + try: + r.delete(background=background) + except Exception as e: + logger.error("Failed to delete resource: %s" % r) + logger.error(traceback.format_exc()) diff --git a/cloud/envs/gcp.py b/cloud/envs/gcp.py index dd4675d..585a04b 100644 --- a/cloud/envs/gcp.py +++ b/cloud/envs/gcp.py @@ -18,291 +18,257 @@ @registry.register("gcp") class GCPInstance(env.Instance): - def __init__(self, collect_existing_tpus=True, **kwargs): - super().__init__(**kwargs) - - # Check for dependencies - try: - utils.call(["gcloud", "--version"]) - except Exception as e: - raise(e) - - self.tpu = TPUManager(self) - self.resource_managers = [self.tpu] - - @property - def driver(self): - if getattr(self, '_driver', None) is None: - r = requests.get( - "http://metadata.google.internal/computeMetadata/v1/project/project-id", - headers={"Metadata-Flavor": "Google"}) - project_id = r.text - self._driver = get_driver(Provider.GCE)("", "", project=project_id) - return self._driver - - @property - def name(self): - if getattr(self, '_name', None) is None: - self._name = utils.call(["hostname"])[1].strip() - return self._name + def __init__(self, collect_existing_tpus=True, **kwargs): + super().__init__(**kwargs) + + # Check for dependencies + try: + utils.call(["gcloud", "--version"]) + except Exception as e: + raise (e) + + self.tpu = TPUManager(self) + self.resource_managers = [self.tpu] + + @property + def driver(self): + if getattr(self, '_driver', None) is None: + r = requests.get("http://metadata.google.internal/computeMetadata/v1/project/project-id", + headers={"Metadata-Flavor": "Google"}) + project_id = r.text + self._driver = get_driver(Provider.GCE)("", "", project=project_id) + return self._driver + + @property + def name(self): + if getattr(self, '_name', None) is None: + self._name = utils.call(["hostname"])[1].strip() + return self._name class TPU(env.Resource): - def __init__(self, name, manager=None): - super().__init__(manager=manager) - self._name = name - details = self.details - self.ip = details.get("ipAddress") - self.preemptible = details.get("preemptible") == "true" - self.version = details.get("acceleratorType") - self._in_use = False - - @property - def name(self): - return self._name - - @property - def details(self): - _, r, _ = utils.call([ - "gcloud", "alpha", "compute", "tpus", "describe", - "--zone={}".format(self.manager.zone), self.name - ]) - r = r.split("\n") - details = dict() - for line in r: - v = line.split(": ") - if len(v) != 2: - continue - k, v = v - details[k.strip()] = v.strip() - return details - - @property - def still_exists(self): - return self.name in self.manager.get_all_tpu_names() - - @property - def free(self): - return not self._in_use - - @property - def usable(self): - details = self.details - if not self.still_exists: - logger.debug("tpu {} no longer exists and will be removed.".format( - self.name)) - self.manager.remove(self) - return False - - is_running = details.get("state") in ["READY", "RUNNING"] - is_healthy = details.get("health") in ["HEALTHY", None] - - if not is_running: - logger.debug("tpu {} is no longer running.".format(self.name)) - - if not is_healthy: - logger.debug("tpu {} is no longer healthy.".format(self.name)) - - return is_running and is_healthy - - def up(self, background=False): - cmd = [ - "gcloud", "alpha", "compute", "tpus", "start", - "--zone={}".format(self.manager.zone), self.name - ] - if background: - cmd += ["--async"] - - utils.try_call(cmd) - - def down(self, background=True): - cmd = [ - "gcloud", "alpha", "compute", "tpus", "stop", - "--zone={}".format(self.manager.zone), self.name - ] - if background: - cmd += ["--async"] - - utils.try_call(cmd) - - def delete(self, background=True): - super().delete(background=background) - - if not self.still_exists: - return - - cmd = [ - "gcloud", "alpha", "compute", "tpus", "delete", - "--zone={}".format(self.manager.zone), self.name - ] - if background: - cmd += ["--async"] - cmd += ["--quiet"] # suppress user confirmation - - utils.try_call(cmd) - - def in_use(self): - self._in_use = True - - def release(self): - self._in_use = False + def __init__(self, name, manager=None): + super().__init__(manager=manager) + self._name = name + details = self.details + self.ip = details.get("ipAddress") + self.preemptible = details.get("preemptible") == "true" + self.version = details.get("acceleratorType") + self._in_use = False + + @property + def name(self): + return self._name + + @property + def details(self): + _, r, _ = utils.call( + ["gcloud", "alpha", "compute", "tpus", "describe", "--zone={}".format(self.manager.zone), self.name]) + r = r.split("\n") + details = dict() + for line in r: + v = line.split(": ") + if len(v) != 2: + continue + k, v = v + details[k.strip()] = v.strip() + return details + + @property + def still_exists(self): + return self.name in self.manager.get_all_tpu_names() + + @property + def free(self): + return not self._in_use + + @property + def usable(self): + details = self.details + if not self.still_exists: + logger.debug("tpu {} no longer exists and will be removed.".format(self.name)) + self.manager.remove(self) + return False + + is_running = details.get("state") in ["READY", "RUNNING"] + is_healthy = details.get("health") in ["HEALTHY", None] + + if not is_running: + logger.debug("tpu {} is no longer running.".format(self.name)) + + if not is_healthy: + logger.debug("tpu {} is no longer healthy.".format(self.name)) + + return is_running and is_healthy + + def up(self, background=False): + cmd = ["gcloud", "alpha", "compute", "tpus", "start", "--zone={}".format(self.manager.zone), self.name] + if background: + cmd += ["--async"] + + utils.try_call(cmd) + + def down(self, background=True): + cmd = ["gcloud", "alpha", "compute", "tpus", "stop", "--zone={}".format(self.manager.zone), self.name] + if background: + cmd += ["--async"] + + utils.try_call(cmd) + + def delete(self, background=True): + super().delete(background=background) + + if not self.still_exists: + return + + cmd = ["gcloud", "alpha", "compute", "tpus", "delete", "--zone={}".format(self.manager.zone), self.name] + if background: + cmd += ["--async"] + cmd += ["--quiet"] # suppress user confirmation + + utils.try_call(cmd) + + def in_use(self): + self._in_use = True + + def release(self): + self._in_use = False class TPUManager(env.ResourceManager): - def __init__(self, instance): - super().__init__(instance, TPU) - try: - import tensorflow as tf - import re - m = re.search(r'(\d+\.\d+)\.\d+', tf.__version__) - self.tf_version = m.group(1) - if "dev" in tf.__version__: - logging.info("Found Tensorflow nightly version. Using TPU software version: nightly") - self.tf_version = "nightly" - except: - logger.warn("Unable to determine Tensorflow version. Assuming 1.15") - self.tf_version = "1.15" - self.hostname = socket.gethostname() - _, r, _ = utils.call([ - "gcloud", "compute", "instances", "list", - "--filter=\"name={}\"".format(self.hostname) - ]) - lines = r.split("\n")[1:] - lines = list(filter(lambda l: l != "", lines)) - self.zone = lines[0].split()[1] - self.refresh() - - @property - def names(self): - return [r.name for r in self.resources] - - @property - def ips(self): - return [r.ip for r in self.resources] - - def get_all_tpu_names(self): - _, r, _ = utils.call([ - "gcloud", "alpha", "compute", "tpus", "list", - "--zone={}".format(self.zone) - ]) - lines = r.split("\n")[1:] - lines = filter(lambda l: l != "", lines) - names = [l.split()[0] for l in lines] - return filter(lambda n: self.instance.name in n, names) - - def refresh(self, background=True): - self.collect_existing() - self.clean(background=background) - - def collect_existing(self): - names = self.get_all_tpu_names() - existing_names = self.names - new_tpus = [ - TPU(name=n, manager=self) for n in names if n not in existing_names - ] - - for tpu in new_tpus: - logger.debug("Found TPU named {}".format(tpu.name)) - - self.resources.extend(new_tpus) - - def clean(self, background=True): - all_tpu_names = self.get_all_tpu_names() - for tpu in self.resources: - if tpu.name not in all_tpu_names: - self.remove(tpu) - elif not tpu.usable: - tpu.delete(background=background) - - def _new_name(self, length=5): - while True: - name = random.sample(string.ascii_lowercase, length) - name = self.instance.name + "-" + ''.join(name) - if name not in self.names: - return name - - def _new_ip(self): - while True: - ip = random.randint(1, 98) - if ip not in self.ips: - return ip - - def add(self, *args, **kwargs): - if len(args) == 1: - arg = args[0] - if isinstance(arg, str): - tpu = TPU(name=arg, manager=self) - self.resources.append(tpu) + def __init__(self, instance): + super().__init__(instance, TPU) + try: + import tensorflow as tf + import re + m = re.search(r'(\d+\.\d+)\.\d+', tf.__version__) + self.tf_version = m.group(1) + if "dev" in tf.__version__: + logging.info("Found Tensorflow nightly version. Using TPU software version: nightly") + self.tf_version = "nightly" + except: + logger.warn("Unable to determine Tensorflow version. Assuming 1.15") + self.tf_version = "1.15" + self.hostname = socket.gethostname() + _, r, _ = utils.call(["gcloud", "compute", "instances", "list", "--filter=\"name={}\"".format(self.hostname)]) + lines = r.split("\n")[1:] + lines = list(filter(lambda l: l != "", lines)) + self.zone = lines[0].split()[1] + self.refresh() + + @property + def names(self): + return [r.name for r in self.resources] + + @property + def ips(self): + return [r.ip for r in self.resources] + + def get_all_tpu_names(self): + _, r, _ = utils.call(["gcloud", "alpha", "compute", "tpus", "list", "--zone={}".format(self.zone)]) + lines = r.split("\n")[1:] + lines = filter(lambda l: l != "", lines) + names = [l.split()[0] for l in lines] + return filter(lambda n: self.instance.name in n, names) + + def refresh(self, background=True): + self.collect_existing() + self.clean(background=background) + + def collect_existing(self): + names = self.get_all_tpu_names() + existing_names = self.names + new_tpus = [TPU(name=n, manager=self) for n in names if n not in existing_names] + + for tpu in new_tpus: + logger.debug("Found TPU named {}".format(tpu.name)) + + self.resources.extend(new_tpus) + + def clean(self, background=True): + all_tpu_names = self.get_all_tpu_names() + for tpu in self.resources: + if tpu.name not in all_tpu_names: + self.remove(tpu) + elif not tpu.usable: + tpu.delete(background=background) + + def _new_name(self, length=5): + while True: + name = random.sample(string.ascii_lowercase, length) + name = self.instance.name + "-" + ''.join(name) + if name not in self.names: + return name + + def _new_ip(self): + while True: + ip = random.randint(1, 98) + if ip not in self.ips: + return ip + + def add(self, *args, **kwargs): + if len(args) == 1: + arg = args[0] + if isinstance(arg, str): + tpu = TPU(name=arg, manager=self) + self.resources.append(tpu) + return tpu + return super().add(*args, **kwargs) + + def get(self, preemptible=True, name=None, version='v3-8', zone=None): + tpu = None + assert re.match(r"v\d-\d+", version) + for tpu in self.resources: + logger.debug("Considering tpu: {}".format(tpu.name)) + if tpu.version == version and tpu.usable and tpu.free and not name: + logger.debug("tpu usable") + break + + if tpu.name == name: + break + else: + logger.debug("creating tpu") + tpu = self.up(preemptible=preemptible, name=name, version=version, zone=zone) + tpu.in_use() return tpu - return super().add(*args, **kwargs) - - def get(self, preemptible=True, name=None, version='v3-8', zone=None): - tpu = None - assert re.match(r"v\d-\d+", version) - for tpu in self.resources: - logger.debug("Considering tpu: {}".format(tpu.name)) - if tpu.version == version and tpu.usable and tpu.free and not name: - logger.debug("tpu usable") - break - - if tpu.name == name: - break - else: - logger.debug("creating tpu") - tpu = self.up(preemptible=preemptible, - name=name, - version=version, - zone=zone) - tpu.in_use() - return tpu - - def _up(self, name, ip, preemptible, version, zone, background): - logger.info("Trying to acquire TPU with name: {} ip: {}".format(name, ip)) - cmd = [ - "gcloud", "alpha", "compute", "tpus", "create", name, - "--range=10.0.{}.0".format(ip), - "--accelerator-type={}".format(version), - "--version={}".format(self.tf_version), "--network=default" - ] - if zone: - cmd += ["--zone={}".format(zone)] - if preemptible: - cmd += ["--preemptible"] - if background: - cmd += ["--async"] - - s, _, err = utils.call(cmd) - if s == 0: - return TPU(name=name, manager=self) - - raise Exception( - "Failed to create TPU with name: {} ip: {} error: \n{}".format( - name, ip, err)) - - def up(self, - preemptible=True, - background=False, - attempts=5, - name=None, - version='v3-8', - zone=None): - if not name: - name = self._new_name() - for i in range(attempts): - try: - tpu = self._up(name, - self._new_ip(), - preemptible=preemptible, - version=version, - zone=zone, - background=background) - tpu.manager = self - self.resources.append(tpu) - return tpu - except Exception as e: - logger.debug("Call resulted in error:\n{}".format(e)) - if i + 1 == attempts: - raise e - continue + + def _up(self, name, ip, preemptible, version, zone, background): + logger.info("Trying to acquire TPU with name: {} ip: {}".format(name, ip)) + cmd = [ + "gcloud", "alpha", "compute", "tpus", "create", name, "--range=10.0.{}.0".format(ip), + "--accelerator-type={}".format(version), "--version={}".format(self.tf_version), "--network=default" + ] + if zone: + cmd += ["--zone={}".format(zone)] + if preemptible: + cmd += ["--preemptible"] + if background: + cmd += ["--async"] + + s, _, err = utils.call(cmd) + if s == 0: + return TPU(name=name, manager=self) + + raise Exception("Failed to create TPU with name: {} ip: {} error: \n{}".format(name, ip, err)) + + def up(self, preemptible=True, background=False, attempts=5, name=None, version='v3-8', zone=None): + if not name: + name = self._new_name() + for i in range(attempts): + try: + tpu = self._up(name, + self._new_ip(), + preemptible=preemptible, + version=version, + zone=zone, + background=background) + tpu.manager = self + self.resources.append(tpu) + return tpu + except Exception as e: + logger.debug("Call resulted in error:\n{}".format(e)) + if i + 1 == attempts: + raise e + continue diff --git a/cloud/envs/registry.py b/cloud/envs/registry.py index d5d4c27..ee054dc 100644 --- a/cloud/envs/registry.py +++ b/cloud/envs/registry.py @@ -2,15 +2,15 @@ def register(name): - global INSTANCES + global INSTANCES - def fn(cls): - INSTANCES[name] = cls - return cls + def fn(cls): + INSTANCES[name] = cls + return cls - return fn + return fn def retrieve(name, **kwargs): - global INSTANCES - return INSTANCES[name](**kwargs) + global INSTANCES + return INSTANCES[name](**kwargs) diff --git a/cloud/envs/utils.py b/cloud/envs/utils.py index a01fd71..6fb465e 100644 --- a/cloud/envs/utils.py +++ b/cloud/envs/utils.py @@ -7,6 +7,8 @@ from errand_boy.run import main as eb_main from errand_boy.transports.unixsocket import UNIXSocketTransport +import cloud + logger = logging.getLogger(__name__) EB_TRANSPORT = None @@ -14,92 +16,90 @@ def get_transport(): - global EB_TRANSPORT - if EB_TRANSPORT is None: - EB_TRANSPORT = UNIXSocketTransport() - return EB_TRANSPORT + global EB_TRANSPORT + if EB_TRANSPORT is None: + EB_TRANSPORT = UNIXSocketTransport(socket_path=cloud.socket_path) + return EB_TRANSPORT def kill_transport(): - global EB_TRANSPORT - if EB_TRANSPORT is None: - return + global EB_TRANSPORT + if EB_TRANSPORT is None: + return - logger.warn("Killing transport") - del EB_TRANSPORT - EB_TRANSPORT = None + logger.warn("Killing transport") + del EB_TRANSPORT + EB_TRANSPORT = None -def _server_fn(): - server = UNIXSocketTransport() - server.run_server() +def _server_fn(socket_path): + server = UNIXSocketTransport(socket_path=socket_path) + server.run_server() def get_server(): - global EB_SERVER - if EB_SERVER is None: - EB_SERVER = multiprocessing.Process(target=_server_fn) - EB_SERVER.start() - time.sleep(1) - logging.getLogger("errand_boy").setLevel(logging.ERROR) - return EB_SERVER + global EB_SERVER + if EB_SERVER is None: + EB_SERVER = multiprocessing.Process(target=_server_fn, args=(cloud.socket_path,)) + EB_SERVER.start() + time.sleep(3) + logging.getLogger("errand_boy").setLevel(logging.ERROR) + return EB_SERVER def kill_server(): - global EB_SERVER - if EB_SERVER is None: - return + global EB_SERVER + if EB_SERVER is None: + return - logger.warn("Killing server") - if EB_SERVER.is_alive(): - EB_SERVER.terminate() - time.sleep(0.5) - EB_SERVER.join(timeout=1) - del EB_SERVER - EB_SERVER = None + logger.warn("Killing server") + if EB_SERVER.is_alive(): + EB_SERVER.terminate() + time.sleep(0.5) + EB_SERVER.join(timeout=1) + del EB_SERVER + EB_SERVER = None def call(cmd): - if isinstance(cmd, list): - cmd = " ".join(cmd) + if isinstance(cmd, list): + cmd = " ".join(cmd) - stdout, stderr, returncode = get_transport().run_cmd(cmd) - return returncode, stdout.decode("utf-8"), stderr.decode("utf-8") + stdout, stderr, returncode = get_transport().run_cmd(cmd) + return returncode, stdout.decode("utf-8"), stderr.decode("utf-8") def try_call(cmd, retry_count=5): - c = cmd - status = -1 - for _ in range(retry_count): - if callable(cmd): - c = cmd() - status, stdout, stderr = call(c) - if status == 0: - logger.debug("Call to `{}` successful".format(c)) - return c - else: - logger.debug("Call to `{}` failed with status: {}. Retrying...".format( - c, status)) - - raise Exception("Call to `{}` failed {} times." - "Aborting.\n STDOUT: {}\n STDERR: {}".format( - c, retry_count, stdout, stderr)) + c = cmd + status = -1 + for _ in range(retry_count): + if callable(cmd): + c = cmd() + status, stdout, stderr = call(c) + if status == 0: + logger.debug("Call to `{}` successful".format(c)) + return c + else: + logger.debug("Call to `{}` failed with status: {}. Retrying...".format(c, status)) + + raise Exception("Call to `{}` failed {} times." + "Aborting.\n STDOUT: {}\n STDERR: {}".format(c, retry_count, stdout, stderr)) def config_path(): - path = os.environ.get("CLOUD_CFG") - if path is not None and os.path.isfile(path): - return path - logger.debug("Unable to find config file at path: {}".format(path)) - - path = os.path.join(os.environ["HOME"], "cloud.toml") - if os.path.isfile(path): - return path - logger.debug("Unable to find config file at path: {}".format(path)) - - path = "/cloud.toml" - if os.path.isfile(path): - return path - logger.debug("Unable to find config file at path: {}".format(path)) - - return None + path = os.environ.get("CLOUD_CFG") + if path is not None and os.path.isfile(path): + return path + logger.debug("Unable to find config file at path: {}".format(path)) + + path = os.path.join(os.environ["HOME"], "cloud.toml") + if os.path.isfile(path): + return path + logger.debug("Unable to find config file at path: {}".format(path)) + + path = "/cloud.toml" + if os.path.isfile(path): + return path + logger.debug("Unable to find config file at path: {}".format(path)) + + return None diff --git a/setup.py b/setup.py index 725f0ef..3ce5152 100644 --- a/setup.py +++ b/setup.py @@ -6,10 +6,10 @@ # Get the long description from the README file with open(path.join(here, 'README.md'), encoding='utf-8') as f: - long_description = f.read() + long_description = f.read() setup(name='dl-cloud', - version='0.1.7', + version='0.1.8', description='Cloud resource management for deep learning applications.', long_description=long_description, long_description_content_type='text/markdown',