Merge pull request #1214 from gzrp/dev-postgresql-trans
Update the data processing for the benchmark dataset
chrishkchris authored Sep 15, 2024
2 parents e0ebfe6 + 69299bf commit 42c5909
Showing 4 changed files with 12 additions and 8 deletions.
9 changes: 7 additions & 2 deletions examples/trans/README.md
@@ -46,7 +46,12 @@ optional arguments:
   --n_layers int     transformer model n_layers default 6
 ```

-run the example
+**run the example**
+
+step 1: Download the dataset to the cmn-eng directory.
+
+step 2: Run the following script.
+
 ```
-python train.py --dataset cmn-2000.txt --max-epoch 100 --batch-size 32 --lr 0.01
+python train.py --dataset cmn-eng/cmn-2000.txt --max-epoch 100 --batch-size 32 --lr 0.01
 ```
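Note: step 1 above assumes cmn-eng/cmn-2000.txt already exists. The full cmn.txt can be downloaded from https://www.manythings.org/anki/ (per the data.py docstring below); a minimal sketch of one way to derive the 2000-pair subset, assuming it is simply the first 2000 lines of cmn.txt (this commit does not specify how the subset was produced):

```python
import os

def make_subset(src="cmn-eng/cmn.txt", dst="cmn-eng/cmn-2000.txt", n_pairs=2000):
    """Hypothetical helper: write the first n_pairs lines of cmn.txt to dst.

    Each line holds one translation pair (English + TAB + Chinese + TAB +
    Attribution), so truncating by lines truncates by pairs.
    """
    with open(src, encoding="utf-8") as fin, open(dst, "w", encoding="utf-8") as fout:
        for i, line in enumerate(fin):
            if i >= n_pairs:
                break
            fout.write(line)

if __name__ == "__main__":
    # Step 1 (download and unzip into cmn-eng/) must have happened already.
    assert os.path.exists("cmn-eng/cmn.txt"), "download cmn.txt into cmn-eng/ first"
    make_subset()
```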
4 changes: 2 additions & 2 deletions examples/trans/data.py
@@ -56,12 +56,12 @@ def __len__(self):


 class CmnDataset:
-    def __init__(self, path='cmn-eng/cmn.txt', shuffle=False, batch_size=32, train_ratio=0.8, random_seed=0):
+    def __init__(self, path, shuffle=False, batch_size=32, train_ratio=0.8, random_seed=0):
         """
         cmn dataset, download from https://www.manythings.org/anki/, contains 29909 Chinese and English translation
         pairs, the pair format: English + TAB + Chinese + TAB + Attribution
         Args:
-            path: the path of the dataset, default 'cmn-eng/cnn.txt'
+            path: the path of the dataset
             shuffle: shuffle the dataset, default False
             batch_size: the size of every batch, default 32
             train_ratio: the proportion of the training set to the total data set, default 0.8
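With the default path removed, CmnDataset no longer falls back to 'cmn-eng/cmn.txt'; every caller has to pass the dataset location explicitly. A minimal usage sketch, assuming it runs from examples/trans/ so that data.py is importable (en_vab_size and cn_vab_size are the attributes train.py reads below):

```python
from data import CmnDataset  # examples/trans/data.py

# path is now a required argument; omitting it raises a TypeError.
dataset = CmnDataset(path="cmn-eng/cmn-2000.txt",
                     shuffle=True,
                     batch_size=32,
                     train_ratio=0.8)

# Vocabulary sizes consumed by train.py when building the transformer.
print(dataset.en_vab_size, dataset.cn_vab_size)
```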
2 changes: 1 addition & 1 deletion examples/trans/run.sh
@@ -18,4 +18,4 @@
 #

 # run this example
-python train.py --dataset cmn-2000.txt --max-epoch 300 --batch-size 32 --lr 0.01
+python train.py --dataset cmn-eng/cmn-2000.txt --max-epoch 100 --batch-size 32 --lr 0.01
5 changes: 2 additions & 3 deletions examples/trans/train.py
@@ -35,7 +35,7 @@ def run(args):
     np.random.seed(args.seed)

     batch_size = args.batch_size
-    cmn_dataset = CmnDataset(path="cmn-eng/"+args.dataset, shuffle=args.shuffle, batch_size=batch_size, train_ratio=0.8)
+    cmn_dataset = CmnDataset(path=args.dataset, shuffle=args.shuffle, batch_size=batch_size, train_ratio=0.8)

     print("【step-0】 prepare dataset...")
     src_vocab_size, tgt_vocab_size = cmn_dataset.en_vab_size, cmn_dataset.cn_vab_size
@@ -151,8 +151,7 @@ def run(args):

 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description="Training Transformer Model.")
-    parser.add_argument('--dataset', choices=['cmn.txt', 'cmn-15000.txt',
-                                              'cmn-2000.txt'], default='cmn-2000.txt')
+    parser.add_argument('--dataset', default='cmn-eng/cmn-2000.txt')
     parser.add_argument('--max-epoch', default=100, type=int, help='maximum epochs.', dest='max_epoch')
     parser.add_argument('--batch-size', default=64, type=int, help='batch size', dest='batch_size')
     parser.add_argument('--shuffle', default=True, type=bool, help='shuffle the dataset', dest='shuffle')
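Dropping the choices=[...] list also drops argparse's built-in validation, so a mistyped --dataset path now surfaces only when CmnDataset tries to open the file. A hedged sketch of a guard a caller could add (existing_file is a hypothetical validator, not part of this commit):

```python
import argparse
import os

def existing_file(value):
    # Hypothetical replacement for the removed choices=[...] check:
    # fail at argument-parsing time if the dataset file is missing.
    if not os.path.isfile(value):
        raise argparse.ArgumentTypeError(f"dataset file not found: {value}")
    return value

parser = argparse.ArgumentParser(description="Training Transformer Model.")
parser.add_argument('--dataset', default='cmn-eng/cmn-2000.txt', type=existing_file)
args = parser.parse_args()
```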
