diff --git a/predict.py b/predict.py index b457f696524..a1edd6df95e 100644 --- a/predict.py +++ b/predict.py @@ -61,7 +61,13 @@ def write_mask(mask: torch.Tensor, output_dir: str, input_filename: str) -> None: - """Write mask to specified output directory.""" + """Write mask to specified output directory with same filename as input raster. + + Args: + mask (torch.Tensor): mask tensor + output_dir (str): output directory + input_filename (str): path to input raster + """ output_path = os.path.join(output_dir, os.path.basename(input_filename)) with rio.open(input_filename) as src: profile = src.profile @@ -73,7 +79,19 @@ def write_mask(mask: torch.Tensor, output_dir: str, input_filename: str) -> None def main(config_dir: str, predict_on: str, output_dir: str, device: str) -> None: - """Main inference loop.""" + """Main inference loop. + + Args: + config_dir (str): Path to config-dir to load config and ckpt + predict_on (str): Directory/Dataset to run inference on + output_dir (str): Path to output_directory to save predicted masks + device (str): Choice of device. Must be in [cuda, cpu] + + Raises: + ValueError: Raised if task name is not in TASK_TO_MODULES_MAPPING + FileExistsError: Raised if specified output directory contains + files and overwrite=False. + """ os.makedirs(output_dir, exist_ok=True) # Load checkpoint and config