Skip to content

Commit

Permalink
fix: warning hidden vs rnn size, priority hidden_size
Browse files Browse the repository at this point in the history
  • Loading branch information
irinaespejo committed Apr 30, 2024
1 parent d362b3e commit 2369eaf
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion src/rxn/onmt_models/scripts/rxn_onmt_train.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import warnings
from pathlib import Path
from typing import Optional, Tuple

Expand All @@ -22,6 +23,24 @@ def get_src_tgt_vocab(data: Path) -> Tuple[Path, Path]:
return src_vocab, tgt_vocab


def check_rnn_vs_hidden_size(hidden_size: Optional[int], rnn_size: int) -> int:
    """
    Decide whether ``hidden_size`` or ``rnn_size`` is used as the hidden size.

    ``rnn_size`` always carries a value (its command-line option defaults to
    ``defaults.RNN_SIZE``), whereas ``hidden_size`` is ``None`` unless it was
    given explicitly. When ``hidden_size`` is given it takes priority and
    ``rnn_size`` is ignored; otherwise ``rnn_size`` is used. A warning is
    emitted in both cases to state which value was selected.

    Args:
        hidden_size: explicitly requested hidden size, or ``None`` if not given.
            A numeric string is accepted as well, because the ``--hidden_size``
            option is declared without a type and is therefore delivered as text.
        rnn_size: fallback size, kept for backward compatibility.

    Returns:
        The size to use: ``hidden_size`` converted to ``int`` when it is given,
        else ``rnn_size`` unchanged.

    Raises:
        ValueError: if ``hidden_size`` is given but cannot be converted to ``int``.
    """
    if hidden_size is None:
        warnings.warn(
            f"Argument hidden_size is not given, rnn_size with value {rnn_size} will be used"
        )
        return rnn_size

    warnings.warn(
        f"Argument hidden_size was given with value {hidden_size}, rnn_size argument will be overwritten."
    )
    # Normalize to int so that the declared return type holds even when the
    # value arrives as a string from the untyped command-line option.
    return int(hidden_size)


@click.command(context_settings=dict(show_default=True))
@click.option("--batch_size", default=defaults.BATCH_SIZE)
@click.option(
Expand Down Expand Up @@ -51,6 +70,7 @@ def get_src_tgt_vocab(data: Path) -> Tuple[Path, Path]:
help="Directory with OpenNMT-preprocessed files",
)
@click.option("--rnn_size", default=defaults.RNN_SIZE)
@click.option("--hidden_size")
@click.option("--seed", default=defaults.SEED)
@click.option("--train_num_steps", default=100000)
@click.option("--transformer_ff", default=defaults.TRANSFORMER_FF)
Expand All @@ -68,6 +88,7 @@ def main(
no_gpu: bool,
preprocess_dir: str,
rnn_size: int,
hidden_size: int,
seed: int,
train_num_steps: int,
transformer_ff: int,
Expand All @@ -80,6 +101,9 @@ def main(
`data_weights` parameters are given (Note: needs to be consistent with the
rxn-onmt-preprocess command executed before training.
"""
# Resolve the hidden size: hidden_size takes priority over rnn_size when given
# NOTE: rnn_size argument is kept for compatibility
hidden_size = check_rnn_vs_hidden_size(hidden_size=hidden_size, rnn_size=rnn_size)

# set up paths
model_files = ModelFiles(model_output_dir)
Expand Down Expand Up @@ -110,7 +134,7 @@ def main(
keep_checkpoint=keep_checkpoint,
layers=layers,
learning_rate=learning_rate,
hidden_size=rnn_size,
hidden_size=hidden_size,
save_model=model_files.model_prefix,
seed=seed,
train_steps=train_num_steps,
Expand Down

0 comments on commit 2369eaf

Please sign in to comment.