feat: new online streaming dataloader and VAE dtype conf
AshishKumar4 committed Aug 12, 2024
1 parent 3cacef5 commit 6244b6a
Showing 6 changed files with 458 additions and 225 deletions.
269 changes: 103 additions & 166 deletions datasets/dataset preparations.ipynb

Large diffs are not rendered by default.

201 changes: 146 additions & 55 deletions evaluate.ipynb

Large diffs are not rendered by default.

Empty file added flaxdiff/data/__init__.py
Empty file.
205 changes: 205 additions & 0 deletions flaxdiff/data/online_loader.py
@@ -0,0 +1,205 @@
import multiprocessing
import threading
from multiprocessing import Queue
# from arrayqueues.shared_arrays import ArrayQueue
# from faster_fifo import Queue
import time
import albumentations as A
import queue
import cv2
from functools import partial
from typing import Any, Dict, List, Tuple

import numpy as np

from datasets import load_dataset, concatenate_datasets, Dataset
from datasets.utils.file_utils import get_datasets_user_agent
from concurrent.futures import ThreadPoolExecutor
import io
import urllib.request

import PIL.Image

USER_AGENT = get_datasets_user_agent()

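# Shared queues: fetched samples land in data_queue, failures in error_queue.
# Sized generously (32k entries) so fetch workers rarely block on the consumer.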
data_queue = Queue(16*2000)
error_queue = Queue(16*2000)


def fetch_single_image(image_url, timeout=None, retries=0):
for _ in range(retries + 1):
try:
request = urllib.request.Request(
image_url,
data=None,
headers={"user-agent": USER_AGENT},
)
with urllib.request.urlopen(request, timeout=timeout) as req:
image = PIL.Image.open(io.BytesIO(req.read()))
break
except Exception:
image = None
return image

def map_sample(
url, caption,
image_shape=(256, 256),
upscale_interpolation=cv2.INTER_LANCZOS4,
downscale_interpolation=cv2.INTER_AREA,
):
try:
        image = fetch_single_image(url, timeout=15, retries=3)
if image is None:
return

        image = np.array(image.convert("RGB"))  # PIL decodes to RGB; normalize any mode (L, RGBA, ...) to 3 channels
original_height, original_width = image.shape[:2]
# check if the image is too small
if min(original_height, original_width) < min(image_shape):
return
# check if wrong aspect ratio
if max(original_height, original_width) / min(original_height, original_width) > 2:
return
# check if the variance is too low
if np.std(image) < 1e-4:
return
downscale = max(original_width, original_height) > max(image_shape)
interpolation = downscale_interpolation if downscale else upscale_interpolation
image = A.longest_max_size(image, max(image_shape), interpolation=interpolation)
image = A.pad(
image,
min_height=image_shape[0],
min_width=image_shape[1],
border_mode=cv2.BORDER_CONSTANT,
value=[255, 255, 255],
)
data_queue.put({
"url": url,
"caption": caption,
"image": image
})
except Exception as e:
error_queue.put({
"url": url,
"caption": caption,
"error": str(e)
})

def map_batch(batch, num_threads=256):
    # Fan the batch's (url, caption) pairs out across a thread pool;
    # map_sample pushes results onto data_queue / error_queue itself.
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        executor.map(map_sample, batch["url"], batch["caption"])

def parallel_image_loader(dataset: Dataset, num_workers: int = 8, num_threads=256):
map_batch_fn = partial(map_batch, num_threads=num_threads)
shard_len = len(dataset) // num_workers
print(f"Local Shard lengths: {shard_len}")
with multiprocessing.Pool(num_workers) as pool:
iteration = 0
while True:
# Repeat forever
dataset = dataset.shuffle(seed=iteration)
shards = [dataset[i*shard_len:(i+1)*shard_len] for i in range(num_workers)]
pool.map(map_batch_fn, shards)
iteration += 1

class ImageBatchIterator:
def __init__(self, dataset: Dataset, batch_size: int = 64, num_workers: int = 8, num_threads=256):
self.dataset = dataset
self.num_workers = num_workers
self.batch_size = batch_size
loader = partial(parallel_image_loader, num_threads=num_threads)
        self.thread = threading.Thread(target=loader, args=(dataset, num_workers), daemon=True)
self.thread.start()

def __iter__(self):
return self

def __next__(self):
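        # Pull batch_size ready samples off the shared data_queue in parallel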
def fetcher(_):
return data_queue.get()
with ThreadPoolExecutor(max_workers=self.batch_size) as executor:
batch = list(executor.map(fetcher, range(self.batch_size)))
return batch

    def __del__(self):
        # The loader thread is a daemon and loops forever; joining it would block.
        pass

def __len__(self):
return len(self.dataset) // self.batch_size

def default_collate(batch):
urls = [sample["url"] for sample in batch]
captions = [sample["caption"] for sample in batch]
images = np.stack([sample["image"] for sample in batch], axis=0)
return {
"url": urls,
"caption": captions,
"image": images,
}

def dataMapper(map: Dict[str, Any]):
def _map(sample) -> Dict[str, Any]:
return {
"url": sample[map["url"]],
"caption": sample[map["caption"]],
}
return _map

class OnlineStreamingDataLoader():
def __init__(
self,
dataset,
batch_size=64,
num_workers=16,
num_threads=512,
default_split="all",
pre_map_maker=dataMapper,
pre_map_def={
"url": "URL",
"caption": "TEXT",
},
global_process_count=1,
global_process_index=0,
prefetch=1000,
collate_fn=default_collate,
):
if isinstance(dataset, str):
dataset_path = dataset
print("Loading dataset from path")
dataset = load_dataset(dataset_path, split=default_split)
        elif isinstance(dataset, list):
            if isinstance(dataset[0], str):
                print("Loading multiple datasets from paths")
                dataset = [load_dataset(dataset_path, split=default_split) for dataset_path in dataset]
            print("Concatenating multiple datasets")
            dataset = concatenate_datasets(dataset)
dataset = dataset.map(pre_map_maker(pre_map_def))
self.dataset = dataset.shard(num_shards=global_process_count, index=global_process_index)
print(f"Dataset length: {len(dataset)}")
self.iterator = ImageBatchIterator(self.dataset, num_workers=num_workers, batch_size=batch_size, num_threads=num_threads)
        self.collate_fn = collate_fn
        self.batch_size = batch_size  # needed by __len__

# Launch a thread to load batches in the background
self.batch_queue = queue.Queue(prefetch)

def batch_loader():
for batch in self.iterator:
self.batch_queue.put(batch)

        self.loader_thread = threading.Thread(target=batch_loader, daemon=True)
self.loader_thread.start()

def __iter__(self):
return self

def __next__(self):
return self.collate_fn(self.batch_queue.get())
# return self.collate_fn(next(self.iterator))

def __len__(self):
return len(self.dataset) // self.batch_size
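
A minimal usage sketch for the new loader (the dataset name, split, and column mapping below are illustrative assumptions, not part of this commit):

from flaxdiff.data.online_loader import OnlineStreamingDataLoader

# Any Hugging Face image-text dataset with URL and caption columns works;
# "laion/laion-coco" and the "train" split are placeholders here.
loader = OnlineStreamingDataLoader(
    "laion/laion-coco",
    batch_size=64,
    num_workers=8,
    num_threads=256,
    default_split="train",
    pre_map_def={"url": "URL", "caption": "TEXT"},
)

batch = next(iter(loader))
print(batch["image"].shape)    # (64, 256, 256, 3): resized and padded to 256x256
print(batch["caption"][:2])

Because samples stream through a shared queue as they finish downloading, batch composition is nondeterministic, and failed URLs are simply skipped (they accumulate in error_queue).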

6 changes: 3 additions & 3 deletions flaxdiff/models/autoencoder/diffusers.py
@@ -11,15 +11,15 @@
"""

class StableDiffusionVAE(AutoEncoder):
-    def __init__(self, modelname = "CompVis/stable-diffusion-v1-4"):
+    def __init__(self, modelname = "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16):

from diffusers.models.vae_flax import FlaxEncoder, FlaxDecoder
from diffusers import FlaxStableDiffusionPipeline

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
modelname,
revision="bf16",
dtype=jnp.bfloat16,
revision=revision,
dtype=dtype,
)

vae = pipeline.vae
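
With the new constructor parameters, the VAE weights' precision is configurable instead of hard-coded to bfloat16. A short sketch (the "main" revision for full precision is an assumption; check the CompVis/stable-diffusion-v1-4 repo for available revisions):

import jax.numpy as jnp
from flaxdiff.models.autoencoder.diffusers import StableDiffusionVAE

vae_bf16 = StableDiffusionVAE()  # default: "bf16" revision with jnp.bfloat16 weights
vae_fp32 = StableDiffusionVAE(revision="main", dtype=jnp.float32)  # assumed fp32 revision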
2 changes: 1 addition & 1 deletion setup.py
@@ -11,7 +11,7 @@
setup(
name='flaxdiff',
packages=find_packages(),
-    version='0.1.12',
+    version='0.1.13',
description='A versatile and easy to understand Diffusion library',
long_description=open('README.md').read(),
long_description_content_type='text/markdown',
