From 4aade13e04fee9a7614c157897a800f60b9a1d48 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 1 Apr 2021 17:10:48 -0700 Subject: [PATCH] make sure is appended --- dalle_pytorch/simple_tokenizer.py | 6 +++--- setup.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dalle_pytorch/simple_tokenizer.py b/dalle_pytorch/simple_tokenizer.py index 08b2619c..e821bf78 100644 --- a/dalle_pytorch/simple_tokenizer.py +++ b/dalle_pytorch/simple_tokenizer.py @@ -122,12 +122,12 @@ def decode(self, tokens): tokenizer = SimpleTokenizer() -def tokenize(texts, context_length = 256, add_start_and_end = False, truncate_text=False): +def tokenize(texts, context_length = 256, add_start = False, add_end = True, truncate_text = False): if isinstance(texts, str): texts = [texts] - sot_tokens = [tokenizer.encoder["<|startoftext|>"]] if add_start_and_end else [] - eot_tokens = [tokenizer.encoder["<|endoftext|>"]] if add_start_and_end else [] + sot_tokens = [tokenizer.encoder["<|startoftext|>"]] if add_start else [] + eot_tokens = [tokenizer.encoder["<|endoftext|>"]] if add_end else [] all_tokens = [sot_tokens + tokenizer.encode(text) + eot_tokens for text in texts] result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) diff --git a/setup.py b/setup.py index fb852855..a4711877 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ name = 'dalle-pytorch', packages = find_packages(), include_package_data = True, - version = '0.8.0', + version = '0.8.1', license='MIT', description = 'DALL-E - Pytorch', author = 'Phil Wang',