From a7166e6e952b133ebffa2a6ce1c95b21f049f66e Mon Sep 17 00:00:00 2001 From: Tamara Norman Date: Mon, 20 Jan 2025 20:22:10 -0800 Subject: [PATCH] Remove dead code PiperOrigin-RevId: 717712430 --- torax/fvm/cell_variable.py | 34 ---------------------------------- torax/state.py | 23 ----------------------- torax/tests/state.py | 31 ------------------------------- 3 files changed, 88 deletions(-) diff --git a/torax/fvm/cell_variable.py b/torax/fvm/cell_variable.py index ed508406..050beb50 100644 --- a/torax/fvm/cell_variable.py +++ b/torax/fvm/cell_variable.py @@ -76,28 +76,6 @@ class CellVariable: # Can't make the above default values be jax zeros because that would be a # call to jax before absl.app.run - def project(self, weights): - assert self.history is not None - - def project(x): - return jnp.dot(weights, x) - - def opt_project(x): - if x is None: - return None - return project(x) - - return dataclasses.replace( - self, - value=project(self.value), - dr=self.dr[0], - left_face_constraint=opt_project(self.left_face_constraint), - left_face_grad_constraint=opt_project(self.left_face_grad_constraint), - right_face_constraint=opt_project(self.right_face_constraint), - right_face_grad_constraint=opt_project(self.right_face_grad_constraint), - history=None, - ) - def __post_init__(self): self.sanity_check() @@ -266,18 +244,6 @@ def assert_not_history(self): 'by `jax.lax.scan`. Most methods of a CellVariable ' 'do not work in history mode.' ) - if hasattr(self.history, 'ndim'): - if self.history.ndim == 0 or ( - self.history.ndim == 1 and self.history.shape[0] == 1 - ): - msg += ( - f' self.history={self.history} which probably indicates' - ' (due to its scalar shape)' - ' that an indexing or projection operation failed to' - ' turn off history mode. self.history should be None for' - ' non-history or a a vector of shape (history_length) for' - ' history.' - ) raise AssertionError(msg) def __hash__(self): diff --git a/torax/state.py b/torax/state.py index 1f7bc6a2..d432d11a 100644 --- a/torax/state.py +++ b/torax/state.py @@ -194,29 +194,6 @@ def sanity_check(self): if hasattr(value, "sanity_check"): value.sanity_check() - def project(self, weights): - project = lambda x: jnp.dot(weights, x) - proj_currents = jax.tree_util.tree_map(project, self.currents) - return dataclasses.replace( - self, - temp_ion=self.temp_ion.project(weights), - temp_el=self.temp_el.project(weights), - psi=self.psi.project(weights), - psidot=self.psidot.project(weights), - ne=self.ne.project(weights), - ni=self.ni.project(weights), - currents=proj_currents, - q_face=project(self.q_face), - s_face=project(self.s_face), - nref=project(self.nref), - Zi=project(self.Zi), - Zi_face=project(self.Zi_face), - Ai=project(self.Ai), - Zimp=project(self.Zimp), - Zimp_face=project(self.Zimp_face), - Aimp=project(self.Aimp), - ) - def __hash__(self): """Make CoreProfiles hashable. diff --git a/torax/tests/state.py b/torax/tests/state.py index a9755703..f20918ca 100644 --- a/torax/tests/state.py +++ b/torax/tests/state.py @@ -153,37 +153,6 @@ def test_index( for i in range(self.history_length): self.assertEqual(i, history.index(i).temp_ion.value[0]) - @parameterized.parameters([ - dict(references_getter=torax_refs.circular_references), - dict(references_getter=torax_refs.chease_references_Ip_from_chease), - dict( - references_getter=torax_refs.chease_references_Ip_from_runtime_params - ), - ]) - def test_project( - self, - references_getter: Callable[[], torax_refs.References], - ): - """Test State.project.""" - references = references_getter() - history = self._make_history( - references.runtime_params, references.geometry_provider - ) - - seed = 20230421 - rng_state = jax.random.PRNGKey(seed) - del seed # Make sure seed isn't accidentally re-used - weights = jax.random.normal(rng_state, (self.history_length,)) - del rng_state # Make sure rng_state isn't accidentally re-used - - expected = jnp.dot(weights, jnp.arange(self.history_length)) - - projected = history.project(weights) - - actual = projected.temp_ion.value[0] - - np.testing.assert_allclose(expected, actual) - class InitialStatesTest(parameterized.TestCase): """Unit tests for the `torax.core_profile_setters` module."""