Skip to content

Commit

Permalink
restrict number of traced structures in autograd
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex committed May 29, 2024
1 parent c75e583 commit afea9ee
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 12 deletions.
Binary file modified results.prof
Binary file not shown.
Empty file.
51 changes: 39 additions & 12 deletions tidy3d/web/api/autograd/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
)
URL_LINK = f"[blue underline][link={ISSUE_URL}]'{ISSUE_URL}'[/link][/blue underline]"

MAX_NUM_TRACED_STRUCTURES = 500


def warn_autograd(fn_name: str, exc: Exception) -> str:
"""Get warning message."""
Expand All @@ -42,6 +44,41 @@ def warn_autograd(fn_name: str, exc: Exception) -> str:
)


def is_valid_for_autograd(simulation: td.Simulation) -> bool:
"""Check whether a supplied simulation can use autograd run."""

# only support Simulations
if not isinstance(simulation, td.Simulation):
return False

# if no tracers just use regular web.run()
traced_fields = simulation.strip_traced_fields(
include_untraced_data_arrays=False, starting_path=("structures",)
)
if not traced_fields:
return False

# if too many structures, raise an error
structure_indices = [i for key, i, *_ in traced_fields.keys() if key == "structures"]
num_traced_structures = len(structure_indices)
if num_traced_structures > MAX_NUM_TRACED_STRUCTURES:
raise ValueError(
f"Autograd support is currently limited to {MAX_NUM_TRACED_STRUCTURES} structures with "
f"traced fields. Found {num_traced_structures} structures with traced fields."
)

return True


def is_valid_for_autograd_async(simulations: dict[str, td.Simulation]) -> bool:
"""Check whether the supplied simulations dict can use autograd run_async."""
if not isinstance(simulations, dict):
return False
if not all(is_valid_for_autograd(sim) for sim in simulations.values()):
return False
return True


def run(
simulation: SimulationType,
task_name: str,
Expand Down Expand Up @@ -131,7 +168,7 @@ def run(
Monitor progress of each of the running tasks.
"""

if isinstance(simulation, td.Simulation) and simulation.strip_traced_fields():
if is_valid_for_autograd(simulation):
try:
return _run(
simulation=simulation,
Expand Down Expand Up @@ -213,17 +250,7 @@ def run_async(
Interface for submitting several :class:`Simulation` objects to sever.
"""

def is_valid_for_autograd(simulations: dict[str, td.Simulation]) -> bool:
"""Check whether the supplied simulations dict can use autograd run."""
if not isinstance(simulations, dict):
return False
if not all(isinstance(sim, td.Simulation) for sim in simulations.values()):
return False
if not any(sim.strip_traced_fields() for sim in simulations.values()):
return False
return True

if is_valid_for_autograd(simulations):
if is_valid_for_autograd_async(simulations):
try:
return _run_async(
simulations=simulations,
Expand Down

0 comments on commit afea9ee

Please sign in to comment.