Skip to content

Commit

Permalink
Add CLEVER metric
Browse files Browse the repository at this point in the history
  • Loading branch information
Irina Nicolae authored and Irina Nicolae committed Apr 25, 2018
2 parents f687a9d + 7dbc91a commit 28397f5
Show file tree
Hide file tree
Showing 4 changed files with 312 additions and 36 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,4 @@ install:

script:
- mkdir ./data
- python -m unittest discover src/ -p '*_unittest.py'
- python -m unittest discover art/ -p '*_unittest.py'
205 changes: 188 additions & 17 deletions art/metrics.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,3 @@
# MIT License
#
# Copyright (C) IBM Corporation 2018
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
Module implementing varying metrics for assessing model robustness. These fall mainly under two categories:
attack-dependent and attack-independent.
Expand All @@ -25,6 +8,10 @@
import numpy as np
import numpy.linalg as la
import tensorflow as tf
from scipy.stats import weibull_min
from scipy.optimize import fmin as scipy_optimizer
from scipy.special import gammainc
from functools import reduce

from art.attacks.fast_gradient import FastGradientMethod

Expand Down Expand Up @@ -196,3 +183,187 @@ def loss_sensitivity(x, classifier, sess):
res = la.norm(res.reshape(res.shape[0], -1), ord=2, axis=1)

return np.mean(res)


def clever_u(x, classifier, n_b, n_s, r, sess, c_init=1):
"""
Compute CLEVER score for an untargeted attack. Paper link: https://arxiv.org/abs/1801.10578
:param x: One input sample
:type x: `np.ndarray`
:param classifier: A trained model.
:type classifier: :class:`Classifier`
:param n_b: Batch size
:type n_b: `int`
:param n_s: Number of examples per batch
:type n_s: `int`
:param r: Maximum perturbation
:type r: `float`
:param sess: The session to run graphs in
:type sess: `tf.Session`
:param c_init: initialization of Weibull distribution
:type c_init: `float`
:return: A tuple of 3 CLEVER scores, corresponding to norms 1, 2 and np.inf
:rtype: `tuple`
"""
# Get a list of untargeted classes
y_pred = classifier.predict(np.array([x]))
pred_class = np.argmax(y_pred, axis=1)[0]
num_class = np.shape(y_pred)[1]
untarget_classes = [i for i in range(num_class) if i != pred_class]

# Compute CLEVER score for each untargeted class
score1_list, score2_list, score8_list = [], [], []
for j in untarget_classes:
s1, s2, s8 = clever_t(x, classifier, j, n_b, n_s, r, sess, c_init)
score1_list.append(s1)
score2_list.append(s2)
score8_list.append(s8)

return np.min(score1_list), np.min(score2_list), np.min(score8_list)


def clever_t(x, classifier, target_class, n_b, n_s, r, sess, c_init=1):
"""
Compute CLEVER score for a targeted attack. Paper link: https://arxiv.org/abs/1801.10578
:param x: One input sample
:type x: `np.ndarray`
:param classifier: A trained model
:type classifier: :class:`Classifier`
:param target_class: Targeted class
:type target_class: `int`
:param n_b: Batch size
:type n_b: `int`
:param n_s: Number of examples per batch
:type n_s: `int`
:param r: Maximum perturbation
:type r: `float`
:param sess: The session to run graphs in
:type sess: `tf.Session`
:param c_init: Initialization of Weibull distribution
:type c_init: `float`
:return: A tuple of 3 CLEVER scores, corresponding to norms 1, 2 and np.inf
:rtype: `tuple`
"""
# Check if the targeted class is different from the predicted class
y_pred = classifier.predict(np.array([x]))
pred_class = np.argmax(y_pred, axis=1)[0]
if target_class == pred_class:
raise ValueError("The targeted class is the predicted class!")

# Define placeholders for computing g gradients
shape = [None]
shape.extend(x.shape)
imgs = tf.placeholder(shape=shape, dtype=tf.float32)
pred_class_ph = tf.placeholder(dtype=tf.int32, shape=[])
target_class_ph = tf.placeholder(dtype=tf.int32, shape=[])

# Define tensors for g gradients
grad_norm_1, grad_norm_2, grad_norm_8, g_x = _build_g_gradient(imgs, classifier, pred_class_ph, target_class_ph)

# Some auxiliary vars
set1, set2, set8 = [], [], []
dim = reduce(lambda x_, y: x_ * y, x.shape, 1)
shape = [n_s]
shape.extend(x.shape)

# Compute predicted class
y_pred = classifier.predict(np.array([x]))
pred_class = np.argmax(y_pred, axis=1)[0]

# Loop over n_b batches
for i in range(n_b):
# Random generation of data points
sample_xs0 = np.reshape(_random_sphere(m=n_s, n=dim, r=r), shape)
sample_xs = sample_xs0 + np.repeat(np.array([x]), n_s, 0)
np.clip(sample_xs, 0, 1, out=sample_xs)

# Preprocess data if it is supported in the classifier
if hasattr(classifier, 'feature_squeeze'):
sample_xs = classifier.feature_squeeze(sample_xs)
sample_xs = classifier._preprocess(sample_xs)

# Compute gradients
max_gn1, max_gn2, max_gn8 = sess.run(
[grad_norm_1, grad_norm_2, grad_norm_8],
feed_dict={imgs: sample_xs, pred_class_ph: pred_class,
target_class_ph: target_class})
set1.append(max_gn1)
set2.append(max_gn2)
set8.append(max_gn8)

# Maximum likelihood estimation for max gradient norms
[_, loc1, _] = weibull_min.fit(-np.array(set1), c_init, optimizer=scipy_optimizer)
[_, loc2, _] = weibull_min.fit(-np.array(set2), c_init, optimizer=scipy_optimizer)
[_, loc8, _] = weibull_min.fit(-np.array(set8), c_init, optimizer=scipy_optimizer)

# Compute g_x0
x0 = np.array([x])
if hasattr(classifier, 'feature_squeeze'):
x0 = classifier.feature_squeeze(x0)
x0 = classifier._preprocess(x0)
g_x0 = sess.run(g_x, feed_dict={imgs: x0, pred_class_ph: pred_class,
target_class_ph: target_class})

# Compute scores
# Note q = p / (p-1)
s8 = np.min([-g_x0[0] / loc1, r])
s2 = np.min([-g_x0[0] / loc2, r])
s1 = np.min([-g_x0[0] / loc8, r])

return s1, s2, s8


def _build_g_gradient(x, classifier, pred_class, target_class):
"""
Build tensors of gradient `g`.
:param x: One input sample
:type x: `np.ndarray`
:param classifier: A trained model
:type classifier: :class:`Classifier`
:param pred_class: Predicted class
:type pred_class: `int`
:param target_class: Target class
:type target_class: `int`
:return: Max gradient norms
:rtype: `tuple`
"""
# Get predict values
y_pred = classifier.model(x)
pred_val = y_pred[:, pred_class]
target_val = y_pred[:, target_class]
g_x = pred_val - target_val

# Get the gradient op
grad_op = tf.gradients(g_x, x)[0]

# Compute the gradient norm
grad_op_rs = tf.reshape(grad_op, (tf.shape(grad_op)[0], -1))
grad_norm_1 = tf.reduce_max(tf.norm(grad_op_rs, ord=1, axis=1))
grad_norm_2 = tf.reduce_max(tf.norm(grad_op_rs, ord=2, axis=1))
grad_norm_8 = tf.reduce_max(tf.norm(grad_op_rs, ord=np.inf, axis=1))

return grad_norm_1, grad_norm_2, grad_norm_8, g_x


def _random_sphere(m, n, r):
"""
Generate randomly `m x n`-dimension points with radius `r` and centered around 0.
:param m: Number of random data points
:type m: `int`
:param n: Dimension
:type n: `int`
:param r: Radius
:type r: `float`
:return: The generated random sphere
:rtype: `np.ndarray`
"""
a = np.random.randn(m, n)
s2 = np.sum(a**2, axis=1)
base = gammainc(n/2, s2/2)**(1/n) * r / np.sqrt(s2)
a = a * (np.tile(base, (n, 1))).T

return a
136 changes: 118 additions & 18 deletions art/metrics_unittest.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,3 @@
# MIT License
#
# Copyright (C) IBM Corporation 2018
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from __future__ import absolute_import, division, print_function, unicode_literals

import unittest
Expand All @@ -27,6 +9,7 @@
from art.classifiers.cnn import CNN
from art.metrics import empirical_robustness
from art.utils import load_mnist, load_cifar10
from art.metrics import clever_t, clever_u
from art.classifiers.classifier import Classifier

BATCH_SIZE = 10
Expand Down Expand Up @@ -100,5 +83,122 @@ def test_emp_robustness_mnist(self):
# self.assertLessEqual(emp_robust_jsma, 1.)


#########################################
# This part is the unit test for Clever.#
#########################################

class TestClassifier(Classifier):
def __init__(self, defences=None, preproc=None):
from keras.models import Sequential
from keras.layers import Lambda
model = Sequential(name="TestClassifier")
model.add(Lambda(lambda x: x + 0, input_shape=(2,)))

super(TestClassifier, self).__init__(model, defences, preproc)


class TestClever(unittest.TestCase):
"""
Unittest for Clever metrics.
"""
def test_clever_t_unit(self):
"""
Test the targeted version with simplified data.
:return:
"""
print("Unit test for the targeted version with simplified data.")
# Define session & params
session = tf.Session()
k.set_session(session)

# Get classifier
classifier = TestClassifier()

# Compute scores
res = clever_t(np.array([1, 0]), classifier, 1, 20, 10, 1, session)

# Test
self.assertAlmostEqual(res[0], 0.9999999999999998, delta=0.00001)
self.assertAlmostEqual(res[1], 0.7071067811865474, delta=0.00001)
self.assertAlmostEqual(res[2], 0.4999999999999999, delta=0.00001)

def test_clever_u_unit(self):
"""
Test the untargeted version with simplified data.
:return:
"""
print("Unit test for the untargeted version with simplified data.")
# Define session & params
session = tf.Session()
k.set_session(session)

# Get classifier
classifier = TestClassifier()

# Compute scores
res = clever_u(np.array([1, 0]), classifier, 20, 10, 1, session)

# Test
self.assertAlmostEqual(res[0], 0.9999999999999998, delta=0.00001)
self.assertAlmostEqual(res[1], 0.7071067811865474, delta=0.00001)
self.assertAlmostEqual(res[2], 0.4999999999999999, delta=0.00001)

def test_clever_t(self):
"""
Test the targeted version.
:return:
"""
print("Test if the targeted version works on a true classifier/data")
# Define session & params
session = tf.Session()
k.set_session(session)

comp_params = {"loss": 'categorical_crossentropy', "optimizer": 'adam',
"metrics": ['accuracy']}

# Get MNIST
(X_train, Y_train), (_, _), _, _ = load_mnist()
X_train, Y_train = X_train[:NB_TRAIN], Y_train[:NB_TRAIN]
im_shape = X_train[0].shape

# Get classifier
classifier = CNN(im_shape, act="relu")
classifier.compile(comp_params)
classifier.fit(X_train, Y_train, epochs=1,
batch_size=BATCH_SIZE, verbose=0)

res = clever_t(X_train[-1], classifier, 7, 20, 10, 5, session)
self.assertGreater(res[0], res[1])
self.assertGreater(res[1], res[2])

def test_clever_u(self):
"""
Test the untargeted version.
:return:
"""
print("Test if the untargeted version works on a true classifier/data")
# Define session & params
session = tf.Session()
k.set_session(session)

comp_params = {"loss": 'categorical_crossentropy', "optimizer": 'adam',
"metrics": ['accuracy']}

# Get MNIST
(X_train, Y_train), (_, _), _, _ = load_mnist()
X_train, Y_train = X_train[:NB_TRAIN], Y_train[:NB_TRAIN]
im_shape = X_train[0].shape

# Get classifier
classifier = CNN(im_shape, act="relu")
classifier.compile(comp_params)
classifier.fit(X_train, Y_train, epochs=1,
batch_size=BATCH_SIZE, verbose=0)

res = clever_u(X_train[-1], classifier, 2, 10, 5, session)
self.assertGreater(res[0], res[1])
self.assertGreater(res[1], res[2])


if __name__ == '__main__':
unittest.main()
5 changes: 5 additions & 0 deletions docs/modules/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,8 @@ Empirical Robustness
Distance to nearest neighbors
-----------------------------
.. autofunction:: nearest_neighbour_dist

CLEVER
------
.. autofunction:: clever_u
.. autofunction:: clever_t

0 comments on commit 28397f5

Please sign in to comment.