Skip to content

Commit

Permalink
Merge branch 'main' into improve-device-error
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster authored Jul 26, 2024
2 parents ad996f5 + 81e5836 commit 864e575
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 10 deletions.
19 changes: 12 additions & 7 deletions src/metatrain/utils/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,20 @@ def calculate_composition_weights(
)
targets = targets.squeeze(dim=(1, 2)) # remove component and property dimensions

structure_list = [sample["system"] for dataset in datasets for sample in dataset]

dtype = structure_list[0].positions.dtype
total_num_structures = sum([len(dataset) for dataset in datasets])
dtype = datasets[0][0]["system"].positions.dtype
composition_features = torch.empty(
(len(structure_list), len(atomic_types)), dtype=dtype
(total_num_structures, len(atomic_types)), dtype=dtype
)
for i, structure in enumerate(structure_list):
for j, s in enumerate(atomic_types):
composition_features[i, j] = torch.sum(structure.types == s)
structure_index = 0
for dataset in datasets:
for sample in dataset:
structure = sample["system"]
for j, s in enumerate(atomic_types):
composition_features[structure_index, j] = torch.sum(
structure.types == s
)
structure_index += 1

regularizer = 1e-20
while regularizer:
Expand Down
6 changes: 3 additions & 3 deletions src/metatrain/utils/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,13 +381,13 @@ def get_atomic_types(datasets: Union[Dataset, List[Dataset]]) -> List[int]:
if not isinstance(datasets, list):
datasets = [datasets]

types = []
types = set()
for dataset in datasets:
for index in range(len(dataset)):
system = dataset[index]["system"]
types += system.types.tolist()
types.update(set(system.types.tolist()))

return sorted(set(types))
return sorted(types)


def get_all_targets(datasets: Union[Dataset, List[Dataset]]) -> List[str]:
Expand Down

0 comments on commit 864e575

Please sign in to comment.