You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2017/10/26 02:29:03 UTC

[2/4] systemml git commit: [SYSTEMML-1969] Support single-precision operations on GPU backend

http://git-wip-us.apache.org/repos/asf/systemml/blob/abbffc55/src/main/java/org/apache/sysml/api/DMLScript.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/DMLScript.java b/src/main/java/org/apache/sysml/api/DMLScript.java
index ba447cf..4da874e 100644
--- a/src/main/java/org/apache/sysml/api/DMLScript.java
+++ b/src/main/java/org/apache/sysml/api/DMLScript.java
@@ -163,6 +163,7 @@ public class DMLScript
 	public static boolean           ENABLE_DEBUG_MODE   = DMLOptions.defaultOptions.debug;       // debug mode
 	public static ExplainType       EXPLAIN             = DMLOptions.defaultOptions.explainType; // explain type
 	public static String            DML_FILE_PATH_ANTLR_PARSER = DMLOptions.defaultOptions.filePath; // filename of dml/pydml script
+	public static String            FLOATING_POINT_PRECISION = "double"; 							// data type to use internally
 
 	/**
 	 * Global variable indicating the script type (DML or PYDML). Can be used

http://git-wip-us.apache.org/repos/asf/systemml/blob/abbffc55/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java b/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java
index a49ffda..51ab6a1 100644
--- a/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java
+++ b/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java
@@ -81,6 +81,10 @@ public class ScriptExecutorUtils {
 		DMLScript.SYNCHRONIZE_GPU = dmlconf.getBooleanValue(DMLConfig.SYNCHRONIZE_GPU);
 		DMLScript.EAGER_CUDA_FREE = dmlconf.getBooleanValue(DMLConfig.EAGER_CUDA_FREE);
 		DMLScript.STATISTICS_MAX_WRAP_LEN = dmlconf.getIntValue(DMLConfig.STATS_MAX_WRAP_LEN);
+		if(DMLScript.USE_ACCELERATOR) {
+			DMLScript.FLOATING_POINT_PRECISION = dmlconf.getTextValue(DMLConfig.FLOATING_POINT_PRECISION);
+			org.apache.sysml.runtime.matrix.data.LibMatrixCUDA.resetFloatingPointPrecision();
+		}
 
 		boolean exceptionThrown = false;
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/abbffc55/src/main/java/org/apache/sysml/conf/DMLConfig.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/conf/DMLConfig.java b/src/main/java/org/apache/sysml/conf/DMLConfig.java
index 0b73ab0..e8bde56 100644
--- a/src/main/java/org/apache/sysml/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysml/conf/DMLConfig.java
@@ -92,6 +92,7 @@ public class DMLConfig
 	// Fraction of available memory to use. The available memory is computer when the GPUContext is created
 	// to handle the tradeoff on calling cudaMemGetInfo too often.
 	public static final String GPU_MEMORY_UTILIZATION_FACTOR = "sysml.gpu.memory.util.factor";
+	public static final String FLOATING_POINT_PRECISION = "sysml.floating.point.precision"; // String to specify the datatype to use internally: supported values are double, single
 
 	// supported prefixes for custom map/reduce configurations
 	public static final String PREFIX_MAPRED = "mapred";
@@ -139,6 +140,7 @@ public class DMLConfig
 		_defaultVals.put(AVAILABLE_GPUS,         "-1");
 		_defaultVals.put(SYNCHRONIZE_GPU,        "true" );
 		_defaultVals.put(EAGER_CUDA_FREE,        "false" );
+		_defaultVals.put(FLOATING_POINT_PRECISION,        	 "double" );
 	}
 	
 	public DMLConfig()
@@ -421,7 +423,7 @@ public class DMLConfig
 				COMPRESSED_LINALG, 
 				CODEGEN, CODEGEN_COMPILER, CODEGEN_OPTIMIZER, CODEGEN_PLANCACHE, CODEGEN_LITERALS,
 				EXTRA_GPU_STATS, EXTRA_DNN_STATS, EXTRA_FINEGRAINED_STATS, STATS_MAX_WRAP_LEN,
-				AVAILABLE_GPUS, SYNCHRONIZE_GPU, EAGER_CUDA_FREE
+				AVAILABLE_GPUS, SYNCHRONIZE_GPU, EAGER_CUDA_FREE, FLOATING_POINT_PRECISION
 		}; 
 		
 		StringBuilder sb = new StringBuilder();

http://git-wip-us.apache.org/repos/asf/systemml/blob/abbffc55/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
index 5297e61..c7ffdb1 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
@@ -404,7 +404,7 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
                 LOG.error("Inconsistent internal state - A copy of this CacheableData was dirty on more than 1 GPU");
                 throw new CacheException("Internal Error : Inconsistent internal state, A copy of this CacheableData was dirty on more than 1 GPU");
             } else if (gObj != null){
-                copiedFromGPU = gObj.acquireHostRead();
+                copiedFromGPU = gObj.acquireHostRead(null);
                 if( _data == null )
                     getCache();
             }
@@ -793,7 +793,7 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
                 LOG.error("Inconsistent internal state - A copy of this CacheableData was dirty on more than 1 GPU");
                 throw new CacheException("Internal Error : Inconsistent internal state, A copy of this CacheableData was dirty on more than 1 GPU");
             } else if (gObj != null){
-                copiedFromGPU = gObj.acquireHostRead();
+                copiedFromGPU = gObj.acquireHostRead(null);
                 if( _data == null )
                     getCache();
             }

http://git-wip-us.apache.org/repos/asf/systemml/blob/abbffc55/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
index 7176a9c..53f1a19 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
@@ -20,7 +20,6 @@
 package org.apache.sysml.runtime.instructions.gpu.context;
 
 import static jcuda.jcusparse.JCusparse.cusparseCreateMatDescr;
-import static jcuda.jcusparse.JCusparse.cusparseDcsr2dense;
 import static jcuda.jcusparse.JCusparse.cusparseSetMatIndexBase;
 import static jcuda.jcusparse.JCusparse.cusparseSetMatType;
 import static jcuda.jcusparse.JCusparse.cusparseSetPointerMode;
@@ -38,6 +37,7 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
+import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
 import org.apache.sysml.utils.GPUStatistics;
 import org.apache.sysml.utils.Statistics;
 
@@ -112,8 +112,8 @@ public class CSRPointer {
 		allocateMatDescrPointer();
 	}
 
-	private static long getDoubleSizeOf(long numElems) {
-		return numElems * ((long) jcuda.Sizeof.DOUBLE);
+	private static long getDataTypeSizeOf(long numElems) {
+		return numElems * ((long) LibMatrixCUDA.sizeOfDataType);
 	}
 
 	//  private Pointer allocate(String instName, long size) throws DMLRuntimeException {
@@ -121,7 +121,7 @@ public class CSRPointer {
 	//  }
 
 	private static long getIntSizeOf(long numElems) {
-		return numElems * ((long) jcuda.Sizeof.INT);
+		return numElems * ((long) Sizeof.INT);
 	}
 
 	//  private void cudaFreeHelper(Pointer toFree) throws DMLRuntimeException {
@@ -163,7 +163,7 @@ public class CSRPointer {
 	 * @return size estimate
 	 */
 	public static long estimateSize(long nnz2, long rows) {
-		long sizeofValArray = getDoubleSizeOf(nnz2);
+		long sizeofValArray = getDataTypeSizeOf(nnz2);
 		long sizeofRowPtrArray = getIntSizeOf(rows + 1);
 		long sizeofColIndArray = getIntSizeOf(nnz2);
 		long sizeofDescr = getIntSizeOf(4);
@@ -181,6 +181,7 @@ public class CSRPointer {
 	/**
 	 * Static method to copy a CSR sparse matrix from Host to Device
 	 *
+	 * @param gCtx GPUContext
 	 * @param dest   [input] destination location (on GPU)
 	 * @param rows   number of rows
 	 * @param nnz    number of non-zeroes
@@ -189,7 +190,7 @@ public class CSRPointer {
 	 * @param values double array of non zero values
 	 * @throws DMLRuntimeException if error occurs
 	 */
-	public static void copyToDevice(CSRPointer dest, int rows, long nnz, int[] rowPtr, int[] colInd, double[] values) throws DMLRuntimeException {
+	public static void copyToDevice(GPUContext gCtx, CSRPointer dest, int rows, long nnz, int[] rowPtr, int[] colInd, double[] values) throws DMLRuntimeException {
 		CSRPointer r = dest;
 		long t0 = 0;
 		if (DMLScript.STATISTICS)
@@ -200,15 +201,15 @@ public class CSRPointer {
 		if(rowPtr.length < rows + 1) throw new DMLRuntimeException("The length of rowPtr needs to be greater than or equal to " + (rows + 1));
 		if(colInd.length < nnz) throw new DMLRuntimeException("The length of colInd needs to be greater than or equal to " + nnz);
 		if(values.length < nnz) throw new DMLRuntimeException("The length of values needs to be greater than or equal to " + nnz);
+		LibMatrixCUDA.cudaSupportFunctions.hostToDevice(gCtx, values, r.val, null);
 		cudaMemcpy(r.rowPtr, Pointer.to(rowPtr), getIntSizeOf(rows + 1), cudaMemcpyHostToDevice);
 		cudaMemcpy(r.colInd, Pointer.to(colInd), getIntSizeOf(nnz), cudaMemcpyHostToDevice);
-		cudaMemcpy(r.val, Pointer.to(values), getDoubleSizeOf(nnz), cudaMemcpyHostToDevice);
 		if (DMLScript.STATISTICS)
 			GPUStatistics.cudaToDevTime.add(System.nanoTime() - t0);
 		if (DMLScript.STATISTICS)
 			GPUStatistics.cudaToDevCount.add(3);
 	}
-
+	
 	/**
 	 * Static method to copy a CSR sparse matrix from Device to host
 	 *
@@ -217,20 +218,12 @@ public class CSRPointer {
 	 * @param nnz    [input] number of non-zeroes
 	 * @param rowPtr [output] pre-allocated integer array of row pointers of size (rows+1)
 	 * @param colInd [output] pre-allocated integer array of column indices of size nnz
-	 * @param values [output] pre-allocated double array of values of size nnz
+	 * @throws DMLRuntimeException if error
 	 */
-	public static void copyToHost(CSRPointer src, int rows, long nnz, int[] rowPtr, int[] colInd, double[] values) {
+	public static void copyPtrToHost(CSRPointer src, int rows, long nnz, int[] rowPtr, int[] colInd) throws DMLRuntimeException {
 		CSRPointer r = src;
-		long t0 = 0;
-		if (DMLScript.STATISTICS)
-			t0 = System.nanoTime();
 		cudaMemcpy(Pointer.to(rowPtr), r.rowPtr, getIntSizeOf(rows + 1), cudaMemcpyDeviceToHost);
 		cudaMemcpy(Pointer.to(colInd), r.colInd, getIntSizeOf(nnz), cudaMemcpyDeviceToHost);
-		cudaMemcpy(Pointer.to(values), r.val, getDoubleSizeOf(nnz), cudaMemcpyDeviceToHost);
-		if (DMLScript.STATISTICS)
-			GPUStatistics.cudaFromDevTime.add(System.nanoTime() - t0);
-		if (DMLScript.STATISTICS)
-			GPUStatistics.cudaFromDevCount.add(3);
 	}
 
 	/**
@@ -305,9 +298,9 @@ public class CSRPointer {
 			// with no memory allocated on the GPU.
 			return r;
 		}
-		gCtx.ensureFreeSpace(getDoubleSizeOf(nnz2) + getIntSizeOf(rows + 1) + getIntSizeOf(nnz2));
+		gCtx.ensureFreeSpace(getDataTypeSizeOf(nnz2) + getIntSizeOf(rows + 1) + getIntSizeOf(nnz2));
 		// increment the cudaCount by 1 for the allocation of all 3 arrays
-		r.val = gCtx.allocate(null, getDoubleSizeOf(nnz2));
+		r.val = gCtx.allocate(null, getDataTypeSizeOf(nnz2));
 		r.rowPtr = gCtx.allocate(null, getIntSizeOf(rows + 1));
 		r.colInd = gCtx.allocate(null, getIntSizeOf(nnz2));
 		return r;
@@ -410,7 +403,7 @@ public class CSRPointer {
 			throws DMLRuntimeException {
 		LOG.trace("GPU : step3AllocateValNInd" + ", GPUContext=" + gCtx);
 		// Increment cudaCount by one when all three arrays of CSR sparse array are allocated
-		C.val = gCtx.allocate(null, getDoubleSizeOf(C.nnz));
+		C.val = gCtx.allocate(null, getDataTypeSizeOf(C.nnz));
 		C.colInd = gCtx.allocate(null, getIntSizeOf(C.nnz));
 	}
 
@@ -441,13 +434,14 @@ public class CSRPointer {
 		that.gpuContext.ensureFreeSpace(totalSize);
 
 		that.nnz = me.nnz;
-		that.val = allocate(that.nnz * Sizeof.DOUBLE);
-		that.rowPtr = allocate(rows * Sizeof.DOUBLE);
-		that.colInd = allocate(that.nnz * Sizeof.DOUBLE);
+		that.val = allocate(that.nnz * LibMatrixCUDA.sizeOfDataType);
+		// TODO: Nakul ... can you please double-check whether the below was a bug or intentional ?
+		that.rowPtr = allocate(rows * Sizeof.INT);
+		that.colInd = allocate(that.nnz * Sizeof.INT);
 
-		cudaMemcpy(that.val, me.val, that.nnz * Sizeof.DOUBLE, cudaMemcpyDeviceToDevice);
-		cudaMemcpy(that.rowPtr, me.rowPtr, rows * Sizeof.DOUBLE, cudaMemcpyDeviceToDevice);
-		cudaMemcpy(that.colInd, me.colInd, that.nnz * Sizeof.DOUBLE, cudaMemcpyDeviceToDevice);
+		cudaMemcpy(that.val, me.val, that.nnz * LibMatrixCUDA.sizeOfDataType, cudaMemcpyDeviceToDevice);
+		cudaMemcpy(that.rowPtr, me.rowPtr, rows * Sizeof.INT, cudaMemcpyDeviceToDevice);
+		cudaMemcpy(that.colInd, me.colInd, that.nnz * Sizeof.INT, cudaMemcpyDeviceToDevice);
 
 		return that;
 	}
@@ -506,12 +500,12 @@ public class CSRPointer {
 		long t0 = GPUStatistics.DISPLAY_STATISTICS && instName != null ? System.nanoTime() : 0;
 		LOG.trace("GPU : sparse -> column major dense (inside CSRPointer) on " + this + ", GPUContext="
 				+ getGPUContext());
-		long size = ((long) rows) * getDoubleSizeOf((long) cols);
+		long size = ((long) rows) * getDataTypeSizeOf((long) cols);
 		Pointer A = allocate(size);
 		// If this sparse block is empty, the allocated dense matrix, initialized to zeroes, will be returned.
 		if (val != null && rowPtr != null && colInd != null && nnz > 0) {
 			// Note: cusparseDcsr2dense method cannot handle empty blocks
-			cusparseDcsr2dense(cusparseHandle, rows, cols, descr, val, rowPtr, colInd, A, rows);
+			LibMatrixCUDA.cudaSupportFunctions.cusparsecsr2dense(cusparseHandle, rows, cols, descr, val, rowPtr, colInd, A, rows);
 			//cudaDeviceSynchronize;
 		} else {
 			LOG.debug("in CSRPointer, the values array, row pointers array or column indices array was null");

http://git-wip-us.apache.org/repos/asf/systemml/blob/abbffc55/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
index 55cb95f..dd776bc 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
@@ -24,8 +24,6 @@ import static jcuda.jcudnn.JCudnn.cudnnCreate;
 import static jcuda.jcudnn.JCudnn.cudnnDestroy;
 import static jcuda.jcusolver.JCusolverDn.cusolverDnCreate;
 import static jcuda.jcusolver.JCusolverDn.cusolverDnDestroy;
-import static jcuda.jcusolver.JCusolverSp.cusolverSpCreate;
-import static jcuda.jcusolver.JCusolverSp.cusolverSpDestroy;
 import static jcuda.jcusparse.JCusparse.cusparseCreate;
 import static jcuda.jcusparse.JCusparse.cusparseDestroy;
 import static jcuda.runtime.JCuda.cudaDeviceScheduleBlockingSync;
@@ -63,7 +61,6 @@ import jcuda.Pointer;
 import jcuda.jcublas.cublasHandle;
 import jcuda.jcudnn.cudnnHandle;
 import jcuda.jcusolver.cusolverDnHandle;
-import jcuda.jcusolver.cusolverSpHandle;
 import jcuda.jcusparse.cusparseHandle;
 import jcuda.runtime.JCuda;
 import jcuda.runtime.cudaDeviceProp;
@@ -107,10 +104,6 @@ public class GPUContext {
 	 */
 	private cusolverDnHandle cusolverDnHandle;
 	/**
-	 * cusolverSpHandle for invoking solve() function on sparse matrices on the GPU
-	 */
-	private cusolverSpHandle cusolverSpHandle;
-	/**
 	 * to launch custom CUDA kernel, specific to the active GPU for this GPUContext
 	 */
 	private JCudaKernels kernels;
@@ -233,12 +226,7 @@ public class GPUContext {
 			cusolverDnHandle = new cusolverDnHandle();
 			cusolverDnCreate(cusolverDnHandle);
 		}
-
-		if (cusolverSpHandle == null) {
-			cusolverSpHandle = new cusolverSpHandle();
-			cusolverSpCreate(cusolverSpHandle);
-		}
-
+		
 		if (kernels == null) {
 			kernels = new JCudaKernels();
 		}
@@ -578,7 +566,7 @@ public class GPUContext {
 								+ "). Allocated GPU objects:" + allocatedGPUObjects.toString());
 			}
 			if (toBeRemoved.dirty) {
-				toBeRemoved.copyFromDeviceToHost();
+				toBeRemoved.copyFromDeviceToHost(instructionName);
 			}
 			toBeRemoved.clearData(true);
 		}
@@ -754,15 +742,6 @@ public class GPUContext {
 	}
 
 	/**
-	 * Returns cusolverSpHandle for invoking solve() function on sparse matrices on the GPU.
-	 *
-	 * @return cusolverSpHandle for current thread
-	 */
-	public cusolverSpHandle getCusolverSpHandle() {
-		return cusolverSpHandle;
-	}
-
-	/**
 	 * Returns utility class used to launch custom CUDA kernel, specific to the active GPU for this GPUContext.
 	 *
 	 * @return {@link JCudaKernels} for current thread
@@ -801,14 +780,10 @@ public class GPUContext {
 		if (cusolverDnHandle != null)
 			cusolverDnDestroy(cusolverDnHandle);
 
-		if (cusolverSpHandle != null)
-			cusolverSpDestroy(cusolverSpHandle);
-
 		cudnnHandle = null;
 		cublasHandle = null;
 		cusparseHandle = null;
 		cusolverDnHandle = null;
-		cusolverSpHandle = null;
 	}
 
 	/**
@@ -827,7 +802,7 @@ public class GPUContext {
 			if (o.isDirty()) {
 				LOG.warn("Attempted to free GPU Memory when a block[" + o
 						+ "] is still on GPU memory, copying it back to host.");
-				o.acquireHostRead();
+				o.acquireHostRead(null);
 			}
 			o.clearData(true);
 		}

http://git-wip-us.apache.org/repos/asf/systemml/blob/abbffc55/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
index 06327db..35dfd58 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
@@ -19,14 +19,10 @@
 package org.apache.sysml.runtime.instructions.gpu.context;
 
 import static jcuda.jcublas.cublasOperation.CUBLAS_OP_T;
-import static jcuda.jcusparse.JCusparse.cusparseDdense2csr;
-import static jcuda.jcusparse.JCusparse.cusparseDnnz;
 import static jcuda.runtime.JCuda.cudaMemcpy;
 import static jcuda.runtime.JCuda.cudaMemset;
 import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToDevice;
 import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToHost;
-import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyHostToDevice;
-
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.LongAdder;
 
@@ -47,9 +43,6 @@ import org.apache.sysml.runtime.matrix.data.SparseBlockMCSR;
 import org.apache.sysml.utils.GPUStatistics;
 
 import jcuda.Pointer;
-import jcuda.Sizeof;
-import jcuda.jcublas.JCublas2;
-import jcuda.jcusparse.JCusparse;
 import jcuda.jcusparse.cusparseDirection;
 import jcuda.jcusparse.cusparseHandle;
 import jcuda.jcusparse.cusparseMatDescr;
@@ -126,7 +119,7 @@ public class GPUObject {
 			if (me.jcudaDenseMatrixPtr != null) {
 				long rows = me.mat.getNumRows();
 				long cols = me.mat.getNumColumns();
-				long size = rows * cols * Sizeof.DOUBLE;
+				long size = rows * cols * LibMatrixCUDA.sizeOfDataType;
 				me.gpuContext.ensureFreeSpace((int) size);
 				that.jcudaDenseMatrixPtr = allocate(size);
 				cudaMemcpy(that.jcudaDenseMatrixPtr, me.jcudaDenseMatrixPtr, size, cudaMemcpyDeviceToDevice);
@@ -181,13 +174,13 @@ public class GPUObject {
 		if(LOG.isTraceEnabled()) {
 			LOG.trace("GPU : transpose of block of size [" + m + "," + n + "]" + ", GPUContext=" + gCtx);
 		}
-		Pointer alpha = Pointer.to(new double[] { 1.0 });
-		Pointer beta = Pointer.to(new double[] { 0.0 });
+		Pointer alpha = LibMatrixCUDA.one();
+		Pointer beta = LibMatrixCUDA.zero();
 		Pointer A = densePtr;
-		Pointer C = gCtx.allocate(((long) m) * getDoubleSizeOf(n));
+		Pointer C = gCtx.allocate(((long) m) * getDatatypeSizeOf(n));
 
 		// Transpose the matrix to get a dense matrix
-		JCublas2.cublasDgeam(gCtx.getCublasHandle(), CUBLAS_OP_T, CUBLAS_OP_T, m, n, alpha, A, lda, beta, new Pointer(),
+		LibMatrixCUDA.cudaSupportFunctions.cublasgeam(gCtx.getCublasHandle(), CUBLAS_OP_T, CUBLAS_OP_T, m, n, alpha, A, lda, beta, new Pointer(),
 				lda, C, ldc);
 		return C;
 	}
@@ -217,7 +210,7 @@ public class GPUObject {
 		nnzTotalDevHostPtr = gCtx.allocate(getIntSizeOf(1));
 
 		// Output is in dense vector format, convert it to CSR
-		cusparseDnnz(cusparseHandle, cusparseDirection.CUSPARSE_DIRECTION_ROW, rows, cols, matDescr, densePtr, rows,
+		LibMatrixCUDA.cudaSupportFunctions.cusparsennz(cusparseHandle, cusparseDirection.CUSPARSE_DIRECTION_ROW, rows, cols, matDescr, densePtr, rows,
 				nnzPerRowPtr, nnzTotalDevHostPtr);
 		//cudaDeviceSynchronize();
 		int[] nnzC = { -1 };
@@ -241,7 +234,7 @@ public class GPUObject {
 		}
 
 		CSRPointer C = CSRPointer.allocateEmpty(gCtx, nnzC[0], rows);
-		cusparseDdense2csr(cusparseHandle, rows, cols, matDescr, densePtr, rows, nnzPerRowPtr, C.val, C.rowPtr,
+		LibMatrixCUDA.cudaSupportFunctions.cusparsedense2csr(cusparseHandle, rows, cols, matDescr, densePtr, rows, nnzPerRowPtr, C.val, C.rowPtr,
 				C.colInd);
 		//cudaDeviceSynchronize();
 
@@ -252,31 +245,6 @@ public class GPUObject {
 	}
 
 	/**
-	 * Gets the double array from GPU memory onto host memory and returns string.
-	 *
-	 * @param A    Pointer to memory on device (GPU), assumed to point to a double array
-	 * @param rows rows in matrix A
-	 * @param cols columns in matrix A
-	 * @return the debug string
-	 * @throws DMLRuntimeException if DMLRuntimeException occurs
-	 */
-	public static String debugString(Pointer A, long rows, long cols) throws DMLRuntimeException {
-		StringBuffer sb = new StringBuffer();
-		int len = toIntExact(rows * cols);
-		double[] tmp = new double[len];
-		cudaMemcpy(Pointer.to(tmp), A, getDoubleSizeOf(len), cudaMemcpyDeviceToHost);
-		int k = 0;
-		for (int i = 0; i < rows; i++) {
-			for (int j = 0; j < cols; j++) {
-				sb.append(tmp[k]).append(' ');
-				k++;
-			}
-			sb.append('\n');
-		}
-		return sb.toString();
-	}
-
-	/**
 	 * Convenience method to directly examine the Sparse matrix on GPU
 	 *
 	 * @return CSR (compressed sparse row) pointer
@@ -287,7 +255,7 @@ public class GPUObject {
 
 	/**
 	 * Convenience method to directly set the sparse matrix on GPU
-	 * Needed for operations like {@link JCusparse#cusparseDcsrgemm(cusparseHandle, int, int, int, int, int, cusparseMatDescr, int, Pointer, Pointer, Pointer, cusparseMatDescr, int, Pointer, Pointer, Pointer, cusparseMatDescr, Pointer, Pointer, Pointer)}
+	 * Needed for operations like cusparseDcsrgemm(cusparseHandle, int, int, int, int, int, cusparseMatDescr, int, Pointer, Pointer, Pointer, cusparseMatDescr, int, Pointer, Pointer, Pointer, cusparseMatDescr, Pointer, Pointer, Pointer)
 	 *
 	 * @param sparseMatrixPtr CSR (compressed sparse row) pointer
 	 * @throws DMLRuntimeException ?
@@ -475,8 +443,8 @@ public class GPUObject {
 		return isSparse;
 	}
 	
-	private static long getDoubleSizeOf(long numElems) {
-		return numElems * ((long) jcuda.Sizeof.DOUBLE);
+	private static long getDatatypeSizeOf(long numElems) {
+		return numElems * LibMatrixCUDA.sizeOfDataType;
 	}
 
 	private static long getIntSizeOf(long numElems) {
@@ -524,7 +492,7 @@ public class GPUObject {
 		long rows = mat.getNumRows();
 		long cols = mat.getNumColumns();
 		int numElems = toIntExact(rows * cols);
-		long size = getDoubleSizeOf(numElems);
+		long size = getDatatypeSizeOf(numElems);
 		setDenseMatrixCudaPointer(allocate(size));
 		// The "fill" kernel is called which treats the matrix "jcudaDensePtr" like a vector and fills it with value "v"
 		// If the fill value is 0, no need to call the special kernel, the allocate memsets the allocated region to 0
@@ -609,10 +577,11 @@ public class GPUObject {
 	/**
 	 * if the data is allocated on the GPU and is dirty, it is copied back to the host memory
 	 *
+	 * @param instName name of the instruction
 	 * @return true if a copy to host happened, false otherwise
 	 * @throws CacheException ?
 	 */
-	public boolean acquireHostRead() throws CacheException {
+	public boolean acquireHostRead(String instName) throws CacheException {
 		boolean copied = false;
 		try {
 			if(LOG.isTraceEnabled()) {
@@ -623,7 +592,7 @@ public class GPUObject {
 					LOG.trace("GPU : data is dirty on device, copying to host, on " + this + ", GPUContext="
 						+ getGPUContext());
 				}
-				copyFromDeviceToHost();
+				copyFromDeviceToHost(instName);
 				copied = true;
 			}
 		} catch (DMLRuntimeException e) {
@@ -728,7 +697,7 @@ public class GPUObject {
 			throw new DMLRuntimeException("Internal error - invalid number of rows when allocating dense matrix");
 		if(cols <= 0)
 			throw new DMLRuntimeException("Internal error - invalid number of columns when allocating dense matrix;");
-		long size = getDoubleSizeOf(rows * cols);
+		long size = getDatatypeSizeOf(rows * cols);
 		Pointer tmp = allocate(size);
 		setDenseMatrixCudaPointer(tmp);
 	}
@@ -774,7 +743,7 @@ public class GPUObject {
 		if (LibMatrixCUDA.isInSparseFormat(getGPUContext(), mat)) {
 			GPUSize = CSRPointer.estimateSize(nnz, rlen);
 		} else {
-			GPUSize = getDoubleSizeOf(rlen * clen);
+			GPUSize = getDatatypeSizeOf(rlen * clen);
 		}
 		return GPUSize;
 	}
@@ -858,7 +827,7 @@ public class GPUObject {
 
 			if (copyToDevice) {
 				long t1 = GPUStatistics.DISPLAY_STATISTICS ? System.nanoTime() : 0;
-				CSRPointer.copyToDevice(getJcudaSparseMatrixPtr(), tmp.getNumRows(), tmp.getNonZeros(), rowPtr, colInd,
+				CSRPointer.copyToDevice(getGPUContext(), getJcudaSparseMatrixPtr(), tmp.getNumRows(), tmp.getNonZeros(), rowPtr, colInd,
 						values);
 				if(GPUStatistics.DISPLAY_STATISTICS) 
 					GPUStatistics.maintainCPMiscTimes(opcode, GPUInstruction.MISC_TIMER_HOST_TO_DEVICE, System.nanoTime() - t1);
@@ -877,18 +846,14 @@ public class GPUObject {
 				// Minor optimization: No need to allocate empty error for CPU 
 				// data = new double[tmp.getNumRows() * tmp.getNumColumns()];
 				long t1 = GPUStatistics.DISPLAY_STATISTICS ? System.nanoTime() : 0;
-				cudaMemset(getJcudaDenseMatrixPtr(), 0, getDoubleSizeOf(mat.getNumRows() * mat.getNumColumns()));
+				cudaMemset(getJcudaDenseMatrixPtr(), 0, getDatatypeSizeOf(mat.getNumRows() * mat.getNumColumns()));
 				if(GPUStatistics.DISPLAY_STATISTICS) 
 					GPUStatistics.maintainCPMiscTimes(opcode, GPUInstruction.MISC_TIMER_SET_ZERO, System.nanoTime() - t1);
 			}
 			else {
 				// Copy dense block
 				// H2D now only measures the time taken to do 
-				long t1 = GPUStatistics.DISPLAY_STATISTICS ? System.nanoTime() : 0;
-				cudaMemcpy(getJcudaDenseMatrixPtr(), Pointer.to(data),
-						getDoubleSizeOf(mat.getNumRows() * mat.getNumColumns()), cudaMemcpyHostToDevice);
-				if(GPUStatistics.DISPLAY_STATISTICS) 
-					GPUStatistics.maintainCPMiscTimes(opcode, GPUInstruction.MISC_TIMER_HOST_TO_DEVICE, System.nanoTime() - t1);
+				LibMatrixCUDA.cudaSupportFunctions.hostToDevice(getGPUContext(), data, getJcudaDenseMatrixPtr(), opcode);
 			}
 		}
 
@@ -907,7 +872,7 @@ public class GPUObject {
 		return (int) l;
 	}
 
-	protected void copyFromDeviceToHost() throws DMLRuntimeException {
+	protected void copyFromDeviceToHost(String instName) throws DMLRuntimeException {
 		if(LOG.isTraceEnabled()) {
 			LOG.trace("GPU : copyFromDeviceToHost, on " + this + ", GPUContext=" + getGPUContext());
 		}
@@ -921,11 +886,7 @@ public class GPUObject {
 				start = System.nanoTime();
 			MatrixBlock tmp = new MatrixBlock(toIntExact(mat.getNumRows()), toIntExact(mat.getNumColumns()), false);
 			tmp.allocateDenseBlock();
-			double[] data = tmp.getDenseBlock();
-
-			cudaMemcpy(Pointer.to(data), getJcudaDenseMatrixPtr(), getDoubleSizeOf(data.length),
-					cudaMemcpyDeviceToHost);
-
+			LibMatrixCUDA.cudaSupportFunctions.deviceToHost(getGPUContext(), getJcudaDenseMatrixPtr(), tmp.getDenseBlock(), instName);
 			tmp.recomputeNonZeros();
 			mat.acquireModify(tmp);
 			mat.release();
@@ -951,10 +912,16 @@ public class GPUObject {
 				int rows = toIntExact(mat.getNumRows());
 				int cols = toIntExact(mat.getNumColumns());
 				int nnz = toIntExact(getJcudaSparseMatrixPtr().nnz);
+				double[] values = new double[nnz];
+				LibMatrixCUDA.cudaSupportFunctions.deviceToHost(getGPUContext(), getJcudaSparseMatrixPtr().val, values, instName);
 				int[] rowPtr = new int[rows + 1];
 				int[] colInd = new int[nnz];
-				double[] values = new double[nnz];
-				CSRPointer.copyToHost(getJcudaSparseMatrixPtr(), rows, nnz, rowPtr, colInd, values);
+				long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
+				CSRPointer.copyPtrToHost(getJcudaSparseMatrixPtr(), rows, nnz, rowPtr, colInd);
+				if (DMLScript.STATISTICS)
+					GPUStatistics.cudaFromDevTime.add(System.nanoTime() - t0);
+				if (DMLScript.STATISTICS)
+					GPUStatistics.cudaFromDevCount.add(3);
 
 				SparseBlockCSR sparseBlock = new SparseBlockCSR(rowPtr, colInd, values, nnz);
 				MatrixBlock tmp = new MatrixBlock(rows, cols, nnz, sparseBlock);

http://git-wip-us.apache.org/repos/asf/systemml/blob/abbffc55/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaKernels.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaKernels.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaKernels.java
index e1894ae..d22110d 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaKernels.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaKernels.java
@@ -29,6 +29,7 @@ import java.util.HashMap;
 
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.io.IOUtilFunctions;
+import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
 
 import jcuda.Pointer;
 import jcuda.driver.CUfunction;
@@ -72,11 +73,17 @@ public class JCudaKernels {
 	 * @throws DMLRuntimeException if DMLRuntimeException occurs
 	 */
 	public void launchKernel(String name, ExecutionConfig config, Object... arguments) throws DMLRuntimeException {
+		name = name + LibMatrixCUDA.customKernelSuffix;
 		CUfunction function = kernels.get(name);
+		
 		if (function == null) {
 			// caching functions into hashmap reduces the lookup overhead
 			function = new CUfunction();
-			checkResult(cuModuleGetFunction(function, module, name));
+			try {
+				checkResult(cuModuleGetFunction(function, module, name));
+			} catch(jcuda.CudaException e) {
+				throw new DMLRuntimeException("Error finding the custom kernel:" + name, e);
+			}
 		}
 
 		// Setup parameters

http://git-wip-us.apache.org/repos/asf/systemml/blob/abbffc55/src/main/java/org/apache/sysml/runtime/matrix/data/CudaSupportFunctions.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/CudaSupportFunctions.java b/src/main/java/org/apache/sysml/runtime/matrix/data/CudaSupportFunctions.java
new file mode 100644
index 0000000..2b6c039
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/CudaSupportFunctions.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.runtime.matrix.data;
+
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
+
+import jcuda.jcublas.cublasHandle;
+import jcuda.jcusolver.cusolverDnHandle;
+import jcuda.jcusparse.cusparseHandle;
+import jcuda.jcusparse.cusparseMatDescr;
+import jcuda.Pointer;
+
+/**
+ * DESIGN DOCUMENTATION FOR SUPPORTING LOWER PRECISION:
+ * 1. SystemML.cu has been templatized in the following way to support different data types:
+ * - Similar to CuBLAS and CuSPARSE, the global kernels encode the data type in their name (for example: _f for float
+ * and _d for double). Unlike CuBLAS and CuSPARSE, the type is a suffix rather than a prefix, which keeps the engine simple.
+ * - Each suffixed global kernel invokes a corresponding templatized device function (without suffix) that contains the core logic.
+ * - The suffix (LibMatrixCUDA.customKernelSuffix) is appended to the kernel name in JCudaKernels#launchKernel before the lookup.
+ * For example:
+ * <pre>
+ * template &lt; typename T &gt;
+ * __device__ void matrix_atan(T *A, T *C, unsigned int size) {
+ *     int index = blockIdx.x * blockDim.x + threadIdx.x;
+ *     if (index &lt; size){
+ *         C[index] = atan(A[index]);
+ *     }
+ * }
+ * extern "C" __global__ void matrix_atan_d(double *A, double *C, unsigned int size) {
+ * 	matrix_atan(A, C, size);
+ * }
+ * extern "C" __global__ void matrix_atan_f(float *A, float *C, unsigned int size) {
+ * 	matrix_atan(A, C, size);
+ * }
+ * </pre>
+ * 
+ * 2. The CUDA library calls (such as CuBLAS, CuSPARSE, CuSolver, etc.) go through this interface.
+ * The naming and parameters of the methods in this class are consistent with those of the CUDA libraries to simplify development.
+ * 
+ * 3. The implementation in use (LibMatrixCUDA.cudaSupportFunctions) is selected by LibMatrixCUDA.resetFloatingPointPrecision based on DMLScript.FLOATING_POINT_PRECISION.
+ */
+public interface CudaSupportFunctions {
+	public static boolean PERFORM_CONVERSION_ON_DEVICE = true; // implicitly final (interface field); NOTE(review): presumably selects device-side double<->float conversion - confirm at use sites
+	public int cusparsecsrgemm(cusparseHandle handle, int transA, int transB, int m, int n, int k, 
+			cusparseMatDescr descrA, int nnzA, Pointer csrValA, Pointer csrRowPtrA, Pointer csrColIndA, 
+			cusparseMatDescr descrB, int nnzB, Pointer csrValB, Pointer csrRowPtrB, Pointer csrColIndB, 
+			cusparseMatDescr descrC, Pointer csrValC, Pointer csrRowPtrC, Pointer csrColIndC);
+	public int	cublasgeam(cublasHandle handle, int transa, int transb, int m, int n, jcuda.Pointer alpha, jcuda.Pointer A, 
+			int lda, jcuda.Pointer beta, jcuda.Pointer B, int ldb, jcuda.Pointer C, int ldc);
+	public int	cusparsecsrmv(cusparseHandle handle, int transA, int m, int n, int nnz, jcuda.Pointer alpha, cusparseMatDescr descrA, jcuda.Pointer csrValA, jcuda.Pointer csrRowPtrA, jcuda.Pointer csrColIndA, 
+			jcuda.Pointer x, jcuda.Pointer beta, jcuda.Pointer y);
+	public int	cusparsecsrmm2(cusparseHandle handle, int transa, int transb, int m, int n, int k, int nnz, jcuda.Pointer alpha, cusparseMatDescr descrA, jcuda.Pointer csrValA, jcuda.Pointer csrRowPtrA, jcuda.Pointer csrColIndA, 
+			jcuda.Pointer B, int ldb, jcuda.Pointer beta, jcuda.Pointer C, int ldc);
+	public int cublasdot(cublasHandle handle, int n, jcuda.Pointer x, int incx, jcuda.Pointer y, int incy, jcuda.Pointer result);
+	public int cublasgemv(cublasHandle handle, int trans, int m, int n, jcuda.Pointer alpha, jcuda.Pointer A, int lda, jcuda.Pointer x, int incx, jcuda.Pointer beta, jcuda.Pointer y, int incy);
+	public int cublasgemm(cublasHandle handle, int transa, int transb, int m, int n, int k, jcuda.Pointer alpha, jcuda.Pointer A, int lda, jcuda.Pointer B, int ldb, jcuda.Pointer beta, jcuda.Pointer C, int ldc);
+	public int cusparsecsr2csc(cusparseHandle handle, int m, int n, int nnz, jcuda.Pointer csrVal, jcuda.Pointer csrRowPtr, jcuda.Pointer csrColInd, jcuda.Pointer cscVal, jcuda.Pointer cscRowInd, jcuda.Pointer cscColPtr, int copyValues, int idxBase);
+	public int cublassyrk(cublasHandle handle, int uplo, int trans, int n, int k, jcuda.Pointer alpha, jcuda.Pointer A, int lda, jcuda.Pointer beta, jcuda.Pointer C, int ldc);
+	public int cublasaxpy(cublasHandle handle, int n, jcuda.Pointer alpha, jcuda.Pointer x, int incx, jcuda.Pointer y, int incy);
+	public int cublastrsm(cublasHandle handle, int side, int uplo, int trans, int diag, int m, int n, jcuda.Pointer alpha, jcuda.Pointer A, int lda, jcuda.Pointer B, int ldb);
+	public int cusolverDngeqrf_bufferSize(cusolverDnHandle handle, int m, int n, Pointer A, int lda, int[] Lwork);
+	public int cusolverDngeqrf(cusolverDnHandle handle, int m, int n, Pointer A, int lda, Pointer TAU, Pointer Workspace, int Lwork, Pointer devInfo);
+	public int cusolverDnormqr(cusolverDnHandle handle, int side, int trans, int m, int n, int k, Pointer A, int lda, Pointer tau, Pointer C, int ldc, Pointer work, int lwork, Pointer devInfo);
+	public int cusparsecsrgeam(cusparseHandle handle, int m, int n, jcuda.Pointer alpha, cusparseMatDescr descrA, int nnzA, jcuda.Pointer csrValA, jcuda.Pointer csrRowPtrA, jcuda.Pointer csrColIndA, jcuda.Pointer beta, cusparseMatDescr descrB, int nnzB, jcuda.Pointer csrValB, jcuda.Pointer csrRowPtrB, jcuda.Pointer csrColIndB, cusparseMatDescr descrC, jcuda.Pointer csrValC, jcuda.Pointer csrRowPtrC, jcuda.Pointer csrColIndC);
+	public int cusparsecsr2dense(cusparseHandle handle, int m, int n, cusparseMatDescr descrA, jcuda.Pointer csrValA, jcuda.Pointer csrRowPtrA, jcuda.Pointer csrColIndA, jcuda.Pointer A, int lda) ;
+	public int cusparsedense2csr(cusparseHandle handle, int m, int n, cusparseMatDescr descrA, jcuda.Pointer A, int lda, jcuda.Pointer nnzPerRow, jcuda.Pointer csrValA, jcuda.Pointer csrRowPtrA, jcuda.Pointer csrColIndA);
+	public int cusparsennz(cusparseHandle handle, int dirA, int m, int n, cusparseMatDescr descrA, jcuda.Pointer A, int lda, jcuda.Pointer nnzPerRowCol, jcuda.Pointer nnzTotalDevHostPtr);
+	public void deviceToHost(GPUContext gCtx, Pointer src, double [] dest, String instName) throws DMLRuntimeException; // host buffer is always double[], regardless of the device precision
+	public void hostToDevice(GPUContext gCtx, double [] src,  Pointer dest, String instName) throws DMLRuntimeException; // host buffer is always double[], regardless of the device precision
+	
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/abbffc55/src/main/java/org/apache/sysml/runtime/matrix/data/DoublePrecisionCudaSupportFunctions.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/DoublePrecisionCudaSupportFunctions.java b/src/main/java/org/apache/sysml/runtime/matrix/data/DoublePrecisionCudaSupportFunctions.java
new file mode 100644
index 0000000..78b4de0
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/DoublePrecisionCudaSupportFunctions.java
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.runtime.matrix.data;
+
+import static jcuda.runtime.JCuda.cudaMemcpy;
+import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToHost;
+import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyHostToDevice;
+
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
+import org.apache.sysml.utils.GPUStatistics;
+
+import jcuda.Pointer;
+import jcuda.Sizeof;
+import jcuda.jcublas.JCublas2;
+import jcuda.jcublas.cublasHandle;
+import jcuda.jcusolver.JCusolverDn;
+import jcuda.jcusolver.cusolverDnHandle;
+import jcuda.jcusparse.JCusparse;
+import jcuda.jcusparse.cusparseHandle;
+import jcuda.jcusparse.cusparseMatDescr;
+
+public class DoublePrecisionCudaSupportFunctions implements CudaSupportFunctions {
+	// Each library wrapper below delegates to the double-precision ("D") variant of the matching JCublas2 / JCusparse / JCusolverDn routine.
+	@Override
+	public int cusparsecsrgemm(cusparseHandle handle, int transA, int transB, int m, int n, int k,
+			cusparseMatDescr descrA, int nnzA, Pointer csrValA, Pointer csrRowPtrA, Pointer csrColIndA,
+			cusparseMatDescr descrB, int nnzB, Pointer csrValB, Pointer csrRowPtrB, Pointer csrColIndB,
+			cusparseMatDescr descrC, Pointer csrValC, Pointer csrRowPtrC, Pointer csrColIndC) {
+		return JCusparse.cusparseDcsrgemm(handle, transA,  transB,  m,  n,  k,
+				 descrA,  nnzA,  csrValA,  csrRowPtrA,  csrColIndA,
+				 descrB,  nnzB,  csrValB,  csrRowPtrB,  csrColIndB,
+				 descrC,  csrValC,  csrRowPtrC,  csrColIndC);
+	}
+	
+	@Override
+	public int cublasgeam(cublasHandle handle, int transa, int transb, int m, int n, Pointer alpha, Pointer A, int lda,
+			Pointer beta, Pointer B, int ldb, Pointer C, int ldc) {
+		return JCublas2.cublasDgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
+	}
+	
+	@Override
+	public int cusparsecsrmv(cusparseHandle handle, int transA, int m, int n, int nnz, Pointer alpha,
+			cusparseMatDescr descrA, Pointer csrValA, Pointer csrRowPtrA, Pointer csrColIndA, Pointer x, Pointer beta,
+			Pointer y) {
+		return JCusparse.cusparseDcsrmv(handle, transA, m, n, nnz, alpha, 
+				descrA, csrValA, csrRowPtrA, csrColIndA, x, beta, y);
+	}
+	
+	@Override
+	public int	cusparsecsrmm2(cusparseHandle handle, int transa, int transb, int m, int n, int k, int nnz, jcuda.Pointer alpha, cusparseMatDescr descrA, 
+			jcuda.Pointer csrValA, jcuda.Pointer csrRowPtrA, jcuda.Pointer csrColIndA, 
+			jcuda.Pointer B, int ldb, jcuda.Pointer beta, jcuda.Pointer C, int ldc) {
+		return JCusparse.cusparseDcsrmm2(handle, transa, transb, m, n, k, nnz, alpha, descrA, csrValA, 
+				csrRowPtrA, csrColIndA, B, ldb, beta, C, ldc);
+	}
+	
+	@Override
+	public int cublasdot(cublasHandle handle, int n, Pointer x, int incx, Pointer y, int incy, Pointer result) {
+		return JCublas2.cublasDdot(handle, n, x, incx, y, incy, result);
+	}
+	
+	@Override
+	public int cublasgemv(cublasHandle handle, int trans, int m, int n, Pointer alpha, Pointer A, int lda, Pointer x,
+			int incx, Pointer beta, Pointer y, int incy) {
+		return JCublas2.cublasDgemv(handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy);
+	}
+	
+	@Override
+	public int cublasgemm(cublasHandle handle, int transa, int transb, int m, int n, int k, Pointer alpha, Pointer A,
+			int lda, Pointer B, int ldb, Pointer beta, Pointer C, int ldc) {
+		return JCublas2.cublasDgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
+	}
+	
+	@Override
+	public int cusparsecsr2csc(cusparseHandle handle, int m, int n, int nnz, Pointer csrVal, Pointer csrRowPtr,
+			Pointer csrColInd, Pointer cscVal, Pointer cscRowInd, Pointer cscColPtr, int copyValues, int idxBase) {
+		return JCusparse.cusparseDcsr2csc(handle, m, n, nnz, csrVal, csrRowPtr, csrColInd, cscVal, cscRowInd, cscColPtr, copyValues, idxBase);
+	}
+	
+	@Override
+	public int cublassyrk(cublasHandle handle, int uplo, int trans, int n, int k, Pointer alpha, Pointer A, int lda,
+			Pointer beta, Pointer C, int ldc) {
+		return JCublas2.cublasDsyrk(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc);
+	}
+	
+	@Override
+	public int cublasaxpy(cublasHandle handle, int n, Pointer alpha, Pointer x, int incx, Pointer y, int incy) {
+		return JCublas2.cublasDaxpy(handle, n, alpha, x, incx, y, incy);
+	}
+	
+	@Override
+	public int cublastrsm(cublasHandle handle, int side, int uplo, int trans, int diag, int m, int n, Pointer alpha,
+			Pointer A, int lda, Pointer B, int ldb) {
+		return JCublas2.cublasDtrsm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb);
+	}
+
+	@Override
+	public int cusolverDngeqrf_bufferSize(cusolverDnHandle handle, int m, int n, Pointer A, int lda, int[] Lwork) {
+		return JCusolverDn.cusolverDnDgeqrf_bufferSize(handle, m, n, A, lda, Lwork);
+	}
+	
+	@Override
+	public int cusolverDngeqrf(cusolverDnHandle handle, int m, int n, Pointer A, int lda, Pointer TAU,
+			Pointer Workspace, int Lwork, Pointer devInfo) {
+		return JCusolverDn.cusolverDnDgeqrf(handle, m, n, A, lda, TAU, Workspace, Lwork, devInfo);
+	}
+
+	@Override
+	public int cusolverDnormqr(cusolverDnHandle handle, int side, int trans, int m, int n, int k, Pointer A, int lda,
+			Pointer tau, Pointer C, int ldc, Pointer work, int lwork, Pointer devInfo) {
+		return JCusolverDn.cusolverDnDormqr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, work, lwork, devInfo);
+	}
+	
+	@Override
+	public int cusparsecsrgeam(cusparseHandle handle, int m, int n, Pointer alpha, cusparseMatDescr descrA, int nnzA,
+			Pointer csrValA, Pointer csrRowPtrA, Pointer csrColIndA, Pointer beta, cusparseMatDescr descrB, int nnzB,
+			Pointer csrValB, Pointer csrRowPtrB, Pointer csrColIndB, cusparseMatDescr descrC, Pointer csrValC,
+			Pointer csrRowPtrC, Pointer csrColIndC) {
+		return JCusparse.cusparseDcsrgeam(handle, m, n, alpha, descrA, nnzA, 
+				csrValA, csrRowPtrA, csrColIndA, beta, descrB, nnzB, 
+				csrValB, csrRowPtrB, csrColIndB, descrC, csrValC, csrRowPtrC, csrColIndC);
+	}
+	
+	@Override
+	public int cusparsecsr2dense(cusparseHandle handle, int m, int n, cusparseMatDescr descrA, Pointer csrValA,
+			Pointer csrRowPtrA, Pointer csrColIndA, Pointer A, int lda) {
+		return JCusparse.cusparseDcsr2dense(handle, m, n, descrA, csrValA, csrRowPtrA, csrColIndA, A, lda);
+	}
+
+	@Override
+	public int cusparsedense2csr(cusparseHandle handle, int m, int n, cusparseMatDescr descrA, Pointer A, int lda,
+			Pointer nnzPerRow, Pointer csrValA, Pointer csrRowPtrA, Pointer csrColIndA) {
+		return JCusparse.cusparseDdense2csr(handle, m, n, descrA, A, lda, nnzPerRow, csrValA, csrRowPtrA, csrColIndA);
+	}
+
+	@Override
+	public int cusparsennz(cusparseHandle handle, int dirA, int m, int n, cusparseMatDescr descrA, Pointer A, int lda,
+			Pointer nnzPerRowCol, Pointer nnzTotalDevHostPtr) {
+		return JCusparse.cusparseDnnz(handle, dirA, m, n, descrA, A, lda, nnzPerRowCol, nnzTotalDevHostPtr);
+	}
+	// Copies dest.length doubles (dest.length * Sizeof.DOUBLE bytes) from src into dest; gCtx is unused here. Timed only if statistics are on and instName != null.
+	@Override
+	public void deviceToHost(GPUContext gCtx, Pointer src, double[] dest, String instName) throws DMLRuntimeException {
+		long t1 = GPUStatistics.DISPLAY_STATISTICS  && instName != null? System.nanoTime() : 0;
+		cudaMemcpy(Pointer.to(dest), src, ((long)dest.length)*Sizeof.DOUBLE, cudaMemcpyDeviceToHost);
+		if(GPUStatistics.DISPLAY_STATISTICS && instName != null) 
+			GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_DEVICE_TO_HOST, System.nanoTime() - t1);
+	}
+	// Copies src.length doubles (src.length * Sizeof.DOUBLE bytes) from src into dest; gCtx is unused here. Timed only if statistics are on and instName != null.
+	@Override
+	public void hostToDevice(GPUContext gCtx, double[] src, Pointer dest, String instName) throws DMLRuntimeException {
+		long t1 = GPUStatistics.DISPLAY_STATISTICS  && instName != null? System.nanoTime() : 0;
+		cudaMemcpy(dest, Pointer.to(src), ((long)src.length)*Sizeof.DOUBLE, cudaMemcpyHostToDevice);
+		if(GPUStatistics.DISPLAY_STATISTICS && instName != null) 
+			GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_HOST_TO_DEVICE, System.nanoTime() - t1);
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/abbffc55/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index 7e25299..eb17e69 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -21,12 +21,13 @@ package org.apache.sysml.runtime.matrix.data;
 
 import static jcuda.jcublas.cublasOperation.CUBLAS_OP_N;
 import static jcuda.jcublas.cublasOperation.CUBLAS_OP_T;
-import static jcuda.jcusparse.JCusparse.cusparseDcsr2csc;
 import static jcuda.runtime.JCuda.cudaMemcpy;
 import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToDevice;
 import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToHost;
+
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
@@ -80,14 +81,11 @@ import org.apache.sysml.utils.Statistics;
 
 import jcuda.Pointer;
 import jcuda.Sizeof;
-import jcuda.jcublas.JCublas2;
 import jcuda.jcublas.cublasDiagType;
 import jcuda.jcublas.cublasFillMode;
 import jcuda.jcublas.cublasHandle;
 import jcuda.jcublas.cublasOperation;
 import jcuda.jcublas.cublasSideMode;
-import jcuda.jcusolver.JCusolverDn;
-import jcuda.jcusparse.JCusparse;
 import jcuda.jcusparse.cusparseAction;
 import jcuda.jcusparse.cusparseHandle;
 import jcuda.jcusparse.cusparseIndexBase;
@@ -100,6 +98,34 @@ import jcuda.jcusparse.cusparseIndexBase;
 public class LibMatrixCUDA {
 
 	private static final Log LOG = LogFactory.getLog(LibMatrixCUDA.class.getName());
+	
+	protected static int CUDNN_DATA_TYPE = jcuda.jcudnn.cudnnDataType.CUDNN_DATA_DOUBLE; // cuDNN data type matching the active precision
+	// The variables below are used in CSRPointer, GPUObjects, etc.; their defaults correspond to double precision.
+	public static CudaSupportFunctions cudaSupportFunctions = new DoublePrecisionCudaSupportFunctions();
+	public static int sizeOfDataType = jcuda.Sizeof.DOUBLE;
+	public static String customKernelSuffix = "_d"; // appended to every custom kernel name by JCudaKernels.launchKernel
+	
+	/**
+	 * Sets CUDNN_DATA_TYPE, cudaSupportFunctions, sizeOfDataType and customKernelSuffix from DMLScript.FLOATING_POINT_PRECISION ("double" or "single", case-insensitive).
+	 * @throws DMLRuntimeException if DMLScript.FLOATING_POINT_PRECISION is neither "double" nor "single"
+	 */
+	public static void resetFloatingPointPrecision() throws DMLRuntimeException {
+		if(DMLScript.FLOATING_POINT_PRECISION.equalsIgnoreCase("double")) {
+			LibMatrixCUDA.CUDNN_DATA_TYPE = jcuda.jcudnn.cudnnDataType.CUDNN_DATA_DOUBLE;
+			LibMatrixCUDA.cudaSupportFunctions = new DoublePrecisionCudaSupportFunctions();
+			LibMatrixCUDA.sizeOfDataType = jcuda.Sizeof.DOUBLE;
+			LibMatrixCUDA.customKernelSuffix = "_d";
+		}
+		else if(DMLScript.FLOATING_POINT_PRECISION.equalsIgnoreCase("single")) {
+			LibMatrixCUDA.CUDNN_DATA_TYPE = jcuda.jcudnn.cudnnDataType.CUDNN_DATA_FLOAT;
+			LibMatrixCUDA.cudaSupportFunctions = new SinglePrecisionCudaSupportFunctions();
+			LibMatrixCUDA.sizeOfDataType = jcuda.Sizeof.FLOAT;
+			LibMatrixCUDA.customKernelSuffix = "_f";
+		}
+		else {
+			throw new DMLRuntimeException("Unsupported floating point precision: " + DMLScript.FLOATING_POINT_PRECISION);
+		}
+	}
 
 	// Assume Compute Capability 3.0
 	// MAX BLOCKS is 2^31 - 1 For compute capability > 3.0
@@ -110,7 +136,7 @@ public class LibMatrixCUDA {
 	
 	// From CuDNN 5.1 documentation:
 	// The total size of a tensor including the potential padding between dimensions is limited to 2 Giga-elements of type datatype.
-	protected static long maxNumDoublesOfCuDNNTensor = 2000000000;
+	protected static long maxNumElementsOfCuDNNTensor = 2000000000;
 
 	//********************************************************************/
 	//***************************** UTILS ********************************/
@@ -179,7 +205,18 @@ public class LibMatrixCUDA {
 	protected static JCudaKernels getCudaKernels(GPUContext gCtx) throws DMLRuntimeException {
 		return gCtx.getKernels();
 	}
-
+	/** Launches the custom "double2float" kernel (launchKernel appends customKernelSuffix) over numElems elements from A into ret; presumably narrows double to float. Returns ret. */
+	public static Pointer double2float(GPUContext gCtx, Pointer A, Pointer ret, int numElems) throws DMLRuntimeException {
+		getCudaKernels(gCtx).launchKernel("double2float", ExecutionConfig.getConfigForSimpleVectorOperations(numElems),
+				A, ret, numElems);
+		return ret;
+	}
+	/** Launches the custom "float2double" kernel (launchKernel appends customKernelSuffix) over numElems elements from A into ret; presumably widens float to double. Returns ret. */
+	public static Pointer float2double(GPUContext gCtx, Pointer A, Pointer ret, int numElems) throws DMLRuntimeException {
+		getCudaKernels(gCtx).launchKernel("float2double", ExecutionConfig.getConfigForSimpleVectorOperations(numElems),
+				A, ret, numElems);
+		return ret;
+	}
 
 	//********************************************************************/
 	//************************ End of UTILS ******************************/
@@ -191,13 +228,15 @@ public class LibMatrixCUDA {
 
 	private static Pointer _one;
 	private static Pointer _zero;
+	private static int oneDataTypeSize, zeroDataTypeSize; // sizeOfDataType each cached constant was built with; tracked per constant so refreshing one cannot mask a stale other
 	/**
 	 * Convenience method to get a pointer to value '1.0' on device. Instead of allocating and deallocating it for every kernel invocation.
 	 * @return jcuda pointer
 	 */
-	protected static Pointer one() {
-		if(_one == null) {
-			_one = pointerTo(1.0);
+	public static Pointer one() {
+		if(_one == null || oneDataTypeSize != sizeOfDataType) {
+			_one = dataTypePointerTo(1.0); // rebuilt whenever the active precision (element size) changes
+			oneDataTypeSize = sizeOfDataType;
 		}
 		return _one;
 	}
@@ -205,9 +244,10 @@ public class LibMatrixCUDA {
 	 * Convenience method to get a pointer to value '0.0f' on device. Instead of allocating and deallocating it for every kernel invocation.
 	 * @return jcuda pointer
 	 */
-	protected static Pointer zero() {
-		if(_zero == null) {
-			_zero = pointerTo(0.0f);
+	public static Pointer zero() {
+		if(_zero == null || zeroDataTypeSize != sizeOfDataType) {
+			_zero = dataTypePointerTo(0.0); // rebuilt whenever the active precision (element size) changes
+			zeroDataTypeSize = sizeOfDataType;
 		}
 		return _zero;
 	}
@@ -242,8 +282,16 @@ public class LibMatrixCUDA {
 		return input.getGPUObject(gCtx).getJcudaSparseMatrixPtr();
 	}
 	
-	protected static Pointer pointerTo(double value) {
-		return Pointer.to(new double[] { value });
+	protected static Pointer dataTypePointerTo(double value) { // wraps value in a 1-element host array of the active precision (sizeOfDataType)
+		if(sizeOfDataType == Sizeof.DOUBLE) {
+			return Pointer.to(new double[] { value });
+		}
+		else if(sizeOfDataType == Sizeof.FLOAT) {
+			return Pointer.to(new float[] { (float) value }); // single precision: value is narrowed to float
+		}
+		else {
+			throw new RuntimeException("Unsupported datatype with size " + sizeOfDataType); // only Sizeof.DOUBLE and Sizeof.FLOAT are handled
+		}
 	}
 	
 
@@ -434,7 +482,7 @@ public class LibMatrixCUDA {
 		long t0=0, t1=0;
 
 		if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
-		JCublas2.cublasDsyrk(getCublasHandle(gCtx), cublasFillMode.CUBLAS_FILL_MODE_LOWER,transa, m, k, one(), A, lda, zero(), C, ldc);
+		cudaSupportFunctions.cublassyrk(getCublasHandle(gCtx), cublasFillMode.CUBLAS_FILL_MODE_LOWER,transa, m, k, one(), A, lda, zero(), C, ldc);
 		if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_SYRK_LIB, System.nanoTime() - t0);
 
 		if (GPUStatistics.DISPLAY_STATISTICS) t1 = System.nanoTime();
@@ -630,7 +678,7 @@ public class LibMatrixCUDA {
 		}
 		case OP_PLUS_SQ : {
 			// Calculate the squares in a temporary object tmp
-			Pointer tmp = gCtx.allocate(instName, size * Sizeof.DOUBLE);
+			Pointer tmp = gCtx.allocate(instName, size * sizeOfDataType);
 
 			squareMatrix(gCtx, instName, in, tmp, rlen, clen);
 			// Then do the sum on the temporary object and free it
@@ -729,8 +777,8 @@ public class LibMatrixCUDA {
 		}
 		case OP_VARIANCE : {
 			// Temporary GPU array for
-			Pointer tmp = gCtx.allocate(instName, size * Sizeof.DOUBLE);
-			Pointer tmp2 = gCtx.allocate(instName, size * Sizeof.DOUBLE);
+			Pointer tmp = gCtx.allocate(instName, size * sizeOfDataType);
+			Pointer tmp2 = gCtx.allocate(instName, size * sizeOfDataType);
 
 			switch(reductionDirection) {
 
@@ -758,7 +806,7 @@ public class LibMatrixCUDA {
 
 				squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen);
 
-				Pointer tmpRow = gCtx.allocate(instName, rlen * Sizeof.DOUBLE);
+				Pointer tmpRow = gCtx.allocate(instName, rlen * sizeOfDataType);
 				reduceRow(gCtx, instName, "reduce_row_sum", tmp2, tmpRow, rlen, clen);
 
 				ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), clen - 1);
@@ -776,7 +824,7 @@ public class LibMatrixCUDA {
 
 				squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen);
 
-				Pointer tmpCol = gCtx.allocate(instName, clen * Sizeof.DOUBLE);
+				Pointer tmpCol = gCtx.allocate(instName, clen * sizeOfDataType);
 				reduceCol(gCtx, instName, "reduce_col_sum", tmp2, tmpCol, rlen, clen);
 
 				ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), rlen - 1);
@@ -847,9 +895,9 @@ public class LibMatrixCUDA {
 		int[] tmp = getKernelParamsForReduceAll(gCtx, n);
 		int blocks = tmp[0], threads = tmp[1], sharedMem = tmp[2];
 
-		Pointer tempOut = gCtx.allocate(instName, n * Sizeof.DOUBLE);
+		Pointer tempOut = gCtx.allocate(instName, n * sizeOfDataType);
 
-		long t1=0,t2=0,t3=0;
+		long t1=0,t2=0;
 
 		if (GPUStatistics.DISPLAY_STATISTICS) t1 = System.nanoTime();
 		getCudaKernels(gCtx).launchKernel(kernelFunction, new ExecutionConfig(blocks, threads, sharedMem), in, tempOut, n);
@@ -867,11 +915,7 @@ public class LibMatrixCUDA {
 			s = (s + (threads*2-1)) / (threads*2);
 		}
 		double[] result = {-1f};
-
-		if (GPUStatistics.DISPLAY_STATISTICS) t3 = System.nanoTime();
-		cudaMemcpy(Pointer.to(result), tempOut, Sizeof.DOUBLE, cudaMemcpyDeviceToHost);
-		if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_DEVICE_TO_HOST, System.nanoTime() - t3);
-
+		cudaSupportFunctions.deviceToHost(gCtx, tempOut, result, instName);
 		gCtx.cudaFreeHelper(instName, tempOut);
 		return result[0];
 	}
@@ -946,7 +990,7 @@ public class LibMatrixCUDA {
 		int blocks = (n + (threads * 2 - 1)) / (threads * 2);
 		blocks = Math.min(MAX_BLOCKS, blocks);
 
-		int sharedMemSize = threads * Sizeof.DOUBLE;
+		int sharedMemSize = threads * sizeOfDataType;
 		if (threads <= WARP_SIZE){
 			sharedMemSize *= 2;
 		}
@@ -965,7 +1009,7 @@ public class LibMatrixCUDA {
 		final int MAX_THREADS = getMaxThreads(gCtx);
 		int threads = (cols < MAX_THREADS *2) ? nextPow2((cols + 1)/ 2) : MAX_THREADS;
 		int blocks = rows;
-		int sharedMemSize = threads * Sizeof.DOUBLE;
+		int sharedMemSize = threads * sizeOfDataType;
 		if (threads <= WARP_SIZE){
 			sharedMemSize *=2;
 		}
@@ -979,7 +1023,7 @@ public class LibMatrixCUDA {
 		int threads = Math.min(cols, MAX_THREADS);
 		int blocks = Math.min(cols/MAX_THREADS, MAX_BLOCKS);
 		if (cols % MAX_THREADS != 0) blocks++;
-		int sharedMemSize = threads * Sizeof.DOUBLE;
+		int sharedMemSize = threads * sizeOfDataType;
 		if (threads <= WARP_SIZE){
 			sharedMemSize *=2;
 		}
@@ -1475,7 +1519,7 @@ public class LibMatrixCUDA {
 	private static void deviceCopy(String instName, Pointer src, Pointer dest, int rlen, int clen) throws DMLRuntimeException {
 		long t0=0;
 		if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
-		int size = rlen * clen * Sizeof.DOUBLE;
+		long size = ((long) rlen) * clen * sizeOfDataType; // widen before multiplying: an int byte count overflows past 2^31-1 bytes
 		cudaMemcpy(dest, src, size, cudaMemcpyDeviceToDevice);
 		if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_DEVICE_TO_DEVICE, System.nanoTime() - t0);
 	}
@@ -1538,8 +1582,8 @@ public class LibMatrixCUDA {
 			LOG.trace("GPU : dgeam" + ", GPUContext=" + gCtx);
 		}
 
-		Pointer alphaPtr = pointerTo(alpha);
-		Pointer betaPtr = pointerTo(beta);
+		Pointer alphaPtr = dataTypePointerTo(alpha);
+		Pointer betaPtr = dataTypePointerTo(beta);
 		int transa = isLeftTransposed ? CUBLAS_OP_T : CUBLAS_OP_N;
 		int transb = isRightTransposed ? CUBLAS_OP_T : CUBLAS_OP_N;
 
@@ -1584,7 +1628,7 @@ public class LibMatrixCUDA {
 				int nnz = (int)A.nnz;
 				CSRPointer C = CSRPointer.allocateEmpty(gCtx, nnz, n);
 				out.getGPUObject(gCtx).setSparseMatrixCudaPointer(C);
-				cusparseDcsr2csc(getCusparseHandle(gCtx), m, n, nnz, A.val, A.rowPtr, A.colInd, C.val, C.colInd, C.rowPtr, cusparseAction.CUSPARSE_ACTION_NUMERIC, cusparseIndexBase.CUSPARSE_INDEX_BASE_ZERO);
+				cudaSupportFunctions.cusparsecsr2csc(getCusparseHandle(gCtx), m, n, nnz, A.val, A.rowPtr, A.colInd, C.val, C.colInd, C.rowPtr, cusparseAction.CUSPARSE_ACTION_NUMERIC, cusparseIndexBase.CUSPARSE_INDEX_BASE_ZERO);
 			} else {
 				// General case (cusparse does not support accept the transpose operator for dgeam)
 				// TODO: to implement the transposed + dgeam for sparse matrices, they need to be converted to csc, which is effectively a tranpose
@@ -1604,7 +1648,7 @@ public class LibMatrixCUDA {
 				//long sizeOfC = CSRPointer.estimateSize(C.nnz, out.getNumRows());
 				if (GPUStatistics.DISPLAY_STATISTICS)
 					t0 = System.nanoTime();
-				JCusparse.cusparseDcsrgeam(getCusparseHandle(gCtx), m, n, alphaPtr, A.descr, toInt(A.nnz), A.val, A.rowPtr, A.colInd, betaPtr,
+				cudaSupportFunctions.cusparsecsrgeam(getCusparseHandle(gCtx), m, n, alphaPtr, A.descr, toInt(A.nnz), A.val, A.rowPtr, A.colInd, betaPtr,
 						B.descr, toInt(B.nnz), B.val, B.rowPtr, B.colInd, C.descr, C.val, C.rowPtr, C.colInd);
 				//cudaDeviceSynchronize;
 				if (GPUStatistics.DISPLAY_STATISTICS)
@@ -1635,7 +1679,7 @@ public class LibMatrixCUDA {
 			Pointer C = getDensePointer(gCtx, out, instName);
 
 			if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
-			JCublas2.cublasDgeam(getCublasHandle(gCtx), transa, transb, m, n, alphaPtr, A, lda, betaPtr, B, ldb, C, ldc);
+			cudaSupportFunctions.cublasgeam(getCublasHandle(gCtx), transa, transb, m, n, alphaPtr, A, lda, betaPtr, B, ldb, C, ldc);
 			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_DENSE_DGEAM_LIB, System.nanoTime() - t0);
 		}
 	}
@@ -1673,7 +1717,7 @@ public class LibMatrixCUDA {
 	//******************* End of Re-org Functions ************************/
 	//********************************************************************/
 
-	static int toInt(long num) throws DMLRuntimeException {
+	public static int toInt(long num) throws DMLRuntimeException {
 		if(num >= Integer.MAX_VALUE || num <= Integer.MIN_VALUE) {
 			throw new DMLRuntimeException("GPU : Exceeded supported size " + num);
 		}
@@ -1751,8 +1795,8 @@ public class LibMatrixCUDA {
 		long t0 = GPUStatistics.DISPLAY_STATISTICS ? System.nanoTime() : 0;
 		long retClen = cu - cl + 1;
 		if (inClen == retClen) {
-			cudaMemcpy(outPointer, inPointer.withByteOffset(rl * inClen * Sizeof.DOUBLE), (ru - rl + 1) * inClen
-					* Sizeof.DOUBLE, cudaMemcpyDeviceToDevice);
+			cudaMemcpy(outPointer, inPointer.withByteOffset(rl * inClen * sizeOfDataType), (ru - rl + 1) * inClen
+					* sizeOfDataType, cudaMemcpyDeviceToDevice);
 		} else {
 			long retRlen = ru - rl + 1;
 			getCudaKernels(gCtx).launchKernel("slice_dense_dense", ExecutionConfig.getConfigForSimpleVectorOperations(toInt(retRlen*retClen)),
@@ -2255,17 +2299,17 @@ public class LibMatrixCUDA {
 
 			// Matrix-Matrix daxpy
 			long n = in1.getNumRows()*in2.getNumColumns(); // Since A is always a matrix
-			Pointer alphaPtr = pointerTo(constant);
+			Pointer alphaPtr = dataTypePointerTo(constant);
 			// C <- A + alpha*B
 			// becomes
 			// C <- A
 			// C <- alpha*B + C
 			if (GPUStatistics.DISPLAY_STATISTICS) t1 = System.nanoTime();
-			cudaMemcpy(C, A, n*((long)jcuda.Sizeof.DOUBLE), cudaMemcpyDeviceToDevice);
+			cudaMemcpy(C, A, n*((long)sizeOfDataType), cudaMemcpyDeviceToDevice);
 			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_DEVICE_TO_DEVICE, System.nanoTime() - t1);
 
 			if (GPUStatistics.DISPLAY_STATISTICS) t2 = System.nanoTime();
-			JCublas2.cublasDaxpy(getCublasHandle(gCtx), toInt(n), alphaPtr, B, 1, C, 1);
+			cudaSupportFunctions.cublasaxpy(getCublasHandle(gCtx), toInt(n), alphaPtr, B, 1, C, 1);
 			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_DAXPY_LIB, System.nanoTime() - t2);
 		}
 		else {
@@ -2353,15 +2397,15 @@ public class LibMatrixCUDA {
 		// step 3: query working space of geqrf and ormqr
 		if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
 		int[] lwork = {0};
-		JCusolverDn.cusolverDnDgeqrf_bufferSize(gCtx.getCusolverDnHandle(), m, n, A, m, lwork);
+		cudaSupportFunctions.cusolverDngeqrf_bufferSize(gCtx.getCusolverDnHandle(), m, n, A, m, lwork);
 		if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_QR_BUFFER, System.nanoTime() - t0);
 
 		// step 4: compute QR factorization
-		Pointer work = gCtx.allocate(instName, lwork[0] * Sizeof.DOUBLE);
-		Pointer tau = gCtx.allocate(instName, m * Sizeof.DOUBLE);
+		Pointer work = gCtx.allocate(instName, lwork[0] * sizeOfDataType);
+		Pointer tau = gCtx.allocate(instName, m * sizeOfDataType);
 		Pointer devInfo = gCtx.allocate(Sizeof.INT);
 		if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
-		JCusolverDn.cusolverDnDgeqrf(gCtx.getCusolverDnHandle(), m, n, A, m, tau, work, lwork[0], devInfo);
+		cudaSupportFunctions.cusolverDngeqrf(gCtx.getCusolverDnHandle(), m, n, A, m, tau, work, lwork[0], devInfo);
 		if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_QR, System.nanoTime() - t0);
 
 		int[] qrError = {-1};
@@ -2372,7 +2416,7 @@ public class LibMatrixCUDA {
 
 		// step 5: compute Q^T*B
 		if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
-		JCusolverDn.cusolverDnDormqr(gCtx.getCusolverDnHandle(), cublasSideMode.CUBLAS_SIDE_LEFT, cublasOperation.CUBLAS_OP_T, m, 1, n, A, m, tau, b, m, work, lwork[0], devInfo);
+		cudaSupportFunctions.cusolverDnormqr(gCtx.getCusolverDnHandle(), cublasSideMode.CUBLAS_SIDE_LEFT, cublasOperation.CUBLAS_OP_T, m, 1, n, A, m, tau, b, m, work, lwork[0], devInfo);
 		if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_ORMQR, System.nanoTime() - t0);
 		cudaMemcpy(Pointer.to(qrError), devInfo, Sizeof.INT, cudaMemcpyDeviceToHost);
 		if (qrError[0] != 0) {
@@ -2381,9 +2425,9 @@ public class LibMatrixCUDA {
 
 		// step 6: compute x = R \ Q^T*B
 		if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
-		JCublas2.cublasDtrsm(gCtx.getCublasHandle(),
+		cudaSupportFunctions.cublastrsm(gCtx.getCublasHandle(),
 			cublasSideMode.CUBLAS_SIDE_LEFT, cublasFillMode.CUBLAS_FILL_MODE_UPPER, cublasOperation.CUBLAS_OP_N, cublasDiagType.CUBLAS_DIAG_NON_UNIT,
-			n, 1, pointerTo(1.0), A, m, b, m);
+			n, 1, dataTypePointerTo(1.0), A, m, b, m);
 		if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_TRSM, System.nanoTime() - t0);
 
 		if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
@@ -2393,7 +2437,7 @@ public class LibMatrixCUDA {
 		// TODO  : Find a way to assign bTobj directly to the output and set the correct flags so as to not crash
 		// There is an avoidable copy happening here
 		MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, instName, outputName, in1.getNumColumns(), 1);
-		cudaMemcpy(out.getGPUObject(gCtx).getJcudaDenseMatrixPtr(), bTobj.getJcudaDenseMatrixPtr(), n * 1 * Sizeof.DOUBLE, cudaMemcpyDeviceToDevice);
+		cudaMemcpy(out.getGPUObject(gCtx).getJcudaDenseMatrixPtr(), bTobj.getJcudaDenseMatrixPtr(), n * 1 * sizeOfDataType, cudaMemcpyDeviceToDevice);
 
 		gCtx.cudaFreeHelper(instName, work);
 		gCtx.cudaFreeHelper(instName, tau);

http://git-wip-us.apache.org/repos/asf/systemml/blob/abbffc55/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
index bb74aa2..7fd766c 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
@@ -30,13 +30,11 @@ import static jcuda.jcudnn.JCudnn.cudnnPoolingForward;
 import static jcuda.jcudnn.JCudnn.cudnnSetActivationDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnSetTensor4dDescriptor;
 import static jcuda.jcudnn.cudnnActivationMode.CUDNN_ACTIVATION_RELU;
-import static jcuda.jcudnn.cudnnDataType.CUDNN_DATA_DOUBLE;
 import static jcuda.jcudnn.cudnnNanPropagation.CUDNN_PROPAGATE_NAN;
 import static jcuda.jcudnn.cudnnTensorFormat.CUDNN_TENSOR_NCHW;
 import static jcuda.runtime.JCuda.cudaMemset;
 import jcuda.CudaException;
 import jcuda.Pointer;
-import jcuda.Sizeof;
 import jcuda.jcudnn.cudnnActivationDescriptor;
 import jcuda.jcudnn.cudnnConvolutionFwdPreference;
 import jcuda.jcudnn.cudnnHandle;
@@ -131,7 +129,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 		long CHW = C*H*W; long KPQ = K*P*Q; long CRS = C*R*S; 
 		long NCHW = N*CHW; long NKPQ = N*KPQ; long KCRS = K*CRS;
 		
-		if(NCHW < maxNumDoublesOfCuDNNTensor && NKPQ < maxNumDoublesOfCuDNNTensor && KCRS < maxNumDoublesOfCuDNNTensor) {
+		if(NCHW < maxNumElementsOfCuDNNTensor && NKPQ < maxNumElementsOfCuDNNTensor && KCRS < maxNumElementsOfCuDNNTensor) {
 			// Filter and output are accounted as dense in the memory estimation for conv2d
 			double overhead = isInSparseFormat(gCtx, filter) ? OptimizerUtils.estimateSizeExactSparsity(K, CRS, 1.0) : 0;
 			overhead += isInSparseFormat(gCtx, image) ? OptimizerUtils.estimateSizeExactSparsity(N, CHW, 1.0) : 0;
@@ -155,7 +153,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 					try(LibMatrixCuDNNInputRowFetcher imgFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, image)) {
 						for(int n = 0; n < N; n++) {
 							// Perform one-input all-channel conv2d
-							cudnnConv2d(gCtx, instName, imgFetcher.getNthRow(n), filterPointer, dstPointer.withByteOffset(n*KPQ*Sizeof.DOUBLE), algo);
+							cudnnConv2d(gCtx, instName, imgFetcher.getNthRow(n), filterPointer, dstPointer.withByteOffset(n*KPQ*sizeOfDataType), algo);
 						}
 					}
 				}
@@ -180,7 +178,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 	 */
 	private static void throwCuDNNDimensionError(long dim1, long dim2, long dim3, long dim4) throws DMLRuntimeException {
 		throw new DMLRuntimeException("The dimensions of input/output matrices is too large to execute a CuDNN kernel. "
-				+ "Max CuDNN matrix size:" + maxNumDoublesOfCuDNNTensor + ". "
+				+ "Max CuDNN matrix size:" + maxNumElementsOfCuDNNTensor + ". "
 				+ "Given input matrix dimensions: [" + dim1 + "," + dim2 + "]. Output dimension:  [" + dim3 + "," + dim4 + "].");
 	}
 
@@ -197,7 +195,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 	 */
 	private static void throwCuDNNDimensionError(long dim1, long dim2, long dim3, long dim4, long dim5, long dim6) throws DMLRuntimeException {
 		throw new DMLRuntimeException("The dimensions of input/output matrices is too large to execute a CuDNN kernel. "
-				+ "Max CuDNN matrix size:" + maxNumDoublesOfCuDNNTensor + ". "
+				+ "Max CuDNN matrix size:" + maxNumElementsOfCuDNNTensor + ". "
 				+ "Given input matrix dimensions: [" + dim1 + "," + dim2 + "], [" + dim3 + "," + dim4 + "]. Output dimension: [" + dim5 + "," + dim6 + "]");
 	}
 
@@ -270,7 +268,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 		long NCHW = N*CHW; long NKPQ = N*KPQ; long KCRS = K*CRS;
 		
 		
-		if(NCHW < maxNumDoublesOfCuDNNTensor && NKPQ < maxNumDoublesOfCuDNNTensor && KCRS < maxNumDoublesOfCuDNNTensor) {
+		if(NCHW < maxNumElementsOfCuDNNTensor && NKPQ < maxNumElementsOfCuDNNTensor && KCRS < maxNumElementsOfCuDNNTensor) {
 			Pointer dwPointer = getDensePointerForCuDNN(gCtx, outputBlock, instName);
 			double overhead = isInSparseFormat(gCtx, image) ? OptimizerUtils.estimateSizeExactSparsity(N, CHW, 1.0) : 0;
 			overhead += isInSparseFormat(gCtx, dout) ? OptimizerUtils.estimateSizeExactSparsity(N, KPQ, 1.0) : 0;
@@ -292,10 +290,10 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 					try(LibMatrixCuDNNInputRowFetcher imgFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, image);
 						LibMatrixCuDNNInputRowFetcher doutFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, dout)) {
 						// Perform one-input conv2dBackwardFilter
-						Pointer tempdwPointer = gCtx.allocate(KCRS*Sizeof.DOUBLE);
+						Pointer tempdwPointer = gCtx.allocate(KCRS*sizeOfDataType);
 						for(int n = 0; n < N; n++) {
 							long t0 = GPUStatistics.DISPLAY_STATISTICS ? System.nanoTime() : 0;
-							cudaMemset(tempdwPointer, 0, KCRS*Sizeof.DOUBLE);
+							cudaMemset(tempdwPointer, 0, KCRS*sizeOfDataType);
 							if(GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_SET_ZERO, System.nanoTime() - t0);
 							// Perform one-input conv2dBackwardFilter
 							cudnnConv2dBackwardFilter(gCtx, instName, imgFetcher.getNthRow(n), doutFetcher.getNthRow(n), tempdwPointer, algo);
@@ -376,7 +374,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 		long CHW = C*H*W; long KPQ = K*P*Q; long CRS = C*R*S; 
 		long NCHW = N*CHW; long NKPQ = N*KPQ; long KCRS = K*CRS;
 
-		if(NCHW < maxNumDoublesOfCuDNNTensor && NKPQ < maxNumDoublesOfCuDNNTensor && KCRS < maxNumDoublesOfCuDNNTensor) {
+		if(NCHW < maxNumElementsOfCuDNNTensor && NKPQ < maxNumElementsOfCuDNNTensor && KCRS < maxNumElementsOfCuDNNTensor) {
 			// Filter and output are accounted as dense in the memory estimation for conv2dBackwardData
 			double overhead = isInSparseFormat(gCtx, filter) ? OptimizerUtils.estimateSizeExactSparsity(K, CRS, 1.0) : 0;
 			overhead += isInSparseFormat(gCtx, dout) ? OptimizerUtils.estimateSizeExactSparsity(N, KPQ, 1.0) : 0;
@@ -398,7 +396,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 				else {
 					try(LibMatrixCuDNNInputRowFetcher doutFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, dout)) {
 						for(int n = 0; n < N; n++) {
-							cudnnConv2dBackwardData(gCtx, instName, doutFetcher.getNthRow(n), filterPointer, dstPointer.withByteOffset(n*CHW*Sizeof.DOUBLE), algo);
+							cudnnConv2dBackwardData(gCtx, instName, doutFetcher.getNthRow(n), filterPointer, dstPointer.withByteOffset(n*CHW*sizeOfDataType), algo);
 						}
 					}
 				}
@@ -468,7 +466,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 		long CHW = C*H*W; long CPQ = C*P*Q;  
 		long NCHW = N*CHW; long NCPQ = N*CPQ; 
 
-		if(NCHW < maxNumDoublesOfCuDNNTensor && NCPQ < maxNumDoublesOfCuDNNTensor) {
+		if(NCHW < maxNumElementsOfCuDNNTensor && NCPQ < maxNumElementsOfCuDNNTensor) {
 			// Filter and output are accounted as dense in the memory estimation for conv2dBackwardData
 			long overhead = isInSparseFormat(gCtx, image) ? OptimizerUtils.estimateSizeExactSparsity(N, CHW, 1.0) : 0;
 			Pointer y = getDensePointerForCuDNN(gCtx, outputBlock, instName);
@@ -479,7 +477,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 			else {
 				LibMatrixCuDNNInputRowFetcher imgFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, image);
 				for(int n = 0; n < N; n++) {
-					cudnnMaxpooling(gCtx, instName, imgFetcher.getNthRow(n), y.withByteOffset(n*CPQ*Sizeof.DOUBLE), 1, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
+					cudnnMaxpooling(gCtx, instName, imgFetcher.getNthRow(n), y.withByteOffset(n*CPQ*sizeOfDataType), 1, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
 				}
 				imgFetcher.close();
 			}
@@ -545,7 +543,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 		long CHW = C*H*W; long CPQ = C*P*Q;  
 		long NCHW = N*CHW; long NCPQ = N*CPQ; 
 
-		if(NCHW < maxNumDoublesOfCuDNNTensor && NCPQ < maxNumDoublesOfCuDNNTensor) {
+		if(NCHW < maxNumElementsOfCuDNNTensor && NCPQ < maxNumElementsOfCuDNNTensor) {
 			// Filter and output are accounted as dense in the memory estimation for conv2dBackwardData
 			long overhead = isInSparseFormat(gCtx, image) ? OptimizerUtils.estimateSizeExactSparsity(N, CHW, 1.0) : 0;
 			overhead += isInSparseFormat(gCtx, dout) ? OptimizerUtils.estimateSizeExactSparsity(N, CPQ, 1.0) : 0;
@@ -560,7 +558,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 				LibMatrixCuDNNInputRowFetcher doutFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, dout);
 				for(int n = 0; n < N; n++) {
 					cudnnMaxpoolingBackward(gCtx, instName, imgFetcher.getNthRow(n), doutFetcher.getNthRow(n), 
-							dx.withByteOffset(n*CHW*Sizeof.DOUBLE), 
+							dx.withByteOffset(n*CHW*sizeOfDataType), 
 							1, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
 				}
 				// Deallocate temporary array to hold one element of input
@@ -591,7 +589,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 			
 			// Calling PoolForward first, y is one of the inputs for poolBackward
 			// TODO: Remove calling poolForward after necessary changes at language level for poolBackward
-			long numBytes = N*C*P*Q*Sizeof.DOUBLE;
+			long numBytes = N*C*P*Q*sizeOfDataType;
 			y = gCtx.allocate(numBytes);
 			
 			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - t1);
@@ -668,7 +666,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 		MatrixObject output = ec.getMatrixObject(outputName);
 		getDenseMatrixOutputForGPUInstruction(ec, instName, outputName, in.getNumRows(), in.getNumColumns()); // Allocated the dense output matrix
 		long t0=0;
-		if(N*CHW >= maxNumDoublesOfCuDNNTensor) {
+		if(N*CHW >= maxNumElementsOfCuDNNTensor) {
 			if(LOG.isTraceEnabled()) {
 				LOG.trace("GPU : relu custom kernel" + ", GPUContext=" + gCtx);
 			}
@@ -684,7 +682,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 		else {
 			cudnnTensorDescriptor tensorDescriptor = new cudnnTensorDescriptor();
 			cudnnCreateTensorDescriptor(tensorDescriptor);
-			cudnnSetTensor4dDescriptor(tensorDescriptor, CUDNN_TENSOR_NCHW, CUDNN_DATA_DOUBLE, toInt(N), 1, 1, toInt(CHW));
+			cudnnSetTensor4dDescriptor(tensorDescriptor, CUDNN_TENSOR_NCHW, CUDNN_DATA_TYPE, toInt(N), 1, 1, toInt(CHW));
 			cudnnReLU(gCtx, instName, in, getDensePointerForCuDNN(gCtx, output, instName), tensorDescriptor);
 			cudnnDestroyTensorDescriptor(tensorDescriptor);
 		}
@@ -701,7 +699,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 	 */
 	protected static Pointer getDensePointerForCuDNN(GPUContext gCtx, MatrixObject image, String instName) throws DMLRuntimeException {
 		long numElems = image.getNumRows()*image.getNumColumns();
-		if(numElems > maxNumDoublesOfCuDNNTensor) {
+		if(numElems > maxNumElementsOfCuDNNTensor) {
 			throw new DMLRuntimeException("CuDNN restriction: the size of input tensor cannot have greater than 2 giga-elements, but has " + numElems + " (i.e. [" + image.getNumRows() + " X " + image.getNumColumns() + "]). Hint: try reducing the mini-batch size.");
 		}
 		return getDensePointer(gCtx, image, instName);
@@ -717,4 +715,4 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 		if(status != cudnnStatus.CUDNN_STATUS_SUCCESS)
 			throw new DMLRuntimeException("Error status returned by CuDNN:" + jcuda.jcudnn.cudnnStatus.stringFor(status));
 	}
-}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/abbffc55/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNConvolutionAlgorithm.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNConvolutionAlgorithm.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNConvolutionAlgorithm.java
index f49433d..ee22541 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNConvolutionAlgorithm.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNConvolutionAlgorithm.java
@@ -40,7 +40,6 @@ import static jcuda.jcudnn.JCudnn.cudnnSetConvolution2dDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnSetFilter4dDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnSetTensor4dDescriptor;
 import static jcuda.jcudnn.cudnnConvolutionMode.CUDNN_CROSS_CORRELATION;
-import static jcuda.jcudnn.cudnnDataType.CUDNN_DATA_DOUBLE;
 import static jcuda.jcudnn.cudnnTensorFormat.CUDNN_TENSOR_NCHW;
 
 /**
@@ -255,14 +254,14 @@ public class LibMatrixCuDNNConvolutionAlgorithm implements java.lang.AutoCloseab
 	private static cudnnTensorDescriptor allocateTensorDescriptor(int N, int C, int H, int W) throws DMLRuntimeException {
 		cudnnTensorDescriptor tensorDescriptor = new cudnnTensorDescriptor();
 		cudnnCreateTensorDescriptor(tensorDescriptor);
-		cudnnSetTensor4dDescriptor(tensorDescriptor, CUDNN_TENSOR_NCHW, CUDNN_DATA_DOUBLE, N, C, H, W);
+		cudnnSetTensor4dDescriptor(tensorDescriptor, CUDNN_TENSOR_NCHW, LibMatrixCUDA.CUDNN_DATA_TYPE, N, C, H, W);
 		return tensorDescriptor;
 	}
 	
 	private static cudnnFilterDescriptor allocateFilterDescriptor(int K, int C, int R, int S) {
 		cudnnFilterDescriptor filterDesc = new cudnnFilterDescriptor();
 		cudnnCreateFilterDescriptor(filterDesc);
-		cudnnSetFilter4dDescriptor(filterDesc, CUDNN_DATA_DOUBLE, CUDNN_TENSOR_NCHW, K, C, R, S);
+		cudnnSetFilter4dDescriptor(filterDesc, LibMatrixCUDA.CUDNN_DATA_TYPE, CUDNN_TENSOR_NCHW, K, C, R, S);
 		return filterDesc;
 	}
 	

http://git-wip-us.apache.org/repos/asf/systemml/blob/abbffc55/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNInputRowFetcher.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNInputRowFetcher.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNInputRowFetcher.java
index 581607e..5121c87 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNInputRowFetcher.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNInputRowFetcher.java
@@ -20,8 +20,6 @@ package org.apache.sysml.runtime.matrix.data;
 
 import static jcuda.runtime.JCuda.cudaMemset;
 import jcuda.Pointer;
-import jcuda.Sizeof;
-
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
@@ -32,7 +30,7 @@ import org.apache.sysml.utils.GPUStatistics;
 /**
  * Performs a slice operation: out = in[(n+1):(n+1), 1:numColumns]
  */
-public class LibMatrixCuDNNInputRowFetcher implements java.lang.AutoCloseable {
+public class LibMatrixCuDNNInputRowFetcher extends LibMatrixCUDA implements java.lang.AutoCloseable {
 	GPUContext gCtx; String instName; int numColumns; boolean isInputInSparseFormat; 
 	Object inPointer; // can be either CSRPointer or Pointer 
 	Pointer outPointer;
@@ -50,7 +48,7 @@ public class LibMatrixCuDNNInputRowFetcher implements java.lang.AutoCloseable {
 		numColumns = LibMatrixCUDA.toInt(image.getNumColumns());
 		isInputInSparseFormat = LibMatrixCUDA.isInSparseFormat(gCtx, image);
 		inPointer = isInputInSparseFormat ? LibMatrixCUDA.getSparsePointer(gCtx, image, instName) : LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, image, instName);
-		outPointer = gCtx.allocate(numColumns*Sizeof.DOUBLE);
+		outPointer = gCtx.allocate(numColumns*sizeOfDataType);
 	}
 	/**
 	 * Copy the nth row and return the dense pointer
@@ -62,7 +60,7 @@ public class LibMatrixCuDNNInputRowFetcher implements java.lang.AutoCloseable {
 		if(isInputInSparseFormat) {
 			jcuda.runtime.JCuda.cudaDeviceSynchronize();
 			long t0 = GPUStatistics.DISPLAY_STATISTICS ? System.nanoTime() : 0;
-			cudaMemset(outPointer, 0, numColumns*Sizeof.DOUBLE);
+			cudaMemset(outPointer, 0, numColumns*sizeOfDataType);
 			jcuda.runtime.JCuda.cudaDeviceSynchronize();
 			if(GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_SET_ZERO, System.nanoTime() - t0);
 			LibMatrixCUDA.sliceSparseDense(gCtx, instName, (CSRPointer)inPointer, outPointer, n, n, 0, LibMatrixCUDA.toInt(numColumns-1), numColumns);

http://git-wip-us.apache.org/repos/asf/systemml/blob/abbffc55/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNPoolingDescriptors.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNPoolingDescriptors.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNPoolingDescriptors.java
index f817bd5..d4b213f 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNPoolingDescriptors.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNPoolingDescriptors.java
@@ -24,7 +24,6 @@ import static jcuda.jcudnn.JCudnn.cudnnCreateTensorDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnDestroyTensorDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnSetPooling2dDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnSetTensor4dDescriptor;
-import static jcuda.jcudnn.cudnnDataType.CUDNN_DATA_DOUBLE;
 import static jcuda.jcudnn.cudnnNanPropagation.CUDNN_PROPAGATE_NAN;
 import static jcuda.jcudnn.cudnnPoolingMode.CUDNN_POOLING_MAX;
 import static jcuda.jcudnn.cudnnTensorFormat.CUDNN_TENSOR_NCHW;
@@ -141,7 +140,7 @@ public class LibMatrixCuDNNPoolingDescriptors implements java.lang.AutoCloseable
 	private static cudnnTensorDescriptor allocateTensorDescriptor(int N, int C, int H, int W) throws DMLRuntimeException {
 		cudnnTensorDescriptor tensorDescriptor = new cudnnTensorDescriptor();
 		cudnnCreateTensorDescriptor(tensorDescriptor);
-		cudnnSetTensor4dDescriptor(tensorDescriptor, CUDNN_TENSOR_NCHW, CUDNN_DATA_DOUBLE, N, C, H, W);
+		cudnnSetTensor4dDescriptor(tensorDescriptor, CUDNN_TENSOR_NCHW, LibMatrixCUDA.CUDNN_DATA_TYPE, N, C, H, W);
 		return tensorDescriptor;
 	}