You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/11/13 00:32:35 UTC

[GitHub] sandeep-krishnamurthy closed pull request #12964: Update multi-task learning example

sandeep-krishnamurthy closed pull request #12964: Update multi-task learning example
URL: https://github.com/apache/incubator-mxnet/pull/12964
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/example/multi-task/README.md b/example/multi-task/README.md
index 9034814c3b5..b7756fe378a 100644
--- a/example/multi-task/README.md
+++ b/example/multi-task/README.md
@@ -1,10 +1,13 @@
 # Multi-task learning example
  
-This is a simple example to show how to use mxnet for multi-task learning. It uses MNIST as an example and mocks up the multi-label task.
+This is a simple example showing how to use MXNet for multi-task learning. It uses MNIST as an example, jointly predicting the digit and whether that digit is odd or even.
 
-## Usage
-First, you need to write a multi-task iterator on your own. The iterator needs to generate multiple labels according to your applications, and the label names should be specified in the `provide_label` function, which needs to be consist with the names of output layers. 
+For example:
 
-Then, if you want to show metrics of different tasks separately, you need to write your own metric class and specify the `num` parameter. In the `update` function of metric, calculate the metrics separately for different tasks.
+![](https://camo.githubusercontent.com/ed3cf256f47713335dc288f32f9b0b60bf1028b7/68747470733a2f2f7777772e636c61737365732e63732e756368696361676f2e6564752f617263686976652f323031332f737072696e672f31323330302d312f70612f7061312f64696769742e706e67)
 
-The example script uses gpu as device by default, if gpu is not available for your environment, you can change `device` to be `mx.cpu()`.
+The digit above should be jointly classified as 4 and even.
+
+In this example we don't expect the tasks to help each other much, but multi-task learning has been successfully applied in other domains, such as image captioning. In [A Multi-task Learning Approach for Image Captioning](https://www.ijcai.org/proceedings/2018/0168.pdf) by Wei Zhao, Benyou Wang, Jianbo Ye, Min Yang, Zhou Zhao, Ruotian Luo, Yu Qiao, the authors train a network to jointly classify images and generate text captions.
+
+Please refer to the notebook `multi-task-learning.ipynb` for a fully worked example.
diff --git a/example/multi-task/example_multi_task.py b/example/multi-task/example_multi_task.py
deleted file mode 100644
index 9e898494a14..00000000000
--- a/example/multi-task/example_multi_task.py
+++ /dev/null
@@ -1,159 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# pylint: skip-file
-import mxnet as mx
-from mxnet.test_utils import get_mnist_iterator
-import numpy as np
-import logging
-import time
-
-logging.basicConfig(level=logging.DEBUG)
-
-def build_network():
-    data = mx.symbol.Variable('data')
-    fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
-    act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
-    fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
-    act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
-    fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10)
-    sm1 = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax1')
-    sm2 = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax2')
-
-    softmax = mx.symbol.Group([sm1, sm2])
-
-    return softmax
-
-class Multi_mnist_iterator(mx.io.DataIter):
-    '''multi label mnist iterator'''
-
-    def __init__(self, data_iter):
-        super(Multi_mnist_iterator, self).__init__()
-        self.data_iter = data_iter
-        self.batch_size = self.data_iter.batch_size
-
-    @property
-    def provide_data(self):
-        return self.data_iter.provide_data
-
-    @property
-    def provide_label(self):
-        provide_label = self.data_iter.provide_label[0]
-        # Different labels should be used here for actual application
-        return [('softmax1_label', provide_label[1]), \
-                ('softmax2_label', provide_label[1])]
-
-    def hard_reset(self):
-        self.data_iter.hard_reset()
-
-    def reset(self):
-        self.data_iter.reset()
-
-    def next(self):
-        batch = self.data_iter.next()
-        label = batch.label[0]
-
-        return mx.io.DataBatch(data=batch.data, label=[label, label], \
-                pad=batch.pad, index=batch.index)
-
-class Multi_Accuracy(mx.metric.EvalMetric):
-    """Calculate accuracies of multi label"""
-
-    def __init__(self, num=None):
-        self.num = num
-        super(Multi_Accuracy, self).__init__('multi-accuracy')
-
-    def reset(self):
-        """Resets the internal evaluation result to initial state."""
-        self.num_inst = 0 if self.num is None else [0] * self.num
-        self.sum_metric = 0.0 if self.num is None else [0.0] * self.num
-
-    def update(self, labels, preds):
-        mx.metric.check_label_shapes(labels, preds)
-
-        if self.num is not None:
-            assert len(labels) == self.num
-
-        for i in range(len(labels)):
-            pred_label = mx.nd.argmax_channel(preds[i]).asnumpy().astype('int32')
-            label = labels[i].asnumpy().astype('int32')
-
-            mx.metric.check_label_shapes(label, pred_label)
-
-            if self.num is None:
-                self.sum_metric += (pred_label.flat == label.flat).sum()
-                self.num_inst += len(pred_label.flat)
-            else:
-                self.sum_metric[i] += (pred_label.flat == label.flat).sum()
-                self.num_inst[i] += len(pred_label.flat)
-
-    def get(self):
-        """Gets the current evaluation result.
-
-        Returns
-        -------
-        names : list of str
-           Name of the metrics.
-        values : list of float
-           Value of the evaluations.
-        """
-        if self.num is None:
-            return super(Multi_Accuracy, self).get()
-        else:
-            return zip(*(('%s-task%d'%(self.name, i), float('nan') if self.num_inst[i] == 0
-                                                      else self.sum_metric[i] / self.num_inst[i])
-                       for i in range(self.num)))
-
-    def get_name_value(self):
-        """Returns zipped name and value pairs.
-
-        Returns
-        -------
-        list of tuples
-            A (name, value) tuple list.
-        """
-        if self.num is None:
-            return super(Multi_Accuracy, self).get_name_value()
-        name, value = self.get()
-        return list(zip(name, value))
-
-
-batch_size=100
-num_epochs=100
-device = mx.gpu(0)
-lr = 0.01
-
-network = build_network()
-train, val = get_mnist_iterator(batch_size=batch_size, input_shape = (784,))
-train = Multi_mnist_iterator(train)
-val = Multi_mnist_iterator(val)
-
-
-model = mx.mod.Module(
-    context            = device,
-    symbol             = network,
-    label_names        = ('softmax1_label', 'softmax2_label'))
-
-model.fit(
-    train_data         = train,
-    eval_data          = val,
-    eval_metric        = Multi_Accuracy(num=2),
-    num_epoch          = num_epochs,
-    optimizer_params   = (('learning_rate', lr), ('momentum', 0.9), ('wd', 0.00001)),
-    initializer        = mx.init.Xavier(factor_type="in", magnitude=2.34),
-    batch_end_callback = mx.callback.Speedometer(batch_size, 50))
-
diff --git a/example/multi-task/multi-task-learning.ipynb b/example/multi-task/multi-task-learning.ipynb
new file mode 100644
index 00000000000..6e03e2b61f8
--- /dev/null
+++ b/example/multi-task/multi-task-learning.ipynb
@@ -0,0 +1,454 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Multi-Task Learning Example"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This is a simple example to show how to use mxnet for multi-task learning.\n",
+    "\n",
+    "The network jointly learns to recognize the digit and to tell whether it is odd or even.\n",
+    "\n",
+    "\n",
+    "For example:\n",
+    "\n",
+    "- 1 : 1 and odd\n",
+    "- 2 : 2 and even\n",
+    "- 3 : 3 and odd\n",
+    "\n",
+    "etc.\n",
+    "\n",
+    "In this example we don't expect the tasks to help each other much, but multi-task learning has been successfully applied in other domains, such as image captioning. In [A Multi-task Learning Approach for Image Captioning](https://www.ijcai.org/proceedings/2018/0168.pdf) by Wei Zhao, Benyou Wang, Jianbo Ye, Min Yang, Zhou Zhao, Ruotian Luo, Yu Qiao, the authors train a network to jointly classify images and generate text captions."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import logging\n",
+    "import random\n",
+    "import time\n",
+    "\n",
+    "import matplotlib.pyplot as plt\n",
+    "import mxnet as mx\n",
+    "from mxnet import gluon, nd, autograd\n",
+    "import numpy as np"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Parameters"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 99,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "batch_size = 128\n",
+    "epochs = 5\n",
+    "ctx = mx.gpu() if len(mx.test_utils.list_gpus()) > 0 else mx.cpu()\n",
+    "lr = 0.01"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Data\n",
+    "\n",
+    "We load the traditional MNIST dataset and add a second label to the existing one: for each digit we return a new label indicating whether it is odd (1) or even (0)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "![](https://upload.wikimedia.org/wikipedia/commons/2/27/MnistExamples.png)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_dataset = gluon.data.vision.MNIST(train=True)\n",
+    "test_dataset = gluon.data.vision.MNIST(train=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def transform(x,y):\n",
+    "    x = x.transpose((2,0,1)).astype('float32')/255.\n",
+    "    y1 = y\n",
+    "    y2 = y % 2 #odd or even\n",
+    "    return x, np.float32(y1), np.float32(y2)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We apply the transform to both the training and test datasets."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_dataset_t = train_dataset.transform(transform)\n",
+    "test_dataset_t = test_dataset.transform(transform)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We wrap the transformed datasets in DataLoaders."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_data = gluon.data.DataLoader(train_dataset_t, shuffle=True, last_batch='rollover', batch_size=batch_size, num_workers=5)\n",
+    "test_data = gluon.data.DataLoader(test_dataset_t, shuffle=False, last_batch='rollover', batch_size=batch_size, num_workers=5)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Input shape: (28, 28, 1), Target Labels: (5.0, 1.0)\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(\"Input shape: {}, Target Labels: {}\".format(train_dataset[0][0].shape, train_dataset_t[0][1:]))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Multi-task Network\n",
+    "\n",
+    "The output of the shared featurization layers is passed to two different output layers."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 135,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class MultiTaskNetwork(gluon.HybridBlock):\n",
+    "    \n",
+    "    def __init__(self):\n",
+    "        super(MultiTaskNetwork, self).__init__()\n",
+    "        \n",
+    "        self.shared = gluon.nn.HybridSequential()\n",
+    "        with self.shared.name_scope():\n",
+    "            self.shared.add(\n",
+    "                gluon.nn.Dense(128, activation='relu'),\n",
+    "                gluon.nn.Dense(64, activation='relu'),\n",
+    "                gluon.nn.Dense(10, activation='relu')\n",
+    "            )\n",
+    "        self.output1 = gluon.nn.Dense(10) # Digist recognition\n",
+    "        self.output2 = gluon.nn.Dense(1) # odd or even\n",
+    "\n",
+    "        \n",
+    "    def hybrid_forward(self, F, x):\n",
+    "        y = self.shared(x)\n",
+    "        output1 = self.output1(y)\n",
+    "        output2 = self.output2(y)\n",
+    "        return output1, output2"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We use two different losses, one for each output."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 136,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "loss_digits = gluon.loss.SoftmaxCELoss()\n",
+    "loss_odd_even = gluon.loss.SigmoidBCELoss()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We create and initialize the network."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 137,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mx.random.seed(42)\n",
+    "random.seed(42)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 138,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "net = MultiTaskNetwork()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 139,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "net.initialize(mx.init.Xavier(), ctx=ctx)\n",
+    "net.hybridize() # hybridize for speed"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 140,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate':lr})"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Evaluate Accuracy\n",
+    "We need to evaluate the accuracy of each task separately."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 141,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def evaluate_accuracy(net, data_iterator):\n",
+    "    acc_digits = mx.metric.Accuracy(name='digits')\n",
+    "    acc_odd_even = mx.metric.Accuracy(name='odd_even')\n",
+    "    \n",
+    "    for i, (data, label_digit, label_odd_even) in enumerate(data_iterator):\n",
+    "        data = data.as_in_context(ctx)\n",
+    "        label_digit = label_digit.as_in_context(ctx)\n",
+    "        label_odd_even = label_odd_even.as_in_context(ctx).reshape(-1,1)\n",
+    "\n",
+    "        output_digit, output_odd_even = net(data)\n",
+    "        \n",
+    "        acc_digits.update(label_digit, output_digit.softmax())\n",
+    "        acc_odd_even.update(label_odd_even, output_odd_even.sigmoid() > 0.5)\n",
+    "    return acc_digits.get(), acc_odd_even.get()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Training Loop"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We need to balance the contribution of each loss to the overall training. We do so with the weighting parameter `alpha`, chosen within [0, 1]: the combined loss is `(1 - alpha) * l_digits + alpha * l_odd_even`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 142,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "alpha = 0.5 # Combine losses factor"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 143,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch [0], Acc Digits   0.8945 Loss Digits   0.3409\n",
+      "Epoch [0], Acc Odd/Even 0.9561 Loss Odd/Even 0.1152\n",
+      "Epoch [0], Testing Accuracies (('digits', 0.9487179487179487), ('odd_even', 0.9770633012820513))\n",
+      "Epoch [1], Acc Digits   0.9576 Loss Digits   0.1475\n",
+      "Epoch [1], Acc Odd/Even 0.9804 Loss Odd/Even 0.0559\n",
+      "Epoch [1], Testing Accuracies (('digits', 0.9642427884615384), ('odd_even', 0.9826722756410257))\n",
+      "Epoch [2], Acc Digits   0.9681 Loss Digits   0.1124\n",
+      "Epoch [2], Acc Odd/Even 0.9852 Loss Odd/Even 0.0418\n",
+      "Epoch [2], Testing Accuracies (('digits', 0.9580328525641025), ('odd_even', 0.9846754807692307))\n",
+      "Epoch [3], Acc Digits   0.9734 Loss Digits   0.0961\n",
+      "Epoch [3], Acc Odd/Even 0.9884 Loss Odd/Even 0.0340\n",
+      "Epoch [3], Testing Accuracies (('digits', 0.9670472756410257), ('odd_even', 0.9839743589743589))\n",
+      "Epoch [4], Acc Digits   0.9762 Loss Digits   0.0848\n",
+      "Epoch [4], Acc Odd/Even 0.9894 Loss Odd/Even 0.0310\n",
+      "Epoch [4], Testing Accuracies (('digits', 0.9652887658227848), ('odd_even', 0.9858583860759493))\n"
+     ]
+    }
+   ],
+   "source": [
+    "for e in range(epochs):\n",
+    "    # Accuracies for each task\n",
+    "    acc_digits = mx.metric.Accuracy(name='digits')\n",
+    "    acc_odd_even = mx.metric.Accuracy(name='odd_even')\n",
+    "    # Accumulative losses\n",
+    "    l_digits_ = 0.\n",
+    "    l_odd_even_ = 0. \n",
+    "    \n",
+    "    for i, (data, label_digit, label_odd_even) in enumerate(train_data):\n",
+    "        data = data.as_in_context(ctx)\n",
+    "        label_digit = label_digit.as_in_context(ctx)\n",
+    "        label_odd_even = label_odd_even.as_in_context(ctx).reshape(-1,1)\n",
+    "        \n",
+    "        with autograd.record():\n",
+    "            output_digit, output_odd_even = net(data)\n",
+    "            l_digits = loss_digits(output_digit, label_digit)\n",
+    "            l_odd_even = loss_odd_even(output_odd_even, label_odd_even)\n",
+    "\n",
+    "            # Combine the loss of each task\n",
+    "            l_combined = (1-alpha)*l_digits + alpha*l_odd_even\n",
+    "            \n",
+    "        l_combined.backward()\n",
+    "        trainer.step(data.shape[0])\n",
+    "        \n",
+    "        l_digits_ += l_digits.mean()\n",
+    "        l_odd_even_ += l_odd_even.mean()\n",
+    "        acc_digits.update(label_digit, output_digit.softmax())\n",
+    "        acc_odd_even.update(label_odd_even, output_odd_even.sigmoid() > 0.5)\n",
+    "        \n",
+    "    print(\"Epoch [{}], Acc Digits   {:.4f} Loss Digits   {:.4f}\".format(\n",
+    "        e, acc_digits.get()[1], l_digits_.asscalar()/(i+1)))\n",
+    "    print(\"Epoch [{}], Acc Odd/Even {:.4f} Loss Odd/Even {:.4f}\".format(\n",
+    "        e, acc_odd_even.get()[1], l_odd_even_.asscalar()/(i+1)))\n",
+    "    print(\"Epoch [{}], Testing Accuracies {}\".format(e, evaluate_accuracy(net, test_data)))\n",
+    "        "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Testing"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 144,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_random_data():\n",
+    "    idx = random.randint(0, len(test_dataset))\n",
+    "\n",
+    "    img = test_dataset[idx][0]\n",
+    "    data, _, _ = test_dataset_t[idx]\n",
+    "    data = data.as_in_context(ctx).expand_dims(axis=0)\n",
+    "\n",
+    "    plt.imshow(img.squeeze().asnumpy(), cmap='gray')\n",
+    "    \n",
+    "    return data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 152,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Predicted digit: [9.], odd: [1.]\n"
+     ]
+    },
+    {
+     "data": {
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADeVJREFUeJzt3X+MFPX9x/HXG6QGAQ3aiBdLpd9Ga6pBak5joqk01caaRuAfUhMbjE2viTUpEVFCNT31Dxu1rdWYJldLCk2/QhUb+KPWWuKP1jQNIKiotFJC00OEkjNBEiNyvPvHzdlTbz6zzs7uzPF+PpLL7e57Z+ad5V7M7H5m9mPuLgDxTKq7AQD1IPxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4I6oZsbMzNOJwQ6zN2tlee1tec3s6vM7O9mtsvMVrSzLgDdZWXP7TezyZL+IelKSYOSNku61t1fSyzDnh/osG7s+S+WtMvdd7v7EUlrJS1oY30Auqid8J8p6d9j7g9mj32ImfWZ2RYz29LGtgBUrOMf+Ln7gKQBicN+oEna2fPvlTR7zP3PZI8BmADaCf9mSWeb2efM7FOSvilpYzVtAei00of97n7UzG6S9JSkyZJWufurlXUGoKNKD/WV2hjv+YGO68pJPgAmLsIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCKj1FtySZ2R5J70galnTU3XuraApA57UV/sxX3P1gBesB0EUc9gNBtRt+l/RHM9tqZn1VNASgO9o97L/M3fea2emSnjazne7+/NgnZP8p8B8D0DDm7tWsyKxf0mF3vz/xnGo2BiCXu1srzyt92G9m08xsxuhtSV+TtKPs+gB0VzuH/bMk/c7MRtfz/+7+h0q6AtBxlR32t7QxDvuBjuv4YT+AiY3wA0ERfiAowg8ERfiBoAg/EFQVV/WhwaZPn56sL1++vK3lb7755mT97bffzq3deeedyWUffvjhZP3o0aPJOtLY8wNBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUFzSOwFMnTo1WV+xYkVurWgcftq0acl69n0NuTr591M0zr9s2bJk/ciRI1W2M2FwSS+AJMIPBEX4gaAIPxAU4QeCIvxAUIQfCIpx/i4oGqe//PLLk/Vbb701WZ8/f/4nballQ0NDbdWnTJmSWzvrrLNK9TTqySefTNafe+653NoDDzyQXHYinyPAOD+AJMIPBEX4gaAIPxAU4QeCIvxAUIQfCKpwnN/MVkn6hqQD7n5+9tipktZJmiNpj6TF7p7/Be3/W9dxOc5/0kknJesPPvhgsn7DDTdU2c6H7NixI1m/5557kvVt27Yl6zt37kzWZ8yYkVt76qmnkstecsklyXo7zjnnnGR9165dHdt2p1U5zv8rSVd95LEVkja5+9mSNmX3AUwgheF39+clffQ0rgWSVme3V0taWHFfADqs7Hv+We6+L7v9lqRZFfUDoEvanqvP3T31Xt7M+iT1tbsdANUqu+ffb2Y9kpT9PpD3RHcfcPded+8tuS0AHVA2/BslLcluL5G0oZp2AHRLYfjN7FFJf5X0BTMbNLNvS/qRpCvN7A1JV2T3AUwghe/53f3anNJXK+5lwrriiiuS9XbH8Q8ePJisr1u3Lrd2yy23JJd97733SvXUqp6entq2jTTO8AOCIvxAUIQfCIrwA0ERfiAowg8E1fbpvVGkprJevnx5R7f9yCOPJOsrV67s2LZPOCH9J7Jo0aJk/aGHHsqtnX766aV6atUzzzyTW9u7d29Htz0RsOcHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAY52/RHXfckVu79NJL21p30Tj+3Xff3db6U84999x
kfenSpcl6X19zv6Ht3nvvza29++67XeykmdjzA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQjPO3qJPXnq9ZsyZZLxqTTk03XTROv3jx4mT9tNNOS9aLpnjvpNR3BUjSs88+251GJij2/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QVOE4v5mtkvQNSQfc/fzssX5J35H0n+xpK939951qsgk2b96cW7v++uvbWveGDRuS9SNHjiTrU6dOza2dfPLJpXoa9f777yfr1113XbKemlNg7ty5pXoa9dhjjyXrTAGe1sqe/1eSrhrn8Z+6+7zs57gOPnA8Kgy/uz8vaagLvQDoonbe899kZi+b2Sozm1lZRwC6omz4fy7p85LmSdon6cd5TzSzPjPbYmZbSm4LQAeUCr+773f3YXc/JukXki5OPHfA3XvdvbdskwCqVyr8ZtYz5u4iSTuqaQdAt7Qy1PeopPmSPm1mg5J+KGm+mc2T5JL2SPpuB3sE0AHWzeuxzay+i7/bNGlS/kHS448/nlx24cKFVbdTmRdeeCFZv+uuu5L1ovMIisbiU4p6mz9/frI+PDxcetsTmbtbK8/jDD8gKMIPBEX4gaAIPxAU4QeCIvxAUHx1d4uOHTuWW7vxxhuTy+7fvz9ZL7osdufOncn6E088kVsr+nrrw4cPJ+snnnhisl40HGeWP+qUek0ladOmTcl61KG8qrDnB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGguKQXSWeccUay/uabb5Ze9/bt25P1Cy+8sPS6I+OSXgBJhB8IivADQRF+ICjCDwRF+IGgCD8QFNfzI6m/v7+t5VNTfK9du7atdaM97PmBoAg/EBThB4Ii/EBQhB8IivADQRF+IKjC6/nNbLakNZJmSXJJA+7+MzM7VdI6SXMk7ZG02N3fLlgX1/M3zKJFi5L11JwAklT093Pffffl1m677bbksiinyuv5j0pa5u5flHSJpO+Z2RclrZC0yd3PlrQpuw9ggigMv7vvc/cXs9vvSHpd0pmSFkhanT1ttaSFnWoSQPU+0Xt+M5sj6UuS/iZplrvvy0pvaeRtAYAJouVz+81suqT1kpa6+6Gxc7C5u+e9nzezPkl97TYKoFot7fnNbIpGgv8bdx/9BGi/mfVk9R5JB8Zb1t0H3L3X3XuraBhANQrDbyO7+F9Ket3dfzKmtFHSkuz2Ekkbqm8PQKe0MtR3maQ/S3pF0uicyis18r7/t5I+K+lfGhnqGypYF0N9DfPSSy8l63Pnzk3Wh4aS/+S64IILcmuDg4PJZVFOq0N9he/53f0vkvJW9tVP0hSA5uAMPyAowg8ERfiBoAg/EBThB4Ii/EBQfHX3ca7ostnzzjsvWR8eHk7Wb7/99mSdsfzmYs8PBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0EVXs9f6ca4nr8j5syZk1vbtm1bctlTTjklWd+6dWuyftFFFyXr6L4qv7obwHGI8ANBEX4gKMIPBEX4gaAIPxAU4QeC4nr+48DSpUtza0Xj+EX6+/vbWh7NxZ4fCIrwA0ERfiAowg8ERfiBoAg/EBThB4IqvJ7fzGZLWiNpliSXNODuPzOzfknfkfSf7Kkr3f33Beviev4SrrnmmmR9/fr1ubXJkye3te1Jk9g/TDStXs/fykk+RyUtc/cXzWyGpK1m9nRW+6m731+2SQD1KQy/u++TtC+7/Y6ZvS7pzE43BqCzPtExnZnNkfQlSX/LHrrJzF42s1VmNjNnmT4z22JmW9rqFEClWg6/mU2XtF7SUnc/JOnnkj4vaZ5Gjgx+PN5y7j7g7r3u3ltBvwAq0lL4zWyKRoL/G3d/QpLcfb+7D7v7MUm/kHRx59oEULXC8JuZSfqlpNfd/SdjHu8Z87RFknZU3x6ATmnl0/5LJX1L0itmtj17bKWka81snkaG//ZI+m5HOoR2796drB86dCi3NnPmuB/FfOD++xmsiaqVT/v/Imm8ccPkmD6AZuM
MDiAowg8ERfiBoAg/EBThB4Ii/EBQTNENHGeYohtAEuEHgiL8QFCEHwiK8ANBEX4gKMIPBNXtKboPSvrXmPufzh5roqb21tS+JHorq8rezmr1iV09yedjGzfb0tTv9mtqb03tS6K3surqjcN+ICjCDwRVd/gHat5+SlN7a2pfEr2VVUtvtb7nB1Cfuvf8AGpSS/jN7Coz+7uZ7TKzFXX0kMfM9pjZK2a2ve4pxrJp0A6Y2Y4xj51qZk+b2RvZ7/R3c3e3t34z25u9dtvN7OqaepttZs+Y2Wtm9qqZfT97vNbXLtFXLa9b1w/7zWyypH9IulLSoKTNkq5199e62kgOM9sjqdfdax8TNrMvSzosaY27n589dq+kIXf/UfYf50x3v60hvfVLOlz3zM3ZhDI9Y2eWlrRQ0vWq8bVL9LVYNbxudez5L5a0y913u/sRSWslLaihj8Zz9+clDX3k4QWSVme3V2vkj6frcnprBHff5+4vZrffkTQ6s3Str12ir1rUEf4zJf17zP1BNWvKb5f0RzPbamZ9dTczjlnZtOmS9JakWXU2M47CmZu76SMzSzfmtSsz43XV+MDv4y5z9wslfV3S97LD20bykfdsTRquaWnm5m4ZZ2bpD9T52pWd8bpqdYR/r6TZY+5/JnusEdx9b/b7gKTfqXmzD+8fnSQ1+32g5n4+0KSZm8ebWVoNeO2aNON1HeHfLOlsM/ucmX1K0jclbayhj48xs2nZBzEys2mSvqbmzT68UdKS7PYSSRtq7OVDmjJzc97M0qr5tWvcjNfu3vUfSVdr5BP/f0r6QR095PT1f5Jeyn5erbs3SY9q5DDwfY18NvJtSadJ2iTpDUl/knRqg3r7taRXJL2skaD11NTbZRo5pH9Z0vbs5+q6X7tEX7W8bpzhBwTFB35AUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4L6L4bahh5ke9v1AAAAAElFTkSuQmCC\n",
+      "text/plain": [
+       "<Figure size 432x288 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "data = get_random_data()\n",
+    "\n",
+    "digit, odd_even = net(data)\n",
+    "\n",
+    "digit = digit.argmax(axis=1)[0].asnumpy()\n",
+    "odd_even = (odd_even.sigmoid()[0] > 0.5).asnumpy()\n",
+    "\n",
+    "print(\"Predicted digit: {}, odd: {}\".format(digit, odd_even))"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services