diff --git a/src/deepforest/main.py b/src/deepforest/main.py index 8e55ed92..bca9fe81 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -881,12 +881,11 @@ def predict_batch(self, images, preprocess_fn=None): #using Pytorch Ligthning's predict_step with torch.no_grad(): - predictions = [] - for idx, image in enumerate(images): - predictions = self.predict_step(image.unsqueeze(0), idx) - predictions.extend(predictions) + predictions = self.predict_step(images, 0) + #convert predictions to dataframes - results = [pd.DataFrame(pred) for pred in predictions if pred is not None] + results = [utilities.read_file(pred) for pred in predictions if pred is not None] + return results def configure_optimizers(self): diff --git a/tests/test_main.py b/tests/test_main.py index c234aa83..7addb59e 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -701,107 +701,52 @@ def test_predict_tile_with_crop_model_empty(): # Assert the result assert result is None -# @pytest.mark.parametrize("batch_size", [1, 4, 8]) -# def test_batch_prediction(m, batch_size, raster_path): -# -# # Prepare input data -# tile = np.array(Image.open(raster_path)) -# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100) -# dl = DataLoader(ds, batch_size=batch_size) - -# # Perform prediction -# predictions = [] -# for batch in dl: -# prediction = m.predict_batch(batch) -# predictions.append(prediction) - -# # Check results -# assert len(predictions) == len(dl) -# for batch_pred in predictions: -# assert isinstance(batch_pred, pd.DataFrame) -# assert set(batch_pred.columns) == { -# "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry" -# } - -# @pytest.mark.parametrize("batch_size", [1, 4]) -# def test_batch_training(m, batch_size, tmpdir): -# -# # Generate synthetic training data -# csv_file = get_data("example.csv") -# root_dir = os.path.dirname(csv_file) -# train_ds = m.load_dataset(csv_file, root_dir=root_dir) -# train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True) - -# # Configure the model and trainer -# m.config["batch_size"] = batch_size -# m.create_trainer() -# trainer = m.trainer - -# # Train the model -# trainer.fit(m, train_dl) - -# # Assertions -# assert trainer.current_epoch == 1 -# assert trainer.batch_size == batch_size - -# @pytest.mark.parametrize("batch_size", [2, 4]) -# def test_batch_data_augmentation(m, batch_size, raster_path): -# -# tile = np.array(Image.open(raster_path)) -# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100, augment=True) -# dl = DataLoader(ds, batch_size=batch_size) +def test_batch_prediction(m, raster_path): + # Prepare input data + tile = np.array(Image.open(raster_path)) + ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=300) + dl = DataLoader(ds, batch_size=3) -# predictions = [] -# for batch in dl: -# prediction = m.predict_batch(batch) -# predictions.append(prediction) + # Perform prediction + predictions = [] + for batch in dl: + prediction = m.predict_batch(batch) + predictions.append(prediction) -# assert len(predictions) == len(dl) -# for batch_pred in predictions: -# assert isinstance(batch_pred, pd.DataFrame) -# assert set(batch_pred.columns) == { -# "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry" -# } - -# def test_batch_inference_consistency(m, raster_path): -# -# tile = np.array(Image.open(raster_path)) -# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100) -# dl = DataLoader(ds, batch_size=4) + # Check results + assert len(predictions) == len(dl) + for batch_pred in predictions: + for image_pred in batch_pred: + assert isinstance(image_pred, pd.DataFrame) + assert "label" in image_pred.columns + assert "score" in image_pred.columns + assert "geometry" in image_pred.columns + +def test_batch_inference_consistency(m, raster_path): + tile = np.array(Image.open(raster_path)) + ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=300) + dl = DataLoader(ds, batch_size=4) -# batch_predictions = [] -# for batch in dl: -# prediction = m.predict_batch(batch) -# batch_predictions.append(prediction) + batch_predictions = [] + for batch in dl: + prediction = m.predict_batch(batch) + batch_predictions.extend(prediction) -# single_predictions = [] -# for image in ds: -# prediction = m.predict_image(image=image) -# single_predictions.append(prediction) + single_predictions = [] + for image in ds: + image = image.permute(1,2,0).numpy() * 255 + prediction = m.predict_image(image=image) + single_predictions.append(prediction) -# batch_df = pd.concat(batch_predictions, ignore_index=True) -# single_df = pd.concat(single_predictions, ignore_index=True) + batch_df = pd.concat(batch_predictions, ignore_index=True) + single_df = pd.concat(single_predictions, ignore_index=True) -# pd.testing.assert_frame_equal(batch_df, single_df) + # Make all xmin, ymin, xmax, ymax integers + for col in ["xmin", "ymin", "xmax", "ymax"]: + batch_df[col] = batch_df[col].astype(int) + single_df[col] = single_df[col].astype(int) + pd.testing.assert_frame_equal(batch_df[["xmin", "ymin", "xmax", "ymax"]], single_df[["xmin", "ymin", "xmax", "ymax"]], check_dtype=False) -# def test_large_batch_handling(m, raster_path): -# -# tile = np.array(Image.open(raster_path)) -# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100) -# dl = DataLoader(ds, batch_size=16) - -# predictions = [] -# for batch in dl: -# prediction = m.predict_batch(batch) -# predictions.append(prediction) - -# assert len(predictions) > 0 -# for batch_pred in predictions: -# assert isinstance(batch_pred, pd.DataFrame) -# assert set(batch_pred.columns) == { -# "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry" -# } -# assert not batch_pred.empty def test_epoch_evaluation_end(m): preds = [{