Skip to content

Commit

Permalink
[inductor] fix the cudagraph tree test (pytorch#132043)
Browse files Browse the repository at this point in the history
Summary:
There are two kinds of exceptions:
Case #1:
```
static input data pointer changed.
input name: primals_2. data pointer changed from 140315748992000 to 140315748993536. input stack trace:   File "/dev/shm/uid-30083/c0899c70-seed-nspid4026535598_cgpid16622182-ns-4026535192/caffe2/test/inductor/test_cudagraph_trees.py", line 1826, in forward
    return self.static_tensor + x + self.goo(x)
  File "/dev/shm/uid-30083/c0899c70-seed-nspid4026535598_cgpid16622182-ns-4026535192/caffe2/test/inductor/test_cudagraph_trees.py", line 1816, in forward
    return self.linear(x)

input name: primals_3. data pointer changed from 140315748990976 to 140315748993024. input stack trace:   File "/dev/shm/uid-30083/c0899c70-seed-nspid4026535598_cgpid16622182-ns-4026535192/caffe2/test/inductor/test_cudagraph_trees.py", line 1825, in forward
    self.static_tensor.add_(torch.ones((2, 2), device="cuda"))

```
Case #2:
```
static input data pointer changed.
input name: primals_2. data pointer changed from 139852509086720 to 139852509088256. input stack trace: None
input name: primals_3. data pointer changed from 139852509085696 to 139852509087744. input stack trace:   File "/dev/shm/uid-30083/f61ee184-seed-nspid4026560782_cgpid769179-ns-4026560865/caffe2/test/inductor/test_cudagraph_trees.py", line 1825, in forward
    self.static_tensor.add_(torch.ones((2, 2), device="cuda"))

```
The current impl only covered the case #2

Test Plan: https://www.internalfb.com/intern/testinfra/testrun/15481123762274476

Differential Revision: D60340212

Pull Request resolved: pytorch#132043
Approved by: https://github.com/BoyuanFeng
  • Loading branch information
sijiac authored and pytorchmergebot committed Jul 30, 2024
1 parent 36e8289 commit 83db609
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions test/inductor/test_cudagraph_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -1839,9 +1839,9 @@ def forward(self, x) -> torch.Tensor:
with self.assertRaisesRegex(
Exception,
r"static input data pointer changed.\n"
r"input name: primals_2. data pointer changed from .* to .*. input stack trace: None\n"
r"input name: primals_2. data pointer changed from .* to .*. input stack trace:(?s).*"
r"input name: primals_3. data pointer changed from .* to .*. input stack trace:.*,"
r" in forward\n.* self.static_tensor.add\_\(torch.ones\(\(2, 2\), device=\"cuda\"\)\).*\n\n",
r" in forward\n.* self.static_tensor.add\_\(torch.ones\(\(2, 2\), device=\"cuda\"\)\).*\n",
):
self.curr_node().run(
[foo.goo.linear.weight, foo.goo.linear.bias, foo.static_tensor, inp]
Expand Down

0 comments on commit 83db609

Please sign in to comment.