Skip to content

Commit

Permalink
Fixed the bug to access tensor stride (#150)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #150

There is a bug to access tensor stride. The current implementation uses the index of the input tensor to access node.tensor_strides. However  node.tensor_strides is for all input variables including non tensor variables. So the fix is to get the list of tensor stride, then uses the index of the input tensor to access the list of tensor stride.

Differential Revision: D60642585
  • Loading branch information
shengfukevin authored and facebook-github-bot committed Aug 7, 2024
1 parent c466b60 commit efc853d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
21 changes: 21 additions & 0 deletions et_replay/execution_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,30 @@ def get_tensors(self, param_list: Iterable) -> List[tuple]:
tensors.extend(self.get_tensors(zip(elem_type, input, shape)))
return tensors

def get_tensor_strides(
self, input_list: Iterable, stride_list: Iterable
) -> List[tuple]:
strides = []
for (type, input, shape), stride in zip(input_list, stride_list):
if type.startswith("Tensor"):
strides.append(stride)
# GenericList could have tensor elements
elif type.startswith("GenericList"):
elem_type = type[12:-1].split(",")
strides.extend(
self.get_tensor_strides(zip(elem_type, input, shape), stride)
)
return strides

def get_input_tensors(self) -> List[tuple]:
return self.get_tensors(self.get_inputs())

def get_input_tensor_strides(self) -> Optional[List[tuple]]:
if self.input_strides is None:
return None
else:
return self.get_tensor_strides(self.get_inputs(), self.input_strides)

def get_output_tensors(self) -> List[tuple]:
return self.get_tensors(self.get_outputs())

Expand Down
3 changes: 2 additions & 1 deletion et_replay/tools/et_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ def allocate_tensors(self):
self.args.pooling_factor,
self.args.alpha,
)
tensor_strides = node.get_input_tensor_strides()
for idx, (data_type, t_id, shape) in enumerate(get_input_tensors(node)):
device = self.device
if self.tensor_with_device:
Expand Down Expand Up @@ -549,7 +550,7 @@ def allocate_tensors(self):

strides = None
if node.input_strides is not None:
strides = node.input_strides[idx]
strides = tensor_strides[idx]
tensor = self.get_tensor_from_storage(
t_id[1], # storage_id
t_id[2], # offset
Expand Down

0 comments on commit efc853d

Please sign in to comment.