From 26771701638ef24d7e350edbfc3e8954060513b5 Mon Sep 17 00:00:00 2001 From: briancoutinho Date: Mon, 29 Apr 2024 16:56:52 -0700 Subject: [PATCH] fix black --- train/compute/python/tools/execution_trace.py | 1 + train/compute/python/tools/validate_trace.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/train/compute/python/tools/execution_trace.py b/train/compute/python/tools/execution_trace.py index 9bc04086..806cbcd7 100644 --- a/train/compute/python/tools/execution_trace.py +++ b/train/compute/python/tools/execution_trace.py @@ -104,6 +104,7 @@ def is_leaf_tensor(self): @dataclass class _CommArgs: """Contains communication collective metadata""" + collective_name: str dtype: str # .. TODO add more see https://github.com/pytorch/pytorch/issues/124674 diff --git a/train/compute/python/tools/validate_trace.py b/train/compute/python/tools/validate_trace.py index 249b1455..7205be99 100644 --- a/train/compute/python/tools/validate_trace.py +++ b/train/compute/python/tools/validate_trace.py @@ -13,7 +13,6 @@ class TraceValidator: - def __init__(self, execution_trace: ExecutionTrace): self.et = execution_trace @@ -60,7 +59,7 @@ def check_comms_node_old(n) -> bool: def check_comms_node_new(n) -> bool: """New elements are added as per - https://github.com/pytorch/pytorch/issues/124674 + https://github.com/pytorch/pytorch/issues/124674 """ # TODO check for node.commArgs dataclass print(n.commArgs)