You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2018/11/06 06:04:13 UTC
[incubator-mxnet] branch master updated: Update adversary attack
generation example (#12918)
This is an automated email from the ASF dual-hosted git repository.
skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 722ad7a Update adversary attack generation example (#12918)
722ad7a is described below
commit 722ad7a7de8390372a27cc52725bdcf29b242ea9
Author: Thomas Delteil <th...@gmail.com>
AuthorDate: Mon Nov 5 22:03:41 2018 -0800
Update adversary attack generation example (#12918)
* Fix adversary example generation
* Update README.md
* Fix test_utils.list_gpus()
* fix unused variable
---
example/adversary/README.md | 4 +-
example/adversary/adversary_generation.ipynb | 343 ++++++++++++++-------------
python/mxnet/test_utils.py | 2 +-
3 files changed, 185 insertions(+), 164 deletions(-)
diff --git a/example/adversary/README.md b/example/adversary/README.md
index 51d295d..5d5b44f 100644
--- a/example/adversary/README.md
+++ b/example/adversary/README.md
@@ -1,7 +1,7 @@
# Adversarial examples
This demonstrates the concept of "adversarial examples" from [1] showing how to fool a well-trained CNN.
-The surprising idea is that one can easily generate examples which the CNN will consistently
-make the wrong prediction for that a human can easily tell are correct.
+Adversarial examples are inputs that have been deliberately manipulated to confuse a model (i.e. to make it confidently predict an incorrect label), while the correct answer still appears obvious to a human.
+The method used here (the fast gradient sign method from [1]) crafts adversarial examples from the gradient of the loss with respect to the input.
[1] Goodfellow, Ian J., Jonathon Shlens, and Christian Szegedy. "Explaining and harnessing adversarial examples." [arXiv preprint arXiv:1412.6572 (2014)](https://arxiv.org/abs/1412.6572)
diff --git a/example/adversary/adversary_generation.ipynb b/example/adversary/adversary_generation.ipynb
index b8804bd..0b45366 100644
--- a/example/adversary/adversary_generation.ipynb
+++ b/example/adversary/adversary_generation.ipynb
@@ -6,8 +6,7 @@
"source": [
"# Fast Sign Adversary Generation Example\n",
"\n",
- "This notebook demos find adversary example by using symbolic API and integration with Numpy\n",
- "Reference: \n",
+ "This notebook demonstrates how to find adversarial examples using MXNet Gluon, taking advantage of the gradient information.\n",
"\n",
"[1] Goodfellow, Ian J., Jonathon Shlens, and Christian Szegedy. \"Explaining and harnessing adversarial examples.\" arXiv preprint arXiv:1412.6572 (2014).\n",
"https://arxiv.org/abs/1412.6572"
@@ -15,7 +14,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
"metadata": {
"collapsed": false
},
@@ -28,290 +27,312 @@
"import matplotlib.pyplot as plt\n",
"import matplotlib.cm as cm\n",
"\n",
- "from mxnet.test_utils import get_mnist_iterator"
+ "from mxnet import gluon"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "Build Network\n",
- "\n",
- "note: in this network, we will calculate softmax, gradient in numpy"
+ "Build a simple CNN for the MNIST handwritten digit recognition task"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 17,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
- "dev = mx.cpu()\n",
- "batch_size = 100\n",
- "train_iter, val_iter = get_mnist_iterator(batch_size=batch_size, input_shape = (1,28,28))"
+ "ctx = mx.gpu() if len(mx.test_utils.list_gpus()) else mx.cpu()\n",
+ "batch_size = 128"
]
},
{
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
+ "cell_type": "markdown",
+ "metadata": {},
"source": [
- "# input\n",
- "data = mx.symbol.Variable('data')\n",
- "# first conv\n",
- "conv1 = mx.symbol.Convolution(data=data, kernel=(5,5), num_filter=20)\n",
- "tanh1 = mx.symbol.Activation(data=conv1, act_type=\"tanh\")\n",
- "pool1 = mx.symbol.Pooling(data=tanh1, pool_type=\"max\",\n",
- " kernel=(2,2), stride=(2,2))\n",
- "# second conv\n",
- "conv2 = mx.symbol.Convolution(data=pool1, kernel=(5,5), num_filter=50)\n",
- "tanh2 = mx.symbol.Activation(data=conv2, act_type=\"tanh\")\n",
- "pool2 = mx.symbol.Pooling(data=tanh2, pool_type=\"max\",\n",
- " kernel=(2,2), stride=(2,2))\n",
- "# first fullc\n",
- "flatten = mx.symbol.Flatten(data=pool2)\n",
- "fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500)\n",
- "tanh3 = mx.symbol.Activation(data=fc1, act_type=\"tanh\")\n",
- "# second fullc\n",
- "fc2 = mx.symbol.FullyConnected(data=tanh3, num_hidden=10)"
+ "## Data Loading"
]
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {
- "collapsed": true
- },
+ "execution_count": 3,
+ "metadata": {},
"outputs": [],
"source": [
- "def Softmax(theta):\n",
- " max_val = np.max(theta, axis=1, keepdims=True)\n",
- " tmp = theta - max_val\n",
- " exp = np.exp(tmp)\n",
- " norm = np.sum(exp, axis=1, keepdims=True)\n",
- " return exp / norm"
+ "transform = lambda x,y: (x.transpose((2,0,1)).astype('float32')/255., y)\n",
+ "\n",
+ "train_dataset = gluon.data.vision.MNIST(train=True).transform(transform)\n",
+ "test_dataset = gluon.data.vision.MNIST(train=False).transform(transform)\n",
+ "\n",
+ "train_data = gluon.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=5)\n",
+ "test_data = gluon.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Create the network"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
- "def LogLossGrad(alpha, label):\n",
- " grad = np.copy(alpha)\n",
- " for i in range(alpha.shape[0]):\n",
- " grad[i, int(label[i])] -= 1.\n",
- " return grad"
+ "net = gluon.nn.HybridSequential()\n",
+ "with net.name_scope():\n",
+ " net.add(\n",
+ " gluon.nn.Conv2D(kernel_size=5, channels=20, activation='tanh'),\n",
+ " gluon.nn.MaxPool2D(pool_size=2, strides=2),\n",
+ " gluon.nn.Conv2D(kernel_size=5, channels=50, activation='tanh'),\n",
+ " gluon.nn.MaxPool2D(pool_size=2, strides=2),\n",
+ " gluon.nn.Flatten(),\n",
+ " gluon.nn.Dense(500, activation='tanh'),\n",
+ " gluon.nn.Dense(10)\n",
+ " )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "Prepare useful data for the network"
+ "## Initialize training"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 5,
"metadata": {
- "collapsed": false
+ "collapsed": true
},
"outputs": [],
"source": [
- "data_shape = (batch_size, 1, 28, 28)\n",
- "arg_names = fc2.list_arguments() # 'data' \n",
- "arg_shapes, output_shapes, aux_shapes = fc2.infer_shape(data=data_shape)\n",
- "\n",
- "arg_arrays = [mx.nd.zeros(shape, ctx=dev) for shape in arg_shapes]\n",
- "grad_arrays = [mx.nd.zeros(shape, ctx=dev) for shape in arg_shapes]\n",
- "reqs = [\"write\" for name in arg_names]\n",
- "\n",
- "model = fc2.bind(ctx=dev, args=arg_arrays, args_grad = grad_arrays, grad_req=reqs)\n",
- "arg_map = dict(zip(arg_names, arg_arrays))\n",
- "grad_map = dict(zip(arg_names, grad_arrays))\n",
- "data_grad = grad_map[\"data\"]\n",
- "out_grad = mx.nd.zeros(model.outputs[0].shape, ctx=dev)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Init weight "
+ "net.initialize(mx.initializer.Uniform(), ctx=ctx)\n",
+ "net.hybridize()"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
- "for name in arg_names:\n",
- " if \"weight\" in name:\n",
- " arr = arg_map[name]\n",
- " arr[:] = mx.rnd.uniform(-0.07, 0.07, arr.shape)"
+ "loss = gluon.loss.SoftmaxCELoss()"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
- "def SGD(weight, grad, lr=0.1, grad_norm=batch_size):\n",
- " weight[:] -= lr * grad / batch_size\n",
- "\n",
- "def CalAcc(pred_prob, label):\n",
- " pred = np.argmax(pred_prob, axis=1)\n",
- " return np.sum(pred == label) * 1.0\n",
- "\n",
- "def CalLoss(pred_prob, label):\n",
- " loss = 0.\n",
- " for i in range(pred_prob.shape[0]):\n",
- " loss += -np.log(max(pred_prob[i, int(label[i])], 1e-10))\n",
- " return loss"
+ "trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1, 'momentum':0.95})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "Train a network"
+ "## Training loop"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 8,
"metadata": {
"collapsed": false
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Train Accuracy: 0.92\t Train Loss: 0.32142\n",
+ "Train Accuracy: 0.97\t Train Loss: 0.16773\n",
+ "Train Accuracy: 0.97\t Train Loss: 0.14660\n"
+ ]
+ }
+ ],
"source": [
- "num_round = 4\n",
- "train_acc = 0.\n",
- "nbatch = 0\n",
- "for i in range(num_round):\n",
+ "epoch = 3\n",
+ "for e in range(epoch):\n",
" train_loss = 0.\n",
- " train_acc = 0.\n",
- " nbatch = 0\n",
- " train_iter.reset()\n",
- " for batch in train_iter:\n",
- " arg_map[\"data\"][:] = batch.data[0]\n",
- " model.forward(is_train=True)\n",
- " theta = model.outputs[0].asnumpy()\n",
- " alpha = Softmax(theta)\n",
- " label = batch.label[0].asnumpy()\n",
- " train_acc += CalAcc(alpha, label) / batch_size\n",
- " train_loss += CalLoss(alpha, label) / batch_size\n",
- " losGrad_theta = LogLossGrad(alpha, label)\n",
- " out_grad[:] = losGrad_theta\n",
- " model.backward([out_grad])\n",
- " # data_grad[:] = grad_map[\"data\"]\n",
- " for name in arg_names:\n",
- " if name != \"data\":\n",
- " SGD(arg_map[name], grad_map[name])\n",
+ " acc = mx.metric.Accuracy()\n",
+ " for i, (data, label) in enumerate(train_data):\n",
+ " data = data.as_in_context(ctx)\n",
+ " label = label.as_in_context(ctx)\n",
+ " \n",
+ " with mx.autograd.record():\n",
+ " output = net(data)\n",
+ " l = loss(output, label)\n",
+ " \n",
+ " l.backward()\n",
+ " trainer.update(data.shape[0])\n",
" \n",
- " nbatch += 1\n",
- " #print(np.linalg.norm(data_grad.asnumpy(), 2))\n",
- " train_acc /= nbatch\n",
- " train_loss /= nbatch\n",
- " print(\"Train Accuracy: %.2f\\t Train Loss: %.5f\" % (train_acc, train_loss))"
+ " train_loss += l.mean().asscalar()\n",
+ " acc.update(label, output)\n",
+ " \n",
+ " print(\"Train Accuracy: %.2f\\t Train Loss: %.5f\" % (acc.get()[1], train_loss/(i+1)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "Get pertubation by using fast sign method, check validation change.\n",
- "See that the validation set was almost entirely correct before the perturbations, but after the perturbations, it is much worse than random guessing."
+ "## Perturbation\n",
+ "\n",
+ "We first run a validation batch and measure the resulting accuracy.\n",
+ "We then perturb this batch by moving each input a small step (0.15) in the direction of the sign of the loss gradient with respect to the input, which increases the loss."
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Validation batch accuracy 0.96875\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Get a batch from the testing set\n",
+ "for data, label in test_data:\n",
+ " data = data.as_in_context(ctx)\n",
+ " label = label.as_in_context(ctx)\n",
+ " break\n",
+ "\n",
+ "# Attach gradient to it to get the gradient of the loss with respect to the input\n",
+ "data.attach_grad()\n",
+ "with mx.autograd.record():\n",
+ " output = net(data) \n",
+ " l = loss(output, label)\n",
+ "l.backward()\n",
+ "\n",
+ "acc = mx.metric.Accuracy()\n",
+ "acc.update(label, output)\n",
+ "\n",
+ "print(\"Validation batch accuracy {}\".format(acc.get()[1]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now we perturb the input"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
"metadata": {
"collapsed": false
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Validation batch accuracy after perturbation 0.40625\n"
+ ]
+ }
+ ],
"source": [
- "val_iter.reset()\n",
- "batch = val_iter.next()\n",
- "data = batch.data[0]\n",
- "label = batch.label[0]\n",
- "arg_map[\"data\"][:] = data\n",
- "model.forward(is_train=True)\n",
- "theta = model.outputs[0].asnumpy()\n",
- "alpha = Softmax(theta)\n",
- "print(\"Val Batch Accuracy: \", CalAcc(alpha, label.asnumpy()) / batch_size)\n",
- "#########\n",
- "grad = LogLossGrad(alpha, label.asnumpy())\n",
- "out_grad[:] = grad\n",
- "model.backward([out_grad])\n",
- "noise = np.sign(data_grad.asnumpy())\n",
- "arg_map[\"data\"][:] = data.asnumpy() + 0.15 * noise\n",
- "model.forward(is_train=True)\n",
- "raw_output = model.outputs[0].asnumpy()\n",
- "pred = Softmax(raw_output)\n",
- "print(\"Val Batch Accuracy after pertubation: \", CalAcc(pred, label.asnumpy()) / batch_size)"
+ "data_perturbated = data + 0.15 * mx.nd.sign(data.grad)\n",
+ "\n",
+ "output = net(data_perturbated) \n",
+ "\n",
+ "acc = mx.metric.Accuracy()\n",
+ "acc.update(label, output)\n",
+ "\n",
+ "print(\"Validation batch accuracy after perturbation {}\".format(acc.get()[1]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "Visualize an example after pertubation.\n",
- "Note that the prediction is consistently incorrect."
+ "## Visualization"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's visualize an example after perturbation.\n",
+ "\n",
+ "We can see that the prediction is often incorrect."
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 16,
"metadata": {
"collapsed": false
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "true label: 1\n",
+ "predicted: 3\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADpJJREFUeJzt3V+IXeW5x/Hfc9JsNbbMmLbGkAQdgxwZAxoZY+EMJy1tgo2F2AuluSg5IE0vIrbQi4q9qJeh9A9eSHGqobG2ScVWDConsaFgS0p1FI/G8VRNSWmGJGOxpCnIjJk8vdgrZYx7r7Wz1989z/cDw+xZ715rPbMmv6y997vW+5q7C0A8/1F3AQDqQfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwT1sSp31mq1fNmyZaVs+/Tp06Vs97yhoaHa9p0lrbYmq/O41X3M0n73 [...]
+ "text/plain": [
+ "<Figure size 432x288 with 1 Axes>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
"source": [
- "import random as rnd\n",
- "idx = rnd.randint(0, 99)\n",
- "images = data.asnumpy() + 0.15 * noise\n",
- "plt.imshow(images[idx, :].reshape(28,28), cmap=cm.Greys_r)\n",
- "print(\"true: %d\" % label.asnumpy()[idx])\n",
- "print(\"pred: %d\" % np.argmax(pred, axis=1)[idx])"
+ "from random import randint\n",
+ "idx = randint(0, batch_size-1)\n",
+ "\n",
+ "plt.imshow(data_perturbated[idx, :].asnumpy().reshape(28,28), cmap=cm.Greys_r)\n",
+ "print(\"true label: %d\" % label.asnumpy()[idx])\n",
+ "print(\"predicted: %d\" % np.argmax(output.asnumpy(), axis=1)[idx])"
]
}
],
"metadata": {
"kernelspec": {
- "display_name": "Python 2",
+ "display_name": "Python 3",
"language": "python",
- "name": "python2"
+ "name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
- "version": 2
+ "version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
- "pygments_lexer": "ipython2",
- "version": "2.7.13"
+ "pygments_lexer": "ipython3",
+ "version": "3.6.4"
}
},
"nbformat": 4,
- "nbformat_minor": 0
+ "nbformat_minor": 2
}
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 5487e35..38a2733 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -1379,7 +1379,7 @@ def list_gpus():
for cmd in nvidia_smi:
try:
re = subprocess.check_output([cmd, "-L"], universal_newlines=True)
- except OSError:
+ except (subprocess.CalledProcessError, OSError):
pass
return range(len([i for i in re.split('\n') if 'GPU' in i]))