Skip to content
This repository has been archived by the owner on Sep 3, 2022. It is now read-only.

Commit

Permalink
Update %%ml explain command to support "tabular" type, and also suppo…
Browse files Browse the repository at this point in the history
…rt "all" type. In "all" type, it walks through all explainable columns --- categorical, text, numeric, and images, and display results for each. (#604)
  • Loading branch information
qimingj authored Nov 9, 2017
1 parent 2295d7b commit 7a472e9
Showing 1 changed file with 153 additions and 39 deletions.
192 changes: 153 additions & 39 deletions google/datalab/contrib/mlworkbench/commands/_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,10 +386,12 @@ def ml(line, cell=None):
'explain',
formatter_class=argparse.RawTextHelpFormatter,
help='Explain a prediction with LIME tool.')
explain_parser.add_argument('--type', required=True, choices=['text', 'image'],
explain_parser.add_argument('--type', default='all', choices=['text', 'image', 'tabular', 'all'],
help='the type of column to explain.')
explain_parser.add_argument('--algorithm', choices=['lime', 'ig'], default='lime',
help='the type of column to explain.')
help='"lime" is the open sourced project for prediction explainer.' +
'"ig" means integrated gradients and currently only applies ' +
'to image.')
explain_parser.add_argument('--model', required=True,
help='path of the model directory used for prediction.')
explain_parser.add_argument('--labels', required=True,
Expand All @@ -398,11 +400,13 @@ def ml(line, cell=None):
help='the name of the column to explain. Optional if text type ' +
'and there is only one text column, or image type and ' +
'there is only one image column.')
explain_parser.add_argument('--hide_color', type=int, default=0,
help='the color to use for perturbed area. If -1, average of ' +
'each channel is used for each channel. For image only.')
explain_parser.add_cell_argument('data', required=True,
help='Prediction Data. Can be a csv line, or a dict.')
explain_parser.add_cell_argument('training_data',
help='A csv or bigquery dataset defined by %%ml dataset. ' +
'Used by tabular explainer only to determine the ' +
'distribution of numeric and categorical values. ' +
'Suggest using original training dataset.')

# options specific for lime
explain_parser.add_argument('--num_features', type=int,
Expand All @@ -411,8 +415,14 @@ def ml(line, cell=None):
explain_parser.add_argument('--num_samples', type=int,
help='size of the neighborhood to learn the linear model. ' +
'For lime only.')
explain_parser.add_argument('--hide_color', type=int, default=0,
help='the color to use for perturbed area. If -1, average of ' +
'each channel is used for each channel. For image only.')
explain_parser.add_argument('--include_negative', action='store_true', default=False,
help='whether to show only positive areas. For lime image only.')
explain_parser.add_argument('--overview', action='store_true', default=False,
help='whether to show overview instead of details view.' +
'For lime text and tabular only.')
explain_parser.add_argument('--batch_size', type=int, default=100,
help='size of batches passed to prediction. For lime only.')

Expand Down Expand Up @@ -857,43 +867,147 @@ def _batch_predict(args, cell):
print('done.')


def _explain(args, cell):
explainer = _prediction_explainer.PredictionExplainer(args['model'])
labels = args['labels'].split(',')
if args['type'] == 'text':
if args['algorithm'] == 'ig':
raise ValueError('Algorithm "ig" does not support text type.')
# Helper classes for explainer. Each for is for a combination
# of algorithm (LIME, IG) and type (text, image, tabular)
# ===========================================================
class _TextLimeExplainerInstance(object):

def __init__(self, explainer, labels, args):
num_features = args['num_features'] if args['num_features'] else 10
num_samples = args['num_samples'] if args['num_samples'] else 5000
exp = explainer.explain_text(labels, args['data'], column_name=args['column_name'],
num_features=num_features, num_samples=num_samples)
exp.show_in_notebook()

elif args['type'] == 'image':
if args['algorithm'] == 'lime':
num_features = args['num_features'] if args['num_features'] else 3
num_samples = args['num_samples'] if args['num_samples'] else 300
hide_color = None if args['hide_color'] == -1 else args['hide_color']
exp = explainer.explain_image(labels, args['data'], column_name=args['column_name'],
num_samples=num_samples, batch_size=args['batch_size'],
hide_color=hide_color)
for i in range(len(labels)):
image, mask = exp.get_image_and_mask(i, positive_only=not args['include_negative'],
num_features=num_features, hide_rest=False)
fig = plt.figure()
fig.suptitle(labels[i], fontsize=16)
plt.imshow(mark_boundaries(image, mask))
elif args['algorithm'] == 'ig':
explainer = _prediction_explainer.PredictionExplainer(args['model'])
ret = explainer.probe_image(labels, args['data'], column_name=args['column_name'],
num_scaled_images=args['num_gradients'],
top_percent=args['percent_show'])
raw_image, analysis_images = ret
IPython.display.display(raw_image)
for label, image in zip(labels, analysis_images):
print(label)
IPython.display.display(image)
self._exp = explainer.explain_text(
labels, args['data'], column_name=args['column_name'],
num_features=num_features, num_samples=num_samples)
self._col_name = args['column_name'] if args['column_name'] else explainer._text_columns[0]
self._show_overview = args['overview']

def visualize(self, label_index):
if self._show_overview:
IPython.display.display(
IPython.display.HTML('<br/> Text Column "<b>%s</b>"<br/>' % self._col_name))
self._exp.show_in_notebook(labels=[label_index])
else:
fig = self._exp.as_pyplot_figure(label=label_index)
# Clear original title set by lime.
plt.title('')
fig.suptitle('Text Column "%s"' % self._col_name, fontsize=16)
plt.close(fig)
IPython.display.display(fig)


class _ImageLimeExplainerInstance(object):

def __init__(self, explainer, labels, args):
num_samples = args['num_samples'] if args['num_samples'] else 300
hide_color = None if args['hide_color'] == -1 else args['hide_color']
self._exp = explainer.explain_image(
labels, args['data'], column_name=args['column_name'],
num_samples=num_samples, batch_size=args['batch_size'], hide_color=hide_color)
self._labels = labels
self._positive_only = not args['include_negative']
self._num_features = args['num_features'] if args['num_features'] else 3
self._col_name = args['column_name'] if args['column_name'] else explainer._image_columns[0]

def visualize(self, label_index):
image, mask = self._exp.get_image_and_mask(
label_index,
positive_only=self._positive_only,
num_features=self._num_features, hide_rest=False)
fig = plt.figure()
fig.suptitle('Image Column "%s"' % self._col_name, fontsize=16)
plt.grid(False)
plt.imshow(mark_boundaries(image, mask))
plt.close(fig)
IPython.display.display(fig)


class _ImageIgExplainerInstance(object):

def __init__(self, explainer, labels, args):
self._raw_image, self._analysis_images = explainer.probe_image(
labels, args['data'], column_name=args['column_name'],
num_scaled_images=args['num_gradients'], top_percent=args['percent_show'])
self._labels = labels
self._col_name = args['column_name'] if args['column_name'] else explainer._image_columns[0]

def visualize(self, label_index):
# Show both resized raw image and analyzed image.
fig = plt.figure()
fig.suptitle('Image Column "%s"' % self._col_name, fontsize=16)
plt.grid(False)
plt.imshow(self._analysis_images[label_index])
plt.close(fig)
IPython.display.display(fig)


class _TabularLimeExplainerInstance(object):

def __init__(self, explainer, labels, args):
if not args['training_data']:
raise ValueError('tabular explanation requires training_data to determine ' +
'values distribution.')

training_data = get_dataset_from_arg(args['training_data'])
if (not isinstance(training_data.train, datalab_ml.CsvDataSet) and
not isinstance(training_data.train, datalab_ml.BigQueryDataSet)):
raise ValueError('Require csv or bigquery dataset.')

sample_size = min(training_data.train.size, 10000)
training_df = training_data.train.sample(sample_size)
num_features = args['num_features'] if args['num_features'] else 5
self._exp = explainer.explain_tabular(training_df, labels, args['data'],
num_features=num_features)
self._show_overview = args['overview']

def visualize(self, label_index):
if self._show_overview:
IPython.display.display(
IPython.display.HTML('<br/>All Categorical and Numeric Columns<br/>'))
self._exp.show_in_notebook(labels=[label_index])
else:
fig = self._exp.as_pyplot_figure(label=label_index)
# Clear original title set by lime.
plt.title('')
fig.suptitle(' All Categorical and Numeric Columns', fontsize=16)
plt.close(fig)
IPython.display.display(fig)

# End of Explainer Helper Classes
# ===================================================


def _explain(args, cell):

explainer = _prediction_explainer.PredictionExplainer(args['model'])
labels = args['labels'].split(',')
instances = []
if args['type'] == 'all':
if explainer._numeric_columns or explainer._categorical_columns:
instances.append(_TabularLimeExplainerInstance(explainer, labels, args))
for col_name in explainer._text_columns:
args['column_name'] = col_name
instances.append(_TextLimeExplainerInstance(explainer, labels, args))
for col_name in explainer._image_columns:
args['column_name'] = col_name
if args['algorithm'] == 'lime':
instances.append(_ImageLimeExplainerInstance(explainer, labels, args))
elif args['algorithm'] == 'ig':
instances.append(_ImageIgExplainerInstance(explainer, labels, args))

elif args['type'] == 'text':
instances.append(_TextLimeExplainerInstance(explainer, labels, args))
elif args['type'] == 'image' and args['algorithm'] == 'lime':
instances.append(_ImageLimeExplainerInstance(explainer, labels, args))
elif args['type'] == 'image' and args['algorithm'] == 'ig':
instances.append(_ImageIgExplainerInstance(explainer, labels, args))
elif args['type'] == 'tabular':
instances.append(_TabularLimeExplainerInstance(explainer, labels, args))

for i, label in enumerate(labels):
IPython.display.display(
IPython.display.HTML('<br/>Explaining features for label <b>"%s"</b><br/>' % label))
for instance in instances:
instance.visualize(i)


def _tensorboard_start(args, cell):
Expand Down

0 comments on commit 7a472e9

Please sign in to comment.