You are viewing a plain-text version of this content. (The link to the canonical version was lost when this copy was converted to plain text.)
Posted to commits@systemml.apache.org by na...@apache.org on 2017/05/04 23:27:43 UTC

incubator-systemml git commit: [HOTFIX] Bug fix for solve, removed warnings and added instrumentation

Repository: incubator-systemml
Updated Branches:
  refs/heads/master 76f3ca5d3 -> 2c5c3b14e


[HOTFIX] Bug fix for solve, removed warnings and added instrumentation


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/2c5c3b14
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/2c5c3b14
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/2c5c3b14

Branch: refs/heads/master
Commit: 2c5c3b14e1906cda70ae1581b19a5e908b3ab329
Parents: 76f3ca5
Author: Nakul Jindal <na...@gmail.com>
Authored: Thu May 4 16:26:47 2017 -0700
Committer: Nakul Jindal <na...@gmail.com>
Committed: Thu May 4 16:26:47 2017 -0700

----------------------------------------------------------------------
 .../instructions/GPUInstructionParser.java      |  4 +-
 .../gpu/BuiltinBinaryGPUInstruction.java        |  2 +
 .../instructions/gpu/GPUInstruction.java        | 28 ++++---
 .../gpu/MatrixMatrixBuiltinGPUInstruction.java  |  1 +
 .../instructions/gpu/context/GPUContext.java    |  2 +
 .../instructions/gpu/context/GPUObject.java     |  3 +-
 .../runtime/matrix/data/LibMatrixCUDA.java      | 77 +++++++++++++++-----
 7 files changed, 86 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2c5c3b14/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index ef0412c..4a45521 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -35,9 +35,9 @@ import org.apache.sysml.runtime.instructions.gpu.AggregateUnaryGPUInstruction;
 
 public class GPUInstructionParser  extends InstructionParser 
 {
-	public static final HashMap<String, GPUINSTRUCTION_TYPE> String2GPUInstructionType;
+	static final HashMap<String, GPUINSTRUCTION_TYPE> String2GPUInstructionType;
 	static {
-		String2GPUInstructionType = new HashMap<String, GPUINSTRUCTION_TYPE>();
+		String2GPUInstructionType = new HashMap<>();
 
 		// Neural Network Operators
 		String2GPUInstructionType.put( "relu_backward",          GPUINSTRUCTION_TYPE.Convolution);

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2c5c3b14/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinBinaryGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinBinaryGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinBinaryGPUInstruction.java
index 372f883..24e9e79 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinBinaryGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinBinaryGPUInstruction.java
@@ -30,7 +30,9 @@ import org.apache.sysml.runtime.matrix.operators.Operator;
 
 public abstract class BuiltinBinaryGPUInstruction extends GPUInstruction {
 
+  @SuppressWarnings("unused")
   private int _arity;
+
   CPOperand output;
   CPOperand input1, input2;
 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2c5c3b14/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
index 9eef072..f4c523b 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
@@ -35,16 +35,20 @@ public abstract class GPUInstruction extends Instruction
 	public enum GPUINSTRUCTION_TYPE { AggregateUnary, AggregateBinary, Convolution, MMTSJ, Reorg, ArithmeticBinary, BuiltinUnary, BuiltinBinary, Builtin };
 
 	// Memory/conversions
-	public final static String MISC_TIMER_HOST_TO_DEVICE = 				"H2D";	// time spent in bringing data to gpu (from host)
-	public final static String MISC_TIMER_DEVICE_TO_HOST =				"D2H"; 	// time spent in bringing data from gpu (to host)
-	public final static String MISC_TIMER_DEVICE_TO_DEVICE = 			"D2D"; 	// time spent in copying data from one region on the device to another
-	public final static String MISC_TIMER_SPARSE_TO_DENSE = 			"s2d";	// time spent in converting data from sparse to dense
-	public final static String MISC_TIMER_DENSE_TO_SPARSE = 			"d2s";	// time spent in converting data from dense to sparse
-	public final static String MISC_TIMER_CUDA_FREE = 						"f";		// time spent in calling cudaFree
-	public final static String MISC_TIMER_ALLOCATE = 							"a";		// time spent to allocate memory on gpu
-	public final static String MISC_TIMER_ALLOCATE_DENSE_OUTPUT = "ao";		// time spent to allocate dense output (recorded differently than MISC_TIMER_ALLOCATE)
-	public final static String MISC_TIMER_SET_ZERO = 							"az";		// time spent to allocate
-	public final static String MISC_TIMER_REUSE = 								"r";		// time spent in reusing already allocated memory on GPU (mainly for the count)
+	public final static String MISC_TIMER_HOST_TO_DEVICE =          "H2D";	// time spent in bringing data to gpu (from host)
+	public final static String MISC_TIMER_DEVICE_TO_HOST =          "D2H"; 	// time spent in bringing data from gpu (to host)
+	public final static String MISC_TIMER_DEVICE_TO_DEVICE =        "D2D"; 	// time spent in copying data from one region on the device to another
+	public final static String MISC_TIMER_SPARSE_TO_DENSE =         "s2d";	// time spent in converting data from sparse to dense
+	public final static String MISC_TIMER_DENSE_TO_SPARSE =         "d2s";	// time spent in converting data from dense to sparse
+	public final static String MISC_TIMER_ROW_TO_COLUMN_MAJOR =     "r2c";	// time spent in converting data from row major to column major
+	public final static String MISC_TIMER_COLUMN_TO_ROW_MAJOR =     "c2r";	// time spent in converting data from column major to row major
+	public final static String MISC_TIMER_OBJECT_CLONE =            "clone";// time spent in cloning (deep copying) a GPUObject instance
+
+	public final static String MISC_TIMER_CUDA_FREE =               "f";		// time spent in calling cudaFree
+	public final static String MISC_TIMER_ALLOCATE =                "a";		// time spent to allocate memory on gpu
+	public final static String MISC_TIMER_ALLOCATE_DENSE_OUTPUT =   "ao";		// time spent to allocate dense output (recorded differently than MISC_TIMER_ALLOCATE)
+	public final static String MISC_TIMER_SET_ZERO =                "az";		// time spent to allocate
+	public final static String MISC_TIMER_REUSE =                   "r";		// time spent in reusing already allocated memory on GPU (mainly for the count)
 
 	// Matmult instructions
 	public final static String MISC_TIMER_SPARSE_ALLOCATE_LIB = 						"Msao";		// time spend in allocating for sparse matrix output
@@ -58,6 +62,10 @@ public abstract class GPUInstruction extends Instruction
 
 	// Other BLAS instructions
 	public final static String MISC_TIMER_DAXPY_LIB = "daxpy";	// time spent in daxpy
+	public final static String MISC_TIMER_QR_BUFFER = "qr_buffer"; 	// time spent in calculating buffer needed to perform QR
+	public final static String MISC_TIMER_QR = "qr"; 	// time spent in doing QR
+	public final static String MISC_TIMER_ORMQR = "ormqr"; // time spent in ormqr
+	public final static String MISC_TIMER_TRSM = "trsm"; // time spent in cublas Dtrsm
 
 	// Transpose
 	public final static String MISC_TIMER_SPARSE_DGEAM_LIB = 	"sdgeaml"; 	// time spent in sparse transpose (and other ops of type a*op(A) + b*op(B))

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2c5c3b14/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixMatrixBuiltinGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixMatrixBuiltinGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixMatrixBuiltinGPUInstruction.java
index f492b6e..8936735 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixMatrixBuiltinGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixMatrixBuiltinGPUInstruction.java
@@ -45,6 +45,7 @@ public class MatrixMatrixBuiltinGPUInstruction extends BuiltinBinaryGPUInstructi
     MatrixObject mat2 = getMatrixInputForGPUInstruction(ec, input2.getName());
 
     if(opcode.equals("solve")) {
+      ec.setMetaData(output.getName(), mat1.getNumColumns(), 1);
       LibMatrixCUDA.solve(ec, ec.getGPUContext(), getExtendedOpcode(), mat1, mat2, output.getName());
 
     } else {

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2c5c3b14/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
index d71f725..673601f 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
@@ -307,6 +307,8 @@ public class GPUContext {
         freeList = new LinkedList<Pointer>();
         freeCUDASpaceMap.put(size, freeList);
       }
+      if (freeList.contains(toFree))
+        throw new RuntimeException("GPU : Internal state corrupted, double free");
       freeList.add(toFree);
     }
   }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2c5c3b14/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
index 1d2285d..d735e38 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
@@ -26,7 +26,6 @@ import static jcuda.jcudnn.cudnnDataType.CUDNN_DATA_DOUBLE;
 import static jcuda.jcudnn.cudnnTensorFormat.CUDNN_TENSOR_NCHW;
 import static jcuda.jcusparse.JCusparse.cusparseDdense2csr;
 import static jcuda.jcusparse.JCusparse.cusparseDnnz;
-import static jcuda.runtime.JCuda.cudaMalloc;
 import static jcuda.runtime.JCuda.cudaMemcpy;
 import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToDevice;
 import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToHost;
@@ -343,7 +342,7 @@ public class GPUObject {
 
 	/**
 	 * Convenience method. Converts Column Major Dense Matrix to Row Major Dense Matrix
-	 * @throws DMLRuntimeException
+	 * @throws DMLRuntimeException if error
 	 */
 	public void denseColumnMajorToRowMajor() throws DMLRuntimeException {
 		LOG.trace("GPU : dense Ptr row-major -> col-major on " + this + ", GPUContext=" + getGPUContext());

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2c5c3b14/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index 23304b5..a99571a 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -329,6 +329,7 @@ public class LibMatrixCUDA {
 	 * @return a sparse matrix pointer
 	 * @throws DMLRuntimeException if error occurs
 	 */
+	@SuppressWarnings("unused")
 	private static CSRPointer getSparsePointer(GPUContext gCtx, MatrixObject input, String instName) throws DMLRuntimeException {
 		if(!isInSparseFormat(gCtx, input)) {
 			input.getGPUObject(gCtx).denseToSparse();
@@ -2754,6 +2755,25 @@ public class LibMatrixCUDA {
 		Pointer betaPtr = pointerTo(beta);
 		int transa = isLeftTransposed ? CUBLAS_OP_T : CUBLAS_OP_N;
 		int transb = isRightTransposed ? CUBLAS_OP_T : CUBLAS_OP_N;
+
+		int lda = (int) in1.getNumColumns();
+		int ldb = (int) in2.getNumColumns();
+		int m = (int) in1.getNumColumns();
+		int n = (int) in2.getNumRows();
+		if (isLeftTransposed && isRightTransposed) {
+			m = (int) in1.getNumRows();
+			n = (int) in2.getNumColumns();
+		}
+		else if (isLeftTransposed) {
+			m = (int) in1.getNumRows();
+		} else if (isRightTransposed) {
+			n = (int) in2.getNumColumns();
+		}
+		int ldc = m;
+
+
+
+		/**
 		int m = (int) in1.getNumRows();
 		int n = (int) in1.getNumColumns();
 		if(!isLeftTransposed && isRightTransposed) {
@@ -2763,6 +2783,7 @@ public class LibMatrixCUDA {
 		int lda = isLeftTransposed ? n : m;
 		int ldb = isRightTransposed ? n : m;
 		int ldc = m;
+		**/
 
 		MatrixObject out = ec.getMatrixObject(outputName);
 		boolean isSparse1 = isInSparseFormat(gCtx, in1);
@@ -2963,8 +2984,10 @@ public class LibMatrixCUDA {
             throw new DMLRuntimeException("GPU : Invalid internal state, the GPUContext set with the ExecutionContext is not the same used to run this LibMatrixCUDA function");
 
         // x = solve(A, b)
+		LOG.trace("GPU : solve" + ", GPUContext=" + gCtx);
+
+		long t0 = -1;
 
-        // Both Sparse
         if (!isInSparseFormat(gCtx, in1) && !isInSparseFormat(gCtx, in2)) {    // Both dense
             GPUObject Aobj = in1.getGPUObject(gCtx);
             GPUObject bobj = in2.getGPUObject(gCtx);
@@ -2980,55 +3003,75 @@ public class LibMatrixCUDA {
             // convert dense matrices to row major
             // Operation in cuSolver and cuBlas are for column major dense matrices
             // and are destructive to the original input
+			if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
             GPUObject ATobj = (GPUObject) Aobj.clone();
-            ATobj.denseRowMajorToColumnMajor();
+			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_OBJECT_CLONE, System.nanoTime() - t0);
+			if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
+			ATobj.denseRowMajorToColumnMajor();
+            if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_ROW_TO_COLUMN_MAJOR, System.nanoTime() - t0);
             Pointer A = ATobj.getJcudaDenseMatrixPtr();
 
+			if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
             GPUObject bTobj = (GPUObject) bobj.clone();
-            bTobj.denseRowMajorToColumnMajor();
-            Pointer b = bTobj.getJcudaDenseMatrixPtr();
+			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_OBJECT_CLONE, System.nanoTime() - t0);
+			if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
+			bTobj.denseRowMajorToColumnMajor();
+			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_ROW_TO_COLUMN_MAJOR, System.nanoTime() - t0);
+
+			Pointer b = bTobj.getJcudaDenseMatrixPtr();
 
             // The following set of operations is done following the example in the cusolver documentation
             // http://docs.nvidia.com/cuda/cusolver/#ormqr-example1
 
             // step 3: query working space of geqrf and ormqr
-            int[] lwork = {0};
+			if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
+			int[] lwork = {0};
             JCusolverDn.cusolverDnDgeqrf_bufferSize(gCtx.getCusolverDnHandle(), m, n, A, m, lwork);
+			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_QR_BUFFER, System.nanoTime() - t0);
+
 
-            // step 4: compute QR factorization
-            Pointer work = gCtx.allocate(lwork[0] * Sizeof.DOUBLE);
-            Pointer tau = gCtx.allocate(Math.max(m, m) * Sizeof.DOUBLE);
+			// step 4: compute QR factorization
+            Pointer work = gCtx.allocate(instName, lwork[0] * Sizeof.DOUBLE);
+            Pointer tau = gCtx.allocate(instName, Math.max(m, m) * Sizeof.DOUBLE);
             Pointer devInfo = gCtx.allocate(Sizeof.INT);
-            JCusolverDn.cusolverDnDgeqrf(gCtx.getCusolverDnHandle(), m, n, A, m, tau, work, lwork[0], devInfo);
+			if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
+			JCusolverDn.cusolverDnDgeqrf(gCtx.getCusolverDnHandle(), m, n, A, m, tau, work, lwork[0], devInfo);
+			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_QR, System.nanoTime() - t0);
 
-            int[] qrError = {-1};
+
+			int[] qrError = {-1};
             cudaMemcpy(Pointer.to(qrError), devInfo, Sizeof.INT, cudaMemcpyDeviceToHost);
             if (qrError[0] != 0) {
                 throw new DMLRuntimeException("GPU : Error in call to geqrf (QR factorization) as part of solve, argument " + qrError[0] + " was wrong");
             }
 
             // step 5: compute Q^T*B
-            JCusolverDn.cusolverDnDormqr(gCtx.getCusolverDnHandle(), cublasSideMode.CUBLAS_SIDE_LEFT, cublasOperation.CUBLAS_OP_T, m, 1, n, A, m, tau, b, m, work, lwork[0], devInfo);
-            cudaMemcpy(Pointer.to(qrError), devInfo, Sizeof.INT, cudaMemcpyDeviceToHost);
+			if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
+			JCusolverDn.cusolverDnDormqr(gCtx.getCusolverDnHandle(), cublasSideMode.CUBLAS_SIDE_LEFT, cublasOperation.CUBLAS_OP_T, m, 1, n, A, m, tau, b, m, work, lwork[0], devInfo);
+			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_ORMQR, System.nanoTime() - t0);
+			cudaMemcpy(Pointer.to(qrError), devInfo, Sizeof.INT, cudaMemcpyDeviceToHost);
             if (qrError[0] != 0) {
                 throw new DMLRuntimeException("GPU : Error in call to ormqr (to compuete Q^T*B after QR factorization) as part of solve, argument " + qrError[0] + " was wrong");
             }
 
             // step 6: compute x = R \ Q^T*B
-            JCublas2.cublasDtrsm(gCtx.getCublasHandle(),
+			if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
+			JCublas2.cublasDtrsm(gCtx.getCublasHandle(),
                 cublasSideMode.CUBLAS_SIDE_LEFT, cublasFillMode.CUBLAS_FILL_MODE_UPPER, cublasOperation.CUBLAS_OP_N, cublasDiagType.CUBLAS_DIAG_NON_UNIT,
                 n, 1, pointerTo(1.0), A, m, b, m);
+			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_TRSM, System.nanoTime() - t0);
 
-            bTobj.denseColumnMajorToRowMajor();
+			if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
+			bTobj.denseColumnMajorToRowMajor();
+			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_COLUMN_TO_ROW_MAJOR, System.nanoTime() - t0);
 
             // TODO  : Find a way to assign bTobj directly to the output and set the correct flags so as to not crash
             // There is an avoidable copy happening here
             MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, instName, outputName);
             cudaMemcpy(out.getGPUObject(gCtx).getJcudaDenseMatrixPtr(), bTobj.getJcudaDenseMatrixPtr(), n * 1 * Sizeof.DOUBLE, cudaMemcpyDeviceToDevice);
 
-            gCtx.cudaFreeHelper(work);
-            gCtx.cudaFreeHelper(tau);
-            gCtx.cudaFreeHelper(tau);
+            gCtx.cudaFreeHelper(instName, work);
+            gCtx.cudaFreeHelper(instName, tau);
             ATobj.clearData();
             bTobj.clearData();