Skip to content

Commit

Permalink
潜在空間の描画をpcolormeshからimshowに変更しオプションにinterpolationを追加
Browse files Browse the repository at this point in the history
  • Loading branch information
ae14watanabe committed Jun 2, 2020
1 parent 6a4c7b4 commit a7c5852
Showing 1 changed file with 30 additions and 7 deletions.
37 changes: 30 additions & 7 deletions libs/models/unsupervised_kernel_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def inverse_transform(self, Znew):

def visualize(self, n_grid_points=30, cmap=None, label_data=None, label_feature=None,
title_latent_space=None, title_feature_bars=None, is_show_all_label_data=False,
fig=None, fig_size=None, ax_latent_space=None, ax_feature_bars=None):
interpolation=None, fig=None, fig_size=None, ax_latent_space=None, ax_feature_bars=None):
"""Visualize fit model interactively.
The dataset can be visualized in an exploratory way using the latent variables and the mapping estimated by UKR.
When an arbitrary coordinate on the latent space is specified, the corresponding feature is displayed as a bar.
Expand All @@ -199,6 +199,8 @@ def visualize(self, n_grid_points=30, cmap=None, label_data=None, label_feature=
:param is_show_all_label_data: bool, optional, default = False
When True the labels of the data is always shown.
When False the label is only shown when the cursor overlaps the corresponding latent variable.
:param interpolation: str, optional, default = None
Interpolation method by imshow.
:param fig: matplotlib.figure.Figure, default = True
The figure to visualize.
It is assigned only when you want to specify a figure to visualize.
Expand All @@ -218,7 +220,7 @@ def visualize(self, n_grid_points=30, cmap=None, label_data=None, label_feature=

self._initialize_to_visualize(n_grid_points, cmap, label_data, label_feature,
title_latent_space, title_feature_bars, is_show_all_label_data,
fig, fig_size, ax_latent_space, ax_feature_bars)
interpolation, fig, fig_size, ax_latent_space, ax_feature_bars)

self._draw_latent_space()
self._draw_feature_bars()
Expand Down Expand Up @@ -259,7 +261,7 @@ def __mouse_over_fig(self, event):

def _initialize_to_visualize(self, n_grid_points, cmap, label_data, label_feature,
title_latent_space, title_feature_bars, is_show_all_label_data,
fig, fig_size, ax_latent_space, ax_feature_bars):
interpolation, fig, fig_size, ax_latent_space, ax_feature_bars):
# invalid check
if self.n_components != 2:
raise ValueError('Now support only n_components = 2')
Expand Down Expand Up @@ -328,6 +330,7 @@ def _initialize_to_visualize(self, n_grid_points, cmap, label_data, label_featur
self.ax_feature_bars = ax_feature_bars

self.cmap = cmap
self.interpolation = interpolation
self.click_point_latent_space = None # index of the clicked representative point
self.clicked_mapping = self.X.mean(axis=0)
self.is_initial_view = True
Expand Down Expand Up @@ -388,10 +391,30 @@ def _draw_latent_space(self):
# To draw by pcolormesh and contour, reshape arrays like grid
grid_values_to_draw_3d = self.__unflatten_grid_array(self.grid_values_to_draw)
grid_points_3d = self.__unflatten_grid_array(self.grid_points)
self.ax_latent_space.pcolormesh(grid_points_3d[:, :, 0],
grid_points_3d[:, :, 1],
grid_values_to_draw_3d,
cmap=self.cmap)
# set coordinate of axis
any_index = 0
if grid_points_3d[any_index, 0, 0] < grid_points_3d[any_index, -1, 0]:
coordinate_ax_left = grid_points_3d[any_index, 0, 0]
coordinate_ax_right = grid_points_3d[any_index, -1, 0]
else:
coordinate_ax_left = grid_points_3d[any_index, -1, 0]
coordinate_ax_right = grid_points_3d[any_index, 0, 0]
grid_values_to_draw_3d = np.flip(grid_values_to_draw_3d, axis=1).copy()

if grid_points_3d[-1, any_index, 1] < grid_points_3d[0, any_index, 1]:
coordinate_ax_bottom = grid_points_3d[-1, any_index, 1]
coordinate_ax_top = grid_points_3d[0, any_index, 1]
else:
coordinate_ax_bottom = grid_points_3d[0, any_index, 1]
coordinate_ax_top = grid_points_3d[-1, any_index, 1]
grid_values_to_draw_3d = np.flip(grid_values_to_draw_3d, axis=0).copy()
self.ax_latent_space.imshow(grid_values_to_draw_3d,
extent=[coordinate_ax_left,
coordinate_ax_right,
coordinate_ax_bottom,
coordinate_ax_top],
interpolation=self.interpolation,
cmap=self.cmap)
ctr = self.ax_latent_space.contour(grid_points_3d[:, :, 0],
grid_points_3d[:, :, 1],
grid_values_to_draw_3d, 6, colors='k')
Expand Down

0 comments on commit a7c5852

Please sign in to comment.