You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by wa...@apache.org on 2016/08/14 09:11:43 UTC

[2/3] incubator-singa git commit: SINGA-238 RBM on MNIST

SINGA-238 RBM on MNIST

Enable training on GPU.
Fixed a bug in KernelSum() by removing it, and implemented the
Sum(Tensor)->float function using Dot() (from BLAS).


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/5b332a40
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/5b332a40
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/5b332a40

Branch: refs/heads/dev
Commit: 5b332a4086ff32b0c3a298169c0befef78f003ca
Parents: e1a524d
Author: Wei Wang <wa...@gmail.com>
Authored: Sun Aug 14 17:07:22 2016 +0800
Committer: Wei Wang <wa...@gmail.com>
Committed: Sun Aug 14 17:07:22 2016 +0800

----------------------------------------------------------------------
 examples/mnist/README.md           |  21 ++-
 examples/mnist/train.py            | 265 ++++++++++++++++----------------
 include/singa/model/loss.h         |   1 -
 src/core/tensor/math_kernel.cu     |   5 +
 src/core/tensor/math_kernel.h      |   2 +-
 src/core/tensor/tensor.cc          |  10 +-
 src/core/tensor/tensor_math_cuda.h |   5 +-
 src/python/singa/optimizer.py      |   1 +
 8 files changed, 169 insertions(+), 141 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5b332a40/examples/mnist/README.md
----------------------------------------------------------------------
diff --git a/examples/mnist/README.md b/examples/mnist/README.md
index bfd480f..9f59e7e 100644
--- a/examples/mnist/README.md
+++ b/examples/mnist/README.md
@@ -1,3 +1,18 @@
-This example is to train an RBM model using mnist data set. This RBM follows paper http://www.cs.toronto.edu/~hinton/science.pdf and the source code for this paper can be found http://www.cs.toronto.edu/~hinton/MatlabForSciencePaper.html
-1. Download dataset mnist.pkl.gz from https://github.com/mnielsen/neural-networks-and-deep-learning/raw/master/data/mnist.pkl.gz
-2. $ python train.py
+# Train an RBM model on the MNIST dataset
+
+This example trains an RBM model on the
+MNIST dataset. The RBM model and its hyper-parameters follow
+[Hinton's paper](http://www.cs.toronto.edu/~hinton/science.pdf).
+
+## Running instructions
+
+1. Download the pre-processed [MNIST dataset](https://github.com/mnielsen/neural-networks-and-deep-learning/raw/master/data/mnist.pkl.gz) (`mnist.pkl.gz`)
+
+2. Start the training, passing the path to the downloaded file
+
+        python train.py mnist.pkl.gz
+
+By default the training code runs on the CPU. To run it on a GPU card, start
+the program with the additional `--use_gpu` argument:
+
+        python train.py mnist.pkl.gz --use_gpu

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5b332a40/examples/mnist/train.py
----------------------------------------------------------------------
diff --git a/examples/mnist/train.py b/examples/mnist/train.py
index 52b023a..43b8e26 100644
--- a/examples/mnist/train.py
+++ b/examples/mnist/train.py
@@ -1,131 +1,134 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# =============================================================================
-
-import cPickle
-import numpy as np
-import numpy.matlib
-import os
-import sys
-import gzip, numpy
-
-
-sys.path.append(os.path.join(os.path.dirname(__file__),
-                             '../../build/python'))
-sys.path.append(os.path.join(os.path.dirname(__file__),
-                             '../../build/lib'))
-sys.path.append(os.path.join(os.path.dirname(__file__),'../../build/src'))
-from singa import initializer
-from singa import utils
-from singa import optimizer
-from singa import device
-from singa import tensor
-from singa.proto import core_pb2
-
-
-
-def load_train_data(dir_path):
-    f = gzip.open(dir_path, 'rb')
-    train_set, valid_set, test_set = cPickle.load(f)
-    traindata = train_set[0].astype(np.float32)
-    validdata = valid_set[0].astype(np.float32)
-    return traindata, validdata
-
-
-
-def train(data_dir, num_epoch=10, batch_size=100):
-    print 'Start intialization............'
-    lr = 0.1   # Learning rate
-    weight_decay  = 0.0002
-    hdim = 1000
-    vdim = 784
-    opt = optimizer.SGD(momentum=0.8, weight_decay=weight_decay)
-    
-    shape = (vdim, hdim)
-    tweight = tensor.Tensor(shape)
-    initializer.gaussian(tweight, 0.0, 0.1)
-    tvbias = tensor.from_numpy(np.zeros(vdim, dtype = np.float32))
-    thbias = tensor.from_numpy(np.zeros(hdim, dtype = np.float32))
-    opt = optimizer.SGD(momentum=0.5, weight_decay=weight_decay)
-
-    print 'Loading data ..................'
-    train_x, valid_x = load_train_data(data_dir)
-
-    num_train_batch = train_x.shape[0]/batch_size
-    print "num_train_batch = \n", num_train_batch
-    for epoch in range(num_epoch):
-        trainerrorsum = 0.0
-        validerrorsum = 0.0
-        print 'Epoch %d' % epoch
-        for b in range(num_train_batch):
-            # positive phase
-            if b % 100 == 0:
-                print "batch: \n", b
-
-            tdata = tensor.from_numpy(train_x[ (b * batch_size): ((b + 1) * batch_size), : ])
-            tposhidprob = tensor.mult(tdata, tweight)
-            tposhidprob.add_row(thbias)
-            tposhidprob = tensor.sigmoid(tposhidprob)
-            tposhidrandom = tensor.Tensor(tposhidprob.shape)
-            initializer.uniform(tposhidrandom, 0.0, 1.0)
-            tposhidsample = tensor.gt(tposhidprob, tposhidrandom)
-            
-            # negative phase
-            tnegdata = tensor.mult(tposhidsample, tweight.transpose())
-            tnegdata.add_row(tvbias)
-            tnegdata = tensor.sigmoid(tnegdata)
-
-            tneghidprob = tensor.mult(tnegdata, tweight)
-            tneghidprob.add_row(thbias) 
-            tneghidprob = tensor.sigmoid(tneghidprob)
-            trainerror = tensor.sum(tensor.eltwise_mult((tdata - tnegdata),(tdata - tnegdata)))
-            trainerrorsum = trainerror + trainerrorsum
-           
-            tgweight = tensor.mult(tnegdata.transpose(), tneghidprob) - tensor.mult(tdata.transpose(), tposhidprob)
-            tgvbias = tensor.sum(tnegdata, 0) - tensor.sum(tdata, 0)
-            tghbias = tensor.sum(tneghidprob, 0) - tensor.sum(tposhidprob, 0)
-            
-            opt.apply_with_lr(epoch, lr / batch_size, tgweight, tweight, '')
-            opt.apply_with_lr(epoch, lr / batch_size, tgvbias, tvbias, '')
-            opt.apply_with_lr(epoch, lr / batch_size, tghbias, thbias, '')
-
-        info = 'train errorsum = %f' \
-            % (trainerrorsum)
-        print info
-
-        tvaliddata = tensor.from_numpy(valid_x[ :, : ])
-        tvalidposhidprob = tensor.mult(tvaliddata, tweight)
-        tvalidposhidprob.add_row(thbias)
-        tvalidposhidprob = tensor.sigmoid(tvalidposhidprob)
-        tvalidposhidrandom = tensor.Tensor(tvalidposhidprob.shape)
-        initializer.uniform(tvalidposhidrandom, 0.0, 1.0)
-        tvalidposhidsample = tensor.gt(tvalidposhidprob, tvalidposhidrandom)
-
-        tvalidnegdata = tensor.mult(tvalidposhidsample, tweight.transpose())
-        tvalidnegdata.add_row(tvbias)
-        tvalidnegdata = tensor.sigmoid(tvalidnegdata)
-
-        validerrorsum = tensor.sum(tensor.eltwise_mult((tvaliddata - tvalidnegdata),(tvaliddata - tvalidnegdata)))
-        validinfo = 'valid errorsum = %f' \
-            % (validerrorsum)
-        print validinfo
-
-
-if __name__ == '__main__':
-    data_dir = 'mnist.pkl.gz'
-    assert os.path.exists(data_dir), \
-        'Pls download the mnist dataset'
-    train(data_dir)
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+import numpy as np
+import os
+import gzip
+import argparse
+import cPickle
+from singa import initializer
+from singa import utils
+from singa import optimizer
+from singa import device
+from singa import tensor
+
+
+from singa.proto import core_pb2
+
+
+
+def load_train_data(file_path):
+    f = gzip.open(file_path, 'rb')
+    train_set, valid_set, test_set = cPickle.load(f)
+    traindata = train_set[0].astype(np.float32)
+    validdata = valid_set[0].astype(np.float32)
+    print traindata.shape, validdata.shape
+    return traindata, validdata
+
+
+
+def train(data_file, use_gpu, num_epoch=10, batch_size=100):
+    print 'Start intialization............'
+    lr = 0.1   # Learning rate
+    weight_decay  = 0.0002
+    hdim = 1000
+    vdim = 784
+    opt = optimizer.SGD(momentum=0.8, weight_decay=weight_decay)
+
+    tweight = tensor.Tensor((vdim, hdim))
+    tweight.gaussian(0.0, 0.1)
+    tvbias = tensor.from_numpy(np.zeros(vdim, dtype = np.float32))
+    thbias = tensor.from_numpy(np.zeros(hdim, dtype = np.float32))
+    opt = optimizer.SGD(momentum=0.5, weight_decay=weight_decay)
+
+    print 'Loading data ..................'
+    train_x, valid_x = load_train_data(data_file)
+
+    if use_gpu:
+        dev = device.create_cuda_gpu()
+    else:
+        dev = device.get_default_device()
+
+    for t in [tweight, tvbias, thbias]:
+        t.to_device(dev)
+
+    num_train_batch = train_x.shape[0] / batch_size
+    print "num_train_batch = %d " % (num_train_batch)
+    for epoch in range(num_epoch):
+        trainerrorsum = 0.0
+        validerrorsum = 0.0
+        print 'Epoch %d' % epoch
+        for b in range(num_train_batch):
+            # positive phase
+            tdata = tensor.from_numpy(
+                    train_x[(b * batch_size):((b + 1) * batch_size), : ])
+            tdata.to_device(dev)
+            tposhidprob = tensor.mult(tdata, tweight)
+            tposhidprob.add_row(thbias)
+            tposhidprob = tensor.sigmoid(tposhidprob)
+            tposhidrandom = tensor.Tensor(tposhidprob.shape, dev)
+            tposhidrandom.uniform(0.0, 1.0)
+            tposhidsample = tensor.gt(tposhidprob, tposhidrandom)
+
+            # negative phase
+            tnegdata = tensor.mult(tposhidsample, tweight.transpose())
+            tnegdata.add_row(tvbias)
+            tnegdata = tensor.sigmoid(tnegdata)
+
+            tneghidprob = tensor.mult(tnegdata, tweight)
+            tneghidprob.add_row(thbias)
+            tneghidprob = tensor.sigmoid(tneghidprob)
+            error = tensor.sum(tensor.square((tdata - tnegdata)))
+            trainerrorsum = error + trainerrorsum
+
+            tgweight = tensor.mult(tnegdata.transpose(), tneghidprob) -\
+                    tensor.mult(tdata.transpose(), tposhidprob)
+            tgvbias = tensor.sum(tnegdata, 0) - tensor.sum(tdata, 0)
+            tghbias = tensor.sum(tneghidprob, 0) - tensor.sum(tposhidprob, 0)
+
+            opt.apply_with_lr(epoch, lr / batch_size, tgweight, tweight, 'w')
+            opt.apply_with_lr(epoch, lr / batch_size, tgvbias, tvbias, 'vb')
+            opt.apply_with_lr(epoch, lr / batch_size, tghbias, thbias, 'hb')
+
+        print 'training errorsum = %f' % (trainerrorsum)
+
+        tvaliddata = tensor.from_numpy(valid_x)
+        tvaliddata.to_device(dev)
+        tvalidposhidprob = tensor.mult(tvaliddata, tweight)
+        tvalidposhidprob.add_row(thbias)
+        tvalidposhidprob = tensor.sigmoid(tvalidposhidprob)
+        tvalidposhidrandom = tensor.Tensor(tvalidposhidprob.shape, dev)
+        initializer.uniform(tvalidposhidrandom, 0.0, 1.0)
+        tvalidposhidsample = tensor.gt(tvalidposhidprob, tvalidposhidrandom)
+
+        tvalidnegdata = tensor.mult(tvalidposhidsample, tweight.transpose())
+        tvalidnegdata.add_row(tvbias)
+        tvalidnegdata = tensor.sigmoid(tvalidnegdata)
+
+        validerrorsum = tensor.sum(tensor.square((tvaliddata - tvalidnegdata)))
+        print 'valid errorsum = %f' % (validerrorsum)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Train RBM over MNIST')
+    parser.add_argument('file', type=str, help='the dataset path')
+    parser.add_argument('--use_gpu', action='store_true')
+    args = parser.parse_args()
+
+    assert os.path.exists(args.file), 'Pls download the MNIST dataset from' \
+            'https://github.com/mnielsen/neural-networks-and-deep-learning/raw/master/data/mnist.pkl.gz'
+    train(args.file, args.use_gpu)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5b332a40/include/singa/model/loss.h
----------------------------------------------------------------------
diff --git a/include/singa/model/loss.h b/include/singa/model/loss.h
index 951c477..4ee41cb 100644
--- a/include/singa/model/loss.h
+++ b/include/singa/model/loss.h
@@ -51,7 +51,6 @@ public:
   /// [Evaluate|Forward] Backward.
   float Evaluate(int flag, const Tensor &prediction, const Tensor &target) {
     Tensor loss = Forward(flag, prediction, target);
-    loss.ToHost();
     return Sum<float>(loss) / (1.0f * loss.Size());
   }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5b332a40/src/core/tensor/math_kernel.cu
----------------------------------------------------------------------
diff --git a/src/core/tensor/math_kernel.cu b/src/core/tensor/math_kernel.cu
index e0112f3..d3f3335 100644
--- a/src/core/tensor/math_kernel.cu
+++ b/src/core/tensor/math_kernel.cu
@@ -35,6 +35,8 @@
 namespace singa {
 // Cuda Kernel Functions
 namespace cuda {
+/*
+wangwei: Not used due to error in the code.
 __global__ void KernelSum(const size_t n, const float *in, float *out) {
   int THREADS = blockDim.x;
 
@@ -65,6 +67,7 @@ __global__ void KernelSum(const size_t n, const float *in, float *out) {
   __syncthreads();
   *out = aux[0];
 }
+*/
 
 __global__ void KernelAdd(const size_t n, const float *in1, const float *in2,
                           float *out) {
@@ -461,12 +464,14 @@ void div(const size_t n, const float *in1, const float *in2, float *out,
   KernelDiv <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
 }
 
+/*
 void sum(const size_t n, const float *in, float *out, cudaStream_t s) {
   int threads_per_block = n > CU1DBLOCK ? CU1DBLOCK : n;
   //  here, we only need one block
   int num_blocks = 1;
   KernelSum <<<num_blocks, threads_per_block>>> (n, in, out);
 }
+*/
 
 void ComputeCrossEntropy(size_t batchsize, const size_t dim, const float *p,
                          const int *t, float *loss, cudaStream_t stream) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5b332a40/src/core/tensor/math_kernel.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/math_kernel.h b/src/core/tensor/math_kernel.h
index 202777e..cb0cb6a 100644
--- a/src/core/tensor/math_kernel.h
+++ b/src/core/tensor/math_kernel.h
@@ -101,7 +101,7 @@ void mult(const size_t n, const float *in1, const float *in2, float *out,
 void div(const size_t n, const float *in1, const float *in2, float *out,
          cudaStream_t s);
 
-void sum(const size_t n, const float *in, float *out, cudaStream_t s);
+// void sum(const size_t n, const float *in, float *out, cudaStream_t s);
 
 void ComputeCrossEntropy(const size_t batchsize, const size_t dim,
                          const float *p, const int *t, float *loss,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5b332a40/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index b80e233..670b27e 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -626,10 +626,14 @@ Tensor Average(const Tensor &M, int axis) {
 template <>
 float Sum<float>(const Tensor &in) {
   float s = 0.0f;
+  Tensor one(in.shape(), in.device(), in.data_type());
+  one.SetValue(1.0f);
   TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, {
-    in.device()->Exec([in, &s](Context *ctx) {
-      Sum<DType, Lang>(in.Size(), in.block(), &s, ctx);
-    }, {in.block()}, {});
+    one.device()->Exec([in, one, &s](Context *ctx) {
+      DType ret = DType(0);
+      Dot<DType, Lang>(in.Size(), in.block(), one.block(), &ret, ctx);
+      s = ret;
+    }, {in.block(), one.block()}, {});
   });
   return s;
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5b332a40/src/core/tensor/tensor_math_cuda.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h
index 1cd61b3..4daa97a 100644
--- a/src/core/tensor/tensor_math_cuda.h
+++ b/src/core/tensor/tensor_math_cuda.h
@@ -263,8 +263,9 @@ void Sub<float, lang::Cuda>(const size_t num, const Block* in1,
 template <>
 void Sum<float, lang::Cuda>(const size_t num, const Block* in, float* out,
                             Context* ctx) {
-  const float* inPtr = static_cast<const float*>(in->data());
-  cuda::sum(num, inPtr, out, ctx->stream);
+  LOG(FATAL) << "Cuda Sum is not implemented!";
+  // const float* inPtr = static_cast<const float*>(in->data());
+  // cuda::sum(num, inPtr, out, ctx->stream);
 }
 
 /// Element-wise operation, out[i]=tanh([in[i])

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5b332a40/src/python/singa/optimizer.py
----------------------------------------------------------------------
diff --git a/src/python/singa/optimizer.py b/src/python/singa/optimizer.py
index 7cab746..aa6bdd1 100644
--- a/src/python/singa/optimizer.py
+++ b/src/python/singa/optimizer.py
@@ -187,6 +187,7 @@ class SGD(Optimizer):
         """
         super(SGD, self).__init__(lr, momentum, decay)
         conf = model_pb2.OptimizerConf()
+        conf.momentum = momentum
         self.opt = singa.CreateOptimizer('SGD')
         self.opt.Setup(conf.SerializeToString())