Skip to content

Commit

Permalink
Script for aggregated mode requests updated.
Browse files Browse the repository at this point in the history
  • Loading branch information
nkaskov committed Dec 22, 2023
1 parent a7f5455 commit 9b806ce
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 73 deletions.
98 changes: 39 additions & 59 deletions scripts/aggregated_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
import tempfile
import time
from datetime import datetime, timedelta
from request_tools import get as get_request
from request_tools import push as push_request


MERGE_TASK_KEY = 1234 # TODO: peplace with actual merge task key, adjust task input as well
INTERMEDIATE_LAYER_VERIFIER_TASK = 1234 # TODO: peplace with actual merge task key, adjust task input as well
ROOT_VERIFIER_TASK = 4321 # TODO: peplace with actual merge task key, adjust task input as well


def progress_bar(iterable, prefix="", suffix="", fill="█"):
Expand All @@ -34,7 +37,8 @@ class TaskDistributor:
class TimeoutError(Exception):
pass

def __init__(self, scripts_path, subtasks_number, task_timeout, poll_interval):
def __init__(self, auth, scripts_path, subtasks_number, task_timeout, poll_interval):
self.auth = auth
self.scripts_path = scripts_path
self.subtasks_number = subtasks_number
self.task_timeout = task_timeout
Expand All @@ -54,57 +58,32 @@ def _run_command(command):
result = subprocess.run(command, stderr=subprocess.PIPE, text=True)
return result.stderr

def distribute_and_merge_tasks(self, key, file, cost):
completed_tasks = self._run_tasks(key, file, cost)
def distribute_and_merge_tasks(self, key, input_file, cost):
completed_tasks = self._order_layer0_tasks(key, input_file, cost)
return self._merge_proofs(completed_tasks, cost)

def _push_task(self, key, file, cost, subkey=None):
cmd = [
"python3",
f"{self.scripts_path}/request_tools.py",
"push",
"--cost", str(cost),
"--file", file,
"--key", str(key),
]
if subkey is not None:
cmd += ["--subkey", str(subkey)]
result = self._run_command(cmd)
response = self._extract_json(result)
return response["_key"]
def _push_task(self, statement_key, input_file, cost, aggregated_mode_id):
response = push_request(self.auth, statement_key, input_file, cost, aggregated_mode_id=aggregated_mode_id, verbose=True)
return response["_key"]["id"]

def _get_proof(self, request_key):
cmd = [
"python3",
f"{self.scripts_path}/proof_tools.py",
"get",
"--request_key", str(request_key),
]
result = self._run_command(cmd)
response = self._extract_json(result)
response = get_request(self.auth, request_key, verbose=True)
return response["proof"]

def _get_status(self, task_key):
cmd = [
"python3",
f"{self.scripts_path}/request_tools.py",
"get",
"--key", str(task_key),
]
result = self._run_command(cmd)
response = self._extract_json(result)
def _get_status(self, request_key):
response = get_request(self.auth, request_key, verbose=True)
return response["status"]

def _wait_for_completion(self, task_key):
def _wait_for_completion(self, request_key):
end_time = datetime.now() + timedelta(seconds=self.task_timeout)
while datetime.now() < end_time:
status = self._get_status(task_key)
if status and status == "completed":
status = self._get_status(request_key)
if status and status == "DONE":
return
time.sleep(self.poll_interval)
raise TimeoutError(f"Task {task_key} timed out.")
raise TimeoutError(f"Task {request_key} timed out.")

def _process_level(self, tasks, cost):
def _process_intermediate_layer(self, tasks, cost):
new_tasks = []
i = 0
while i < len(tasks):
Expand All @@ -115,32 +94,32 @@ def _process_level(self, tasks, cost):
proofs.append(self._get_proof(tasks[i]))
i += 1

with tempfile.NamedTemporaryFile(mode="w") as tmp_file:
json.dump(proofs, tmp_file)
tmp_file.flush()
with tempfile.NamedTemporaryFile(mode="w") as joined_proofs_file:
json.dump(proofs, joined_proofs_file)
joined_proofs_file.flush()
combine_task_key = self._push_task(
MERGE_TASK_KEY, tmp_file.name, cost)
INTERMEDIATE_LAYER_VERIFIER_TASK, joined_proofs_file.name, cost)

new_tasks.append(combine_task_key)

for task_key in progress_bar(new_tasks, prefix="Merges awaited:"):
self._wait_for_completion(task_key)
for request_key in progress_bar(new_tasks, prefix="Intermediate layers awaited:"):
self._wait_for_completion(request_key)

return new_tasks

def _run_tasks(self, key, file, cost):
def _order_layer0_tasks(self, key, input_file, cost):
tasks = [
self._push_task(key, file, cost, i) for i in range(self.subtasks_number)
self._push_task(key, input_file, cost, aggregated_mode_id) for aggregated_mode_id in range(self.subtasks_number)
]
for task_key in progress_bar(tasks, prefix="Proofs awaited:"):
self._wait_for_completion(task_key)
for request_key in progress_bar(tasks, prefix="Proofs awaited:"):
self._wait_for_completion(request_key)

return tasks

def _merge_proofs(self, tasks, cost):
# Process results in a Merkle tree fashion
# Process results in a binary tree fashion
while len(tasks) > 1:
tasks = self._process_level(tasks, cost)
tasks = self._process_layer(tasks, cost)

return self._get_proof(tasks[0])

Expand All @@ -149,28 +128,29 @@ def main():
parser = argparse.ArgumentParser(
description="Distribute tasks and assemble results in a Merkle tree fashion."
)
parser.add_argument("--auth", type=str, help="auth file")
parser.add_argument(
"--scripts-path",
required=True,
help="Path to the directory containing the scripts",
)
parser.add_argument(
"--key", required=True, help="Key to be forwarded to request_tool.py"
"--statement-key", required=True, help="Key to be forwarded to request_tool.py"
)
parser.add_argument(
"--file", required=True, help="File to be forwarded to request_tool.py"
"--input", required=True, help="Input file for the proof producer"
)
parser.add_argument(
"--cost",
type=int,
required=True,
help="Cost parameter to be forwarded to request_tool.py",
help="Cost parameter for proof producing",
)
parser.add_argument(
"--subtasks-number",
"--aggregation-ratio",
type=int,
required=True,
help="How many subtasks to split into",
help="How many provers to use on the zero layer",
)
parser.add_argument(
"--task-timeout",
Expand All @@ -185,10 +165,10 @@ def main():
args = parser.parse_args()

distributor = TaskDistributor(
args.scripts_path, args.subtasks_number, args.task_timeout, args.poll_interval
args.auth, args.scripts_path, args.aggregation_ratio, args.task_timeout, args.poll_interval
)
merged_proof = distributor.distribute_and_merge_tasks(
args.key, args.file, args.cost
args.statement_key, args.input, args.cost
)

print(f"Final proof is: {merged_proof}")
Expand Down
28 changes: 14 additions & 14 deletions scripts/request_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ def get_prepared_input(input_file):
input = json.load(f)
return input

def push(auth, key, file, cost, subkey=None, verbose=False):
def push(auth, key, file, cost, aggregated_mode_id=None, verbose=False):
data = {
"statement_key": key,
"input": get_prepared_input(file),
"cost": cost,
}
if subkey is not None:
data["statement_subkey"] = subkey
if aggregated_mode_id is not None:
data["aggregated_mode_id"] = aggregated_mode_id

headers = get_headers(auth)
url = URL + "/request"
Expand All @@ -41,7 +41,7 @@ def get(auth, key=None, request_status=None, verbose=False):
if request_status:
url += f'?q=[{{"key" : "status", "value" : "{request_status}"}}]&limit=100'
elif key:
url += key
url += str(key)
else:
url += "?limit=100"
res = requests.get(url=url, headers=headers)
Expand All @@ -58,11 +58,11 @@ def get(auth, key=None, request_status=None, verbose=False):


def push_parser(args):
push(args.auth, args.key, args.file, args.cost, args.subkey, args.verbose)
push(args.auth, args.statement_key, args.input, args.cost, verbose=args.verbose)


def get_parser(args):
get(args.auth, args.key, args.request_status, args.verbose)
get(args.auth, args.request_key, args.request_status, args.verbose)


if __name__ == "__main__":
Expand All @@ -74,20 +74,20 @@ def get_parser(args):
"-v", "--verbose", action="store_true", help="increase output verbosity"
)
subparsers = parser.add_subparsers(help="sub-command help")
parser_push = subparsers.add_parser("push", help="push request")
parser_push.set_defaults(func=push_parser)
parser_get = subparsers.add_parser("get", help="get request")
parser_get.set_defaults(func=get_parser)
parser_get.add_argument("--key", type=str, help="request key")
parser_get.add_argument("--request_status", type=str, help="request status")
parser_get.add_argument("--request-key", type=str, help="request key")
parser_get.add_argument("--request-status", type=str, help="request status")

parser_push = subparsers.add_parser("push", help="push request")
parser_push.set_defaults(func=push_parser)
parser_push.add_argument("--cost", type=float, required=True, help="cost")
parser_push.add_argument(
"--file", type=str, required=True, help="json file with public input"
"--input", type=str, required=True, help="json file with public input"
)
parser_push.add_argument("--key", type=str, required=True, help="statement key")
parser_push.add_argument("--subkey", type=str, help="statement key")
parser_push.add_argument("--statement-key", type=str, required=True, help="statement key")
parser_push.add_argument(
"--generation_time",
"--generation-time",
default=30,
type=int,
help="required proof time generation (in mins)",
Expand Down

0 comments on commit 9b806ce

Please sign in to comment.