From 50f5e3ee640c392da1507e59dd80cdbc6ab63e27 Mon Sep 17 00:00:00 2001 From: prashkh Date: Tue, 30 Apr 2024 21:07:53 -0400 Subject: [PATCH] Added multi-threaded mode solver via `mode.run_batch` --- CHANGELOG.md | 1 + tests/test_plugins/test_mode_solver.py | 71 ++++++++++++++++ tidy3d/plugins/mode/web.py | 5 +- tidy3d/web/api/mode.py | 111 +++++++++++++++++++++++++ 4 files changed, 186 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ad530ede..3ec88652e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added +- A batch of `ModeSolver` objects can be run concurrently using `tidy3d.plugins.mode.web.run_batch()` - `RectangularWaveguide.plot_field` optionally draws geometry edges over fields. - `RectangularWaveguide` supports layered cladding above and below core. diff --git a/tests/test_plugins/test_mode_solver.py b/tests/test_plugins/test_mode_solver.py index e43296859..edcf9975f 100644 --- a/tests/test_plugins/test_mode_solver.py +++ b/tests/test_plugins/test_mode_solver.py @@ -16,6 +16,7 @@ from ..utils import log_capture # noqa: F401 from tidy3d import ScalarFieldDataArray from tidy3d.web.core.environment import Env +from tidy3d.components.data.monitor_data import ModeSolverData WG_MEDIUM = td.Medium(permittivity=4.0, conductivity=1e-4) @@ -108,6 +109,34 @@ def mock_download(resource_id, remote_filename, to_file, *args, **kwargs): status=200, ) + responses.add( + responses.POST, + f"{Env.current.web_api_endpoint}/tidy3d/modesolver/py", + match=[ + responses.matchers.json_params_matcher( + { + "projectId": PROJECT_ID, + "taskName": "BatchModeSolver_0", + "modeSolverName": MODESOLVER_NAME + "_batch_0", + "fileType": "Gz", + "source": "Python", + "protocolVersion": td.version.__version__, + } + ) + ], + json={ + "data": { + "refId": TASK_ID, + "id": SOLVER_ID, + "status": "draft", + "createdAt": "2023-05-19T16:47:57.190Z", + 
@responses.activate
def test_mode_solver_web_run_batch(mock_remote_api):
    """Check that ``msweb.run_batch`` returns one ``ModeSolverData`` per submitted solver."""

    wav = 1.5
    wav_min = 1.4
    wav_max = 1.5
    num_freqs = 1
    num_of_sims = 1
    freqs = np.linspace(td.C_0 / wav_min, td.C_0 / wav_max, num_freqs)

    simulation = td.Simulation(
        size=SIM_SIZE,
        grid_spec=td.GridSpec(wavelength=wav),
        structures=[WAVEGUIDE],
        run_time=1e-12,
        boundary_spec=td.BoundarySpec.all_sides(boundary=td.PML()),
    )

    # Build ``num_of_sims`` mode solvers; the i-th one requests ``i + 1`` modes.
    mode_solver_list = [
        ModeSolver(
            simulation=simulation,
            plane=PLANE,
            mode_spec=td.ModeSpec(
                num_modes=i + 1,
                target_neff=2.0,
            ),
            freqs=freqs,
            direction="+",
        )
        for i in range(num_of_sims)
    ]

    # Submit the whole list as a single batch.
    results = msweb.run_batch(mode_solver_list, verbose=False, folder_name="Mode Solver")

    # One entry per submitted solver, and none of them may be a failed (``None``) entry.
    assert len(results) == num_of_sims
    assert all(isinstance(x, ModeSolverData) for x in results)

    # NOTE(review): the previous version ended with an ``assert`` on a bare generator
    # expression comparing ``n_eff.shape`` to ``(num_freqs, i + 1)``. A generator object
    # is always truthy, so that assertion never checked anything. Whether the data
    # written by the mocked download matches each solver's ``mode_spec`` is not visible
    # from this test -- TODO confirm, then reinstate the shape check wrapped in ``all(...)``.
def run_batch(
    mode_solvers: List[ModeSolver],
    task_name: str = "BatchModeSolver",
    folder_name: str = "BatchModeSolvers",
    results_files: List[str] = None,
    verbose: bool = True,
    max_workers: int = DEFAULT_NUM_WORKERS,
    max_retries: int = DEFAULT_MAX_RETRIES,
    retry_delay: float = DEFAULT_RETRY_DELAY,
    progress_callback_upload: Callable[[float], None] = None,
    progress_callback_download: Callable[[float], None] = None,
) -> List[ModeSolverData]:
    """
    Submits a batch of ModeSolver to the server concurrently, manages progress, and retrieves results.

    Parameters
    ----------
    mode_solvers : List[ModeSolver]
        List of mode solvers to be submitted to the server.
    task_name : str
        Base name for tasks. Each task in the batch will have a unique index appended to this base name.
    folder_name : str
        Name of the folder where tasks are stored on the server's web UI.
    results_files : List[str], optional
        List of file paths where the results for each ModeSolver should be downloaded, one per
        entry of ``mode_solvers``. If None, ``mode_solver_batch_results_{index}.hdf5`` is used.
    verbose : bool
        If True, displays a progress bar. If False, runs silently.
    max_workers : int
        Maximum number of concurrent workers (threads) to use for processing the batch.
    max_retries : int
        Maximum number of retries for each simulation after its first failed attempt, i.e. each
        solver is attempted at most ``max_retries + 1`` times. Negative values are treated as 0.
    retry_delay : float
        Delay in seconds between retries when a simulation fails.
    progress_callback_upload : Callable[[float], None], optional
        Optional callback function called when uploading file with ``bytes_in_chunk`` as argument.
    progress_callback_download : Callable[[float], None], optional
        Optional callback function called when downloading file with ``bytes_in_chunk`` as argument.


    Returns
    -------
    List[ModeSolverData]
        A list of ModeSolverData objects containing the results from each simulation in the batch,
        in the same order as ``mode_solvers``. ``None`` is placed in the list for simulations that
        fail after all retries. An empty list is returned if input validation fails.
    """
    console = get_logging_console()

    if not all(isinstance(x, ModeSolver) for x in mode_solvers):
        console.log(
            "Validation Error: All items in `mode_solvers` must be instances of `ModeSolver`."
        )
        return []

    num_mode_solvers = len(mode_solvers)

    if results_files is None:
        results_files = [f"mode_solver_batch_results_{i}.hdf5" for i in range(num_mode_solvers)]
    elif len(results_files) < num_mode_solvers:
        # Without this check the missing entries surface as an ``IndexError`` inside the
        # worker, which the retry loop would swallow and pointlessly retry (with delays).
        console.log(
            "Validation Error: `results_files` must provide one path per entry in `mode_solvers`."
        )
        return []

    # Always make at least one attempt, even if a negative ``max_retries`` is passed.
    num_attempts = max(max_retries, 0) + 1

    def handle_mode_solver(index, progress, pbar):
        """Run the ``index``-th solver with retries; return its data, or ``None`` on failure."""
        for attempt in range(num_attempts):
            try:
                result = run(
                    mode_solver=mode_solvers[index],
                    task_name=f"{task_name}_{index}",
                    mode_solver_name=f"mode_solver_batch_{index}",
                    folder_name=folder_name,
                    results_file=results_files[index],
                    verbose=False,
                    progress_callback_upload=progress_callback_upload,
                    progress_callback_download=progress_callback_download,
                )
            except Exception as e:
                # Deliberately broad: one failing solver must not abort the whole batch.
                console.log(f"Error in mode solver {index}: {str(e)}")
                # No point in sleeping after the final attempt.
                if attempt < num_attempts - 1:
                    time.sleep(retry_delay)
            else:
                if verbose:
                    progress.update(pbar, advance=1)
                return result

        console.log(f"The {index}-th mode solver failed after {num_attempts} attempts.")
        if verbose:
            # A failed solver still counts as "done" for the progress bar.
            progress.update(pbar, advance=1)
        return None

    def run_all(progress=None, pbar=None):
        """Dispatch every solver to the thread pool; results keep the input order."""
        return Parallel(n_jobs=max_workers, backend="threading")(
            delayed(handle_mode_solver)(i, progress, pbar) for i in range(num_mode_solvers)
        )

    if not verbose:
        return run_all()

    console.log(f"[cyan]Running a batch of [deep_pink4]{num_mode_solvers} mode solvers.\n")
    with Progress(console=console) as progress:
        pbar = progress.add_task("Status:", total=num_mode_solvers)
        results = run_all(progress, pbar)
        # Make sure the progress bar is complete
        progress.update(pbar, completed=num_mode_solvers, refresh=True)

        # Only claim success when every solver actually produced data.
        num_failed = sum(result is None for result in results)
        if num_failed:
            console.log(
                f"[red]{num_failed} of {num_mode_solvers} `ModeSolver` tasks failed; "
                "their entries in the returned list are `None`."
            )
        else:
            console.log("[green]A batch of `ModeSolver` tasks completed successfully!")

    return results