Skip to content

Commit

Permalink
Fix in llm test runner
Browse files Browse the repository at this point in the history
  • Loading branch information
renan-souza committed Jan 13, 2025
1 parent ff5623a commit 516d5da
Showing 1 changed file with 21 additions and 22 deletions.
43 changes: 21 additions & 22 deletions examples/llm_complex/llm_main_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

from examples.llm_complex.llm_dataprep import dataprep_workflow
from examples.llm_complex.llm_model import model_train, TransformerModel
from flowcept.commons.daos.docdb_dao.mongodb_dao import MongoDBDAO
from flowcept.commons.flowcept_logger import FlowceptLogger

from flowcept.configs import MONGO_ENABLED, INSTRUMENTATION
from flowcept import Flowcept
Expand Down Expand Up @@ -237,7 +235,7 @@ def run_asserts_and_exports(campaign_id, model_search_wf_id):
return n_workflows_expected, n_tasks_expected


def save_files(campaign_id, model_search_wf_id, output_dir="output_data"):
def save_files(mongo_dao, campaign_id, model_search_wf_id, output_dir="output_data"):
os.makedirs(output_dir, exist_ok=True)
best_task = Flowcept.db.query({"workflow_id": model_search_wf_id, "activity_id": "model_train"}, limit=1,
sort=[("generated.test_loss", Flowcept.db.ASCENDING)])[0]
Expand All @@ -258,8 +256,7 @@ def save_files(campaign_id, model_search_wf_id, output_dir="output_data"):
f"{output_dir}/wf_{model_search_wf_id}_transformer_wikitext2.pth")

print("Deleting best model from the database.")
dao = MongoDBDAO(create_indices=False)
dao.delete_object_keys("object_id", [doc["object_id"]])
mongo_dao.delete_object_keys("object_id", [doc["object_id"]])

workflows_file = f"{output_dir}/workflows_{uuid.uuid4()}.json"
print(f"workflows_file = '{workflows_file}'")
Expand Down Expand Up @@ -311,7 +308,7 @@ def run_campaign():
return _campaign_id, _dataprep_wf_id, _search_wf_id, epochs, max_runs, dataprep_generated["train_n_batches"], dataprep_generated["val_n_batches"]


def asserts_on_saved_dfs(workflows_file, tasks_file, n_workflows_expected, n_tasks_expected, epoch_iterations, max_runs, n_batches_train, n_batches_eval, n_modules):
def asserts_on_saved_dfs(mongo_dao, workflows_file, tasks_file, n_workflows_expected, n_tasks_expected, epoch_iterations, max_runs, n_batches_train, n_batches_eval, n_modules):
workflows_df = pd.read_json(workflows_file)
# Assert workflows dump
assert len(workflows_df) == n_workflows_expected
Expand Down Expand Up @@ -365,16 +362,14 @@ def asserts_on_saved_dfs(workflows_file, tasks_file, n_workflows_expected, n_tas
task_ids = list(tasks_df["task_id"].unique())
workflow_ids = list(workflows_df["workflow_id"].unique())
print("Deleting generated data in MongoDB")
dao = MongoDBDAO(create_indices=False)
dao.delete_task_keys("task_id", task_ids)
dao.delete_workflow_keys("workflow_id", workflow_ids)
mongo_dao.delete_task_keys("task_id", task_ids)
mongo_dao.delete_workflow_keys("workflow_id", workflow_ids)


def verify_number_docs_in_db(n_tasks=None, n_wfs=None, n_objects=None):
dao = MongoDBDAO(create_indices=False)
_n_tasks = dao.count_tasks()
_n_wfs = dao.count_workflows()
_n_objects = dao.count_objects()
def verify_number_docs_in_db(mongo_dao, n_tasks=None, n_wfs=None, n_objects=None):
_n_tasks = mongo_dao.count_tasks()
_n_wfs = mongo_dao.count_workflows()
_n_objects = mongo_dao.count_objects()

if n_tasks:
if n_tasks != _n_tasks:
Expand All @@ -400,25 +395,29 @@ def verify_number_docs_in_db(n_tasks=None, n_wfs=None, n_objects=None):

def main():

if not MONGO_ENABLED:
print("This test is only available if Mongo is enabled.")
sys.exit(0)

print("TORCH SETTINGS: " + str(INSTRUMENTATION.get("torch")))

n_tasks, n_wfs, n_objects = verify_number_docs_in_db()
from flowcept.commons.daos.docdb_dao.mongodb_dao import MongoDBDAO
mongo_dao = MongoDBDAO(create_indices=False)

n_tasks, n_wfs, n_objects = verify_number_docs_in_db(mongo_dao)

campaign_id, dataprep_wf_id, model_search_wf_id, epochs, max_runs, n_batches_train, n_batches_eval = run_campaign()

n_workflows_expected, n_tasks_expected = run_asserts_and_exports(campaign_id, model_search_wf_id)
workflows_file, tasks_file = save_files(campaign_id, model_search_wf_id)
asserts_on_saved_dfs(workflows_file, tasks_file, n_workflows_expected, n_tasks_expected,
workflows_file, tasks_file = save_files(mongo_dao, campaign_id, model_search_wf_id)
asserts_on_saved_dfs(mongo_dao, workflows_file, tasks_file, n_workflows_expected, n_tasks_expected,
epochs, max_runs, n_batches_train, n_batches_eval, n_modules=4)
verify_number_docs_in_db(n_tasks, n_wfs, n_objects)
verify_number_docs_in_db(mongo_dao, n_tasks, n_wfs, n_objects)
print("Alright! Congrats.")


if __name__ == "__main__":

if not MONGO_ENABLED:
print("This test is only available if Mongo is enabled.")
sys.exit(0)

main()
sys.exit(0)

0 comments on commit 516d5da

Please sign in to comment.