From a6f02a20b61bdf7b005edbe7de7201252751a325 Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Fri, 15 Nov 2024 15:02:17 -0900 Subject: [PATCH] dnsm-experiments-1 #39 supporting PR (#87) Adds output_dim to `single` model for matching output dimensions expected by DASM. Supporting PR for https://github.com/matsengrp/dnsm-experiments-1/pull/41 --- netam/models.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/netam/models.py b/netam/models.py index 835b5984..0f5b2854 100644 --- a/netam/models.py +++ b/netam/models.py @@ -706,18 +706,23 @@ def predict(self, representation: Tensor): class SingleValueBinarySelectionModel(AbstractBinarySelectionModel): """A one parameter selection model as a baseline.""" - def __init__(self): + def __init__(self, output_dim: int = 1): super().__init__() self.single_value = nn.Parameter(torch.tensor(0.0)) + self.output_dim = output_dim @property def hyperparameters(self): - return {} + return {"output_dim": self.output_dim} def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor: """Build a binary log selection matrix from an index-encoded parent sequence.""" - replicated_value = self.single_value.expand_as(amino_acid_indices) - return replicated_value + if self.output_dim == 1: + return self.single_value.expand(amino_acid_indices.shape) + else: + return self.single_value.expand( + amino_acid_indices.shape + (self.output_dim,) + ) class HitClassModel(nn.Module):