You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@systemml.apache.org by ni...@apache.org on 2017/09/07 19:49:59 UTC

[2/5] systemml git commit: [SYSTEMML-540] Support sparse GPU conv2d as well as fix memory estimation of convolution operations

http://git-wip-us.apache.org/repos/asf/systemml/blob/772d9302/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index 09ffe9f..a362364 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -21,37 +21,6 @@ package org.apache.sysml.runtime.matrix.data;
 
 import static jcuda.jcublas.cublasOperation.CUBLAS_OP_N;
 import static jcuda.jcublas.cublasOperation.CUBLAS_OP_T;
-import static jcuda.jcudnn.JCudnn.cudnnActivationForward;
-import static jcuda.jcudnn.JCudnn.cudnnBatchNormalizationBackward;
-import static jcuda.jcudnn.JCudnn.cudnnBatchNormalizationForwardInference;
-import static jcuda.jcudnn.JCudnn.cudnnBatchNormalizationForwardTraining;
-import static jcuda.jcudnn.JCudnn.cudnnConvolutionBackwardData;
-import static jcuda.jcudnn.JCudnn.cudnnConvolutionBackwardFilter;
-import static jcuda.jcudnn.JCudnn.cudnnConvolutionForward;
-import static jcuda.jcudnn.JCudnn.cudnnCreateActivationDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnCreateConvolutionDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnCreateFilterDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnCreatePoolingDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnCreateTensorDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnDestroyConvolutionDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnDestroyFilterDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnDestroyPoolingDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnGetConvolutionBackwardDataWorkspaceSize;
-import static jcuda.jcudnn.JCudnn.cudnnGetConvolutionBackwardFilterWorkspaceSize;
-import static jcuda.jcudnn.JCudnn.cudnnGetConvolutionForwardWorkspaceSize;
-import static jcuda.jcudnn.JCudnn.cudnnPoolingBackward;
-import static jcuda.jcudnn.JCudnn.cudnnPoolingForward;
-import static jcuda.jcudnn.JCudnn.cudnnSetActivationDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnSetConvolution2dDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnSetFilter4dDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnSetPooling2dDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnSetTensor4dDescriptor;
-import static jcuda.jcudnn.cudnnActivationMode.CUDNN_ACTIVATION_RELU;
-import static jcuda.jcudnn.cudnnConvolutionMode.CUDNN_CROSS_CORRELATION;
-import static jcuda.jcudnn.cudnnDataType.CUDNN_DATA_DOUBLE;
-import static jcuda.jcudnn.cudnnNanPropagation.CUDNN_PROPAGATE_NAN;
-import static jcuda.jcudnn.cudnnPoolingMode.CUDNN_POOLING_MAX;
-import static jcuda.jcudnn.cudnnTensorFormat.CUDNN_TENSOR_NCHW;
 import static jcuda.jcusparse.JCusparse.cusparseDcsr2csc;
 import static jcuda.jcusparse.JCusparse.cusparseDcsrgemm;
 import static jcuda.jcusparse.JCusparse.cusparseDcsrmv;
@@ -116,7 +85,6 @@ import org.apache.sysml.runtime.util.IndexRange;
 import org.apache.sysml.utils.GPUStatistics;
 import org.apache.sysml.utils.Statistics;
 
-import jcuda.CudaException;
 import jcuda.Pointer;
 import jcuda.Sizeof;
 import jcuda.jcublas.JCublas2;
@@ -125,15 +93,6 @@ import jcuda.jcublas.cublasFillMode;
 import jcuda.jcublas.cublasHandle;
 import jcuda.jcublas.cublasOperation;
 import jcuda.jcublas.cublasSideMode;
-import jcuda.jcudnn.cudnnActivationDescriptor;
-import jcuda.jcudnn.cudnnBatchNormMode;
-import jcuda.jcudnn.cudnnConvolutionDescriptor;
-import jcuda.jcudnn.cudnnConvolutionFwdPreference;
-import jcuda.jcudnn.cudnnFilterDescriptor;
-import jcuda.jcudnn.cudnnHandle;
-import jcuda.jcudnn.cudnnPoolingDescriptor;
-import jcuda.jcudnn.cudnnStatus;
-import jcuda.jcudnn.cudnnTensorDescriptor;
 import jcuda.jcusolver.JCusolverDn;
 import jcuda.jcusparse.JCusparse;
 import jcuda.jcusparse.cusparseAction;
@@ -155,6 +114,10 @@ public class LibMatrixCUDA {
 	private static int _MAX_THREADS = -1;
 	private static int _MAX_BLOCKS  = -1;
 	private static int _WARP_SIZE 	= -1;
+	
+	// From CuDNN 5.1 documentation:
+	// The total size of a tensor including the potential padding between dimensions is limited to 2 Giga-elements of type datatype.
+	protected static long maxNumDoublesOfCuDNNTensor = 2000000000;
 
 	//********************************************************************/
 	//***************************** UTILS ********************************/
@@ -220,11 +183,7 @@ public class LibMatrixCUDA {
 		return gCtx.getCublasHandle();
 	}
 
-	private static cudnnHandle getCudnnHandle(GPUContext gCtx) throws DMLRuntimeException {
-		return gCtx.getCudnnHandle();
-	}
-
-	private static JCudaKernels getCudaKernels(GPUContext gCtx) throws DMLRuntimeException {
+	protected static JCudaKernels getCudaKernels(GPUContext gCtx) throws DMLRuntimeException {
 		return gCtx.getKernels();
 	}
 
@@ -237,17 +196,13 @@ public class LibMatrixCUDA {
 	//***************** DEEP LEARNING Operators **************************/
 	//********************************************************************/
 
-
-
-	private static int CONVOLUTION_PREFERENCE = cudnnConvolutionFwdPreference.CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
-
 	private static Pointer _one;
 	private static Pointer _zero;
 	/**
 	 * Convenience method to get a pointer to value '1.0' on device. Instead of allocating and deallocating it for every kernel invocation.
 	 * @return jcuda pointer
 	 */
-	private static Pointer one() {
+	protected static Pointer one() {
 		if(_one == null) {
 			_one = pointerTo(1.0);
 		}
@@ -257,7 +212,7 @@ public class LibMatrixCUDA {
 	 * Convenience method to get a pointer to value '0.0f' on device. Instead of allocating and deallocating it for every kernel invocation.
 	 * @return jcuda pointer
 	 */
-	private static Pointer zero() {
+	protected static Pointer zero() {
 		if(_zero == null) {
 			_zero = pointerTo(0.0f);
 		}
@@ -265,56 +220,6 @@ public class LibMatrixCUDA {
 	}
 
 	/**
-	 * Convenience method to get tensor descriptor from underlying GPUObject
-	 * @param gCtx   a valid {@link GPUContext}
-	 * @param mat matrix object
-	 * @param N number of images
-	 * @param C number of channels
-	 * @param H height
-	 * @param W width
-	 * @return cudnn tensor descriptor
-	 * @throws DMLRuntimeException if the input descriptor and matrix dimensions don't match
-	 */
-	private static cudnnTensorDescriptor allocateTensorDescriptor(GPUContext gCtx, MatrixObject mat, int N, int C, int H, int W) throws DMLRuntimeException {
-		if(mat.getNumRows() != N || mat.getNumColumns() != C*H*W) {
-			throw new DMLRuntimeException("Mismatch descriptor-matrix dimensions:" + mat.getNumRows() + " != " + N
-					+ " || " + mat.getNumColumns() + " != " + (C*H*W));
-		}
-		return mat.getGPUObject(gCtx).allocateTensorDescriptor(N, C, H, W);
-	}
-
-	/**
-	 * Convenience method to get tensor descriptor
-	 * @param N number of images
-	 * @param C number of channels
-	 * @param H height
-	 * @param W width
-	 * @return cudnn tensor descriptor
-	 * @throws DMLRuntimeException if the input descriptor and matrix dimensions don't match
-	 */
-	private static cudnnTensorDescriptor allocateTensorDescriptor(int N, int C, int H, int W) throws DMLRuntimeException {
-		cudnnTensorDescriptor tensorDescriptor = new cudnnTensorDescriptor();
-		cudnnCreateTensorDescriptor(tensorDescriptor);
-		cudnnSetTensor4dDescriptor(tensorDescriptor, CUDNN_TENSOR_NCHW, CUDNN_DATA_DOUBLE, N, C, H, W);
-		return tensorDescriptor;
-	}
-
-	/**
-	 * Convenience method to get jcudaDenseMatrixPtr. This method explicitly converts sparse to dense format, so use it judiciously.
-	 * @param gCtx a valid {@link GPUContext}
-	 * @param image input matrix object
-	 * @param isForCuDNN true if the dense pointer is to be used by a CuDNN kernel
-	 * @return jcuda pointer
-	 * @throws DMLRuntimeException if error occurs while sparse to dense conversion
-	 */
-	private static Pointer getDensePointer(GPUContext gCtx, MatrixObject image, boolean isForCuDNN, String instName) throws DMLRuntimeException {
-		if(isForCuDNN && image.getNumRows()*image.getNumColumns() > numDoublesIn2GB) {
-			throw new DMLRuntimeException("CuDNN restriction: the size of input tensor cannot be greater than 2GB. Hint: try reducing the mini-batch size.");
-		}
-		return getDensePointer(gCtx, image, instName);
-	}
-
-	/**
 	 * Convenience method to get jcudaDenseMatrixPtr. This method explicitly converts sparse to dense format, so use it judiciously.
 	 * @param gCtx a valid {@link GPUContext}
 	 * @param input input matrix object
@@ -322,7 +227,7 @@ public class LibMatrixCUDA {
 	 * @return jcuda pointer
 	 * @throws DMLRuntimeException if error occurs while sparse to dense conversion
 	 */
-	private static Pointer getDensePointer(GPUContext gCtx, MatrixObject input, String instName) throws DMLRuntimeException {
+	protected static Pointer getDensePointer(GPUContext gCtx, MatrixObject input, String instName) throws DMLRuntimeException {
 		if(isInSparseFormat(gCtx, input)) {
 			input.getGPUObject(gCtx).sparseToDense(instName);
 		}
@@ -337,222 +242,17 @@ public class LibMatrixCUDA {
 	 * @return a sparse matrix pointer
 	 * @throws DMLRuntimeException if error occurs
 	 */
-	private static CSRPointer getSparsePointer(GPUContext gCtx, MatrixObject input, String instName) throws DMLRuntimeException {
+	protected static CSRPointer getSparsePointer(GPUContext gCtx, MatrixObject input, String instName) throws DMLRuntimeException {
 		if(!isInSparseFormat(gCtx, input)) {
 			input.getGPUObject(gCtx).denseToSparse();
 		}
 		return input.getGPUObject(gCtx).getJcudaSparseMatrixPtr();
 	}
-
-	/**
-	 * Convenience method for checking the status of CuDNN kernel.
-	 *
-	 * @param status status returned by CuDNN
-	 * @throws DMLRuntimeException if status is not CUDNN_STATUS_SUCCESS
-	 */
-	private static void checkStatus(int status) throws DMLRuntimeException {
-		if(status != cudnnStatus.CUDNN_STATUS_SUCCESS)
-			throw new DMLRuntimeException("Error status returned by CuDNN:" + jcuda.jcudnn.cudnnStatus.stringFor(status));
-	}
-
-	/**
-	 * Does a 2D convolution followed by a bias_add
-	 *
-	 * @param gCtx     a valid {@link GPUContext}
-	 * @param instName the invoking instruction's name for record {@link Statistics}.
-	 * @param image    input image matrix object
-	 * @param bias     bias matrix object
-	 * @param filter   filter matrix object
-	 * @param output   output matrix object
-	 * @param N        number of input images
-	 * @param C        number of channels
-	 * @param H        height of each image
-	 * @param W        width of each image
-	 * @param K        number of output "channels"
-	 * @param R        height of filter
-	 * @param S        width of filter
-	 * @param pad_h    padding height
-	 * @param pad_w    padding width
-	 * @param stride_h stride height
-	 * @param stride_w string width
-	 * @param P        output height
-	 * @param Q        output width
-	 * @throws DMLRuntimeException if error
-	 */
-	public static void conv2dBiasAdd(GPUContext gCtx, String instName, MatrixObject image, MatrixObject bias, MatrixObject filter, MatrixObject output, int N, int C, int H, int W,
-			int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q)
-					throws DMLRuntimeException {
-		/*
-		int rows = (int) output.getNumRows();
-		int cols = (int) output.getNumColumns();
-		long size  = rows * cols * Sizeof.DOUBLE;
-
-		Pointer imagePointer = getDensePointer(image, instName);
-		Pointer biasPointer = getDensePointer(bias, instName);
-		Pointer outputPointer = getDensePointer(output, instName);
-		Pointer filterPointer = getDensePointer(filter, instName);
-
-		Pointer tmp = allocate(size);
-
-		conv2d(instName, imagePointer, filterPointer, tmp, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
-		cudaDeviceSynchronize();
-
-		long k1 = bias.getNumColumns();
-		if(k1 != bias.getNumColumns() || bias.getNumColumns() != 1 || cols % k1 != 0) {
-			throw new DMLRuntimeException("Incorrect inputs for bias_add: input[" + rows + " X " + cols + "] and bias[" + K + " X " + bias.getNumColumns() + "]");
-		}
-		// biasAdd(instName, output, bias, output);
-		biasAdd(instName, tmp, biasPointer, outputPointer, rows, cols, (int)k1);
-
-		cudaFreeHelper(tmp);
-		 */
-		LOG.trace("GPU : conv2dBiasAdd" + ", GPUContext=" + gCtx);
-		conv2d(gCtx, instName, image, filter, output, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
-		//cudaDeviceSynchronize;
-		biasAdd(gCtx, instName, output, bias, output);
-	}
-
-	public static void conv2d(GPUContext gCtx, String instName, MatrixObject image, MatrixObject filter, MatrixObject outputBlock, int N, int C, int H, int W,
-			int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q)
-					throws DMLRuntimeException {
-		Pointer imagePointer = getDensePointer(gCtx, image, true, instName);
-		Pointer filterPointer = getDensePointer(gCtx, filter, true, instName);
-		Pointer dstPointer = getDensePointer(gCtx, outputBlock, true, instName);
-
-		conv2d(gCtx, instName, imagePointer, filterPointer, dstPointer, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
-	}
-
-	/**
-	 * Performs 2D convolution
-	 * Takes up an insignificant amount of intermediate space when CONVOLUTION_PREFERENCE is set to CUDNN_CONVOLUTION_FWD_NO_WORKSPACE
-	 * Intermediate space is required by the filter descriptor and convolution descriptor which are metadata structures and don't scale with the size of the input
-	 *
-	 * @param gCtx     a valid {@link GPUContext}
-	 * @param instName the invoking instruction's name for record {@link Statistics}.
-	 * @param image    the input matrix (or image) allocated on the GPU
-	 * @param filter   the filter allocated on the GPU
-	 * @param output   the output matrix allocated on the GPU
-	 * @param N        number of input images
-	 * @param C        number of channels
-	 * @param H        height of each image
-	 * @param W        width of each image
-	 * @param K        number of output "channels"
-	 * @param R        height of filter
-	 * @param S        width of filter
-	 * @param pad_h    padding height
-	 * @param pad_w    padding width
-	 * @param stride_h stride height
-	 * @param stride_w string width
-	 * @param P        output height
-	 * @param Q        output width
-	 * @throws DMLRuntimeException if error
-	 */
-	public static void conv2d(GPUContext gCtx, String instName, Pointer image, Pointer filter, Pointer output, int N,
-			int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q)
-					throws DMLRuntimeException {
-		LOG.trace("GPU : conv2d" + ", GPUContext=" + gCtx);
-		cudnnFilterDescriptor filterDesc = null;
-		cudnnConvolutionDescriptor convDesc = null;
-		Pointer workSpace = null;
-		long sizeInBytes = 0;
-		try {
-			long t1 = 0, t2 = 0;
-			// Allocate descriptors
-			if (GPUStatistics.DISPLAY_STATISTICS) t1 = System.nanoTime();
-			cudnnTensorDescriptor srcTensorDesc = allocateTensorDescriptor(N, C, H, W);
-			cudnnTensorDescriptor dstTensorDesc = allocateTensorDescriptor(N, K, P, Q);
-			filterDesc = allocateFilterDescriptor(K, C, R, S);
-
-			int padding[] = {pad_h, pad_w};
-			int strides[] = {stride_h, stride_w};
-			convDesc = allocateConvolutionDescriptor(padding, strides);
-
-			// Select the best algorithm depending on the data and supported CUDA
-
-			int algo = -1;
-			workSpace = new Pointer();
-
-			if (CONVOLUTION_PREFERENCE == cudnnConvolutionFwdPreference.CUDNN_CONVOLUTION_FWD_NO_WORKSPACE) {
-				algo = jcuda.jcudnn.cudnnConvolutionFwdAlgo.CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
-			} else if (CONVOLUTION_PREFERENCE == cudnnConvolutionFwdPreference.CUDNN_CONVOLUTION_FWD_PREFER_FASTEST) {
-				int[] algos = {-1};
-				// TODO: Look into FFt, Winograd, etc
-				// Also ensure that GPU has enough memory to allocate memory
-				long sizeInBytesArray[] = {0};
-				jcuda.jcudnn.JCudnn.cudnnGetConvolutionForwardAlgorithm(getCudnnHandle(gCtx), srcTensorDesc, filterDesc, convDesc, dstTensorDesc,
-						CONVOLUTION_PREFERENCE, sizeInBytesArray[0], algos);
-				cudnnGetConvolutionForwardWorkspaceSize(getCudnnHandle(gCtx), srcTensorDesc, filterDesc, convDesc, dstTensorDesc, algos[0], sizeInBytesArray);
-				if (sizeInBytesArray[0] != 0)
-					workSpace = gCtx.allocate(sizeInBytesArray[0]);
-				sizeInBytes = sizeInBytesArray[0];
-			} else if (CONVOLUTION_PREFERENCE == cudnnConvolutionFwdPreference.CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT) {
-				throw new DMLRuntimeException("CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT is not implemented");
-			} else {
-				throw new DMLRuntimeException("Unsupported preference criteria for convolution");
-			}
-			if (GPUStatistics.DISPLAY_STATISTICS)
-				GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - t1);
-			if (GPUStatistics.DISPLAY_STATISTICS) t2 = System.nanoTime();
-			int status = cudnnConvolutionForward(getCudnnHandle(gCtx), one(),
-					srcTensorDesc, image,
-					filterDesc, filter,
-					convDesc, algo, workSpace, sizeInBytes, zero(),
-					dstTensorDesc, output);
-			if (GPUStatistics.DISPLAY_STATISTICS)
-				GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CONVOLUTION_FORWARD_LIB, System.nanoTime() - t2);
-			if (status != cudnnStatus.CUDNN_STATUS_SUCCESS) {
-				throw new DMLRuntimeException("Could not executed cudnnConvolutionForward: " + cudnnStatus.stringFor(status));
-			}
-		} catch (CudaException e) {
-			throw new DMLRuntimeException("Error in conv2d in GPUContext " + gCtx.toString() + " from Thread " + Thread.currentThread().toString(), e);
-		} finally {
-			long t3 = 0;
-			if (GPUStatistics.DISPLAY_STATISTICS) t3 = System.nanoTime();
-			if (filterDesc != null)
-				cudnnDestroyFilterDescriptor(filterDesc);
-			if (convDesc != null)
-				cudnnDestroyConvolutionDescriptor(convDesc);
-			if (workSpace != null && sizeInBytes != 0)
-				gCtx.cudaFreeHelper(instName, workSpace);
-			if (GPUStatistics.DISPLAY_STATISTICS)
-				GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - t3);
-		}
-	}
-
-	private static cudnnConvolutionDescriptor allocateConvolutionDescriptor(int padding [], int strides []) {
-		cudnnConvolutionDescriptor convDesc = new cudnnConvolutionDescriptor();
-		cudnnCreateConvolutionDescriptor(convDesc);
-		cudnnSetConvolution2dDescriptor(convDesc, padding[0], padding[1], strides[0], strides[1], 1, 1, CUDNN_CROSS_CORRELATION);
-		return convDesc;
-	}
-
-	private static Pointer pointerTo(double value) {
+	
+	protected static Pointer pointerTo(double value) {
 		return Pointer.to(new double[] { value });
 	}
-
-	private static cudnnFilterDescriptor allocateFilterDescriptor(int K, int C, int R, int S) {
-		cudnnFilterDescriptor filterDesc = new cudnnFilterDescriptor();
-		cudnnCreateFilterDescriptor(filterDesc);
-		cudnnSetFilter4dDescriptor(filterDesc, CUDNN_DATA_DOUBLE, CUDNN_TENSOR_NCHW, K, C, R, S);
-		return filterDesc;
-	}
-
-	/**
-	 * allocates pooling descriptor, used in poolingForward and poolingBackward
-	 * @param R			pooling window height
-	 * @param S			pooling window width
-	 * @param pad_h		vertical padding
-	 * @param pad_w		horizontal padding
-	 * @param stride_h	pooling vertical stride
-	 * @param stride_w	pooling horizontal stride
-	 * @return cudnn pooling descriptor
-	 */
-	private static cudnnPoolingDescriptor allocatePoolingDescriptor(int R, int S, int pad_h, int pad_w, int stride_h, int stride_w) {
-		cudnnPoolingDescriptor poolingDesc = new cudnnPoolingDescriptor();
-		cudnnCreatePoolingDescriptor(poolingDesc);
-		cudnnSetPooling2dDescriptor(poolingDesc, CUDNN_POOLING_MAX, CUDNN_PROPAGATE_NAN, R, S, pad_h, pad_w, stride_h, stride_w);
-		return poolingDesc;
-	}
+	
 
 	/**
 	 * This method computes the backpropagation errors for previous layer of relu operation
@@ -669,598 +369,7 @@ public class LibMatrixCUDA {
 				image, bias, output, rows, cols, PQ);
 		if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_RELU_BACKWARD_KERNEL, System.nanoTime() - t1);
 	}
-
-	private static void validateBatchNormalizationDimensions(MatrixObject scale, MatrixObject bias, MatrixObject runningMean, MatrixObject runningVar, int C) throws DMLRuntimeException {
-		if(scale.getNumRows() != 1 || scale.getNumColumns() != C) {
-			throw new DMLRuntimeException("Incorrect dimensions for scale");
-		}
-		if(bias.getNumRows() != 1 || bias.getNumColumns() != C) {
-			throw new DMLRuntimeException("Incorrect dimensions for bias");
-		}
-		if(runningMean.getNumRows() != 1 || runningMean.getNumColumns() != C) {
-			throw new DMLRuntimeException("Incorrect dimensions for running mean");
-		}
-		if(runningVar.getNumRows() != 1 || runningVar.getNumColumns() != C) {
-			throw new DMLRuntimeException("Incorrect dimensions for running variance");
-		}
-	}
-
-	/**
-	 * Performs the forward BatchNormalization layer computation for inference
-	 * @param gCtx   a valid {@link GPUContext}
-	 * @param instName name of the instruction
-	 * @param image input image
-	 * @param scale scale (as per CuDNN) and gamma as per original paper: shape [1, C, 1, 1]
-	 * @param bias bias (as per CuDNN) and beta as per original paper: shape [1, C, 1, 1]
-	 * @param runningMean running mean accumulated during training phase: shape [1, C, 1, 1]
-	 * @param runningVar running variance accumulated during training phase: shape [1, C, 1, 1]
-	 * @param ret normalized input
-	 * @param epsilon epsilon value used in the batch normalization formula
-	 * @throws DMLRuntimeException if error occurs
-	 */
-	public static void batchNormalizationForwardInference(GPUContext gCtx, String instName, MatrixObject image,
-			MatrixObject scale, MatrixObject bias, MatrixObject runningMean, MatrixObject runningVar,
-			MatrixObject ret, double epsilon) throws DMLRuntimeException {
-		LOG.trace("GPU : batchNormalizationForwardInference" + ", GPUContext=" + gCtx);
-		int mode = cudnnBatchNormMode.CUDNN_BATCHNORM_SPATIAL;
-
-		int N = toInt(image.getNumRows());
-		int C = toInt(scale.getNumColumns());
-		long CHW = image.getNumColumns();
-		validateBatchNormalizationDimensions(scale, bias, runningMean, runningVar, C);
-
-		// Allocate descriptors
-		cudnnTensorDescriptor nCHWDescriptor = allocateNCHWDescriptors(gCtx, N, C, CHW,
-				new MatrixObject[] {image},  new MatrixObject[] {ret});
-		cudnnTensorDescriptor scaleTensorDesc = allocateTensorDescriptor(gCtx, scale, 1, C, 1, 1);
-
-		// Get underlying dense pointer
-		Pointer imagePtr = getDensePointer(gCtx, image, true, instName);
-		Pointer retPtr = getDensePointer(gCtx, ret, true, instName);
-		Pointer biasPtr = getDensePointer(gCtx, bias, true, instName);
-		Pointer scalePtr = getDensePointer(gCtx, scale, true, instName);
-		Pointer runningMeanPtr = getDensePointer(gCtx, runningMean, true, instName);
-		Pointer runningVarPtr = getDensePointer(gCtx, runningVar, true, instName);
-
-		checkStatus(cudnnBatchNormalizationForwardInference(getCudnnHandle(gCtx), mode, one(), zero(),
-				nCHWDescriptor, imagePtr, nCHWDescriptor, retPtr,
-				scaleTensorDesc, scalePtr, biasPtr,
-				runningMeanPtr, runningVarPtr, epsilon));
-	}
-
-	/**
-	 * Performs the forward BatchNormalization layer computation for training
-	 * @param gCtx   a valid {@link GPUContext}
-	 * @param instName name of the instruction
-	 * @param image input image
-	 * @param scale scale (as per CuDNN) and gamma as per original paper: shape [1, C, 1, 1]
-	 * @param bias bias (as per CuDNN) and beta as per original paper: shape [1, C, 1, 1]
-	 * @param runningMean running mean accumulated during training phase: shape [1, C, 1, 1]
-	 * @param runningVar running variance accumulated during training phase: shape [1, C, 1, 1]
-	 * @param ret (output) normalized input
-	 * @param retRunningMean (output) running mean accumulated during training phase: shape [1, C, 1, 1]
-	 * @param retRunningVar (output) running variance accumulated during training phase: shape [1, C, 1, 1]
-	 * @param epsilon epsilon value used in the batch normalization formula
-	 * @param exponentialAverageFactor factor used in the moving average computation
-	 * @throws DMLRuntimeException if error occurs
-	 */
-	public static void batchNormalizationForwardTraining(GPUContext gCtx, String instName, MatrixObject image,
-			MatrixObject scale,  MatrixObject bias, MatrixObject runningMean, MatrixObject runningVar,
-			MatrixObject ret, MatrixObject retRunningMean, MatrixObject retRunningVar, double epsilon, double exponentialAverageFactor) throws DMLRuntimeException {
-		LOG.trace("GPU : batchNormalizationForwardTraining" + ", GPUContext=" + gCtx);
-		int mode = cudnnBatchNormMode.CUDNN_BATCHNORM_SPATIAL;
-
-		int N = toInt(image.getNumRows());
-		int C = toInt(scale.getNumColumns());
-		long CHW = image.getNumColumns();
-		validateBatchNormalizationDimensions(scale, bias, runningMean, runningVar, C);
-
-		// Allocate descriptors
-		cudnnTensorDescriptor nCHWDescriptor = allocateNCHWDescriptors(gCtx, N, C, CHW,
-				new MatrixObject[] {image},  new MatrixObject[] {ret});
-		cudnnTensorDescriptor scaleTensorDesc = allocateTensorDescriptor(gCtx, scale, 1, C, 1, 1);
-
-		// Get underlying dense pointer
-		Pointer imagePtr = getDensePointer(gCtx, image, true, instName);
-		Pointer retPtr = getDensePointer(gCtx, ret, true, instName);
-		Pointer biasPtr = getDensePointer(gCtx, bias, true, instName);
-		Pointer scalePtr = getDensePointer(gCtx, scale, true, instName);
-		Pointer runningMeanPtr = getDensePointer(gCtx, runningMean, true, instName);
-		Pointer runningVarPtr = getDensePointer(gCtx, runningVar, true, instName);
-
-		// To allow for copy-on-write
-		Pointer retRunningMeanPtr = getDensePointer(gCtx, retRunningMean, true, instName);
-		Pointer retRunningVarPtr = getDensePointer(gCtx, retRunningVar, true, instName);
-		cudaMemcpy(retRunningMeanPtr, runningMeanPtr, C * Sizeof.DOUBLE, cudaMemcpyDeviceToDevice);
-		cudaMemcpy(retRunningVarPtr, runningVarPtr, C * Sizeof.DOUBLE, cudaMemcpyDeviceToDevice);
-
-		// ignoring resultSaveMean and resultSaveVariance as it requires state management
-		checkStatus(cudnnBatchNormalizationForwardTraining(getCudnnHandle(gCtx), mode, one(), zero(),
-				nCHWDescriptor, imagePtr, nCHWDescriptor, retPtr,
-				scaleTensorDesc, scalePtr, biasPtr, exponentialAverageFactor,
-				retRunningMeanPtr, retRunningVarPtr, epsilon, new Pointer(), new Pointer()));
-	}
-
-	/**
-	 * Convenient utility for batch normalization that returns a NCHW descriptor
-	 * @param gCtx a valid {@link GPUContext}
-	 * @param N number of images
-	 * @param C number of channels
-	 * @param CHW channels*height*width
-	 * @param input input matrix objects
-	 * @param output output matrix objects
-	 * @return one of the NCHW descriptor
-	 * @throws DMLRuntimeException if error occurs
-	 */
-	private static cudnnTensorDescriptor allocateNCHWDescriptors(GPUContext gCtx, int N, int C, long CHW, MatrixObject [] input, MatrixObject [] output) throws DMLRuntimeException {
-		cudnnTensorDescriptor ret  = null; // Return any one
-		if(CHW > ((long)Integer.MAX_VALUE)*C) {
-			throw new DMLRuntimeException("image size (height*width) should be less than " + Integer.MAX_VALUE);
-		}
-		cudnnTensorDescriptor knownNCHWdescriptor = null;
-		int H = -1; int W = -1;
-		for(int i = 0; i < input.length; i++) {
-			knownNCHWdescriptor = input[i].getGPUObject(gCtx).getTensorDescriptor();
-			if(knownNCHWdescriptor != null) {
-				int [] shape = input[i].getGPUObject(gCtx).getTensorShape();
-				if(shape[0] != N || shape[1] != C) {
-					throw new DMLRuntimeException("Incorrect N and C:" + shape[0]  + " != " + N + " || " + shape[1]  + " != " +  C);
-				}
-				H = shape[2];
-				W = shape[3];
-				break;
-			}
-		}
-		if(knownNCHWdescriptor != null) {
-			// We precisely know N, C, H, W
-			for(int i = 0; i < input.length; i++) {
-				ret = allocateTensorDescriptor(gCtx, input[i], N, C, H, W);
-			}
-			for(int i = 0; i < output.length; i++) {
-				ret = allocateTensorDescriptor(gCtx, output[i], N, C, H, W);
-			}
-		}
-		else {
-			int HW = (int) (CHW / C);
-			H = HW; W = 1; // If not known
-			double potentialH = Math.sqrt(HW);
-			if(potentialH == ((int) potentialH)) {
-				H = (int) potentialH;
-				W = H;
-			}
-			// We are not sure about H and W, hence don't allocate them.
-			ret = new cudnnTensorDescriptor();
-			cudnnCreateTensorDescriptor(ret);
-			cudnnSetTensor4dDescriptor(ret, CUDNN_TENSOR_NCHW, CUDNN_DATA_DOUBLE, N, C, H, W);
-		}
-		return ret;
-	}
-
-	/**
-	 * This method computes the backpropagation errors for image, scale and bias of batch normalization layer
-	 * @param gCtx   a valid {@link GPUContext}
-	 * @param instName name of the instruction
-	 * @param image input image
-	 * @param dout input errors of shape C, H, W
-	 * @param scale scale (as per CuDNN) and gamma as per original paper: shape [1, C, 1, 1]
-	 * @param ret (output) backpropagation errors for previous layer
-	 * @param retScale backpropagation error for scale
-	 * @param retBias backpropagation error for bias
-	 * @param epsilon epsilon value used in the batch normalization formula
-	 * @throws DMLRuntimeException if error occurs
-	 */
-	public static void batchNormalizationBackward(GPUContext gCtx, String instName, MatrixObject image, MatrixObject dout,
-			MatrixObject scale, MatrixObject ret, MatrixObject retScale, MatrixObject retBias,
-			double epsilon) throws DMLRuntimeException {
-		LOG.trace("GPU : batchNormalizationBackward" + ", GPUContext=" + gCtx);
-		int mode = cudnnBatchNormMode.CUDNN_BATCHNORM_SPATIAL;
-
-		int N = toInt(image.getNumRows());
-		int C = toInt(scale.getNumColumns());
-		long CHW = image.getNumColumns();
-
-		// Allocate descriptors
-		cudnnTensorDescriptor nCHWDescriptor = allocateNCHWDescriptors(gCtx, N, C, CHW,
-				new MatrixObject[] {image, dout},  new MatrixObject[] {ret});
-		cudnnTensorDescriptor scaleTensorDesc = allocateTensorDescriptor(gCtx, scale, 1, C, 1, 1);
-
-		// Get underlying dense pointer
-		Pointer imagePtr = getDensePointer(gCtx, image, true, instName);
-		Pointer doutPtr = getDensePointer(gCtx, dout, true, instName);
-		Pointer scalePtr = getDensePointer(gCtx, scale, true, instName);
-		Pointer retPtr = getDensePointer(gCtx, ret, true, instName);
-		Pointer retScalePtr = getDensePointer(gCtx, retScale, true, instName);
-		Pointer retBiasPtr = getDensePointer(gCtx, retBias, true, instName);
-
-		// ignoring resultSaveMean and resultSaveVariance as it requires state management
-		checkStatus(cudnnBatchNormalizationBackward(getCudnnHandle(gCtx), mode,  one(), zero(), one(), zero(),
-				nCHWDescriptor,  imagePtr, nCHWDescriptor, doutPtr, nCHWDescriptor, retPtr,
-				scaleTensorDesc, scalePtr, retScalePtr, retBiasPtr, epsilon, new Pointer(), new Pointer()));
-	}
-
-
-	/**
-	 * This method computes the backpropogation errors for filter of convolution operation
-	 * @param gCtx   a valid {@link GPUContext}
-	 * @param instName the invoking instruction's name for record {@link Statistics}.
-	 * @param image input image
-	 * @param dout errors from next layer
-	 * @param outputBlock  output errors
-	 * @param N number of images
-	 * @param C number of channels
-	 * @param H height
-	 * @param W width
-	 * @param K number of filters
-	 * @param R filter height
-	 * @param S filter width
-	 * @param pad_h pad height
-	 * @param pad_w pad width
-	 * @param stride_h stride height
-	 * @param stride_w stride width
-	 * @param P output activation height
-	 * @param Q output activation width
-	 * @throws DMLRuntimeException if DMLRuntimeException occurs
-	 */
-	public static void conv2dBackwardFilter(GPUContext gCtx, String instName, MatrixObject image, MatrixObject dout,
-			MatrixObject outputBlock, int N, int C, int H, int W, int K, int R,
-			int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
-			int Q) throws DMLRuntimeException {
-		LOG.trace("GPU : conv2dBackwardFilter" + ", GPUContext=" + gCtx);
-		cudnnFilterDescriptor dwDesc = null;
-		cudnnConvolutionDescriptor convDesc = null;
-
-		Pointer workSpace = null;
-		long sizeInBytes = 0;
-		try {
-
-			long t1 = 0, t2 = 0;
-			if (GPUStatistics.DISPLAY_STATISTICS) t1 = System.nanoTime();
-			// Allocate descriptors
-			cudnnTensorDescriptor xTensorDesc = allocateTensorDescriptor(gCtx, image, N, C, H, W);
-			cudnnTensorDescriptor doutTensorDesc = allocateTensorDescriptor(gCtx, dout, N, K, P, Q);
-			dwDesc = allocateFilterDescriptor(K, C, R, S);
-
-			// Allocate data
-			Pointer imagePointer = getDensePointer(gCtx, image, true, instName);
-			Pointer doutPointer = getDensePointer(gCtx, dout, true, instName);
-			Pointer dwPointer = getDensePointer(gCtx, outputBlock, true, instName);
-			int padding[] = {pad_h, pad_w};
-			int strides[] = {stride_h, stride_w};
-			convDesc = allocateConvolutionDescriptor(padding, strides);
-			long sizeInBytesArray[] = {0};
-
-			// TODO: Select the best algorithm depending on the data and supported CUDA
-			int algo = jcuda.jcudnn.cudnnConvolutionBwdFilterAlgo.CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
-
-			workSpace = new Pointer();
-			cudnnGetConvolutionBackwardFilterWorkspaceSize(getCudnnHandle(gCtx),
-					xTensorDesc, doutTensorDesc, convDesc, dwDesc, algo, sizeInBytesArray);
-			if (GPUStatistics.DISPLAY_STATISTICS)
-				GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - t1);
-
-			if (GPUStatistics.DISPLAY_STATISTICS) t2 = System.nanoTime();
-			int status = cudnnConvolutionBackwardFilter(getCudnnHandle(gCtx), one(), xTensorDesc, imagePointer,
-					doutTensorDesc, doutPointer, convDesc, algo, workSpace, sizeInBytes, zero(), dwDesc, dwPointer);
-			if (GPUStatistics.DISPLAY_STATISTICS)
-				GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CONVOLUTION_BACKWARD_FILTER_LIB, System.nanoTime() - t2);
-
-			if (status != jcuda.jcudnn.cudnnStatus.CUDNN_STATUS_SUCCESS) {
-				throw new DMLRuntimeException("Could not executed cudnnConvolutionBackwardFilter: " + jcuda.jcudnn.cudnnStatus.stringFor(status));
-			}
-		} catch (CudaException e) {
-			throw new DMLRuntimeException("Error in conv2d in GPUContext " + gCtx.toString() + " from Thread " + Thread.currentThread().toString(), e);
-		} finally {
-			long t3=0;
-			if (GPUStatistics.DISPLAY_STATISTICS) t3 = System.nanoTime();
-
-			if(workSpace != null && sizeInBytes != 0)
-				gCtx.cudaFreeHelper(instName, workSpace);
-			if(dwDesc != null)
-				cudnnDestroyFilterDescriptor(dwDesc);
-
-			if(convDesc != null)
-				cudnnDestroyConvolutionDescriptor(convDesc);
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - t3);
-		}
-	}
-
-	private static long numDoublesIn2GB = 268435456;
-
-	/**
-	 * This method computes the backpropogation errors for previous layer of convolution operation
-	 * @param gCtx   a valid {@link GPUContext}
-	 * @param instName the invoking instruction's name for record {@link Statistics}.
-	 * @param filter filter used in conv2d
-	 * @param dout errors from next layer
-	 * @param output  output errors
-	 * @param N number of images
-	 * @param C number of channels
-	 * @param H height
-	 * @param W width
-	 * @param K number of filters
-	 * @param R filter height
-	 * @param S filter width
-	 * @param pad_h pad height
-	 * @param pad_w pad width
-	 * @param stride_h stride height
-	 * @param stride_w stride width
-	 * @param P output activation height
-	 * @param Q output activation width
-	 * @throws DMLRuntimeException if DMLRuntimeException occurs
-	 */
-	public static void conv2dBackwardData(GPUContext gCtx, String instName, MatrixObject filter, MatrixObject dout,
-			MatrixObject output, int N, int C, int H, int W, int K, int R,
-			int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
-			int Q) throws DMLRuntimeException {
-		LOG.trace("GPU : conv2dBackwardData" + ", GPUContext=" + gCtx);
-		cudnnFilterDescriptor wDesc = null;
-		cudnnConvolutionDescriptor convDesc = null;
-
-		Pointer workSpace = null;
-		long sizeInBytes = 0;
-		try {
-			long t1=0, t2=0;
-			if (GPUStatistics.DISPLAY_STATISTICS) t1 = System.nanoTime();
-			// Allocate descriptors
-			wDesc = allocateFilterDescriptor(K, C, R, S);
-			cudnnTensorDescriptor dyDesc = allocateTensorDescriptor(gCtx, dout, N, K, P, Q);
-			cudnnTensorDescriptor dxDesc = allocateTensorDescriptor(gCtx, output, N, C, H, W);
-
-			// Allocate data
-			Pointer w = getDensePointer(gCtx, filter, true, instName);
-			Pointer dy = getDensePointer(gCtx, dout, true, instName);
-			Pointer dx = getDensePointer(gCtx, output, true, instName);
-
-			int padding [] = { pad_h, pad_w };
-			int strides [] = { stride_h, stride_w };
-			convDesc = allocateConvolutionDescriptor(padding, strides);
-			long sizeInBytesArray[] = { 0 };
-
-			// TODO: Select the best algorithm depending on the data and supported CUDA
-			int algo = jcuda.jcudnn.cudnnConvolutionBwdDataAlgo.CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
-			workSpace = new Pointer();
-			cudnnGetConvolutionBackwardDataWorkspaceSize(getCudnnHandle(gCtx),
-					wDesc, dyDesc, convDesc, dxDesc, algo, sizeInBytesArray);
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - t1);
-
-			if (GPUStatistics.DISPLAY_STATISTICS) t2 = System.nanoTime();
-			int status = cudnnConvolutionBackwardData(getCudnnHandle(gCtx), one(), wDesc, w,
-					dyDesc, dy, convDesc, algo, workSpace, sizeInBytes, zero(), dxDesc, dx);
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CONVOLUTION_BACKWARD_DATA_LIB, System.nanoTime() - t2);
-
-			if(status != jcuda.jcudnn.cudnnStatus.CUDNN_STATUS_SUCCESS) {
-				throw new DMLRuntimeException("Could not executed cudnnConvolutionBackwardData: " + jcuda.jcudnn.cudnnStatus.stringFor(status));
-			}
-		} catch (CudaException e) {
-			throw new DMLRuntimeException("Error in conv2d in GPUContext " + gCtx.toString() + " from Thread " + Thread.currentThread().toString(), e);
-		}
-		finally {
-			long t3=0;
-			if (GPUStatistics.DISPLAY_STATISTICS) t3 = System.nanoTime();
-
-			if(workSpace != null && sizeInBytes != 0)
-				gCtx.cudaFreeHelper(instName, workSpace);
-			if(wDesc != null)
-				cudnnDestroyFilterDescriptor(wDesc);
-			if(convDesc != null)
-				cudnnDestroyConvolutionDescriptor(convDesc);
-
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - t3);
-		}
-	}
-
-	/**
-	 * performs maxpooling on GPU by exploiting cudnnPoolingForward(...)
-	 * @param gCtx   a valid {@link GPUContext}
-	 * @param instName the invoking instruction's name for record {@link Statistics}.
-	 * @param image image as matrix object
-	 * @param outputBlock output matrix
-	 * @param N				batch size
-	 * @param C				number of channels
-	 * @param H				height of image
-	 * @param W				width of image
-	 * @param K				number of filters
-	 * @param R				height of filter
-	 * @param S				width of filter
-	 * @param pad_h			vertical padding
-	 * @param pad_w			horizontal padding
-	 * @param stride_h		horizontal stride
-	 * @param stride_w		vertical stride
-	 * @param P				(H - R + 1 + 2*pad_h)/stride_h
-	 * @param Q				(W - S + 1 + 2*pad_w)/stride_w
-	 * @throws DMLRuntimeException if DMLRuntimeException occurs
-	 */
-	public static void maxpooling(GPUContext gCtx, String instName, MatrixObject image,
-			MatrixObject outputBlock, int N, int C, int H, int W, int K, int R,
-			int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
-			int Q) throws DMLRuntimeException {
-		Pointer x = getDensePointer(gCtx, image, true, instName);
-		cudnnTensorDescriptor xDesc = allocateTensorDescriptor(gCtx, image, N, C, H, W);
-		performMaxpooling(gCtx, instName, x, xDesc, outputBlock, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
-	}
-
-	public static void performMaxpooling(GPUContext gCtx, String instName, Pointer x, cudnnTensorDescriptor xDesc,
-			MatrixObject outputBlock, int N, int C, int H, int W, int K, int R,
-			int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
-			int Q) throws DMLRuntimeException {
-		LOG.trace("GPU : performMaxpooling" + ", GPUContext=" + gCtx);
-		Pointer y = getDensePointer(gCtx, outputBlock, true, instName);
-		cudnnPoolingDescriptor poolingDesc = null;
-
-		try {
-			long t1=0,t2=0;
-			if (GPUStatistics.DISPLAY_STATISTICS) t1 = System.nanoTime();
-			// Allocate descriptors
-			cudnnTensorDescriptor yDesc = allocateTensorDescriptor(gCtx, outputBlock, N, C, P, Q);
-			poolingDesc = allocatePoolingDescriptor(R, S, pad_h, pad_w, stride_h, stride_w);
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - t1);
-
-			if (GPUStatistics.DISPLAY_STATISTICS) t2 = System.nanoTime();
-			int status = cudnnPoolingForward(getCudnnHandle(gCtx), poolingDesc, one(), xDesc, x, zero(), yDesc, y);
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_MAXPOOLING_FORWARD_LIB, System.nanoTime() - t2);
-
-			if(status != jcuda.jcudnn.cudnnStatus.CUDNN_STATUS_SUCCESS) {
-				throw new DMLRuntimeException("Could not executed cudnnPoolingForward: " + jcuda.jcudnn.cudnnStatus.stringFor(status));
-			}
-		} catch (CudaException e) {
-			throw new DMLRuntimeException("Error in conv2d in GPUContext " + gCtx.toString() + " from Thread " + Thread.currentThread().toString(), e);
-		}
-		finally {
-			long t3=0;
-			if (GPUStatistics.DISPLAY_STATISTICS) t3 = System.nanoTime();
-			if(poolingDesc != null)
-				cudnnDestroyPoolingDescriptor(poolingDesc);
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - t3);
-		}
-	}
-
-	/**
-	 * Performs maxpoolingBackward on GPU by exploiting cudnnPoolingBackward(...)
-	 * This method computes the backpropogation errors for previous layer of maxpooling operation
-	 * @param gCtx   a valid {@link GPUContext}
-	 * @param instName the invoking instruction's name for record {@link Statistics}.
-	 * @param image image as matrix object
-	 * @param dout			delta matrix, output of previous layer
-	 * @param outputBlock output matrix
-	 * @param N				batch size
-	 * @param C				number of channels
-	 * @param H				height of image
-	 * @param W				width of image
-	 * @param K				number of filters
-	 * @param R				height of filter
-	 * @param S				width of filter
-	 * @param pad_h			vertical padding
-	 * @param pad_w			horizontal padding
-	 * @param stride_h		horizontal stride
-	 * @param stride_w		vertical stride
-	 * @param P				(H - R + 1 + 2*pad_h)/stride_h
-	 * @param Q				(W - S + 1 + 2*pad_w)/stride_w
-	 * @throws DMLRuntimeException if DMLRuntimeException occurs
-	 */
-	public static void maxpoolingBackward(GPUContext gCtx, String instName, MatrixObject image, MatrixObject dout,
-			MatrixObject outputBlock, int N, int C, int H, int W, int K, int R,
-			int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
-			int Q) throws DMLRuntimeException {
-		LOG.trace("GPU : maxpoolingBackward" + ", GPUContext=" + gCtx);
-		Pointer y = null;
-		cudnnPoolingDescriptor poolingDesc = null;
-
-		try {
-			long t1=0, t2=0, t3=0;
-			if (GPUStatistics.DISPLAY_STATISTICS) t1 = System.nanoTime();
-			// Allocate descriptors
-			cudnnTensorDescriptor xDesc = allocateTensorDescriptor(gCtx, image, N, C, H, W);
-			cudnnTensorDescriptor yDesc = allocateTensorDescriptor(gCtx, dout, N, C, P, Q);
-			cudnnTensorDescriptor dxDesc = allocateTensorDescriptor(gCtx, outputBlock, N, C, H, W);
-			cudnnTensorDescriptor dyDesc = allocateTensorDescriptor(gCtx, dout, N, C, P, Q);
-
-			poolingDesc = allocatePoolingDescriptor(R, S, pad_h, pad_w, stride_h, stride_w);
-
-			// Calling PoolForward first, y is one of the inputs for poolBackward
-			// TODO: Remove calling poolForward after necessary changes at language level for poolBackward
-			long numBytes = N*C*P*Q*Sizeof.DOUBLE;
-			y = gCtx.allocate(numBytes);
-
-			// Allocate data
-			Pointer x = getDensePointer(gCtx, image, true, instName);
-			Pointer dx = getDensePointer(gCtx, outputBlock, true, instName);
-			Pointer dy = getDensePointer(gCtx, dout, true, instName);
-
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - t1);
-
-			if (GPUStatistics.DISPLAY_STATISTICS) t2 = System.nanoTime();
-			int status = cudnnPoolingForward(getCudnnHandle(gCtx), poolingDesc, one(), xDesc, x, zero(), yDesc, y);
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_MAXPOOLING_FORWARD_LIB, System.nanoTime() - t2);
-
-			if(status != jcuda.jcudnn.cudnnStatus.CUDNN_STATUS_SUCCESS) {
-				throw new DMLRuntimeException("Could not executed cudnnPoolingForward before cudnnPoolingBackward: " + jcuda.jcudnn.cudnnStatus.stringFor(status));
-			}
-
-			if (GPUStatistics.DISPLAY_STATISTICS) t3 = System.nanoTime();
-			status = cudnnPoolingBackward(getCudnnHandle(gCtx), poolingDesc, one(), yDesc, y, dyDesc, dy, xDesc, x, zero(), dxDesc, dx);
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_MAXPOOLING_BACKWARD_LIB, System.nanoTime() - t3);
-
-			if(status != jcuda.jcudnn.cudnnStatus.CUDNN_STATUS_SUCCESS) {
-				throw new DMLRuntimeException("Could not executed cudnnPoolingBackward: " + jcuda.jcudnn.cudnnStatus.stringFor(status));
-			}
-		} catch (CudaException e) {
-			throw new DMLRuntimeException("Error in conv2d in GPUContext " + gCtx.toString() + " from Thread " + Thread.currentThread().toString(), e);
-		}
-		finally {
-			long t4=0;
-			if (GPUStatistics.DISPLAY_STATISTICS) t4 = System.nanoTime();
-
-			if(y != null)
-				gCtx.cudaFreeHelper(instName, y);
-			if(poolingDesc != null)
-				cudnnDestroyPoolingDescriptor(poolingDesc);
-
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - t4);
-		}
-	}
-
-	private static void performCuDNNReLU(GPUContext gCtx, String instName, MatrixObject in, Pointer dstData, cudnnTensorDescriptor srcTensorDesc) throws DMLRuntimeException {
-		long t0=0;
-		try {
-			LOG.trace("GPU : performCuDNNReLU" + ", GPUContext=" + gCtx);
-			cudnnTensorDescriptor dstTensorDesc = srcTensorDesc;
-
-			Pointer srcData = getDensePointer(gCtx, in, true, instName);
-			cudnnActivationDescriptor activationDescriptor = new cudnnActivationDescriptor();
-			cudnnCreateActivationDescriptor(activationDescriptor);
-			double dummy = -1;
-			cudnnSetActivationDescriptor(activationDescriptor, CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, dummy);
-			if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
-			cudnnActivationForward(getCudnnHandle(gCtx), activationDescriptor,
-					one(), srcTensorDesc, srcData,
-					zero(), dstTensorDesc, dstData);
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_ACTIVATION_FORWARD_LIB, System.nanoTime() - t0);
-		} catch (CudaException e) {
-			throw new DMLRuntimeException("Error in conv2d in GPUContext " + gCtx.toString() + " from Thread " + Thread.currentThread().toString(), e);
-		}
-		finally {
-			long t1=0;
-			if (GPUStatistics.DISPLAY_STATISTICS) t1 = System.nanoTime();
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - t1);
-		}
-	}
-
-
-	/**
-	 * Performs the relu operation on the GPU.
-	 * @param ec currently active {@link ExecutionContext}
-	 * @param gCtx   a valid {@link GPUContext}
-	 * @param instName the invoking instruction's name for record {@link Statistics}.
-	 * @param in input matrix
-	 * @param outputName	name of the output matrix
-	 * @throws DMLRuntimeException	if an error occurs
-	 */
-	public static void relu(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject in, String outputName) throws DMLRuntimeException {
-		if (ec.getGPUContext(0) != gCtx)
-			throw new DMLRuntimeException("GPU : Invalid internal state, the GPUContext set with the ExecutionContext is not the same used to run this LibMatrixCUDA function");
-		long N = in.getNumRows();
-		long CHW = in.getNumColumns();
-		MatrixObject output = ec.getMatrixObject(outputName);
-		getDenseMatrixOutputForGPUInstruction(ec, instName, outputName, in.getNumRows(), in.getNumColumns()); // Allocated the dense output matrix
-		long t0=0;
-		cudnnTensorDescriptor srcTensorDesc = in.getGPUObject(gCtx).getTensorDescriptor();
-		if(N*CHW >= numDoublesIn2GB ||  srcTensorDesc == null) {
-			LOG.trace("GPU : relu custom kernel" + ", GPUContext=" + gCtx);
-			// Invokes relu(double* A,  double* ret, int rlen, int clen)
-			if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
-			Pointer dstData = getDensePointer(gCtx, output, instName);
-			Pointer srcData = getDensePointer(gCtx, in, instName); // TODO: FIXME: Add sparse kernel support for relu
-			getCudaKernels(gCtx).launchKernel("relu",
-					ExecutionConfig.getConfigForSimpleMatrixOperations(toInt(N), toInt(CHW)),
-					srcData, dstData, toInt(N), toInt(CHW));
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_RELU_KERNEL, System.nanoTime() - t0);
-		}
-		else {
-			performCuDNNReLU(gCtx, instName, in, getDensePointer(gCtx, output, true, instName), srcTensorDesc);
-		}
-	}
-
-
+	
 
 	//********************************************************************/
 	//************* End of DEEP LEARNING Operators ***********************/
@@ -2814,28 +1923,6 @@ public class LibMatrixCUDA {
 		deviceCopy(instName, srcPtr, destPtr, (int)src.getNumRows(), (int)src.getNumColumns());
 	}
 
-	@SuppressWarnings("unused")
-	private static void compareAndSet(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject in, String outputName, double compareVal,  double tolerance,
-			double ifEqualsVal, double ifLessThanVal, double ifGreaterThanVal) throws DMLRuntimeException {
-		if (ec.getGPUContext(0) != gCtx)
-			throw new DMLRuntimeException("GPU : Invalid internal state, the GPUContext set with the ExecutionContext is not the same used to run this LibMatrixCUDA function");
-		Pointer A = getDensePointer(gCtx, in, instName); // TODO: FIXME: Implement sparse kernel
-		MatrixObject out = ec.getMatrixObject(outputName);
-		int rlen = toInt(out.getNumRows());
-		int clen = toInt(out.getNumColumns());
-		getDenseMatrixOutputForGPUInstruction(ec, instName, outputName, rlen, clen);	// Allocated the dense output matrix
-		Pointer ret = getDensePointer(gCtx, out, instName);
-
-		// out.getMatrixCharacteristics().setNonZeros(rlen*clen);
-		// compareAndSet(double* A,  double* ret, int rlen, int clen, double compareVal, double ifEqualsVal, double ifNotEqualsVal)
-		long t0=0;
-		if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
-		getCudaKernels(gCtx).launchKernel("compare_and_set",
-				ExecutionConfig.getConfigForSimpleMatrixOperations(rlen, clen),
-				A, ret, rlen, clen, compareVal, tolerance, ifEqualsVal, ifLessThanVal, ifGreaterThanVal);
-		if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_COMPARE_AND_SET_KERNEL, System.nanoTime() - t0);
-	}
-
 	/**
 	 * Fills an an array on the GPU with a given scalar value
 	 * @param ec					currently active instance of the {@link ExecutionContext}
@@ -3075,7 +2162,7 @@ public class LibMatrixCUDA {
 	//******************* End of Re-org Functions ************************/
 	//********************************************************************/
 
-	private static int toInt(long num) throws DMLRuntimeException {
+	protected static int toInt(long num) throws DMLRuntimeException {
 		if(num >= Integer.MAX_VALUE || num <= Integer.MIN_VALUE) {
 			throw new DMLRuntimeException("GPU : Exceeded supported size " + num);
 		}
@@ -3115,21 +2202,13 @@ public class LibMatrixCUDA {
 					+ in1.getNumColumns() + "]");
 		}
 
-		int len1 = toInt(in1.getNumColumns());
-		int len2 = toInt(ec.getMatrixObject(outputName).getNumColumns());
+		
 		if(isInSparseFormat(gCtx, in1)) {
 			// Input in1 is in sparse format and output is in dense format
 			MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, instName, outputName, ru - rl + 1, cu - cl + 1);
 			CSRPointer inPointer = getSparsePointer(gCtx, in1, instName);
 			Pointer outPointer = getDensePointer(gCtx, out, instName);
-			int size = ru - rl + 1;
-			long t0 = GPUStatistics.DISPLAY_STATISTICS ? System.nanoTime() : 0;
-			// Performs a slice operation where the input matrix is sparse and the output matrix is dense.
-			// This function avoids unnecessary sparse to dense conversion of the input matrix.
-			// We can generalize this later to output sparse matrix.
-			getCudaKernels(gCtx).launchKernel("slice_sparse_dense", ExecutionConfig.getConfigForSimpleVectorOperations(size),
-					inPointer.val, inPointer.rowPtr, inPointer.colInd, outPointer, rl, ru, cl, cu);
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_RIX_SPARSE_DENSE_OP, System.nanoTime() - t0);
+			sliceSparseDense(gCtx, instName, inPointer, outPointer, rl, ru, cl, cu);
 		}
 		else {
 			// Input in1 is in dense format (see inPointer)
@@ -3137,18 +2216,64 @@ public class LibMatrixCUDA {
 
 			Pointer inPointer = getDensePointer(gCtx, in1, instName);
 			Pointer outPointer = getDensePointer(gCtx, out, instName);
-			long t0 = GPUStatistics.DISPLAY_STATISTICS ? System.nanoTime() : 0;
-			if (len1 == len2) {
-				cudaMemcpy(outPointer, inPointer.withByteOffset(rl * len1 * Sizeof.DOUBLE), (ru - rl + 1) * len1
-						* Sizeof.DOUBLE, cudaMemcpyDeviceToDevice);
-			} else {
-				for (int i = rl, ix1 = rl * len1 + cl, ix2 = 0; i <= ru; i++, ix1 += len1, ix2 += len2) {
-					cudaMemcpy(outPointer.withByteOffset(ix2 * Sizeof.DOUBLE),
-							inPointer.withByteOffset(ix1 * Sizeof.DOUBLE), len2 * Sizeof.DOUBLE, cudaMemcpyDeviceToDevice);
-				}
+			int len1 = toInt(in1.getNumColumns());
+			int len2 = toInt(ec.getMatrixObject(outputName).getNumColumns());
+			sliceDenseDense(gCtx, instName, inPointer, outPointer, rl, ru, cl, cu, len1, len2);
+		}
+	}
+	
+	/**
+	 * Perform slice operation on dense input and output it in dense format
+	 * 
+	 * @param gCtx gpu context
+	 * @param instName instruction name
+	 * @param inPointer dense input pointer
+	 * @param outPointer dense output pointer (doesnot need to be zeroed out)
+	 * @param rl row lower
+	 * @param ru row upper
+	 * @param cl column lower
+	 * @param cu column upper
+	 * @param len1 input number of columns
+	 * @param len2 output number of columns
+	 * @throws DMLRuntimeException
+	 */
+	protected static void sliceDenseDense(GPUContext gCtx, String instName, Pointer inPointer, Pointer outPointer, 
+			int rl, int ru, int cl, int cu, int len1, int len2) throws DMLRuntimeException {
+		long t0 = GPUStatistics.DISPLAY_STATISTICS ? System.nanoTime() : 0;
+		if (len1 == len2) {
+			cudaMemcpy(outPointer, inPointer.withByteOffset(rl * len1 * Sizeof.DOUBLE), (ru - rl + 1) * len1
+					* Sizeof.DOUBLE, cudaMemcpyDeviceToDevice);
+		} else {
+			for (int i = rl, ix1 = rl * len1 + cl, ix2 = 0; i <= ru; i++, ix1 += len1, ix2 += len2) {
+				cudaMemcpy(outPointer.withByteOffset(ix2 * Sizeof.DOUBLE),
+						inPointer.withByteOffset(ix1 * Sizeof.DOUBLE), len2 * Sizeof.DOUBLE, cudaMemcpyDeviceToDevice);
 			}
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_RIX_DENSE_OP, System.nanoTime() - t0);
 		}
+		if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_RIX_DENSE_OP, System.nanoTime() - t0);
+	}
+	
+	/**
+	 * Perform slice operation on sparse input and output it in dense format
+	 * 
+	 * @param gCtx gpu context
+	 * @param instName instruction name
+	 * @param inPointer sparse CSR input pointer
+	 * @param outPointer dense output pointer (expected to be zeroed out)
+	 * @param rl row lower
+	 * @param ru row upper
+	 * @param cl column lower
+	 * @param cu column upper
+	 * @throws DMLRuntimeException
+	 */
+	protected static void sliceSparseDense(GPUContext gCtx, String instName, CSRPointer inPointer, Pointer outPointer, int rl, int ru, int cl, int cu) throws DMLRuntimeException {
+		int size = ru - rl + 1;
+		long t0 = GPUStatistics.DISPLAY_STATISTICS ? System.nanoTime() : 0;
+		// Performs a slice operation where the input matrix is sparse and the output matrix is dense.
+		// This function avoids unnecessary sparse to dense conversion of the input matrix.
+		// We can generalize this later to output sparse matrix.
+		getCudaKernels(gCtx).launchKernel("slice_sparse_dense", ExecutionConfig.getConfigForSimpleVectorOperations(size),
+				inPointer.val, inPointer.rowPtr, inPointer.colInd, outPointer, rl, ru, cl, cu);
+		if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_RIX_SPARSE_DENSE_OP, System.nanoTime() - t0);
 	}
 
 	public static void cbind(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject in1, MatrixObject in2, String outputName) throws DMLRuntimeException {
@@ -3650,26 +2775,6 @@ public class LibMatrixCUDA {
 	//********************************************************************/
 
 	/**
-	 * Convenience method for debugging matrices on the GPU.
-	 * @param in		Pointer to a double array (matrix) on the GPU
-	 * @param rlen	row length
-	 * @param clen	column length
-	 */
-	@SuppressWarnings("unused")
-	private static void debugPrintMatrix(Pointer in, int rlen, int clen){
-		double[] data = new double[rlen * clen];
-		cudaMemcpy(Pointer.to(data), in, rlen*clen*Sizeof.DOUBLE, cudaMemcpyDeviceToHost);
-		int k=0;
-		for (int i=0; i<rlen; ++i){
-			for (int j=0; j<clen; ++j){
-				System.out.print(data[k]);
-				k++;
-			}
-			System.out.println();
-		}
-	}
-
-	/**
 	 * Helper method to get the output block (allocated on the GPU)
 	 * Also records performance information into {@link Statistics}
 	 * @param ec		active {@link ExecutionContext}
@@ -3680,7 +2785,7 @@ public class LibMatrixCUDA {
 	 * @return	the matrix object
 	 * @throws DMLRuntimeException	if an error occurs
 	 */
-	private static MatrixObject getDenseMatrixOutputForGPUInstruction(ExecutionContext ec, String instName, String name, long numRows, long numCols) throws DMLRuntimeException {
+	protected static MatrixObject getDenseMatrixOutputForGPUInstruction(ExecutionContext ec, String instName, String name, long numRows, long numCols) throws DMLRuntimeException {
 		long t0=0;
 		if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
 		Pair<MatrixObject, Boolean> mb = ec.getDenseMatrixOutputForGPUInstruction(name, numRows, numCols);