-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Moved evaluator to jobs dir * Fixed breaking evaluator integration test * Pinned the minor Python version for the backend * Added first draft of inference job * Addressed review comments * Removed timing function * Minor fixes * Addressed review comments --------- Signed-off-by: Davide Eynard <[email protected]>
- Loading branch information
Showing
89 changed files
with
327 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
3.11 | ||
3.11.9 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Inference Documentation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
"""python job to run batch inference""" | ||
|
||
import argparse | ||
import json | ||
from collections.abc import Iterable | ||
from pathlib import Path | ||
|
||
import s3fs | ||
from box import Box | ||
from datasets import load_from_disk | ||
from loguru import logger | ||
from model_clients import ( | ||
BaseModelClient, | ||
MistralModelClient, | ||
OpenAIModelClient, | ||
) | ||
from tqdm import tqdm | ||
|
||
|
||
def predict(dataset_iterable: Iterable, model_client: BaseModelClient) -> list:
    """Run the model client over every sample and collect the predictions.

    Args:
        dataset_iterable: iterable of input texts (possibly wrapped in tqdm
            for progress reporting — iteration order is preserved).
        model_client: client exposing a ``predict(sample) -> prediction`` method.

    Returns:
        A list with one prediction per input sample, in input order.
    """
    # Comprehension instead of a manual append loop: same calls, same order.
    return [model_client.predict(sample_txt) for sample_txt in dataset_iterable]
|
||
|
||
def save_to_disk(local_path: Path, data_dict: dict):
    """Serialize *data_dict* as JSON into *local_path*.

    Parent directories are created on demand so the caller does not have to
    prepare the results directory ahead of time.
    """
    logger.info(f"Storing into {local_path}...")
    local_path.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(data_dict)
    local_path.write_text(serialized)
|
||
|
||
def save_to_s3(config: Box, local_path: Path, storage_path: str):
    """Upload the local results file to S3.

    When *storage_path* ends with "/" it is treated as a directory prefix and
    the final object key becomes ``<prefix>/<config.name>/inference_results.json``;
    otherwise it is used verbatim as the destination object path.
    """
    s3 = s3fs.S3FileSystem()
    if storage_path.endswith("/"):
        # storage_path[5:] drops the "s3://" scheme before key manipulation
        object_key = Path(storage_path[5:]) / config.name / "inference_results.json"
        storage_path = "s3://" + str(object_key)
    logger.info(f"Storing into {storage_path}...")
    s3.put_file(local_path, storage_path)
|
||
|
||
def save_outputs(config: Box, inference_results: dict) -> Path:
    """Persist inference results locally and, optionally, to S3.

    The results are always written to a local file first:
    - if no S3 storage path is configured, the local file is kept and its
      path is returned;
    - if ``config.evaluation.storage_path`` starts with "s3://" and the
      upload succeeds, the local copy (and its then-empty parent directory)
      is removed and the S3 path is returned.

    Returns:
        The storage path (S3 URI string or local ``Path``), or ``None`` when
        saving failed — the error is logged, not re-raised (best-effort).
    """
    storage_path = config.evaluation.storage_path

    # Local temp file is generated unconditionally:
    # - if storage_path is not provided, it is kept in the default results dir
    # - if storage_path is provided AND the S3 upload succeeds, it is deleted
    local_path = (
        Path.home() / ".lumigator" / "results" / config.name / "inference_results.json"
    )

    try:
        save_to_disk(local_path, inference_results)

        # copy to s3 and return the remote path
        if storage_path is not None and storage_path.startswith("s3://"):
            save_to_s3(config, local_path, storage_path)
            # clean up the local copy now that the upload succeeded
            local_path.unlink()
            local_path.parent.rmdir()
            return storage_path
        return local_path

    except Exception as e:
        # Deliberate best-effort: log the failure and make the implicit
        # None return explicit so callers see it was intentional.
        logger.error(e)
        return None
|
||
|
||
def run_inference(config: Box) -> Path:
    """Run batch inference over a dataset and store the predictions.

    Loads the dataset from ``config.dataset.path``, optionally truncates it to
    ``config.evaluation.max_samples``, runs every "examples" entry through the
    configured model inference service, and saves predictions alongside the
    inputs and ground truth via ``save_outputs``.

    Raises:
        ValueError: if ``config.model.inference`` is missing — previously this
            fell through and crashed later with an unbound ``model_client``.

    Returns:
        The storage path returned by ``save_outputs``.
    """
    # initialize output dictionary
    output = {}

    # Load dataset given its URI
    dataset = load_from_disk(config.dataset.path)

    # Limit dataset length if max_samples is specified
    max_samples = config.evaluation.max_samples
    if max_samples is not None and max_samples > 0:
        if max_samples > len(dataset):
            logger.info(f"max_samples ({max_samples}) resized to dataset size ({len(dataset)})")
            max_samples = len(dataset)
        dataset = dataset.select(range(max_samples))

    # Enable / disable tqdm
    input_samples = dataset["examples"]
    dataset_iterable = tqdm(input_samples) if config.evaluation.enable_tqdm else input_samples

    # Choose which model client to use
    if config.model.inference is None:
        # Fail fast with a clear message instead of a NameError further down.
        raise ValueError("No model inference service configured (config.model.inference)")

    # a model *inference service* is passed
    base_url = config.model.inference.base_url
    output_model_name = config.model.inference.engine
    if "mistral" in base_url:
        # run the mistral client
        logger.info(f"Using Mistral client. Endpoint: {base_url}")
        model_client = MistralModelClient(base_url, config.model)
    else:
        # run the openai client
        logger.info(f"Using OAI client. Endpoint: {base_url}")
        model_client = OpenAIModelClient(base_url, config.model)

    # run inference
    output["predictions"] = predict(dataset_iterable, model_client)
    output["examples"] = dataset["examples"]
    output["ground_truth"] = dataset["ground_truth"]
    output["model"] = output_model_name

    output_path = save_outputs(config, output)
    return output_path
|
||
|
||
if __name__ == "__main__":
    # CLI entry point: the whole job configuration arrives as one JSON string.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--config", type=str, help="Configuration in JSON format")
    cli_args = arg_parser.parse_args()

    if cli_args.config:
        raw_config = json.loads(cli_args.config)
        # default_box=True with default_box_attr=None makes missing config
        # attributes resolve to None instead of raising.
        results_location = run_inference(Box(raw_config, default_box=True, default_box_attr=None))
        logger.info(f"Inference results stored at {results_location}")
    else:
        arg_parser.print_help()  # Print the usage message and exit
        err_str = "No input configuration provided. Please pass one using the --config flag"
        logger.error(err_str)
Oops, something went wrong.