Skip to content

Commit

Permalink
fix: crash if requested output field for inference doesn't exist in d…
Browse files Browse the repository at this point in the history
…ataset
  • Loading branch information
ividal committed Jan 16, 2025
1 parent 898b997 commit b9f096a
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions lumigator/python/mzai/jobs/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,19 @@ def run_inference(config: InferenceJobConfig) -> Path:
else:
raise NotImplementedError("Inference pipeline not supported.")

# run inference
# We keep any columns that were already there (not just the original input
# samples, but also past predictions under another column name)
for k in dataset.column_names:
logger.info(f"Keeping original dataset's {k}")
output[k] = dataset[k]

# We are trusting the user: if the dataset already had a column with the output_field
# they selected, we overwrite it with the values from our inference.

if config.job.output_field in dataset.column_names:
logger.info(f"Overwriting {config.job.output_field}")

output[config.job.output_field] = predict(dataset_iterable, model_client)
output["examples"] = dataset["examples"]
output["ground_truth"] = dataset["ground_truth"]
output["model"] = output_model_name

output_path = save_outputs(config, output)
Expand Down

0 comments on commit b9f096a

Please sign in to comment.