Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[Feature] Add Machine translation estimator in api #1156

Open
wants to merge 26 commits into
base: v0.x
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update index.rst
liuzh47 committed Feb 13, 2020
commit feef52e223580ae0987fe2e10e3f21718b879056
4 changes: 2 additions & 2 deletions scripts/machine_translation/index.rst
Original file line number Diff line number Diff line change
@@ -10,7 +10,7 @@ Use the following command to train the GNMT model on the IWSLT2015 dataset.

.. code-block:: console

$ MXNET_GPU_MEM_POOL_TYPE=Round python train_gnmt.py --src_lang en --tgt_lang vi --batch_size 128 \
$ MXNET_GPU_MEM_POOL_TYPE=Round python train_gnmt_estimator.py --src_lang en --tgt_lang vi --batch_size 128 \
--optimizer adam --lr 0.001 --lr_update_factor 0.5 --beam_size 10 --bucket_scheme exp \
--num_hidden 512 --save_dir gnmt_en_vi_l2_h512_beam10 --epochs 12 --gpu 0

@@ -23,7 +23,7 @@ Use the following commands to train the Transformer model on the WMT14 dataset f

.. code-block:: console

$ MXNET_GPU_MEM_POOL_TYPE=Round python train_transformer.py --dataset WMT2014BPE \
$ MXNET_GPU_MEM_POOL_TYPE=Round python train_transformer_estimator.py --dataset WMT2014BPE \
--src_lang en --tgt_lang de --batch_size 2700 \
--optimizer adam --num_accumulated 16 --lr 2.0 --warmup_steps 4000 \
--save_dir transformer_en_de_u512 --epochs 30 --gpus 0,1,2,3,4,5,6,7 --scaled \
4 changes: 1 addition & 3 deletions src/gluonnlp/estimator/__init__.py
Original file line number Diff line number Diff line change
@@ -20,8 +20,6 @@
from .machine_translation_estimator import *
from .machine_translation_event_handler import *
from .machine_translation_batch_processor import *
from .length_normalized_loss import *

__all__ = (machine_translation_estimator.__all__ + machine_translation_event_handler.__all__
+ machine_translation_batch_processor.__all__
+ length_normalized_loss.__all__)
+ machine_translation_batch_processor.__all__)