
Refactor the forward API of modules #8

Open
agoscinski opened this issue Jun 7, 2023 · 3 comments

Comments

@agoscinski
Collaborator

agoscinski commented Jun 7, 2023

The goal is to agree on a forward function input type (the data type, and logically what that type represents). There are two aspects I want to discuss; let me briefly describe both.

Data type:

I looked into the mace code to see how they handle it: they transform the input to a Dict before passing it to the model.
https://github.com/ACEsuit/mace/blob/44b6da4a5edaa4c3ef867a11728555403d1d475d/mace/calculators/mace.py#L71
So it is pretty clear to me that we should not use torch_geometric.data.Batch in the modules for the spherical expansion computation.

Logical type (neighbor list creation):

  • torch_alchemical: creates the neighbor list outside of the model as a preprocessing step (the logical input type is neighbor lists)
  • torch_spex: creates the neighbor list in the model (the logical input type is atomic structures)

I think we should move the neighbor list creation outside of the spherical expansion computation. That way, using torch_geometric in the DataLoader remains a possibility (as in mace).
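As a rough sketch of what "neighbor list creation as a preprocessing step" could look like (the function name and the naive non-periodic approach are my own illustration, not torch_alchemical's actual implementation):

```python
import torch

def compute_neighborlist(positions: torch.Tensor, cutoff: float):
    """Naive O(N^2) neighbor list for a single non-periodic structure.

    Illustrative only: a real implementation would also handle the
    cell and periodic boundary conditions (the cell shifts).
    """
    distances = torch.cdist(positions, positions)
    # keep pairs within the cutoff, excluding self-pairs on the diagonal
    within_cutoff = distances < cutoff
    not_self = ~torch.eye(len(positions), dtype=torch.bool)
    centers, neighbors = torch.nonzero(within_cutoff & not_self, as_tuple=True)
    return centers, neighbors

# atoms 0 and 1 are 1 Å apart, atom 2 is far away
positions = torch.tensor([[0.0, 0.0, 0.0],
                          [0.0, 0.0, 1.0],
                          [0.0, 0.0, 5.0]])
centers, neighbors = compute_neighborlist(positions, cutoff=2.0)
```

The DataLoader would then hand the model the precomputed pairs instead of raw structures.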

Usefulness of torch_geometric.data.Data

I don't see much reason to use it. We can achieve the same thing with the regular torch DataLoader. In any case, we need to define the logic that creates a NL from whatever atomic data type we use and puts it into a Dict[str, torch.Tensor]. All I could take from torch_alchemical and mace is that defining a Data object gives a somewhat nicer structure, but you can also get that without the extra dependency. I think the AtomisticDataset(torch.utils.data.Dataset) class in torch_alchemical is something we want, but without torch_geometric.data.Data.
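A minimal sketch of how such a dataset could look without torch_geometric (the class shape and the choice of keys are hypothetical, not the actual torch_alchemical API):

```python
from typing import Dict, List

import torch

class AtomisticDataset(torch.utils.data.Dataset):
    """Yields one structure per item as a plain Dict[str, torch.Tensor]."""

    def __init__(self, structures: List[Dict[str, torch.Tensor]]):
        self.structures = structures

    def __len__(self) -> int:
        return len(self.structures)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        return self.structures[idx]

def collate(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    # concatenate the per-structure tensors and record which atom
    # belongs to which structure (what torch_geometric calls `batch`)
    positions = torch.cat([s["positions"] for s in batch])
    structure_index = torch.cat([
        torch.full((len(s["positions"]),), i) for i, s in enumerate(batch)
    ])
    return {"positions": positions, "structure_index": structure_index}

dataset = AtomisticDataset([
    {"positions": torch.zeros(2, 3)},  # structure with 2 atoms
    {"positions": torch.ones(3, 3)},   # structure with 3 atoms
])
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=collate)
```

Iterating over `loader` then yields one Dict[str, torch.Tensor] per mini-batch, with no extra dependency.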

@agoscinski
Collaborator Author

I had a discussion with @abmazitov.
To support future models inheriting from torch_geometric.nn.MessagePassing, we will move to a forward API taking multiple torch.Tensor arguments (see example https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html?highlight=MessagePassing#the-messagepassing-base-class). A Dict[str, torch.Tensor] would have given us more flexibility in the input type, but I don't think we need that in the spherical expansion module.

Regarding the usefulness of torch_geometric: they have a bunch of data loader iterators that allow different walks over graphs (https://pytorch-geometric.readthedocs.io/en/latest/modules/loader.html), which might be useful in the future, and they offer somewhat nicer handling of mini-batches. Given that this is not much added complexity and it takes care of the concatenation of different structures/neighbor lists, I would give it a go.

@agoscinski
Collaborator Author

agoscinski commented Jun 7, 2023

I checked that Dict[str, torch.Tensor] also works with torch_geometric.nn.MessagePassing, but I would still go with explicit inputs because they give a nicer user interface.

```python
from typing import Optional, Dict

import torch
from torch import Tensor
from torch_geometric.nn import MessagePassing

class MyConv(MessagePassing):
    # variant 1: explicit tensor arguments in forward
    propagate_type = {'x': Tensor, 'edge_weight': Optional[Tensor]}

    def forward(self, x: Tensor, edge_index: Tensor,
                edge_weight: Optional[Tensor]) -> Tensor:
        return self.propagate(edge_index, x=x, edge_weight=edge_weight,
                              size=None)

torch.jit.script(MyConv().jittable())

class MyConv2(MessagePassing):
    # variant 2: a single Dict[str, torch.Tensor] argument
    propagate_type = {'x': Tensor, 'edge_weight': Optional[Tensor]}

    def forward(self, data: Dict[str, torch.Tensor]) -> Tensor:
        x = data['x']
        edge_index = data['edge_index']
        # annotate as Optional[Tensor] so TorchScript accepts both branches
        edge_weight: Optional[Tensor] = None
        if 'edge_weight' in data:
            edge_weight = data['edge_weight']
        return self.propagate(edge_index, x=x, edge_weight=edge_weight,
                              size=None)

torch.jit.script(MyConv2().jittable())
```

@agoscinski
Collaborator Author

So in PR #9 we took the step of supporting a neighbor list + direction vectors as input, to move out unnecessary computation that can be precomputed for training.

In the next step, as soon as the equistore Dataset and DataLoader are ready, I would like to change the API to something like

```python
def forward_eqs(self,
        direction_vectors: TensorMap,  # metadata (structure, center, neighbor, species_center, species_neighbor || cell_shift_x, cell_shift_y, cell_shift_z)
    ) -> TensorMap:
```

The position and cell information stays outside of the SphericalExpansion and can be used to retrieve the gradients:

```python
    positions: TensorMap,  # metadata (structure, center, species || x, y, z)
    cell: TensorMap,       # metadata (structure || cell_x, cell_y, cell_z)
```