You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by wa...@apache.org on 2016/08/17 18:02:32 UTC

[11/51] [abbrv] incubator-singa git commit: SINGA-238 RBM on mnist

SINGA-238 RBM on mnist

Implement a Python version of an RBM trained on the MNIST data set
1. The model follows http://www.cs.toronto.edu/~hinton/science.pdf
2. The model is implemented using Python tensors
3. Users should first download mnist.pkl.gz


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/e1a524d1
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/e1a524d1
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/e1a524d1

Branch: refs/heads/master
Commit: e1a524d1f428fa4289bdaee48e3b82acac6c0260
Parents: a91bf2a
Author: zhaojing <zh...@comp.nus.edu.sg>
Authored: Tue Aug 9 23:36:49 2016 +0800
Committer: zhaojing <zh...@comp.nus.edu.sg>
Committed: Sun Aug 14 13:29:54 2016 +0800

----------------------------------------------------------------------
 examples/mnist/README.md           |   3 +
 examples/mnist/train.py            | 131 ++++++++++++++++++++++++++++++++
 include/singa/core/tensor.h        |  19 +++++
 src/core/tensor/math_kernel.cu     |  48 +++++++++++-
 src/core/tensor/math_kernel.h      |  12 +++
 src/core/tensor/tensor.cc          |   5 +-
 src/core/tensor/tensor_math.h      |  24 ++++++
 src/core/tensor/tensor_math_cpp.h  |  42 ++++++++++
 src/core/tensor/tensor_math_cuda.h |  35 ++++++++-
 src/python/singa/optimizer.py      |   6 +-
 src/python/singa/tensor.py         |  20 ++++-
 src/python/swig/core_tensor.i      |  10 +++
 12 files changed, 343 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e1a524d1/examples/mnist/README.md
----------------------------------------------------------------------
diff --git a/examples/mnist/README.md b/examples/mnist/README.md
new file mode 100644
index 0000000..bfd480f
--- /dev/null
+++ b/examples/mnist/README.md
@@ -0,0 +1,3 @@
+This example trains an RBM model on the MNIST data set. The RBM follows the paper http://www.cs.toronto.edu/~hinton/science.pdf, and the source code for that paper can be found at http://www.cs.toronto.edu/~hinton/MatlabForSciencePaper.html
+1. Download the dataset mnist.pkl.gz from https://github.com/mnielsen/neural-networks-and-deep-learning/raw/master/data/mnist.pkl.gz
+2. $ python train.py

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e1a524d1/examples/mnist/train.py
----------------------------------------------------------------------
diff --git a/examples/mnist/train.py b/examples/mnist/train.py
new file mode 100644
index 0000000..52b023a
--- /dev/null
+++ b/examples/mnist/train.py
@@ -0,0 +1,131 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+import cPickle
+import numpy as np
+import numpy.matlib
+import os
+import sys
+import gzip, numpy
+
+
+sys.path.append(os.path.join(os.path.dirname(__file__),
+                             '../../build/python'))
+sys.path.append(os.path.join(os.path.dirname(__file__),
+                             '../../build/lib'))
+sys.path.append(os.path.join(os.path.dirname(__file__),'../../build/src'))
+from singa import initializer
+from singa import utils
+from singa import optimizer
+from singa import device
+from singa import tensor
+from singa.proto import core_pb2
+
+
+
+def load_train_data(dir_path):
+    f = gzip.open(dir_path, 'rb')
+    train_set, valid_set, test_set = cPickle.load(f)
+    traindata = train_set[0].astype(np.float32)
+    validdata = valid_set[0].astype(np.float32)
+    return traindata, validdata
+
+
+
+def train(data_dir, num_epoch=10, batch_size=100):
+    print 'Start intialization............'
+    lr = 0.1   # Learning rate
+    weight_decay  = 0.0002
+    hdim = 1000
+    vdim = 784
+    opt = optimizer.SGD(momentum=0.8, weight_decay=weight_decay)
+    
+    shape = (vdim, hdim)
+    tweight = tensor.Tensor(shape)
+    initializer.gaussian(tweight, 0.0, 0.1)
+    tvbias = tensor.from_numpy(np.zeros(vdim, dtype = np.float32))
+    thbias = tensor.from_numpy(np.zeros(hdim, dtype = np.float32))
+    opt = optimizer.SGD(momentum=0.5, weight_decay=weight_decay)
+
+    print 'Loading data ..................'
+    train_x, valid_x = load_train_data(data_dir)
+
+    num_train_batch = train_x.shape[0]/batch_size
+    print "num_train_batch = \n", num_train_batch
+    for epoch in range(num_epoch):
+        trainerrorsum = 0.0
+        validerrorsum = 0.0
+        print 'Epoch %d' % epoch
+        for b in range(num_train_batch):
+            # positive phase
+            if b % 100 == 0:
+                print "batch: \n", b
+
+            tdata = tensor.from_numpy(train_x[ (b * batch_size): ((b + 1) * batch_size), : ])
+            tposhidprob = tensor.mult(tdata, tweight)
+            tposhidprob.add_row(thbias)
+            tposhidprob = tensor.sigmoid(tposhidprob)
+            tposhidrandom = tensor.Tensor(tposhidprob.shape)
+            initializer.uniform(tposhidrandom, 0.0, 1.0)
+            tposhidsample = tensor.gt(tposhidprob, tposhidrandom)
+            
+            # negative phase
+            tnegdata = tensor.mult(tposhidsample, tweight.transpose())
+            tnegdata.add_row(tvbias)
+            tnegdata = tensor.sigmoid(tnegdata)
+
+            tneghidprob = tensor.mult(tnegdata, tweight)
+            tneghidprob.add_row(thbias) 
+            tneghidprob = tensor.sigmoid(tneghidprob)
+            trainerror = tensor.sum(tensor.eltwise_mult((tdata - tnegdata),(tdata - tnegdata)))
+            trainerrorsum = trainerror + trainerrorsum
+           
+            tgweight = tensor.mult(tnegdata.transpose(), tneghidprob) - tensor.mult(tdata.transpose(), tposhidprob)
+            tgvbias = tensor.sum(tnegdata, 0) - tensor.sum(tdata, 0)
+            tghbias = tensor.sum(tneghidprob, 0) - tensor.sum(tposhidprob, 0)
+            
+            opt.apply_with_lr(epoch, lr / batch_size, tgweight, tweight, '')
+            opt.apply_with_lr(epoch, lr / batch_size, tgvbias, tvbias, '')
+            opt.apply_with_lr(epoch, lr / batch_size, tghbias, thbias, '')
+
+        info = 'train errorsum = %f' \
+            % (trainerrorsum)
+        print info
+
+        tvaliddata = tensor.from_numpy(valid_x[ :, : ])
+        tvalidposhidprob = tensor.mult(tvaliddata, tweight)
+        tvalidposhidprob.add_row(thbias)
+        tvalidposhidprob = tensor.sigmoid(tvalidposhidprob)
+        tvalidposhidrandom = tensor.Tensor(tvalidposhidprob.shape)
+        initializer.uniform(tvalidposhidrandom, 0.0, 1.0)
+        tvalidposhidsample = tensor.gt(tvalidposhidprob, tvalidposhidrandom)
+
+        tvalidnegdata = tensor.mult(tvalidposhidsample, tweight.transpose())
+        tvalidnegdata.add_row(tvbias)
+        tvalidnegdata = tensor.sigmoid(tvalidnegdata)
+
+        validerrorsum = tensor.sum(tensor.eltwise_mult((tvaliddata - tvalidnegdata),(tvaliddata - tvalidnegdata)))
+        validinfo = 'valid errorsum = %f' \
+            % (validerrorsum)
+        print validinfo
+
+
+if __name__ == '__main__':
+    data_dir = 'mnist.pkl.gz'
+    assert os.path.exists(data_dir), \
+        'Pls download the mnist dataset'
+    train(data_dir)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e1a524d1/include/singa/core/tensor.h
----------------------------------------------------------------------
diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index 3420a0c..2075b5d 100644
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -283,23 +283,42 @@ Tensor operator<(const Tensor &in, const SType x);
 template <typename SType>
 void LT(const Tensor &in, const SType x, Tensor *out);
 
+/// Element-wise operation, out[i]= (in1[i] < in2[i]) ? 1.f : 0.f
+Tensor operator<(const Tensor &in1, const Tensor& in2);
+void LT(const Tensor &in1, const Tensor& in2, Tensor *out);
+
 /// Element-wise operation, out[i]= (in[i] <= x) ? 1.f : 0.f
 template <typename SType>
 Tensor operator<=(const Tensor &in, const SType x);
 template <typename SType>
 void LE(const Tensor &in, const SType x, Tensor *out);
+
+/// Element-wise operation, out[i]= (in1[i] <= in2[i]) ? 1.f : 0.f
+Tensor operator<=(const Tensor &in1, const Tensor& in2);
+void LE(const Tensor &in1, const Tensor& in2, Tensor *out);
+
 /// Element-wise operation, out[i]= (in[i] > x) ? 1.f : 0.f
 template <typename SType>
 Tensor operator>(const Tensor &in, const SType x);
 template <typename SType>
 void GT(const Tensor &in, const SType x, Tensor *out);
 
+/// Element-wise operation, out[i]= (in1[i] > in2[i]) ? 1.f : 0.f
+Tensor operator>(const Tensor &in1, const Tensor& in2);
+void GT(const Tensor &in1, const Tensor& in2, Tensor *out);
+
+
 /// Element-wise operation, out[i]= (in[i] >= x) ? 1.f : 0.f
 template <typename SType>
 Tensor operator>=(const Tensor &in, const SType x);
 template <typename SType>
 void GE(const Tensor &in, const SType x, Tensor *out);
 
+/// Element-wise operation, out[i]= (in1[i] >= in2[i]) ? 1.f : 0.f
+Tensor operator>=(const Tensor &in1, const Tensor& in2);
+void GE(const Tensor &in1, const Tensor& in2, Tensor *out);
+
+
 Tensor operator+(const Tensor &lhs, const Tensor &rhs);
 void Add(const Tensor &lhs, const Tensor &rhs, Tensor *out);
 Tensor operator-(const Tensor &lhs, const Tensor &rhs);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e1a524d1/src/core/tensor/math_kernel.cu
----------------------------------------------------------------------
diff --git a/src/core/tensor/math_kernel.cu b/src/core/tensor/math_kernel.cu
index 13005af..e0112f3 100644
--- a/src/core/tensor/math_kernel.cu
+++ b/src/core/tensor/math_kernel.cu
@@ -243,6 +243,14 @@ __global__ void KernelGE(const size_t num, const float *in, const float x,
     out[idx] = in[idx] >= x ? 1.0f : 0.0f;
   }
 }
+
+// out[i] = 1.0f where in1[i] >= in2[i], else 0.0f, for every i < num.
+__global__ void KernelBGE(const size_t num, const float *in1, const float *in2,
+                          float *out) {
+  const size_t stride = blockDim.x * gridDim.x;  // total threads launched
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += stride)
+    out[i] = (in1[i] >= in2[i]) ? 1.0f : 0.0f;
+}
 __global__ void KernelGT(const size_t num, const float *in, const float x,
                          float *out) {
   for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
@@ -250,6 +258,13 @@ __global__ void KernelGT(const size_t num, const float *in, const float x,
     out[idx] = in[idx] > x ? 1.0f : 0.0f;
   }
 }
+// out[i] = 1.0f where in1[i] > in2[i], else 0.0f, for every i < num.
+__global__ void KernelBGT(const size_t num, const float *in1, const float *in2,
+                          float *out) {
+  const size_t stride = blockDim.x * gridDim.x;  // total threads launched
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += stride)
+    out[i] = (in1[i] > in2[i]) ? 1.0f : 0.0f;
+}
 __global__ void KernelLE(const size_t num, const float *in, const float x,
                          float *out) {
   for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
@@ -257,7 +272,13 @@ __global__ void KernelLE(const size_t num, const float *in, const float x,
     out[idx] = in[idx] <= x ? 1.0f : 0.0f;
   }
 }
-
+// out[i] = 1.0f where in1[i] <= in2[i], else 0.0f, for every i < num.
+__global__ void KernelBLE(const size_t num, const float *in1, const float *in2,
+                          float *out) {
+  const size_t stride = blockDim.x * gridDim.x;  // total threads launched
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += stride)
+    out[i] = (in1[i] <= in2[i]) ? 1.0f : 0.0f;
+}
 __global__ void KernelLT(const size_t num, const float *in, const float x,
                          float *out) {
   for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
@@ -265,7 +286,13 @@ __global__ void KernelLT(const size_t num, const float *in, const float x,
     out[idx] = in[idx] < x ? 1.0f : 0.0f;
   }
 }
-
+// out[i] = 1.0f where in1[i] < in2[i], else 0.0f, for every i < num.
+__global__ void KernelBLT(const size_t num, const float *in1, const float *in2,
+                          float *out) {
+  const size_t stride = blockDim.x * gridDim.x;  // total threads launched
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += stride)
+    out[i] = (in1[i] < in2[i]) ? 1.0f : 0.0f;
+}
 __global__ void KernelRowMax(const size_t nrow, const size_t ncol, const float *inPtr,
     float *outPtr) {
   for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < nrow;
@@ -381,19 +408,34 @@ void gt(const size_t num, const float *in, const float x, float *out,
         cudaStream_t s) {
   KernelGT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
 }
+void gt(const size_t num, const float *in1, const float *in2, float *out,
+        cudaStream_t s) {  // out[i] = (in1[i] > in2[i]) ? 1.f : 0.f via KernelBGT
+  KernelBGT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in1, in2, out);
+}  // NOTE(review): s is unused; the launch names no stream -- confirm intent
 void ge(const size_t num, const float *in, const float x, float *out,
         cudaStream_t s) {
   KernelGE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
 }
+void ge(const size_t num, const float *in1, const float *in2, float *out,
+        cudaStream_t s) {  // out[i] = (in1[i] >= in2[i]) ? 1.f : 0.f via KernelBGE
+  KernelBGE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in1, in2, out);
+}  // NOTE(review): s is unused; the launch names no stream -- confirm intent
 void lt(const size_t num, const float *in, const float x, float *out,
         cudaStream_t s) {
   KernelLT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
 }
+void lt(const size_t num, const float *in1, const float *in2, float *out,
+        cudaStream_t s) {  // out[i] = (in1[i] < in2[i]) ? 1.f : 0.f via KernelBLT
+  KernelBLT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in1, in2, out);
+}  // NOTE(review): s is unused; the launch names no stream -- confirm intent
 void le(const size_t num, const float *in, const float x, float *out,
         cudaStream_t s) {
   KernelLE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
 }
-
+void le(const size_t num, const float *in1, const float *in2, float *out,
+        cudaStream_t s) {  // out[i] = (in1[i] <= in2[i]) ? 1.f : 0.f via KernelBLE
+  KernelBLE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in1, in2, out);
+}  // NOTE(review): s is unused; the launch names no stream -- confirm intent
 void pow(const size_t n, const float *in1, const float *in2, float *out,
          cudaStream_t s) {
   KernelPow <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e1a524d1/src/core/tensor/math_kernel.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/math_kernel.h b/src/core/tensor/math_kernel.h
index 63b0d82..202777e 100644
--- a/src/core/tensor/math_kernel.h
+++ b/src/core/tensor/math_kernel.h
@@ -66,12 +66,24 @@ void threshold(const size_t n, const float x, const float *in, float *out,
 
 void gt(const size_t num, const float *in, const float x, float *out,
         cudaStream_t s);
+void gt(const size_t num, const float *in1, const float *in2, float *out,
+        cudaStream_t s);
+
 void ge(const size_t num, const float *in, const float x, float *out,
         cudaStream_t s);
+void ge(const size_t num, const float *in1, const float *in2, float *out,
+        cudaStream_t s);
+
+
 void lt(const size_t num, const float *in, const float x, float *out,
         cudaStream_t s);
+void lt(const size_t num, const float *in1, const float *in2, float *out,
+        cudaStream_t s);
+
 void le(const size_t num, const float *in, const float x, float *out,
         cudaStream_t s);
+void le(const size_t num, const float *in1, const float *in2, float *out,
+        cudaStream_t s);
 
 // 2 inputs
 void pow(const size_t n, const float *in1, const float *in2, float *out,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e1a524d1/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index dfb1eb2..b80e233 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -541,7 +541,10 @@ GenBinaryTensorFn(operator-, Sub);
 GenBinaryTensorFn(operator*, EltwiseMult);
 GenBinaryTensorFn(operator/, Div);
 GenBinaryTensorFn(Pow, Pow);
-
+GenBinaryTensorFn(operator<, LT);  // Tensor-Tensor comparisons producing 1.f/0.f masks
+GenBinaryTensorFn(operator<=, LE);
+GenBinaryTensorFn(operator>, GT);
+GenBinaryTensorFn(operator>=, GE);
 #define EltwiseTensorScalarFn(fn, t, x, ret)                            \
   do {                                                                  \
     TYPE_LANG_SWITCH(t.data_type(), DType, t.device()->lang(), Lang, {  \

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e1a524d1/src/core/tensor/tensor_math.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math.h b/src/core/tensor/tensor_math.h
index 1914ca6..bf913c0 100644
--- a/src/core/tensor/tensor_math.h
+++ b/src/core/tensor/tensor_math.h
@@ -127,6 +127,12 @@ void LE(const size_t num, const Block *in, const DType x, Block *out,
         Context *ctx) {
   LOG(FATAL) << "LE Not Implemented";
 }
+/// Tensor-Tensor LE: out[i] = (in1[i] <= in2[i]) ? 1.f : 0.f
+template <typename DType, typename Lang>
+void LE(const size_t num, const Block *in1, const Block *in2, Block *out,
+        Context *ctx) {
+  LOG(FATAL) << "Tensor-Tensor LE Not Implemented";  // fallback for unspecialized DType/Lang
+}
 /// Natual logarithm, the base is e, Neper number out[i]=log(in[i]).
 template <typename DType, typename Lang>
 void Log(const size_t num, const Block *in, Block *out, Context *ctx) {
@@ -138,18 +144,36 @@ void LT(const size_t num, const Block *in, const DType x, Block *out,
         Context *ctx) {
   LOG(FATAL) << "LT Not Implemented";
 }
+/// Tensor-Tensor LT: out[i] = (in1[i] < in2[i]) ? 1.f : 0.f
+template <typename DType, typename Lang>
+void LT(const size_t num, const Block *in1, const Block *in2, Block *out,
+        Context *ctx) {
+  LOG(FATAL) << "Tensor-Tensor LT Not Implemented";  // fallback for unspecialized DType/Lang
+}
 /// out[i]=(in[i]>=x)?1.f:0.f
 template <typename DType, typename Lang>
 void GE(const size_t num, const Block *in, const DType x, Block *out,
         Context *ctx) {
   LOG(FATAL) << "GE Not Implemented";
 }
+/// Tensor-Tensor GE: out[i] = (in1[i] >= in2[i]) ? 1.f : 0.f
+template <typename DType, typename Lang>
+void GE(const size_t num, const Block *in1, const Block *in2, Block *out,
+        Context *ctx) {
+  LOG(FATAL) << "Tensor-Tensor GE Not Implemented";  // fallback for unspecialized DType/Lang
+}
 /// out[i]=(in[i]>x)?1.f:0.f
 template <typename DType, typename Lang>
 void GT(const size_t num, const Block *in, const DType x, Block *out,
         Context *ctx) {
   LOG(FATAL) << "GT Not Implemented";
 }
+/// Tensor-Tensor GT: out[i] = (in1[i] > in2[i]) ? 1.f : 0.f
+template <typename DType, typename Lang>
+void GT(const size_t num, const Block *in1, const Block *in2, Block *out,
+        Context *ctx) {
+  LOG(FATAL) << "Tensor-Tensor GT Not Implemented";  // fallback for unspecialized DType/Lang
+}
 /// out[i] = pow(in[i], x)
 template <typename DType, typename Lang>
 void Pow(const size_t num, const Block *in, const DType x, Block *out,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e1a524d1/src/core/tensor/tensor_math_cpp.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h
index a2802d5..8c8a40a 100644
--- a/src/core/tensor/tensor_math_cpp.h
+++ b/src/core/tensor/tensor_math_cpp.h
@@ -142,6 +142,16 @@ void GE<float, lang::Cpp>(const size_t num, const Block *in, const float x,
 }
 
 template <>
+void GE<float, lang::Cpp>(const size_t num, const Block *in1, const Block *in2,
+                          Block *out, Context *ctx) {
+  // Element-wise out[i] = (in1[i] >= in2[i]) ? 1.f : 0.f.
+  float *dst = static_cast<float *>(out->mutable_data());
+  const float *lhs = static_cast<const float *>(in1->data());
+  const float *rhs = static_cast<const float *>(in2->data());
+  for (size_t i = 0; i < num; ++i)
+    dst[i] = lhs[i] >= rhs[i] ? 1.f : 0.f;
+}
+template <>
 void GT<float, lang::Cpp>(const size_t num, const Block *in, const float x,
                           Block *out, Context *ctx) {
   float *outPtr = static_cast<float *>(out->mutable_data());
@@ -151,6 +161,17 @@ void GT<float, lang::Cpp>(const size_t num, const Block *in, const float x,
   }
 }
 template <>
+void GT<float, lang::Cpp>(const size_t num, const Block *in1, const Block *in2,
+                          Block *out, Context *ctx) {
+  // Element-wise out[i] = (in1[i] > in2[i]) ? 1.f : 0.f.
+  float *dst = static_cast<float *>(out->mutable_data());
+  const float *lhs = static_cast<const float *>(in1->data());
+  const float *rhs = static_cast<const float *>(in2->data());
+  for (size_t i = 0; i < num; ++i)
+    dst[i] = lhs[i] > rhs[i] ? 1.f : 0.f;
+}
+
+template <>
 void LE<float, lang::Cpp>(const size_t num, const Block *in, const float x,
                           Block *out, Context *ctx) {
   float *outPtr = static_cast<float *>(out->mutable_data());
@@ -160,6 +181,16 @@ void LE<float, lang::Cpp>(const size_t num, const Block *in, const float x,
   }
 }
 template <>
+void LE<float, lang::Cpp>(const size_t num, const Block *in1, const Block *in2,
+                          Block *out, Context *ctx) {
+  // Element-wise out[i] = (in1[i] <= in2[i]) ? 1.f : 0.f.
+  float *dst = static_cast<float *>(out->mutable_data());
+  const float *lhs = static_cast<const float *>(in1->data());
+  const float *rhs = static_cast<const float *>(in2->data());
+  for (size_t i = 0; i < num; ++i)
+    dst[i] = lhs[i] <= rhs[i] ? 1.f : 0.f;
+}
+template <>
 void Log<float, lang::Cpp>(const size_t num, const Block *in, Block *out,
                            Context *ctx) {
   float *outPtr = static_cast<float *>(out->mutable_data());
@@ -179,6 +210,17 @@ void LT<float, lang::Cpp>(const size_t num, const Block *in, const float x,
   }
 }
 template <>
+void LT<float, lang::Cpp>(const size_t num, const Block *in1, const Block *in2,
+                          Block *out, Context *ctx) {
+  // Element-wise out[i] = (in1[i] < in2[i]) ? 1.f : 0.f.
+  float *dst = static_cast<float *>(out->mutable_data());
+  const float *lhs = static_cast<const float *>(in1->data());
+  const float *rhs = static_cast<const float *>(in2->data());
+  for (size_t i = 0; i < num; ++i)
+    dst[i] = lhs[i] < rhs[i] ? 1.f : 0.f;
+}
+
+template <>
 void Pow<float, lang::Cpp>(const size_t num, const Block *in, const float x,
                            Block *out, Context *ctx) {
   float *outPtr = static_cast<float *>(out->mutable_data());

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e1a524d1/src/core/tensor/tensor_math_cuda.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h
index 8b6e939..1cd61b3 100644
--- a/src/core/tensor/tensor_math_cuda.h
+++ b/src/core/tensor/tensor_math_cuda.h
@@ -117,6 +117,15 @@ void GE<float, lang::Cuda>(const size_t num, const Block* in, const float x,
   const float* inPtr = static_cast<const float*>(in->data());
   cuda::ge(num, inPtr, x, outPtr, ctx->stream);
 }
+template <>
+void GE<float, lang::Cuda>(const size_t num, const Block* in1, const Block* in2,
+                           Block* out, Context* ctx) {
+  float* dst = static_cast<float*>(out->mutable_data());
+  const float* lhs = static_cast<const float*>(in1->data());
+  const float* rhs = static_cast<const float*>(in2->data());
+  cuda::ge(num, lhs, rhs, dst, ctx->stream);  // dst[i] = lhs[i] >= rhs[i] ? 1 : 0
+}
+
 
 template <>
 void GT<float, lang::Cuda>(const size_t num, const Block* in, const float x,
@@ -125,7 +134,14 @@ void GT<float, lang::Cuda>(const size_t num, const Block* in, const float x,
   const float* inPtr = static_cast<const float*>(in->data());
   cuda::gt(num, inPtr, x, outPtr, ctx->stream);
 }
-
+template <>
+void GT<float, lang::Cuda>(const size_t num, const Block* in1, const Block* in2,
+                           Block* out, Context* ctx) {
+  float* dst = static_cast<float*>(out->mutable_data());
+  const float* lhs = static_cast<const float*>(in1->data());
+  const float* rhs = static_cast<const float*>(in2->data());
+  cuda::gt(num, lhs, rhs, dst, ctx->stream);  // dst[i] = lhs[i] > rhs[i] ? 1 : 0
+}
 template <>
 void LE<float, lang::Cuda>(const size_t num, const Block* in, const float x,
                            Block* out, Context* ctx) {
@@ -133,6 +149,14 @@ void LE<float, lang::Cuda>(const size_t num, const Block* in, const float x,
   const float* inPtr = static_cast<const float*>(in->data());
   cuda::le(num, inPtr, x, outPtr, ctx->stream);
 }
+template <>
+void LE<float, lang::Cuda>(const size_t num, const Block* in1, const Block* in2,
+                           Block* out, Context* ctx) {
+  float* dst = static_cast<float*>(out->mutable_data());
+  const float* lhs = static_cast<const float*>(in1->data());
+  const float* rhs = static_cast<const float*>(in2->data());
+  cuda::le(num, lhs, rhs, dst, ctx->stream);  // dst[i] = lhs[i] <= rhs[i] ? 1 : 0
+}
 
 /// Natual logarithm, the base is e, Neper number out[i]=ln(in[i]).
 template <>
@@ -149,7 +173,14 @@ void LT<float, lang::Cuda>(const size_t num, const Block* in, const float x,
   const float* inPtr = static_cast<const float*>(in->data());
   cuda::lt(num, inPtr, x, outPtr, ctx->stream);
 }
-
+template <>
+void LT<float, lang::Cuda>(const size_t num, const Block* in1, const Block* in2,
+                           Block* out, Context* ctx) {
+  float* dst = static_cast<float*>(out->mutable_data());
+  const float* lhs = static_cast<const float*>(in1->data());
+  const float* rhs = static_cast<const float*>(in2->data());
+  cuda::lt(num, lhs, rhs, dst, ctx->stream);  // dst[i] = lhs[i] < rhs[i] ? 1 : 0
+}
 /// Element-wise operation, out[i] = in[i]^x
 template <>
 void Pow<float, lang::Cuda>(const size_t num, const Block* in, const float x,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e1a524d1/src/python/singa/optimizer.py
----------------------------------------------------------------------
diff --git a/src/python/singa/optimizer.py b/src/python/singa/optimizer.py
index 503527f..7cab746 100644
--- a/src/python/singa/optimizer.py
+++ b/src/python/singa/optimizer.py
@@ -102,10 +102,12 @@ class Optimizer(object):
             name (str): parameter name
             specs (ParamSpec): protobuf obj
         """
+	assert type(specs) == model_pb2.ParamSpec, \
+		'specs should be model_pb2.ParamSpec instance'
         if specs.HasField('regularizer'):
-            self.regularizers[name] = CppRegularizer(specs.constraint)
+            self.regularizers[name] = CppRegularizer(specs.regularizer)
         if specs.HasField('constraint'):
-            self.constraints[name] = CppConstraint(specs.regularizer)
+            self.constraints[name] = CppConstraint(specs.constraint)
         if specs.lr_mult != 1:
             self.learning_rate_multiplier[name] = specs.lr_mult
         if specs.decay_mult != 1:

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e1a524d1/src/python/singa/tensor.py
----------------------------------------------------------------------
diff --git a/src/python/singa/tensor.py b/src/python/singa/tensor.py
index 6e84a4f..ed651e9 100644
--- a/src/python/singa/tensor.py
+++ b/src/python/singa/tensor.py
@@ -238,16 +238,28 @@ class Tensor(object):
                                     self.singa_tensor, rhs)
 
     def __lt__(self, rhs):
-        return _call_singa_func(singa.LT_Tf, self.singa_tensor, rhs)
+        if isinstance(rhs, Tensor):
+            return _call_singa_func(singa.LT_TT, self.singa_tensor, rhs.singa_tensor)
+        else:
+            return _call_singa_func(singa.LT_Tf, self.singa_tensor, rhs)
 
     def __le__(self, rhs):
-        return _call_singa_func(singa.LE_Tf, self.singa_tensor, rhs)
+        if isinstance(rhs, Tensor):
+            return _call_singa_func(singa.LE_TT, self.singa_tensor, rhs.singa_tensor)
+        else:
+            return _call_singa_func(singa.LE_Tf, self.singa_tensor, rhs)
 
     def __gt__(self, rhs):
-        return _call_singa_func(singa.GT_Tf, self.singa_tensor, rhs)
+        if isinstance(rhs, Tensor):
+            return _call_singa_func(singa.GT_TT, self.singa_tensor, rhs.singa_tensor)
+        else:
+            return _call_singa_func(singa.GT_Tf, self.singa_tensor, rhs)
 
     def __ge__(self, rhs):
-        return _call_singa_func(singa.GE_Tf, self.singa_tensor, rhs)
+        if isinstance(rhs, Tensor):
+            return _call_singa_func(singa.GE_TT, self.singa_tensor, rhs.singa_tensor)
+        else:
+            return _call_singa_func(singa.GE_Tf, self.singa_tensor, rhs)
 
 
 ''' python functions for global functions in Tensor.h

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e1a524d1/src/python/swig/core_tensor.i
----------------------------------------------------------------------
diff --git a/src/python/swig/core_tensor.i b/src/python/swig/core_tensor.i
index c4ee610..60f8b45 100644
--- a/src/python/swig/core_tensor.i
+++ b/src/python/swig/core_tensor.i
@@ -207,6 +207,16 @@ namespace singa{
   %rename(LE_Tf) operator<=(const Tensor &t, const float x);
   %rename(GT_Tf) operator>(const Tensor &t, const float x);
   %rename(GE_Tf) operator>=(const Tensor &t, const float x);
+  %rename(LT_TT) operator<(const Tensor &lhs, const Tensor &rhs);
+  %rename(LE_TT) operator<=(const Tensor &lhs, const Tensor &rhs);
+  %rename(GT_TT) operator>(const Tensor &lhs, const Tensor &rhs);
+  %rename(GE_TT) operator>=(const Tensor &lhs, const Tensor &rhs);
+
+  Tensor operator<(const Tensor &lhs, const Tensor &rhs);
+  Tensor operator<=(const Tensor &lhs, const Tensor &rhs);
+  Tensor operator>(const Tensor &lhs, const Tensor &rhs);
+  Tensor operator>=(const Tensor &lhs, const Tensor &rhs);
+
 
   template <typename DType>
   Tensor operator<(const Tensor &t, const DType x);