You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2017/04/25 21:47:05 UTC

[1/2] incubator-systemml git commit: [SYSTEMML-687] Optimized LibMatrixDNN for sparse inputs

Repository: incubator-systemml
Updated Branches:
  refs/heads/master 32f075695 -> 2d2196d84


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2d2196d8/src/test/scripts/functions/tensor/Conv2DTest.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/tensor/Conv2DTest.dml b/src/test/scripts/functions/tensor/Conv2DTest.dml
index aec8499..792367f 100644
--- a/src/test/scripts/functions/tensor/Conv2DTest.dml
+++ b/src/test/scripts/functions/tensor/Conv2DTest.dml
@@ -31,6 +31,16 @@ x=matrix(seq(1, numImg*numChannels*imgSize*imgSize), rows=numImg, cols=numChanne
 w=matrix(seq(1, numFilters*numChannels*filterSize*filterSize), rows=numFilters, cols=numChannels*filterSize*filterSize)
 b=matrix(seq(1, numFilters), rows=numFilters, cols=1) 
 
+if($9) {
+	zero_mask = (x - mean(x)) > 0 
+	x = x * zero_mask
+}
+if($10) {
+	zero_mask = (w - mean(w)) > 0 
+	w = w * zero_mask
+}
+x = x - mean(x)
+w = w - mean(w)
 output = conv2d(x, w, padding=[pad, pad], stride=[stride, stride], input_shape=[numImg, numChannels, imgSize, imgSize], filter_shape=[numFilters, numChannels, filterSize, filterSize], bias=b)
 output = bias_add(output, b) 
 write(output, $8, format="text")
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2d2196d8/src/test/scripts/functions/tensor/PoolBackwardTest.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/tensor/PoolBackwardTest.R b/src/test/scripts/functions/tensor/PoolBackwardTest.R
index 8cb8a7c..f3133a7 100644
--- a/src/test/scripts/functions/tensor/PoolBackwardTest.R
+++ b/src/test/scripts/functions/tensor/PoolBackwardTest.R
@@ -34,7 +34,16 @@ Q=as.integer(args[9])
 # Assumption: NCHW image format
 x=matrix(seq(1, numImg*numChannels*imgSize*imgSize), numImg, numChannels*imgSize*imgSize, byrow=TRUE)
 dout=matrix(seq(1, numImg*numChannels*P*Q), numImg, numChannels*P*Q, byrow=TRUE)
-
+if(as.logical(args[11])) {
+	# zero_mask = (x - mean(x)) > 0 
+	# x = x * zero_mask
+}
+if(as.logical(args[12])) {
+	# zero_mask = (dout - mean(dout)) > 0 
+	# dout = dout * zero_mask
+}
+x = x - mean(x)
+dout = dout - mean(dout)
 max_pool_backward <- function(dout, Hout, Wout, X, C,
                     Hin, Win, Hf, Wf, strideh, stridew)
      {

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2d2196d8/src/test/scripts/functions/tensor/PoolBackwardTest.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/tensor/PoolBackwardTest.dml b/src/test/scripts/functions/tensor/PoolBackwardTest.dml
index 0ee80df..22f778f 100644
--- a/src/test/scripts/functions/tensor/PoolBackwardTest.dml
+++ b/src/test/scripts/functions/tensor/PoolBackwardTest.dml
@@ -33,6 +33,16 @@ Q = $10
 # Assumption: NCHW image format
 x=matrix(seq(1, numImg*numChannels*imgSize*imgSize), rows=numImg, cols=numChannels*imgSize*imgSize)
 dout=matrix(seq(1, numImg*numChannels*P*Q), rows=numImg, cols=numChannels*P*Q)
+if($12) {
+	# zero_mask = (x - mean(x)) > 0 
+	# x = x * zero_mask
+}
+if($13) {
+	# zero_mask = (dout - mean(dout)) > 0 
+	# dout = dout * zero_mask
+}
+x = x - mean(x)
+dout = dout - mean(dout)
 if(poolMode == "max") {
 	output = max_pool_backward(x, dout, stride=[stride, stride], padding=[pad, pad], input_shape=[numImg, numChannels, imgSize, imgSize], pool_size=[poolSize1, poolSize2])
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2d2196d8/src/test/scripts/functions/tensor/PoolTest.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/tensor/PoolTest.R b/src/test/scripts/functions/tensor/PoolTest.R
index 3731807..d9c8d0c 100644
--- a/src/test/scripts/functions/tensor/PoolTest.R
+++ b/src/test/scripts/functions/tensor/PoolTest.R
@@ -31,7 +31,14 @@ pad=as.integer(args[7])
 
 # Assumption: NCHW image format
 x=matrix(seq(1, numImg*numChannels*imgSize*imgSize), numImg, numChannels*imgSize*imgSize, byrow=TRUE)
-
+# Sparse variant: zero out entries <= mean(x). Do not mean-center afterwards:
+# that would turn every zero into -mean(x) and make x dense again.
+if(as.logical(args[9])) {
+	zero_mask = (x - mean(x)) > 0 
+	x = x * zero_mask
+} else {
+	x = x - mean(x)
+}
 pad_image <- function(img, Hin, Win, padh, padw){
   C = nrow(img)
   img_padded = matrix(0, C, (Hin+2*padh)*(Win+2*padw))  # zeros

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2d2196d8/src/test/scripts/functions/tensor/PoolTest.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/tensor/PoolTest.dml b/src/test/scripts/functions/tensor/PoolTest.dml
index e163e89..b701e71 100644
--- a/src/test/scripts/functions/tensor/PoolTest.dml
+++ b/src/test/scripts/functions/tensor/PoolTest.dml
@@ -29,6 +29,14 @@ poolMode=$8
 
 # Assumption: NCHW image format
 x=matrix(seq(1, numImg*numChannels*imgSize*imgSize), rows=numImg, cols=numChannels*imgSize*imgSize)
+# Sparse variant: zero out entries <= mean(x). Do not mean-center afterwards:
+# that would turn every zero into -mean(x) and make x dense again.
+if($10) {
+	zero_mask = (x - mean(x)) > 0 
+	x = x * zero_mask
+} else {
+	x = x - mean(x)
+}
 if(poolMode == "max") {
 	output = max_pool(x, stride=[stride, stride], padding=[pad, pad], input_shape=[numImg, numChannels, imgSize, imgSize], pool_size=[poolSize1, poolSize2])
 }


[2/2] incubator-systemml git commit: [SYSTEMML-687] Optimized LibMatrixDNN for sparse inputs

Posted by ni...@apache.org.
[SYSTEMML-687] Optimized LibMatrixDNN for sparse inputs


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/2d2196d8
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/2d2196d8
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/2d2196d8

Branch: refs/heads/master
Commit: 2d2196d84750df8801f1218df2c7160ca8b438cb
Parents: 32f0756
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Tue Apr 25 13:46:53 2017 -0800
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Tue Apr 25 14:46:53 2017 -0700

----------------------------------------------------------------------
 .../matrix/data/ConvolutionParameters.java      |   2 +-
 .../sysml/runtime/matrix/data/LibMatrixDNN.java | 518 +++++++------------
 .../sysml/runtime/util/ConvolutionUtils.java    | 144 ++++++
 .../tensor/Conv2DBackwardDataTest.java          |  87 +++-
 .../functions/tensor/Conv2DBackwardTest.java    | 139 ++++-
 .../functions/tensor/Conv2DTest.java            | 152 +++---
 .../functions/tensor/PoolBackwardTest.java      |  93 +++-
 .../integration/functions/tensor/PoolTest.java  |  67 ++-
 .../functions/tensor/Conv2DBackwardDataTest.R   |  11 +-
 .../functions/tensor/Conv2DBackwardDataTest.dml |  10 +
 .../functions/tensor/Conv2DBackwardTest.R       |  11 +-
 .../functions/tensor/Conv2DBackwardTest.dml     |  10 +
 src/test/scripts/functions/tensor/Conv2DTest.R  |  10 +
 .../scripts/functions/tensor/Conv2DTest.dml     |  10 +
 .../scripts/functions/tensor/PoolBackwardTest.R |  11 +-
 .../functions/tensor/PoolBackwardTest.dml       |  10 +
 src/test/scripts/functions/tensor/PoolTest.R    |   6 +-
 src/test/scripts/functions/tensor/PoolTest.dml  |   5 +
 18 files changed, 812 insertions(+), 484 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2d2196d8/src/main/java/org/apache/sysml/runtime/matrix/data/ConvolutionParameters.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/ConvolutionParameters.java b/src/main/java/org/apache/sysml/runtime/matrix/data/ConvolutionParameters.java
index 213e564..3f0437f 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/ConvolutionParameters.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/ConvolutionParameters.java
@@ -35,7 +35,7 @@ public class ConvolutionParameters implements Serializable {
 	public int P; public int Q; public int numThreads;
 	
 	
-	MatrixBlock input1; MatrixBlock input2; MatrixBlock output;
+	public MatrixBlock input1; public MatrixBlock input2; public MatrixBlock output;
 	
 	public MatrixBlock bias;
 	public int [] start_indexes_h, end_indexes_h, start_indexes_w, end_indexes_w; 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2d2196d8/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index 5ab41e0..8a1a43f 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -20,7 +20,6 @@ package org.apache.sysml.runtime.matrix.data;
 
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Iterator;
 import java.util.List;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ConcurrentLinkedQueue;
@@ -34,6 +33,9 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.instructions.InstructionUtils;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysml.runtime.util.ConvolutionUtils;
 
 /**
  * This class allows users to invoke deep learning related operations 
@@ -124,6 +126,18 @@ public class LibMatrixDNN {
 		loopedConvBwdDataMatMultTime.set(0);
 		loopedConvBwdDataCol2ImTime.set(0);
 	}
+	
+	// Commonly used operators
+	private static BinaryOperator _binaryElementWiseAddition = null;
+	private static BinaryOperator _binaryElementWiseMultiplication = null;
+	static {
+		try {
+			_binaryElementWiseAddition = InstructionUtils.parseBinaryOperator("+");
+			_binaryElementWiseMultiplication = InstructionUtils.parseBinaryOperator("*");
+		} catch (DMLRuntimeException e) {
+			throw new RuntimeException("ERROR initializing LibMatrixDNN", e);
+		}
+	}
 	// ------------------------------------------------------------------------------------------------
 	
 	/**
@@ -199,37 +213,6 @@ public class LibMatrixDNN {
 	}
 	
 	/**
-	 * Performs the operation: ret += elem
-	 * @param ret left and output matrix
-	 * @param elem right matrix
-	 * @throws DMLRuntimeException if DMLRuntimeException occurs
-	 */
-	private static void elementWiseInPlaceAddition(MatrixBlock ret, MatrixBlock elem) throws DMLRuntimeException {
-		if(ret.getNumRows() != elem.getNumRows() || ret.getNumColumns() != elem.getNumColumns()) {
-			throw new DMLRuntimeException("Incorrect dimensions");
-		}
-		if(!ret.isInSparseFormat() && !elem.isInSparseFormat()) {
-			for(int i = 0; i < ret.getNumRows()*ret.getNumColumns(); i++) {
-				ret.denseBlock[i] += elem.denseBlock[i];
-			}
-		}
-		else if(!ret.isInSparseFormat() && elem.isInSparseFormat()) {
-			if(!elem.isEmptyBlock()) {
-				Iterator<IJV> iter = elem.sparseBlock.getIterator();
-				int numCol = ret.getNumColumns();
-				while(iter.hasNext()) {
-					IJV ijv = iter.next();
-					int index = ijv.getI()*numCol + ijv.getJ();
-					ret.denseBlock[index] += ijv.getV(); 
-				}
-			}
-		}
-		else {
-			throw new DMLRuntimeException("Sparse return format not supported");
-		}
-	}
-	
-	/**
 	 * Performs the operation for(e : elem) ret += t(e) in a cache-conscious manner
 	 * by sequentially aggregating for(e : elem) tmp += e and finally transposing
 	 * ret = t(tmp).
@@ -284,9 +267,9 @@ public class LibMatrixDNN {
 	}
 	
 	private static MatrixBlock doLoopedIm2ColConv2dBwdFilter(int n, 
-			MatrixBlock im2ColOutBlock, MatrixBlock dout_reshaped, MatrixBlock partialRetBlock, ConvolutionParameters params) throws DMLRuntimeException {
+			MatrixBlock im2ColOutBlock, MatrixBlock dout_reshaped, MatrixBlock partialRetBlock, ConvolutionParameters params, double []  tempIm2ColArr) throws DMLRuntimeException {
 		long t1 = DMLScript.STATISTICS && DISPLAY_STATISTICS ? System.nanoTime() : 0;
-		doIm2col(n, im2ColOutBlock, params);
+		doIm2col(n, im2ColOutBlock, params, tempIm2ColArr);
 		im2ColOutBlock.recomputeNonZeros();
 		long t2 = DMLScript.STATISTICS && DISPLAY_STATISTICS ? System.nanoTime() : 0 ;
 		
@@ -301,8 +284,11 @@ public class LibMatrixDNN {
 			loopedConvBwdFilterMatMultTime.addAndGet(t4-t3);
 			loopedConvBwdFilterIm2ColTime.addAndGet(t2-t1);
 		}
-		if(!temp.isEmptyBlock())
-			elementWiseInPlaceAddition(partialRetBlock, temp);
+		if(!temp.isEmptyBlock()) {
+			// partialRetBlock is size: [params.C*params.R*params.S, params.K]
+			ConvolutionUtils.binaryOperationInPlace(temp, partialRetBlock.getDenseBlock(), 0, params.K, 0, params.C*params.R*params.S, 
+					_binaryElementWiseAddition);
+		}
 		return partialRetBlock;
 	}
 	
@@ -331,22 +317,15 @@ public class LibMatrixDNN {
 			}
 		}
 		
-		if(!input.isInSparseFormat() && TEST_SPARSE_INPUT) {
-			input.denseToSparse();
-		}
-		if(!filter.isInSparseFormat() && TEST_SPARSE_FILTER) {
-			filter.denseToSparse();
-		}
-		
 		runConvTask(TaskType.LoopedIm2ColConv2d, params);
 		
 		//post-processing: maintain nnz
 		outputBlock.recomputeNonZeros();
 	}
 	
-	private static void doLoopedIm2ColConv2d(int n, MatrixBlock im2ColOutBlock, ConvolutionParameters params) throws DMLRuntimeException {
+	private static void doLoopedIm2ColConv2d(int n, MatrixBlock im2ColOutBlock, ConvolutionParameters params, double []  temp) throws DMLRuntimeException {
 		long t1 = DMLScript.STATISTICS && DISPLAY_STATISTICS ? System.nanoTime() : 0;
-		doIm2col(n, im2ColOutBlock, params);
+		doIm2col(n, im2ColOutBlock, params, temp);
 		im2ColOutBlock.recomputeNonZeros();
 		long t2 = DMLScript.STATISTICS && DISPLAY_STATISTICS ? System.nanoTime() : 0;
 		
@@ -366,15 +345,21 @@ public class LibMatrixDNN {
 		int length = params.K*params.P*params.Q;
 		if(!matMultOutBlock.isEmptyBlock()) {
 			if(matMultOutBlock.isInSparseFormat()) {
-				// NOTE: Potential bottlenc to copy sparse matmult back to dense output
-				Iterator<IJV> iter = matMultOutBlock.sparseBlock.getIterator();
+				// Copy the sparse matrix matMultOutBlock of shape [K X PQ] to 
+				// params.output.denseBlock + destPos
 				final int outOffset = n*params.K*params.P*params.Q;
-				while(iter.hasNext()) {
-					IJV ijv = iter.next();
-					int k = ijv.getI();
-					int p = ijv.getJ() / params.Q;
-					int q = ijv.getJ() % params.Q;
-					params.output.denseBlock[outOffset + k*params.P*params.Q + p*params.Q + q] = ijv.getV();
+				final int PQ = params.P*params.Q;
+				for(int k = 0; k < matMultOutBlock.getNumRows(); k++) {
+					if( !matMultOutBlock.sparseBlock.isEmpty(k) ) {
+						int apos = matMultOutBlock.sparseBlock.pos(k);
+						int alen = matMultOutBlock.sparseBlock.size(k);
+						int[] aix = matMultOutBlock.sparseBlock.indexes(k);
+						double[] avals = matMultOutBlock.sparseBlock.values(k);
+						for(int j = apos; j < apos+alen; j++) {
+							int pqIndex = aix[j];
+							params.output.denseBlock[outOffset + k*PQ + pqIndex ] = avals[j];
+						}
+					}
 				}
 			}
 			else
@@ -387,6 +372,7 @@ public class LibMatrixDNN {
 		// params.output.recomputeNonZeros(); 
 	}
 	
+	
 	/**
 	 * This method computes the backpropogation errors for previous layer of maxpooling operation
 	 * 
@@ -504,42 +490,43 @@ public class LibMatrixDNN {
 		if (!params.input1.isInSparseFormat())
 			throw new DMLRuntimeException("Incorrect usage: Call optimized versions");
 		
-		// params.input2.isEmptyBlock() check is done by the caller
-		Iterator<IJV> iter = params.input2.sparseBlock.getIterator(n, n+1);
-		int [] tensorIndexes = new int[3];
-		
-		while(iter.hasNext()) {
-			IJV ijv = iter.next();
-			computeTensorIndexes(ijv.getJ(), tensorIndexes, params.P, params.Q);
-			int c = tensorIndexes[0];
-			int p = tensorIndexes[1];
-			int q = tensorIndexes[2];
-			
-			final int inputOffset = n*params.C*params.H*params.W + c*params.H*params.W;
-			int maxIndex = getMaxIndexSparse(p, q, inputOffset, n, c, params.input1, params);
-			if(maxIndex != -1)
-				outputArray[maxIndex] += ijv.getV();
+		if( !params.input2.sparseBlock.isEmpty(n) ) {
+			int [] tensorIndexes = new int[3];
+			int apos = params.input2.sparseBlock.pos(n);
+			int alen = params.input2.sparseBlock.size(n);
+			int[] aix = params.input2.sparseBlock.indexes(n);
+			double[] avals = params.input2.sparseBlock.values(n);
+			for(int j = apos; j < apos+alen; j++) {
+				computeTensorIndexes(aix[j], tensorIndexes, params.P, params.Q);
+				int c = tensorIndexes[0];
+				int p = tensorIndexes[1];
+				int q = tensorIndexes[2];
+				final int inputOffset = n*params.C*params.H*params.W + c*params.H*params.W;
+				int maxIndex = getMaxIndexSparse(p, q, inputOffset, n, c, params.input1, params);
+				if(maxIndex != -1)
+					outputArray[maxIndex] += avals[j];
+			}
 		}
-		
 	}
 	
 	private static void doPoolingBackwardDenseSparse(int n, double [] inputArray, 
 			MatrixBlock dout, double [] outputArray, ConvolutionParameters params) throws DMLRuntimeException {
-		// dout.isEmptyBlock() check is done by the caller
-		Iterator<IJV> iter = dout.sparseBlock.getIterator(n, n+1);
-		int [] tensorIndexes = new int[3];
-		
-		while(iter.hasNext()) {
-			IJV ijv = iter.next();
-			computeTensorIndexes(ijv.getJ(), tensorIndexes, params.P, params.Q);
-			int c = tensorIndexes[0];
-			int p = tensorIndexes[1];
-			int q = tensorIndexes[2];
-			
-			final int inputOffset = n*params.C*params.H*params.W + c*params.H*params.W;
-			int maxIndex = getMaxIndex(p, q, inputOffset, inputArray, params);
-			if(maxIndex != -1)
-				outputArray[maxIndex] += ijv.getV();
+		if( !dout.sparseBlock.isEmpty(n) ) {
+			int [] tensorIndexes = new int[3];
+			int apos = dout.sparseBlock.pos(n);
+			int alen = dout.sparseBlock.size(n);
+			int[] aix = dout.sparseBlock.indexes(n);
+			double[] avals = dout.sparseBlock.values(n);
+			for(int j = apos; j < apos+alen; j++) {
+				computeTensorIndexes(aix[j], tensorIndexes, params.P, params.Q);
+				int c = tensorIndexes[0];
+				int p = tensorIndexes[1];
+				int q = tensorIndexes[2];
+				final int inputOffset = n*params.C*params.H*params.W + c*params.H*params.W;
+				int maxIndex = getMaxIndex(p, q, inputOffset, inputArray, params);
+				if(maxIndex != -1)
+					outputArray[maxIndex] += avals[j];
+			}
 		}
 	}
 	
@@ -576,8 +563,6 @@ public class LibMatrixDNN {
 		if(!input.isInSparseFormat())
 			throw new DMLRuntimeException("Incorrect usage: Only sparse format supported");
 		
-		// input.isEmptyBlock() check is done by the caller
-		Iterator<IJV> iter = input.sparseBlock.getIterator(n, n+1);
 		int [] tensorIndexes = new int[3];
 		
 		int start_index_h = params.start_indexes_h[p];
@@ -592,22 +577,29 @@ public class LibMatrixDNN {
 		// maxVal = 0 
 		// if start_index_h < 0 || start_index_w < 0 || end_index_h >= params.H || end_index_w >= params.W

-		// Find maxIndex
-		double currDoutVal = -1;
-		while(iter.hasNext()) {
-			IJV ijv = iter.next();
-			computeTensorIndexes(ijv.getJ(), tensorIndexes, params.H, params.W);
-			if(c != tensorIndexes[0])
-				continue;
-			int h = tensorIndexes[1];
-			int w = tensorIndexes[2];
-			if(h >= start_index_h && h < end_index_h && w >= start_index_w && w < end_index_w) {
-				currDoutVal = ijv.getV();
-				if(maxVal < currDoutVal) {
-					maxIndex = inputOffset +  h*params.W + w;
-					maxVal = currDoutVal;
+		// input.isEmptyBlock() check is done by the caller
+		if( !input.sparseBlock.isEmpty(n) ) {
+			// Find maxIndex
+			int apos = input.sparseBlock.pos(n);
+			int alen = input.sparseBlock.size(n);
+			int[] aix = input.sparseBlock.indexes(n);
+			double[] avals = input.sparseBlock.values(n);
+			for(int j=apos; j<apos+alen; j++) {
+				computeTensorIndexes(aix[j], tensorIndexes, params.H, params.W);
+				if(c != tensorIndexes[0])
+					continue;
+				int h = tensorIndexes[1];
+				int w = tensorIndexes[2];
+				if(h >= start_index_h && h < end_index_h && w >= start_index_w && w < end_index_w) {
+					if(maxVal < avals[j]) {
+						maxIndex = inputOffset +  h*params.W + w;
+						maxVal = avals[j];
+					}
 				}
-			}	
+			}
+		}
+		else {
+			maxIndex = inputOffset + start_index_h*params.W + start_index_w; // all-zero row: use the window's top-left cell, not cell (0,0) of channel c
 		}
 		return maxIndex;
 	}
@@ -688,37 +680,11 @@ public class LibMatrixDNN {
 		}
 		else {
 			// Perform (X > 0)
-			if(params.input1.isInSparseFormat()) {
-				Iterator<IJV> iter = params.input1.sparseBlock.getIterator(rl, ru);
-				while(iter.hasNext()) {
-					IJV ijv = iter.next();
-					int i = ijv.getI();
-					int j = ijv.getJ();
-					outputArray[i*numOutCols + j] = ijv.getV() > 0 ? 1 : 0;
-				}
-			}
-			else {
-				double [] inputArr = params.input1.getDenseBlock();
-				for(int i = rl*numOutCols; i < ru*numOutCols; i++) {
-					outputArray[i] = inputArr[i] > 0 ? 1 : 0;
-				}
-			}
+			ConvolutionUtils.scalarOperations(params.input1, outputArray, rl*numOutCols, numOutCols, rl, ru, 
+					InstructionUtils.parseScalarBinaryOperator(">", false, 0));
 			// Then perform (X > 0) * dout
-			if(params.input2.isInSparseFormat()) {
-				Iterator<IJV> iter = params.input2.sparseBlock.getIterator(rl, ru);
-				while(iter.hasNext()) {
-					IJV ijv = iter.next();
-					int i = ijv.getI();
-					int j = ijv.getJ();
-					outputArray[i*numOutCols + j] *= ijv.getV();
-				}
-			}
-			else {
-				double [] doutArr = params.input2.getDenseBlock();
-				for(int i = rl*numOutCols; i < ru*numOutCols; i++) {
-					outputArray[i] *= doutArr[i];
-				}
-			}
+			ConvolutionUtils.binaryOperationInPlace(params.input2, outputArray, rl*numOutCols, numOutCols, rl, ru, 
+					_binaryElementWiseMultiplication);
 		}
 		
 		//post-processing: maintain nnz
@@ -748,13 +714,6 @@ public class LibMatrixDNN {
 		params.input2 = bias;
 		params.output = outputBlock;
 		
-		if(!input.isInSparseFormat() && TEST_SPARSE_INPUT) {
-			input.denseToSparse();
-		}
-		if(!bias.isInSparseFormat() && TEST_SPARSE_FILTER) {
-			bias.denseToSparse();
-		}
-		
 		if(bias.getNumColumns() != 1 || input.getNumColumns() % K != 0) {
 			throw new DMLRuntimeException("Incorrect inputs for bias_add: input[" + N + " X " + input.getNumColumns()  + "] and bias[" + K + " X " + bias.getNumColumns() + "]");
 		}
@@ -762,7 +721,7 @@ public class LibMatrixDNN {
 		if(input.isEmptyBlock()) {
 			double [] outputArray = outputBlock.getDenseBlock();
 			for(int n = 0;  n < N; n++) 
-				fillBias(bias, outputArray, n, n+1, N, K, PQ);
+				ConvolutionUtils.fillBias(bias, outputArray, n, n+1, N, K, PQ);
 		}
 		else {
 			runConvTask(TaskType.BiasAdd, params);
@@ -795,13 +754,6 @@ public class LibMatrixDNN {
 		params.input2 = bias;
 		params.output = outputBlock;
 		
-		if(!input.isInSparseFormat() && TEST_SPARSE_INPUT) {
-			input.denseToSparse();
-		}
-		if(!bias.isInSparseFormat() && TEST_SPARSE_FILTER) {
-			bias.denseToSparse();
-		}
-		
 		if(bias.getNumColumns() != 1 || input.getNumColumns() % K != 0) {
 			throw new DMLRuntimeException("Incorrect inputs for bias_multiply: input[" + N + " X " + input.getNumColumns()  + "] and bias[" + K + " X " + bias.getNumColumns() + "]");
 		}
@@ -816,116 +768,6 @@ public class LibMatrixDNN {
 		}
 	}
 	
-	private static void doBiasMultiply(ConvolutionParameters params, int rl, int ru) throws DMLRuntimeException {
-		double [] outputArray = params.output.getDenseBlock();
-		int PQ = params.C;
-		int numOutCols = params.input1.getNumColumns();
-		
-		if(!params.input1.isInSparseFormat() && !params.input2.isInSparseFormat()) {
-			double [] inputArr = params.input1.getDenseBlock();
-			double [] biasArr = params.input2.getDenseBlock();
-			int K = params.K;
-			int index = rl*K*PQ;
-			for(int n = rl; n < ru; n++) {
-				for(int k = 0; k < K; k++) {
-					for(int pq = 0; pq < PQ; pq++, index++) {
-						outputArray[index] = inputArr[index] * biasArr[k];
-					}
-				}
-			}
-		}
-		else {
-			// Fill non-zero values
-			if(params.input1.isInSparseFormat()) {
-				Iterator<IJV> iter = params.input1.sparseBlock.getIterator(rl, ru);
-				while(iter.hasNext()) {
-					IJV ijv = iter.next();
-					int i = ijv.getI();
-					int j = ijv.getJ();
-					outputArray[i*numOutCols + j] = ijv.getV();
-				}
-			}
-			else {
-				System.arraycopy(params.input1.getDenseBlock(), 0, outputArray, 0, outputArray.length);
-			}
-			int K = params.K;
-			int index = rl*K*PQ;
-			for(int k = 0; k < K; k++) {
-				double val = params.input2.getValue(k, 1);
-				for(int n = rl; n < ru; n++) {
-					for(int pq = 0; pq < PQ; pq++, index++) {
-						outputArray[index] *= val;
-					}
-				}
-			}
-		}
-		
-	}
-	
-	private static void doBiasAdd(ConvolutionParameters params, int rl, int ru) throws DMLRuntimeException {
-		double [] outputArray = params.output.getDenseBlock();
-		int PQ = params.C;
-		int numOutCols = params.input1.getNumColumns();
-		
-		if(!params.input1.isInSparseFormat() && !params.input2.isInSparseFormat()) {
-			double [] inputArr = params.input1.getDenseBlock();
-			double [] biasArr = params.input2.getDenseBlock();
-			int K = params.K;
-			int index = rl*K*PQ;
-			for(int n = rl; n < ru; n++) {
-				for(int k = 0; k < K; k++) {
-					for(int pq = 0; pq < PQ; pq++, index++) {
-						outputArray[index] = inputArr[index] + biasArr[k];
-					}
-				}
-			}
-		}
-		else {
-			fillBias(params.input2, outputArray, rl, ru, params.N, params.K, PQ);
-			if(params.input1.isInSparseFormat()) {
-				Iterator<IJV> iter = params.input1.sparseBlock.getIterator(rl, ru);
-				while(iter.hasNext()) {
-					IJV ijv = iter.next();
-					int i = ijv.getI();
-					int j = ijv.getJ();
-					outputArray[i*numOutCols + j] += ijv.getV();
-				}
-			}
-			else {
-				double [] inputArr = params.input1.getDenseBlock();
-				for(int i = rl*numOutCols; i < ru*numOutCols; i++) {
-					outputArray[i] += inputArr[i];
-				}
-			}
-		}
-		
-	}
-	
-	private static void fillBias(MatrixBlock bias, double [] outputArray, int n1, int n2, int N, int K, int PQ) {
-		if(bias.isInSparseFormat()) {
-			Iterator<IJV> iter = bias.sparseBlock.getIterator();
-			while(iter.hasNext()) {
-				IJV ijv = iter.next();
-				int k = ijv.getI();
-				double val = ijv.getV();
-				for(int n = n1; n < n2; n++) {
-					int fromIndex = n*K*PQ + k*PQ;
-					Arrays.fill(outputArray, fromIndex, fromIndex + PQ, val);
-				}
-			}
-		}
-		else {
-			double [] biasArr = bias.getDenseBlock();
-			for(int n = n1; n < n2; n++) {
-				for(int k = 0; k < K; k++) {
-					int fromIndex = n*K*PQ + k*PQ;
-					double val = biasArr[k];
-					Arrays.fill(outputArray, fromIndex, fromIndex + PQ, val);
-				}
-			}
-		}
-	}
-
 	public static void maxpooling(MatrixBlock input, MatrixBlock outputBlock, ConvolutionParameters params) throws DMLRuntimeException {
 		params.input1 = input;
 		params.output = outputBlock;
@@ -1009,15 +851,19 @@ public class LibMatrixDNN {
 				Arrays.fill(outputArray, 0);
 			
 			if(!input.isEmptyBlock()) {
-				Iterator<IJV> iter = input.sparseBlock.getIterator(inputN, inputN+1);
-				int [] tensorIndexes = new int[3];
-				while(iter.hasNext()) {
-					IJV ijv = iter.next();
-					computeTensorIndexes(ijv.getJ(), tensorIndexes, params.P, params.Q);
-					int k = tensorIndexes[0];
-					int p = tensorIndexes[1];
-					int q = tensorIndexes[2];
-					outputArray[outputOffset + p*params.Q*params.K + q*params.K + k] = ijv.getV();
+				if( !input.sparseBlock.isEmpty(inputN) ) {
+					int [] tensorIndexes = new int[3];
+					int apos = input.sparseBlock.pos(inputN);
+					int alen = input.sparseBlock.size(inputN);
+					int[] aix = input.sparseBlock.indexes(inputN);
+					double[] avals = input.sparseBlock.values(inputN);
+					for(int j = apos; j < apos+alen; j++) {
+						computeTensorIndexes(aix[j], tensorIndexes, params.P, params.Q);
+						int k = tensorIndexes[0];
+						int p = tensorIndexes[1];
+						int q = tensorIndexes[2];
+						outputArray[outputOffset + p*params.Q*params.K + q*params.K + k] = avals[j];
+					}
 				}
 			}
 		}
@@ -1137,22 +983,32 @@ public class LibMatrixDNN {
 						doPoolingBackward(n, _params);
 					break;
 				case BiasAdd:
-					doBiasAdd(_params, _rl, _ru);
+				{
+					double [] dest = _params.output.getDenseBlock();
+					ConvolutionUtils.binaryBiasOperations(_params.input1, _params.bias, dest, _params.K, _params.P*_params.Q, 
+							_rl, _ru, _binaryElementWiseAddition);
 					break;
+				}
 				case BiasMultiply:
-					doBiasMultiply(_params, _rl, _ru);
+				{
+					double [] dest = _params.output.getDenseBlock();
+					ConvolutionUtils.binaryBiasOperations(_params.input1, _params.bias, dest, _params.K, _params.P*_params.Q, 
+							_rl, _ru, _binaryElementWiseMultiplication);
 					break;
+				}
 				case ReluBackward:
 					lnnz = doReluBackward(_params, _rl, _ru);
 					break;
 				case LoopedIm2ColConv2d:
 				{	
 					MatrixBlock im2ColOutBlock = _im2ColOutBlocks.remove();
+					double [] temp = _params.input1.isInSparseFormat() ? new double[_params.input1.getNumColumns()] : null;
 					for(int n = _rl; n < _ru; n++) 
-						doLoopedIm2ColConv2d(n, im2ColOutBlock, _params);
+						doLoopedIm2ColConv2d(n, im2ColOutBlock, _params, temp);
 					_im2ColOutBlocks.add(im2ColOutBlock);
 					if(_params.bias != null)
-						addBias(_params, _rl, _ru);
+						ConvolutionUtils.binaryBiasOperationInPlace(_params.bias, _params.output.getDenseBlock(), _params.K, 
+								_params.P*_params.Q, _rl, _ru, _binaryElementWiseAddition);
 					break;
 				}
 				case LoopedIm2ColConv2dBwdFilter:
@@ -1160,8 +1016,9 @@ public class LibMatrixDNN {
 					MatrixBlock im2ColOutBlock = _im2ColOutBlocks.remove();
 					MatrixBlock partialRetBlock = _partialRetBlocks.remove();
 					MatrixBlock doutReshapedBlock = _doutReshapedBlocks.remove();
+					double [] temp = _params.input1.isInSparseFormat() ? new double[_params.input1.getNumColumns()] : null;
 					for(int n = _rl; n < _ru; n++) 
-						partialRetBlock = doLoopedIm2ColConv2dBwdFilter(n, im2ColOutBlock, doutReshapedBlock, partialRetBlock, _params);
+						partialRetBlock = doLoopedIm2ColConv2dBwdFilter(n, im2ColOutBlock, doutReshapedBlock, partialRetBlock, _params, temp);
 					_im2ColOutBlocks.add(im2ColOutBlock);
 					_partialRetBlocks.add(partialRetBlock);
 					_doutReshapedBlocks.add(doutReshapedBlock);
@@ -1182,37 +1039,6 @@ public class LibMatrixDNN {
 			return lnnz;
 		}
 	}
-	
-	private static void addBias(ConvolutionParameters params, int rl, int ru) {
-		int PQ = params.P*params.Q;
-		int K = params.K;
-		double [] outputArr = params.output.getDenseBlock();
-		if(!params.bias.isInSparseFormat()) {
-			double [] biasArr = params.bias.getDenseBlock();
-			int index = rl*K*PQ;
-			for(int n = rl; n < ru; n++) {
-				for(int k = 0; k < K; k++) {
-					for(int pq = 0; pq < PQ; pq++, index++) {
-						outputArr[index] += biasArr[k];
-					}
-				}
-			}
-		}
-		else {
-			Iterator<IJV> iter = params.bias.getSparseBlockIterator();
-			while(iter.hasNext()) {
-				IJV ijv = iter.next();
-				int k = ijv.getI();
-				double val = ijv.getV();
-				for(int n = rl; n < ru; n++) {
-					int index = n*K*PQ + k*PQ;
-					for(int pq = 0; pq < PQ; pq++, index++) {
-						outputArr[index] += val;
-					}
-				}
-			}
-		}
-	}
 		
 	// Converts input: PQ X CRS matrix and writes to 1 X CHW
 	private static void doCol2imOverSingleImage(int outputN, MatrixBlock input, ConvolutionParameters params) throws DMLRuntimeException {
@@ -1232,31 +1058,34 @@ public class LibMatrixDNN {
 			doCol2IMDenseInput(0, outputN, inputArray, outputArray, params);
 		}
 		else {
-			if(!input.isEmptyBlock())
-				doCol2IMSparseInput(0, outputN, input.getSparseBlockIterator(), outputArray, params);
-		}
-	}
-	
-	private static void doCol2IMSparseInput(int inputN, int outputN, Iterator<IJV> inputIter, double [] outputArray, ConvolutionParameters params) throws DMLRuntimeException {
-		int [] tensorIndexes = new int[3];
-		
-		while(inputIter.hasNext()) {
-			IJV ijv = inputIter.next();
-			computeTensorIndexes(ijv.getJ(), tensorIndexes, params.R, params.S);
-			int c = tensorIndexes[0];
-			int r = tensorIndexes[1];
-			int s = tensorIndexes[2];
-			computeTensorIndexes(ijv.getI(), tensorIndexes, params.P, params.Q);
-			int p = tensorIndexes[1];
-			int q = tensorIndexes[2];
-			if(inputN != tensorIndexes[0]) {
-				throw new DMLRuntimeException("Incorrect tensor indexes: " + inputN + " != " + tensorIndexes[0] + " <" + p + " " + q + " " + ijv.getI() + params.P + " " + params.Q + ">");
-			}
-			int h = p*params.stride_h + r - params.pad_h;
-			int w = q*params.stride_w + s - params.pad_w;
-			if(h >= 0 && h < params.H && w >= 0 && w < params.W) {
-				int outIndex = outputN*params.C*params.H*params.W + c*params.H*params.W + h*params.W + w;
-				outputArray[outIndex] += ijv.getV();
+			if(!input.isEmptyBlock()) {
+				int [] tensorIndexes = new int[3]; // scratch reused for both the row (n, p, q) and the column (c, r, s) decompositions
+				for(int i = 0; i < input.getNumRows(); i++) { // row i of the PQ x CRS input corresponds to output position (p, q)
+					if( !input.sparseBlock.isEmpty(i) ) {
+						computeTensorIndexes(i, tensorIndexes, params.P, params.Q);
+						int p = tensorIndexes[1];
+						int q = tensorIndexes[2];
+						if(tensorIndexes[0] != 0) // rows at or beyond P*Q would belong to another image
+							throw new DMLRuntimeException("Incorrect tensor indexes: " + tensorIndexes[0] + " != 0 <" + p + " " + q + " " + i + " " + params.P + " " + params.Q + ">");
+						// Scatter-add every non-zero (c, r, s) of this row into image cell (h, w); cells that fall into the padding are dropped
+						int apos = input.sparseBlock.pos(i);
+						int alen = input.sparseBlock.size(i);
+						int[] aix = input.sparseBlock.indexes(i);
+						double[] avals = input.sparseBlock.values(i);
+						for(int j = apos; j < apos+alen; j++) {
+							computeTensorIndexes(aix[j], tensorIndexes, params.R, params.S);
+							int c = tensorIndexes[0];
+							int r = tensorIndexes[1];
+							int s = tensorIndexes[2];
+							int h = p*params.stride_h + r - params.pad_h;
+							int w = q*params.stride_w + s - params.pad_w;
+							if(h >= 0 && h < params.H && w >= 0 && w < params.W) {
+								int outIndex = outputN*params.C*params.H*params.W + c*params.H*params.W + h*params.W + w;
+								outputArray[outIndex] += avals[j];
+							}
+						}
+					}
+				}
 			}
 		}
 	}
@@ -1341,9 +1170,28 @@ public class LibMatrixDNN {
 		}
 	}
 	
+	// Densifies row n of the sparse matrix input into temp (cleared first) and returns temp; temp must hold at least input.getNumColumns() values
+	private static double [] getRowInDenseFormat(MatrixBlock input, int n, double []  temp) {
+		// A dense copy of the row lets callers index columns directly instead of paying a binary search per getValue call
+		Arrays.fill(temp, 0);
+		if( input.sparseBlock != null && !input.sparseBlock.isEmpty(n) ) { // an empty sparse matrix may have no sparse block allocated
+			int apos = input.sparseBlock.pos(n);
+			int alen = input.sparseBlock.size(n);
+			int[] aix = input.sparseBlock.indexes(n);
+			double[] avals = input.sparseBlock.values(n);
+			for(int j=apos; j<apos+alen; j++)
+				temp[ aix[j] ] = avals[j];
+		}
+		return temp;
+	}
+	
 	// Keeping this as a separate sparse method to allow for further dense optimizations
-	private static void doIm2colSparse(int n, MatrixBlock input, double [] outputArray, ConvolutionParameters params) {
+	private static void doIm2colSparse(int n, MatrixBlock input, double [] outputArray, ConvolutionParameters params, double []  temp) throws DMLRuntimeException {
 		int CRS = params.C * params.R * params.S;
+		
+		// Using a temporary array improves performance by not requiring binary search for getValue
+		// Since the access pattern depends on ConvolutionParameters, this serves as a temporary fix.
+		temp = getRowInDenseFormat(input, n, temp);
 		// final int nOffset = n * params.C*params.H*params.W;
 		for (int c = 0; c < CRS; ++c) {
 			int wOffset = c % params.S;
@@ -1359,10 +1207,8 @@ public class LibMatrixDNN {
 				} else {
 					for (int w = 0; w < params.Q; ++w) {
 						int wPadded = w * params.stride_w - params.pad_w + wOffset;
-						if (wPadded >= 0 && wPadded < params.W) {
-							// NOTE: Potential performance bottleneck as we have to do binary search to getValue
-							outputArray[outOffset + w] = input.getValue(n, tempOffset + wPadded);
-						}
+						if (wPadded >= 0 && wPadded < params.W) 
+							outputArray[outOffset + w] = temp[tempOffset + wPadded]; // temp holds row n densified by getRowInDenseFormat, so no per-element binary search
 						else
 							outputArray[outOffset + w] = 0;
 					}
@@ -1371,7 +1217,7 @@ public class LibMatrixDNN {
 		}
 	}
 	
-	private static void doIm2col(int n, MatrixBlock output, ConvolutionParameters params) throws DMLRuntimeException {
+	private static void doIm2col(int n, MatrixBlock output, ConvolutionParameters params, double []  temp) throws DMLRuntimeException { // temp is only forwarded to doIm2colSparse, which runs whenever inputArray stays null; NOTE(review): confirm temp is non-null whenever getDenseBlock() of a dense input1 can be null
 		double [] inputArray = null;
 		if (!params.input1.isInSparseFormat())
 			inputArray = params.input1.getDenseBlock();
@@ -1384,12 +1230,6 @@ public class LibMatrixDNN {
 		if(inputArray != null)
 			doIm2colDense(n, inputArray, outputArray, params);
 		else
-			doIm2colSparse(n, params.input1, outputArray, params);
+			doIm2colSparse(n, params.input1, outputArray, params, temp);
 	}
-	
-	// ------------------------------------------------------------------------------------------------
-	// Used in integration tests. Please donot edit them
-	public static boolean TEST_SPARSE_INPUT = false;
-	public static boolean TEST_SPARSE_FILTER = false;
-	// ------------------------------------------------------------------------------------------------
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2d2196d8/src/main/java/org/apache/sysml/runtime/util/ConvolutionUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/util/ConvolutionUtils.java b/src/main/java/org/apache/sysml/runtime/util/ConvolutionUtils.java
index 814cf22..b988546 100644
--- a/src/main/java/org/apache/sysml/runtime/util/ConvolutionUtils.java
+++ b/src/main/java/org/apache/sysml/runtime/util/ConvolutionUtils.java
@@ -19,6 +19,13 @@
 
 package org.apache.sysml.runtime.util;
 
+import java.util.Arrays;
+
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysml.runtime.matrix.operators.ScalarOperator;
+
 
 public class ConvolutionUtils {
 	
@@ -52,4 +59,141 @@ public class ConvolutionUtils {
 		return ret;
 	}
 	
+	// Copies rows [src_rl, src_ru) of src into dest starting at destPos, using destNumCols as the dest row stride.
+	// Assumes dest is zeroed out before calling: zero entries of src (sparse, or an unallocated dense block) are not written.
+	public static void copy(MatrixBlock src, double [] dest, int destPos, int destNumCols, int src_rl, int src_ru) {
+		if(src.isInSparseFormat()) {
+			if(!src.isEmptyBlock()) {
+				for(int i = src_rl, cix = destPos; i < src_ru; i++, cix += destNumCols) {
+					if( !src.getSparseBlock().isEmpty(i) ) {
+						int apos = src.getSparseBlock().pos(i);
+						int alen = src.getSparseBlock().size(i);
+						int[] aix = src.getSparseBlock().indexes(i);
+						double[] avals = src.getSparseBlock().values(i);
+						for(int j = apos; j < apos+alen; j++) {
+							dest[ cix+aix[j] ] = avals[j];
+						}
+					}
+				}
+			}
+		}
+		else if(src.getDenseBlock() != null) {
+			System.arraycopy(src.getDenseBlock(), src_rl*src.getNumColumns(), dest, destPos, (src_ru-src_rl)*src.getNumColumns()); // contiguous copy: assumes src.getNumColumns() == destNumCols
+		}
+	}
+	
+	// Performs dest[destPos + (i-src_rl)*destNumCols + j] op= src[i, j] for rows i in [src_rl, src_ru).
+	public static void binaryOperationInPlace(MatrixBlock src, double [] dest, 
+			int destPos, int destNumCols, int src_rl, int src_ru, BinaryOperator op) throws DMLRuntimeException {
+		if(src.isInSparseFormat()) { // zero entries of src are skipped, so results are only complete when op(x, 0) == x (e.g. plus)
+			for(int i = src_rl, cix = destPos; i < src_ru; i++, cix += destNumCols) {
+				if( src.getSparseBlock() != null && !src.getSparseBlock().isEmpty(i) ) {
+					int apos = src.getSparseBlock().pos(i);
+					int alen = src.getSparseBlock().size(i);
+					int[] aix = src.getSparseBlock().indexes(i);
+					double[] avals = src.getSparseBlock().values(i);
+					for(int j = apos; j < apos+alen; j++) {
+						dest[ cix+aix[j] ] = op.fn.execute(dest[ cix+aix[j] ], avals[j]);
+					}
+				}
+			}
+		}
+		else if(src.getDenseBlock() != null) { // an unallocated dense block is all zeros and is skipped, as in the sparse case
+			double [] inputArr = src.getDenseBlock();
+			for(int i = src_rl, cix = destPos, ncol = src.getNumColumns(); i < src_ru; i++, cix += destNumCols)
+				for(int j = 0, aix = i*ncol; j < ncol; j++) // src row i lands in the dest row starting at cix, like the sparse branch
+					dest[cix+j] = op.fn.execute(dest[cix+j], inputArr[aix+j]);
+		}
+	}
+	
+	// Performs dest[destPos + (i-src_rl)*destNumCols + j] = src[i, j] op scalar for rows i in [src_rl, src_ru).
+	public static void scalarOperations(MatrixBlock src, double [] dest, 
+			int destPos, int destNumCols, int src_rl, int src_ru, ScalarOperator scalarOp) throws DMLRuntimeException {
+		if(src.isInSparseFormat()) { // NOTE(review): only non-zeros are written; cells of zero entries keep their previous dest value even if (0 op scalar) != 0
+			for(int i = src_rl, cix = destPos; i < src_ru; i++, cix += destNumCols) {
+				if( src.getSparseBlock() != null && !src.getSparseBlock().isEmpty(i) ) {
+					int apos = src.getSparseBlock().pos(i);
+					int alen = src.getSparseBlock().size(i);
+					int[] aix = src.getSparseBlock().indexes(i);
+					double[] avals = src.getSparseBlock().values(i);
+					for(int j = apos; j < apos+alen; j++) {
+						dest[ cix+aix[j] ] = scalarOp.executeScalar(avals[j]);
+					}
+				}
+			}
+		}
+		else if(src.getDenseBlock() != null) { // an unallocated dense block is all zeros and is skipped, as in the sparse case
+			double [] inputArr = src.getDenseBlock();
+			for(int i = src_rl, cix = destPos, ncol = src.getNumColumns(); i < src_ru; i++, cix += destNumCols)
+				for(int j = 0, aix = i*ncol; j < ncol; j++) // src row i lands in the dest row starting at cix, like the sparse branch
+					dest[cix+j] = scalarOp.executeScalar(inputArr[aix+j]);
+		}
+	}
+	
+	// Writes (input op bias) into dest for image rows [rl, ru): input and dest are N x KPQ, bias is K x 1 (one value per channel)
+	public static void binaryBiasOperations(MatrixBlock input, MatrixBlock bias, double [] dest, 
+			int K, int PQ, int rl, int ru, BinaryOperator op) throws DMLRuntimeException {
+		copy(input, dest, K*PQ*rl, K*PQ, rl, ru); // stage the input rows in dest (row stride K*PQ); copy expects dest pre-zeroed
+		binaryBiasOperationInPlace(bias, dest, K, PQ, rl, ru, op); // then combine each channel's bias value in place
+	}
+	
+	// dest (of size N x KPQ) op= bias (of size K x 1): dest[n*K*PQ + k*PQ + pq] op= bias[k] for images n in [rl, ru)
+	public static void binaryBiasOperationInPlace(MatrixBlock bias, double [] dest, 
+			int K, int PQ, int rl, int ru, BinaryOperator op) throws DMLRuntimeException {
+		// bias.getNumColumns() == 1 checked outside; zero bias entries (sparse or unallocated) are skipped, so op should satisfy op(x, 0) == x
+		if(!bias.isInSparseFormat() && bias.getDenseBlock() != null) {
+			double [] biasArr = bias.getDenseBlock();
+			int index = rl*K*PQ;
+			for(int n = rl; n < ru; n++) {
+				for(int k = 0; k < K; k++) {
+					for(int pq = 0; pq < PQ; pq++, index++) {
+						dest[index] = op.fn.execute(dest[index], biasArr[k]);
+					}
+				}
+			}
+		}
+		else if(bias.isInSparseFormat() && bias.getSparseBlock() != null) {
+			for(int k = 0; k < K; k++) {
+				if( !bias.getSparseBlock().isEmpty(k) ) {
+					int apos = bias.getSparseBlock().pos(k);
+					double[] avals = bias.getSparseBlock().values(k);
+					double val = avals[apos]; // the only stored entry of row k, since bias has a single column
+					for(int n = rl; n < ru; n++) {
+						int index = n*K*PQ + k*PQ;
+						for(int pq = 0; pq < PQ; pq++, index++) {
+							dest[index] = op.fn.execute(dest[index], val);
+						}
+					}
+				}
+			}
+		}
+	}
+	
+	public static void fillBias(MatrixBlock bias, double [] outputArray, int src_rl, int src_ru, int N, int K, int PQ) throws DMLRuntimeException {
+		// Fills the PQ cells of channel k of every image n in [src_rl, src_ru) with bias[k]; N is unused. bias.getNumColumns() == 1 checked outside.
+		if(bias.isInSparseFormat()) { // NOTE(review): zero bias entries are not written, so outputArray is presumably pre-zeroed; confirm with callers
+			for(int k = 0; k < K; k++) {
+				if( bias.getSparseBlock() != null && !bias.getSparseBlock().isEmpty(k) ) {
+					int apos = bias.getSparseBlock().pos(k);
+					double[] avals = bias.getSparseBlock().values(k);
+					double val = avals[apos];
+					for(int n = src_rl; n < src_ru; n++) {
+						int fromIndex = n*K*PQ + k*PQ;
+						Arrays.fill(outputArray, fromIndex, fromIndex + PQ, val);
+					}
+				}
+			}
+		}
+		else if(bias.getDenseBlock() != null) { // an unallocated dense bias writes nothing, the same as an all-zero sparse bias
+			double [] biasArr = bias.getDenseBlock();
+			for(int n = src_rl; n < src_ru; n++) {
+				for(int k = 0; k < K; k++) {
+					int fromIndex = n*K*PQ + k*PQ;
+					double val = biasArr[k];
+					Arrays.fill(outputArray, fromIndex, fromIndex + PQ, val);
+				}
+			}
+		}
+	}
+	
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2d2196d8/src/test/java/org/apache/sysml/test/integration/functions/tensor/Conv2DBackwardDataTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/tensor/Conv2DBackwardDataTest.java b/src/test/java/org/apache/sysml/test/integration/functions/tensor/Conv2DBackwardDataTest.java
index d3b6742..8f01f06 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/tensor/Conv2DBackwardDataTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/tensor/Conv2DBackwardDataTest.java
@@ -45,33 +45,84 @@ public class Conv2DBackwardDataTest extends AutomatedTestBase
 	}
 	
 	@Test
-	public void testConv2DDense1() 
+	public void testConv2DBwdDataDense1() 
 	{
 		int numImg = 2; int imgSize = 10; int numChannels = 3; int numFilters = 2; int filterSize = 2; int stride = 1; int pad = 0;
-		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad);
+		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, false);
 	}
 	
 	@Test
 	public void testConv2DDense2() 
 	{
 		int numImg = 5; int imgSize = 3; int numChannels = 2; int numFilters = 3; int filterSize = 3; int stride = 1; int pad = 1;
-		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad);
+		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, false);
 	}
 	
 	@Test
 	public void testConv2DDense3() 
 	{
 		int numImg = 5; int imgSize = 3; int numChannels = 2; int numFilters = 3; int filterSize = 3; int stride = 2; int pad = 1;
-		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad);
+		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, false);
 	}
 	
 	@Test
-	public void testConv2DDense4() 
+	public void testConv2DBwdDataDense4() 
 	{
 		int numImg = 5; int imgSize = 10; int numChannels = 2; int numFilters = 3; int filterSize = 2; int stride = 2; int pad = 1;
-		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad);
+		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, false);
 	}
 	
+	@Test
+	public void testConv2DBwdDataSparse1() 
+	{
+		int numImg = 2; int imgSize = 10; int numChannels = 3; int numFilters = 2; int filterSize = 2; int stride = 1; int pad = 0;
+		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true, false);
+	}
+	
+	@Test
+	public void testConv2DBwdDataSparse2() 
+	{
+		int numImg = 5; int imgSize = 3; int numChannels = 2; int numFilters = 3; int filterSize = 3; int stride = 1; int pad = 1;
+		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, true); // (false, false) would only repeat testConv2DDense2
+	}
+	
+	@Test
+	public void testConv2DBwdDataSparse3() 
+	{
+		int numImg = 5; int imgSize = 3; int numChannels = 2; int numFilters = 3; int filterSize = 3; int stride = 2; int pad = 1;
+		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, true);
+	}
+	
+	@Test
+	public void testConv2DBwdDataSparse4() 
+	{
+		int numImg = 5; int imgSize = 10; int numChannels = 2; int numFilters = 3; int filterSize = 2; int stride = 2; int pad = 1;
+		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true, true);
+	}
+	
+	@Test
+	public void testConv2DBwdDataSparse5() 
+	{
+		int numImg = 5; int imgSize = 3; int numChannels = 2; int numFilters = 3; int filterSize = 3; int stride = 1; int pad = 1;
+		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true, false);
+	}
+	
+	@Test
+	public void testConv2DBwdDataSparse6() 
+	{
+		int numImg = 5; int imgSize = 3; int numChannels = 2; int numFilters = 3; int filterSize = 3; int stride = 2; int pad = 1;
+		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true, true);
+	}
+	
+	@Test
+	public void testConv2DBwdDataSparse7() 
+	{
+		int numImg = 5; int imgSize = 10; int numChannels = 2; int numFilters = 3; int filterSize = 2; int stride = 2; int pad = 1;
+		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true, true); // NOTE(review): same configuration and flags as testConv2DBwdDataSparse4
+	}
+	
+	
+	
 	
 	/**
 	 * 
@@ -79,7 +130,7 @@ public class Conv2DBackwardDataTest extends AutomatedTestBase
 	 * @param sparse
 	 */
 	public void runConv2DTest( ExecType et, int imgSize, int numImg, int numChannels, int numFilters, 
-			int filterSize, int stride, int pad) 
+			int filterSize, int stride, int pad, boolean sparse1, boolean sparse2) // the flags are forwarded to both the DML and R scripts as TRUE/FALSE
 	{
 		RUNTIME_PLATFORM oldRTP = rtplatform;
 			
@@ -87,13 +138,13 @@ public class Conv2DBackwardDataTest extends AutomatedTestBase
 		
 		try
 		{
-		    TestConfiguration config = getTestConfiguration(TEST_NAME);
-		    if(et == ExecType.SPARK) {
-		    	rtplatform = RUNTIME_PLATFORM.SPARK;
-		    }
-		    else {
-		    	rtplatform = (et==ExecType.MR)? RUNTIME_PLATFORM.HADOOP : RUNTIME_PLATFORM.SINGLE_NODE;
-		    }
+			TestConfiguration config = getTestConfiguration(TEST_NAME);
+			if(et == ExecType.SPARK) {
+				rtplatform = RUNTIME_PLATFORM.SPARK;
+			}
+			else {
+				rtplatform = (et==ExecType.MR)? RUNTIME_PLATFORM.HADOOP : RUNTIME_PLATFORM.SINGLE_NODE;
+			}
 			if( rtplatform == RUNTIME_PLATFORM.SPARK )
 				DMLScript.USE_LOCAL_SPARK_CONFIG = true;
 			
@@ -103,13 +154,15 @@ public class Conv2DBackwardDataTest extends AutomatedTestBase
 			String RI_HOME = SCRIPT_DIR + TEST_DIR;
 			fullDMLScriptName = RI_HOME + TEST_NAME + ".dml";
 			
+			String sparseVal1 = (""+sparse1).toUpperCase(); // "TRUE"/"FALSE", appended to both the DML -args and the Rscript command line
+			String sparseVal2 = (""+sparse2).toUpperCase();
 			
 			long P = ConvolutionUtils.getP(imgSize, filterSize, stride, pad);
 			programArgs = new String[]{"-explain", "-args",  "" + imgSize, "" + numImg, 
 					"" + numChannels, "" + numFilters, 
 					"" + filterSize, "" + stride, "" + pad,
 					"" + P, "" + P, 
-					output("B")};
+					output("B"), sparseVal1, sparseVal2};
 			        
 			boolean exceptionExpected = false;
 			int expectedNumberOfJobs = -1;
@@ -118,7 +171,8 @@ public class Conv2DBackwardDataTest extends AutomatedTestBase
 			fullRScriptName = RI_HOME + TEST_NAME + ".R";
 			rCmd = "Rscript" + " " + fullRScriptName + " " + imgSize + " " + numImg + 
 					" " + numChannels + " " + numFilters + 
-					" " + filterSize + " " + stride + " " + pad + " " + P + " " + P + " " + expectedDir();
+					" " + filterSize + " " + stride + " " + pad + " " + P + " " + P + " " + expectedDir() +
+					" " + sparseVal1 + " " + sparseVal2; // same trailing flags as the DML run, so both sides use identical inputs
 			// Run comparison R script
 			runRScript(true);
 			HashMap<CellIndex, Double> bHM = readRMatrixFromFS("B");
@@ -132,6 +186,7 @@ public class Conv2DBackwardDataTest extends AutomatedTestBase
 			rtplatform = oldRTP;
 			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
 		}
+		
 	}
 	
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2d2196d8/src/test/java/org/apache/sysml/test/integration/functions/tensor/Conv2DBackwardTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/tensor/Conv2DBackwardTest.java b/src/test/java/org/apache/sysml/test/integration/functions/tensor/Conv2DBackwardTest.java
index 74d3d14..decca59 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/tensor/Conv2DBackwardTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/tensor/Conv2DBackwardTest.java
@@ -49,35 +49,140 @@ public class Conv2DBackwardTest extends AutomatedTestBase
 	public void testConv2DBackwardFilterDense1() 
 	{
 		int numImg = 3; int imgSize = 3; int numChannels = 3; int numFilters = 1; int filterSize = 2; int stride = 1; int pad = 0;
-		runConv2DBackwardFilterTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad);
+		runConv2DBackwardFilterTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, false);
 	}
 	
 	@Test
 	public void testConv2DBackwardFilterDense2() 
 	{
 		int numImg = 3; int imgSize = 3; int numChannels = 3; int numFilters = 4; int filterSize = 2; int stride = 1; int pad = 0;
-		runConv2DBackwardFilterTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad);
+		runConv2DBackwardFilterTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, false);
 	}
 	
 	@Test
 	public void testConv2DBackwardFilterDense3() 
 	{
 		int numImg = 3; int imgSize = 10; int numChannels = 4; int numFilters = 3; int filterSize = 2; int stride = 2; int pad = 1;
-		runConv2DBackwardFilterTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad);
+		runConv2DBackwardFilterTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, false);
 	}
 	
 	@Test
 	public void testConv2DBackwardFilterDense4() 
 	{
 		int numImg = 3; int imgSize = 10; int numChannels = 4; int numFilters = 3; int filterSize = 5; int stride = 1; int pad = 1;
-		runConv2DBackwardFilterTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad);
+		runConv2DBackwardFilterTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, false);
 	}
 	
 	@Test
 	public void testConv2DBackwardFilterDense5() 
 	{
 		int numImg = 3; int imgSize = 10; int numChannels = 2; int numFilters = 3; int filterSize = 5; int stride = 3; int pad = 2;
-		runConv2DBackwardFilterTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad);
+		runConv2DBackwardFilterTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, false);
+	}
+	
+	@Test
+	public void testConv2DBackwardFilterSparse1() 
+	{
+		int numImg = 3; int imgSize = 3; int numChannels = 3; int numFilters = 1; int filterSize = 2; int stride = 1; int pad = 0;
+		runConv2DBackwardFilterTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, true);
+	}
+	
+	@Test
+	public void testConv2DBackwardFilterSparse2() 
+	{
+		int numImg = 3; int imgSize = 3; int numChannels = 3; int numFilters = 4; int filterSize = 2; int stride = 1; int pad = 0;
+		runConv2DBackwardFilterTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, true);
+	}
+	
+	@Test
+	public void testConv2DBackwardFilterSparse3() 
+	{
+		int numImg = 3; int imgSize = 10; int numChannels = 4; int numFilters = 3; int filterSize = 2; int stride = 2; int pad = 1;
+		runConv2DBackwardFilterTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, true);
+	}
+	
+	@Test
+	public void testConv2DBackwardFilterSparse4() 
+	{
+		int numImg = 3; int imgSize = 10; int numChannels = 4; int numFilters = 3; int filterSize = 5; int stride = 1; int pad = 1;
+		runConv2DBackwardFilterTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, true);
+	}
+	
+	@Test
+	public void testConv2DBackwardFilterSparse5() 
+	{
+		int numImg = 3; int imgSize = 10; int numChannels = 2; int numFilters = 3; int filterSize = 5; int stride = 3; int pad = 2;
+		runConv2DBackwardFilterTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, true);
+	}
+	
+	@Test
+	public void testConv2DBackwardFilterSparse6() 
+	{
+		int numImg = 3; int imgSize = 3; int numChannels = 3; int numFilters = 1; int filterSize = 2; int stride = 1; int pad = 0;
+		runConv2DBackwardFilterTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true, false);
+	}
+	
+	@Test
+	public void testConv2DBackwardFilterSparse7() 
+	{
+		int numImg = 3; int imgSize = 3; int numChannels = 3; int numFilters = 4; int filterSize = 2; int stride = 1; int pad = 0;
+		runConv2DBackwardFilterTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true, false);
+	}
+	
+	@Test
+	public void testConv2DBackwardFilterSparse8() 
+	{
+		int numImg = 3; int imgSize = 10; int numChannels = 4; int numFilters = 3; int filterSize = 2; int stride = 2; int pad = 1;
+		runConv2DBackwardFilterTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true, false);
+	}
+	
+	@Test
+	public void testConv2DBackwardFilterSparse9() 
+	{
+		int numImg = 3; int imgSize = 10; int numChannels = 4; int numFilters = 3; int filterSize = 5; int stride = 1; int pad = 1;
+		runConv2DBackwardFilterTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true, false);
+	}
+	
+	@Test
+	public void testConv2DBackwardFilterSparse10() 
+	{
+		int numImg = 3; int imgSize = 10; int numChannels = 2; int numFilters = 3; int filterSize = 5; int stride = 3; int pad = 2;
+		runConv2DBackwardFilterTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true, false);
+	}
+	
+	@Test
+	public void testConv2DBackwardFilterSparse11() 
+	{
+		int numImg = 3; int imgSize = 3; int numChannels = 3; int numFilters = 1; int filterSize = 2; int stride = 1; int pad = 0;
+		runConv2DBackwardFilterTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true, true);
+	}
+	
+	@Test
+	public void testConv2DBackwardFilterSparse12() 
+	{
+		int numImg = 3; int imgSize = 3; int numChannels = 3; int numFilters = 4; int filterSize = 2; int stride = 1; int pad = 0;
+		runConv2DBackwardFilterTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true, true);
+	}
+	
+	@Test
+	public void testConv2DBackwardFilterSparse13() 
+	{
+		int numImg = 3; int imgSize = 10; int numChannels = 4; int numFilters = 3; int filterSize = 2; int stride = 2; int pad = 1;
+		runConv2DBackwardFilterTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true, true);
+	}
+	
+	@Test
+	public void testConv2DBackwardFilterSparse14() 
+	{
+		int numImg = 3; int imgSize = 10; int numChannels = 4; int numFilters = 3; int filterSize = 5; int stride = 1; int pad = 1;
+		runConv2DBackwardFilterTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true, true);
+	}
+	
+	@Test
+	public void testConv2DBackwardFilterSparse15() 
+	{
+		int numImg = 3; int imgSize = 10; int numChannels = 2; int numFilters = 3; int filterSize = 5; int stride = 3; int pad = 2;
+		runConv2DBackwardFilterTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true, true); // (true, false) would only repeat testConv2DBackwardFilterSparse10
 	}
 	
 	/**
@@ -86,20 +191,23 @@ public class Conv2DBackwardTest extends AutomatedTestBase
 	 * @param sparse
 	 */
 	public void runConv2DBackwardFilterTest( ExecType et, int imgSize, int numImg, int numChannels, int numFilters, 
-			int filterSize, int stride, int pad) 
+			int filterSize, int stride, int pad, boolean sparse1, boolean sparse2) // the flags are forwarded to both the DML and R scripts as TRUE/FALSE
 	{
 		RUNTIME_PLATFORM oldRTP = rtplatform;
 			
 		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
 		try
 		{
-		    TestConfiguration config = getTestConfiguration(TEST_NAME);
-		    if(et == ExecType.SPARK) {
-		    	rtplatform = RUNTIME_PLATFORM.SPARK;
-		    }
-		    else {
-		    	rtplatform = (et==ExecType.MR)? RUNTIME_PLATFORM.HADOOP : RUNTIME_PLATFORM.SINGLE_NODE;
-		    }
+			String sparseVal1 = (""+sparse1).toUpperCase(); // "TRUE"/"FALSE", appended to both the DML -args and the Rscript command line
+			String sparseVal2 = (""+sparse2).toUpperCase();
+			
+			TestConfiguration config = getTestConfiguration(TEST_NAME);
+			if(et == ExecType.SPARK) {
+				rtplatform = RUNTIME_PLATFORM.SPARK;
+			}
+			else {
+				rtplatform = (et==ExecType.MR)? RUNTIME_PLATFORM.HADOOP : RUNTIME_PLATFORM.SINGLE_NODE;
+			}
 			if( rtplatform == RUNTIME_PLATFORM.SPARK )
 				DMLScript.USE_LOCAL_SPARK_CONFIG = true;
 			
@@ -116,7 +224,7 @@ public class Conv2DBackwardTest extends AutomatedTestBase
 				"" + numChannels, "" + numFilters, 
 				"" + filterSize, "" + stride, "" + pad,
 				"" + P, "" + P, 
-				output("B")};
+				output("B"), sparseVal1, sparseVal2};
 			        
 			boolean exceptionExpected = false;
 			int expectedNumberOfJobs = -1;
@@ -125,7 +233,8 @@ public class Conv2DBackwardTest extends AutomatedTestBase
 			fullRScriptName = RI_HOME + TEST_NAME + ".R";
 			rCmd = "Rscript" + " " + fullRScriptName + " " + imgSize + " " + numImg + 
 					" " + numChannels + " " + numFilters + 
-					" " + filterSize + " " + stride + " " + pad + " " + P + " " + P + " " + expectedDir();
+					" " + filterSize + " " + stride + " " + pad + " " + P + " " + P + " " + expectedDir() +
+					" " + sparseVal1 + " " + sparseVal2; // same trailing flags as the DML run, so both sides use identical inputs
 			// Run comparison R script
 			runRScript(true);
 			HashMap<CellIndex, Double> bHM = readRMatrixFromFS("B");

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2d2196d8/src/test/java/org/apache/sysml/test/integration/functions/tensor/Conv2DTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/tensor/Conv2DTest.java b/src/test/java/org/apache/sysml/test/integration/functions/tensor/Conv2DTest.java
index 81fe154..e5528d2 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/tensor/Conv2DTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/tensor/Conv2DTest.java
@@ -23,7 +23,6 @@ import java.util.HashMap;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.lops.LopProperties.ExecType;
-import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
 import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
@@ -48,7 +47,7 @@ public class Conv2DTest extends AutomatedTestBase
 	public void testConv2DDense1() 
 	{
 		int numImg = 5; int imgSize = 3; int numChannels = 3; int numFilters = 6; int filterSize = 2; int stride = 1; int pad = 0;
-		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false);
+		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, false); // presumably the two flags feed Conv2DTest.dml's $9/$10 (sparsify x / w); TODO confirm arg order
 	}
 	
 	
@@ -56,76 +55,76 @@ public class Conv2DTest extends AutomatedTestBase
 	public void testConv2DDense2() 
 	{
 		int numImg = 1; int imgSize = 10; int numChannels = 4; int numFilters = 3; int filterSize = 4; int stride = 2; int pad = 0;
-		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false);
+		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, false);
 	}
 	
 	@Test
 	public void testConv2DDense3() 
 	{
 		int numImg = 1; int imgSize = 10; int numChannels = 4; int numFilters = 3; int filterSize = 4; int stride = 2; int pad = 1;
-		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false);
+		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, false);
 	}
 	
 	@Test
 	public void testConv2DDense4() 
 	{
 		int numImg = 3; int imgSize = 10; int numChannels = 1; int numFilters = 3; int filterSize = 2; int stride = 2; int pad = 1;
-		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false);
+		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, false);
 	}
 	
 	@Test
 	public void testConv2DDense5() 
 	{
 		int numImg = 3; int imgSize = 8; int numChannels = 2; int numFilters = 3; int filterSize = 3; int stride = 1; int pad = 2;
-		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false);
+		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, false);
 	}
 	
 	@Test
 	public void testConv2DDense6() 
 	{
 		int numImg = 1; int imgSize = 10; int numChannels = 4; int numFilters = 3; int filterSize = 4; int stride = 1; int pad = 0;
-		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false);
+		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, false);
 	}
 	
 	@Test
 	public void testConv2DDense7() 
 	{
 		int numImg = 3; int imgSize = 10; int numChannels = 1; int numFilters = 3; int filterSize = 2; int stride = 1; int pad = 0;
-		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false);
+		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, false);
 	}
 	
 	@Test
 	public void testConv2DSparse1() 
 	{
 		int numImg = 5; int imgSize = 3; int numChannels = 3; int numFilters = 6; int filterSize = 2; int stride = 1; int pad = 0;
-		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true);
+		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true, false);
 	}
 	
 	@Test
 	public void testConv2DSparse2() 
 	{
 		int numImg = 1; int imgSize = 10; int numChannels = 4; int numFilters = 3; int filterSize = 4; int stride = 2; int pad = 0;
-		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true);
+		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, true);
 	}
 	
 	@Test
 	public void testConv2DSparse3() 
 	{
 		int numImg = 1; int imgSize = 10; int numChannels = 4; int numFilters = 3; int filterSize = 4; int stride = 2; int pad = 1;
-		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true);
+		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true, false);
 	}
 	
 	public void testConv2DSparse4() 
 	{
 		int numImg = 3; int imgSize = 10; int numChannels = 1; int numFilters = 3; int filterSize = 2; int stride = 2; int pad = 1;
-		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true);
+		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, true);
 	}
 	
 	@Test
 	public void testConv2DSparse5() 
 	{
 		int numImg = 3; int imgSize = 8; int numChannels = 2; int numFilters = 3; int filterSize = 3; int stride = 1; int pad = 2;
-		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true);
+		runConv2DTest(ExecType.CP, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true, false);
 	}
 	
 	// --------------------------------------------
@@ -135,83 +134,83 @@ public class Conv2DTest extends AutomatedTestBase
 	public void testConv2DDense1SP() 
 	{
 		int numImg = 5; int imgSize = 3; int numChannels = 3; int numFilters = 6; int filterSize = 2; int stride = 1; int pad = 0;
-		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false);
+		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, false);
 	}
 	
 	@Test
 	public void testConv2DDense2SP() 
 	{
 		int numImg = 1; int imgSize = 10; int numChannels = 4; int numFilters = 3; int filterSize = 4; int stride = 2; int pad = 0;
-		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false);
+		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, false);
 	}
 	
 	@Test
 	public void testConv2DDense3SP() 
 	{
 		int numImg = 1; int imgSize = 10; int numChannels = 4; int numFilters = 3; int filterSize = 4; int stride = 2; int pad = 1;
-		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false);
+		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, false);
 	}
 	
 	@Test
 	public void testConv2DDense4SP() 
 	{
 		int numImg = 3; int imgSize = 10; int numChannels = 1; int numFilters = 3; int filterSize = 2; int stride = 2; int pad = 1;
-		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false);
+		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, false);
 	}
 	
 	@Test
 	public void testConv2DDense5SP() 
 	{
 		int numImg = 3; int imgSize = 8; int numChannels = 2; int numFilters = 3; int filterSize = 3; int stride = 1; int pad = 2;
-		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false);
+		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, false);
 	}
 	
 	@Test
 	public void testConv2DDense6SP() 
 	{
 		int numImg = 1; int imgSize = 10; int numChannels = 4; int numFilters = 3; int filterSize = 4; int stride = 1; int pad = 0;
-		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false);
+		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, false);
 	}
 	
 	@Test
 	public void testConv2DDense7SP() 
 	{
 		int numImg = 3; int imgSize = 10; int numChannels = 1; int numFilters = 3; int filterSize = 2; int stride = 1; int pad = 0;
-		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false);
+		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, false);
 	}
 	
 	@Test
 	public void testConv2DSparse1SP() 
 	{
 		int numImg = 5; int imgSize = 3; int numChannels = 3; int numFilters = 6; int filterSize = 2; int stride = 1; int pad = 0;
-		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true);
+		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, false, true);
 	}
 	
 	@Test
 	public void testConv2DSparse2SP() 
 	{
 		int numImg = 1; int imgSize = 10; int numChannels = 4; int numFilters = 3; int filterSize = 4; int stride = 2; int pad = 0;
-		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true);
+		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true, false);
 	}
 	
 	@Test
 	public void testConv2DSparse3SP() 
 	{
 		int numImg = 1; int imgSize = 10; int numChannels = 4; int numFilters = 3; int filterSize = 4; int stride = 2; int pad = 1;
-		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true);
+		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true, false);
 	}
 	
 	public void testConv2DSparse4SP() 
 	{
 		int numImg = 3; int imgSize = 10; int numChannels = 1; int numFilters = 3; int filterSize = 2; int stride = 2; int pad = 1;
-		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true);
+		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true, true);
 	}
 	
 	@Test
 	public void testConv2DSparse5SP() 
 	{
 		int numImg = 3; int imgSize = 8; int numChannels = 2; int numFilters = 3; int filterSize = 3; int stride = 1; int pad = 2;
-		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true);
+		runConv2DTest(ExecType.SPARK, imgSize, numImg, numChannels, numFilters, filterSize, stride, pad, true, true);
 	}
 	
 	/**
@@ -220,64 +219,62 @@ public class Conv2DTest extends AutomatedTestBase
-	 * @param sparse
+	 * @param sparse1 forwarded to the DML script as $9 and to the R script as args[9]; if true, a zero mask is applied to the input image x
+	 * @param sparse2 forwarded to the DML script as $10 and to the R script as args[10]; if true, a zero mask is applied to the filter w
 	 */
 	public void runConv2DTest( ExecType et, int imgSize, int numImg, int numChannels, int numFilters, 
-			int filterSize, int stride, int pad, boolean sparse) 
+			int filterSize, int stride, int pad, boolean sparse1, boolean sparse2) 
 	{
 		RUNTIME_PLATFORM oldRTP = rtplatform;
 			
 		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
 		
-		synchronized(LibMatrixDNN.class) {
-			try
-			{
-				LibMatrixDNN.TEST_SPARSE_INPUT = sparse;
-				LibMatrixDNN.TEST_SPARSE_FILTER = sparse;
-				
-			    TestConfiguration config = getTestConfiguration(TEST_NAME);
-			    if(et == ExecType.SPARK) {
-			    	rtplatform = RUNTIME_PLATFORM.SPARK;
-			    }
-			    else {
-			    	rtplatform = (et==ExecType.MR)? RUNTIME_PLATFORM.HADOOP : RUNTIME_PLATFORM.SINGLE_NODE;
-			    }
-				if( rtplatform == RUNTIME_PLATFORM.SPARK )
-					DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-				
-				loadTestConfiguration(config);
-		        
-				/* This is for running the junit test the new way, i.e., construct the arguments directly */
-				String RI_HOME = SCRIPT_DIR + TEST_DIR;
-				fullDMLScriptName = RI_HOME + TEST_NAME + ".dml";
-				
-				
-				programArgs = new String[]{"-explain", "recompile_runtime", "-args",  "" + imgSize, "" + numImg, 
-					"" + numChannels, "" + numFilters, 
-					"" + filterSize, "" + stride, "" + pad, 
-					output("B")};
-				
-				fullRScriptName = RI_HOME + TEST_NAME + ".R";
-				rCmd = "Rscript" + " " + fullRScriptName + " " + imgSize + " " + numImg + 
-						" " + numChannels + " " + numFilters + 
-						" " + filterSize + " " + stride + " " + pad + " " + expectedDir(); 
-				
-				boolean exceptionExpected = false;
-				int expectedNumberOfJobs = -1;
-				runTest(true, exceptionExpected, null, expectedNumberOfJobs);
-	
-				// Run comparison R script
-				runRScript(true);
-				HashMap<CellIndex, Double> bHM = readRMatrixFromFS("B");
-				
-				HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("B");
-				TestUtils.compareMatrices(dmlfile, bHM, epsilon, "B-DML", "B-R");
-				
-			}
-			finally
-			{
-				rtplatform = oldRTP;
-				DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
-				LibMatrixDNN.TEST_SPARSE_INPUT = false;
-				LibMatrixDNN.TEST_SPARSE_FILTER = false;
-			}
+		try
+		{
+			String sparseVal1 = (""+sparse1).toUpperCase();
+			String sparseVal2 = (""+sparse2).toUpperCase();
+			
+			TestConfiguration config = getTestConfiguration(TEST_NAME);
+			if(et == ExecType.SPARK) {
+				rtplatform = RUNTIME_PLATFORM.SPARK;
+			}
+			else {
+				rtplatform = (et==ExecType.MR)? RUNTIME_PLATFORM.HADOOP : RUNTIME_PLATFORM.SINGLE_NODE;
+			}
+			if( rtplatform == RUNTIME_PLATFORM.SPARK )
+				DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+			
+			loadTestConfiguration(config);
+			
+			/* This is for running the junit test the new way, i.e., construct the arguments directly */
+			String RI_HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = RI_HOME + TEST_NAME + ".dml";
+			
+			
+			programArgs = new String[]{"-explain", "recompile_runtime", "-args",  "" + imgSize, "" + numImg, 
+				"" + numChannels, "" + numFilters, 
+				"" + filterSize, "" + stride, "" + pad, 
+				output("B"), sparseVal1, sparseVal2};
+			
+			fullRScriptName = RI_HOME + TEST_NAME + ".R";
+			rCmd = "Rscript" + " " + fullRScriptName + " " + imgSize + " " + numImg + 
+					" " + numChannels + " " + numFilters + 
+					" " + filterSize + " " + stride + " " + pad + " " + expectedDir() +
+					" " + sparseVal1 + " " + sparseVal2; 
+			
+			boolean exceptionExpected = false;
+			int expectedNumberOfJobs = -1;
+			runTest(true, exceptionExpected, null, expectedNumberOfJobs);
+
+			// Run comparison R script
+			runRScript(true);
+			HashMap<CellIndex, Double> bHM = readRMatrixFromFS("B");
+			
+			HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("B");
+			TestUtils.compareMatrices(dmlfile, bHM, epsilon, "B-DML", "B-R");
+			
+		}
+		finally
+		{
+			rtplatform = oldRTP;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
 		}
 	}
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2d2196d8/src/test/java/org/apache/sysml/test/integration/functions/tensor/PoolBackwardTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/tensor/PoolBackwardTest.java b/src/test/java/org/apache/sysml/test/integration/functions/tensor/PoolBackwardTest.java
index 35cfad9..54fda03 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/tensor/PoolBackwardTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/tensor/PoolBackwardTest.java
@@ -48,21 +48,84 @@ public class PoolBackwardTest extends AutomatedTestBase
 	public void testMaxPool2DBackwardDense1() 
 	{
 		int numImg = 1; int imgSize = 4; int numChannels = 1;  int stride = 2; int pad = 0; int poolSize1 = 2; int poolSize2 = 2;
-		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max");
+		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max", false, false);
 	}
 	
 	@Test
 	public void testMaxPool2DBackwardDense2() 
 	{
 		int numImg = 3; int imgSize = 6; int numChannels = 3;  int stride = 1; int pad = 0; int poolSize1 = 2; int poolSize2 = 2;
-		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max");
+		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max", false, false);
 	}
 	
 	@Test
 	public void testMaxPool2DBackwardDense3() 
 	{
 		int numImg = 2; int imgSize = 7; int numChannels = 2;  int stride = 2; int pad = 0; int poolSize1 = 3; int poolSize2 = 3;
-		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max");
+		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max", false, false);
+	}
+	
+	@Test
+	public void testMaxPool2DBackwardSparse1() 
+	{
+		int numImg = 1; int imgSize = 4; int numChannels = 1;  int stride = 2; int pad = 0; int poolSize1 = 2; int poolSize2 = 2;
+		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max", true, false);
+	}
+	
+	@Test
+	public void testMaxPool2DBackwardSparse2() 
+	{
+		int numImg = 3; int imgSize = 6; int numChannels = 3;  int stride = 1; int pad = 0; int poolSize1 = 2; int poolSize2 = 2;
+		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max", true, false);
+	}
+	
+	@Test
+	public void testMaxPool2DBackwardSparse3() 
+	{
+		int numImg = 2; int imgSize = 7; int numChannels = 2;  int stride = 2; int pad = 0; int poolSize1 = 3; int poolSize2 = 3;
+		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max", true, false);
+	}
+	
+	@Test
+	public void testMaxPool2DBackwardSparse4() 
+	{
+		int numImg = 1; int imgSize = 4; int numChannels = 1;  int stride = 2; int pad = 0; int poolSize1 = 2; int poolSize2 = 2;
+		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max", true, true);
+	}
+	
+	@Test
+	public void testMaxPool2DBackwardSparse5() 
+	{
+		int numImg = 3; int imgSize = 6; int numChannels = 3;  int stride = 1; int pad = 0; int poolSize1 = 2; int poolSize2 = 2;
+		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max", true, true);
+	}
+	
+	@Test
+	public void testMaxPool2DBackwardSparse6() 
+	{
+		int numImg = 2; int imgSize = 7; int numChannels = 2;  int stride = 2; int pad = 0; int poolSize1 = 3; int poolSize2 = 3;
+		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max", true, true);
+	}
+	
+	@Test
+	public void testMaxPool2DBackwardSparse7() 
+	{
+		int numImg = 1; int imgSize = 4; int numChannels = 1;  int stride = 2; int pad = 0; int poolSize1 = 2; int poolSize2 = 2;
+		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max", false, true);
+	}
+	
+	@Test
+	public void testMaxPool2DBackwardSparse8() 
+	{
+		int numImg = 3; int imgSize = 6; int numChannels = 3;  int stride = 1; int pad = 0; int poolSize1 = 2; int poolSize2 = 2;
+		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max", false, true);
+	}
+	
+	@Test
+	public void testMaxPool2DBackwardSparse9() 
+	{
+		int numImg = 2; int imgSize = 7; int numChannels = 2;  int stride = 2; int pad = 0; int poolSize1 = 3; int poolSize2 = 3;
+		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max", false, true);
 	}
 	
 	/**
@@ -71,7 +134,7 @@ public class PoolBackwardTest extends AutomatedTestBase
 	 * @param sparse
 	 */
 	public void runPoolTest( ExecType et, int imgSize, int numImg, int numChannels, int stride, 
-			int pad, int poolSize1, int poolSize2, String poolMode) 
+			int pad, int poolSize1, int poolSize2, String poolMode, boolean sparse1, boolean sparse2) 
 	{
 		RUNTIME_PLATFORM oldRTP = rtplatform;
 			
@@ -79,13 +142,15 @@ public class PoolBackwardTest extends AutomatedTestBase
 		
 		try
 		{
-		    TestConfiguration config = getTestConfiguration(TEST_NAME);
-		    if(et == ExecType.SPARK) {
-		    	rtplatform = RUNTIME_PLATFORM.SPARK;
-		    }
-		    else {
-		    	rtplatform = (et==ExecType.MR)? RUNTIME_PLATFORM.HADOOP : RUNTIME_PLATFORM.SINGLE_NODE;
-		    }
+			String sparseVal1 = (""+sparse1).toUpperCase();
+			String sparseVal2 = (""+sparse2).toUpperCase();
+			TestConfiguration config = getTestConfiguration(TEST_NAME);
+			if(et == ExecType.SPARK) {
+				rtplatform = RUNTIME_PLATFORM.SPARK;
+			}
+			else {
+				rtplatform = (et==ExecType.MR)? RUNTIME_PLATFORM.HADOOP : RUNTIME_PLATFORM.SINGLE_NODE;
+			}
 			if( rtplatform == RUNTIME_PLATFORM.SPARK )
 				DMLScript.USE_LOCAL_SPARK_CONFIG = true;
 			
@@ -100,7 +165,7 @@ public class PoolBackwardTest extends AutomatedTestBase
 					"" + numChannels, "" + poolSize1, "" + poolSize2, 
 					"" + stride, "" + pad, poolMode, 
 					"" + P, "" + P, 
-					output("B")};
+					output("B"), sparseVal1, sparseVal2};
 			        
 			boolean exceptionExpected = false;
 			int expectedNumberOfJobs = -1;
@@ -109,7 +174,8 @@ public class PoolBackwardTest extends AutomatedTestBase
 			fullRScriptName = RI_HOME + TEST_NAME + ".R";
 			rCmd = "Rscript" + " " + fullRScriptName + " " + imgSize + " " + numImg + 
 					" " + numChannels + " " + poolSize1 + 
-					" " + poolSize2 + " " + stride + " " + pad + " " +  P + " " + P + " " + expectedDir(); 
+					" " + poolSize2 + " " + stride + " " + pad + " " +  P + " " + P + " " + expectedDir() +
+					" " + sparseVal1 + " " + sparseVal2; 
 			
 			// Run comparison R script
 			runRScript(true);
@@ -124,6 +190,7 @@ public class PoolBackwardTest extends AutomatedTestBase
 			rtplatform = oldRTP;
 			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
 		}
+		
 	}
 	
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2d2196d8/src/test/java/org/apache/sysml/test/integration/functions/tensor/PoolTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/tensor/PoolTest.java b/src/test/java/org/apache/sysml/test/integration/functions/tensor/PoolTest.java
index c064ca6..e1c84c5 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/tensor/PoolTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/tensor/PoolTest.java
@@ -47,14 +47,14 @@ public class PoolTest extends AutomatedTestBase
 	public void testMaxPool2DDense1() 
 	{
 		int numImg = 1; int imgSize = 6; int numChannels = 1;  int stride = 2; int pad = 0; int poolSize1 = 2; int poolSize2 = 2;
-		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max");
+		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max", false);
 	}
 	
 	@Test
 	public void testMaxPool2DDense2() 
 	{
 		int numImg = 2; int imgSize = 6; int numChannels = 1;  int stride = 1; int pad = 0; int poolSize1 = 2; int poolSize2 = 2;
-		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max");
+		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max", false);
 	}
 	
 	
@@ -62,14 +62,43 @@ public class PoolTest extends AutomatedTestBase
 	public void testMaxPool2DDense3() 
 	{
 		int numImg = 3; int imgSize = 7; int numChannels = 2;  int stride = 2; int pad = 0; int poolSize1 = 3; int poolSize2 = 3;
-		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max");
+		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max", false);
 	}
 	
 	@Test
 	public void testMaxPool2DDense4() 
 	{
 		int numImg = 2; int imgSize = 4; int numChannels = 2;  int stride = 1; int pad = 0; int poolSize1 = 3; int poolSize2 = 3;
-		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max");
+		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max", false);
+	}
+	
+	@Test
+	public void testMaxPool2DSparse1() 
+	{
+		int numImg = 1; int imgSize = 6; int numChannels = 1;  int stride = 2; int pad = 0; int poolSize1 = 2; int poolSize2 = 2;
+		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max", true);
+	}
+	
+	@Test
+	public void testMaxPool2DSparse2() 
+	{
+		int numImg = 2; int imgSize = 6; int numChannels = 1;  int stride = 1; int pad = 0; int poolSize1 = 2; int poolSize2 = 2;
+		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max", true);
+	}
+	
+	
+	@Test
+	public void testMaxPool2DSparse3() 
+	{
+		int numImg = 3; int imgSize = 7; int numChannels = 2;  int stride = 2; int pad = 0; int poolSize1 = 3; int poolSize2 = 3;
+		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max", true);
+	}
+	
+	@Test
+	public void testMaxPool2DSparse4() 
+	{
+		int numImg = 2; int imgSize = 4; int numChannels = 2;  int stride = 1; int pad = 0; int poolSize1 = 3; int poolSize2 = 3;
+		runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max", true);
 	}
 	
 	// ----------------------------------------
@@ -78,14 +107,14 @@ public class PoolTest extends AutomatedTestBase
 	public void testMaxPool2DDense1SP() 
 	{
 		int numImg = 1; int imgSize = 50; int numChannels = 1;  int stride = 2; int pad = 0; int poolSize1 = 2; int poolSize2 = 2;
-		runPoolTest(ExecType.SPARK, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max");
+		runPoolTest(ExecType.SPARK, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max", false);
 	}
 	
 	@Test
 	public void testMaxPool2DDense2SP() 
 	{
 		int numImg = 2; int imgSize = 6; int numChannels = 1;  int stride = 1; int pad = 0; int poolSize1 = 2; int poolSize2 = 2;
-		runPoolTest(ExecType.SPARK, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max");
+		runPoolTest(ExecType.SPARK, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max", false);
 	}
 	
 	
@@ -93,14 +122,14 @@ public class PoolTest extends AutomatedTestBase
 	public void testMaxPool2DDense3SP() 
 	{
 		int numImg = 3; int imgSize = 7; int numChannels = 2;  int stride = 2; int pad = 0; int poolSize1 = 3; int poolSize2 = 3;
-		runPoolTest(ExecType.SPARK, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max");
+		runPoolTest(ExecType.SPARK, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max", false);
 	}
 	
 	@Test
 	public void testMaxPool2DDense4SP() 
 	{
 		int numImg = 2; int imgSize = 4; int numChannels = 2;  int stride = 1; int pad = 0; int poolSize1 = 3; int poolSize2 = 3;
-		runPoolTest(ExecType.SPARK, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max");
+		runPoolTest(ExecType.SPARK, imgSize, numImg, numChannels, stride, pad, poolSize1, poolSize2, "max", false);
 	}
 	
 	/**
@@ -109,7 +138,7 @@ public class PoolTest extends AutomatedTestBase
 	 * @param sparse
 	 */
 	public void runPoolTest( ExecType et, int imgSize, int numImg, int numChannels, int stride, 
-			int pad, int poolSize1, int poolSize2, String poolMode) 
+			int pad, int poolSize1, int poolSize2, String poolMode, boolean sparse) 
 	{
 		RUNTIME_PLATFORM oldRTP = rtplatform;
 			
@@ -117,13 +146,14 @@ public class PoolTest extends AutomatedTestBase
 		
 		try
 		{
-		    TestConfiguration config = getTestConfiguration(TEST_NAME);
-		    if(et == ExecType.SPARK) {
-		    	rtplatform = RUNTIME_PLATFORM.SPARK;
-		    }
-		    else {
-		    	rtplatform = (et==ExecType.MR)? RUNTIME_PLATFORM.HADOOP : RUNTIME_PLATFORM.SINGLE_NODE;
-		    }
+			String sparseVal = (""+sparse).toUpperCase();
+			TestConfiguration config = getTestConfiguration(TEST_NAME);
+			if(et == ExecType.SPARK) {
+				rtplatform = RUNTIME_PLATFORM.SPARK;
+			}
+			else {
+				rtplatform = (et==ExecType.MR)? RUNTIME_PLATFORM.HADOOP : RUNTIME_PLATFORM.SINGLE_NODE;
+			}
 			if( rtplatform == RUNTIME_PLATFORM.SPARK )
 				DMLScript.USE_LOCAL_SPARK_CONFIG = true;
 			
@@ -136,7 +166,7 @@ public class PoolTest extends AutomatedTestBase
 			programArgs = new String[]{"-explain", "-args",  "" + imgSize, "" + numImg, 
 					"" + numChannels, "" + poolSize1, "" + poolSize2, 
 					"" + stride, "" + pad, poolMode, 
-					output("B")};
+					output("B"), sparseVal};
 			        
 			boolean exceptionExpected = false;
 			int expectedNumberOfJobs = -1;
@@ -145,7 +175,7 @@ public class PoolTest extends AutomatedTestBase
 			fullRScriptName = RI_HOME + TEST_NAME + ".R";
 			rCmd = "Rscript" + " " + fullRScriptName + " " + imgSize + " " + numImg + 
 					" " + numChannels + " " + poolSize1 + 
-					" " + poolSize2 + " " + stride + " " + pad + " " + expectedDir(); 
+					" " + poolSize2 + " " + stride + " " + pad + " " + expectedDir() + " " + sparseVal; 
 			
 			// Run comparison R script
 			runRScript(true);
@@ -162,4 +192,5 @@ public class PoolTest extends AutomatedTestBase
 		}
 	}
 	
+	
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2d2196d8/src/test/scripts/functions/tensor/Conv2DBackwardDataTest.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/tensor/Conv2DBackwardDataTest.R b/src/test/scripts/functions/tensor/Conv2DBackwardDataTest.R
index e66d9e2..a251f7a 100644
--- a/src/test/scripts/functions/tensor/Conv2DBackwardDataTest.R
+++ b/src/test/scripts/functions/tensor/Conv2DBackwardDataTest.R
@@ -34,7 +34,22 @@ Q=as.integer(args[9])
 w=matrix(seq(1, numFilters*numChannels*filterSize*filterSize), numFilters, numChannels*filterSize*filterSize, byrow=TRUE)
 dout=matrix(seq(1, numImg*numFilters*P*Q), numImg, numFilters*P*Q, byrow=TRUE)
 
-
+# Sparse case: zero out cells at or below the mean and keep those zeros.
+# Dense case: center on the mean. Centering a masked matrix would turn its
+# zeros into -mean and silently make the "sparse" input dense again.
+if (as.logical(args[11])) {
+  zero_mask = (w - mean(w)) > 0
+  w = w * zero_mask
+} else {
+  w = w - mean(w)
+}
+if (as.logical(args[12])) {
+  zero_mask = (dout - mean(dout)) > 0
+  dout = dout * zero_mask
+} else {
+  dout = dout - mean(dout)
+}
+
 col2im <- function(img_cols, C, Hin, Win, Hf, Wf,
                    strideh, stridew, reduction) {
 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2d2196d8/src/test/scripts/functions/tensor/Conv2DBackwardDataTest.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/tensor/Conv2DBackwardDataTest.dml b/src/test/scripts/functions/tensor/Conv2DBackwardDataTest.dml
index 78b2dee..c10ac37 100644
--- a/src/test/scripts/functions/tensor/Conv2DBackwardDataTest.dml
+++ b/src/test/scripts/functions/tensor/Conv2DBackwardDataTest.dml
@@ -32,5 +32,20 @@ Q = $9
 # Assumption: NCHW image format
 w=matrix(seq(1, numFilters*numChannels*filterSize*filterSize), rows=numFilters, cols=numChannels*filterSize*filterSize)
 dout=matrix(seq(1, numImg*numFilters*P*Q), rows=numImg, cols=numFilters*P*Q)
+# Sparse case: zero out cells at or below the mean and keep those zeros.
+# Dense case: center on the mean. Centering a masked matrix would turn its
+# zeros into -mean and silently make the "sparse" input dense again.
+if($11) {
+	zero_mask = (w - mean(w)) > 0
+	w = w * zero_mask
+} else {
+	w = w - mean(w)
+}
+if($12) {
+	zero_mask = (dout - mean(dout)) > 0
+	dout = dout * zero_mask
+} else {
+	dout = dout - mean(dout)
+}
 dx = conv2d_backward_data(w, dout, stride=[stride, stride], padding=[pad, pad], input_shape=[numImg, numChannels, imgSize, imgSize], filter_shape=[numFilters, numChannels, filterSize, filterSize])
 write(dx, $10, format="text")
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2d2196d8/src/test/scripts/functions/tensor/Conv2DBackwardTest.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/tensor/Conv2DBackwardTest.R b/src/test/scripts/functions/tensor/Conv2DBackwardTest.R
index 91e0065..a6bbdca 100644
--- a/src/test/scripts/functions/tensor/Conv2DBackwardTest.R
+++ b/src/test/scripts/functions/tensor/Conv2DBackwardTest.R
@@ -34,7 +34,22 @@ Q=as.integer(args[9])
 x=matrix(seq(1, numImg*numChannels*imgSize*imgSize), numImg, numChannels*imgSize*imgSize, byrow=TRUE)
 dout=matrix(seq(1, numImg*numFilters*P*Q), numImg, numFilters*P*Q, byrow=TRUE)
 
-
+# Sparse case: zero out cells at or below the mean and keep those zeros.
+# Dense case: center on the mean. Centering a masked matrix would turn its
+# zeros into -mean and silently make the "sparse" input dense again.
+if (as.logical(args[11])) {
+  zero_mask = (x - mean(x)) > 0
+  x = x * zero_mask
+} else {
+  x = x - mean(x)
+}
+if (as.logical(args[12])) {
+  zero_mask = (dout - mean(dout)) > 0
+  dout = dout * zero_mask
+} else {
+  dout = dout - mean(dout)
+}
+
 pad_image <- function(img, Hin, Win, padh, padw){
   C = nrow(img)
   img_padded = matrix(0, C, (Hin+2*padh)*(Win+2*padw))  # zeros

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2d2196d8/src/test/scripts/functions/tensor/Conv2DBackwardTest.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/tensor/Conv2DBackwardTest.dml b/src/test/scripts/functions/tensor/Conv2DBackwardTest.dml
index 155c77b..c98e52b 100644
--- a/src/test/scripts/functions/tensor/Conv2DBackwardTest.dml
+++ b/src/test/scripts/functions/tensor/Conv2DBackwardTest.dml
@@ -32,5 +32,20 @@ Q = $9
 # Assumption: NCHW image format
 x=matrix(seq(1, numImg*numChannels*imgSize*imgSize), rows=numImg, cols=numChannels*imgSize*imgSize)
 dout=matrix(seq(1, numImg*numFilters*P*Q), rows=numImg, cols=numFilters*P*Q)
+# Sparse case: zero out cells at or below the mean and keep those zeros.
+# Dense case: center on the mean. Centering a masked matrix would turn its
+# zeros into -mean and silently make the "sparse" input dense again.
+if($11) {
+	zero_mask = (x - mean(x)) > 0
+	x = x * zero_mask
+} else {
+	x = x - mean(x)
+}
+if($12) {
+	zero_mask = (dout - mean(dout)) > 0
+	dout = dout * zero_mask
+} else {
+	dout = dout - mean(dout)
+}
 dw = conv2d_backward_filter(x, dout, stride=[stride, stride], padding=[pad, pad], input_shape=[numImg, numChannels, imgSize, imgSize], filter_shape=[numFilters, numChannels, filterSize, filterSize])
 write(dw, $10, format="text")
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2d2196d8/src/test/scripts/functions/tensor/Conv2DTest.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/tensor/Conv2DTest.R b/src/test/scripts/functions/tensor/Conv2DTest.R
index 15e0e81..bec1ed7 100644
--- a/src/test/scripts/functions/tensor/Conv2DTest.R
+++ b/src/test/scripts/functions/tensor/Conv2DTest.R
@@ -32,6 +32,16 @@ pad=as.integer(args[7])
 x=matrix(seq(1, numImg*numChannels*imgSize*imgSize), numImg, numChannels*imgSize*imgSize, byrow=TRUE)
 w=matrix(seq(1, numFilters*numChannels*filterSize*filterSize), numFilters, numChannels*filterSize*filterSize, byrow=TRUE)
 
+if(as.logical(args[9])) {
+	zero_mask = (x - mean(x)) > 0 
+	x = x * zero_mask
+}
+if(as.logical(args[10])) {
+	zero_mask = (w - mean(w)) > 0 
+	w = w * zero_mask
+}
+x = x - mean(x)
+w = w - mean(w)
 pad_image <- function(img, Hin, Win, padh, padw){
   C = nrow(img)
   img_padded = matrix(0, C, (Hin+2*padh)*(Win+2*padw), byrow=TRUE)  # zeros