diff --git a/CHANGELOG.rst b/CHANGELOG.rst index f6002efee1..69e3b9bf0c 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -22,6 +22,8 @@ Config Updates - Added a new ``initial_image`` input type that lets you read in an existing image file and draw onto it. (#1237) - Added skip_failures option in stamp fields. (#1238) +- Let input items depend on other input items, even if they appear later in the input field. + (#1239) New Features diff --git a/galsim/config/input.py b/galsim/config/input.py index ab5e311c54..784d05f964 100644 --- a/galsim/config/input.py +++ b/galsim/config/input.py @@ -164,11 +164,7 @@ class InputManager(SafeManager): pass if file_scope_only and not loader.file_scope: continue logger.debug('file %d: Process input key %s',file_num,key) - fields = config['input'][key] - nfields = len(fields) if isinstance(fields, list) else 1 - - for num in range(nfields): - input_obj = LoadInputObj(config, key, num, safe_only, logger) + LoadAllInputObj(config, key, safe_only, logger) # Check that there are no other attributes specified. valid_keys = valid_input_types.keys() @@ -192,6 +188,34 @@ def SetupInput(config, logger=None): PropagateIndexKeyRNGNum(config['input']) ProcessInput(config, logger=logger) + +def LoadAllInputObj(config, key, safe_only=False, logger=None): + """Load all items of a single input type, named key, with definition given by the dict field. + + This function just detects if the dict item for this key is a list and calls LoadInputObj + for every num. + + .. note:: + + This is designed as an internal implementation detail, not meant to be used by end users. + So it doesn't have some of the safeguards we normally put on public facing functions. + However, we expect the API to persist, and we'll try to use deprecations if we change + anything, so if users find it useful, it is fine to go ahead and use it in your own + custom input module implementations. + + Parameters: + config: The configuration dict to process + key: The key name of this input type + safe_only: Only load "safe" input objects. + logger: If given, a logger object to log progress. [default: None] + """ + fields = config['input'][key] + nfields = len(fields) if isinstance(fields, list) else 1 + for num in range(nfields): + input_obj = LoadInputObj(config, key, num, safe_only, logger) + return input_obj + + def LoadInputObj(config, key, num=0, safe_only=False, logger=None): """Load a single input object, named key, with definition given by the dict field. @@ -249,6 +273,14 @@ def LoadInputObj(config, key, num=0, safe_only=False, logger=None): # just implies that this input object isn't safe to keep around anyway. # So in this case, we just continue on. If it was not a safe_only run, # the exception is reraised. + # The one exception is if the exception indicates that another input type was needed. + # In that case, we load the dependent input type and try again. + if str(e).startswith("No input"): + dep_input = str(e).split()[2] + logger.info("%s input seems to depend on %s. Try loading that.", key, dep_input) + input_obj = LoadAllInputObj(config, dep_input, safe_only=safe_only, logger=logger) + # Now recurse to try this key again. + return LoadInputObj(config, key, num=num, safe_only=safe_only, logger=logger) if safe_only: logger.debug('file %d: caught exception: %s', file_num,e) safe = False @@ -507,7 +539,7 @@ def getKwargs(self, config, base, logger): """ req, opt, single, takes_rng = get_cls_params(self.init_func) kwargs, safe = GetAllParams(config, base, req=req, opt=opt, single=single) - if takes_rng: # pragma: no cover (We don't have any inputs that do this.) + if takes_rng: rng = GetRNG(config, base, logger, 'input '+self.init_func.__name__) kwargs['rng'] = rng safe = False diff --git a/tests/test_config_input.py b/tests/test_config_input.py index 6e3f0e30c4..32b3185890 100644 --- a/tests/test_config_input.py +++ b/tests/test_config_input.py @@ -288,6 +288,154 @@ def BuildAtmPSF(config, base, ignore, gsparams, logger): galsim.config.InputLoader(AtmPSF, use_proxy=False, worker_initargs=galsim.phase_screens.initWorkerArgs) +def test_dependent_inputs(): + """Test inputs that depend on other inputs. + + imSim has input types that depend on other input types. If these are listed in order, + with dependencies first, everything is fine. But when using recursive templates, this can + be difficult to get right. So we now have a check in GalSim when loading in input object, + if it raises an exception that indicates it needs a different input type, try to load that + one first. + """ + class Dict1: + def __init__(self): + self.d = {'a': 1, 'b': 2} + + def Dict1Item(config, base, value_type): + d = galsim.config.GetInputObj('dict1', config, base, 'Dict1').d + key, safe = galsim.config.ParseValue(config, 'key', base, str) + return d[key], safe + + class Dict2: + def __init__(self): + self.d = {'c': 1, 'd': 2} + + def Dict2Item(config, base, value_type): + d = galsim.config.GetInputObj('dict2', config, base, 'Dict2').d + key, safe = galsim.config.ParseValue(config, 'key', base, str) + return d[key], safe + + class Dep: + _req_params = dict(a=float, b=float, c=float, d=float) + + def __init__(self, a, b, c, d): + self.d = dict(a=a, b=b, c=c, d=d) + + def DepItem(config, base, value_type): + d = galsim.config.GetInputObj('dep', config, base, 'Dep').d + key, safe = galsim.config.ParseValue(config, 'key', base, str) + return d[key], safe + + galsim.config.RegisterInputType('dict1', galsim.config.InputLoader(Dict1)) + galsim.config.RegisterValueType('Dict1Item', Dict1Item, input_type='dict1', valid_types=[float]) + galsim.config.RegisterInputType('dict2', galsim.config.InputLoader(Dict2)) + galsim.config.RegisterValueType('Dict2Item', Dict2Item, input_type='dict2', valid_types=[float]) + galsim.config.RegisterInputType('dep', galsim.config.InputLoader(Dep)) + galsim.config.RegisterValueType('DepItem', DepItem, input_type='dep', valid_types=[float]) + + # First put the input items in order, so all dependencies are resolved before they are needed. + config = { + 'input': { + 'dict1': {}, + 'dict2': {}, + 'dep': { + 'a': {'type': 'Dict1Item', 'key': 'a'}, + 'b': {'type': 'Dict1Item', 'key': 'b'}, + 'c': {'type': 'Dict2Item', 'key': 'c'}, + 'd': {'type': 'Dict2Item', 'key': 'd'}, + }, + } + } + with CaptureLog() as cl: + galsim.config.ProcessInput(config, cl.logger) + assert 'input seems to depend on' not in cl.output + dep = galsim.config.GetInputObj('dep', config, config, 'Dep') + assert dep.d == dict(a=1, b=2, c=1, d=2) + + # Now put them out of order. + config = { + 'input': { + 'dep': { + 'a': {'type': 'Dict1Item', 'key': 'a'}, + 'b': {'type': 'Dict1Item', 'key': 'b'}, + 'c': {'type': 'Dict2Item', 'key': 'c'}, + 'd': {'type': 'Dict2Item', 'key': 'd'}, + }, + 'dict1': {}, + 'dict2': {}, + } + } + with CaptureLog() as cl: + galsim.config.ProcessInput(config, cl.logger) + assert 'dep input seems to depend on dict1' in cl.output + assert 'dep input seems to depend on dict2' in cl.output + dep = galsim.config.GetInputObj('dep', config, config, 'Dep') + assert dep.d == dict(a=1, b=2, c=1, d=2) + + # If it depends on something that in turn fails, it doesn't work. + class Dict3: + _opt_params = dict(e=float) + _takes_rng = True + def __init__(self, rng, e=None): + ud = galsim.UniformDeviate(rng) + self.d = dict(a=ud(), b=ud(), c=ud(), d=ud()) + if e: self.d['e'] = e + + def Dict3Item(config, base, value_type): + d = galsim.config.GetInputObj('dict3', config, base, 'Dict3').d + key, safe = galsim.config.ParseValue(config, 'key', base, str) + return d[key], safe + + galsim.config.RegisterInputType('dict3', galsim.config.InputLoader(Dict3)) + galsim.config.RegisterValueType('Dict3Item', Dict3Item, input_type='dict3', valid_types=[float]) + + config = { + 'input': { + 'dep': { + 'a': {'type': 'Dict1Item', 'key': 'a'}, + 'b': {'type': 'Dict3Item', 'key': 'b'}, + 'c': {'type': 'Dict2Item', 'key': 'c'}, + 'd': {'type': 'Dict3Item', 'key': 'd'}, + }, + 'dict1': {}, + 'dict2': {}, + 'dict3': { + 'e': {'type': 'Dict4Item', 'key': 'a'}, + } + } + } + with np.testing.assert_raises(galsim.config.GalSimConfigError): + galsim.config.ProcessInput(config, cl.logger) + + # But with safe_only=True, it doesn't raise an exception + config = galsim.config.CleanConfig(config) + with CaptureLog() as cl: + galsim.config.ProcessInput(config, cl.logger, safe_only=True) + + assert 'dep input seems to depend on dict1' in cl.output + assert 'dep input seems to depend on dict2' not in cl.output # Doesn't get to the Dict2Item + assert 'dep input seems to depend on dict3' in cl.output + dep = galsim.config.GetInputObj('dep', config, config, 'Dep') + assert dep is None + + # If the dependency graph is circular, make sure we don't get an infinite loop. + config = galsim.config.CleanConfig(config) + config['input']['dict3']['e'] = {'type': 'DepItem', 'key': 'a'} + galsim.config.ProcessInput(config, cl.logger, safe_only=True) + dep = galsim.config.GetInputObj('dep', config, config, 'Dep') + assert dep is None + + # Finally, just make sure that if dict3 isn't broken, it all works as expected. + config = galsim.config.CleanConfig(config) + config['input']['dict3']['e']['type'] = 'Dict1Item' + galsim.config.ProcessInput(config, cl.logger) + dict3 = galsim.config.GetInputObj('dict3', config, config, 'Dict3') + dep = galsim.config.GetInputObj('dep', config, config, 'Dep') + print('dict3.d = ',dict3.d) + print('dep.d = ',dep.d) + assert dep.d == dict(a=1, b=dict3.d['b'], c=1, d=dict3.d['d']) + + if __name__ == "__main__": testfns = [v for k, v in vars().items() if k[:5] == 'test_' and callable(v)] for testfn in testfns: