Posted to commits@systemml.apache.org by ni...@apache.org on 2017/09/07 19:50:02 UTC

[5/5] systemml git commit: [SYSTEMML-540] Support sparse GPU conv2d as well as fix memory estimation of convolution operations

[SYSTEMML-540] Support sparse GPU conv2d as well as fix memory estimation of convolution operations

Design doc: Memory estimation of GPU operators

- Since not all operators are supported on GPU, isGPUEnabled indicates
whether an operation is enabled for GPU. This method does not take into
account any memory estimates.
- To simplify the memory estimation logic, the methods computeOutputMemEstimate
and computeIntermediateMemEstimate should return the maximum of the memory
required by the GPU and CP operators.
- Additionally, these methods are guarded so that when the -gpu flag is not
provided, additional memory overheads due to the GPU (for example,
sparse-to-dense conversion on GPU) are ignored.
- (WIP) Every GPU operator should respect the memory returned by
computeIntermediateMemEstimate (and computeOutputMemEstimate - see the
point below).
- (WIP) Every GPU operator should create its output in the same format as the
corresponding CP operator, so that the computeOutputMemEstimate estimates
are consistent across both CP and GPU in terms of the worst case.
- The drawbacks of using the maximum memory (mem = Math.max(mem_gpu, mem_cpu))
are:
a. the GPU operator is not selected when mem_gpu < total memory available on GPU
< mem
b. the CP operator is not selected (i.e. a distributed operator is compiled) when
mem_cpu < driver memory budget < mem
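The max-based, flag-guarded estimate described by the points above can be
sketched as follows. This is a minimal illustration only; the class and method
names (MemEstimateSketch, estimateMem, gpuEnabled) are hypothetical and are not
the actual ConvolutionOp/Hop code in this commit:

```java
// Minimal sketch of the max-based memory estimate described above.
// All names here are hypothetical, not SystemML's actual API.
public class MemEstimateSketch {
    // Worst-case estimate: the maximum of CP and GPU requirements,
    // but GPU overheads are only counted when -gpu was provided.
    static long estimateMem(long memCp, long memGpu, boolean gpuEnabled) {
        return gpuEnabled ? Math.max(memCp, memGpu) : memCp;
    }

    public static void main(String[] args) {
        // Without -gpu: GPU-only overheads (e.g. sparse-to-dense) are ignored.
        System.out.println(estimateMem(100, 400, false)); // prints 100
        // With -gpu: the larger GPU requirement dominates the estimate,
        // which is exactly what causes drawbacks (a) and (b) above.
        System.out.println(estimateMem(100, 400, true));  // prints 400
    }
}
```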

Closes #650.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/772d9302
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/772d9302
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/772d9302

Branch: refs/heads/master
Commit: 772d9302dc196b047134ea491542d55113f52a08
Parents: a0cf8e3
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Thu Sep 7 11:49:52 2017 -0800
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Thu Sep 7 12:49:52 2017 -0700

----------------------------------------------------------------------
 src/main/cpp/kernels/SystemML.cu                |   19 +
 src/main/cpp/kernels/SystemML.ptx               | 2757 +++++++++---------
 .../org/apache/sysml/hops/ConvolutionOp.java    |  458 ++-
 src/main/java/org/apache/sysml/hops/Hop.java    |   63 +-
 .../apache/sysml/lops/ConvolutionTransform.java |   11 +-
 .../cp/ConvolutionCPInstruction.java            |   91 +-
 .../gpu/ConvolutionGPUInstruction.java          |   72 +-
 .../gpu/MatrixBuiltinGPUInstruction.java        |    3 +-
 .../instructions/gpu/context/GPUContext.java    |    6 +
 .../matrix/data/ConvolutionParameters.java      |   25 +
 .../runtime/matrix/data/LibMatrixCUDA.java      | 1041 +------
 .../runtime/matrix/data/LibMatrixCuDNN.java     | 1219 ++++++++
 12 files changed, 3229 insertions(+), 2536 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/772d9302/src/main/cpp/kernels/SystemML.cu
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu
index d64d8aa..bb6482d 100644
--- a/src/main/cpp/kernels/SystemML.cu
+++ b/src/main/cpp/kernels/SystemML.cu
@@ -156,6 +156,25 @@ __global__ void relu_backward(double* X,  double* dout, double* ret, int rlen, i
 	}
 }
 
+/**
+ * Performs inplace addition: ret += input
+ *
+ * @param input rhs input array allocated on the GPU
+ * @param ret the input and output array allocated on the GPU
+ * @param rlen the number of rows
+ * @param clen the number of columns
+ */
+extern "C"
+__global__ void inplace_add(double* input,  double* ret, int rlen, int clen) {
+	int tid = blockIdx.x * blockDim.x + threadIdx.x;
+	int ix = tid / clen;
+	int iy = tid % clen;
+	if(ix < rlen && iy < clen) {
+		int index = ix * clen + iy;
+		ret[index] += input[index];
+	}
+}
+
 // Performs the operation corresponding to the DML script:
 // ones = matrix(1, rows=1, cols=Hout*Wout)
 // output = input + matrix(bias %*% ones, rows=1, cols=F*Hout*Wout)
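The inplace_add kernel in the hunk above assigns one GPU thread per matrix
element, mapping thread id tid to element (tid / clen, tid % clen). A host-side
sketch of the same indexing (plain Java, for illustration only; the class and
method names are hypothetical) behaves like this:

```java
// Host-side emulation of the inplace_add kernel's indexing: each "thread"
// tid maps to element (tid / clen, tid % clen) and accumulates in place.
public class InplaceAddSketch {
    static void inplaceAdd(double[] input, double[] ret, int rlen, int clen) {
        for (int tid = 0; tid < rlen * clen; tid++) {
            int ix = tid / clen;          // row index
            int iy = tid % clen;          // column index
            if (ix < rlen && iy < clen) { // mirrors the kernel's bounds guard
                int index = ix * clen + iy;
                ret[index] += input[index];
            }
        }
    }

    public static void main(String[] args) {
        double[] in  = {1, 2, 3, 4};
        double[] ret = {10, 20, 30, 40};
        inplaceAdd(in, ret, 2, 2);
        // ret is now {11.0, 22.0, 33.0, 44.0}
        System.out.println(java.util.Arrays.toString(ret));
    }
}
```

On the GPU the loop is replaced by the grid of threads, so any tid beyond
rlen * clen (from the last partial block) is rejected by the guard.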