Skip to content

Commit

Permalink
simplify __init__ files, remove some unused code (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
francoishernandez authored Jul 5, 2024
1 parent e6ef35e commit c74e495
Show file tree
Hide file tree
Showing 12 changed files with 10 additions and 213 deletions.
24 changes: 1 addition & 23 deletions eole/__init__.py
Original file line number Diff line number Diff line change
@@ -1,23 +1 @@
import eole.inputters
import eole.encoders
import eole.decoders
import eole.models
import eole.utils
import eole.modules
import sys
import eole.utils.optimizers

eole.utils.optimizers.Optim = eole.utils.optimizers.Optimizer
sys.modules["eole.Optim"] = eole.utils.optimizers

# For Flake
__all__ = [
eole.inputters,
eole.encoders,
eole.decoders,
eole.models,
eole.utils,
eole.modules,
]

__version__ = "3.5.1"
__version__ = "0.0.1"
2 changes: 1 addition & 1 deletion eole/bin/model/extract_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from eole.models import get_model_class
from eole.models.model import get_model_class
from eole.models.model_saver import load_checkpoint
from eole.inputters.inputter import dict_to_vocabs

Expand Down
2 changes: 1 addition & 1 deletion eole/bin/model/lora_weights.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from eole.utils.logging import init_logger, logger
from eole.models.model_saver import load_checkpoint
from eole.models import get_model_class
from eole.models.model import get_model_class
from eole.inputters.inputter import dict_to_vocabs, vocabs_to_dict
from eole.config import recursive_model_fields_set
from safetensors import safe_open
Expand Down
11 changes: 0 additions & 11 deletions eole/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
from collections import OrderedDict
from pydantic import Field
from eole.config.config import Config
from eole.utils.logging import logger

Expand Down Expand Up @@ -109,13 +108,3 @@ def get_non_default_values(parsed_args, defaults):
if value != defaults.get(key, None):
non_default_values[key] = value
return non_default_values


# tentative wrapper functions to lighten definitions below
def field_with_default(default, description, **kwargs):
    """Build a pydantic ``Field`` that has both a default and a description.

    Args:
        default: value passed to ``Field`` as ``default``.
        description: text passed to ``Field`` as ``description``.
        **kwargs: any further keyword arguments, forwarded to ``Field`` as is.

    Returns:
        Whatever ``Field`` returns for these arguments.
    """
    options = {"default": default, "description": description, **kwargs}
    return Field(**options)


def required_field(description, **kwargs):
    """Build a pydantic ``Field`` that only carries a description.

    Args:
        description: text passed to ``Field`` as ``description``.
        **kwargs: any further keyword arguments, forwarded to ``Field`` as is.

    Returns:
        Whatever ``Field`` returns for these arguments.
    """
    # Deliberately no ``default`` here: the field is meant to be mandatory
    # (presumably pydantic treats a Field without default as required).
    options = {"description": description, **kwargs}
    return Field(**options)
51 changes: 0 additions & 51 deletions eole/decoders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
"""Module defining decoders."""
import os
import importlib
from eole.decoders.decoder import DecoderBase
from eole.decoders.rnn_decoder import InputFeedRNNDecoder, StdRNNDecoder
from eole.decoders.transformer_decoder import TransformerDecoder
from eole.decoders.transformer_lm_decoder import TransformerLMDecoder
Expand All @@ -15,51 +12,3 @@
"transformer": TransformerDecoder,
"transformer_lm": TransformerLMDecoder,
}

__all__ = [
"DecoderBase",
"TransformerDecoder",
"StdRNNDecoder",
"CNNDecoder",
"InputFeedRNNDecoder",
"str2dec",
"TransformerLMDecoder",
]


def get_decoders_cls(decoders_names):
    """Look up the decoder class registered in ``str2dec`` for each name.

    Args:
        decoders_names: iterable of decoder names to resolve.

    Returns:
        dict: maps every requested name to its entry in ``str2dec``.

    Raises:
        ValueError: on the first name that is not a key of ``str2dec``.
    """
    selected = {}
    for dec_name in decoders_names:
        if dec_name in str2dec:
            selected[dec_name] = str2dec[dec_name]
        else:
            raise ValueError("%s decoder not supported!" % dec_name)
    return selected


def register_decoder(name):
    """Return a class decorator that registers a decoder under ``name``.

    The decorated class is stored in ``str2dec[name]`` and its class name is
    appended to this module's ``__all__``.

    Args:
        name: key under which the decoder class is stored in ``str2dec``.

    Returns:
        A decorator that takes the class, registers it, and returns it
        unchanged.

    Raises:
        ValueError: (when the decorator is applied) if ``name`` is already a
            key of ``str2dec`` or the class does not subclass ``DecoderBase``.
    """

    def register_decoder_cls(cls):
        if name in str2dec:
            raise ValueError("Cannot register duplicate decoder ({})".format(name))
        if not issubclass(cls, DecoderBase):
            # ``cls.__name__`` (two trailing underscores): the previous
            # ``cls.__name_`` raised AttributeError and hid this ValueError.
            raise ValueError(
                f"decoder ({name}: {cls.__name__}) must extend DecoderBase"
            )
        str2dec[name] = cls
        __all__.append(cls.__name__)  # added to be complete
        return cls

    return register_decoder_cls


# Import every public module and sub-directory that sits next to this file.
# Entries whose name starts with "_" or "." are ignored.
# NOTE(review): presumably this exists so that modules applying
# ``register_decoder`` get executed on package import -- confirm.
decoder_dir = os.path.dirname(__file__)
for file in os.listdir(decoder_dir):
    path = os.path.join(decoder_dir, file)
    if file.startswith("_") or file.startswith("."):
        continue
    if file.endswith(".py"):
        # Keep what precedes the first ".py" occurrence as the module name.
        file_name = file[: file.find(".py")]
    elif os.path.isdir(path):
        file_name = file
    else:
        continue
    module = importlib.import_module("eole.decoders." + file_name)
2 changes: 1 addition & 1 deletion eole/decoders/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import copy
from eole.encoders.encoder import EncoderBase
from eole.decoders.decoder import DecoderBase
from eole.models import EncoderDecoderModel, BaseModel
from eole.models.model import EncoderDecoderModel, BaseModel


class EnsembleDecoderOutput(object):
Expand Down
50 changes: 0 additions & 50 deletions eole/encoders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
"""Module defining encoders."""
import os
import importlib
from eole.encoders.encoder import EncoderBase
from eole.encoders.transformer import TransformerEncoder
from eole.encoders.rnn_encoder import RNNEncoder
from eole.encoders.cnn_encoder import CNNEncoder
Expand All @@ -15,50 +12,3 @@
"transformer": TransformerEncoder,
"mean": MeanEncoder,
}

__all__ = [
"EncoderBase",
"TransformerEncoder",
"RNNEncoder",
"CNNEncoder",
"MeanEncoder",
"str2enc",
]


def get_encoders_cls(encoder_names):
    """Look up the encoder class registered in ``str2enc`` for each name.

    Args:
        encoder_names: iterable of encoder names to resolve.

    Returns:
        dict: maps every requested name to its entry in ``str2enc``.

    Raises:
        ValueError: on the first name that is not a key of ``str2enc``.
    """
    selected = {}
    for enc_name in encoder_names:
        if enc_name in str2enc:
            selected[enc_name] = str2enc[enc_name]
        else:
            raise ValueError("%s encoder not supported!" % enc_name)
    return selected


def register_encoder(name):
    """Return a class decorator that registers an encoder under ``name``.

    The decorated class is stored in ``str2enc[name]`` and its class name is
    appended to this module's ``__all__``.

    Args:
        name: key under which the encoder class is stored in ``str2enc``.

    Returns:
        A decorator that takes the class, registers it, and returns it
        unchanged.

    Raises:
        ValueError: (when the decorator is applied) if ``name`` is already a
            key of ``str2enc`` or the class does not subclass ``EncoderBase``.
    """

    def register_encoder_cls(cls):
        if name in str2enc:
            raise ValueError("Cannot register duplicate encoder ({})".format(name))
        if not issubclass(cls, EncoderBase):
            # ``cls.__name__`` (two trailing underscores): the previous
            # ``cls.__name_`` raised AttributeError and hid this ValueError.
            raise ValueError(
                f"encoder ({name}: {cls.__name__}) must extend EncoderBase"
            )
        str2enc[name] = cls
        __all__.append(cls.__name__)  # added to be complete
        return cls

    return register_encoder_cls


# Import every public module and sub-directory that sits next to this file.
# Entries whose name starts with "_" or "." are ignored.
# NOTE(review): presumably this exists so that modules applying
# ``register_encoder`` get executed on package import -- confirm.
encoder_dir = os.path.dirname(__file__)
for file in os.listdir(encoder_dir):
    path = os.path.join(encoder_dir, file)
    if file.startswith("_") or file.startswith("."):
        continue
    if file.endswith(".py"):
        # Keep what precedes the first ".py" occurrence as the module name.
        file_name = file[: file.find(".py")]
    elif os.path.isdir(path):
        file_name = file
    else:
        continue
    module = importlib.import_module("eole.encoders." + file_name)
34 changes: 0 additions & 34 deletions eole/inputters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,34 +0,0 @@
"""Module defining inputters.
Inputters implement the logic of transforming raw data to vectorized inputs,
e.g., from a line of text to a sequence of vectors.
"""
from eole.inputters.inputter import build_vocab
from eole.inputters.text_utils import (
text_sort_key,
transform_bucket,
numericalize,
tensorify,
)
from eole.inputters.text_corpus import ParallelCorpus, ParallelCorpusIterator
from eole.inputters.dynamic_iterator import (
MixingStrategy,
SequentialMixer,
WeightedMixer,
DynamicDatasetIter,
)


__all__ = [
"build_vocab",
"text_sort_key",
"transform_bucket",
"numericalize",
"tensorify",
"ParallelCorpus",
"ParallelCorpusIterator",
"MixingStrategy",
"SequentialMixer",
"WeightedMixer",
"DynamicDatasetIter",
]
18 changes: 0 additions & 18 deletions eole/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +0,0 @@
"""Module defining models."""
from eole.models.model_saver import build_model_saver
from eole.models.model import (
BaseModel,
EncoderDecoderModel,
DecoderModel,
EncoderModel,
get_model_class,
)

__all__ = [
"build_model_saver",
"BaseModel",
"EncoderDecoderModel",
"DecoderModel",
"EncoderModel",
"get_model_class",
]
21 changes: 2 additions & 19 deletions eole/predict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,12 @@
from eole.predict.generator import GeneratorLM
from eole.predict.encoder import Encoder

from eole.predict.beam_search import BeamSearch, GNMTGlobalScorer
from eole.predict.beam_search import BeamSearchLM
from eole.predict.decode_strategy import DecodeStrategy
from eole.predict.greedy_search import GreedySearch, GreedySearchLM
from eole.predict.penalties import PenaltyBuilder
from eole.predict.beam_search import GNMTGlobalScorer
from eole.decoders.ensemble import load_test_model as ensemble_load_test_model
from eole.models import BaseModel
from eole.models.model import BaseModel
import codecs


__all__ = [
"Translator",
"BeamSearch",
"GNMTGlobalScorer",
"PenaltyBuilder",
"DecodeStrategy",
"GreedySearch",
"GreedySearchLM",
"BeamSearchLM",
"GeneratorLM",
]


def get_infer_class(model_config):
# might have more cases later
if model_config.decoder is None:
Expand Down
6 changes: 3 additions & 3 deletions eole/train_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
get_specials,
get_transforms_cls,
)
from eole.inputters import build_vocab
from eole.inputters.inputter import vocabs_to_dict # , dict_to_vocabs
from eole.inputters.inputter import vocabs_to_dict, build_vocab # , dict_to_vocabs
from eole.inputters.dynamic_iterator import build_dynamic_dataset_iter
from eole.inputters.text_corpus import save_transformed_sample
from eole.models.model_saver import load_checkpoint
from eole.utils.optimizers import Optimizer
from eole.utils.misc import set_random_seed
from eole.trainer import build_trainer
from eole.models import build_model_saver, get_model_class
from eole.models.model_saver import build_model_saver
from eole.models.model import get_model_class
from eole.modules.embeddings import prepare_pretrained_embeddings

from eole.config import (
Expand Down
2 changes: 1 addition & 1 deletion eole/utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from eole.modules.sparse_losses import SparsemaxLoss
from eole.modules.sparse_activations import LogSparsemax
from eole.constants import DefaultTokens
from eole.models import DecoderModel
from eole.models.model import DecoderModel

try:
import ctranslate2
Expand Down

0 comments on commit c74e495

Please sign in to comment.