Commit

Added logging of throughput and GPU utilization, and a file containing code for an asynchronous data loader, but I have yet to incorporate this into `train.py`.
asaparov committed Dec 28, 2023
1 parent 10b3c58 commit 955937c
Showing 2 changed files with 126 additions and 8 deletions.
113 changes: 113 additions & 0 deletions async_dataloader.py
@@ -0,0 +1,113 @@
import re
from queue import Queue
from threading import Thread
from typing import Any, Optional, Union

import torch
from torch import Tensor
from torch.utils.data import DataLoader, Dataset


class AsynchronousLoader:
    """Class for asynchronously loading from CPU memory to device memory with a DataLoader.

    Note that this only works for single-GPU training. Multi-GPU training with DataParallel or
    DistributedDataParallel uses its own machinery for transferring data across GPUs, so this
    class could break or slow things down in those settings.

    Args:
        data: The PyTorch Dataset or DataLoader we're using to load.
        device: The PyTorch device we are loading to.
        q_size: Size of the queue used to store the data loaded to the device.
        num_batches: Number of batches to load. This must be set if the dataloader
            doesn't have a finite __len__. It will also override DataLoader.__len__
            if set and DataLoader has a __len__. Otherwise it can be left as None.
        **kwargs: Any additional arguments to pass to the dataloader if we're
            constructing one here.
    """

    def __init__(
        self,
        data: Union[DataLoader, Dataset],
        device: torch.device = torch.device("cuda", 0),
        q_size: int = 10,
        num_batches: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        if isinstance(data, torch.utils.data.DataLoader):
            self.dataloader = data
        else:
            self.dataloader = DataLoader(data, **kwargs)

        if num_batches is not None:
            self.num_batches = num_batches
        elif hasattr(self.dataloader, "__len__"):
            self.num_batches = len(self.dataloader)
        else:
            raise Exception("num_batches must be specified or data must have finite __len__")

        self.device = device
        self.q_size = q_size

        self.load_stream = torch.cuda.Stream(device=device)
        self.queue: Queue = Queue(maxsize=self.q_size)

        self.idx = 0

        self.np_str_obj_array_pattern = re.compile(r"[SaUO]")

    def load_loop(self) -> None:  # The loop that will load into the queue in the background
        for i, sample in enumerate(self.dataloader):
            self.queue.put(self.load_instance(sample))
            if i == len(self):
                break

    # Recursive loading for each instance based on torch.utils.data.default_collate
    def load_instance(self, sample: Any) -> Any:
        elem_type = type(sample)

        if torch.is_tensor(sample):
            with torch.cuda.stream(self.load_stream):
                # Can only do asynchronous transfer if we use pin_memory
                if not sample.is_pinned():
                    sample = sample.pin_memory()
                return sample.to(self.device, non_blocking=True)
        elif elem_type.__module__ == "numpy" and elem_type.__name__ != "str_" and elem_type.__name__ != "string_":
            if elem_type.__name__ == "ndarray" and self.np_str_obj_array_pattern.search(sample.dtype.str) is not None:
                # Arrays of strings or objects cannot be converted to tensors, so leave them on the CPU
                return sample
            return self.load_instance(torch.as_tensor(sample))
        elif isinstance(sample, tuple) and hasattr(sample, "_fields"):  # namedtuple
            return elem_type(*(self.load_instance(d) for d in sample))
        else:
            return sample

    def __iter__(self) -> "AsynchronousLoader":
        # We don't want to run the thread more than once.
        # Start a new thread if we are at the beginning of a new epoch and our current worker is dead.
        if_worker = not hasattr(self, "worker") or not self.worker.is_alive()  # type: ignore[has-type]
        if if_worker and self.queue.empty() and self.idx == 0:
            self.worker = Thread(target=self.load_loop)
            self.worker.daemon = True
            self.worker.start()
        return self

    def __next__(self) -> Tensor:
        # If we've reached the number of batches to return,
        # or the queue is empty and the worker is dead, then exit.
        done = not self.worker.is_alive() and self.queue.empty()
        done = done or self.idx >= len(self)
        if done:
            self.idx = 0
            self.queue.join()
            self.worker.join()
            raise StopIteration
        # Otherwise return the next batch
        out = self.queue.get()
        self.queue.task_done()
        self.idx += 1
        return out

    def __len__(self) -> int:
        return self.num_batches
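
The commit message notes that this loader has not yet been incorporated into `train.py`, so nothing in this commit exercises it. Below is a minimal, hypothetical usage sketch, not part of the commit: `ToyDataset` and the hyperparameters are made up for illustration, and a CUDA device is assumed to be available. With `load_instance` as written above, only bare tensors (and numpy arrays or namedtuples of them) are moved to the device, so each dataset item here is a single tensor.

import torch
from torch.utils.data import Dataset
from async_dataloader import AsynchronousLoader

# Hypothetical usage (not part of this commit): a toy CPU dataset whose items are bare
# tensors, wrapped in AsynchronousLoader so host-to-device copies run on a side CUDA
# stream while the model computes.
class ToyDataset(Dataset):
    def __init__(self, n: int = 4096, dim: int = 16) -> None:
        self.data = torch.randn(n, dim)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, i: int) -> torch.Tensor:
        return self.data[i]

loader = AsynchronousLoader(
    ToyDataset(),
    device=torch.device("cuda", 0),
    q_size=10,          # number of prefetched batches buffered on the device
    batch_size=512,     # extra kwargs are forwarded to the internal DataLoader
    shuffle=True,
)

for batch in loader:    # each batch tensor is already on cuda:0 when it leaves the queue
    pass                # training step would go here

Keeping `q_size` small bounds how much device memory the prefetched batches occupy; `load_instance` pins each tensor before the copy, so the transfer can be issued with `non_blocking=True` on the dedicated `load_stream`.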
21 changes: 13 additions & 8 deletions train.py
@@ -1,5 +1,5 @@
from random import sample, randrange, choice, shuffle, seed, uniform, getstate, setstate, Random
-from os import listdir, makedirs, rename, remove
+from os import listdir, makedirs, rename, remove, popen
from os.path import isfile, isdir
from sys import stdout
import pickle
@@ -9,7 +9,7 @@
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.utils.data import Dataset, DataLoader
-from gpt2 import Transformer
+from gpt2 import Transformer, ToeplitzMode
from Sophia import SophiaG

RESERVED_INDICES = (0,)
@@ -525,7 +525,7 @@ def train(max_input_size, dataset_size, max_lookahead, seed_value, nlayers, hidd
else:
device = torch.device('cuda')

-BATCH_SIZE = 4096
+BATCH_SIZE = 512
if dataset_size != -1:
train_data = DummyDataset(inputs, outputs, device)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
@@ -601,8 +601,8 @@ def train(max_input_size, dataset_size, max_lookahead, seed_value, nlayers, hidd
# we are doing streaming training, so use an IterableDataset
from itertools import cycle
from threading import Lock
-STREAMING_BLOCK_SIZE = 2 ** 18
-NUM_DATA_WORKERS = 32
+STREAMING_BLOCK_SIZE = 2 ** 16
+NUM_DATA_WORKERS = 8
seed_generator = Random(seed_value)
seed_generator_lock = Lock()
seed_values = []
@@ -653,10 +653,11 @@ def __iter__(self):
train_loader = DataLoader(iterable_dataset, batch_size=None, num_workers=NUM_DATA_WORKERS, pin_memory=True)

while True:
-#import time
-#time1 = time.perf_counter()
+import time
+start_time = time.perf_counter()
epoch_loss = 0.0
num_batches = 0
+effective_dataset_size = (STREAMING_BLOCK_SIZE if dataset_size == -1 else dataset_size)
for batch in cycle(train_loader):
model.train()
optimizer.zero_grad()
@@ -694,7 +695,7 @@ def compute_toeplitz_regularization(m):
optimizer.step()

num_batches += 1
-if (dataset_size == -1 and num_batches == STREAMING_BLOCK_SIZE // BATCH_SIZE) or (dataset_size >= 0 and num_batches == dataset_size // BATCH_SIZE):
+if num_batches == effective_dataset_size // BATCH_SIZE:
#time4 = time.perf_counter()
#print('[MAIN] Time to train: {}s'.format(time4 - time3))
#stdout.flush()
@@ -711,8 +712,12 @@ def compute_toeplitz_regularization(m):
#stdout.flush()

if epoch % log_interval == 0:
+elapsed_time = time.perf_counter() - start_time
print("epoch = {}, training loss = {}".format(epoch, epoch_loss))
+utilization = popen('nvidia-smi --query-gpu=utilization.gpu --format=csv').read().split('\n')[1]
+print("throughput = {} examples/s, GPU utilization = {}".format(effective_dataset_size / elapsed_time, utilization))
stdout.flush()
+start_time = time.perf_counter()

if epoch % eval_interval == 0:
model.eval()
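
A note on the new logging: with the defaults in this commit, STREAMING_BLOCK_SIZE = 2 ** 16 = 65536 and BATCH_SIZE = 512, so each pass of the streaming loop covers 65536 / 512 = 128 batches, and the reported throughput divides `effective_dataset_size` (one pass's worth of examples) by the wall-clock time since `start_time` was last reset. The GPU-utilization figure shells out to `nvidia-smi` with `--format=csv` and takes the second output line, since the first line is the CSV header. The snippet below is a hypothetical alternative, not part of this commit, that asks `nvidia-smi` to omit the header and units so no line-splitting is needed; it assumes `nvidia-smi` is on the PATH and a single GPU is queried.

import subprocess

# Hypothetical helper (not part of this commit): same query as the popen() call above, but
# 'csv,noheader,nounits' makes nvidia-smi print just a bare number (e.g. "37"), so there is
# no header line or ' %' suffix to strip.
def gpu_utilization_percent(gpu_index: int = 0) -> int:
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=utilization.gpu",
         "--format=csv,noheader,nounits", "-i", str(gpu_index)],
        capture_output=True, text=True, check=True,
    )
    return int(result.stdout.strip())

Either way, the reading is a point sample of utilization taken at logging time, not an average over the whole pass.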
Expand Down

