refactor: add type hints for encoders #3449

Merged · 8 commits · Jul 13, 2023
42 changes: 21 additions & 21 deletions ludwig/combiners/combiners.py
@@ -23,7 +23,7 @@
from torch.nn import Linear, ModuleList

from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import BINARY, NUMBER
from ludwig.constants import BINARY, ENCODER_OUTPUT, NUMBER
from ludwig.encoders.registry import get_sequence_encoder_registry
from ludwig.features.base_feature import InputFeature
from ludwig.modules.attention_modules import TransformerStack
@@ -96,7 +96,7 @@ def output_shape(self) -> torch.Size:
pseudo_input = {}
for k in self.handle.input_features:
pseudo_input[k] = {
"encoder_output": torch.rand(
ENCODER_OUTPUT: torch.rand(
2, *self.handle.input_features.get(k).output_shape, dtype=self.input_dtype, device=self.device
)
}
@@ -157,7 +157,7 @@ def __init__(self, input_features: Dict[str, "InputFeature"] = None, config: Con
self.supports_masking = True

def forward(self, inputs: Dict) -> Dict: # encoder outputs
encoder_outputs = [inputs[k]["encoder_output"] for k in inputs]
encoder_outputs = [inputs[k][ENCODER_OUTPUT] for k in inputs]

# ================ Flatten ================
if self.flatten_inputs:
@@ -181,7 +181,7 @@ def forward(self, inputs: Dict) -> Dict: # encoder outputs
# TODO(Justin): Think about how to make this communication work for multi-sequence
# features. Other combiners.
for key, value in [d for d in inputs.values()][0].items():
if key != "encoder_output":
if key != ENCODER_OUTPUT:
return_data[key] = value

return return_data
@@ -233,7 +233,7 @@ def forward(self, inputs: Dict) -> Dict: # encoder outputs
# todo: when https://github.com/ludwig-ai/ludwig/issues/810 is closed
# convert following test from using shape to use explicit
# if_outputs[TYPE] values for sequence features
if len(if_outputs["encoder_output"].shape) == 3:
if len(if_outputs[ENCODER_OUTPUT].shape) == 3:
self.main_sequence_feature = if_name
break

@@ -242,7 +242,7 @@ def forward(self, inputs: Dict) -> Dict: # encoder outputs

main_sequence_feature_encoding = inputs[self.main_sequence_feature]

representation = main_sequence_feature_encoding["encoder_output"]
representation = main_sequence_feature_encoding[ENCODER_OUTPUT]
representations = [representation]

sequence_max_length = representation.shape[1]
@@ -251,7 +251,7 @@ def forward(self, inputs: Dict) -> Dict: # encoder outputs
# ================ Concat ================
for if_name, if_outputs in inputs.items():
if if_name != self.main_sequence_feature:
if_representation = if_outputs["encoder_output"]
if_representation = if_outputs[ENCODER_OUTPUT]
if len(if_representation.shape) == 3:
# The following check makes sense when
# both representations have a specified
@@ -322,7 +322,7 @@ def forward(self, inputs: Dict) -> Dict: # encoder outputs

if len(inputs) == 1:
for key, value in [d for d in inputs.values()][0].items():
if key != "encoder_output":
if key != ENCODER_OUTPUT:
return_data[key] = value

return return_data
@@ -383,9 +383,9 @@ def forward(self, inputs: Dict) -> Dict: # encoder outputs
# ================ Sequence encoding ================
hidden = self.encoder_obj(hidden["combiner_output"])

return_data = {"combiner_output": hidden["encoder_output"]}
return_data = {"combiner_output": hidden[ENCODER_OUTPUT]}
for key, value in hidden.items():
if key != "encoder_output":
if key != ENCODER_OUTPUT:
return_data[key] = value

return return_data
@@ -435,7 +435,7 @@ def forward(
self,
inputs: torch.Tensor, # encoder outputs
) -> Dict:
encoder_outputs = [inputs[k]["encoder_output"] for k in inputs]
encoder_outputs = [inputs[k][ENCODER_OUTPUT] for k in inputs]

# ================ Flatten ================
batch_size = encoder_outputs[0].shape[0]
@@ -460,7 +460,7 @@ def forward(

if len(inputs) == 1:
for key, value in [d for d in inputs.values()][0].items():
if key != "encoder_output":
if key != ENCODER_OUTPUT:
return_data[key] = value

return return_data
@@ -536,7 +536,7 @@ def forward(
self,
inputs, # encoder outputs
) -> Dict:
encoder_outputs = [inputs[k]["encoder_output"] for k in inputs]
encoder_outputs = [inputs[k][ENCODER_OUTPUT] for k in inputs]

# ================ Flatten ================
batch_size = encoder_outputs[0].shape[0]
@@ -561,7 +561,7 @@ def forward(

if len(inputs) == 1:
for key, value in [d for d in inputs.values()][0].items():
if key != "encoder_output":
if key != ENCODER_OUTPUT:
return_data[key] = value

return return_data
@@ -695,8 +695,8 @@ def forward(
self,
inputs: Dict, # encoder outputs
) -> Dict:
unembeddable_encoder_outputs = [inputs[k]["encoder_output"] for k in inputs if k in self.unembeddable_features]
embeddable_encoder_outputs = [inputs[k]["encoder_output"] for k in inputs if k in self.embeddable_features]
unembeddable_encoder_outputs = [inputs[k][ENCODER_OUTPUT] for k in inputs if k in self.unembeddable_features]
embeddable_encoder_outputs = [inputs[k][ENCODER_OUTPUT] for k in inputs if k in self.embeddable_features]

batch_size = (
embeddable_encoder_outputs[0].shape[0]
@@ -758,7 +758,7 @@ def forward(

if len(inputs) == 1:
for key, value in [d for d in inputs.values()][0].items():
if key != "encoder_output":
if key != ENCODER_OUTPUT:
return_data[key] = value

return return_data
@@ -845,7 +845,7 @@ def forward(
############
# Entity 1 #
############
e1_enc_outputs = [inputs[k]["encoder_output"] for k in self.entity_1]
e1_enc_outputs = [inputs[k][ENCODER_OUTPUT] for k in self.entity_1]

# ================ Flatten ================
batch_size = e1_enc_outputs[0].shape[0]
@@ -863,7 +863,7 @@ def forward(
############
# Entity 2 #
############
e2_enc_outputs = [inputs[k]["encoder_output"] for k in self.entity_2]
e2_enc_outputs = [inputs[k][ENCODER_OUTPUT] for k in self.entity_2]

# ================ Flatten ================
batch_size = e2_enc_outputs[0].shape[0]
@@ -960,7 +960,7 @@ def __init__(
self.supports_masking = True

def forward(self, inputs: Dict) -> Dict: # encoder outputs
encoder_outputs = [inputs[k]["encoder_output"] for k in inputs]
encoder_outputs = [inputs[k][ENCODER_OUTPUT] for k in inputs]

# ================ Flatten ================
batch_size = encoder_outputs[0].shape[0]
@@ -986,7 +986,7 @@ def forward(self, inputs: Dict) -> Dict: # encoder outputs
# TODO(Justin): Think about how to make this communication work for multi-sequence
# features. Other combiners.
for key, value in [d for d in inputs.values()][0].items():
if key != "encoder_output":
if key != ENCODER_OUTPUT:
return_data[key] = value

return return_data
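For context: the combiner edits above only replace the string literal "encoder_output" with the new ENCODER_OUTPUT constant, so the data contract between encoders and combiners is unchanged. A minimal sketch of that contract (illustrative only; the feature names and shapes are made up, not taken from this PR):

```python
import torch

from ludwig.constants import ENCODER_OUTPUT

# Each input feature's encoder yields a dict keyed by ENCODER_OUTPUT
# ("encoder_output"); combiners look up that key for every feature.
inputs = {
    "age": {ENCODER_OUTPUT: torch.rand(2, 1)},     # hypothetical number feature
    "title": {ENCODER_OUTPUT: torch.rand(2, 16)},  # hypothetical text feature
}

# Mirrors the lookup the combiners above perform before flattening/concatenating.
encoder_outputs = [inputs[k][ENCODER_OUTPUT] for k in inputs]
combined = torch.cat(encoder_outputs, dim=-1)
print(combined.shape)  # torch.Size([2, 17])
```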
1 change: 1 addition & 0 deletions ludwig/constants.py
@@ -130,6 +130,7 @@
LOGITS = "logits"
HIDDEN = "hidden"
LAST_HIDDEN = "last_hidden"
ENCODER_OUTPUT = "encoder_output"
ENCODER_OUTPUT_STATE = "encoder_output_state"
PROJECTION_INPUT = "projection_input"
LEARNING_RATE_SCHEDULER = "learning_rate_scheduler"
2 changes: 1 addition & 1 deletion ludwig/decoders/sequence_decoder_utils.py
@@ -76,7 +76,7 @@ def get_lstm_init_state(
Returns:
Tuple of 2 tensors (decoder hidden state, decoder cell state), each [num_layers, batch_size, hidden_size].
"""
if "encoder_output_state" not in combiner_outputs:
if ENCODER_OUTPUT_STATE not in combiner_outputs:
# Use the combiner's hidden state.
decoder_hidden_state = combiner_outputs[HIDDEN]
decoder_cell_state = torch.clone(decoder_hidden_state)
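The line changed here sits inside a fallback: when no encoder output state reaches the decoder, the initial LSTM hidden state is taken from the combiner's hidden output and the cell state is a clone of it. A simplified sketch of that branch (trimmed signature, no reshaping; the real helper does more):

```python
from typing import Dict, Tuple

import torch

from ludwig.constants import ENCODER_OUTPUT_STATE, HIDDEN


def lstm_init_state_sketch(combiner_outputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Simplified stand-in for the fallback branch of get_lstm_init_state."""
    if ENCODER_OUTPUT_STATE not in combiner_outputs:
        # No encoder state was passed through: seed the decoder from the
        # combiner's hidden representation and clone it for the cell state.
        decoder_hidden_state = combiner_outputs[HIDDEN]
        decoder_cell_state = torch.clone(decoder_hidden_state)
        return decoder_hidden_state, decoder_cell_state
    raise NotImplementedError("the real helper also handles the encoder-state case")
```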
12 changes: 7 additions & 5 deletions ludwig/encoders/bag_encoders.py
@@ -14,17 +14,19 @@
# limitations under the License.
# ==============================================================================
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Type

import torch

from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import BAG
from ludwig.constants import BAG, ENCODER_OUTPUT
from ludwig.encoders.base import Encoder
from ludwig.encoders.registry import register_encoder
from ludwig.encoders.types import EncoderOutputDict
from ludwig.modules.embedding_modules import EmbedWeighted
from ludwig.modules.fully_connected_modules import FCStack
from ludwig.schema.encoders.bag_encoders import BagEmbedWeightedConfig
from ludwig.schema.encoders.base import BaseEncoderConfig

logger = logging.getLogger(__name__)

@@ -87,7 +89,7 @@ def __init__(
)

@staticmethod
def get_schema_cls():
def get_schema_cls() -> Type[BaseEncoderConfig]:
return BagEmbedWeightedConfig

@property
@@ -98,7 +100,7 @@ def input_shape(self) -> torch.Size:
def output_shape(self) -> torch.Size:
return self.fc_stack.output_shape

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
def forward(self, inputs: torch.Tensor) -> EncoderOutputDict:
"""
:param inputs: The inputs fed into the encoder.
Shape: [batch x vocab size], type torch.int32
@@ -108,4 +110,4 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
hidden = self.embed_weighted(inputs)
hidden = self.fc_stack(hidden)

return hidden
return {ENCODER_OUTPUT: hidden}
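The bag encoder's forward no longer returns a bare tensor; it now returns an EncoderOutputDict keyed by ENCODER_OUTPUT. A hedged sketch of the caller side (the helper below is hypothetical, written against any post-refactor encoder rather than a specific class):

```python
import torch

from ludwig.constants import ENCODER_OUTPUT


def encode(encoder: torch.nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    # After this refactor an encoder returns a dict such as {ENCODER_OUTPUT: tensor};
    # callers unpack the tensor instead of using the return value directly.
    outputs = encoder(inputs)
    return outputs[ENCODER_OUTPUT]
```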
8 changes: 6 additions & 2 deletions ludwig/encoders/base.py
@@ -29,9 +29,13 @@ def forward(self, inputs, training=None, mask=None):
raise NotImplementedError

def get_embedding_layer(self) -> nn.Module:
"""Returns layer that embeds inputs, used for computing explanations."""
"""Returns layer that embeds inputs, used for computing explanations.

Captum adds an evaluation hook to the module returned by this function. The hook copies the module's return
with .clone(). The module returned by this function must therefore return a tensor, not a dictionary of tensors.
"""
return next(self.children())

@property
def name(self):
def name(self) -> str:
return self.__class__.__name__
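This constraint is the reason the passthrough-style category encoders below gain an nn.Identity submodule: once forward returns a dict, the encoder itself is no longer a safe hook target, while the identity layer still receives and returns the plain tensor. A minimal sketch of the pattern (a toy class, not the exact Ludwig encoders):

```python
import torch
from torch import nn


class PassthroughSketch(nn.Module):
    """Toy encoder following the identity-layer pattern used below."""

    def __init__(self) -> None:
        super().__init__()
        # Hook target for explanation tooling: it returns a tensor, not a dict.
        self.identity = nn.Identity()

    def forward(self, inputs: torch.Tensor) -> dict:
        return {"encoder_output": self.identity(inputs.float())}

    def get_embedding_layer(self) -> nn.Module:
        # Returning `self` would hand Captum a dict-producing module;
        # returning the identity layer keeps the hooked output a tensor.
        return self.identity
```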
37 changes: 21 additions & 16 deletions ludwig/encoders/category_encoders.py
@@ -14,16 +14,18 @@
# limitations under the License.
# ==============================================================================
import logging
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Type, Union

import torch
from torch import nn

from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import CATEGORY
from ludwig.constants import CATEGORY, ENCODER_OUTPUT
from ludwig.encoders.base import Encoder
from ludwig.encoders.registry import register_encoder
from ludwig.encoders.types import EncoderOutputDict
from ludwig.modules.embedding_modules import Embed
from ludwig.schema.encoders.base import BaseEncoderConfig
from ludwig.schema.encoders.category_encoders import (
CategoricalEmbedConfig,
CategoricalOneHotEncoderConfig,
@@ -43,16 +45,17 @@ def __init__(self, input_size=1, encoder_config=None, **kwargs):

logger.debug(f" {self.name}")
self.input_size = input_size
self.identity = nn.Identity()

def forward(self, inputs, mask=None):
def forward(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None) -> EncoderOutputDict:
"""
:param inputs: The inputs fed into the encoder.
Shape: [batch x 1]
"""
return inputs.float()
return {"encoder_output": self.identity(inputs.float())}

@staticmethod
def get_schema_cls():
def get_schema_cls() -> Type[BaseEncoderConfig]:
return CategoricalPassthroughEncoderConfig

@property
@@ -64,7 +67,7 @@ def output_shape(self) -> torch.Size:
return self.input_shape

def get_embedding_layer(self) -> nn.Module:
return self
return self.identity


@DeveloperAPI
@@ -100,18 +103,18 @@ def __init__(
)
self.embedding_size = self.embed.embedding_size

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
def forward(self, inputs: torch.Tensor) -> EncoderOutputDict:
"""
:param inputs: The inputs fed into the encoder.
Shape: [batch x 1], type torch.int32

:param return: embeddings of shape [batch x embed size], type torch.float32
"""
embedded = self.embed(inputs)
return embedded
return {ENCODER_OUTPUT: embedded}

@staticmethod
def get_schema_cls():
def get_schema_cls() -> Type[BaseEncoderConfig]:
return CategoricalEmbedConfig

@property
@@ -155,18 +158,18 @@ def __init__(
)
self.embedding_size = self.embed.embedding_size

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
def forward(self, inputs: torch.Tensor) -> EncoderOutputDict:
"""
:param inputs: The inputs fed into the encoder.
Shape: [batch x 1], type torch.int32

:param return: embeddings of shape [batch x embed size], type torch.float32
"""
embedded = self.embed(inputs)
return embedded
return {ENCODER_OUTPUT: embedded}

@staticmethod
def get_schema_cls():
def get_schema_cls() -> Type[BaseEncoderConfig]:
return CategoricalSparseConfig

@property
@@ -192,19 +195,21 @@ def __init__(

logger.debug(f" {self.name}")
self.vocab_size = len(vocab)
self.identity = nn.Identity()

def forward(self, inputs, mask=None):
def forward(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None) -> EncoderOutputDict:
"""
:param inputs: The inputs fed into the encoder.
Shape: [batch, 1] or [batch]
"""
t = inputs.reshape(-1).long()
# the output of this must be a float so that it can be concatenated with other
# encoder outputs and passed to dense layers in the combiner, decoder, etc.
return torch.nn.functional.one_hot(t, num_classes=self.vocab_size).float()
outputs = self.identity(torch.nn.functional.one_hot(t, num_classes=self.vocab_size).float())
return {"encoder_output": outputs}

@staticmethod
def get_schema_cls():
def get_schema_cls() -> Type[BaseEncoderConfig]:
return CategoricalOneHotEncoderConfig

@property
@@ -216,4 +221,4 @@ def output_shape(self) -> torch.Size:
return torch.Size([self.vocab_size])

def get_embedding_layer(self) -> nn.Module:
return self
return self.identity
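Taken together, the category encoder changes mean every forward now returns a dict, and the embedding layer exposed for explanations is the nn.Identity that produced the tensor inside it. A small usage sketch (the constructor call is an assumption; only the vocab argument is visible in this diff):

```python
import torch

from ludwig.constants import ENCODER_OUTPUT
from ludwig.encoders.category_encoders import CategoricalOneHotEncoder

# Assumed constructor call; the full __init__ signature is collapsed above.
encoder = CategoricalOneHotEncoder(vocab=["a", "b", "c"])

batch = torch.tensor([[0], [2]])
outputs = encoder(batch)

hidden = outputs[ENCODER_OUTPUT]  # one-hot floats of shape [2, 3]
assert hidden.shape == torch.Size([2, 3])
assert isinstance(encoder.get_embedding_layer(), torch.nn.Identity)
```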