diff --git a/et_replay/execution_trace.py b/et_replay/execution_trace.py index c3a7c037..cd4ad941 100644 --- a/et_replay/execution_trace.py +++ b/et_replay/execution_trace.py @@ -294,9 +294,30 @@ def get_tensors(self, param_list: Iterable) -> List[tuple]: tensors.extend(self.get_tensors(zip(elem_type, input, shape))) return tensors + def get_tensor_strides( + self, input_list: Iterable, stride_list: Iterable + ) -> List[tuple]: + strides = [] + for (type, input, shape), stride in zip(input_list, stride_list): + if type.startswith("Tensor"): + strides.append(tuple(stride)) + # GenericList could have tensor elements + elif type.startswith("GenericList"): + elem_type = type[len("GenericList[") : -1].split(",") + strides.extend( + self.get_tensor_strides(zip(elem_type, input, shape), stride) + ) + return strides + def get_input_tensors(self) -> List[tuple]: return self.get_tensors(self.get_inputs()) + def get_input_tensor_strides(self) -> Optional[List[tuple]]: + if self.input_strides is None: + return None + else: + return self.get_tensor_strides(self.get_inputs(), self.input_strides) + def get_output_tensors(self) -> List[tuple]: return self.get_tensors(self.get_outputs()) @@ -542,7 +563,7 @@ def get_param(value, type, shape): if type.startswith("genericlist"): param = {"type": "genericlist"} param["value"] = [] - type_list = type[12:-1].split(",") + type_list = type[len("GenericList[") : -1].split(",") param_list = zip(value, type_list, shape) for v, t, s in param_list: param["value"].append(get_param(v, t, s)) diff --git a/et_replay/tools/et_replay.py b/et_replay/tools/et_replay.py index 21531dee..480c078f 100644 --- a/et_replay/tools/et_replay.py +++ b/et_replay/tools/et_replay.py @@ -517,6 +517,7 @@ def allocate_tensors(self): self.args.pooling_factor, self.args.alpha, ) + tensor_strides = node.get_input_tensor_strides() for idx, (data_type, t_id, shape) in enumerate(get_input_tensors(node)): device = self.device if self.tensor_with_device: @@ -549,7 +550,7 @@ def allocate_tensors(self): strides = None if node.input_strides is not None: - strides = node.input_strides[idx] + strides = tensor_strides[idx] tensor = self.get_tensor_from_storage( t_id[1], # storage_id t_id[2], # offset @@ -833,7 +834,7 @@ def _parse_element_type(node, output_type, output_tensors, override): ) elif output_type.startswith("GenericList"): outputs += "[" - elements_type = output_type[12:-1].split(",") + elements_type = output_type[len("GenericList[") : -1].split(",") for element_type in elements_type: outputs += _parse_element_type( node, element_type, output_tensors, override @@ -1022,7 +1023,7 @@ def get_tensor_from_storage( storage_tensor.untyped_storage(), storage_offset=data_offset, size=shape, - stride=tuple(strides), + stride=strides, ) return x