From 6e6eb8fcafd6c074e156fc419fbc38c6587b4624 Mon Sep 17 00:00:00 2001
From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com>
Date: Mon, 15 Jul 2024 02:58:19 -0700
Subject: [PATCH] Restructure et_replay subpackage (#133)

Summary:
Pull Request resolved: https://github.com/facebookresearch/param/pull/133

1. Make the `et_replay` package structure easier to refactor: add an entry point that exports the public symbols and use it in the train code (see the illustrative snippet below).
2. Enable tests to run as part of GitHub Actions (GHA) CI.
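
A minimal sketch of the import change the new entry point enables (illustration only, not part of this diff; `trace.json` is a placeholder file name):

```python
# Old layout (before this change):
#     from et_replay.lib.execution_trace import ExecutionTrace
# New layout: the package root re-exports the public symbol.
from et_replay import ExecutionTrace

import json

# Hypothetical usage: the constructor takes the parsed trace JSON.
with open("trace.json") as f:
    et = ExecutionTrace(json.load(f))
```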
Reviewed at https://github.com/facebookresearch/param/pull/124

Pull Request resolved: https://github.com/facebookresearch/param/pull/129

Test Plan: CI

Differential Revision: D59178710

Pulled By: sanrise
---
 .github/workflows/python_lint.yml            | 14 +++++++
 .gitignore                                   |  2 +
 et_replay/__init__.py                        |  3 ++
 et_replay/{lib => }/comm/commsTraceParser.py | 13 ++-----
 et_replay/{lib => }/comm/comms_utils.py      | 11 ++----
 et_replay/{lib => }/comm/param_profile.py    |  0
 .../{lib => }/comm/pytorch_backend_utils.py  |  2 +-
 .../{lib => }/comm/pytorch_dist_backend.py   | 10 ++---
 .../{lib => }/comm/pytorch_tpu_backend.py    |  2 +-
 et_replay/{lib => }/et_replay_utils.py       | 15 ++++----
 et_replay/{lib => }/execution_trace.py       |  0
 et_replay/pyproject.toml                     |  4 +-
 et_replay/tests/test_execution_trace.py      |  6 +--
 et_replay/tools/comm_replay.py               |  3 +-
 et_replay/tools/et_replay.py                 | 37 +++++++++----------
 et_replay/tools/validate_trace.py            |  2 +-
 et_replay/{lib => }/utils.py                 |  2 +-
 requirements.txt                             |  1 +
 train/comms/pt/commsTraceParser.py           |  2 +-
 19 files changed, 66 insertions(+), 63 deletions(-)
 create mode 100644 .gitignore
 create mode 100644 et_replay/__init__.py
 rename et_replay/{lib => }/comm/commsTraceParser.py (98%)
 rename et_replay/{lib => }/comm/comms_utils.py (99%)
 rename et_replay/{lib => }/comm/param_profile.py (100%)
 rename et_replay/{lib => }/comm/pytorch_backend_utils.py (99%)
 rename et_replay/{lib => }/comm/pytorch_dist_backend.py (99%)
 rename et_replay/{lib => }/comm/pytorch_tpu_backend.py (98%)
 rename et_replay/{lib => }/et_replay_utils.py (98%)
 rename et_replay/{lib => }/execution_trace.py (100%)
 rename et_replay/{lib => }/utils.py (96%)

diff --git a/.github/workflows/python_lint.yml b/.github/workflows/python_lint.yml
index b92f35a6..00d218e9 100644
--- a/.github/workflows/python_lint.yml
+++ b/.github/workflows/python_lint.yml
@@ -21,3 +21,17 @@ jobs:
 
       - name: Run Black
         run: black . --check
+
+      - name: Run tests
+        run: |
+          python -m pip install -r requirements.txt
+          python -m pip install et_replay/
+          python et_replay/tests/test_execution_trace.py
+
+      - name: Validate imports
+        run: |
+          python -m pip install fbgemm-gpu
+          python -c 'from et_replay import ExecutionTrace'
+          python -c 'from et_replay.comm import comms_utils'
+          python -c 'from et_replay.tools.validate_trace import TraceValidator'
+          python -c 'from et_replay.utils import trace_handler'
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 00000000..a230a78a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+.venv/
+__pycache__/
diff --git a/et_replay/__init__.py b/et_replay/__init__.py
new file mode 100644
index 00000000..e9864914
--- /dev/null
+++ b/et_replay/__init__.py
@@ -0,0 +1,3 @@
+from et_replay.execution_trace import ExecutionTrace
+
+__all__ = ["ExecutionTrace"]
diff --git a/et_replay/lib/comm/commsTraceParser.py b/et_replay/comm/commsTraceParser.py
similarity index 98%
rename from et_replay/lib/comm/commsTraceParser.py
rename to et_replay/comm/commsTraceParser.py
index c9a07cd3..aa23f22a 100644
--- a/et_replay/lib/comm/commsTraceParser.py
+++ b/et_replay/comm/commsTraceParser.py
@@ -2,14 +2,12 @@
 from __future__ import annotations
 
 import json
-
 from typing import List, Tuple
 
-from et_replay.lib.comm import comms_utils
-from et_replay.lib.comm.comms_utils import commsArgs
-from et_replay.lib.comm.pytorch_backend_utils import supportedP2pOps
-
-from et_replay.lib.execution_trace import ExecutionTrace
+from et_replay import ExecutionTrace
+from et_replay.comm import comms_utils
+from et_replay.comm.comms_utils import commsArgs
+from et_replay.comm.pytorch_backend_utils import supportedP2pOps
 
 tensorDtypeMap = {
     "Tensor(int)": "int",
@@ -63,7 +61,6 @@ def _parseBasicTrace(in_trace: List):
     """
     newCommsTrace = []
     for cnt, curComm in enumerate(in_trace):
-
         newComm = commsArgs()
         newComm.id = cnt
         newComm.markerStack = curComm.get("markers")
@@ -84,7 +81,6 @@ def _parseBasicTrace(in_trace: List):
 
 
 def _parseBasicTraceComms(curComm, newComm: commsArgs) -> None:
-
     newComm.comms = comms_utils.paramToCommName(curComm["comms"].lower())
     if newComm.markerStack is None:
         newComm.markerStack = [newComm.comms]
@@ -165,7 +161,6 @@ def _parseKinetoUnitrace(in_trace: List, target_rank: int) -> List:
             and entry["name"] == "record_param_comms"
             and entry["args"]["rank"] == target_rank
         ):
-
             newComm = commsArgs()
             newComm.comms = comms_utils.paramToCommName(entry["args"]["comms"].lower())
             newComm.id = commsCnt
diff --git a/et_replay/lib/comm/comms_utils.py b/et_replay/comm/comms_utils.py
similarity index 99%
rename from et_replay/lib/comm/comms_utils.py
rename to et_replay/comm/comms_utils.py
index 94ad2acc..6805303d 100644
--- a/et_replay/lib/comm/comms_utils.py
+++ b/et_replay/comm/comms_utils.py
@@ -38,8 +38,10 @@
 import numpy as np
 import torch
 
-from et_replay.lib.comm.param_profile import paramTimer
-from et_replay.lib.comm.pytorch_backend_utils import (
+
+
+from et_replay.comm.param_profile import paramTimer
+from et_replay.comm.pytorch_backend_utils import (
     backendFunctions,
     collectiveArgsHolder,
     customized_backend,
@@ -1072,7 +1074,6 @@ def _prep_all_to_all_single(
     scaleFactor: float,
     allocate: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-
     ipTensor = None
     opTensor = None
     if allocate:
@@ -1193,7 +1194,6 @@ def _prep_all_gather_base(
     scaleFactor: float,
     allocate: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-
     opTensor = []
     if not commsParams.size_from_trace:
         numElementsOut = numElementsIn
@@ -1257,7 +1257,6 @@ def _prep_reduce_scatter(
     scaleFactor: float,
     allocate: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-
     ipTensor = []
     opTensor = []
     if not commsParams.size_from_trace:
@@ -1304,7 +1303,6 @@ def _prep_reduce_scatter_base(
     scaleFactor: float,
     allocate: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-
     ipTensor = []
     opTensor = []
     if not commsParams.size_from_trace:
@@ -1727,7 +1725,6 @@ def init_emb_lookup(collectiveArgs, commsParams, backendFuncs):
     try:
         # fbgemm_gpu can be downloaded from https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu
         from fbgemm_gpu.split_embedding_utils import generate_requests
-
         from fbgemm_gpu.split_table_batched_embeddings_ops import (
             ComputeDevice,
             EmbeddingLocation,
diff --git a/et_replay/lib/comm/param_profile.py b/et_replay/comm/param_profile.py
similarity index 100%
rename from et_replay/lib/comm/param_profile.py
rename to et_replay/comm/param_profile.py
diff --git a/et_replay/lib/comm/pytorch_backend_utils.py b/et_replay/comm/pytorch_backend_utils.py
similarity index 99%
rename from et_replay/lib/comm/pytorch_backend_utils.py
rename to et_replay/comm/pytorch_backend_utils.py
index 35d72847..017888ec 100644
--- a/et_replay/lib/comm/pytorch_backend_utils.py
+++ b/et_replay/comm/pytorch_backend_utils.py
@@ -9,8 +9,8 @@
 
 import torch
 
-from et_replay.lib.comm.param_profile import paramTimer
+from et_replay.comm.param_profile import paramTimer
 from torch.distributed import ProcessGroup
 
 logger = logging.getLogger(__name__)
diff --git a/et_replay/lib/comm/pytorch_dist_backend.py b/et_replay/comm/pytorch_dist_backend.py
similarity index 99%
rename from et_replay/lib/comm/pytorch_dist_backend.py
rename to et_replay/comm/pytorch_dist_backend.py
index 6f1a0960..f7711fd5 100644
--- a/et_replay/lib/comm/pytorch_dist_backend.py
+++ b/et_replay/comm/pytorch_dist_backend.py
@@ -13,11 +13,10 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from et_replay.lib.comm.param_profile import paramProfile
-from et_replay.lib.comm.pytorch_backend_utils import (
-    backendFunctions,
-    collectiveArgsHolder,
-)
+
+from et_replay.comm.param_profile import paramProfile
+from et_replay.comm.pytorch_backend_utils import backendFunctions, collectiveArgsHolder
+
 
 try:
     from param_bench.train.comms.pt.fb.internals import (
@@ -634,7 +633,6 @@ def complete_single_op(self, collectiveArgs, retFlag=False):
         self.device_sync(collectiveArgs)
 
     def wait(self, collectiveArgs, retFlag=False):
-
         # for backwards compatibility, use old wait functionality.
         if len(collectiveArgs.waitObjIds) == 0:
             self.complete_single_op(collectiveArgs)
diff --git a/et_replay/lib/comm/pytorch_tpu_backend.py b/et_replay/comm/pytorch_tpu_backend.py
similarity index 98%
rename from et_replay/lib/comm/pytorch_tpu_backend.py
rename to et_replay/comm/pytorch_tpu_backend.py
index 7c5675eb..def1bf6a 100644
--- a/et_replay/lib/comm/pytorch_tpu_backend.py
+++ b/et_replay/comm/pytorch_tpu_backend.py
@@ -7,7 +7,7 @@
 import torch_xla.core.xla_model as xm  # @manual
 import torch_xla.distributed.xla_multiprocessing as xmp  # @manual
 
-from .comms_utils import backendFunctions
+from et_replay.comm.comms_utils import backendFunctions
 
 
 class PyTorchTPUBackend(backendFunctions):
diff --git a/et_replay/lib/et_replay_utils.py b/et_replay/et_replay_utils.py
similarity index 98%
rename from et_replay/lib/et_replay_utils.py
rename to et_replay/et_replay_utils.py
index bced9dbe..5013bb62 100644
--- a/et_replay/lib/et_replay_utils.py
+++ b/et_replay/et_replay_utils.py
@@ -2,11 +2,10 @@
 import re
 
 import torch
-from et_replay.lib.execution_trace import NodeType
-from fbgemm_gpu.split_table_batched_embeddings_ops import PoolingMode, WeightDecayMode
+from et_replay.execution_trace import NodeType
+from fbgemm_gpu.split_table_batched_embeddings_ops import PoolingMode, WeightDecayMode
 from param_bench.train.compute.python.lib.pytorch.config_util import create_op_args
-
 from param_bench.train.compute.python.workloads.pytorch.split_table_batched_embeddings_ops import (
     SplitTableBatchedEmbeddingBagsCodegenInputDataGenerator,
     SplitTableBatchedEmbeddingBagsCodegenOp,
@@ -477,11 +476,11 @@ def generate_prefix(label, skip_nodes, et_input, cuda, compute_only, tf32, rows)
 import os
 import time
 from datetime import datetime
-from et_replay.lib.comm import comms_utils
+from et_replay.comm import comms_utils
 
 import torch
-from et_replay.lib.comm import commsTraceReplay
-from et_replay.lib.et_replay_utils import (
+from et_replay.comm import commsTraceReplay
+from et_replay.et_replay_utils import (
     build_fbgemm_func,
     build_torchscript_func,
     generate_fbgemm_tensors,
@@ -490,8 +489,8 @@ def generate_prefix(label, skip_nodes, et_input, cuda, compute_only, tf32, rows)
     is_qualified,
 )
 
-from et_replay.lib.execution_trace import ExecutionTrace
-from et_replay.lib.utils import trace_handler
+from et_replay.execution_trace import ExecutionTrace
+from et_replay.utils import trace_handler
 
 print("PyTorch version: ", torch.__version__)
diff --git a/et_replay/lib/execution_trace.py b/et_replay/execution_trace.py
similarity index 100%
rename from et_replay/lib/execution_trace.py
rename to et_replay/execution_trace.py
diff --git a/et_replay/pyproject.toml b/et_replay/pyproject.toml
index de2e3bd2..e5ee8cac 100644
--- a/et_replay/pyproject.toml
+++ b/et_replay/pyproject.toml
@@ -7,9 +7,7 @@ name = "et_replay"
 version = "0.5.0"
 
 [tool.setuptools.package-dir]
-"et_replay.lib" = "lib"
-"et_replay.lib.comm" = "lib/comm"
-"et_replay.tools" = "tools"
+"et_replay" = "."
 "param_bench" = ".."
 
 [project.scripts]
diff --git a/et_replay/tests/test_execution_trace.py b/et_replay/tests/test_execution_trace.py
index f5fddaf1..8a31a6c1 100644
--- a/et_replay/tests/test_execution_trace.py
+++ b/et_replay/tests/test_execution_trace.py
@@ -3,15 +3,15 @@
 import os
 import unittest
 
-from param_bench.train.compute.python.tools.execution_trace import ExecutionTrace
-from param_bench.train.compute.python.tools.validate_trace import TraceValidator
+from et_replay import ExecutionTrace
+from et_replay.tools.validate_trace import TraceValidator
 
 CURR_DIR = os.path.dirname(os.path.realpath(__file__))
 
 
 class TestTraceLoadAndValidate(unittest.TestCase):
     def setUp(self):
-        self.trace_base = os.path.join(CURR_DIR, "data")
+        self.trace_base = os.path.join(CURR_DIR, "inputs")
 
     def _test_and_validate_trace(self, trace_file):
         with (
diff --git a/et_replay/tools/comm_replay.py b/et_replay/tools/comm_replay.py
index 6d5e8bd4..63a06144 100644
--- a/et_replay/tools/comm_replay.py
+++ b/et_replay/tools/comm_replay.py
@@ -17,6 +17,7 @@
 
 import numpy as np
 import torch
+
 from et_replay.comm import comms_utils
 from et_replay.comm.comms_utils import (
     bootstrap_info_holder,
@@ -294,7 +295,6 @@ def checkArgs(self, args: argparse.Namespace) -> None:
             and not os.path.isfile(self.trace_file)
             and not os.path.isdir(self.trace_file)
         ):
-
             raise ValueError(
                 f"The specified trace path '{self.trace_file}' is neither a "
                 "file nor a directory. Please provide a valid path."
@@ -637,7 +637,6 @@ def prepComms(
         # for all_to_allv, we can shrink the size if running on smaller scale
         # this is for sanity test or debug purpose only since we don't always get to run very large scale
         if self.shrink:
-
             cur_world_size = self.collectiveArgs.world_size
             real_world_size = cur_world_size
diff --git a/et_replay/tools/et_replay.py b/et_replay/tools/et_replay.py
index 7939048f..03dba90f 100644
--- a/et_replay/tools/et_replay.py
+++ b/et_replay/tools/et_replay.py
@@ -1,7 +1,6 @@
 import argparse
 import gc
 import json
-
 import logging
 import os
 import sys
@@ -13,23 +12,10 @@
 import numpy as np
 import torch
 
-from et_replay.lib.comm import comms_utils
-
-from et_replay.lib.execution_trace import ExecutionTrace, NodeType
-from et_replay.lib.utils import trace_handler
+from et_replay.comm import comms_utils
+from et_replay.et_replay_utils import (
-from param_bench.train.compute.python.lib import pytorch as lib_pytorch
-from param_bench.train.compute.python.lib.init_helper import load_modules
-from param_bench.train.compute.python.workloads import pytorch as workloads_pytorch
-from torch._inductor.async_compile import AsyncCompile
-from torch._inductor.codecache import TritonFuture
-
-# grid and split_scan_grid are dynamically loaded
-from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
-from torch.profiler import ExecutionTraceObserver
-
-from ..lib.et_replay_utils import (
     build_fbgemm_func,
     build_torchscript_func,
     build_triton_func,
@@ -52,6 +38,17 @@
     TORCH_DTYPES_RNG,
     TORCH_DTYPES_RNG_str,
 )
+from et_replay.execution_trace import ExecutionTrace, NodeType
+from et_replay.utils import trace_handler
+from param_bench.train.compute.python.lib import pytorch as lib_pytorch
+from param_bench.train.compute.python.lib.init_helper import load_modules
+from param_bench.train.compute.python.workloads import pytorch as workloads_pytorch
+from torch._inductor.async_compile import AsyncCompile
+from torch._inductor.codecache import TritonFuture
+
+# grid and split_scan_grid are dynamically loaded
+from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
+from torch.profiler import ExecutionTraceObserver
 
 
 class ExgrReplayManager:
@@ -129,7 +126,7 @@ def __init__(self):
         self.label = ""
 
         try:
-            from param_bench.et_replay.lib.fb.internals import (
+            from param_bench.et_replay.fb.internals import (
                 add_internal_label,
                 add_internal_parallel_nodes_parents,
                 add_internal_skip_nodes,
@@ -212,7 +209,7 @@ def initBench(self):
         # Input et trace should be explicitly specified after --input.
         if "://" in self.args.input:
             try:
-                from param_bench.et_replay.lib.fb.internals import read_remote_trace
+                from param_bench.et_replay.fb.internals import read_remote_trace
             except ImportError:
                 logging.info("FB internals not present")
                 exit(1)
@@ -237,7 +234,7 @@ def initBench(self):
             # Different processes should read different traces based on global_rank_id.
             if "://" in self.args.trace_path:
                 try:
-                    from param_bench.et_replay.lib.fb.internals import read_remote_trace
+                    from param_bench.et_replay.fb.internals import read_remote_trace
                 except ImportError:
                     logging.info("FB internals not present")
                     exit(1)
@@ -1437,7 +1434,7 @@ def benchTime(self):
         end_time = datetime.now()
 
         try:
-            from param_bench.et_replay.lib.fb.internals import generate_query_url
+            from param_bench.et_replay.fb.internals import generate_query_url
         except ImportError:
             logging.info("FB internals not present")
         else:
diff --git a/et_replay/tools/validate_trace.py b/et_replay/tools/validate_trace.py
index 46af10bb..a705e5bf 100644
--- a/et_replay/tools/validate_trace.py
+++ b/et_replay/tools/validate_trace.py
@@ -9,7 +9,7 @@
 import gzip
 import json
 
-from et_replay.lib.execution_trace import ExecutionTrace
+from et_replay.execution_trace import ExecutionTrace
 
 
 class TraceValidator:
diff --git a/et_replay/lib/utils.py b/et_replay/utils.py
similarity index 96%
rename from et_replay/lib/utils.py
rename to et_replay/utils.py
index 5f188666..c695bc24 100644
--- a/et_replay/lib/utils.py
+++ b/et_replay/utils.py
@@ -6,7 +6,7 @@
 import uuid
 from typing import Any, Dict
 
-from et_replay.lib.execution_trace import ExecutionTrace
+from et_replay import ExecutionTrace
 
 
 def get_tmp_trace_filename() -> str:
diff --git a/requirements.txt b/requirements.txt
index 79f10f90..0fb4ac63 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
 torch
 future
 numpy
+pydot
diff --git a/train/comms/pt/commsTraceParser.py b/train/comms/pt/commsTraceParser.py
index 0111537f..de3b4dc6 100644
--- a/train/comms/pt/commsTraceParser.py
+++ b/train/comms/pt/commsTraceParser.py
@@ -5,7 +5,7 @@
 
 from typing import List, Tuple
 
-from et_replay.lib.execution_trace import ExecutionTrace
+from et_replay import ExecutionTrace
 from param_bench.train.comms.pt import comms_utils
 from param_bench.train.comms.pt.comms_utils import commsArgs