diff --git a/google/datalab/contrib/mlworkbench/commands/_ml.py b/google/datalab/contrib/mlworkbench/commands/_ml.py index ed6814196..6351d40a6 100644 --- a/google/datalab/contrib/mlworkbench/commands/_ml.py +++ b/google/datalab/contrib/mlworkbench/commands/_ml.py @@ -386,10 +386,12 @@ def ml(line, cell=None): 'explain', formatter_class=argparse.RawTextHelpFormatter, help='Explain a prediction with LIME tool.') - explain_parser.add_argument('--type', required=True, choices=['text', 'image'], + explain_parser.add_argument('--type', default='all', choices=['text', 'image', 'tabular', 'all'], help='the type of column to explain.') explain_parser.add_argument('--algorithm', choices=['lime', 'ig'], default='lime', - help='the type of column to explain.') + help='"lime" is the open sourced project for prediction explainer.' + + '"ig" means integrated gradients and currently only applies ' + + 'to image.') explain_parser.add_argument('--model', required=True, help='path of the model directory used for prediction.') explain_parser.add_argument('--labels', required=True, @@ -398,11 +400,13 @@ def ml(line, cell=None): help='the name of the column to explain. Optional if text type ' + 'and there is only one text column, or image type and ' + 'there is only one image column.') - explain_parser.add_argument('--hide_color', type=int, default=0, - help='the color to use for perturbed area. If -1, average of ' + - 'each channel is used for each channel. For image only.') explain_parser.add_cell_argument('data', required=True, help='Prediction Data. Can be a csv line, or a dict.') + explain_parser.add_cell_argument('training_data', + help='A csv or bigquery dataset defined by %%ml dataset. ' + + 'Used by tabular explainer only to determine the ' + + 'distribution of numeric and categorical values. ' + + 'Suggest using original training dataset.') # options specific for lime explain_parser.add_argument('--num_features', type=int, @@ -411,8 +415,14 @@ def ml(line, cell=None): explain_parser.add_argument('--num_samples', type=int, help='size of the neighborhood to learn the linear model. ' + 'For lime only.') + explain_parser.add_argument('--hide_color', type=int, default=0, + help='the color to use for perturbed area. If -1, average of ' + + 'each channel is used for each channel. For image only.') explain_parser.add_argument('--include_negative', action='store_true', default=False, help='whether to show only positive areas. For lime image only.') + explain_parser.add_argument('--overview', action='store_true', default=False, + help='whether to show overview instead of details view.' + + 'For lime text and tabular only.') explain_parser.add_argument('--batch_size', type=int, default=100, help='size of batches passed to prediction. For lime only.') @@ -857,43 +867,147 @@ def _batch_predict(args, cell): print('done.') -def _explain(args, cell): - explainer = _prediction_explainer.PredictionExplainer(args['model']) - labels = args['labels'].split(',') - if args['type'] == 'text': - if args['algorithm'] == 'ig': - raise ValueError('Algorithm "ig" does not support text type.') +# Helper classes for explainer. Each for is for a combination +# of algorithm (LIME, IG) and type (text, image, tabular) +# =========================================================== +class _TextLimeExplainerInstance(object): + def __init__(self, explainer, labels, args): num_features = args['num_features'] if args['num_features'] else 10 num_samples = args['num_samples'] if args['num_samples'] else 5000 - exp = explainer.explain_text(labels, args['data'], column_name=args['column_name'], - num_features=num_features, num_samples=num_samples) - exp.show_in_notebook() - - elif args['type'] == 'image': - if args['algorithm'] == 'lime': - num_features = args['num_features'] if args['num_features'] else 3 - num_samples = args['num_samples'] if args['num_samples'] else 300 - hide_color = None if args['hide_color'] == -1 else args['hide_color'] - exp = explainer.explain_image(labels, args['data'], column_name=args['column_name'], - num_samples=num_samples, batch_size=args['batch_size'], - hide_color=hide_color) - for i in range(len(labels)): - image, mask = exp.get_image_and_mask(i, positive_only=not args['include_negative'], - num_features=num_features, hide_rest=False) - fig = plt.figure() - fig.suptitle(labels[i], fontsize=16) - plt.imshow(mark_boundaries(image, mask)) - elif args['algorithm'] == 'ig': - explainer = _prediction_explainer.PredictionExplainer(args['model']) - ret = explainer.probe_image(labels, args['data'], column_name=args['column_name'], - num_scaled_images=args['num_gradients'], - top_percent=args['percent_show']) - raw_image, analysis_images = ret - IPython.display.display(raw_image) - for label, image in zip(labels, analysis_images): - print(label) - IPython.display.display(image) + self._exp = explainer.explain_text( + labels, args['data'], column_name=args['column_name'], + num_features=num_features, num_samples=num_samples) + self._col_name = args['column_name'] if args['column_name'] else explainer._text_columns[0] + self._show_overview = args['overview'] + + def visualize(self, label_index): + if self._show_overview: + IPython.display.display( + IPython.display.HTML('
Text Column "%s"
' % self._col_name)) + self._exp.show_in_notebook(labels=[label_index]) + else: + fig = self._exp.as_pyplot_figure(label=label_index) + # Clear original title set by lime. + plt.title('') + fig.suptitle('Text Column "%s"' % self._col_name, fontsize=16) + plt.close(fig) + IPython.display.display(fig) + + +class _ImageLimeExplainerInstance(object): + + def __init__(self, explainer, labels, args): + num_samples = args['num_samples'] if args['num_samples'] else 300 + hide_color = None if args['hide_color'] == -1 else args['hide_color'] + self._exp = explainer.explain_image( + labels, args['data'], column_name=args['column_name'], + num_samples=num_samples, batch_size=args['batch_size'], hide_color=hide_color) + self._labels = labels + self._positive_only = not args['include_negative'] + self._num_features = args['num_features'] if args['num_features'] else 3 + self._col_name = args['column_name'] if args['column_name'] else explainer._image_columns[0] + + def visualize(self, label_index): + image, mask = self._exp.get_image_and_mask( + label_index, + positive_only=self._positive_only, + num_features=self._num_features, hide_rest=False) + fig = plt.figure() + fig.suptitle('Image Column "%s"' % self._col_name, fontsize=16) + plt.grid(False) + plt.imshow(mark_boundaries(image, mask)) + plt.close(fig) + IPython.display.display(fig) + + +class _ImageIgExplainerInstance(object): + + def __init__(self, explainer, labels, args): + self._raw_image, self._analysis_images = explainer.probe_image( + labels, args['data'], column_name=args['column_name'], + num_scaled_images=args['num_gradients'], top_percent=args['percent_show']) + self._labels = labels + self._col_name = args['column_name'] if args['column_name'] else explainer._image_columns[0] + + def visualize(self, label_index): + # Show both resized raw image and analyzed image. + fig = plt.figure() + fig.suptitle('Image Column "%s"' % self._col_name, fontsize=16) + plt.grid(False) + plt.imshow(self._analysis_images[label_index]) + plt.close(fig) + IPython.display.display(fig) + + +class _TabularLimeExplainerInstance(object): + + def __init__(self, explainer, labels, args): + if not args['training_data']: + raise ValueError('tabular explanation requires training_data to determine ' + + 'values distribution.') + + training_data = get_dataset_from_arg(args['training_data']) + if (not isinstance(training_data.train, datalab_ml.CsvDataSet) and + not isinstance(training_data.train, datalab_ml.BigQueryDataSet)): + raise ValueError('Require csv or bigquery dataset.') + + sample_size = min(training_data.train.size, 10000) + training_df = training_data.train.sample(sample_size) + num_features = args['num_features'] if args['num_features'] else 5 + self._exp = explainer.explain_tabular(training_df, labels, args['data'], + num_features=num_features) + self._show_overview = args['overview'] + + def visualize(self, label_index): + if self._show_overview: + IPython.display.display( + IPython.display.HTML('
All Categorical and Numeric Columns
')) + self._exp.show_in_notebook(labels=[label_index]) + else: + fig = self._exp.as_pyplot_figure(label=label_index) + # Clear original title set by lime. + plt.title('') + fig.suptitle(' All Categorical and Numeric Columns', fontsize=16) + plt.close(fig) + IPython.display.display(fig) + +# End of Explainer Helper Classes +# =================================================== + + +def _explain(args, cell): + + explainer = _prediction_explainer.PredictionExplainer(args['model']) + labels = args['labels'].split(',') + instances = [] + if args['type'] == 'all': + if explainer._numeric_columns or explainer._categorical_columns: + instances.append(_TabularLimeExplainerInstance(explainer, labels, args)) + for col_name in explainer._text_columns: + args['column_name'] = col_name + instances.append(_TextLimeExplainerInstance(explainer, labels, args)) + for col_name in explainer._image_columns: + args['column_name'] = col_name + if args['algorithm'] == 'lime': + instances.append(_ImageLimeExplainerInstance(explainer, labels, args)) + elif args['algorithm'] == 'ig': + instances.append(_ImageIgExplainerInstance(explainer, labels, args)) + + elif args['type'] == 'text': + instances.append(_TextLimeExplainerInstance(explainer, labels, args)) + elif args['type'] == 'image' and args['algorithm'] == 'lime': + instances.append(_ImageLimeExplainerInstance(explainer, labels, args)) + elif args['type'] == 'image' and args['algorithm'] == 'ig': + instances.append(_ImageIgExplainerInstance(explainer, labels, args)) + elif args['type'] == 'tabular': + instances.append(_TabularLimeExplainerInstance(explainer, labels, args)) + + for i, label in enumerate(labels): + IPython.display.display( + IPython.display.HTML('
Explaining features for label "%s"
' % label)) + for instance in instances: + instance.visualize(i) def _tensorboard_start(args, cell):