diff --git a/google/datalab/contrib/mlworkbench/commands/_ml.py b/google/datalab/contrib/mlworkbench/commands/_ml.py
index ed6814196..6351d40a6 100644
--- a/google/datalab/contrib/mlworkbench/commands/_ml.py
+++ b/google/datalab/contrib/mlworkbench/commands/_ml.py
@@ -386,10 +386,12 @@ def ml(line, cell=None):
'explain',
formatter_class=argparse.RawTextHelpFormatter,
help='Explain a prediction with LIME tool.')
- explain_parser.add_argument('--type', required=True, choices=['text', 'image'],
+ explain_parser.add_argument('--type', default='all', choices=['text', 'image', 'tabular', 'all'],
help='the type of column to explain.')
explain_parser.add_argument('--algorithm', choices=['lime', 'ig'], default='lime',
- help='the type of column to explain.')
+ help='"lime" is the open sourced project for prediction explainer.' +
+ '"ig" means integrated gradients and currently only applies ' +
+ 'to image.')
explain_parser.add_argument('--model', required=True,
help='path of the model directory used for prediction.')
explain_parser.add_argument('--labels', required=True,
@@ -398,11 +400,13 @@ def ml(line, cell=None):
help='the name of the column to explain. Optional if text type ' +
'and there is only one text column, or image type and ' +
'there is only one image column.')
- explain_parser.add_argument('--hide_color', type=int, default=0,
- help='the color to use for perturbed area. If -1, average of ' +
- 'each channel is used for each channel. For image only.')
explain_parser.add_cell_argument('data', required=True,
help='Prediction Data. Can be a csv line, or a dict.')
+ explain_parser.add_cell_argument('training_data',
+ help='A csv or bigquery dataset defined by %%ml dataset. ' +
+ 'Used by tabular explainer only to determine the ' +
+ 'distribution of numeric and categorical values. ' +
+ 'Suggest using original training dataset.')
# options specific for lime
explain_parser.add_argument('--num_features', type=int,
@@ -411,8 +415,14 @@ def ml(line, cell=None):
explain_parser.add_argument('--num_samples', type=int,
help='size of the neighborhood to learn the linear model. ' +
'For lime only.')
+ explain_parser.add_argument('--hide_color', type=int, default=0,
+ help='the color to use for perturbed area. If -1, average of ' +
+ 'each channel is used for each channel. For image only.')
explain_parser.add_argument('--include_negative', action='store_true', default=False,
help='whether to show only positive areas. For lime image only.')
+ explain_parser.add_argument('--overview', action='store_true', default=False,
+ help='whether to show overview instead of details view.' +
+ 'For lime text and tabular only.')
explain_parser.add_argument('--batch_size', type=int, default=100,
help='size of batches passed to prediction. For lime only.')
@@ -857,43 +867,147 @@ def _batch_predict(args, cell):
print('done.')
-def _explain(args, cell):
- explainer = _prediction_explainer.PredictionExplainer(args['model'])
- labels = args['labels'].split(',')
- if args['type'] == 'text':
- if args['algorithm'] == 'ig':
- raise ValueError('Algorithm "ig" does not support text type.')
+# Helper classes for explainer. Each one is for a combination
+# of algorithm (LIME, IG) and type (text, image, tabular)
+# ===========================================================
+class _TextLimeExplainerInstance(object):  # explains one text column via explain_text()
+ def __init__(self, explainer, labels, args):  # the explanation is computed here, eagerly
num_features = args['num_features'] if args['num_features'] else 10
num_samples = args['num_samples'] if args['num_samples'] else 5000
- exp = explainer.explain_text(labels, args['data'], column_name=args['column_name'],
- num_features=num_features, num_samples=num_samples)
- exp.show_in_notebook()
-
- elif args['type'] == 'image':
- if args['algorithm'] == 'lime':
- num_features = args['num_features'] if args['num_features'] else 3
- num_samples = args['num_samples'] if args['num_samples'] else 300
- hide_color = None if args['hide_color'] == -1 else args['hide_color']
- exp = explainer.explain_image(labels, args['data'], column_name=args['column_name'],
- num_samples=num_samples, batch_size=args['batch_size'],
- hide_color=hide_color)
- for i in range(len(labels)):
- image, mask = exp.get_image_and_mask(i, positive_only=not args['include_negative'],
- num_features=num_features, hide_rest=False)
- fig = plt.figure()
- fig.suptitle(labels[i], fontsize=16)
- plt.imshow(mark_boundaries(image, mask))
- elif args['algorithm'] == 'ig':
- explainer = _prediction_explainer.PredictionExplainer(args['model'])
- ret = explainer.probe_image(labels, args['data'], column_name=args['column_name'],
- num_scaled_images=args['num_gradients'],
- top_percent=args['percent_show'])
- raw_image, analysis_images = ret
- IPython.display.display(raw_image)
- for label, image in zip(labels, analysis_images):
- print(label)
- IPython.display.display(image)
+ self._exp = explainer.explain_text(
+ labels, args['data'], column_name=args['column_name'],
+ num_features=num_features, num_samples=num_samples)
+ self._col_name = args['column_name'] if args['column_name'] else explainer._text_columns[0]
+ self._show_overview = args['overview']  # value of the --overview flag
+
+ def visualize(self, label_index):  # label_index: position in the labels list
+ if self._show_overview:  # overview: HTML caption followed by show_in_notebook()
+ IPython.display.display(
+ IPython.display.HTML('
Text Column "%s"
' % self._col_name))
+ self._exp.show_in_notebook(labels=[label_index])
+ else:  # default: render the explanation as a matplotlib figure
+ fig = self._exp.as_pyplot_figure(label=label_index)
+ # Clear original title set by lime.
+ plt.title('')
+ fig.suptitle('Text Column "%s"' % self._col_name, fontsize=16)
+ plt.close(fig)  # presumably avoids a duplicate automatic rendering; confirm
+ IPython.display.display(fig)
+
+
+class _ImageLimeExplainerInstance(object):  # explains one image column via explain_image()
+
+ def __init__(self, explainer, labels, args):  # the explanation is computed here, eagerly
+ num_samples = args['num_samples'] if args['num_samples'] else 300
+ hide_color = None if args['hide_color'] == -1 else args['hide_color']  # -1 maps to None
+ self._exp = explainer.explain_image(
+ labels, args['data'], column_name=args['column_name'],
+ num_samples=num_samples, batch_size=args['batch_size'], hide_color=hide_color)
+ self._labels = labels  # stored but not read by visualize()
+ self._positive_only = not args['include_negative']
+ self._num_features = args['num_features'] if args['num_features'] else 3
+ self._col_name = args['column_name'] if args['column_name'] else explainer._image_columns[0]
+
+ def visualize(self, label_index):  # label_index: position in the labels list
+ image, mask = self._exp.get_image_and_mask(
+ label_index,
+ positive_only=self._positive_only,
+ num_features=self._num_features, hide_rest=False)
+ fig = plt.figure()
+ fig.suptitle('Image Column "%s"' % self._col_name, fontsize=16)
+ plt.grid(False)
+ plt.imshow(mark_boundaries(image, mask))
+ plt.close(fig)  # presumably avoids a duplicate automatic rendering; confirm
+ IPython.display.display(fig)
+
+
+class _ImageIgExplainerInstance(object):  # explains one image column via probe_image()
+
+ def __init__(self, explainer, labels, args):  # probe_image() is called here, eagerly
+ self._raw_image, self._analysis_images = explainer.probe_image(
+ labels, args['data'], column_name=args['column_name'],
+ num_scaled_images=args['num_gradients'], top_percent=args['percent_show'])
+ self._labels = labels  # stored but not read by visualize()
+ self._col_name = args['column_name'] if args['column_name'] else explainer._image_columns[0]
+
+ def visualize(self, label_index):  # label_index: position in the labels list
+ # Show the analyzed image for this label; self._raw_image is not displayed here.
+ fig = plt.figure()
+ fig.suptitle('Image Column "%s"' % self._col_name, fontsize=16)
+ plt.grid(False)
+ plt.imshow(self._analysis_images[label_index])
+ plt.close(fig)  # presumably avoids a duplicate automatic rendering; confirm
+ IPython.display.display(fig)
+
+
+class _TabularLimeExplainerInstance(object):  # explains via explain_tabular()
+
+ def __init__(self, explainer, labels, args):  # raises ValueError on missing/unsupported data
+ if not args['training_data']:
+ raise ValueError('tabular explanation requires training_data to determine ' +
+ 'values distribution.')
+
+ training_data = get_dataset_from_arg(args['training_data'])
+ if (not isinstance(training_data.train, datalab_ml.CsvDataSet) and
+ not isinstance(training_data.train, datalab_ml.BigQueryDataSet)):
+ raise ValueError('Require csv or bigquery dataset.')
+
+ sample_size = min(training_data.train.size, 10000)  # cap the sample size at 10000
+ training_df = training_data.train.sample(sample_size)
+ num_features = args['num_features'] if args['num_features'] else 5
+ self._exp = explainer.explain_tabular(training_df, labels, args['data'],
+ num_features=num_features)
+ self._show_overview = args['overview']  # value of the --overview flag
+
+ def visualize(self, label_index):  # label_index: position in the labels list
+ if self._show_overview:  # overview: HTML caption followed by show_in_notebook()
+ IPython.display.display(
+ IPython.display.HTML('
All Categorical and Numeric Columns
'))
+ self._exp.show_in_notebook(labels=[label_index])
+ else:  # default: render the explanation as a matplotlib figure
+ fig = self._exp.as_pyplot_figure(label=label_index)
+ # Clear original title set by lime.
+ plt.title('')
+ fig.suptitle(' All Categorical and Numeric Columns', fontsize=16)
+ plt.close(fig)  # presumably avoids a duplicate automatic rendering; confirm
+ IPython.display.display(fig)
+
+# End of Explainer Helper Classes
+# ===================================================
+
+
+def _explain(args, cell):  # builds explainer instances, then renders them label by label
+
+ explainer = _prediction_explainer.PredictionExplainer(args['model'])
+ labels = args['labels'].split(',')  # --labels is a comma-separated string
+ instances = []
+ if args['type'] == 'all':  # tabular (if such columns exist) + each text and image column
+ if explainer._numeric_columns or explainer._categorical_columns:
+ instances.append(_TabularLimeExplainerInstance(explainer, labels, args))
+ for col_name in explainer._text_columns:
+ args['column_name'] = col_name  # NOTE: overwrites the caller's args entry
+ instances.append(_TextLimeExplainerInstance(explainer, labels, args))
+ for col_name in explainer._image_columns:
+ args['column_name'] = col_name  # NOTE: overwrites the caller's args entry
+ if args['algorithm'] == 'lime':
+ instances.append(_ImageLimeExplainerInstance(explainer, labels, args))
+ elif args['algorithm'] == 'ig':
+ instances.append(_ImageIgExplainerInstance(explainer, labels, args))
+
+ elif args['type'] == 'text':  # NOTE(review): --algorithm is not checked for text; confirm
+ instances.append(_TextLimeExplainerInstance(explainer, labels, args))
+ elif args['type'] == 'image' and args['algorithm'] == 'lime':
+ instances.append(_ImageLimeExplainerInstance(explainer, labels, args))
+ elif args['type'] == 'image' and args['algorithm'] == 'ig':
+ instances.append(_ImageIgExplainerInstance(explainer, labels, args))
+ elif args['type'] == 'tabular':  # NOTE(review): --algorithm is not checked here either
+ instances.append(_TabularLimeExplainerInstance(explainer, labels, args))
+
+ for i, label in enumerate(labels):  # label-major: all instances draw label i together
+ IPython.display.display(
+ IPython.display.HTML('
Explaining features for label "%s"
' % label))
+ for instance in instances:
+ instance.visualize(i)
def _tensorboard_start(args, cell):