Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async runner improvements #2056

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
12 changes: 4 additions & 8 deletions metaflow/plugins/argo/argo_workflows_deployer_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from metaflow.plugins.argo.argo_workflows import ArgoWorkflows
from metaflow.runner.deployer import Deployer, DeployedFlow, TriggeredRun

from metaflow.runner.utils import get_lower_level_group, handle_timeout
from metaflow.runner.utils import get_lower_level_group, handle_timeout, temporary_fifo


def generate_fake_flow_file_contents(
Expand Down Expand Up @@ -341,18 +341,14 @@ def trigger(self, **kwargs) -> ArgoWorkflowsTriggeredRun:
Exception
If there is an error during the trigger process.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)

with temporary_fifo() as (attribute_file_path, attribute_file_fd):
# every subclass needs to have `self.deployer_kwargs`
command = get_lower_level_group(
self.deployer.api,
self.deployer.top_level_kwargs,
self.deployer.TYPE,
self.deployer.deployer_kwargs,
).trigger(deployer_attribute_file=tfp_runner_attribute.name, **kwargs)
).trigger(deployer_attribute_file=attribute_file_path, **kwargs)

pid = self.deployer.spm.run_command(
[sys.executable, *command],
Expand All @@ -363,7 +359,7 @@ def trigger(self, **kwargs) -> ArgoWorkflowsTriggeredRun:

command_obj = self.deployer.spm.get(pid)
content = handle_timeout(
tfp_runner_attribute, command_obj, self.deployer.file_read_timeout
attribute_file_fd, command_obj, self.deployer.file_read_timeout
)

if command_obj.process.returncode == 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from metaflow.plugins.aws.step_functions.step_functions import StepFunctions
from metaflow.runner.deployer import DeployedFlow, TriggeredRun

from metaflow.runner.utils import get_lower_level_group, handle_timeout
from metaflow.runner.utils import get_lower_level_group, handle_timeout, temporary_fifo


class StepFunctionsTriggeredRun(TriggeredRun):
Expand Down Expand Up @@ -196,18 +196,14 @@ def trigger(self, **kwargs) -> StepFunctionsTriggeredRun:
Exception
If there is an error during the trigger process.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)

with temporary_fifo() as (attribute_file_path, attribute_file_fd):
# every subclass needs to have `self.deployer_kwargs`
command = get_lower_level_group(
self.deployer.api,
self.deployer.top_level_kwargs,
self.deployer.TYPE,
self.deployer.deployer_kwargs,
).trigger(deployer_attribute_file=tfp_runner_attribute.name, **kwargs)
).trigger(deployer_attribute_file=attribute_file_path, **kwargs)

pid = self.deployer.spm.run_command(
[sys.executable, *command],
Expand All @@ -218,7 +214,7 @@ def trigger(self, **kwargs) -> StepFunctionsTriggeredRun:

command_obj = self.deployer.spm.get(pid)
content = handle_timeout(
tfp_runner_attribute, command_obj, self.deployer.file_read_timeout
attribute_file_fd, command_obj, self.deployer.file_read_timeout
)

if command_obj.process.returncode == 0:
Expand Down
12 changes: 4 additions & 8 deletions metaflow/runner/deployer_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
import json
import os
import sys
import tempfile

from typing import Any, ClassVar, Dict, Optional, TYPE_CHECKING, Type

from .subprocess_manager import SubprocessManager
from .utils import get_lower_level_group, handle_timeout
from .utils import get_lower_level_group, handle_timeout, temporary_fifo

if TYPE_CHECKING:
import metaflow.runner.deployer
Expand Down Expand Up @@ -121,14 +120,11 @@ def create(self, **kwargs) -> "metaflow.runner.deployer.DeployedFlow":
def _create(
self, create_class: Type["metaflow.runner.deployer.DeployedFlow"], **kwargs
) -> "metaflow.runner.deployer.DeployedFlow":
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)
with temporary_fifo() as (attribute_file_path, attribute_file_fd):
# every subclass needs to have `self.deployer_kwargs`
command = get_lower_level_group(
self.api, self.top_level_kwargs, self.TYPE, self.deployer_kwargs
).create(deployer_attribute_file=tfp_runner_attribute.name, **kwargs)
).create(deployer_attribute_file=attribute_file_path, **kwargs)

pid = self.spm.run_command(
[sys.executable, *command],
Expand All @@ -139,7 +135,7 @@ def _create(

command_obj = self.spm.get(pid)
content = handle_timeout(
tfp_runner_attribute, command_obj, self.file_read_timeout
attribute_file_fd, command_obj, self.file_read_timeout
)
content = json.loads(content)
self.name = content.get("name")
Expand Down
63 changes: 33 additions & 30 deletions metaflow/runner/metaflow_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
import os
import sys
import json
import tempfile

from typing import Dict, Iterator, Optional, Tuple

from metaflow import Run

from .utils import handle_timeout
from .utils import (
temporary_fifo,
handle_timeout,
async_handle_timeout,
)
from .subprocess_manager import CommandManager, SubprocessManager


Expand Down Expand Up @@ -267,10 +270,8 @@ def __enter__(self) -> "Runner":
async def __aenter__(self) -> "Runner":
return self

def __get_executing_run(self, tfp_runner_attribute, command_obj):
content = handle_timeout(
tfp_runner_attribute, command_obj, self.file_read_timeout
)
def __get_executing_run(self, attribute_file_fd, command_obj):
content = handle_timeout(attribute_file_fd, command_obj, self.file_read_timeout)
content = json.loads(content)
pathspec = "%s/%s" % (content.get("flow_name"), content.get("run_id"))

Expand All @@ -282,6 +283,20 @@ def __get_executing_run(self, tfp_runner_attribute, command_obj):
)
return ExecutingRun(self, command_obj, run_object)

async def __async_get_executing_run(self, attribute_file_fd, command_obj):
    """
    Async counterpart of `__get_executing_run`: await the runner attribute
    data written by the subprocess, then build an `ExecutingRun` for the
    started run.

    Parameters
    ----------
    attribute_file_fd :
        File descriptor yielded by `temporary_fifo` that the launched
        subprocess writes its run attributes (JSON) to.
    command_obj : CommandManager
        Manager wrapping the subprocess that is executing the run.

    Returns
    -------
    ExecutingRun
        Wrapper around the newly started run.
    """
    content = await async_handle_timeout(
        attribute_file_fd, command_obj, self.file_read_timeout
    )
    content = json.loads(content)
    pathspec = "%s/%s" % (content.get("flow_name"), content.get("run_id"))

    # Set the correct metadata from the runner_attribute file corresponding to this run.
    metadata_for_flow = content.get("metadata")
    metadata(metadata_for_flow)

    # NOTE(review): namespace check is skipped so the run resolves even when
    # the current namespace differs from the one that produced it.
    run_object = Run(pathspec, _namespace_check=False)
    return ExecutingRun(self, command_obj, run_object)

def run(self, **kwargs) -> ExecutingRun:
"""
Blocking execution of the run. This method will wait until
Expand All @@ -298,12 +313,9 @@ def run(self, **kwargs) -> ExecutingRun:
ExecutingRun
ExecutingRun containing the results of the run.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)
with temporary_fifo() as (attribute_file_path, attribute_file_fd):
command = self.api(**self.top_level_kwargs).run(
runner_attribute_file=tfp_runner_attribute.name, **kwargs
runner_attribute_file=attribute_file_path, **kwargs
)

pid = self.spm.run_command(
Expand All @@ -314,7 +326,7 @@ def run(self, **kwargs) -> ExecutingRun:
)
command_obj = self.spm.get(pid)

return self.__get_executing_run(tfp_runner_attribute, command_obj)
return self.__get_executing_run(attribute_file_fd, command_obj)

def resume(self, **kwargs):
"""
Expand All @@ -332,12 +344,9 @@ def resume(self, **kwargs):
ExecutingRun
ExecutingRun containing the results of the resumed run.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)
with temporary_fifo() as (attribute_file_path, attribute_file_fd):
command = self.api(**self.top_level_kwargs).resume(
runner_attribute_file=tfp_runner_attribute.name, **kwargs
runner_attribute_file=attribute_file_path, **kwargs
)

pid = self.spm.run_command(
Expand All @@ -348,7 +357,7 @@ def resume(self, **kwargs):
)
command_obj = self.spm.get(pid)

return self.__get_executing_run(tfp_runner_attribute, command_obj)
return self.__get_executing_run(attribute_file_fd, command_obj)

async def async_run(self, **kwargs) -> ExecutingRun:
"""
Expand All @@ -368,12 +377,9 @@ async def async_run(self, **kwargs) -> ExecutingRun:
ExecutingRun
ExecutingRun representing the run that was started.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)
with temporary_fifo() as (attribute_file_path, attribute_file_fd):
command = self.api(**self.top_level_kwargs).run(
runner_attribute_file=tfp_runner_attribute.name, **kwargs
runner_attribute_file=attribute_file_path, **kwargs
)

pid = await self.spm.async_run_command(
Expand All @@ -383,7 +389,7 @@ async def async_run(self, **kwargs) -> ExecutingRun:
)
command_obj = self.spm.get(pid)

return self.__get_executing_run(tfp_runner_attribute, command_obj)
return await self.__async_get_executing_run(attribute_file_fd, command_obj)

async def async_resume(self, **kwargs):
"""
Expand All @@ -403,12 +409,9 @@ async def async_resume(self, **kwargs):
ExecutingRun
ExecutingRun representing the resumed run that was started.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)
with temporary_fifo() as (attribute_file_path, attribute_file_fd):
command = self.api(**self.top_level_kwargs).resume(
runner_attribute_file=tfp_runner_attribute.name, **kwargs
runner_attribute_file=attribute_file_path, **kwargs
)

pid = await self.spm.async_run_command(
Expand All @@ -418,7 +421,7 @@ async def async_resume(self, **kwargs):
)
command_obj = self.spm.get(pid)

return self.__get_executing_run(tfp_runner_attribute, command_obj)
return await self.__async_get_executing_run(attribute_file_fd, command_obj)

def __exit__(self, exc_type, exc_value, traceback):
self.spm.cleanup()
Expand Down
55 changes: 46 additions & 9 deletions metaflow/runner/subprocess_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,49 @@
import threading
from typing import Callable, Dict, Iterator, List, Optional, Tuple

from .utils import check_process_exited

def kill_processes_and_descendants(pids: List[str], termination_timeout: float):
    """
    Terminate the given processes and their direct children.

    Sends SIGTERM to the children of each pid (via ``pkill -P``) and to the
    pids themselves, waits `termination_timeout` seconds for a graceful
    shutdown, then repeats the pass with SIGKILL.

    Parameters
    ----------
    pids : List[str]
        Process ids, as strings, of the parent processes to terminate.
    termination_timeout : float
        Seconds to wait between the TERM pass and the KILL pass.
    """
    # TODO: there's a race condition that new descendants might
    # spawn b/w the invocations of 'pkill' and 'kill'.
    # Needs to be fixed in future.
    try:
        subprocess.check_call(["pkill", "-TERM", "-P", *pids])
        subprocess.check_call(["kill", "-TERM", *pids])
    except subprocess.CalledProcessError:
        # Non-zero exit (e.g. no matching processes) is expected here and
        # deliberately ignored — best-effort termination.
        pass

    time.sleep(termination_timeout)

    try:
        subprocess.check_call(["pkill", "-KILL", "-P", *pids])
        subprocess.check_call(["kill", "-KILL", *pids])
    except subprocess.CalledProcessError:
        pass


async def async_kill_processes_and_descendants(
    pids: List[str], termination_timeout: float
):
    """
    Asynchronously terminate the given processes and their direct children:
    a TERM pass first, then — after `termination_timeout` seconds — a KILL
    pass, without blocking the event loop.
    """
    # TODO: there's a race condition that new descendants might
    # spawn b/w the invocations of 'pkill' and 'kill'.
    # Needs to be fixed in future.
    async def _broadcast(sig):
        # Signal the children of every pid first ('pkill -P'), then the
        # pids themselves ('kill'), waiting for each command to finish.
        children = await asyncio.create_subprocess_exec("pkill", sig, "-P", *pids)
        await children.wait()

        parents = await asyncio.create_subprocess_exec("kill", sig, *pids)
        await parents.wait()

    await _broadcast("-TERM")
    await asyncio.sleep(termination_timeout)
    await _broadcast("-KILL")


class LogReadTimeoutError(Exception):
"""Exception raised when reading logs times out."""

Expand All @@ -46,14 +69,28 @@ def __init__(self):
loop = asyncio.get_running_loop()
loop.add_signal_handler(
signal.SIGINT,
lambda: self._handle_sigint(signum=signal.SIGINT, frame=None),
lambda: asyncio.create_task(self._async_handle_sigint()),
)
except RuntimeError:
signal.signal(signal.SIGINT, self._handle_sigint)

async def _async_handle_sigint(self):
    """Async SIGINT handler: hard-terminate every still-running managed subprocess."""
    live_pids = []
    for command in self.commands.values():
        # Only target commands whose subprocess was started and has not exited.
        if command.process and not check_process_exited(command):
            live_pids.append(str(command.process.pid))

    if live_pids:
        await async_kill_processes_and_descendants(live_pids, termination_timeout=2)

def _handle_sigint(self, signum, frame):
    """Synchronous SIGINT handler: hard-terminate every still-running managed subprocess."""
    live_pids = []
    for command in self.commands.values():
        # Only target commands whose subprocess was started and has not exited.
        if command.process and not check_process_exited(command):
            live_pids.append(str(command.process.pid))

    if live_pids:
        kill_processes_and_descendants(live_pids, termination_timeout=2)

async def __aenter__(self) -> "SubprocessManager":
return self
Expand Down Expand Up @@ -472,7 +509,7 @@ def kill(self, termination_timeout: float = 2):
"""

if self.process is not None:
kill_process_and_descendants(self.process.pid, termination_timeout)
kill_processes_and_descendants([str(self.process.pid)], termination_timeout)
else:
print("No process to kill.")

Expand Down
Loading