diff --git a/examples/llm_complex/llm_main_example.py b/examples/llm_complex/llm_main_example.py index 46253d0b..f0916f7d 100644 --- a/examples/llm_complex/llm_main_example.py +++ b/examples/llm_complex/llm_main_example.py @@ -12,8 +12,6 @@ from examples.llm_complex.llm_dataprep import dataprep_workflow from examples.llm_complex.llm_model import model_train, TransformerModel -from flowcept.commons.daos.docdb_dao.mongodb_dao import MongoDBDAO -from flowcept.commons.flowcept_logger import FlowceptLogger from flowcept.configs import MONGO_ENABLED, INSTRUMENTATION from flowcept import Flowcept @@ -237,7 +235,7 @@ def run_asserts_and_exports(campaign_id, model_search_wf_id): return n_workflows_expected, n_tasks_expected -def save_files(campaign_id, model_search_wf_id, output_dir="output_data"): +def save_files(mongo_dao, campaign_id, model_search_wf_id, output_dir="output_data"): os.makedirs(output_dir, exist_ok=True) best_task = Flowcept.db.query({"workflow_id": model_search_wf_id, "activity_id": "model_train"}, limit=1, sort=[("generated.test_loss", Flowcept.db.ASCENDING)])[0] @@ -258,8 +256,7 @@ def save_files(campaign_id, model_search_wf_id, output_dir="output_data"): f"{output_dir}/wf_{model_search_wf_id}_transformer_wikitext2.pth") print("Deleting best model from the database.") - dao = MongoDBDAO(create_indices=False) - dao.delete_object_keys("object_id", [doc["object_id"]]) + mongo_dao.delete_object_keys("object_id", [doc["object_id"]]) workflows_file = f"{output_dir}/workflows_{uuid.uuid4()}.json" print(f"workflows_file = '{workflows_file}'") @@ -311,7 +308,7 @@ def run_campaign(): return _campaign_id, _dataprep_wf_id, _search_wf_id, epochs, max_runs, dataprep_generated["train_n_batches"], dataprep_generated["val_n_batches"] -def asserts_on_saved_dfs(workflows_file, tasks_file, n_workflows_expected, n_tasks_expected, epoch_iterations, max_runs, n_batches_train, n_batches_eval, n_modules): +def asserts_on_saved_dfs(mongo_dao, workflows_file, tasks_file, n_workflows_expected, n_tasks_expected, epoch_iterations, max_runs, n_batches_train, n_batches_eval, n_modules): workflows_df = pd.read_json(workflows_file) # Assert workflows dump assert len(workflows_df) == n_workflows_expected @@ -365,16 +362,14 @@ def asserts_on_saved_dfs(workflows_file, tasks_file, n_workflows_expected, n_tas task_ids = list(tasks_df["task_id"].unique()) workflow_ids = list(workflows_df["workflow_id"].unique()) print("Deleting generated data in MongoDB") - dao = MongoDBDAO(create_indices=False) - dao.delete_task_keys("task_id", task_ids) - dao.delete_workflow_keys("workflow_id", workflow_ids) + mongo_dao.delete_task_keys("task_id", task_ids) + mongo_dao.delete_workflow_keys("workflow_id", workflow_ids) -def verify_number_docs_in_db(n_tasks=None, n_wfs=None, n_objects=None): - dao = MongoDBDAO(create_indices=False) - _n_tasks = dao.count_tasks() - _n_wfs = dao.count_workflows() - _n_objects = dao.count_objects() +def verify_number_docs_in_db(mongo_dao, n_tasks=None, n_wfs=None, n_objects=None): + _n_tasks = mongo_dao.count_tasks() + _n_wfs = mongo_dao.count_workflows() + _n_objects = mongo_dao.count_objects() if n_tasks: if n_tasks != _n_tasks: @@ -400,25 +395,29 @@ def verify_number_docs_in_db(n_tasks=None, n_wfs=None, n_objects=None): def main(): - if not MONGO_ENABLED: - print("This test is only available if Mongo is enabled.") - sys.exit(0) - print("TORCH SETTINGS: " + str(INSTRUMENTATION.get("torch"))) - n_tasks, n_wfs, n_objects = verify_number_docs_in_db() + from flowcept.commons.daos.docdb_dao.mongodb_dao import MongoDBDAO + mongo_dao = MongoDBDAO(create_indices=False) + + n_tasks, n_wfs, n_objects = verify_number_docs_in_db(mongo_dao) campaign_id, dataprep_wf_id, model_search_wf_id, epochs, max_runs, n_batches_train, n_batches_eval = run_campaign() n_workflows_expected, n_tasks_expected = run_asserts_and_exports(campaign_id, model_search_wf_id) - workflows_file, tasks_file = save_files(campaign_id, model_search_wf_id) - asserts_on_saved_dfs(workflows_file, tasks_file, n_workflows_expected, n_tasks_expected, + workflows_file, tasks_file = save_files(mongo_dao, campaign_id, model_search_wf_id) + asserts_on_saved_dfs(mongo_dao, workflows_file, tasks_file, n_workflows_expected, n_tasks_expected, epochs, max_runs, n_batches_train, n_batches_eval, n_modules=4) - verify_number_docs_in_db(n_tasks, n_wfs, n_objects) + verify_number_docs_in_db(mongo_dao, n_tasks, n_wfs, n_objects) print("Alright! Congrats.") if __name__ == "__main__": + + if not MONGO_ENABLED: + print("This test is only available if Mongo is enabled.") + sys.exit(0) + main() sys.exit(0)