Skip to content

Commit

Permalink
Implement dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Nov 30, 2023
1 parent b9f0502 commit 366f39c
Show file tree
Hide file tree
Showing 10 changed files with 1,299 additions and 4 deletions.
Empty file.
2 changes: 2 additions & 0 deletions src/metatensor_models/utils/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .dataset import Dataset, collate_fn # noqa: F401
from .readers import read_structures, read_targets # noqa: F401
76 changes: 76 additions & 0 deletions src/metatensor_models/utils/data/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from typing import Dict, List

import metatensor.torch
import rascaline.torch
import torch
from metatensor.torch import Labels, TensorMap


class Dataset(torch.utils.data.Dataset):
def __init__(
self, structures: List[rascaline.torch.System], targets: Dict[str, TensorMap]
):
"""Creates a dataset from a list of `rascaline.torch.System`s and
a list of dictionaries of `TensorMap`s."""

for tensor_map in targets.values():
n_structures = (
torch.max(tensor_map.block(0).samples["structure"]).item() + 1
)
if n_structures != len(structures):
raise ValueError(
f"Number of structures in input ({len(structures)}) and "
f"output ({n_structures}) must be the same"
)

self.structures = structures
self.targets = targets

def __len__(self):
"""
Return the total number of samples in the dataset.
"""
return len(self.structures)

def __getitem__(self, index):
"""
Generates one sample of data.
Args:
index: The index of the item in the dataset.
Returns:
A tuple containing the structure and targets for the given index.
"""
structure = self.structures[index]

structure_index_samples = Labels(
names=["structure"],
values=torch.tensor([[index]]), # must be a 2D-array
)

targets = {}
for name, tensor_map in self.targets.items():
targets[name] = metatensor.torch.slice(
tensor_map, "samples", structure_index_samples
)

return structure, targets


def collate_fn(batch):
"""
Creates a batch from a list of samples.
Args:
batch: A list of samples, where each sample is a tuple containing a
structure and targets.
Returns:
A tuple containing the structures and targets for the batch.
"""

structures = [sample[0] for sample in batch]
targets = {}
for name in batch[0][1].keys():
targets[name] = metatensor.torch.join(
[sample[1][name] for sample in batch], "samples"
)

return structures, targets
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from .structures import STRUCTURE_READERS
from .targets import TARGET_READERS

from rascaline.torch.system import Systems
from rascaline.torch.system import System


def read_structures(filename: str, fileformat: Optional[str] = None) -> List[Systems]:
def read_structures(filename: str, fileformat: Optional[str] = None) -> List[System]:
"""Reads a structure information from file."""

if fileformat is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import ase.io
from rascaline.systems import AseSystem
from rascaline.torch.system import Systems, systems_to_torch
from rascaline.torch.system import System, systems_to_torch


def read_ase(filename: str) -> List[Systems]:
def read_ase(filename: str) -> List[System]:
systems = [AseSystem(atoms) for atoms in ase.io.read(filename, ":")]

return systems_to_torch(systems)
16 changes: 16 additions & 0 deletions tests/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torch

from metatensor_models.utils.data import Dataset, collate_fn, read_structures, read_targets


def test_dataset():
"""Tests the readers and the dataset class."""

structures = read_structures("data/qm9_reduced_100.xyz")
targets = read_targets("data/qm9_reduced_100.xyz", "U0")

dataset = Dataset(structures, targets)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, collate_fn=collate_fn)

for batch in dataloader:
assert batch[1]["U0"].block().values.shape == (10, 1)
1,201 changes: 1,201 additions & 0 deletions tests/data/qm9_reduced_100.xyz

Large diffs are not rendered by default.

0 comments on commit 366f39c

Please sign in to comment.