
Commit: Rebase to main

javiermtorres committed Jan 29, 2025
1 parent 5cfea73 · commit 0c8ef33
Showing 2 changed files with 6 additions and 195 deletions.
lumigator/python/mzai/backend/backend/services/experiments.py: 9 changes (6 additions & 3 deletions)
@@ -3,17 +3,20 @@
 
 import loguru
 from fastapi import BackgroundTasks
-from lumigator_schemas.experiments import ExperimentCreate, ExperimentResponse
+from lumigator_schemas.experiments import (
+    ExperimentCreate,
+    ExperimentResponse,
+    ExperimentResultDownloadResponse,
+)
+from lumigator_schemas.extras import ListingResponse
 from lumigator_schemas.jobs import (
     JobCreate,
     JobEvalLiteConfig,
     JobInferenceConfig,
     JobResponse,
     JobStatus,
     JobType,
 )
-from lumigator_schemas.extras import ListingResponse
-from lumigator_schemas.jobs import JobEvalLiteCreate, JobInferenceCreate, JobResponse, JobStatus
 from s3fs import S3FileSystem
 
 from backend.records.jobs import JobRecord
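
The notable addition in this file is ExperimentResultDownloadResponse, imported alongside the existing experiment schemas. Its definition is not shown on this page; the following is a minimal sketch of what such a schema plausibly looks like, assuming it mirrors JobResultDownloadResponse, whose download_url field the tests below rely on. The field names here are assumptions, not the project's actual code.

# Hypothetical sketch only: the real model lives in lumigator_schemas.experiments
# and may differ. Fields are assumed by analogy with JobResultDownloadResponse,
# whose download_url attribute is used in the tests in this commit.
from uuid import UUID

from pydantic import BaseModel


class ExperimentResultDownloadResponse(BaseModel):
    """Assumed shape: a presigned URL for downloading an experiment's results."""

    id: UUID
    download_url: str
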
Second changed file: 192 changes (0 additions & 192 deletions)
@@ -11,17 +11,7 @@
 from lumigator_schemas.datasets import DatasetFormat, DatasetResponse
 from lumigator_schemas.experiments import ExperimentResponse
 from lumigator_schemas.extras import ListingResponse
-<<<<<<< HEAD
-from lumigator_schemas.jobs import (
-    Job,
-    JobLogsResponse,
-    JobResponse,
-    JobResultDownloadResponse,
-    JobStatus,
-)
-=======
 from lumigator_schemas.jobs import JobLogsResponse, JobResponse, JobResultDownloadResponse, JobType
->>>>>>> 697d8f81 (Replace templates with pydantic models)
 
 from backend.main import app
 from backend.tests.conftest import (
@@ -30,187 +20,6 @@
     wait_for_job,
 )
 
-<<<<<<< HEAD
-# @app.on_event("startup")
-# def test_health_ok(local_client: TestClient):
-#     response = local_client.get("/health/")
-#     assert response.status_code == 200
-#
-#
-# def test_upload_data_launch_job(
-#     local_client: TestClient,
-#     dialog_dataset,
-#     simple_eval_template,
-#     simple_infer_template,
-#     dependency_overrides_services,
-# ):
-#     response = local_client.get("/health")
-#     assert response.status_code == 200
-#
-#     logger.info(f"Running test...")
-#
-#     create_response = local_client.post(
-#         "/datasets/",
-#         data={},
-#         files={"dataset": dialog_dataset, "format": (None, DatasetFormat.JOB.value)},
-#     )
-#
-#     assert create_response.status_code == 201
-#
-#     created_dataset = DatasetResponse.model_validate(create_response.json())
-#
-#     get_ds_response = local_client.get("/datasets/")
-#     assert get_ds_response.status_code == 200
-#     get_ds = ListingResponse[DatasetResponse].model_validate(get_ds_response.json())
-#
-#     headers = {
-#         "accept": "application/json",
-#         "Content-Type": "application/json",
-#     }
-#     infer_payload = {
-#         "name": "test_run_hugging_face",
-#         "description": "Test run for Huggingface model",
-#         "model": TEST_CAUSAL_MODEL,
-#         "dataset": str(created_dataset.id),
-#         "max_samples": 10,
-#         "config_template": simple_infer_template,
-#         "output_field": "predictions",
-#         "store_to_dataset": True,
-#     }
-#     create_inference_job_response = local_client.post(
-#         "/jobs/inference/", headers=headers, json=infer_payload
-#     )
-#     assert create_inference_job_response.status_code == 201
-#
-#     create_inference_job_response_model = JobResponse.model_validate(
-#         create_inference_job_response.json()
-#     )
-#
-#     wait_for_job(local_client, create_inference_job_response_model.id)
-#
-#     logs_infer_job_response = local_client.get(
-#         f"/jobs/{create_inference_job_response_model.id}/logs"
-#     )
-#     logs_infer_job_response_model = JobLogsResponse.model_validate(logs_infer_job_response.json())
-#     logger.info(f"-- infer logs -- {create_inference_job_response_model.id}")
-#     logger.info(f"#{logs_infer_job_response_model.logs}#")
-#
-#     # retrieve the DS for the infer job...
-#     output_infer_job_response = local_client.get(
-#         f"/jobs/{create_inference_job_response_model.id}/dataset"
-#     )
-#     output_infer_job_response_model = DatasetResponse.model_validate(
-#         output_infer_job_response.json()
-#     )
-#     assert output_infer_job_response_model is not None
-#
-#     headers = {
-#         "accept": "application/json",
-#         "Content-Type": "application/json",
-#     }
-#     eval_payload = {
-#         "name": "test_run_hugging_face",
-#         "description": "Test run for Huggingface model",
-#         "model": TEST_CAUSAL_MODEL,
-#         "dataset": str(output_infer_job_response_model.id),
-#         "config_template": simple_eval_template,
-#         "max_samples": 10,
-#     }
-#
-#     create_evaluation_job_response = local_client.post(
-#         "/jobs/eval_lite/", headers=headers, json=eval_payload
-#     )
-#     assert create_evaluation_job_response.status_code == 201
-#
-#     create_evaluation_job_response_model = JobResponse.model_validate(
-#         create_evaluation_job_response.json()
-#     )
-#
-#     assert wait_for_job(local_client, create_evaluation_job_response_model.id)
-#
-#     logs_evaluation_job_response = local_client.get(
-#         f"/jobs/{create_evaluation_job_response_model.id}/logs"
-#     )
-#     logs_evaluation_job_response_model = JobLogsResponse.model_validate(
-#         logs_evaluation_job_response.json()
-#     )
-#     logger.info(f"-- eval logs -- {create_evaluation_job_response_model.id}")
-#     logger.info(f"#{logs_evaluation_job_response_model.logs}#")
-#
-#     get_ds_after_response = local_client.get("/datasets/")
-#     assert get_ds_after_response.status_code == 200
-#     get_ds_after = ListingResponse[DatasetResponse].model_validate(get_ds_after_response.json())
-#     assert get_ds_after.total == get_ds.total + 1
-#
-#
-# @pytest.mark.parametrize("unnanotated_dataset", ["dialog_empty_gt_dataset", "dialog_no_gt_dataset"])
-# def test_upload_data_no_gt_launch_annotation(
-#     request: pytest.FixtureRequest,
-#     local_client: TestClient,
-#     unnanotated_dataset,
-#     simple_eval_template,
-#     simple_infer_template,
-#     dependency_overrides_services,
-# ):
-#     dataset = request.getfixturevalue(unnanotated_dataset)
-#     create_response = local_client.post(
-#         "/datasets/",
-#         data={},
-#         files={"dataset": dataset, "format": (None, DatasetFormat.JOB.value)},
-#     )
-#
-#     assert create_response.status_code == 201
-#
-#     created_dataset = DatasetResponse.model_validate(create_response.json())
-#
-#     headers = {
-#         "accept": "application/json",
-#         "Content-Type": "application/json",
-#     }
-#
-#     annotation_payload = {
-#         "name": "test_annotate",
-#         "description": "Test run for Huggingface model",
-#         "dataset": str(created_dataset.id),
-#         "max_samples": 2,
-#         "task": "summarization",
-#     }
-#     create_annotation_job_response = local_client.post(
-#         "/jobs/annotate/", headers=headers, json=annotation_payload
-#     )
-#     assert create_annotation_job_response.status_code == 201
-#
-#     create_annotation_job_response_model = JobResponse.model_validate(
-#         create_annotation_job_response.json()
-#     )
-#
-#     assert wait_for_job(local_client, create_annotation_job_response_model.id)
-#
-#     logs_annotation_job_response = local_client.get(
-#         f"/jobs/{create_annotation_job_response_model.id}/logs"
-#     )
-#     logger.info(logs_annotation_job_response)
-#     logs_annotation_job_response_model = JobLogsResponse.model_validate(
-#         logs_annotation_job_response.json()
-#     )
-#     logger.info(f"-- infer logs -- {create_annotation_job_response_model.id}")
-#     logger.info(f"#{logs_annotation_job_response_model.logs}#")
-#
-#     logs_annotation_job_results = local_client.get(
-#         f"/jobs/{create_annotation_job_response_model.id}/result/download"
-#     )
-#     logs_annotation_job_results_model = JobResultDownloadResponse.model_validate(
-#         logs_annotation_job_results.json()
-#     )
-#     logger.info(f"Download url: {logs_annotation_job_results_model.download_url}")
-#     logs_annotation_job_results_url = requests.get(logs_annotation_job_results_model.download_url)
-#     logs_annotation_job_output = InferenceJobOutput.model_validate(
-#         logs_annotation_job_results_url.json()
-#     )
-#     assert logs_annotation_job_output.predictions is None
-#     assert logs_annotation_job_output.ground_truth is not None
-#     logger.info(f"Created results: {logs_annotation_job_output}")
-=======

 @app.on_event("startup")
 def test_health_ok(local_client: TestClient):
@@ -393,7 +202,6 @@ def test_upload_data_no_gt_launch_annotation(
     assert logs_annotation_job_output.predictions is None
     assert logs_annotation_job_output.ground_truth is not None
     logger.info(f"Created results: {logs_annotation_job_output}")
->>>>>>> 697d8f81 (Replace templates with pydantic models)
 
 
 def test_full_experiment_launch(
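
Both conflicts above come from the rebased-in commit 697d8f81, "Replace templates with pydantic models": the deleted test variants still passed a config_template string to the job endpoints, while the surviving code (and the service-side imports of JobInferenceConfig and JobEvalLiteConfig in experiments.py) builds typed configuration objects instead. A minimal sketch of that pattern follows; every field name below is an assumption, since the real JobInferenceConfig lives in lumigator_schemas.jobs and its definition is not shown on this page.

# Illustrative stand-in for the template-to-pydantic migration; the field
# names are assumptions, not the project's actual schema.
from pydantic import BaseModel


class JobInferenceConfig(BaseModel):
    model: str
    dataset: str
    max_samples: int = -1


# Before: a string template was rendered and only validated downstream, if at all.
template = '{{"model": "{model}", "dataset": "{dataset}", "max_samples": {max_samples}}}'
rendered = template.format(model="some-hf-model", dataset="dataset-uuid", max_samples=10)

# After: the payload is validated at construction time and serialized at the boundary.
config = JobInferenceConfig(model="some-hf-model", dataset="dataset-uuid", max_samples=10)
payload = config.model_dump_json()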

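The surviving tests poll for completion with wait_for_job, imported from backend.tests.conftest; its body sits outside this diff. A helper like that is typically a polling loop against the job status endpoint, as in the sketch below, where the endpoint path, terminal status names, and timing are all assumptions rather than the project's actual implementation.

# Sketch of a polling helper in the spirit of conftest.wait_for_job; the
# /jobs/{id} endpoint shape and the terminal status names are assumed.
import time

from fastapi.testclient import TestClient


def wait_for_job(client: TestClient, job_id, timeout_s: float = 300.0) -> bool:
    """Poll a job's status until it reaches a terminal state or the timeout expires."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        response = client.get(f"/jobs/{job_id}")
        response.raise_for_status()
        status = str(response.json().get("status", "")).lower()
        if status in ("succeeded", "failed"):  # assumed terminal states
            return status == "succeeded"
        time.sleep(2)  # brief pause between polls
    return False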