Restructure et_replay subpackage (#133)
Summary:
Pull Request resolved: #133

1. Made the `et_replay` structure easier to refactor: added a package entry point that exports the public symbols, and used it in the train code (see the sketch after this list).
2. Enabled the tests to run as part of GHA CI.
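For illustration, a minimal sketch of how imports change under the new layout. The module paths are taken from the diff below; the snippet assumes the `et_replay` package from this repository is installed (e.g. `pip install et_replay/`).

```python
# Illustrative sketch only; paths mirror this diff.

# Before this commit, public modules lived under et_replay.lib:
#   from et_replay.lib.execution_trace import ExecutionTrace
#   from et_replay.lib.comm import comms_utils

# After this commit, the package root re-exports its public symbols
# (via et_replay/__init__.py) and the former lib/ modules sit one level up:
from et_replay import ExecutionTrace            # re-exported in __init__.py
from et_replay.comm import comms_utils          # was et_replay.lib.comm
from et_replay.utils import trace_handler       # was et_replay.lib.utils
from et_replay.tools.validate_trace import TraceValidator
```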

Reviewed at #124

Pull Request resolved: #129

Test Plan:
CI


Differential Revision: D59178710

Pulled By: sanrise
TaekyungHeo authored and facebook-github-bot committed Jul 15, 2024
1 parent 0ce2962 commit 6e6eb8f
Showing 19 changed files with 66 additions and 63 deletions.
14 changes: 14 additions & 0 deletions .github/workflows/python_lint.yml
@@ -21,3 +21,17 @@ jobs:
- name: Run Black
run: black . --check

- name: Run tests
run: |
python -m pip install -r requirements.txt
python -m pip install et_replay/
python et_replay/tests/test_execution_trace.py
- name: Validate imports
run: |
python -m pip install fbgemm-gpu
python -c 'from et_replay import ExecutionTrace'
python -c 'from et_replay.comm import comms_utils'
python -c 'from et_replay.tools.validate_trace import TraceValidator'
python -c 'from et_replay.utils import trace_handler'
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
.venv/
__pycache__/
3 changes: 3 additions & 0 deletions et_replay/__init__.py
@@ -0,0 +1,3 @@
from et_replay.execution_trace import ExecutionTrace

__all__ = ["ExecutionTrace"]
@@ -2,14 +2,12 @@
from __future__ import annotations

import json

from typing import List, Tuple

from et_replay.lib.comm import comms_utils
from et_replay.lib.comm.comms_utils import commsArgs
from et_replay.lib.comm.pytorch_backend_utils import supportedP2pOps

from et_replay.lib.execution_trace import ExecutionTrace
from et_replay import ExecutionTrace
from et_replay.comm import comms_utils
from et_replay.comm.comms_utils import commsArgs
from et_replay.comm.pytorch_backend_utils import supportedP2pOps

tensorDtypeMap = {
"Tensor(int)": "int",
@@ -63,7 +61,6 @@ def _parseBasicTrace(in_trace: List):
"""
newCommsTrace = []
for cnt, curComm in enumerate(in_trace):

newComm = commsArgs()
newComm.id = cnt
newComm.markerStack = curComm.get("markers")
@@ -84,7 +81,6 @@


def _parseBasicTraceComms(curComm, newComm: commsArgs) -> None:

newComm.comms = comms_utils.paramToCommName(curComm["comms"].lower())
if newComm.markerStack is None:
newComm.markerStack = [newComm.comms]
@@ -165,7 +161,6 @@ def _parseKinetoUnitrace(in_trace: List, target_rank: int) -> List:
and entry["name"] == "record_param_comms"
and entry["args"]["rank"] == target_rank
):

newComm = commsArgs()
newComm.comms = comms_utils.paramToCommName(entry["args"]["comms"].lower())
newComm.id = commsCnt
@@ -38,8 +38,10 @@

import numpy as np
import torch
from et_replay.lib.comm.param_profile import paramTimer
from et_replay.lib.comm.pytorch_backend_utils import (


from et_replay.comm.param_profile import paramTimer
from et_replay.comm.pytorch_backend_utils import (
backendFunctions,
collectiveArgsHolder,
customized_backend,
@@ -1072,7 +1074,6 @@ def _prep_all_to_all_single(
scaleFactor: float,
allocate: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:

ipTensor = None
opTensor = None
if allocate:
@@ -1193,7 +1194,6 @@ def _prep_all_gather_base(
scaleFactor: float,
allocate: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:

opTensor = []
if not commsParams.size_from_trace:
numElementsOut = numElementsIn
@@ -1257,7 +1257,6 @@ def _prep_reduce_scatter(
scaleFactor: float,
allocate: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:

ipTensor = []
opTensor = []
if not commsParams.size_from_trace:
@@ -1304,7 +1303,6 @@ def _prep_reduce_scatter_base(
scaleFactor: float,
allocate: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:

ipTensor = []
opTensor = []
if not commsParams.size_from_trace:
@@ -1727,7 +1725,6 @@ def init_emb_lookup(collectiveArgs, commsParams, backendFuncs):
try:
# fbgemm_gpu can be downloaded from https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu
from fbgemm_gpu.split_embedding_utils import generate_requests

from fbgemm_gpu.split_table_batched_embeddings_ops import (
ComputeDevice,
EmbeddingLocation,
File renamed without changes.
@@ -9,8 +9,8 @@

import torch

from et_replay.lib.comm.param_profile import paramTimer

from et_replay.comm.param_profile import paramTimer
from torch.distributed import ProcessGroup

logger = logging.getLogger(__name__)
@@ -13,11 +13,10 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from et_replay.lib.comm.param_profile import paramProfile
from et_replay.lib.comm.pytorch_backend_utils import (
backendFunctions,
collectiveArgsHolder,
)

from et_replay.comm.param_profile import paramProfile
from et_replay.comm.pytorch_backend_utils import backendFunctions, collectiveArgsHolder


try:
from param_bench.train.comms.pt.fb.internals import (
@@ -634,7 +633,6 @@ def complete_single_op(self, collectiveArgs, retFlag=False):
self.device_sync(collectiveArgs)

def wait(self, collectiveArgs, retFlag=False):

# for backwards compatibility, use old wait functionality.
if len(collectiveArgs.waitObjIds) == 0:
self.complete_single_op(collectiveArgs)
@@ -7,7 +7,7 @@
import torch_xla.core.xla_model as xm # @manual
import torch_xla.distributed.xla_multiprocessing as xmp # @manual

from .comms_utils import backendFunctions
from et_replay.comm.comms_utils import backendFunctions


class PyTorchTPUBackend(backendFunctions):
@@ -2,11 +2,10 @@
import re

import torch
from et_replay.lib.execution_trace import NodeType
from fbgemm_gpu.split_table_batched_embeddings_ops import PoolingMode, WeightDecayMode

from et_replay.execution_trace import NodeType
from fbgemm_gpu.split_table_batched_embeddings_ops import PoolingMode, WeightDecayMode
from param_bench.train.compute.python.lib.pytorch.config_util import create_op_args

from param_bench.train.compute.python.workloads.pytorch.split_table_batched_embeddings_ops import (
SplitTableBatchedEmbeddingBagsCodegenInputDataGenerator,
SplitTableBatchedEmbeddingBagsCodegenOp,
@@ -477,11 +476,11 @@ def generate_prefix(label, skip_nodes, et_input, cuda, compute_only, tf32, rows)
import os
import time
from datetime import datetime
from et_replay.lib.comm import comms_utils
from et_replay.comm import comms_utils
import torch
from et_replay.lib.comm import commsTraceReplay
from et_replay.lib.et_replay_utils import (
from et_replay.comm import commsTraceReplay
from et_replay.et_replay_utils import (
build_fbgemm_func,
build_torchscript_func,
generate_fbgemm_tensors,
@@ -490,8 +489,8 @@ def generate_prefix(label, skip_nodes, et_input, cuda, compute_only, tf32, rows)
is_qualified,
)
from et_replay.lib.execution_trace import ExecutionTrace
from et_replay.lib.utils import trace_handler
from et_replay.execution_trace import ExecutionTrace
from et_replay.utils import trace_handler
print("PyTorch version: ", torch.__version__)
File renamed without changes.
4 changes: 1 addition & 3 deletions et_replay/pyproject.toml
@@ -7,9 +7,7 @@ name = "et_replay"
version = "0.5.0"

[tool.setuptools.package-dir]
"et_replay.lib" = "lib"
"et_replay.lib.comm" = "lib/comm"
"et_replay.tools" = "tools"
"et_replay" = "."
"param_bench" = ".."

[project.scripts]
6 changes: 3 additions & 3 deletions et_replay/tests/test_execution_trace.py
@@ -3,15 +3,15 @@
import os
import unittest

from param_bench.train.compute.python.tools.execution_trace import ExecutionTrace
from param_bench.train.compute.python.tools.validate_trace import TraceValidator
from et_replay import ExecutionTrace
from et_replay.tools.validate_trace import TraceValidator

CURR_DIR = os.path.dirname(os.path.realpath(__file__))


class TestTraceLoadAndValidate(unittest.TestCase):
def setUp(self):
self.trace_base = os.path.join(CURR_DIR, "data")
self.trace_base = os.path.join(CURR_DIR, "inputs")

def _test_and_validate_trace(self, trace_file):
with (
3 changes: 1 addition & 2 deletions et_replay/tools/comm_replay.py
@@ -17,6 +17,7 @@

import numpy as np
import torch

from et_replay.comm import comms_utils
from et_replay.comm.comms_utils import (
bootstrap_info_holder,
@@ -294,7 +295,6 @@ def checkArgs(self, args: argparse.Namespace) -> None:
and not os.path.isfile(self.trace_file)
and not os.path.isdir(self.trace_file)
):

raise ValueError(
f"The specified trace path '{self.trace_file}' is neither a "
"file nor a directory. Please provide a valid path."
@@ -637,7 +637,6 @@ def prepComms(
# for all_to_allv, we can shrink the size if running on smaller scale
# this is for sanity test or debug purpose only since we don't always get to run very large scale
if self.shrink:

cur_world_size = self.collectiveArgs.world_size
real_world_size = cur_world_size

37 changes: 17 additions & 20 deletions et_replay/tools/et_replay.py
@@ -1,7 +1,6 @@
import argparse
import gc
import json

import logging
import os
import sys
@@ -13,23 +12,10 @@
import numpy as np
import torch

from et_replay.lib.comm import comms_utils

from et_replay.lib.execution_trace import ExecutionTrace, NodeType

from et_replay.lib.utils import trace_handler
from et_replay.comm import comms_utils
from et_replay.et_replay_utils import (

from param_bench.train.compute.python.lib import pytorch as lib_pytorch
from param_bench.train.compute.python.lib.init_helper import load_modules
from param_bench.train.compute.python.workloads import pytorch as workloads_pytorch
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.codecache import TritonFuture

# grid and split_scan_grid are dynamically loaded
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
from torch.profiler import ExecutionTraceObserver

from ..lib.et_replay_utils import (
build_fbgemm_func,
build_torchscript_func,
build_triton_func,
@@ -52,6 +38,17 @@
TORCH_DTYPES_RNG,
TORCH_DTYPES_RNG_str,
)
from et_replay.execution_trace import ExecutionTrace, NodeType
from et_replay.utils import trace_handler
from param_bench.train.compute.python.lib import pytorch as lib_pytorch
from param_bench.train.compute.python.lib.init_helper import load_modules
from param_bench.train.compute.python.workloads import pytorch as workloads_pytorch
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.codecache import TritonFuture

# grid and split_scan_grid are dynamically loaded
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
from torch.profiler import ExecutionTraceObserver


class ExgrReplayManager:
@@ -129,7 +126,7 @@ def __init__(self):
self.label = ""

try:
from param_bench.et_replay.lib.fb.internals import (
from param_bench.et_replay.fb.internals import (
add_internal_label,
add_internal_parallel_nodes_parents,
add_internal_skip_nodes,
@@ -212,7 +209,7 @@ def initBench(self):
# Input et trace should be explicitly specified after --input.
if "://" in self.args.input:
try:
from param_bench.et_replay.lib.fb.internals import read_remote_trace
from param_bench.et_replay.fb.internals import read_remote_trace
except ImportError:
logging.info("FB internals not present")
exit(1)
@@ -237,7 +234,7 @@
# Different processes should read different traces based on global_rank_id.
if "://" in self.args.trace_path:
try:
from param_bench.et_replay.lib.fb.internals import read_remote_trace
from param_bench.et_replay.fb.internals import read_remote_trace
except ImportError:
logging.info("FB internals not present")
exit(1)
@@ -1437,7 +1434,7 @@ def benchTime(self):
end_time = datetime.now()

try:
from param_bench.et_replay.lib.fb.internals import generate_query_url
from param_bench.et_replay.fb.internals import generate_query_url
except ImportError:
logging.info("FB internals not present")
else:
2 changes: 1 addition & 1 deletion et_replay/tools/validate_trace.py
@@ -9,7 +9,7 @@
import gzip
import json

from et_replay.lib.execution_trace import ExecutionTrace
from et_replay.execution_trace import ExecutionTrace


class TraceValidator:
2 changes: 1 addition & 1 deletion et_replay/lib/utils.py → et_replay/utils.py
@@ -6,7 +6,7 @@
import uuid
from typing import Any, Dict

from et_replay.lib.execution_trace import ExecutionTrace
from et_replay import ExecutionTrace


def get_tmp_trace_filename() -> str:
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,3 +1,4 @@
torch
future
numpy
pydot
2 changes: 1 addition & 1 deletion train/comms/pt/commsTraceParser.py
@@ -5,7 +5,7 @@

from typing import List, Tuple

from et_replay.lib.execution_trace import ExecutionTrace
from et_replay import ExecutionTrace

from param_bench.train.comms.pt import comms_utils
from param_bench.train.comms.pt.comms_utils import commsArgs
