How to avoid nodata-only patches #1330
Comments
https://gdal.org/programs/gdal_footprint.html looks very promising. It seems like it's possible to access this information. It's unclear how fast this would be or how we could make use of it in our R-tree.
Hi! Just happened to come across this issue and wanted to add this as a relevant resource:
Personally, at the moment, I'm using
From the link: "The gdal_footprint utility can be used to compute the footprint of a raster file, taking into account nodata values (or more generally the mask band attached to the raster bands), and generating polygons/multipolygons corresponding to areas where pixels are valid, and write to an output vector file." I believe it is something like:
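A minimal sketch of the invocation, assuming GDAL >= 3.8 (which introduced gdal_footprint) and placeholder file names:

```python
import subprocess

# Compute the valid-pixel footprint of a raster and write it as a vector file.
# "input.tif" and "footprint.geojson" are placeholder paths.
subprocess.run(
    ["gdal_footprint", "-of", "GeoJSON", "input.tif", "footprint.geojson"],
    check=True,
)
```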
I'm not entirely familiar with the torchgeo codebase; do the samplers have access to the raster sample itself before creating the sampler window/ROI? Perhaps an option is: for each raster, load the nodata mask at once, create this polygon, and then make it possible for the sampler to only work with the "valid" region. (🤔 Reading the nodata mask (the byte array) patch by patch is not a good option, because the I/O overhead of sampling the entire raster that way will take far longer than reading it at once.)
We tried something similar to this before (reading files inside the sampler), but rasterio had issues with parallel locking. Did you not encounter any issues with your solution when using multiple workers?
I'm thinking of something like loading the nodata at the dataset level, then just accessing it instead of each sampler looking into the files.

Edit: As I said, I am not familiar with the code base. Assuming the dataset is a group of samples... what I'm thinking is: when the first access occurs (or sampling), or in the indexing stage:

Then, when we go to sample item N of the dataset, we just access the Nth polygon/bounding box if available. I don't see it causing a memory issue to store an oriented bounding box for each, but it could be cached somehow for giant datasets. I don't know if the CRS of each is stored here, but it would be the same idea as having a multi-CRS dataset where we need to handle it.
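A minimal sketch of that idea, assuming rasterio's `dataset_mask()` as the nodata source and shapely for polygonization (the helper name is hypothetical, not torchgeo API):

```python
from functools import lru_cache

import rasterio
from rasterio.features import shapes
from shapely.geometry import shape
from shapely.ops import unary_union


@lru_cache(maxsize=None)
def valid_footprint(filepath: str):
    """Polygonize the valid-pixel region of a raster, reading the mask once.

    Cached so that repeated sampler queries against the same file never
    touch the disk again (hypothetical helper).
    """
    with rasterio.open(filepath) as src:
        mask = src.dataset_mask()  # 0 = nodata, 255 = valid
        polygons = [
            shape(geom)
            for geom, value in shapes(mask, transform=src.transform)
            if value == 255
        ]
    return unary_union(polygons)
```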
The problem is that R-tree does not support oriented bounding boxes. We would have to replace R-tree with something else. For the record, I'm completely fine with doing this, just don't know what else would work well for this use case. Also, note that this will make instantiating the dataset extremely slow because it needs to read every single file just to populate the index. We'll have to benchmark this to see if it's better or worse than doing it in the sampler. We may be able to parallelize this with multiprocessing though.
Yeah, it's probably best to do this the first time we try to access the sample, and then when indexing the dataset... as it's using rtree, I believe it should be using some of its functionalities... some doubts:
We can store theta, but we can't check for overlap using theta. We would have to write our own check to find the valid regions where we can sample from.
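A sketch of what such a custom check could look like, using shapely's affine rotation to build the oriented box (the helper and the theta value here are hypothetical):

```python
from shapely.affinity import rotate
from shapely.geometry import box


def oriented_box(minx, miny, maxx, maxy, theta):
    """Hypothetical helper: an axis-aligned box rotated by theta degrees
    about its center, representing an oriented bounding box."""
    return rotate(box(minx, miny, maxx, maxy), theta, origin="center")


# A patch is a valid sample location only if it lies entirely inside the
# rotated scene footprint (values are made up for illustration):
scene = oriented_box(0, 0, 10980, 10980, theta=13.0)
patch = box(4000, 4000, 4256, 4256)
print(patch.within(scene))  # True -> safe to sample
```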
Yes, which I believe leads to the same case/function as #1190
There are currently two situations where nodata will be sampled.
Reprojection will not solve point two. A possible solution for both cases:

**For Step 1**

Example of how to retrieve the footprint in Sentinel-2:

```python
import rasterio
from shapely import wkt

with rasterio.open("/<product_id>.SAFE/MTD_MSIL1C.xml") as src:
    valid_pixels_footprint = wkt.loads(src.tags()['FOOTPRINT'])
```

**For Step 2**

I see this as the simplest solution. Alternatives like replacing rtree as the index are a bigger task that should be viewed in relation to #409, in my opinion.
Step 1: It's still not clear to me how this valid footprint would be used. Also, I would prefer a solution that works for all datasets, not just Sentinel-2.
Here is an example of the how, but I don't know where in the code it belongs. You would need access to the query representing the bounds of the sample/patch, so my logic is upon `__getitem__`.
One solution, as you have mentioned, is using:

```python
from typing import cast

import rasterio
import shapely
from shapely import wkt
from shapely.geometry import box

# GeoDataset, BoundingBox, and concat_samples come from torchgeo


def extract_footprint(filepath):
    # Sentinel-2 example showing how to find the footprint.
    # In the general case, could run gdal_footprint beforehand and save to some file.
    with rasterio.open(filepath) as src:
        valid_pixels_footprint = wkt.loads(src.tags()['FOOTPRINT'])
    # TODO: reproject to the dataset CRS if it is not already
    return valid_pixels_footprint


def query_intersects_with_footprint(query, filepath):
    valid_pixels_footprint = extract_footprint(filepath)
    bbox = box(query.minx, query.miny, query.maxx, query.maxy)
    return shapely.intersects(bbox, valid_pixels_footprint)


class RasterDataset(GeoDataset):
    def __getitem__(self, query: BoundingBox):
        hits = self.index.intersection(tuple(query), objects=True)
        filepaths = cast(list[str], [hit.object for hit in hits])
        # remove filepaths whose valid-pixel footprint does not intersect the sample bbox
        filepaths = [
            path
            for path in filepaths
            if query_intersects_with_footprint(query, path)
        ]
        if not filepaths:
            return None
        # ... rest of existing method returns samples


# then collate_fn can replace/remove None
def concat_samples_replace_none(samples):
    """Based on https://stackoverflow.com/questions/57815001/pytorch-collate-fn-reject-sample-and-yield-another"""
    len_batch = len(samples)  # original batch length
    samples = [s for s in samples if s is not None]  # filter out all the Nones
    if len_batch > len(samples):
        # if there are samples missing, just reuse existing members;
        # doesn't work if you reject every sample in a batch
        diff = len_batch - len(samples)
        samples = samples + samples[:diff]
    return concat_samples(samples)  # original collate_fn used by IntersectionDataset
```
"the query representing the bounds of the sample/patch" is created in the sampler. If we can decide whether or not it's a valid location to sample from before getting to the |
Makes sense! Something like this? RasterDataset still needs to do the same check, since the

```python
class RandomGeoSampler(GeoSampler):
    def __iter__(self) -> Iterator[BoundingBox]:
        for _ in range(len(self)):
            # Choose a random tile, weighted by area
            idx = torch.multinomial(self.areas, 1)
            hit = self.hits[idx]
            bounds = BoundingBox(*hit.bounds)

            # Choose a random index within that tile
            bounding_box = get_random_bounding_box(bounds, self.size, self.res)
            if not query_intersects_with_footprint(bounding_box, hit.object):
                # this bounding_box is outside the valid-pixel footprint of the raster
                continue
            yield bounding_box
```
We don't need to do it in RasterDataset because all hits will be merged (stitched together), so if the sampler says the query is valid, it's valid. Your implementation of
Assuming the above is correct, my

```python
import rasterio
from geopandas import GeoSeries
from shapely import wkt
from shapely.geometry import box


def extract_footprint(filepath):
    with rasterio.open(filepath) as src:
        valid_pixels_footprint = wkt.loads(src.tags()['FOOTPRINT'])
    return valid_pixels_footprint


def query_intersects_with_footprint(filepath, query, common_crs):
    # Reproject the footprint to the same CRS the GeoSampler grid is based on,
    # using geopandas for simplicity
    valid_pixels_footprint_reprojected = (
        GeoSeries([extract_footprint(filepath)], crs=4326)
        .to_crs(common_crs)
    )
    bbox = box(query.minx, query.miny, query.maxx, query.maxy)
    # all() to get a scalar bool from the GeoSeries
    return valid_pixels_footprint_reprojected.intersects(bbox).all()
```
So this approach needs access to the common CRS used by RasterDataset, and some way to map from the

Can compare against the
The sampler is given the dataset index, which is already in a common CRS, so there's no need to warp anything yourself. The problem is that images are rotated with respect to almost any CRS and have significant nodata pixels around the border. I think I understand your implementation better now: you're not checking the bbox of the image, but of the patch. This should work. I'm just not sure how fast

If you can get
Yes, I think we are trying to explain the same thing here. I pushed a minimal working (🤞) example that fixes this for
I'll try to review when I get a chance. I'll likely need to implement an I/O benchmarking subcommand to see how much this affects I/O rates before merging, which I won't have time to get to until March. So don't hold your breath, but I promise I'm interested and will review in detail as soon as I can.
Update: we actually don't need a new I/O benchmarking subcommand; Lightning has built-in support for this: https://lightning.ai/docs/pytorch/stable/tuning/profiler.html

So all we really need is to:
For 1, in our preliminary TorchGeo paper, we sampled 100 random Landsat scenes and one CDL map, each in a different CRS. There are a million things we can play around with (COGs, block size, resolution, CRS, etc.). We may want to develop a list of multiple options:
Essentially, we would be developing a set of benchmark datasets not for benchmarking models, but for benchmarking I/O. I wonder if such a thing already exists. This might actually make for an interesting paper if it doesn't exist. Let me ask around with the GDAL folks. But anyway, for your contribution, a single dataset should be sufficient. Initially I framed this from the perspective of "as long as your PR doesn't make things significantly slower, it's fine". However, after more thought, it's possible your implementation is actually significantly faster, as it allows us to skip many regions we would have otherwise sampled. So when benchmarking I/O, we should definitely take this into account. I.e., the best measure of I/O speeds is how long it takes GridGeoSampler to iterate over the entire dataset, not how long it takes for a specific number of patches to be loaded. |
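A minimal sketch of both pieces, with assumptions flagged: the profiler is Lightning's built-in one mentioned above, `dataset` is assumed to be an existing RasterDataset instance, and the patch size/stride values are made up:

```python
import time

from lightning.pytorch import Trainer
from torch.utils.data import DataLoader
from torchgeo.datasets import stack_samples
from torchgeo.samplers import GridGeoSampler

# Lightning's built-in profiler: "simple" prints per-hook timings
# (including data loading) at the end of a run.
trainer = Trainer(profiler="simple")

# The I/O metric proposed above: time a full GridGeoSampler pass.
sampler = GridGeoSampler(dataset, size=256, stride=256)
loader = DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)

start = time.perf_counter()
num_patches = sum(1 for _ in loader)
elapsed = time.perf_counter() - start
print(f"{num_patches} patches in {elapsed:.1f}s ({num_patches / elapsed:.2f} patches/s)")
```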
@AdeelH I'm curious if/how Raster Vision handles this situation. Basically, if you have a raster image that is rotated with many nodata pixels around the edge, how do you prevent your sliding/random window datasets from returning images that are entirely nodata pixels? |
If we're talking about how, even though the data is stored as a rectangular array in the GeoTIFF, it is rotated when its bounds are transformed to map coordinates: this is not an issue in Raster Vision, since it samples windows in pixel coordinates (i.e. coordinates of the rectangular array). Now it might be the case that the data does not look right when read this way (I think this is the case with the projection used in MODIS imagery); if that's the case, then it is on the user to reproject it (I think this doesn't contradict what I said in the other issue, since the reprojection is for reasons other than alignment).

If we're talking about how sometimes that rectangular array might have nodata pixels near the edges (e.g. if it has been reprojected): then you can specify an AOI geometry for the data region, if you have one (e.g. the geometry of the STAC item or the output of GDAL footprint). If you don't have such a geometry, then you can try restricting the padding around the edges to avoid sampling too many windows near the edges.

There's no mechanism for on-the-fly filtering based on nodata content in RV GeoDatasets. However, RV also provides a way to pre-chip a dataset and then treat it as a non-geospatial dataset, and that does allow filtering based on nodata percentage via a nodata_threshold option.
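A plain-NumPy illustration of that pre-chip filter (not Raster Vision's actual API; the helper name and threshold are made up):

```python
import numpy as np


def keep_chip(chip: np.ndarray, nodata: float = 0.0, nodata_threshold: float = 0.5) -> bool:
    """Hypothetical helper: keep a pre-generated chip only if its fraction
    of nodata pixels is below the threshold."""
    frac_nodata = float((chip == nodata).mean())
    return frac_nodata < nodata_threshold


# A chip that is entirely nodata is always rejected:
print(keep_chip(np.zeros((3, 256, 256))))  # False
```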
Summary
Large scenes are often rotated due to CRS like so:
When using our GeoSamplers, we often sample from the nodata regions around the edges.
Rationale
Sampling these nodata-only patches results in slower I/O and slower GPU training, despite these patches contributing nothing to the model.
Implementation
There are two places where we could potentially improve performance.
I/O
Best case scenario would be to avoid sampling these regions entirely. In #449, someone tried adding a check to the sampler that actually loads the patch, checks if it's entirely nodata pixels, and skips it if so. Unfortunately, this doesn't seem to work in parallel, and is slow since it needs to load each patch twice.
If there was a way to load the image in its native CRS (i.e., a square with no nodata pixels in the orientation it was taken in), this would solve all of our problems. I don't know of a way to do this.
GPU
This is actually easier to solve. We could add a feature to our data module base classes that removes all nodata-only images from each mini-batch inside `transfer_batch_to_device` or another step. This would result in variable batch sizes, but I don't think that's an issue.
Alternatives
No response
Additional information
This is a highly requested feature: