-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update formatting and support multiprocessing (#30)
- Loading branch information
1 parent
5049c2d
commit a189729
Showing
9 changed files
with
497 additions
and
531 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,43 +1,51 @@ | ||
import atexit | ||
import logging | ||
import os | ||
|
||
import cloud | ||
import libcloud | ||
import toml | ||
|
||
import cloud | ||
from cloud import Instance | ||
from cloud import registry as reg | ||
from cloud.envs import utils | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def connect(): | ||
config_filepath = utils.config_path() | ||
if config_filepath is None: | ||
logger.error( | ||
"ASSUMING LOCAL: Configuration file not found in any of the locations" | ||
" below.\n See github.com/for-ai/cloud/tree/master/configs for " | ||
"example configurations to fill in and use; copy and place in a file " | ||
"named `cloud.toml` at `/cloud.toml` or `$HOME/cloud.toml`.") | ||
return | ||
def connect(socket_path=None): | ||
config_filepath = utils.config_path() | ||
if config_filepath is None: | ||
logger.error("ASSUMING LOCAL: Configuration file not found in any of the locations" | ||
" below.\n See github.com/for-ai/cloud/tree/master/configs for " | ||
"example configurations to fill in and use; copy and place in a file " | ||
"named `cloud.toml` at `/cloud.toml` or `$HOME/cloud.toml`.") | ||
return | ||
|
||
with open(config_filepath, "r") as cf: | ||
config = toml.load(cf) | ||
provider = config.pop("provider").lower() | ||
if socket_path: | ||
_set_socket_path(socket_path) | ||
cloud.instance = reg.retrieve(provider, config=config) | ||
|
||
|
||
with open(config_filepath, "r") as cf: | ||
config = toml.load(cf) | ||
provider = config.pop("provider").lower() | ||
cloud.instance = reg.retrieve(provider, config=config) | ||
def _set_socket_path(local_socket_path): | ||
os.makedirs(os.path.dirname(local_socket_path), exist_ok=True) | ||
cloud.socket_path = local_socket_path | ||
|
||
|
||
def close(): | ||
utils.kill_transport() | ||
utils.kill_server() | ||
utils.kill_transport() | ||
utils.kill_server() | ||
|
||
|
||
def down(): | ||
cloud.instance.down() | ||
cloud.instance.down() | ||
|
||
|
||
def delete(confirm=True): | ||
cloud.instance.delete(confirm) | ||
cloud.instance.delete(confirm) | ||
|
||
|
||
atexit.register(close) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,34 +1,31 @@ | ||
import os | ||
import requests | ||
|
||
from cloud.envs import env | ||
from cloud.envs import registry | ||
from cloud.envs import utils | ||
|
||
from libcloud.compute.types import Provider | ||
import requests | ||
from libcloud.compute.providers import get_driver | ||
from libcloud.compute.types import Provider | ||
|
||
from cloud.envs import env, registry, utils | ||
|
||
|
||
@registry.register("aws") | ||
class AWSInstance(env.Instance): | ||
|
||
def __init__(self, config, **kwargs): | ||
super().__init__(**kwargs) | ||
self.access_key = config["access_key"] | ||
self.secret_key = config["secret_key"] | ||
self.region = config["region"] | ||
def __init__(self, config, **kwargs): | ||
super().__init__(**kwargs) | ||
self.access_key = config["access_key"] | ||
self.secret_key = config["secret_key"] | ||
self.region = config["region"] | ||
|
||
@property | ||
def driver(self): | ||
if getattr(self, '_driver', None) is None: | ||
self._driver = get_driver(Provider.EC2)( | ||
self.access_key, self.secret_key, region=self.region) | ||
return self._driver | ||
@property | ||
def driver(self): | ||
if getattr(self, '_driver', None) is None: | ||
self._driver = get_driver(Provider.EC2)(self.access_key, self.secret_key, region=self.region) | ||
return self._driver | ||
|
||
@property | ||
def name(self): | ||
if getattr(self, '_name', None) is None: | ||
# https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html | ||
r = requests.get("http://169.254.169.254/latest/meta-data/instance-id") | ||
self._name = r.text | ||
return self._name | ||
@property | ||
def name(self): | ||
if getattr(self, '_name', None) is None: | ||
# https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html | ||
r = requests.get("http://169.254.169.254/latest/meta-data/instance-id") | ||
self._name = r.text | ||
return self._name |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,33 +1,30 @@ | ||
from cloud.envs import env | ||
from cloud.envs import registry | ||
from cloud.envs import utils | ||
|
||
from libcloud.compute.types import Provider | ||
from libcloud.compute.providers import get_driver | ||
from libcloud.compute.types import Provider | ||
|
||
from cloud.envs import env, registry, utils | ||
|
||
|
||
@registry.register("azure") | ||
class AzureInstance(env.Instance): | ||
|
||
def __init__(self, config, **kwargs): | ||
super().__init__(**kwargs) | ||
self.application_id = config["application_id"] | ||
self.subscription_id = config["subscription_id"] | ||
self.tenant_id = config["tenant_id"] | ||
self.key = config["key"] | ||
def __init__(self, config, **kwargs): | ||
super().__init__(**kwargs) | ||
self.application_id = config["application_id"] | ||
self.subscription_id = config["subscription_id"] | ||
self.tenant_id = config["tenant_id"] | ||
self.key = config["key"] | ||
|
||
@property | ||
def driver(self): | ||
if getattr(self, '_driver', None) is None: | ||
self._driver = get_driver(Provider.AZURE_ARM)( | ||
tenant_id=self.tenant_id, | ||
subscription_id=self.subscription_id, | ||
key=self.application_id, | ||
secret=self.key) | ||
return self._driver | ||
@property | ||
def driver(self): | ||
if getattr(self, '_driver', None) is None: | ||
self._driver = get_driver(Provider.AZURE_ARM)(tenant_id=self.tenant_id, | ||
subscription_id=self.subscription_id, | ||
key=self.application_id, | ||
secret=self.key) | ||
return self._driver | ||
|
||
@property | ||
def name(self): | ||
if getattr(self, '_name', None) is None: | ||
self._name = utils.call(["hostname"])[1].strip() | ||
return self._name | ||
@property | ||
def name(self): | ||
if getattr(self, '_name', None) is None: | ||
self._name = utils.call(["hostname"])[1].strip() | ||
return self._name |
Oops, something went wrong.