From 47b9804f5218f014a33cb7af353715dc74673738 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Irina=20Vidal=20Migall=C3=B3n?=
Date: Fri, 10 Jan 2025 17:51:58 +0100
Subject: [PATCH] fix: crash if requested output field for inference doesn't exist in dataset

---
 lumigator/python/mzai/jobs/inference/inference.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/lumigator/python/mzai/jobs/inference/inference.py b/lumigator/python/mzai/jobs/inference/inference.py
index e6aa8a5ad..fef58a5e8 100644
--- a/lumigator/python/mzai/jobs/inference/inference.py
+++ b/lumigator/python/mzai/jobs/inference/inference.py
@@ -109,10 +109,19 @@ def run_inference(config: InferenceJobConfig) -> Path:
     else:
         raise NotImplementedError("Inference pipeline not supported.")
 
-    # run inference
+    # We keep any columns that were already there (not just the original input
+    # samples, but also past predictions under another column name)
+    for k in dataset.column_names:
+        logger.info(f"Keeping original dataset's {k}")
+        output[k] = dataset[k]
+
+    # We are trusting the user: if the dataset already had a column with the output_field
+    # they selected, we overwrite it with the values from our inference.
+
+    if config.job.output_field in dataset.column_names:
+        logger.info(f"Overwriting {config.job.output_field}")
+
     output[config.job.output_field] = predict(dataset_iterable, model_client)
-    output["examples"] = dataset["examples"]
-    output["ground_truth"] = dataset["ground_truth"]
     output["model"] = output_model_name
 
     output_path = save_outputs(config, output)