diff --git a/eqcorrscan/core/match_filter/detection.py b/eqcorrscan/core/match_filter/detection.py index 9ac074dd..4b11f380 100644 --- a/eqcorrscan/core/match_filter/detection.py +++ b/eqcorrscan/core/match_filter/detection.py @@ -397,7 +397,8 @@ def extract_stream(self, stream, length, prepick, all_vert=False, pick = [p for p in pick if p.waveform_id.channel_code == channel] if len(pick) == 0: - Logger.info("No pick for {0}.{1}".format(station, channel)) + Logger.info( + "No pick for {0}.{1}".format(station, channel)) continue elif len(pick) > 1: Logger.info( @@ -406,13 +407,28 @@ def extract_stream(self, stream, length, prepick, all_vert=False, pick.sort(key=lambda p: p.time) pick = pick[0] cut_start = pick.time - prepick - cut_end = cut_start + length - _st = _st.slice(starttime=cut_start, endtime=cut_end).copy() - # Minimum length check + # Find nearest sample to avoid to too-short length - see #573 for tr in _st: - if abs((tr.stats.endtime - tr.stats.starttime) - + sample_offset = (cut_start - + tr.stats.starttime) * tr.stats.sampling_rate + Logger.debug( + f"Sample offset for slice on {tr.id}: {sample_offset}") + sample_offset //= 1 + # If the sample offset is not a whole number, always take the + # sample before that requested + _tr_cut_start = tr.stats.starttime + ( + sample_offset * tr.stats.delta) + _tr_cut_end = _tr_cut_start + length + Logger.debug( + f"Trimming {tr.id} between {_tr_cut_end} " + f"and {_tr_cut_end}.") + _tr = tr.slice(_tr_cut_start, _tr_cut_end).copy() + Logger.debug( + f"Length: {(_tr.stats.endtime - _tr.stats.starttime)}") + Logger.debug(f"Requested length: {length}") + if abs((_tr.stats.endtime - _tr.stats.starttime) - length) < tr.stats.delta: - cut_stream += tr + cut_stream += _tr else: Logger.info( "Insufficient data length for {0}".format(tr.id)) diff --git a/eqcorrscan/tests/lag_calc_test.py b/eqcorrscan/tests/lag_calc_test.py index 97f33c7c..4a660a3f 100644 --- a/eqcorrscan/tests/lag_calc_test.py +++ b/eqcorrscan/tests/lag_calc_test.py @@ -26,7 +26,6 @@ class SyntheticTests(unittest.TestCase): def setUpClass(cls): np.random.seed(999) print("Setting up class") - np.random.seed(999) samp_rate = 50 t_length = .75 # Make some synthetic templates @@ -34,6 +33,7 @@ def setUpClass(cls): nsta=5, ntemplates=5, nseeds=10, samp_rate=samp_rate, t_length=t_length, max_amp=10, max_lag=15, phaseout="both", jitter=0, noise=False, same_phase=True) + print("Made synthetic data") # Rename channels channel_mapper = {"SYN_Z": "HHZ", "SYN_H": "HHN"} for tr in data: @@ -44,6 +44,7 @@ def setUpClass(cls): party = Party() t = 0 data_start = data[0].stats.starttime + print("Making party") for template, template_seeds in zip(templates, seeds): template_name = "template_{0}".format(t) detections = [] @@ -68,6 +69,8 @@ def setUpClass(cls): family = Family(template=_template, detections=detections) party += family t += 1 + print(f"Made template {template_name}") + print("Made party") cls.party = party cls.data = data cls.t_length = t_length diff --git a/eqcorrscan/tests/matched_filter/match_filter_test.py b/eqcorrscan/tests/matched_filter/match_filter_test.py index fd9cb3e1..9f8ebca1 100644 --- a/eqcorrscan/tests/matched_filter/match_filter_test.py +++ b/eqcorrscan/tests/matched_filter/match_filter_test.py @@ -1606,6 +1606,52 @@ def test_family_catalogs(self): family.detections.append(additional_detection) self.assertEqual(family.catalog, get_catalog(family.detections)) + def test_detection_extract_stream(self): + # Create simple synthetic stream + traces = [] + pick_sids = {p.waveform_id.get_seed_string() for f in self.party + for d in f for p in d.event.picks} + pick_times = sorted([p.time for f in self.party for d in f + for p in d.event.picks]) + first_pick, last_pick = pick_times[0], pick_times[-1] + delta = self.party[0].template.st[0].stats.delta + n_samples = int(((last_pick - first_pick) + 120) / delta) + data = np.arange(n_samples, dtype=int) + for sid in pick_sids: + tr = Trace(data=data.copy()) + tr.stats.starttime = first_pick - 60 + n, s, l, c = sid.split('.') + tr.stats.delta = delta + tr.stats.network = n + tr.stats.station = s + tr.stats.location = l + tr.stats.channel = c + traces.append(tr) + st = Stream(traces=traces) + + # Test straightforward extraction + detection = self.party[0][0] + length, pre_pick = 40.0, 1 / delta + for shift in range(6): + pre_pick -= shift * (delta / 6) # Sub-sample shifting + expected_starts = { + p.waveform_id.get_seed_string(): p.time - pre_pick + for p in detection.event.picks} + cut_st = detection.extract_stream( + stream=st, length=length, prepick=pre_pick) + for sid, expected_start in expected_starts.items(): + Logger.debug(f"Checking for {sid}") + tr = cut_st.select(id=sid) + # Check that we get a returned trace + self.assertTrue(len(tr), 1) + tr = tr[0] + # Check that start is within one sample of the expected start + self.assertLess(abs(tr.stats.starttime - expected_start), + delta) + # Check that the length is correct + returned_length = tr.stats.endtime - tr.stats.starttime + self.assertEqual(length, returned_length) + class TestTemplateGrouping(unittest.TestCase): @classmethod