Skip to content

Commit

Permalink
Merge pull request #466 from LSSTDESC/instcat_delay_read
Browse files Browse the repository at this point in the history
Delay read of instcat until data is needed
  • Loading branch information
rmjarvis authored Jun 6, 2024
2 parents d580f71 + feac1e0 commit 25f868b
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 28 deletions.
76 changes: 55 additions & 21 deletions imsim/instcat.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,19 @@ class InstCatalog(object):
fnu = (0 * u.ABmag).to(u.erg/u.s/u.cm**2/u.Hz)
_flux_density = fnu.to_value(u.ph/u.nm/u.s/u.cm**2, u.spectral_density(500*u.nm))
def __init__(self, file_name, wcs, xsize=4096, ysize=4096, sed_dir=None,
edge_pix=100, sort_mag=True, flip_g2=True, approx_nobjects=None,
edge_pix=100, sort_mag=True, flip_g2=True,
pupil_area=RUBIN_AREA, min_source=None, skip_invalid=True,
logger=None):
logger = galsim.config.LoggerWrapper(logger)
self.file_name = file_name
self.wcs = wcs
self.xsize = xsize
self.ysize = ysize
self.edge_pix = edge_pix
self.sort_mag = sort_mag
self.flip_g2 = flip_g2
self.approx_nobjects = approx_nobjects
self.min_source = min_source
self.skip_invalid = skip_invalid
self.pupil_area = pupil_area
self._sed_cache = {}

Expand All @@ -185,8 +191,20 @@ def __init__(self, file_name, wcs, xsize=4096, ysize=4096, sed_dir=None,
self.sed_dir = sed_dir
self.inst_dir = os.path.dirname(file_name)

self._id = None # Sentinal that _load hasn't been run yet.

@property
def id(self):
self._load()
return self._id

def _load(self, logger=None):
if self._id is not None:
return

logger = galsim.config.LoggerWrapper(logger)
min_ra, max_ra, min_dec, max_dec, min_x, min_y, max_x, max_y, ref_ra \
= get_radec_limits(wcs, xsize, ysize, logger, edge_pix)
= get_radec_limits(self.wcs, self.xsize, self.ysize, logger, self.edge_pix)

# What position do the dust parameters start, based on object type.
dust_index_dict = {
Expand All @@ -204,7 +222,7 @@ def __init__(self, file_name, wcs, xsize=4096, ysize=4096, sed_dir=None,
lens_list = []
objinfo_list = []
dust_list = []
g2_sign = -1 if flip_g2 else 1
g2_sign = -1 if self.flip_g2 else 1
logger.warning('Reading instance catalog %s', self.file_name)
nuse = 0
ntot = 0
Expand All @@ -229,7 +247,7 @@ def __init__(self, file_name, wcs, xsize=4096, ysize=4096, sed_dir=None,
world_pos = galsim.CelestialCoord(ra, dec)
#logger.debug('world_pos = %s',world_pos)
try:
image_pos = wcs.toImage(world_pos)
image_pos = self.wcs.toImage(world_pos)
except RuntimeError as e:
# Inverse direction can fail for objects off the image.
logger.debug('%s',e)
Expand All @@ -251,7 +269,7 @@ def __init__(self, file_name, wcs, xsize=4096, ysize=4096, sed_dir=None,
objinfo = tokens[12:dust_index]
dust = tokens[dust_index:]

if skip_invalid:
if self.skip_invalid:
# Check for some reasons to skip this object.
object_is_valid = (magnorm < 50.0 and
not (objinfo[0] == 'sersic2d' and
Expand Down Expand Up @@ -281,7 +299,7 @@ def __init__(self, file_name, wcs, xsize=4096, ysize=4096, sed_dir=None,
logger.warning("No objects found on image")

# Sort the object lists by mag and convert to numpy arrays.
self.id = np.array(id_list, dtype=str)
self._id = np.array(id_list, dtype=str)
self.world_pos = np.array(world_pos_list, dtype=object)
self.image_pos = np.array(image_pos_list, dtype=object)
self.magnorm = np.array(magnorm_list, dtype=float)
Expand All @@ -290,11 +308,11 @@ def __init__(self, file_name, wcs, xsize=4096, ysize=4096, sed_dir=None,
self.objinfo = np.array(objinfo_list, dtype=object)
self.dust = np.array(dust_list, dtype=object)

if min_source is not None:
if self.min_source is not None:
nsersic = np.sum([params[0].lower() == 'sersic2d' for params in self.objinfo])
if nsersic < min_source:
logger.warning(f"Fewer than {min_source} galaxies on sensor. Skipping.")
self.id = self.id[:0]
if nsersic < self.min_source:
logger.warning(f"Fewer than {self.min_source} galaxies on sensor. Skipping.")
self._id = self._id[:0]
self.world_pos = self.world_pos[:0]
self.image_pos = self.image_pos[:0]
self.magnorm = self.magnorm[:0]
Expand All @@ -303,9 +321,9 @@ def __init__(self, file_name, wcs, xsize=4096, ysize=4096, sed_dir=None,
self.objinfo = self.objinfo[:0]
self.dust = self.dust[:0]

if sort_mag:
if self.sort_mag:
index = np.argsort(self.magnorm)
self.id = self.id[index]
self._id = self._id[index]
self.world_pos = self.world_pos[index]
self.image_pos = self.image_pos[index]
self.magnorm = self.magnorm[index]
Expand All @@ -316,11 +334,20 @@ def __init__(self, file_name, wcs, xsize=4096, ysize=4096, sed_dir=None,
logger.warning("Sorted objects by magnitude (brightest first).")

def getNObjects(self, logger=None):
self._load(logger)
# Note: This method name is required by the config parser.
return len(self.id)
return len(self._id)

def getApproxNObjects(self, logger=None):
return self.approx_nobjects or self.getNObjects()
if self._id is None:
# If we haven't read the file yet, just (over-)estimate the number by
# quickly counting the lines in the file without doing any processing.
with fopen(self.file_name, mode='rt') as _input:
# generators don't implement len(); this is the tricky workaround that doesn't
# store all the data in memory (like len(list(_input))).
return sum(1 for _ in _input)
else:
return self.getNObjects(logger)

@property
def nobjects(self):
Expand All @@ -331,18 +358,23 @@ def nobjects(self):
# attribute directly. Since input objects such as this are used via proxy in
# multiprocessing contexts, we need to keep the method version around.
def getID(self, index):
return self.id[index]
self._load()
return self._id[index]

def getWorldPos(self, index):
self._load()
return self.world_pos[index]

def getImagePos(self, index):
self._load()
return self.image_pos[index]

def getMagNorm(self, index):
self._load()
return self.magnorm[index]

def getSED(self, index):
self._load()
# These require reading in an input file. So cache the raw (unredshifted versions)
# to try to minimize how much I/O we'll need for these.
name, redshift = self.sed[index]
Expand Down Expand Up @@ -395,6 +427,7 @@ def getSED(self, index):
return sed

def getLens(self, index):
self._load()
# The galsim.lens(...) function wants to be passed reduced
# shears and magnification, so convert the WL parameters as
# defined in phosim instance catalogs to these values. See
Expand All @@ -407,6 +440,7 @@ def getLens(self, index):
return g1,g2,mu

def getDust(self, index):
self._load()
params = self.dust[index]
if params[0].lower() != 'none':
internal_av = float(params[1])
Expand All @@ -426,7 +460,8 @@ def getDust(self, index):

return internal_av, internal_rv, galactic_av, galactic_rv

def getObj(self, index, gsparams=None, rng=None, exptime=30):
def getObj(self, index, gsparams=None, rng=None, exptime=30, logger=None):
self._load(logger)
if self.objinfo.size == 0:
raise RuntimeError("Trying to get an object from an empty instance catalog")
params = self.objinfo[index]
Expand Down Expand Up @@ -534,7 +569,7 @@ def InstCatObj(config, base, ignore, gsparams, logger):
# Setup the indexing sequence if it hasn't been specified.
# The normal thing with a catalog is to just use each object in order,
# so we don't require the user to specify that by hand. We can do it for them.
galsim.config.SetDefaultIndex(config, inst.getNObjects())
galsim.config.SetDefaultIndex(config, inst.getNObjects(logger))

req = { 'index' : int }
opt = { 'num' : int }
Expand All @@ -544,7 +579,7 @@ def InstCatObj(config, base, ignore, gsparams, logger):
rng = galsim.config.GetRNG(config, base, logger, 'InstCatObj')
exptime = base.get('exptime', 30)

obj = inst.getObj(index, gsparams=gsparams, rng=rng, exptime=exptime)
obj = inst.getObj(index, gsparams=gsparams, rng=rng, exptime=exptime, logger=logger)
base['object_id'] = inst.getID(index)

return obj, safe
Expand Down Expand Up @@ -581,7 +616,7 @@ def buildSED(self, config, base, logger):
"""
inst = galsim.config.GetInputObj('instance_catalog', config, base, 'InstCatWorldPos')

galsim.config.SetDefaultIndex(config, inst.getNObjects())
galsim.config.SetDefaultIndex(config, inst.getNObjects(logger))

req = { 'index' : int }
opt = { 'num' : int }
Expand All @@ -602,7 +637,6 @@ def getKwargs(self, config, base, logger):
'edge_pix' : float,
'sort_mag' : bool,
'flip_g2' : bool,
'approx_nobjects' : int,
'pupil_area' : float,
'min_source' : int,
'skip_invalid' : bool,
Expand Down
14 changes: 7 additions & 7 deletions tests/test_instcat_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,14 +289,14 @@ def test_object_extraction_galaxies(self):
# Note: the truth catalog apparently didn't flip the g2 values, so use flip_g2=False.
cat = all_cats[det_name] = imsim.InstCatalog(galaxy_phosim_file, all_wcs[det_name],
sed_dir=sed_dir, edge_pix=0, flip_g2=False)
print(det_name, cat.getNObjects(), cat.getApproxNObjects())
assert cat.getApproxNObjects() == cat.getNObjects()
approx_nobj = cat.getApproxNObjects() # This is only different from getNObjects
# if it is called first.
nobj = cat.getNObjects()
print(det_name, nobj, approx_nobj)
assert approx_nobj > nobj

cat2 = imsim.InstCatalog(galaxy_phosim_file, all_wcs[det_name],
sed_dir=sed_dir, edge_pix=0, flip_g2=False,
approx_nobjects=10**5)
assert cat2.getApproxNObjects() == 10**5
assert cat2.getNObjects() == cat.getNObjects()
# After loading, the two values are equal.
assert cat.getApproxNObjects() == cat.getNObjects()

id_arr = np.concatenate([cat.id for cat in all_cats.values()])
print('diff1 = ',set(truth_data['uniqueId'])-set(id_arr))
Expand Down

0 comments on commit 25f868b

Please sign in to comment.