From e14207bd847c866c13f2a1d3e5771e66506b5acd Mon Sep 17 00:00:00 2001
From: Filippo Bigi <98903385+frostedoyster@users.noreply.github.com>
Date: Mon, 22 Jul 2024 15:04:58 +0200
Subject: [PATCH] Fix bug with distributed training and fixed composition
 (#308)

---
 src/metatrain/experimental/soap_bpnn/trainer.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/metatrain/experimental/soap_bpnn/trainer.py b/src/metatrain/experimental/soap_bpnn/trainer.py
index c9e7dd004..2a2967fa6 100644
--- a/src/metatrain/experimental/soap_bpnn/trainer.py
+++ b/src/metatrain/experimental/soap_bpnn/trainer.py
@@ -103,15 +103,18 @@ def train(
                     "user-supplied composition weights"
                 )
                 cur_weight_dict = self.hypers["fixed_composition_weights"][target_name]
-                atomic_types = set()
+                atomic_types = []
                 num_species = len(cur_weight_dict)
                 fixed_weights = torch.zeros(num_species, dtype=dtype, device=device)
 
                 for ii, (key, weight) in enumerate(cur_weight_dict.items()):
-                    atomic_types.add(key)
+                    atomic_types.append(key)
                     fixed_weights[ii] = weight
 
-                if not set(atomic_types) == model.atomic_types:
+                if (
+                    not set(atomic_types)
+                    == (model.module if is_distributed else model).atomic_types
+                ):
                     raise ValueError(
                         "Supplied atomic types are not present in the dataset."
                     )
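
Note (not part of the patch): the fix works because
torch.nn.parallel.DistributedDataParallel does not forward custom Python
attributes of the wrapped module, so `model.atomic_types` raises an
AttributeError when `model` is DDP-wrapped; the attribute has to be read
from `model.module` instead. Below is a minimal, self-contained sketch of
that behaviour; the TinyModel class and the single-process gloo setup are
illustrative assumptions, not metatrain code:

    import os

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel


    class TinyModel(torch.nn.Module):
        """Stand-in for the SOAP-BPNN model, carrying a custom attribute."""

        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 1)
            # Custom attribute, analogous to `atomic_types` in metatrain
            self.atomic_types = {1, 6, 8}

        def forward(self, x):
            return self.linear(x)


    # Single-process "distributed" setup, just enough to build a DDP wrapper
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    model = DistributedDataParallel(TinyModel())
    is_distributed = True

    # `model.atomic_types` would raise AttributeError here: DDP only exposes
    # the parameters, buffers, and submodules of the wrapped model. Unwrapping
    # first, as the patched comparison does, works in both modes:
    atomic_types = (model.module if is_distributed else model).atomic_types
    print(sorted(atomic_types))  # [1, 6, 8]

    dist.destroy_process_group()

The same unwrap pattern explains the `(model.module if is_distributed else
model)` expression in the patched `if` condition above.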