diff --git a/train/compute/python/tools/execution_trace.py b/train/compute/python/tools/execution_trace.py index 9bc04086..806cbcd7 100644 --- a/train/compute/python/tools/execution_trace.py +++ b/train/compute/python/tools/execution_trace.py @@ -104,6 +104,7 @@ def is_leaf_tensor(self): @dataclass class _CommArgs: """Contains communication collective metadata""" + collective_name: str dtype: str # .. TODO add more see https://github.com/pytorch/pytorch/issues/124674 diff --git a/train/compute/python/tools/validate_trace.py b/train/compute/python/tools/validate_trace.py index 249b1455..7205be99 100644 --- a/train/compute/python/tools/validate_trace.py +++ b/train/compute/python/tools/validate_trace.py @@ -13,7 +13,6 @@ class TraceValidator: - def __init__(self, execution_trace: ExecutionTrace): self.et = execution_trace @@ -60,7 +59,7 @@ def check_comms_node_old(n) -> bool: def check_comms_node_new(n) -> bool: """New elements are added as per - https://github.com/pytorch/pytorch/issues/124674 + https://github.com/pytorch/pytorch/issues/124674 """ # TODO check for node.commArgs dataclass print(n.commArgs)