You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2017/02/24 23:23:24 UTC

incubator-systemml git commit: [SYSTEMML-1343] Spark Convolution and Pooling forward instruction (Map-side only)

Repository: incubator-systemml
Updated Branches:
  refs/heads/master 2f7fa8d73 -> 9c31c2ca7


[SYSTEMML-1343] Spark Convolution and Pooling forward instruction (Map-side only)

This commit supports the instructions necessary for a distributed forward
pass (for example: prediction). Since this is a first cut, we are only
supporting map-side-only operators, which have the following constraints:
- Weights need to be smaller than 2G (i.e., fit into Spark's broadcast
budget).
- The first convolution pays the penalty of a reblock if used for prediction
over the entire dataset. If used in a loop, then the first convolution of
every iteration will pay the penalty.

In a subsequent commit, we will introduce more general convolution/pooling
operators that will support distributed training as well.

Closes #402.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/9c31c2ca
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/9c31c2ca
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/9c31c2ca

Branch: refs/heads/master
Commit: 9c31c2ca7624c61357894ec463223c507165d1d8
Parents: 2f7fa8d
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Fri Feb 24 15:17:17 2017 -0800
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Fri Feb 24 15:20:57 2017 -0800

----------------------------------------------------------------------
 .../org/apache/sysml/hops/ConvolutionOp.java    | 116 ++++--
 .../instructions/SPInstructionParser.java       |  10 +
 .../spark/ConvolutionSPInstruction.java         | 405 +++++++++++++++++++
 .../instructions/spark/SPInstruction.java       |   1 +
 .../matrix/data/ConvolutionParameters.java      |   5 +-
 .../sysml/runtime/matrix/data/LibMatrixDNN.java |   7 +-
 .../functions/tensor/Conv2DTest.java            |  89 +++-
 .../integration/functions/tensor/PoolTest.java  |  31 ++
 8 files changed, 618 insertions(+), 46 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c31c2ca/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
index 9f67968..7751999 100644
--- a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
+++ b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
@@ -21,7 +21,6 @@ package org.apache.sysml.hops;
 
 import java.util.ArrayList;
 
-import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.hops.Hop.MultiThreadedHop;
 import org.apache.sysml.lops.ConvolutionTransform;
@@ -29,6 +28,7 @@ import org.apache.sysml.lops.ConvolutionTransform.OperationTypes;
 import org.apache.sysml.lops.Lop;
 import org.apache.sysml.lops.LopsException;
 import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.lops.ReBlock;
 import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.parser.Expression.ValueType;
 import org.apache.sysml.runtime.DMLRuntimeException;
@@ -70,6 +70,10 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 		return "" + HopsConv2Lops.get(op);
 	}
 
+	private boolean isEligibleForSpark() {
+		return (op == ConvOp.DIRECT_CONV2D || op == ConvOp.MAX_POOLING) ? true : false;
+	}
+	
 	@Override
 	public Lop constructLops()
 		throws HopsException, LopsException 
@@ -90,19 +94,11 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 			case DIRECT_CONV2D_BACKWARD_FILTER:
 			case BIAS_ADD:
 			{	
-				//TODO: Fix me. Currently forcing the instruction to GPU if gpu flag is set
-				if(DMLScript.USE_ACCELERATOR) {
-					et = ExecType.GPU;
+				if(et == ExecType.CP || et == ExecType.GPU || et == ExecType.SPARK) {
 					setLops(constructConvolutionLops(et, inputs));
 					break;
 				}
-				else if(et == ExecType.CP) {
-					setLops(constructConvolutionLops(et, inputs));
-					break;
-				}			
 				else {
-					// TODO: Add support for SPARK/MR backends once we are happy with the performance of
-					// single node Lenet script. 
 					throw new HopsException("Unimplemented ConvolutionOp for execution type: " + et.name());
 				}
 				// break;
@@ -120,34 +116,74 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 	public void setOp(ConvOp op) {
 		this.op = op;
 	}
-
-	public Lop constructConvolutionLops(ExecType et, ArrayList<Hop> inputs) throws HopsException, LopsException {
-		int expectedNumInputs = 13;
-		if(op == ConvOp.MAX_POOLING_BACKWARD 
-				|| op == ConvOp.DIRECT_CONV2D 
-				|| op == ConvOp.DIRECT_CONV2D_BACKWARD_FILTER
-				|| op == ConvOp.DIRECT_CONV2D_BACKWARD_DATA) {
-			expectedNumInputs = 14;
+	
+	private int getNumExpectedInputs() {
+		switch(op) {
+			case MAX_POOLING_BACKWARD: 
+			case DIRECT_CONV2D:
+			case DIRECT_CONV2D_BACKWARD_FILTER:
+			case DIRECT_CONV2D_BACKWARD_DATA:
+				return 14;
+			case BIAS_ADD:
+				return 2;
+			default:
+				return 13;
 		}
-		else if(op == ConvOp.BIAS_ADD) {
-			expectedNumInputs = 2;
+	}
+	
+	private boolean isInputReLU(Hop input) {
+		return input instanceof UnaryOp && ((UnaryOp) input).getOp() == OpOp1.SELP;
+	}
+	
+	private boolean isInputConv2d(Hop input) {
+		return input instanceof ConvolutionOp && ((ConvolutionOp) input).getOp() == ConvOp.DIRECT_CONV2D;
+	}
+	
+	@SuppressWarnings("unused")
+	private Lop addReblockIfNecessary(ExecType et, OperationTypes lopOp, Lop in) throws LopsException {
+		if(et == ExecType.SPARK) {
+			switch(lopOp) {
+				case MAX_POOLING:
+				case RELU_MAX_POOLING:
+				case DIRECT_CONV2D:
+				case DIRECT_CONV2D_BIAS_ADD:
+					if(in.getOutputParameters().getColsInBlock() < in.getOutputParameters().getNumCols() || 
+						in.getOutputParameters().getRowsInBlock() != 1) {
+						// Need to add a reblock
+						return new ReBlock(in, 1L, in.getOutputParameters().getNumCols(), DataType.MATRIX, ValueType.DOUBLE, true, et);
+					}
+					else 
+						return in;
+				default:
+					throw new LopsException("Spark operator is not implemented for " + lopOp.name());
+			}
 		}
-		
-		if(inputs.size() != expectedNumInputs) {
-			throw new HopsException("Incorrect number of inputs for " + op.name());
+		return in;
+	}
+	
+	@SuppressWarnings("unused")
+	private void setReblockedOutputDimension(ExecType et, Lop lop) throws HopsException {
+		if(et == ExecType.SPARK) {
+			lop.getOutputParameters().setDimensions(getDim1(), getDim2(), 1L, getDim2(), getNnz(), getUpdateType());
 		}
+		else {
+			setOutputDimensions(lop);
+		}
+	}
+	
+	public Lop constructConvolutionLops(ExecType et, ArrayList<Hop> inputs) throws HopsException, LopsException {
+		if(inputs.size() != getNumExpectedInputs()) 
+			throw new HopsException("Incorrect number of inputs for " + op.name());
 		
 		Lop in = null; Lop in2 = null;
-		OperationTypes lopOp = HopsConv2Lops.get(op);
-		int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
 		ArrayList<Hop> inputs1 = inputs;
-		if(op == ConvOp.MAX_POOLING && et == ExecType.CP && inputs.get(0) instanceof UnaryOp
-				&& ((UnaryOp) inputs.get(0)).getOp() == OpOp1.SELP) {
+		int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
+		OperationTypes lopOp = HopsConv2Lops.get(op);
+		if(op == ConvOp.MAX_POOLING && (et == ExecType.CP || et == ExecType.SPARK) && isInputReLU(inputs.get(0))) {
 			in = inputs.get(0).getInput().get(0).constructLops();
 			lopOp = OperationTypes.RELU_MAX_POOLING;
 		}
-		else if(op == ConvOp.BIAS_ADD && et == ExecType.CP && inputs.get(0) instanceof ConvolutionOp
-				&& ((ConvolutionOp) inputs.get(0)).getOp() == ConvOp.DIRECT_CONV2D) {
+		else if(op == ConvOp.BIAS_ADD && (et == ExecType.CP || et == ExecType.SPARK) && isInputConv2d(inputs.get(0))) {
 			lopOp = OperationTypes.DIRECT_CONV2D_BIAS_ADD;
 			
 			// the first lop is image 
@@ -161,8 +197,13 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 		else {
 			in = inputs.get(0).constructLops();
 		}
-		ConvolutionTransform transform1 = new ConvolutionTransform( in, lopOp, getDataType(), getValueType(), et, k);
+		
+//		// TODO: Inserting reblock requires knowing columns apriori
+//		ConvolutionTransform transform1 = new ConvolutionTransform(addReblockIfNecessary(et, lopOp, in), lopOp, getDataType(), getValueType(), et, k);
+//		setReblockedOutputDimension(et, transform1);
+		ConvolutionTransform transform1 = new ConvolutionTransform(in, lopOp, getDataType(), getValueType(), et, k);
 		setOutputDimensions(transform1);
+		
 		setLineNumbers(transform1);
 		in.addOutput(transform1);
 		
@@ -290,11 +331,6 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 		
 		checkAndSetForcedPlatform();
 		
-		//TODO: Remove this once memEstimate is fixed for these instructions 
-		if((op == ConvOp.MAX_POOLING || op == ConvOp.MAX_POOLING_BACKWARD) && DMLScript.USE_ACCELERATOR) {
-			return ExecType.GPU;
-		}
-	
 		ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR;
 		
 		if( _etypeForced != null ) 			
@@ -303,12 +339,12 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 		}
 		else 
 		{	
-			// TODO: After adding Spark backend, uncomment this
 			if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
-				_etype = findExecTypeByMemEstimate();
+				_etype = findGPUExecTypeByMemEstimate(findExecTypeByMemEstimate());
+				// TODO: Fix this after adding remaining spark instructions
+				_etype = !isEligibleForSpark() && _etype == REMOTE ?  ExecType.CP : _etype;
 			}
-			else 
-			{
+			else {
 				_etype = REMOTE;
 			}
 			
@@ -320,8 +356,6 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 		if( ConfigurationManager.isDynamicRecompilation() && !dimsKnown(true) && _etype==REMOTE )
 			setRequiresRecompile();
 		
-		_etype = ExecType.CP;
-	
 		return _etype;
 	}
 	

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c31c2ca/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
index fa05de3..6658a88 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
@@ -50,6 +50,7 @@ import org.apache.sysml.runtime.instructions.spark.CastSPInstruction;
 import org.apache.sysml.runtime.instructions.spark.CentralMomentSPInstruction;
 import org.apache.sysml.runtime.instructions.spark.CheckpointSPInstruction;
 import org.apache.sysml.runtime.instructions.spark.CompressionSPInstruction;
+import org.apache.sysml.runtime.instructions.spark.ConvolutionSPInstruction;
 import org.apache.sysml.runtime.instructions.spark.CovarianceSPInstruction;
 import org.apache.sysml.runtime.instructions.spark.CpmmSPInstruction;
 import org.apache.sysml.runtime.instructions.spark.CumulativeAggregateSPInstruction;
@@ -132,6 +133,12 @@ public class SPInstructionParser extends InstructionParser
 		//ternary aggregate operators
 		String2SPInstructionType.put( "tak+*"      , SPINSTRUCTION_TYPE.AggregateTernary);
 		String2SPInstructionType.put( "tack+*"     , SPINSTRUCTION_TYPE.AggregateTernary);
+
+		// Neural network operators
+		String2SPInstructionType.put( "conv2d",                 SPINSTRUCTION_TYPE.Convolution);
+		String2SPInstructionType.put( "conv2d_bias_add", SPINSTRUCTION_TYPE.Convolution);
+		String2SPInstructionType.put( "maxpooling",             SPINSTRUCTION_TYPE.Convolution);
+		String2SPInstructionType.put( "relu_maxpooling",          SPINSTRUCTION_TYPE.Convolution);
 		
 		String2SPInstructionType.put( "rangeReIndex"   	, SPINSTRUCTION_TYPE.MatrixIndexing);
 		String2SPInstructionType.put( "leftIndex"   	, SPINSTRUCTION_TYPE.MatrixIndexing);
@@ -331,6 +338,9 @@ public class SPInstructionParser extends InstructionParser
 			case AggregateTernary:
 				return AggregateTernarySPInstruction.parseInstruction(str);
 				
+			case Convolution:
+				 return ConvolutionSPInstruction.parseInstruction(str);
+
 			case MatrixIndexing:
 				return IndexingSPInstruction.parseInstruction(str);
 				

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c31c2ca/src/main/java/org/apache/sysml/runtime/instructions/spark/ConvolutionSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/ConvolutionSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/ConvolutionSPInstruction.java
new file mode 100644
index 0000000..b485233
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/ConvolutionSPInstruction.java
@@ -0,0 +1,405 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.runtime.instructions.spark;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.function.PairFlatMapFunction;
+import org.apache.spark.broadcast.Broadcast;
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.parser.Expression.ValueType;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysml.runtime.functionobjects.SwapIndex;
+import org.apache.sysml.runtime.instructions.InstructionUtils;
+import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.instructions.spark.data.LazyIterableIterator;
+import org.apache.sysml.runtime.instructions.spark.functions.ExtractBlockForBinaryReblock;
+import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.MatrixFormatMetaData;
+import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
+import org.apache.sysml.runtime.matrix.data.InputInfo;
+import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysml.runtime.matrix.data.OutputInfo;
+import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
+import org.apache.sysml.runtime.util.ConvolutionUtils;
+
+import scala.Tuple2;
+
+public class ConvolutionSPInstruction extends UnarySPInstruction {
+	private CPOperand _in2;
+	private CPOperand _in3; 
+	private ArrayList<CPOperand> _input_shape;
+	private ArrayList<CPOperand> _filter_shape;
+	private ArrayList<CPOperand> _stride = new ArrayList<CPOperand>();
+	private ArrayList<CPOperand> _padding = new ArrayList<CPOperand>();
+
+	public ConvolutionSPInstruction(CPOperand in, CPOperand out, String opcode,
+			String istr, ArrayList<CPOperand> stride,
+			ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape,
+			ArrayList<CPOperand> filter_shape) {
+		super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out,
+				opcode, istr);
+		_sptype = SPINSTRUCTION_TYPE.Convolution;
+		_stride = stride;
+		_padding = padding;
+		_input_shape = input_shape;
+		_filter_shape = filter_shape;
+	}
+
+	public ConvolutionSPInstruction(CPOperand in, CPOperand in2, CPOperand out,
+			String opcode, String istr, ArrayList<CPOperand> stride,
+			ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape,
+			ArrayList<CPOperand> filter_shape) {
+		super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out,
+				opcode, istr);
+		_in2 = in2;
+		_sptype = SPINSTRUCTION_TYPE.Convolution;
+		_stride = stride;
+		_padding = padding;
+		_input_shape = input_shape;
+		_filter_shape = filter_shape;
+	}
+
+	public ConvolutionSPInstruction(CPOperand in, CPOperand in2, CPOperand in3,
+			CPOperand out, String opcode, String istr,
+			ArrayList<CPOperand> stride, ArrayList<CPOperand> padding,
+			ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape) {
+		super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out,
+				opcode, istr);
+		_in2 = in2;
+		_in3 = in3;
+		_sptype = SPINSTRUCTION_TYPE.Convolution;
+		_stride = stride;
+		_padding = padding;
+		_input_shape = input_shape;
+		_filter_shape = filter_shape;
+	}
+
+	public ConvolutionSPInstruction(CPOperand in, CPOperand in2, CPOperand out,
+			String opcode, String istr) {
+		super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out,
+				opcode, istr);
+		_in2 = in2;
+		_sptype = SPINSTRUCTION_TYPE.Convolution;
+	}
+
+	public static ConvolutionSPInstruction parseInstruction( String str ) throws DMLRuntimeException {
+		CPOperand in = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
+		CPOperand out = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
+
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode = parts[0];
+		if (opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("relu_maxpooling")) {
+			InstructionUtils.checkNumFields(parts, 14);
+			// stride1, stride2, padding1, padding2
+			// input_shape1, input_shape2, input_shape3, input_shape4,
+			// filter_shape1, filter_shape2, filter_shape3, filter_shape4, k
+			in.split(parts[1]);
+			out.split(parts[14]);
+
+			ArrayList<CPOperand> stride = new ArrayList<CPOperand>();
+			ArrayList<CPOperand> padding = new ArrayList<CPOperand>();
+			ArrayList<CPOperand> input_shape = new ArrayList<CPOperand>();
+			ArrayList<CPOperand> filter_shape = new ArrayList<CPOperand>();
+			stride.add(new CPOperand(parts[2]));
+			stride.add(new CPOperand(parts[3]));
+			padding.add(new CPOperand(parts[4]));
+			padding.add(new CPOperand(parts[5]));
+			input_shape.add(new CPOperand(parts[6]));
+			input_shape.add(new CPOperand(parts[7]));
+			input_shape.add(new CPOperand(parts[8]));
+			input_shape.add(new CPOperand(parts[9]));
+			filter_shape.add(new CPOperand(parts[10]));
+			filter_shape.add(new CPOperand(parts[11]));
+			filter_shape.add(new CPOperand(parts[12]));
+			filter_shape.add(new CPOperand(parts[13]));
+
+			return new ConvolutionSPInstruction(in, out, opcode, str, stride,
+					padding, input_shape, filter_shape);
+		} 
+		else if (opcode.equalsIgnoreCase("maxpooling_backward")
+				|| opcode.equalsIgnoreCase("conv2d")
+				|| opcode.equalsIgnoreCase("conv2d_backward_filter")
+				|| opcode.equalsIgnoreCase("conv2d_backward_data")) {
+			InstructionUtils.checkNumFields(parts, 15);
+			// dout, stride1, stride2, padding1, padding2
+			// input_shape1, input_shape2, input_shape3, input_shape4,
+			// filter_shape1, filter_shape2, filter_shape3, filter_shape4, k
+			in.split(parts[1]);
+			CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
+			in2.split(parts[2]);
+			out.split(parts[15]);
+
+			ArrayList<CPOperand> stride = new ArrayList<CPOperand>();
+			ArrayList<CPOperand> padding = new ArrayList<CPOperand>();
+			ArrayList<CPOperand> input_shape = new ArrayList<CPOperand>();
+			ArrayList<CPOperand> filter_shape = new ArrayList<CPOperand>();
+			stride.add(new CPOperand(parts[3]));
+			stride.add(new CPOperand(parts[4]));
+			padding.add(new CPOperand(parts[5]));
+			padding.add(new CPOperand(parts[6]));
+			input_shape.add(new CPOperand(parts[7]));
+			input_shape.add(new CPOperand(parts[8]));
+			input_shape.add(new CPOperand(parts[9]));
+			input_shape.add(new CPOperand(parts[10]));
+			filter_shape.add(new CPOperand(parts[11]));
+			filter_shape.add(new CPOperand(parts[12]));
+			filter_shape.add(new CPOperand(parts[13]));
+			filter_shape.add(new CPOperand(parts[14]));
+
+			return new ConvolutionSPInstruction(in, in2, out, opcode, str, stride,
+					padding, input_shape, filter_shape);
+		}
+		else if (opcode.equalsIgnoreCase("conv2d_bias_add")) {
+			InstructionUtils.checkNumFields(parts, 16);
+			// dout, stride1, stride2, padding1, padding2
+			// input_shape1, input_shape2, input_shape3, input_shape4,
+			// filter_shape1, filter_shape2, filter_shape3, filter_shape4, k
+			in.split(parts[1]);
+			CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
+			in2.split(parts[2]);
+			CPOperand in3 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
+			in3.split(parts[3]);
+			out.split(parts[16]);
+
+			ArrayList<CPOperand> stride = new ArrayList<CPOperand>();
+			ArrayList<CPOperand> padding = new ArrayList<CPOperand>();
+			ArrayList<CPOperand> input_shape = new ArrayList<CPOperand>();
+			ArrayList<CPOperand> filter_shape = new ArrayList<CPOperand>();
+			stride.add(new CPOperand(parts[4]));
+			stride.add(new CPOperand(parts[5]));
+			padding.add(new CPOperand(parts[6]));
+			padding.add(new CPOperand(parts[7]));
+			input_shape.add(new CPOperand(parts[8]));
+			input_shape.add(new CPOperand(parts[9]));
+			input_shape.add(new CPOperand(parts[10]));
+			input_shape.add(new CPOperand(parts[11]));
+			filter_shape.add(new CPOperand(parts[12]));
+			filter_shape.add(new CPOperand(parts[13]));
+			filter_shape.add(new CPOperand(parts[14]));
+			filter_shape.add(new CPOperand(parts[15]));
+
+			return new ConvolutionSPInstruction(in, in2, in3, out, opcode, str, stride,
+					padding, input_shape, filter_shape);
+		}
+		else if (opcode.equalsIgnoreCase("bias_add")) {
+			InstructionUtils.checkNumFields(parts, 3);
+			in.split(parts[1]);
+			CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
+			in2.split(parts[2]);
+			out.split(parts[3]);
+			return new ConvolutionSPInstruction(in, in2, out, opcode, str);
+		}
+		else {
+			throw new DMLRuntimeException("Unknown opcode while parsing a ConvolutionCPInstruction: " + str);
+		}
+	}
+	
+	private JavaPairRDD<MatrixIndexes,MatrixBlock> reblockAsRectangularMatrices(SparkExecutionContext sec, String name, int numRowsPerBlock) throws DMLRuntimeException {
+		JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable( name );
+		MatrixCharacteristics mcRdd = sec.getMatrixCharacteristics(name);
+		if(mcRdd.getColsPerBlock() < mcRdd.getCols() || mcRdd.getRowsPerBlock() != 1) {
+			MatrixCharacteristics mcOut = new MatrixCharacteristics(mcRdd);
+			mcOut.setColsPerBlock((int)mcRdd.getCols());
+			mcOut.setRowsPerBlock(numRowsPerBlock); 
+			in1 = RDDAggregateUtils.mergeByKey(in1.flatMapToPair(new ExtractBlockForBinaryReblock(mcRdd, mcOut)));
+			// TODO: Inject checkpoint to avoid doing this repeated for validation set
+//			sec.setRDDHandleForVariable(name, in1);
+//			sec.setMetaData(name, new MatrixDimensionsMetaData(mcOut));
+		}
+		return in1;
+	}
+	
+	private Broadcast<MatrixBlock> getBroadcast(SparkExecutionContext sec, String name) throws DMLRuntimeException {
+		MatrixBlock mb = sec.getMatrixInput( name );
+		sec.releaseMatrixInput(name);
+		return sec.getSparkContext().broadcast(mb);
+	}
+
+	@Override
+	public void processInstruction(ExecutionContext ec)
+			throws DMLRuntimeException {
+		SparkExecutionContext sec = (SparkExecutionContext)ec;
+		if(instOpcode.equalsIgnoreCase("conv2d") || instOpcode.equalsIgnoreCase("conv2d_bias_add")
+			|| instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling")) {
+			String rddVar = input1.getName();
+			int numRowsPerBlock = 1;
+			JavaPairRDD<MatrixIndexes,MatrixBlock> inputRDD = reblockAsRectangularMatrices(sec, rddVar, numRowsPerBlock);
+			MatrixCharacteristics mcRdd = sec.getMatrixCharacteristics(rddVar);
+			
+			// ------------------------------------
+			// TODO: Handle large filters > 2G
+			Broadcast<MatrixBlock> filterBroadcast = null;
+			Broadcast<MatrixBlock> biasBroadcast = null;
+			if(instOpcode.equalsIgnoreCase("conv2d")) {
+				filterBroadcast = getBroadcast(sec, _in2.getName());
+			}
+			else if(instOpcode.equalsIgnoreCase("conv2d_bias_add")) {
+				filterBroadcast = getBroadcast(sec, _in3.getName());
+				biasBroadcast = getBroadcast(sec, _in2.getName());
+			}
+			// ------------------------------------
+			
+			int pad_h = getScalarInput(ec, _padding, 0);
+			int pad_w = getScalarInput(ec, _padding, 1);
+			int stride_h = getScalarInput(ec, _stride, 0);
+			int stride_w = getScalarInput(ec, _stride, 1);
+
+			// int N = getScalarInput(ec, _input_shape, 0);
+			int C = getScalarInput(ec, _input_shape, 1);
+			int H = getScalarInput(ec, _input_shape, 2);
+			int W = getScalarInput(ec, _input_shape, 3);
+
+			int K = getScalarInput(ec, _filter_shape, 0);
+			int R = getScalarInput(ec, _filter_shape, 2);
+			int S = getScalarInput(ec, _filter_shape, 3);
+			int P = (int) ConvolutionUtils.getP(H, R, stride_h, pad_h);
+			int Q = (int) ConvolutionUtils.getQ(W, S, stride_w, pad_w);
+			
+			ConvolutionParameters params = new ConvolutionParameters(numRowsPerBlock, C, H, W, K, R, S, stride_h, stride_w, pad_h, pad_w, 1);
+			JavaPairRDD<MatrixIndexes,MatrixBlock> out = inputRDD.mapPartitionsToPair(new RDDConv2dMapMMFunction(filterBroadcast, params, instOpcode, biasBroadcast, mcRdd.getRows()), true);
+			
+			//put output RDD handle into symbol table
+			sec.setRDDHandleForVariable(output.getName(), out);
+			sec.addLineageRDD(output.getName(), rddVar);
+			
+			long nnz = -1; // TODO: Handle nnz
+			long numCols = ((long)K)*((long)P)*((long)Q);
+			if(instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling")) {
+				numCols = ((long)C)*((long)P)*((long)Q);
+			}
+			if(numCols > Integer.MAX_VALUE) {
+				throw new DMLRuntimeException("The current operator doesnot support large outputs.");
+			}
+			sec.setMetaData(output.getName(), 
+					new MatrixFormatMetaData(new MatrixCharacteristics(mcRdd.getRows(), numCols, numRowsPerBlock, (int)numCols, nnz), OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
+		}
+		else {
+			throw new DMLRuntimeException("Not implemented: " + instOpcode);
+		}
+	}
+
+	private int getScalarInput(ExecutionContext ec, ArrayList<CPOperand> aL,
+			int index) throws DMLRuntimeException {
+		return (int) ec.getScalarInput(aL.get(index).getName(),
+				aL.get(index).getValueType(), aL.get(index).isLiteral())
+				.getLongValue();
+	}
+	
+	private static class RDDConv2dMapMMFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> {
+	// PairFunction<Tuple2<MatrixIndexes,MatrixBlock>, MatrixIndexes, MatrixBlock> {
+		private static final long serialVersionUID = -2106155380020232155L;
+		Broadcast<MatrixBlock> filterBroadcast = null;
+		Broadcast<MatrixBlock> biasBroadcast = null;
+		ConvolutionParameters params = null;
+		String instOpcode = null;
+		long numRows = 0;
+		public RDDConv2dMapMMFunction(Broadcast<MatrixBlock> filterBroadcast, 
+				ConvolutionParameters params, String instOpcode, Broadcast<MatrixBlock> biasBroadcast, long numRows) {
+			this.filterBroadcast = filterBroadcast;
+			this.params = params;
+			this.instOpcode = instOpcode;
+			this.biasBroadcast = biasBroadcast;
+			this.numRows = numRows;
+		}
+		
+		private MatrixBlock processRectangularBlock(MatrixBlock matBlock) throws Exception {
+			MatrixBlock outputBlock = null;
+			if(instOpcode.equalsIgnoreCase("conv2d")) {
+				MatrixBlock filter = filterBroadcast.getValue();
+				if(filter.isEmptyBlock() || matBlock.isEmptyBlock()) {
+					outputBlock = new MatrixBlock(params.N, params.K*params.P*params.Q, true);
+				}
+				else {
+					outputBlock = getDenseOutputBlock(params.N, params.K*params.P*params.Q);
+					LibMatrixDNN.conv2d(matBlock, filter, outputBlock, params);
+				}
+			}
+			else if (instOpcode.equalsIgnoreCase("conv2d_bias_add")) {
+				MatrixBlock filter = filterBroadcast.getValue();
+				MatrixBlock bias = biasBroadcast.getValue();
+				if((filter.isEmptyBlock() || matBlock.isEmptyBlock()) && bias.isEmptyBlock()) {
+					outputBlock = new MatrixBlock(params.N, params.K*params.P*params.Q, true);
+				}
+				else {
+					outputBlock = getDenseOutputBlock(params.N, params.K*params.P*params.Q);
+					if(!bias.isEmptyBlock())
+						params.bias = bias;
+					LibMatrixDNN.conv2d(matBlock, filter, outputBlock, params);
+				}
+			}
+			else if(instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling")) {
+				if(matBlock.isEmptyBlock()) {
+					outputBlock = new MatrixBlock(params.N, params.C*params.P*params.Q, true);
+				}
+				else {
+					outputBlock = getDenseOutputBlock(params.N, params.C*params.P*params.Q);
+					if(instOpcode.equalsIgnoreCase("maxpooling"))
+						Arrays.fill(outputBlock.getDenseBlock(), -Double.MAX_VALUE);
+					LibMatrixDNN.maxpooling(matBlock, outputBlock, params);
+				}
+			}
+			else {
+				throw new RuntimeException("Not implemented");
+			}
+			return outputBlock;
+		}
+		
+		private MatrixBlock getDenseOutputBlock(int numRows, int numCols) throws DMLRuntimeException {
+			MatrixBlock outputBlock = new MatrixBlock(numRows, numCols, false);
+			outputBlock.allocateDenseBlock();
+			return outputBlock;
+		}
+
+		@Override
+		public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(
+				Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg0)
+				throws Exception {
+			return new MapsideConvolutionPartitionIterator(arg0);
+		}
+		
+		// Avoid materialization of partitions
+		private class MapsideConvolutionPartitionIterator extends LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> {
+			public MapsideConvolutionPartitionIterator(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> in) {
+				super(in);
+			}
+
+			@Override
+			protected Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> arg) throws Exception {
+				if(arg._1.getRowIndex() > numRows || arg._1.getColumnIndex() != 1) {
+					throw new RuntimeException("Expected the inputs to be reblocked as rectangular RDD");
+				}
+				MatrixBlock out = processRectangularBlock(arg._2);
+				if(out.getNumRows() != 1) {
+					throw new RuntimeException("Expected the output to have 1 row");
+				}
+				return new Tuple2<MatrixIndexes, MatrixBlock>(arg._1, out);
+			}
+		}
+		
+	}
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c31c2ca/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java
index 8f866af..b28e408 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java
@@ -38,6 +38,7 @@ public abstract class SPInstruction extends Instruction
 		ParameterizedBuiltin, MAppend, RAppend, GAppend, GAlignedAppend, Rand, 
 		MatrixReshape, Ternary, Quaternary, CumsumAggregate, CumsumOffset, BinUaggChain, UaggOuterChain, 
 		Write, INVALID, 
+		Convolution
 	};
 	
 	protected SPINSTRUCTION_TYPE _sptype;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c31c2ca/src/main/java/org/apache/sysml/runtime/matrix/data/ConvolutionParameters.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/ConvolutionParameters.java b/src/main/java/org/apache/sysml/runtime/matrix/data/ConvolutionParameters.java
index 9cd187c..213e564 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/ConvolutionParameters.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/ConvolutionParameters.java
@@ -19,6 +19,8 @@
 
 package org.apache.sysml.runtime.matrix.data;
 
+import java.io.Serializable;
+
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.util.ConvolutionUtils;
 
@@ -26,7 +28,8 @@ import org.apache.sysml.runtime.util.ConvolutionUtils;
  * This class is container that stores parameters required for executing following operations:
  * conv2d, conv2d_backward_data, conv2d_backward_filter, maxpooling, maxpooling_backward 
  */
-public class ConvolutionParameters {
+public class ConvolutionParameters implements Serializable {
+	private static final long serialVersionUID = -212362627205772829L;
 	public int N; public int C; public int H; public int W;
 	public int K; public int R; public int S; public int stride_h; public int stride_w; public int pad_h; public int pad_w;
 	public int P; public int Q; public int numThreads;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c31c2ca/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index 82b0a61..0c0410c 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -325,7 +325,7 @@ public class LibMatrixDNN {
 		
 		if(input.getNumRows() != params.N || input.getNumColumns() != params.C*params.H*params.W || 
 				filter.getNumRows() != params.K || filter.getNumColumns() != params.C*params.R*params.S) {
-			throw new DMLRuntimeException("Incorrect input to conv2d");
+			throw new DMLRuntimeException("Incorrect input to conv2d: " + input.getNumRows());
 		}
 		
 		if(DMLScript.STATISTICS && DISPLAY_STATISTICS) {
@@ -389,8 +389,9 @@ public class LibMatrixDNN {
 		}
 		// -----------------------------------------------------------------------------
 		
+		// Recomputing nnz is not required for each individual im2col, as im2col is invoked by the outer public methods (i.e. conv2d).
 		//post-processing: maintain nnz
-		params.output.recomputeNonZeros(); 
+		// params.output.recomputeNonZeros(); 
 	}
 	
 	/**
@@ -807,7 +808,7 @@ public class LibMatrixDNN {
 		params.output = outputBlock;
 		
 		if(input.getNumColumns() != params.C*params.H*params.W || input.getNumRows() != params.N) {
-			throw new DMLRuntimeException("Incorrect input dimensions in maxpooling:" + input.getNumRows() + " " + input.getNumColumns() + " " + params.N + " " + params.K*params.P*params.Q);
+			throw new DMLRuntimeException("Incorrect input dimensions in maxpooling:" + input.getNumRows() + " " + input.getNumColumns() + " " + params.N + " " + params.C*params.H*params.W);
 		}
 		
 		fillIndexesArray(params);

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c31c2ca/src/test/java/org/apache/sysml/test/integration/functions/tensor/Conv2DTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/tensor/Conv2DTest.java b/src/test/java/org/apache/sysml/test/integration/functions/tensor/Conv2DTest.java
index 9a1a823..81fe154 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/tensor/Conv2DTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/tensor/Conv2DTest.java
@@ -51,6 +51,7 @@ public class Conv2DTest extends AutomatedTestBase
 		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false);
 	}
 	
+	
 	@Test
 	public void testConv2DDense2() 
 	{
@@ -127,6 +128,92 @@ public class Conv2DTest extends AutomatedTestBase
 		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true);
 	}
 	
+	// --------------------------------------------
+	
+
+	/** Conv2D test on ExecType.SPARK, dense case 1 (trailing boolean argument {@code false}). */
+	@Test
+	public void testConv2DDense1SP()
+	{
+		final int numImg = 5, imgSize = 3, numChannels = 3;
+		final int numFilters = 6, filterSize = 2, stride = 1, pad = 0;
+		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false);
+	}
+	
+	/** Conv2D test on ExecType.SPARK, dense case 2 (trailing boolean argument {@code false}). */
+	@Test
+	public void testConv2DDense2SP()
+	{
+		final int numImg = 1, imgSize = 10, numChannels = 4;
+		final int numFilters = 3, filterSize = 4, stride = 2, pad = 0;
+		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false);
+	}
+	
+	/** Conv2D test on ExecType.SPARK, dense case 3 (trailing boolean argument {@code false}). */
+	@Test
+	public void testConv2DDense3SP()
+	{
+		final int numImg = 1, imgSize = 10, numChannels = 4;
+		final int numFilters = 3, filterSize = 4, stride = 2, pad = 1;
+		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false);
+	}
+	
+	/** Conv2D test on ExecType.SPARK, dense case 4 (trailing boolean argument {@code false}). */
+	@Test
+	public void testConv2DDense4SP()
+	{
+		final int numImg = 3, imgSize = 10, numChannels = 1;
+		final int numFilters = 3, filterSize = 2, stride = 2, pad = 1;
+		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false);
+	}
+	
+	/** Conv2D test on ExecType.SPARK, dense case 5 (trailing boolean argument {@code false}). */
+	@Test
+	public void testConv2DDense5SP()
+	{
+		final int numImg = 3, imgSize = 8, numChannels = 2;
+		final int numFilters = 3, filterSize = 3, stride = 1, pad = 2;
+		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false);
+	}
+	
+	/** Conv2D test on ExecType.SPARK, dense case 6 (trailing boolean argument {@code false}). */
+	@Test
+	public void testConv2DDense6SP()
+	{
+		final int numImg = 1, imgSize = 10, numChannels = 4;
+		final int numFilters = 3, filterSize = 4, stride = 1, pad = 0;
+		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false);
+	}
+	
+	/** Conv2D test on ExecType.SPARK, dense case 7 (trailing boolean argument {@code false}). */
+	@Test
+	public void testConv2DDense7SP()
+	{
+		final int numImg = 3, imgSize = 10, numChannels = 1;
+		final int numFilters = 3, filterSize = 2, stride = 1, pad = 0;
+		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false);
+	}
+	
+	/** Conv2D test on ExecType.SPARK, sparse case 1 (trailing boolean argument {@code true}). */
+	@Test
+	public void testConv2DSparse1SP()
+	{
+		final int numImg = 5, imgSize = 3, numChannels = 3;
+		final int numFilters = 6, filterSize = 2, stride = 1, pad = 0;
+		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true);
+	}
+	
+	/** Conv2D test on ExecType.SPARK, sparse case 2 (trailing boolean argument {@code true}). */
+	@Test
+	public void testConv2DSparse2SP()
+	{
+		final int numImg = 1, imgSize = 10, numChannels = 4;
+		final int numFilters = 3, filterSize = 4, stride = 2, pad = 0;
+		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true);
+	}
+	
+	/** Conv2D test on ExecType.SPARK, sparse case 3 (trailing boolean argument {@code true}). */
+	@Test
+	public void testConv2DSparse3SP()
+	{
+		final int numImg = 1, imgSize = 10, numChannels = 4;
+		final int numFilters = 3, filterSize = 4, stride = 2, pad = 1;
+		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true);
+	}
+	
+	/**
+	 * Conv2D test on ExecType.SPARK, sparse case 4 (trailing boolean argument {@code true}).
+	 * Annotated with {@code @Test} like its sibling cases so that JUnit actually executes it.
+	 */
+	@Test
+	public void testConv2DSparse4SP() 
+	{
+		int numImg = 3; int imgSize = 10; int numChannels = 1; int numFilters = 3; int filterSize = 2; int stride = 2; int pad = 1;
+		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true);
+	}
+	
+	/** Conv2D test on ExecType.SPARK, sparse case 5 (trailing boolean argument {@code true}). */
+	@Test
+	public void testConv2DSparse5SP()
+	{
+		final int numImg = 3, imgSize = 8, numChannels = 2;
+		final int numFilters = 3, filterSize = 3, stride = 1, pad = 2;
+		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true);
+	}
+	
 	/**
 	 * 
 	 * @param et
@@ -162,7 +249,7 @@ public class Conv2DTest extends AutomatedTestBase
 				fullDMLScriptName = RI_HOME + TEST_NAME + ".dml";
 				
 				
-				programArgs = new String[]{"-explain", "-args",  "" + imgSize, "" + numImg, 
+				programArgs = new String[]{"-explain", "recompile_runtime", "-args",  "" + imgSize, "" + numImg, 
 					"" + numChannels, "" + numFilters, 
 					"" + filterSize, "" + stride, "" + pad, 
 					output("B")};

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c31c2ca/src/test/java/org/apache/sysml/test/integration/functions/tensor/PoolTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/tensor/PoolTest.java b/src/test/java/org/apache/sysml/test/integration/functions/tensor/PoolTest.java
index 3a10714..c064ca6 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/tensor/PoolTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/tensor/PoolTest.java
@@ -72,6 +72,37 @@ public class PoolTest extends AutomatedTestBase
 		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max");
 	}
 	
+	// ----------------------------------------
+	
+	/** Pooling test on ExecType.SPARK with pool mode "max", dense case 1. */
+	@Test
+	public void testMaxPool2DDense1SP()
+	{
+		final int numImg = 1, imgSize = 50, numChannels = 1;
+		final int stride = 2, pad = 0, poolSize1 = 2, poolSize2 = 2;
+		runPoolTest(ExecType.SPARK, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max");
+	}
+	
+	/** Pooling test on ExecType.SPARK with pool mode "max", dense case 2. */
+	@Test
+	public void testMaxPool2DDense2SP()
+	{
+		final int numImg = 2, imgSize = 6, numChannels = 1;
+		final int stride = 1, pad = 0, poolSize1 = 2, poolSize2 = 2;
+		runPoolTest(ExecType.SPARK, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max");
+	}
+	
+	
+	/** Pooling test on ExecType.SPARK with pool mode "max", dense case 3. */
+	@Test
+	public void testMaxPool2DDense3SP()
+	{
+		final int numImg = 3, imgSize = 7, numChannels = 2;
+		final int stride = 2, pad = 0, poolSize1 = 3, poolSize2 = 3;
+		runPoolTest(ExecType.SPARK, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max");
+	}
+	
+	/** Pooling test on ExecType.SPARK with pool mode "max", dense case 4. */
+	@Test
+	public void testMaxPool2DDense4SP()
+	{
+		final int numImg = 2, imgSize = 4, numChannels = 2;
+		final int stride = 1, pad = 0, poolSize1 = 3, poolSize2 = 3;
+		runPoolTest(ExecType.SPARK, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max");
+	}
+	
 	/**
 	 * 
 	 * @param et