diff --git a/damnit/ctxsupport/ctxrunner.py b/damnit/ctxsupport/ctxrunner.py index e6c404ec..5c69ca05 100644 --- a/damnit/ctxsupport/ctxrunner.py +++ b/damnit/ctxsupport/ctxrunner.py @@ -233,8 +233,12 @@ def from_str(cls, code: str, path=''): log.debug("Loaded %d variables", len(vars)) return cls(vars, code) - def vars_to_dict(self): - """Get a plain dict of variable metadata to store in the database""" + def vars_to_dict(self, inc_transient=False): + """Get a plain dict of variable metadata to store in the database + + args: + inc_transient (bool): include transient Variables in the dict + """ return { name: { 'title': v.title, @@ -244,6 +248,7 @@ def vars_to_dict(self): 'type': None, } for (name, v) in self.vars.items() + if not v.transient or inc_transient } def filter(self, run_data=RunData.ALL, cluster=True, name_matches=(), variables=()): @@ -349,6 +354,12 @@ def execute(self, run_data, run_number, proposal, input_vars) -> 'Results': t1 = time.perf_counter() log.info("Computed %s in %.03f s", name, t1 - t0) res[name] = data + + # remove transient results + for name, var in self.vars.items(): + if var.transient and (name in res): + res.pop(name) + return Results(res, self) diff --git a/damnit/ctxsupport/damnit_ctx.py b/damnit/ctxsupport/damnit_ctx.py index 35464261..451cb797 100644 --- a/damnit/ctxsupport/damnit_ctx.py +++ b/damnit/ctxsupport/damnit_ctx.py @@ -41,13 +41,15 @@ class Variable: _name = None def __init__( - self, title=None, description=None, summary=None, data=None, cluster=False, tags=None, + self, title=None, description=None, summary=None, data=None, + cluster=False, tags=None, transient=False ): self.title = title self.tags = (tags,) if isinstance(tags, str) else tags self.description = description self.summary = summary self.cluster = cluster + self.transient = transient self._data = data # @Variable() is used as a decorator on a function that computes a value diff --git a/docs/backend.md b/docs/backend.md index 163743a0..85918ed8 100644 --- a/docs/backend.md +++ b/docs/backend.md @@ -71,6 +71,9 @@ these arguments: ``` - `cluster` (bool): whether or not to execute this variable in a Slurm job. This should always be used if the variable does any heavy processing. +- `transient` (bool): whether or not to save the variable's result to the + database. This is useful for e.g. intermediate results. By default variables + save their results (transient=False). Variable functions can return any of: diff --git a/docs/changelog.md b/docs/changelog.md index a567d928..64333881 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -10,6 +10,7 @@ Added: - Add a `tags` attribute allowing cathegorizing `Variable`s (!354). - Add support for `complex` numbers (!374) - GUI: Add a Dark theme (!376) +- add a`transient` attribute for variables we don't want to save data (!xxx) Changed: diff --git a/tests/test_backend.py b/tests/test_backend.py index dc447a94..04fdc8fb 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -25,7 +25,7 @@ from damnit.backend import backend_is_running, initialize_and_start_backend from damnit.backend.db import DamnitDB -from damnit.backend.extract_data import Extractor, RunExtractor, add_to_db +from damnit.backend.extract_data import Extractor, RunExtractor, add_to_db, load_reduced_data from damnit.backend.extraction_control import ExtractionJobTracker from damnit.backend.listener import (MAX_CONCURRENT_THREADS, EventProcessor, local_extraction_threads) @@ -1002,3 +1002,47 @@ def test_job_tracker(): fake_squeue.assert_called() assert set(tracker.jobs) == set() + +def test_transient_variables(mock_run, mock_db, tmp_path): + db_dir, db = mock_db + + ctx_code = """ + from damnit_ctx import Variable, Cell + import numpy as np + + @Variable() + def var1(run): + return 7 + + @Variable(transient=True) + def var2(run, data: 'var#var1'): + return np.arange(data) + + @Variable(summary='max') + def var3(run, data: 'var#var2'): + return data.size * data + """ + ctx = mkcontext(ctx_code) + results = ctx.execute(mock_run, 1000, 123, {}) + results_hdf5_path = tmp_path / 'results.h5' + results.save_hdf5(results_hdf5_path) + + with h5py.File(results_hdf5_path) as f: + assert '.reduced/var1' in f + assert 'var1' in f + # transient variables are not saved + assert '.reduced/var2' not in f + assert 'var2' not in f + assert '.reduced/var3' in f + assert 'var3' in f + + assert f['.reduced/var3'][()] == 42 + assert np.allclose(f['var3/data'][()], np.arange(7) * 7) + + reduced_data = load_reduced_data(results_hdf5_path) + add_to_db(reduced_data, db, 1000, 123) + vars = db.conn.execute('SELECT value FROM run_variables WHERE name="var3"').fetchall() + assert vars[0]['value'] == 42 + # also not saved in the db + vars = db.conn.execute('SELECT * FROM run_variables WHERE name="var2"').fetchall() + assert vars == []