Skip to content

Commit

Permalink
Add logging, pre-commit hooks, fix setup.py incorrect paths (#141)
Browse files Browse the repository at this point in the history
* Add in root.pem. change quickstart dir to playground

* Add logging; add pre-commit hooks; fix setup.py incorrect paths
  • Loading branch information
chester-leung authored Jun 11, 2021
1 parent 31cb6d2 commit 8817d88
Show file tree
Hide file tree
Showing 18 changed files with 538 additions and 414 deletions.
3 changes: 3 additions & 0 deletions .clang-format
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
MaxEmptyLinesToKeep: 2
IndentWidth: 4
NamespaceIndentation: All
5 changes: 5 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[flake8]
ignore = E203, E266, E501, W503, W605, F403, F401, C901
max-line-length = 79
max-complexity = 18
select = B,C,E,F,W,T4,B9
20 changes: 20 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
repos:
- repo: https://github.com/ambv/black
rev: 21.6b0
hooks:
- id: black
language_version: python3
- repo: https://gitlab.com/pycqa/flake8
rev: 3.9.2
hooks:
- id: flake8
- repo: https://github.com/pocc/pre-commit-hooks
rev: v1.1.1
hooks:
- id: clang-format
exclude: ^(src/include/(csv.hpp|json.hpp|base64.h))
args:
- -i
- --style=file
- --fallback-style=Chromium

61 changes: 33 additions & 28 deletions mc2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import logging
import os
import pathlib
import shutil
Expand All @@ -10,6 +11,13 @@

from envyaml import EnvYAML

# Configure logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(message)s",
level=logging.INFO,
datefmt="%Y-%m-%d %H:%M:%S",
)
logging.Formatter.converter = time.gmtime

parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(help="Command to run.", dest="command")
Expand All @@ -20,19 +28,15 @@
)

# -------------Launch------------------
parser_launch = subparsers.add_parser(
"launch", help="Launch Azure resources"
)
parser_launch = subparsers.add_parser("launch", help="Launch Azure resources")

# -------------Start-------------------
parser_start = subparsers.add_parser(
"start", help="Start services using specificed start up commands"
)

# -------------Upload----------------
parser_upload = subparsers.add_parser(
"upload", help="Encrypt and upload data."
)
parser_upload = subparsers.add_parser("upload", help="Encrypt and upload data.")

# -------------Run--------------------
parser_run = subparsers.add_parser(
Expand All @@ -45,19 +49,17 @@
)

# -------------Stop-------------------
parser_stop = subparsers.add_parser(
"stop", help="Stop previously started service"
)
parser_stop = subparsers.add_parser("stop", help="Stop previously started service")

# -------------Teardown---------------
parser_teardown = subparsers.add_parser(
"teardown", help="Teardown Azure resources"
)
parser_teardown = subparsers.add_parser("teardown", help="Teardown Azure resources")

if __name__ == "__main__":
oc_config = os.environ.get("MC2_CONFIG")
if not oc_config:
raise Exception("Please set the environment variable `MC2_CONFIG` to the path of your config file")
raise Exception(
"Please set the environment variable `MC2_CONFIG` to the path of your config file"
)

mc2.set_config(oc_config)
args = parser.parse_args()
Expand All @@ -75,8 +77,10 @@

# If the nodes have been manually specified, don't do anything
if config_launch.get("head") or config_launch.get("workers"):
print("Node addresses have been manually specified in the config "\
"... doing nothing")
logging.warning(
"Node addresses have been manually specified in the config "
"... doing nothing"
)
quit()

# Create resource group
Expand Down Expand Up @@ -123,16 +127,16 @@

encrypted_data = [d + ".enc" for d in data]

print("Encrypting and uploading data...")

dst_dir = config_upload.get("dst", "")
for i in range(len(data)):
# Encrypt data
if enc_format == "xgb":
mc2.encrypt_data(data[i], encrypted_data[i], None, "xgb")
elif enc_format == "sql":
if schemas is None:
raise Exception("Please specify a schema when uploading data for Opaque SQL")
raise Exception(
"Please specify a schema when uploading data for Opaque SQL"
)
# Remove temporary files from a previous run
if os.path.exists(encrypted_data[i]):
if os.path.isdir(encrypted_data[i]):
Expand All @@ -150,7 +154,6 @@
if dst_dir:
remote_path = os.path.join(dst_dir, filename)
mc2.upload_file(encrypted_data[i], remote_path, use_azure)
print("Uploaded data to {}".format(remote_path))

# Remove temporary directory
if os.path.isdir(encrypted_data[i]):
Expand All @@ -163,7 +166,7 @@
script = config_run["script"]

if config_run["compute"] == "xgb":
print("run() unimplemented for secure-xgboost")
logging.error("run() unimplemented for secure-xgboost")
quit()
elif config_run["compute"] == "sql":
mc2.configure_job(config)
Expand All @@ -183,8 +186,6 @@
remote_results = config_download.get("src", [])
local_results_dir = config_download["dst"]

print("Downloading and decrypting data")

# Create the local results directory if it doesn't exist
if not os.path.exists(local_results_dir):
pathlib.Path(local_results_dir).mkdir(parents=True, exist_ok=True)
Expand All @@ -195,15 +196,12 @@

# Fetch file
mc2.download_file(remote_result, local_result, use_azure)
print("Downloaded result to ", local_result)

# Decrypt data
if enc_format == "xgb":
mc2.decrypt_data(local_result, local_result + ".dec", "xgb")
print("Decrypted result saved to ", local_result + ".dec")
elif enc_format == "sql":
mc2.decrypt_data(local_result, local_result + ".dec", "sql")
print("Decrypted result saved to ", local_result + ".dec")
else:
raise Exception("Specified format {} not supported".format(enc_format))

Expand All @@ -213,16 +211,18 @@
os.remove(local_result)

elif args.command == "stop":
print("Currently unsupported")
logging.error("`opaque stop` is currently unsupported")
pass

elif args.command == "teardown":
config_teardown = config["teardown"]

# If the nodes have been manually specified, don't do anything
if config["launch"].get("head") or config["launch"].get("workers"):
print("Node addresses have been manually specified in the config "\
"... doing nothing")
logging.warning(
"Node addresses have been manually specified in the config "
"... doing nothing"
)
quit()

delete_container = config_teardown.get("container")
Expand All @@ -240,3 +240,8 @@
delete_resource_group = config_teardown.get("resource_group")
if delete_resource_group:
mc2.delete_resource_group()

else:
logging.error(
"Unsupported command specified. Please type `opaque -h` for a list of supported commands."
)
Loading

0 comments on commit 8817d88

Please sign in to comment.