Refactor the forward API of modules #8
I had a discussion with @abmazitov regarding the usefulness of torch_geometric: they have a bunch of data loader iterators that allow different walks over graphs (https://pytorch-geometric.readthedocs.io/en/latest/modules/loader.html) that might be useful in the future. They also offer somewhat nicer handling of mini-batches. Given that it is not much added complexity and it takes over the concatenation of different structures/neighbor lists, I would give it a go.
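For reference, the mini-batch handling that `torch_geometric.data.Batch` provides essentially amounts to concatenating the structures and shifting each structure's neighbor indices by the number of atoms that precede it. A minimal plain-torch sketch; `concatenate_structures` and the keys `positions`/`neighbor_pairs` are hypothetical names chosen for illustration, not an existing API:

```python
from typing import Dict, List

import torch


def concatenate_structures(samples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Hypothetical helper: merge per-structure inputs into one batch.

    Each sample is assumed to contain
      "positions":      [n_atoms, 3]
      "neighbor_pairs": [n_pairs, 2] atom indices local to that structure
    """
    positions, pairs, structure_index = [], [], []
    offset = 0
    for i, sample in enumerate(samples):
        n_atoms = sample["positions"].shape[0]
        positions.append(sample["positions"])
        # Shift local atom indices so they point into the concatenated arrays
        pairs.append(sample["neighbor_pairs"] + offset)
        # Record which structure each atom belongs to (like PyG's `batch` vector)
        structure_index.append(torch.full((n_atoms,), i, dtype=torch.long))
        offset += n_atoms
    return {
        "positions": torch.cat(positions),
        "neighbor_pairs": torch.cat(pairs),
        "structure_index": torch.cat(structure_index),
    }


a = {"positions": torch.zeros(2, 3), "neighbor_pairs": torch.tensor([[0, 1], [1, 0]])}
b = {"positions": torch.ones(3, 3), "neighbor_pairs": torch.tensor([[0, 2]])}
batch = concatenate_structures([a, b])
print(batch["neighbor_pairs"].tolist())   # [[0, 1], [1, 0], [2, 4]]
print(batch["structure_index"].tolist())  # [0, 0, 1, 1, 1]
```

The offsetting is the only non-trivial part; whether we get it from torch_geometric or write it ourselves is the trade-off being discussed.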
I checked that a `MessagePassing` layer can be scripted with `torch.jit.script`, both with separate tensor arguments and with a single `Dict[str, Tensor]` argument:

```python
from typing import Dict, Optional

import torch
from torch import Tensor
from torch_geometric.nn import MessagePassing


class MyConv(MessagePassing):
    propagate_type = {'x': Tensor, 'edge_weight': Optional[Tensor]}

    def forward(self, x: Tensor, edge_index: Tensor,
                edge_weight: Optional[Tensor]) -> Tensor:
        return self.propagate(edge_index, x=x, edge_weight=edge_weight,
                              size=None)


# .jittable() was required by the torch_geometric versions current at the time;
# newer releases deprecate it
torch.jit.script(MyConv().jittable())


class MyConv2(MessagePassing):
    propagate_type = {'x': Tensor, 'edge_weight': Optional[Tensor]}

    def forward(self, data: Dict[str, Tensor]) -> Tensor:
        # Same layer, but all inputs are passed in a single dict
        x: Tensor = data['x']
        edge_index: Tensor = data['edge_index']
        # Declare edge_weight as Optional up front so its type is explicit for TorchScript
        edge_weight: Optional[Tensor] = None
        if 'edge_weight' in data:
            edge_weight = data['edge_weight']
        return self.propagate(edge_index, x=x, edge_weight=edge_weight,
                              size=None)


torch.jit.script(MyConv2().jittable())
```
So in PR #9 we took the step of supporting a neighbor list + direction vectors as input, to move out unnecessary computation that can be precomputed for training. As the next step, as soon as an equistore Dataset and DataLoader are ready, I would like to change the API to something like

```python
def forward_eqs(self,
                direction_vectors: TensorMap,  # metadata (structure, center, neighbor, species_center, species_neighbor || cell_shift_x, cell_shift_y, cell_shift_z)
                ) -> TensorMap:
    ...
```

The position and cell information stays outside of the SphericalExpansion and can be used to retrieve the gradients:

```python
positions: TensorMap,  # metadata (structure, center, species || x, y, z)
cell: TensorMap,       # metadata (structure || cell_x, cell_y, cell_z)
```
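To illustrate why positions and cell can live outside the module while still yielding gradients: if the direction vectors are computed from them with autograd enabled, the gradients of any output with respect to positions and cell follow from backpropagation. A minimal sketch with plain tensors standing in for the TensorMaps; the function name `compute_direction_vectors`, the ASE-like convention (lattice vectors as rows of `cell`, r_ij = r_j - r_i + S_ij · cell) and the squared-norm "energy" are assumptions for illustration:

```python
import torch


def compute_direction_vectors(
    positions: torch.Tensor,    # [n_atoms, 3]
    cell: torch.Tensor,         # [3, 3], lattice vectors as rows (assumed convention)
    centers: torch.Tensor,      # [n_pairs] index of the center atom i
    neighbors: torch.Tensor,    # [n_pairs] index of the neighbor atom j
    cell_shifts: torch.Tensor,  # [n_pairs, 3] integer shifts S_ij of the neighbor image
) -> torch.Tensor:
    # r_ij = r_j - r_i + S_ij @ cell; differentiable w.r.t. positions and cell
    return positions[neighbors] - positions[centers] + cell_shifts.to(cell.dtype) @ cell


positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], requires_grad=True)
cell = (4.0 * torch.eye(3)).requires_grad_()
centers = torch.tensor([0, 1])
neighbors = torch.tensor([1, 0])
cell_shifts = torch.tensor([[0, 0, 0], [1, 0, 0]])  # second pair uses the image shifted by +a

vectors = compute_direction_vectors(positions, cell, centers, neighbors, cell_shifts)
print(vectors.detach().tolist())  # [[1.0, 0.0, 0.0], [3.0, 0.0, 0.0]]

# Hypothetical stand-in for the model output computed from the direction vectors
energy = (vectors ** 2).sum()
energy.backward()
# positions.grad and cell.grad now hold d(energy)/d(positions) and d(energy)/d(cell)
```

The spherical expansion itself would then only ever see `vectors`; forces and stress come from the autograd graph that connects them back to `positions` and `cell`.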
The goal is to agree on an input type for the forward function (both the data type and what the type logically represents). There are two aspects I want to discuss; let me briefly describe both.
Data type:

- `torch_geometric.data.Batch` as input type: convenient to use, but it creates a dependency and seems not to support TorchScript (their models are defined differently for scripting, see https://pytorch-geometric.readthedocs.io/en/latest/advanced/jit.html).
- `Dict[str, torch.Tensor]` as input type (after PR #7 "Remove ase.Atoms dependence of SphericalExpansion by Structure object" is merged).

I looked into the MACE code to see how they handle it, and they transform the input to a `Dict` before passing it to the model: https://github.com/ACEsuit/mace/blob/44b6da4a5edaa4c3ef867a11728555403d1d475d/mace/calculators/mace.py#L71

So it is pretty clear to me that we should not use `torch_geometric.data.Batch` as input for the modules that compute the spherical expansion.

Logical type (neighbor list creation):
I think we should move the neighbor list creation outside of the spherical expansion computation. With that, using torch_geometric in the DataLoader becomes a possibility (as in MACE).
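With the neighbor list computed outside, a module whose `forward` accepts a `Dict[str, torch.Tensor]` can be scripted without any torch_geometric dependency. A minimal sketch; the `RadialDensity` module, its Gaussian radial function and the key names are hypothetical stand-ins for the spherical expansion and the eventual input format:

```python
from typing import Dict

import torch


class RadialDensity(torch.nn.Module):
    """Hypothetical toy descriptor: per-atom sum of exp(-|r_ij|^2 / (2 sigma^2))."""

    def __init__(self, sigma: float = 1.0):
        super().__init__()
        self.sigma = sigma

    def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor:
        n_atoms = data["positions"].shape[0]
        centers = data["centers"]               # [n_pairs], precomputed neighbor list
        vectors = data["direction_vectors"]     # [n_pairs, 3], precomputed r_ij
        distances_sq = (vectors ** 2).sum(dim=1)
        contributions = torch.exp(-distances_sq / (2.0 * self.sigma * self.sigma))
        # Optional entries are handled by checking the dict keys
        if "pair_weights" in data:
            contributions = contributions * data["pair_weights"]
        out = torch.zeros([n_atoms], dtype=vectors.dtype, device=vectors.device)
        # Accumulate each pair's contribution onto its center atom
        return out.index_add(0, centers, contributions)


module = torch.jit.script(RadialDensity())
data = {
    "positions": torch.zeros(2, 3),
    "centers": torch.tensor([0, 1]),
    "direction_vectors": torch.zeros(2, 3),
}
print(module(data).tolist())  # [1.0, 1.0]
```

Because the dict contains only tensors, the same module can be fed from a plain `torch.utils.data.DataLoader` or from a torch_geometric loader whose output is converted to a dict first, as MACE does.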
Usefulness of `torch_geometric.data.Data`:

I don't see much reason to use it. We can achieve the same thing with the regular torch DataLoader. In any case, we need to define the logic for creating a NL from whatever atomic data type and putting it into a `Dict[str, torch.Tensor]`. All I could gather from torch_alchemical and MACE is that it gives a somewhat nicer structure because you define a `Data` object, but you can also do this without the extra dependency. I think the `AtomisticDataset(torch.utils.data.Dataset)` class in torch_alchemical is something we want, but without `torch_geometric.data.Data`.
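A dataset along those lines needs only `torch.utils.data`. The sketch below is hypothetical (it is not the torch_alchemical class): it precomputes a brute-force, non-periodic neighbor list once per structure and returns plain `Dict[str, torch.Tensor]` samples. `brute_force_neighbor_list` and the key names are assumptions for illustration:

```python
from typing import Dict, List, Tuple

import torch
from torch.utils.data import DataLoader, Dataset


def brute_force_neighbor_list(
    positions: torch.Tensor, cutoff: float
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """All pairs i != j with |r_j - r_i| < cutoff (no periodic images)."""
    vectors = positions.unsqueeze(0) - positions.unsqueeze(1)  # vectors[i, j] = r_j - r_i
    distances = vectors.norm(dim=-1)
    idx = torch.arange(positions.shape[0])
    mask = (distances < cutoff) & (idx.unsqueeze(0) != idx.unsqueeze(1))
    centers, neighbors = mask.nonzero(as_tuple=True)
    return centers, neighbors, vectors[centers, neighbors]


class AtomisticDataset(Dataset):
    """Hypothetical sketch: neighbor lists are computed once, in __init__."""

    def __init__(self, structures: List[torch.Tensor], cutoff: float):
        self.samples: List[Dict[str, torch.Tensor]] = []
        for positions in structures:
            centers, neighbors, vectors = brute_force_neighbor_list(positions, cutoff)
            self.samples.append({
                "positions": positions,
                "centers": centers,
                "neighbors": neighbors,
                "direction_vectors": vectors,
            })

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        return self.samples[idx]


structures = [
    torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]),
    torch.tensor([[0.0, 0.0, 0.0], [5.0, 0.0, 0.0], [5.5, 0.0, 0.0]]),
]
dataset = AtomisticDataset(structures, cutoff=2.0)
# batch_size=None disables automatic batching: each item is one structure's dict
loader = DataLoader(dataset, batch_size=None)
for sample in loader:
    print(sample["centers"].tolist(), sample["neighbors"].tolist())
# [0, 1] [1, 0]
# [1, 2] [2, 1]
```

Two caveats of this sketch: batching several structures into one dict would additionally need a `collate_fn` that offsets the atom indices, and if gradients with respect to positions are needed, the direction vectors should be recomputed from the positions inside the model rather than stored in the dataset.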