Merge pull request #248 from rmodrak/master
 Improved treatment of time shifts
rmodrak authored Feb 2, 2024
2 parents 3730282 + fdd2cda commit 90ad897
Showing 15 changed files with 733 additions and 224 deletions.
1 change: 1 addition & 0 deletions .github/workflows/python-app.yaml
@@ -41,6 +41,7 @@ jobs:
python tests/test_grid_search_mt.py --no_figures
python tests/test_grid_search_mt_depth.py --no_figures
python tests/test_greens_SPECFEM3D_SAC.py --no_figures
python tests/test_time_shifts.py --no_figures
# unfortunately, these Conda installation tests exceed the resource limits
# for GitHub workflows
2 changes: 1 addition & 1 deletion docs/install/issues.rst
@@ -44,7 +44,7 @@ MTUQ installation on Apple M1 and Apple M2 Macs

Installation on Apple M1 and Apple M2 Macs is now possible using the default installation procedure.

For older versions of MTUQ, a modified installation procedure may stil be necessary. For more information, please see:
For older versions of MTUQ, a modified installation procedure may still be necessary. For more information, please see:

`MTUQ installation on ARM64 systems <https://uafgeotools.github.io/mtuq/install/arm64.html>`_

25 changes: 24 additions & 1 deletion mtuq/dataset.py
@@ -207,6 +207,23 @@ def get_stations(self):
return stations


def get_stats(self):
""" Returns trace metadata in nested lists
.. note ::
For Datasets created using ``mtuq.io.readers``, SAC header metadata
is used to populate the Station attributes
"""
stats = []
for stream in self:
stats += [[]]
for trace in stream:
stats[-1] += [trace.stats]
return stats


def get_origins(self):
""" Returns origin metadata from all streams as a `list` of
`mtuq.event.Origin` objects
@@ -269,8 +286,12 @@ def __copy__(self):
return new_ds


def copy(self):
return self.__copy__()


def write(self, path, format='sac'):
""" Writes a Python pickle of current dataset
""" Writes dataset to disk
"""
if format.lower() == 'pickle':

@@ -289,5 +310,7 @@ def write(self, path, format='sac'):
fullpath = '%s/%s.%s' % (path,filename,'sac')
trace.write(fullpath, format='sac')

else:
raise ValueError('Unrecognized file format')
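
A minimal usage sketch for the `Dataset` methods touched above (`get_stats`, `copy`, and the extended `write`); the variable `data` and the output directory are hypothetical, and the dataset is assumed to have been read beforehand, e.g. with the readers in `mtuq.io.readers`:

    # hedged sketch -- assumes `data` is an existing mtuq.Dataset
    # get_stats() returns nested lists: one inner list per stream, one ObsPy
    # Stats object per trace
    stats = data.get_stats()
    first = stats[0][0]
    print(first.station, first.channel, first.npts)

    # shallow copy via the new copy() wrapper, then write one SAC file per trace;
    # format='pickle' is also accepted, and any other string raises
    # ValueError('Unrecognized file format')
    backup = data.copy()
    backup.write('output_dir', format='sac')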


8 changes: 4 additions & 4 deletions mtuq/graphics/annotations.py
@@ -49,10 +49,10 @@ def trace_label_writer(axis, dat, syn, total_misfit=1.):
d = dat.data

# display cross-correlation time shift
time_shift = 0.
time_shift += _getattr(syn, 'time_shift', np.nan)
time_shift += _getattr(dat, 'static_time_shift', 0)
axis.text(0.,(1/4.)*ymin, '%.2f' %time_shift, fontsize=11)
total_shift = 0.
total_shift += _getattr(syn, 'time_shift', np.nan)
total_shift += _getattr(dat, 'static_shift', 0)
axis.text(0.,(1/4.)*ymin, '%.2f' % total_shift, fontsize=11)

# display maximum cross-correlation coefficient
Ns = np.dot(s,s)**0.5
25 changes: 23 additions & 2 deletions mtuq/graphics/attrs.py
@@ -9,9 +9,30 @@
from mtuq.util import defaults, warn


def plot_time_shifts(dirname, attrs, stations, origin, **kwargs):
def plot_time_shifts(dirname, attrs, stations, origin, key='total_shift',
**kwargs):

""" Plots how time shifts vary by location and component
By default, total time shifts are plotted. To plot just static or
cross-correlation time shifts, use ``key='static_shift'`` or
``key='time_shift'``, respectively
.. note ::
MTUQ distinguishes between the following different types of
time shifts
- `static_shift` is an initial user-supplied time shift applied during
data processing
- `time_shift` is a subsequent cross-correlation time shift applied
during misfit evaluation
- `total_shift` is the total correction, or in other words the sum of
static and cross-correlation time shifts
.. rubric :: Required input arguments
``dirname`` (`str`):
@@ -37,7 +58,7 @@ def plot_time_shifts(dirname, attrs, stations, origin, **kwargs):
'label': 'Time shift (s)',
})

_plot_attrs(dirname, stations, origin, attrs, 'time_shift', **kwargs)
_plot_attrs(dirname, stations, origin, attrs, key, **kwargs)


def plot_amplitude_ratios(dirname, attrs, stations, origin, **kwargs):
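
The docstring above distinguishes three kinds of time shift; a short sketch of how the updated `plot_time_shifts` could be called with each `key` (directory names are illustrative, and `attrs`, `stations`, `origin` are assumed to come from an existing workflow, e.g. `attrs` from `Misfit.collect_attributes` further down in this diff):

    from mtuq.graphics.attrs import plot_time_shifts

    # total correction (static + cross-correlation), the new default
    plot_time_shifts('plots/total_shift', attrs, stations, origin)

    # only the static shifts applied during data processing
    plot_time_shifts('plots/static_shift', attrs, stations, origin,
        key='static_shift')

    # only the cross-correlation shifts applied during misfit evaluation
    plot_time_shifts('plots/time_shift', attrs, stations, origin,
        key='time_shift')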
5 changes: 3 additions & 2 deletions mtuq/graphics/waveforms.py
@@ -441,8 +441,9 @@ def _plot(axis, dat, syn, label=None):
"""
t1,t2,nt,dt = _time_stats(dat)

start = _getattr(syn, 'start', 0)
stop = _getattr(syn, 'stop', len(syn.data))
# which start and stop indices will correctly align synthetics?
start = _getattr(syn, 'idx_start', 0)
stop = _getattr(syn, 'idx_stop', len(syn.data))

t = np.linspace(0,t2-t1,nt,dt)
d = dat.data
2 changes: 1 addition & 1 deletion mtuq/greens_tensor/FK.py
@@ -48,7 +48,7 @@ def _precompute(self):

def _precompute_mt(self):
""" Recombines FK time series so they can be used in straightforward
liner combination with Mrr,Mtt,Mpp,Mrt,Mrp,Mtp
linear combination with Mrr,Mtt,Mpp,Mrt,Mrp,Mtp
"""

array = self._array
41 changes: 23 additions & 18 deletions mtuq/greens_tensor/base.py
@@ -142,24 +142,27 @@ def _get_shape(self):
return nc, nr, nt


def _allocate_stream(self):
def _allocate_stream(self, stats=None):
""" Allocates ObsPy stream used by `get_synthetics`
"""
nc, nr, nt = self._get_shape()

if not stats:
stats = []
for component in self.components:
stats += [self[0].stats.copy()]
stats[-1].update({'npts': nt, 'channel': component})

stream = Stream()
for component in self.components:
# add stats object
stats = self.station.copy()
stats.update({'npts': nt, 'channel': component})
for _i, component in enumerate(self.components):
# add trace object
stream += Trace(np.zeros(nt), stats)
stream += Trace(np.zeros(nt), stats[_i])

return stream



def get_synthetics(self, source, components=None, inplace=False):
def get_synthetics(self, source, components=None, stats=None, inplace=False):
""" Generates synthetics through a linear combination of time series
Returns an ObsPy stream
@@ -190,7 +193,7 @@ def get_synthetics(self, source, components=None, inplace=False):
if inplace:
synthetics = self._synthetics
else:
synthetics = self._allocate_stream()
synthetics = self._allocate_stream(stats)

for _i, component in enumerate(self.components):
# Even with careful attention to index order, np.dot is very slow.
@@ -284,7 +287,7 @@ def select(self, selector):
return selected


def get_synthetics(self, source, components=None, mode='apply', **kwargs):
def get_synthetics(self, source, components=None, stats=None, mode='apply', **kwargs):
""" Generates synthetics through a linear combination of time series
Returns an MTUQ `Dataset`
@@ -302,15 +305,15 @@ def get_synthetics(self, source, components=None, mode='apply', **kwargs):
if mode=='map':
synthetics = Dataset()
for _i, tensor in enumerate(self):
synthetics.append(
tensor.get_synthetics(source, components=components[_i], **kwargs))
synthetics.append(tensor.get_synthetics(
source, components=components[_i], stats=stats[_i], **kwargs))
return synthetics

elif mode=='apply':
synthetics = Dataset()
for tensor in self:
synthetics.append(
tensor.get_synthetics(source, components=components, **kwargs))
synthetics.append(tensor.get_synthetics(
source, components=components, stats=stats, **kwargs))
return synthetics

else:
@@ -335,8 +338,8 @@ def apply(self, function, *args, **kwargs):
"""
processed = []
for tensor in self:
processed +=\
[function(tensor, *args, **kwargs)]
processed += [function(tensor, *args, **kwargs)]

return self.__class__(processed)


@@ -359,8 +362,8 @@ def map(self, function, *sequences):
processed = []
for _i, tensor in enumerate(self):
args = [sequence[_i] for sequence in sequences]
processed +=\
[function(tensor, *args)]
processed += [function(tensor, *args)]

return self.__class__(processed)


@@ -416,7 +419,6 @@ def parallel_map(self, function, *sequences):
return self.__class__(final_list)



def convolve(self, wavelet):
""" Convolves time series with given wavelet
@@ -486,6 +488,9 @@ def __copy__(self):
return new_ds


def copy(self):
return self.__copy__()


def write(self, filename):
""" Writes a Python pickle of current `GreensTensorList`
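
A brief sketch of the new `stats` argument accepted by `get_synthetics`, mirroring the calls added in `mtuq/misfit/waveform/__init__.py` below; `data`, `greens`, and `source` are assumed to exist already in a typical MTUQ workflow:

    # hedged sketch -- pass observed-trace headers through to the synthetics, so
    # each allocated stream reuses the corresponding data headers rather than
    # copies of the station metadata
    synthetics = greens.get_synthetics(
        source,
        components=data.get_components(),  # one component list per station
        stats=data.get_stats(),            # one list of ObsPy Stats per station
        mode='map')                        # 'map' pairs the _i-th entries with the _i-th tensor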
11 changes: 8 additions & 3 deletions mtuq/misfit/waveform/__init__.py
@@ -132,6 +132,9 @@ def __init__(self,
assert norm in ['L1', 'L2', 'hybrid'],\
ValueError("Bad input argument: norm")

assert time_shift_min <= 0.,\
ValueError("Bad input argument: time_shift_min")

assert time_shift_max >= 0.,\
ValueError("Bad input argument: time_shift_max")

@@ -176,7 +179,7 @@ def __call__(self, data, greens, sources, progress_handle=Null(),
warn("Empty data set. No misfit evaluations will be carried out")
return np.zeros((len(sources), 1))

# checks that the container legnths are consistent
# checks that the container lengths are consistent
if len(data) != len(greens):
raise Exception("Inconsistent container lengths\n\n "+
"len(data): %d\n len(greens): %d\n" %
@@ -219,7 +222,8 @@ def collect_attributes(self, data, greens, source):
check_padding(greens, self.time_shift_min, self.time_shift_max)

synthetics = greens.get_synthetics(
source, components=data.get_components(), mode='map', inplace=True)
source, components=data.get_components(), stats=data.get_stats(),
mode='map', inplace=True)

# attaches attributes to synthetics
_ = level0.misfit(
@@ -254,7 +258,8 @@ def collect_synthetics(self, data, greens, source):
check_padding(greens, self.time_shift_min, self.time_shift_max)

synthetics = greens.get_synthetics(
source, components=data.get_components(), mode='map', inplace=True)
source, components=data.get_components(), stats=data.get_stats(),
mode='map', inplace=True)

# attaches attributes to synthetics
_ = level0.misfit(
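
The new assertions require `time_shift_min <= 0` and `time_shift_max >= 0`; below is a hedged sketch of a misfit configuration that satisfies them, followed by the `collect_attributes` call updated above. The import path and the `time_shift_groups` value reflect typical MTUQ usage but are given only for illustration; `data`, `greens`, and `best_source` are assumed to exist already:

    from mtuq.misfit import Misfit   # assumed import path for the class defined above

    # cross-correlation shifts are searched over [-5, +5] s; under the new checks
    # time_shift_min must be non-positive and time_shift_max non-negative
    misfit = Misfit(
        norm='L2',
        time_shift_min=-5.,
        time_shift_max=+5.,
        time_shift_groups=['ZR', 'T'])   # illustrative component grouping

    # per-trace attributes (static_shift, time_shift, total_shift, ...), now built
    # internally with stats=data.get_stats() so synthetics carry observed-trace headers
    attrs = misfit.collect_attributes(data, greens, best_source)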
