refactor: make BagEncoders, CategoricalEncoders and SetEncoders return dicts
dennisrall committed Jul 11, 2023
1 parent f600e76 commit 5595148
Showing 9 changed files with 24 additions and 28 deletions.
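In short, the encoder forward() methods below now return an EncoderOutputDict keyed by "encoder_output" instead of a bare tensor, and the input features and tests index into that dict. A minimal sketch of the new calling convention, adapted from tests/ludwig/encoders/test_category_encoders.py (the vocab, shapes, and constructor arguments here are illustrative, not part of the commit):

import torch

from ludwig.encoders.category_encoders import CategoricalSparseEncoder

# Illustrative vocab; for the sparse encoder, the embedding size equals the vocab size.
vocab = ["red", "green", "blue"]
encoder = CategoricalSparseEncoder(vocab=vocab, embeddings_trainable=True)

# [batch x 1] integer category ids, as in the updated tests.
inputs = torch.unsqueeze(torch.randint(len(vocab), (10,)), 1)

# forward() now returns an EncoderOutputDict rather than a bare tensor,
# so the embeddings are read from the "encoder_output" key.
outputs = encoder(inputs)
embeddings = outputs["encoder_output"]
assert embeddings.shape[-1] == len(vocab)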
5 changes: 3 additions & 2 deletions ludwig/encoders/bag_encoders.py
@@ -22,6 +22,7 @@
from ludwig.constants import BAG
from ludwig.encoders.base import Encoder
from ludwig.encoders.registry import register_encoder
+from ludwig.encoders.types import EncoderOutputDict
from ludwig.modules.embedding_modules import EmbedWeighted
from ludwig.modules.fully_connected_modules import FCStack
from ludwig.schema.encoders.bag_encoders import BagEmbedWeightedConfig
@@ -99,7 +100,7 @@ def input_shape(self) -> torch.Size:
def output_shape(self) -> torch.Size:
return self.fc_stack.output_shape

-def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+def forward(self, inputs: torch.Tensor) -> EncoderOutputDict:
"""
:param inputs: The inputs fed into the encoder.
Shape: [batch x vocab size], type torch.int32
@@ -109,4 +110,4 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
hidden = self.embed_weighted(inputs)
hidden = self.fc_stack(hidden)

-return hidden
+return {"encoder_output": hidden}
18 changes: 10 additions & 8 deletions ludwig/encoders/category_encoders.py
@@ -23,6 +23,7 @@
from ludwig.constants import CATEGORY
from ludwig.encoders.base import Encoder
from ludwig.encoders.registry import register_encoder
+from ludwig.encoders.types import EncoderOutputDict
from ludwig.modules.embedding_modules import Embed
from ludwig.schema.encoders.base import BaseEncoderConfig
from ludwig.schema.encoders.category_encoders import (
@@ -45,12 +46,12 @@ def __init__(self, input_size=1, encoder_config=None, **kwargs):
logger.debug(f" {self.name}")
self.input_size = input_size

-def forward(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+def forward(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None) -> EncoderOutputDict:
"""
:param inputs: The inputs fed into the encoder.
Shape: [batch x 1]
"""
-return inputs.float()
+return {"encoder_output": inputs.float()}

@staticmethod
def get_schema_cls() -> Type[BaseEncoderConfig]:
@@ -101,15 +102,15 @@ def __init__(
)
self.embedding_size = self.embed.embedding_size

-def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+def forward(self, inputs: torch.Tensor) -> EncoderOutputDict:
"""
:param inputs: The inputs fed into the encoder.
Shape: [batch x 1], type torch.int32
:param return: embeddings of shape [batch x embed size], type torch.float32
"""
embedded = self.embed(inputs)
-return embedded
+return {"encoder_output": embedded}

@staticmethod
def get_schema_cls() -> Type[BaseEncoderConfig]:
@@ -156,15 +157,15 @@ def __init__(
)
self.embedding_size = self.embed.embedding_size

-def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+def forward(self, inputs: torch.Tensor) -> EncoderOutputDict:
"""
:param inputs: The inputs fed into the encoder.
Shape: [batch x 1], type torch.int32
:param return: embeddings of shape [batch x embed size], type torch.float32
"""
embedded = self.embed(inputs)
-return embedded
+return {"encoder_output": embedded}

@staticmethod
def get_schema_cls() -> Type[BaseEncoderConfig]:
@@ -194,15 +195,16 @@ def __init__(
logger.debug(f" {self.name}")
self.vocab_size = len(vocab)

-def forward(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+def forward(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None) -> EncoderOutputDict:
"""
:param inputs: The inputs fed into the encoder.
Shape: [batch, 1] or [batch]
"""
t = inputs.reshape(-1).long()
# the output of this must be a float so that it can be concatenated with other
# encoder outputs and passed to dense layers in the combiner, decoder, etc.
-return torch.nn.functional.one_hot(t, num_classes=self.vocab_size).float()
+outputs = torch.nn.functional.one_hot(t, num_classes=self.vocab_size).float()
+return {"encoder_output": outputs}

@staticmethod
def get_schema_cls() -> Type[BaseEncoderConfig]:
5 changes: 3 additions & 2 deletions ludwig/encoders/set_encoders.py
@@ -22,6 +22,7 @@
from ludwig.constants import SET
from ludwig.encoders.base import Encoder
from ludwig.encoders.registry import register_encoder
+from ludwig.encoders.types import EncoderOutputDict
from ludwig.modules.embedding_modules import EmbedSet
from ludwig.modules.fully_connected_modules import FCStack
from ludwig.schema.encoders.base import BaseEncoderConfig
@@ -89,7 +90,7 @@ def __init__(
default_dropout=dropout,
)

-def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+def forward(self, inputs: torch.Tensor) -> EncoderOutputDict:
"""
Params:
inputs: The inputs fed into the encoder.
@@ -101,7 +102,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
hidden = self.embed(inputs)
hidden = self.fc_stack(hidden)

-return hidden
+return {"encoder_output": hidden}

@staticmethod
def get_schema_cls() -> Type[BaseEncoderConfig]:
2 changes: 1 addition & 1 deletion ludwig/features/bag_feature.py
@@ -103,7 +103,7 @@ def forward(self, inputs):

encoder_output = self.encoder_obj(inputs)

return {"encoder_output": encoder_output}
return encoder_output

@property
def input_shape(self) -> torch.Size:
12 changes: 2 additions & 10 deletions ludwig/features/category_feature.py
@@ -278,23 +278,15 @@ def __init__(self, input_feature_config: CategoryInputFeatureConfig, encoder_obj

def forward(self, inputs):
assert isinstance(inputs, torch.Tensor)
-assert (
-    inputs.dtype == torch.int8
-    or inputs.dtype == torch.int16
-    or inputs.dtype == torch.int32
-    or inputs.dtype == torch.int64
-)
+assert inputs.dtype in (torch.int8, torch.int16, torch.int32, torch.int64)
assert len(inputs.shape) == 1 or (len(inputs.shape) == 2 and inputs.shape[1] == 1)

inputs = inputs.reshape(-1, 1)
if inputs.dtype == torch.int8 or inputs.dtype == torch.int16:
inputs = inputs.type(torch.int)
encoder_output = self.encoder_obj(inputs)

-batch_size = inputs.shape[0]
-inputs = inputs.reshape(batch_size, -1)

return {"encoder_output": encoder_output}
return encoder_output

@property
def input_dtype(self):
2 changes: 1 addition & 1 deletion ludwig/features/set_feature.py
@@ -214,7 +214,7 @@ def forward(self, inputs):

encoder_output = self.encoder_obj(inputs)

return {"encoder_output": encoder_output}
return encoder_output

@property
def input_dtype(self):
2 changes: 1 addition & 1 deletion tests/ludwig/encoders/test_bag_encoders.py
@@ -28,7 +28,7 @@ def test_set_encoder(vocab: List[str], embedding_size: int, representation: str,
dropout=dropout,
).to(DEVICE)
inputs = torch.randint(0, 9, size=(2, len(vocab))).to(DEVICE)
-outputs = bag_encoder(inputs)
+outputs = bag_encoder(inputs)["encoder_output"]
assert outputs.shape[1:] == bag_encoder.output_shape

# check for parameter updating
4 changes: 2 additions & 2 deletions tests/ludwig/encoders/test_category_encoders.py
@@ -25,7 +25,7 @@ def test_categorical_dense_encoder(vocab: List[str], embedding_size: int, traina
).to(DEVICE)
inputs = torch.randint(len(vocab), (10,)).to(DEVICE) # Chooses 10 items from vocab with replacement.
inputs = torch.unsqueeze(inputs, 1)
-outputs = dense_encoder(inputs)
+outputs = dense_encoder(inputs)["encoder_output"]
# In dense mode, the embedding size should be less than or equal to vocab size.
assert outputs.shape[-1] == min(embedding_size, len(vocab))
# Ensures output shape matches encoder expected output shape.
@@ -52,7 +52,7 @@ def test_categorical_sparse_encoder(vocab: List[str], trainable: bool):
sparse_encoder = CategoricalSparseEncoder(vocab=vocab, embeddings_trainable=trainable).to(DEVICE)
inputs = torch.randint(len(vocab), (10,)).to(DEVICE) # Chooses 10 items from vocab with replacement.
inputs = torch.unsqueeze(inputs, 1)
-outputs = sparse_encoder(inputs)
+outputs = sparse_encoder(inputs)["encoder_output"]
# In sparse mode, embedding_size will always be equal to vocab size.
assert outputs.shape[-1] == len(vocab)
# Ensures output shape matches encoder expected output shape.
2 changes: 1 addition & 1 deletion tests/ludwig/encoders/test_set_encoders.py
@@ -33,7 +33,7 @@ def test_set_encoder(
num_fc_layers=num_fc_layers,
).to(DEVICE)
inputs = torch.randint(0, 2, size=(2, len(vocab))).bool().to(DEVICE)
-outputs = set_encoder(inputs)
+outputs = set_encoder(inputs)["encoder_output"]
assert outputs.shape[1:] == set_encoder.output_shape

# check for parameter updating
