From c8abd7c670b1a89d3b9677f66269456f66865176 Mon Sep 17 00:00:00 2001 From: Helder Lopes Date: Wed, 30 Oct 2024 11:34:05 +0000 Subject: [PATCH] fix: updates to continue training for newest OpenNMT-py --- setup.cfg | 2 +- .../scripts/rxn_onmt_continue_training.py | 13 +++++++++++++ src/rxn/onmt_models/training_files.py | 2 +- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index ddf0d5f..d34f092 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,7 +30,7 @@ install_requires = rxn-chem-utils>=1.1.4 rxn-reaction-preprocessing>=2.0.2 rxn-utils>=1.1.9 - rxn-onmt-utils @ git+https://github.com/rxn4chemistry/rxn-onmt-utils.git@3561002cb13158fb91fa5690e0d1cd406a6b0f5f #rxn-onmt-utils without rxn-opennmt-py depedency + rxn-onmt-utils @ git+https://github.com/rxn4chemistry/rxn-onmt-utils.git@0058c723c7371c6ff3b88647247c9e44cf1ffaa7 #rxn-onmt-utils without rxn-opennmt-py depedency OpenNMT-py>=3.5.1 # official onmt [options.packages.find] diff --git a/src/rxn/onmt_models/scripts/rxn_onmt_continue_training.py b/src/rxn/onmt_models/scripts/rxn_onmt_continue_training.py index 4e3373e..e993557 100644 --- a/src/rxn/onmt_models/scripts/rxn_onmt_continue_training.py +++ b/src/rxn/onmt_models/scripts/rxn_onmt_continue_training.py @@ -1,4 +1,5 @@ import logging +from pathlib import Path from typing import Optional, Tuple import click @@ -20,6 +21,12 @@ logger.addHandler(logging.NullHandler()) +def get_src_tgt_vocab(data: Path) -> Tuple[Path, Path]: + src_vocab = data.parent / (data.name + ".vocab.src") + tgt_vocab = data.parent / (data.name + ".vocab.tgt") + return src_vocab, tgt_vocab + + @click.command(context_settings=dict(show_default=True)) @click.option("--batch_size", default=defaults.BATCH_SIZE) @click.option( @@ -102,9 +109,15 @@ def main( dropout = get_model_dropout(train_from) seed = get_model_seed(train_from) + src_vocab, tgt_vocab = get_src_tgt_vocab( + data=onmt_preprocessed_files.preprocess_prefix + ) + train_cmd = OnmtTrainCommand.continue_training( batch_size=batch_size, data=onmt_preprocessed_files.preprocess_prefix, + src_vocab=src_vocab, + tgt_vocab=tgt_vocab, keep_checkpoint=keep_checkpoint, dropout=dropout, save_model=model_files.model_prefix, diff --git a/src/rxn/onmt_models/training_files.py b/src/rxn/onmt_models/training_files.py index 1157480..9eb5954 100644 --- a/src/rxn/onmt_models/training_files.py +++ b/src/rxn/onmt_models/training_files.py @@ -94,7 +94,7 @@ def preprocess_prefix(self) -> Path: @property def vocab_file(self) -> Path: - return self.preprocess_prefix.with_suffix(".vocab.pt") + return self.preprocess_prefix.with_suffix(".vocab.src") class RxnPreprocessingFiles: