From 69299bf070174f31b8704d272a7c39c08769ca6b Mon Sep 17 00:00:00 2001
From: gzhangruipeng
Date: Sat, 14 Sep 2024 19:43:35 +0800
Subject: [PATCH] Update the data processing for the benchmark dataset

---
 examples/trans/README.md | 9 +++++++--
 examples/trans/data.py   | 4 ++--
 examples/trans/run.sh    | 2 +-
 examples/trans/train.py  | 5 ++---
 4 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/examples/trans/README.md b/examples/trans/README.md
index 43066fab2..101c9ab3b 100644
--- a/examples/trans/README.md
+++ b/examples/trans/README.md
@@ -46,7 +46,12 @@ optional arguments:
   --n_layers int        transformer model n_layers default 6
 ```
 
-run the example
+**Run the example**
+
+Step 1: Download the dataset to the cmn-eng directory.
+
+Step 2: Run the following script.
+
 ```
-python train.py --dataset cmn-2000.txt --max-epoch 100 --batch-size 32 --lr 0.01
+python train.py --dataset cmn-eng/cmn-2000.txt --max-epoch 100 --batch-size 32 --lr 0.01
 ```
diff --git a/examples/trans/data.py b/examples/trans/data.py
index 8be5157ed..1087bed7a 100644
--- a/examples/trans/data.py
+++ b/examples/trans/data.py
@@ -56,12 +56,12 @@ def __len__(self):
 
 
 class CmnDataset:
-    def __init__(self, path='cmn-eng/cmn.txt', shuffle=False, batch_size=32, train_ratio=0.8, random_seed=0):
+    def __init__(self, path, shuffle=False, batch_size=32, train_ratio=0.8, random_seed=0):
         """
             cmn dataset, download from https://www.manythings.org/anki/, contains 29909 Chinese and English translation
             pairs, the pair format: English + TAB + Chinese + TAB + Attribution
         Args:
-            path: the path of the dataset, default 'cmn-eng/cnn.txt'
+            path: the path of the dataset
            shuffle: shuffle the dataset, default False
             batch_size: the size of every batch, default 32
             train_ratio: the proportion of the training set to the total data set, default 0.8
diff --git a/examples/trans/run.sh b/examples/trans/run.sh
index 3e7e11bc0..b73559f39 100644
--- a/examples/trans/run.sh
+++ b/examples/trans/run.sh
@@ -18,4 +18,4 @@
 #
 
 # run this example
-python train.py --dataset cmn-2000.txt --max-epoch 300 --batch-size 32 --lr 0.01
\ No newline at end of file
+python train.py --dataset cmn-eng/cmn-2000.txt --max-epoch 100 --batch-size 32 --lr 0.01
\ No newline at end of file
diff --git a/examples/trans/train.py b/examples/trans/train.py
index 995312033..906ff0916 100644
--- a/examples/trans/train.py
+++ b/examples/trans/train.py
@@ -35,7 +35,7 @@ def run(args):
     np.random.seed(args.seed)
 
     batch_size = args.batch_size
-    cmn_dataset = CmnDataset(path="cmn-eng/"+args.dataset, shuffle=args.shuffle, batch_size=batch_size, train_ratio=0.8)
+    cmn_dataset = CmnDataset(path=args.dataset, shuffle=args.shuffle, batch_size=batch_size, train_ratio=0.8)
 
     print("【step-0】 prepare dataset...")
     src_vocab_size, tgt_vocab_size = cmn_dataset.en_vab_size, cmn_dataset.cn_vab_size
@@ -151,8 +151,7 @@ def run(args):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description="Training Transformer Model.")
-    parser.add_argument('--dataset', choices=['cmn.txt', 'cmn-15000.txt',
-                        'cmn-2000.txt'], default='cmn-2000.txt')
+    parser.add_argument('--dataset', default='cmn-eng/cmn-2000.txt')
     parser.add_argument('--max-epoch', default=100, type=int, help='maximum epochs.', dest='max_epoch')
     parser.add_argument('--batch-size', default=64, type=int, help='batch size', dest='batch_size')
     parser.add_argument('--shuffle', default=True, type=bool, help='shuffle the dataset', dest='shuffle')
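
Note: after this patch, `train.py` hands `--dataset` to `CmnDataset` verbatim instead of prefixing it with `"cmn-eng/"`, so the argument must now be a full relative or absolute path. A minimal usage sketch of the changed constructor call, assuming the dataset file has been downloaded to `cmn-eng/` as the updated README instructs; the `CmnDataset` signature and the `en_vab_size`/`cn_vab_size` attributes come from the diff context, while the surrounding code is illustrative only:

```python
# Illustrative sketch: mirrors the call train.py makes after this patch.
from data import CmnDataset  # examples/trans/data.py

# The path is now passed through unchanged; no "cmn-eng/" prefix is added.
dataset = CmnDataset(path="cmn-eng/cmn-2000.txt",
                     shuffle=True,
                     batch_size=32,
                     train_ratio=0.8)

# Vocabulary sizes used to size the transformer (attribute names per train.py).
print(dataset.en_vab_size, dataset.cn_vab_size)
```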