Posted to commits@mxnet.apache.org by sk...@apache.org on 2018/11/06 06:04:13 UTC

[incubator-mxnet] branch master updated: Update adversary attack generation example (#12918)

This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 722ad7a  Update adversary attack generation example (#12918)
722ad7a is described below

commit 722ad7a7de8390372a27cc52725bdcf29b242ea9
Author: Thomas Delteil <th...@gmail.com>
AuthorDate: Mon Nov 5 22:03:41 2018 -0800

    Update adversary attack generation example (#12918)
    
    * Fix adversary example generation
    
    * Update README.md
    
    * Fix test_utils.list_gpus()
    
    * fix unused variable
---
 example/adversary/README.md                  |   4 +-
 example/adversary/adversary_generation.ipynb | 343 ++++++++++++++-------------
 python/mxnet/test_utils.py                   |   2 +-
 3 files changed, 185 insertions(+), 164 deletions(-)

diff --git a/example/adversary/README.md b/example/adversary/README.md
index 51d295d..5d5b44f 100644
--- a/example/adversary/README.md
+++ b/example/adversary/README.md
@@ -1,7 +1,7 @@
 # Adversarial examples
 
 This demonstrates the concept of "adversarial examples" from [1] showing how to fool a well-trained CNN.
-The surprising idea is that one can easily generate examples which the CNN will consistently 
-make the wrong prediction for that a human can easily tell are correct.
+Adversarial examples are samples where the input has been manipulated to confuse a model (i.e. make it confident in an incorrect prediction) while the correct answer remains obvious to a human.
+This method uses the gradient of the loss with respect to the input to craft such adversarial examples.
 
 [1] Goodfellow, Ian J., Jonathon Shlens, and Christian Szegedy. "Explaining and harnessing adversarial examples." [arXiv preprint arXiv:1412.6572 (2014)](https://arxiv.org/abs/1412.6572)
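As background for the README's description above, here is a minimal standalone sketch of the fast gradient sign method it refers to. It assumes a trained Gluon net, a loss such as gluon.loss.SoftmaxCELoss(), and a (data, label) batch already on the right context; the helper name fgsm_perturb and the epsilon default are illustrative, not part of this commit:

    import mxnet as mx

    def fgsm_perturb(net, loss, data, label, epsilon=0.15):
        # Track gradients on the input itself rather than the weights.
        data.attach_grad()
        with mx.autograd.record():
            l = loss(net(data), label)
        l.backward()  # populates data.grad
        # Step in the direction of the sign of the input gradient,
        # which increases the loss and degrades the prediction.
        return data + epsilon * mx.nd.sign(data.grad)

The notebook diff below implements this same flow inline, with epsilon = 0.15 on MNIST images scaled to [0, 1].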
diff --git a/example/adversary/adversary_generation.ipynb b/example/adversary/adversary_generation.ipynb
index b8804bd..0b45366 100644
--- a/example/adversary/adversary_generation.ipynb
+++ b/example/adversary/adversary_generation.ipynb
@@ -6,8 +6,7 @@
    "source": [
     "# Fast Sign Adversary Generation Example\n",
     "\n",
-    "This notebook demos find adversary example by using symbolic API and integration with Numpy\n",
-    "Reference: \n",
+    "This notebook demos finds adversary examples using MXNet Gluon and taking advantage of the gradient information\n",
     "\n",
     "[1] Goodfellow, Ian J., Jonathon Shlens, and Christian Szegedy. \"Explaining and harnessing adversarial examples.\" arXiv preprint arXiv:1412.6572 (2014).\n",
     "https://arxiv.org/abs/1412.6572"
@@ -15,7 +14,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 1,
    "metadata": {
     "collapsed": false
    },
@@ -28,290 +27,312 @@
     "import matplotlib.pyplot as plt\n",
     "import matplotlib.cm as cm\n",
     "\n",
-    "from mxnet.test_utils import get_mnist_iterator"
+    "from mxnet import gluon"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Build Network\n",
-    "\n",
-    "note: in this network, we will calculate softmax, gradient in numpy"
+    "Build simple CNN network for solving the MNIST dataset digit recognition task"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 17,
    "metadata": {
     "collapsed": true
    },
    "outputs": [],
    "source": [
-    "dev = mx.cpu()\n",
-    "batch_size = 100\n",
-    "train_iter, val_iter = get_mnist_iterator(batch_size=batch_size, input_shape = (1,28,28))"
+    "ctx = mx.gpu() if len(mx.test_utils.list_gpus()) else mx.cpu()\n",
+    "batch_size = 128"
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
-   "outputs": [],
+   "cell_type": "markdown",
+   "metadata": {},
    "source": [
-    "# input\n",
-    "data = mx.symbol.Variable('data')\n",
-    "# first conv\n",
-    "conv1 = mx.symbol.Convolution(data=data, kernel=(5,5), num_filter=20)\n",
-    "tanh1 = mx.symbol.Activation(data=conv1, act_type=\"tanh\")\n",
-    "pool1 = mx.symbol.Pooling(data=tanh1, pool_type=\"max\",\n",
-    "                          kernel=(2,2), stride=(2,2))\n",
-    "# second conv\n",
-    "conv2 = mx.symbol.Convolution(data=pool1, kernel=(5,5), num_filter=50)\n",
-    "tanh2 = mx.symbol.Activation(data=conv2, act_type=\"tanh\")\n",
-    "pool2 = mx.symbol.Pooling(data=tanh2, pool_type=\"max\",\n",
-    "                          kernel=(2,2), stride=(2,2))\n",
-    "# first fullc\n",
-    "flatten = mx.symbol.Flatten(data=pool2)\n",
-    "fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500)\n",
-    "tanh3 = mx.symbol.Activation(data=fc1, act_type=\"tanh\")\n",
-    "# second fullc\n",
-    "fc2 = mx.symbol.FullyConnected(data=tanh3, num_hidden=10)"
+    "## Data Loading"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "execution_count": 3,
+   "metadata": {},
    "outputs": [],
    "source": [
-    "def Softmax(theta):\n",
-    "    max_val = np.max(theta, axis=1, keepdims=True)\n",
-    "    tmp = theta - max_val\n",
-    "    exp = np.exp(tmp)\n",
-    "    norm = np.sum(exp, axis=1, keepdims=True)\n",
-    "    return exp / norm"
+    "transform = lambda x,y: (x.transpose((2,0,1)).astype('float32')/255., y)\n",
+    "\n",
+    "train_dataset = gluon.data.vision.MNIST(train=True).transform(transform)\n",
+    "test_dataset = gluon.data.vision.MNIST(train=False).transform(transform)\n",
+    "\n",
+    "train_data = gluon.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=5)\n",
+    "test_data = gluon.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Create the network"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 4,
    "metadata": {
     "collapsed": true
    },
    "outputs": [],
    "source": [
-    "def LogLossGrad(alpha, label):\n",
-    "    grad = np.copy(alpha)\n",
-    "    for i in range(alpha.shape[0]):\n",
-    "        grad[i, int(label[i])] -= 1.\n",
-    "    return grad"
+    "net = gluon.nn.HybridSequential()\n",
+    "with net.name_scope():\n",
+    "    net.add(\n",
+    "        gluon.nn.Conv2D(kernel_size=5, channels=20, activation='tanh'),\n",
+    "        gluon.nn.MaxPool2D(pool_size=2, strides=2),\n",
+    "        gluon.nn.Conv2D(kernel_size=5, channels=50, activation='tanh'),\n",
+    "        gluon.nn.MaxPool2D(pool_size=2, strides=2),\n",
+    "        gluon.nn.Flatten(),\n",
+    "        gluon.nn.Dense(500, activation='tanh'),\n",
+    "        gluon.nn.Dense(10)\n",
+    "    )"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Prepare useful data for the network"
+    "## Initialize training"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 5,
    "metadata": {
-    "collapsed": false
+    "collapsed": true
    },
    "outputs": [],
    "source": [
-    "data_shape = (batch_size, 1, 28, 28)\n",
-    "arg_names = fc2.list_arguments() # 'data' \n",
-    "arg_shapes, output_shapes, aux_shapes = fc2.infer_shape(data=data_shape)\n",
-    "\n",
-    "arg_arrays = [mx.nd.zeros(shape, ctx=dev) for shape in arg_shapes]\n",
-    "grad_arrays = [mx.nd.zeros(shape, ctx=dev) for shape in arg_shapes]\n",
-    "reqs = [\"write\" for name in arg_names]\n",
-    "\n",
-    "model = fc2.bind(ctx=dev, args=arg_arrays, args_grad = grad_arrays, grad_req=reqs)\n",
-    "arg_map = dict(zip(arg_names, arg_arrays))\n",
-    "grad_map = dict(zip(arg_names, grad_arrays))\n",
-    "data_grad = grad_map[\"data\"]\n",
-    "out_grad = mx.nd.zeros(model.outputs[0].shape, ctx=dev)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Init weight "
+    "net.initialize(mx.initializer.Uniform(), ctx=ctx)\n",
+    "net.hybridize()"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 6,
    "metadata": {
     "collapsed": true
    },
    "outputs": [],
    "source": [
-    "for name in arg_names:\n",
-    "    if \"weight\" in name:\n",
-    "        arr = arg_map[name]\n",
-    "        arr[:] = mx.rnd.uniform(-0.07, 0.07, arr.shape)"
+    "loss = gluon.loss.SoftmaxCELoss()"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 7,
    "metadata": {
     "collapsed": true
    },
    "outputs": [],
    "source": [
-    "def SGD(weight, grad, lr=0.1, grad_norm=batch_size):\n",
-    "    weight[:] -= lr * grad / batch_size\n",
-    "\n",
-    "def CalAcc(pred_prob, label):\n",
-    "    pred = np.argmax(pred_prob, axis=1)\n",
-    "    return np.sum(pred == label) * 1.0\n",
-    "\n",
-    "def CalLoss(pred_prob, label):\n",
-    "    loss = 0.\n",
-    "    for i in range(pred_prob.shape[0]):\n",
-    "        loss += -np.log(max(pred_prob[i, int(label[i])], 1e-10))\n",
-    "    return loss"
+    "trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1, 'momentum':0.95})"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Train a network"
+    "## Training loop"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 8,
    "metadata": {
     "collapsed": false
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Train Accuracy: 0.92\t Train Loss: 0.32142\n",
+      "Train Accuracy: 0.97\t Train Loss: 0.16773\n",
+      "Train Accuracy: 0.97\t Train Loss: 0.14660\n"
+     ]
+    }
+   ],
    "source": [
-    "num_round = 4\n",
-    "train_acc = 0.\n",
-    "nbatch = 0\n",
-    "for i in range(num_round):\n",
+    "epoch = 3\n",
+    "for e in range(epoch):\n",
     "    train_loss = 0.\n",
-    "    train_acc = 0.\n",
-    "    nbatch = 0\n",
-    "    train_iter.reset()\n",
-    "    for batch in train_iter:\n",
-    "        arg_map[\"data\"][:] = batch.data[0]\n",
-    "        model.forward(is_train=True)\n",
-    "        theta = model.outputs[0].asnumpy()\n",
-    "        alpha = Softmax(theta)\n",
-    "        label = batch.label[0].asnumpy()\n",
-    "        train_acc += CalAcc(alpha, label) / batch_size\n",
-    "        train_loss += CalLoss(alpha, label) / batch_size\n",
-    "        losGrad_theta = LogLossGrad(alpha, label)\n",
-    "        out_grad[:] = losGrad_theta\n",
-    "        model.backward([out_grad])\n",
-    "        # data_grad[:] = grad_map[\"data\"]\n",
-    "        for name in arg_names:\n",
-    "            if name != \"data\":\n",
-    "                SGD(arg_map[name], grad_map[name])\n",
+    "    acc = mx.metric.Accuracy()\n",
+    "    for i, (data, label) in enumerate(train_data):\n",
+    "        data = data.as_in_context(ctx)\n",
+    "        label = label.as_in_context(ctx)\n",
+    "        \n",
+    "        with mx.autograd.record():\n",
+    "            output = net(data)\n",
+    "            l = loss(output, label)\n",
+    "            \n",
+    "        l.backward()\n",
+    "        trainer.update(data.shape[0])\n",
     "        \n",
-    "        nbatch += 1\n",
-    "    #print(np.linalg.norm(data_grad.asnumpy(), 2))\n",
-    "    train_acc /= nbatch\n",
-    "    train_loss /= nbatch\n",
-    "    print(\"Train Accuracy: %.2f\\t Train Loss: %.5f\" % (train_acc, train_loss))"
+    "        train_loss += l.mean().asscalar()\n",
+    "        acc.update(label, output)\n",
+    "    \n",
+    "    print(\"Train Accuracy: %.2f\\t Train Loss: %.5f\" % (acc.get()[1], train_loss/(i+1)))"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Get pertubation by using fast sign method, check validation change.\n",
-    "See that the validation set was almost entirely correct before the perturbations, but after the perturbations, it is much worse than random guessing."
+    "## Perturbation\n",
+    "\n",
+    "We first run a validation batch and measure the resulting accuracy.\n",
+    "We then perturbate this batch by modifying the input in the opposite direction of the gradient."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Validation batch accuracy 0.96875\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Get a batch from the testing set\n",
+    "for data, label in test_data:\n",
+    "    data = data.as_in_context(ctx)\n",
+    "    label = label.as_in_context(ctx)\n",
+    "    break\n",
+    "\n",
+    "# Attach gradient to it to get the gradient of the loss with respect to the input\n",
+    "data.attach_grad()\n",
+    "with mx.autograd.record():\n",
+    "    output = net(data)    \n",
+    "    l = loss(output, label)\n",
+    "l.backward()\n",
+    "\n",
+    "acc = mx.metric.Accuracy()\n",
+    "acc.update(label, output)\n",
+    "\n",
+    "print(\"Validation batch accuracy {}\".format(acc.get()[1]))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now we perturb the input"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
    "metadata": {
     "collapsed": false
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Validation batch accuracy after perturbation 0.40625\n"
+     ]
+    }
+   ],
    "source": [
-    "val_iter.reset()\n",
-    "batch = val_iter.next()\n",
-    "data = batch.data[0]\n",
-    "label = batch.label[0]\n",
-    "arg_map[\"data\"][:] = data\n",
-    "model.forward(is_train=True)\n",
-    "theta = model.outputs[0].asnumpy()\n",
-    "alpha = Softmax(theta)\n",
-    "print(\"Val Batch Accuracy: \", CalAcc(alpha, label.asnumpy()) / batch_size)\n",
-    "#########\n",
-    "grad = LogLossGrad(alpha, label.asnumpy())\n",
-    "out_grad[:] = grad\n",
-    "model.backward([out_grad])\n",
-    "noise = np.sign(data_grad.asnumpy())\n",
-    "arg_map[\"data\"][:] = data.asnumpy() + 0.15 * noise\n",
-    "model.forward(is_train=True)\n",
-    "raw_output = model.outputs[0].asnumpy()\n",
-    "pred = Softmax(raw_output)\n",
-    "print(\"Val Batch Accuracy after pertubation: \", CalAcc(pred, label.asnumpy()) / batch_size)"
+    "data_perturbated = data + 0.15 * mx.nd.sign(data.grad)\n",
+    "\n",
+    "output = net(data_perturbated)    \n",
+    "\n",
+    "acc = mx.metric.Accuracy()\n",
+    "acc.update(label, output)\n",
+    "\n",
+    "print(\"Validation batch accuracy after perturbation {}\".format(acc.get()[1]))"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Visualize an example after pertubation.\n",
-    "Note that the prediction is consistently incorrect."
+    "## Visualization"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's visualize an example after pertubation.\n",
+    "\n",
+    "We can see that the prediction is often incorrect."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 16,
    "metadata": {
     "collapsed": false
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "true label: 1\n",
+      "predicted: 3\n"
+     ]
+    },
+    {
+     "data": {
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADpJJREFUeJzt3V+IXeW5x/Hfc9JsNbbMmLbGkAQdgxwZAxoZY+EMJy1tgo2F2AuluSg5IE0vIrbQi4q9qJeh9A9eSHGqobG2ScVWDConsaFgS0p1FI/G8VRNSWmGJGOxpCnIjJk8vdgrZYx7r7Wz1989z/cDw+xZ715rPbMmv6y997vW+5q7C0A8/1F3AQDqQfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwT1sSp31mq1fNmyZaVs+/Tp06Vs97yhoaHa9p0lrbYmq/O41X3M0n73 [...]
+      "text/plain": [
+       "<Figure size 432x288 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
    "source": [
-    "import random as rnd\n",
-    "idx = rnd.randint(0, 99)\n",
-    "images = data.asnumpy()  + 0.15 * noise\n",
-    "plt.imshow(images[idx, :].reshape(28,28), cmap=cm.Greys_r)\n",
-    "print(\"true: %d\" % label.asnumpy()[idx])\n",
-    "print(\"pred: %d\" % np.argmax(pred, axis=1)[idx])"
+    "from random import randint\n",
+    "idx = randint(0, batch_size-1)\n",
+    "\n",
+    "plt.imshow(data_perturbated[idx, :].asnumpy().reshape(28,28), cmap=cm.Greys_r)\n",
+    "print(\"true label: %d\" % label.asnumpy()[idx])\n",
+    "print(\"predicted: %d\" % np.argmax(output.asnumpy(), axis=1)[idx])"
    ]
   }
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 2",
+   "display_name": "Python 3",
    "language": "python",
-   "name": "python2"
+   "name": "python3"
   },
   "language_info": {
    "codemirror_mode": {
     "name": "ipython",
-    "version": 2
+    "version": 3
    },
    "file_extension": ".py",
    "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython2",
-   "version": "2.7.13"
+   "pygments_lexer": "ipython3",
+   "version": "3.6.4"
   }
  },
  "nbformat": 4,
- "nbformat_minor": 0
+ "nbformat_minor": 2
 }
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 5487e35..38a2733 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -1379,7 +1379,7 @@ def list_gpus():
     for cmd in nvidia_smi:
         try:
             re = subprocess.check_output([cmd, "-L"], universal_newlines=True)
-        except OSError:
+        except (subprocess.CalledProcessError, OSError):
             pass
     return range(len([i for i in re.split('\n') if 'GPU' in i]))
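A note on the test_utils.py change above: subprocess.check_output raises subprocess.CalledProcessError, not OSError, when nvidia-smi is present but exits non-zero (for example, with the driver installed but no visible device), so the narrower except clause let that error escape list_gpus(). A simplified sketch of the corrected pattern, where the candidate command tuple is illustrative (the real nvidia_smi list is defined earlier in test_utils.py):

    import subprocess

    def count_gpus():
        # Return the number of GPUs nvidia-smi reports, treating both a
        # missing binary (OSError) and a failing run (CalledProcessError)
        # as "no GPUs found".
        output = ''
        for cmd in ('nvidia-smi', 'nvidia-smi.exe'):
            try:
                output = subprocess.check_output([cmd, '-L'], universal_newlines=True)
                break
            except (subprocess.CalledProcessError, OSError):
                pass  # fall through and try the next candidate
        return len([line for line in output.split('\n') if 'GPU' in line])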