Skip to content

Commit

Permalink
Added multi-threaded mode solver via mode.run_batch
Browse files Browse the repository at this point in the history
  • Loading branch information
prashkh authored and tylerflex committed May 24, 2024
1 parent 38b9f93 commit 50f5e3e
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]

### Added
- A batch of `ModeSolver` objects can be run concurrently using `tidy3d.plugins.mode.web.run_batch()`
- `RectangularWaveguide.plot_field` optionally draws geometry edges over fields.
- `RectangularWaveguide` supports layered cladding above and below core.

Expand Down
71 changes: 71 additions & 0 deletions tests/test_plugins/test_mode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..utils import log_capture # noqa: F401
from tidy3d import ScalarFieldDataArray
from tidy3d.web.core.environment import Env
from tidy3d.components.data.monitor_data import ModeSolverData


WG_MEDIUM = td.Medium(permittivity=4.0, conductivity=1e-4)
Expand Down Expand Up @@ -108,6 +109,34 @@ def mock_download(resource_id, remote_filename, to_file, *args, **kwargs):
status=200,
)

# Mock the POST that creates a mode solver task named "BatchModeSolver_0" (the first entry of a
# batch); the request body must match the JSON below and the reply is a task in "draft" status.
responses.add(
responses.POST,
f"{Env.current.web_api_endpoint}/tidy3d/modesolver/py",
match=[
responses.matchers.json_params_matcher(
{
"projectId": PROJECT_ID,
"taskName": "BatchModeSolver_0",
"modeSolverName": MODESOLVER_NAME + "_batch_0",
"fileType": "Gz",
"source": "Python",
"protocolVersion": td.version.__version__,
}
)
],
json={
"data": {
"refId": TASK_ID,
"id": SOLVER_ID,
"status": "draft",
"createdAt": "2023-05-19T16:47:57.190Z",
"charge": 0,
"fileType": "Gz",
}
},
status=200,
)

responses.add(
responses.GET,
f"{Env.current.web_api_endpoint}/tidy3d/modesolver/py/{TASK_ID}/{SOLVER_ID}",
Expand Down Expand Up @@ -837,3 +866,45 @@ def test_mode_solver_method_defaults():

sim = ms.sim_with_monitor(name="test")
assert np.allclose(sim.monitors[-1].freqs, ms.freqs)


@responses.activate
def test_mode_solver_web_run_batch(mock_remote_api):
    """Check that ``msweb.run_batch`` returns one ``ModeSolverData`` per submitted solver."""

    wav = 1.5
    wav_min = 1.4
    wav_max = 1.5
    num_freqs = 1
    num_of_sims = 1
    freqs = np.linspace(td.C_0 / wav_min, td.C_0 / wav_max, num_freqs)

    simulation = td.Simulation(
        size=SIM_SIZE,
        grid_spec=td.GridSpec(wavelength=wav),
        structures=[WAVEGUIDE],
        run_time=1e-12,
        boundary_spec=td.BoundarySpec.all_sides(boundary=td.PML()),
    )

    # One mode solver per batch entry; entry ``i`` requests ``i + 1`` modes so that the
    # entries differ from each other when ``num_of_sims`` is increased.
    mode_solver_list = [
        ModeSolver(
            simulation=simulation,
            plane=PLANE,
            mode_spec=td.ModeSpec(
                num_modes=i + 1,
                target_neff=2.0,
            ),
            freqs=freqs,
            direction="+",
        )
        for i in range(num_of_sims)
    ]

    # Submit the whole batch at once through the (mocked) web API.
    results = msweb.run_batch(mode_solver_list, verbose=False, folder_name="Mode Solver")

    # ``run_batch`` must return exactly one entry per submitted solver, and none may be the
    # ``None`` placeholder used for failed tasks.
    assert len(results) == num_of_sims
    assert all(isinstance(x, ModeSolverData) for x in results)

    # NOTE(review): the previous ``n_eff.shape`` check asserted a bare generator expression,
    # which is always truthy, so it never verified anything and has been removed. Re-adding it
    # wrapped in ``all(...)`` requires the data served by ``mock_remote_api`` to match each
    # solver's ``num_modes`` and ``freqs`` -- TODO confirm against the fixture before enabling.
5 changes: 3 additions & 2 deletions tidy3d/plugins/mode/web.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Web API for mode solver"""
from ...web.api.mode import run

__all__ = ["run"]
from ...web.api.mode import run, run_batch

__all__ = ["run", "run_batch"]
111 changes: 111 additions & 0 deletions tidy3d/web/api/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@
import pathlib
import tempfile
import time
from joblib import Parallel, delayed

import pydantic.v1 as pydantic
from botocore.exceptions import ClientError

from typing import List
from rich.progress import Progress

from ..core.environment import Env
from ...components.simulation import Simulation
from ...components.data.monitor_data import ModeSolverData
Expand Down Expand Up @@ -39,6 +43,10 @@
MODESOLVER_RESULT = "output/result.hdf5"
MODESOLVER_RESULT_GZ = "output/mode_solver_data.hdf5.gz"

# Defaults for the keyword arguments of ``run_batch``.
DEFAULT_NUM_WORKERS = 10  # default for ``max_workers``
DEFAULT_MAX_RETRIES = 3  # default for ``max_retries``
DEFAULT_RETRY_DELAY = 10  # default for ``retry_delay``, in seconds


def run(
mode_solver: ModeSolver,
Expand Down Expand Up @@ -138,6 +146,109 @@ def run(
)


def run_batch(
    mode_solvers: List[ModeSolver],
    task_name: str = "BatchModeSolver",
    folder_name: str = "BatchModeSolvers",
    results_files: List[str] = None,
    verbose: bool = True,
    max_workers: int = DEFAULT_NUM_WORKERS,
    max_retries: int = DEFAULT_MAX_RETRIES,
    retry_delay: float = DEFAULT_RETRY_DELAY,
    progress_callback_upload: Callable[[float], None] = None,
    progress_callback_download: Callable[[float], None] = None,
) -> List[ModeSolverData]:
    """Submit a batch of mode solvers to the server concurrently and collect the results.

    Each solver is run through :func:`run` in its own worker thread. A solver whose run raises
    is retried up to ``max_retries`` times, waiting ``retry_delay`` seconds between attempts.

    Parameters
    ----------
    mode_solvers : List[ModeSolver]
        List of mode solvers to be submitted to the server.
    task_name : str
        Base name for tasks. The task for entry ``i`` is named ``f"{task_name}_{i}"``.
    folder_name : str
        Name of the folder where tasks are stored on the server's web UI.
    results_files : List[str], optional
        File paths where the results of each mode solver are downloaded, one per entry in
        ``mode_solvers``. If ``None``, ``mode_solver_batch_results_{i}.hdf5`` is used for
        entry ``i``.
    verbose : bool
        If ``True``, displays a progress bar. If ``False``, no progress bar is shown
        (errors are still logged to the console).
    max_workers : int
        Maximum number of concurrent workers used to process the batch.
    max_retries : int
        Maximum number of retries for each mode solver after a failed attempt, so every
        solver is attempted at most ``max_retries + 1`` times. Negative values are treated
        as ``0``.
    retry_delay : float
        Delay in seconds between retries when a mode solver fails.
    progress_callback_upload : Callable[[float], None], optional
        Optional callback function called when uploading file with ``bytes_in_chunk`` as argument.
    progress_callback_download : Callable[[float], None], optional
        Optional callback function called when downloading file with ``bytes_in_chunk`` as argument.

    Returns
    -------
    List[ModeSolverData]
        One entry per mode solver, in input order. ``None`` is placed in the list for mode
        solvers that fail after all retries. An empty list is returned if the inputs fail
        validation.
    """
    console = get_logging_console()

    if not all(isinstance(x, ModeSolver) for x in mode_solvers):
        console.log(
            "Validation Error: All items in `mode_solvers` must be instances of `ModeSolver`."
        )
        return []

    num_mode_solvers = len(mode_solvers)

    if results_files is None:
        results_files = [f"mode_solver_batch_results_{i}.hdf5" for i in range(num_mode_solvers)]
    elif len(results_files) < num_mode_solvers:
        # Without this check the missing path surfaces as an ``IndexError`` inside the retry
        # loop, which would sleep through every retry before giving up on those entries.
        console.log(
            "Validation Error: `results_files` must provide one path per item in `mode_solvers`."
        )
        return []

    # Always make at least one attempt, even if a negative ``max_retries`` is passed.
    total_attempts = max(max_retries, 0) + 1

    def handle_mode_solver(index, progress, pbar):
        """Run one mode solver with retries; return its data, or ``None`` if all attempts fail."""
        try:
            for attempt in range(1, total_attempts + 1):
                try:
                    return run(
                        mode_solver=mode_solvers[index],
                        task_name=f"{task_name}_{index}",
                        mode_solver_name=f"mode_solver_batch_{index}",
                        folder_name=folder_name,
                        results_file=results_files[index],
                        verbose=False,
                        progress_callback_upload=progress_callback_upload,
                        progress_callback_download=progress_callback_download,
                    )
                # Deliberately broad: a single failing task must not abort the whole batch.
                except Exception as e:
                    console.log(f"Error in mode solver {index}: {str(e)}")
                    if attempt < total_attempts:
                        time.sleep(retry_delay)
            console.log(f"The {index}-th mode solver failed after {total_attempts} attempts.")
            return None
        finally:
            # Advance the bar exactly once per solver, outside the retried region so that a
            # progress-bar error can never trigger a re-run of an already finished task.
            if verbose:
                progress.update(pbar, advance=1)

    def run_all(progress=None, pbar=None):
        """Dispatch every mode solver to the thread pool and gather results in input order."""
        return Parallel(n_jobs=max_workers, backend="threading")(
            delayed(handle_mode_solver)(i, progress, pbar) for i in range(num_mode_solvers)
        )

    if not verbose:
        return run_all()

    console.log(f"[cyan]Running a batch of [deep_pink4]{num_mode_solvers} mode solvers.\n")
    with Progress(console=console) as progress:
        pbar = progress.add_task("Status:", total=num_mode_solvers)
        results = run_all(progress, pbar)
        # Make sure the progress bar is complete
        progress.update(pbar, completed=num_mode_solvers, refresh=True)

        # Only claim success when every task actually produced data.
        num_failed = sum(result is None for result in results)
        if num_failed:
            console.log(
                f"[red]{num_failed} of {num_mode_solvers} `ModeSolver` tasks failed "
                "after all retries."
            )
        else:
            console.log("[green]A batch of `ModeSolver` tasks completed successfully!")

    return results


class ModeSolverTask(ResourceLifecycle, Submittable, extra=pydantic.Extra.allow):
"""Interface for managing the running of a :class:`.ModeSolver` task on server."""

Expand Down

0 comments on commit 50f5e3e

Please sign in to comment.