Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

57 use tifffiles #58

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
2 changes: 0 additions & 2 deletions environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@ dependencies:
- numpy>=2.0
- pillow
- pytorch>=2.1
- rasterio
- scikit-learn
- tensorboard
- torchaudio
- torchvision
- tqdm
- tifffile
Expand Down
10 changes: 4 additions & 6 deletions pangaea/datasets/biomassters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import torch
import pandas as pd
import pathlib
import rasterio
from tifffile import imread
import tifffile
from os.path import join as opj

from pangaea.datasets.utils import read_tif
Expand All @@ -23,7 +22,7 @@ def read_imgs(multi_temporal, temp , fname, data_dir, img_size):

s1_filepath = data_dir.joinpath(s1_fname)
if s1_filepath.exists():
img_s1 = imread(s1_filepath)
img_s1 = tifffile.imread(s1_filepath)
m = img_s1 == -9999
img_s1 = img_s1.astype('float32')
img_s1 = np.where(m, 0, img_s1)
Expand All @@ -32,7 +31,7 @@ def read_imgs(multi_temporal, temp , fname, data_dir, img_size):

s2_filepath = data_dir.joinpath(s2_fname)
if s2_filepath.exists():
img_s2 = imread(s2_filepath)
img_s2 = tifffile.imread(s2_filepath)
img_s2 = img_s2.astype('float32')
else:
img_s2 = np.zeros((img_size, img_size) + (11,), dtype='float32')
Expand Down Expand Up @@ -155,8 +154,7 @@ def __getitem__(self, index):
fname = str(chip_id)+'_agbm.tif'

imgs_s1, imgs_s2, mask = read_imgs(self.multi_temporal, self.temp, fname, self.dir_features, self.img_size)
with rasterio.open(self.dir_labels.joinpath(fname)) as lbl:
target = lbl.read(1)
target = tifffile.imread(self.dir_labels.joinpath(fname), key=0)
target = np.nan_to_num(target)

imgs_s1 = torch.from_numpy(imgs_s1).float()
Expand Down
1 change: 0 additions & 1 deletion pangaea/datasets/fivebillionpixels.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import time
import torch
import numpy as np
import rasterio
import random
from glob import glob

Expand Down
1 change: 0 additions & 1 deletion pangaea/datasets/hlsburnscars.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import time
import torch
import numpy as np
# import rasterio
import tifffile as tiff
from typing import Sequence, Dict, Any, Union, Literal, Tuple
from sklearn.model_selection import train_test_split
Expand Down
80 changes: 38 additions & 42 deletions pangaea/datasets/mados.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,11 @@
import zipfile

from glob import glob
import rasterio
import cv2
import tifffile
import numpy as np

import warnings

warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)

import torch
import torchvision.transforms.functional as TF
import torchvision.transforms as T

from pangaea.datasets.utils import DownloadProgressBar
from pangaea.datasets.base import GeoFMDataset
Expand Down Expand Up @@ -112,27 +107,32 @@ def __init__(
self.download_url = download_url
self.auto_download = auto_download

self.ROIs_split = np.genfromtxt(os.path.join(self.root_path, 'splits', f'{split}_X.txt'), dtype='str')
self.ROIs_split = np.genfromtxt(
os.path.join(self.root_path, "splits", f"{split}_X.txt"), dtype="str"
)

self.image_list = []
self.target_list = []

self.tiles = sorted(glob(os.path.join(self.root_path, '*')))
self.tiles = sorted(glob(os.path.join(self.root_path, "*")))

for tile in self.tiles:
splits = [f.split('_cl_')[-1] for f in glob(os.path.join(tile, '10', '*_cl_*'))]
splits = [
f.split("_cl_")[-1] for f in glob(os.path.join(tile, "10", "*_cl_*"))
]

for crop in splits:
crop_name = os.path.basename(tile) + '_' + crop.split('.tif')[0]
crop_name = os.path.basename(tile) + "_" + crop.split(".tif")[0]

if crop_name in self.ROIs_split:
all_bands = glob(os.path.join(tile, '*', '*L2R_rhorc*_' + crop))
all_bands = glob(os.path.join(tile, "*", "*L2R_rhorc*_" + crop))
all_bands = sorted(all_bands, key=self.get_band)
# all_bands = np.array(all_bands)

self.image_list.append(all_bands)

cl_path = os.path.join(tile, '10', os.path.basename(tile) + '_L2R_cl_' + crop)
cl_path = os.path.join(
tile, "10", os.path.basename(tile) + "_L2R_cl_" + crop
)
self.target_list.append(cl_path)

def __len__(self):
Expand All @@ -143,42 +143,36 @@ def getnames(self):

def __getitem__(self, index):

all_bands = self.image_list[index]
band_paths = self.image_list[index]
current_image = []
for c, band in enumerate(all_bands):
upscale_factor = int(os.path.basename(os.path.dirname(band))) // 10
with rasterio.open(band, mode='r') as src:
this_band = src.read(1,
out_shape=(int(src.height * upscale_factor), int(src.width * upscale_factor)),
resampling=rasterio.enums.Resampling.nearest
)
this_band = torch.from_numpy(this_band)
#this_band[torch.isnan(this_band)] = self.data_mean['optical'][c]
current_image.append(this_band)

image = torch.stack(current_image)
invalid_mask = torch.isnan(image)
image[invalid_mask] = 0
for path in band_paths:
upscale_factor = int(os.path.basename(os.path.dirname(path))) // 10

band = tifffile.imread(path)
band = cv2.resize(band, dsize=None, fx=upscale_factor, fy=upscale_factor, interpolation=cv2.INTER_NEAREST_EXACT)
band_tensor = torch.from_numpy(band).unsqueeze(0)
current_image.append(band_tensor)

with rasterio.open(self.target_list[index], mode='r') as src:
target = src.read(1)
image = torch.cat(current_image)
invalid_mask = torch.isnan(image)
image[invalid_mask] = 0
target = tifffile.imread(self.target_list[index])
target = torch.from_numpy(target.astype(np.int64))
target = target - 1

output = {
'image': {
'optical': image,
"image": {
"optical": image,
},
'target': target,
'metadata': {}
"target": target,
"metadata": {},
}

return output

@staticmethod
def get_band(path):
return int(path.split('_')[-2])
return int(path.split("_")[-2])

@staticmethod
def download(self, silent=False):
Expand All @@ -199,15 +193,17 @@ def download(self, silent=False):
try:
urllib.request.urlretrieve(url, output_path / temp_file_name, pbar)
except urllib.error.HTTPError as e:
print('Error while downloading dataset: The server couldn\'t fulfill the request.')
print('Error code: ', e.code)
print(
"Error while downloading dataset: The server couldn't fulfill the request."
)
print("Error code: ", e.code)
return
except urllib.error.URLError as e:
print('Error while downloading dataset: Failed to reach a server.')
print('Reason: ', e.reason)
print("Error while downloading dataset: Failed to reach a server.")
print("Reason: ", e.reason)
return

with zipfile.ZipFile(output_path / temp_file_name, 'r') as zip_ref:
with zipfile.ZipFile(output_path / temp_file_name, "r") as zip_ref:
print(f"Extracting to {output_path} ...")
# Remove top-level dir in ZIP file for nicer data dir structure
members = []
Expand All @@ -219,4 +215,4 @@ def download(self, silent=False):
zip_ref.extractall(output_path, members)
print("done.")

(output_path / temp_file_name).unlink()
(output_path / temp_file_name).unlink()
14 changes: 6 additions & 8 deletions pangaea/datasets/pastis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import geopandas as gpd
import numpy as np
import pandas as pd
import rasterio
import tifffile
import torch
from einops import rearrange

Expand Down Expand Up @@ -203,17 +203,15 @@ def __getitem__(self, i: int) -> dict[str, torch.Tensor | dict[str, torch.Tensor

for modality in self.modalities:
if modality == "aerial":
with rasterio.open(
os.path.join(
path = os.path.join(
self.path,
"DATA_SPOT/PASTIS_SPOT6_RVB_1M00_2019/SPOT6_RVB_1M00_2019_"
+ str(name)
+ ".tif",
)
) as f:
output["aerial"] = split_image(
torch.FloatTensor(f.read()), self.nb_split, part
)
)
# Build the channel-first tensor first, then hand it to split_image together
# with the split parameters (the tensor constructor must receive only the array).
# NOTE(review): assumes tifffile.imread yields (H, W, C) for this file — confirm.
aerial = torch.FloatTensor(tifffile.imread(path).transpose(2, 0, 1))
output["aerial"] = split_image(aerial, self.nb_split, part)
elif modality == "s1-median":
modality_name = "s1a"
images = split_image(
Expand Down
16 changes: 6 additions & 10 deletions pangaea/datasets/sen1floods11.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import geopandas
import numpy as np
import pandas as pd
import rasterio
import tifffile
import torch

from pangaea.datasets.utils import download_bucket_concurrently
Expand Down Expand Up @@ -138,17 +138,13 @@ def _get_date(self, index):
return date_np

def __getitem__(self, index):
with rasterio.open(self.s2_image_list[index]) as src:
s2_image = src.read()
s2_image = tifffile.imread(self.s2_image_list[index])

with rasterio.open(self.s1_image_list[index]) as src:
s1_image = src.read()
# Convert the missing values (clouds etc.)
s1_image = np.nan_to_num(s1_image)

with rasterio.open(self.target_list[index]) as src:
target = src.read(1)
s1_image = tifffile.imread(self.s1_image_list[index])
# Convert the missing values (clouds etc.)
s1_image = np.nan_to_num(s1_image)

target = tifffile.imread(self.target_list[index], key=0)
timestamp = self._get_date(index)

s2_image = torch.from_numpy(s2_image).float()
Expand Down
15 changes: 8 additions & 7 deletions pangaea/datasets/spacenet7.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

import json
from glob import glob
import rasterio
import cv2
import tifffile
import numpy as np

import torch
Expand Down Expand Up @@ -214,17 +215,17 @@ def __len__(self) -> int:
def load_planet_mosaic(self, aoi_id: str, year: int, month: int) -> np.ndarray:
folder = self.root_path / 'train' / aoi_id / 'images_masked'
file = folder / f'global_monthly_{year}_{month:02d}_mosaic_{aoi_id}.tif'
with rasterio.open(str(file), mode='r') as src:
img = src.read(out_shape=(self.sn7_img_size, self.sn7_img_size), resampling=rasterio.enums.Resampling.nearest)
# 4th band (last oen) is alpha band
img = img[:-1]
img = tifffile.imread(file)
img = cv2.resize(img, dsize=(self.sn7_img_size, self.sn7_img_size), interpolation=cv2.INTER_NEAREST_EXACT)
# 4th band (last one) is alpha band
img = img.transpose(2, 0, 1)[:-1]
return img.astype(np.float32)

def load_building_label(self, aoi_id: str, year: int, month: int) -> np.ndarray:
folder = self.root_path / 'train' / aoi_id / 'labels_raster'
file = folder / f'global_monthly_{year}_{month:02d}_mosaic_{aoi_id}_Buildings.tif'
with rasterio.open(str(file), mode='r') as src:
label = src.read(out_shape=(self.sn7_img_size, self.sn7_img_size), resampling=rasterio.enums.Resampling.nearest)
label = tifffile.imread(file)
label = cv2.resize(label, dsize=(self.sn7_img_size, self.sn7_img_size), interpolation=cv2.INTER_NEAREST_EXACT)
label = (label > 0).squeeze()
return label.astype(np.int64)

Expand Down
30 changes: 15 additions & 15 deletions pangaea/datasets/utae_dynamicen.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import numpy as np
import rasterio
import tifffile
import torch
# from torch.utils.data import Dataset
# from torchvision import transforms
Expand Down Expand Up @@ -154,19 +154,19 @@ def load_data(self, index):
cur_images, cur_dates = [], []
if self.mode == 'daily':
for i in range(1, self.all_days[index][0]+1):
img = rasterio.open(os.path.join(self.root_path, self.all_days[index][i][0][1:]))
red = img.read(3)
green = img.read(2)
blue = img.read(1)
nir = img.read(4)
with tifffile.TiffFile.open(os.path.join(self.root_path, self.all_days[index][i][0][1:])) as img:
red = img.pages[2].asarray()
green = img.pages[1].asarray()
blue = img.pages[0].asarray()
nir = img.pages[3].asarray()
image = np.dstack((red, green, blue, nir))
cur_images.append(np.expand_dims(np.asarray(image, dtype=np.float32), axis=0)) # np.array already\
cur_dates.append(self.all_days[index][i][1])

image_stack = np.concatenate(cur_images, axis=0)
dates = torch.from_numpy(np.array(cur_dates, dtype=np.int32))
label = rasterio.open(os.path.join(self.root_path, self.labels[index][1:]))
label = label.read()
label = tifffile.imread(os.path.join(self.root_path, self.labels[index][1:]))
label = label.transpose(2, 0, 1)
mask = np.zeros((label.shape[1], label.shape[2]), dtype=np.int32)

for i in range(self.num_classes + 1):
Expand All @@ -180,17 +180,17 @@ def load_data(self, index):
else:
for i in range(len(self.dates)):
# read .tif
img = rasterio.open(os.path.join(self.root_path, self.planet_day[index][i][1:]))
red = img.read(3)
green = img.read(2)
blue = img.read(1)
nir = img.read(4)
with tifffile.TiffFile.open(os.path.join(self.root_path, self.planet_day[index][i][1:])) as img:
red = img.pages[2].asarray()
green = img.pages[1].asarray()
blue = img.pages[0].asarray()
nir = img.pages[3].asarray()
image = np.dstack((red, green, blue, nir))
cur_images.append(np.expand_dims(np.asarray(image, dtype=np.float32), axis=0)) # np.array already\
image_stack = np.concatenate(cur_images, axis=0)
dates = torch.from_numpy(np.array(self.planet_day[index][len(self.dates):], dtype=np.int32))
label = rasterio.open(os.path.join(self.root_path, self.labels[index][1:]))
label = label.read()
label = tifffile.imread(os.path.join(self.root_path, self.labels[index][1:]))
label = label.transpose(2, 0, 1)
mask = np.zeros((label.shape[1], label.shape[2]), dtype=np.int32)

for i in range(self.num_classes + 1):
Expand Down
15 changes: 3 additions & 12 deletions pangaea/datasets/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import tqdm
import rasterio
import tifffile
import pathlib
import concurrent.futures
from google.cloud.storage import Client
Expand Down Expand Up @@ -83,14 +83,5 @@ def download_blob_file_pair(blob_file_pair):


def read_tif(file: pathlib.Path):
with rasterio.open(file) as dataset:
arr = dataset.read() # (bands X height X width)
return arr.transpose((1, 2, 0))


def read_tif_with_metadata(file: pathlib.Path):
with rasterio.open(file) as dataset:
arr = dataset.read() # (bands X height X width)
transform = dataset.transform
crs = dataset.crs
return arr.transpose((1, 2, 0)), transform, crs
arr = tifffile.imread(file)
return arr.transpose(2, 0, 1)
Loading