# prepare_dataset.py — 44 lines (32 loc) · 1.32 KB
import torch
from LatentPixel.dataprocess.preprocess import preprocess_pretrain_data
from LatentPixel.config import PretrainDatasetConfig
from LatentPixel import TGraph, DEFAULT_BINARY_RENDERING
from datasets import load_dataset, Dataset
import pandas as pd
import pyarrow as pa
# Download the two pretraining corpora and write each back to disk in 256
# shards so that downstream preprocessing can be parallelised.
book_corpus = load_dataset('lucadiliello/bookcorpusopen', split='train')
book_corpus.save_to_disk('storage/bookcorpusopen', num_shards=256)
del book_corpus  # free memory before pulling the (much larger) wiki dump

english_wiki = load_dataset('wikipedia', '20220301.en', split='train')
english_wiki.save_to_disk('storage/enwiki/', num_shards=256)
del english_wiki
# Initialise the shared binary text renderer, then run the project's
# pretraining preprocessing pipeline over both sharded corpora.
TGraph.init_render(**DEFAULT_BINARY_RENDERING)

pretrain_config = PretrainDatasetConfig(
    dataset_paths=['storage/bookcorpusopen/', 'storage/enwiki/'],
    max_len=1180,   # keep passages of at most 1180 characters
    min_len=100,    # drop passages shorter than 100 characters
    seed=42,
    shuffle=True,
    num_shards=256,
)
dataset = preprocess_pretrain_data(pretrain_config)
def add_image(sample: dict) -> dict:
    """Render a sample's text to pixels and attach the result in place.

    Adds two fields to *sample*: ``image`` (rendered pixel tensor cast to
    uint8) and ``num_text_patches`` (patch count reported by the renderer),
    then returns the mutated dict, as ``datasets.Dataset.map`` expects.
    """
    rendered = TGraph.from_text(sample['text'])
    # NOTE(review): reaches into the private ``_value`` tensor of TGraph —
    # confirm there is no public accessor for the raw pixel data.
    sample['image'] = rendered._value.to(torch.uint8)
    sample['num_text_patches'] = rendered.num_text_patches
    return sample
# Render every text sample to an image (two worker processes) and persist
# the image-augmented dataset, again split into 256 shards.
rendered_dataset = dataset.map(add_image, num_proc=2)
rendered_dataset.save_to_disk('storage/booksAndWiki2/data', num_shards=256)