Posted to commits@systemml.apache.org by ni...@apache.org on 2017/10/30 17:37:06 UTC

systemml git commit: [SYSTEMML-540] Added a rewrite to support a common tensor operation (sum over channels)

Repository: systemml
Updated Branches:
  refs/heads/master 2896f3316 -> d916ba5bd


[SYSTEMML-540] Added a rewrite to support a common tensor operation (sum over channels)

- Added a rewrite to convert out = rowSums(matrix(colSums(A), rows=C, cols=HW))
  into out = channel_sums(A) when nrow(A) > 1 and the exec type is CP or GPU
  (see the DML sketch below).
- This avoids unnecessary intermediates and a GPU-CP-GPU transfer (for the
  reshape), saving roughly 150 seconds on the sentence CNN over 200 epochs.
- Once we move to a newer cuDNN version, the custom channel_sums kernel can be
  replaced with the (possibly more optimized) cuDNN reduce tensor kernel.
- Added the corresponding CPU and GPU tests.
- Updated T_MAX(val) to MAX(). Interestingly, nvcc had already been eliminating
  the unused parameter on its own, so the PTX remained the same after the
  change.

Closes #693.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/d916ba5b
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/d916ba5b
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/d916ba5b

Branch: refs/heads/master
Commit: d916ba5bd8ceec591a04f4d16c6d24f3985e3e4f
Parents: 2896f33
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Mon Oct 30 10:32:53 2017 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Mon Oct 30 10:32:53 2017 -0700

----------------------------------------------------------------------
 src/main/cpp/kernels/Makefile                   |   2 +-
 src/main/cpp/kernels/SystemML.cu                | 308 ++++++++++---------
 .../java/org/apache/sysml/hops/AggUnaryOp.java  |  95 ++++--
 .../apache/sysml/lops/ConvolutionTransform.java |  42 ++-
 .../instructions/CPInstructionParser.java       |   1 +
 .../instructions/GPUInstructionParser.java      |   1 +
 .../cp/ConvolutionCPInstruction.java            |  86 ++++++
 .../gpu/ConvolutionGPUInstruction.java          |  47 +++
 .../spark/QuantilePickSPInstruction.java        |   2 +-
 .../runtime/matrix/data/LibMatrixCUDA.java      |  31 ++
 .../runtime/matrix/data/LibMatrixCuDNN.java     |   4 +-
 .../sysml/test/gpu/AggregateUnaryOpTests.java   |  31 ++
 .../apache/sysml/test/gpu/UnaryOpTestsBase.java |   8 +-
 .../functions/tensor/ChannelSumTest.java        | 146 +++++++++
 .../scripts/functions/tensor/ChannelSumTest.R   |  39 +++
 .../scripts/functions/tensor/ChannelSumTest.dml |  35 +++
 16 files changed, 690 insertions(+), 188 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/main/cpp/kernels/Makefile
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/Makefile b/src/main/cpp/kernels/Makefile
index 5feae69..ec10317 100644
--- a/src/main/cpp/kernels/Makefile
+++ b/src/main/cpp/kernels/Makefile
@@ -16,7 +16,7 @@
 # under the License.
 
 NVCC=nvcc
-CUDAFLAGS= -ptx -c -arch=sm_30 
+CUDAFLAGS= -ptx -c -arch=sm_30 --std c++11
 
 # Use these flags for precise math
 #CUDAFLAGS= -ptx -c -arch=sm_30 -ftz=false -prec-div=true -prec-sqrt=true

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/main/cpp/kernels/SystemML.cu
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu
index d176f8f..ade2dd1 100644
--- a/src/main/cpp/kernels/SystemML.cu
+++ b/src/main/cpp/kernels/SystemML.cu
@@ -20,7 +20,7 @@
 /**********************************
 When updating a kernel or adding a new one,
 please compile the ptx file and commit it:
-nvcc -ptx -arch=sm_30 SystemML.cu
+nvcc -ptx -arch=sm_30 --std c++11 SystemML.cu
 ***********************************/
 
 #include <cfloat>
@@ -29,7 +29,8 @@ nvcc -ptx -arch=sm_30 SystemML.cu
 extern "C" __global__ void double2float_f(double *A, float *ret, int N) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   if (tid < N) {
-  	// TODO: Use __double2float_rd or __double2float_rn  or __double2float_ru or __double2float_rz after 
+    // TODO: Use __double2float_rd or __double2float_rn  or __double2float_ru or
+    // __double2float_rz after
     ret[tid] = (float)A[tid];
   }
 }
@@ -84,15 +85,14 @@ __device__ void slice_sparse_dense_row(T *inVal, int *inRowPtr, int *colInd,
      *
      * int size = inRowPtr[rowIndex+1] - inRowPtr[rowIndex];
      * double numThreads = (double)min(size, MAX_NUM_THREADS_CHILD_KERNEL);
-     * slice_sparse_dense_row_helper<<< ceil(numThreads/
-*MAX_NUM_THREADS_CHILD_KERNEL), MAX_NUM_THREADS_CHILD_KERNEL>>>(inVal, inRowPtr,
-*colInd, ret,
-*			rl, ru, cl, cu, retClen, inRowPtr[rowIndex],
-*inRowPtr[rowIndex+1], index);
-*
-* Two-step compilation and linking process in JCudaKernels's constructor:
-* cuLinkAddFile(linkState, CUjitInputType.CU_JIT_INPUT_LIBRARY,
-*"/usr/local/cuda/lib64/libcudadevrt.a", jitOptions);
+     * slice_sparse_dense_row_helper
+     * <<< ceil(numThreads/MAX_NUM_THREADS_CHILD_KERNEL), MAX_NUM_THREADS_CHILD_KERNEL>>>
+     * (inVal, inRowPtr, colInd, ret, rl, ru, cl, cu, retClen, inRowPtr[rowIndex],
+     *	inRowPtr[rowIndex+1], index);
+     *
+     * Two-step compilation and linking process in JCudaKernels's constructor:
+     * cuLinkAddFile(linkState, CUjitInputType.CU_JIT_INPUT_LIBRARY,
+     * "/usr/local/cuda/lib64/libcudadevrt.a", jitOptions);
      */
     // Iterate over elements of the row 'rowIndex'.
     for (int i = inRowPtr[rowIndex]; i < inRowPtr[rowIndex + 1]; i++) {
@@ -104,17 +104,18 @@ __device__ void slice_sparse_dense_row(T *inVal, int *inRowPtr, int *colInd,
   }
 }
 
-extern "C" __global__ void slice_sparse_dense_row_d(double *inVal, int *inRowPtr,
-                                                   int *colInd, double *ret,
-                                                   int rl, int ru, int cl,
-                                                   int cu, int retClen) {
+extern "C" __global__ void slice_sparse_dense_row_d(double *inVal,
+                                                    int *inRowPtr, int *colInd,
+                                                    double *ret, int rl, int ru,
+                                                    int cl, int cu,
+                                                    int retClen) {
   slice_sparse_dense_row(inVal, inRowPtr, colInd, ret, rl, ru, cl, cu, retClen);
 }
 
 extern "C" __global__ void slice_sparse_dense_row_f(float *inVal, int *inRowPtr,
-                                                   int *colInd, float *ret,
-                                                   int rl, int ru, int cl,
-                                                   int cu, int retClen) {
+                                                    int *colInd, float *ret,
+                                                    int rl, int ru, int cl,
+                                                    int cu, int retClen) {
   slice_sparse_dense_row(inVal, inRowPtr, colInd, ret, rl, ru, cl, cu, retClen);
 }
 
@@ -153,17 +154,18 @@ __device__ void slice_sparse_dense_nnz(T *inVal, int *inRowPtr, int *colInd,
   }
 }
 
-extern "C" __global__ void slice_sparse_dense_nnz_d(double *inVal, int *inRowPtr,
-                                                   int *colInd, double *ret,
-                                                   int rl, int ru, int cl,
-                                                   int cu, int retClen) {
+extern "C" __global__ void slice_sparse_dense_nnz_d(double *inVal,
+                                                    int *inRowPtr, int *colInd,
+                                                    double *ret, int rl, int ru,
+                                                    int cl, int cu,
+                                                    int retClen) {
   slice_sparse_dense_nnz(inVal, inRowPtr, colInd, ret, rl, ru, cl, cu, retClen);
 }
 
 extern "C" __global__ void slice_sparse_dense_nnz_f(float *inVal, int *inRowPtr,
-                                                   int *colInd, float *ret,
-                                                   int rl, int ru, int cl,
-                                                   int cu, int retClen) {
+                                                    int *colInd, float *ret,
+                                                    int rl, int ru, int cl,
+                                                    int cu, int retClen) {
   slice_sparse_dense_nnz(inVal, inRowPtr, colInd, ret, rl, ru, cl, cu, retClen);
 }
 
@@ -194,16 +196,16 @@ __device__ void slice_dense_dense(T *in, T *ret, int rl, int ru, int cl, int cu,
 }
 
 extern "C" __global__ void slice_dense_dense_d(double *in, double *ret, int rl,
-                                              int ru, int cl, int cu,
-                                              int inClen, int retRlen,
-                                              int retClen) {
+                                               int ru, int cl, int cu,
+                                               int inClen, int retRlen,
+                                               int retClen) {
   slice_dense_dense(in, ret, rl, ru, cl, cu, inClen, retRlen, retClen);
 }
 
 extern "C" __global__ void slice_dense_dense_f(float *in, float *ret, int rl,
-                                              int ru, int cl, int cu,
-                                              int inClen, int retRlen,
-                                              int retClen) {
+                                               int ru, int cl, int cu,
+                                               int inClen, int retRlen,
+                                               int retClen) {
   slice_dense_dense(in, ret, rl, ru, cl, cu, inClen, retRlen, retClen);
 }
 
@@ -236,15 +238,15 @@ extern "C" __global__ void copy_u2l_dense_f(float *ret, int dim, int N) {
 
 // Use this method in templates to fetch the maximum value for a given datatype
 template <typename T>
-__forceinline__ __device__ T T_MAX(T x) {
-  return (T)DBL_MAX;
+__forceinline__ __device__ T MAX() {
+  return T();
 }
 template <>
-__forceinline__ __device__ float T_MAX(float x) {
+__forceinline__ __device__ float MAX <float>() {
   return FLT_MAX;
 }
 template <>
-__forceinline__ __device__ double T_MAX(double x) {
+__forceinline__ __device__ double MAX <double>() {
   return DBL_MAX;
 }
 
@@ -311,7 +313,7 @@ __forceinline__ __device__ T binaryOp(T x, T y, int op) {
       }
     }
     default:
-      return T_MAX(x);
+      return MAX<T>();
   }
 }
 
@@ -342,7 +344,8 @@ extern "C" __global__ void relu_f(float *A, float *ret, int rlen, int clen) {
 }
 
 /**
- * This method computes the backpropagation errors for previous layer of relu operation
+ * This method computes the backpropagation errors for previous layer of relu
+ * operation
  *
  * @param X input activation array allocated on the GPU
  * @param dout errors from previous layer
@@ -361,12 +364,12 @@ __device__ void relu_backward(T *X, T *dout, T *ret, int rlen, int clen) {
 }
 
 extern "C" __global__ void relu_backward_d(double *X, double *dout, double *ret,
-                                          int rlen, int clen) {
+                                           int rlen, int clen) {
   relu_backward(X, dout, ret, rlen, clen);
 }
 
 extern "C" __global__ void relu_backward_f(float *X, float *dout, float *ret,
-                                          int rlen, int clen) {
+                                           int rlen, int clen) {
   relu_backward(X, dout, ret, rlen, clen);
 }
 
@@ -389,12 +392,12 @@ __device__ void inplace_add(T *input, T *ret, int rlen, int clen) {
 }
 
 extern "C" __global__ void inplace_add_d(double *input, double *ret, int rlen,
-                                        int clen) {
+                                         int clen) {
   inplace_add(input, ret, rlen, clen);
 }
 
 extern "C" __global__ void inplace_add_f(float *input, float *ret, int rlen,
-                                        int clen) {
+                                         int clen) {
   inplace_add(input, ret, rlen, clen);
 }
 
@@ -416,12 +419,12 @@ __device__ void bias_add(T *input, T *bias, T *ret, int rlen, int clen,
 }
 
 extern "C" __global__ void bias_add_d(double *input, double *bias, double *ret,
-                                     int rlen, int clen, int PQ) {
+                                      int rlen, int clen, int PQ) {
   bias_add(input, bias, ret, rlen, clen, PQ);
 }
 
 extern "C" __global__ void bias_add_f(float *input, float *bias, float *ret,
-                                     int rlen, int clen, int PQ) {
+                                      int rlen, int clen, int PQ) {
   bias_add(input, bias, ret, rlen, clen, PQ);
 }
 
@@ -443,16 +446,16 @@ __device__ void daxpy_matrix_vector(T *A, T *B, double alpha, T *ret, int rlenA,
 }
 
 extern "C" __global__ void daxpy_matrix_vector_d(double *A, double *B,
-                                                double alpha, double *ret,
-                                                int rlenA, int clenA, int rlenB,
-                                                int clenB) {
+                                                 double alpha, double *ret,
+                                                 int rlenA, int clenA,
+                                                 int rlenB, int clenB) {
   daxpy_matrix_vector(A, B, alpha, ret, rlenA, clenA, rlenB, clenB);
 }
 
 extern "C" __global__ void daxpy_matrix_vector_f(float *A, float *B,
-                                                double alpha, float *ret,
-                                                int rlenA, int clenA, int rlenB,
-                                                int clenB) {
+                                                 double alpha, float *ret,
+                                                 int rlenA, int clenA,
+                                                 int rlenB, int clenB) {
   daxpy_matrix_vector(A, B, alpha, ret, rlenA, clenA, rlenB, clenB);
 }
 
@@ -471,13 +474,14 @@ __device__ void bias_multiply(T *input, T *bias, T *ret, int rlen, int clen,
 }
 
 extern "C" __global__ void bias_multiply_d(double *input, double *bias,
-                                          double *ret, int rlen, int clen,
-                                          int PQ) {
+                                           double *ret, int rlen, int clen,
+                                           int PQ) {
   bias_multiply(input, bias, ret, rlen, clen, PQ);
 }
 
-extern "C" __global__ void bias_multiply_f(float *input, float *bias, float *ret,
-                                          int rlen, int clen, int PQ) {
+extern "C" __global__ void bias_multiply_f(float *input, float *bias,
+                                           float *ret, int rlen, int clen,
+                                           int PQ) {
   bias_multiply(input, bias, ret, rlen, clen, PQ);
 }
 
@@ -563,14 +567,14 @@ __device__ void matrix_scalar_op(T *A, T scalar, T *C, int size, int op,
 }
 
 extern "C" __global__ void matrix_scalar_op_d(double *A, double scalar,
-                                             double *C, int size, int op,
-                                             int isLeftScalar) {
+                                              double *C, int size, int op,
+                                              int isLeftScalar) {
   matrix_scalar_op(A, scalar, C, size, op, isLeftScalar);
 }
 
 extern "C" __global__ void matrix_scalar_op_f(float *A, double scalar, float *C,
-                                             int size, int op,
-                                             int isLeftScalar) {
+                                              int size, int op,
+                                              int isLeftScalar) {
   matrix_scalar_op(A, (float)scalar, C, size, op, isLeftScalar);
 }
 
@@ -635,12 +639,12 @@ __device__ void cbind(T *A, T *B, T *C, int rowsA, int colsA, int rowsB,
 }
 
 extern "C" __global__ void cbind_d(double *A, double *B, double *C, int rowsA,
-                                  int colsA, int rowsB, int colsB) {
+                                   int colsA, int rowsB, int colsB) {
   cbind(A, B, C, rowsA, colsA, rowsB, colsB);
 }
 
 extern "C" __global__ void cbind_f(float *A, float *B, float *C, int rowsA,
-                                  int colsA, int rowsB, int colsB) {
+                                   int colsA, int rowsB, int colsB) {
   cbind(A, B, C, rowsA, colsA, rowsB, colsB);
 }
 
@@ -684,12 +688,12 @@ __device__ void rbind(T *A, T *B, T *C, int rowsA, int colsA, int rowsB,
 }
 
 extern "C" __global__ void rbind_d(double *A, double *B, double *C, int rowsA,
-                                  int colsA, int rowsB, int colsB) {
+                                   int colsA, int rowsB, int colsB) {
   rbind(A, B, C, rowsA, colsA, rowsB, colsB);
 }
 
 extern "C" __global__ void rbind_f(float *A, float *B, float *C, int rowsA,
-                                  int colsA, int rowsB, int colsB) {
+                                   int colsA, int rowsB, int colsB) {
   rbind(A, B, C, rowsA, colsA, rowsB, colsB);
 }
 
@@ -828,15 +832,15 @@ template <typename ReductionOp, typename AssignmentOp, typename T>
 __device__ void reduce_row(
     T *g_idata,  ///< input data stored in device memory (of size rows*cols)
     T *g_odata,  ///< output/temporary array store in device memory (of size
-                 ///rows*cols)
+    /// rows*cols)
     unsigned int rows,  ///< rows in input and temporary/output arrays
     unsigned int cols,  ///< columns in input and temporary/output arrays
     ReductionOp
         reduction_op,  ///< Reduction operation to perform (functor object)
     AssignmentOp assignment_op,  ///< Operation to perform before assigning this
-                                 ///to its final location in global memory for
-                                 ///each row
-    T initialValue) {            ///< initial value for the reduction variable
+    /// to its final location in global memory for
+    /// each row
+    T initialValue) {  ///< initial value for the reduction variable
   // extern __shared__ T sdata[];
   extern __shared__ __align__(sizeof(T)) unsigned char my_sdata[];
   T *sdata = reinterpret_cast<T *>(my_sdata);
@@ -935,15 +939,15 @@ template <typename ReductionOp, typename AssignmentOp, typename T>
 __device__ void reduce_col(
     T *g_idata,  ///< input data stored in device memory (of size rows*cols)
     T *g_odata,  ///< output/temporary array store in device memory (of size
-                 ///rows*cols)
+    /// rows*cols)
     unsigned int rows,  ///< rows in input and temporary/output arrays
     unsigned int cols,  ///< columns in input and temporary/output arrays
     ReductionOp
         reduction_op,  ///< Reduction operation to perform (functor object)
     AssignmentOp assignment_op,  ///< Operation to perform before assigning this
-                                 ///to its final location in global memory for
-                                 ///each column
-    T initialValue)              ///< initial value for the reduction variable
+    /// to its final location in global memory for
+    /// each column
+    T initialValue)  ///< initial value for the reduction variable
 {
   unsigned int global_tid = blockIdx.x * blockDim.x + threadIdx.x;
   if (global_tid >= cols) {
@@ -990,12 +994,12 @@ __device__ void reduce_sum(T *g_idata, T *g_odata, unsigned int n) {
 }
 
 extern "C" __global__ void reduce_sum_d(double *g_idata, double *g_odata,
-                                       unsigned int n) {
+                                        unsigned int n) {
   reduce_sum(g_idata, g_odata, n);
 }
 
 extern "C" __global__ void reduce_sum_f(float *g_idata, float *g_odata,
-                                       unsigned int n) {
+                                        unsigned int n) {
   reduce_sum(g_idata, g_odata, n);
 }
 
@@ -1016,14 +1020,14 @@ __device__ void reduce_row_sum(T *g_idata, T *g_odata, unsigned int rows,
 }
 
 extern "C" __global__ void reduce_row_sum_d(double *g_idata, double *g_odata,
-                                           unsigned int rows,
-                                           unsigned int cols) {
+                                            unsigned int rows,
+                                            unsigned int cols) {
   reduce_row_sum(g_idata, g_odata, rows, cols);
 }
 
 extern "C" __global__ void reduce_row_sum_f(float *g_idata, float *g_odata,
-                                           unsigned int rows,
-                                           unsigned int cols) {
+                                            unsigned int rows,
+                                            unsigned int cols) {
   reduce_row_sum(g_idata, g_odata, rows, cols);
 }
 
@@ -1044,14 +1048,14 @@ __device__ void reduce_col_sum(T *g_idata, T *g_odata, unsigned int rows,
 }
 
 extern "C" __global__ void reduce_col_sum_d(double *g_idata, double *g_odata,
-                                           unsigned int rows,
-                                           unsigned int cols) {
+                                            unsigned int rows,
+                                            unsigned int cols) {
   reduce_col_sum(g_idata, g_odata, rows, cols);
 }
 
 extern "C" __global__ void reduce_col_sum_f(float *g_idata, float *g_odata,
-                                           unsigned int rows,
-                                           unsigned int cols) {
+                                            unsigned int rows,
+                                            unsigned int cols) {
   reduce_col_sum(g_idata, g_odata, rows, cols);
 }
 
@@ -1063,12 +1067,13 @@ struct MaxOp {
   __device__ __forceinline__ T operator()(T a, T b) const { return fmax(a, b); }
 };
 
-template<>
+template <>
 struct MaxOp<float> {
-  __device__ __forceinline__ float operator()(float a, float b) const { return fmaxf(a, b); }
+  __device__ __forceinline__ float operator()(float a, float b) const {
+    return fmaxf(a, b);
+  }
 };
 
-
 /**
  * Do a max over all elements of an array/matrix
  * @param g_idata   input data stored in device memory (of size n)
@@ -1078,16 +1083,16 @@ struct MaxOp<float> {
 template <typename T>
 __device__ void reduce_max(T *g_idata, T *g_odata, unsigned int n) {
   MaxOp<T> op;
-  reduce<MaxOp<T>, T>(g_idata, g_odata, n, op, -T_MAX(g_idata[0]));
+  reduce<MaxOp<T>, T>(g_idata, g_odata, n, op, -MAX<T>());
 }
 
 extern "C" __global__ void reduce_max_d(double *g_idata, double *g_odata,
-                                       unsigned int n) {
+                                        unsigned int n) {
   reduce_max(g_idata, g_odata, n);
 }
 
 extern "C" __global__ void reduce_max_f(float *g_idata, float *g_odata,
-                                       unsigned int n) {
+                                        unsigned int n) {
   reduce_max(g_idata, g_odata, n);
 }
 
@@ -1104,18 +1109,18 @@ __device__ void reduce_row_max(T *g_idata, T *g_odata, unsigned int rows,
   MaxOp<T> op;
   IdentityOp<T> aop;
   reduce_row<MaxOp<T>, IdentityOp<T>, T>(g_idata, g_odata, rows, cols, op, aop,
-                                         -T_MAX(g_idata[0]));
+                                         -MAX<T>());
 }
 
 extern "C" __global__ void reduce_row_max_d(double *g_idata, double *g_odata,
-                                           unsigned int rows,
-                                           unsigned int cols) {
+                                            unsigned int rows,
+                                            unsigned int cols) {
   reduce_row_max(g_idata, g_odata, rows, cols);
 }
 
 extern "C" __global__ void reduce_row_max_f(float *g_idata, float *g_odata,
-                                           unsigned int rows,
-                                           unsigned int cols) {
+                                            unsigned int rows,
+                                            unsigned int cols) {
   reduce_row_max(g_idata, g_odata, rows, cols);
 }
 
@@ -1132,18 +1137,18 @@ __device__ void reduce_col_max(T *g_idata, T *g_odata, unsigned int rows,
   MaxOp<T> op;
   IdentityOp<T> aop;
   reduce_col<MaxOp<T>, IdentityOp<T>, T>(g_idata, g_odata, rows, cols, op, aop,
-                                         (T)-T_MAX(g_idata[0]));
+                                         -MAX<T>());
 }
 
 extern "C" __global__ void reduce_col_max_d(double *g_idata, double *g_odata,
-                                           unsigned int rows,
-                                           unsigned int cols) {
+                                            unsigned int rows,
+                                            unsigned int cols) {
   reduce_col_max(g_idata, g_odata, rows, cols);
 }
 
 extern "C" __global__ void reduce_col_max_f(float *g_idata, float *g_odata,
-                                           unsigned int rows,
-                                           unsigned int cols) {
+                                            unsigned int rows,
+                                            unsigned int cols) {
   reduce_col_max(g_idata, g_odata, rows, cols);
 }
 
@@ -1164,16 +1169,16 @@ struct MinOp {
 template <typename T>
 __device__ void reduce_min(T *g_idata, T *g_odata, unsigned int n) {
   MinOp<T> op;
-  reduce<MinOp<T>, T>(g_idata, g_odata, n, op, T_MAX(g_idata[0]));
+  reduce<MinOp<T>, T>(g_idata, g_odata, n, op, MAX<T>());
 }
 
 extern "C" __global__ void reduce_min_d(double *g_idata, double *g_odata,
-                                       unsigned int n) {
+                                        unsigned int n) {
   reduce_min(g_idata, g_odata, n);
 }
 
 extern "C" __global__ void reduce_min_f(float *g_idata, float *g_odata,
-                                       unsigned int n) {
+                                        unsigned int n) {
   reduce_min(g_idata, g_odata, n);
 }
 
@@ -1190,18 +1195,18 @@ __device__ void reduce_row_min(T *g_idata, T *g_odata, unsigned int rows,
   MinOp<T> op;
   IdentityOp<T> aop;
   reduce_row<MinOp<T>, IdentityOp<T>, T>(g_idata, g_odata, rows, cols, op, aop,
-                                         T_MAX(g_idata[0]));
+                                         MAX<T>());
 }
 
 extern "C" __global__ void reduce_row_min_d(double *g_idata, double *g_odata,
-                                           unsigned int rows,
-                                           unsigned int cols) {
+                                            unsigned int rows,
+                                            unsigned int cols) {
   reduce_row_min(g_idata, g_odata, rows, cols);
 }
 
 extern "C" __global__ void reduce_row_min_f(float *g_idata, float *g_odata,
-                                           unsigned int rows,
-                                           unsigned int cols) {
+                                            unsigned int rows,
+                                            unsigned int cols) {
   reduce_row_min(g_idata, g_odata, rows, cols);
 }
 
@@ -1218,18 +1223,18 @@ __device__ void reduce_col_min(T *g_idata, T *g_odata, unsigned int rows,
   MinOp<T> op;
   IdentityOp<T> aop;
   reduce_col<MinOp<T>, IdentityOp<T>, T>(g_idata, g_odata, rows, cols, op, aop,
-                                         T_MAX(g_idata[0]));
+                                         MAX<T>());
 }
 
 extern "C" __global__ void reduce_col_min_d(double *g_idata, double *g_odata,
-                                           unsigned int rows,
-                                           unsigned int cols) {
+                                            unsigned int rows,
+                                            unsigned int cols) {
   reduce_col_min(g_idata, g_odata, rows, cols);
 }
 
 extern "C" __global__ void reduce_col_min_f(float *g_idata, float *g_odata,
-                                           unsigned int rows,
-                                           unsigned int cols) {
+                                            unsigned int rows,
+                                            unsigned int cols) {
   reduce_col_min(g_idata, g_odata, rows, cols);
 }
 
@@ -1254,12 +1259,12 @@ __device__ void reduce_prod(T *g_idata, T *g_odata, unsigned int n) {
 }
 
 extern "C" __global__ void reduce_prod_d(double *g_idata, double *g_odata,
-                                        unsigned int n) {
+                                         unsigned int n) {
   reduce_prod(g_idata, g_odata, n);
 }
 
 extern "C" __global__ void reduce_prod_f(float *g_idata, float *g_odata,
-                                        unsigned int n) {
+                                         unsigned int n) {
   reduce_prod(g_idata, g_odata, n);
 }
 
@@ -1293,14 +1298,14 @@ __device__ void reduce_row_mean(T *g_idata, T *g_odata, unsigned int rows,
 }
 
 extern "C" __global__ void reduce_row_mean_d(double *g_idata, double *g_odata,
-                                            unsigned int rows,
-                                            unsigned int cols) {
+                                             unsigned int rows,
+                                             unsigned int cols) {
   reduce_row_mean(g_idata, g_odata, rows, cols);
 }
 
 extern "C" __global__ void reduce_row_mean_f(float *g_idata, float *g_odata,
-                                            unsigned int rows,
-                                            unsigned int cols) {
+                                             unsigned int rows,
+                                             unsigned int cols) {
   reduce_row_mean(g_idata, g_odata, rows, cols);
 }
 
@@ -1321,14 +1326,14 @@ __device__ void reduce_col_mean(T *g_idata, T *g_odata, unsigned int rows,
 }
 
 extern "C" __global__ void reduce_col_mean_d(double *g_idata, double *g_odata,
-                                            unsigned int rows,
-                                            unsigned int cols) {
+                                             unsigned int rows,
+                                             unsigned int cols) {
   reduce_col_mean(g_idata, g_odata, rows, cols);
 }
 
 extern "C" __global__ void reduce_col_mean_f(float *g_idata, float *g_odata,
-                                            unsigned int rows,
-                                            unsigned int cols) {
+                                             unsigned int rows,
+                                             unsigned int cols) {
   reduce_col_mean(g_idata, g_odata, rows, cols);
 }
 
@@ -1347,7 +1352,7 @@ __device__ void matrix_exp(T *A, T *C, unsigned int size) {
 }
 
 extern "C" __global__ void matrix_exp_d(double *A, double *C,
-                                       unsigned int size) {
+                                        unsigned int size) {
   matrix_exp(A, C, size);
 }
 
@@ -1370,11 +1375,12 @@ __device__ void matrix_sqrt(T *A, T *C, unsigned int size) {
 }
 
 extern "C" __global__ void matrix_sqrt_d(double *A, double *C,
-                                        unsigned int size) {
+                                         unsigned int size) {
   matrix_sqrt(A, C, size);
 }
 
-extern "C" __global__ void matrix_sqrt_f(float *A, float *C, unsigned int size) {
+extern "C" __global__ void matrix_sqrt_f(float *A, float *C,
+                                         unsigned int size) {
   matrix_sqrt(A, C, size);
 }
 
@@ -1393,12 +1399,12 @@ __device__ void matrix_round(T *A, T *C, unsigned int size) {
 }
 
 extern "C" __global__ void matrix_round_d(double *A, double *C,
-                                         unsigned int size) {
+                                          unsigned int size) {
   matrix_round(A, C, size);
 }
 
 extern "C" __global__ void matrix_round_f(float *A, float *C,
-                                         unsigned int size) {
+                                          unsigned int size) {
   matrix_round(A, C, size);
 }
 
@@ -1417,7 +1423,7 @@ __device__ void matrix_abs(T *A, T *C, unsigned int size) {
 }
 
 extern "C" __global__ void matrix_abs_d(double *A, double *C,
-                                       unsigned int size) {
+                                        unsigned int size) {
   matrix_abs(A, C, size);
 }
 
@@ -1440,7 +1446,7 @@ __device__ void matrix_log(T *A, T *C, unsigned int size) {
 }
 
 extern "C" __global__ void matrix_log_d(double *A, double *C,
-                                       unsigned int size) {
+                                        unsigned int size) {
   matrix_log(A, C, size);
 }
 
@@ -1463,12 +1469,12 @@ __device__ void matrix_floor(T *A, T *C, unsigned int size) {
 }
 
 extern "C" __global__ void matrix_floor_d(double *A, double *C,
-                                         unsigned int size) {
+                                          unsigned int size) {
   matrix_floor(A, C, size);
 }
 
 extern "C" __global__ void matrix_floor_f(float *A, float *C,
-                                         unsigned int size) {
+                                          unsigned int size) {
   matrix_floor(A, C, size);
 }
 
@@ -1487,11 +1493,12 @@ __device__ void matrix_ceil(T *A, T *C, unsigned int size) {
 }
 
 extern "C" __global__ void matrix_ceil_d(double *A, double *C,
-                                        unsigned int size) {
+                                         unsigned int size) {
   matrix_ceil(A, C, size);
 }
 
-extern "C" __global__ void matrix_ceil_f(float *A, float *C, unsigned int size) {
+extern "C" __global__ void matrix_ceil_f(float *A, float *C,
+                                         unsigned int size) {
   matrix_ceil(A, C, size);
 }
 
@@ -1510,7 +1517,7 @@ __device__ void matrix_sin(T *A, T *C, unsigned int size) {
 }
 
 extern "C" __global__ void matrix_sin_d(double *A, double *C,
-                                       unsigned int size) {
+                                        unsigned int size) {
   matrix_sin(A, C, size);
 }
 
@@ -1533,11 +1540,12 @@ __device__ void matrix_sinh(T *A, T *C, unsigned int size) {
 }
 
 extern "C" __global__ void matrix_sinh_d(double *A, double *C,
-                                        unsigned int size) {
+                                         unsigned int size) {
   matrix_sinh(A, C, size);
 }
 
-extern "C" __global__ void matrix_sinh_f(float *A, float *C, unsigned int size) {
+extern "C" __global__ void matrix_sinh_f(float *A, float *C,
+                                         unsigned int size) {
   matrix_sinh(A, C, size);
 }
 
@@ -1556,7 +1564,7 @@ __device__ void matrix_cos(T *A, T *C, unsigned int size) {
 }
 
 extern "C" __global__ void matrix_cos_d(double *A, double *C,
-                                       unsigned int size) {
+                                        unsigned int size) {
   matrix_cos(A, C, size);
 }
 
@@ -1579,11 +1587,12 @@ __device__ void matrix_cosh(T *A, T *C, unsigned int size) {
 }
 
 extern "C" __global__ void matrix_cosh_d(double *A, double *C,
-                                        unsigned int size) {
+                                         unsigned int size) {
   matrix_cosh(A, C, size);
 }
 
-extern "C" __global__ void matrix_cosh_f(float *A, float *C, unsigned int size) {
+extern "C" __global__ void matrix_cosh_f(float *A, float *C,
+                                         unsigned int size) {
   matrix_cosh(A, C, size);
 }
 
@@ -1602,7 +1611,7 @@ __device__ void matrix_tan(T *A, T *C, unsigned int size) {
 }
 
 extern "C" __global__ void matrix_tan_d(double *A, double *C,
-                                       unsigned int size) {
+                                        unsigned int size) {
   matrix_tan(A, C, size);
 }
 
@@ -1625,11 +1634,12 @@ __device__ void matrix_tanh(T *A, T *C, unsigned int size) {
 }
 
 extern "C" __global__ void matrix_tanh_d(double *A, double *C,
-                                        unsigned int size) {
+                                         unsigned int size) {
   matrix_tanh(A, C, size);
 }
 
-extern "C" __global__ void matrix_tanh_f(float *A, float *C, unsigned int size) {
+extern "C" __global__ void matrix_tanh_f(float *A, float *C,
+                                         unsigned int size) {
   matrix_tanh(A, C, size);
 }
 
@@ -1648,11 +1658,12 @@ __device__ void matrix_asin(T *A, T *C, unsigned int size) {
 }
 
 extern "C" __global__ void matrix_asin_d(double *A, double *C,
-                                        unsigned int size) {
+                                         unsigned int size) {
   matrix_asin(A, C, size);
 }
 
-extern "C" __global__ void matrix_asin_f(float *A, float *C, unsigned int size) {
+extern "C" __global__ void matrix_asin_f(float *A, float *C,
+                                         unsigned int size) {
   matrix_asin(A, C, size);
 }
 
@@ -1671,11 +1682,12 @@ __device__ void matrix_acos(T *A, T *C, unsigned int size) {
 }
 
 extern "C" __global__ void matrix_acos_d(double *A, double *C,
-                                        unsigned int size) {
+                                         unsigned int size) {
   matrix_acos(A, C, size);
 }
 
-extern "C" __global__ void matrix_acos_f(float *A, float *C, unsigned int size) {
+extern "C" __global__ void matrix_acos_f(float *A, float *C,
+                                         unsigned int size) {
   matrix_acos(A, C, size);
 }
 
@@ -1694,11 +1706,12 @@ __device__ void matrix_atan(T *A, T *C, unsigned int size) {
 }
 
 extern "C" __global__ void matrix_atan_d(double *A, double *C,
-                                        unsigned int size) {
+                                         unsigned int size) {
   matrix_atan(A, C, size);
 }
 
-extern "C" __global__ void matrix_atan_f(float *A, float *C, unsigned int size) {
+extern "C" __global__ void matrix_atan_f(float *A, float *C,
+                                         unsigned int size) {
   matrix_atan(A, C, size);
 }
 
@@ -1722,10 +1735,11 @@ __device__ void matrix_sign(T *A, T *C, unsigned int size) {
 }
 
 extern "C" __global__ void matrix_sign_d(double *A, double *C,
-                                        unsigned int size) {
+                                         unsigned int size) {
   matrix_sign(A, C, size);
 }
 
-extern "C" __global__ void matrix_sign_f(float *A, float *C, unsigned int size) {
+extern "C" __global__ void matrix_sign_f(float *A, float *C,
+                                         unsigned int size) {
   matrix_sign(A, C, size);
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index 04a32bd..9b9406a 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -26,6 +26,7 @@ import org.apache.sysml.hops.rewrite.HopRewriteUtils;
 import org.apache.sysml.lops.Aggregate;
 import org.apache.sysml.lops.Aggregate.OperationTypes;
 import org.apache.sysml.lops.Binary;
+import org.apache.sysml.lops.ConvolutionTransform;
 import org.apache.sysml.lops.Group;
 import org.apache.sysml.lops.Lop;
 import org.apache.sysml.lops.LopsException;
@@ -131,6 +132,20 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
 		return false;
 	}
 	
+	/**
+	 * Checks if channels sum rewrite is applicable
+	 * 
+	 * @return returns true for pattern rowSums(matrix(colSums(X), rows=.., cols=..)) else false
+	 */
+	private boolean isChannelSumRewriteApplicable() {
+		if( OptimizerUtils.ALLOW_OPERATOR_FUSION && _op == AggOp.SUM && _direction == Direction.Row
+			&& getInput().get(0) instanceof ReorgOp && ((ReorgOp)getInput().get(0)).getOp() == ReOrgOp.RESHAPE) {
+			Hop input1 = getInput().get(0).getInput().get(0);
+			return input1 instanceof AggUnaryOp && ((AggUnaryOp)input1)._op == AggOp.SUM && ((AggUnaryOp)input1)._direction == Direction.Col;
+		}
+		return false;
+	}
+	
 	@Override
 	public Lop constructLops()
 		throws HopsException, LopsException 
@@ -147,41 +162,57 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
 			if ( et == ExecType.CP || et == ExecType.GPU ) 
 			{
 				Lop agg1 = null;
-				if( isTernaryAggregateRewriteApplicable() ) {
-					agg1 = constructLopsTernaryAggregateRewrite(et);
+				long numChannels = isChannelSumRewriteApplicable() ? Hop.computeSizeInformation(getInput().get(0).getInput().get(1)) : -1;
+				if(numChannels > 0 && numChannels < 1000000) {
+					// Apply channel sums only if rewrite is applicable and if the dimension of C is known at compile time
+					// and if numChannels is less than 8 MB.
+					ReorgOp in = ((ReorgOp)getInput().get(0));
+					agg1 = new ConvolutionTransform(
+							in.getInput().get(0).getInput().get(0).constructLops(), 
+							in.getInput().get(1).constructLops(),
+							in.getInput().get(2).constructLops(),
+							ConvolutionTransform.OperationTypes.CHANNEL_SUMS, getDataType(), getValueType(), et, -1);
+					agg1.getOutputParameters().setDimensions(numChannels, 1, getRowsInBlock(), getColsInBlock(), -1);
+					setLineNumbers(agg1);
+					setLops(agg1);
 				}
-				else if( isUnaryAggregateOuterCPRewriteApplicable() )
-				{
-					OperationTypes op = HopsAgg2Lops.get(_op);
-					DirectionTypes dir = HopsDirection2Lops.get(_direction);
-
-					BinaryOp binput = (BinaryOp)getInput().get(0);
-					agg1 = new UAggOuterChain( binput.getInput().get(0).constructLops(), 
-							binput.getInput().get(1).constructLops(), op, dir, 
-							HopsOpOp2LopsB.get(binput.getOp()), DataType.MATRIX, getValueType(), ExecType.CP);
-					PartialAggregate.setDimensionsBasedOnDirection(agg1, getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock(), dir);
-				
+				else { 
+					if( isTernaryAggregateRewriteApplicable() ) {
+						agg1 = constructLopsTernaryAggregateRewrite(et);
+					}
+					else if( isUnaryAggregateOuterCPRewriteApplicable() )
+					{
+						OperationTypes op = HopsAgg2Lops.get(_op);
+						DirectionTypes dir = HopsDirection2Lops.get(_direction);
+	
+						BinaryOp binput = (BinaryOp)getInput().get(0);
+						agg1 = new UAggOuterChain( binput.getInput().get(0).constructLops(), 
+								binput.getInput().get(1).constructLops(), op, dir, 
+								HopsOpOp2LopsB.get(binput.getOp()), DataType.MATRIX, getValueType(), ExecType.CP);
+						PartialAggregate.setDimensionsBasedOnDirection(agg1, getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock(), dir);
+					
+						if (getDataType() == DataType.SCALAR) {
+							UnaryCP unary1 = new UnaryCP(agg1, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR),
+									                    getDataType(), getValueType());
+							unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
+							setLineNumbers(unary1);
+							setLops(unary1);
+						}
+					
+					}				
+					else { //general case		
+						int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
+						agg1 = new PartialAggregate(input.constructLops(), 
+								HopsAgg2Lops.get(_op), HopsDirection2Lops.get(_direction), getDataType(),getValueType(), et, k);
+					}
+					
+					setOutputDimensions(agg1);
+					setLineNumbers(agg1);
+					setLops(agg1);
+					
 					if (getDataType() == DataType.SCALAR) {
-						UnaryCP unary1 = new UnaryCP(agg1, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR),
-								                    getDataType(), getValueType());
-						unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
-						setLineNumbers(unary1);
-						setLops(unary1);
+						agg1.getOutputParameters().setDimensions(1, 1, getRowsInBlock(), getColsInBlock(), getNnz());
 					}
-				
-				}				
-				else { //general case		
-					int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
-					agg1 = new PartialAggregate(input.constructLops(), 
-							HopsAgg2Lops.get(_op), HopsDirection2Lops.get(_direction), getDataType(),getValueType(), et, k);
-				}
-				
-				setOutputDimensions(agg1);
-				setLineNumbers(agg1);
-				setLops(agg1);
-				
-				if (getDataType() == DataType.SCALAR) {
-					agg1.getOutputParameters().setDimensions(1, 1, getRowsInBlock(), getColsInBlock(), getNnz());
 				}
 			}
 			else if( et == ExecType.MR )

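Usage note (a hedged sketch, not part of the commit diff): the new channel-sums
branch above only fires when the channel count passed to matrix() resolves to a
compile-time constant via Hop.computeSizeInformation and stays below the
1,000,000 guard; otherwise the pre-existing colSums/reshape/rowSums plan is
kept. The DML below is hypothetical and only meant to contrast the two cases.

    A = rand(rows=64, cols=32*784)

    # Channel count is a compile-time literal: rewritten to channel_sums.
    C1 = 32
    s1 = rowSums(matrix(colSums(A), rows=C1, cols=784))
    print(sum(s1))

    # Channel count is only known at runtime (read from a matrix): size
    # propagation yields an unknown dimension, so the rewrite is skipped.
    meta = matrix(32, rows=1, cols=1)
    C2 = as.scalar(meta[1, 1])
    s2 = rowSums(matrix(colSums(A), rows=C2, cols=784))
    print(sum(s2))
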
http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java b/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java
index dfc187c..94a67f0 100644
--- a/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java
+++ b/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java
@@ -32,7 +32,7 @@ public class ConvolutionTransform extends Lop
 	public enum OperationTypes {
 		MAX_POOLING, MAX_POOLING_BACKWARD, RELU_MAX_POOLING, RELU_BACKWARD, RELU_MAX_POOLING_BACKWARD,
 		DIRECT_CONV2D, DIRECT_CONV2D_BACKWARD_FILTER, DIRECT_CONV2D_BACKWARD_DATA,
-		BIAS_ADD, DIRECT_CONV2D_BIAS_ADD, BIAS_MULTIPLY
+		BIAS_ADD, DIRECT_CONV2D_BIAS_ADD, BIAS_MULTIPLY, CHANNEL_SUMS
 	}
 	
 	private OperationTypes operation = null;
@@ -67,6 +67,18 @@ public class ConvolutionTransform extends Lop
 		input2.addOutput(this);
 		setLevel();
 	}
+	
+	public ConvolutionTransform(Lop input1, Lop input2, Lop input3, ConvolutionTransform.OperationTypes op, DataType dt, ValueType vt, ExecType et, int k) 
+	{
+		super(Lop.Type.Transform, dt, vt);		
+		init(input1, op, dt, vt, et);
+		numThreads = k;
+		this.addInput(input2);
+		input2.addOutput(this);
+		this.addInput(input3);
+		input3.addOutput(this);
+		setLevel();
+	}
 
 	private void init (Lop input, ConvolutionTransform.OperationTypes op, DataType dt, ValueType vt, ExecType et) 
 	{
@@ -142,6 +154,9 @@ public class ConvolutionTransform extends Lop
 		case DIRECT_CONV2D_BACKWARD_DATA:
 			return "conv2d_backward_data";
 			
+		case CHANNEL_SUMS:
+			return "channel_sums";
+			
 		default:
 			throw new UnsupportedOperationException(this.printErrorLocation() + "Instruction is not defined for Transform operation " + operation);
 				
@@ -180,6 +195,31 @@ public class ConvolutionTransform extends Lop
 	}
 	
 	@Override
+	public String getInstructions(String input, String C, String HW, String output) throws LopsException {
+		if(operation == OperationTypes.CHANNEL_SUMS) {
+			StringBuilder sb = new StringBuilder();
+			sb.append( getExecType() );
+			
+			sb.append( OPERAND_DELIMITOR );
+			sb.append( getOpcode() );
+			sb.append( OPERAND_DELIMITOR );
+			sb.append( getInputs().get(0).prepInputOperand(input));
+			sb.append( OPERAND_DELIMITOR );
+			sb.append( getInputs().get(1).prepInputOperand(C));
+			sb.append( OPERAND_DELIMITOR );
+			sb.append( getInputs().get(2).prepInputOperand(HW));
+			//output
+			sb.append( OPERAND_DELIMITOR );
+			sb.append( this.prepOutputOperand(output));
+			
+			return sb.toString();
+		}
+		else {
+			throw new LopsException("The operation is not supported with three operands:" + operation.name());
+		}
+	}
+	
+	@Override
 	public String getInstructions(String[] inputs, String output) throws LopsException {
 		StringBuilder sb = new StringBuilder();
 		appendOpcode(sb);

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
index 4e66042..d0bc429 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
@@ -233,6 +233,7 @@ public class CPInstructionParser extends InstructionParser
 		String2CPInstructionType.put( "conv2d_backward_data"      , CPINSTRUCTION_TYPE.Convolution);
 		String2CPInstructionType.put( "bias_add"      , CPINSTRUCTION_TYPE.Convolution);
 		String2CPInstructionType.put( "bias_multiply"      , CPINSTRUCTION_TYPE.Convolution);
+		String2CPInstructionType.put( "channel_sums"      , CPINSTRUCTION_TYPE.Convolution);
 		
 		// Quaternary instruction opcodes
 		String2CPInstructionType.put( "wsloss"  , CPINSTRUCTION_TYPE.Quaternary);

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index 503576f..ae19969 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -53,6 +53,7 @@ public class GPUInstructionParser  extends InstructionParser
 		String2GPUInstructionType.put( "maxpooling_backward",    GPUINSTRUCTION_TYPE.Convolution);
 		String2GPUInstructionType.put( "bias_add",               GPUINSTRUCTION_TYPE.Convolution);
 		String2GPUInstructionType.put( "bias_multiply",          GPUINSTRUCTION_TYPE.Convolution);
+		String2GPUInstructionType.put( "channel_sums",          GPUINSTRUCTION_TYPE.Convolution);
 
 		// Matrix Multiply Operators
 		String2GPUInstructionType.put( "ba+*",  GPUINSTRUCTION_TYPE.AggregateBinary);

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
index c6b4698..36422d9 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
@@ -27,12 +27,14 @@ import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.functionobjects.KahanPlus;
 import org.apache.sysml.runtime.functionobjects.SwapIndex;
 import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
 import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
 import org.apache.sysml.runtime.matrix.data.LibMatrixNative;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.SparseBlock;
 import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysml.runtime.util.ConvolutionUtils;
 import org.apache.sysml.utils.NativeHelper;
@@ -59,6 +61,19 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
 		_numThreads = numThreads;
 		_intermediateMemoryBudget = intermediateMemoryBudget;
 	}
+	
+	public ConvolutionCPInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr, int numThreads, double intermediateMemoryBudget) throws DMLRuntimeException {
+		super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out,
+				opcode, istr);
+		if( !opcode.equals("channel_sums") ) {
+			throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be channel_sums, but found " + opcode);
+		}
+		_in2 = in2;
+		_in3 = in3;
+		_cptype = CPINSTRUCTION_TYPE.Convolution;
+		_numThreads = numThreads;
+		_intermediateMemoryBudget = intermediateMemoryBudget;
+	}
 
 	private ConvolutionCPInstruction(CPOperand in, CPOperand out, String opcode, String istr,
 			ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape,
@@ -212,6 +227,14 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
 			int k = Integer.parseInt(parts[4]);
 			return new ConvolutionCPInstruction(in, in2, out, opcode, str, k, Double.parseDouble(parts[5]));
 		}
+		else if (opcode.equalsIgnoreCase("channel_sums")) {
+			InstructionUtils.checkNumFields(parts, 4);
+			CPOperand in = new CPOperand(parts[1]);
+			CPOperand in2 = new CPOperand(parts[2]);
+			CPOperand in3 = new CPOperand(parts[3]);
+			CPOperand out = new CPOperand(parts[4]);
+			return new ConvolutionCPInstruction(in, in2, in3, out, opcode, str, -1, 0);
+		}
 		else {
 			throw new DMLRuntimeException("Unknown opcode while parsing a ConvolutionCPInstruction: " + str);
 		}
@@ -297,6 +320,65 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
 		ec.setMatrixOutput(getOutputVariableName(), outputBlock, getExtendedOpcode());
 	}
 	
+	public void processChannelSumsInstruction(ExecutionContext ec) throws DMLRuntimeException {
+		MatrixBlock input = ec.getMatrixInput(input1.getName(), getExtendedOpcode());
+		int C = (int) ec.getScalarInput(_in2.getName(), _in2.getValueType(), _in2.isLiteral()).getLongValue();
+		int HW = (int) ec.getScalarInput(_in3.getName(), _in3.getValueType(), _in3.isLiteral()).getLongValue();
+		if(C*HW != input.getNumColumns()) {
+			throw new DMLRuntimeException("Expected rows*cols" + C + "*" + HW + " to be equal to number of columns of input " + input.getNumColumns());
+		}
+		MatrixBlock outputBlock = null;
+		if(input.isEmpty()) {
+			outputBlock = new MatrixBlock(C, 1, true);
+		}
+		else {
+			outputBlock = new MatrixBlock(C, 1, false).allocateBlock();
+			double [] output = outputBlock.getDenseBlock();
+			if(input.isInSparseFormat()) {
+				SparseBlock sblock = input.getSparseBlock();
+				for(int n = 0; n < input.getNumRows(); n++) {
+					if( sblock.isEmpty(n) )
+						continue;
+					int apos = sblock.pos(n);
+					int alen = sblock.size(n);
+					int[] aix = sblock.indexes(n);
+					double[] avals = sblock.values(n);
+					
+					// Iterate over the sparse block
+					for(int j=apos; j<apos+alen; j++) {
+						// Note: the input is of shape [N, CHW]
+						int chw = aix[j];
+						
+						// Get individual zero-based c,h,w indexes from zero-based 'chw'
+						int c = chw / HW;
+						output[c] += avals[j];
+					}
+				}
+			}
+			else {
+				double [] inArr = input.getDenseBlock();
+				if(inArr != null) {
+					KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
+					for(int c = 0; c < C; c++) {
+						KahanObject sum = new KahanObject(0.0, 0.0);
+						for(int n = 0; n < input.getNumRows(); n++) {
+							int index =  n*C*HW + c*HW;
+							for(int hw = 0; hw < HW; hw++, index++) {
+								kplus.execute2(sum, inArr[index]);
+							}
+						}
+						output[c] = sum._sum;
+					}
+				}
+			}
+			outputBlock.recomputeNonZeros(getExtendedOpcode());
+		}
+		
+		// release inputs/outputs
+		ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());
+		ec.setMatrixOutput(getOutputVariableName(), outputBlock, getExtendedOpcode());
+	}
+	
 	// Assumption: enableNative && NativeHelper.isNativeLibraryLoaded() is true
 	// This increases the number of native calls. For example:the cases where filter is sparse but input is dense
 	private static boolean isFilterSparse(MatrixBlock filter) throws DMLRuntimeException {
@@ -324,6 +406,10 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
 			processReluBackwardInstruction(ec);
 			return;
 		}
+		else if (instOpcode.equalsIgnoreCase("channel_sums")) {
+			processChannelSumsInstruction(ec);
+			return;
+		}
 		
 		// acquire inputs
 		MatrixBlock outputBlock = null;

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
index 8565b5a..fdb208e 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
@@ -20,12 +20,17 @@ package org.apache.sysml.runtime.instructions.gpu;
 
 import java.util.ArrayList;
 
+import jcuda.Pointer;
+
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.functionobjects.SwapIndex;
 import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.instructions.cp.ConvolutionCPInstruction;
+import org.apache.sysml.runtime.instructions.gpu.context.ExecutionConfig;
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
 import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
 import org.apache.sysml.runtime.matrix.data.LibMatrixCuDNN;
 import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
@@ -57,6 +62,19 @@ public class ConvolutionGPUInstruction extends GPUInstruction {
 		_intermediateMemoryBudget = intermediateMemoryBudget;
 	}
 	
+	public ConvolutionGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr, double intermediateMemoryBudget) throws DMLRuntimeException {
+		super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
+		if( !opcode.equals("channel_sums") ) {
+			throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be channel_sums, but found " + opcode);
+		}
+		_input1 = in1;
+		_input2 = in2;
+		_input3 = in3;
+		_gputype = GPUINSTRUCTION_TYPE.Convolution;
+		_output = out;
+		_intermediateMemoryBudget = intermediateMemoryBudget;
+	}
+	
 	public ConvolutionGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode,
 			String istr, ArrayList<CPOperand> stride,
 			ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape,
@@ -210,6 +228,14 @@ public class ConvolutionGPUInstruction extends GPUInstruction {
 			CPOperand out = new CPOperand(parts[3]);
 			return new ConvolutionGPUInstruction(in1, in2, out, opcode, str, Double.parseDouble(parts[4]));
 		}
+		else if (opcode.equalsIgnoreCase("channel_sums")) {
+			InstructionUtils.checkNumFields(parts, 4);
+			CPOperand in = new CPOperand(parts[1]);
+			CPOperand in2 = new CPOperand(parts[2]);
+			CPOperand in3 = new CPOperand(parts[3]);
+			CPOperand out = new CPOperand(parts[4]);
+			return new ConvolutionGPUInstruction(in, in2, in3, out, opcode, str, 0);
+		}
 		else {
 			throw new DMLRuntimeException("Unknown opcode while parsing a ConvolutionGPUInstruction: " + str);	
 		}
@@ -246,6 +272,23 @@ public class ConvolutionGPUInstruction extends GPUInstruction {
 		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
 	}
 	
+	public void processChannelSumsInstruction(ExecutionContext ec) throws DMLRuntimeException {
+		GPUStatistics.incrementNoOfExecutedGPUInst();
+		MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName());
+		int C = (int) ec.getScalarInput(_input2.getName(), _input2.getValueType(), _input2.isLiteral()).getLongValue();
+		int HW = (int) ec.getScalarInput(_input3.getName(), _input3.getValueType(), _input3.isLiteral()).getLongValue();
+		if(C*HW != input.getNumColumns()) {
+			throw new DMLRuntimeException("Expected rows*cols (" + C + "*" + HW + ") to equal the number of columns of the input: " + input.getNumColumns());
+		}
+		MatrixObject outputBlock = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), C, 1);
+		
+		LibMatrixCUDA.channelSums(ec.getGPUContext(0), getExtendedOpcode(), input, outputBlock, C, HW);
+		
+		// release inputs/outputs
+		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
+		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+	}
+	
 	@Override
 	public void processInstruction(ExecutionContext ec) 
 			throws DMLRuntimeException 
@@ -258,6 +301,10 @@ public class ConvolutionGPUInstruction extends GPUInstruction {
 			processReLUBackwardInstruction(ec);
 			return;
 		}
+		else if (instOpcode.equalsIgnoreCase("channel_sums")) {
+			processChannelSumsInstruction(ec);
+			return;
+		}
 		
 		GPUStatistics.incrementNoOfExecutedGPUInst();
 					

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/main/java/org/apache/sysml/runtime/instructions/spark/QuantilePickSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/QuantilePickSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/QuantilePickSPInstruction.java
index e7f515a..caaa9e8 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/QuantilePickSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/QuantilePickSPInstruction.java
@@ -316,7 +316,7 @@ public class QuantilePickSPInstruction extends BinarySPInstruction {
 				sum += v2.next()._2().sumWeightForQuantile();
 			
 			//return tuple for partition aggregate
-			return Arrays.asList(new Tuple2<>(v1,sum)).iterator();
+			return Arrays.asList(new Tuple2<Integer, Double>(v1,sum)).iterator();
 		}
 	}
 	

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index 2cccde0..c0091c8 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -322,6 +322,37 @@ public class LibMatrixCUDA {
 		if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_RELU_BACKWARD_KERNEL, System.nanoTime() - t1);
 
 	}
+	
+	/**
+	 * Performs the channel_sums operation: out = rowSums(matrix(colSums(A), rows=C, cols=HW))
+	 * 
+	 * @param gCtx a valid {@link GPUContext}
+	 * @param instName the invoking instruction's name for record {@link Statistics}.
+	 * @param input input image
+	 * @param outputBlock output
+	 * @param C number of channels
+	 * @param HW height*width
+	 * @throws DMLRuntimeException if DMLRuntimeException occurs
+	 */
+	public static void channelSums(GPUContext gCtx, String instName, MatrixObject input, MatrixObject outputBlock, long C, long HW) throws DMLRuntimeException {
+		if(LOG.isTraceEnabled()) {
+			LOG.trace("GPU : channelSums, GPUContext=" + gCtx);
+		}
+		int N = toInt(input.getNumRows());
+		int cols = toInt(input.getNumColumns());
+		if(cols != C*HW) {
+			throw new DMLRuntimeException("Incorrect parameters, number of columns " + cols + " != " + C + "*" + HW);
+		}
+		Pointer imagePointer = getDensePointer(gCtx, input, instName);
+		Pointer outputPointer = getDensePointer(gCtx, outputBlock, instName);
+		
+		// We can replace this with CuDNN tensor reduce
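+		// Two-step reduction: tmp = colSums(input), a row vector of length C*HW; summing each HW-block of tmp then yields the C per-channel sums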
+		Pointer tmp = gCtx.allocate(instName, cols*sizeOfDataType);
+		reduceCol(gCtx, instName, "reduce_col_sum", imagePointer, tmp, N, cols);
+		reduceRow(gCtx, instName, "reduce_row_sum", tmp, outputPointer, toInt(C), toInt(HW));
+		gCtx.cudaFreeHelper(tmp);
+
+	}
 
 	/**
 	 * Performs the operation corresponding to the DML script:

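The channelSums path above composes two generic reduction kernels. As a minimal CPU-side sketch of that two-step computation (plain Java, hypothetical class name, no SystemML or JCuda types): step 1 corresponds to reduce_col_sum over the [N, C*HW] input, and step 2 to reduce_row_sum over the [C, HW] view of the resulting row vector:

public class TwoStepChannelSums {
	static double[] channelSums(double[] in, int N, int C, int HW) {
		int cols = C * HW;
		double[] colSums = new double[cols];   // step 1: reduce over the N rows
		for (int n = 0; n < N; n++)
			for (int j = 0; j < cols; j++)
				colSums[j] += in[n*cols + j];
		double[] out = new double[C];          // step 2: reduce each channel's HW block
		for (int c = 0; c < C; c++)
			for (int hw = 0; hw < HW; hw++)
				out[c] += colSums[c*HW + hw];
		return out;
	}

	public static void main(String[] args) {
		double[] in = {1,2,3, 4,5,6,  7,8,9, 10,11,12};  // [N=2, C=2, HW=3]
		double[] s = channelSums(in, 2, 2, 3);
		System.out.println(s[0] + " " + s[1]);           // 30.0 48.0
	}
}

Both steps operate on dense buffers, which matches the dense pointers obtained via getDensePointer above.
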
http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
index e0a6a57..5935285 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
@@ -64,7 +64,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 	protected static cudnnHandle getCudnnHandle(GPUContext gCtx) throws DMLRuntimeException {
 		return gCtx.getCudnnHandle();
 	}
-
+	
 	/**
 	 * Does a 2D convolution followed by a bias_add
 	 *
@@ -722,4 +722,4 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 		if(status != cudnnStatus.CUDNN_STATUS_SUCCESS)
 			throw new DMLRuntimeException("Error status returned by CuDNN:" + jcuda.jcudnn.cudnnStatus.stringFor(status));
 	}
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/test/java/org/apache/sysml/test/gpu/AggregateUnaryOpTests.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/AggregateUnaryOpTests.java b/src/test/java/org/apache/sysml/test/gpu/AggregateUnaryOpTests.java
index 0b229f0..59e9cb1 100644
--- a/src/test/java/org/apache/sysml/test/gpu/AggregateUnaryOpTests.java
+++ b/src/test/java/org/apache/sysml/test/gpu/AggregateUnaryOpTests.java
@@ -45,6 +45,37 @@ public class AggregateUnaryOpTests extends UnaryOpTestsBase {
 	public void colSums() {
 		testSimpleUnaryOpMatrixOutput("colSums", "gpu_uack+");
 	}
+	
+	@Test
+	public void channelSums() {
+		int[] rows = rowSizes;
+		int[] C = new int[] { 2, 5, 10, 50 };
+		int[] HW = new int[] { 10, 12, 21, 51 };
+		double[] sparsities = this.sparsities;
+		int seed = this.seed;	
+
+		for (int k = 0; k < sparsities.length; k++) {
+			double sparsity = sparsities[k];
+			if(sparsity == 0)
+				continue; // sparsity == 0 is tested separately; here it fails with a non-informative MLContext error
+			for (int i = 0; i < rows.length; i++) {
+				int row = rows[i];
+				if(row == 1)
+					continue; // The channel_sums rewrite is currently enabled only when the number of rows > 1
+				for (int c : C) {
+					if(c == 1)
+						continue; // C == 1 reduces the output to a scalar; that case is tested separately
+					for (int hw : HW) {
+						// Exercise the channel-sums pattern (compiled to gpu_channel_sums) for this shape and sparsity
+						// System.out.println("Started channelSum test for " + row + " " + c + " " + hw + " " +  sparsity);
+						String scriptStr = "out = rowSums(matrix(colSums(in1), rows=" + c + ", cols=" + hw + "));";
+						testUnaryOpMatrixOutput(scriptStr, "gpu_channel_sums", "in1", "out", seed, row, c*hw, sparsity);
+						// System.out.println("Ended channelSum test for " + row + " " + c + " " + hw + " " +  sparsity);
+					}
+				}
+			}
+		}
+	}
 
 	@Test
 	public void rowSums() {

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/test/java/org/apache/sysml/test/gpu/UnaryOpTestsBase.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/UnaryOpTestsBase.java b/src/test/java/org/apache/sysml/test/gpu/UnaryOpTestsBase.java
index 0051dd4..0f6b59c 100644
--- a/src/test/java/org/apache/sysml/test/gpu/UnaryOpTestsBase.java
+++ b/src/test/java/org/apache/sysml/test/gpu/UnaryOpTestsBase.java
@@ -31,10 +31,10 @@ import org.apache.sysml.api.mlcontext.Matrix;
 public abstract class UnaryOpTestsBase extends GPUTests {
 
 	// Set of rows and column sizes & sparsities to test unary ops
-	private final int[] rowSizes = new int[] { 2049, 1024, 140, 64, 1 };
-	private final int[] columnSizes = new int[] { 2049, 1024, 140, 64, 1 };
-	private final double[] sparsities = new double[] { 0.9, 0.3, 0.03, 0.0 };
-	private final int seed = 42;
+	protected final int[] rowSizes = new int[] { 2049, 1024, 140, 64, 1 };
+	protected final int[] columnSizes = new int[] { 2049, 1024, 150, 64, 1 };
+	protected final double[] sparsities = new double[] { 0.9, 0.3, 0.03, 0.0 };
+	protected final int seed = 42;
 
 	/**
 	 * Tests unary ops with a variety of matrix shapes and sparsities.

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/test/java/org/apache/sysml/test/integration/functions/tensor/ChannelSumTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/tensor/ChannelSumTest.java b/src/test/java/org/apache/sysml/test/integration/functions/tensor/ChannelSumTest.java
new file mode 100644
index 0000000..61ca370
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/tensor/ChannelSumTest.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.test.integration.functions.tensor;
+
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Test;
+
+public class ChannelSumTest extends AutomatedTestBase
+{
+	
+	private final static String TEST_NAME = "ChannelSumTest";
+	private final static String TEST_DIR = "functions/tensor/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + ChannelSumTest.class.getSimpleName() + "/";
+	private final static double epsilon=0.0000000001;
+	
+	@Override
+	public void setUp() {
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, 
+				new String[] {"B"}));
+	}
+	
+	@Test
+	public void testChannelSumDense1() 
+	{
+		int numImg = 10; int imgSize = 9; int numChannels = 5; 
+		runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, false);
+	}
+	
+	@Test
+	public void testChannelSumDense2() 
+	{
+		int numImg = 2; int imgSize = 5; int numChannels = 3; 
+		runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, false);
+	}
+	
+	@Test
+	public void testChannelSumDense3() 
+	{
+		int numImg = 9; int imgSize = 4; int numChannels = 11; 
+		runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, false);
+	}
+	
+	@Test
+	public void testChannelSumDense4() 
+	{
+		int numImg = 7; int imgSize = 8; int numChannels = 12; 
+		runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, false);
+	}
+	
+	@Test
+	public void testChannelSumSparse1() 
+	{
+		int numImg = 4; int imgSize = 10; int numChannels = 5; 
+		runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, true);
+	}
+	
+	@Test
+	public void testChannelSumSparse2() 
+	{
+		int numImg = 2; int imgSize = 10; int numChannels = 8; 
+		runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, true);
+	}
+	
+	@Test
+	public void testChannelSumSparse3() 
+	{
+		int numImg = 4; int imgSize = 10; int numChannels = 11; 
+		runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, true);
+	}
+	
+	@Test
+	public void testChannelSumSparse4() 
+	{
+		int numImg = 9; int imgSize = 6; int numChannels = 8; 
+		runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, true);
+	}
+	
+	public void runChannelSumTest( ExecType et, int imgSize, int numImg, int numChannels, boolean sparse) 
+	{
+		RUNTIME_PLATFORM platformOld = rtplatform;
+		switch( et ){
+			case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+			case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+			default: rtplatform = RUNTIME_PLATFORM.HYBRID; break;
+		}
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		if( rtplatform == RUNTIME_PLATFORM.SPARK )
+			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+		
+		try
+		{
+			String sparseVal = String.valueOf(sparse).toUpperCase();
+			
+			TestConfiguration config = getTestConfiguration(TEST_NAME);
+			loadTestConfiguration(config);
+	
+			String RI_HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = RI_HOME + TEST_NAME + ".dml";
+			programArgs = new String[]{"-explain", "hops", "-args", String.valueOf(imgSize), 
+				String.valueOf(numImg), String.valueOf(numChannels),
+				output("B"), sparseVal};
+			
+			fullRScriptName = RI_HOME + TEST_NAME + ".R";
+			rCmd = "Rscript" + " " + fullRScriptName + " " + imgSize + " " + numImg + 
+				" " + numChannels + " " + expectedDir() + " " + sparseVal; 
+			
+			// run scripts
+			runTest(true, false, null, -1);
+			runRScript(true);
+			
+			//compare results
+			HashMap<CellIndex, Double> bHM = readRMatrixFromFS("B");
+			HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("B");
+			TestUtils.compareMatrices(dmlfile, bHM, epsilon, "B-DML", "B-R");
+		}
+		finally {
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+	
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/test/scripts/functions/tensor/ChannelSumTest.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/tensor/ChannelSumTest.R b/src/test/scripts/functions/tensor/ChannelSumTest.R
new file mode 100644
index 0000000..c605074
--- /dev/null
+++ b/src/test/scripts/functions/tensor/ChannelSumTest.R
@@ -0,0 +1,39 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+args <- commandArgs(TRUE)
+library("Matrix")
+library("matrixStats") 
+imgSize=as.integer(args[1])
+numImg=as.integer(args[2])
+numChannels=as.integer(args[3])
+
+# Assumption: NCHW image format
+x=matrix(seq(1, numImg*numChannels*imgSize*imgSize), numImg, numChannels*imgSize*imgSize, byrow=TRUE)
+if(as.logical(args[5])) {
+	zero_mask = (x - 1.5*mean(x)) > 0 
+	x = x * zero_mask
+} else {
+	x = x - mean(x)
+}
+
+output = rowSums(matrix(colSums(x), numChannels, imgSize*imgSize, byrow=TRUE));
+
+writeMM(as(output,"CsparseMatrix"), paste(args[4], "B", sep=""))
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/test/scripts/functions/tensor/ChannelSumTest.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/tensor/ChannelSumTest.dml b/src/test/scripts/functions/tensor/ChannelSumTest.dml
new file mode 100644
index 0000000..7810a12
--- /dev/null
+++ b/src/test/scripts/functions/tensor/ChannelSumTest.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# 
+#-------------------------------------------------------------
+imgSize=$1
+numImg=$2
+numChannels=$3
+
+# Assumption: NCHW image format
+x=matrix(seq(1, numImg*numChannels*imgSize*imgSize), rows=numImg, cols=numChannels*imgSize*imgSize)
+if($5) {
+	zero_mask = (x - 1.5*mean(x)) > 0 
+	x = x * zero_mask
+}
+else {
+	x = x - mean(x)
+}
+output = rowSums(matrix(colSums(x), rows=numChannels, cols=imgSize*imgSize))  # shape (C, 1)
+write(output, $4, format="text")
\ No newline at end of file