diff --git a/ludwig/encoders/bag_encoders.py b/ludwig/encoders/bag_encoders.py
index ffdd6394353..495e110857c 100644
--- a/ludwig/encoders/bag_encoders.py
+++ b/ludwig/encoders/bag_encoders.py
@@ -22,6 +22,7 @@
 from ludwig.constants import BAG
 from ludwig.encoders.base import Encoder
 from ludwig.encoders.registry import register_encoder
+from ludwig.encoders.types import EncoderOutputDict
 from ludwig.modules.embedding_modules import EmbedWeighted
 from ludwig.modules.fully_connected_modules import FCStack
 from ludwig.schema.encoders.bag_encoders import BagEmbedWeightedConfig
@@ -99,7 +100,7 @@ def input_shape(self) -> torch.Size:
     def output_shape(self) -> torch.Size:
         return self.fc_stack.output_shape

-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+    def forward(self, inputs: torch.Tensor) -> EncoderOutputDict:
         """
         :param inputs: The inputs fed into the encoder.
                Shape: [batch x vocab size], type torch.int32
@@ -109,4 +110,4 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         hidden = self.embed_weighted(inputs)
         hidden = self.fc_stack(hidden)

-        return hidden
+        return {"encoder_output": hidden}
diff --git a/ludwig/encoders/category_encoders.py b/ludwig/encoders/category_encoders.py
index 92a14f2b5b0..4e96b8423ca 100644
--- a/ludwig/encoders/category_encoders.py
+++ b/ludwig/encoders/category_encoders.py
@@ -23,6 +23,7 @@
 from ludwig.constants import CATEGORY
 from ludwig.encoders.base import Encoder
 from ludwig.encoders.registry import register_encoder
+from ludwig.encoders.types import EncoderOutputDict
 from ludwig.modules.embedding_modules import Embed
 from ludwig.schema.encoders.base import BaseEncoderConfig
 from ludwig.schema.encoders.category_encoders import (
@@ -45,12 +46,12 @@ def __init__(self, input_size=1, encoder_config=None, **kwargs):
         logger.debug(f" {self.name}")
         self.input_size = input_size

-    def forward(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None) -> EncoderOutputDict:
         """
         :param inputs: The inputs fed into the encoder.
                Shape: [batch x 1]
         """
-        return inputs.float()
+        return {"encoder_output": inputs.float()}

     @staticmethod
     def get_schema_cls() -> Type[BaseEncoderConfig]:
@@ -101,7 +102,7 @@ def __init__(
         )
         self.embedding_size = self.embed.embedding_size

-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+    def forward(self, inputs: torch.Tensor) -> EncoderOutputDict:
         """
         :param inputs: The inputs fed into the encoder.
                Shape: [batch x 1], type torch.int32
@@ -109,7 +110,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         :param return: embeddings of shape [batch x embed size], type torch.float32
         """
         embedded = self.embed(inputs)
-        return embedded
+        return {"encoder_output": embedded}

     @staticmethod
     def get_schema_cls() -> Type[BaseEncoderConfig]:
@@ -156,7 +157,7 @@ def __init__(
         )
         self.embedding_size = self.embed.embedding_size

-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+    def forward(self, inputs: torch.Tensor) -> EncoderOutputDict:
         """
         :param inputs: The inputs fed into the encoder.
                Shape: [batch x 1], type torch.int32
@@ -164,7 +165,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         :param return: embeddings of shape [batch x embed size], type torch.float32
         """
         embedded = self.embed(inputs)
-        return embedded
+        return {"encoder_output": embedded}

     @staticmethod
     def get_schema_cls() -> Type[BaseEncoderConfig]:
@@ -194,7 +195,7 @@ def __init__(
         logger.debug(f" {self.name}")
         self.vocab_size = len(vocab)

-    def forward(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None) -> EncoderOutputDict:
         """
         :param inputs: The inputs fed into the encoder.
                Shape: [batch, 1] or [batch]
@@ -202,7 +203,8 @@ def forward(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None) ->
         t = inputs.reshape(-1).long()
         # the output of this must be a float so that it can be concatenated with other
         # encoder outputs and passed to dense layers in the combiner, decoder, etc.
-        return torch.nn.functional.one_hot(t, num_classes=self.vocab_size).float()
+        outputs = torch.nn.functional.one_hot(t, num_classes=self.vocab_size).float()
+        return {"encoder_output": outputs}

     @staticmethod
     def get_schema_cls() -> Type[BaseEncoderConfig]:
diff --git a/ludwig/encoders/set_encoders.py b/ludwig/encoders/set_encoders.py
index cb3906befc1..19856654817 100644
--- a/ludwig/encoders/set_encoders.py
+++ b/ludwig/encoders/set_encoders.py
@@ -22,6 +22,7 @@
 from ludwig.constants import SET
 from ludwig.encoders.base import Encoder
 from ludwig.encoders.registry import register_encoder
+from ludwig.encoders.types import EncoderOutputDict
 from ludwig.modules.embedding_modules import EmbedSet
 from ludwig.modules.fully_connected_modules import FCStack
 from ludwig.schema.encoders.base import BaseEncoderConfig
@@ -89,7 +90,7 @@ def __init__(
             default_dropout=dropout,
         )

-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+    def forward(self, inputs: torch.Tensor) -> EncoderOutputDict:
         """
         Params:
             inputs: The inputs fed into the encoder.
@@ -101,7 +102,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         hidden = self.embed(inputs)
         hidden = self.fc_stack(hidden)

-        return hidden
+        return {"encoder_output": hidden}

     @staticmethod
     def get_schema_cls() -> Type[BaseEncoderConfig]:
diff --git a/ludwig/features/bag_feature.py b/ludwig/features/bag_feature.py
index 01ca5fc3abe..fcb1b7488c0 100644
--- a/ludwig/features/bag_feature.py
+++ b/ludwig/features/bag_feature.py
@@ -103,7 +103,7 @@ def forward(self, inputs):

         encoder_output = self.encoder_obj(inputs)

-        return {"encoder_output": encoder_output}
+        return encoder_output

     @property
     def input_shape(self) -> torch.Size:
diff --git a/ludwig/features/category_feature.py b/ludwig/features/category_feature.py
index 0f65a2521fb..48ed77dd724 100644
--- a/ludwig/features/category_feature.py
+++ b/ludwig/features/category_feature.py
@@ -278,12 +278,7 @@ def __init__(self, input_feature_config: CategoryInputFeatureConfig, encoder_obj

     def forward(self, inputs):
         assert isinstance(inputs, torch.Tensor)
-        assert (
-            inputs.dtype == torch.int8
-            or inputs.dtype == torch.int16
-            or inputs.dtype == torch.int32
-            or inputs.dtype == torch.int64
-        )
+        assert inputs.dtype in (torch.int8, torch.int16, torch.int32, torch.int64)
         assert len(inputs.shape) == 1 or (len(inputs.shape) == 2 and inputs.shape[1] == 1)

         inputs = inputs.reshape(-1, 1)
@@ -291,10 +286,7 @@ def forward(self, inputs):
         inputs = inputs.type(torch.int)
         encoder_output = self.encoder_obj(inputs)

-        batch_size = inputs.shape[0]
-        inputs = inputs.reshape(batch_size, -1)
-
-        return {"encoder_output": encoder_output}
+        return encoder_output

     @property
     def input_dtype(self):
diff --git a/ludwig/features/set_feature.py b/ludwig/features/set_feature.py
index 29ea510d6cc..46ae76d8cf7 100644
--- a/ludwig/features/set_feature.py
+++ b/ludwig/features/set_feature.py
@@ -214,7 +214,7 @@ def forward(self, inputs):

         encoder_output = self.encoder_obj(inputs)

-        return {"encoder_output": encoder_output}
+        return encoder_output

     @property
     def input_dtype(self):
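
The new return annotation comes from `ludwig/encoders/types.py`, which this diff imports but does not show. Below is a minimal sketch of what that module plausibly declares, assuming `EncoderOutputDict` is a `TypedDict` keyed on `"encoder_output"`; the real module may define additional keys:

```python
# Hypothetical sketch of ludwig/encoders/types.py. Only the "encoder_output"
# key is exercised by the hunks above; anything else would be a guess.
import torch
from typing_extensions import TypedDict


class EncoderOutputDict(TypedDict, total=False):
    # Primary hidden representation, consumed by combiners and decoders.
    encoder_output: torch.Tensor
```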
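As a quick smoke test of the new contract, something like the following should hold. The class name `CategoricalOneHotEncoder` and its `vocab` keyword are assumptions inferred from `self.vocab_size = len(vocab)` in the hunks above, not verified against the module:

```python
import torch

# Assumed import path and class name; the forward() body is the one-hot
# encoder shown in the category_encoders.py hunks of this diff.
from ludwig.encoders.category_encoders import CategoricalOneHotEncoder

encoder = CategoricalOneHotEncoder(vocab=["a", "b", "c"])
inputs = torch.tensor([0, 2, 1], dtype=torch.int32)

# The encoder itself now returns the output dict, so the input feature's
# forward() passes it through instead of wrapping a bare tensor.
output = encoder(inputs)
assert output["encoder_output"].shape == (3, 3)  # one-hot over the 3-item vocab
assert output["encoder_output"].dtype == torch.float32
```

This is also why the feature-side `return {"encoder_output": encoder_output}` wrappers could be dropped: the dict is now constructed once, inside the encoder.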