diff --git a/et_replay/execution_trace.py b/et_replay/execution_trace.py
index 6d08f561..74108164 100644
--- a/et_replay/execution_trace.py
+++ b/et_replay/execution_trace.py
@@ -292,9 +292,30 @@ def get_tensors(self, param_list: Iterable) -> List[tuple]:
                 tensors.extend(self.get_tensors(zip(elem_type, input, shape)))
         return tensors
 
+    def get_tensor_strides(
+        self, input_list: Iterable, stride_list: Iterable
+    ) -> List[tuple]:
+        strides = []
+        for (type, input, shape), stride in zip(input_list, stride_list):
+            if type.startswith("Tensor"):
+                strides.append(stride)
+            # GenericList could have tensor elements
+            elif type.startswith("GenericList"):
+                elem_type = type[12:-1].split(",")
+                strides.extend(
+                    self.get_tensor_strides(zip(elem_type, input, shape), stride)
+                )
+        return strides
+
     def get_input_tensors(self) -> List[tuple]:
         return self.get_tensors(self.get_inputs())
 
+    def get_input_tensor_strides(self) -> Optional[List[tuple]]:
+        if self.input_strides is None:
+            return None
+        else:
+            return self.get_tensor_strides(self.get_inputs(), self.input_strides)
+
     def get_output_tensors(self) -> List[tuple]:
         return self.get_tensors(self.get_outputs())
 
diff --git a/et_replay/tools/et_replay.py b/et_replay/tools/et_replay.py
index 21531dee..ae40fe0b 100644
--- a/et_replay/tools/et_replay.py
+++ b/et_replay/tools/et_replay.py
@@ -517,6 +517,7 @@ def allocate_tensors(self):
                     self.args.pooling_factor,
                     self.args.alpha,
                 )
+            tensor_strides = node.get_input_tensor_strides()
             for idx, (data_type, t_id, shape) in enumerate(get_input_tensors(node)):
                 device = self.device
                 if self.tensor_with_device:
@@ -549,7 +550,7 @@
 
                     strides = None
                     if node.input_strides is not None:
-                        strides = node.input_strides[idx]
+                        strides = tensor_strides[idx]
                     tensor = self.get_tensor_from_storage(
                         t_id[1],  # storage_id
                         t_id[2],  # offset