diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index 633946bb4ed84..b8e3d338dd9b7 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -47,6 +47,25 @@ kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code") +def pre_fork_setup(): + """ + Setup that must be done prior to forking with a process pool. + """ + # ensure properties have been calculated before processes + # are forked + caching_device_properties() + + # Computing the triton key can be slow. If we call it before fork, + # it will be cached for the forked subprocesses. + try: + from triton.compiler.compiler import triton_key + + triton_key() + except ModuleNotFoundError: + # Might not be installed. + pass + + def caching_device_properties(): for _, device_interface in get_registered_device_interfaces(): if device_interface.is_available(): @@ -115,9 +134,7 @@ def process_pool() -> AnyPool: # Wrapper around ProcessPoolExecutor forks in a new process we control pool = SubprocPool(config.compile_threads) else: - # ensure properties have been calculated before processes - # are forked - caching_device_properties() + pre_fork_setup() ctx = multiprocessing.get_context(config.worker_start_method) pool = ProcessPoolExecutor( config.compile_threads, diff --git a/torch/_inductor/compile_worker/__main__.py b/torch/_inductor/compile_worker/__main__.py index e478a53456752..fc8148f20c5fb 100644 --- a/torch/_inductor/compile_worker/__main__.py +++ b/torch/_inductor/compile_worker/__main__.py @@ -3,7 +3,7 @@ import sys import typing -from torch._inductor.async_compile import caching_device_properties +from torch._inductor.async_compile import pre_fork_setup from torch._inductor.compile_worker.subproc_pool import Pipe, SubprocMain from torch._inductor.compile_worker.watchdog import _async_compile_initializer from torch._inductor.runtime.compile_tasks import _set_triton_ptxas_path @@ -34,9 +34,7 @@ def main(): # redirect output of workers to stderr os.dup2(sys.stderr.fileno(), sys.stdout.fileno()) - # ensure properties have been calculated before processes - # are forked - caching_device_properties() + pre_fork_setup() _async_compile_initializer(args.parent) SubprocMain(args.workers, read_fd, write_fd).main()