You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2017/10/28 20:48:30 UTC

systemml git commit: [SYSTEMML-540] Avoid redundant computation of cudnnPoolingForward in max_pool_backward

Repository: systemml
Updated Branches:
  refs/heads/master 118e3c0f6 -> 06d5bb073


[SYSTEMML-540] Avoid redundant computation of cudnnPoolingForward in max_pool_backward

- If max_pool is invoked on the same input with the same pooling parameters
  in the forward pass, max_pool_backward on the GPU now reuses its output
  instead of recomputing it with cudnnPoolingForward. For a sentence CNN
  trained for 2 epochs, this reduces the time spent in max_pool_backward
  from 6.361 to 2.966 seconds.

Closes #691.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/06d5bb07
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/06d5bb07
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/06d5bb07

Branch: refs/heads/master
Commit: 06d5bb073792345f7c4b7ecd0fb4454a335cc421
Parents: 118e3c0
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Sat Oct 28 13:44:37 2017 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Sat Oct 28 13:45:52 2017 -0700

----------------------------------------------------------------------
 .../org/apache/sysml/hops/ConvolutionOp.java    | 163 +++++++++++++------
 .../gpu/ConvolutionGPUInstruction.java          |  43 ++++-
 .../runtime/matrix/data/LibMatrixCuDNN.java     |  51 +++---
 .../sysml/test/gpu/NeuralNetworkOpTests.java    |  82 ++++++++++
 4 files changed, 260 insertions(+), 79 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/06d5bb07/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
index 50a7ca3..16a8b63 100644
--- a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
+++ b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
@@ -47,14 +47,23 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 	private static final boolean THROW_ERROR_IF_INFERRED_SHAPE_MISMATCH = true;
 	// -------------------------------------------------------------------------
 	
+	// Specifies the type of this hop
 	private Hop.ConvOp op;
-
 	private int _maxNumThreads = -1; //-1 for unlimited
 
 	private ConvolutionOp() {
 		//default constructor for clone
 	}
 
+	/**
+	 * Create a hop from the builtin expression
+	 * 
+	 * @param l name of the hop
+	 * @param dt datatype (only supports matrix datatype)
+	 * @param vt valuetype  (only supports matrix valuetype) 
+	 * @param o type of this hop
+	 * @param inp input hops
+	 */
 	public ConvolutionOp(String l, DataType dt, ValueType vt, ConvOp o, ArrayList<Hop> inp) 
 	{
 		super(l, dt, vt);
@@ -75,8 +84,7 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 		HopsException.check(_input.size() >= 1, this, "should have at least one input but has %d inputs", _input.size());
 	}
 
-	public ConvOp getOp()
-	{
+	public ConvOp getOp() {
 		return op;
 	}
 	
@@ -163,77 +171,129 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 		return input instanceof ConvolutionOp && ((ConvolutionOp) input).getOp() == ConvOp.DIRECT_CONV2D;
 	}
 	
+	/**
+	 * Compares the input parameters for max_pool/max_pool_backward operations
+	 * 
+	 * @return true if the following parameters match: stride=[stride, stride], padding=[pad, pad], input_shape=[numImg, numChannels, imgSize, imgSize], pool_size=[poolSize1, poolSize2]
+	 */
+	private static boolean isPoolingParametersEqualAndKnown(ConvolutionParameters param1, ConvolutionParameters param2) {
+		return isEqualAndKnown(param1.stride_h, param2.stride_h) && isEqualAndKnown(param1.stride_w, param2.stride_w) && 
+			isEqualAndKnown(param1.pad_h, param2.pad_h) && isEqualAndKnown(param1.pad_w, param2.pad_w) &&
+			isEqualAndKnown(param1.R, param2.R) && isEqualAndKnown(param1.S, param2.S) &&
+			isEqualAndKnown(param1.N, param2.N) && isEqualAndKnown(param1.C, param2.C) &&
+			isEqualAndKnown(param1.H, param2.H) && isEqualAndKnown(param1.W, param2.W);
+	}
+	
+	private static boolean isEqualAndKnown(int val1, int val2) {
+		return val1 >= 0 && val2 >= 0 && val1 == val2;
+	}
+	
+	/**
+	 * Returns the output lop of maxpool operation with same parameters as this hop.
+	 * If corresponding output lop is not found or if this is not a max_pool_backward operation, this function returns null
+	 * 
+	 * @return output lop of maxpool operation with same parameters as this hop
+	 * @throws HopsException if error 
+	 * @throws LopsException if error
+	 */
+	private Lop getMaxPoolOutputLop() throws HopsException, LopsException {
+		if(op != ConvOp.MAX_POOLING_BACKWARD)
+			return null;
+		
+		Hop inputImage = getInput().get(0);
+		for(Hop tmpParent : inputImage.getParent()) {
+			if(!(tmpParent instanceof ConvolutionOp))
+				continue;
+			ConvolutionOp parent = (ConvolutionOp) tmpParent;
+			if(parent.getOp() == ConvOp.MAX_POOLING && isPoolingParametersEqualAndKnown(parent._cachedParams, _cachedParams)) {
+				return parent.constructLops();
+			}
+		}
+		return null;
+	}
+	
 	public Lop constructConvolutionLops(ExecType et, ArrayList<Hop> inputs) throws HopsException, LopsException {
 		if(inputs.size() != getNumExpectedInputs()) 
 			throw new HopsException("Incorrect number of inputs for " + op.name());
 		
-		Lop in = null; Lop in2 = null;
-		ArrayList<Hop> inputs1 = inputs;
-		int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
+		// ---------------------------------------------------------------
+		// Deal with fused operators and contruct lhsInputLop/optionalRhsInputLop
+		Lop lhsInputLop = null; Lop optionalRhsInputLop = null;
+		ArrayList<Hop> inputsOfPotentiallyFusedOp = inputs;
 		OperationTypes lopOp = HopsConv2Lops.get(op);
-
+		
 		// RELU_MAX_POOLING and RELU_MAX_POOLING_BACKWARD is extremely useful for CP backend 
 		// by reducing unnecessary sparse-to-dense-to-sparse conversion.
 		// For other backends, this operators is not necessary as it reduces an additional relu operator.
 		if(OptimizerUtils.ALLOW_OPERATOR_FUSION && et == ExecType.CP && op == ConvOp.MAX_POOLING && isInputReLU(inputs.get(0))) {
-			in = inputs.get(0).getInput().get(0).constructLops();
+			lhsInputLop = inputs.get(0).getInput().get(0).constructLops();
 			lopOp = OperationTypes.RELU_MAX_POOLING;
 		}
 		else if(OptimizerUtils.ALLOW_OPERATOR_FUSION && et == ExecType.CP && op == ConvOp.MAX_POOLING_BACKWARD && isInputReLU(inputs.get(0))) {
-			in = inputs.get(0).getInput().get(0).constructLops();
+			lhsInputLop = inputs.get(0).getInput().get(0).constructLops();
 			lopOp = OperationTypes.RELU_MAX_POOLING_BACKWARD;
 		}
 		else if(OptimizerUtils.ALLOW_OPERATOR_FUSION && op == ConvOp.BIAS_ADD && isInputConv2d(inputs.get(0))) {
 			lopOp = OperationTypes.DIRECT_CONV2D_BIAS_ADD;
 			
 			// the first lop is image 
-			in = inputs.get(0).getInput().get(0).constructLops();
+			lhsInputLop = inputs.get(0).getInput().get(0).constructLops();
 			// the second lop is bias
-			in2 = inputs.get(1).constructLops();
+			optionalRhsInputLop = inputs.get(1).constructLops();
 			
 			// Use the inputs from conv2d rather than bias_add
-			inputs1 = inputs.get(0).getInput();
+			inputsOfPotentiallyFusedOp = inputs.get(0).getInput();
 		}
 		else {
-			in = inputs.get(0).constructLops();
+			lhsInputLop = inputs.get(0).constructLops();
 		}
+		// ---------------------------------------------------------------
 		
-//		// TODO: Inserting reblock requires knowing columns apriori
-//		ConvolutionTransform transform1 = new ConvolutionTransform(addReblockIfNecessary(et, lopOp, in), lopOp, getDataType(), getValueType(), et, k);
-//		setReblockedOutputDimension(et, transform1);
-		double cpIntermediateMemEstimate = computeIntermediateMemEstimate(-1, -1, -1 );
+		// ---------------------------------------------------------------
+		// Compute intermediate memory budget that can be passed to GPU operators 
+		// for better CuDNN operator selection at runtime
+		double intermediateMemEstimate = computeIntermediateMemEstimate(-1, -1, -1 );
 		if(et == ExecType.GPU && _dim1 > 0 && _dim2 > 0) {
 			// This enables us to compile more efficient matrix-matrix CuDNN operation instead of 
 			// row-by-row invocation of multiple vector-matrix CuDNN operations.
 			// This is possible as the operations on GPU are single-threaded
 			double optimisticIntermediateMemEstimate = GPUContextPool.initialGPUMemBudget() - getOutputMemEstimate() - inputs.get(0).getOutputMemEstimate();
-			if(in2 != null) {
+			if(optionalRhsInputLop != null) {
 				optimisticIntermediateMemEstimate -= inputs.get(1).getOutputMemEstimate();
 			}
-			cpIntermediateMemEstimate = Math.max(cpIntermediateMemEstimate, optimisticIntermediateMemEstimate);
+			intermediateMemEstimate = Math.max(intermediateMemEstimate, optimisticIntermediateMemEstimate);
 		}
-		ConvolutionTransform transform1 = new ConvolutionTransform(in, lopOp, getDataType(), getValueType(), et, k, cpIntermediateMemEstimate);
-		setOutputDimensions(transform1);
+		// ---------------------------------------------------------------
 		
-		setLineNumbers(transform1);
-		in.addOutput(transform1);
+		// Contruct the lop
+		ConvolutionTransform convolutionLop = new ConvolutionTransform(lhsInputLop, lopOp, 
+				getDataType(), getValueType(), et, OptimizerUtils.getConstrainedNumThreads(_maxNumThreads), intermediateMemEstimate);
 		
-		if(in2 != null) {
-			transform1.addInput(in2);
-			in2.addOutput(transform1);
-		}
+		// Propagate the output dimensions and the line number of ConvolutionOp to ConvolutionTransform
+		setOutputDimensions(convolutionLop);
+		setLineNumbers(convolutionLop);
 		
-		// stride1, stride2, padding1, padding2  
-		// input_shape1, input_shape2, input_shape3, input_shape4, 
-		// filter_shape1, filter_shape2, filter_shape3, filter_shape4
-		for( int i=1; i < inputs1.size(); i++ )
-		{
-			Lop ltmp = inputs1.get(i).constructLops();
-			transform1.addInput(ltmp);
-			ltmp.addOutput(transform1);
+		// ---------------------------------------------------------------
+		// Add input/output for parent lops of convolutionLop
+		lhsInputLop.addOutput(convolutionLop);
+		if(optionalRhsInputLop != null) {
+			convolutionLop.addInput(optionalRhsInputLop);
+			optionalRhsInputLop.addOutput(convolutionLop);
+		}
+		for( int i=1; i < inputsOfPotentiallyFusedOp.size(); i++ ) {
+			Lop ltmp = inputsOfPotentiallyFusedOp.get(i).constructLops();
+			convolutionLop.addInput(ltmp);
+			ltmp.addOutput(convolutionLop);
 		}
-		transform1.setLevel(); //force order of added lops
-		return transform1;
+		// Only valid for MAX_POOLING_BACKWARD on GPU
+		Lop optionalMaxPoolOutput = (et == ExecType.GPU) ? getMaxPoolOutputLop() : null; 
+		if(optionalMaxPoolOutput != null) {
+			convolutionLop.addInput(optionalMaxPoolOutput);
+			optionalMaxPoolOutput.addOutput(convolutionLop);
+		}
+		convolutionLop.setLevel(); //force order of added lops
+		// ---------------------------------------------------------------
+		return convolutionLop;
 	}
 
 			
@@ -453,12 +513,10 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 		
 		ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR;
 		
-		if( _etypeForced != null ) 			
-		{
+		if( _etypeForced != null ) {
 			_etype = _etypeForced;
 		}
-		else 
-		{	
+		else {	
 			if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
 				_etype = findExecTypeByMemEstimate();
 			}
@@ -479,8 +537,9 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 		return _etype;
 	}
 	
-	// Caching parameters speed-ups dynamic recompilation time by avoiding unnecessary computeSizeInformation
+	// Parameters recomputed in refreshSizeInformation and passed across many calls of getDim
 	private ConvolutionParameters _cachedParams = new ConvolutionParameters(-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, _maxNumThreads);
+	
 	// stride1, stride2, padding1, padding2  
 	// input_shape1, input_shape2, input_shape3, input_shape4, 
 	// filter_shape1, filter_shape2, filter_shape3, filter_shape4
@@ -494,16 +553,16 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 			imageHeightHop = getInput().get(8);
 			filterHeightHop = getInput().get(12);
 			_cachedParams.setIfUnknown(
-					getInput().get(6),
-					getInput().get(7), 
-					imageHeightHop, 
-					getInput().get(9), 
-					getInput().get(10), 
-					filterHeightHop, 
-					getInput().get(13), 
-					getInput().get(2), 
-					getInput().get(3), 
-					getInput().get(4), 
+					getInput().get(6),  // N
+					getInput().get(7),  // C
+					imageHeightHop,     // H
+					getInput().get(9),  // W
+					getInput().get(10), // K
+					filterHeightHop,    // R
+					getInput().get(13), // S
+					getInput().get(2),  // stride_h
+					getInput().get(3),  // stride_w
+					getInput().get(4),  // pad+h
 					getInput().get(5), _maxNumThreads);
 		}
 		else {

http://git-wip-us.apache.org/repos/asf/systemml/blob/06d5bb07/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
index 354ea63..8565b5a 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
@@ -92,8 +92,7 @@ public class ConvolutionGPUInstruction extends GPUInstruction {
 		
 		if( ( opcode.equalsIgnoreCase("conv2d")
 			 || opcode.equalsIgnoreCase("conv2d_backward_filter")
-			 || opcode.equalsIgnoreCase("conv2d_backward_data")
-			 || opcode.equalsIgnoreCase("maxpooling_backward")) ) {
+			 || opcode.equalsIgnoreCase("conv2d_backward_data")) ) {
 			InstructionUtils.checkNumFields(parts, 16);
 			CPOperand in1 = new CPOperand(parts[1]);
 			CPOperand in2 = new CPOperand(parts[2]);
@@ -119,6 +118,39 @@ public class ConvolutionGPUInstruction extends GPUInstruction {
 			return new ConvolutionGPUInstruction(in1, in2, out, opcode, str, stride,
 					padding, input_shape, filter_shape, Double.parseDouble(parts[16]));
 		}
+		else if( opcode.equalsIgnoreCase("maxpooling_backward") ) {
+			boolean withMaxPoolOut = false;
+			if(parts.length == 18) {
+				withMaxPoolOut = true;
+			}
+			else
+				InstructionUtils.checkNumFields(parts, 16);
+			CPOperand in1 = new CPOperand(parts[1]);
+			CPOperand in2 = new CPOperand(parts[2]);
+			CPOperand in3 = withMaxPoolOut ? new CPOperand(parts[15]) : null;
+			CPOperand out = withMaxPoolOut ? new CPOperand(parts[16]) : new CPOperand(parts[15]);
+			double memBudget = withMaxPoolOut ? Double.parseDouble(parts[17]) : Double.parseDouble(parts[16]);
+		
+			ArrayList<CPOperand> stride = new ArrayList<>();
+			ArrayList<CPOperand> padding = new ArrayList<>();
+			ArrayList<CPOperand> input_shape = new ArrayList<>();
+			ArrayList<CPOperand> filter_shape = new ArrayList<>();
+			stride.add(new CPOperand(parts[3]));
+			stride.add(new CPOperand(parts[4]));
+			padding.add(new CPOperand(parts[5]));
+			padding.add(new CPOperand(parts[6]));
+			input_shape.add(new CPOperand(parts[7]));
+			input_shape.add(new CPOperand(parts[8]));
+			input_shape.add(new CPOperand(parts[9]));
+			input_shape.add(new CPOperand(parts[10]));
+			filter_shape.add(new CPOperand(parts[11]));
+			filter_shape.add(new CPOperand(parts[12]));
+			filter_shape.add(new CPOperand(parts[13]));
+			filter_shape.add(new CPOperand(parts[14]));
+
+			return new ConvolutionGPUInstruction(in1, in2, in3, out, opcode, str, stride,
+					padding, input_shape, filter_shape, memBudget);
+		}
 		else if (opcode.equalsIgnoreCase("conv2d_bias_add")) {
 			InstructionUtils.checkNumFields(parts, 17);
 			CPOperand in1 = new CPOperand(parts[1]);
@@ -324,7 +356,7 @@ public class ConvolutionGPUInstruction extends GPUInstruction {
 		else if (instOpcode.equalsIgnoreCase("maxpooling_backward")) {
 			MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
 			MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName());
-			
+			MatrixObject maxPoolOutput = _input3 != null ? getMatrixInputForGPUInstruction(ec, _input3.getName()) : null;
 			if(dout.getNumRows() != N || dout.getNumColumns() != C*P*Q) 
 				throw new DMLRuntimeException("Incorrect dimensions for dout in maxpooling_backward");
 			if(image.getNumRows() != N || image.getNumColumns() != C*H*W) 
@@ -333,7 +365,7 @@ public class ConvolutionGPUInstruction extends GPUInstruction {
 			
 			MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, C * H * W);
 			
-			LibMatrixCuDNN.maxpoolingBackward(ec.getGPUContext(0), getExtendedOpcode(), image, dout, out, N, C, H, W,
+			LibMatrixCuDNN.maxpoolingBackward(ec.getGPUContext(0), getExtendedOpcode(), image, dout, maxPoolOutput, out, N, C, H, W,
 					K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, _intermediateMemoryBudget);
 		}
 		else {
@@ -346,7 +378,8 @@ public class ConvolutionGPUInstruction extends GPUInstruction {
 		if ( !instOpcode.equalsIgnoreCase("maxpooling") )
 			ec.releaseMatrixInputForGPUInstruction(_input2.getName());
 
-		if (instOpcode.equalsIgnoreCase("conv2d_bias_add"))
+		if (instOpcode.equalsIgnoreCase("conv2d_bias_add") || 
+			(instOpcode.equalsIgnoreCase("maxpooling_backward") && _input3 != null))
 			ec.releaseMatrixInputForGPUInstruction(_input3.getName());
 
 		ec.releaseMatrixOutputForGPUInstruction(_output.getName());

http://git-wip-us.apache.org/repos/asf/systemml/blob/06d5bb07/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
index 7fd766c..e0a6a57 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
@@ -519,6 +519,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 	 * @param instName the invoking instruction's name for record {@link Statistics}.
 	 * @param image image as matrix object
 	 * @param dout			delta matrix, output of previous layer
+	 * @param maxpoolOutput (optional and can be null) output of maxpool forward function
 	 * @param outputBlock output matrix
 	 * @param N				batch size
 	 * @param C				number of channels
@@ -537,12 +538,14 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 	 * @throws DMLRuntimeException if DMLRuntimeException occurs
 	 */
 	public static void maxpoolingBackward(GPUContext gCtx, String instName, MatrixObject image, MatrixObject dout,
-			MatrixObject outputBlock, int N, int C, int H, int W, int K, int R,
+			MatrixObject maxpoolOutput, MatrixObject outputBlock, int N, int C, int H, int W, int K, int R,
 			int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
 			int Q, double intermediateMemoryBudget) throws DMLRuntimeException {
 		long CHW = C*H*W; long CPQ = C*P*Q;  
 		long NCHW = N*CHW; long NCPQ = N*CPQ; 
 
+		final boolean isMaxPoolOutputProvided = maxpoolOutput != null;
+		
 		if(NCHW < maxNumElementsOfCuDNNTensor && NCPQ < maxNumElementsOfCuDNNTensor) {
 			// Filter and output are accounted as dense in the memory estimation for conv2dBackwardData
 			long overhead = isInSparseFormat(gCtx, image) ? OptimizerUtils.estimateSizeExactSparsity(N, CHW, 1.0) : 0;
@@ -551,19 +554,26 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 			if(overhead <= intermediateMemoryBudget) {
 				Pointer x = getDensePointerForCuDNN(gCtx, image, instName);
 				Pointer dy = getDensePointerForCuDNN(gCtx, dout, instName);
-				cudnnMaxpoolingBackward(gCtx, instName, x, dy, dx, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
+				Pointer y = isMaxPoolOutputProvided ? getDensePointerForCuDNN(gCtx, maxpoolOutput, instName) : null;
+				cudnnMaxpoolingBackward(gCtx, instName, x, dy, y, dx, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
 			}
 			else {
 				LibMatrixCuDNNInputRowFetcher imgFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, image);
 				LibMatrixCuDNNInputRowFetcher doutFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, dout);
+				LibMatrixCuDNNInputRowFetcher maxPoolOutFetcher = isMaxPoolOutputProvided ? new LibMatrixCuDNNInputRowFetcher(gCtx, instName, maxpoolOutput) : null;
 				for(int n = 0; n < N; n++) {
-					cudnnMaxpoolingBackward(gCtx, instName, imgFetcher.getNthRow(n), doutFetcher.getNthRow(n), 
+					Pointer x = imgFetcher.getNthRow(n);
+					Pointer dy = doutFetcher.getNthRow(n);
+					Pointer y = isMaxPoolOutputProvided ? maxPoolOutFetcher.getNthRow(n) : null;
+					cudnnMaxpoolingBackward(gCtx, instName, x, dy, y, 
 							dx.withByteOffset(n*CHW*sizeOfDataType), 
 							1, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
 				}
 				// Deallocate temporary array to hold one element of input
 				imgFetcher.close();
 				doutFetcher.close();
+				if(isMaxPoolOutputProvided)
+					maxPoolOutFetcher.close();
 			}
 		}
 		else {
@@ -572,36 +582,33 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 	}
 	
 	private static void cudnnMaxpoolingBackward(GPUContext gCtx, String instName, 
-			Pointer x, Pointer dy, Pointer dx, 
+			Pointer x, Pointer dy, Pointer y, Pointer dx, 
 			int N, int C, int H, int W, int K, int R,
 			int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
 			int Q) throws DMLRuntimeException {
 		if(LOG.isTraceEnabled()) {
 			LOG.trace("GPU : maxpoolingBackward" + ", GPUContext=" + gCtx);
 		}
-		Pointer y = null;
+		
+		boolean isMaxPoolOutputProvided = (y != null);
 
 		try(LibMatrixCuDNNPoolingDescriptors desc = 
 				LibMatrixCuDNNPoolingDescriptors.cudnnMaxpoolingBackwardDescriptors(gCtx, instName, N, C, H, W, K, R, S, 
 						pad_h, pad_w, stride_h, stride_w, P, Q)) {
 			long t1=0, t2=0, t3=0;
-			if (GPUStatistics.DISPLAY_STATISTICS) t1 = System.nanoTime();
-			
-			// Calling PoolForward first, y is one of the inputs for poolBackward
-			// TODO: Remove calling poolForward after necessary changes at language level for poolBackward
-			long numBytes = N*C*P*Q*sizeOfDataType;
-			y = gCtx.allocate(numBytes);
-			
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - t1);
-			
-			if (GPUStatistics.DISPLAY_STATISTICS) t2 = System.nanoTime();
-			int status = cudnnPoolingForward(getCudnnHandle(gCtx), desc.poolingDesc, one(), desc.xDesc, x, zero(), desc.yDesc, y);
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_MAXPOOLING_FORWARD_LIB, System.nanoTime() - t2);
-
-			if(status != jcuda.jcudnn.cudnnStatus.CUDNN_STATUS_SUCCESS) {
-				throw new DMLRuntimeException("Could not executed cudnnPoolingForward before cudnnPoolingBackward: " + jcuda.jcudnn.cudnnStatus.stringFor(status));
+			int status;
+			if(!isMaxPoolOutputProvided) {
+				if (GPUStatistics.DISPLAY_STATISTICS) t1 = System.nanoTime();
+				long numBytes = N*C*P*Q*sizeOfDataType;
+				y = gCtx.allocate(numBytes);
+				if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - t1);
+				if (GPUStatistics.DISPLAY_STATISTICS) t2 = System.nanoTime();
+				status = cudnnPoolingForward(getCudnnHandle(gCtx), desc.poolingDesc, one(), desc.xDesc, x, zero(), desc.yDesc, y);
+				if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_MAXPOOLING_FORWARD_LIB, System.nanoTime() - t2);
+				if(status != jcuda.jcudnn.cudnnStatus.CUDNN_STATUS_SUCCESS) {
+					throw new DMLRuntimeException("Could not executed cudnnPoolingForward before cudnnPoolingBackward: " + jcuda.jcudnn.cudnnStatus.stringFor(status));
+				}
 			}
-
 			if (GPUStatistics.DISPLAY_STATISTICS) t3 = System.nanoTime();
 			status = cudnnPoolingBackward(getCudnnHandle(gCtx), desc.poolingDesc, one(), desc.yDesc, y, desc.dyDesc, dy, desc.xDesc, x, zero(), desc.dxDesc, dx);
 			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_MAXPOOLING_BACKWARD_LIB, System.nanoTime() - t3);
@@ -615,7 +622,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 		finally {
 			long t4=0;
 			if (GPUStatistics.DISPLAY_STATISTICS) t4 = System.nanoTime();
-			if(y != null)
+			if(!isMaxPoolOutputProvided)
 				gCtx.cudaFreeHelper(instName, y);
 			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - t4);
 		}

http://git-wip-us.apache.org/repos/asf/systemml/blob/06d5bb07/src/test/java/org/apache/sysml/test/gpu/NeuralNetworkOpTests.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/NeuralNetworkOpTests.java b/src/test/java/org/apache/sysml/test/gpu/NeuralNetworkOpTests.java
index aba0cae..c57e997 100644
--- a/src/test/java/org/apache/sysml/test/gpu/NeuralNetworkOpTests.java
+++ b/src/test/java/org/apache/sysml/test/gpu/NeuralNetworkOpTests.java
@@ -579,5 +579,87 @@ public class NeuralNetworkOpTests extends GPUTests {
 			}
 		}
 	}
+	
+	
+	@Test
+	@Ignore
+	public void testMaxPoolBackwardWithMaxpoolOut() {
+		String scriptStr = "tmp = max_pool(image, padding=[padH, padW], stride=[strideH, strideW], input_shape=[N,C,H,W], pool_size=[R,S]); print(sum(tmp)); O = max_pool_backward(image, dout, padding=[padH, padW], stride=[strideH, strideW], input_shape=[N,C,H,W], pool_size=[R,S])";
+
+		for (long N : Nlst) {
+			for (long C : Clst) {
+				for (long H : Hlst) {
+					long W = H;
+					for (long R : Rlst) {
+						long S = R;
+						for (long strideH : strideLst) {
+							long strideW = strideH;
+							for (long padH : padLst) {
+								long padW = padH;
+								for (double sparsity : sparsitylst) {
+
+									// pool is smaller than image + padding
+									if (R > (H + padH) || S > (W + padW))
+										continue;
+
+									// Make sure ops fit in GPU memory and within constraints of cudnn
+									long imageSize = N * C * H * W * 8l;
+									if (imageSize > MAX_OP_SIZE)  // image size
+										continue;
+									long poolSize = R * S * 8l;
+									if (poolSize > MAX_OP_SIZE)  // filter size
+										continue;
+
+									int P = (int) ConvolutionUtils.getP(H, R, strideH, padH);
+									int Q = (int) ConvolutionUtils.getQ(W, S, strideW, padW);
+
+									long doutSize = N * C * P * Q * 8l;
+									if (doutSize > MAX_OP_SIZE) // dout/output size
+										continue;
+
+									double imageSizeInMB = imageSize / (1024.0 * 1024.0);
+									double poolSizeInMB = poolSize / (1024.0 * 1024.0);
+									double doutSizeInMB = doutSize / (1024.0 * 1024.0);
+									System.out
+									.format("max_pool_backward, image[%d,%d,%d,%d](%.1fMB), pool[%d,%d](%.1f), dout[%d,%d,%d,%d](%.1fMB), stride[%d,%d], padding[%d,%d]",
+											N, C, H, W, imageSizeInMB, R, S, poolSizeInMB, N, C,
+											P, Q, doutSizeInMB, strideH, strideW, padH, padW);
+
+									Matrix image = generateInputMatrix(spark, (int) N,
+											(int) (C * H * W), -127.0, 127, sparsity, seed, true);
+									Matrix dout = generateInputMatrix(spark, (int) N, (int) (C * P * Q),
+											-127.0, 127, sparsity, seed, true);
+									HashMap<String, Object> inputs = new HashMap<>();
+									inputs.put("N", N);
+									inputs.put("C", C);
+									inputs.put("H", H);
+									inputs.put("W", W);
+									inputs.put("R", R);
+									inputs.put("S", S);
+									inputs.put("strideH", strideH);
+									inputs.put("strideW", strideW);
+									inputs.put("padH", padH);
+									inputs.put("padW", padW);
+									inputs.put("image", image);
+									inputs.put("dout", dout);
+									List<Object> outCPU = runOnCPU(spark, scriptStr, inputs,
+											Arrays.asList("O"));
+									List<Object> outGPU = runOnGPU(spark, scriptStr, inputs,
+											Arrays.asList("O"));
+									assertHeavyHitterPresent("gpu_maxpooling_backward");
+									assertEqualObjects(outCPU.get(0), outGPU.get(0));
+									clearGPUMemory();
+								}
+							}
+						}
+					}
+
+
+
+
+				}
+			}
+		}
+	}
 
 }