From 2369eafdcf22700e2a27de97dd2ac63024eab556 Mon Sep 17 00:00:00 2001
From: irinaespejo
Date: Tue, 30 Apr 2024 10:44:44 +0200
Subject: [PATCH] fix: warning hidden vs rnn size, priority hidden_size

---
 src/rxn/onmt_models/scripts/rxn_onmt_train.py | 26 ++++++++++++++++++-
 1 file changed, 25 insertions(+), 1 deletion(-)

diff --git a/src/rxn/onmt_models/scripts/rxn_onmt_train.py b/src/rxn/onmt_models/scripts/rxn_onmt_train.py
index 497fdf2..0e84af5 100644
--- a/src/rxn/onmt_models/scripts/rxn_onmt_train.py
+++ b/src/rxn/onmt_models/scripts/rxn_onmt_train.py
@@ -1,4 +1,5 @@
 import logging
+import warnings
 from pathlib import Path
 from typing import Tuple
 
@@ -22,6 +23,24 @@ def get_src_tgt_vocab(data: Path) -> Tuple[Path, Path]:
     return src_vocab, tgt_vocab
 
 
+def check_rnn_vs_hidden_size(hidden_size: int, rnn_size: int) -> int:
+    """
+    Helper function that decides whether to use hidden_size or rnn_size and emits a warning accordingly.
+    rnn_size always has a default (defaults.RNN_SIZE); if hidden_size is not given, rnn_size is used.
+    If hidden_size is given, it takes priority over rnn_size.
+    """
+    if hidden_size is None:
+        warnings.warn(
+            f"Argument hidden_size is not given, rnn_size with value {rnn_size} will be used."
+        )
+        return rnn_size
+    if hidden_size is not None:
+        warnings.warn(
+            f"Argument hidden_size was given with value {hidden_size}, the rnn_size argument will be ignored."
+        )
+        return hidden_size
+
+
 @click.command(context_settings=dict(show_default=True))
 @click.option("--batch_size", default=defaults.BATCH_SIZE)
 @click.option(
@@ -51,6 +70,7 @@ def get_src_tgt_vocab(data: Path) -> Tuple[Path, Path]:
     help="Directory with OpenNMT-preprocessed files",
 )
 @click.option("--rnn_size", default=defaults.RNN_SIZE)
+@click.option("--hidden_size", type=int)
 @click.option("--seed", default=defaults.SEED)
 @click.option("--train_num_steps", default=100000)
 @click.option("--transformer_ff", default=defaults.TRANSFORMER_FF)
@@ -68,6 +88,7 @@ def main(
     no_gpu: bool,
     preprocess_dir: str,
     rnn_size: int,
+    hidden_size: int,
     seed: int,
     train_num_steps: int,
     transformer_ff: int,
@@ -80,6 +101,9 @@ def main(
     `data_weights` parameters are given (Note: needs to be consistent with the
     rxn-onmt-preprocess command executed before training.
     """
+    # Decide whether to use rnn_size or hidden_size (hidden_size takes priority)
+    # NOTE: the rnn_size argument is kept for backward compatibility
+    hidden_size = check_rnn_vs_hidden_size(hidden_size=hidden_size, rnn_size=rnn_size)
 
     # set up paths
     model_files = ModelFiles(model_output_dir)
@@ -110,7 +134,7 @@ def main(
         keep_checkpoint=keep_checkpoint,
         layers=layers,
         learning_rate=learning_rate,
-        hidden_size=rnn_size,
+        hidden_size=hidden_size,
         save_model=model_files.model_prefix,
         seed=seed,
         train_steps=train_num_steps,
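
Below is a minimal, standalone sketch of the precedence rule introduced by this patch: hidden_size wins when given, otherwise rnn_size is used. It re-implements check_rnn_vs_hidden_size outside the package so it can be run without rxn-onmt-models installed; RNN_SIZE_DEFAULT and the example sizes 512/256 are hypothetical stand-ins for defaults.RNN_SIZE and user-supplied values, and hidden_size is annotated Optional[int] here because click passes None when --hidden_size is absent.

import warnings
from typing import Optional

RNN_SIZE_DEFAULT = 384  # hypothetical stand-in for defaults.RNN_SIZE


def check_rnn_vs_hidden_size(hidden_size: Optional[int], rnn_size: int) -> int:
    """Return hidden_size when it is given, otherwise fall back to rnn_size."""
    if hidden_size is None:
        warnings.warn(
            f"Argument hidden_size is not given, rnn_size with value {rnn_size} will be used."
        )
        return rnn_size
    warnings.warn(
        f"Argument hidden_size was given with value {hidden_size}, the rnn_size argument will be ignored."
    )
    return hidden_size


if __name__ == "__main__":
    # No --hidden_size on the command line: the rnn_size value (default or user-supplied) is used.
    assert check_rnn_vs_hidden_size(hidden_size=None, rnn_size=RNN_SIZE_DEFAULT) == RNN_SIZE_DEFAULT
    # Both given, e.g. `rxn-onmt-train --rnn_size 512 --hidden_size 256 ...`: hidden_size takes priority.
    assert check_rnn_vs_hidden_size(hidden_size=256, rnn_size=512) == 256

Either branch emits a warning, which matches the patch's intent of surfacing the rnn_size/hidden_size ambiguity to the user rather than resolving it silently.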