From 24b94bbbeb96663136b213839b2588d74569e6da Mon Sep 17 00:00:00 2001
From: Noemi Anau Montel
Date: Sat, 30 Mar 2024 07:19:33 +0100
Subject: [PATCH] update autoregressive model (#148)

---
 swyft/lightning/estimators.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/swyft/lightning/estimators.py b/swyft/lightning/estimators.py
index 8bfb54c..8f81074 100644
--- a/swyft/lightning/estimators.py
+++ b/swyft/lightning/estimators.py
@@ -351,6 +351,7 @@ def __init__(
         dropout=0.1,
         num_blocks=2,
         hidden_features=64,
+        min_l2 = None
     ):
         super().__init__()
         self.cl1 = swyft.LogRatioEstimator_1dim(
@@ -372,6 +373,7 @@ def __init__(
             Lmax=0,
         )
         self.num_params = num_params
+        self.min_l2 = min_l2
 
     def forward(self, xA, zA, zB):
         xA, zB = swyft.equalize_tensors(xA, zB)
@@ -397,7 +399,8 @@ def forward(self, xA, zA, zB):
         l1 = logratios1.logratios.sum(-1)
         l2 = logratios2.logratios.sum(-1)
-        l2 = torch.where(l2 > 0, l2, 0)
+        if self.min_l2 is not None:
+            l2 = torch.where(l2 > self.min_l2, l2, self.min_l2)
         l = (l1 - l2).detach().unsqueeze(-1)
         logratios_tot = swyft.LogRatioSamples(l, logratios1.params, logratios1.parnames)
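
For reference, a minimal sketch of the behaviour change in the patched forward(): the summed log-ratios l2 are now only floored when min_l2 is given, instead of always being floored at 0. The standalone helper clamp_l2 below is hypothetical (not part of swyft) and the tensor values are illustrative only.

    import torch

    def clamp_l2(l2: torch.Tensor, min_l2=None) -> torch.Tensor:
        # Hypothetical helper mirroring the patched logic: apply a floor
        # only when min_l2 is supplied; otherwise pass l2 through unchanged.
        # The pre-patch code always floored at 0.
        if min_l2 is not None:
            l2 = torch.where(l2 > min_l2, l2, min_l2)
        return l2

    l2 = torch.tensor([-1.5, 0.3, 2.0])
    print(clamp_l2(l2))             # no floor: values unchanged (new default, min_l2=None)
    print(clamp_l2(l2, min_l2=0.))  # floor at 0: [0.0, 0.3, 2.0] (matches the old behaviour)

Passing min_l2=0 therefore reproduces the previous hard-coded clamp, while the new default (min_l2=None) leaves l2 untouched before computing l = l1 - l2.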