You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by zh...@apache.org on 2016/06/13 13:20:33 UTC

[40/50] [abbrv] incubator-singa git commit: SINGA-192 Implement optimization algorithms for v1

SINGA-192 Implement optimization algorithms for v1

Merge branch PR#164 into dev

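Fix the bugs in the Adagrad and RMSProp tests.
Note: EXPECT_NEAR (with a tolerance of 1e-5) is used instead of exact comparison
to avoid spurious failures caused by floating-point rounding differences. The
tests still need to be run on more machines.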


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/5784bff3
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/5784bff3
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/5784bff3

Branch: refs/heads/master
Commit: 5784bff3e5ebfb3a992624d10f03f30cd5e520a3
Parents: 6d69047 178db01
Author: Wei Wang <wa...@comp.nus.edu.sg>
Authored: Sun Jun 12 15:43:53 2016 +0800
Committer: Wei Wang <wa...@comp.nus.edu.sg>
Committed: Sun Jun 12 18:03:12 2016 +0800

----------------------------------------------------------------------
 include/singa/model/optimizer.h |  43 ++++++++++++++
 src/core/tensor/math_kernel.cu  |  14 ++---
 src/core/tensor/math_kernel.h   |   2 +-
 src/core/tensor/tensor.cc       |   3 +-
 src/model/optimizer/adagrad.cc  |  36 ++++++++++++
 src/model/optimizer/nesterov.cc |  43 ++++++++++++++
 src/model/optimizer/rmsprop.cc  |  41 ++++++++++++++
 src/proto/model.proto           |   3 +
 test/singa/test_adagrad.cc      |  96 +++++++++++++++++++++++++++++++
 test/singa/test_nesterov.cc     | 101 +++++++++++++++++++++++++++++++++
 test/singa/test_rmsprop.cc      | 106 +++++++++++++++++++++++++++++++++++
 11 files changed, 478 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5784bff3/src/core/tensor/math_kernel.cu
----------------------------------------------------------------------
diff --cc src/core/tensor/math_kernel.cu
index b618f9b,aed6add..484868a
--- a/src/core/tensor/math_kernel.cu
+++ b/src/core/tensor/math_kernel.cu
@@@ -236,192 -300,151 +236,192 @@@ __global__ void KernelThreshold(const s
    }
  }
  
- __global__ void KernelGE(const int num, const float *in, const float x,
 -__global__ void kernel_div(const float *src_data_a, const float *src_data_b,
 -                           float *des_data, int n) {
 -  int index = blockIdx.x * blockDim.x + threadIdx.x;
 -  int num_threads = blockDim.x * gridDim.x;
 -  for (; index < n; index += num_threads) {
 -    des_data[index] = src_data_a[index] / src_data_b[index];
++__global__ void KernelGE(const size_t num, const float *in, const float x,
 +                         float *out) {
 +  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
 +       idx += blockDim.x * gridDim.x) {
 +    out[idx] = in[idx] >= x ? 1.0f : 0.0f;
    }
  }
- __global__ void KernelGT(const int num, const float *in, const float x,
 -
 -__global__ static void kernel_set_value(float *data, float value, int n) {
 -  int index = blockIdx.x * blockDim.x + threadIdx.x;
 -  int num_threads = blockDim.x * gridDim.x;
 -  for (; index < n; index += num_threads) {
 -    data[index] = value;
++__global__ void KernelGT(const size_t num, const float *in, const float x,
 +                         float *out) {
 +  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
 +       idx += blockDim.x * gridDim.x) {
 +    out[idx] = in[idx] > x ? 1.0f : 0.0f;
    }
  }
- __global__ void KernelLE(const int num, const float *in, const float x,
 -
 -__global__ void kernel_threshold(const float *src_data, float *des_data,
 -                                 float alpha, int n) {
 -  int index = blockIdx.x * blockDim.x + threadIdx.x;
 -  int num_threads = blockDim.x * gridDim.x;
 -  for (; index < n; index += num_threads) {
 -    des_data[index] = src_data[index] < alpha ? 1.0f : 0.0f;
++__global__ void KernelLE(const size_t num, const float *in, const float x,
 +                         float *out) {
 +  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
 +       idx += blockDim.x * gridDim.x) {
 +    out[idx] = in[idx] <= x ? 1.0f : 0.0f;
    }
  }
 -void sum(int n, const float *in, float *out) {
 -  int threads_per_block = n > CU1DBLOCK ? CU1DBLOCK : n;
 -  //  here, we only need one block
 -  int num_blocks = 1;
  
- __global__ void KernelLT(const int num, const float *in, const float x,
 -  kernel_sum_vec << <num_blocks, threads_per_block>>> (in, out, n);
++__global__ void KernelLT(const size_t num, const float *in, const float x,
 +                         float *out) {
 +  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
 +       idx += blockDim.x * gridDim.x) {
 +    out[idx] = in[idx] < x ? 1.0f : 0.0f;
 +  }
  }
  
 -void sum_row(int rows, int cols, int stride, const float *in, float *out) {
 -  int threads_per_block = rows > CU1DBLOCK ? CU1DBLOCK : rows;
 -  int num_blocks = cols;
 +// ********************************
 +// Functions call kernels
 +// ********************************
  
 -  kernel_sum_row << <num_blocks, threads_per_block>>>
 -      (in, out, rows, cols, stride);
 +void set(const size_t n, const float v, float *out, cudaStream_t s) {
 +  KernelSet <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, v, out);
  }
  
 -void sum_col(int rows, int cols, int stride, const float *in, float *out) {
 -  int threads_per_block = cols > CU1DBLOCK ? CU1DBLOCK : cols;
 -  int num_blocks = rows;
 +void abs(const size_t n, const float *in, float *out, cudaStream_t s) {
 +  KernelAbs <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
 +}
  
 -  kernel_sum_col << <num_blocks, threads_per_block>>>
 -      (in, out, rows, cols, stride);
 +void sign(const size_t n, const float *in, float *out, cudaStream_t s) {
 +  KernelSign <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
  }
 -void add_row(int rows, int cols, int stride, const float *in_row,
 -             const float *in_mat, float *out) {
 -  dim3 threads_per_block(CU2DBLOCK_X, CU2DBLOCK_Y);
 -  dim3 num_blocks(
 -      cols / threads_per_block.x + (cols % threads_per_block.x == 0 ? 0 : 1),
 -      rows / threads_per_block.y + (rows % threads_per_block.y == 0 ? 0 : 1));
 -  kernel_add_vec_row << <num_blocks, threads_per_block>>>
 -      (in_row, in_mat, out, rows, cols, stride);
 +
 +void exp(const size_t n, const float *in, float *out, cudaStream_t s) {
 +  KernelExp <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
  }
 -void add(int n, const float *a, const float *b, float *out) {
 -  kernel_add << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, b, out, n);
 +
 +void log(const size_t n, const float *in, float *out, cudaStream_t s) {
 +  KernelLog <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
  }
 -void sub(int n, const float *a, const float *b, float *out) {
 -  kernel_sub << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, b, out, n);
 +
 +void sqrt(const size_t n, const float *in, float *out, cudaStream_t s) {
 +  KernelSqrt <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
  }
 -void exp(int n, const float *in, float *out) {
 -  kernel_exp << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 +
 +void square(const size_t n, const float *in, float *out, cudaStream_t s) {
 +  KernelSquare <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
  }
  
 -void log(int n, const float *in, float *out) {
 -  kernel_log << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 +void tanh(const size_t n, const float *in, float *out, cudaStream_t s) {
 +  KernelTanh <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
  }
  
 -void sigmoid(int n, const float *in, float *out) {
 -  kernel_sigmoid << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 +void relu(const size_t n, const float *in, float *out, cudaStream_t s) {
 +  KernelRelu <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
 +}
- void sigmoid(const int n, const float *in, float *out, cudaStream_t s) {
++void sigmoid(const size_t n, const float *in, float *out, cudaStream_t s) {
 +  KernelSigmoid <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
 +}
 +void softplus(const size_t n, const float *in, float *out, cudaStream_t s) {
 +  KernelSoftplus <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
 +}
 +void clamp(const size_t n, const float low, const float high, const float *in,
 +           float *out, cudaStream_t s) {
 +  KernelClamp <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, low, high, in, out);
  }
  
 -void sigmoid_grad(int n, const float *in, float *out) {
 -  kernel_sigmoid_grad << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 +void pow(const size_t n, const float *in, const float x, float *out,
 +         cudaStream_t s) {
 +  KernelPow <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, x, out);
  }
  
 -void relu(int n, const float *in, float *out) {
 -  kernel_relu << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 +void add(const size_t n, const float *in, const float x, float *out,
 +         cudaStream_t s) {
 +  KernelAdd <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, x, out);
  }
  
 -void relu_grad(int n, const float *in, float *out) {
 -  kernel_relu_grad << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 +void mult(const size_t n, const float *in, const float x, float *out,
 +          cudaStream_t s) {
 +  KernelMult <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, x, out);
  }
  
 -void tanh(int n, const float *in, float *out) {
 -  kernel_tanh << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 +void div(const size_t n, const float x, const float *in, float *out,
 +          cudaStream_t s) {
 +  KernelDiv <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, x, in, out);
  }
  
 -void tanh_grad(int n, const float *in, float *out) {
 -  kernel_tanh_grad << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 +void threshold(const size_t n, const float x, const float *in, float *out,
 +               cudaStream_t s) {
 +  KernelThreshold <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, x, in, out);
  }
  
 -void softplus(int n, const float *in, float *out) {
 -  kernel_softplus << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 +void gt(const size_t num, const float *in, const float x, float *out,
 +        cudaStream_t s) {
 +  KernelGT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
 +}
 +void ge(const size_t num, const float *in, const float x, float *out,
 +        cudaStream_t s) {
 +  KernelGE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
 +}
 +void lt(const size_t num, const float *in, const float x, float *out,
 +        cudaStream_t s) {
 +  KernelLT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
 +}
 +void le(const size_t num, const float *in, const float x, float *out,
 +        cudaStream_t s) {
 +  KernelLE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
  }
  
 -void softplus_grad(int n, const float *in, float *out) {
 -  kernel_softplus_grad << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 +void pow(const size_t n, const float *in1, const float *in2, float *out,
 +         cudaStream_t s) {
 +  KernelPow <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
  }
  
 -void square(int n, const float *in, float *out) {
 -  kernel_square << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 +void add(const size_t n, const float *in1, const float *in2, float *out,
 +         cudaStream_t s) {
 +  KernelAdd <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
  }
  
 -void square_grad(int n, const float *in, float *out) {
 -  kernel_square_grad << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 +void sub(const size_t n, const float *in1, const float *in2, float *out,
 +         cudaStream_t s) {
 +  KernelSub <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
  }
  
 -void sqrt(int n, const float *in, float *out) {
 -  kernel_sqrt << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 +void mult(const size_t n, const float *in1, const float *in2, float *out,
 +          cudaStream_t s) {
 +  KernelMult <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
  }
  
 -void pow(int n, const float *a, const float *b, float *out) {
 -  kernel_pow << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, b, out, n);
 +void div(const size_t n, const float *in1, const float *in2, float *out,
 +         cudaStream_t s) {
 +  KernelDiv <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
  }
  
 -void mult(int n, const float *a, const float *b, float *out) {
 -  kernel_mult << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, b, out, n);
 +void sum(const size_t n, const float *in, float *out, cudaStream_t s) {
 +  int threads_per_block = n > CU1DBLOCK ? CU1DBLOCK : n;
 +  //  here, we only need one block
 +  int num_blocks = 1;
 +  KernelSum <<<num_blocks, threads_per_block>>> (n, in, out);
 +}
 +/*
 +void square_grad(int n, const float *in, float *out, cudaStream_t s) {
 +  kernel_square_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
  }
  
 -void mult(int n, const float *a, const float x, float *out) {
 -  kernel_mult << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, x, out, n);
 +void tanh_grad(int n, const float *in, float *out, cudaStream_t s) {
 +  kernel_tanh_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
  }
  
 -void div(int n, const float *a, const float *b, float *out) {
 -  kernel_div << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, b, out, n);
 +
 +void relu_grad(int n, const float *in, float *out, cudaStream_t s) {
 +  kernel_relu_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
  }
  
 -void set_value(int n, float v, float *out) {
 -  kernel_set_value << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (out, v, n);
 +
 +void sigmoid_grad(int n, const float *in, float *out, cudaStream_t s) {
 +  kernel_sigmoid_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
  }
  
 -void threshold(int n, float alpha, const float *in, float *out) {
 -  kernel_threshold << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, alpha, n);
 +void softplus_grad(int n, const float *in, float *out, cudaStream_t s) {
 +  kernel_softplus_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
  }
  
 -// follow the consistency guide for math API
 -__global__ void KernelDiv(const size_t num, const float alpha, const float *in,
 -                          float *out) {
 -  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
 -       idx += blockDim.x * gridDim.x) {
 -    out[idx] = alpha / in[idx];
 +
 +__global__ void kernel_sum_col(const float *src_mat_data, float *dst_vec_data,
 +                               int rows, int cols, int stride) {
 +  int index = blockIdx.x * blockDim.x + threadIdx.x;
 +  int num_threads = blockDim.x * gridDim.x;
 +  for (; index < rows; index += num_threads) {
 +    dst_vec_data[index] = 0.0f;
 +    for (int k = 0; k < cols; k++) {
 +      dst_vec_data[index] += src_mat_data[index * stride + k];
 +    }
    }
  }
  
@@@ -485,62 -485,30 +485,62 @@@ __global__ void kernel_sigmoid_grad(con
    }
  }
  
 -void Set(const size_t num, const float x, float *out, cudaStream_t s) {
 -  KernelSet << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, x, out);
 +
 +__global__ void kernel_relu_grad(const float *src_data, float *des_data,
 +                                 int n) {
 +  int index = blockIdx.x * blockDim.x + threadIdx.x;
 +  int num_threads = blockDim.x * gridDim.x;
 +  for (; index < n; index += num_threads) {
 +    des_data[index] = src_data[index] > 0.0f ? 1.0f : 0.0f;
 +  }
  }
 -void Div(const size_t num, float alpha, const float *in, float *out,
 -         cudaStream_t s) {
 -  KernelDiv << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, alpha, in, out);
 +
 +__global__ void kernel_tanh_grad(const float *src_data, float *des_data,
 +                                 int n) {
 +  int index = blockIdx.x * blockDim.x + threadIdx.x;
 +  int num_threads = blockDim.x * gridDim.x;
 +  for (; index < n; index += num_threads) {
 +    des_data[index] = (1.0f - src_data[index] * src_data[index]);
 +  }
  }
  
 -void GT(const size_t num, const float *in, const float x, float *out,
 -        cudaStream_t s) {
 -  KernelGT << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
 +
 +__global__ void kernel_softplus_grad(const float *src_data, float *des_data,
 +                                     int n) {
 +  int index = blockIdx.x * blockDim.x + threadIdx.x;
 +  int num_threads = blockDim.x * gridDim.x;
 +  for (; index < n; index += num_threads) {
 +    des_data[index] = 1.0f / (1.0f + expf(-src_data[index]));
 +  }
  }
 -void GE(const size_t num, const float *in, const float x, float *out,
 -        cudaStream_t s) {
 -  KernelGE << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
 +__global__ void KernelSquareGrad(const float *src_data, float *des_data,
 +                                   int n) {
 +  int index = blockIdx.x * blockDim.x + threadIdx.x;
 +  int num_threads = blockDim.x * gridDim.x;
 +  for (; index < n; index += num_threads) {
 +    des_data[index] = 2 * src_data[index];
 +  }
  }
- __global__ void kernel_softmax_loss(const float *prob, const int *label,
 -void LT(const size_t num, const float *in, const float x, float *out,
 -        cudaStream_t s) {
 -  KernelLT << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
++__global__ void kernel_softmax_loss(const float *prob, const size_t *label,
 +                                    float *loss, int n, int dim) {
 +  int index = blockIdx.x * blockDim.x + threadIdx.x;
 +  int num_threads = blockDim.x * gridDim.x;
 +  for (; index < n; index += num_threads) {
 +    float prob_of_truth = prob[index * dim + label[index]];
 +    loss[index] -= std::log(max(prob_of_truth, FLT_MIN));
 +  }
  }
- __global__ void kernel_softmax_gradient(float *grad, const int *label, int n,
 -void LE(const size_t num, const float *in, const float x, float *out,
 -        cudaStream_t s) {
 -  KernelLE << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
++__global__ void kernel_softmax_gradient(float *grad, const size_t *label, int n,
 +                                        int dim, float scale) {
 +  int index = blockIdx.x * blockDim.x + threadIdx.x;
 +  int num_threads = blockDim.x * gridDim.x;
 +  for (; index < n; index += num_threads) {
 +    int pos = index * dim + label[index];
 +    grad[pos] = (grad[pos] - 1.0f) * scale;
 +  }
  }
 +*/
 +
  
  }  // namespace cuda
  }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5784bff3/src/core/tensor/math_kernel.h
----------------------------------------------------------------------
diff --cc src/core/tensor/math_kernel.h
index d8a58a5,5c906a9..444f6ca
--- a/src/core/tensor/math_kernel.h
+++ b/src/core/tensor/math_kernel.h
@@@ -31,66 -31,65 +31,66 @@@ namespace singa 
  
  // TODO(wangwei) make all function templates.
  namespace cuda {
 -void sum(int n, const float *in, float *out);
  
 -void sum_row(int rows, int cols, int stride, const float *in, float *out);
 -
 -void sum_col(int rows, int cols, int stride, const float *in, float *out);
 -
 -void add_row(int rows, int cols, int stride, const float *in_row,
 -             const float *in_mat, float *out);
 -
 -void add(int n, const float *a, const float *b, float *out);
 -
 -void sub(int n, const float *a, const float *b, float *out);
 -
 -void exp(int n, const float *in, float *out);
 -
 -void log(int n, const float *in, float *out);
 -
 -void sigmoid(int n, const float *in, float *out);
 -
 -void sigmoid_grad(int n, const float *in, float *out);
 -
 -void relu(int n, const float *in, float *out);
 -
 -void relu_grad(int n, const float *in, float *out);
 -
 -void tanh(int n, const float *in, float *out);
 -
 -void tanh_grad(int n, const float *in, float *out);
 +// 0 input
 +void set(const size_t n, const float v, float *out, cudaStream_t s);
 +
 +// 1 input
 +void abs(const size_t n, const float *in, float *out, cudaStream_t s);
 +void sign(const size_t n, const float *in, float *out, cudaStream_t s);
 +void exp(const size_t n, const float *in, float *out, cudaStream_t s);
 +void log(const size_t n, const float *in, float *out, cudaStream_t s);
 +void sqrt(const size_t n, const float *in, float *out, cudaStream_t s);
 +void square(const size_t n, const float *in, float *out, cudaStream_t s);
 +void tanh(const size_t n, const float *in, float *out, cudaStream_t s);
 +void relu(const size_t n, const float *in, float *out, cudaStream_t s);
- void sigmoid(const int n, const float *in, float *out, cudaStream_t s);
++void sigmoid(const size_t n, const float *in, float *out, cudaStream_t s);
 +void softplus(const size_t n, const float *in, float *out, cudaStream_t s);
 +void clamp(const size_t n, const float low, const float high, const float *in,
 +           float *out, cudaStream_t s);
 +
 +void pow(const size_t n, const float *in, const float x, float *out,
 +         cudaStream_t s);
  
 -void softplus(int n, const float *in, float *out);
 +void add(const size_t n, const float *in, const float x, float *out,
 +         cudaStream_t s);
  
 -void softplus_grad(int n, const float *in, float *out);
 +void mult(const size_t n, const float *in, const float x, float *out,
 +          cudaStream_t s);
  
 -void square(int n, const float *in, float *out);
 +void div(const size_t n, const float x, const float *in, float *out,
 +         cudaStream_t s);
  
 -void square_grad(int n, const float *in, float *out);
 +void threshold(const size_t n, const float x, const float *in, float *out,
 +               cudaStream_t s);
  
 -void sqrt(int n, const float *in, float *out);
 +void gt(const size_t num, const float *in, const float x, float *out,
 +        cudaStream_t s);
 +void ge(const size_t num, const float *in, const float x, float *out,
 +        cudaStream_t s);
 +void lt(const size_t num, const float *in, const float x, float *out,
 +        cudaStream_t s);
 +void le(const size_t num, const float *in, const float x, float *out,
 +        cudaStream_t s);
  
 -void pow(int n, const float *a, const float *b, float *out);
 +// 2 inputs
 +void pow(const size_t n, const float *in1, const float *in2, float *out,
 +         cudaStream_t s);
  
 -void mult(int n, const float *a, const float *b, float *out);
 +void add(const size_t n, const float *in1, const float *in2, float *out,
 +         cudaStream_t s);
  
 -void mult(int n, const float *a, const float x, float *out);
 +void sub(const size_t n, const float *in1, const float *in2, float *out,
 +         cudaStream_t s);
  
 -void div(int n, const float *a, const float *b, float *out);
 +void mult(const size_t n, const float *in1, const float *in2, float *out,
 +          cudaStream_t s);
  
 -void set_value(int n, float v, float *out);
 +void div(const size_t n, const float *in1, const float *in2, float *out,
 +         cudaStream_t s);
  
 -void threshold(int n, float alpha, const float *in, float *out);
 +void sum(const size_t n, const float *in, float *out, cudaStream_t s);
  
 -// follow the consistency guide for math API
 -void Div(const size_t num, const float x, const float *in, float *out,
 -         cudaStream_t s);
 -void Set(const size_t num, const float x, float *out, cudaStream_t s);
 -void GT(size_t num, const float *in, const float x, float *out, cudaStream_t s);
 -void GE(size_t num, const float *in, const float x, float *out, cudaStream_t s);
 -void LT(size_t num, const float *in, const float x, float *out, cudaStream_t s);
 -void LE(size_t num, const float *in, const float x, float *out, cudaStream_t s);
  }  // cuda
  
  }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5784bff3/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --cc src/core/tensor/tensor.cc
index e62386a,5ae375c..e6917d8
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@@ -639,92 -701,4 +639,91 @@@ void SumRows(const Tensor &M, Tensor *v
      Mult(X, one, v);
    }
  }
 +// ====================Random operations=====================================
 +template <typename SType>
 +void Bernoulli(const SType p, Tensor *out) {
 +  TYPE_LANG_SWITCH(out->data_type(), DType, out->device()->lang(), Lang, {
 +    auto prob = TypeCast<SType, DType>(p);
 +    out->device()->Exec([prob, out](Context *ctx) {
 +      Bernoulli<DType, Lang>(out->Size(), prob, out->blob(), ctx);
 +    }, {}, {out->blob()}, true);
 +  });
 +}
 +template void Bernoulli<float>(const float p, Tensor *out);
 +
 +template <typename SType>
 +void Uniform(const SType low, const SType high, Tensor *out) {
 +  TYPE_LANG_SWITCH(out->data_type(), DType, out->device()->lang(), Lang, {
 +    auto l = TypeCast<SType, DType>(low);
 +    auto h = TypeCast<SType, DType>(high);
 +    out->device()->Exec([l, h, out](Context *ctx) {
 +      Uniform<DType, Lang>(out->Size(), l, h, out->blob(), ctx);
 +    }, {}, {out->blob()}, true);
 +  });
 +}
 +template void Uniform<float>(const float low, const float high, Tensor *out);
 +
 +template <typename SType>
 +void Gaussian(const SType mean, const SType std, Tensor *out) {
 +  TYPE_LANG_SWITCH(out->data_type(), DType, out->device()->lang(), Lang, {
 +    auto m = TypeCast<SType, DType>(mean);
 +    auto s = TypeCast<SType, DType>(std);
 +    out->device()->Exec([m, s, out](Context *ctx) {
 +      Gaussian<DType, Lang>(out->Size(), m, s, out->blob(), ctx);
 +    }, {}, {out->blob()}, true);
 +  });
 +}
 +template void Gaussian<float>(const float mean, const float std, Tensor *out);
 +
 +// ================Blas operations============================================
 +template <typename SType>
 +void Axpy(const SType alpha, const Tensor &in, Tensor *out) {
 +  TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, {
 +    auto a = TypeCast<SType, DType>(alpha);
 +    out->device()->Exec([a, in, out](Context *ctx) {
 +      Axpy<DType, Lang>(in.Size(), a, in.blob(), out->blob(), ctx);
 +    }, {in.blob(), out->blob()}, {out->blob()});
 +  });
 +}
- template <>
- void Axpy(const float alpha, const Tensor &in, Tensor *out);
++template void Axpy(const float alpha, const Tensor &in, Tensor *out);
 +
 +Tensor Mult(const Tensor &A, const Tensor &B) {
 +  Shape s;
 +  s.push_back(A.shape(0));
 +  if (B.nDim() == 2) s.push_back(B.shape(1));
 +  Tensor out(s, A.device(), A.data_type());
 +  Mult(A, B, &out);
 +  return out;
 +}
 +
 +void Mult(const Tensor &A, const Tensor &B, Tensor *out) {
 +  Mult(1.0f, A, B, 0.0f, out);
 +}
 +
 +template <typename SType>
 +void Mult(const SType alpha, const Tensor &A, const Tensor &B, const SType beta,
 +          Tensor *C) {
 +  CHECK_EQ(A.shape().size(), 2u);
 +  if (B.nDim() == 1u) {
 +    TYPE_LANG_SWITCH(A.data_type(), DType, A.device()->lang(), Lang, {
 +      auto a = TypeCast<SType, DType>(alpha);
 +      auto b = TypeCast<SType, DType>(beta);
 +      C->device()->Exec([a, A, b, B, C](Context *ctx) {
 +        GEMV<DType, Lang>(A.transpose(), A.shape(0), A.shape(1), a, A.blob(),
 +                          B.blob(), b, C->blob(), ctx);
 +      }, {A.blob(), B.blob()}, {C->blob()});
 +    });
 +  } else {
 +    CHECK(!C->transpose());
 +    TYPE_LANG_SWITCH(A.data_type(), DType, A.device()->lang(), Lang, {
 +      auto a = TypeCast<SType, DType>(alpha);
 +      auto b = TypeCast<SType, DType>(beta);
 +      C->device()->Exec([a, A, b, B, C](Context *ctx) {
 +        GEMM<DType, Lang>(A.transpose(), B.transpose(), A.shape(0), B.shape(1),
 +                          A.shape(1), a, A.blob(), B.blob(), b, C->blob(), ctx);
 +      }, {A.blob(), B.blob()}, {C->blob()});
 +    });
 +  }
 +}
 +
  }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5784bff3/src/model/optimizer/adagrad.cc
----------------------------------------------------------------------
diff --cc src/model/optimizer/adagrad.cc
index 0000000,8bdb07c..0b8ec88
mode 000000,100644..100644
--- a/src/model/optimizer/adagrad.cc
+++ b/src/model/optimizer/adagrad.cc
@@@ -1,0 -1,35 +1,36 @@@
+ /**
+  * Licensed to the Apache Software Foundation (ASF) under one
+  * or more contributor license agreements.  See the NOTICE file
+  * distributed with this work for additional information
+  * regarding copyright ownership.  The ASF licenses this file
+  * to you under the Apache License, Version 2.0 (the
+  * "License"); you may not use this file except in compliance
+  * with the License.  You may obtain a copy of the License at
+  *
+  *     http://www.apache.org/licenses/LICENSE-2.0
+  *
+  * Unless required by applicable law or agreed to in writing, software
+  * distributed under the License is distributed on an "AS IS" BASIS,
+  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  * See the License for the specific language governing permissions and
+  * limitations under the License.
+  */
+ #ifndef SRC_MODEL_OPTIMIZER_ADAGRAD_H_
+ #define SRC_MODEL_OPTIMIZER_ADAGRAD_H_
+ #include "singa/model/optimizer.h"
+ #include <functional>
+ namespace singa {
+ 
+ void Adagrad::Setup(const OptimizerConf& conf) { delta_ = conf.delta(); }
+ 
+ void Adagrad::Apply(int step, float lr, const string& name, Tensor* grad,
+                     Tensor* value) {
+   if (history_gradient_.find(name) == history_gradient_.end())
+     history_gradient_[name].ResetLike(*value);
+   Tensor& history = history_gradient_[name];
 -  history += (*grad) * (*grad);
 -  (*value) -= (*grad) * lr / Sqrt(history + delta_);
++  history += Square(*grad);
++  (*grad) /= Sqrt(history + delta_);
++  Axpy(-lr, *grad, value);
+ }
+ }  // namespace singa
+ #endif  // SRC_MODEL_OPTIMIZER_ADAGRAD_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5784bff3/src/model/optimizer/rmsprop.cc
----------------------------------------------------------------------
diff --cc src/model/optimizer/rmsprop.cc
index 0000000,cad333c..7b9934c
mode 000000,100644..100644
--- a/src/model/optimizer/rmsprop.cc
+++ b/src/model/optimizer/rmsprop.cc
@@@ -1,0 -1,38 +1,41 @@@
+ /**
+  * Licensed to the Apache Software Foundation (ASF) under one
+  * or more contributor license agreements.  See the NOTICE file
+  * distributed with this work for additional information
+  * regarding copyright ownership.  The ASF licenses this file
+  * to you under the Apache License, Version 2.0 (the
+  * "License"); you may not use this file except in compliance
+  * with the License.  You may obtain a copy of the License at
+  *
+  *     http://www.apache.org/licenses/LICENSE-2.0
+  *
+  * Unless required by applicable law or agreed to in writing, software
+  * distributed under the License is distributed on an "AS IS" BASIS,
+  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  * See the License for the specific language governing permissions and
+  * limitations under the License.
+  */
+ #ifndef SRC_MODEL_OPTIMIZER_ADAGRAD_H_
+ #define SRC_MODEL_OPTIMIZER_ADAGRAD_H_
+ #include "singa/model/optimizer.h"
+ #include <functional>
+ namespace singa {
+ 
+ void RMSProp::Setup(const OptimizerConf& conf) {
+   delta_ = conf.delta();
 -  rho_ = conf.delta();
++  rho_ = conf.rho();
+ }
+ 
+ void RMSProp::Apply(int step, float lr, const string& name, Tensor* grad,
+                     Tensor* value) {
 -  if (history_gradient_.find(name) == history_gradient_.end())
++  if (history_gradient_.find(name) == history_gradient_.end()) {
+     history_gradient_[name].ResetLike(*value);
++  }
+   Tensor& history = history_gradient_[name];
 -  history = history * rho_ + (*grad) * (*grad) * (1 - rho_);
 -  (*value) -= (*grad) * lr / Sqrt(history + delta_);
++  history *= rho_;
++  Axpy(1 - rho_, Square(*grad), &history);
++  (*grad) /= Sqrt(history + delta_);
++  Axpy(-lr, *grad, value);
+ }
+ }  // namespace singa
+ #endif  // SRC_MODEL_OPTIMIZER_ADAGRAD_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5784bff3/src/proto/model.proto
----------------------------------------------------------------------
diff --cc src/proto/model.proto
index d368296,c26aa35..ca6f0cd
--- a/src/proto/model.proto
+++ b/src/proto/model.proto
@@@ -86,6 -86,9 +86,9 @@@ message OptimizerConf 
  
    // used by vanilla sgd and nesterov
    optional float momentum = 5 [default = 0.9];
+ 
 -  // delta is used to avoid dividing zero
++  // delta is a small constant used to avoid dividing by zero
 -  optional float delta = 6 [default = 0.0000001];
++  optional float delta = 6 [default = 1e-8];
  }
  
  message ConstraintConf {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5784bff3/test/singa/test_adagrad.cc
----------------------------------------------------------------------
diff --cc test/singa/test_adagrad.cc
index 0000000,1382467..80240b1
mode 000000,100644..100644
--- a/test/singa/test_adagrad.cc
+++ b/test/singa/test_adagrad.cc
@@@ -1,0 -1,92 +1,102 @@@
+ /************************************************************
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *************************************************************/
+ 
+ #include "gtest/gtest.h"
+ #include "singa/model/optimizer.h"
+ #include "singa_config.h"
+ #include <cmath>
+ 
++// Reference model checked by these tests (history starts at 0):
++//   history += g^2
++//   value   -= lr * g / sqrt(history + delta)
++// Results are compared with a 1e-5 tolerance, not exact float equality.
+ TEST(Adagrad, ApplyCPU) {
+   singa::Adagrad adagrad;
+   float lr = 0.1f;
+   const float v[4] = {0.1, 0.2, 0.3, 0.4};
+   const float g[4] = {0.01, 0.02, 0.03, 0.04};
+ 
+   singa::Tensor value(singa::Shape{4}), grad(singa::Shape{4});
+   value.CopyDataFromHostPtr(v, 4);
+   grad.CopyDataFromHostPtr(g, 4);
+ 
++  singa::OptimizerConf conf;  // uses the proto's default delta
++  adagrad.Setup(conf);
+   adagrad.Apply(0, lr, "xx", &grad, &value);
+ 
+   singa::Tensor v1 = value.Clone();
+   const float* newv1 = v1.data<const float*>();
+   float history[4];
+   for (int i = 0; i < 4; ++i) history[i] = g[i] * g[i];
+   for (int i = 0; i < 4; ++i)
 -    EXPECT_FLOAT_EQ(newv1[i],
 -                    v[i] - lr * g[i] / sqrt(history[i] + (float)1E-8));
++    EXPECT_NEAR(newv1[i], v[i] - lr * g[i] / sqrt(history[i] + conf.delta()),
++                1e-5);
+ 
++  // Reload grad in case Apply modified it in place.
+   grad.CopyDataFromHostPtr(g, 4);
+   adagrad.Apply(1, lr, "xx", &grad, &value);
+   singa::Tensor v2 = value.Clone();
+   const float* newv2 = v2.data<const float*>();
+   for (int i = 0; i < 4; ++i) history[i] += g[i] * g[i];
+ 
+   for (int i = 0; i < 4; ++i)
 -    EXPECT_FLOAT_EQ(newv2[i],
 -                    newv1[i] - lr * g[i] / sqrt(history[i] + (float)1E-8));
++    EXPECT_NEAR(newv2[i],
++                newv1[i] - lr * g[i] / sqrt(history[i] + conf.delta()), 1e-5);
+ }
+ 
+ #ifdef USE_CUDA
+ TEST(Adagrad, ApplyCUDA) {
+   singa::Adagrad adagrad;
+   float lr = 0.1f;
+   const float v[4] = {0.1, 0.2, 0.3, 0.4};
+   const float g[4] = {0.01, 0.02, 0.03, 0.04};
+ 
+   singa::CudaGPU dev;
+   singa::Tensor value(singa::Shape{4}, &dev), grad(singa::Shape{4}, &dev);
+   value.CopyDataFromHostPtr(v, 4);
+   grad.CopyDataFromHostPtr(g, 4);
+ 
++  singa::OptimizerConf conf;  // uses the proto's default delta
++  adagrad.Setup(conf);
+   adagrad.Apply(0, lr, "xx", &grad, &value);
+ 
+   singa::Tensor v1 = value.Clone();
+   v1.ToHost();
+   const float* newv1 = v1.data<const float*>();
+   float history[4];
+   for (int i = 0; i < 4; ++i) history[i] = g[i] * g[i];
+   for (int i = 0; i < 4; ++i)
 -    EXPECT_FLOAT_EQ(newv1[i],
 -                    v[i] - lr * g[i] / sqrt(history[i] + (float)1E-8));
++    EXPECT_NEAR(newv1[i], v[i] - lr * g[i] / sqrt(history[i] + conf.delta()),
++                1e-5);
+ 
++  // Reload grad in case Apply modified it in place.
+   grad.CopyDataFromHostPtr(g, 4);
+   adagrad.Apply(1, lr, "xx", &grad, &value);
+   singa::Tensor v2 = value.Clone();
+   v2.ToHost();
+   const float* newv2 = v2.data<const float*>();
+   for (int i = 0; i < 4; ++i) history[i] += g[i] * g[i];
+ 
+   for (int i = 0; i < 4; ++i)
 -    EXPECT_FLOAT_EQ(newv2[i],
 -                    newv1[i] - lr * g[i] / sqrt(history[i] + (float)1E-8));
++    EXPECT_NEAR(newv2[i],
++                newv1[i] - lr * g[i] / sqrt(history[i] + conf.delta()), 1e-5);
+ }
+ #endif

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5784bff3/test/singa/test_rmsprop.cc
----------------------------------------------------------------------
diff --cc test/singa/test_rmsprop.cc
index 0000000,62101f7..8104f50
mode 000000,100644..100644
--- a/test/singa/test_rmsprop.cc
+++ b/test/singa/test_rmsprop.cc
@@@ -1,0 -1,103 +1,112 @@@
+ /************************************************************
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *************************************************************/
+ 
+ #include "gtest/gtest.h"
+ #include "singa/model/optimizer.h"
+ #include "singa_config.h"
+ #include <cmath>
+ 
++// Reference model checked by these tests (history starts at 0):
++//   history = rho * history + (1 - rho) * g^2
++//   value  -= lr * g / sqrt(history + delta)
++// Results are compared with a 1e-5 tolerance, not exact float equality.
+ TEST(RMSProp, ApplyCPU) {
+   singa::RMSProp rmsprop;
+   float lr = 0.1f;
 -  float rho = 0.002f;
++  float rho = 0.9f;
+   const float v[4] = {0.1, 0.2, 0.3, 0.4};
+   const float g[4] = {0.01, 0.02, 0.03, 0.04};
+ 
+   singa::OptimizerConf conf;
+   conf.set_rho(rho);
++  conf.set_delta(1e-8);
+ 
+   singa::Tensor value(singa::Shape{4}), grad(singa::Shape{4});
+   value.CopyDataFromHostPtr(v, 4);
+   grad.CopyDataFromHostPtr(g, 4);
+ 
+   rmsprop.Setup(conf);
+   rmsprop.Apply(0, lr, "xx", &grad, &value);
+ 
+   singa::Tensor v1 = value.Clone();
+   const float* newv1 = v1.data<const float*>();
+   float history[4];
+   for (int i = 0; i < 4; ++i) history[i] = g[i] * g[i] * (1 - rho);
+   for (int i = 0; i < 4; ++i)
 -    EXPECT_FLOAT_EQ(newv1[i],
 -                    v[i] - lr * g[i] / sqrt(history[i] + (float)1E-8));
++    EXPECT_NEAR(newv1[i], v[i] - lr * g[i] / sqrt(history[i] + conf.delta()),
++                1e-5);
+ 
++  // Reload grad: RMSProp::Apply overwrites it in place.
+   grad.CopyDataFromHostPtr(g, 4);
+   rmsprop.Apply(1, lr, "xx", &grad, &value);
+   singa::Tensor v2 = value.Clone();
+   const float* newv2 = v2.data<const float*>();
+   for (int i = 0; i < 4; ++i)
 -    history[i] += history[i] * rho + g[i] * g[i] * (1 - rho);
++    history[i] = history[i] * rho + g[i] * g[i] * (1 - rho);
+ 
+   for (int i = 0; i < 4; ++i)
 -    EXPECT_FLOAT_EQ(newv2[i],
 -                    newv1[i] - lr * g[i] / sqrt(history[i] + (float)1E-8));
++    EXPECT_NEAR(newv2[i],
++                newv1[i] - lr * g[i] / sqrt(history[i] + conf.delta()), 1e-5);
+ }
+ 
+ #ifdef USE_CUDA
+ TEST(RMSProp, ApplyCUDA) {
+   singa::RMSProp rmsprop;
+   float lr = 0.1f;
 -  float rho = 0.002f;
++  float rho = 0.02f;
+   const float v[4] = {0.1, 0.2, 0.3, 0.4};
+   const float g[4] = {0.01, 0.02, 0.03, 0.04};
+ 
+   singa::OptimizerConf conf;
+   conf.set_rho(rho);
++  conf.set_delta(1e-8);
+ 
+   singa::CudaGPU dev;
+   singa::Tensor value(singa::Shape{4}, &dev), grad(singa::Shape{4}, &dev);
+   value.CopyDataFromHostPtr(v, 4);
+   grad.CopyDataFromHostPtr(g, 4);
+ 
++  rmsprop.Setup(conf);
+   rmsprop.Apply(0, lr, "xx", &grad, &value);
+ 
+   singa::Tensor v1 = value.Clone();
+   v1.ToHost();
+   const float* newv1 = v1.data<const float*>();
+   float history[4];
+   for (int i = 0; i < 4; ++i) history[i] = g[i] * g[i] * (1 - rho);
+   for (int i = 0; i < 4; ++i)
 -    EXPECT_FLOAT_EQ(newv1[i],
 -                    v[i] - lr * g[i] / sqrt(history[i] + (float)1E-8));
++    EXPECT_NEAR(newv1[i], v[i] - lr * g[i] / sqrt(history[i] + conf.delta()),
++                1e-5);
+ 
++  // Reload grad: RMSProp::Apply overwrites it in place.
+   grad.CopyDataFromHostPtr(g, 4);
+   rmsprop.Apply(1, lr, "xx", &grad, &value);
+   singa::Tensor v2 = value.Clone();
+   v2.ToHost();
+   const float* newv2 = v2.data<const float*>();
+   for (int i = 0; i < 4; ++i)
 -    history[i] += history[i] * rho + g[i] * g[i] * (1 - rho);
++    history[i] = history[i] * rho + g[i] * g[i] * (1 - rho);
+ 
+   for (int i = 0; i < 4; ++i)
 -    EXPECT_FLOAT_EQ(newv2[i],
 -                    newv1[i] - lr * g[i] / sqrt(history[i] + (float)1E-8));
++    EXPECT_NEAR(newv2[i],
++                newv1[i] - lr * g[i] / sqrt(history[i] + conf.delta()), 1e-5);
+ }
+ #endif