Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/12/22 21:44:59 UTC

[GitHub] szha closed pull request #9144: rbm tutorial

szha closed pull request #9144: rbm tutorial
URL: https://github.com/apache/incubator-mxnet/pull/9144
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/example/rbm/rbm-tutorial.ipynb b/example/rbm/rbm-tutorial.ipynb
new file mode 100644
index 0000000000..7266efee53
--- /dev/null
+++ b/example/rbm/rbm-tutorial.ipynb
@@ -0,0 +1,622 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# RBM: Restricted Boltzmann Machine with MNIST dataset\n",
+    "\n",
+    "- In this tutorial, we use the MXNet NDArray API to build an RBM (Restricted Boltzmann Machine) and use it to initialize the weights of a neural network before that network is trained.\n",
+    "- This tutorial is for beginners who want to learn about weight initialization and the basic usage of the MXNet API.\n",
+    "\n",
+    "## Prerequisites\n",
+    "\n",
+    "\n",
+    "You need to install:\n",
+    "- Python 2.7: https://www.python.org/downloads/\n",
+    "- MXNet (GPU version): https://mxnet.incubator.apache.org/get_started/install.html\n",
+    "\n",
+    "\n",
+    "This tutorial trains an RBM and uses the learned weights in a Gluon model. RBMs were popularized by Geoffrey Hinton; here the RBM serves as a weight initializer. It is not trained to predict outputs; instead, it learns the distribution of the input data. For more details, refer to the following article:\n",
+    "\n",
+    "- Geoffrey Hinton, *A Practical Guide to Training Restricted Boltzmann Machines*\n",
+    "\n",
+    "\n",
+    "## The Data\n",
+    "\n",
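+    "We use the MNIST dataset, a standard first dataset for neural networks. You can download it via MXNet or from Yann LeCun's website (note that the `load_data` function below expects the pickled file `data/mnist.pkl.gz`, which holds train, validation and test splits):\n",
+    "- http://yann.lecun.com/exdb/mnist/\n",
+    "\n",
+    "\n",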
+    "## Prepare the Data\n",
+    "\n",
+    "#### 1. Load the data used to pre-train the RBM weights.\n",
+    " \n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# -*- coding: utf-8 -*-\n",
+    "\n",
+    "import cPickle\n",
+    "import gzip\n",
+    "import os\n",
+    "\n",
+    "import numpy\n",
+    "\n",
+    "def load_data(dataset='mnist.pkl.gz'):\n",
+    "  dataset = os.path.join('data', dataset)  # relative to the working directory\n",
+    "  f = gzip.open(dataset, 'rb')\n",
+    "  train_set, valid_set, test_set = cPickle.load(f)\n",
+    "  f.close()\n",
+    "\n",
+    "  def make_numpy_array(data_xy):\n",
+    "    data_x, data_y = data_xy\n",
+    "    return numpy.array(data_x), numpy.array(data_y)\n",
+    "\n",
+    "  train_set_x, train_set_y = make_numpy_array(train_set)\n",
+    "  valid_set_x, valid_set_y = make_numpy_array(valid_set)\n",
+    "  test_set_x, test_set_y = make_numpy_array(test_set)\n",
+    "\n",
+    "  rval = [(train_set_x, train_set_y), (valid_set_x, valid_set_y),\n",
+    "          (test_set_x, test_set_y)]\n",
+    "\n",
+    "  return rval\n",
+    "\n",
+    "datasets = load_data('mnist.pkl.gz')\n",
+    "train_set_x, train_set_y = datasets[0]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
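+    "## Create the Model and Initialize the Weights\n",
+    "An RBM is an energy-based model. Here it is trained before the actual neural network, and the weights it learns between the visible (input) layer and the hidden layer are used to initialize that network. Training alternates between propagating the visible units up to the hidden layer (`propup`) and propagating the hidden units back down (`propdown`)."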
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "![RBM diagram](http://deeplearning.net/tutorial/_images/rbm.png)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/ubuntu/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/urllib3/contrib/pyopenssl.py:46: DeprecationWarning: OpenSSL.rand is deprecated - you should use os.urandom instead\n",
+      "  import OpenSSL.SSL\n"
+     ]
+    }
+   ],
+   "source": [
+    "import mxnet as mx\n",
+    "class RBM(object):\n",
+    "\n",
+    "    def __init__(self, n_hiddens=10,\n",
+    "                 epsilon=0.1,\n",
+    "                 W=None,\n",
+    "                 b=None,\n",
+    "                 c=None,\n",
+    "                 n_samples=784,\n",
+    "                 epochs=2):\n",
+    "        self.n_hiddens = n_hiddens\n",
+    "        self.W = W\n",
+    "        self.b = b\n",
+    "        self.c = c\n",
+    "        self.epsilon = epsilon\n",
+    "        self.n_samples = n_samples\n",
+    "        self.epochs = epochs\n",
+    "        self.h_samples = None\n",
+    "\n",
+    "    def sigmoid(self, x):\n",
+    "        \"\"\"\n",
+    "        Sigmoid logistic function \n",
+    "        \"\"\"\n",
+    "        return 1. / (1. + numpy.exp(-numpy.maximum(numpy.minimum(x, 30), -30)))\n",
+    "\n",
+    "    def propup(self, vis):\n",
+    "        \"\"\"\n",
+    "        propagate the visible units up: p(h = 1 | v)\n",
+    "        \"\"\"\n",
+    "        return self.prob_h(vis)\n",
+    "\n",
+    "    def prob_h(self, vis):\n",
+    "        \"\"\"\n",
+    "        p(h_j = 1 | v) = sigmoid(b_j + (v W)_j) for each hidden unit j\n",
+    "        \"\"\"\n",
+    "        W_np = self.W.asnumpy()\n",
+    "        return self.sigmoid(numpy.dot(vis, W_np) + self.b)\n",
+    "\n",
+    "    def sample_h_given_v(self, vis):\n",
+    "        \"\"\"\n",
+    "        sample binary hidden units given the visible layer\n",
+    "        \"\"\"\n",
+    "        return numpy.random.binomial(1, self.propup(vis))\n",
+    "\n",
+    "    def propdown(self, hid):\n",
+    "        \"\"\"\n",
+    "        propagate the hidden units down: p(v_i = 1 | h) = sigmoid(c_i + (h W^T)_i)\n",
+    "        \"\"\"\n",
+    "        W_np = self.W.asnumpy()\n",
+    "        return self.sigmoid(numpy.dot(hid, W_np.T) + self.c)\n",
+    "\n",
+    "    def sample_v_given_h(self, hid):\n",
+    "        \"\"\"\n",
+    "        sample visible layer with given hidden layer\n",
+    "        sample binary visible units given the hidden layer\n",
+    "        return numpy.random.binomial(1, self.propdown(hid))\n",
+    "\n",
+    "    def free_energy(self, vis):\n",
+    "        \"\"\"\n",
+    "        calculate the free energy F(v) = -v.c - sum_j log(1 + exp(b_j + (v W)_j))\n",
+    "        \"\"\"\n",
+    "        W_np = self.W.asnumpy()\n",
+    "        return - numpy.dot(vis, self.c) \\\n",
+    "               - numpy.logaddexp(0., numpy.dot(vis, W_np) + self.b).sum(1)\n",
+    "\n",
+    "    def gibbs_chain(self, vis):\n",
+    "        \"\"\"\n",
+    "        one step of Gibbs sampling: v -> h -> v\n",
+    "        \"\"\"\n",
+    "        h_ = self.sample_h_given_v(vis)\n",
+    "        v_ = self.sample_v_given_h(h_)\n",
+    "        return v_\n",
+    "\n",
+    "    def _update(self, v_pos, verbose=True):\n",
+    "        \"\"\"\n",
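+    "        update helper: one contrastive divergence step whose negative phase starts from the persistent chain self.h_samples\n",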
+    "        \"\"\"\n",
+    "        h_pos = self.propup(v_pos)\n",
+    "        v_neg = self.sample_v_given_h(self.h_samples)\n",
+    "        h_neg = self.propup(v_neg)\n",
+    "        \n",
+    "        W_np = self.W.asnumpy()\n",
+    "        W_np += self.epsilon * (numpy.dot(v_pos.T, h_pos)\n",
+    "                                  - numpy.dot(v_neg.T, h_neg)) / self.n_samples\n",
+    "        self.W = mx.nd.array(W_np)\n",
+    "        self.b += self.epsilon * (h_pos.mean(0) - h_neg.mean(0))\n",
+    "        self.c += self.epsilon * (v_pos.mean(0) - v_neg.mean(0))\n",
+    "\n",
+    "        self.h_samples = numpy.random.binomial(1, h_neg)\n",
+    "\n",
+    "        return self.pseudo_likelihood(v_pos)\n",
+    "\n",
+    "    def pseudo_likelihood(self, v):\n",
+    "        \"\"\"\n",
+    "        stochastic pseudo-likelihood of v, used to monitor training\n",
+    "        \"\"\"\n",
+    "        fe = self.free_energy(v)\n",
+    "\n",
+    "        v_ = v.copy()\n",
+    "        i_ = numpy.random.randint(0, v.shape[1], v.shape[0])\n",
+    "        v_[range(v.shape[0]), i_] = v_[range(v.shape[0]), i_] == 0\n",
+    "        fe_ = self.free_energy(v_)\n",
+    "\n",
+    "        return v.shape[1] * numpy.log(self.sigmoid(fe_ - fe))\n",
+    "\n",
+    "    def update(self, X, verbose=False):\n",
+    "        \"\"\"\n",
+    "        train the RBM on X for self.epochs epochs of minibatch updates\n",
+    "        \"\"\"\n",
+    "        if self.W is None:\n",
+    "            self.W = numpy.asarray(numpy.random.normal(0, 0.01,\n",
+    "                (X.shape[1], self.n_hiddens)), dtype=X.dtype)\n",
+    "            self.W = mx.nd.array(self.W)\n",
+    "            self.b = numpy.zeros(self.n_hiddens, dtype=X.dtype)\n",
+    "            self.c = numpy.zeros(X.shape[1], dtype=X.dtype)\n",
+    "            self.h_samples = numpy.zeros((self.n_samples, self.n_hiddens),\n",
+    "                                         dtype=X.dtype)\n",
+    "\n",
+    "        # shuffled sample indices; also works on Python 3, where range() is immutable\n",
+    "        inds = numpy.random.permutation(X.shape[0])\n",
+    "\n",
+    "\n",
+    "        n_batches = int(numpy.ceil(len(inds) / float(self.n_samples)))\n",
+    "\n",
+    "        for epoch in range(self.epochs):\n",
+    "            print(epoch)\n",
+    "            pl = 0.\n",
+    "            for minibatch in range(n_batches):\n",
+    "                pl += self._update(X[inds[minibatch::n_batches]]).sum()\n",
+    "            pl /= X.shape[0]\n",
+    "            print(self.W)\n",
+    "            print(\"Epoch %d, Pseudo-Likelihood = %.2f\" % (epoch, pl))\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 2. Train the RBM, then save the learned weights to a file."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0\n",
+      "\n",
+      "[[-0.2669937  -0.26080966 -0.2550104  ..., -0.26267484 -0.27630749\n",
+      "  -0.26593226]\n",
+      " [-0.26792291 -0.26412126 -0.27305239 ..., -0.28993195 -0.29635659\n",
+      "  -0.27198008]\n",
+      " [-0.28744251 -0.26321027 -0.26610094 ..., -0.28906548 -0.27673733\n",
+      "  -0.27588615]\n",
+      " ..., \n",
+      " [-0.28515878 -0.28254196 -0.27648616 ..., -0.28323209 -0.25867426\n",
+      "  -0.29814324]\n",
+      " [-0.28058594 -0.28413859 -0.27928758 ..., -0.27831089 -0.27733001\n",
+      "  -0.28026232]\n",
+      " [-0.27965176 -0.26214284 -0.28712091 ..., -0.28626716 -0.27738428\n",
+      "  -0.28456607]]\n",
+      "<NDArray 784x10 @cpu(0)>\n",
+      "Epoch 0, Pseudo-Likelihood = -278.24\n",
+      "1\n",
+      "\n",
+      "[[-0.31722581 -0.31104183 -0.30520838 ..., -0.31287125 -0.32648882\n",
+      "  -0.31620175]\n",
+      " [-0.31315628 -0.30939487 -0.31827205 ..., -0.33511931 -0.34155318\n",
+      "  -0.31724405]\n",
+      " [-0.33147728 -0.30728689 -0.310119   ..., -0.3330701  -0.32075664\n",
+      "  -0.3199743 ]\n",
+      " ..., \n",
+      " [-0.3306635  -0.32808211 -0.32197085 ..., -0.32873008 -0.30414727\n",
+      "  -0.3436805 ]\n",
+      " [-0.32325622 -0.32684505 -0.32192981 ..., -0.3209655  -0.31998613\n",
+      "  -0.32297391]\n",
+      " [-0.32545409 -0.30801317 -0.33286527 ..., -0.33204952 -0.32312897\n",
+      "  -0.33038604]]\n",
+      "<NDArray 784x10 @cpu(0)>\n",
+      "Epoch 1, Pseudo-Likelihood = -230.39\n",
+      "2\n",
+      "\n",
+      "[[-0.34656063 -0.34037054 -0.3345139  ..., -0.34217981 -0.35578695\n",
+      "  -0.34554166]\n",
+      " [-0.34504208 -0.34128273 -0.35013109 ..., -0.36697814 -0.37339473\n",
+      "  -0.3491497 ]\n",
+      " [-0.36412287 -0.33997643 -0.3427701  ..., -0.36569414 -0.35340148\n",
+      "  -0.35265172]\n",
+      " ..., \n",
+      " [-0.36151129 -0.35893264 -0.35281274 ..., -0.35958707 -0.33498588\n",
+      "  -0.37453163]\n",
+      " [-0.35363358 -0.35722253 -0.35229218 ..., -0.35132316 -0.35034156\n",
+      "  -0.35336182]\n",
+      " [-0.35867885 -0.34126469 -0.36605495 ..., -0.36523733 -0.35634357\n",
+      "  -0.36361733]]\n",
+      "<NDArray 784x10 @cpu(0)>\n",
+      "Epoch 2, Pseudo-Likelihood = -219.00\n",
+      "3\n",
+      "\n",
+      "[[-0.37045184 -0.36426327 -0.35839516 ..., -0.36606663 -0.3796756\n",
+      "  -0.36942664]\n",
+      " [-0.36639854 -0.36263373 -0.37147391 ..., -0.38831803 -0.39473808\n",
+      "  -0.37049818]\n",
+      " [-0.39011613 -0.3659786  -0.36875641 ..., -0.39168712 -0.37939385\n",
+      "  -0.37865001]\n",
+      " ..., \n",
+      " [-0.38678032 -0.38420725 -0.37807328 ..., -0.3848615  -0.36026153\n",
+      "  -0.39981103]\n",
+      " [-0.37738124 -0.38097379 -0.37601754 ..., -0.37507504 -0.37407941\n",
+      "  -0.37713015]\n",
+      " [-0.37858096 -0.36116457 -0.3859455  ..., -0.38513628 -0.37624007\n",
+      "  -0.38351849]]\n",
+      "<NDArray 784x10 @cpu(0)>\n",
+      "Epoch 3, Pseudo-Likelihood = -217.01\n",
+      "4\n",
+      "\n",
+      "[[-0.39066988 -0.38447928 -0.37860864 ..., -0.38628352 -0.39989403\n",
+      "  -0.38964248]\n",
+      " [-0.38581786 -0.38205105 -0.39089113 ..., -0.40773496 -0.41415608\n",
+      "  -0.38991463]\n",
+      " [-0.41081282 -0.38667658 -0.38945273 ..., -0.41238484 -0.40009099\n",
+      "  -0.39934757]\n",
+      " ..., \n",
+      " [-0.4057149  -0.40314078 -0.39700639 ..., -0.4037953  -0.37919831\n",
+      "  -0.41874313]\n",
+      " [-0.39733237 -0.40092501 -0.39596796 ..., -0.39502415 -0.39403406\n",
+      "  -0.39708087]\n",
+      " [-0.39474222 -0.37732247 -0.40209258 ..., -0.40129253 -0.39239594\n",
+      "  -0.39967355]]\n",
+      "<NDArray 784x10 @cpu(0)>\n",
+      "Epoch 4, Pseudo-Likelihood = -219.13\n",
+      "5\n",
+      "\n",
+      "[[-0.40391761 -0.39773053 -0.39185682 ..., -0.39953253 -0.41314024\n",
+      "  -0.40289411]\n",
+      " [-0.40074226 -0.39697546 -0.40581551 ..., -0.42265934 -0.42908052\n",
+      "  -0.40483904]\n",
+      " [-0.42638156 -0.4022449  -0.40502089 ..., -0.42795363 -0.41565904\n",
+      "  -0.41491571]\n",
+      " ..., \n",
+      " [-0.42154524 -0.41897181 -0.41283536 ..., -0.41962427 -0.39502779\n",
+      "  -0.43457252]\n",
+      " [-0.4127686  -0.41636142 -0.41140431 ..., -0.41046035 -0.40947038\n",
+      "  -0.41251734]\n",
+      " [-0.41336769 -0.3959479  -0.4207179  ..., -0.41991758 -0.41102123\n",
+      "  -0.41829905]]\n",
+      "<NDArray 784x10 @cpu(0)>\n",
+      "Epoch 5, Pseudo-Likelihood = -218.51\n",
+      "6\n",
+      "\n",
+      "[[-0.41909733 -0.41291019 -0.40703654 ..., -0.41471228 -0.42831987\n",
+      "  -0.41807368]\n",
+      " [-0.41337326 -0.40960664 -0.41844711 ..., -0.4352906  -0.44171208\n",
+      "  -0.41747022]\n",
+      " [-0.43612379 -0.41198817 -0.41476318 ..., -0.43769524 -0.42539707\n",
+      "  -0.42465729]\n",
+      " ..., \n",
+      " [-0.43609312 -0.43352005 -0.42738396 ..., -0.43417269 -0.40957665\n",
+      "  -0.44912058]\n",
+      " [-0.42807785 -0.43167073 -0.42671362 ..., -0.42576966 -0.42477983\n",
+      "  -0.42782652]\n",
+      " [-0.42574343 -0.40832356 -0.43309388 ..., -0.43229359 -0.423397\n",
+      "  -0.43067467]]\n",
+      "<NDArray 784x10 @cpu(0)>\n",
+      "Epoch 6, Pseudo-Likelihood = -218.04\n",
+      "7\n",
+      "\n",
+      "[[-0.43057719 -0.42439005 -0.41851637 ..., -0.42619213 -0.4397997\n",
+      "  -0.42955354]\n",
+      " [-0.42523578 -0.42146915 -0.43030962 ..., -0.44715312 -0.4535746\n",
+      "  -0.42933273]\n",
+      " [-0.4487516  -0.42461598 -0.42739099 ..., -0.45032305 -0.43802488\n",
+      "  -0.4372851 ]\n",
+      " ..., \n",
+      " [-0.44770053 -0.44512746 -0.43899137 ..., -0.4457801  -0.42118406\n",
+      "  -0.46072799]\n",
+      " [-0.43840972 -0.44200259 -0.43704548 ..., -0.43610153 -0.4351117\n",
+      "  -0.43815839]\n",
+      " [-0.43824369 -0.42082381 -0.44559413 ..., -0.44479388 -0.43589726\n",
+      "  -0.44317496]]\n",
+      "<NDArray 784x10 @cpu(0)>\n",
+      "Epoch 7, Pseudo-Likelihood = -215.13\n",
+      "8\n",
+      "\n",
+      "[[-0.43988863 -0.43370149 -0.42782781 ..., -0.43550357 -0.44911113\n",
+      "  -0.43886498]\n",
+      " [-0.43339923 -0.4296326  -0.43847305 ..., -0.45531657 -0.46173802\n",
+      "  -0.43749619]\n",
+      " [-0.45831811 -0.43418249 -0.43695751 ..., -0.45988956 -0.44759139\n",
+      "  -0.44685161]\n",
+      " ..., \n",
+      " [-0.45905283 -0.45647976 -0.45034367 ..., -0.4571324  -0.43253636\n",
+      "  -0.47208029]\n",
+      " [-0.44823134 -0.45182422 -0.44686711 ..., -0.44592315 -0.44493333\n",
+      "  -0.44798002]\n",
+      " [-0.44730002 -0.42988014 -0.45465046 ..., -0.45385021 -0.44495359\n",
+      "  -0.45223129]]\n",
+      "<NDArray 784x10 @cpu(0)>\n",
+      "Epoch 8, Pseudo-Likelihood = -220.10\n",
+      "9\n",
+      "\n",
+      "[[-0.45073071 -0.44454357 -0.43866989 ..., -0.44634566 -0.45995322\n",
+      "  -0.44970706]\n",
+      " [-0.44245556 -0.43868893 -0.44752938 ..., -0.4643729  -0.47079435\n",
+      "  -0.44655252]\n",
+      " [-0.46762955 -0.44349393 -0.44626895 ..., -0.469201   -0.45690283\n",
+      "  -0.45616305]\n",
+      " ..., \n",
+      " [-0.46543053 -0.46285746 -0.45672137 ..., -0.4635101  -0.43891406\n",
+      "  -0.47845799]\n",
+      " [-0.45805299 -0.46164587 -0.45668873 ..., -0.4557448  -0.45475495\n",
+      "  -0.45780167]\n",
+      " [-0.45839715 -0.44097728 -0.46574759 ..., -0.46494734 -0.45605072\n",
+      "  -0.46332842]]\n",
+      "<NDArray 784x10 @cpu(0)>\n",
+      "Epoch 9, Pseudo-Likelihood = -218.46\n",
+      "\n",
+      "[[-0.45073071 -0.44454357 -0.43866989 ..., -0.44634566 -0.45995322\n",
+      "  -0.44970706]\n",
+      " [-0.44245556 -0.43868893 -0.44752938 ..., -0.4643729  -0.47079435\n",
+      "  -0.44655252]\n",
+      " [-0.46762955 -0.44349393 -0.44626895 ..., -0.469201   -0.45690283\n",
+      "  -0.45616305]\n",
+      " ..., \n",
+      " [-0.46543053 -0.46285746 -0.45672137 ..., -0.4635101  -0.43891406\n",
+      "  -0.47845799]\n",
+      " [-0.45805299 -0.46164587 -0.45668873 ..., -0.4557448  -0.45475495\n",
+      "  -0.45780167]\n",
+      " [-0.45839715 -0.44097728 -0.46574759 ..., -0.46494734 -0.45605072\n",
+      "  -0.46332842]]\n",
+      "<NDArray 784x10 @cpu(0)>\n"
+     ]
+    }
+   ],
+   "source": [
+    "X = numpy.array(train_set_x)\n",
+    "model = RBM(n_hiddens=10, n_samples=784, epochs=10)\n",
+    "model.update(X)\n",
+    "\n",
+    "print(model.W)\n",
+    "\n",
+    "with open('weight.p', 'wb') as f:\n",
+    "    cPickle.dump(model.W, f)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Check Whether the RBM-Initialized Weights Work Well\n",
+    "- The code below is adapted from the Gluon tutorial."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Parameter mydense1_weight (shape=(784, 10), dtype=<type 'numpy.float32'>)\n",
+      "Epoch 0. Train_acc 0.903085714286, Test_acc 0.9093\n",
+      "Epoch 1. Train_acc 0.907714285714, Test_acc 0.9045125\n",
+      "Epoch 2. Train_acc 0.910352380952, Test_acc 0.908293333333\n",
+      "Epoch 3. Train_acc 0.9122, Test_acc 0.910790909091\n",
+      "Epoch 4. Train_acc 0.913762857143, Test_acc 0.912527586207\n"
+     ]
+    }
+   ],
+   "source": [
+    "from __future__ import print_function\n",
+    "import mxnet as mx\n",
+    "import numpy as np\n",
+    "import pickle as cPickle\n",
+    "from mxnet import nd, autograd, gluon\n",
+    "from mxnet.gluon import nn, Block\n",
+    "mx.random.seed(1)\n",
+    "\n",
+    "###########################\n",
+    "#  Specify the context we'll be using\n",
+    "###########################\n",
+    "ctx = mx.gpu(0)\n",
+    "\n",
+    "###########################\n",
+    "#  Load up our dataset\n",
+    "###########################\n",
+    "batch_size = 64\n",
+    "def transform(data, label):\n",
+    "    return data.astype(np.float32)/255, label.astype(np.float32)\n",
+    "train_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=True, transform=transform),\n",
+    "                                      batch_size, shuffle=True)\n",
+    "test_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=False, transform=transform),\n",
+    "                                     batch_size, shuffle=False)\n",
+    "def relu(X):\n",
+    "    return nd.maximum(X, 0)\n",
+    "class MyDense(Block):\n",
+    "    ####################\n",
+    "    # We add arguments to our constructor (__init__)\n",
+    "    # to indicate the number of input units (``in_units``) \n",
+    "    # and output units (``units``)\n",
+    "    ####################\n",
+    "    def __init__(self, units, in_units=0, **kwargs):\n",
+    "        super(MyDense, self).__init__(**kwargs)\n",
+    "        with self.name_scope():\n",
+    "            self.units = units\n",
+    "            self._in_units = in_units\n",
+    "            #################\n",
+    "            # We add the required parameters to the ``Block``'s ParameterDict , \n",
+    "            # indicating the desired shape\n",
+    "            #################\n",
+    "            self.weight = self.params.get('weight', init=mx.init.Uniform(scale=1), shape=(in_units, units))            \n",
+    "            self.bias = self.params.get('bias', shape=(units,))\n",
+    "            print(self.weight)\n",
+    "    #################\n",
+    "    #  Now we just have to write the forward pass. \n",
+    "    #  We could rely upon the FullyConnected primitive in NDArray, \n",
+    "    #  but it's better to get our hands dirty and write it out\n",
+    "    #  so you'll know how to compose arbitrary functions\n",
+    "    #################\n",
+    "    def forward(self, x):\n",
+    "        with x.context:\n",
+    "            linear = nd.dot(x, self.weight.data()) + self.bias.data()\n",
+    "            activation = relu(linear)\n",
+    "            return activation\n",
+    "        \n",
+    "dense = MyDense(10, in_units= 28*28)\n",
+    "dense.collect_params().initialize(ctx=ctx)\n",
+    "dense.weight.set_data(cPickle.load(open('weight.p', 'rb')))  # RBM weights, shape (784, 10)\n",
+    "net = gluon.nn.Sequential()\n",
+    "with net.name_scope():\n",
+    "    net.add(dense)\n",
+    "loss = gluon.loss.SoftmaxCrossEntropyLoss()\n",
+    "trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': .1})\n",
+    "metric = mx.metric.Accuracy()\n",
+    "\n",
+    "def evaluate_accuracy(data_iterator, net):\n",
+    "    # use a fresh metric on every call; the global one would keep accumulating\n",
+    "    # across epochs and across the train/test evaluations\n",
+    "    metric = mx.metric.Accuracy()\n",
+    "    for i, (data, label) in enumerate(data_iterator):\n",
+    "        # no autograd.record() here: evaluation needs no gradients\n",
+    "        data = data.as_in_context(ctx).reshape((-1,784))\n",
+    "        label = label.as_in_context(ctx)\n",
+    "        output = net(data)\n",
+    "        metric.update([label], [output])\n",
+    "    return metric.get()[1]\n",
+    "epochs = 5  # Low number for testing, set higher when you run!\n",
+    "moving_loss = 0.\n",
+    "\n",
+    "for e in range(epochs):\n",
+    "    for i, (data, label) in enumerate(train_data):\n",
+    "        data = data.as_in_context(ctx).reshape((-1,784))\n",
+    "        label = label.as_in_context(ctx)\n",
+    "        with autograd.record():\n",
+    "            output = net(data)\n",
+    "            cross_entropy = loss(output, label)\n",
+    "        cross_entropy.backward()\n",
+    "        trainer.step(data.shape[0])\n",
+    "            \n",
+    "    test_accuracy = evaluate_accuracy(test_data, net)\n",
+    "    train_accuracy = evaluate_accuracy(train_data, net)\n",
+    "    print(\"Epoch %s. Train_acc %s, Test_acc %s\" % (e, train_accuracy, test_accuracy))\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Summary\n",
+    "\n",
+    "In this run, the RBM-initialized network reaches a test accuracy of about 0.91 after the first epoch. With uniform random initialization, the accuracy in the first few epochs is very low (that run is not shown here). RBM initialization may still be worse than Xavier initialization, but it gives good accuracy from the start of training."
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python [conda env:mxnet_p27]",
+   "language": "python",
+   "name": "conda-env-mxnet_p27-py"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 2
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython2",
+   "version": "2.7.14"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
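
The conditional probabilities behind the notebook's `prob_h`/`propup`, `propdown` and `sample_*_given_*` methods can be sketched in plain NumPy as below. This is a minimal illustration, not the notebook's code: the free-function names, the toy layer sizes and the use of NumPy's `default_rng` Generator API (NumPy 1.17+) are assumptions made for the example. As in the notebook, `W` has shape `(n_visible, n_hidden)`, `b` is the hidden bias and `c` is the visible bias.

```python
import numpy as np

def sigmoid(x):
    # Logistic function with the same [-30, 30] clipping as the notebook's RBM.sigmoid,
    # which keeps np.exp from overflowing for large |x|.
    return 1.0 / (1.0 + np.exp(-np.clip(x, -30, 30)))

def prob_h_given_v(v, W, b):
    # p(h_j = 1 | v) = sigmoid(b_j + sum_i v_i W_ij)
    return sigmoid(v @ W + b)

def prob_v_given_h(h, W, c):
    # p(v_i = 1 | h) = sigmoid(c_i + sum_j W_ij h_j)
    return sigmoid(h @ W.T + c)

def sample_bernoulli(p, rng):
    # Draw 0/1 samples with success probability p (like numpy.random.binomial(1, p)).
    return rng.binomial(1, p).astype(p.dtype)

rng = np.random.default_rng(0)
n_visible, n_hidden, batch = 6, 3, 4
W = rng.normal(0.0, 0.01, size=(n_visible, n_hidden))
b = np.zeros(n_hidden)   # hidden bias (self.b in the notebook)
c = np.zeros(n_visible)  # visible bias (self.c in the notebook)
v0 = rng.binomial(1, 0.5, size=(batch, n_visible)).astype(float)

ph = prob_h_given_v(v0, W, b)                          # "up" pass
h0 = sample_bernoulli(ph, rng)                         # binary hidden sample
v1 = sample_bernoulli(prob_v_given_h(h0, W, c), rng)   # "down" pass: one Gibbs step
print(ph.shape, h0.shape, v1.shape)  # → (4, 3) (4, 3) (4, 6)
```

Chaining one up-sample and one down-sample like this is what `gibbs_chain` does with a single visible batch.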

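The closed form used by `free_energy` follows from summing exp(-E(v, h)) over all binary hidden vectors, with joint energy E(v, h) = -v.c - h.b - v W h. The sketch below, using the same shape conventions, checks that closed form against a brute-force sum on a tiny model. The helpers `energy` and `free_energy_brute_force` are hypothetical and exist only for this check; `numpy.logaddexp(0, x)` is used to evaluate log(1 + exp(x)) without overflow.

```python
import itertools
import numpy as np

def free_energy(v, W, b, c):
    # F(v) = -v.c - sum_j log(1 + exp(b_j + (v W)_j)); b: hidden bias, c: visible bias.
    # Works for a single vector v or a batch of rows.
    return -v @ c - np.logaddexp(0.0, v @ W + b).sum(axis=-1)

def energy(v, h, W, b, c):
    # Joint energy of a binary RBM: E(v, h) = -v.c - h.b - v W h
    return -v @ c - h @ b - v @ W @ h

def free_energy_brute_force(v, W, b, c):
    # F(v) = -log sum_h exp(-E(v, h)) over all 2**n_hidden binary hidden states.
    n_hidden = W.shape[1]
    hs = np.array(list(itertools.product([0.0, 1.0], repeat=n_hidden)))
    neg_energies = np.array([-energy(v, h, W, b, c) for h in hs])
    return -np.logaddexp.reduce(neg_energies)

rng = np.random.default_rng(1)
W = rng.normal(size=(5, 3))
b = rng.normal(size=3)
c = rng.normal(size=5)
v = rng.binomial(1, 0.5, size=5).astype(float)

print(free_energy(v, W, b, c), free_energy_brute_force(v, W, b, c))
```

Because the -v.c term can be positive when visible biases are negative, F(v) is not negative in general.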

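`_update` draws its negative phase from `self.h_samples`, a hidden state that persists across minibatches instead of restarting from the data; this corresponds to persistent contrastive divergence with one Gibbs step (PCD-1). The function below is a minimal sketch of that update under stated assumptions: `pcd_step` and its argument names are hypothetical, it returns new arrays instead of mutating attributes, and it averages each gradient term over its own number of rows, whereas the notebook divides both terms of the weight gradient by `n_samples`.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-np.clip(x, -30, 30)))

def pcd_step(v_pos, W, b, c, h_chain, lr, rng):
    """One persistent-CD update in the style of the notebook's RBM._update (sketch).

    v_pos:   (batch, n_visible) data minibatch
    h_chain: (n_chains, n_hidden) persistent hidden states carried between calls
    Returns updated (W, b, c, h_chain); the input arrays are not modified.
    """
    # Positive phase: hidden probabilities driven by the data.
    h_pos = sigmoid(v_pos @ W + b)
    # Negative phase: one Gibbs step from the persistent chain, not from the data.
    v_neg = rng.binomial(1, sigmoid(h_chain @ W.T + c)).astype(float)
    h_neg = sigmoid(v_neg @ W + b)
    # Gradient estimate <v h>_data - <v h>_model, each term averaged over its rows.
    dW = v_pos.T @ h_pos / len(v_pos) - v_neg.T @ h_neg / len(v_neg)
    W = W + lr * dW
    b = b + lr * (h_pos.mean(axis=0) - h_neg.mean(axis=0))
    c = c + lr * (v_pos.mean(axis=0) - v_neg.mean(axis=0))
    # Advance the persistent chain by sampling hidden units from h_neg.
    h_chain = rng.binomial(1, h_neg).astype(float)
    return W, b, c, h_chain

rng = np.random.default_rng(0)
n_visible, n_hidden, batch = 8, 4, 16
data = (rng.random((batch, n_visible)) < 0.3).astype(float)
W = rng.normal(0.0, 0.01, size=(n_visible, n_hidden))
b = np.zeros(n_hidden)
c = np.zeros(n_visible)
h_chain = np.zeros((batch, n_hidden))  # zero start, like the notebook's h_samples
for _ in range(50):
    W, b, c, h_chain = pcd_step(data, W, b, c, h_chain, lr=0.1, rng=rng)
```

Keeping the chain alive between updates lets the negative samples drift toward the model distribution over many steps, which is the usual motivation for PCD over restarting each chain at the data.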
 
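`pseudo_likelihood` estimates the cost by flipping one randomly chosen visible unit per row and using log p(v_i | v_-i) = log sigmoid(F(v_flipped) - F(v)), scaled by the number of visible units. A minimal sketch, assuming binary visible units: it flips with `1 - v`, while the notebook's `== 0` comparison also maps grayscale MNIST values to 0 or 1. The `free_energy` and `pseudo_likelihood` functions here are illustrative reimplementations, not the notebook's methods.

```python
import numpy as np

def sigmoid(x):
    # Logistic function, clipped to [-30, 30] like the notebook's RBM.sigmoid.
    return 1.0 / (1.0 + np.exp(-np.clip(x, -30, 30)))

def free_energy(v, W, b, c):
    # F(v) = -v.c - sum_j log(1 + exp(b_j + (v W)_j)), one value per row of v.
    return -v @ c - np.logaddexp(0.0, v @ W + b).sum(axis=-1)

def pseudo_likelihood(v, W, b, c, rng):
    # For each row, flip one random visible unit i and use
    # log p(v_i | v_-i) = log sigmoid(F(v_flipped) - F(v)), scaled by n_visible.
    n_rows, n_visible = v.shape
    fe = free_energy(v, W, b, c)
    v_flip = v.copy()
    rows = np.arange(n_rows)
    cols = rng.integers(0, n_visible, size=n_rows)
    v_flip[rows, cols] = 1.0 - v_flip[rows, cols]  # assumes binary (0/1) visible units
    fe_flip = free_energy(v_flip, W, b, c)
    return n_visible * np.log(sigmoid(fe_flip - fe))

rng = np.random.default_rng(0)
v = rng.binomial(1, 0.5, size=(5, 6)).astype(float)
W, b, c = np.zeros((6, 3)), np.zeros(3), np.zeros(6)
pl = pseudo_likelihood(v, W, b, c, rng)
# With all-zero parameters F(v) is the same for every v, so each entry is 6 * log(0.5).
print(pl)
```

The ratio p(v) / (p(v) + p(v_flipped)) equals sigmoid(F(v_flipped) - F(v)) because p(v) is proportional to exp(-F(v)), so the intractable partition function cancels.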

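`update` shuffles the sample indices once, sets `n_batches = ceil(N / n_samples)` and takes `inds[minibatch::n_batches]`, so each strided slice is one minibatch. The generator below is a hypothetical helper that reproduces this slicing so its properties can be checked: every index lands in exactly one batch, and batch sizes differ by at most one. Assuming the 50,000-image training split of `mnist.pkl.gz` and `n_samples=784`, this gives 64 batches of 781 or 782 rows.

```python
import numpy as np

def strided_minibatches(n_items, batch_size, rng):
    # Shuffle once, then take every n_batches-th index starting at k,
    # mirroring inds[minibatch::n_batches] in the notebook's RBM.update.
    inds = rng.permutation(n_items)
    n_batches = int(np.ceil(n_items / float(batch_size)))
    for k in range(n_batches):
        yield inds[k::n_batches]

rng = np.random.default_rng(0)
batches = list(strided_minibatches(50000, 784, rng))
print(len(batches), min(map(len, batches)), max(map(len, batches)))  # → 64 781 782
```

Since the persistent chain `h_samples` has `n_samples` = 784 rows, the negative phase then uses slightly more samples than each positive minibatch; the weight update still works because both matrix products sum over the batch dimension.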
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services