Skip to content

Commit

Permalink
streamlined masking subsurface plots by confidence, burned-in datapoi…
Browse files Browse the repository at this point in the history
…nt, and the doi. All optional or additive
  • Loading branch information
leonfoks committed Jan 10, 2024
1 parent 1c3641a commit 8518053
Showing 1 changed file with 65 additions and 33 deletions.
98 changes: 65 additions & 33 deletions geobipy/src/inversion/Inference2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def mesh(self):
mesh = hdfRead.read_item(self.hdf_file['/model/mesh/y/edges/posterior/mesh'], skip_posterior=True)

# Change positive depth to negative height
mesh.y.edges = StatArray.StatArray(-out.y.edges, name='Height', units=self.y.units)
mesh.y.edges = StatArray.StatArray(-mesh.y.edges, name='Height', units=self.y.units)
mesh.y.relativeTo = self.elevation
mesh.x.centres = self.longest_coordinate

Expand Down Expand Up @@ -470,7 +470,8 @@ def compute_median_parameter(self, log=None, track=True):
def compute_mode_parameter(self, log=None, track=True):

posterior = self.parameter_posterior()
mean = posterior.mode(axis=1)

mode = posterior.mode(axis=1)

# if self.mode == 'r+':
# key = 'mean_parameter'
Expand All @@ -481,7 +482,7 @@ def compute_mode_parameter(self, log=None, track=True):
# self.hdf_file[key].attrs['name'] = mean.values.name
# self.hdf_file[key].attrs['units'] = mean.values.units

return mean
return mode


def compute_doi(self, percent=67.0, smooth=None, track=True):
Expand Down Expand Up @@ -1170,10 +1171,16 @@ def plot_burned_in(self, **kwargs):

x = self.axis(kwargs.pop('x', 'x'))
cmap = plt.get_cmap(kwargs.pop('cmap', 'cividis'))
labels = kwargs.pop('labels', True)

plt.fill_between(x, self.burned_in, 0.0, step='mid', color=cmap(1.0), label="")
plt.fill_between(x, 1-self.burned_in, 0.0, step='mid', color=cmap(0.0), label="")
ax = kwargs.pop('ax', plt.gca())

ylim = (0.0, 1.0)
if kwargs.pop('underlay', False):
kwargs['alpha'] = kwargs.get('alpha', 0.5)
ylim = ax.get_ylim()

plt.fill_between(x, ylim[1], ylim[0], step='mid', color=cmap(1.0), label="", **kwargs)
plt.fill_between(x, (ylim[1]-ylim[0])*(1-self.burned_in)+ylim[0], ylim[0], step='mid', color=cmap(0.0), label="", **kwargs)


def plot_channel_saturation(self, **kwargs):
Expand All @@ -1182,7 +1189,6 @@ def plot_channel_saturation(self, **kwargs):
labels = kwargs.pop('labels', True)
kwargs['color'] = kwargs.pop('color', 'k')
kwargs['linewidth'] = kwargs.pop('linewidth', 0.5)
kwargs['ylim'] = [kwargs.pop('ymin', 0.0), kwargs.pop('ymax', 1.0)]

self.data.plot(values=self.data.channel_saturation, **kwargs)

Expand Down Expand Up @@ -1219,15 +1225,13 @@ def plot_doi(self, **kwargs):
self.data.plot(values=self.doi, axis=1, **kwargs)

def plot_elevation(self, **kwargs):

kwargs['x'] = kwargs.pop('x', 'x')
labels = kwargs.pop('labels', True)
kwargs['color'] = kwargs.pop('color','k')
kwargs['linewidth'] = kwargs.pop('linewidth',0.5)

self.data.plot(values=self.data.elevation, **kwargs)


def plot_k_layers(self, **kwargs):
""" Plot the number of layers in the best model for each data point """
post = self.model.nCells.posterior
Expand Down Expand Up @@ -1282,6 +1286,8 @@ def plot_confidence(self, **kwargs):
opacity = self.opacity()
opacity.mesh.x.centres = self.data.axis(kwargs.get('x', 'x'))

mask, kwargs = self.mask(opacity, **kwargs); kwargs['alpha'] = mask

ax, pm, cb = opacity.pcolor(ticks=[0.0, 0.5, 1.0], **kwargs)

if cb is not None:
Expand All @@ -1295,6 +1301,9 @@ def plot_entropy(self, **kwargs):

entropy = self.entropy
entropy.mesh.x.centres = self.data.axis(kwargs.get('x', 'x'))

mask, kwargs = self.mask(entropy, **kwargs); kwargs['alpha'] = mask

entropy.pcolor(**kwargs)

# def plotError2DJointProbabilityDistribution(self, index, system=0, **kwargs):
Expand All @@ -1317,6 +1326,9 @@ def plot_interfaces(self, cut=0.0, **kwargs):

interfaces = self.interface_probability()
interfaces.mesh.x.centres = self.data.axis(kwargs.get('x', 'x'))

mask, kwargs = self.mask(interfaces, **kwargs); kwargs['alpha'] = mask

interfaces.pcolor(**kwargs)


Expand Down Expand Up @@ -1475,6 +1487,12 @@ def parameterHistogram(self, nBins, depth = None, depth2 = None, log=None):

def plot_best_model(self, **kwargs):
self.model.x.centres = self.data.axis(kwargs.get('x', 'x'))

kwargs['mask_by_confidence'] = False
kwargs['mask_by_doi'] = False

mask, kwargs = self.mask(self.model, **kwargs); kwargs['alpha'] = mask

return self.model.pcolor(**kwargs);

# def plot_cross_section(self, values, **kwargs):
Expand Down Expand Up @@ -1543,17 +1561,35 @@ def plot_marginal_probabilities(self, **kwargs):
gs2 = gridspec.GridSpec(nrows=1, ncols=1, left=0.92, right=0.93, bottom=cbar_b, top=0.95, wspace=0.01)
gs3 = gridspec.GridSpec(nrows=1, ncols=1, left=0.92, right=0.93, bottom=0.06, top=cbar_a, wspace=0.01)

def mask(self, model, **kwargs):

from pprint import pprint
mask = None
if kwargs.pop('mask_by_confidence', False):
mask = self.opacity().values

if kwargs.pop('mask_by_burned_in', True):
if mask is not None:
mask *= self.burned_in_mask(model)
else:
mask = self.burned_in_mask(model)

if kwargs.pop('mask_by_doi', False):
if mask is not None:
mask *= self.doi_mask(model)
else:
mask = self.doi_mask(model)

return mask, kwargs


def plot_mean_model(self, **kwargs):

model = self.mean_parameters()

model.mesh.x.centres = self.data.axis(kwargs.get('x', 'x'))

if kwargs.pop('use_variance', False):
kwargs['alpha'] = self.opacity().values

if kwargs.pop('mask_below_doi', False):
kwargs['alpha'] = self.doi_mask(model, **kwargs)
mask, kwargs = self.mask(model, **kwargs); kwargs['alpha'] = mask

return model.pcolor(**kwargs)

Expand All @@ -1563,11 +1599,7 @@ def plot_median_model(self, **kwargs):

model.mesh.x.centres = self.data.axis(kwargs.get('x', 'x'))

if kwargs.pop('use_variance', False):
kwargs['alpha'] = self.opacity().values

if kwargs.pop('mask_below_doi', False):
kwargs['alpha'] = self.doi_mask(model, **kwargs)
mask, kwargs = self.mask(model, **kwargs); kwargs['alpha'] = mask

return model.pcolor(**kwargs)

Expand All @@ -1577,28 +1609,26 @@ def plot_mode_model(self, **kwargs):

model.mesh.x.centres = self.data.axis(kwargs.get('x', 'x'))

if kwargs.pop('use_variance', False):
kwargs['alpha'] = self.opacity().values

if kwargs.pop('mask_below_doi', False):
kwargs['alpha'] = self.doi_mask(model, **kwargs)
mask, kwargs = self.mask(model, **kwargs); kwargs['alpha'] = mask

return model.pcolor(**kwargs)

def doi_mask(self, model, **kwargs):
opacity = kwargs.get('alpha', None)
if opacity is None:
opacity = ones(model.mesh.shape)
else:
opacity = opacity.copy()
def doi_mask(self, model):

mask = ones(model.shape)
indices = model.mesh.y.cellIndex(self.doi + model.mesh.y.relativeTo)


for i in range(self.nPoints):
opacity[i, indices[i]:] = 0.0
mask[i, indices[i]:] = 0.0

return mask

def burned_in_mask(self, model):
mask = ones(model.shape)
mask[~self.burned_in, :] = 0.0

return mask

return opacity

# def plotModeModel(self, **kwargs):

Expand All @@ -1617,6 +1647,8 @@ def plot_percentile(self, percent, **kwargs):

percentile = posterior.percentile(percent, axis=1)

mask, kwargs = self.mask(percentile, **kwargs); kwargs['alpha'] = mask

return percentile.pcolor(**kwargs)

def marginal_probability(self, slic=None):
Expand Down

0 comments on commit 8518053

Please sign in to comment.