Skip to content

Commit

Permalink
fix: updates to continue training for newest OpenNMT-py
Browse files Browse the repository at this point in the history
  • Loading branch information
helderlopes97 committed Oct 30, 2024
1 parent eb13000 commit c8abd7c
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 2 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ install_requires =
rxn-chem-utils>=1.1.4
rxn-reaction-preprocessing>=2.0.2
rxn-utils>=1.1.9
rxn-onmt-utils @ git+https://github.com/rxn4chemistry/rxn-onmt-utils.git@3561002cb13158fb91fa5690e0d1cd406a6b0f5f #rxn-onmt-utils without rxn-opennmt-py dependency
rxn-onmt-utils @ git+https://github.com/rxn4chemistry/rxn-onmt-utils.git@0058c723c7371c6ff3b88647247c9e44cf1ffaa7 #rxn-onmt-utils without rxn-opennmt-py dependency
OpenNMT-py>=3.5.1 # official onmt

[options.packages.find]
Expand Down
13 changes: 13 additions & 0 deletions src/rxn/onmt_models/scripts/rxn_onmt_continue_training.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from pathlib import Path
from typing import Optional, Tuple

import click
Expand All @@ -20,6 +21,12 @@
logger.addHandler(logging.NullHandler())


def get_src_tgt_vocab(data: Path) -> Tuple[Path, Path]:
    """Derive the source and target vocabulary paths from a data prefix.

    The two paths live in the same directory as ``data`` and are named after
    its final path component with ``.vocab.src`` / ``.vocab.tgt`` appended.
    The suffix is appended to the full name rather than substituted via
    ``Path.with_suffix``, so a prefix whose name already contains a dot
    (e.g. ``model.v1``) keeps that part intact.

    Args:
        data: Path prefix the vocabulary file names are built from.

    Returns:
        A ``(src_vocab, tgt_vocab)`` tuple of paths. This function only
        builds the paths; it does not check that the files exist.
    """
    directory = data.parent
    prefix_name = data.name
    return (
        directory / f"{prefix_name}.vocab.src",
        directory / f"{prefix_name}.vocab.tgt",
    )


@click.command(context_settings=dict(show_default=True))
@click.option("--batch_size", default=defaults.BATCH_SIZE)
@click.option(
Expand Down Expand Up @@ -102,9 +109,15 @@ def main(
dropout = get_model_dropout(train_from)
seed = get_model_seed(train_from)

src_vocab, tgt_vocab = get_src_tgt_vocab(
data=onmt_preprocessed_files.preprocess_prefix
)

train_cmd = OnmtTrainCommand.continue_training(
batch_size=batch_size,
data=onmt_preprocessed_files.preprocess_prefix,
src_vocab=src_vocab,
tgt_vocab=tgt_vocab,
keep_checkpoint=keep_checkpoint,
dropout=dropout,
save_model=model_files.model_prefix,
Expand Down
2 changes: 1 addition & 1 deletion src/rxn/onmt_models/training_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def preprocess_prefix(self) -> Path:

@property
def vocab_file(self) -> Path:
return self.preprocess_prefix.with_suffix(".vocab.pt")
return self.preprocess_prefix.with_suffix(".vocab.src")


class RxnPreprocessingFiles:
Expand Down

0 comments on commit c8abd7c

Please sign in to comment.