Skip to content

Commit

Permalink
Merge pull request #261 from uafgeotools/waveform_normalization
Browse files Browse the repository at this point in the history
Implementation of median amplitude normalization.
  • Loading branch information
rmodrak authored Apr 6, 2024
2 parents 5ebdd79 + da8ea0e commit 109b4e5
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 28 deletions.
2 changes: 1 addition & 1 deletion data/examples/unpack.bash
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ cd $(dirname ${BASH_SOURCE[0]})
wd=$PWD

for filename in \
20090407201255351.tgz 20210809074550.tgz 20SPECFEM3D_SGT.tgz SPECFEM3D_SAC.tgz;
20090407201255351.tgz 20210809074550.tgz SPECFEM3D_SGT.tgz SPECFEM3D_SAC.tgz;
do
cd $wd
cd $(dirname $filename)
Expand Down
6 changes: 4 additions & 2 deletions mtuq/graphics/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from matplotlib import pyplot
from obspy.geodetics import gps2dist_azimuth
# location to degree distance with obspy
from obspy.geodetics import locations2degrees


def station_label_writer(ax, station, origin, units='km'):
Expand Down Expand Up @@ -31,7 +33,8 @@ def station_label_writer(ax, station, origin, units='km'):
label = '%d km' % round(distance_in_m/1000.)

elif units=='deg':
label = '%d%s' % (round(m_to_deg(distance_in_m)), u'\N{DEGREE SIGN}')
label = '%d%s' % (round(locations2degrees(origin.latitude, origin.longitude,
station.latitude, station.longitude)), u'\N{DEGREE SIGN}')

pyplot.text(0.2,0.35, label, fontsize=11, transform=ax.transAxes)

Expand Down Expand Up @@ -89,4 +92,3 @@ def _getattr(trace, name, *args):
else:
raise TypeError("Wrong number of arguments")


2 changes: 1 addition & 1 deletion mtuq/graphics/header.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def parse_data_processing(self):
if not self.process_bw:
pass
if not self.process_sw:
raise Excpetion()
raise Exception()

if self.process_sw.freq_max > 1.:
units = 'Hz'
Expand Down
114 changes: 91 additions & 23 deletions mtuq/graphics/waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,17 @@ def plot_waveforms1(filename,

max_amplitude = _max(data, synthetics)

if normalize == 'median_amplitude':
# Using the updated _median_amplitude function to calculate the median of non-zero maximum amplitudes
max_amplitude_median = _median_amplitude(data, synthetics)
max_amplitudes = np.array([max_amplitude_median if len(data[i]) > 0 and len(synthetics[i]) > 0 else 0.0 for i in range(len(data))])
elif normalize == 'maximum_amplitude':
max_amplitudes = np.array([max_amplitude if len(data[i]) > 0 and len(synthetics[i]) > 0 else 0.0 for i in range(len(data))])
elif normalize == 'station_amplitude' or normalize == 'trace_amplitude':
pass
else:
raise ValueError("Invalid normalization method specified.")

#
# loop over stations
#
Expand Down Expand Up @@ -100,7 +111,7 @@ def plot_waveforms1(filename,
continue

_plot_ZRT(axes[ir], 1, dat, syn, component,
normalize, trace_label_writer, max_amplitude, total_misfit)
normalize, trace_label_writer, max_amplitudes[_i], total_misfit)

ir += 1

Expand Down Expand Up @@ -153,6 +164,23 @@ def plot_waveforms2(filename,
max_amplitude_sw = _max(data_sw, synthetics_sw)


if normalize == 'median_amplitude':
# For body wave data and synthetics
bw_median = _median_amplitude(data_bw, synthetics_bw)
max_amplitudes_bw = np.array([bw_median if len(data_bw[i]) > 0 and len(synthetics_bw[i]) > 0 else 0.0 for i in range(len(data_bw))])

# For surface wave data and synthetics
sw_median = _median_amplitude(data_sw, synthetics_sw)
max_amplitudes_sw = np.array([sw_median if len(data_sw[i]) > 0 and len(synthetics_sw[i]) > 0 else 0.0 for i in range(len(data_sw))])
elif normalize == 'maximum_amplitude':
max_amplitudes_bw = np.array([max_amplitude_bw if len(data_bw[i]) > 0 and len(synthetics_bw[i]) > 0 else 0.0 for i in range(len(data_bw))])
max_amplitudes_sw = np.array([max_amplitude_sw if len(data_sw[i]) > 0 and len(synthetics_sw[i]) > 0 else 0.0 for i in range(len(data_sw))])
elif normalize == 'station_amplitude' or normalize == 'trace_amplitude':
max_amplitudes_bw = np.array([_max(data_bw[i], synthetics_bw[i]) if len(data_bw[i]) > 0 and len(synthetics_bw[i]) > 0 else 0.0 for i in range(len(data_bw))])
max_amplitudes_sw = np.array([_max(data_sw[i], synthetics_sw[i]) if len(data_sw[i]) > 0 and len(synthetics_sw[i]) > 0 else 0.0 for i in range(len(data_sw))])
else:
raise ValueError("Invalid normalization method specified.")

#
# loop over stations
#
Expand Down Expand Up @@ -191,7 +219,7 @@ def plot_waveforms2(filename,
continue

_plot_ZR(axes[ir], 1, dat, syn, component,
normalize, trace_label_writer, max_amplitude_bw, total_misfit_bw)
normalize, trace_label_writer, max_amplitudes_bw[_i], total_misfit_bw)


#
Expand All @@ -216,7 +244,7 @@ def plot_waveforms2(filename,
continue

_plot_ZRT(axes[ir], 3, dat, syn, component,
normalize, trace_label_writer, max_amplitude_sw, total_misfit_sw)
normalize, trace_label_writer, max_amplitudes_sw[_i], total_misfit_sw)


ir += 1
Expand Down Expand Up @@ -373,7 +401,7 @@ def _initialize(nrows=None, ncolumns=None, column_width_ratios=None,

def _plot_ZRT(axes, ic, dat, syn, component,
normalize='maximum_amplitude', trace_label_writer=None,
max_amplitude=1., total_misfit=1.):
normalization_amplitude=1., total_misfit=1.):

# plot traces
if component=='Z':
Expand All @@ -387,17 +415,13 @@ def _plot_ZRT(axes, ic, dat, syn, component,

_plot(axis, dat, syn)

# normalize amplitude
# normalize amplitude -- logic for station_amplitude, median_amplitude, and maximum_amplitude is done at higher level
if normalize=='trace_amplitude':
max_trace = _max(dat, syn)
ylim = [-1.5*max_trace, +1.5*max_trace]
axis.set_ylim(*ylim)
elif normalize=='station_amplitude':
max_stream = _max(stream_dat, stream_syn)
ylim = [-1.5*max_stream, +1.5*max_stream]
axis.set_ylim(*ylim)
elif normalize=='maximum_amplitude':
ylim = [-0.75*max_amplitude, +0.75*max_amplitude]
elif normalize=='station_amplitude' or normalize=='median_amplitude' or normalize=='maximum_amplitude':
ylim = [-1.25*normalization_amplitude, +1.25*normalization_amplitude]
axis.set_ylim(*ylim)

if trace_label_writer is not None:
Expand All @@ -406,7 +430,7 @@ def _plot_ZRT(axes, ic, dat, syn, component,

def _plot_ZR(axes, ic, dat, syn, component,
normalize='maximum_amplitude', trace_label_writer=None,
max_amplitude=1., total_misfit=1.):
normalization_amplitude=1., total_misfit=1.):

# plot traces
if component=='Z':
Expand All @@ -418,20 +442,15 @@ def _plot_ZR(axes, ic, dat, syn, component,

_plot(axis, dat, syn)

# normalize amplitude
# normalize amplitude -- logic for station_amplitude, median_amplitude, and maximum_amplitude is done at higher level
if normalize=='trace_amplitude':
max_trace = _max(dat, syn)
ylim = [-1.5*max_trace, +1.5*max_trace]
axis.set_ylim(*ylim)
elif normalize=='station_amplitude':
max_stream = _max(stream_dat, stream_syn)
ylim = [-1.5*max_stream, +1.5*max_stream]
axis.set_ylim(*ylim)
elif normalize=='maximum_amplitude':
ylim = [-0.75*max_amplitude, +0.75*max_amplitude]
elif normalize=='station_amplitude' or normalize=='median_amplitude' or normalize=='maximum_amplitude':
ylim = [-1.25*normalization_amplitude, +1.25*normalization_amplitude]
axis.set_ylim(*ylim)


if trace_label_writer is not None:
trace_label_writer(axis, dat, syn, total_misfit)

Expand All @@ -450,9 +469,9 @@ def _plot(axis, dat, syn, label=None):
s = syn.data

axis.plot(t, d, 'k', linewidth=1.5,
clip_on=False, zorder=10)
clip_on=True, zorder=10)
axis.plot(t, s[start:stop], 'r', linewidth=1.25,
clip_on=False, zorder=10)
clip_on=True, zorder=10)


def _add_component_labels1(axes, body_wave_labels=True, surface_wave_labels=True):
Expand Down Expand Up @@ -534,6 +553,19 @@ def _isempty(dataset):


def _max(dat, syn):
"""
Computes the maximum value from a set of two input data objects (observed and synthetics).
Parameters:
dat (Trace, Stream, or Dataset): observed data.
syn (Trace, Stream, or Dataset): synthetics.
Returns:
float: The maximum value for normalization purposes.
Raises:
TypeError: If the input objects are not of the same type (Trace, Stream, or Dataset).
"""
if type(dat)==type(syn)==Trace:
return max(
abs(dat.max()),
Expand All @@ -552,6 +584,43 @@ def _max(dat, syn):
else:
raise TypeError

def _median_amplitude(data, synthetics):
    """
    Return the median peak amplitude over paired observed/synthetic entries.

    Each (observed, synthetic) pair contributes the value returned by
    ``_max`` for that pair; a pair in which either member is falsy
    (e.g. empty) contributes zero instead. The median is then taken over
    the strictly positive contributions only, so placeholder zeros do not
    drag the result down.

    Args:
        data: Sequence of observed entries, or a single ``Trace`` (which
            is wrapped in a one-element list).
        synthetics: Sequence of synthetic entries matching ``data``
            one-to-one, or a single ``Trace``.

    Returns:
        The median of the positive per-pair amplitudes, or ``0.0`` when
        no pair yields a positive amplitude.

    Raises:
        ValueError: If ``data`` and ``synthetics`` differ in length.
    """
    # A bare Trace is promoted to a one-element list so that the pairing
    # logic below is uniform.
    if isinstance(data, Trace):
        data = [data]
    if isinstance(synthetics, Trace):
        synthetics = [synthetics]

    if len(data) != len(synthetics):
        raise ValueError("Data and synthetics lists must have the same length.")

    # One amplitude per pair; falsy observed or synthetic entries are
    # recorded as 0 so they can be filtered out afterwards.
    peaks = np.array([
        _max(observed, synthetic) if observed and synthetic else 0
        for observed, synthetic in zip(data, synthetics)
    ])

    positive = peaks > 0
    if not np.any(positive):
        return 0.0
    return np.median(peaks[positive])


def _hide_axes(axes):
Expand Down Expand Up @@ -597,4 +666,3 @@ def _get_tag(tags, pattern):
else:
return None


2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def run_tests(self):
# (consider using a conda based installation instead)
install_requires=[
"numpy",
"scipy",
"scipy<1.13.0",
"pandas",
"xarray",
"netCDF4",
Expand Down

0 comments on commit 109b4e5

Please sign in to comment.