Posted to commits@singa.apache.org by zh...@apache.org on 2017/01/23 04:51:09 UTC

[1/2] incubator-singa git commit: SINGA-296 - Add sign and to_host function for pysinga tensor module

Repository: incubator-singa
Updated Branches:
  refs/heads/master f647d685f -> 2d5f696bd


SINGA-296 - Add sign and to_host function for pysinga tensor module

Add a sign function to the PySINGA tensor module, and a tensor.to_host() module function that copies a tensor to a host tensor.
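
A rough usage sketch of the new functions, based on the tensor.py changes below (illustrative only):

    from singa import tensor

    x = tensor.Tensor((2, 3))
    x.uniform(-1, 1)
    s = tensor.sign(x)      # element-wise sign of x
    h = tensor.to_host(x)   # copy of x on the default host device
    a = tensor.to_numpy(x)  # now copies to host internally before converting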


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/3d407061
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/3d407061
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/3d407061

Branch: refs/heads/master
Commit: 3d4070618022278902bea96d93eb914503b50ead
Parents: f647d68
Author: wangwei <wa...@comp.nus.edu.sg>
Authored: Sun Jan 22 12:02:35 2017 +0800
Committer: wangwei <wa...@comp.nus.edu.sg>
Committed: Sun Jan 22 12:07:38 2017 +0800

----------------------------------------------------------------------
 python/singa/loss.py                  |   8 +-
 python/singa/metric.py                |   1 +
 python/singa/net.py                   |  36 ++++++++-
 python/singa/snapshot.py              |  12 ++-
 python/singa/tensor.py                |  50 +++++++-----
 src/core/tensor/tensor_math_cpp.h     |   2 +-
 src/core/tensor/tensor_math_opencl.cl |  86 ++++++++++-----------
 src/core/tensor/tensor_math_opencl.h  | 120 ++++++++++++++---------------
 test/singa/test_tensor_math.cc        |   2 +-
 9 files changed, 187 insertions(+), 130 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3d407061/python/singa/loss.py
----------------------------------------------------------------------
diff --git a/python/singa/loss.py b/python/singa/loss.py
index f3330dc..60835fc 100644
--- a/python/singa/loss.py
+++ b/python/singa/loss.py
@@ -25,14 +25,13 @@ Example usage::
 
     from singa import tensor
     from singa import loss
-    from singa.proto import model_pb2
 
     x = tensor.Tensor((3, 5))
     x.uniform(0, 1)  # randomly generate the prediction activation
     y = tensor.from_numpy(np.array([0, 1, 3], dtype=np.int))  # set the truth
 
     f = loss.SoftmaxCrossEntropy()
-    l = f.forward(model_pb2.kTrain, x, y)  # l is tensor with 3 loss values
+    l = f.forward(True, x, y)  # l is a tensor with 3 loss values
     g = f.backward()  # g is a tensor containing all gradients of x w.r.t l
 '''
 
@@ -42,7 +41,6 @@ from proto import model_pb2
 import tensor
 
 
-
 class Loss(object):
     '''Base loss class.
 
@@ -58,7 +56,7 @@ class Loss(object):
         '''Compute the loss values.
 
         Args:
-            flag (int): kTrain or kEval. If it is kTrain, then the backward
+            flag: kTrain/kEval or bool. If it is kTrain/True, then the backward
                 function must be called before calling forward again.
             x (Tensor): the prediction Tensor
             y (Tensor): the ground truth Tensor, x.shape[0] must equal y.shape[0]
@@ -125,7 +123,7 @@ class SquaredError(Loss):
     It is implemented using Python Tensor operations.
     '''
     def __init__(self):
-        super(SquareLoss, self).__init__()
+        super(SquaredError, self).__init__()
         self.err = None
 
     def forward(self, flag, x, y):
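
A note on the API change above: forward() now accepts either the model_pb2 enum or a plain bool as the flag argument. A minimal sketch, assuming f, x and y as in the docstring example, with model_pb2 imported from singa.proto:

    l1 = f.forward(model_pb2.kTrain, x, y)  # enum flag, as before
    l2 = f.forward(True, x, y)              # bool flag, now also accepted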

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3d407061/python/singa/metric.py
----------------------------------------------------------------------
diff --git a/python/singa/metric.py b/python/singa/metric.py
index 3a5750d..da8213b 100644
--- a/python/singa/metric.py
+++ b/python/singa/metric.py
@@ -35,6 +35,7 @@ Example usage::
 
 '''
 
+
 from . import singa_wrap as singa
 import tensor
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3d407061/python/singa/net.py
----------------------------------------------------------------------
diff --git a/python/singa/net.py b/python/singa/net.py
index 36c70f8..9d09740 100644
--- a/python/singa/net.py
+++ b/python/singa/net.py
@@ -17,8 +17,42 @@
 """
 Neural net class for constructing the nets using layers and providing access
 functions for net info, e.g., parameters.
-"""
 
+Example usages::
+
+    from singa import net as ffnet
+    from singa import metric
+    from singa import loss
+    from singa import layer
+    from singa import device
+    from singa import tensor
+
+    # create net and add layers
+    net = ffnet.FeedForwardNet(loss.SoftmaxCrossEntropy(), metric.Accuracy())
+    net.add(layer.Conv2D('conv1', 32, 5, 1, input_sample_shape=(3,32,32,)))
+    net.add(layer.Activation('relu1'))
+    net.add(layer.MaxPooling2D('pool1', 3, 2))
+    net.add(layer.Flatten('flat'))
+    net.add(layer.Dense('dense', 10))
+
+    # init parameters
+    for p in net.param_values():
+        if len(p.shape) == 0:
+            p.set_value(0)
+        else:
+            p.gaussian(0, 0.01)
+
+    # move net onto gpu
+    dev = device.create_cuda_gpu()
+    net.to_device(dev)
+
+    # training (skipped)
+
+    # do prediction after training
+    x = tensor.Tensor((2, 3, 32, 32), dev)
+    x.uniform(-1, 1)
+    y = net.predict(x)
+    print tensor.to_numpy(y)
+"""
 
 from .proto.model_pb2 import kTrain, kEval
 import tensor

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3d407061/python/singa/snapshot.py
----------------------------------------------------------------------
diff --git a/python/singa/snapshot.py b/python/singa/snapshot.py
index c259850..bd8918e 100644
--- a/python/singa/snapshot.py
+++ b/python/singa/snapshot.py
@@ -18,6 +18,16 @@
 '''
 This script includes io::snapshot class and its methods.
 
+Example usages::
+
+    from singa import snapshot
+
+    sn1 = snapshot.Snapshot('param', False)
+    params = sn1.read()  # read all params as a dictionary
+
+    sn2 = snapshot.Snapshot('param_new', False)
+    for k, v in params.iteritems():
+        sn2.write(k, v)
 '''
 
 from . import singa_wrap as singa
@@ -36,7 +46,7 @@ class Snapshot(object):
             buffer_size (int): Buffer size (in MB), default is 10
         '''
         self.snapshot = singa.Snapshot(f, mode, buffer_size)
-    
+
     def write(self, param_name, param_val):
         '''Call Write method to write a parameter
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3d407061/python/singa/tensor.py
----------------------------------------------------------------------
diff --git a/python/singa/tensor.py b/python/singa/tensor.py
index 57ce563..d1851d1 100644
--- a/python/singa/tensor.py
+++ b/python/singa/tensor.py
@@ -21,11 +21,11 @@ Example usage::
     from singa import tensor
     from singa import device
 
-# create a tensor with shape (2,3), default CppCPU device and float32
+    # create a tensor with shape (2,3), default CppCPU device and float32
     x = tensor.Tensor((2, 3))
     x.set_value(0.4)
 
-# create a tensor from a numpy array
+    # create a tensor from a numpy array
     npy = np.zeros((3, 3), dtype=np.float32)
     y = tensor.from_numpy(npy)
 
@@ -40,13 +40,13 @@ Example usage::
 
     r = tensor.relu(x)
 
-    r.to_host()  # move the data back to host cpu
-    s = tensor.to_numpy(r)  # tensor -> numpy array, r must be on cpu
+    s = tensor.to_numpy(r)  # tensor -> numpy array
 
 There are two sets of tensor functions,
 
 Tensor member functions
     which would change the internal state of the Tensor instance.
+
 Tensor module functions
     which accept Tensor instances as arguments and return Tensor instances.
 
@@ -558,28 +558,31 @@ def from_numpy(np_array):
     return ret
 
 
-def to_numpy(t):
-    '''Convert the tensor into a numpy array.
+def to_host(t):
+    '''Copy the data to a host tensor.
+    '''
+    ret = t.clone()
+    ret.to_host()
+    return ret
 
-    Since numpy array is allocated on CPU devices, the input Tensor instance
-    must be on the default CppCPU device.
+
+def to_numpy(t):
+    '''Copy the tensor into a numpy array.
 
     Args:
-        t (Tensor), a Tensor on the default CppCPU device.
+        t (Tensor), a Tensor
 
     Returns:
         a numpy array
     '''
-    assert (t.device.id() == -1) or (t.device is None), \
-        'Please move the tensor onto the default host device'
-
-    if t.dtype == core_pb2.kFloat32:
-        np_array = t.singa_tensor.GetFloatValue(int(t.size()))
-    elif t.dtype == core_pb2.kInt:
-        np_array = t.singa_tensor.GetIntValue(int(t.size()))
+    th = to_host(t)
+    if th.dtype == core_pb2.kFloat32:
+        np_array = th.singa_tensor.GetFloatValue(int(th.size()))
+    elif th.dtype == core_pb2.kInt:
+        np_array = th.singa_tensor.GetIntValue(int(th.size()))
     else:
-        print 'Not implemented yet for ', t.dtype
-    return np_array.reshape(t.shape)
+        print 'Not implemented yet for ', th.dtype
+    return np_array.reshape(th.shape)
 
 
 def abs(t):
@@ -638,6 +641,17 @@ def sigmoid(t):
     return _call_singa_func(singa.Sigmoid, t.singa_tensor)
 
 
+def sign(t):
+    '''
+    Args:
+        t (Tensor): input Tensor
+
+    Returns:
+        a new Tensor whose element y = sign(x)
+    '''
+    return _call_singa_func(singa.Sign, t.singa_tensor)
+
+
 def sqrt(t):
     '''
     Args:

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3d407061/src/core/tensor/tensor_math_cpp.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h
index 5167fba..4f510ed 100644
--- a/src/core/tensor/tensor_math_cpp.h
+++ b/src/core/tensor/tensor_math_cpp.h
@@ -278,7 +278,7 @@ void Sign<float, lang::Cpp>(const size_t num, const Block *in, Block *out,
   float *outPtr = static_cast<float *>(out->mutable_data());
   const float *inPtr = static_cast<const float *>(in->data());
   for (size_t i = 0; i < num; i++) {
-    outPtr[i] = inPtr[i] > 0 ? 1.0f : 0.0f;
+    outPtr[i] = (inPtr[i] > 0) - (inPtr[i] < 0);
   }
 }
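
A note on the kernel change above: the old expression mapped negative inputs to 0, whereas the new one computes the standard three-valued sign. A minimal sketch of the per-element logic, written in Python for brevity (illustrative only):

    def sign(x):
        # (x > 0) - (x < 0) is -1 for x < 0, 0 for x == 0, and 1 for x > 0
        return (x > 0) - (x < 0)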
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3d407061/src/core/tensor/tensor_math_opencl.cl
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_opencl.cl b/src/core/tensor/tensor_math_opencl.cl
index 7b89970..d5bc62f 100644
--- a/src/core/tensor/tensor_math_opencl.cl
+++ b/src/core/tensor/tensor_math_opencl.cl
@@ -23,7 +23,7 @@
 // Sum is basically reduction.
 // This reduction code is serial reduction modified from AMD's example.
 // http://developer.amd.com/resources/documentation-articles/articles-whitepapers/opencl-optimization-case-study-simple-reductions/
-__kernel 
+__kernel
 void clkernel_fabs(const int num, __global const float* in, __global float* out) {
   const int i = get_global_id(0);
   if (i >= num) return;
@@ -38,7 +38,7 @@ void clkernel_add_scalar(const int num, float x, __global const float* in, __glo
 }
 
 __kernel
-void clkernel_add(const int num, __global const float* in1, __global const float* in2, 
+void clkernel_add(const int num, __global const float* in1, __global const float* in2,
 		  __global float* out) {
   const int i = get_global_id(0);
   if (i >= num) return;
@@ -46,7 +46,7 @@ void clkernel_add(const int num, __global const float* in1, __global const float
 }
 
 __kernel
-void clkernel_clamp(const int num, float low, float high, __global const float* in, 
+void clkernel_clamp(const int num, float low, float high, __global const float* in,
 		  __global float* out) {
   const int i = get_global_id(0);
   if (i >= num) return;
@@ -62,7 +62,7 @@ void clkernel_divide_scalar_matx(const int num, __global const float* in1, const
 }
 
 __kernel
-void clkernel_divide_scalar_xmat(const int num, const float x, __global const float* in1, 
+void clkernel_divide_scalar_xmat(const int num, const float x, __global const float* in1,
 		  __global float* out) {
   const int i = get_global_id(0);
   if (i >= num) return;
@@ -70,7 +70,7 @@ void clkernel_divide_scalar_xmat(const int num, const float x, __global const fl
 }
 
 __kernel
-void clkernel_divide(const int num, __global const float* in1, __global const float* in2, 
+void clkernel_divide(const int num, __global const float* in1, __global const float* in2,
 		  __global float* out) {
   const int i = get_global_id(0);
   if (i >= num) return;
@@ -78,7 +78,7 @@ void clkernel_divide(const int num, __global const float* in1, __global const fl
 }
 
 __kernel
-void clkernel_eltmult_scalar(const int num, const float x, __global const float* in, 
+void clkernel_eltmult_scalar(const int num, const float x, __global const float* in,
 		  __global float* out) {
   const int i = get_global_id(0);
   if (i >= num) return;
@@ -86,7 +86,7 @@ void clkernel_eltmult_scalar(const int num, const float x, __global const float*
 }
 
 __kernel
-void clkernel_eltmult(const int num, __global const float* in1, __global const float* in2, 
+void clkernel_eltmult(const int num, __global const float* in1, __global const float* in2,
 		  __global float* out) {
   const int i = get_global_id(0);
   if (i >= num) return;
@@ -101,7 +101,7 @@ void clkernel_exp(const int num, __global const float* in, __global float* out)
 }
 
 __kernel
-void clkernel_le(const int num, __global const float* in, const float x, 
+void clkernel_le(const int num, __global const float* in, const float x,
 		  __global float* out) {
   const int i = get_global_id(0);
   if (i >= num) return;
@@ -116,7 +116,7 @@ void clkernel_log(const int num, __global const float* in, __global float* out)
 }
 
 __kernel
-void clkernel_lt(const int num, __global const float* in, const float x, 
+void clkernel_lt(const int num, __global const float* in, const float x,
 		  __global float* out) {
   const int i = get_global_id(0);
   if (i >= num) return;
@@ -124,7 +124,7 @@ void clkernel_lt(const int num, __global const float* in, const float x,
 }
 
 __kernel
-void clkernel_ge(const int num, __global const float* in, const float x, 
+void clkernel_ge(const int num, __global const float* in, const float x,
 		  __global float* out) {
   const int i = get_global_id(0);
   if (i >= num) return;
@@ -132,7 +132,7 @@ void clkernel_ge(const int num, __global const float* in, const float x,
 }
 
 __kernel
-void clkernel_gt(const int num, __global const float* in, const float x, 
+void clkernel_gt(const int num, __global const float* in, const float x,
 		  __global float* out) {
   const int i = get_global_id(0);
   if (i >= num) return;
@@ -140,7 +140,7 @@ void clkernel_gt(const int num, __global const float* in, const float x,
 }
 
 __kernel
-void clkernel_pow_scalar(const int num, const float x, __global const float* in, 
+void clkernel_pow_scalar(const int num, const float x, __global const float* in,
 		  __global float* out) {
   const int i = get_global_id(0);
   if (i >= num) return;
@@ -148,7 +148,7 @@ void clkernel_pow_scalar(const int num, const float x, __global const float* in,
 }
 
 __kernel
-void clkernel_pow(const int num, __global const float* in1, __global const float* in2, 
+void clkernel_pow(const int num, __global const float* in1, __global const float* in2,
 		  __global float* out) {
   const int i = get_global_id(0);
   if (i >= num) return;
@@ -180,7 +180,7 @@ __kernel
 void clkernel_sign(const int num, __global const float* in, __global float* out) {
   const int i = get_global_id(0);
   if (i >= num) return;
-  out[i] = sign(in[i]);
+  out[i] = (in[i] > 0) - (in[i] < 0);
 }
 
 __kernel
@@ -193,7 +193,7 @@ void clkernel_sqrt(const int num, __global const float* in, __global float* out)
 // kernel for square is called pow(2).
 
 __kernel
-void clkernel_subtract_scalar(const int num, __global const float* in, const float x, 
+void clkernel_subtract_scalar(const int num, __global const float* in, const float x,
 							  __global float* out) {
   const int i = get_global_id(0);
   if (i >= num) return;
@@ -201,7 +201,7 @@ void clkernel_subtract_scalar(const int num, __global const float* in, const flo
 }
 
 __kernel
-void clkernel_subtract(const int num, __global const float* in1, __global const float* in2, 
+void clkernel_subtract(const int num, __global const float* in1, __global const float* in2,
 					   __global float* out) {
   const int i = get_global_id(0);
   if (i >= num) return;
@@ -210,8 +210,8 @@ void clkernel_subtract(const int num, __global const float* in1, __global const
 
 // reduce3 kernel from
 // https://github.com/sschaetz/nvidia-opencl-examples/blob/master/OpenCL/src/oclReduction/oclReduction_kernel.cl
-__kernel 
-void clkernel_sum(const int num, __global const float* in, __global float* out, 
+__kernel
+void clkernel_sum(const int num, __global const float* in, __global float* out,
 				  __local float* sdata) {
   const int i = get_group_id(0)*(get_local_size(0)*2) + get_local_id(0);
   const int tid = get_local_id(0);
@@ -253,7 +253,7 @@ void clkernel_tanh(const int num, __global const float* in, __global float* out)
 // *********************************************************
 
 __kernel
-void clkernel_amax(const int num, __global const float* in, __global int* ret, 
+void clkernel_amax(const int num, __global const float* in, __global int* ret,
 				   __local uint* sdata, __local size_t* temp) {
   const int gid = get_global_id(0);
   const int tid = get_local_id(0);
@@ -272,7 +272,7 @@ void clkernel_amax(const int num, __global const float* in, __global int* ret,
 
 /* TODO: Fix line 284:20.
 __kernel
-void clkernel_amin(const int num, __global const float* in, __global int* ret, 
+void clkernel_amin(const int num, __global const float* in, __global int* ret,
 				   __local float* sdata, __local size_t* temp) {
   const int gid = get_global_id(0);
   const int tid = get_local_id(0);
@@ -294,7 +294,7 @@ void clkernel_amin(const int num, __global const float* in, __global int* ret,
 
 
 __kernel
-void clkernel_asum(const int num, __global const float* in, __global float* out, 
+void clkernel_asum(const int num, __global const float* in, __global float* out,
 				   __local float* sdata) {
   const int tid = get_local_id(0);
   const int i = get_global_id(0);
@@ -319,7 +319,7 @@ void clkernel_asum(const int num, __global const float* in, __global float* out,
 }
 
 __kernel
-void clkernel_axpy(const int num, float alpha, __global const float* in, 
+void clkernel_axpy(const int num, float alpha, __global const float* in,
 				   __global float* out) {
   const int i = get_global_id(0);
   if (i >= num) return;
@@ -362,13 +362,13 @@ void clkernel_scale(const int num, float x, __global float* out) {
 }
 
 __kernel
-void clkernel_dot(const int num, __global const float* in1, __global const float* in2, 
+void clkernel_dot(const int num, __global const float* in1, __global const float* in2,
 	  			  __global float* out, __local float* scratch) {
   const int i = get_global_id(0);
   if (i >= num) return;
   int offset = i << 2;
   scratch[i] = in1[offset] * in2[offset];
-  
+
 }
 
 // First kernel from http://www.bealto.com/gpu-gemv_intro.html
@@ -376,7 +376,7 @@ void clkernel_dot(const int num, __global const float* in1, __global const float
 // fma(a, b, c) == (a * b) + c with infinite precision
 __kernel
 void clkernel_gemv(const int m, const int n, const float alpha,
-				   __global const float* A, __global const float* v, 
+				   __global const float* A, __global const float* v,
 				   const float beta, __global float* out) {
   const int i = get_global_id(0);
   float sum  = 0.0f;
@@ -387,13 +387,13 @@ void clkernel_gemv(const int m, const int n, const float alpha,
 }
 
 // http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-dgmm
-// X[j] = x[j*inc(x)] 						if inc(x) ≥ 0 
+// X[j] = x[j*inc(x)] 						if inc(x) ≥ 0
 //		= x[(χ − 1)*|inc(x)| − j*|inc(x)|] 	if inc(x) < 0
 
 // C = diag( X )*A
 __kernel
 void clkernel_dgmm_left(const int nrow, const int ncol,
-						__global const float* M, __global const float* v, 
+						__global const float* M, __global const float* v,
 						__global float* out) {
   const uint gidx = get_global_id(0);
 
@@ -406,7 +406,7 @@ void clkernel_dgmm_left(const int nrow, const int ncol,
 // C = A*diag( X )
 __kernel
 void clkernel_dgmm_right(const int nrow, const int ncol,
-						 __global const float* M, __global const float* v, 
+						 __global const float* M, __global const float* v,
 						 __global float* out) {
   const uint gidx = get_global_id(0);
 
@@ -420,7 +420,7 @@ void clkernel_dgmm_right(const int nrow, const int ncol,
 //  C = α*A*B + β*C
 __kernel
 void clkernel_gemm(const uint nrowA, const uint ncolB, const uint ncolA, const float alpha,
-		 		   __global const float* A, __global const float* B, const float beta, 
+		 		   __global const float* A, __global const float* B, const float beta,
 		  		   __global float* C, __local float* Asub, __local float* Bsub) {
 
   const uint lidx = get_local_id(0);
@@ -428,10 +428,10 @@ void clkernel_gemm(const uint nrowA, const uint ncolB, const uint ncolA, const f
   const uint TS = get_local_size(0); // Tile size
   const uint gidx = TS * get_group_id(0) + lidx; // Row ID of C (0..M)
   const uint gidy = TS * get_group_id(1) + lidy; // Row ID of C (0..N)
-  
+
   // Initialise the accumulation register
   float acc = 0.0f;
-  
+
   // Loop over all tiles
   const int numtiles = ncolA / TS;
   for (int t = 0; t < numtiles; t++) {
@@ -439,23 +439,23 @@ void clkernel_gemm(const uint nrowA, const uint ncolB, const uint ncolA, const f
     const int tiledCol = TS * t + lidy;
     Asub[lidy * TS + lidx] = A[tiledCol * nrowA + gidx];
     Bsub[lidy * TS + lidx] = B[gidy * ncolA + tiledRow];
-    
+
     barrier(CLK_LOCAL_MEM_FENCE);
-    
+
     for(int k = 0; k < TS; k++) {
       acc += Asub[k * TS + lidx] * Bsub[lidy * TS + k] * alpha;
     }
-    
+
     barrier(CLK_LOCAL_MEM_FENCE);
   }
-  
+
   C[gidy * nrowA + gidx] = fma(beta, C[gidy * nrowA + gidx], acc);
 }
 
 
 __kernel
-void clkernel_crossentropy(const uint batchsize, const uint dim, 
-						   __global const float* p, __global const int* t, 
+void clkernel_crossentropy(const uint batchsize, const uint dim,
+						   __global const float* p, __global const int* t,
 						   __global float* loss) {
   const uint gidx = get_global_id(0);
   if (gidx >= batchsize) return;
@@ -485,12 +485,12 @@ void clkernel_rowmax(const uint nrow, const uint ncol,
                      __global const float* in, __global float* out) {
   const uint row_id = get_global_id(0);
   if (row_id >= nrow) return;
-  
+
   float row_max_val = -FLT_MAX;
   for (uint i = 0; i < ncol; i++) {
     row_max_val = fmax(row_max_val, in[row_id * ncol + i]);
   }
-  
+
   out[row_id] = row_max_val;
 }
 
@@ -521,7 +521,7 @@ __kernel
 void clkernel_outerproduct(int m, const int n, __global const float* in1, __global const float* in2, __global float* out) {
   const int col = get_global_id(0);
   const int row = get_global_id(1);
-  
+
   // TODO: This
 }
 
@@ -541,7 +541,7 @@ __kernel
 void clkernel_sumrow(int nrow, int ncol, __global const float* in, __global float* out) {
   const int idx = get_global_id(0);
   if (idx >= nrow) return;
-  
+
   float sum = 0.0f;
   for (int j = 0; j < ncol; j++) {
 	sum += in[j + ncol * idx];
@@ -553,8 +553,8 @@ void clkernel_sumrow(int nrow, int ncol, __global const float* in, __global floa
 // Adapted from http://code.haskell.org/HsOpenCL/tests/bench/transpose.cl
 #define BLOCK_DIM 16
 __kernel
-void clkernel_transpose(uint nrow, uint ncol, 
-						__global const float* in, __global float* out, 
+void clkernel_transpose(uint nrow, uint ncol,
+						__global const float* in, __global float* out,
 						__local float* sdata) {
   uint gidx = get_global_id(0);
   uint gidy = get_global_id(1);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3d407061/src/core/tensor/tensor_math_opencl.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_opencl.h b/src/core/tensor/tensor_math_opencl.h
index a209de4..bc876b3 100644
--- a/src/core/tensor/tensor_math_opencl.h
+++ b/src/core/tensor/tensor_math_opencl.h
@@ -50,10 +50,10 @@ template<>
 void Abs<float, lang::Opencl>(const size_t num, const Block* in, Block* out, Context* ctx) {
   auto ocl_ctx = get_context(ctx->vcl_ctx_id);
   auto kernel = ocl_ctx.get_kernel("tensor_math_opencl.cl", "clkernel_fabs");
-  
+
   viennacl::vector<float> v_in((const cl_mem)in->data(), num);
   viennacl::vector<float> v_out(static_cast<cl_mem>(out->mutable_data()), num);
-  
+
   v_out = v_in;
   enqueue(kernel((cl_int)num, v_in, v_out));
 }
@@ -62,11 +62,11 @@ void Abs<float, lang::Opencl>(const size_t num, const Block* in, Block* out, Con
 template<>
 void Add<float, lang::Opencl>(const size_t num, const Block* in, const float x, Block* out, Context* ctx) {
   auto ocl_ctx = get_context(ctx->vcl_ctx_id);
-  
+
   viennacl::vector<float> x_in = viennacl::scalar_vector<float>(num, x, ocl_ctx);
   viennacl::vector<float> v_in((const cl_mem)in->data(), num);
   viennacl::vector<float> v_out(static_cast<cl_mem>(out->mutable_data()), num);
-  
+
   v_out = v_in + x_in;
 }
 
@@ -86,7 +86,7 @@ void Clamp<float, lang::Opencl>(const size_t num, const float low, const float h
                                 const Block* in, Block* out, Context* ctx) {
   auto ocl_ctx = get_context(ctx->vcl_ctx_id);
   auto kernel = ocl_ctx.get_kernel("tensor_math_opencl.cl", "clkernel_clamp");
-  
+
   viennacl::vector<float> v_in((const cl_mem)in->data(), num);
   viennacl::vector<float> v_out(static_cast<cl_mem>(out->mutable_data()), num);
 
@@ -97,7 +97,7 @@ void Clamp<float, lang::Opencl>(const size_t num, const float low, const float h
 template<>
 void Div<float, lang::Opencl>(const size_t num, const Block* in, const float x, Block* out, Context* ctx) {
   auto ocl_ctx = get_context(ctx->vcl_ctx_id);
-  
+
   viennacl::vector<float> x_in = viennacl::scalar_vector<float>(num, x, ocl_ctx);
   viennacl::vector<float> v_in((const cl_mem)in->data(), num);
   viennacl::vector<float> v_out(static_cast<cl_mem>(out->mutable_data()), num);
@@ -109,7 +109,7 @@ void Div<float, lang::Opencl>(const size_t num, const Block* in, const float x,
 template<>
 void Div<float, lang::Opencl>(const size_t num, const float x, const Block* in, Block* out, Context* ctx) {
   auto ocl_ctx = get_context(ctx->vcl_ctx_id);
-  
+
   viennacl::vector<float> x_in = viennacl::scalar_vector<float>(num, x, ocl_ctx);
   viennacl::vector<float> v_in((const cl_mem)in->data(), num);
   viennacl::vector<float> v_out(static_cast<cl_mem>(out->mutable_data()), num);
@@ -131,7 +131,7 @@ void Div<float, lang::Opencl>(const size_t num, const Block* in1, const Block* i
 template<>
 void EltwiseMult<float, lang::Opencl>(const size_t num, const Block* in, const float x, Block* out, Context* ctx) {
   auto ocl_ctx = get_context(ctx->vcl_ctx_id);
-  
+
   viennacl::vector<float> x_in = viennacl::scalar_vector<float>(num, x, ocl_ctx);
   viennacl::vector<float> v_in((const cl_mem)in->data(), num);
   viennacl::vector<float> v_out(static_cast<cl_mem>(out->mutable_data()), num);
@@ -154,7 +154,7 @@ template<>
 void Exp<float, lang::Opencl>(const size_t num, const Block* in, Block* out, Context* ctx) {
   viennacl::vector<float> v_in((const cl_mem)in->data(), num);
   viennacl::vector<float> v_out(static_cast<cl_mem>(out->mutable_data()), num);
-  
+
   v_out = viennacl::linalg::element_exp(v_in);
 }
 
@@ -163,10 +163,10 @@ template<>
 void LE<float, lang::Opencl>(const size_t num, const Block *in, const float x, Block *out, Context *ctx) {
   auto ocl_ctx = get_context(ctx->vcl_ctx_id);
   auto kernel = ocl_ctx.get_kernel("tensor_math_opencl.cl", "clkernel_le");
-  
+
   viennacl::vector<float> in_buf((const cl_mem)in->data(), num);
   viennacl::vector<float> out_buf(static_cast<cl_mem>(out->mutable_data()), num);
-  
+
   enqueue(kernel((cl_int)num, in_buf, x, out_buf));
 }
 
@@ -175,7 +175,7 @@ template<>
 void Log<float, lang::Opencl>(const size_t num, const Block* in, Block* out, Context* ctx) {
   viennacl::vector<float> v_in((const cl_mem)in->data(), num);
   viennacl::vector<float> v_out(static_cast<cl_mem>(out->mutable_data()), num);
-  
+
   v_out = viennacl::linalg::element_log(v_in);
 }
 
@@ -184,10 +184,10 @@ template<>
 void LT<float, lang::Opencl>(const size_t num, const Block *in, const float x, Block *out, Context *ctx) {
   auto ocl_ctx = get_context(ctx->vcl_ctx_id);
   auto kernel = ocl_ctx.get_kernel("tensor_math_opencl.cl", "clkernel_lt");
-  
+
   viennacl::vector<float> in_buf((const cl_mem)in->data(), num);
   viennacl::vector<float> out_buf(static_cast<cl_mem>(out->mutable_data()), num);
-  
+
   enqueue(kernel((cl_int)num, in_buf, x, out_buf));
 }
 
@@ -196,10 +196,10 @@ template<>
 void GE<float, lang::Opencl>(const size_t num, const Block *in, const float x, Block *out, Context *ctx) {
   auto ocl_ctx = get_context(ctx->vcl_ctx_id);
   auto kernel = ocl_ctx.get_kernel("tensor_math_opencl.cl", "clkernel_ge");
-  
+
   viennacl::vector<float> in_buf((const cl_mem)in->data(), num);
   viennacl::vector<float> out_buf(static_cast<cl_mem>(out->mutable_data()), num);
-  
+
   enqueue(kernel((cl_int)num, in_buf, x, out_buf));
 }
 
@@ -208,10 +208,10 @@ template<>
 void GT<float, lang::Opencl>(const size_t num, const Block *in, const float x, Block *out, Context *ctx) {
   auto ocl_ctx = get_context(ctx->vcl_ctx_id);
   auto kernel = ocl_ctx.get_kernel("tensor_math_opencl.cl", "clkernel_gt");
-  
+
   viennacl::vector<float> in_buf((const cl_mem)in->data(), num);
   viennacl::vector<float> out_buf(static_cast<cl_mem>(out->mutable_data()), num);
-  
+
   enqueue(kernel((cl_int)num, in_buf, x, out_buf));
 }
 
@@ -219,7 +219,7 @@ void GT<float, lang::Opencl>(const size_t num, const Block *in, const float x, B
 template<>
 void Pow<float, lang::Opencl>(const size_t num, const Block* in, float x, Block* out, Context* ctx) {
   auto ocl_ctx = get_context(ctx->vcl_ctx_id);
-  
+
   viennacl::vector<float> x_in = viennacl::scalar_vector<float>(num, x, ocl_ctx);
   viennacl::vector<float> v_in((const cl_mem)in->data(), num);
   viennacl::vector<float> v_out(static_cast<cl_mem>(out->mutable_data()), num);
@@ -242,10 +242,10 @@ template<>
 void ReLU<float, lang::Opencl>(const size_t num, const Block* in, Block* out, Context* ctx) {
   auto ocl_ctx = get_context(ctx->vcl_ctx_id);
   auto kernel = ocl_ctx.get_kernel("tensor_math_opencl.cl", "clkernel_relu");
-  
+
   viennacl::vector<float> in_buf((const cl_mem)in->data(), num);
   viennacl::vector<float> out_buf(static_cast<cl_mem>(out->mutable_data()), num);
-  
+
   enqueue(kernel((cl_int)num, in_buf, out_buf));
 }
 
@@ -255,7 +255,7 @@ void Set<float, lang::Opencl>(const size_t num, const float x, Block* out, Conte
   auto ocl_ctx = get_context(ctx->vcl_ctx_id);
 
   viennacl::vector<float> v_out(static_cast<cl_mem>(out->mutable_data()), num);
-  
+
   v_out = viennacl::scalar_vector<float>(num, x, ocl_ctx);
 }
 
@@ -263,13 +263,13 @@ void Set<float, lang::Opencl>(const size_t num, const float x, Block* out, Conte
 template<>
 void Sigmoid<float, lang::Opencl>(const size_t num, const Block* in, Block* out, Context* ctx) {
   auto ocl_ctx = get_context(ctx->vcl_ctx_id);
-  
+
   const viennacl::vector<float> zero = viennacl::zero_vector<float>(num, ocl_ctx);
   const viennacl::vector<float> one = viennacl::scalar_vector<float>(num, 1.0f, ocl_ctx);
-  
+
   viennacl::vector<float> v_in((const cl_mem)in->data(), num);
   viennacl::vector<float> v_out(static_cast<cl_mem>(out->mutable_data()), num);
-  
+
   v_out = viennacl::linalg::element_div(one, viennacl::linalg::element_exp(zero - v_in) + one);
 }
 
@@ -277,11 +277,11 @@ void Sigmoid<float, lang::Opencl>(const size_t num, const Block* in, Block* out,
 template<>
 void Sign<float, lang::Opencl>(const size_t num, const Block* in, Block* out, Context* ctx) {
   auto ocl_ctx = get_context(ctx->vcl_ctx_id);
-  auto kernel = ocl_ctx.get_kernel("tensor_math_opencl.cl", "clkernel_abs");
-  
+  auto kernel = ocl_ctx.get_kernel("tensor_math_opencl.cl", "clkernel_sign");
+
   viennacl::vector<float> in_buf((const cl_mem)in->data(), num);
   viennacl::vector<float> out_buf(static_cast<cl_mem>(out->mutable_data()), num);
-  
+
   enqueue(kernel(num, in_buf, out_buf));
 }
 
@@ -344,11 +344,11 @@ template<>
 void Bernoulli<float, lang::Opencl>(const size_t num, const float p, Block* out, Context *ctx) {
   auto ocl_ctx = get_context(ctx->vcl_ctx_id);
   auto kernel = ocl_ctx.get_kernel("distribution.cl", "PRNG_threefry4x32_bernoulli");
-  
+
   viennacl::vector<float> v_out(static_cast<cl_mem>(out->mutable_data()), num);
-  
+
   viennacl::ocl::packed_cl_uint seed = {0, 32, 42, 888};
-  
+
   enqueue(kernel(v_out, seed, 0.0f, 1.0f, p, rounds, cl_uint(num / 4)));
 }
 
@@ -357,11 +357,11 @@ template<>
 void Gaussian<float, lang::Opencl>(const size_t num, const float mean, const float std, Block* out, Context *ctx) {
   auto ocl_ctx = get_context(ctx->vcl_ctx_id);
   auto kernel = ocl_ctx.get_kernel("distribution.cl", "PRNG_threefry4x32_gaussian");
-  
+
   viennacl::vector<float> v_out(static_cast<cl_mem>(out->mutable_data()), num);
-  
+
   viennacl::ocl::packed_cl_uint seed = {0, 32, 42, 888};
-  
+
   enqueue(kernel(v_out, seed, mean, std, rounds, cl_uint(num/4)));
 }
 
@@ -370,11 +370,11 @@ template<>
 void Uniform<float, lang::Opencl>(const size_t num, const float low, const float high, Block* out, Context *ctx) {
   auto ocl_ctx = get_context(ctx->vcl_ctx_id);
   auto kernel = ocl_ctx.get_kernel("distribution.cl", "PRNG_threefry4x32_uniform");
-  
+
   viennacl::ocl::packed_cl_uint seed = {0, 32, 42, 888};
-  
+
   viennacl::vector<float> v_out(static_cast<cl_mem>(out->mutable_data()), num);
-  
+
   enqueue(kernel(v_out, seed, low, high, rounds, cl_uint(num/4)));
 }
 
@@ -441,7 +441,7 @@ void Amin<float, lang::Opencl>(const size_t num, const Block* in, size_t* out, C
   delete temp;
 }
 
-	
+
 template<>
 void Asum<float, lang::Opencl>(const size_t num, const Block* in, float* out, Context* ctx) {
   cl_int status = CL_SUCCESS;
@@ -450,7 +450,7 @@ void Asum<float, lang::Opencl>(const size_t num, const Block* in, float* out, Co
   auto kernel = ctx->kernels->at(kname);
 
   cl::Buffer inbuf = *(static_cast<cl::Buffer*>(in->mutable_data()));
-  
+
   size_t size = sizeof(float) * num;
   cl::Buffer outval(ctx->ocl_ctx, CL_MEM_WRITE_ONLY, size, nullptr, &status);
   OCL_CHECK(status, "Failed to create buffer!");
@@ -475,7 +475,7 @@ template<>
 void Axpy<float, lang::Opencl>(const size_t num, const float alpha, const Block* in, Block* out, Context* ctx) {
   viennacl::vector<float> inbuf((const cl_mem)in->data(), num);
   viennacl::vector<float> outbuf(static_cast<cl_mem>(out->mutable_data()), num);
-  
+
   outbuf += alpha * inbuf;
 }
 
@@ -483,7 +483,7 @@ void Axpy<float, lang::Opencl>(const size_t num, const float alpha, const Block*
 template<>
 void Nrm2<float, lang::Opencl>(const size_t num, const Block* in, float* out, Context* ctx) {
   viennacl::vector<float> inbuf((const cl_mem)in->data(), num);
-  
+
   out[0] = viennacl::linalg::norm_2(inbuf);
 }
 
@@ -491,7 +491,7 @@ void Nrm2<float, lang::Opencl>(const size_t num, const Block* in, float* out, Co
 template<>
 void Scale<float, lang::Opencl>(const size_t num, const float x, Block* out, Context* ctx) {
   auto ocl_ctx = get_context(ctx->vcl_ctx_id);
-  
+
   viennacl::vector<float> x_in = viennacl::scalar_vector<float>(num, x, ocl_ctx);
   viennacl::vector<float> v_out(static_cast<cl_mem>(out->mutable_data()), num);
 
@@ -503,7 +503,7 @@ template<>
 void Dot<float, lang::Opencl>(const size_t num, const Block *in1, const Block *in2, float *out, Context *ctx) {
   viennacl::vector<float> in1_buf((const cl_mem)in1->data(), num);
   viennacl::vector<float> in2_buf((const cl_mem)in2->data(), num);
-  
+
   out[0] = viennacl::linalg::inner_prod(in1_buf, in2_buf);
 }
 
@@ -513,9 +513,9 @@ void GEMV<float, lang::Opencl>(bool trans, const size_t m, const size_t n, const
 		  const Block *A, const Block *v, const float beta, Block* out, Context* ctx) {
   viennacl::vector<float> v_buf((const cl_mem)v->data(), n);
   viennacl::vector<float> o_buf(static_cast<cl_mem>(out->mutable_data()), m);
-  
+
   viennacl::matrix<float> A_buf;
-  
+
   if (trans) {
     A_buf = viennacl::matrix<float>((const cl_mem)A->data(), n, m);
     A_buf = viennacl::trans(A_buf);
@@ -537,9 +537,9 @@ void DGMM<float, lang::Opencl>(bool side_right,
   viennacl::matrix<float> M_buf((const cl_mem)M->data(), nrow, ncol);
   viennacl::vector<float> v_buf((const cl_mem)v->data(), nrow);
   viennacl::matrix<float> out_buf(static_cast<cl_mem>(out->mutable_data()), nrow, ncol);
-  
+
   auto diag = viennacl::diag(v_buf);
-  
+
   if (side_right) {
     out_buf = viennacl::linalg::prod(diag, M_buf);
   } else {
@@ -556,21 +556,21 @@ void GEMM<float, lang::Opencl>(const bool transA, const bool transB,
 
   viennacl::matrix<float> A_buf, B_buf;
   viennacl::matrix<float> C_buf(static_cast<cl_mem>(C->mutable_data()), nrowA, ncolB);
-  
+
   if (transA) {
     A_buf = viennacl::matrix<float>((const cl_mem)A->data(), ncolA, nrowA);
     A_buf = viennacl::trans(A_buf);
   } else {
     A_buf = viennacl::matrix<float>((const cl_mem)A->data(), nrowA, ncolA);
   }
-  
+
   if (transB) {
     B_buf = viennacl::matrix<float>((const cl_mem)B->data(), ncolB, ncolA);
     B_buf = viennacl::trans(B_buf);
   } else {
     B_buf = viennacl::matrix<float>((const cl_mem)B->data(), ncolA, ncolB);
   }
-  
+
   C_buf *= beta;
   C_buf += alpha * viennacl::linalg::prod(A_buf, B_buf);
 }
@@ -582,11 +582,11 @@ void ComputeCrossEntropy<float, lang::Opencl>(const size_t batchsize, const size
                          Context *ctx) {
   auto ocl_ctx = get_context(ctx->vcl_ctx_id);
   auto kernel = ocl_ctx.get_kernel("tensor_math_opencl.cl", "clkernel_crossentropy");
-  
+
   viennacl::vector<float> p_buf((const cl_mem)p->data(), batchsize);
   viennacl::vector<float> t_buf((const cl_mem)t->data(), batchsize);
   viennacl::vector<float> loss_buf(static_cast<cl_mem>(loss->mutable_data()), batchsize);
-  
+
   enqueue(kernel((cl_uint)batchsize, (cl_uint)dim, p_buf, t_buf, loss_buf));
 }
 
@@ -597,11 +597,11 @@ void SoftmaxCrossEntropyBwd<float, lang::Opencl>(const size_t batchsize, const s
                             Context *ctx) {
   auto ocl_ctx = get_context(ctx->vcl_ctx_id);
   auto kernel = ocl_ctx.get_kernel("tensor_math_opencl.cl", "clkernel_softmaxentropy");
-  
+
   viennacl::vector<float> p_buf((const cl_mem)p->data(), batchsize);
   viennacl::vector<float> t_buf((const cl_mem)t->data(), batchsize);
   viennacl::vector<float> grad_buf(static_cast<cl_mem>(grad->mutable_data()), batchsize);
-  
+
   enqueue(kernel((cl_uint)batchsize, (cl_uint)dim, p_buf, t_buf, grad_buf));
 }
 
@@ -611,12 +611,12 @@ void RowMax<float, lang::Opencl>(const size_t nrow, const size_t ncol,
                                  const Block *in, Block *out, Context *ctx) {
   auto ocl_ctx = get_context(ctx->vcl_ctx_id);
   auto kernel = ocl_ctx.get_kernel("tensor_math_opencl.cl", "clkernel_rowmax");
-  
+
 //  kernel.global_work_size(0, nrow);
-  
+
   viennacl::matrix<float> in_buf((const cl_mem)in->data(), nrow, ncol);
   viennacl::vector<float> outbuf(static_cast<cl_mem>(out->mutable_data()), nrow);
-  
+
   enqueue(kernel((cl_uint)nrow, (cl_uint)ncol, in_buf, outbuf));
 }
 
@@ -641,7 +641,7 @@ void Outer<float, lang::Opencl>(const size_t m, const size_t n, const Block* lhs
   viennacl::vector<float> lhs_in((const cl_mem)lhs->data(), m);
   viennacl::vector<float> rhs_in((const cl_mem)rhs->data(), n);
   viennacl::matrix<float> out_buf(static_cast<cl_mem>(out->mutable_data()), m, n);
-  
+
   out_buf = viennacl::linalg::outer_prod(lhs_in, rhs_in);
 }
 
@@ -650,7 +650,7 @@ template<>
 void SumColumns<float, lang::Opencl>(const size_t nrow, const size_t ncol, const Block* in, Block* out, Context* ctx) {
   viennacl::matrix<float> m_in((const cl_mem)in->data(), nrow, ncol);
   viennacl::vector<float> v_out(static_cast<cl_mem>(out->mutable_data()), nrow);
-  
+
   v_out = viennacl::linalg::column_sum(m_in);
 }
 
@@ -659,7 +659,7 @@ template<>
 void SumRows<float, lang::Opencl>(const size_t nrow, const size_t ncol, const Block* in, Block* out, Context* ctx) {
   viennacl::matrix<float> m_in((const cl_mem)in->data(), nrow, ncol);
   viennacl::vector<float> v_out(static_cast<cl_mem>(out->mutable_data()), ncol);
-  
+
   v_out = viennacl::linalg::column_sum(m_in);
 }
 */

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3d407061/test/singa/test_tensor_math.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_tensor_math.cc b/test/singa/test_tensor_math.cc
index c3a1039..116262c 100644
--- a/test/singa/test_tensor_math.cc
+++ b/test/singa/test_tensor_math.cc
@@ -105,7 +105,7 @@ TEST_F(TestTensorMath, MemberSign) {
 
   Tensor p = Sign(cc);
   const float *dptr1 = p.data<float>();
-  EXPECT_EQ(0.0f, dptr1[0]);
+  EXPECT_EQ(-1.0f, dptr1[0]);
   EXPECT_EQ(0.0f, dptr1[1]);
   EXPECT_EQ(1.0f, dptr1[2]);
 }
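
The updated expectation reflects the three-valued Sign: a negative input now maps to -1 instead of 0, so the fixture evidently holds a negative, a zero and a positive entry. A hypothetical NumPy mirror of the three checks (values assumed for illustration):

    import numpy as np

    x = np.array([-2.0, 0.0, 3.0], dtype=np.float32)  # assumed fixture-like values
    s = np.sign(x)
    assert list(s) == [-1.0, 0.0, 1.0]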



[2/2] incubator-singa git commit: re-organize the installation page and update the index page

Posted by zh...@apache.org.
re-organize the installation page and update the index page


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/2d5f696b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/2d5f696b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/2d5f696b

Branch: refs/heads/master
Commit: 2d5f696bd2b996f2cec571f72ef938570c68af7e
Parents: 3d40706
Author: wang wei <wa...@comp.nus.edu.sg>
Authored: Sun Jan 22 08:51:46 2017 +0000
Committer: wang wei <wa...@comp.nus.edu.sg>
Committed: Sun Jan 22 08:56:00 2017 +0000

----------------------------------------------------------------------
 doc/build.sh                |   2 +-
 doc/en/docs/index.rst       |   2 +-
 doc/en/docs/installation.md | 211 +++++++++++++++++----------------------
 doc/en/docs/net.rst         |   3 +
 doc/en/index.rst            |  10 +-
 examples/index.rst          |  12 +--
 python/singa/net.py         |  18 ++++
 python/singa/tensor.py      |   1 +
 8 files changed, 127 insertions(+), 132 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2d5f696b/doc/build.sh
----------------------------------------------------------------------
diff --git a/doc/build.sh b/doc/build.sh
index c0873a9..44eb1c2 100755
--- a/doc/build.sh
+++ b/doc/build.sh
@@ -29,7 +29,7 @@ fi
 
 
 if [ "$1"x = "html"x ]; then
-  cp -rf ../examples en/docs/
+  cp -rf ../examples en/docs/model_zoo
   cp README.md en/develop/contribute-docs.md
   for (( i=0; i<${#LANG_ARR[@]}; i++)) do
     echo "building language ${LANG_ARR[i]} ..."

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2d5f696b/doc/en/docs/index.rst
----------------------------------------------------------------------
diff --git a/doc/en/docs/index.rst b/doc/en/docs/index.rst
index 691c3c0..ee78290 100644
--- a/doc/en/docs/index.rst
+++ b/doc/en/docs/index.rst
@@ -36,4 +36,4 @@ Documentation
    snapshot
    converter
    utils
-   examples/index
+   model_zoo/index

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2d5f696b/doc/en/docs/installation.md
----------------------------------------------------------------------
diff --git a/doc/en/docs/installation.md b/doc/en/docs/installation.md
index f9231f6..116f629 100755
--- a/doc/en/docs/installation.md
+++ b/doc/en/docs/installation.md
@@ -1,13 +1,14 @@
 # Installation
 
 
-## Install PySINGA
+## From wheel
 
+Users can download the pre-compiled wheel files to install PySINGA.
 PySINGA has been tested on Linux (Ubuntu 14.04 and 16.04) and Mac OS (10.11 and 10.12).
 
-### Install dependent libraries
+### Pre-requisite
 
-Python 2.7 is required to run PySINGA.
+Python 2.7 and pip are required.
 
     # For Ubuntu
     $ sudo apt-get install python2.7-dev python-pip
@@ -16,6 +17,9 @@ Python 2.7 is required to run PySINGA.
     $ brew tap homebrew/python
     $ brew install python
 
+Note that on Mac OS, you need to configure the Python paths correctly if multiple Python versions are installed.
+Refer to FAQ for the errors and solutions.
+
 ### Virtual environment
 
 Users are recommended to use PySINGA in python virtual environment.
@@ -37,15 +41,15 @@ Note that in python virtual environment, you may need to reset the `PYTHONPATH`
 to avoid the conflicts of system path and virtual environment path.
 
 
-### From wheel
+### Instructions
 
-Currently, we have the following wheel files,
+Currently, the following wheel files are available,
 
 <table border="1">
   <tr>
     <th>OS</th>
     <th>Device</th>
-    <th>CUDA/CUDNN</th>
+    <th>CUDA/cuDNN</th>
     <th>Link</th>
   </tr>
   <tr>
@@ -57,13 +61,13 @@ Currently, we have the following wheel files,
   <tr>
     <td>Ubuntu14.04</td>
     <td>GPU</td>
-    <td>CUDA7.5+CUDNN4</td>
+    <td>CUDA7.5+cuDNN4</td>
     <td><a href="http://comp.nus.edu.sg/~dbsystem/singa/assets/file/wheel/linux/latest/ubuntu14.04-cuda7.5-cudnn4/">latest</a>, <a href="http://www.comp.nus.edu.sg/~dbsystem/singa/assets/file/wheel/linux">history</a></td>
   </tr>
   <tr>
     <td>Ubuntu14.04</td>
     <td>GPU</td>
-    <td>CUDA7.5+CUDNN5</td>
+    <td>CUDA7.5+cuDNN5</td>
     <td><a href="http://comp.nus.edu.sg/~dbsystem/singa/assets/file/wheel/linux/latest/ubuntu14.04-cuda7.5-cudnn5/">latest</a>, <a href="http://www.comp.nus.edu.sg/~dbsystem/singa/assets/file/wheel/linux">history</a></td>
   </tr>
   <tr>
@@ -75,7 +79,7 @@ Currently, we have the following wheel files,
   <tr>
     <td>Ubuntu16.04</td>
     <td>GPU</td>
-    <td>CUDA8.0+CUDNN5</td>
+    <td>CUDA8.0+cuDNN5</td>
     <td><a href="http://comp.nus.edu.sg/~dbsystem/singa/assets/file/wheel/linux/latest/ubuntu16.04-cuda8.0-cudnn5/">latest</a>, <a href="http://www.comp.nus.edu.sg/~dbsystem/singa/assets/file/wheel/linux">history</a></td>
   </tr>
   <tr>
@@ -96,85 +100,50 @@ Download the whl file and execute the following command to install PySINGA,
 
     $ pip install --upgrade <path to the wheel file>
 
-To install the wheel file compiled with CUDA, you need to install CUDA and export the `LD_LIBRARY_PATH` to CUDNN before running the above instruction.
+To install the wheel file compiled with CUDA, you need to install CUDA and export the `LD_LIBRARY_PATH` to cuDNN before running the above instruction.
 
 If you have sudo right, you can run the above commands using `sudo pip install` without python virtual environment.
 The option `--upgrade` may cause errors sometimes, in which case you can ignore it.
 
-### From source
-
-To build the PySINGA from source, the following dependent libraries are required,
-
-* swig(>=3.0.10)
-* numpy(>=1.11.0)
-
-They can be installed by
-
-    $ Ubuntu 14.04 and 16.04
-    $ sudo apt-get install python-numpy
-    # Ubuntu 16.04
-    $ sudo apt-get install swig
-
-Note that swig has to be installed from source on Ubuntu 14.04.
-After installing numpy, export the header path of numpy.i as
-
-    $ export CPLUS_INCLUDE_PATH=`python -c "import numpy; print numpy.get_include()"`:$CPLUS_INCLUDE_PATH
-
-**compile SINGA from source** (see the next section) with `cmake -DUSE_PYTHON=ON ..`,
-and then run the following commands,
-
-    # under the build directory
-    $ cd python
-    $ pip install .
-
-Developers can build the wheel file via
-
-    # under the build directory
-    $ cd python
-    $ python setup.py bdist_wheel
-
-The generated wheel file is under "dist" directory.
-
-
-## Install SINGA from Debian Package
+## From Debian Package
 
-We have prepared the Debian packages (on architecture: amd64) for SINGA as listed below,
+The following Debian packages (for the amd64 architecture) are available:
 
 <table border="1">
   <tr>
     <th>OS</th>
     <th>Device</th>
-    <th>CUDA/CUDNN</th>
+    <th>CUDA/cuDNN</th>
     <th>Link</th>
   </tr>
   <tr>
     <td>Ubuntu14.04</td>
     <td>CPU</td>
     <td>-</td>
-    <td><a href="http://comp.nus.edu.sg/~dbsystem/singa/assets/file/debian/latest/ubuntu14.04-cpp/">latest</a>, <a href="http://www.comp.nus.edu.sg/~dbsystem/singa/assets/file/debian">history</a></td>
+    <td><a href="http://comp.nus.edu.sg/~dbsystem/singa/assets/file/debian/latest/ubuntu14.04-cpp/python-singa.deb">latest</a>, <a href="http://www.comp.nus.edu.sg/~dbsystem/singa/assets/file/debian">history</a></td>
   </tr>
   <tr>
     <td>Ubuntu14.04</td>
     <td>GPU</td>
-    <td>CUDA7.5+CUDNN4</td>
+    <td>CUDA7.5+cuDNN4</td>
     <td>coming soon</td>
   </tr>
   <tr>
     <td>Ubuntu14.04</td>
     <td>GPU</td>
-    <td>CUDA7.5+CUDNN5</td>
+    <td>CUDA7.5+cuDNN5</td>
     <td>coming soon</td>
   </tr>
   <tr>
     <td>Ubuntu16.04</td>
     <td>CPU</td>
     <td>-</td>
-    <td><a href="http://comp.nus.edu.sg/~dbsystem/singa/assets/file/debian/latest/ubuntu16.04-cpp/">latest</a>, <a href="http://www.comp.nus.edu.sg/~dbsystem/singa/assets/file/debian">history</a></td>
+    <td><a href="http://comp.nus.edu.sg/~dbsystem/singa/assets/file/debian/latest/ubuntu16.04-cpp/python-singa.deb">latest</a>, <a href="http://www.comp.nus.edu.sg/~dbsystem/singa/assets/file/debian">history</a></td>
   </tr>
   <tr>
     <td>Ubuntu16.04</td>
     <td>GPU</td>
-    <td>CUDA8.0+CUDNN5</td>
+    <td>CUDA8.0+cuDNN5</td>
     <td>coming soon</td>
   </tr>
 </table>
@@ -183,69 +152,54 @@ Download the deb file and install it via
 
     apt-get install <path to the deb file, e.g., ./python-singa.deb>
 
-To create the Debian packages, please refer to the README.md file under `SINGA_ROOT/tool/debian`.
-
-
-## Use SINGA Docker Images
-
-A list of Docker images with SINGA installed are available on [Dockerhub](https://hub.docker.com/r/nusdbsystem/singa/).
-To use the image, run
-
-    # for images built without CUDA
-    $ docker run -it nusdbsystem/singa:<tag> /bin/bash
-    # for images built with CUDA support
-    $ nvidia-docker run -it nusdbsystem/singa:<tag> /bin/bash
-
-All available tags and descriptions are on [Dockerhub](https://hub.docker.com/r/nusdbsystem/singa/) and [Github](https://github.com/apache/incubator-singa/blob/master/tool/docker/README.md)
+Note that the path must include `./` if the file is inside the current folder.
 
-## Compile SINGA from source on Linux and Mac OS
+## From source
 
 The source files could be downloaded either as a [tar.gz file](https://dist.apache.org/repos/dist/dev/incubator/singa/), or as a git repo
 
     $ git clone https://github.com/apache/incubator-singa.git
     $ cd incubator-singa/
 
-cmake (>=2.8) is used for compile SINGA, which can be installed by
-
-    # For Ubuntu 14.04 and 16.04
-    $ sudo apt-get install cmake
-
-GCC (>=4.8.1) is required to compile SINGA on Linux.
-For Mac OS users, you can use either GCC or Clang.
-
-### Compile SINGA together with dependent libraries
-
-SINGA code uses CBLAS and Protobuf (>=2.5, <3).
-If they are not installed in your OS, you can compile SINGA together with them
-
-    $ In SINGA ROOT folder
-    $ mkdir build
-    $ cd build
-    $ cmake -DUSE_MODULES=ON ..
-    $ make
-
-cmake would download OpenBlas and Protobuf (2.6.1) and compile them together
-with SINGA.
-
-### Install dependent libraries and then compile SINGA
-
-Users can also install the dependent libraries and then link SINGA with them.
+### Pre-requisite
 
 The following libraries are required
+* cmake (>=2.8)
+* gcc (>=4.8.1) or Clang
 * google protobuf (>=2.5,<3)
 * blas (tested with openblas >=0.2.10)
-
+* swig(>=3.0.10) for compiling PySINGA
+* numpy(>=1.11.0) for compiling PySINGA
 
 The following libraries are optional
 * opencv (tested with 2.4.8)
 * lmdb (tested with 0.9)
 * glog
 
+### Instructions
+
+1. create a `build` folder inside incubator-singa and go into that folder
+2. run `cmake [options] ..`
+  by default all options are OFF except `USE_PYTHON`
+
+    * `USE_MODULES=ON`, used if protobuf and blas are not installed beforehand
+    * `USE_CUDA=ON`, used if CUDA and cuDNN are available
+    * `USE_PYTHON=ON`, used for compiling PySINGA
+    * `USE_OPENCL=ON`, used for compiling with OpenCL support
+3. compile the code, e.g., `make`
+4. go to the `python` folder
+5. run `pip install .`
+6. [optional] run `python setup.py bdist_wheel` to generate the wheel file
+
+Steps 4 and 5 install PySINGA.
+Details on the installation of dependent libraries and the instructions for each OS are given in the following sections.
+
+### Linux and Mac OS
 
 Most of the dependent libraries could be installed from source or via package managers like
 apt-get, yum, and homebrew. Please refer to FAQ for problems caused by the path setting of the dependent libraries.
 
-The following instructions are tested on Ubuntu 14.04 for installing dependent libraries.
+The following instructions are tested on Ubuntu 14.04 and 16.04 for installing the dependent libraries.
 
     # required libraries
     $ sudo apt-get install libprotobuf-dev libopenblas-dev protobuf-compiler
@@ -254,10 +208,7 @@ The following instructions are tested on Ubuntu 14.04 for installing dependent l
     $ sudo apt-get install python2.7-dev python-pip python-numpy
     $ sudo apt-get install libopencv-dev libgoogle-glog-dev liblmdb-dev
 
-Note that PySINGA requires swig >=3.0, which could be installed via
-apt-get on Ubuntu 16.04; but it has to be installed from source for other Ubuntu versions including 14.04.
-
-The following instructions are tested on Mac OS X Yosemite (10.10.5) for installing dependent libraries.
+The following instructions are tested on Mac OS X (10.11 and 10.12) for installing the dependent libraries.
 
     # required libraries
     $ brew tap homebrew/science
@@ -281,20 +232,25 @@ To let the runtime know the openblas path,
     $ export LD_LIBRARY_PATH=/usr/local/opt/openblas/library:$LD_LIBRARY_PATH
 
 
-With the dependent libraries installed, SINGA can be compiled via
+#### Compile with USE_MODULES=ON
+
+If protobuf and openblas are not installed, you can compile SINGA together with them
 
+    # in the SINGA root folder
     $ mkdir build
     $ cd build
-    $ cmake ..
+    $ cmake -DUSE_MODULES=ON ..
     $ make
-    $ make install
+
+cmake would download OpenBlas and Protobuf (2.6.1) and compile them together
+with SINGA.
 
 After compiling SINGA, you can run the unit tests by
 
     $ ./bin/test_singa
 
 You can see all the testing cases with testing results. If SINGA passes all
-tests, then you have successfully installed SINGA. Please proceed to try the examples!
+tests, then you have successfully installed SINGA.
 
 You can use `ccmake ..` to configure the compilation options.
 If some dependent libraries are not in the system default paths, you need to export
@@ -303,13 +259,41 @@ the following environment variables
     export CMAKE_INCLUDE_PATH=<path to the header file folder>
     export CMAKE_LIBRARY_PATH=<path to the lib file folder>
 
-### Compile SINGA with CUDA and CUDNN
+#### Compile with USE_PYTHON=ON
+swig and numpy can be installed by
+
+    # For Ubuntu 14.04 and 16.04
+    $ sudo apt-get install python-numpy
+    # Ubuntu 16.04
+    $ sudo apt-get install swig
+
+Note that swig has to be installed from source on Ubuntu 14.04.
+After installing numpy, export the header path of numpy.i as
+
+    $ export CPLUS_INCLUDE_PATH=`python -c "import numpy; print numpy.get_include()"`:$CPLUS_INCLUDE_PATH
+
+Similar to compile CPP code, PySINGA is compiled by
+
+    $ cmake -DUSE_PYTHON=ON ..
+    $ make
+    $ cd python
+    $ pip install .
+
+Developers can build the wheel file via
+
+    # under the build directory
+    $ cd python
+    $ python setup.py bdist_wheel
+
+The generated wheel file is under the "dist" directory.
+
+
+#### Compile SINGA with USE_CUDA=ON
 
 Users are encouraged to install the CUDA and
-[CUDNN](https://developer.nvidia.com/cudnn) for running SINGA on GPUs to
+[cuDNN](https://developer.nvidia.com/cudnn) for running SINGA on GPUs to
 get better performance.
 
-SINGA has been tested over CUDA (7, 7.5, 8), and CUDNN (4 and 5).  If CUDNN is
+SINGA has been tested over CUDA (7, 7.5, 8), and cuDNN (4 and 5).  If cuDNN is
 decompressed into a non-system folder, e.g. /home/bob/local/cudnn/, the following
 commands should be executed for cmake and the runtime to find it
 
@@ -317,16 +301,12 @@ commands should be executed for cmake and the runtime to find it
     $ export CMAKE_LIBRARY_PATH=/home/bob/local/cudnn/lib64:$CMAKE_LIBRARY_PATH
     $ export LD_LIBRARY_PATH=/home/bob/local/cudnn/lib64:$LD_LIBRARY_PATH
 
-The cmake options for CUDA and CUDNN should be switched on
+The cmake options for CUDA and cuDNN should be switched on
 
     # Dependent libs are installed already
     $ cmake -DUSE_CUDA=ON ..
 
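+After compiling with USE_CUDA=ON (and USE_PYTHON=ON), a quick check, as a
+sketch assuming a CUDA-capable GPU is present, is to allocate a tensor on
+the GPU from Python,
+
+    >>> from singa import device, tensor
+    >>> dev = device.create_cuda_gpu()
+    >>> t = tensor.Tensor((2, 3), dev)   # the tensor lives on the GPU
+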
-    # Compile dependent libs together with SINGA
-    $ cmake -DUSE_CUDA=ON -DUSE_MODULES=ON ..
-
-
-### Compile SINGA with OpenCL support (Linux)
+#### Compile SINGA with USE_OPENCL=ON
 
 SINGA uses opencl-headers and viennacl (version 1.7.1 or newer) for OpenCL support, which
 can be installed via
@@ -355,7 +335,7 @@ To build SINGA with OpenCL support, you need to pass the flag during cmake:
 
     cmake -DUSE_OPENCL=ON ..
 
-## Build SINGA on Windows
+### Compile SINGA on Windows
 
 For the dependent library installation, please refer to [Dependencies](dependencies.md).
 After all the dependencies are successfully installed, just run the following commands to
@@ -391,11 +371,6 @@ unit tests file named "test_singa" in the project binary folder.
 If you get errors when running test_singa.exe due to missing libglog.dll/libopenblas.dll,
 just copy the dll files into the same folder as test_singa.exe.
 
-## Build the Debian packages
-
-    $ cd debian
-    $ ./build.sh
-
 ## FAQ
 
 * Q: Error from 'import singa' using PySINGA installed from wheel.

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2d5f696b/doc/en/docs/net.rst
----------------------------------------------------------------------
diff --git a/doc/en/docs/net.rst b/doc/en/docs/net.rst
index cc20c21..7aff364 100644
--- a/doc/en/docs/net.rst
+++ b/doc/en/docs/net.rst
@@ -21,3 +21,6 @@ FeedForward Net
 
 .. automodule:: singa.net
    :members:
+   :member-order: bysource
+   :show-inheritance:
+   :undoc-members:

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2d5f696b/doc/en/index.rst
----------------------------------------------------------------------
diff --git a/doc/en/index.rst b/doc/en/index.rst
index e0f8aaa..6d17557 100755
--- a/doc/en/index.rst
+++ b/doc/en/index.rst
@@ -58,18 +58,16 @@ Recent News
 
 Getting Started
 ---------------
-* The `Software stack <docs/software_stack.html>`_ page gives an overview of SINGA.
+* Try SINGA on `AWS <https://aws.amazon.com/marketplace/pp/B01NAUAWZW>`_ or via `Docker <https://hub.docker.com/r/nusdbsystem/singa/>`_.
 
-* The `Installation <docs/installation.html>`_ guide describes details on downloading and installing SINGA.
+* Install SINGA via `python wheel files <./docs/installation.html#from-wheel>`_, `Debian packages <./docs/installation.html#from-debian-package>`_ or from `source <./docs/installation.html#from-source>`_.
 
-* Please follow the `Examples <docs/examples/index.html>`_ guide to run simple applications on SINGA.
-
-* More exmaples in `Jupyter <http://jupyter.org/>`_ (IPython) can be open in `notebook viewer <http://nbviewer.jupyter.org/github/apache/incubator-singa/blob/master/doc/en/docs/notebook/index.ipynb>`_ .
+* Refer to the `Jupyter notebooks <http://nbviewer.jupyter.org/github/apache/incubator-singa/blob/master/doc/en/docs/notebook/index.ipynb>`_ for some basic examples and the `model zoo page <./docs/model_zoo/index.html>`_ for more examples.
 
 Documentation
 -------------
 
-* Documentations are listed `here <docs.html>`_.
+* Documentation and APIs are listed `here <docs.html>`_.
 
 * Research publication list is available `here <http://www.comp.nus.edu.sg/~dbsystem/singa/research/publication/>`_.
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2d5f696b/examples/index.rst
----------------------------------------------------------------------
diff --git a/examples/index.rst b/examples/index.rst
index b501b36..fa1c77b 100644
--- a/examples/index.rst
+++ b/examples/index.rst
@@ -1,4 +1,4 @@
-.. 
+..
 .. Licensed to the Apache Software Foundation (ASF) under one
 .. or more contributor license agreements.  See the NOTICE file
 .. distributed with this work for additional information
@@ -6,18 +6,18 @@
 .. to you under the Apache License, Version 2.0 (the
 .. "License"); you may not use this file except in compliance
 .. with the License.  You may obtain a copy of the License at
-.. 
+..
 ..     http://www.apache.org/licenses/LICENSE-2.0
-.. 
+..
 .. Unless required by applicable law or agreed to in writing, software
 .. distributed under the License is distributed on an "AS IS" BASIS,
 .. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 .. See the License for the specific language governing permissions and
 .. limitations under the License.
-.. 
+..
 
-Examples
-========
+Model Zoo
+=========
 
 .. toctree::
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2d5f696b/python/singa/net.py
----------------------------------------------------------------------
diff --git a/python/singa/net.py b/python/singa/net.py
index 9d09740..d6e313e 100644
--- a/python/singa/net.py
+++ b/python/singa/net.py
@@ -67,6 +67,12 @@ verbose = False
 class FeedForwardNet(object):
 
     def __init__(self, loss=None, metric=None):
+        '''Representing a feed-forward neural net.
+
+        Args:
+            loss, a Loss instance. Necessary for training
+            metric, a Metric instance. Necessary for evaluation
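+
+        Example usage, as a sketch (assumes the loss and metric modules
+        have been imported)::
+
+            net = FeedForwardNet(loss.SoftmaxCrossEntropy(), metric.Accuracy())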
+        '''
         self.loss = loss
         self.metric = metric
         self.layers = []
@@ -76,6 +82,9 @@ class FeedForwardNet(object):
         self.out_sample_shape_of_layer = {}
 
     def to_device(self, dev):
+        '''Move the net onto the given device, including
+        all parameters and intermediate data.
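+
+        Example usage, as a sketch (assumes SINGA was built with CUDA and
+        a GPU is available)::
+
+            from singa import device
+            net.to_device(device.create_cuda_gpu())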
+        '''
         for lyr in self.layers:
             lyr.to_device(dev)
 
@@ -90,6 +99,7 @@ class FeedForwardNet(object):
 
         Args:
             lyr (Layer): the layer to be added
+            src (Layer): the source layer of lyr
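+
+        Example usage, as a sketch (when src is None, lyr is connected to
+        the most recently added layer)::
+
+            fc1 = net.add(layer.Dense('fc1', 20, input_sample_shape=(10,)))
+            net.add(layer.Dense('fc2', 5), src=fc1)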
         """
         if src is not None:
             if isinstance(src, layer.Layer):
@@ -128,6 +138,7 @@ class FeedForwardNet(object):
         return lyr
 
     def param_values(self):
+        '''Return a list of tensors for all parameters'''
         values = []
         layers = self.layers
         if self.ordered_layers is not None:
@@ -137,6 +148,7 @@ class FeedForwardNet(object):
         return values
 
     def param_specs(self):
+        '''Return a list of ParamSpec for all parameters'''
         specs = []
         layers = self.layers
         if self.ordered_layers is not None:
@@ -146,6 +158,7 @@ class FeedForwardNet(object):
         return specs
 
     def param_names(self):
+        '''Return a list of the names of all params'''
         return [spec.name for spec in self.param_specs()]
 
     def train(self, x, y):
@@ -304,6 +317,11 @@ class FeedForwardNet(object):
             return ret
 
     def backward(self):
+        '''Run back-propagation after forward-propagation.
+
+        Returns:
+            a list of gradient tensors for all parameters
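+
+        Example usage, as a sketch (opt and epoch are hypothetical, e.g.
+        an optimizer.SGD instance and the current epoch id; forward() must
+        have run with the training flag on)::
+
+            grads = net.backward()
+            for (s, p, g) in zip(net.param_specs(), net.param_values(), grads):
+                opt.apply_with_lr(epoch, 0.01, g, p, str(s.name))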
+        '''
         if self.dst_of_layer is None:
             self.dst_of_layer = {}
             for cur in self.layers:

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2d5f696b/python/singa/tensor.py
----------------------------------------------------------------------
diff --git a/python/singa/tensor.py b/python/singa/tensor.py
index d1851d1..12d7c53 100644
--- a/python/singa/tensor.py
+++ b/python/singa/tensor.py
@@ -17,6 +17,7 @@
 # =============================================================================
 """
 Example usage::
+
     import numpy as np
     from singa import tensor
     from singa import device