Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove dead code #664

Merged
merged 1 commit into from
Jan 21, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 0 additions & 34 deletions torax/fvm/cell_variable.py
Original file line number Diff line number Diff line change
@@ -76,28 +76,6 @@ class CellVariable:
# Can't make the above default values be jax zeros because that would be a
# call to jax before absl.app.run

def project(self, weights):
assert self.history is not None

def project(x):
return jnp.dot(weights, x)

def opt_project(x):
if x is None:
return None
return project(x)

return dataclasses.replace(
self,
value=project(self.value),
dr=self.dr[0],
left_face_constraint=opt_project(self.left_face_constraint),
left_face_grad_constraint=opt_project(self.left_face_grad_constraint),
right_face_constraint=opt_project(self.right_face_constraint),
right_face_grad_constraint=opt_project(self.right_face_grad_constraint),
history=None,
)

def __post_init__(self):
self.sanity_check()

@@ -266,18 +244,6 @@ def assert_not_history(self):
'by `jax.lax.scan`. Most methods of a CellVariable '
'do not work in history mode.'
)
if hasattr(self.history, 'ndim'):
if self.history.ndim == 0 or (
self.history.ndim == 1 and self.history.shape[0] == 1
):
msg += (
f' self.history={self.history} which probably indicates'
' (due to its scalar shape)'
' that an indexing or projection operation failed to'
' turn off history mode. self.history should be None for'
' non-history or a a vector of shape (history_length) for'
' history.'
)
raise AssertionError(msg)

def __hash__(self):
23 changes: 0 additions & 23 deletions torax/state.py
Original file line number Diff line number Diff line change
@@ -194,29 +194,6 @@ def sanity_check(self):
if hasattr(value, "sanity_check"):
value.sanity_check()

def project(self, weights):
project = lambda x: jnp.dot(weights, x)
proj_currents = jax.tree_util.tree_map(project, self.currents)
return dataclasses.replace(
self,
temp_ion=self.temp_ion.project(weights),
temp_el=self.temp_el.project(weights),
psi=self.psi.project(weights),
psidot=self.psidot.project(weights),
ne=self.ne.project(weights),
ni=self.ni.project(weights),
currents=proj_currents,
q_face=project(self.q_face),
s_face=project(self.s_face),
nref=project(self.nref),
Zi=project(self.Zi),
Zi_face=project(self.Zi_face),
Ai=project(self.Ai),
Zimp=project(self.Zimp),
Zimp_face=project(self.Zimp_face),
Aimp=project(self.Aimp),
)

def __hash__(self):
"""Make CoreProfiles hashable.

31 changes: 0 additions & 31 deletions torax/tests/state.py
Original file line number Diff line number Diff line change
@@ -153,37 +153,6 @@ def test_index(
for i in range(self.history_length):
self.assertEqual(i, history.index(i).temp_ion.value[0])

@parameterized.parameters([
dict(references_getter=torax_refs.circular_references),
dict(references_getter=torax_refs.chease_references_Ip_from_chease),
dict(
references_getter=torax_refs.chease_references_Ip_from_runtime_params
),
])
def test_project(
self,
references_getter: Callable[[], torax_refs.References],
):
"""Test State.project."""
references = references_getter()
history = self._make_history(
references.runtime_params, references.geometry_provider
)

seed = 20230421
rng_state = jax.random.PRNGKey(seed)
del seed # Make sure seed isn't accidentally re-used
weights = jax.random.normal(rng_state, (self.history_length,))
del rng_state # Make sure rng_state isn't accidentally re-used

expected = jnp.dot(weights, jnp.arange(self.history_length))

projected = history.project(weights)

actual = projected.temp_ion.value[0]

np.testing.assert_allclose(expected, actual)


class InitialStatesTest(parameterized.TestCase):
"""Unit tests for the `torax.core_profile_setters` module."""