Skip to content

Commit

Permalink
[inductor] parallel-compile: call triton_key() before forking (pytorc…
Browse files Browse the repository at this point in the history
…h#127639)

Summary:
A user reported severe slowdown on a workload when using parallel compile. The issue is that in some environments, the process affinity changes after forking such that all forked subprocesses use a single logical processor. Described here: pytorch#99625. That requires a separate fix, but during debuging we noticed that we can at least optimize the expensive call to triton_key() before forking.

Pull Request resolved: pytorch#127639
Approved by: https://github.com/eellison, https://github.com/anijain2305
  • Loading branch information
masnesral authored and pytorchmergebot committed Jun 7, 2024
1 parent 96806b1 commit e8e0bdf
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
23 changes: 20 additions & 3 deletions torch/_inductor/async_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,25 @@
kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code")


def pre_fork_setup():
"""
Setup that must be done prior to forking with a process pool.
"""
# ensure properties have been calculated before processes
# are forked
caching_device_properties()

# Computing the triton key can be slow. If we call it before fork,
# it will be cached for the forked subprocesses.
try:
from triton.compiler.compiler import triton_key

triton_key()
except ModuleNotFoundError:
# Might not be installed.
pass


def caching_device_properties():
for _, device_interface in get_registered_device_interfaces():
if device_interface.is_available():
Expand Down Expand Up @@ -115,9 +134,7 @@ def process_pool() -> AnyPool:
# Wrapper around ProcessPoolExecutor forks in a new process we control
pool = SubprocPool(config.compile_threads)
else:
# ensure properties have been calculated before processes
# are forked
caching_device_properties()
pre_fork_setup()
ctx = multiprocessing.get_context(config.worker_start_method)
pool = ProcessPoolExecutor(
config.compile_threads,
Expand Down
6 changes: 2 additions & 4 deletions torch/_inductor/compile_worker/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
import typing

from torch._inductor.async_compile import caching_device_properties
from torch._inductor.async_compile import pre_fork_setup
from torch._inductor.compile_worker.subproc_pool import Pipe, SubprocMain
from torch._inductor.compile_worker.watchdog import _async_compile_initializer
from torch._inductor.runtime.compile_tasks import _set_triton_ptxas_path
Expand Down Expand Up @@ -34,9 +34,7 @@ def main():
# redirect output of workers to stderr
os.dup2(sys.stderr.fileno(), sys.stdout.fileno())

# ensure properties have been calculated before processes
# are forked
caching_device_properties()
pre_fork_setup()

_async_compile_initializer(args.parent)
SubprocMain(args.workers, read_fd, write_fd).main()
Expand Down

0 comments on commit e8e0bdf

Please sign in to comment.