You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2017/10/11 22:52:47 UTC

systemml git commit: [SYSTEMML-445] Refactoring to avoid potential memory leaks

Repository: systemml
Updated Branches:
  refs/heads/master 8f786aa22 -> 96ae6c7eb


[SYSTEMML-445] Refactoring to avoid potential memory leaks

- Removed tensor descriptors from GPUObject
- Created a closeable LibMatrixCuDNNPoolingDescriptors class to manage the cuDNN descriptors (pooling and tensor descriptors) required by max pooling, both forward and backward
- Enabled JCuda exceptions to catch CUDA errors eagerly
- Added debug-level messages to the eviction logic. They are printed only when debug logging is enabled (guarded by LOG.isDebugEnabled()) to avoid additional overhead
- Removed unused batch normalization methods from LibMatrixCuDNN

Closes #679.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/96ae6c7e
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/96ae6c7e
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/96ae6c7e

Branch: refs/heads/master
Commit: 96ae6c7eb34e792f9fe7c3a8b37c9130fb0ea7ae
Parents: 8f786aa
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Wed Oct 11 15:47:59 2017 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Wed Oct 11 15:49:14 2017 -0700

----------------------------------------------------------------------
 .../instructions/gpu/context/GPUContext.java    |  18 +-
 .../instructions/gpu/context/GPUObject.java     |  75 +---
 .../runtime/matrix/data/LibMatrixCuDNN.java     | 363 ++-----------------
 .../LibMatrixCuDNNConvolutionAlgorithm.java     |  56 ++-
 .../data/LibMatrixCuDNNPoolingDescriptors.java  | 164 +++++++++
 5 files changed, 251 insertions(+), 425 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/96ae6c7e/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
index 118602b..55cb95f 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
@@ -401,8 +401,12 @@ public class GPUContext {
 	 */
 	public void cudaFreeHelper(String instructionName, final Pointer toFree, boolean eager) {
 		Pointer dummy = new Pointer();
-		if (toFree == dummy) // trying to free a null pointer
+		if (toFree == dummy) { // trying to free a null pointer
+			if (LOG.isTraceEnabled()) {
+				LOG.trace("GPU : trying to free an empty pointer");
+			}
 			return;
+		}
 		long t0 = 0;
 		if (!cudaBlockSizeMap.containsKey(toFree))
 			throw new RuntimeException(
@@ -410,7 +414,7 @@ public class GPUContext {
 		long size = cudaBlockSizeMap.get(toFree);
 		if (eager) {
 			if (LOG.isTraceEnabled()) {
-				LOG.trace("GPU : eagerly freeing cuda memory [ " + toFree + " ] for instruction " + instructionName
+				LOG.trace("GPU : eagerly freeing cuda memory [ " + toFree + " ] of size " + size + " for instruction " + instructionName
 						+ " on " + this);
 			}
 			if (DMLScript.STATISTICS)
@@ -426,7 +430,7 @@ public class GPUContext {
 						System.nanoTime() - t0);
 		} else {
 			if (LOG.isTraceEnabled()) {
-				LOG.trace("GPU : lazily freeing cuda memory for instruction " + instructionName + " on " + this);
+				LOG.trace("GPU : lazily freeing cuda memory of size " + size + " for instruction " + instructionName + " on " + this);
 			}
 			Set<Pointer> freeList = freeCUDASpaceMap.get(size);
 			if (freeList == null) {
@@ -492,6 +496,10 @@ public class GPUContext {
 			LOG.trace("GPU : evict called from " + instructionName + " for size " + neededSize + " on " + this);
 		}
 		GPUStatistics.cudaEvictionCount.add(1);
+		if (LOG.isDebugEnabled()) {
+			printMemoryInfo("EVICTION_CUDA_FREE_SPACE");
+		}
+		
 		// Release the set of free blocks maintained in a GPUObject.freeCUDASpaceMap
 		// to free up space
 		LRUCacheMap<Long, Set<Pointer>> lruCacheMap = freeCUDASpaceMap;
@@ -560,6 +568,9 @@ public class GPUContext {
 		});
 
 		while (neededSize > getAvailableMemory() && allocatedGPUObjects.size() > 0) {
+			if (LOG.isDebugEnabled()) {
+				printMemoryInfo("EVICTION_UNLOCKED");
+			}
 			GPUObject toBeRemoved = allocatedGPUObjects.get(allocatedGPUObjects.size() - 1);
 			if (toBeRemoved.isLocked()) {
 				throw new DMLRuntimeException(
@@ -569,7 +580,6 @@ public class GPUContext {
 			if (toBeRemoved.dirty) {
 				toBeRemoved.copyFromDeviceToHost();
 			}
-
 			toBeRemoved.clearData(true);
 		}
 	}

http://git-wip-us.apache.org/repos/asf/systemml/blob/96ae6c7e/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
index 31bf151..feb34bc 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
@@ -19,11 +19,6 @@
 package org.apache.sysml.runtime.instructions.gpu.context;
 
 import static jcuda.jcublas.cublasOperation.CUBLAS_OP_T;
-import static jcuda.jcudnn.JCudnn.cudnnCreateTensorDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnDestroyTensorDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnSetTensor4dDescriptor;
-import static jcuda.jcudnn.cudnnDataType.CUDNN_DATA_DOUBLE;
-import static jcuda.jcudnn.cudnnTensorFormat.CUDNN_TENSOR_NCHW;
 import static jcuda.jcusparse.JCusparse.cusparseDdense2csr;
 import static jcuda.jcusparse.JCusparse.cusparseDnnz;
 import static jcuda.runtime.JCuda.cudaMemcpy;
@@ -32,7 +27,6 @@ import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToDevice;
 import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToHost;
 import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyHostToDevice;
 
-import java.util.Arrays;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.LongAdder;
 
@@ -55,7 +49,6 @@ import org.apache.sysml.utils.GPUStatistics;
 import jcuda.Pointer;
 import jcuda.Sizeof;
 import jcuda.jcublas.JCublas2;
-import jcuda.jcudnn.cudnnTensorDescriptor;
 import jcuda.jcusparse.JCusparse;
 import jcuda.jcusparse.cusparseDirection;
 import jcuda.jcusparse.cusparseHandle;
@@ -84,17 +77,6 @@ public class GPUObject {
 	private CSRPointer jcudaSparseMatrixPtr = null;
 
 	/**
-	 * An optional tensor descriptor (and shape) that can be set by a tensor instruction such as convolution,
-	 * maxpooling and exploited by a subsequent non-tensor instruction such as relu
-	 */
-	private cudnnTensorDescriptor tensorDescriptor = null;
-
-	/**
-	 * the shape of this tensor, if in fact this is a tensor
-	 */
-	private int[] tensorShape = null;
-
-	/**
 	 * whether the block attached to this {@link GPUContext} is dirty on the device and needs to be copied back to host
 	 */
 	protected boolean dirty = false;
@@ -132,13 +114,7 @@ public class GPUObject {
 	public Object clone() {
 		GPUObject me = this;
 		GPUObject that = new GPUObject(me.gpuContext, me.mat);
-		if (me.tensorShape != null) {
-			that.tensorShape = new int[me.tensorShape.length];
-			System.arraycopy(me.tensorShape, 0, that.tensorShape, 0, me.tensorShape.length);
-			that.allocateTensorDescriptor(me.tensorShape[0], me.tensorShape[1], me.tensorShape[2], me.tensorShape[3]);
-		}
 		that.dirty = me.dirty;
-		// TODO Nakul: Should the locks be cloned here ?
 		// The only place clone is getting called: LibMatrixCUDA's solve
 		that.readLocks.reset();
 		that.writeLock = false;
@@ -498,51 +474,7 @@ public class GPUObject {
 	public boolean isSparse() {
 		return isSparse;
 	}
-
-	/**
-	 * Returns a previously allocated tensor shape or null
-	 *
-	 * @return int array of four elements or null
-	 */
-	public int[] getTensorShape() {
-		return tensorShape;
-	}
-
-	/**
-	 * Returns a previously allocated tensor descriptor or null
-	 *
-	 * @return cudnn tensor descriptor
-	 */
-	public cudnnTensorDescriptor getTensorDescriptor() {
-		return tensorDescriptor;
-	}
-
-	/**
-	 * Returns a previously allocated or allocates and returns a tensor descriptor
-	 *
-	 * @param N number of images
-	 * @param C number of channels
-	 * @param H height
-	 * @param W width
-	 * @return cudnn tensor descriptor
-	 */
-	public cudnnTensorDescriptor allocateTensorDescriptor(int N, int C, int H, int W) {
-		if(LOG.isTraceEnabled()) {
-			LOG.trace("GPU : allocateTensorDescriptor with [N=" + N + ",C=" + C + ",H=" + H + ",W=" + W + "] on " + this);
-		}
-		if (tensorDescriptor == null) {
-			tensorDescriptor = new cudnnTensorDescriptor();
-			cudnnCreateTensorDescriptor(tensorDescriptor);
-			cudnnSetTensor4dDescriptor(tensorDescriptor, CUDNN_TENSOR_NCHW, CUDNN_DATA_DOUBLE, N, C, H, W);
-			tensorShape = new int[4];
-			tensorShape[0] = N;
-			tensorShape[1] = C;
-			tensorShape[2] = H;
-			tensorShape[3] = W;
-		}
-		return tensorDescriptor;
-	}
-
+	
 	private static long getDoubleSizeOf(long numElems) {
 		return numElems * ((long) jcuda.Sizeof.DOUBLE);
 	}
@@ -829,10 +761,6 @@ public class GPUObject {
 		}
 		jcudaDenseMatrixPtr = null;
 		jcudaSparseMatrixPtr = null;
-		if (tensorDescriptor != null) {
-			cudnnDestroyTensorDescriptor(tensorDescriptor);
-			tensorDescriptor = null;
-		}
 		resetReadWriteLock();
 		getGPUContext().removeRecordedUsage(this);
 	}
@@ -1094,7 +1022,6 @@ public class GPUObject {
 	@Override
 	public String toString() {
 		final StringBuilder sb = new StringBuilder("GPUObject{");
-		sb.append(", tensorShape=").append(Arrays.toString(tensorShape));
 		sb.append(", dirty=").append(dirty);
 		sb.append(", readLocks=").append(readLocks.longValue());
 		sb.append(", writeLock=").append(writeLock);

http://git-wip-us.apache.org/repos/asf/systemml/blob/96ae6c7e/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
index 25dc604..bb74aa2 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
@@ -19,43 +19,27 @@
 package org.apache.sysml.runtime.matrix.data;
 
 import static jcuda.jcudnn.JCudnn.cudnnActivationForward;
-import static jcuda.jcudnn.JCudnn.cudnnBatchNormalizationBackward;
-import static jcuda.jcudnn.JCudnn.cudnnBatchNormalizationForwardInference;
-import static jcuda.jcudnn.JCudnn.cudnnBatchNormalizationForwardTraining;
 import static jcuda.jcudnn.JCudnn.cudnnConvolutionBackwardData;
 import static jcuda.jcudnn.JCudnn.cudnnConvolutionBackwardFilter;
 import static jcuda.jcudnn.JCudnn.cudnnConvolutionForward;
 import static jcuda.jcudnn.JCudnn.cudnnCreateActivationDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnCreateConvolutionDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnCreateFilterDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnCreatePoolingDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnCreateTensorDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnDestroyTensorDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnPoolingBackward;
 import static jcuda.jcudnn.JCudnn.cudnnPoolingForward;
 import static jcuda.jcudnn.JCudnn.cudnnSetActivationDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnSetConvolution2dDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnSetFilter4dDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnSetPooling2dDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnSetTensor4dDescriptor;
 import static jcuda.jcudnn.cudnnActivationMode.CUDNN_ACTIVATION_RELU;
-import static jcuda.jcudnn.cudnnConvolutionMode.CUDNN_CROSS_CORRELATION;
 import static jcuda.jcudnn.cudnnDataType.CUDNN_DATA_DOUBLE;
 import static jcuda.jcudnn.cudnnNanPropagation.CUDNN_PROPAGATE_NAN;
-import static jcuda.jcudnn.cudnnPoolingMode.CUDNN_POOLING_MAX;
 import static jcuda.jcudnn.cudnnTensorFormat.CUDNN_TENSOR_NCHW;
-import static jcuda.runtime.JCuda.cudaMemcpy;
 import static jcuda.runtime.JCuda.cudaMemset;
-import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToDevice;
 import jcuda.CudaException;
 import jcuda.Pointer;
 import jcuda.Sizeof;
 import jcuda.jcudnn.cudnnActivationDescriptor;
-import jcuda.jcudnn.cudnnBatchNormMode;
-import jcuda.jcudnn.cudnnConvolutionDescriptor;
 import jcuda.jcudnn.cudnnConvolutionFwdPreference;
-import jcuda.jcudnn.cudnnFilterDescriptor;
 import jcuda.jcudnn.cudnnHandle;
-import jcuda.jcudnn.cudnnPoolingDescriptor;
 import jcuda.jcudnn.cudnnStatus;
 import jcuda.jcudnn.cudnnTensorDescriptor;
 
@@ -115,6 +99,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 		//cudaDeviceSynchronize;
 		biasAdd(gCtx, instName, output, bias, output);
 	}
+	
 
 	/**
 	 * Performs a 2D convolution
@@ -145,7 +130,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 
 		long CHW = C*H*W; long KPQ = K*P*Q; long CRS = C*R*S; 
 		long NCHW = N*CHW; long NKPQ = N*KPQ; long KCRS = K*CRS;
-
+		
 		if(NCHW < maxNumDoublesOfCuDNNTensor && NKPQ < maxNumDoublesOfCuDNNTensor && KCRS < maxNumDoublesOfCuDNNTensor) {
 			// Filter and output are accounted as dense in the memory estimation for conv2d
 			double overhead = isInSparseFormat(gCtx, filter) ? OptimizerUtils.estimateSizeExactSparsity(K, CRS, 1.0) : 0;
@@ -489,14 +474,12 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 			Pointer y = getDensePointerForCuDNN(gCtx, outputBlock, instName);
 			if(overhead <= intermediateMemoryBudget) {
 				Pointer x = getDensePointerForCuDNN(gCtx, image, instName);
-				cudnnTensorDescriptor xDesc = allocateTensorDescriptor(gCtx, image, N, C, H, W);
-				cudnnMaxpooling(gCtx, instName, x, xDesc, y, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
+				cudnnMaxpooling(gCtx, instName, x, y, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
 			}
 			else {
 				LibMatrixCuDNNInputRowFetcher imgFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, image);
-				cudnnTensorDescriptor xDesc = allocateTensorDescriptor(gCtx, image, N, C, H, W);
 				for(int n = 0; n < N; n++) {
-					cudnnMaxpooling(gCtx, instName, imgFetcher.getNthRow(n), xDesc, y.withByteOffset(n*CPQ*Sizeof.DOUBLE), 1, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
+					cudnnMaxpooling(gCtx, instName, imgFetcher.getNthRow(n), y.withByteOffset(n*CPQ*Sizeof.DOUBLE), 1, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
 				}
 				imgFetcher.close();
 			}
@@ -506,7 +489,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 		}
 	}
 
-	private static void cudnnMaxpooling(GPUContext gCtx, String instName, Pointer x, cudnnTensorDescriptor xDesc,
+	private static void cudnnMaxpooling(GPUContext gCtx, String instName, Pointer x,
 			Pointer y, int N, int C, int H, int W, int K, int R,
 			int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
 			int Q) throws DMLRuntimeException {
@@ -514,33 +497,21 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 			LOG.trace("GPU : performMaxpooling" + ", GPUContext=" + gCtx);
 		}
 
-		cudnnPoolingDescriptor poolingDesc = null;
-
-		try {
+		try(LibMatrixCuDNNPoolingDescriptors desc = 
+				LibMatrixCuDNNPoolingDescriptors.cudnnMaxpoolingDescriptors(gCtx, instName, N, C, H, W, K, R, S, 
+						pad_h, pad_w, stride_h, stride_w, P, Q)) {
 			long t1=0,t2=0;
 			if (GPUStatistics.DISPLAY_STATISTICS) t1 = System.nanoTime();
-			// Allocate descriptors
-			cudnnTensorDescriptor yDesc = allocateTensorDescriptor(N, C, P, Q);
-			poolingDesc = allocatePoolingDescriptor(R, S, pad_h, pad_w, stride_h, stride_w);
 			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - t1);
-
 			if (GPUStatistics.DISPLAY_STATISTICS) t2 = System.nanoTime();
-			int status = cudnnPoolingForward(getCudnnHandle(gCtx), poolingDesc, one(), xDesc, x, zero(), yDesc, y);
+			int status = cudnnPoolingForward(getCudnnHandle(gCtx), desc.poolingDesc, one(), desc.xDesc, x, zero(), desc.yDesc, y);
 			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_MAXPOOLING_FORWARD_LIB, System.nanoTime() - t2);
-
 			if(status != jcuda.jcudnn.cudnnStatus.CUDNN_STATUS_SUCCESS) {
 				throw new DMLRuntimeException("Could not executed cudnnPoolingForward: " + jcuda.jcudnn.cudnnStatus.stringFor(status));
 			}
 		} catch (CudaException e) {
 			throw new DMLRuntimeException("Error in conv2d in GPUContext " + gCtx.toString() + " from Thread " + Thread.currentThread().toString(), e);
 		}
-		finally {
-			long t3=0;
-			if (GPUStatistics.DISPLAY_STATISTICS) t3 = System.nanoTime();
-			if(poolingDesc != null)
-				jcuda.jcudnn.JCudnn.cudnnDestroyPoolingDescriptor(poolingDesc);
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - t3);
-		}
 	}
 
 	/**
@@ -611,28 +582,22 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 			LOG.trace("GPU : maxpoolingBackward" + ", GPUContext=" + gCtx);
 		}
 		Pointer y = null;
-		cudnnPoolingDescriptor poolingDesc = null;
 
-		try {
+		try(LibMatrixCuDNNPoolingDescriptors desc = 
+				LibMatrixCuDNNPoolingDescriptors.cudnnMaxpoolingBackwardDescriptors(gCtx, instName, N, C, H, W, K, R, S, 
+						pad_h, pad_w, stride_h, stride_w, P, Q)) {
 			long t1=0, t2=0, t3=0;
 			if (GPUStatistics.DISPLAY_STATISTICS) t1 = System.nanoTime();
-			// Allocate descriptors
-			cudnnTensorDescriptor xDesc = allocateTensorDescriptor(N, C, H, W);
-			cudnnTensorDescriptor yDesc = allocateTensorDescriptor(N, C, P, Q);
-			cudnnTensorDescriptor dxDesc = allocateTensorDescriptor(N, C, H, W);
-			cudnnTensorDescriptor dyDesc = allocateTensorDescriptor(N, C, P, Q);
-
-			poolingDesc = allocatePoolingDescriptor(R, S, pad_h, pad_w, stride_h, stride_w);
-
+			
 			// Calling PoolForward first, y is one of the inputs for poolBackward
 			// TODO: Remove calling poolForward after necessary changes at language level for poolBackward
 			long numBytes = N*C*P*Q*Sizeof.DOUBLE;
 			y = gCtx.allocate(numBytes);
 			
 			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - t1);
-
+			
 			if (GPUStatistics.DISPLAY_STATISTICS) t2 = System.nanoTime();
-			int status = cudnnPoolingForward(getCudnnHandle(gCtx), poolingDesc, one(), xDesc, x, zero(), yDesc, y);
+			int status = cudnnPoolingForward(getCudnnHandle(gCtx), desc.poolingDesc, one(), desc.xDesc, x, zero(), desc.yDesc, y);
 			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_MAXPOOLING_FORWARD_LIB, System.nanoTime() - t2);
 
 			if(status != jcuda.jcudnn.cudnnStatus.CUDNN_STATUS_SUCCESS) {
@@ -640,7 +605,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 			}
 
 			if (GPUStatistics.DISPLAY_STATISTICS) t3 = System.nanoTime();
-			status = cudnnPoolingBackward(getCudnnHandle(gCtx), poolingDesc, one(), yDesc, y, dyDesc, dy, xDesc, x, zero(), dxDesc, dx);
+			status = cudnnPoolingBackward(getCudnnHandle(gCtx), desc.poolingDesc, one(), desc.yDesc, y, desc.dyDesc, dy, desc.xDesc, x, zero(), desc.dxDesc, dx);
 			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_MAXPOOLING_BACKWARD_LIB, System.nanoTime() - t3);
 
 			if(status != jcuda.jcudnn.cudnnStatus.CUDNN_STATUS_SUCCESS) {
@@ -652,297 +617,12 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 		finally {
 			long t4=0;
 			if (GPUStatistics.DISPLAY_STATISTICS) t4 = System.nanoTime();
-
 			if(y != null)
 				gCtx.cudaFreeHelper(instName, y);
-			if(poolingDesc != null)
-				jcuda.jcudnn.JCudnn.cudnnDestroyPoolingDescriptor(poolingDesc);
-
 			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - t4);
 		}
 	}
 
-	static cudnnConvolutionDescriptor allocateConvolutionDescriptor(int padding [], int strides []) {
-		cudnnConvolutionDescriptor convDesc = new cudnnConvolutionDescriptor();
-		cudnnCreateConvolutionDescriptor(convDesc);
-		cudnnSetConvolution2dDescriptor(convDesc, padding[0], padding[1], strides[0], strides[1], 1, 1, CUDNN_CROSS_CORRELATION);
-		return convDesc;
-	}
-
-	protected static cudnnFilterDescriptor allocateFilterDescriptor(int K, int C, int R, int S) {
-		cudnnFilterDescriptor filterDesc = new cudnnFilterDescriptor();
-		cudnnCreateFilterDescriptor(filterDesc);
-		cudnnSetFilter4dDescriptor(filterDesc, CUDNN_DATA_DOUBLE, CUDNN_TENSOR_NCHW, K, C, R, S);
-		return filterDesc;
-	}
-
-	/**
-	 * allocates pooling descriptor, used in poolingForward and poolingBackward
-	 * @param R			pooling window height
-	 * @param S			pooling window width
-	 * @param pad_h		vertical padding
-	 * @param pad_w		horizontal padding
-	 * @param stride_h	pooling vertical stride
-	 * @param stride_w	pooling horizontal stride
-	 * @return cudnn pooling descriptor
-	 */
-	private static cudnnPoolingDescriptor allocatePoolingDescriptor(int R, int S, int pad_h, int pad_w, int stride_h, int stride_w) {
-		cudnnPoolingDescriptor poolingDesc = new cudnnPoolingDescriptor();
-		cudnnCreatePoolingDescriptor(poolingDesc);
-		cudnnSetPooling2dDescriptor(poolingDesc, CUDNN_POOLING_MAX, CUDNN_PROPAGATE_NAN, R, S, pad_h, pad_w, stride_h, stride_w);
-		return poolingDesc;
-	}
-
-	/**
-	 * Convenience method to get tensor descriptor
-	 * @param N number of images
-	 * @param C number of channels
-	 * @param H height
-	 * @param W width
-	 * @return cudnn tensor descriptor
-	 * @throws DMLRuntimeException if the input descriptor and matrix dimensions don't match
-	 */
-	static cudnnTensorDescriptor allocateTensorDescriptor(int N, int C, int H, int W) throws DMLRuntimeException {
-		cudnnTensorDescriptor tensorDescriptor = new cudnnTensorDescriptor();
-		cudnnCreateTensorDescriptor(tensorDescriptor);
-		cudnnSetTensor4dDescriptor(tensorDescriptor, CUDNN_TENSOR_NCHW, CUDNN_DATA_DOUBLE, N, C, H, W);
-		return tensorDescriptor;
-	}
-
-	/**
-	 * Convenience method to get tensor descriptor from underlying GPUObject
-	 * @param gCtx   a valid {@link GPUContext}
-	 * @param mat matrix object
-	 * @param N number of images
-	 * @param C number of channels
-	 * @param H height
-	 * @param W width
-	 * @return cudnn tensor descriptor
-	 * @throws DMLRuntimeException if the input descriptor and matrix dimensions don't match
-	 */
-	private static cudnnTensorDescriptor allocateTensorDescriptor(GPUContext gCtx, MatrixObject mat, int N, int C, int H, int W) throws DMLRuntimeException {
-		if(mat.getNumRows() != N || mat.getNumColumns() != C*H*W) {
-			throw new DMLRuntimeException("Mismatch descriptor-matrix dimensions:" + mat.getNumRows() + " != " + N
-					+ " || " + mat.getNumColumns() + " != " + (C*H*W));
-		}
-		return mat.getGPUObject(gCtx).allocateTensorDescriptor(N, C, H, W);
-	}
-
-	/**
-	 * Performs the forward BatchNormalization layer computation for inference
-	 * @param gCtx   a valid {@link GPUContext}
-	 * @param instName name of the instruction
-	 * @param image input image
-	 * @param scale scale (as per CuDNN) and gamma as per original paper: shape [1, C, 1, 1]
-	 * @param bias bias (as per CuDNN) and beta as per original paper: shape [1, C, 1, 1]
-	 * @param runningMean running mean accumulated during training phase: shape [1, C, 1, 1]
-	 * @param runningVar running variance accumulated during training phase: shape [1, C, 1, 1]
-	 * @param ret normalized input
-	 * @param epsilon epsilon value used in the batch normalization formula
-	 * @throws DMLRuntimeException if error occurs
-	 */
-	public static void batchNormalizationForwardInference(GPUContext gCtx, String instName, MatrixObject image,
-			MatrixObject scale, MatrixObject bias, MatrixObject runningMean, MatrixObject runningVar,
-			MatrixObject ret, double epsilon) throws DMLRuntimeException {
-		if(LOG.isTraceEnabled()) {
-			LOG.trace("GPU : batchNormalizationForwardInference" + ", GPUContext=" + gCtx);
-		}
-		int mode = cudnnBatchNormMode.CUDNN_BATCHNORM_SPATIAL;
-
-		int N = toInt(image.getNumRows());
-		int C = toInt(scale.getNumColumns());
-		long CHW = image.getNumColumns();
-		validateBatchNormalizationDimensions(scale, bias, runningMean, runningVar, C);
-
-		// Allocate descriptors
-		cudnnTensorDescriptor nCHWDescriptor = allocateNCHWDescriptors(gCtx, N, C, CHW,
-				new MatrixObject[] {image},  new MatrixObject[] {ret});
-		cudnnTensorDescriptor scaleTensorDesc = allocateTensorDescriptor(gCtx, scale, 1, C, 1, 1);
-
-		// Get underlying dense pointer
-		Pointer imagePtr = getDensePointerForCuDNN(gCtx, image, instName);
-		Pointer retPtr = getDensePointerForCuDNN(gCtx, ret, instName);
-		Pointer biasPtr = getDensePointerForCuDNN(gCtx, bias, instName);
-		Pointer scalePtr = getDensePointerForCuDNN(gCtx, scale, instName);
-		Pointer runningMeanPtr = getDensePointerForCuDNN(gCtx, runningMean, instName);
-		Pointer runningVarPtr = getDensePointerForCuDNN(gCtx, runningVar, instName);
-
-		checkStatus(cudnnBatchNormalizationForwardInference(getCudnnHandle(gCtx), mode, one(), zero(),
-				nCHWDescriptor, imagePtr, nCHWDescriptor, retPtr,
-				scaleTensorDesc, scalePtr, biasPtr,
-				runningMeanPtr, runningVarPtr, epsilon));
-	}
-
-	/**
-	 * Performs the forward BatchNormalization layer computation for training
-	 * @param gCtx   a valid {@link GPUContext}
-	 * @param instName name of the instruction
-	 * @param image input image
-	 * @param scale scale (as per CuDNN) and gamma as per original paper: shape [1, C, 1, 1]
-	 * @param bias bias (as per CuDNN) and beta as per original paper: shape [1, C, 1, 1]
-	 * @param runningMean running mean accumulated during training phase: shape [1, C, 1, 1]
-	 * @param runningVar running variance accumulated during training phase: shape [1, C, 1, 1]
-	 * @param ret (output) normalized input
-	 * @param retRunningMean (output) running mean accumulated during training phase: shape [1, C, 1, 1]
-	 * @param retRunningVar (output) running variance accumulated during training phase: shape [1, C, 1, 1]
-	 * @param epsilon epsilon value used in the batch normalization formula
-	 * @param exponentialAverageFactor factor used in the moving average computation
-	 * @throws DMLRuntimeException if error occurs
-	 */
-	public static void batchNormalizationForwardTraining(GPUContext gCtx, String instName, MatrixObject image,
-			MatrixObject scale,  MatrixObject bias, MatrixObject runningMean, MatrixObject runningVar,
-			MatrixObject ret, MatrixObject retRunningMean, MatrixObject retRunningVar, double epsilon, double exponentialAverageFactor) throws DMLRuntimeException {
-		if(LOG.isTraceEnabled()) {
-			LOG.trace("GPU : batchNormalizationForwardTraining" + ", GPUContext=" + gCtx);
-		}
-		int mode = cudnnBatchNormMode.CUDNN_BATCHNORM_SPATIAL;
-
-		int N = toInt(image.getNumRows());
-		int C = toInt(scale.getNumColumns());
-		long CHW = image.getNumColumns();
-		validateBatchNormalizationDimensions(scale, bias, runningMean, runningVar, C);
-
-		// Allocate descriptors
-		cudnnTensorDescriptor nCHWDescriptor = allocateNCHWDescriptors(gCtx, N, C, CHW,
-				new MatrixObject[] {image},  new MatrixObject[] {ret});
-		cudnnTensorDescriptor scaleTensorDesc = allocateTensorDescriptor(gCtx, scale, 1, C, 1, 1);
-
-		// Get underlying dense pointer
-		Pointer imagePtr = getDensePointerForCuDNN(gCtx, image, instName);
-		Pointer retPtr = getDensePointerForCuDNN(gCtx, ret, instName);
-		Pointer biasPtr = getDensePointerForCuDNN(gCtx, bias, instName);
-		Pointer scalePtr = getDensePointerForCuDNN(gCtx, scale, instName);
-		Pointer runningMeanPtr = getDensePointerForCuDNN(gCtx, runningMean, instName);
-		Pointer runningVarPtr = getDensePointerForCuDNN(gCtx, runningVar, instName);
-
-		// To allow for copy-on-write
-		Pointer retRunningMeanPtr = getDensePointerForCuDNN(gCtx, retRunningMean, instName);
-		Pointer retRunningVarPtr = getDensePointerForCuDNN(gCtx, retRunningVar, instName);
-		cudaMemcpy(retRunningMeanPtr, runningMeanPtr, C * Sizeof.DOUBLE, cudaMemcpyDeviceToDevice);
-		cudaMemcpy(retRunningVarPtr, runningVarPtr, C * Sizeof.DOUBLE, cudaMemcpyDeviceToDevice);
-
-		// ignoring resultSaveMean and resultSaveVariance as it requires state management
-		checkStatus(cudnnBatchNormalizationForwardTraining(getCudnnHandle(gCtx), mode, one(), zero(),
-				nCHWDescriptor, imagePtr, nCHWDescriptor, retPtr,
-				scaleTensorDesc, scalePtr, biasPtr, exponentialAverageFactor,
-				retRunningMeanPtr, retRunningVarPtr, epsilon, new Pointer(), new Pointer()));
-	}
-
-	private static void validateBatchNormalizationDimensions(MatrixObject scale, MatrixObject bias, MatrixObject runningMean, MatrixObject runningVar, int C) throws DMLRuntimeException {
-		if(scale.getNumRows() != 1 || scale.getNumColumns() != C) {
-			throw new DMLRuntimeException("Incorrect dimensions for scale");
-		}
-		if(bias.getNumRows() != 1 || bias.getNumColumns() != C) {
-			throw new DMLRuntimeException("Incorrect dimensions for bias");
-		}
-		if(runningMean.getNumRows() != 1 || runningMean.getNumColumns() != C) {
-			throw new DMLRuntimeException("Incorrect dimensions for running mean");
-		}
-		if(runningVar.getNumRows() != 1 || runningVar.getNumColumns() != C) {
-			throw new DMLRuntimeException("Incorrect dimensions for running variance");
-		}
-	}
-
-	/**
-	 * Convenient utility for batch normalization that returns a NCHW descriptor
-	 * @param gCtx a valid {@link GPUContext}
-	 * @param N number of images
-	 * @param C number of channels
-	 * @param CHW channels*height*width
-	 * @param input input matrix objects
-	 * @param output output matrix objects
-	 * @return one of the NCHW descriptor
-	 * @throws DMLRuntimeException if error occurs
-	 */
-	private static cudnnTensorDescriptor allocateNCHWDescriptors(GPUContext gCtx, int N, int C, long CHW, MatrixObject [] input, MatrixObject [] output) throws DMLRuntimeException {
-		cudnnTensorDescriptor ret  = null; // Return any one
-		if(CHW > ((long)Integer.MAX_VALUE)*C) {
-			throw new DMLRuntimeException("image size (height*width) should be less than " + Integer.MAX_VALUE);
-		}
-		cudnnTensorDescriptor knownNCHWdescriptor = null;
-		int H = -1; int W = -1;
-		for(int i = 0; i < input.length; i++) {
-			knownNCHWdescriptor = input[i].getGPUObject(gCtx).getTensorDescriptor();
-			if(knownNCHWdescriptor != null) {
-				int [] shape = input[i].getGPUObject(gCtx).getTensorShape();
-				if(shape[0] != N || shape[1] != C) {
-					throw new DMLRuntimeException("Incorrect N and C:" + shape[0]  + " != " + N + " || " + shape[1]  + " != " +  C);
-				}
-				H = shape[2];
-				W = shape[3];
-				break;
-			}
-		}
-		if(knownNCHWdescriptor != null) {
-			// We precisely know N, C, H, W
-			for(int i = 0; i < input.length; i++) {
-				ret = allocateTensorDescriptor(gCtx, input[i], N, C, H, W);
-			}
-			for(int i = 0; i < output.length; i++) {
-				ret = allocateTensorDescriptor(gCtx, output[i], N, C, H, W);
-			}
-		}
-		else {
-			int HW = (int) (CHW / C);
-			H = HW; W = 1; // If not known
-			double potentialH = Math.sqrt(HW);
-			if(potentialH == ((int) potentialH)) {
-				H = (int) potentialH;
-				W = H;
-			}
-			// We are not sure about H and W, hence don't allocate them.
-			ret = new cudnnTensorDescriptor();
-			cudnnCreateTensorDescriptor(ret);
-			cudnnSetTensor4dDescriptor(ret, CUDNN_TENSOR_NCHW, CUDNN_DATA_DOUBLE, N, C, H, W);
-		}
-		return ret;
-	}
-
-	/**
-	 * This method computes the backpropagation errors for image, scale and bias of batch normalization layer
-	 * @param gCtx   a valid {@link GPUContext}
-	 * @param instName name of the instruction
-	 * @param image input image
-	 * @param dout input errors of shape C, H, W
-	 * @param scale scale (as per CuDNN) and gamma as per original paper: shape [1, C, 1, 1]
-	 * @param ret (output) backpropagation errors for previous layer
-	 * @param retScale backpropagation error for scale
-	 * @param retBias backpropagation error for bias
-	 * @param epsilon epsilon value used in the batch normalization formula
-	 * @throws DMLRuntimeException if error occurs
-	 */
-	public static void batchNormalizationBackward(GPUContext gCtx, String instName, MatrixObject image, MatrixObject dout,
-			MatrixObject scale, MatrixObject ret, MatrixObject retScale, MatrixObject retBias,
-			double epsilon) throws DMLRuntimeException {
-		if(LOG.isTraceEnabled()) {
-			LOG.trace("GPU : batchNormalizationBackward" + ", GPUContext=" + gCtx);
-		}
-		int mode = cudnnBatchNormMode.CUDNN_BATCHNORM_SPATIAL;
-
-		int N = toInt(image.getNumRows());
-		int C = toInt(scale.getNumColumns());
-		long CHW = image.getNumColumns();
-
-		// Allocate descriptors
-		cudnnTensorDescriptor nCHWDescriptor = allocateNCHWDescriptors(gCtx, N, C, CHW,
-				new MatrixObject[] {image, dout},  new MatrixObject[] {ret});
-		cudnnTensorDescriptor scaleTensorDesc = allocateTensorDescriptor(gCtx, scale, 1, C, 1, 1);
-
-		// Get underlying dense pointer
-		Pointer imagePtr = getDensePointerForCuDNN(gCtx, image, instName);
-		Pointer doutPtr = getDensePointerForCuDNN(gCtx, dout, instName);
-		Pointer scalePtr = getDensePointerForCuDNN(gCtx, scale, instName);
-		Pointer retPtr = getDensePointerForCuDNN(gCtx, ret, instName);
-		Pointer retScalePtr = getDensePointerForCuDNN(gCtx, retScale, instName);
-		Pointer retBiasPtr = getDensePointerForCuDNN(gCtx, retBias, instName);
-
-		// ignoring resultSaveMean and resultSaveVariance as it requires state management
-		checkStatus(cudnnBatchNormalizationBackward(getCudnnHandle(gCtx), mode,  one(), zero(), one(), zero(),
-				nCHWDescriptor,  imagePtr, nCHWDescriptor, doutPtr, nCHWDescriptor, retPtr,
-				scaleTensorDesc, scalePtr, retScalePtr, retBiasPtr, epsilon, new Pointer(), new Pointer()));
-	}
-
-
 	private static void cudnnReLU(GPUContext gCtx, String instName, MatrixObject in, Pointer dstData, cudnnTensorDescriptor srcTensorDesc) throws DMLRuntimeException {
 		long t0=0;
 		try {
@@ -988,8 +668,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 		MatrixObject output = ec.getMatrixObject(outputName);
 		getDenseMatrixOutputForGPUInstruction(ec, instName, outputName, in.getNumRows(), in.getNumColumns()); // Allocated the dense output matrix
 		long t0=0;
-		cudnnTensorDescriptor srcTensorDesc = in.getGPUObject(gCtx).getTensorDescriptor();
-		if(N*CHW >= maxNumDoublesOfCuDNNTensor ||  srcTensorDesc == null) {
+		if(N*CHW >= maxNumDoublesOfCuDNNTensor) {
 			if(LOG.isTraceEnabled()) {
 				LOG.trace("GPU : relu custom kernel" + ", GPUContext=" + gCtx);
 			}
@@ -1003,7 +682,16 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_RELU_KERNEL, System.nanoTime() - t0);
 		}
 		else {
-			cudnnReLU(gCtx, instName, in, getDensePointerForCuDNN(gCtx, output, instName), srcTensorDesc);
+			// Describe the input as an N x 1 x 1 x CHW tensor; the descriptor is destroyed even if setup or cudnnReLU throws.
+			cudnnTensorDescriptor tensorDescriptor = new cudnnTensorDescriptor();
+			cudnnCreateTensorDescriptor(tensorDescriptor);
+			try {
+				cudnnSetTensor4dDescriptor(tensorDescriptor, CUDNN_TENSOR_NCHW, CUDNN_DATA_DOUBLE, toInt(N), 1, 1, toInt(CHW));
+				cudnnReLU(gCtx, instName, in, getDensePointerForCuDNN(gCtx, output, instName), tensorDescriptor);
+			}
+			finally {
+				cudnnDestroyTensorDescriptor(tensorDescriptor);
+			}
 		}
 	}
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/96ae6c7e/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNConvolutionAlgorithm.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNConvolutionAlgorithm.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNConvolutionAlgorithm.java
index 2243b58..871194e 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNConvolutionAlgorithm.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNConvolutionAlgorithm.java
@@ -31,9 +31,18 @@ import jcuda.jcudnn.cudnnConvolutionDescriptor;
 import jcuda.jcudnn.cudnnConvolutionFwdPreference;
 import jcuda.jcudnn.cudnnFilterDescriptor;
 import jcuda.jcudnn.cudnnTensorDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnCreateConvolutionDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnCreateFilterDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnCreateTensorDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnDestroyConvolutionDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnDestroyFilterDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnDestroyTensorDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnSetConvolution2dDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnSetFilter4dDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnSetTensor4dDescriptor;
+import static jcuda.jcudnn.cudnnConvolutionMode.CUDNN_CROSS_CORRELATION;
+import static jcuda.jcudnn.cudnnDataType.CUDNN_DATA_DOUBLE;
+import static jcuda.jcudnn.cudnnTensorFormat.CUDNN_TENSOR_NCHW;
 
 /**
  * This class is a wrapper that contain necessary data structures to invoke 
@@ -48,6 +57,9 @@ import static jcuda.jcudnn.JCudnn.cudnnDestroyTensorDescriptor;
  *  
  */
 public class LibMatrixCuDNNConvolutionAlgorithm implements java.lang.AutoCloseable {
+	// Upper bound (10^9 bytes, i.e. 1 GB) on the workspace offered to cudnn convolution algorithm selection; final because it is a constant
+	private static final long MAX_WORKSPACE_LIMIT_BYTES = (long) 1e+9;
+	
 	public int algo = -1;
 	public Pointer workSpace = new Pointer();
 	public long sizeInBytes = 0;
@@ -61,12 +73,12 @@ public class LibMatrixCuDNNConvolutionAlgorithm implements java.lang.AutoCloseab
 			int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q) throws DMLRuntimeException {
 		int padding[] = {pad_h, pad_w};
 		int strides[] = {stride_h, stride_w};
-		convDesc = LibMatrixCuDNN.allocateConvolutionDescriptor(padding, strides);
+		convDesc = allocateConvolutionDescriptor(padding, strides);
 		this.gCtx = gCtx;
 		this.instName = instName;
-		nchwTensorDesc = LibMatrixCuDNN.allocateTensorDescriptor(N, C, H, W);
-		nkpqTensorDesc = LibMatrixCuDNN.allocateTensorDescriptor(N, K, P, Q);
-		filterDesc = LibMatrixCuDNN.allocateFilterDescriptor(K, C, R, S);
+		nchwTensorDesc = allocateTensorDescriptor(N, C, H, W);
+		nkpqTensorDesc = allocateTensorDescriptor(N, K, P, Q);
+		filterDesc = allocateFilterDescriptor(K, C, R, S);
 	}
 	
 	/**
@@ -125,7 +137,7 @@ public class LibMatrixCuDNNConvolutionAlgorithm implements java.lang.AutoCloseab
 		}
 		else {
 			int[] algos = {-1};
-			long sizeInBytesArray[] = {workspaceLimit};
+			long sizeInBytesArray[] = {Math.min(workspaceLimit, MAX_WORKSPACE_LIMIT_BYTES)};
 			jcuda.jcudnn.JCudnn.cudnnGetConvolutionForwardAlgorithm(LibMatrixCuDNN.getCudnnHandle(gCtx), 
 					ret.nchwTensorDesc, ret.filterDesc, ret.convDesc, ret.nkpqTensorDesc,
 					cudnnConvolutionFwdPreference.CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, sizeInBytesArray[0], algos);
@@ -177,7 +189,7 @@ public class LibMatrixCuDNNConvolutionAlgorithm implements java.lang.AutoCloseab
 		}
 		else {
 			int[] algos = {-1};
-			long sizeInBytesArray[] = {workspaceLimit};
+			long sizeInBytesArray[] = {Math.min(workspaceLimit, MAX_WORKSPACE_LIMIT_BYTES)};
 			jcuda.jcudnn.JCudnn.cudnnGetConvolutionBackwardFilterAlgorithm(
 					LibMatrixCuDNN.getCudnnHandle(gCtx), 
 					ret.nchwTensorDesc, ret.nkpqTensorDesc, ret.convDesc, ret.filterDesc, 
@@ -230,7 +242,7 @@ public class LibMatrixCuDNNConvolutionAlgorithm implements java.lang.AutoCloseab
 		}
 		else {
 			int[] algos = {-1};
-			long sizeInBytesArray[] = {workspaceLimit};
+			long sizeInBytesArray[] = {Math.min(workspaceLimit, MAX_WORKSPACE_LIMIT_BYTES)};
 			jcuda.jcudnn.JCudnn.cudnnGetConvolutionBackwardDataAlgorithm(
 					LibMatrixCuDNN.getCudnnHandle(gCtx), 
 					ret.filterDesc, ret.nkpqTensorDesc, ret.convDesc, ret.nchwTensorDesc,
@@ -246,4 +258,73 @@ public class LibMatrixCuDNNConvolutionAlgorithm implements java.lang.AutoCloseab
 			GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - t1);
 		return ret;
 	}
+	
+	/**
+	 * Creates and initializes a 4-d NCHW tensor descriptor of doubles.
+	 * If initialization fails, the freshly created descriptor is destroyed before the error propagates,
+	 * so that a half-initialized native descriptor is never leaked.
+	 * 
+	 * @param N number of images
+	 * @param C number of channels
+	 * @param H height
+	 * @param W width
+	 * @return cudnn tensor descriptor; the caller is responsible for destroying it
+	 * @throws DMLRuntimeException declared for compatibility with existing callers; when JCuda exceptions are enabled, cuda errors surface as unchecked exceptions
+	 */
+	private static cudnnTensorDescriptor allocateTensorDescriptor(int N, int C, int H, int W) throws DMLRuntimeException {
+		cudnnTensorDescriptor tensorDescriptor = new cudnnTensorDescriptor();
+		cudnnCreateTensorDescriptor(tensorDescriptor);
+		try {
+			cudnnSetTensor4dDescriptor(tensorDescriptor, CUDNN_TENSOR_NCHW, CUDNN_DATA_DOUBLE, N, C, H, W);
+		}
+		catch(RuntimeException e) {
+			cudnnDestroyTensorDescriptor(tensorDescriptor);
+			throw e;
+		}
+		return tensorDescriptor;
+	}
+	
+	/**
+	 * Creates and initializes a 4-d filter descriptor of doubles laid out as (K, C, R, S).
+	 * If initialization fails, the freshly created descriptor is destroyed before the error propagates.
+	 * 
+	 * @param K number of filters
+	 * @param C number of channels
+	 * @param R filter height
+	 * @param S filter width
+	 * @return cudnn filter descriptor; the caller is responsible for destroying it
+	 */
+	private static cudnnFilterDescriptor allocateFilterDescriptor(int K, int C, int R, int S) {
+		cudnnFilterDescriptor filterDesc = new cudnnFilterDescriptor();
+		cudnnCreateFilterDescriptor(filterDesc);
+		try {
+			cudnnSetFilter4dDescriptor(filterDesc, CUDNN_DATA_DOUBLE, CUDNN_TENSOR_NCHW, K, C, R, S);
+		}
+		catch(RuntimeException e) {
+			cudnnDestroyFilterDescriptor(filterDesc);
+			throw e;
+		}
+		return filterDesc;
+	}
+	
+	/**
+	 * Creates and initializes a 2-d cross-correlation convolution descriptor with unit upscale factors.
+	 * If initialization fails, the freshly created descriptor is destroyed before the error propagates.
+	 * 
+	 * @param padding {pad_h, pad_w}
+	 * @param strides {stride_h, stride_w}
+	 * @return cudnn convolution descriptor; the caller is responsible for destroying it
+	 */
+	private static cudnnConvolutionDescriptor allocateConvolutionDescriptor(int padding [], int strides []) {
+		cudnnConvolutionDescriptor convDesc = new cudnnConvolutionDescriptor();
+		cudnnCreateConvolutionDescriptor(convDesc);
+		try {
+			cudnnSetConvolution2dDescriptor(convDesc, padding[0], padding[1], strides[0], strides[1], 1, 1, CUDNN_CROSS_CORRELATION);
+		}
+		catch(RuntimeException e) {
+			cudnnDestroyConvolutionDescriptor(convDesc);
+			throw e;
+		}
+		return convDesc;
+	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/96ae6c7e/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNPoolingDescriptors.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNPoolingDescriptors.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNPoolingDescriptors.java
new file mode 100644
index 0000000..f817bd5
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNPoolingDescriptors.java
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.matrix.data;
+
+import static jcuda.jcudnn.JCudnn.cudnnCreatePoolingDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnCreateTensorDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnDestroyPoolingDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnDestroyTensorDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnSetPooling2dDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnSetTensor4dDescriptor;
+import static jcuda.jcudnn.cudnnDataType.CUDNN_DATA_DOUBLE;
+import static jcuda.jcudnn.cudnnNanPropagation.CUDNN_PROPAGATE_NAN;
+import static jcuda.jcudnn.cudnnPoolingMode.CUDNN_POOLING_MAX;
+import static jcuda.jcudnn.cudnnTensorFormat.CUDNN_TENSOR_NCHW;
+
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
+
+import jcuda.jcudnn.cudnnPoolingDescriptor;
+import jcuda.jcudnn.cudnnTensorDescriptor;
+
+/**
+ * This class is a wrapper that contains the descriptors necessary to invoke
+ * the cudnn pooling functions (such as cudnnPoolingForward and cudnnPoolingBackward).
+ * 
+ * It implements AutoCloseable so that all descriptors can be released with a single
+ * try-with-resources statement; this simplifies the LibMatrixCuDNN code and avoids potential memory leaks.
+ */
+public class LibMatrixCuDNNPoolingDescriptors implements java.lang.AutoCloseable {
+
+	public cudnnTensorDescriptor xDesc;
+	public cudnnTensorDescriptor yDesc;
+	public cudnnTensorDescriptor dxDesc;
+	public cudnnTensorDescriptor dyDesc;
+	public cudnnPoolingDescriptor poolingDesc;
+	
+	/**
+	 * Destroys every allocated descriptor and resets the corresponding field to null, so calling
+	 * close() more than once is harmless. All descriptors are released even if destroying one of them
+	 * throws; the first such exception is rethrown at the end, with later ones attached as suppressed.
+	 */
+	@Override
+	public void close() {
+		RuntimeException error = null;
+		error = destroy(xDesc, error);
+		xDesc = null;
+		error = destroy(yDesc, error);
+		yDesc = null;
+		error = destroy(dxDesc, error);
+		dxDesc = null;
+		error = destroy(dyDesc, error);
+		dyDesc = null;
+		error = destroy(poolingDesc, error);
+		poolingDesc = null;
+		if(error != null)
+			throw error;
+	}
+	
+	/**
+	 * Get descriptors for maxpooling backward operation.
+	 * If any descriptor cannot be allocated, the ones allocated so far are destroyed before the error propagates.
+	 * 
+	 * @param gCtx			gpu context (currently unused)
+	 * @param instName		instruction name (currently unused)
+	 * @param N				batch size
+	 * @param C				number of channels
+	 * @param H				height of image
+	 * @param W				width of image
+	 * @param K				number of filters (unused: pooling keeps the C input channels)
+	 * @param R				height of pooling window
+	 * @param S				width of pooling window
+	 * @param pad_h			vertical padding
+	 * @param pad_w			horizontal padding
+	 * @param stride_h		vertical stride
+	 * @param stride_w		horizontal stride
+	 * @param P				output height, (H + 2*pad_h - R)/stride_h + 1
+	 * @param Q				output width, (W + 2*pad_w - S)/stride_w + 1
+	 * @return descriptor wrapper holding xDesc, yDesc, dxDesc, dyDesc and poolingDesc; the caller must close it
+	 * @throws DMLRuntimeException declared for API compatibility; when JCuda exceptions are enabled, cuda errors surface as unchecked exceptions
+	 */
+	public static LibMatrixCuDNNPoolingDescriptors cudnnMaxpoolingBackwardDescriptors(GPUContext gCtx,
+			String instName, int N, int C, int H, int W, int K, int R,
+			int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
+			int Q) throws DMLRuntimeException {
+		LibMatrixCuDNNPoolingDescriptors ret = new LibMatrixCuDNNPoolingDescriptors();
+		try {
+			ret.xDesc = allocateTensorDescriptor(N, C, H, W);
+			ret.yDesc = allocateTensorDescriptor(N, C, P, Q);
+			ret.dxDesc = allocateTensorDescriptor(N, C, H, W);
+			ret.dyDesc = allocateTensorDescriptor(N, C, P, Q);
+			ret.poolingDesc = allocatePoolingDescriptor(R, S, pad_h, pad_w, stride_h, stride_w);
+		}
+		catch(RuntimeException e) {
+			ret.closeAfterFailure(e);
+			throw e;
+		}
+		return ret;
+	}
+	
+	/**
+	 * Get descriptors for maxpooling (forward) operation.
+	 * If any descriptor cannot be allocated, the ones allocated so far are destroyed before the error propagates.
+	 * 
+	 * @param gCtx			gpu context (currently unused)
+	 * @param instName		instruction name (currently unused)
+	 * @param N				batch size
+	 * @param C				number of channels
+	 * @param H				height of image
+	 * @param W				width of image
+	 * @param K				number of filters (unused: pooling keeps the C input channels)
+	 * @param R				height of pooling window
+	 * @param S				width of pooling window
+	 * @param pad_h			vertical padding
+	 * @param pad_w			horizontal padding
+	 * @param stride_h		vertical stride
+	 * @param stride_w		horizontal stride
+	 * @param P				output height, (H + 2*pad_h - R)/stride_h + 1
+	 * @param Q				output width, (W + 2*pad_w - S)/stride_w + 1
+	 * @return descriptor wrapper holding xDesc, yDesc and poolingDesc; the caller must close it
+	 * @throws DMLRuntimeException declared for API compatibility; when JCuda exceptions are enabled, cuda errors surface as unchecked exceptions
+	 */
+	public static LibMatrixCuDNNPoolingDescriptors cudnnMaxpoolingDescriptors(GPUContext gCtx,
+			String instName, int N, int C, int H, int W, int K, int R,
+			int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
+			int Q) throws DMLRuntimeException {
+		LibMatrixCuDNNPoolingDescriptors ret = new LibMatrixCuDNNPoolingDescriptors();
+		try {
+			ret.xDesc = allocateTensorDescriptor(N, C, H, W);
+			ret.yDesc = allocateTensorDescriptor(N, C, P, Q);
+			ret.poolingDesc = allocatePoolingDescriptor(R, S, pad_h, pad_w, stride_h, stride_w);
+		}
+		catch(RuntimeException e) {
+			ret.closeAfterFailure(e);
+			throw e;
+		}
+		return ret;
+	}
+
+	/**
+	 * Releases whatever this wrapper allocated before an allocation failed. A failure while releasing
+	 * is attached to the original error as a suppressed exception so that it does not mask the root cause.
+	 * @param cause the allocation failure that the caller rethrows
+	 */
+	private void closeAfterFailure(RuntimeException cause) {
+		try {
+			close();
+		}
+		catch(RuntimeException closeError) {
+			cause.addSuppressed(closeError);
+		}
+	}
+	
+	/**
+	 * Destroys the given tensor descriptor if it is non-null.
+	 * @param desc descriptor to destroy (may be null)
+	 * @param error failure accumulated so far (may be null)
+	 * @return the accumulated failure, including any failure raised while destroying desc
+	 */
+	private static RuntimeException destroy(cudnnTensorDescriptor desc, RuntimeException error) {
+		if(desc == null)
+			return error;
+		try {
+			cudnnDestroyTensorDescriptor(desc);
+		}
+		catch(RuntimeException e) {
+			return chain(error, e);
+		}
+		return error;
+	}
+	
+	/**
+	 * Destroys the given pooling descriptor if it is non-null.
+	 * @param desc descriptor to destroy (may be null)
+	 * @param error failure accumulated so far (may be null)
+	 * @return the accumulated failure, including any failure raised while destroying desc
+	 */
+	private static RuntimeException destroy(cudnnPoolingDescriptor desc, RuntimeException error) {
+		if(desc == null)
+			return error;
+		try {
+			cudnnDestroyPoolingDescriptor(desc);
+		}
+		catch(RuntimeException e) {
+			return chain(error, e);
+		}
+		return error;
+	}
+	
+	/**
+	 * Combines failures: the first one is kept and later ones are recorded as suppressed.
+	 * @param first failure seen so far (may be null)
+	 * @param next newly observed failure
+	 * @return the failure to rethrow
+	 */
+	private static RuntimeException chain(RuntimeException first, RuntimeException next) {
+		if(first == null)
+			return next;
+		first.addSuppressed(next);
+		return first;
+	}
+
+	/**
+	 * Creates and initializes a 4-d NCHW tensor descriptor of doubles.
+	 * If initialization fails, the freshly created descriptor is destroyed before the error propagates.
+	 * @param N number of images
+	 * @param C number of channels
+	 * @param H height
+	 * @param W width
+	 * @return cudnn tensor descriptor; the caller is responsible for destroying it
+	 */
+	private static cudnnTensorDescriptor allocateTensorDescriptor(int N, int C, int H, int W) {
+		cudnnTensorDescriptor tensorDescriptor = new cudnnTensorDescriptor();
+		cudnnCreateTensorDescriptor(tensorDescriptor);
+		try {
+			cudnnSetTensor4dDescriptor(tensorDescriptor, CUDNN_TENSOR_NCHW, CUDNN_DATA_DOUBLE, N, C, H, W);
+		}
+		catch(RuntimeException e) {
+			throw destroy(tensorDescriptor, e);
+		}
+		return tensorDescriptor;
+	}
+	
+	/**
+	 * Allocates a max-pooling descriptor (with NaN propagation), used in poolingForward and poolingBackward.
+	 * If initialization fails, the freshly created descriptor is destroyed before the error propagates.
+	 * @param R			pooling window height
+	 * @param S			pooling window width
+	 * @param pad_h		vertical padding
+	 * @param pad_w		horizontal padding
+	 * @param stride_h	pooling vertical stride
+	 * @param stride_w	pooling horizontal stride
+	 * @return cudnn pooling descriptor; the caller is responsible for destroying it
+	 */
+	private static cudnnPoolingDescriptor allocatePoolingDescriptor(int R, int S, int pad_h, int pad_w, int stride_h, int stride_w) {
+		cudnnPoolingDescriptor poolingDesc = new cudnnPoolingDescriptor();
+		cudnnCreatePoolingDescriptor(poolingDesc);
+		try {
+			cudnnSetPooling2dDescriptor(poolingDesc, CUDNN_POOLING_MAX, CUDNN_PROPAGATE_NAN, R, S, pad_h, pad_w, stride_h, stride_w);
+		}
+		catch(RuntimeException e) {
+			throw destroy(poolingDesc, e);
+		}
+		return poolingDesc;
+	}
+}