From 78036a43c07945aa3dfc3b50ac2d62e1b2f86423 Mon Sep 17 00:00:00 2001
From: Sheng Fu
Date: Fri, 9 Aug 2024 13:39:40 -0700
Subject: [PATCH] Fixed a bug in accessing tensor strides (#150)

Summary:
Pull Request resolved: https://github.com/facebookresearch/param/pull/150

There is a bug in how tensor strides are accessed. The current
implementation uses the index of the input tensor to index into
node.tensor_strides. However, node.tensor_strides holds one entry per
input variable, including non-tensor variables, so the tensor index does
not line up with it. The fix is to first build a list of strides for the
tensor inputs only, and then use the index of the input tensor to access
that list.

Reviewed By: briancoutinho

Differential Revision: D60642585
---
 et_replay/execution_trace.py | 23 ++++++++++++++++++++++-
 et_replay/tools/et_replay.py |  7 ++++---
 2 files changed, 26 insertions(+), 4 deletions(-)

diff --git a/et_replay/execution_trace.py b/et_replay/execution_trace.py
index c3a7c037..cd4ad941 100644
--- a/et_replay/execution_trace.py
+++ b/et_replay/execution_trace.py
@@ -294,9 +294,30 @@ def get_tensors(self, param_list: Iterable) -> List[tuple]:
                 tensors.extend(self.get_tensors(zip(elem_type, input, shape)))
         return tensors
 
+    def get_tensor_strides(
+        self, input_list: Iterable, stride_list: Iterable
+    ) -> List[tuple]:
+        strides = []
+        for (type, input, shape), stride in zip(input_list, stride_list):
+            if type.startswith("Tensor"):
+                strides.append(tuple(stride))
+            # GenericList could have tensor elements
+            elif type.startswith("GenericList"):
+                elem_type = type[len("GenericList[") : -1].split(",")
+                strides.extend(
+                    self.get_tensor_strides(zip(elem_type, input, shape), stride)
+                )
+        return strides
+
     def get_input_tensors(self) -> List[tuple]:
         return self.get_tensors(self.get_inputs())
 
+    def get_input_tensor_strides(self) -> Optional[List[tuple]]:
+        if self.input_strides is None:
+            return None
+        else:
+            return self.get_tensor_strides(self.get_inputs(), self.input_strides)
+
     def get_output_tensors(self) -> List[tuple]:
         return self.get_tensors(self.get_outputs())
 
@@ -542,7 +563,7 @@ def get_param(value, type, shape):
     if type.startswith("genericlist"):
         param = {"type": "genericlist"}
         param["value"] = []
-        type_list = type[12:-1].split(",")
+        type_list = type[len("GenericList[") : -1].split(",")
         param_list = zip(value, type_list, shape)
         for v, t, s in param_list:
             param["value"].append(get_param(v, t, s))
diff --git a/et_replay/tools/et_replay.py b/et_replay/tools/et_replay.py
index 21531dee..480c078f 100644
--- a/et_replay/tools/et_replay.py
+++ b/et_replay/tools/et_replay.py
@@ -517,6 +517,7 @@ def allocate_tensors(self):
                     self.args.pooling_factor,
                     self.args.alpha,
                 )
+            tensor_strides = node.get_input_tensor_strides()
             for idx, (data_type, t_id, shape) in enumerate(get_input_tensors(node)):
                 device = self.device
                 if self.tensor_with_device:
@@ -549,7 +550,7 @@
 
                         strides = None
                         if node.input_strides is not None:
-                            strides = node.input_strides[idx]
+                            strides = tensor_strides[idx]
                         tensor = self.get_tensor_from_storage(
                             t_id[1],  # storage_id
                             t_id[2],  # offset
@@ -833,7 +834,7 @@ def _parse_element_type(node, output_type, output_tensors, override):
             )
         elif output_type.startswith("GenericList"):
             outputs += "["
-            elements_type = output_type[12:-1].split(",")
+            elements_type = output_type[len("GenericList[") : -1].split(",")
             for element_type in elements_type:
                 outputs += _parse_element_type(
                     node, element_type, output_tensors, override
@@ -1022,7 +1023,7 @@ def get_tensor_from_storage(
                 storage_tensor.untyped_storage(),
                 storage_offset=data_offset,
                 size=shape,
-                stride=tuple(strides),
+                stride=strides,
             )
         return x
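
The index mismatch is easiest to see with a small example. The sketch
below is illustrative only: input_types and input_strides are
hypothetical stand-ins that mirror the layout of node.input_strides,
which carries one stride entry per input variable, tensors and
non-tensors alike.

# Hypothetical node inputs: a scalar followed by two tensors.
input_types = ["Int", "Tensor(float)", "Tensor(float)"]
# One stride entry per input variable, like node.input_strides.
input_strides = [[], [4, 1], [8, 1]]

idx = 0  # index of the first *tensor* input, i.e. input_types[1]

# Buggy lookup: idx counts tensor inputs only, but indexes the full
# list, so it lands on the scalar's placeholder entry.
buggy = input_strides[idx]  # [] instead of [4, 1]

# Fixed lookup: filter the stride list down to tensor inputs first
# (what get_tensor_strides() does), then index by tensor position.
tensor_strides = [
    tuple(s) for t, s in zip(input_types, input_strides) if t.startswith("Tensor")
]
fixed = tensor_strides[idx]  # (4, 1)

assert buggy == [] and fixed == (4, 1)

The same filtering is why allocate_tensors() now calls
node.get_input_tensor_strides() once per node and indexes the filtered
list instead of node.input_strides.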