
Rework forward pass to remove old gradients #46

Open
wants to merge 13 commits into main

Conversation


@Arkay92 Arkay92 commented Dec 27, 2022

This change uses the torch.cuda.device_of() function to determine whether the input tensors are on the GPU or CPU, and then chooses the appropriate layer implementations for better performance. It also wraps the forward pass in the torch.no_grad() context manager so the model does not track gradients during inference.
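
A minimal sketch of that pattern, for illustration only: it checks the tensor's is_cuda flag rather than torch.cuda.device_of(), and the class and layer names below are placeholders, not the actual point-e modules.

import torch
import torch.nn as nn

class ResidualBlockSketch(nn.Module):
    """Illustrative stand-in for one of the point-e residual blocks (not the real class)."""

    def __init__(self, width: int):
        super().__init__()
        self.norm = nn.LayerNorm(width)   # stock LayerNorm works on both CPU and GPU tensors
        self.mlp = nn.Linear(width, width)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Inference-only forward pass: no gradient history is recorded.
        with torch.no_grad():
            # Keep the layers on whichever device the input already lives on.
            if x.is_cuda:
                self.to(x.device)
            return x + self.mlp(self.norm(x))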

@Arkay92 Arkay92 commented Dec 27, 2022

This may be linked to issue #27

@dancergraham

This is awesome - without this change I cannot run any of the examples on my GeForce GTX 1650 with 4 GB of dedicated GPU memory. With this change I can run the 40M-textvec model, which takes sampling time from nearly one hour (CPU) down to a couple of minutes (GPU) on my laptop. Thank you so much! I hope it is accepted into the repo.

@dancergraham

This also relates to issue #36

@dancergraham dancergraham commented Dec 29, 2022

Hello,
Using the pointcloud2mesh.ipynb notebook I get an error:

AttributeError: module 'torch.nn' has no attribute 'CUDALayerNorm'

I am using PyTorch version 1.13.1+cu117

@Arkay92 Arkay92 commented Dec 29, 2022

Good spot yet again @dancergraham - I have switched over to NVIDIA Apex for FusedLayerNorm on GPU. Can you try this now?
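
For reference, a minimal hedged sketch of what using Apex's fused layer norm looks like (the width and tensor shapes are made up for illustration; this requires the NVIDIA Apex extension and a CUDA device):

import torch
from apex.normalization import FusedLayerNorm  # provided by the NVIDIA Apex extension

width = 512  # hypothetical hidden size; the point-e blocks pass their own width
norm = FusedLayerNorm(width).cuda()

x = torch.randn(2, 1024, width, device="cuda")
y = norm(x)  # fused CUDA kernel in place of torch.nn.LayerNorm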

@Arkay92 Arkay92 commented Dec 29, 2022

NB: this requires the external Apex lib to work, but it should speed up rendering once it fires up. Any issues, let me know and I'll rework it @dancergraham (I have added it to setup.py install_requires).

@dancergraham

I was not able to install Apex with pip on my Windows machine - I got a lot of errors about "filename too long".

I tried python -m pip install "apex @ git+https://github.com/NVIDIA/apex.git"

@Arkay92 Arkay92 commented Dec 29, 2022

I am looking into other alternatives to LayerNorm, and it seems instance or group normalisation may help speed things up here! Will ping another refactored PR soon.
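
As a hedged illustration of the group-normalisation idea (not code from this PR), torch.nn.GroupNorm with a single group normalises across all channels of a channel-first tensor and can serve as a rough stand-in for LayerNorm:

import torch
import torch.nn as nn

channels = 512  # hypothetical hidden size
# One group means the statistics are computed over every channel (and any
# trailing dims), which is close in spirit to LayerNorm for (N, C, ...) tensors.
norm = nn.GroupNorm(num_groups=1, num_channels=channels)

x = torch.randn(2, channels, 1024)  # (batch, channels, points)
y = norm(x)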

@dancergraham

Hmm, this looks rather complex - I will try it out on my machine, but if I were running the point-e repo I don't think I would want to adopt a complicated dependency, especially one marked as "experimental" on Windows...

It might be good to add it as an optional dependency, in the same way that the code currently works with or without CUDA; that adds complexity to the library, so it is the maintainers' call whether or not to accept that approach.
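
A hedged sketch of the optional-dependency pattern suggested above, falling back to stock PyTorch when Apex is not installed (the alias name LayerNormImpl is arbitrary):

import torch.nn as nn

try:
    # Use the fused CUDA implementation when the Apex extension is available.
    from apex.normalization import FusedLayerNorm as LayerNormImpl
except ImportError:
    # Apex not installed (e.g. on Windows): fall back to stock PyTorch.
    LayerNormImpl = nn.LayerNorm

norm = LayerNormImpl(512)  # hypothetical hidden size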

@Arkay92 Arkay92 commented Dec 30, 2022

I shall remove the Apex lib but keep the forward pass change; this should still preserve the performance gain without the lib dependency.

@Arkay92 Arkay92 commented Dec 30, 2022

@dancergraham try this now - textvec rendering should still be significantly faster whilst I find a native way of speeding up layer norm on GPU/CUDA.

@dancergraham

I now get an error when I try to run pointcloud2mesh: TypeError: ResidualCrossAttentionBlock.forward() missing 1 required positional argument: 'device'

perceiver.py:154, in SimplePerceiver.forward(self, x, data)
    152 with torch.no_grad():
    153     for block in self.resblocks:
--> 154         x = block(x, data)
    155 return x

@Arkay92 Arkay92 commented Jan 2, 2023

My bad @dancergraham - I forgot I had added it as a param. Changed it back so .to() uses torch.device directly rather than by reference from the param list.

@dancergraham

Still not working for me - I get errors with pointcloud2mesh:

File ...\point_e\models\perceiver.py:154, in SimplePerceiver.forward(self, x, data)
    152 with torch.no_grad():
    153     for block in self.resblocks:
--> 154         x = block(x, data)
    155 return x

File ...\lib\site-packages\torch\nn\modules\module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ...\point-e\point_e\models\perceiver.py:106, in ResidualCrossAttentionBlock.forward(self, x, data)
    103 def forward(self, x: torch.Tensor, data: torch.Tensor):
    104     with torch.no_grad():
    105         # Use the to() method to move the input tensors to the specified device
--> 106         x = x.to(torch.device)
    107         data = data.to(torch.device)
    109         # Normalize input tensors and pass them through the attention and MLP layers

TypeError: to() received an invalid combination of arguments - got (type), but expected one of:
 * (torch.device device, torch.dtype dtype, bool non_blocking, bool copy, *, torch.memory_format memory_format)
 * (torch.dtype dtype, bool non_blocking, bool copy, *, torch.memory_format memory_format)
 * (Tensor tensor, bool non_blocking, bool copy, *, torch.memory_format memory_format)
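
The TypeError comes from passing the torch.device class itself to .to() instead of a concrete device. One hedged sketch of a fix, keeping the no-grad block but deriving the device from the module's own parameters (the class below is illustrative, not the real ResidualCrossAttentionBlock):

import torch
import torch.nn as nn

class CrossAttentionBlockSketch(nn.Module):
    """Illustrative stand-in for ResidualCrossAttentionBlock."""

    def __init__(self, width: int):
        super().__init__()
        self.norm = nn.LayerNorm(width)  # placeholder for the real attention + MLP stack

    def forward(self, x: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            # Move inputs to the device this module's weights already live on,
            # rather than passing the torch.device class itself.
            device = next(self.parameters()).device
            x = x.to(device)
            data = data.to(device)
            return self.norm(x)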

@dancergraham

We have liftoff 🚀 I can now run at grid_size=128 in 45 seconds per model on my GPU - many thanks again!

@Arkay92 Arkay92 commented Jan 3, 2023

Thank you so much for the testing support @dancergraham! LFG!

@Arkay92 Arkay92 changed the title from "Rework perceiver tensor logic" to "Rework forward pass to remove old gradients" on Jan 3, 2023