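You are viewing a plain-text version of this content. The canonical version is the original message in the commits@systemml.apache.org mailing-list archive (the hyperlink was lost in this plain-text copy).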
Posted to commits@systemml.apache.org by ni...@apache.org on 2017/10/26 03:00:59 UTC

systemml git commit: [SYSTEMML-446] Bugfix for GPU sparse right indexing with empty output

Repository: systemml
Updated Branches:
  refs/heads/master abbffc55e -> d3917effd


[SYSTEMML-446] Bugfix for GPU sparse right indexing with empty output


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/d3917eff
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/d3917eff
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/d3917eff

Branch: refs/heads/master
Commit: d3917effd988de0e0977a310c73c4f232214632e
Parents: abbffc5
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Wed Oct 25 19:57:28 2017 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Wed Oct 25 19:57:28 2017 -0700

----------------------------------------------------------------------
 .../gpu/context/ExecutionConfig.java            | 29 ++------------------
 .../runtime/matrix/data/LibMatrixCUDA.java      |  8 ++++--
 2 files changed, 7 insertions(+), 30 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/d3917eff/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/ExecutionConfig.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/ExecutionConfig.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/ExecutionConfig.java
index 7f8eb9e..cae0660 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/ExecutionConfig.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/ExecutionConfig.java
@@ -69,6 +69,8 @@ public class ExecutionConfig {
 	 * @throws DMLRuntimeException if DMLRuntimeException occurs
 	 */
 	public static ExecutionConfig getConfigForSimpleVectorOperations(int numCells) throws DMLRuntimeException {
+		if(numCells == 0)
+			throw new DMLRuntimeException("Attempting to invoke a kernel with 0 threads");
 		int deviceNumber = 0;
 		int blockDimX = getMaxBlockDim(deviceNumber);
 		int gridDimX = (int) Math.ceil((double) numCells / blockDimX);
@@ -76,32 +78,6 @@ public class ExecutionConfig {
 	}
 
 	/**
-	 * Use this for simple matrix operations and use following in the kernel
-	 * <code>
-	 * int ix = blockIdx.x * blockDim.x + threadIdx.x;
-	 * int iy = blockIdx.y * blockDim.y + threadIdx.y;
-	 * </code>
-	 * <p>
-	 * This tries to schedule as minimum grids as possible.
-	 *
-	 * @param rlen number of rows
-	 * @param clen number of columns
-	 * @return execution configuration
-	 * @throws DMLRuntimeException if DMLRuntimeException occurs
-	 */
-	public static ExecutionConfig getConfigForMatrixOperations(int rlen, int clen) throws DMLRuntimeException {
-		int deviceNumber = 0;
-		int maxBlockDim = getMaxBlockDim(deviceNumber);
-		int blockDimX = (int) Math.min(maxBlockDim, rlen);
-		int gridDimX = (int) Math.ceil((double) rlen / blockDimX);
-		int blockDimY = (int) Math.min(Math.floor(((double) maxBlockDim) / blockDimX), clen);
-		int gridDimY = (int) Math.ceil((double) clen / blockDimY);
-		if (gridDimY > 65535)
-			throw new DMLRuntimeException("Internal Error: gridDimY must be less than 65535 for all supported CUDA compute capabilites!");
-		return new ExecutionConfig(gridDimX, gridDimY, blockDimX, blockDimY);
-	}
-
-	/**
 	 * Use this for simple vector operations and use following in the kernel
 	 * <code>
 	 * int index = blockIdx.x * blockDim.x + threadIdx.x
@@ -116,7 +92,6 @@ public class ExecutionConfig {
 		return getConfigForSimpleVectorOperations(rlen * clen);
 	}
 
-
 	public ExecutionConfig(int gridDimX, int blockDimX) {
 		this.gridDimX = gridDimX;
 		this.blockDimX = blockDimX;

http://git-wip-us.apache.org/repos/asf/systemml/blob/d3917eff/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index eb17e69..2cccde0 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -1821,17 +1821,19 @@ public class LibMatrixCUDA {
 	 */
 	protected static void sliceSparseDense(GPUContext gCtx, String instName, CSRPointer inPointer, Pointer outPointer, 
 			int rl, int ru, int cl, int cu, int inClen) throws DMLRuntimeException {
+		int size = getNnz(inPointer, rl, ru);
+		// Return since nnz of the output is 0 as outPointer is expected to be zeroed out.
+		if(size == 0) return;
+		
 		int retRlen = ru - rl + 1;
 		long t0 = GPUStatistics.DISPLAY_STATISTICS ? System.nanoTime() : 0;
 		int retClen = cu - cl + 1;
 		
-		int size = -1; String kernel = null; String timer = null;
-		
+		String kernel = null; String timer = null;
 		// Note: row-wise parallelization scheme iterates over input rows in single thread 
 		// whereas nnz parallelization scheme iterates over number of output rows in single thread.
 		if(inClen > 10 && retClen > 2*retRlen) {
 			// Perform nnz parallelization for wide and short matrices
-			size = getNnz(inPointer, rl, ru);
 			timer = GPUInstruction.MISC_TIMER_RIX_SPARSE_DENSE_OP_NNZ;
 			kernel = "slice_sparse_dense_nnz";
 		}