Skip to content

Commit

Permalink
Improve predict batch (#876)
Browse files Browse the repository at this point in the history
* run per batch, not per image per batch
  • Loading branch information
bw4sz authored Jan 9, 2025
1 parent ef4435c commit 2bf8a9a
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 99 deletions.
9 changes: 4 additions & 5 deletions src/deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,12 +881,11 @@ def predict_batch(self, images, preprocess_fn=None):

#using Pytorch Ligthning's predict_step
with torch.no_grad():
predictions = []
for idx, image in enumerate(images):
predictions = self.predict_step(image.unsqueeze(0), idx)
predictions.extend(predictions)
predictions = self.predict_step(images, 0)

#convert predictions to dataframes
results = [pd.DataFrame(pred) for pred in predictions if pred is not None]
results = [utilities.read_file(pred) for pred in predictions if pred is not None]

return results

def configure_optimizers(self):
Expand Down
133 changes: 39 additions & 94 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,107 +701,52 @@ def test_predict_tile_with_crop_model_empty():
# Assert the result
assert result is None

# @pytest.mark.parametrize("batch_size", [1, 4, 8])
# def test_batch_prediction(m, batch_size, raster_path):
#
# # Prepare input data
# tile = np.array(Image.open(raster_path))
# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100)
# dl = DataLoader(ds, batch_size=batch_size)

# # Perform prediction
# predictions = []
# for batch in dl:
# prediction = m.predict_batch(batch)
# predictions.append(prediction)

# # Check results
# assert len(predictions) == len(dl)
# for batch_pred in predictions:
# assert isinstance(batch_pred, pd.DataFrame)
# assert set(batch_pred.columns) == {
# "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"
# }

# @pytest.mark.parametrize("batch_size", [1, 4])
# def test_batch_training(m, batch_size, tmpdir):
#
# # Generate synthetic training data
# csv_file = get_data("example.csv")
# root_dir = os.path.dirname(csv_file)
# train_ds = m.load_dataset(csv_file, root_dir=root_dir)
# train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

# # Configure the model and trainer
# m.config["batch_size"] = batch_size
# m.create_trainer()
# trainer = m.trainer

# # Train the model
# trainer.fit(m, train_dl)

# # Assertions
# assert trainer.current_epoch == 1
# assert trainer.batch_size == batch_size

# @pytest.mark.parametrize("batch_size", [2, 4])
# def test_batch_data_augmentation(m, batch_size, raster_path):
#
# tile = np.array(Image.open(raster_path))
# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100, augment=True)
# dl = DataLoader(ds, batch_size=batch_size)
def test_batch_prediction(m, raster_path):
# Prepare input data
tile = np.array(Image.open(raster_path))
ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=300)
dl = DataLoader(ds, batch_size=3)

# predictions = []
# for batch in dl:
# prediction = m.predict_batch(batch)
# predictions.append(prediction)
# Perform prediction
predictions = []
for batch in dl:
prediction = m.predict_batch(batch)
predictions.append(prediction)

# assert len(predictions) == len(dl)
# for batch_pred in predictions:
# assert isinstance(batch_pred, pd.DataFrame)
# assert set(batch_pred.columns) == {
# "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"
# }

# def test_batch_inference_consistency(m, raster_path):
#
# tile = np.array(Image.open(raster_path))
# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100)
# dl = DataLoader(ds, batch_size=4)
# Check results
assert len(predictions) == len(dl)
for batch_pred in predictions:
for image_pred in batch_pred:
assert isinstance(image_pred, pd.DataFrame)
assert "label" in image_pred.columns
assert "score" in image_pred.columns
assert "geometry" in image_pred.columns

def test_batch_inference_consistency(m, raster_path):
tile = np.array(Image.open(raster_path))
ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=300)
dl = DataLoader(ds, batch_size=4)

# batch_predictions = []
# for batch in dl:
# prediction = m.predict_batch(batch)
# batch_predictions.append(prediction)
batch_predictions = []
for batch in dl:
prediction = m.predict_batch(batch)
batch_predictions.extend(prediction)

# single_predictions = []
# for image in ds:
# prediction = m.predict_image(image=image)
# single_predictions.append(prediction)
single_predictions = []
for image in ds:
image = image.permute(1,2,0).numpy() * 255
prediction = m.predict_image(image=image)
single_predictions.append(prediction)

# batch_df = pd.concat(batch_predictions, ignore_index=True)
# single_df = pd.concat(single_predictions, ignore_index=True)
batch_df = pd.concat(batch_predictions, ignore_index=True)
single_df = pd.concat(single_predictions, ignore_index=True)

# pd.testing.assert_frame_equal(batch_df, single_df)
# Make all xmin, ymin, xmax, ymax integers
for col in ["xmin", "ymin", "xmax", "ymax"]:
batch_df[col] = batch_df[col].astype(int)
single_df[col] = single_df[col].astype(int)
pd.testing.assert_frame_equal(batch_df[["xmin", "ymin", "xmax", "ymax"]], single_df[["xmin", "ymin", "xmax", "ymax"]], check_dtype=False)

# def test_large_batch_handling(m, raster_path):
#
# tile = np.array(Image.open(raster_path))
# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100)
# dl = DataLoader(ds, batch_size=16)

# predictions = []
# for batch in dl:
# prediction = m.predict_batch(batch)
# predictions.append(prediction)

# assert len(predictions) > 0
# for batch_pred in predictions:
# assert isinstance(batch_pred, pd.DataFrame)
# assert set(batch_pred.columns) == {
# "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"
# }
# assert not batch_pred.empty

def test_epoch_evaluation_end(m):
preds = [{
Expand Down

0 comments on commit 2bf8a9a

Please sign in to comment.