Majorly slow dataloader using pylidc #68

Open
joangog opened this issue Apr 18, 2024 · 2 comments


joangog commented Apr 18, 2024

Hi! I am trying to implement a dataloader for the LIDC dataset using pylidc. Unfortunately, there is a major slowdown in training my model when using this dataloader (from 20 seconds per epoch to 15 minutes per epoch). I suspect the issue is that loading the data/annotations for a specific sample is slow in pylidc. The augmentations are definitely not the bottleneck, because I apply the same ones to a different dataset (BraTS) with no problems. Does anyone know what the issue is, and how to implement an efficient dataloader? Currently this is what I have:

import os
import random

import pylidc as pl
import torch
import torch.nn.functional as f
from torch.utils.data import Dataset


class LidcFineTune(Dataset):
    def __init__(self, config, img_list, crop_size=(128, 128, 64), train=False):  # 64 because some raw data don't have many slices
        self.config = config
        self.train = train
        self.img_list = img_list
        self.crop_size = crop_size

        # Create setup file for pylidc
        txt = f"""
        [dicom]
        path = {config.data}
        warn = True
        """
        with open(os.path.join(os.path.expanduser('~'),'.pylidcrc'), 'w') as file:
            file.write(txt)

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, index):
        pid = self.img_list[index]
        scan = pl.query(pl.Scan).filter(pl.Scan.patient_id == pid).first()
        ann = scan.annotations[0]  # first annotation belonging to this scan (a bare Annotation query without a join ignores the patient filter)
        
        # Image
        try:
            x = torch.FloatTensor(scan.to_volume())
        except Exception as e:
            raise RuntimeError(f"Corrupted file in {pid}. Redownload!") from e

        # Segmentation mask
        y = torch.zeros(x.shape)
        mask = torch.FloatTensor(ann.boolean_mask())
        bbox = ann.bbox()
        y[bbox[0].start:bbox[0].stop,bbox[1].start:bbox[1].stop,bbox[2].start:bbox[2].stop] = mask
        
        # Resize
        x = x.permute(2, 0, 1).unsqueeze(1)  # Move slice dim to batch dim and add temporary channel dim (H x W x D) -> (D x 1 x H x W)
        y = y.permute(2, 0, 1).unsqueeze(1)
        x = f.interpolate(x, scale_factor=(0.5, 0.5))  # Scale only height and width, not the slice dim
        y = f.interpolate(y, scale_factor=(0.5, 0.5))
        x = x.permute(1,2,3,0)  # Put slice dim last (D x 1 x H x W -> 1 x H x W x D)
        y = y.permute(1,2,3,0)
        
        x, y = self.aug_sample(x, y)

        # min max
        x = self.normalize(x)

        return x, y
    
    def aug_sample(self, x, y):
        if self.train:
            # Random crop and augment
            x, y = self.random_crop(x, y)
            if random.random() < 0.5:
                x = torch.flip(x, dims=(1,))  # torch.flip not the source of the major slowdown
                y = torch.flip(y, dims=(1,))
            if random.random() < 0.5:
                x = torch.flip(x, dims=(2,))
                y = torch.flip(y, dims=(2,))
            if random.random() < 0.5:
                x = torch.flip(x, dims=(3,))
                y = torch.flip(y, dims=(3,))
        else:
            # Center crop
            x, y = self.center_crop(x, y)
        
        return x, y

    def random_crop(self, x, y):
        """
        Args:
            x: 4d array, [channel, h, w, d]
        """
        crop_size = self.crop_size
        height, width, depth = x.shape[-3:]
        sx = random.randint(0, max(0, height - crop_size[0]))  # max(0, ...) avoids ValueError when a dim equals the crop size
        sy = random.randint(0, max(0, width - crop_size[1]))
        sz = random.randint(0, max(0, depth - crop_size[2]))
        crop_volume = x[:, sx:sx + crop_size[0], sy:sy + crop_size[1], sz:sz + crop_size[2]]
        crop_seg = y[:, sx:sx + crop_size[0], sy:sy + crop_size[1], sz:sz + crop_size[2]]

        return crop_volume, crop_seg

    def center_crop(self, x, y):
        crop_size = self.crop_size
        height, width, depth = x.shape[-3:]
        sx = max(0, (height - crop_size[0]) // 2)  # max(0, ...) avoids a negative start when a dim equals the crop size
        sy = max(0, (width - crop_size[1]) // 2)
        sz = max(0, (depth - crop_size[2]) // 2)
        crop_volume = x[:, sx:sx + crop_size[0], sy:sy + crop_size[1], sz:sz + crop_size[2]]
        crop_seg = y[:, sx:sx + crop_size[0], sy:sy + crop_size[1], sz:sz + crop_size[2]]

        return crop_volume, crop_seg

    def normalize(self, x):
        return (x - x.min()) / (x.max() - x.min())
@notmatthancock (Owner) commented

You're repeatedly computing scan.to_volume() and ann.boolean_mask() in your augmentation pipeline. to_volume() reads all the DICOM files from disk and converts them to a numpy volume, and ann.boolean_mask() does ray-casting for each pixel to determine whether it is inside or outside the annotation contour. Both of these are pretty expensive. If you are generating multiple augmentations from the same index in __getitem__, I would recommend caching the base data used to generate those augmentations.
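A minimal sketch of that caching idea, using functools.lru_cache. `load_base_data` is a hypothetical helper, and the counter merely stands in for the expensive pylidc calls (scan.to_volume(), ann.boolean_mask()), which would go where the comment indicates:

```python
from functools import lru_cache

# Stand-in counter so we can see how often the expensive work actually runs.
CALLS = {"n": 0}

@lru_cache(maxsize=32)  # keep the most recently used patients in memory
def load_base_data(pid):
    CALLS["n"] += 1
    # In the real dataloader, the expensive pylidc work would happen here:
    #   scan = pl.query(pl.Scan).filter(pl.Scan.patient_id == pid).first()
    #   vol = scan.to_volume()                 # reads all DICOM from disk
    #   mask = scan.annotations[0].boolean_mask()  # ray-casting per pixel
    return f"volume-for-{pid}"

# Repeated epochs hit the cache instead of re-reading DICOM from disk.
for _ in range(3):
    for pid in ("LIDC-IDRI-0001", "LIDC-IDRI-0002"):
        load_base_data(pid)
```

Note that with a multi-worker DataLoader each worker process keeps its own cache, so the first epoch still pays the loading cost once per worker.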


Emvlt commented Jul 24, 2024

Hello! We are working on a PyTorch-compatible dataloader for LIDC-IDRI: please check out https://github.com/CambridgeCIA/LION/tree/main/LION/data_loaders
for direct use or as inspiration to write your own :)
Currently we use the dataset in 2D and preprocess it for that purpose, which also avoids the expensive calls mentioned by @notmatthancock.
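One way to do that kind of offline preprocessing, sketched here with synthetic data (`preprocess`, `load_cached`, and the cache directory are hypothetical, not LION's actual API): convert each scan to a compressed .npz once, so __getitem__ only has to do a cheap np.load instead of re-reading DICOM and re-rasterizing masks.

```python
import os
import tempfile

import numpy as np

# Hypothetical on-disk cache; a real pipeline would use a fixed directory.
cache_dir = tempfile.mkdtemp()

def preprocess(pid, volume, mask):
    # One-time step: in a real pipeline, volume/mask would come from
    # scan.to_volume() and ann.boolean_mask().
    np.savez_compressed(os.path.join(cache_dir, f"{pid}.npz"),
                        volume=volume, mask=mask)

def load_cached(pid):
    # Cheap per-sample load, suitable for Dataset.__getitem__.
    data = np.load(os.path.join(cache_dir, f"{pid}.npz"))
    return data["volume"], data["mask"]

# Synthetic stand-in for a CT volume and its nodule mask.
vol = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
msk = vol > 10
preprocess("LIDC-IDRI-0001", vol, msk)
v, m = load_cached("LIDC-IDRI-0001")
```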
