Skip to content

Commit

Permalink
Remove dead code
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 717712430
  • Loading branch information
tamaranorman authored and Torax team committed Jan 21, 2025
1 parent 198d7b5 commit a7166e6
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 88 deletions.
34 changes: 0 additions & 34 deletions torax/fvm/cell_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,28 +76,6 @@ class CellVariable:
# Can't make the above default values be jax zeros because that would be a
# call to jax before absl.app.run

def project(self, weights):
assert self.history is not None

def project(x):
return jnp.dot(weights, x)

def opt_project(x):
if x is None:
return None
return project(x)

return dataclasses.replace(
self,
value=project(self.value),
dr=self.dr[0],
left_face_constraint=opt_project(self.left_face_constraint),
left_face_grad_constraint=opt_project(self.left_face_grad_constraint),
right_face_constraint=opt_project(self.right_face_constraint),
right_face_grad_constraint=opt_project(self.right_face_grad_constraint),
history=None,
)

def __post_init__(self):
self.sanity_check()

Expand Down Expand Up @@ -266,18 +244,6 @@ def assert_not_history(self):
'by `jax.lax.scan`. Most methods of a CellVariable '
'do not work in history mode.'
)
if hasattr(self.history, 'ndim'):
if self.history.ndim == 0 or (
self.history.ndim == 1 and self.history.shape[0] == 1
):
msg += (
f' self.history={self.history} which probably indicates'
' (due to its scalar shape)'
' that an indexing or projection operation failed to'
' turn off history mode. self.history should be None for'
' non-history or a a vector of shape (history_length) for'
' history.'
)
raise AssertionError(msg)

def __hash__(self):
Expand Down
23 changes: 0 additions & 23 deletions torax/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,29 +194,6 @@ def sanity_check(self):
if hasattr(value, "sanity_check"):
value.sanity_check()

def project(self, weights):
project = lambda x: jnp.dot(weights, x)
proj_currents = jax.tree_util.tree_map(project, self.currents)
return dataclasses.replace(
self,
temp_ion=self.temp_ion.project(weights),
temp_el=self.temp_el.project(weights),
psi=self.psi.project(weights),
psidot=self.psidot.project(weights),
ne=self.ne.project(weights),
ni=self.ni.project(weights),
currents=proj_currents,
q_face=project(self.q_face),
s_face=project(self.s_face),
nref=project(self.nref),
Zi=project(self.Zi),
Zi_face=project(self.Zi_face),
Ai=project(self.Ai),
Zimp=project(self.Zimp),
Zimp_face=project(self.Zimp_face),
Aimp=project(self.Aimp),
)

def __hash__(self):
"""Make CoreProfiles hashable.
Expand Down
31 changes: 0 additions & 31 deletions torax/tests/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,37 +153,6 @@ def test_index(
for i in range(self.history_length):
self.assertEqual(i, history.index(i).temp_ion.value[0])

@parameterized.parameters([
dict(references_getter=torax_refs.circular_references),
dict(references_getter=torax_refs.chease_references_Ip_from_chease),
dict(
references_getter=torax_refs.chease_references_Ip_from_runtime_params
),
])
def test_project(
self,
references_getter: Callable[[], torax_refs.References],
):
"""Test State.project."""
references = references_getter()
history = self._make_history(
references.runtime_params, references.geometry_provider
)

seed = 20230421
rng_state = jax.random.PRNGKey(seed)
del seed # Make sure seed isn't accidentally re-used
weights = jax.random.normal(rng_state, (self.history_length,))
del rng_state # Make sure rng_state isn't accidentally re-used

expected = jnp.dot(weights, jnp.arange(self.history_length))

projected = history.project(weights)

actual = projected.temp_ion.value[0]

np.testing.assert_allclose(expected, actual)


class InitialStatesTest(parameterized.TestCase):
"""Unit tests for the `torax.core_profile_setters` module."""
Expand Down

0 comments on commit a7166e6

Please sign in to comment.