diff --git a/lumigator/python/mzai/jobs/inference/inference.py b/lumigator/python/mzai/jobs/inference/inference.py index e6aa8a5ad..fef58a5e8 100644 --- a/lumigator/python/mzai/jobs/inference/inference.py +++ b/lumigator/python/mzai/jobs/inference/inference.py @@ -109,10 +109,19 @@ def run_inference(config: InferenceJobConfig) -> Path: else: raise NotImplementedError("Inference pipeline not supported.") - # run inference + # We keep any columns that were already there (not just the original input + # samples, but also past predictions under another column name) + for k in dataset.column_names: + logger.info(f"Keeping original dataset's {k}") + output[k] = dataset[k] + + # We are trusting the user: if the dataset already had a column with the output_field + # they selected, we overwrite it with the values from our inference. + + if config.job.output_field in dataset.column_names: + logger.info(f"Overwriting {config.job.output_field}") + output[config.job.output_field] = predict(dataset_iterable, model_client) - output["examples"] = dataset["examples"] - output["ground_truth"] = dataset["ground_truth"] output["model"] = output_model_name output_path = save_outputs(config, output)