diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index 813dc0b1076..7fab7c8e963 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -254,7 +254,7 @@ jobs:
           # remove torch and ray from the dependencies so we can add them depending on the matrix args for the job.
           cat requirements.txt | sed '/^torch[>=<\b]/d' | sed '/^torchtext/d' | sed '/^torchvision/d' | sed '/^torchaudio/d' > requirements-temp && mv requirements-temp requirements.txt
           cat requirements_distributed.txt | sed '/^ray[\[]/d'
-          pip install torch==2.0.0 torchtext torchvision torchaudio
+          pip install torch==2.1.0 torchtext torchvision torchaudio
           pip install ray==2.3.0
           pip install '.[test]'
           pip list
diff --git a/ludwig/explain/captum.py b/ludwig/explain/captum.py
index 081568e18f7..36cd0a569a9 100644
--- a/ludwig/explain/captum.py
+++ b/ludwig/explain/captum.py
@@ -216,13 +216,13 @@ def explain(self) -> ExplanationsResult:
                 feat_to_token_attributions_global[feat_name] = token_attributions_global

         self.global_explanation.add(
-            input_features.keys(), total_attribution_global, feat_to_token_attributions_global
+            list(input_features.keys()), total_attribution_global, feat_to_token_attributions_global
         )

         for i, (feature_attributions, explanation) in enumerate(zip(total_attribution, self.row_explanations)):
             # Add the feature attributions to the explanation object for this row.
             explanation.add(
-                input_features.keys(),
+                list(input_features.keys()),
                 feature_attributions,
                 {k: v[i] for k, v in feat_to_token_attributions.items()},
             )
@@ -245,7 +245,7 @@ def explain(self) -> ExplanationsResult:
             }
             # Prepend the negative class to the list of label explanations.
             self.global_explanation.add(
-                input_features.keys(), negated_attributions, negated_token_attributions, prepend=True
+                list(input_features.keys()), negated_attributions, negated_token_attributions, prepend=True
             )

             for explanation in self.row_explanations:
@@ -257,7 +257,9 @@ def explain(self) -> ExplanationsResult:
                     if fa.token_attributions is not None
                 }
                 # Prepend the negative class to the list of label explanations.
-                explanation.add(input_features.keys(), negated_attributions, negated_token_attributions, prepend=True)
+                explanation.add(
+                    list(input_features.keys()), negated_attributions, negated_token_attributions, prepend=True
+                )

         # TODO(travis): for force plots, need something similar to SHAP E[X]
         expected_values.append(0.0)
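Note: LudwigFeatureDict now implements MutableMapping (see the feature_utils.py diff below), so input_features.keys() returns a lazy view rather than a list; call sites that pass the keys onward therefore wrap them in list(). A minimal sketch of the difference, using a plain dict as a stand-in and hypothetical feature names:

    # A plain dict stands in for LudwigFeatureDict; both return views from keys().
    input_features = {"Survived": object(), "Age": object()}
    assert input_features.keys() != ["Survived", "Age"]  # a view never equals a list
    assert list(input_features.keys()) == ["Survived", "Age"]  # a materialized list does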
diff --git a/ludwig/explain/captum_ray.py b/ludwig/explain/captum_ray.py
index 21d15db3815..24e96a8c45d 100644
--- a/ludwig/explain/captum_ray.py
+++ b/ludwig/explain/captum_ray.py
@@ -115,13 +115,13 @@ def explain(self) -> ExplanationsResult:
                 feat_to_token_attributions_global[feat_name] = token_attributions_global

         self.global_explanation.add(
-            input_features.keys(), total_attribution_global, feat_to_token_attributions_global
+            list(input_features.keys()), total_attribution_global, feat_to_token_attributions_global
         )

         for i, (feature_attributions, explanation) in enumerate(zip(total_attribution, self.row_explanations)):
             # Add the feature attributions to the explanation object for this row.
             explanation.add(
-                input_features.keys(),
+                list(input_features.keys()),
                 feature_attributions,
                 {k: v[i] for k, v in feat_to_token_attributions.items()},
             )
@@ -140,7 +140,7 @@ def explain(self) -> ExplanationsResult:
             }
             # Prepend the negative class to the list of label explanations.
             self.global_explanation.add(
-                input_features.keys(), negated_attributions, negated_token_attributions, prepend=True
+                list(input_features.keys()), negated_attributions, negated_token_attributions, prepend=True
             )

             for explanation in self.row_explanations:
@@ -152,7 +152,9 @@ def explain(self) -> ExplanationsResult:
                     if fa.token_attributions is not None
                 }
                 # Prepend the negative class to the list of label explanations.
-                explanation.add(input_features.keys(), negated_attributions, negated_token_attributions, prepend=True)
+                explanation.add(
+                    list(input_features.keys()), negated_attributions, negated_token_attributions, prepend=True
+                )

         # TODO(travis): for force plots, need something similar to SHAP E[X]
         expected_values.append(0.0)
diff --git a/ludwig/explain/gbm.py b/ludwig/explain/gbm.py
index 0eb8c652ac1..395a596d2cc 100644
--- a/ludwig/explain/gbm.py
+++ b/ludwig/explain/gbm.py
@@ -55,11 +55,11 @@ def explain(self) -> ExplanationsResult:
         expected_values = []

         for _ in range(self.vocab_size):
-            self.global_explanation.add(base_model.input_features.keys(), feat_imp)
+            self.global_explanation.add(list(base_model.input_features.keys()), feat_imp)

             for explanation in self.row_explanations:
                 # Add the feature attributions to the explanation object for this row.
-                explanation.add(base_model.input_features.keys(), feat_imp)
+                explanation.add(list(base_model.input_features.keys()), feat_imp)

             # TODO:
             expected_values.append(0.0)
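The {k: v[i] for k, v in feat_to_token_attributions.items()} comprehension above picks out row i's token attributions for every feature before handing them to explanation.add(). A small sketch with hypothetical data (the real values are attribution objects, simplified here to (token, score) tuples):

    # One attribution list per row, per text feature (hypothetical shapes).
    feat_to_token_attributions = {
        "title": [[("a", 0.1)], [("b", 0.2)]],
        "body": [[("c", 0.3)], [("d", 0.4)]],
    }
    # Row 0's attributions for every feature, as passed to explanation.add().
    row_0 = {k: v[0] for k, v in feat_to_token_attributions.items()}
    assert row_0 == {"title": [("a", 0.1)], "body": [("c", 0.3)]}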
diff --git a/ludwig/features/feature_utils.py b/ludwig/features/feature_utils.py
index 50834a89cfc..ff134742a8e 100644
--- a/ludwig/features/feature_utils.py
+++ b/ludwig/features/feature_utils.py
@@ -14,7 +14,8 @@
 # limitations under the License.
 # ==============================================================================
 import re
-from typing import Dict, List, Optional, Tuple, Union
+from collections.abc import MutableMapping
+from typing import Iterator, List, Optional, Union

 import numpy as np
 import torch
@@ -157,7 +158,7 @@ def get_name_from_module_dict_key(key: str, feature_name_suffix_length: int = FE
     return name[:-feature_name_suffix_length]


-class LudwigFeatureDict(torch.nn.Module):
+class LudwigFeatureDict(torch.nn.Module, MutableMapping):
     """Torch ModuleDict wrapper that permits keys with any name.

     Torch's ModuleDict implementation doesn't allow certain keys to be used if they conflict with existing class
@@ -174,39 +175,26 @@ class LudwigFeatureDict(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.module_dict = torch.nn.ModuleDict()
-        self.internal_key_to_original_name_map = {}

-    def get(self, key) -> torch.nn.Module:
+    def __getitem__(self, key: str) -> torch.nn.Module:
         return self.module_dict[get_module_dict_key_from_name(key)]

-    def set(self, key: str, module: torch.nn.Module) -> None:
+    def __setitem__(self, key: str, value: torch.nn.Module) -> None:
         module_dict_key_name = get_module_dict_key_from_name(key)
-        self.internal_key_to_original_name_map[module_dict_key_name] = key
-        self.module_dict[module_dict_key_name] = module
+        self.module_dict[module_dict_key_name] = value

-    def __len__(self) -> int:
-        return len(self.module_dict)
-
-    def __next__(self) -> None:
-        return next(iter(self))
+    def __delitem__(self, key: str) -> None:
+        del self.module_dict[get_module_dict_key_from_name(key)]

-    def __iter__(self) -> None:
-        return iter(self.keys())
+    def __iter__(self) -> Iterator[str]:
+        return (get_name_from_module_dict_key(key) for key in self.module_dict)

-    def keys(self) -> List[str]:
-        return [
-            get_name_from_module_dict_key(feature_name)
-            for feature_name in self.internal_key_to_original_name_map.keys()
-        ]
-
-    def values(self) -> List[torch.nn.Module]:
-        return [module for _, module in self.module_dict.items()]
+    def __len__(self) -> int:
+        return len(self.module_dict)

-    def items(self) -> List[Tuple[str, torch.nn.Module]]:
-        return [
-            (get_name_from_module_dict_key(feature_name), module) for feature_name, module in self.module_dict.items()
-        ]
+    def set(self, key: str, value: torch.nn.Module) -> None:
+        self[key] = value

-    def update(self, modules: Dict[str, torch.nn.Module]) -> None:
-        for feature_name, module in modules.items():
-            self.set(feature_name, module)
+    def __hash__(self) -> int:
+        """Static hash value, because the object is mutable, but needs to be hashable for pytorch."""
+        return 1
diff --git a/ludwig/models/ecd.py b/ludwig/models/ecd.py
index fcb5450d1cf..6b0cb64e01b 100644
--- a/ludwig/models/ecd.py
+++ b/ludwig/models/ecd.py
@@ -143,7 +143,7 @@ def forward(
         else:
             targets = None

-        assert list(inputs.keys()) == self.input_features.keys()
+        assert list(inputs.keys()) == list(self.input_features.keys())

         encoder_outputs = self.encode(inputs)
         combiner_outputs = self.combine(encoder_outputs)
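For context on the rewrite above: inheriting from collections.abc.MutableMapping means only __getitem__, __setitem__, __delitem__, __iter__, and __len__ must be implemented; the ABC derives get, keys, values, items, update, pop, popitem, clear, setdefault, and equality from them. Because Mapping also sets __hash__ = None, the diff restores a constant __hash__ so PyTorch can keep hashing the module. A minimal sketch with a hypothetical TinyDict (not the Ludwig class):

    from collections.abc import MutableMapping

    class TinyDict(MutableMapping):
        # Implement the five abstract methods; everything else comes for free.
        def __init__(self):
            self._data = {}

        def __getitem__(self, key):
            return self._data[key]

        def __setitem__(self, key, value):
            self._data[key] = value

        def __delitem__(self, key):
            del self._data[key]

        def __iter__(self):
            return iter(self._data)

        def __len__(self):
            return len(self._data)

    td = TinyDict()
    td["a"] = 1
    td.update({"b": 2})              # derived from __setitem__
    assert td.pop("a") == 1          # derived from __getitem__/__delitem__
    assert list(td.keys()) == ["b"]  # derived from __iter__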
diff --git a/ludwig/models/gbm.py b/ludwig/models/gbm.py
index b249fce2cb6..a30c170025f 100644
--- a/ludwig/models/gbm.py
+++ b/ludwig/models/gbm.py
@@ -101,14 +101,14 @@ def forward(
     ) -> Dict[str, torch.Tensor]:
         # Invoke output features.
         output_logits = {}
-        output_feature_name = self.output_features.keys()[0]
+        output_feature_name = next(iter(self.output_features.keys()))
         output_feature = self.output_features.get(output_feature_name)

         # If `inputs` is a tuple, it should contain `(inputs, targets)`.
         if isinstance(inputs, tuple):
             inputs, _ = inputs

-        assert list(inputs.keys()) == self.input_features.keys()
+        assert list(inputs.keys()) == list(self.input_features.keys())

         # If the model has not been compiled, predict using the LightGBM sklearn interface. Otherwise, use torch with
         # the Hummingbird compiled model. Notably, when compiling the model to torchscript, compiling with Hummingbird
diff --git a/ludwig/models/llm.py b/ludwig/models/llm.py
index cc7cbc2efa8..acef9f8e222 100644
--- a/ludwig/models/llm.py
+++ b/ludwig/models/llm.py
@@ -65,13 +65,13 @@ def __iter__(self) -> None:
         return iter(self.obj.keys())

     def keys(self) -> List[str]:
-        return self.obj.keys()
+        return list(self.obj.keys())

     def values(self) -> List[torch.nn.Module]:
-        return self.obj.values()
+        return list(self.obj.values())

     def items(self) -> List[Tuple[str, torch.nn.Module]]:
-        return self.obj.items()
+        return list(self.obj.items())

     def update(self, modules: Dict[str, torch.nn.Module]) -> None:
         self.obj.update(modules)
@@ -148,7 +148,8 @@ def __init__(
         )

         # Extract the decoder object for the forward pass
-        self._output_feature_decoder = ModuleWrapper(self.output_features.items()[0][1])
+        decoder = next(iter(self.output_features.values()))
+        self._output_feature_decoder = ModuleWrapper(decoder)

         self.attention_masks = None

@@ -401,7 +402,7 @@ def _unpack_inputs(
         else:
             targets = None

-        assert list(inputs.keys()) == self.input_features.keys()
+        assert list(inputs.keys()) == list(self.input_features.keys())

         input_ids = self.get_input_ids(inputs)
         target_ids = self.get_target_ids(targets) if targets else None
diff --git a/ludwig/trainers/trainer_lightgbm.py b/ludwig/trainers/trainer_lightgbm.py
index 2a8174bc4f1..01cd31082e3 100644
--- a/ludwig/trainers/trainer_lightgbm.py
+++ b/ludwig/trainers/trainer_lightgbm.py
@@ -831,8 +831,8 @@ def _construct_lgb_datasets(
         validation_set: Optional["Dataset"] = None,  # noqa: F821
         test_set: Optional["Dataset"] = None,  # noqa: F821
     ) -> Tuple[lgb.Dataset, List[lgb.Dataset], List[str]]:
-        X_train = training_set.to_scalar_df(self.model.input_features.values())
-        y_train = training_set.to_scalar_df(self.model.output_features.values())
+        X_train = training_set.to_scalar_df(list(self.model.input_features.values()))
+        y_train = training_set.to_scalar_df(list(self.model.output_features.values()))

         # create dataset for lightgbm
         # keep raw data for continued training https://github.com/microsoft/LightGBM/issues/4965#issuecomment-1019344293
@@ -850,8 +850,8 @@ def _construct_lgb_datasets(
         eval_sets = [lgb_train]
         eval_names = [LightGBMTrainer.TRAIN_KEY]
         if validation_set is not None:
-            X_val = validation_set.to_scalar_df(self.model.input_features.values())
-            y_val = validation_set.to_scalar_df(self.model.output_features.values())
+            X_val = validation_set.to_scalar_df(list(self.model.input_features.values()))
+            y_val = validation_set.to_scalar_df(list(self.model.output_features.values()))
             try:
                 lgb_val = lgb.Dataset(X_val, label=y_val, reference=lgb_train, free_raw_data=False).construct()
             except lgb.basic.LightGBMError as e:
@@ -869,8 +869,8 @@ def _construct_lgb_datasets(
             pass

         if test_set is not None:
-            X_test = test_set.to_scalar_df(self.model.input_features.values())
-            y_test = test_set.to_scalar_df(self.model.output_features.values())
+            X_test = test_set.to_scalar_df(list(self.model.input_features.values()))
+            y_test = test_set.to_scalar_df(list(self.model.output_features.values()))
             try:
                 lgb_test = lgb.Dataset(X_test, label=y_test, reference=lgb_train, free_raw_data=False).construct()
             except lgb.basic.LightGBMError as e:
diff --git a/requirements.txt b/requirements.txt
index 1b2f320d387..8c58ec0875f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -67,3 +67,6 @@ datasets

 # pin required for torch 2.1.0
 urllib3<2
+
+# required for torchaudio 2.1.0
+sox
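The keys()[0] and items()[0][1] call sites above had to change because mapping views are not subscriptable; next(iter(...)) fetches the first element without materializing a list. A quick sketch with a hypothetical single-output mapping:

    output_features = {"out": "decoder"}
    try:
        output_features.keys()[0]
    except TypeError:  # 'dict_keys' object is not subscriptable
        pass
    assert next(iter(output_features.keys())) == "out"
    assert next(iter(output_features.values())) == "decoder"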
diff --git a/tests/ludwig/features/test_feature_utils.py b/tests/ludwig/features/test_feature_utils.py
index c0fd14ac9cd..d59e2dea015 100644
--- a/tests/ludwig/features/test_feature_utils.py
+++ b/tests/ludwig/features/test_feature_utils.py
@@ -5,6 +5,113 @@
 from ludwig.features import feature_utils


+@pytest.fixture
+def to_module() -> torch.nn.Module:
+    """Dummy Module to test the LudwigFeatureDict."""
+    return torch.nn.Module()
+
+
+@pytest.fixture
+def type_module() -> torch.nn.Module:
+    """Dummy Module to test the LudwigFeatureDict."""
+    return torch.nn.Module()
+
+
+@pytest.fixture
+def feature_dict(to_module: torch.nn.Module, type_module: torch.nn.Module) -> feature_utils.LudwigFeatureDict:
+    fdict = feature_utils.LudwigFeatureDict()
+    fdict.set("to", to_module)
+    fdict["type"] = type_module
+    return fdict
+
+
+def test_ludwig_feature_dict_get(
+    feature_dict: feature_utils.LudwigFeatureDict, to_module: torch.nn.Module, type_module: torch.nn.Module
+):
+    assert feature_dict["to"] == to_module
+    assert feature_dict.get("type") == type_module
+    assert feature_dict.get("other_key", default=None) is None
+
+
+def test_ludwig_feature_dict_keys(feature_dict: feature_utils.LudwigFeatureDict):
+    assert list(feature_dict.keys()) == ["to", "type"]
+
+
+def test_ludwig_feature_dict_values(
+    feature_dict: feature_utils.LudwigFeatureDict, to_module: torch.nn.Module, type_module: torch.nn.Module
+):
+    assert list(feature_dict.values()) == [to_module, type_module]
+
+
+def test_ludwig_feature_dict_items(
+    feature_dict: feature_utils.LudwigFeatureDict, to_module: torch.nn.Module, type_module: torch.nn.Module
+):
+    assert list(feature_dict.items()) == [("to", to_module), ("type", type_module)]
+
+
+def test_ludwig_feature_dict_iter(feature_dict: feature_utils.LudwigFeatureDict):
+    assert list(iter(feature_dict)) == ["to", "type"]
+    assert list(feature_dict) == ["to", "type"]
+
+
+def test_ludwig_feature_dict_len(feature_dict: feature_utils.LudwigFeatureDict):
+    assert len(feature_dict) == 2
+
+
+def test_ludwig_feature_dict_contains(feature_dict: feature_utils.LudwigFeatureDict):
+    assert "to" in feature_dict and "type" in feature_dict
+
+
+def test_ludwig_feature_dict_eq(feature_dict: feature_utils.LudwigFeatureDict):
+    other_dict = feature_utils.LudwigFeatureDict()
+    assert not feature_dict == other_dict
+    other_dict.update(feature_dict.items())
+    assert feature_dict == other_dict
+
+
+def test_ludwig_feature_dict_update(
+    feature_dict: feature_utils.LudwigFeatureDict, to_module: torch.nn.Module, type_module: torch.nn.Module
+):
+    feature_dict.update({"to": torch.nn.Module(), "new": torch.nn.Module()})
+    assert len(feature_dict) == 3
+    assert not feature_dict.get("to") == to_module
+    assert feature_dict.get("type") == type_module
+
+
+def test_ludwig_feature_dict_del(feature_dict: feature_utils.LudwigFeatureDict):
+    del feature_dict["to"]
+    assert len(feature_dict) == 1
+
+
+def test_ludwig_feature_dict_clear(feature_dict: feature_utils.LudwigFeatureDict):
+    feature_dict.clear()
+    assert len(feature_dict) == 0
+
+
+def test_ludwig_feature_dict_pop(feature_dict: feature_utils.LudwigFeatureDict, type_module: torch.nn.Module):
+    assert feature_dict.pop("type") == type_module
+    assert len(feature_dict) == 1
+    assert feature_dict.pop("type", default=None) is None
+
+
+def test_ludwig_feature_dict_popitem(feature_dict: feature_utils.LudwigFeatureDict, to_module: torch.nn.Module):
+    assert feature_dict.popitem() == ("to", to_module)
+    assert len(feature_dict) == 1
+
+
+def test_ludwig_feature_dict_setdefault(feature_dict: feature_utils.LudwigFeatureDict, to_module: torch.nn.Module):
+    assert feature_dict.setdefault("to") == to_module
+    assert feature_dict.get("other_key") is None
+
+
+@pytest.mark.parametrize("name", ["to", "type", "foo", "foo.bar"])
+def test_name_to_module_dict_key(name: str):
+    key = feature_utils.get_module_dict_key_from_name(name)
+    assert key != name
+    assert "." not in key
+    assert feature_utils.get_name_from_module_dict_key(key) == name
+
+
 def test_ludwig_feature_dict():
     feature_dict = feature_utils.LudwigFeatureDict()

@@ -15,10 +122,9 @@ def test_ludwig_feature_dict():
     feature_dict.set("type", type_module)

     assert iter(feature_dict) is not None
-    assert next(feature_dict) is not None
     assert len(feature_dict) == 2
-    assert feature_dict.keys() == ["to", "type"]
-    assert feature_dict.items() == [("to", to_module), ("type", type_module)]
+    assert list(feature_dict.keys()) == ["to", "type"]
+    assert list(feature_dict.items()) == [("to", to_module), ("type", type_module)]
     assert feature_dict.get("to"), to_module

     feature_dict.update({"to_empty": torch.nn.Module()})
@@ -34,8 +140,8 @@ def test_ludwig_feature_dict_with_periods():

     feature_dict.set("to.", to_module)

-    assert feature_dict.keys() == ["to."]
-    assert feature_dict.items() == [("to.", to_module)]
+    assert list(feature_dict.keys()) == ["to."]
+    assert list(feature_dict.items()) == [("to.", to_module)]
     assert feature_dict.get("to.") == to_module
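Taken together, the new tests exercise the full dict-style surface of LudwigFeatureDict. Assuming a Ludwig checkout where these imports resolve, usage now reads like ordinary mapping code:

    import torch
    from ludwig.features import feature_utils

    fdict = feature_utils.LudwigFeatureDict()
    fdict["to"] = torch.nn.Module()  # __setitem__ (set() is kept as an alias)
    module = fdict["to"]             # __getitem__
    for name in fdict:               # __iter__ yields the original key names
        print(name, fdict[name])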