336 kornia #12

Open · wants to merge 7 commits into base: develop
2 changes: 2 additions & 0 deletions config/augmentation/basic_augmentation_segmentation.yaml
@@ -8,6 +8,8 @@ augmentation:
 
   # Radiometric augmentations
   noise:  # Standard deviation of Gaussian Noise
+  clahe_clip_limit: 0.1
+  brightness_contrast_range:  # Not yet implemented
 
   # Augmentations done immediately before conversion to torch tensor
   normalization:  # Normalization: parameters for finetuning. See examples below:
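For context, kornia exposes CLAHE as kornia.enhance.equalize_clahe. A minimal sketch of the operation the new clahe_clip_limit key appears to configure; whether the config value maps one-to-one onto kornia's clip_limit parameter is an assumption here:

    import torch
    from kornia.enhance import equalize_clahe

    # equalize_clahe expects a float tensor in [0, 1] shaped (B, C, H, W);
    # clip_limit caps how much any histogram bin can be amplified.
    image = torch.rand(1, 3, 256, 256)
    enhanced = equalize_clahe(image, clip_limit=0.1)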
5 changes: 3 additions & 2 deletions config/visualization/default_visualization.yaml
@@ -1,8 +1,9 @@
 # @package _global_
 visualization:
-  vis_at_train:
+  vis_at_train: True
   vis_at_evaluation: True
-  vis_batch_range: [0,1,1]
+  vis_at_init: True
+  vis_batch_range: [0,10,1]
   vis_at_checkpoint:
   vis_at_ckpt_min_ep_diff:
   vis_at_ckpt_dataset:
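Judging from vis_from_dataloader further down in this PR, the vis_batch_range triple is unpacked as [min, max, step] and fed to Python's range to select which batches get visualized, so the new value widens the selection from batch 0 only to batches 0 through 9:

    # Sketch of how the triple is consumed (see vis_from_dataloader below):
    min_vis_batch, max_vis_batch, increment = [0, 10, 1]
    selected = range(min_vis_batch, max_vis_batch, increment)  # batches 0..9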
21 changes: 13 additions & 8 deletions inference_segmentation.py
@@ -11,6 +11,8 @@
 import rasterio
 import ttach as tta
 from collections import OrderedDict
+
+from einops import rearrange
 from fiona.crs import to_string
 from tqdm import tqdm
 from rasterio import features
@@ -21,6 +23,7 @@
 from omegaconf.listconfig import ListConfig
 
 from dataset.aoi import aois_from_csv
+from utils.create_dataset import SegmentationDatamodule
 from utils.logger import get_logger, set_tracker
 from models.model_choice import define_model, read_checkpoint
 from utils import augmentation
@@ -157,9 +160,11 @@ def segmentation(param,
    Returns:

    """
+    dummy_datamodule = SegmentationDatamodule(dontcare2backgr=True)
+
    subdiv = 2
    threshold = 0.5
-    sample = {'sat_img': None, 'map_img': None, 'metadata': None}
+    sample = {'image': None, 'map_img': None, 'metadata': None}
    start_seg = time.time()
    print_log = True if logging.level == 20 else False  # 20 is INFO
    pad = chunk_size * 2
@@ -188,14 +193,14 @@
                                         raster_info={})

                sample['metadata'] = image_metadata
+                # FIXME: update according to last devs getting closer to torchgeo/kornia
                totensor_transform = augmentation.compose_transforms(param,
-                                                                     dataset="tst",
-                                                                     scale=scale,
-                                                                     aug_type='totensor',
-                                                                     print_log=print_log)
-                sample['sat_img'] = sub_image
-                sample = totensor_transform(sample)
-                inputs = sample['sat_img'].unsqueeze_(0)
+                                                                     dataset="tst")
+                sample['image'] = sub_image
+                sample["image"] = rearrange(sample["image"], 'h w c -> c h w')
+                sample["image"] = torch.from_numpy(sample["image"])
+                sample = dummy_datamodule.preprocess(sample)
+                inputs = sample['image'].unsqueeze_(0)
                inputs = inputs.to(device)
                if inputs.shape[1] == 4 and any("module.modelNIR" in s for s in model.state_dict().keys()):
                    # Init NIR TODO: make a proper way to read the NIR channel
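A standalone sketch of the new tensor prep above, with illustrative shapes; the datamodule's preprocess step (presumably normalization and dontcare handling) is project-specific and omitted here:

    import numpy as np
    import torch
    from einops import rearrange

    # A (H, W, C) image chip becomes a (1, C, H, W) float tensor for the model.
    chip = np.random.rand(512, 512, 3).astype(np.float32)
    inputs = torch.from_numpy(rearrange(chip, 'h w c -> c h w')).unsqueeze(0)
    assert inputs.shape == (1, 3, 512, 512)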
131 changes: 67 additions & 64 deletions train_segmentation.py
@@ -13,12 +13,15 @@
 from sklearn.utils import compute_sample_weight
 import torch
 from torch import optim
+from torchgeo.datasets import stack_samples
 from torch.utils.data import DataLoader
+from torchvision.transforms import Compose
 from tqdm import tqdm
 
 from models.model_choice import read_checkpoint, define_model, adapt_checkpoint_to_dp_model
 from tiling_segmentation import Tiler
-from utils import augmentation as aug, create_dataset
+from utils import augmentation as aug
+from utils.create_dataset import SegmentationDatamodule, SegmentationDataset
 from utils.logger import InformationLogger, tsv_line, get_logger, set_tracker
 from utils.loss import verify_weights, define_loss
 from utils.metrics import report_classification, create_metrics_dict, iou
@@ -47,14 +50,11 @@ def create_dataloader(samples_folder: Path,
                       gpu_devices_dict: dict,
                       sample_size: int,
                       dontcare_val: int,
-                      crop_size: int,
                       num_bands: int,
                       min_annot_perc: int,
                       attr_vals: Sequence,
-                      scale: Sequence,
-                      cfg: dict,
+                      cfg: DictConfig,
                       eval_batch_size: int = None,
-                      dontcare2backgr: bool = False,
                       compute_sampler_weights: bool = False,
                       debug: bool = False):
    """
@@ -64,15 +64,11 @@
    @param gpu_devices_dict: (dict) dictionary where each key contains an available GPU with its ram info stored as value
    @param sample_size: (int) size of hdf5 samples (used to evaluate eval batch-size)
    @param dontcare_val: (int) value in label to be ignored during loss calculation
-    @param crop_size: (int) size of one side of the square crop performed on original patch during training
    @param num_bands: (int) number of bands in imagery
    @param min_annot_perc: (int) minimum proportion of ground truth containing non-background information
    @param attr_vals: (Sequence)
-    @param scale: (List) imagery data will be scaled to this min and max value (ex.: 0 to 1)
    @param cfg: (dict) Parameters found in the yaml config file.
    @param eval_batch_size: (int) Batch size for evaluation (val and test). Optional, calculated automatically if omitted
-    @param dontcare2backgr: (bool) if True, all dontcare values in label will be replaced with 0 (background value)
-        before training
    @param compute_sampler_weights: (bool)
        if True, weights will be computed from dataset patches to oversample the minority class(es) and undersample
        the majority class(es) during training.
@@ -92,30 +88,21 @@
    if not num_samples['trn'] >= batch_size and num_samples['val'] >= batch_size:
        raise ValueError(f"Number of patches is smaller than batch size")
    logging.info(f"Number of samples : {num_samples}\n")
-    dataset_constr = create_dataset.SegmentationDataset
    datasets = []
+    dummy_datamodule = SegmentationDatamodule(dontcare2backgr=True, dontcare_val=dontcare_val)

    for subset in ["trn", "val", "tst"]:
        # TODO: should user point to the paths of these csvs directly?
        dataset_file, _ = Tiler.make_dataset_file_name(experiment_name, min_annot_perc, subset, attr_vals)
        dataset_filepath = samples_folder / dataset_file
-        datasets.append(dataset_constr(dataset_filepath, subset, num_bands,
-                                       max_sample_count=num_samples[subset],
-                                       radiom_transform=aug.compose_transforms(params=cfg,
-                                                                               dataset=subset,
-                                                                               aug_type='radiometric'),
-                                       geom_transform=aug.compose_transforms(params=cfg,
-                                                                             dataset=subset,
-                                                                             aug_type='geometric',
-                                                                             dontcare=dontcare_val,
-                                                                             crop_size=crop_size),
-                                       totensor_transform=aug.compose_transforms(params=cfg,
-                                                                                 dataset=subset,
-                                                                                 scale=scale,
-                                                                                 dontcare2backgr=dontcare2backgr,
-                                                                                 dontcare=dontcare_val,
-                                                                                 aug_type='totensor'),
-                                       debug=debug))
+        datasets.append(SegmentationDataset(dataset_filepath, subset, num_bands,
+                                            max_sample_count=num_samples[subset],
+                                            transforms=Compose([dummy_datamodule.preprocess]),
+                                            augmentations=aug.compose_transforms(
+                                                params=cfg,
+                                                dataset=subset,
+                                            ),
+                                            debug=debug))
    trn_dataset, val_dataset, tst_dataset = datasets

    # Number of workers
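The transforms argument works because torchvision's Compose simply calls each callable in order, so any function that takes and returns a sample dict can slot in. A rough sketch of the pattern, with a hypothetical stand-in for SegmentationDatamodule.preprocess (the real method lives in utils/create_dataset.py and likely does more, e.g. dontcare remapping):

    import torch
    from torchvision.transforms import Compose

    def preprocess(sample):
        # Hypothetical stand-in: scale a uint8 image to float in [0, 1].
        sample['image'] = sample['image'].float() / 255
        return sample

    transform = Compose([preprocess])
    sample = {'image': torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8),
              'mask': torch.zeros(64, 64, dtype=torch.long)}
    sample = transform(sample)  # the dict is piped through each callable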
@@ -134,12 +121,30 @@
    elif not eval_batch_size:
        eval_batch_size = batch_size

-    trn_dataloader = DataLoader(trn_dataset, batch_size=batch_size, num_workers=num_workers, sampler=sampler,
-                                drop_last=True)
-    val_dataloader = DataLoader(val_dataset, batch_size=eval_batch_size, num_workers=num_workers, shuffle=False,
-                                drop_last=True)
-    tst_dataloader = DataLoader(tst_dataset, batch_size=eval_batch_size, num_workers=num_workers, shuffle=False,
-                                drop_last=True) if num_samples['tst'] > 0 else None
+    trn_dataloader = DataLoader(
+        trn_dataset,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        sampler=sampler,
+        drop_last=True,
+        collate_fn=stack_samples
+    )
+    val_dataloader = DataLoader(
+        val_dataset,
+        batch_size=eval_batch_size,
+        num_workers=num_workers,
+        shuffle=False,
+        drop_last=True,
+        collate_fn=stack_samples
+    )
+    tst_dataloader = DataLoader(
+        tst_dataset,
+        batch_size=eval_batch_size,
+        num_workers=num_workers,
+        shuffle=False,
+        drop_last=True,
+        collate_fn=stack_samples
+    ) if num_samples['tst'] > 0 else None

    if len(trn_dataloader) == 0 or len(val_dataloader) == 0:
        raise ValueError(f"\nTrain and validation dataloader should contain at least one data item."
@@ -178,7 +183,7 @@ def get_num_samples(
        params,
        min_annot_perc,
        attr_vals,
-        experiment_name:str,
+        experiment_name: str,
        compute_sampler_weights=False
):
    """
@@ -267,8 +272,8 @@ def vis_from_dataloader(vis_params,
    for batch_index, data in enumerate(_tqdm):
        if vis_batch_range is not None and batch_index in range(min_vis_batch, max_vis_batch, increment):
            with torch.no_grad():
-                inputs = data['sat_img'].to(device)
-                labels = data['map_img'].to(device)
+                inputs = data['image'].to(device)
+                labels = data['mask'].to(device)

                outputs = model(inputs)
                if isinstance(outputs, OrderedDict):
@@ -285,19 +290,19 @@


def training(train_loader,
-             model,
-             criterion,
-             optimizer,
-             scheduler,
-             num_classes,
-             batch_size,
-             ep_idx,
-             progress_log,
-             device,
-             scale,
-             vis_params,
-             debug=False
-             ):
+             model,
+             criterion,
+             optimizer,
+             scheduler,
+             num_classes,
+             batch_size,
+             ep_idx,
+             progress_log,
+             device,
+             scale,
+             vis_params,
+             debug=False
+             ):
    """
    Train the model and return the metrics of the training epoch

@@ -322,8 +327,8 @@
    for batch_index, data in enumerate(tqdm(train_loader, desc=f'Iterating train batches with {device.type}')):
        progress_log.open('a', buffering=1).write(tsv_line(ep_idx, 'trn', batch_index, len(train_loader), time.time()))

-        inputs = data['sat_img'].to(device)
-        labels = data['map_img'].to(device)
+        inputs = data['image'].to(device)
+        labels = data['mask'].to(device)

        # forward
        optimizer.zero_grad()
@@ -360,8 +365,8 @@
                gpu_perc=f"{res['gpu']} %",
                gpu_RAM=f"{mem['used'] / (1024 ** 2):.0f}/{mem['total'] / (1024 ** 2):.0f} MiB",
                lr=optimizer.param_groups[0]['lr'],
-                img=data['sat_img'].numpy().shape,
-                smpl=data['map_img'].numpy().shape,
+                img=data['image'].numpy().shape,
+                smpl=data['mask'].numpy().shape,
                bs=batch_size,
                out_vals=np.unique(outputs[0].argmax(dim=0).detach().cpu().numpy()),
                gt_vals=np.unique(labels[0].detach().cpu().numpy())))
Expand Down Expand Up @@ -414,8 +419,8 @@ def evaluation(eval_loader,
progress_log.open('a', buffering=1).write(tsv_line(ep_idx, dataset, batch_index, len(eval_loader), time.time()))

with torch.no_grad():
inputs = data['sat_img'].to(device)
labels = data['map_img'].to(device)
inputs = data['image'].to(device)
labels = data['mask'].to(device)

labels_flatten = flatten_labels(labels)

@@ -468,7 +473,7 @@
            res, mem = gpu_stats(device=device.index)
            logging.debug(OrderedDict(
                device=device, gpu_perc=f"{res['gpu']} %",
-                gpu_RAM=f"{mem['used']/(1024**2):.0f}/{mem['total']/(1024**2):.0f} MiB"
+                gpu_RAM=f"{mem['used'] / (1024 ** 2):.0f}/{mem['total'] / (1024 ** 2):.0f} MiB"
            ))

    if eval_metrics['loss'].avg:
@@ -524,7 +529,7 @@ def train(cfg: DictConfig) -> None:

    # OPTIONAL PARAMETERS
    debug = get_key_def('debug', cfg)
-    task = get_key_def('task', cfg['general'], default='segmentation')
+    task = get_key_def('task', cfg['general'], default='segmentation')
    dontcare_val = get_key_def("ignore_index", cfg['dataset'], default=-1)
    scale = get_key_def('scale_data', cfg['augmentation'], default=[0, 1])
    batch_metrics = get_key_def('batch_metrics', cfg['training'], default=None)
@@ -539,7 +544,8 @@
    elif not cfg.loss.is_binary and num_classes == 1:
        raise ValueError(f"Parameter mismatch: a multiclass loss was chosen for a 1-class (binary) task")
    del cfg.loss.is_binary  # prevent exception at instantiation
-    optimizer = get_key_def('optimizer_name', cfg['optimizer'], default='adam', expected_type=str)  # TODO change something to call the function
+    optimizer = get_key_def('optimizer_name', cfg['optimizer'], default='adam',
+                            expected_type=str)  # TODO change something to call the function
    pretrained = get_key_def('pretrained', cfg['model'], default=True, expected_type=(bool, str))
    train_state_dict_path = get_key_def('state_dict_path', cfg['training'], default=None, expected_type=str)
    state_dict_strict = get_key_def('state_dict_strict_load', cfg['training'], default=True, expected_type=bool)
@@ -575,7 +581,7 @@

    data_path = get_key_def('raw_data_dir', cfg['dataset'], to_path=True, validate_path_exists=True)
    tiling_root_dir = get_key_def('tiling_data_dir', cfg['tiling'], default=data_path, to_path=True,
-                                  validate_path_exists=True)
+                                  validate_path_exists=True)
    logging.info("\nThe tiling directory used '{}'".format(tiling_root_dir))

    tiling_dir = tiling_root_dir / experiment_name
@@ -652,13 +658,10 @@
                                                                       gpu_devices_dict=gpu_devices_dict,
                                                                       sample_size=samples_size,
                                                                       dontcare_val=dontcare_val,
-                                                                       crop_size=crop_size,
                                                                       num_bands=num_bands,
                                                                       min_annot_perc=min_annot_perc,
                                                                       attr_vals=attr_vals,
-                                                                       scale=scale,
                                                                       cfg=cfg,
-                                                                       dontcare2backgr=dontcare2backgr,
                                                                       compute_sampler_weights=compute_sampler_weights,
                                                                       debug=debug)

@@ -770,7 +773,7 @@
                    vis_from_dataloader(vis_params=vis_params,
                                        eval_loader=val_dataloader if vis_at_ckpt_dataset == 'val' else tst_dataloader,
                                        model=model,
-                                        ep_num=epoch+1,
+                                        ep_num=epoch + 1,
                                        output_path=output_path,
                                        dataset=vis_at_ckpt_dataset,
                                        scale=scale,
@@ -782,7 +785,7 @@
    # logging.info(f'\nCurrent elapsed time {cur_elapsed // 60:.0f}m {cur_elapsed % 60:.0f}s')

    # load checkpoint model and evaluate it on test dataset.
-    if int(cfg['general']['max_epochs']) > 0: # if num_epochs is set to 0, model is loaded to evaluate on test set
+    if int(cfg['general']['max_epochs']) > 0:  # if num_epochs is set to 0, model is loaded to evaluate on test set
        checkpoint = read_checkpoint(filename)
        checkpoint = adapt_checkpoint_to_dp_model(checkpoint, model)
        model.load_state_dict(state_dict=checkpoint['model_state_dict'])