Problem training with DataLoader #1

Closed
ClaraVth opened this issue Jan 21, 2025 · 3 comments

Comments

@ClaraVth
Owner

ClaraVth commented Jan 21, 2025

I have isolated the training task from the .pyt script so that I can run it without ArcGIS shutting down after the second run. The error I get in the first run also reproduces in this isolated script, torchgeo_logic.py.

It fails with the error: ValueError: A frozen dataclass was passed to 'apply_to_collection' but this is not allowed.

This looks similar to Issue #1426, but I haven't been able to implement the solution suggested there.
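For context on the message: as far as we can tell from `lightning_utilities`, `apply_to_collection` (which Lightning uses when moving a batch to the device) handles a dataclass by deep-copying it and then assigning each field back with `setattr`, and it turns the resulting `FrozenInstanceError` into this `ValueError`. The `bounds` entry of a torchgeo sample is a `BoundingBox`, presumably a frozen dataclass, hence the failure. The stdlib-only sketch below imitates that step; `FrozenBox` and `rebuild_dataclass` are hypothetical names, not library code.

```python
import copy
import dataclasses


@dataclasses.dataclass(frozen=True)
class FrozenBox:
    """Hypothetical stand-in for a frozen bounding-box dataclass."""
    minx: float
    maxx: float
    miny: float
    maxy: float


def rebuild_dataclass(data, fn):
    """Imitate a collection walker's dataclass branch: copy, then reassign every field."""
    result = copy.deepcopy(data)
    for field in dataclasses.fields(data):
        new_value = fn(getattr(data, field.name))
        try:
            # Assigning to a field of a frozen instance raises FrozenInstanceError
            setattr(result, field.name, new_value)
        except dataclasses.FrozenInstanceError as e:
            raise ValueError("A frozen dataclass was passed but this is not allowed.") from e
    return result


box = FrozenBox(0.0, 10.0, 0.0, 10.0)
try:
    rebuild_dataclass(box, lambda v: v)
except ValueError as err:
    print(err)
```

The same loop succeeds on a regular (non-frozen) dataclass, which is why only the `bounds` value trips it.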

@adamjstewart
Collaborator

We have a few options we can use to solve this problem:

  1. Fix this in TorchGeo (remove crs and bounds from __getitem__)
  2. Fix this in PyTorch/Lightning (figure out why frozen dataclasses aren't allowed and ignore them)
  3. Fix this in your code

Option 1 would be problematic: we would like to keep this metadata if possible. Option 2 is worth doing, but it may take a while for the fix to reach a release. I suggest option 3 for now.

To fix this in your code, you can either:

  1. Use a transform to pop the crs and bounds keys from each sample
  2. Use a GeoDataModule to pop those keys

I think 1 is easiest. Try using:

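from torchgeo.datasets import RasterDataset

class DropFrozenKeys:
    """Remove sample metadata that breaks Lightning's batch transfer."""
    def __call__(self, sample):
        # 'bounds' holds a frozen dataclass; pop with a default so a missing key doesn't raise
        sample.pop('crs', None)
        sample.pop('bounds', None)
        return sample

transforms = DropFrozenKeys()

# image_path and mask_path point to your image and mask rasters
image_dataset = RasterDataset(paths=image_path, transforms=transforms)
mask_dataset = RasterDataset(paths=mask_path, transforms=transforms)
mask_dataset.is_image = False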
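If you would rather not hard-code key names, a variant of the same transform could drop any value that is a frozen dataclass. This is a hypothetical alternative, not a torchgeo API. It relies on the `__dataclass_params__` attribute that CPython's `dataclasses` module attaches to generated classes (an implementation detail), and unlike `DropFrozenKeys` it keeps `crs`, which is presumably not a dataclass.

```python
import dataclasses
from typing import Any


def is_frozen_dataclass(value: Any) -> bool:
    """True only for instances (not classes) of dataclasses declared with frozen=True."""
    if not dataclasses.is_dataclass(value) or isinstance(value, type):
        return False
    # __dataclass_params__ is a CPython implementation detail, present since 3.7
    return type(value).__dataclass_params__.frozen


class DropFrozenValues:
    """Hypothetical transform: delete every sample entry whose value is a frozen dataclass."""

    def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
        # Collect keys first so the dict is not mutated while iterating over it
        for key in [k for k, v in sample.items() if is_frozen_dataclass(v)]:
            del sample[key]
        return sample
```

It would be passed the same way as the original, e.g. `RasterDataset(paths=image_path, transforms=DropFrozenValues())`.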

Let me know if this helps. If it does, it should be sufficient as a temporary fix. We can discuss a proper fix in TorchGeo or PyTorch/Lightning later.
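Option 2 (popping the keys at the data-module level) is not spelled out in the thread. One way it might look, assuming a `LightningDataModule` subclass (torchgeo's `GeoDataModule` is one): strip the keys from the collated batch in the `on_before_batch_transfer(batch, dataloader_idx)` hook, which Lightning calls before moving the batch to the device. The helper below is hypothetical and written in plain Python so it can be tested on its own; the hook wiring is only described in its docstring.

```python
from typing import Any

# Metadata keys to strip, mirroring the DropFrozenKeys transform
METADATA_KEYS = ("crs", "bounds")


def drop_frozen_keys_from_batch(batch: dict[str, Any], keys=METADATA_KEYS) -> dict[str, Any]:
    """Remove metadata keys from a collated batch before it is moved to the device.

    Hypothetical helper; in a LightningDataModule subclass it could be returned
    from on_before_batch_transfer(self, batch, dataloader_idx).
    """
    for key in keys:
        # pop with a default so batches without the key pass through unchanged
        batch.pop(key, None)
    return batch
```

Compared with the per-sample transform, this keeps `crs` and `bounds` available to anything that inspects samples before batching, at the cost of touching the data module.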

@ClaraVth ClaraVth reopened this Jan 22, 2025
@ClaraVth
Owner Author

I've reopened this issue since we need to keep the bounds for the subsequent segmentation task.

@adamjstewart
Collaborator

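To keep the bounds, try something like this inside the transform's __call__, instead of popping 'bounds':

import torch
bbox = sample['bounds']
# Replace the frozen BoundingBox with a plain tensor: [minx, miny, maxx, maxy]
sample['bounds'] = torch.tensor([bbox.minx, bbox.miny, bbox.maxx, bbox.maxy])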
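Since the bounds are needed later for the segmentation output, it helps to be explicit about the round trip. The snippet flattens to (minx, miny, maxx, maxy), while torchgeo's `BoundingBox` constructor, as far as we know, takes (minx, maxx, miny, maxy, mint, maxt), and the flattened form drops the time bounds. This sketch uses plain tuples so it runs without torch; `BBox`, `bounds_to_tuple`, and `tuple_to_bounds` are hypothetical names, with `BBox` standing in for `BoundingBox`.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class BBox:
    """Hypothetical stand-in assumed to share BoundingBox's field order."""
    minx: float
    maxx: float
    miny: float
    maxy: float
    mint: float
    maxt: float


def bounds_to_tuple(bbox: BBox) -> tuple[float, float, float, float]:
    """Flatten to (minx, miny, maxx, maxy), the order used in the snippet above."""
    return (bbox.minx, bbox.miny, bbox.maxx, bbox.maxy)


def tuple_to_bounds(values, mint: float, maxt: float) -> BBox:
    """Rebuild a box; the constructor order differs from the flattened order.

    The time bounds are not stored in the flattened form, so the caller supplies them.
    """
    minx, miny, maxx, maxy = (float(v) for v in values)
    return BBox(minx, maxx, miny, maxy, mint, maxt)
```

Getting the reordering wrong would silently swap `maxx` and `miny`, which is easy to miss until the georeferenced output lands in the wrong place.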
