Fix periodically stuck in parallel datamanager (#2713)
* Fix sticking in parallel datamanager
liruilong940607 authored Jan 5, 2024
1 parent a8e6f8f commit 98a2126
Showing 2 changed files with 14 additions and 38 deletions.
49 changes: 12 additions & 37 deletions nerfstudio/data/datamanagers/parallel_datamanager.py
@@ -22,41 +22,24 @@
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import (
-    Dict,
-    Generic,
-    List,
-    Literal,
-    Optional,
-    Tuple,
-    Type,
-    Union,
-)
+from typing import Dict, Generic, List, Literal, Optional, Tuple, Type, Union

 import torch
-import torch.multiprocessing as mp
+from pathos.helpers import mp
 from rich.progress import track
 from torch.nn import Parameter

 from nerfstudio.cameras.cameras import Cameras, CameraType
 from nerfstudio.cameras.rays import RayBundle
 from nerfstudio.data.datamanagers.base_datamanager import (
     DataManager,
-    VanillaDataManagerConfig,
     TDataset,
+    VanillaDataManagerConfig,
     variable_res_collate,
 )
 from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs
-from nerfstudio.data.pixel_samplers import (
-    PixelSampler,
-    PixelSamplerConfig,
-    PatchPixelSamplerConfig,
-)
-from nerfstudio.data.utils.dataloaders import (
-    CacheDataloader,
-    FixedIndicesEvalDataloader,
-    RandIndicesEvalDataloader,
-)
+from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig
+from nerfstudio.data.utils.dataloaders import CacheDataloader, FixedIndicesEvalDataloader, RandIndicesEvalDataloader
 from nerfstudio.model_components.ray_generators import RayGenerator
 from nerfstudio.utils.rich_utils import CONSOLE

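Note on the import swap above: `torch.multiprocessing` is replaced by the `mp` helper that pathos re-exports, which (to my understanding) is the `multiprocess` fork of the stdlib module, pickling with dill instead of the stdlib pickler. A minimal sketch of the drop-in usage, assuming pathos is installed; the queue and worker names are illustrative only:

```python
# Minimal sketch of the swapped-in module, assuming pathos is installed.
# pathos.helpers exposes `mp`, a dill-backed fork of multiprocessing
# (the `multiprocess` package) with the same Process/Queue API.
from pathos.helpers import mp


def produce(q):
    # Worker body; dill-based serialization is more permissive than the
    # stdlib pickler about what the process target may capture.
    q.put("batch-0")


if __name__ == "__main__":
    out_queue = mp.Queue(maxsize=2)  # same interface as multiprocessing.Queue
    proc = mp.Process(target=produce, args=(out_queue,), daemon=True)
    proc.start()
    print(out_queue.get())  # -> "batch-0"
    proc.join()
```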
@@ -76,7 +59,7 @@ class ParallelDataManagerConfig(VanillaDataManagerConfig):
     """Maximum number of threads to use in thread pool executor. If None, use ThreadPool default."""


-class DataProcessor(mp.Process):
+class DataProcessor(mp.Process):  # type: ignore
     """Parallel dataset batch processor.
     This class is responsible for generating ray bundles from an input dataset
@@ -92,7 +75,7 @@ class DataProcessor(mp.Process):

     def __init__(
         self,
-        out_queue: mp.Queue,
+        out_queue: mp.Queue,  # type: ignore
         config: ParallelDataManagerConfig,
         dataparser_outputs: DataparserOutputs,
         dataset: TDataset,
@@ -120,7 +103,7 @@ def run(self):
            ray_bundle = ray_bundle.pin_memory()
            while True:
                try:
-                    self.out_queue.put_nowait((ray_bundle, batch))
+                    self.out_queue.put((ray_bundle, batch))
                    break
                except queue.Full:
                    time.sleep(0.0001)
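The behavioral change in this hunk: `put_nowait()` raises `queue.Full` whenever the bounded queue has no free slot, so the old loop spun on 0.1 ms sleeps; a blocking `put()` instead parks the producer on the queue's internal condition until a slot frees (the surrounding `try`/`except` is retained by the commit, though a blocking `put` should no longer raise `queue.Full`). A sketch contrasting the two patterns, with a hypothetical bounded queue `q`:

```python
import queue
import time


def spin_put(q, item):
    # Old pattern: non-blocking put that busy-waits while the queue is full.
    while True:
        try:
            q.put_nowait(item)
            break
        except queue.Full:
            time.sleep(0.0001)  # polling loop; wastes CPU under backpressure


def blocking_put(q, item):
    # New pattern: block until a slot frees; the queue wakes the producer
    # via its internal condition variable, with no polling.
    q.put(item)
```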
@@ -188,8 +171,8 @@ def __init__(
            self.exclude_batch_keys_from_device = self.train_dataset.exclude_batch_keys_from_device
        # Spawn is critical for not freezing the program (PyTorch compatability issue)
        # check if spawn is already set
-        if mp.get_start_method(allow_none=True) is None:
-            mp.set_start_method("spawn")
+        if mp.get_start_method(allow_none=True) is None:  # type: ignore
+            mp.set_start_method("spawn")  # type: ignore
        super().__init__()

    def create_train_dataset(self) -> TDataset:
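The `# type: ignore` markers aside, the guard itself is unchanged and worth noting: `set_start_method()` raises `RuntimeError` once a start method has been fixed, and `get_start_method(allow_none=True)` returns `None` (rather than implicitly selecting the platform default) when nothing has been set yet. A stdlib sketch of the same pattern:

```python
import multiprocessing as mp

# Probe with allow_none=True: returns None if no start method has been
# fixed yet, without implicitly selecting the platform default.
if mp.get_start_method(allow_none=True) is None:
    # Safe to set; calling this after the context is fixed raises RuntimeError.
    mp.set_start_method("spawn")
```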
@@ -223,7 +206,7 @@ def setup_train(self):
        """Sets up parallel python data processes for training."""
        assert self.train_dataset is not None
        self.train_pixel_sampler = self._get_pixel_sampler(self.train_dataset, self.config.train_num_rays_per_batch)  # type: ignore
-        self.data_queue = mp.Manager().Queue(maxsize=self.config.queue_size)
+        self.data_queue = mp.Queue(maxsize=self.config.queue_size)  # type: ignore
        self.data_procs = [
            DataProcessor(
                out_queue=self.data_queue,  # type: ignore
@@ -238,10 +221,6 @@
            proc.start()
        print("Started threads")

-        # Prime the executor with the first batch
-        self.train_executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.config.max_thread_workers)
-        self.train_batch_fut = self.train_executor.submit(self.data_queue.get)
-
    def setup_eval(self):
        """Sets up the data loader for evaluation."""
        assert self.eval_dataset is not None
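One plausible reading of the `mp.Manager().Queue` → `mp.Queue` change two hunks up: a manager queue is a proxy to a `queue.Queue` living in a separate server process, so every `put`/`get` is an IPC round-trip through that extra process, while `mp.Queue` moves items over a pipe (fed by a background feeder thread) directly between producer and consumer. A stdlib illustration of the two flavors:

```python
import multiprocessing as mp

if __name__ == "__main__":
    # Proxy queue: lives in a separate manager server process; every
    # put/get is an IPC round-trip through that process.
    manager_queue = mp.Manager().Queue(maxsize=4)

    # Direct queue: a pipe plus a feeder thread; items travel straight
    # from producer to consumer with one fewer hop.
    direct_queue = mp.Queue(maxsize=4)
```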
@@ -274,11 +253,7 @@ def setup_eval(self):
    def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
        """Returns the next batch of data from the parallel training processes."""
        self.train_count += 1
-
-        # Fetch the next batch in an executor to parallelize the queue get() operation
-        # with the train step
-        bundle, batch = self.train_batch_fut.result()
-        self.train_batch_fut = self.train_executor.submit(self.data_queue.get)
+        bundle, batch = self.data_queue.get()
        ray_bundle = bundle.to(self.device)
        return ray_bundle, batch

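The deleted lines here (and in `setup_train` above) retire a one-slot prefetch pipeline: the old code primed a `ThreadPoolExecutor` with `data_queue.get` so the next fetch overlapped the train step, and the fix trades that overlap for a plain blocking `get()` with fewer places for a future to wedge, consistent with the periodic sticking reported in #2713. A self-contained sketch of the retired pattern, with hypothetical names, for contrast:

```python
import concurrent.futures
import queue

data_queue: queue.Queue = queue.Queue(maxsize=2)
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
pending = executor.submit(data_queue.get)  # prime with the first fetch


def next_batch():
    """Overlap the queue read with the caller's work between calls."""
    global pending
    batch = pending.result()                   # blocks until a batch arrives
    pending = executor.submit(data_queue.get)  # immediately start the next read
    return batch


data_queue.put("batch-0")
data_queue.put("batch-1")
data_queue.put("batch-2")  # keeps the final primed future from hanging at exit
print(next_batch())        # -> "batch-0"
print(next_batch())        # -> "batch-1"
```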
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -63,7 +63,8 @@ dependencies = [
     "trimesh>=3.20.2",
     "timm==0.6.7",
     "gsplat==0.1.0",
-    "pytorch-msssim"
+    "pytorch-msssim",
+    "pathos"
 ]

 [project.urls]
