Make BagEncoders CategoricalEncoders and SetEncoders return dicts
dennisrall committed Jul 11, 2023
1 parent cce7d5d commit 41df76d
Showing 6 changed files with 20 additions and 24 deletions.
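
Every hunk below follows the same pattern: each encoder's forward() now wraps its output tensor in a dict under the "encoder_output" key (annotated with the EncoderOutputDict type imported from ludwig.encoders.types), and the owning input feature returns that dict unchanged instead of adding the wrapper itself. A minimal sketch of the new contract, assuming EncoderOutputDict behaves like a plain dict carrying a single "encoder_output" tensor (its actual definition is not part of this diff):

    import torch

    # Hypothetical stand-in for an encoder after this change; the real embedding /
    # FC-stack computation is reduced to a float cast for illustration.
    def encoder_forward(inputs: torch.Tensor) -> dict:
        hidden = inputs.float()
        return {"encoder_output": hidden}  # dict instead of a bare tensor

    # Hypothetical stand-in for an input feature: it no longer wraps the result,
    # it just passes the encoder's dict through.
    def feature_forward(inputs: torch.Tensor) -> dict:
        return encoder_forward(inputs.reshape(-1, 1))

    batch = torch.randint(0, 10, (4,), dtype=torch.int32)
    hidden = feature_forward(batch)["encoder_output"]  # consumers read the tensor by key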
5 changes: 3 additions & 2 deletions ludwig/encoders/bag_encoders.py
@@ -22,6 +22,7 @@
 from ludwig.constants import BAG
 from ludwig.encoders.base import Encoder
 from ludwig.encoders.registry import register_encoder
+from ludwig.encoders.types import EncoderOutputDict
 from ludwig.modules.embedding_modules import EmbedWeighted
 from ludwig.modules.fully_connected_modules import FCStack
 from ludwig.schema.encoders.bag_encoders import BagEmbedWeightedConfig
@@ -99,7 +100,7 @@ def input_shape(self) -> torch.Size:
     def output_shape(self) -> torch.Size:
         return self.fc_stack.output_shape

-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+    def forward(self, inputs: torch.Tensor) -> EncoderOutputDict:
         """
         :param inputs: The inputs fed into the encoder.
             Shape: [batch x vocab size], type torch.int32
@@ -109,4 +110,4 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         hidden = self.embed_weighted(inputs)
         hidden = self.fc_stack(hidden)

-        return hidden
+        return {"encoder_output": hidden}
18 changes: 10 additions & 8 deletions ludwig/encoders/category_encoders.py
@@ -23,6 +23,7 @@
 from ludwig.constants import CATEGORY
 from ludwig.encoders.base import Encoder
 from ludwig.encoders.registry import register_encoder
+from ludwig.encoders.types import EncoderOutputDict
 from ludwig.modules.embedding_modules import Embed
 from ludwig.schema.encoders.base import BaseEncoderConfig
 from ludwig.schema.encoders.category_encoders import (
@@ -45,12 +46,12 @@ def __init__(self, input_size=1, encoder_config=None, **kwargs):
         logger.debug(f" {self.name}")
         self.input_size = input_size

-    def forward(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None) -> EncoderOutputDict:
         """
         :param inputs: The inputs fed into the encoder.
             Shape: [batch x 1]
         """
-        return inputs.float()
+        return {"encoder_output": inputs.float()}

     @staticmethod
     def get_schema_cls() -> Type[BaseEncoderConfig]:
@@ -101,15 +102,15 @@ def __init__(
         )
         self.embedding_size = self.embed.embedding_size

-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+    def forward(self, inputs: torch.Tensor) -> EncoderOutputDict:
         """
         :param inputs: The inputs fed into the encoder.
             Shape: [batch x 1], type torch.int32

         :param return: embeddings of shape [batch x embed size], type torch.float32
         """
         embedded = self.embed(inputs)
-        return embedded
+        return {"encoder_output": embedded}

     @staticmethod
     def get_schema_cls() -> Type[BaseEncoderConfig]:
@@ -156,15 +157,15 @@ def __init__(
         )
         self.embedding_size = self.embed.embedding_size

-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+    def forward(self, inputs: torch.Tensor) -> EncoderOutputDict:
         """
         :param inputs: The inputs fed into the encoder.
             Shape: [batch x 1], type torch.int32

         :param return: embeddings of shape [batch x embed size], type torch.float32
         """
         embedded = self.embed(inputs)
-        return embedded
+        return {"encoder_output": embedded}

     @staticmethod
     def get_schema_cls() -> Type[BaseEncoderConfig]:
@@ -194,15 +195,16 @@ def __init__(
         logger.debug(f" {self.name}")
         self.vocab_size = len(vocab)

-    def forward(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None) -> EncoderOutputDict:
         """
         :param inputs: The inputs fed into the encoder.
             Shape: [batch, 1] or [batch]
         """
         t = inputs.reshape(-1).long()
         # the output of this must be a float so that it can be concatenated with other
         # encoder outputs and passed to dense layers in the combiner, decoder, etc.
-        return torch.nn.functional.one_hot(t, num_classes=self.vocab_size).float()
+        outputs = torch.nn.functional.one_hot(t, num_classes=self.vocab_size).float()
+        return {"encoder_output": outputs}

     @staticmethod
     def get_schema_cls() -> Type[BaseEncoderConfig]:
5 changes: 3 additions & 2 deletions ludwig/encoders/set_encoders.py
@@ -22,6 +22,7 @@
 from ludwig.constants import SET
 from ludwig.encoders.base import Encoder
 from ludwig.encoders.registry import register_encoder
+from ludwig.encoders.types import EncoderOutputDict
 from ludwig.modules.embedding_modules import EmbedSet
 from ludwig.modules.fully_connected_modules import FCStack
 from ludwig.schema.encoders.base import BaseEncoderConfig
@@ -89,7 +90,7 @@ def __init__(
             default_dropout=dropout,
         )

-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+    def forward(self, inputs: torch.Tensor) -> EncoderOutputDict:
         """
         Params:
             inputs: The inputs fed into the encoder.
@@ -101,7 +102,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         hidden = self.embed(inputs)
         hidden = self.fc_stack(hidden)

-        return hidden
+        return {"encoder_output": hidden}

     @staticmethod
     def get_schema_cls() -> Type[BaseEncoderConfig]:
2 changes: 1 addition & 1 deletion ludwig/features/bag_feature.py
@@ -103,7 +103,7 @@ def forward(self, inputs):

         encoder_output = self.encoder_obj(inputs)

-        return {"encoder_output": encoder_output}
+        return encoder_output

     @property
     def input_shape(self) -> torch.Size:
12 changes: 2 additions & 10 deletions ludwig/features/category_feature.py
@@ -278,23 +278,15 @@ def __init__(self, input_feature_config: CategoryInputFeatureConfig, encoder_obj

     def forward(self, inputs):
         assert isinstance(inputs, torch.Tensor)
-        assert (
-            inputs.dtype == torch.int8
-            or inputs.dtype == torch.int16
-            or inputs.dtype == torch.int32
-            or inputs.dtype == torch.int64
-        )
+        assert inputs.dtype in (torch.int8, torch.int16, torch.int32, torch.int64)
         assert len(inputs.shape) == 1 or (len(inputs.shape) == 2 and inputs.shape[1] == 1)

         inputs = inputs.reshape(-1, 1)
         if inputs.dtype == torch.int8 or inputs.dtype == torch.int16:
             inputs = inputs.type(torch.int)
         encoder_output = self.encoder_obj(inputs)

-        batch_size = inputs.shape[0]
-        inputs = inputs.reshape(batch_size, -1)
-
-        return {"encoder_output": encoder_output}
+        return encoder_output

     @property
     def input_dtype(self):
2 changes: 1 addition & 1 deletion ludwig/features/set_feature.py
@@ -214,7 +214,7 @@ def forward(self, inputs):

         encoder_output = self.encoder_obj(inputs)

-        return {"encoder_output": encoder_output}
+        return encoder_output

     @property
     def input_dtype(self):
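
Downstream code that used to receive a bare tensor from these encoders now has to index into the returned dict. A short usage sketch under the same assumption that the dict carries a single "encoder_output" tensor (the consume() helper below is illustrative, not a Ludwig API):

    import torch

    def consume(encoder_result: dict) -> torch.Tensor:
        # Unwrap the tensor by key instead of using the encoder result directly.
        hidden = encoder_result["encoder_output"]
        return hidden.mean(dim=-1)  # placeholder for the real combiner/decoder work

    example = {"encoder_output": torch.randn(4, 8)}
    print(consume(example).shape)  # torch.Size([4])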
