Posted to commits@systemml.apache.org by ni...@apache.org on 2017/05/29 23:21:57 UTC

[1/2] incubator-systemml git commit: [SYSTEMML-540] Refactored LibMatrixDNN to reduce instruction cache misses

Repository: incubator-systemml
Updated Branches:
  refs/heads/master 28c92b93f -> 19eed8f38


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19eed8f3/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
new file mode 100644
index 0000000..40a39f0
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
@@ -0,0 +1,541 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.runtime.matrix.data;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.concurrent.Callable;
+
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.instructions.InstructionUtils;
+import org.apache.sysml.runtime.util.ConvolutionUtils;
+import org.apache.sysml.utils.NativeHelper;
+
+
+public class LibMatrixDNNHelper {
+	
+	// *********************************** low-level runtime operator selection ***********************************************
+	// *********************************** based on runtime properties (sparsity, native, etc) ********************************
+	// These methods help reduce branch mispredictions and instruction-cache misses.
+	// They also simplify the design of LibMatrixDNN and ease code maintenance.
+	
+	/**
+	 * Factory method that returns a list of callable tasks for performing the maxpooling operation
+	 * 
+	 * @param params convolution parameters
+	 * @return list of callable tasks for performing the maxpooling operation
+	 * @throws DMLRuntimeException if error occurs
+	 */
+	public static ArrayList<Callable<Long>> getMaxPoolingWorkers(ConvolutionParameters params) throws DMLRuntimeException {
+		ArrayList<Callable<Long>> ret = new ArrayList<Callable<Long>>();
+		int k = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
+		int taskSize = (int)(Math.ceil((double)params.N / k));
+		for(int i = 0; i*taskSize < params.N; i++) {
+			if(params.input1.isInSparseFormat())
+				ret.add(new LibMatrixDNNPoolingHelper.SparseMaxPooling(i*taskSize, Math.min((i+1)*taskSize, params.N), params));
+			else
+				ret.add(new LibMatrixDNNPoolingHelper.DenseMaxPooling(i*taskSize, Math.min((i+1)*taskSize, params.N), params));
+		}
+		return ret;
+	}
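+	
+	// --- Illustrative sketch (editor's addition, not part of this commit) ---
+	// How a caller might run the workers returned by these factory methods on a
+	// thread pool; the pool setup below is an assumption, not SystemML's code.
+	private static void runWorkersExample(ConvolutionParameters params) throws Exception {
+		int k = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
+		java.util.concurrent.ExecutorService pool = java.util.concurrent.Executors.newFixedThreadPool(k);
+		try {
+			for(java.util.concurrent.Future<Long> f : pool.invokeAll(getMaxPoolingWorkers(params)))
+				f.get(); // surfaces any exception thrown inside a worker
+		}
+		finally {
+			pool.shutdown();
+		}
+	}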
+	
+	/**
+	 * Factory method that returns a list of callable tasks for performing the maxpooling backward operation
+	 * 
+	 * @param params convolution parameters
+	 * @param performReluBackward whether to perform the ReLU backward operation as part of maxpooling backward
+	 * @return list of callable tasks for performing the maxpooling backward operation
+	 * @throws DMLRuntimeException if error occurs
+	 */
+	public static ArrayList<Callable<Long>> getMaxPoolingBackwardWorkers(ConvolutionParameters params, boolean performReluBackward) throws DMLRuntimeException {
+		ArrayList<Callable<Long>> ret = new ArrayList<Callable<Long>>();
+		int k = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
+		int taskSize = (int)(Math.ceil((double)params.N / k));
+		for(int i = 0; i*taskSize < params.N; i++) {
+			if(!params.input1.isInSparseFormat()) {
+				if(!params.input2.isInSparseFormat()) 
+					ret.add(new LibMatrixDNNPoolingBackwardHelper.PoolingBackwardDenseDense(i*taskSize, Math.min((i+1)*taskSize, params.N), params, performReluBackward));
+				else
+					ret.add(new LibMatrixDNNPoolingBackwardHelper.PoolingBackwardDenseSparse(i*taskSize, Math.min((i+1)*taskSize, params.N), params, performReluBackward));
+			}
+			else {
+				if(!params.input2.isInSparseFormat()) 
+					ret.add(new LibMatrixDNNPoolingBackwardHelper.PoolingBackwardSparseDense(i*taskSize, Math.min((i+1)*taskSize, params.N), params, performReluBackward));
+				else
+					ret.add(new LibMatrixDNNPoolingBackwardHelper.PoolingBackwardSparseSparse(i*taskSize, Math.min((i+1)*taskSize, params.N), params, performReluBackward));
+			}
+		}
+		return ret;
+	}
+	
+	/**
+	 * Factory method that returns a list of callable tasks for performing the relu backward operation
+	 * 
+	 * @param params convolution parameters
+	 * @return list of callable tasks for performing the relu backward operation
+	 * @throws DMLRuntimeException if error occurs
+	 */
+	public static ArrayList<Callable<Long>> getReluBackwardWorkers(ConvolutionParameters params) throws DMLRuntimeException {
+		ArrayList<Callable<Long>> ret = new ArrayList<Callable<Long>>();
+		int k = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
+		int taskSize = (int)(Math.ceil((double)params.N / k));
+		for(int i = 0; i*taskSize < params.N; i++) {
+			ret.add(new ReluBackward(i*taskSize, Math.min((i+1)*taskSize, params.N), params));
+		}
+		return ret;
+	}
+	
+	/**
+	 * Factory method that returns a list of callable tasks for performing conv2d
+	 * 
+	 * @param params convolution parameters
+	 * @return list of callable tasks for performing conv2d
+	 * @throws DMLRuntimeException if error occurs
+	 */
+	public static ArrayList<Callable<Long>> getConv2dWorkers(ConvolutionParameters params) throws DMLRuntimeException {
+		ArrayList<Callable<Long>> ret = new ArrayList<Callable<Long>>();
+		
+		// Try to create as many tasks as threads. 
+		// Creating more tasks would help with tail (straggler) tasks, but would add the overhead of
+		// maintaining more intermediate data structures such as im2col blocks.
+		int k = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
+		int taskSize = (int)(Math.ceil((double)params.N / k));
+		
+		// TODO: Decide here based on params whether to use LoopedIm2ColConv2dAllChannels or LoopedIm2ColConv2dOneChannel
+		// For now, we stick to the existing approach of converting [1, CHW] to [CRS, PQ], as it yields
+		// matrices large enough for efficient matrix multiplication.
+		boolean allChannels = true;
+		ArrayList<MatrixBlock> filters = null;
+		if(!allChannels) {
+			filters = splitFilter(params);
+		}
+		
+		boolean isEmptyDenseInput = !params.input1.isInSparseFormat() && params.input1.denseBlock == null;
+		
+		for(int i = 0; i*taskSize < params.N; i++) {
+			if(LibMatrixDNN.isEligibleForConv2dSparse(params)) 
+				ret.add(new LibMatrixDNNConv2dHelper.SparseNativeConv2d(i*taskSize, Math.min((i+1)*taskSize, params.N), params));
+			else if(!isEmptyDenseInput && allChannels)
+				ret.add(new LibMatrixDNNConv2dHelper.LoopedIm2ColConv2dAllChannels(i*taskSize, Math.min((i+1)*taskSize, params.N), params));
+			else if(!isEmptyDenseInput && !allChannels)
+				ret.add(new LibMatrixDNNConv2dHelper.LoopedIm2ColConv2dOneChannel(i*taskSize, Math.min((i+1)*taskSize, params.N), params, filters));
+			else
+				throw new DMLRuntimeException("Unsupported operator");
+		}
+		return ret;
+	}
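+	
+	// Task partitioning example: N=100 images on k=8 threads gives
+	// taskSize = ceil(100/8) = 13, i.e. seven tasks of 13 rows each plus a
+	// tail task of 9 rows.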
+	
+	/**
+	 * Factory method that returns a list of callable tasks for performing conv2d backward filter
+	 * 
+	 * @param params convolution parameters
+	 * @return list of callable tasks for performing conv2d backward filter
+	 * @throws DMLRuntimeException if error occurs
+	 */
+	public static ArrayList<Callable<Long>> getConv2dBackwardFilterWorkers(ConvolutionParameters params) throws DMLRuntimeException {
+		ArrayList<Callable<Long>> ret = new ArrayList<Callable<Long>>();
+		// Try to create as many tasks as threads. 
+		// Creating more tasks would help with tail (straggler) tasks, but would add the overhead of
+		// maintaining more intermediate data structures such as im2col blocks.
+		int k = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
+		int taskSize = (int)(Math.ceil((double)params.N / k));
+		
+		boolean isEmptyDenseInput = (!params.input1.isInSparseFormat() && params.input1.denseBlock == null) || 
+			(!params.input2.isInSparseFormat() && params.input2.denseBlock == null);
+		
+		for(int i = 0; i*taskSize < params.N; i++) {
+			if(LibMatrixDNN.isEligibleForConv2dBackwardFilterSparseDense(params)) 
+				ret.add(new LibMatrixDNNConv2dBackwardFilterHelper.SparseNativeConv2dBackwardFilterDense(i*taskSize, Math.min((i+1)*taskSize, params.N), params));
+			else if(!isEmptyDenseInput)
+				ret.add(new LibMatrixDNNConv2dBackwardFilterHelper.Conv2dBackwardFilter(i*taskSize, Math.min((i+1)*taskSize, params.N), params));
+			else
+				throw new DMLRuntimeException("Unsupported operator");
+		}
+		return ret;
+	}
+	
+	/**
+	 * Factory method that returns a list of callable tasks for performing conv2d backward data
+	 * 
+	 * @param params convolution parameters
+	 * @return list of callable tasks for performing conv2d backward data
+	 * @throws DMLRuntimeException if error occurs
+	 */
+	public static ArrayList<Callable<Long>> getConv2dBackwardDataWorkers(ConvolutionParameters params) throws DMLRuntimeException {
+		ArrayList<Callable<Long>> ret = new ArrayList<Callable<Long>>();
+		
+		// Try to create as many tasks as threads. 
+		// Creating more tasks would help with tail (straggler) tasks, but would add the overhead of
+		// maintaining more intermediate data structures such as im2col blocks.
+		int k = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
+		int taskSize = (int)(Math.ceil((double)params.N / k));
+		
+		boolean isEmptyDenseInput = (!params.input1.isInSparseFormat() && params.input1.denseBlock == null) || 
+			(!params.input2.isInSparseFormat() && params.input2.denseBlock == null);
+		
+		for(int i = 0; i*taskSize < params.N; i++) {
+			if(LibMatrixDNN.isEligibleForConv2dBackwardDataDense(params)) 
+				ret.add(new LibMatrixDNNConv2dBackwardDataHelper.SparseNativeConv2dBackwardDataDense(i*taskSize, Math.min((i+1)*taskSize, params.N), params));
+			else if(!isEmptyDenseInput)
+				ret.add(new LibMatrixDNNConv2dBackwardDataHelper.Conv2dBackwardData(i*taskSize, Math.min((i+1)*taskSize, params.N), params));
+			else
+				throw new DMLRuntimeException("Unsupported operator");
+		}
+		return ret;
+	}
+	
+	// *********************************** relu backward operator ******************************************************
+	
+	/**
+	 * Performs the operation: (X gt 0) * dout
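+	 * e.g. X = [-2, 3] and dout = [4, 5] yield [0, 5].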
+	 */
+	public static class ReluBackward implements Callable<Long> 
+	{
+		public int _rl; public int _ru; 
+		private final ConvolutionParameters _params; 
+		double [] outputArray; int numOutCols;
+		public ReluBackward(int rl, int ru, ConvolutionParameters params) {
+			_rl = rl; _ru = ru;
+			_params = params;
+			outputArray = params.output.getDenseBlock();
+			numOutCols = params.input1.getNumColumns();
+		}
+		
+		@Override
+		public Long call() throws Exception {
+			if(!_params.input1.isInSparseFormat() && !_params.input2.isInSparseFormat()) {
+				double [] inputArr = _params.input1.getDenseBlock();
+				double [] doutArr = _params.input2.getDenseBlock();
+				for(int i = _rl*numOutCols; i < _ru*numOutCols; i++) {
+					outputArray[i] = inputArr[i] > 0 ? doutArr[i] : 0;
+				}
+			}
+			else {
+				// Perform (X > 0)
+				ConvolutionUtils.scalarOperations(_params.input1, outputArray, _rl*numOutCols, numOutCols, _rl, _ru, 
+						InstructionUtils.parseScalarBinaryOperator(">", false, 0));
+				// Then perform (X > 0) * dout
+				ConvolutionUtils.binaryOperationInPlace(_params.input2, outputArray, _rl*numOutCols, numOutCols, _rl, _ru, 
+						LibMatrixDNN._binaryElementWiseMultiplication);
+			}
+			return 0L;
+		}
+	}
+	
+	// *********************************** utility methods ******************************************************
+	
+	/**
+	 * Computes tensor indexes from a column index such that the column index equals ret[0]*H*W + ret[1]*W + ret[2]
+	 * 
+	 * @param j column index
+	 * @param ret computed tensor indexes
+	 * @param H second-to-last dimension
+	 * @param W last dimension
+	 */
+	static void computeTensorIndexes(int j, int [] ret, int H, int W) {
+		ret[0] = j / (H*W);
+		ret[1] = (j - ret[0]*(H*W))/W;
+		ret[2] = j % W;
+	}
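+	// e.g. H=3, W=4: j=17 decomposes into ret = {1, 1, 1}, since 17 = 1*12 + 1*4 + 1.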
+	
+	// Splits a filter of size [K, CRS] into C filters of size [K, RS], one per channel
+	private static ArrayList<MatrixBlock> splitFilter(ConvolutionParameters _params) {
+		ArrayList<MatrixBlock> ret = new ArrayList<MatrixBlock>();
+		int RS = _params.R*_params.S; int CRS = _params.C*_params.R*_params.S;
+		double [] filter = _params.input2.getDenseBlock(); int S = _params.S;
+		for(int c = 0; c < _params.C; c++) {
+			MatrixBlock mb = new MatrixBlock(_params.K, RS, false);
+			mb.allocateDenseBlock(); long nnz = 0;
+			double [] outputArr = mb.getDenseBlock();
+			if(filter != null) {
+				for(int k = 0; k < _params.K; k++) {
+					for(int rs = 0; rs < RS; rs++) {
+						outputArr[k*RS + rs] = filter[k*CRS + c*RS + rs];
+						nnz += outputArr[k*RS + rs] != 0 ? 1 : 0;
+					}
+				}
+			}
+			else {
+				for(int k = 0; k < _params.K; k++) {
+					if( !_params.input2.sparseBlock.isEmpty(k) ) {
+						int [] tensorIndexes = new int[3];
+						// Copy the non-zeros of channel c from row k of the sparse filter
+						int apos = _params.input2.sparseBlock.pos(k);
+						int alen = _params.input2.sparseBlock.size(k);
+						int[] aix = _params.input2.sparseBlock.indexes(k);
+						double[] avals = _params.input2.sparseBlock.values(k);
+						for(int j=apos; j<apos+alen; j++) {
+							computeTensorIndexes(aix[j], tensorIndexes, _params.R, _params.S);
+							if(c != tensorIndexes[0])
+								continue;
+							int r = tensorIndexes[1];
+							int s = tensorIndexes[2];
+							outputArr[k*RS + r*S + s] = avals[j];
+							nnz += outputArr[k*RS + r*S + s] != 0 ? 1 : 0;
+						}
+					}
+				}
+			}
+			mb.setNonZeros(nnz);
+			ret.add(mb);
+		}
+		return ret;
+	}
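+	
+	// e.g. C=3, R=S=3 (CRS=27): filter c of the returned list holds columns
+	// [c*9, (c+1)*9) of the original [K, 27] filter, i.e. a [K, 9] block per channel.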
+	
+	// Single-threaded matrix multiplication
+	static void singleThreadedMatMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, 
+			boolean recomputeNNZM1, boolean recomputeNNZM2, ConvolutionParameters params) throws DMLRuntimeException {
+		if(!params.enableNative || m1.isInSparseFormat() || m2.isInSparseFormat()) {
+			if(recomputeNNZM1)
+				m1.recomputeNonZeros();
+			if(recomputeNNZM2)
+				m2.recomputeNonZeros();
+			LibMatrixMult.matrixMult(m1, m2, ret, false);
+		}
+		else {
+			ret.sparse = false;
+			if(ret.getDenseBlock() == null)
+				ret.allocateDenseBlock();
+			NativeHelper.matrixMultDenseDense(m1.denseBlock, m2.denseBlock, 
+					ret.denseBlock, m1.getNumRows(), m1.getNumColumns(), m2.getNumColumns(), 1);
+			ret.recomputeNonZeros();
+		}
+	}
+	
+	static void addBias(int _rl, int _ru, double [] outputArr, double [] biasArr, int K, int PQ) {
+		// Adds bias k to all PQ entries of channel k, for images [_rl, _ru)
+		int index = _rl*K*PQ;
+		for(int n = _rl; n < _ru; n++) {
+			for(int k = 0; k < K; k++) {
+				for(int pq = 0; pq < PQ; pq++, index++) {
+					outputArr[index] += biasArr[k];
+				}
+			}
+		}
+	}
+	
+	/**
+	 * Returns the index of cell with maximum value. This method is optimized for dense input
+	 * 
+	 * @param p output feature map height
+	 * @param q output feature map width
+	 * @param inputOffset offset to be used for input index
+	 * @param inputArray input array
+	 * @param params convolution parameters
+	 * @param performReluBackward perform ReLU backward
+	 * @return index of cell with maximum value
+	 */
+	static int getMaxIndex(int p, int q, int inputOffset, double [] inputArray, ConvolutionParameters params, boolean performReluBackward) {
+		int start_index_h = params.start_indexes_h[p];
+		int end_index_h = params.end_indexes_h[p];
+		int start_index_w = params.start_indexes_w[q];
+		int end_index_w = params.end_indexes_w[q];
+		
+		int maxIndex = -1; 
+		double maxVal = -Double.MAX_VALUE;
+		
+		// Note: padding is not treated as zero here, and hence we do not initialize maxVal = 0 when
+		// start_index_h < 0 || start_index_w < 0 || end_index_h >= params.H || end_index_w >= params.W
+		
+		// Find maxIndex
+		double currVal = -1;
+		for (int h = start_index_h; h < end_index_h; h++) {
+			for (int w = start_index_w; w < end_index_w; w++) {
+				currVal = inputArray[inputOffset + h*params.W + w];
+				currVal = performReluBackward && currVal < 0 ? 0 : currVal;
+				if(maxVal < currVal) {
+					maxIndex = inputOffset + h*params.W + w;
+					maxVal = currVal;
+				}
+			}
+		}
+		return maxIndex;
+	}
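+	
+	// Example (assuming the precomputed window arrays follow the usual pooling
+	// definition start = p*stride_h - pad_h clipped to [0, H)): for 2x2 pooling
+	// with stride 2 and pad 0, output cell (p,q) scans input rows [2p, 2p+2)
+	// and columns [2q, 2q+2).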
+	
+	/**
+	 * Returns the index of cell with maximum value. This method is optimized for sparse input
+	 * 
+	 * @param p output feature map height
+	 * @param q output feature map width
+	 * @param inputOffset offset to be used for input index
+	 * @param n number of images
+	 * @param c number of channels 
+	 * @param input input matrix
+	 * @param params convolution parameters
+	 * @param performReluBackward perform ReLU backward
+	 * @return index of the cell with maximum value
+	 * @throws DMLRuntimeException if error occurs
+	 */
+	static int getMaxIndexSparse(int p, int q, int inputOffset, int n, int c, MatrixBlock input, ConvolutionParameters params, boolean performReluBackward) throws DMLRuntimeException {
+		if(!input.isInSparseFormat())
+			throw new DMLRuntimeException("Incorrect usage: Only sparse format supported");
+		
+		int [] tensorIndexes = new int[3];
+		
+		int start_index_h = params.start_indexes_h[p];
+		int end_index_h = params.end_indexes_h[p];
+		int start_index_w = params.start_indexes_w[q];
+		int end_index_w = params.end_indexes_w[q];
+		
+		int maxIndex = -1; 
+		double maxVal = -Double.MAX_VALUE;
+		
+		// Note: padding is not treated as zero here, and hence we do not initialize maxVal = 0 when
+		// start_index_h < 0 || start_index_w < 0 || end_index_h >= params.H || end_index_w >= params.W
+
+		// input.isEmptyBlock() check is done by the caller
+		if( !input.sparseBlock.isEmpty(n) ) {
+			// Find maxIndex
+			int apos = input.sparseBlock.pos(n);
+			int alen = input.sparseBlock.size(n);
+			int[] aix = input.sparseBlock.indexes(n);
+			double[] avals = input.sparseBlock.values(n);
+			for(int j=apos; j<apos+alen; j++) {
+				computeTensorIndexes(aix[j], tensorIndexes, params.H, params.W);
+				if(c != tensorIndexes[0])
+					continue;
+				int h = tensorIndexes[1];
+				int w = tensorIndexes[2];
+				if(h >= start_index_h && h < end_index_h && w >= start_index_w && w < end_index_w) {
+					double val = performReluBackward && avals[j] < 0 ? 0 : avals[j]; 
+					if(maxVal < val) {
+						maxIndex = inputOffset +  h*params.W + w;
+						maxVal = val;
+					}
+				}
+			}
+		}
+		else {
+			maxIndex = inputOffset;
+		}
+		return maxIndex;
+	}
+	
+	// Writes row n of the matrix into ret in dense format
+	static void getRowInDenseFormat(MatrixBlock input, int n, double []  ret) throws DMLRuntimeException {
+		if(input.getNumColumns() != ret.length) {
+			throw new DMLRuntimeException("Invalid parameters");
+		}
+		// Use temporary array to avoid binary search
+		if(input.isInSparseFormat()) {
+			Arrays.fill(ret, 0);
+			if( !input.sparseBlock.isEmpty(n) ) {
+				int apos = input.sparseBlock.pos(n);
+				int alen = input.sparseBlock.size(n);
+				int[] aix = input.sparseBlock.indexes(n);
+				double[] avals = input.sparseBlock.values(n);
+				for(int j=apos; j<apos+alen; j++)
+					ret[ aix[j] ] = avals[j];
+			}
+		}
+		else {
+			System.arraycopy(input.getDenseBlock(), n*input.getNumColumns(), ret, 0, input.getNumColumns());
+		}
+	}
+	
+	// ------------------------------------------------------------------------------------------------------
+	// Since col2im always operates on an intermediate generated as part of matrix multiplication, it is not
+	// clear a priori which operator to select. Therefore, it is provided as a utility function rather than
+	// as an operator (like im2col or rotate180).
+	
+	// Converts a PQ X CRS input matrix and accumulates it into row outputN (1 X CHW) of the output
+	static void doCol2imOverSingleImage(int outputN, MatrixBlock input, ConvolutionParameters params) throws DMLRuntimeException {
+		if(input.rlen != params.P*params.Q || input.clen != params.C*params.R*params.S) {
+			throw new DMLRuntimeException("Incorrect input dimensions");
+		}
+		
+		double [] outputArray = null;
+		if (!params.output.isInSparseFormat())
+			outputArray = params.output.getDenseBlock();
+		else {
+			throw new DMLRuntimeException("Only dense output is implemented");
+		}
+		
+		if(!input.isInSparseFormat()) {
+			double [] inputArray = input.getDenseBlock();
+			doCol2IMDenseInput(0, outputN, inputArray, outputArray, params);
+		}
+		else {
+			if(!input.isEmptyBlock()) {
+				int [] tensorIndexes = new int[3];
+				for(int i = 0; i < input.getNumRows(); i++) {
+					if( !input.sparseBlock.isEmpty(i) ) {
+						computeTensorIndexes(i, tensorIndexes, params.P, params.Q);
+						int p = tensorIndexes[1];
+						int q = tensorIndexes[2];
+						if(tensorIndexes[0] != 0) 
+							throw new DMLRuntimeException("Incorrect tensor indexes: " + tensorIndexes[0] + " != 0 <" + p + " " + q + " " + tensorIndexes[0] + " " + params.P + " " + params.Q + ">");
+						
+						int apos = input.sparseBlock.pos(i);
+						int alen = input.sparseBlock.size(i);
+						int[] aix = input.sparseBlock.indexes(i);
+						double[] avals = input.sparseBlock.values(i);
+						for(int j = apos; j < apos+alen; j++) {
+							computeTensorIndexes(aix[j], tensorIndexes, params.R, params.S);
+							int c = tensorIndexes[0];
+							int r = tensorIndexes[1];
+							int s = tensorIndexes[2];
+							int h = p*params.stride_h + r - params.pad_h;
+							int w = q*params.stride_w + s - params.pad_w;
+							if(h >= 0 && h < params.H && w >= 0 && w < params.W) {
+								int outIndex = outputN*params.C*params.H*params.W + c*params.H*params.W + h*params.W + w;
+								outputArray[outIndex] += avals[j];
+							}
+						}
+					}
+				}
+			}
+		}
+	}
+	
+	// Converts a PQ X CRS input matrix and writes it to 1 X CHW if inputN == 0,
+	// or converts an NPQ X CRS input matrix and writes it to N X CHW
+	private static void doCol2IMDenseInput(int inputN, int outputN, double [] inputArray, double [] outputArray, ConvolutionParameters params) throws DMLRuntimeException {
+		final int outputNOffset = outputN*params.C*params.H*params.W;
+		for (int p = 0; p < params.P; p++) {
+			// h = p*params.stride_h + r - params.pad_h
+			//   = r + hOffset
+			// Based on restrictions: h >= 0 and r >= 0 and h < params.H and r < params.R, we get
+			// max(0, - hOffset) <= r < min(params.R, params.H - hOffset)
+			final int hOffset = p*params.stride_h - params.pad_h;
+			final int rStart = Math.max(0, - hOffset);
+			final int rEnd = Math.min(params.R, params.H - hOffset);
+			for (int q = 0; q < params.Q; q++) {
+				// Using the same logic as above on following:
+				// w = q*params.stride_w + s - params.pad_w
+				final int wOffset = q*params.stride_w - params.pad_w;
+				final int sStart = Math.max(0, - wOffset);
+				final int sEnd = Math.min(params.S, params.W - wOffset);
+				final int tempOffset = (inputN*params.P*params.Q + p*params.Q + q)*params.C*params.R*params.S;
+				for (int c = 0; c < params.C; c++) {
+					final int outOffset = outputNOffset + c*params.H*params.W;
+					final int inputOffset = tempOffset + c*params.R*params.S;
+					for (int r = rStart; r < rEnd; r++) {
+						for (int s = sStart; s < sEnd; s++) {
+							int inputIndex = inputOffset + r*params.S + s;
+							int outIndex = outOffset + (hOffset + r)*params.W + wOffset + s;
+							outputArray[outIndex] += inputArray[inputIndex];
+						}
+					}
+				}
+			}
+		}
+	}
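+	
+	// Bounds example for the derivation above: stride_h=2, pad_h=1, R=3, H=5 and
+	// p=0 give hOffset=-1, hence rStart=max(0,1)=1 and rEnd=min(3,6)=3, so the
+	// filter rows r in {1,2} map to input rows h = hOffset + r in {0,1}.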
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19eed8f3/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2ColHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2ColHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2ColHelper.java
new file mode 100644
index 0000000..9ae39bf
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2ColHelper.java
@@ -0,0 +1,386 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.runtime.matrix.data;
+
+import java.util.Arrays;
+
+/**
+ * This class contains the different implementations of the im2col operation
+ */
+public class LibMatrixDNNIm2ColHelper {
+	
+	static interface Im2colWorker {
+		public void execute(int n);
+		public void execute(int n, int c);
+		public static Im2colWorker getWorker(MatrixBlock input, MatrixBlock im2ColOutBlock, ConvolutionParameters params, boolean allChannels) {
+			if(im2ColOutBlock.isInSparseFormat() || im2ColOutBlock.getDenseBlock() == null)
+				throw new RuntimeException("im2col output is always in dense format");
+			if(allChannels) {
+				if(!input.isInSparseFormat()) {
+					if (params.stride_h == 1 && params.stride_w == 1 && params.pad_h == 0 && params.pad_w == 0) 
+						return new DenseIm2colWorkerStride1Pad0AllChannels(input.getDenseBlock(), im2ColOutBlock.getDenseBlock(), params);
+					else
+						return new DenseIm2colWorkerAllChannels(input.getDenseBlock(), im2ColOutBlock.getDenseBlock(), params);
+				}
+				else 
+					return new SparseIm2colWorkerAllChannels(input, im2ColOutBlock, params);
+			}
+			else {
+				if(!input.isInSparseFormat()) {
+					if (params.stride_h == 1 && params.stride_w == 1 && params.pad_h == 0 && params.pad_w == 0) 
+						return new DenseIm2colWorkerStride1Pad0(input.getDenseBlock(), im2ColOutBlock.getDenseBlock(), params);
+					else
+						return new DenseIm2colWorker(input.getDenseBlock(), im2ColOutBlock.getDenseBlock(), params);
+				}
+				else 
+					return new SparseIm2colWorker(input, im2ColOutBlock, params);
+			}
+		}
+	}
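+	
+	// --- Illustrative sketch (editor's addition, not part of this commit) ---
+	// How a conv2d task might drive an Im2colWorker for rows [rl, ru); the
+	// subsequent filter multiplication is elided.
+	private static void im2colUsageExample(MatrixBlock input, ConvolutionParameters params, int rl, int ru) {
+		MatrixBlock im2ColOutBlock = new MatrixBlock(params.C*params.R*params.S, params.P*params.Q, false);
+		im2ColOutBlock.allocateDenseBlock();
+		Im2colWorker worker = Im2colWorker.getWorker(input, im2ColOutBlock, params, true);
+		for(int n = rl; n < ru; n++) {
+			worker.execute(n); // expands row n of the [N, CHW] input into [CRS, PQ]
+			// ... multiply the [K, CRS] filter with the [CRS, PQ] block to get [K, PQ] ...
+		}
+	}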
+	
+	/**
+	 * Special-case operator for performing dense im2col on a single channel when stride = [1, 1] and pad = [0, 0], using System.arraycopy
+	 */
+	static class DenseIm2colWorkerStride1Pad0 implements Im2colWorker {
+		double [] inputArray; double [] outputArray; 
+		int CRS; int S; int R; int P; int Q; int CHW; int H; int W;
+		public DenseIm2colWorkerStride1Pad0(double [] inputArray, double [] outputArray, ConvolutionParameters params) {
+			this.inputArray = inputArray;
+			this.outputArray = outputArray;
+			this.CRS = params.C * params.R * params.S;
+			this.H = params.H; this.W = params.W; this.R = params.R; this.S = params.S; this.P = params.P; this.Q = params.Q;
+			this.CHW = params.C*params.H*params.W;
+		}
+		
+		@Override
+		public void execute(int n) {
+			throw new RuntimeException("Not supported");
+		}
+
+		@Override
+		public void execute(int n, int cInput) {
+			int nOffset = n * CHW;
+			int RS = R*S;
+			for (int rs = 0; rs < RS; ++rs) {
+				int wOffset = rs % S;
+				int hOffset = rs / S;
+				for (int h = 0; h < P; ++h) {
+					int hPadded = h + hOffset;
+					int outOffset = (rs * P + h) * Q;
+					int inputOffset = nOffset + (cInput * H + hPadded) * W;
+					System.arraycopy(inputArray, inputOffset + wOffset, outputArray, outOffset, Q);
+					int w = Q - 1;
+					int wPadded = w + wOffset;
+					if (hPadded < H && wPadded < W)
+						outputArray[outOffset + w] = inputArray[inputOffset + wPadded];
+					else
+						outputArray[outOffset + w] = 0;
+				}
+			}
+		}
+	}
+	
+	/**
+	 * Special-case operator for performing dense im2col over all channels when stride = [1, 1] and pad = [0, 0], using System.arraycopy
+	 */
+	static class DenseIm2colWorkerStride1Pad0AllChannels implements Im2colWorker {
+		double [] inputArray; double [] outputArray; 
+		int CRS; int S; int R; int P; int Q; int CHW; int H; int W;
+		public DenseIm2colWorkerStride1Pad0AllChannels(double [] inputArray, double [] outputArray, ConvolutionParameters params) {
+			this.inputArray = inputArray;
+			this.outputArray = outputArray;
+			this.CRS = params.C * params.R * params.S;
+			this.H = params.H; this.W = params.W; this.R = params.R; this.S = params.S; this.P = params.P; this.Q = params.Q;
+			this.CHW = params.C*params.H*params.W;
+		}
+		
+		@Override
+		public void execute(int n, int c) {
+			throw new RuntimeException("Not supported");
+		}
+
+		@Override
+		public void execute(int n) {
+			int nOffset = n * CHW;
+			for (int c = 0; c < CRS; ++c) {
+				int wOffset = c % S;
+				int hOffset = (c / S) % R;
+				int cInput = c / R / S;
+				for (int h = 0; h < P; ++h) {
+					int hPadded = h + hOffset;
+					int outOffset = (c * P + h) * Q;
+					int inputOffset = nOffset + (cInput * H + hPadded) * W;
+					System.arraycopy(inputArray, inputOffset + wOffset, outputArray, outOffset, Q);
+					int w = Q - 1;
+					int wPadded = w + wOffset;
+					if (hPadded < H && wPadded < W)
+						outputArray[outOffset + w] = inputArray[inputOffset + wPadded];
+					else
+						outputArray[outOffset + w] = 0;
+				}
+			}
+		}
+	}
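+	
+	// e.g. C=1, H=W=3, R=S=2 (so P=Q=2): output row (r,s), column (p,q) holds
+	// input[(p+r)*W + (q+s)]; row 0 of the 4x4 im2col output is the top-left
+	// 2x2 patch flattened as [x00, x01, x10, x11].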
+	
+	/**
+	 * Performs dense im2col on a single channel (general stride and pad)
+	 */
+	static class DenseIm2colWorker implements Im2colWorker {
+		double [] inputArray; double [] outputArray; 
+		int CRS; int S; int R; int P; int Q; int CHW; int H; int W; 
+		int stride_h; int stride_w; int pad_h; int pad_w;
+		public DenseIm2colWorker(double [] inputArray, double [] outputArray, ConvolutionParameters params) {
+			this.inputArray = inputArray;
+			this.outputArray = outputArray;
+			this.CRS = params.C * params.R * params.S;
+			this.H = params.H; this.W = params.W; this.R = params.R; this.S = params.S; this.P = params.P; this.Q = params.Q;
+			this.CHW = params.C*params.H*params.W;
+			this.stride_h = params.stride_h; this.stride_w = params.stride_w;
+			this.pad_h = params.pad_h; this.pad_w = params.pad_w;
+		}
+		
+		@Override
+		public void execute(int n) {
+			throw new RuntimeException("Not supported");
+		}
+
+		@Override
+		public void execute(int n, int cInput) {
+			int nOffset = n * CHW; int RS = R*S;
+			for (int rs = 0; rs < RS; ++rs) {
+				int wOffset = rs % S;
+				int hOffset = rs / S;
+				for (int h = 0; h < P; ++h) {
+					int outOffset = (rs * P + h) * Q;
+					int hPadded = h * stride_h - pad_h + hOffset;
+					int inputOffset = nOffset + (cInput * H + hPadded) * W;
+					if (hPadded < 0 || hPadded >= H) {
+						Arrays.fill(outputArray, outOffset, outOffset+Q, 0);
+					} else {
+						for (int w = 0; w < Q; ++w) {
+							int wPadded = w * stride_w - pad_w + wOffset;
+							if (wPadded >= 0 && wPadded < W)
+								outputArray[outOffset + w] = inputArray[inputOffset + wPadded];
+							else
+								outputArray[outOffset + w] = 0;
+						}
+					}
+				}
+			}
+		}
+	}
+	
+	/**
+	 * Performs dense im2col over all channels (general stride and pad)
+	 */
+	static class DenseIm2colWorkerAllChannels implements Im2colWorker {
+		double [] inputArray; double [] outputArray; 
+		int CRS; int S; int R; int P; int Q; int CHW; int H; int W; 
+		int stride_h; int stride_w; int pad_h; int pad_w;
+		public DenseIm2colWorkerAllChannels(double [] inputArray, double [] outputArray, ConvolutionParameters params) {
+			this.inputArray = inputArray;
+			this.outputArray = outputArray;
+			this.CRS = params.C * params.R * params.S;
+			this.H = params.H; this.W = params.W; this.R = params.R; this.S = params.S; this.P = params.P; this.Q = params.Q;
+			this.CHW = params.C*params.H*params.W;
+			this.stride_h = params.stride_h; this.stride_w = params.stride_w;
+			this.pad_h = params.pad_h; this.pad_w = params.pad_w;
+		}
+		
+		@Override
+		public void execute(int n, int c) {
+			throw new RuntimeException("Not supported");
+		}
+
+		@Override
+		public void execute(int n) {
+			int nOffset = n * CHW;
+			for (int c = 0; c < CRS; ++c) {
+				int wOffset = c % S;
+				int hOffset = (c / S) % R;
+				int cInput = c / R / S;
+				for (int h = 0; h < P; ++h) {
+					int outOffset = (c * P + h) * Q;
+					int hPadded = h * stride_h - pad_h + hOffset;
+					int inputOffset = nOffset + (cInput * H + hPadded) * W;
+					if (hPadded < 0 || hPadded >= H) {
+						Arrays.fill(outputArray, outOffset, outOffset+Q, 0);
+					} else {
+						for (int w = 0; w < Q; ++w) {
+							int wPadded = w * stride_w - pad_w + wOffset;
+							if (wPadded >= 0 && wPadded < W)
+								outputArray[outOffset + w] = inputArray[inputOffset + wPadded];
+							else
+								outputArray[outOffset + w] = 0;
+						}
+					}
+				}
+			}
+		}
+	}
+	
+	/**
+	 * Performs sparse im2col over all channels (general stride and pad)
+	 */
+	static class SparseIm2colWorkerAllChannels implements Im2colWorker {
+		MatrixBlock input; double [] outputArray; 
+		int CRS; int S; int R; int P; int Q; int H; int W; 
+		int stride_h; int stride_w; int pad_h; int pad_w; double [] temp;
+		public SparseIm2colWorkerAllChannels(MatrixBlock input, MatrixBlock im2ColOutBlock, ConvolutionParameters params) {
+			this.input = input;
+			this.outputArray = im2ColOutBlock.getDenseBlock();
+			this.CRS = params.C * params.R * params.S;
+			this.H = params.H; this.W = params.W; this.R = params.R; this.S = params.S; this.P = params.P; this.Q = params.Q;
+			this.stride_h = params.stride_h; this.stride_w = params.stride_w;
+			this.pad_h = params.pad_h; this.pad_w = params.pad_w;
+			temp = new double[input.getNumColumns()];
+		}
+		
+		@Override
+		public void execute(int n, int c) {
+			throw new RuntimeException("Not supported");
+		}
+
+		@Override
+		public void execute(int n) {
+			// Use a temporary dense array to avoid a binary search per value access.
+			// Since the access pattern depends on ConvolutionParameters, this serves as a pragmatic fix.
+			fillTemp(input, n);
+			for (int c = 0; c < CRS; ++c) {
+				int wOffset = c % S;
+				int hOffset = (c / S) % R;
+				int cInput = c / R / S;
+				for (int h = 0; h < P; ++h) {
+					int outOffset = (c * P + h) * Q;
+					int hPadded = h * stride_h - pad_h + hOffset;
+					int tempOffset = (cInput * H + hPadded) * W;
+					if (hPadded < 0 || hPadded >= H) {
+						Arrays.fill(outputArray, outOffset, outOffset+Q, 0);
+					} else {
+						for (int w = 0; w < Q; ++w) {
+							int wPadded = w * stride_w - pad_w + wOffset;
+							if (wPadded >= 0 && wPadded < W) 
+								outputArray[outOffset + w] = temp[tempOffset + wPadded];
+							else
+								outputArray[outOffset + w] = 0;
+						}
+					}
+				}
+			}
+		}
+		// Fills temp with row n of the matrix in dense format
+		private void fillTemp(MatrixBlock input, int n) {
+			if(input.getNumColumns() != temp.length) {
+				throw new RuntimeException("Invalid parameters");
+			}
+			// Use temporary array to avoid binary search
+			if(input.isInSparseFormat()) {
+				Arrays.fill(temp, 0);
+				if( !input.sparseBlock.isEmpty(n) ) {
+					int apos = input.sparseBlock.pos(n);
+					int alen = input.sparseBlock.size(n);
+					int[] aix = input.sparseBlock.indexes(n);
+					double[] avals = input.sparseBlock.values(n);
+					for(int j=apos; j<apos+alen; j++)
+						temp[ aix[j] ] = avals[j];
+				}
+			}
+			else {
+				System.arraycopy(input.getDenseBlock(), n*input.getNumColumns(), temp, 0, input.getNumColumns());
+			}
+		}
+	}
+	
+	/**
+	 * Performs sparse im2col on a single channel (general stride and pad)
+	 */
+	static class SparseIm2colWorker implements Im2colWorker {
+		MatrixBlock input; double [] outputArray; 
+		int CRS; int S; int R; int P; int Q; int H; int W; 
+		int stride_h; int stride_w; int pad_h; int pad_w; double [] temp;
+		public SparseIm2colWorker(MatrixBlock input, MatrixBlock im2ColOutBlock, ConvolutionParameters params) {
+			this.input = input;
+			this.outputArray = im2ColOutBlock.getDenseBlock();
+			this.CRS = params.C * params.R * params.S;
+			this.H = params.H; this.W = params.W; this.R = params.R; this.S = params.S; this.P = params.P; this.Q = params.Q;
+			this.stride_h = params.stride_h; this.stride_w = params.stride_w;
+			this.pad_h = params.pad_h; this.pad_w = params.pad_w;
+			temp = new double[input.getNumColumns()];
+		}
+		
+		@Override
+		public void execute(int n) {
+			throw new RuntimeException("Not supported");
+		}
+
+		@Override
+		public void execute(int n, int cInput) {
+			// Use a temporary dense array to avoid a binary search per value access.
+			// Since the access pattern depends on ConvolutionParameters, this serves as a pragmatic fix.
+			fillTemp(input, n);
+			int RS = R*S;
+			for (int rs = 0; rs < RS; ++rs) {
+				int wOffset = rs % S;
+				int hOffset = rs / S;
+				for (int h = 0; h < P; ++h) {
+					int outOffset = (rs * P + h) * Q;
+					int hPadded = h * stride_h - pad_h + hOffset;
+					int tempOffset = (cInput * H + hPadded) * W;
+					if (hPadded < 0 || hPadded >= H) {
+						Arrays.fill(outputArray, outOffset, outOffset+Q, 0);
+					} else {
+						for (int w = 0; w < Q; ++w) {
+							int wPadded = w * stride_w - pad_w + wOffset;
+							if (wPadded >= 0 && wPadded < W) 
+								outputArray[outOffset + w] = temp[tempOffset + wPadded];
+							else
+								outputArray[outOffset + w] = 0;
+						}
+					}
+				}
+			}
+		}
+		// Fills temp with row n of the matrix in dense format
+		private void fillTemp(MatrixBlock input, int n) {
+			if(input.getNumColumns() != temp.length) {
+				throw new RuntimeException("Invalid parameters");
+			}
+			// Use temporary array to avoid binary search
+			if(input.isInSparseFormat()) {
+				Arrays.fill(temp, 0);
+				if( !input.sparseBlock.isEmpty(n) ) {
+					int apos = input.sparseBlock.pos(n);
+					int alen = input.sparseBlock.size(n);
+					int[] aix = input.sparseBlock.indexes(n);
+					double[] avals = input.sparseBlock.values(n);
+					for(int j=apos; j<apos+alen; j++)
+						temp[ aix[j] ] = avals[j];
+				}
+			}
+			else {
+				System.arraycopy(input.getDenseBlock(), n*input.getNumColumns(), temp, 0, input.getNumColumns());
+			}
+		}
+	}
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19eed8f3/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
new file mode 100644
index 0000000..b400105
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.runtime.matrix.data;
+
+import java.util.concurrent.Callable;
+
+/**
+ * This class contains the set of operators used for performing the maxpooling backward operation
+ */
+public class LibMatrixDNNPoolingBackwardHelper {
+	/**
+	 * Performs the maxpooling backward operation for dense input and dense error (dout)
+	 */
+	public static class PoolingBackwardDenseDense implements Callable<Long> 
+	{
+		public int _rl; public int _ru; 
+		private final ConvolutionParameters _params; 
+		double [] outputArray; boolean performReluBackward;
+		double [] inputArray; double [] doutArray;
+		int C; int CHW; int P; int Q; int HW; int CPQ; int PQ;
+		public PoolingBackwardDenseDense(int rl, int ru, ConvolutionParameters params, boolean performReluBackward) {
+			_rl = rl; _ru = ru;
+			_params = params;
+			this.performReluBackward = performReluBackward;
+			inputArray = params.input1.getDenseBlock();
+			doutArray = params.input2.getDenseBlock();
+			outputArray = params.output.getDenseBlock();
+			C = params.C; CHW = params.C*params.H*params.W; HW = params.H*params.W;
+			P = params.P; Q = params.Q; CPQ = params.C*params.P*params.Q;
+			PQ = params.P*params.Q;
+			if (inputArray == null || doutArray == null || outputArray == null )
+				throw new RuntimeException("Incorrect usage: empty inputs");
+		}
+		
+		@Override
+		public Long call() throws Exception {
+			for(int n = _rl; n < _ru; n++)  {
+				for (int c = 0; c < C; c++) {
+					final int inputOffset = n*CHW + c*HW;
+					final int outputOffset = n*CPQ + c*PQ;
+					for (int p = 0; p < P; p++) {
+						for (int q = 0; q < Q; q++) {
+							int maxIndex = LibMatrixDNNHelper.getMaxIndex(p, q, inputOffset, inputArray, _params, performReluBackward);
+							if(maxIndex != -1)
+								outputArray[maxIndex] += doutArray[outputOffset +  p * Q + q];
+						}
+					}
+				}
+			}
+			return 0L;
+		}
+	}
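+	
+	// e.g. for a 2x2 window over X = [[1, 5], [3, 2]] with dout = 7, the max is
+	// at the position of 5, so the full gradient 7 is routed to that cell and
+	// the remaining cells of the window receive nothing.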
+	
+	/**
+	 * Performs the maxpooling backward operation for dense input and sparse error (dout)
+	 */
+	public static class PoolingBackwardDenseSparse implements Callable<Long> 
+	{
+		public int _rl; public int _ru; 
+		private final ConvolutionParameters _params; 
+		double [] outputArray; boolean performReluBackward;
+		double [] inputArray;  MatrixBlock dout;
+		int C; int CHW; int P; int Q; int HW;
+		public PoolingBackwardDenseSparse(int rl, int ru, ConvolutionParameters params, boolean performReluBackward) {
+			_rl = rl; _ru = ru;
+			_params = params;
+			this.performReluBackward = performReluBackward;
+			inputArray = params.input1.getDenseBlock();
+			dout = params.input2;
+			outputArray = params.output.getDenseBlock();
+			C = params.C; CHW = params.C*params.H*params.W; HW = params.H*params.W;
+			P = params.P; Q = params.Q; 
+			if (inputArray == null || outputArray == null )
+				throw new RuntimeException("Incorrect usage: empty inputs");
+			if (!params.input2.isInSparseFormat())
+				throw new RuntimeException("Incorrect usage: Call optimized versions");
+		}
+		
+		@Override
+		public Long call() throws Exception {
+			for(int n = _rl; n < _ru; n++)  {
+				if( !dout.sparseBlock.isEmpty(n) ) {
+					int [] tensorIndexes = new int[3];
+					int apos = dout.sparseBlock.pos(n);
+					int alen = dout.sparseBlock.size(n);
+					int[] aix = dout.sparseBlock.indexes(n);
+					double[] avals = dout.sparseBlock.values(n);
+					for(int j = apos; j < apos+alen; j++) {
+						LibMatrixDNNHelper.computeTensorIndexes(aix[j], tensorIndexes, P, Q);
+						int c = tensorIndexes[0];
+						int p = tensorIndexes[1];
+						int q = tensorIndexes[2];
+						final int inputOffset = n*CHW + c*HW;
+						int maxIndex = LibMatrixDNNHelper.getMaxIndex(p, q, inputOffset, inputArray, _params, performReluBackward);
+						if(maxIndex != -1)
+							outputArray[maxIndex] += avals[j];
+					}
+				}
+			}
+			return 0L;
+		}
+	}
+	
+	/**
+	 * Performs the maxpooling backward operation for sparse input and dense error (dout)
+	 */
+	public static class PoolingBackwardSparseDense implements Callable<Long> 
+	{
+		public int _rl; public int _ru; 
+		private final ConvolutionParameters _params; 
+		double [] outputArray; boolean performReluBackward;
+		double [] doutArray;
+		int C; int CHW; int P; int Q; int HW; int CPQ; int PQ;
+		public PoolingBackwardSparseDense(int rl, int ru, ConvolutionParameters params, boolean performReluBackward) {
+			_rl = rl; _ru = ru;
+			_params = params;
+			this.performReluBackward = performReluBackward;
+			doutArray = params.input2.getDenseBlock();
+			outputArray = params.output.getDenseBlock();
+			C = params.C; CHW = params.C*params.H*params.W; HW = params.H*params.W;
+			P = params.P; Q = params.Q; CPQ = params.C*params.P*params.Q;
+			PQ = params.P*params.Q;
+			if (doutArray == null || outputArray == null )
+				throw new RuntimeException("Incorrect usage: empty inputs");
+			if (!params.input1.isInSparseFormat())
+				throw new RuntimeException("Incorrect usage: Call optimized versions");
+		}
+		
+		@Override
+		public Long call() throws Exception {
+			for(int n = _rl; n < _ru; n++)  {
+				for (int c = 0; c < C; c++) {
+					for (int p = 0; p < P; p++) {
+						for (int q = 0; q < Q; q++) {
+							double inVal = doutArray[n*CPQ + c*PQ +  p * Q + q];
+							if(inVal != 0) {
+								final int inputOffset = n*CHW + c*HW;
+								int maxIndex = LibMatrixDNNHelper.getMaxIndexSparse(p, q, inputOffset, n, c, _params.input1, _params, performReluBackward);
+								if(maxIndex != -1)
+									outputArray[maxIndex] += inVal;
+							}
+						}
+					}
+				}
+			}
+			return 0L;
+		}
+	}
+	
+	/**
+	 * Performs the maxpooling backward operation for sparse input and sparse error (dout)
+	 */
+	public static class PoolingBackwardSparseSparse implements Callable<Long> 
+	{
+		public int _rl; public int _ru; 
+		private final ConvolutionParameters _params; 
+		double [] outputArray; boolean performReluBackward;
+		int C; int CHW; int P; int Q; int HW; 
+		public PoolingBackwardSparseSparse(int rl, int ru, ConvolutionParameters params, boolean performReluBackward) {
+			_rl = rl; _ru = ru;
+			_params = params;
+			this.performReluBackward = performReluBackward;
+			outputArray = params.output.getDenseBlock();
+			C = params.C; CHW = params.C*params.H*params.W; HW = params.H*params.W;
+			P = params.P; Q = params.Q;
+			if (outputArray == null )
+				throw new RuntimeException("Incorrect usage: empty outputs");
+			if (!params.input1.isInSparseFormat() || !params.input2.isInSparseFormat())
+				throw new RuntimeException("Incorrect usage: Call optimized versions");
+		}
+		
+		@Override
+		public Long call() throws Exception {
+			for(int n = _rl; n < _ru; n++)  {
+				if( !_params.input2.sparseBlock.isEmpty(n) ) {
+					int [] tensorIndexes = new int[3];
+					int apos = _params.input2.sparseBlock.pos(n);
+					int alen = _params.input2.sparseBlock.size(n);
+					int[] aix = _params.input2.sparseBlock.indexes(n);
+					double[] avals = _params.input2.sparseBlock.values(n);
+					for(int j = apos; j < apos+alen; j++) {
+						LibMatrixDNNHelper.computeTensorIndexes(aix[j], tensorIndexes, P, Q);
+						int c = tensorIndexes[0];
+						int p = tensorIndexes[1];
+						int q = tensorIndexes[2];
+						final int inputOffset = n*CHW + c*HW;
+						int maxIndex = LibMatrixDNNHelper.getMaxIndexSparse(p, q, inputOffset, n, c, _params.input1, _params, performReluBackward);
+						if(maxIndex != -1)
+							outputArray[maxIndex] += avals[j];
+					}
+				}
+			}
+			return 0L;
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19eed8f3/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingHelper.java
new file mode 100644
index 0000000..c6aaee2
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingHelper.java
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.runtime.matrix.data;
+
+import java.util.Arrays;
+import java.util.concurrent.Callable;
+
+/**
+ * This class contains the set of operators used for performing the maxpooling operation
+ */
+public class LibMatrixDNNPoolingHelper {
+	
+	/**
+	 * Performs the dense maxpooling
+	 */
+	public static class DenseMaxPooling implements Callable<Long> 
+	{
+		public int _rl; public int _ru; 
+		private final ConvolutionParameters _params;
+		double [] inputArray; double [] outputArray;
+		int C; int P; int Q; int W;
+		public DenseMaxPooling(int rl, int ru, ConvolutionParameters params) {
+			_rl = rl; _ru = ru;
+			_params = params;
+			inputArray = params.input1.getDenseBlock();
+			outputArray = params.output.getDenseBlock();
+			C = params.C; P = params.P; Q = params.Q; W = params.W;
+		}
+		
+		@Override
+		public Long call() throws Exception {
+			final int HW = _params.H*_params.W;
+			final int CHW = _params.C*_params.H*_params.W;
+			final int CPQ = C*P*Q;
+			for(int n = _rl; n < _ru; n++)  {
+				final int inOffset = n*CHW;
+				int out_index = n*CPQ;
+				for (int c = 0; c < C; c++) {
+					final int inOffset1 = inOffset + c*HW;
+					for (int p = 0; p < P; p++) {
+						for (int q = 0; q < Q; q++, out_index++) {
+							for (int h = _params.start_indexes_h[p]; h < _params.end_indexes_h[p]; h++) {
+								for (int w = _params.start_indexes_w[q]; w < _params.end_indexes_w[q]; w++) {
+									outputArray[out_index] = Math.max(outputArray[out_index], inputArray[inOffset1 +  h*W + w]);
+								}
+							}
+						}
+					}
+				}
+			}
+			return 0L;
+		}
+	}
+	
+	/**
+	 * Performs the sparse maxpooling
+	 */
+	public static class SparseMaxPooling implements Callable<Long> 
+	{
+		public int _rl; public int _ru; 
+		private final ConvolutionParameters _params;
+		int HW;
+		double [] outputArray;
+		int C; int P; int Q; int W;
+		public SparseMaxPooling(int rl, int ru, ConvolutionParameters params) {
+			_rl = rl; _ru = ru;
+			_params = params;
+			outputArray = params.output.getDenseBlock();
+			C = params.C; P = params.P; Q = params.Q; W = params.W;
+			HW = _params.H*_params.W;
+		}
+		
+		boolean isNthRowEmpty = false;
+		int apos; int alen; int[] aix; double[] avals;
+		private void getNthSparseRow(int n) {
+			if( !_params.input1.sparseBlock.isEmpty(n) ) {
+				apos = _params.input1.sparseBlock.pos(n);
+				alen = _params.input1.sparseBlock.size(n);
+				aix = _params.input1.sparseBlock.indexes(n);
+				avals = _params.input1.sparseBlock.values(n);
+				isNthRowEmpty = false;
+			}
+			else
+				isNthRowEmpty = true;
+		}
+		int fromIndex = -1; // start of the binary-search range for the current channel c
+		int toIndex = -1;   // end (exclusive) of the binary-search range for the current channel c
+		private int setSearchIndex(int from, int searchVal) {
+			for(int j = from; j < apos+alen; j++) {
+				if(aix[j] > searchVal)
+					return Math.max(from, j-1);
+			}
+			return apos+alen;
+		}
+		private double getValue(int col) {
+			if( !isNthRowEmpty ) {
+				int index = Arrays.binarySearch(aix, fromIndex, toIndex, col);
+				return index >= 0 ? avals[index] : 0;
+			}
+			return 0;
+		}
+		
+		@Override
+		public Long call() throws Exception {
+			final int CPQ = C*P*Q;
+			for(int n = _rl; n < _ru; n++)  {
+				getNthSparseRow(n);
+				int out_index = n*CPQ;
+				for (int c = 0; c < C; c++) {
+					// Restrict the binary-search range in getValue to the non-zeros of channel c
+					fromIndex = setSearchIndex(apos, c*HW);
+					toIndex = Math.min(apos+alen, setSearchIndex(fromIndex, (c+1)*HW));
+					for (int p = 0; p < P; p++) {
+						for (int q = 0; q < Q; q++, out_index++) {
+							for (int h = _params.start_indexes_h[p]; h < _params.end_indexes_h[p]; h++) {
+								for (int w = _params.start_indexes_w[q]; w < _params.end_indexes_w[q]; w++) {
+									outputArray[out_index] = Math.max(outputArray[out_index], getValue(c*HW +  h*W + w));
+								}
+							}
+						}
+					}
+				}
+			}
+			return 0L;
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19eed8f3/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRotate180Helper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRotate180Helper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRotate180Helper.java
new file mode 100644
index 0000000..c003756
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRotate180Helper.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.runtime.matrix.data;
+
+import java.util.Arrays;
+
+/**
+ * This class contains the different implementations of the rotate180 operation
+ */
+public class LibMatrixDNNRotate180Helper {
+
+	static interface Rotate180Worker {
+		public void execute(int inputN, int outputN);
+		public static Rotate180Worker getWorker(MatrixBlock input, double [] outputArray, ConvolutionParameters params, boolean zeroOutSparseOutput) {
+			if(!input.isInSparseFormat()) 
+				return new DenseRotate180Worker(input, outputArray, params);
+			else
+				return new SparseRotate180Worker(input, outputArray, params, zeroOutSparseOutput);
+		}
+	}
+	
+	/**
+	 * Performing dense rotate180 (general case)
+	 */
+	static class DenseRotate180Worker implements Rotate180Worker {
+
+		double [] inputArray; double [] outputArray;  
+		ConvolutionParameters params;
+		public DenseRotate180Worker(MatrixBlock input, double [] outputArray,  ConvolutionParameters params) {
+			this.outputArray = outputArray;
+			this.params = params;
+			inputArray = input.getDenseBlock();
+			if(inputArray == null || outputArray == null)
+				throw new RuntimeException("Incorrect usage: empty inputs");
+		}
+		
+		@Override
+		public void execute(int inputN, int outputN) {
+			int outputOffset = outputN*params.K*params.P*params.Q;
+			for (int k = 0; k < params.K; k++) {
+				for (int p = 0; p < params.P; p++) {
+					for (int q = 0; q < params.Q; q++) {
+						outputArray[outputOffset + p*params.Q*params.K + q*params.K + k] = 
+								inputArray[inputN*params.K*params.P*params.Q + k*params.P*params.Q + p*params.Q + q];
+					}
+				}
+			}
+		}
+	}
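+	
+	// Layout example: rotate180 transposes each image's [K, PQ] slice into
+	// [PQ, K]; for K=2, P=1, Q=2 the input row [a0, a1, b0, b1] (channels a, b)
+	// becomes [a0, b0, a1, b1].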
+	
+	/**
+	 * Performing rotate180 when input is sparse (general case)
+	 */
+	static class SparseRotate180Worker implements Rotate180Worker {
+
+		double [] outputArray;  MatrixBlock input;
+		ConvolutionParameters params; boolean zeroOutSparseOutput;
+		public SparseRotate180Worker(MatrixBlock input, double [] outputArray,  ConvolutionParameters params, boolean zeroOutSparseOutput) {
+			this.outputArray = outputArray;
+			this.params = params;
+			this.zeroOutSparseOutput = zeroOutSparseOutput;
+			this.input = input;
+			if(outputArray == null)
+				throw new RuntimeException("Incorrect usage: empty inputs");
+		}
+		
+		@Override
+		public void execute(int inputN, int outputN) {
+			if(zeroOutSparseOutput)
+				Arrays.fill(outputArray, 0);
+			
+			int outputOffset = outputN*params.K*params.P*params.Q;
+			if(!input.isEmptyBlock()) {
+				if( !input.sparseBlock.isEmpty(inputN) ) {
+					int [] tensorIndexes = new int[3];
+					int apos = input.sparseBlock.pos(inputN);
+					int alen = input.sparseBlock.size(inputN);
+					int[] aix = input.sparseBlock.indexes(inputN);
+					double[] avals = input.sparseBlock.values(inputN);
+					for(int j = apos; j < apos+alen; j++) {
+						LibMatrixDNNHelper.computeTensorIndexes(aix[j], tensorIndexes, params.P, params.Q);
+						int k = tensorIndexes[0];
+						int p = tensorIndexes[1];
+						int q = tensorIndexes[2];
+						outputArray[outputOffset + p*params.Q*params.K + q*params.K + k] = avals[j];
+					}
+				}
+			}
+		}
+	}
+}
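
For intuition, a minimal standalone sketch of the index mapping that both
workers implement (rotate180Row is a hypothetical name, not part of this
commit): element (k,p,q) of one image moves from K-major offset
k*P*Q + p*Q + q to PQ-major offset p*Q*K + q*K + k.

    // Sketch, assuming dense double[] buffers for one input/output row.
    static void rotate180Row(double[] in, int inOff, double[] out, int outOff,
            int K, int P, int Q) {
        for (int k = 0; k < K; k++)
            for (int p = 0; p < P; p++)
                for (int q = 0; q < Q; q++)
                    // (k,p,q) in K-major input -> (p,q,k) in PQ-major output
                    out[outOff + p*Q*K + q*K + k] = in[inOff + k*P*Q + p*Q + q];
    }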


[2/2] incubator-systemml git commit: [SYSTEMML-540] Refactored LibMatrixDNN to reduce instruction cache misses

Posted by ni...@apache.org.
[SYSTEMML-540] Refactored LibMatrixDNN to reduce instruction cache misses

- Bugfix for empty filter in conv2d_bias_add
- Improved the performance of sparse maxpooling
- Reduced branch mispredictions and instruction cache misses

Closes #520.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/19eed8f3
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/19eed8f3
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/19eed8f3

Branch: refs/heads/master
Commit: 19eed8f3858d7daad1c549b548b7de4ff270def8
Parents: 28c92b9
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Mon May 29 15:21:22 2017 -0800
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Mon May 29 16:21:22 2017 -0700

----------------------------------------------------------------------
 docs/python-reference.md                        |   8 +-
 scripts/nn/test/compare_backends/compare.dml    |   5 +-
 .../cp/ConvolutionCPInstruction.java            |  20 +-
 .../matrix/data/ConvolutionParameters.java      |   3 +-
 .../sysml/runtime/matrix/data/LibMatrixDNN.java | 995 ++-----------------
 .../LibMatrixDNNConv2dBackwardDataHelper.java   | 112 +++
 .../LibMatrixDNNConv2dBackwardFilterHelper.java | 138 +++
 .../matrix/data/LibMatrixDNNConv2dHelper.java   | 224 +++++
 .../runtime/matrix/data/LibMatrixDNNHelper.java | 541 ++++++++++
 .../matrix/data/LibMatrixDNNIm2ColHelper.java   | 386 +++++++
 .../data/LibMatrixDNNPoolingBackwardHelper.java | 212 ++++
 .../matrix/data/LibMatrixDNNPoolingHelper.java  | 143 +++
 .../data/LibMatrixDNNRotate180Helper.java       | 107 ++
 13 files changed, 1951 insertions(+), 943 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19eed8f3/docs/python-reference.md
----------------------------------------------------------------------
diff --git a/docs/python-reference.md b/docs/python-reference.md
index 2ebfc38..a847964 100644
--- a/docs/python-reference.md
+++ b/docs/python-reference.md
@@ -189,14 +189,10 @@ method as DataFrame or NumPy array.
 
 ### Support for NumPy's universal functions
 
-The matrix class also supports most of NumPy's universal functions (i.e. ufuncs).
-The current version of NumPy explicitly disables overriding ufunc, but this should be enabled in next release. 
-Until then to test above code, please use:
+The matrix class also supports most of NumPy's universal functions (i.e. ufuncs):
 
 ```bash
-git clone https://github.com/niketanpansare/numpy.git
-cd numpy
-python setup.py install
+pip install --ignore-installed 'numpy>=1.13.0rc2'
 ```
 
 This will enable NumPy's functions to invoke the matrix class:

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19eed8f3/scripts/nn/test/compare_backends/compare.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/compare_backends/compare.dml b/scripts/nn/test/compare_backends/compare.dml
index f87c472..7205631 100644
--- a/scripts/nn/test/compare_backends/compare.dml
+++ b/scripts/nn/test/compare_backends/compare.dml
@@ -22,7 +22,10 @@
 X = read($1)
 Y = read($2)
 msg = ifdef($3, " ")
-eps = 1e-3
+eps = 1e-6
+# Normalize X and Y
+X = X / max(X)
+Y = Y / max(Y)
 num_mismatch = sum(abs(X - Y) > eps)
 if(num_mismatch > 0) {
 	print("---------------------------------------------------\nERROR: >>>>>>>>> The results don't match(num_mismatch:" + num_mismatch + "): " + msg + "\n---------------------------------------------------")
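
The tighter eps is possible because both results are first brought onto a
common scale. A rough Java equivalent of the new check (countMismatches is a
hypothetical helper, illustrative only; assumes non-empty inputs whose
maximum is non-zero):

    static long countMismatches(double[] x, double[] y, double eps) {
        // normalize by the maximum value, as compare.dml now does
        double mx = java.util.Arrays.stream(x).max().getAsDouble();
        double my = java.util.Arrays.stream(y).max().getAsDouble();
        long mismatches = 0;
        for (int i = 0; i < x.length; i++)
            if (Math.abs(x[i]/mx - y[i]/my) > eps)
                mismatches++;
        return mismatches;
    }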

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19eed8f3/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
index 1331d64..840b39e 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
@@ -218,13 +218,11 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction
 				.getLongValue();
 	}
 	
-	@SuppressWarnings("unused")
 	public void processReluBackwardInstruction(ExecutionContext ec) throws DMLRuntimeException {
 		// (X > 0) * dout
 		MatrixBlock input = ec.getMatrixInput(input1.getName());
 		MatrixBlock dout = ec.getMatrixInput(_in2.getName());
-		MatrixBlock outputBlock =  new MatrixBlock(input.getNumRows(), input.getNumColumns(), 
-			LibMatrixDNN.SUPPORTS_SPARSE_OUTPUTS && (input.isInSparseFormat() || dout.isInSparseFormat()));
+		MatrixBlock outputBlock =  new MatrixBlock(input.getNumRows(), input.getNumColumns(), (input.isInSparseFormat() || dout.isInSparseFormat()));
 		
 		if( !input.isEmpty() && !dout.isEmpty() ) {
 			outputBlock.allocateDenseOrSparseBlock();
@@ -383,12 +381,26 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction
 		else if (instOpcode.equalsIgnoreCase("conv2d_bias_add")) {
 			MatrixBlock filter = ec.getMatrixInput(_in3.getName());
 			MatrixBlock bias = ec.getMatrixInput(_in2.getName());
-			if((filter.isEmpty() || matBlock.isEmpty()) && bias.isEmpty()) {
+			if(bias.getNumRows() != params.K || bias.getNumColumns() != 1) {
+				throw new DMLRuntimeException("Incorrect shape of bias matrix: [" + bias.getNumRows() + " " + bias.getNumColumns() + "]. "
+						+ "Expected: [" + params.K + ", 1]");
+			}
+			boolean isOutputConvEmpty = filter.isEmpty() || matBlock.isEmpty();
+			if(isOutputConvEmpty && bias.isEmpty()) {
+				// bias_add(empty mb, empty mb) = empty mb
 				outputBlock = new MatrixBlock(N, K*P*Q, true);
 			}
+			else if(isOutputConvEmpty && !bias.isEmpty()) {
+				// Add bias to empty output block
+				// bias_add(empty mb, bias)
+				outputBlock = getDenseOutputBlock(N, K*P*Q);
+				for(int n = 0;  n < params.N; n++) 
+					ConvolutionUtils.fillBias(bias, outputBlock.getDenseBlock(), n, n+1, params.N, params.K, params.P*params.Q);
+			}
 			else {
 				outputBlock = getDenseOutputBlock(N, K*P*Q);
 				if(!bias.isEmpty()) {
+					// Both input and filter are non-empty here; set bias only when it is also non-empty
 					params.bias = bias;
 				}
 				if(params.enableNative && !isFilterSparse(filter) && !matBlock.isInSparseFormat())
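
Conceptually, the new empty-input branch reduces conv2d_bias_add to
broadcasting the bias across channels: every cell (n, k, p, q) of the dense
[N x K*P*Q] output is simply bias[k]. A minimal sketch of that fill
(broadcastBias is a hypothetical name; the commit itself delegates to
ConvolutionUtils.fillBias):

    static void broadcastBias(double[] bias, double[] out, int N, int K, int PQ) {
        for (int n = 0; n < N; n++)
            for (int k = 0; k < K; k++) {
                int off = (n*K + k) * PQ;  // start of channel k of image n
                java.util.Arrays.fill(out, off, off + PQ, bias[k]);
            }
    }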

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19eed8f3/src/main/java/org/apache/sysml/runtime/matrix/data/ConvolutionParameters.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/ConvolutionParameters.java b/src/main/java/org/apache/sysml/runtime/matrix/data/ConvolutionParameters.java
index 11e74ca..a24a736 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/ConvolutionParameters.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/ConvolutionParameters.java
@@ -58,7 +58,8 @@ public class ConvolutionParameters implements Serializable {
 	}
 	
 	public String toString() {
-		return "(" + N + " " + C + " " + H + " " + W + " " + K + " " + R + " " + S + ")";  
+		return "(NCHW=[" + N + " " + C + " " + H + " " + W + "], KCRS=[" + K + " " + R + " " + S + "], stride=[" + stride_h + "," + stride_w  + 
+				"], pad=[" + pad_h + "," + pad_w + "])";  
 	}
 	
 	public ConvolutionParameters(long N, long C, long H, long W,

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19eed8f3/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index ab82697..30b8b64 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -19,10 +19,8 @@
 package org.apache.sysml.runtime.matrix.data;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
 import java.util.concurrent.Callable;
-import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
@@ -36,10 +34,9 @@ import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysml.runtime.util.ConvolutionUtils;
-import org.apache.sysml.utils.NativeHelper;
 import org.apache.sysml.utils.Statistics;
 
-/**
+/*
  * This class allows users to invoke deep learning related operations 
  * (such as conv2d, conv2d_backward_data, conv2d_backward_filter, maxpooling, maxpooling_backward, bias_add)
  * using multiple threads.
@@ -47,23 +44,32 @@ import org.apache.sysml.utils.Statistics;
  * The methods accept the input matrices as MatrixBlock and the parameters using ConvolutionParameters.
  * 
  * To run in single thread, please set ConvolutionParameters.numThreads to 1.
+ * 
+ * DESIGN:
+ * 
+ * 1. LibMatrixDNN contains the user-facing methods for deep learning related operations. 
+ * 2. The deep learning tasks are executed in parallel using Java's ExecutorService. The key pattern
+ * followed by the above-mentioned methods is as follows:
+ *   execute(LibMatrixDNNHelper.get__Workers(params), params);
+ * 3. LibMatrixDNN's execute() method ensures the creation and shutdown of the ExecutorService.
+ * 4. LibMatrixDNNHelper.get__Workers creates the appropriate workers based on the runtime characteristics of
+ * the input data (for example: input activations, filter, dout, ...). For code maintenance, these workers
+ * are placed in the respective LibMatrixDNN__Helper files.
+ * 5. The above-mentioned workers may also use additional workers such as im2col and rotate180.
+ * We have created similar get__Workers methods to return the appropriate worker based on the
+ * runtime characteristics.
+ * 6. As opposed to the earlier implementation, this design reduces branch mispredictions as well
+ * as instruction cache misses. It also allows us to experiment with new operators for different
+ * data characteristics without affecting the performance of other operators.
+ * 7. This class assumes that the caller (ConvolutionCPInstruction for CP) deals with the empty-block cases.
+ * 
  */
 public class LibMatrixDNN {
 	
 	protected static final Log LOG =  LogFactory.getLog(LibMatrixDNN.class.getName());
 	
 	//library configurations and external contracts
-	public static final boolean SUPPORTS_SPARSE_OUTPUTS = false; //operations able to handle sparse outputs 
-	private static final boolean ALLOW_MULTI_THREADED_OPS = true; //enable multi-threading in cp
-	private static final int NUM_TASK_FACTOR = 2; //number of tasks is vcores scaled by this factor
 	public static boolean DISPLAY_STATISTICS = false; //conv2d summaries in stats output
-
-	private enum TaskType {
-		MaxPooling_Forward, MaxPooling_Backward, MaxPooling_Relu_Backward,
-		// Alternate approaches that we tried but the performance was unsatisfactory be included: direct, non-looped im2col
-		LoopedIm2ColConv2d, LoopedIm2ColConv2dBwdFilter, LoopedIm2ColConv2dBwdData,
-		ReluBackward
-	}
 	
 	// ------------------------------------------------------------------------------------------------
 	private static AtomicLong conv2dSparseCount = new AtomicLong(0);
@@ -76,12 +82,12 @@ public class LibMatrixDNN {
 	private static AtomicLong im2colDenseCount = new AtomicLong(0);
 	private static AtomicLong maxPoolBwdSparseCount = new AtomicLong(0);
 	private static AtomicLong maxPoolBwdDenseCount = new AtomicLong(0);
-	private static AtomicLong loopedConvMatMultTime = new AtomicLong(0);
-	private static AtomicLong loopedConvIm2ColTime = new AtomicLong(0);
-	private static AtomicLong loopedConvBwdFilterMatMultTime = new AtomicLong(0);
-	private static AtomicLong loopedConvBwdFilterIm2ColTime = new AtomicLong(0);
-	private static AtomicLong loopedConvBwdDataMatMultTime = new AtomicLong(0);
-	private static AtomicLong loopedConvBwdDataCol2ImTime = new AtomicLong(0);
+	static AtomicLong loopedConvMatMultTime = new AtomicLong(0);
+	static AtomicLong loopedConvIm2ColTime = new AtomicLong(0);
+	static AtomicLong loopedConvBwdFilterMatMultTime = new AtomicLong(0);
+	static AtomicLong loopedConvBwdFilterIm2ColTime = new AtomicLong(0);
+	static AtomicLong loopedConvBwdDataMatMultTime = new AtomicLong(0);
+	static AtomicLong loopedConvBwdDataCol2ImTime = new AtomicLong(0);
 	
 	public static void appendStatistics(StringBuilder sb) {
 		if(DMLScript.STATISTICS && DISPLAY_STATISTICS) {
@@ -128,8 +134,8 @@ public class LibMatrixDNN {
 	}
 	
 	// Commonly used operators
-	private static BinaryOperator _binaryElementWiseAddition = null;
-	private static BinaryOperator _binaryElementWiseMultiplication = null;
+	static BinaryOperator _binaryElementWiseAddition = null;
+	static BinaryOperator _binaryElementWiseMultiplication = null;
 	static {
 		try {
 			_binaryElementWiseAddition = InstructionUtils.parseBinaryOperator("+");
@@ -158,7 +164,7 @@ public class LibMatrixDNN {
 		if(isEligibleForConv2dSparse(params))
 			Statistics.numNativeSparseConv2dCalls.increment();
 		
-		runConvTask(TaskType.LoopedIm2ColConv2d, params);
+		execute(LibMatrixDNNHelper.getConv2dWorkers(params), params);
 		
 		//post-processing: maintain nnz
 		outputBlock.recomputeNonZeros();
@@ -179,7 +185,7 @@ public class LibMatrixDNN {
 		if(isEligibleForConv2dBackwardDataDense(params))
 			Statistics.numNativeSparseConv2dBwdDataCalls.increment();
 		
-		runConvTask(TaskType.LoopedIm2ColConv2dBwdData, params);
+		execute(LibMatrixDNNHelper.getConv2dBackwardDataWorkers(params), params);
 		
 		//post-processing: maintain nnz
 		outputBlock.recomputeNonZeros();
@@ -200,7 +206,7 @@ public class LibMatrixDNN {
 		if(isEligibleForConv2dBackwardFilterSparseDense(params))
 			Statistics.numNativeSparseConv2dBwdFilterCalls.increment();
 		
-		runConvTask(TaskType.LoopedIm2ColConv2dBwdFilter, params);
+		execute(LibMatrixDNNHelper.getConv2dBackwardFilterWorkers(params), params);
 		
 		//post-processing: maintain nnz
 		outputBlock.recomputeNonZeros();
@@ -239,10 +245,6 @@ public class LibMatrixDNN {
 				conv2dBwdDataDenseCount.addAndGet(1);
 			}
 		}
-		
-		int constrainedNumThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
-		if (!(ALLOW_MULTI_THREADED_OPS && params.isOutputThreadSafe() && constrainedNumThreads > 1))
-			params.numThreads = 1;
 	}
 	
 	static void checkInputsConv2dBackwardFilter(MatrixBlock input, MatrixBlock dout, MatrixBlock outputBlock, ConvolutionParameters params)  throws DMLRuntimeException {
@@ -270,89 +272,6 @@ public class LibMatrixDNN {
 		}
 	}
 	
-	/**
-	 * Performs the operation for(e : elem) ret += t(e) in a cache-conscious manner
-	 * by sequentially aggregating for(e : elem) tmp += e and finally transposing
-	 * ret = t(tmp).
-	 * 
-	 * @param ret left and output matrix
-	 * @param elem array of right untransposed matrices (expected in dense format)
-	 * @param params convolution parameters
-	 * @throws DMLRuntimeException in case of unsupported inputs or output
-	 */
-	private static void elementWiseInPlaceTransposedAddition(MatrixBlock ret, MatrixBlock[] elem) 
-		throws DMLRuntimeException 
-	{
-		//sanity checks non-empty and dense inputs / dense output
-		if( elem == null || elem.length==0 )
-			throw new DMLRuntimeException("Empty input not supported.");
-		for( MatrixBlock e : elem )
-			if( e.isInSparseFormat() )
-				throw new DMLRuntimeException("Sparse input format not supported.");
-		if( ret.isInSparseFormat() )
-			throw new DMLRuntimeException("Sparse output format not supported.");
-				
-		//Step 1: aggregate partial blocks without transpose
-		MatrixBlock tmpAgg = elem[0]; 
-		double[] tmp = tmpAgg.denseBlock;
-		for( int k=1; k<elem.length; k++ ) {
-			double[] tmp2 = elem[k].denseBlock;
-			for( int i=0; i<tmp.length; i++ )
-				tmp[i] += tmp2[i];
-		}
-		
-		//Step 2: cache-conscious transpose to output
-		tmpAgg.setNonZeros(-1); //avoid early abort
-		LibMatrixReorg.transpose(tmpAgg, ret);
-	}
-	
-	private static void doLoopedIm2ColConv2dBwdData(int n, MatrixBlock dout_reshaped, ConvolutionParameters params) throws DMLRuntimeException {
-		MatrixBlock filter = params.input1;
-		MatrixBlock dout = params.input2;
-		doRotate180(n, 0, dout, dout_reshaped.denseBlock, params, true);
-		
-		MatrixBlock temp = new MatrixBlock(params.P*params.Q, params.C*params.R*params.S, false);
-		long t1 = DMLScript.STATISTICS && DISPLAY_STATISTICS ? System.nanoTime() : 0;
-		singleThreadedMatMult(dout_reshaped, filter, temp, true, false, params);
-		long t2 = DMLScript.STATISTICS && DISPLAY_STATISTICS ? System.nanoTime() : 0 ;
-		doCol2imOverSingleImage(n, temp, params);
-		long t3 = DMLScript.STATISTICS && DISPLAY_STATISTICS ? System.nanoTime() : 0 ;
-		if(DMLScript.STATISTICS && DISPLAY_STATISTICS) {
-			loopedConvBwdDataMatMultTime.addAndGet(t2-t1);
-			loopedConvBwdDataCol2ImTime.addAndGet(t3-t2);
-		}
-	}
-	
-	private static MatrixBlock doLoopedIm2ColConv2dBwdFilter(int n, 
-			MatrixBlock im2ColOutBlock, MatrixBlock dout_reshaped, MatrixBlock partialRetBlock, ConvolutionParameters params, double []  tempIm2ColArr) throws DMLRuntimeException {
-		long t1 = DMLScript.STATISTICS && DISPLAY_STATISTICS ? System.nanoTime() : 0;
-		doIm2col(n, im2ColOutBlock, params, tempIm2ColArr);
-		long t2 = DMLScript.STATISTICS && DISPLAY_STATISTICS ? System.nanoTime() : 0 ;
-		
-		doRotate180(n, 0, params.input2, dout_reshaped.denseBlock, params, true);
-		
-		MatrixBlock temp = new MatrixBlock(params.C*params.R*params.S, params.K, false);
-		long t3 = DMLScript.STATISTICS && DISPLAY_STATISTICS ? System.nanoTime() : 0 ;
-		singleThreadedMatMult(im2ColOutBlock, dout_reshaped, temp, true, true, params);
-		long t4 = DMLScript.STATISTICS && DISPLAY_STATISTICS ? System.nanoTime() : 0 ;
-		if(DMLScript.STATISTICS && DISPLAY_STATISTICS) {
-			loopedConvBwdFilterMatMultTime.addAndGet(t4-t3);
-			loopedConvBwdFilterIm2ColTime.addAndGet(t2-t1);
-		}
-		if(!temp.isEmptyBlock()) {
-			// partialRetBlock is size: [params.C*params.R*params.S, params.K]
-			ConvolutionUtils.binaryOperationInPlace(temp, partialRetBlock.getDenseBlock(), 0, params.K, 0, params.C*params.R*params.S, 
-					_binaryElementWiseAddition);
-		}
-		return partialRetBlock;
-	}
-	
-	private static void computeTensorIndexes(int j, int [] ret, int H, int W) throws DMLRuntimeException {
-		ret[0] = j / (H*W);
-		ret[1] = (j - ret[0]*(H*W))/W;
-		ret[2] = j % W;
-	}
-	
 	static void checkInputsConv2d(MatrixBlock input, MatrixBlock filter, MatrixBlock outputBlock, ConvolutionParameters params) throws DMLRuntimeException {
 		params.input1 = input;
 		params.input2 = filter;
@@ -379,76 +298,6 @@ public class LibMatrixDNN {
 		}
 	}
 	
-	// Single-threaded matrix multiplication
-	private static void singleThreadedMatMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, 
-			boolean recomputeNNZM1, boolean recomputeNNZM2, ConvolutionParameters params) throws DMLRuntimeException {
-		if(!params.enableNative || m1.isInSparseFormat() || m2.isInSparseFormat()) {
-			if(recomputeNNZM1)
-				m1.recomputeNonZeros();
-			if(recomputeNNZM2)
-				m2.recomputeNonZeros();
-			LibMatrixMult.matrixMult(m1, m2, ret, false);
-		}
-		else {
-			ret.sparse = false;
-			if(ret.getDenseBlock() == null)
-				ret.allocateDenseBlock();
-			NativeHelper.matrixMultDenseDense(m1.denseBlock, m2.denseBlock, 
-					ret.denseBlock, m1.getNumRows(), m1.getNumColumns(), m2.getNumColumns(), 1);
-			ret.recomputeNonZeros();
-		}
-	}
-	
-	private static void doLoopedIm2ColConv2d(int n, MatrixBlock im2ColOutBlock, ConvolutionParameters params, double []  temp) throws DMLRuntimeException {
-		long t1 = DMLScript.STATISTICS && DISPLAY_STATISTICS ? System.nanoTime() : 0;
-		doIm2col(n, im2ColOutBlock, params, temp);
-		long t2 = DMLScript.STATISTICS && DISPLAY_STATISTICS ? System.nanoTime() : 0;
-		
-		MatrixBlock matMultOutBlock = new MatrixBlock(params.K, params.P*params.Q, false);
-		singleThreadedMatMult(params.input2, im2ColOutBlock, matMultOutBlock, false, true, params);
-		
-		long t3 = DMLScript.STATISTICS && DISPLAY_STATISTICS ? System.nanoTime() : 0;
-		
-		if(DMLScript.STATISTICS && DISPLAY_STATISTICS) {
-			loopedConvIm2ColTime.addAndGet(t2 - t1);
-			loopedConvMatMultTime.addAndGet(t3 - t2);
-		}
-		
-		// -----------------------------------------------------------------------------
-		// Copying is required as LibMatrixMult.matrixMult (and/or Java) is not pointer aware.
-		// This is not required in Native implementation
-		int destPos = n*params.K*params.P*params.Q;
-		int length = params.K*params.P*params.Q;
-		if(!matMultOutBlock.isEmptyBlock()) {
-			if(matMultOutBlock.isInSparseFormat()) {
-				// Copy the sparse matrix matMultOutBlock of shape [K X PQ] to 
-				// params.output.denseBlock + destPos
-				final int outOffset = n*params.K*params.P*params.Q;
-				final int PQ = params.P*params.Q;
-				for(int k = 0; k < matMultOutBlock.getNumRows(); k++) {
-					if( !matMultOutBlock.sparseBlock.isEmpty(k) ) {
-						int apos = matMultOutBlock.sparseBlock.pos(k);
-						int alen = matMultOutBlock.sparseBlock.size(k);
-						int[] aix = matMultOutBlock.sparseBlock.indexes(k);
-						double[] avals = matMultOutBlock.sparseBlock.values(k);
-						for(int j = apos; j < apos+alen; j++) {
-							int pqIndex = aix[j];
-							params.output.denseBlock[outOffset + k*PQ + pqIndex ] = avals[j];
-						}
-					}
-				}
-			}
-			else
-				System.arraycopy(matMultOutBlock.denseBlock, 0, params.output.denseBlock, destPos, length);
-		}
-		// -----------------------------------------------------------------------------
-		
-		// Recomputing nnz is not required for each individual im2col as it is invoked by outer public methods (i.e. conv2d.
-		//post-processing: maintain nnz
-		// params.output.recomputeNonZeros(); 
-	}
-	
-	
 	/**
 	 * This method computes the backpropagation errors for the previous layer of the maxpooling operation
 	 * 
@@ -485,10 +334,8 @@ public class LibMatrixDNN {
 			throw new DMLRuntimeException("Sparse maxpooling_backward is not supported");
 
 		fillIndexesArray(params);
-		if(performReluBackward)
-			runConvTask(TaskType.MaxPooling_Relu_Backward, params);
-		else
-			runConvTask(TaskType.MaxPooling_Backward, params);
+		
+		execute(LibMatrixDNNHelper.getMaxPoolingBackwardWorkers(params, performReluBackward), params);
 		
 		//post-processing: maintain nnz 
 		outputBlock.recomputeNonZeros();
@@ -521,211 +368,6 @@ public class LibMatrixDNN {
 		}
 	}
 	
-	private static void doPoolingBackward(int n, ConvolutionParameters params, boolean performReluBackward) throws DMLRuntimeException {
-		double [] inputArray = null;
-		if (!params.input1.isInSparseFormat())
-			inputArray = params.input1.getDenseBlock();
-		double [] doutArray = null;
-		if (!params.input2.isInSparseFormat())
-			doutArray = params.input2.getDenseBlock();
-		double [] outputArray = null;
-		if (!params.output.isInSparseFormat())
-			outputArray = params.output.getDenseBlock();
-		else
-			throw new DMLRuntimeException("Only dense output supported for pooling_backward");
-			
-		if(inputArray != null) {
-			if(doutArray != null)
-				doPoolingBackwardDenseDense(n, inputArray, doutArray, outputArray, params, performReluBackward);
-			else
-				doPoolingBackwardDenseSparse(n, inputArray, params.input2, outputArray, params, performReluBackward);
-		}
-		else {
-			if(doutArray != null)
-				doPoolingBackwardSparseDense(n, doutArray, outputArray, params, performReluBackward);
-			else
-				doPoolingBackwardSparseSparse(n, outputArray, params, performReluBackward);
-		}
-	}
-	
-	private static void doPoolingBackwardSparseDense(int n, double [] doutArray,  double [] outputArray, ConvolutionParameters params, boolean performReluBackward) throws DMLRuntimeException {
-		if (!params.input1.isInSparseFormat())
-			throw new DMLRuntimeException("Incorrect usage: Call optimized versions");
-		
-		for (int c = 0; c < params.C; c++) {
-			for (int p = 0; p < params.P; p++) {
-				for (int q = 0; q < params.Q; q++) {
-					double inVal = doutArray[n*params.C*params.P*params.Q + c*params.P*params.Q +  p * params.Q + q];
-					if(inVal != 0) {
-						final int inputOffset = n*params.C*params.H*params.W + c*params.H*params.W;
-						int maxIndex = getMaxIndexSparse(p, q, inputOffset, n, c, params.input1, params, performReluBackward);
-						if(maxIndex != -1)
-							outputArray[maxIndex] += inVal;
-					}
-				}
-			}
-		}
-	}
-	
-	private static void doPoolingBackwardSparseSparse(int n, double [] outputArray, ConvolutionParameters params, boolean performReluBackward) throws DMLRuntimeException {
-		if (!params.input1.isInSparseFormat())
-			throw new DMLRuntimeException("Incorrect usage: Call optimized versions");
-		
-		if( !params.input2.sparseBlock.isEmpty(n) ) {
-			int [] tensorIndexes = new int[3];
-			int apos = params.input2.sparseBlock.pos(n);
-			int alen = params.input2.sparseBlock.size(n);
-			int[] aix = params.input2.sparseBlock.indexes(n);
-			double[] avals = params.input2.sparseBlock.values(n);
-			for(int j = apos; j < apos+alen; j++) {
-				computeTensorIndexes(aix[j], tensorIndexes, params.P, params.Q);
-				int c = tensorIndexes[0];
-				int p = tensorIndexes[1];
-				int q = tensorIndexes[2];
-				final int inputOffset = n*params.C*params.H*params.W + c*params.H*params.W;
-				int maxIndex = getMaxIndexSparse(p, q, inputOffset, n, c, params.input1, params, performReluBackward);
-				if(maxIndex != -1)
-					outputArray[maxIndex] += avals[j];
-			}
-		}
-	}
-	
-	private static void doPoolingBackwardDenseSparse(int n, double [] inputArray, 
-			MatrixBlock dout, double [] outputArray, ConvolutionParameters params, boolean performReluBackward) throws DMLRuntimeException {
-		if( !dout.sparseBlock.isEmpty(n) ) {
-			int [] tensorIndexes = new int[3];
-			int apos = dout.sparseBlock.pos(n);
-			int alen = dout.sparseBlock.size(n);
-			int[] aix = dout.sparseBlock.indexes(n);
-			double[] avals = dout.sparseBlock.values(n);
-			for(int j = apos; j < apos+alen; j++) {
-				computeTensorIndexes(aix[j], tensorIndexes, params.P, params.Q);
-				int c = tensorIndexes[0];
-				int p = tensorIndexes[1];
-				int q = tensorIndexes[2];
-				final int inputOffset = n*params.C*params.H*params.W + c*params.H*params.W;
-				int maxIndex = getMaxIndex(p, q, inputOffset, inputArray, params, performReluBackward);
-				if(maxIndex != -1)
-					outputArray[maxIndex] += avals[j];
-			}
-		}
-	}
-	
-	private static void doPoolingBackwardDenseDense(int n, double [] inputArray, double [] doutArray, 
-			double [] outputArray, ConvolutionParameters params, boolean performReluBackward) {
-		for (int c = 0; c < params.C; c++) {
-			final int inputOffset = n*params.C*params.H*params.W + c*params.H*params.W;
-			final int outputOffset = n*params.C*params.P*params.Q + c*params.P*params.Q;
-			
-			for (int p = 0; p < params.P; p++) {
-				for (int q = 0; q < params.Q; q++) {
-					int maxIndex = getMaxIndex(p, q, inputOffset, inputArray, params, performReluBackward);
-					if(maxIndex != -1)
-						outputArray[maxIndex] += doutArray[outputOffset +  p * params.Q + q];
-				}
-			}
-		}
-	}
-	
-	/**
-	 * Returns the index of cell with maximum value. This method is optimized for sparse input
-	 * 
-	 * @param p output feature map height
-	 * @param q output feature map width
-	 * @param inputOffset offset to be used for input index
-	 * @param n number of images
-	 * @param c number of channels 
-	 * @param input input matrix
-	 * @param params convolution parameters
-	 * @param performReluBackward perform ReLU on input
-	 * @return index of the cell with maximum value
-	 * @throws DMLRuntimeException if error occurs
-	 */
-	private static int getMaxIndexSparse(int p, int q, int inputOffset, int n, int c, MatrixBlock input, ConvolutionParameters params, boolean performReluBackward) throws DMLRuntimeException {
-		if(!input.isInSparseFormat())
-			throw new DMLRuntimeException("Incorrect usage: Only sparse format supported");
-		
-		int [] tensorIndexes = new int[3];
-		
-		int start_index_h = params.start_indexes_h[p];
-		int end_index_h = params.end_indexes_h[p];
-		int start_index_w = params.start_indexes_w[q];
-		int end_index_w = params.end_indexes_w[q];
-		
-		int maxIndex = -1; 
-		double maxVal = -Double.MAX_VALUE;
-		
-		// Note: We do not treat pad as zero and hence we don't do:  
-		// maxVal = 0 
-		// if start_index_h < 0 || start_index_w < 0 || end_index_h >= params.H || end_index_w >= params.W
-
-		// input.isEmptyBlock() check is done by the caller
-		if( !input.sparseBlock.isEmpty(n) ) {
-			// Find maxIndex
-			int apos = input.sparseBlock.pos(n);
-			int alen = input.sparseBlock.size(n);
-			int[] aix = input.sparseBlock.indexes(n);
-			double[] avals = input.sparseBlock.values(n);
-			for(int j=apos; j<apos+alen; j++) {
-				computeTensorIndexes(aix[j], tensorIndexes, params.H, params.W);
-				if(c != tensorIndexes[0])
-					continue;
-				int h = tensorIndexes[1];
-				int w = tensorIndexes[2];
-				if(h >= start_index_h && h < end_index_h && w >= start_index_w && w < end_index_w) {
-					double val = performReluBackward && avals[j] < 0 ? 0 : avals[j]; 
-					if(maxVal < val) {
-						maxIndex = inputOffset +  h*params.W + w;
-						maxVal = val;
-					}
-				}
-			}
-		}
-		else {
-			maxIndex = inputOffset;
-		}
-		return maxIndex;
-	}
-	
-	/**
-	 * Returns the index of cell with maximum value. This method is optimized for dense input
-	 * 
-	 * @param p output feature map height
-	 * @param q output feature map width
-	 * @param inputOffset offset to be used for input index
-	 * @param inputArray input array
-	 * @param params convolution parameters
-	 * @param performReluBackward perform ReLU backward
-	 * @return index of cell with maximum value
-	 */
-	private static int getMaxIndex(int p, int q, int inputOffset, double [] inputArray, ConvolutionParameters params, boolean performReluBackward) {
-		int start_index_h = params.start_indexes_h[p];
-		int end_index_h = params.end_indexes_h[p];
-		int start_index_w = params.start_indexes_w[q];
-		int end_index_w = params.end_indexes_w[q];
-		
-		int maxIndex = -1; 
-		double maxVal = -Double.MAX_VALUE;
-		
-		// Note: We do not treat pad as zero and hence we don't do:  
-		// maxVal = 0 
-		// if start_index_h < 0 || start_index_w < 0 || end_index_h >= params.H || end_index_w >= params.W
-		
-		// Find maxIndex
-		double currDoutVal = -1;
-		for (int h = start_index_h; h < end_index_h; h++) {
-			for (int w = start_index_w; w < end_index_w; w++) {
-				currDoutVal = inputArray[inputOffset +  h*params.W + w];
-				currDoutVal = performReluBackward && currDoutVal < 0 ? 0 : currDoutVal;
-				if(maxVal < currDoutVal) {
-					maxIndex = inputOffset +  h*params.W + w;
-					maxVal = currDoutVal;
-				}
-			}
-		}
-		return maxIndex;
-	}
-	
 	/**
 	 * This method computes the backpropagation errors for the previous layer of the relu operation
 	 * 
@@ -746,37 +388,12 @@ public class LibMatrixDNN {
 				input.getNumRows() + " != " + dout.getNumRows() + " || " + input.getNumColumns() + " != " + dout.getNumColumns());
 		}
 		
-		runConvTask(TaskType.ReluBackward, params);
-		
-		//note: no post-processing as nnz maintained per task
-	}
-	
-	private static long doReluBackward(ConvolutionParameters params, int rl, int ru) throws DMLRuntimeException {
-		// (X > 0) * dout
-		double [] outputArray = params.output.getDenseBlock();
-		int numOutCols = params.input1.getNumColumns();
-		
-		if(!params.input1.isInSparseFormat() && !params.input2.isInSparseFormat()) {
-			double [] inputArr = params.input1.getDenseBlock();
-			double [] doutArr = params.input2.getDenseBlock();
-			for(int i = rl*numOutCols; i < ru*numOutCols; i++) {
-				outputArray[i] = inputArr[i] > 0 ? doutArr[i] : 0;
-			}
-		}
-		else {
-			// Perform (X > 0)
-			ConvolutionUtils.scalarOperations(params.input1, outputArray, rl*numOutCols, numOutCols, rl, ru, 
-					InstructionUtils.parseScalarBinaryOperator(">", false, 0));
-			// Then perform (X > 0) * dout
-			ConvolutionUtils.binaryOperationInPlace(params.input2, outputArray, rl*numOutCols, numOutCols, rl, ru, 
-					_binaryElementWiseMultiplication);
-		}
+		execute(LibMatrixDNNHelper.getReluBackwardWorkers(params), params);
 		
-		//post-processing: maintain nnz
-		return params.output.recomputeNonZeros(rl, ru-1, 0, numOutCols-1);
+		// post-processing: maintain nnz
+		outputBlock.recomputeNonZeros();
 	}
 	
-	
 	/**
 	 * Performs the operation corresponding to the DML script:
 	 * ones = matrix(1, rows=1, cols=Hout*Wout)		
@@ -883,539 +500,55 @@ public class LibMatrixDNN {
 		}
 		
 		fillIndexesArray(params);
-		runConvTask(TaskType.MaxPooling_Forward, params);
-		
-		//post-processing: maintain nnz
-		outputBlock.recomputeNonZeros();
-	}
-	
-	private static void doPooling(int n, ConvolutionParameters params) throws DMLRuntimeException {
-		double [] inputArray = null;
-		if (!params.input1.isInSparseFormat())
-			inputArray = params.input1.getDenseBlock();
-		double [] outputArray = null;
-		if (!params.output.isInSparseFormat())
-			outputArray = params.output.getDenseBlock();
-		else
-			throw new DMLRuntimeException("Expected the output to be allocated in dense format");
-		
-		final int inOffset = n*params.C*params.H*params.W;
-		int out_index = n*params.C*params.P*params.Q;
-		final int HW = params.H*params.W;
 		
-		if(inputArray != null) {
-			for (int c = 0; c < params.C; c++) {
-				final int inOffset1 = inOffset + c*HW;
-				for (int p = 0; p < params.P; p++) {
-					for (int q = 0; q < params.Q; q++, out_index++) {
-						for (int h = params.start_indexes_h[p]; h < params.end_indexes_h[p]; h++) {
-							for (int w = params.start_indexes_w[q]; w < params.end_indexes_w[q]; w++) {
-								outputArray[out_index] = Math.max(outputArray[out_index], inputArray[inOffset1 +  h*params.W + w]);
-							}
-						}
-					}
-				}
-			}
-		}
-		else {
-			// TODO: Optimize sparse maxpooling
-			// Low priority after adding fused relu_maxpooling operator as output of conv2d expected to be dense
-			for (int c = 0; c < params.C; c++) {
-				for (int p = 0; p < params.P; p++) {
-					for (int q = 0; q < params.Q; q++, out_index++) {
-						for (int h = params.start_indexes_h[p]; h < params.end_indexes_h[p]; h++) {
-							for (int w = params.start_indexes_w[q]; w < params.end_indexes_w[q]; w++) {
-								outputArray[out_index] = Math.max(outputArray[out_index], params.input1.quickGetValue(n, c*HW +  h*params.W + w));
-							}
-						}
-					}
-				}
-			}
-		}
-	}
-	
-	private static void doRotate180(int inputN, int outputN, MatrixBlock input, 
-			double [] outputArray,  ConvolutionParameters params, boolean zeroOutSparseOutput) throws DMLRuntimeException {
-		double [] inputArray = null;
-		if (!input.isInSparseFormat())
-			inputArray = input.getDenseBlock();
-		if(outputArray == null)
-			throw new DMLRuntimeException("Sparse output is not supported for rotate180");
+		execute(LibMatrixDNNHelper.getMaxPoolingWorkers(params), params);
 		
-		int outputOffset = outputN*params.K*params.P*params.Q;
-		if(inputArray != null) {
-			for (int k = 0; k < params.K; k++) {
-				for (int p = 0; p < params.P; p++) {
-					for (int q = 0; q < params.Q; q++) {
-						outputArray[outputOffset + p*params.Q*params.K + q*params.K + k] = inputArray[inputN*params.K*params.P*params.Q + k*params.P*params.Q + p*params.Q + q];
-					}
-				}
-			}
-		}
-		else {
-			if(zeroOutSparseOutput)
-				Arrays.fill(outputArray, 0);
-			
-			if(!input.isEmptyBlock()) {
-				if( !input.sparseBlock.isEmpty(inputN) ) {
-					int [] tensorIndexes = new int[3];
-					int apos = input.sparseBlock.pos(inputN);
-					int alen = input.sparseBlock.size(inputN);
-					int[] aix = input.sparseBlock.indexes(inputN);
-					double[] avals = input.sparseBlock.values(inputN);
-					for(int j = apos; j < apos+alen; j++) {
-						computeTensorIndexes(aix[j], tensorIndexes, params.P, params.Q);
-						int k = tensorIndexes[0];
-						int p = tensorIndexes[1];
-						int q = tensorIndexes[2];
-						outputArray[outputOffset + p*params.Q*params.K + q*params.K + k] = avals[j];
-					}
-				}
-			}
-		}
+		// post-processing: maintain nnz
+		outputBlock.recomputeNonZeros();
 	}
 	
-	// ----------------------------------------------------------------------------------------------------------------
-	private static void addMatrixBlocks(int poolSize, TaskType type, ConvolutionParameters params, 
-			ConcurrentLinkedQueue<MatrixBlock> im2ColOutBlocks, ConcurrentLinkedQueue<MatrixBlock> doutReshapedBlocks,
-			ConcurrentLinkedQueue<MatrixBlock> partialRetBlocks) {
-		boolean isEligibleForConv2dSparse = (type == TaskType.LoopedIm2ColConv2d) && isEligibleForConv2dSparse(params);
-		boolean isEligibleForConv2dBackwardFilterSparseDense = (type == TaskType.LoopedIm2ColConv2dBwdFilter) && isEligibleForConv2dBackwardFilterSparseDense(params) ;
-		for(int i = 0; i < poolSize; i++) {
-			if(type == TaskType.LoopedIm2ColConv2d || type == TaskType.LoopedIm2ColConv2dBwdFilter) {
-				if(!isEligibleForConv2dSparse && !isEligibleForConv2dBackwardFilterSparseDense) {
-					MatrixBlock im2ColOutBlock = new MatrixBlock(params.C*params.R*params.S, params.P*params.Q, false);
-					im2ColOutBlock.allocateDenseBlock();
-					im2ColOutBlocks.add(im2ColOutBlock);
+	/**
+	 * Executes the given tasks using Java's ExecutorService, in parallel when more than one thread is available.
+	 *  
+	 * @param tasks deep learning related tasks
+	 * @param params convolution parameters
+	 * @throws DMLRuntimeException if an error occurs
+	 */
+	private static void execute(ArrayList<Callable<Long>> tasks, ConvolutionParameters params) throws DMLRuntimeException {
+		int k = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
+		try {
+			if(k == 1) {
+				// Single-threaded execution (e.g., when invoked from parfor);
+				// this avoids the unnecessary creation of a thread pool.
+				for(Callable<Long> task : tasks) {
+					task.call();
 				}
 			}
-			
-			if(type == TaskType.LoopedIm2ColConv2dBwdFilter) {
-				MatrixBlock partialRetBlock = new MatrixBlock(params.C*params.R*params.S, params.K, false);
-				partialRetBlock.allocateDenseBlock();
-				partialRetBlocks.add(partialRetBlock);
-			}
-			
-			if(type == TaskType.LoopedIm2ColConv2dBwdData || type == TaskType.LoopedIm2ColConv2dBwdFilter) {
-				MatrixBlock doutReshapedBlock = new MatrixBlock(params.P*params.Q, params.K, false);
-				doutReshapedBlock.allocateDenseBlock();
-				doutReshapedBlocks.add(doutReshapedBlock);
-			}
-		}
-	}
-	// Methods to execute convolution-related tasks using multiple threads.
-	private static void runConvTask(TaskType type, ConvolutionParameters params) throws DMLRuntimeException {
-		int k = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
-		ConcurrentLinkedQueue<MatrixBlock> im2ColOutBlocks = new ConcurrentLinkedQueue<MatrixBlock>();
-		ConcurrentLinkedQueue<MatrixBlock> doutReshapedBlocks = new ConcurrentLinkedQueue<MatrixBlock>();
-		ConcurrentLinkedQueue<MatrixBlock> partialRetBlocks = new ConcurrentLinkedQueue<MatrixBlock>();
-		
-		if (ALLOW_MULTI_THREADED_OPS && params.isOutputThreadSafe() && k > 1) {
-			int poolSize = Math.min(k, params.N);
-			addMatrixBlocks(poolSize, type, params, im2ColOutBlocks, doutReshapedBlocks, partialRetBlocks);
-			
-			ArrayList<ConvTask> tasks = new ArrayList<ConvTask>();
-			int blklen = (int)(Math.ceil((double)params.N/poolSize/NUM_TASK_FACTOR));
-			for( int i=0; i<poolSize*NUM_TASK_FACTOR && i*blklen<params.N; i++ )
-				tasks.add(new ConvTask(i*blklen, Math.min((i+1)*blklen, params.N), 
-						type, params, im2ColOutBlocks, doutReshapedBlocks, partialRetBlocks));
-			
-			try {
-				ExecutorService pool = Executors.newFixedThreadPool( poolSize );
+			else {
+				ExecutorService pool = Executors.newFixedThreadPool( Math.min(k, params.N) );
 				List<Future<Long>> taskret = pool.invokeAll(tasks);
 				pool.shutdown();
 				for( Future<Long> task : taskret )
-					params.output.nonZeros += task.get();
-				if(type == TaskType.LoopedIm2ColConv2dBwdFilter) {
-					elementWiseInPlaceTransposedAddition(params.output, partialRetBlocks.toArray(new MatrixBlock[0]));
-				}
-			} 
-			catch (Exception e) {
-				throw new DMLRuntimeException("Error while executing multi-threaded " + type.name(), e);
-			}
-		}
-		else {
-			addMatrixBlocks(1, type, params, im2ColOutBlocks, doutReshapedBlocks, partialRetBlocks);
-			try {
-				//execute single task and maintain nnz if supported
-				params.output.setNonZeros(new ConvTask(0, params.N, type, params, im2ColOutBlocks, 
-						doutReshapedBlocks, partialRetBlocks).call());
-				
-				if(type == TaskType.LoopedIm2ColConv2dBwdFilter) {
-					elementWiseInPlaceTransposedAddition(params.output, partialRetBlocks.toArray(new MatrixBlock[0]));
-				}
-			} catch (Exception e) {
-				throw new DMLRuntimeException("Error while executing single-threaded " + type.name(), e);
+					task.get();
 			}
+		} 
+		catch (Exception e) {
+			throw new DMLRuntimeException("Error while executing tasks", e);
 		}
 	}
-	// ----------------------------------------------------------------------------------------------------------------
 	
-	private static boolean isEligibleForConv2dBackwardFilterSparseDense(ConvolutionParameters params) {
-		// NativeHelper.conv2dBackwardFilterSparseDense only if filter is sparse. 
+	static boolean isEligibleForConv2dBackwardFilterSparseDense(ConvolutionParameters params) {
+		// NativeHelper.conv2dBackwardFilterSparseDense only if input is sparse. 
 		// dout converted to dense if sparse.
 		return params.enableNative && params.input1.isInSparseFormat();
 	}
-	private static boolean isEligibleForConv2dSparse(ConvolutionParameters params) {
+	static boolean isEligibleForConv2dSparse(ConvolutionParameters params) {
 		// NativeHelper.conv2dSparse only if filter is dense and input is sparse
 		return params.enableNative && params.input1.isInSparseFormat() && !params.input2.isInSparseFormat();
 	}
-	private static boolean isEligibleForConv2dBackwardDataDense(ConvolutionParameters params) {
+	static boolean isEligibleForConv2dBackwardDataDense(ConvolutionParameters params) {
 		// NativeHelper.conv2dBackwardDataDense only if filter is dense. 
 		// dout converted to dense if sparse.
 		return params.enableNative && !params.input1.isInSparseFormat();
 	}
-	
-	/**
-	 * The ConvTask allows the convolution operations (such s conv2d, conv2d_backward, maxpooling, etc)
-	 * to be executed in multi-thread manner.
-	 * 
-	 */
-	private static class ConvTask implements Callable<Long> 
-	{
-		public int _rl; 
-		public int _ru; 
-		private final ConvolutionParameters _params;
-		private final TaskType _type;
-		private final ConcurrentLinkedQueue<MatrixBlock> _im2ColOutBlocks;
-		private final ConcurrentLinkedQueue<MatrixBlock> _partialRetBlocks;
-		private final ConcurrentLinkedQueue<MatrixBlock> _doutReshapedBlocks;
-		
-		public ConvTask(int rl, int ru, TaskType type, ConvolutionParameters params, 
-				ConcurrentLinkedQueue<MatrixBlock> im2ColOutBlocks,
-				ConcurrentLinkedQueue<MatrixBlock> doutReshapedBlocks,
-				ConcurrentLinkedQueue<MatrixBlock> partialRetBlocks) {
-			_rl = rl;
-			_ru = ru;
-			_type = type;
-			_params = params;
-			_im2ColOutBlocks = im2ColOutBlocks;
-			_partialRetBlocks = partialRetBlocks;
-			_doutReshapedBlocks = doutReshapedBlocks;
-		}
-		
-		@Override
-		public Long call() throws DMLRuntimeException {
-			long lnnz = 0; //nnz per partition
-			
-			switch(_type) {
-				case MaxPooling_Forward:
-					for(int n = _rl; n < _ru; n++)
-						doPooling(n, _params);
-					break;
-				case MaxPooling_Backward:
-					for(int n = _rl; n < _ru; n++) 
-						doPoolingBackward(n, _params, false);
-					break;
-				case MaxPooling_Relu_Backward:
-					for(int n = _rl; n < _ru; n++) 
-						doPoolingBackward(n, _params, true);
-					break;
-				case ReluBackward:
-					lnnz = doReluBackward(_params, _rl, _ru);
-					break;
-				case LoopedIm2ColConv2d:
-				{	
-					if(isEligibleForConv2dSparse(_params)) {
-						// NativeHelper.conv2dSparse only if filter is dense and input is sparse
-						int KPQ = _params.K*_params.P*_params.Q;
-						double[] temp = new double[KPQ];
-						for(int n = _rl; n < _ru; n++)  {
-							if( !_params.input1.getSparseBlock().isEmpty(n) ) {
-								int apos = _params.input1.getSparseBlock().pos(n);
-								int alen = _params.input1.getSparseBlock().size(n);
-								int[] aix = _params.input1.getSparseBlock().indexes(n);
-								double[] avals = _params.input1.getSparseBlock().values(n);
-								NativeHelper.conv2dSparse(apos, alen, aix, avals, _params.input2.getDenseBlock(), temp, 
-										1, _params.C, _params.H, _params.W, _params.K, _params.R, _params.S, 
-										_params.stride_h, _params.stride_w, _params.pad_h, _params.pad_w, _params.P, _params.Q, 1);
-								System.arraycopy(temp, 0, _params.output.denseBlock, n*KPQ, KPQ);
-							}
-						}
-					}
-					else {
-						// In all other cases, perform im2col in Java + matmult (either native or java).
-						MatrixBlock im2ColOutBlock = _im2ColOutBlocks.remove();
-						double [] temp = (_params.input1.isInSparseFormat() || _params.input1.denseBlock == null) ? new double[_params.input1.getNumColumns()] : null;
-						for(int n = _rl; n < _ru; n++) 
-							doLoopedIm2ColConv2d(n, im2ColOutBlock, _params, temp);
-						_im2ColOutBlocks.add(im2ColOutBlock);
-					}
-					if(_params.bias != null) {
-						// bias is always converted to dense format
-						double [] biasArr = _params.bias.getDenseBlock();
-						int PQ = _params.P*_params.Q;
-						int index = _rl*_params.K*PQ;
-						for(int n = _rl; n < _ru; n++) {
-							for(int k = 0; k < _params.K; k++) {
-								for(int pq = 0; pq < PQ; pq++, index++) {
-									_params.output.denseBlock[index] += biasArr[k];
-								}
-							}
-						}
-					}
-					break;
-				}
-				case LoopedIm2ColConv2dBwdFilter:
-				{
-					MatrixBlock partialRetBlock = _partialRetBlocks.remove();
-					MatrixBlock doutReshapedBlock = _doutReshapedBlocks.remove();
-					if(isEligibleForConv2dBackwardFilterSparseDense(_params)) {
-						double [] dout_n = doutReshapedBlock.getDenseBlock();
-						for(int n = _rl; n < _ru; n++) {
-							if( !_params.input1.getSparseBlock().isEmpty(n) ) {
-								doRotate180(n, 0, _params.input2, dout_n, _params, true);
-								int apos = _params.input1.getSparseBlock().pos(n);
-								int alen = _params.input1.getSparseBlock().size(n);
-								int[] aix = _params.input1.getSparseBlock().indexes(n);
-								double[] avals = _params.input1.getSparseBlock().values(n);
-								NativeHelper.conv2dBackwardFilterSparseDense(apos, alen, aix, avals, 
-										dout_n, partialRetBlock.getDenseBlock(), 1, _params.C, _params.H, _params.W, _params.K, 
-										_params.R, _params.S, _params.stride_h, _params.stride_w, _params.pad_h, _params.pad_w, _params.P, _params.Q, 1);
-							}
-						}
-					}
-					else {
-						MatrixBlock im2ColOutBlock = _im2ColOutBlocks.remove();
-						double [] temp = _params.input1.isInSparseFormat() ? new double[_params.input1.getNumColumns()] : null;
-						for(int n = _rl; n < _ru; n++) 
-							partialRetBlock = doLoopedIm2ColConv2dBwdFilter(n, im2ColOutBlock, doutReshapedBlock, partialRetBlock, _params, temp);
-						_im2ColOutBlocks.add(im2ColOutBlock);
-					}
-					_doutReshapedBlocks.add(doutReshapedBlock);
-					_partialRetBlocks.add(partialRetBlock);
-					break;
-				}
-				case LoopedIm2ColConv2dBwdData:
-				{
-					MatrixBlock doutReshapedBlock = _doutReshapedBlocks.remove();
-					if(isEligibleForConv2dBackwardDataDense(_params)) {
-						int CHW = _params.C*_params.H*_params.W;
-						double [] ret = new double[CHW];
-						double [] filterArr = _params.input1.getDenseBlock();
-						for(int n = _rl; n < _ru; n++) {
-							double [] dout_n = getRowInDenseFormat(_params.input2, n, doutReshapedBlock.getDenseBlock());
-							if(n > _rl)
-								Arrays.fill(ret, 0);
-							NativeHelper.conv2dBackwardDataDense(filterArr, dout_n, ret, 1, 
-									_params.C, _params.H, _params.W, _params.K, 
-									_params.R, _params.S, _params.stride_h, _params.stride_w, _params.pad_h, _params.pad_w, _params.P, _params.Q, 1);
-							System.arraycopy(ret, 0, _params.output.getDenseBlock(), n*CHW, CHW);
-						}
-					}
-					else {
-						for(int n = _rl; n < _ru; n++) 
-							doLoopedIm2ColConv2dBwdData(n, doutReshapedBlock, _params);
-					}
-					_doutReshapedBlocks.add(doutReshapedBlock);
-					break;
-				}
-				default:
-					throw new DMLRuntimeException("Unsupported ConvTask:" + _type.name());
-			}
-			
-			return lnnz;
-		}
-	}
-		
-	// Converts input: PQ X CRS matrix and writes to 1 X CHW
-	private static void doCol2imOverSingleImage(int outputN, MatrixBlock input, ConvolutionParameters params) throws DMLRuntimeException {
-		if(input.rlen != params.P*params.Q || input.clen != params.C*params.R*params.S) {
-			throw new DMLRuntimeException("Incorrect input dimensions");
-		}
-		
-		double [] outputArray = null;
-		if (!params.output.isInSparseFormat())
-			outputArray = params.output.getDenseBlock();
-		else {
-			throw new DMLRuntimeException("Only dense output is implemented");
-		}
-		
-		if(!input.isInSparseFormat()) {
-			double [] inputArray = input.getDenseBlock();
-			doCol2IMDenseInput(0, outputN, inputArray, outputArray, params);
-		}
-		else {
-			if(!input.isEmptyBlock()) {
-				int [] tensorIndexes = new int[3];
-				for(int i = 0; i < input.getNumRows(); i++) {
-					if( !input.sparseBlock.isEmpty(i) ) {
-						computeTensorIndexes(i, tensorIndexes, params.P, params.Q);
-						int p = tensorIndexes[1];
-						int q = tensorIndexes[2];
-						if(tensorIndexes[0] != 0) 
-							throw new DMLRuntimeException("Incorrect tensor indexes: " + tensorIndexes[0] + " != 0 <" + p + " " + q + " " + tensorIndexes[0] + params.P + " " + params.Q + ">");
-						
-						int apos = input.sparseBlock.pos(i);
-						int alen = input.sparseBlock.size(i);
-						int[] aix = input.sparseBlock.indexes(i);
-						double[] avals = input.sparseBlock.values(i);
-						for(int j = apos; j < apos+alen; j++) {
-							computeTensorIndexes(aix[j], tensorIndexes, params.R, params.S);
-							int c = tensorIndexes[0];
-							int r = tensorIndexes[1];
-							int s = tensorIndexes[2];
-							int h = p*params.stride_h + r - params.pad_h;
-							int w = q*params.stride_w + s - params.pad_w;
-							if(h >= 0 && h < params.H && w >= 0 && w < params.W) {
-								int outIndex = outputN*params.C*params.H*params.W + c*params.H*params.W + h*params.W + w;
-								outputArray[outIndex] += avals[j];
-							}
-						}
-					}
-				}
-			}
-		}
-	}
-	
-	// Converts input: PQ X CRS matrix and writes to 1 X CHW if inputN == 0
-	// Or converts input: NPQ X CRS matrix and writes to N X CHW 
-	private static void doCol2IMDenseInput(int inputN, int outputN, double [] inputArray, double [] outputArray, ConvolutionParameters params) throws DMLRuntimeException {
-		final int outputNOffset = outputN*params.C*params.H*params.W;
-		for (int p = 0; p < params.P; p++) {
-			// h = p*params.stride_h + r - params.pad_h
-			//   = r + hOffset
-			// Based on restrictions: h >= 0 and r >= 0 and h < params.H and r < params.R, we get
-			// max(0, - hOffset) <= r < min(params.R, params.H - hOffset)
-			final int hOffset = p*params.stride_h - params.pad_h;
-			final int rStart = Math.max(0, - hOffset);
-			final int rEnd = Math.min(params.R, params.H - hOffset);
-			for (int q = 0; q < params.Q; q++) {
-				// Using the same logic as above on following:
-				// w = q*params.stride_w + s - params.pad_w
-				final int wOffset = q*params.stride_w - params.pad_w;
-				final int sStart = Math.max(0, - wOffset);
-				final int sEnd = Math.min(params.S, params.W - wOffset);
-				final int tempOffset = (inputN*params.P*params.Q + p*params.Q + q)*params.C*params.R*params.S;
-				for (int c = 0; c < params.C; c++) {
-					final int outOffset = outputNOffset + c*params.H*params.W;
-					final int inputOffset = tempOffset + c*params.R*params.S;
-					for (int r = rStart; r < rEnd; r++) {
-						for (int s = sStart; s < sEnd; s++) {
-							int inputIndex = inputOffset + r*params.S + s;
-							int outIndex = outOffset + (hOffset + r)*params.W + wOffset + s;
-							outputArray[outIndex] += inputArray[inputIndex];
-						}
-					}
-				}
-			}
-		}
-	}
-	
-	private static void doIm2colDense(int n, double [] inputArray, double [] outputArray, ConvolutionParameters params) {
-		int CRS = params.C * params.R * params.S;
-		final int nOffset = n * params.C*params.H*params.W;
-		if (params.stride_h == 1 && params.stride_w == 1 && params.pad_h == 0 && params.pad_w == 0) {
-			for (int c = 0; c < CRS; ++c) {
-				int wOffset = c % params.S;
-				int hOffset = (c / params.S) % params.R;
-				int cInput = c / params.R / params.S;
-				for (int h = 0; h < params.P; ++h) {
-					int hPadded = h + hOffset;
-					int outOffset = (c * params.P + h) * params.Q;
-					int inputOffset = nOffset + (cInput * params.H + hPadded) * params.W;
-					System.arraycopy(inputArray, inputOffset + wOffset, outputArray, outOffset, params.Q);
-					int w = params.Q - 1;
-					int wPadded = w + wOffset;
-					if (hPadded < params.H && wPadded < params.W)
-						outputArray[outOffset + w] = inputArray[inputOffset + wPadded];
-					else
-						outputArray[outOffset + w] = 0;
-				}
-			}
-		} else {
-			for (int c = 0; c < CRS; ++c) {
-				int wOffset = c % params.S;
-				int hOffset = (c / params.S) % params.R;
-				int cInput = c / params.R / params.S;
-				for (int h = 0; h < params.P; ++h) {
-					int outOffset = (c * params.P + h) * params.Q;
-					int hPadded = h * params.stride_h - params.pad_h + hOffset;
-					int inputOffset = nOffset + (cInput * params.H + hPadded) * params.W;
-					if (hPadded < 0 || hPadded >= params.H) {
-						Arrays.fill(outputArray, outOffset, outOffset+params.Q, 0);
-					} else {
-						for (int w = 0; w < params.Q; ++w) {
-							int wPadded = w * params.stride_w - params.pad_w + wOffset;
-							if (wPadded >= 0 && wPadded < params.W)
-								outputArray[outOffset + w] = inputArray[inputOffset + wPadded];
-							else
-								outputArray[outOffset + w] = 0;
-						}
-					}
-				}
-			}
-		}
-	}
-	
-	// Returns the row of matrix in dense format
-	private static double [] getRowInDenseFormat(MatrixBlock input, int n, double []  temp) throws DMLRuntimeException {
-		if(input.getNumColumns() != temp.length) {
-			throw new DMLRuntimeException("Invalid parameters");
-		}
-		// Use temporary array to avoid binary search
-		if(input.isInSparseFormat()) {
-			Arrays.fill(temp, 0);
-			if( !input.sparseBlock.isEmpty(n) ) {
-				int apos = input.sparseBlock.pos(n);
-				int alen = input.sparseBlock.size(n);
-				int[] aix = input.sparseBlock.indexes(n);
-				double[] avals = input.sparseBlock.values(n);
-				for(int j=apos; j<apos+alen; j++)
-					temp[ aix[j] ] = avals[j];
-			}
-		}
-		else {
-			System.arraycopy(input.getDenseBlock(), n*input.getNumColumns(), temp, 0, input.getNumColumns());
-		}
-		return temp;
-	}
-	
-	// Keeping this as a separate sparse method to allow for further dense optimizations
-	private static void doIm2colSparse(int n, MatrixBlock input, double [] outputArray, ConvolutionParameters params, double []  temp) throws DMLRuntimeException {
-		int CRS = params.C * params.R * params.S;
-		
-		// Copying the row into a temporary dense array avoids a binary search per value access.
-		// Since the access pattern depends on ConvolutionParameters, this densification is a
-		// stop-gap until a truly sparse im2col is implemented.
-		temp = getRowInDenseFormat(input, n, temp);
-		for (int c = 0; c < CRS; ++c) {
-			int wOffset = c % params.S;
-			int hOffset = (c / params.S) % params.R;
-			int cInput = c / params.R / params.S;
-			for (int h = 0; h < params.P; ++h) {
-				int outOffset = (c * params.P + h) * params.Q;
-				int hPadded = h * params.stride_h - params.pad_h + hOffset;
-				int tempOffset = (cInput * params.H + hPadded) * params.W;
-				if (hPadded < 0 || hPadded >= params.H) {
-					Arrays.fill(outputArray, outOffset, outOffset+params.Q, 0);
-				} else {
-					for (int w = 0; w < params.Q; ++w) {
-						int wPadded = w * params.stride_w - params.pad_w + wOffset;
-						if (wPadded >= 0 && wPadded < params.W) 
-							outputArray[outOffset + w] = temp[tempOffset + wPadded];
-						else
-							outputArray[outOffset + w] = 0;
-					}
-				}
-			}
-		}
-	}
-	
-	private static void doIm2col(int n, MatrixBlock output, ConvolutionParameters params, double []  temp) throws DMLRuntimeException {
-		double [] inputArray = null;
-		if (!params.input1.isInSparseFormat())
-			inputArray = params.input1.getDenseBlock();
-		double [] outputArray = null;
-		if(!output.isInSparseFormat())
-			outputArray = output.getDenseBlock();
-		else 
-			throw new DMLRuntimeException("Sparse output is not supported for im2col");
-		
-		if(inputArray != null)
-			doIm2colDense(n, inputArray, outputArray, params);
-		else
-			doIm2colSparse(n, params.input1, outputArray, params, temp);
-	}
 }

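The im2col routines above flatten every sliding R x S window of a C x H x W image
into one column of a [C*R*S x P*Q] matrix, so that convolution collapses into a
single matrix multiply. The following is a minimal standalone sketch of that
layout for one row-major CHW image; the names Im2colSketch and naiveIm2col are
illustrative only, not part of the SystemML API:

    // Illustrative sketch, not SystemML code: naive im2col for one CHW image.
    // Output is [C*R*S x P*Q] in row-major order, matching the indexing used
    // in doIm2colDense/doIm2colSparse above (zero-fill for padded positions).
    public class Im2colSketch {
        static double[] naiveIm2col(double[] img, int C, int H, int W,
                int R, int S, int strideH, int strideW, int padH, int padW) {
            int P = (H + 2*padH - R) / strideH + 1;  // output height
            int Q = (W + 2*padW - S) / strideW + 1;  // output width
            double[] out = new double[C*R*S * P*Q];
            for (int c = 0; c < C*R*S; c++) {
                int sOff = c % S, rOff = (c / S) % R, cIn = c / (R*S);
                for (int p = 0; p < P; p++) {
                    int h = p*strideH - padH + rOff;
                    for (int q = 0; q < Q; q++) {
                        int w = q*strideW - padW + sOff;
                        out[(c*P + p)*Q + q] = (h >= 0 && h < H && w >= 0 && w < W)
                            ? img[(cIn*H + h)*W + w] : 0;
                    }
                }
            }
            return out;
        }
        public static void main(String[] args) {
            // 1x3x3 image, 2x2 kernel, stride 1, pad 0 => P*Q = 4 columns
            double[] img = {1,2,3, 4,5,6, 7,8,9};
            double[] col = naiveIm2col(img, 1, 3, 3, 2, 2, 1, 1, 0, 0);
            System.out.println(java.util.Arrays.toString(col));
            // one row per filter cell: (1,2,4,5), (2,3,5,6), (4,5,7,8), (5,6,8,9)
        }
    }
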
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19eed8f3/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardDataHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardDataHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardDataHelper.java
new file mode 100644
index 0000000..609af11
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardDataHelper.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.runtime.matrix.data;
+
+import java.util.Arrays;
+import java.util.concurrent.Callable;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.utils.NativeHelper;
+
+/**
+ * This class contains the set of operators used for performing conv2d backward data
+ */
+public class LibMatrixDNNConv2dBackwardDataHelper {
+
+	/**
+	 * This operator is used only if native is enabled and dout is sparse
+	 * (the filter is read as a dense block); each row of dout is converted
+	 * into dense format before invoking the native kernel.
+	 */
+	public static class SparseNativeConv2dBackwardDataDense implements Callable<Long> 
+	{
+		public int _rl; public int _ru; 
+		private final ConvolutionParameters _params; 
+		public SparseNativeConv2dBackwardDataDense(int rl, int ru, ConvolutionParameters params) {
+			_rl = rl; _ru = ru;
+			_params = params;
+		}
+
+		@Override
+		public Long call() throws Exception {
+			int CHW = _params.C*_params.H*_params.W;
+			double [] ret = new double[CHW];
+			double [] filterArr = _params.input1.getDenseBlock();
+			double [] dout_n = new double[_params.P*_params.Q*_params.K];
+			for(int n = _rl; n < _ru; n++) {
+				LibMatrixDNNHelper.getRowInDenseFormat(_params.input2, n, dout_n);
+				if(n > _rl)
+					Arrays.fill(ret, 0);
+				NativeHelper.conv2dBackwardDataDense(filterArr, dout_n, ret, 1, 
+						_params.C, _params.H, _params.W, _params.K, 
+						_params.R, _params.S, _params.stride_h, _params.stride_w, _params.pad_h, _params.pad_w, _params.P, _params.Q, 1);
+				System.arraycopy(ret, 0, _params.output.getDenseBlock(), n*CHW, CHW);
+			}
+			return 0L;
+		}
+	}
+	
+	/**
+	 * General conv2d backward data operator
+	 */
+	public static class Conv2dBackwardData implements Callable<Long> {
+
+		public int _rl; public int _ru; 
+		private final ConvolutionParameters _params; 
+		public Conv2dBackwardData(int rl, int ru, ConvolutionParameters params) {
+			_rl = rl; _ru = ru;
+			_params = params;
+		}
+		
+		@Override
+		public Long call() throws Exception {
+			int PQ = _params.P*_params.Q; int K = _params.K; int CRS = _params.C*_params.R*_params.S;
+			MatrixBlock filter = _params.input1;
+			MatrixBlock dout = _params.input2;
+			MatrixBlock dout_reshaped = new MatrixBlock(PQ, K, false);
+			dout_reshaped.allocateDenseBlock();
+			LibMatrixDNNRotate180Helper.Rotate180Worker rotate180Worker = 
+					LibMatrixDNNRotate180Helper.Rotate180Worker.getWorker( dout, dout_reshaped.getDenseBlock(), _params, true);
+			long time1 = 0; long time2 = 0;
+			for(int n = _rl; n < _ru; n++)  {
+				// rotate180(dout[n,]) => dout_reshaped
+				rotate180Worker.execute(n, 0);
+				
+				// dout_reshaped %*% filter => temp
+				MatrixBlock temp = new MatrixBlock(PQ, CRS, false);
+				long t1 = DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0;
+				LibMatrixDNNHelper.singleThreadedMatMult(dout_reshaped, filter, temp, true, false, _params);
+				long t2 = DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0;
+				// col2im(temp) => output[n,] 
+				LibMatrixDNNHelper.doCol2imOverSingleImage(n, temp, _params);
+				long t3 = DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0;
+				
+				if(DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS) {
+					time1 += t2 - t1;
+					time2 += t3 - t2;
+				}
+			}
+			if(DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS) {
+				LibMatrixDNN.loopedConvBwdDataMatMultTime.addAndGet(time1);
+				LibMatrixDNN.loopedConvBwdDataCol2ImTime.addAndGet(time2);
+			}
+			return 0L;
+		}
+		
+	}
+}

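For reference, Conv2dBackwardData computes the data gradient per image as
col2im(rotate180(dout[n,]) %*% filter): the rotated [P*Q x K] output gradient is
multiplied by the [K x C*R*S] filter, and the resulting patch matrix is
scatter-added back into the C x H x W image. Below is a minimal sketch of the
col2im scatter-add, written against the [C*R*S x P*Q] layout of the earlier
im2col sketch (SystemML's doCol2imOverSingleImage works on the transposed temp
block); Col2imSketch and naiveCol2im are illustrative names:

    // Illustrative sketch, not SystemML code: col2im scatter-adds a
    // [C*R*S x P*Q] patch matrix back into a CHW image. Overlapping windows
    // hit the same cell more than once, hence += rather than plain assignment.
    public class Col2imSketch {
        static double[] naiveCol2im(double[] col, int C, int H, int W,
                int R, int S, int strideH, int strideW, int padH, int padW) {
            int P = (H + 2*padH - R) / strideH + 1;
            int Q = (W + 2*padW - S) / strideW + 1;
            double[] img = new double[C*H*W];
            for (int c = 0; c < C*R*S; c++) {
                int sOff = c % S, rOff = (c / S) % R, cIn = c / (R*S);
                for (int p = 0; p < P; p++) {
                    int h = p*strideH - padH + rOff;
                    if (h < 0 || h >= H) continue;  // padded rows contribute nothing
                    for (int q = 0; q < Q; q++) {
                        int w = q*strideW - padW + sOff;
                        if (w >= 0 && w < W)
                            img[(cIn*H + h)*W + w] += col[(c*P + p)*Q + q];
                    }
                }
            }
            return img;
        }
        public static void main(String[] args) {
            // All-ones patches over a 1x3x3 image (2x2 kernel, stride 1, pad 0):
            // the result counts how many windows cover each cell.
            double[] col = new double[4*4];
            java.util.Arrays.fill(col, 1);
            System.out.println(java.util.Arrays.toString(
                naiveCol2im(col, 1, 3, 3, 2, 2, 1, 1, 0, 0)));
            // => 1,2,1 / 2,4,2 / 1,2,1 (window-overlap counts)
        }
    }
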
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19eed8f3/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardFilterHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardFilterHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardFilterHelper.java
new file mode 100644
index 0000000..560f32c
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardFilterHelper.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.runtime.matrix.data;
+
+import java.util.concurrent.Callable;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.runtime.util.ConvolutionUtils;
+import org.apache.sysml.utils.NativeHelper;
+
+public class LibMatrixDNNConv2dBackwardFilterHelper {
+
+	/**
+	 * This operator is used only if native is enabled and input is sparse. 
+	 * dout is converted into dense if sparse.
+	 */
+	public static class SparseNativeConv2dBackwardFilterDense implements Callable<Long> 
+	{
+
+		public int _rl; public int _ru; 
+		private final ConvolutionParameters _params; 
+		public SparseNativeConv2dBackwardFilterDense(int rl, int ru, ConvolutionParameters params) {
+			_rl = rl; _ru = ru;
+			_params = params;
+		}
+		
+		@Override
+		public Long call() throws Exception {
+			int CRS = _params.C*_params.R*_params.S; 
+			double [] dout_n = new double[_params.P*_params.Q*_params.K];
+			LibMatrixDNNRotate180Helper.Rotate180Worker rotate180Worker = 
+					LibMatrixDNNRotate180Helper.Rotate180Worker.getWorker( _params.input2, dout_n, _params, true);
+			// partialRetBlock is size: [params.C*params.R*params.S, params.K]
+			double [] partialRetBlock = new double[CRS*_params.K];
+			for(int n = _rl; n < _ru; n++) {
+				if( !_params.input1.getSparseBlock().isEmpty(n) ) {
+					// rotate180(dout[n,]) => dout_n
+					rotate180Worker.execute(n, 0);
+					
+					int apos = _params.input1.getSparseBlock().pos(n);
+					int alen = _params.input1.getSparseBlock().size(n);
+					int[] aix = _params.input1.getSparseBlock().indexes(n);
+					double[] avals = _params.input1.getSparseBlock().values(n);
+					NativeHelper.conv2dBackwardFilterSparseDense(apos, alen, aix, avals, 
+							dout_n, partialRetBlock, 1, _params.C, _params.H, _params.W, _params.K, 
+							_params.R, _params.S, _params.stride_h, _params.stride_w, _params.pad_h, _params.pad_w, _params.P, _params.Q, 1);
+				}
+			}
+			inplaceTransposedAddition(partialRetBlock, _params);
+			return 0L;
+		}
+	}
+	
+	/**
+	 * General conv2d backward filter operator
+	 */
+	public static class Conv2dBackwardFilter implements Callable<Long> {
+
+		public int _rl; public int _ru; 
+		private final ConvolutionParameters _params; 
+		public Conv2dBackwardFilter(int rl, int ru, ConvolutionParameters params) {
+			_rl = rl; _ru = ru;
+			_params = params;
+		}
+		
+		@Override
+		public Long call() throws Exception {
+			int PQ = _params.P*_params.Q; int K = _params.K; int CRS = _params.C*_params.R*_params.S;
+			MatrixBlock dout = _params.input2;
+			MatrixBlock im2ColOutBlock = new MatrixBlock(CRS, PQ, false);
+			im2ColOutBlock.allocateDenseBlock();
+			MatrixBlock dout_reshaped = new MatrixBlock(PQ, K, false);
+			dout_reshaped.allocateDenseBlock();
+			LibMatrixDNNIm2ColHelper.Im2colWorker im2ColWorker = LibMatrixDNNIm2ColHelper.Im2colWorker.getWorker( _params.input1, im2ColOutBlock, _params, true);
+			LibMatrixDNNRotate180Helper.Rotate180Worker rotate180Worker = 
+					LibMatrixDNNRotate180Helper.Rotate180Worker.getWorker( dout, dout_reshaped.getDenseBlock(), _params, true);
+			double [] partialRetBlock = new double[CRS*_params.K];
+			long time1 = 0; long time2 = 0;
+			for(int n = _rl; n < _ru; n++) {
+				// rotate180(dout[n,]) => dout_reshaped
+				rotate180Worker.execute(n, 0);
+				
+				// im2col(input[n,]) => im2ColOutBlock
+				long t1 = DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0;
+				im2ColWorker.execute(n);
+				long t2 = DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0;
+				
+				MatrixBlock temp = new MatrixBlock(CRS, K, false);
+				LibMatrixDNNHelper.singleThreadedMatMult(im2ColOutBlock, dout_reshaped, temp, true, true, _params);
+				long t3 = DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0;
+				
+				if(!temp.isEmptyBlock()) {
+					// partialRetBlock is size: [params.C*params.R*params.S, params.K]
+					ConvolutionUtils.binaryOperationInPlace(temp, partialRetBlock, 0, K, 0, CRS, 
+							LibMatrixDNN._binaryElementWiseAddition);
+				}
+				
+				if(DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS) {
+					time1 += t2 - t1;
+					time2 += t3 - t2;
+				}
+			}
+			inplaceTransposedAddition(partialRetBlock, _params);
+			if(DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS) {
+				LibMatrixDNN.loopedConvBwdFilterIm2ColTime.addAndGet(time1);
+				LibMatrixDNN.loopedConvBwdFilterMatMultTime.addAndGet(time2);
+			}
+			return 0L;
+		}
+	}
+	private static synchronized void inplaceTransposedAddition(double [] partialRetBlock, ConvolutionParameters params) {
+		// Perform transposed addition: output of size [K, CRS] += partialRetBlock of size [CRS,K]
+		int iter = 0; int CRS = params.C*params.R*params.S; int K = params.K;
+		double [] outputArr = params.output.denseBlock;
+		for(int i = 0; i < CRS; i++) {
+			for(int j = 0; j < K; j++, iter++) {
+				int index = j*CRS+i;
+				outputArr[index] += partialRetBlock[iter];
+			}
+		}
+	}
+}

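Note how each Conv2dBackwardFilter worker accumulates its partial filter
gradient in a private [C*R*S x K] buffer and synchronizes only once, when
inplaceTransposedAddition folds that buffer into the shared [K x C*R*S] output.
A minimal sketch of that transposed accumulation; TransposedAddSketch is an
illustrative name, not part of SystemML:

    // Illustrative sketch, not SystemML code: fold a row-major [CRS x K]
    // partial into a row-major [K x CRS] output, i.e. out += t(partial).
    public class TransposedAddSketch {
        static synchronized void transposedAdd(double[] out, double[] partial,
                int CRS, int K) {
            int iter = 0;
            for (int i = 0; i < CRS; i++)
                for (int j = 0; j < K; j++, iter++)
                    out[j*CRS + i] += partial[iter];
        }
        public static void main(String[] args) {
            int CRS = 2, K = 3;
            double[] partial = {1,2,3, 4,5,6};   // [2 x 3]
            double[] out = new double[K*CRS];    // [3 x 2]
            transposedAdd(out, partial, CRS, K);
            System.out.println(java.util.Arrays.toString(out));
            // => [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]: the transpose of partial
        }
    }

Keeping the per-thread buffer in [CRS x K] order matches the shape of the
per-image matrix product, so the transpose is paid once per worker inside the
synchronized method instead of once per image.
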
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19eed8f3/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dHelper.java
new file mode 100644
index 0000000..b2c4d67
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dHelper.java
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.runtime.matrix.data;
+
+import java.util.ArrayList;
+import java.util.concurrent.Callable;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.utils.NativeHelper;
+
+/**
+ * This class contains the set of operators used for performing conv2d
+ */
+public class LibMatrixDNNConv2dHelper {
+
+	/**
+	 * Performs convolution one channel at a time via: add(filter[c] %*% im2col(input[, c])) => output.
+	 * Because im2col materializes only R*S rows per channel instead of C*R*S,
+	 * this operator has less memory pressure than LoopedIm2ColConv2dAllChannels.
+	 */
+	public static class LoopedIm2ColConv2dOneChannel implements Callable<Long> 
+	{
+		public int _rl; public int _ru; 
+		private final ConvolutionParameters _params; ArrayList<MatrixBlock> _filters;
+		public LoopedIm2ColConv2dOneChannel(int rl, int ru, ConvolutionParameters params, ArrayList<MatrixBlock> filters) {
+			_rl = rl; _ru = ru;
+			_params = params; 
+			_filters = filters;
+		}
+		
+		@Override
+		public Long call() throws Exception {
+			int PQ = _params.P*_params.Q; int K = _params.K;
+			int RS = _params.R*_params.S;
+			MatrixBlock im2ColOutBlock = new MatrixBlock(RS, PQ, false);
+			im2ColOutBlock.allocateDenseBlock();
+			LibMatrixDNNIm2ColHelper.Im2colWorker im2ColWorker = LibMatrixDNNIm2ColHelper.Im2colWorker.getWorker( _params.input1, im2ColOutBlock, _params, false);
+			long time1 = 0; long time2 = 0;
+			for(int n = _rl; n < _ru; n++)  {
+				for(int c = 0; c < _params.C; c++)  {
+					// im2col(input[n, c]) => im2ColOutBlock
+					long t1 = DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0;
+					im2ColWorker.execute(n, c);
+					long t2 = DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0;
+					
+					// filter %*% _im2ColOutBlock => matMultOutBlock
+					MatrixBlock matMultOutBlock = new MatrixBlock(K, PQ, false);
+					LibMatrixDNNHelper.singleThreadedMatMult(_filters.get(c), im2ColOutBlock, matMultOutBlock, false, true, _params);
+					long t3 = DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0;
+					
+					if(DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS) {
+						time1 += t2 - t1;
+						time2 += t3 - t2;
+					}
+					
+					// Add the matrix matMultOutBlock of shape [K X PQ] to params.output.denseBlock + destPos
+					add(matMultOutBlock, _params.output.getDenseBlock(), n*K*PQ, K, PQ);
+				}
+			}
+			if(_params.bias != null) {
+				// bias is always converted to dense format
+				LibMatrixDNNHelper.addBias(_rl, _ru, _params.output.getDenseBlock(), _params.bias.getDenseBlock(), K, PQ);
+			}
+			if(DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS) {
+				LibMatrixDNN.loopedConvIm2ColTime.addAndGet(time1);
+				LibMatrixDNN.loopedConvMatMultTime.addAndGet(time2);
+			}
+			return 0L;
+		}
+		
+		// Add the matrix src of shape [K X PQ] into params.output.denseBlock at destPos
+		private void add(MatrixBlock src, double [] dest, int destPos, int K, int PQ) {
+			// The explicit accumulation is required because LibMatrixMult.matrixMult cannot
+			// write into an offset of a larger array (Java has no pointer arithmetic).
+			// The native implementation writes its result in place and does not need this step.
+			if(!src.isEmptyBlock()) {
+				if(src.isInSparseFormat()) {
+					// Add the sparse matrix matMultOutBlock of shape [K X PQ] into
+					// params.output.denseBlock at destPos
+					for(int k = 0; k < src.getNumRows(); k++) {
+						if( !src.sparseBlock.isEmpty(k) ) {
+							int apos = src.sparseBlock.pos(k);
+							int alen = src.sparseBlock.size(k);
+							int[] aix = src.sparseBlock.indexes(k);
+							double[] avals = src.sparseBlock.values(k);
+							for(int j = apos; j < apos+alen; j++) {
+								int pqIndex = aix[j];
+								dest[destPos + k*PQ + pqIndex ] += avals[j];
+							}
+						}
+					}
+				}
+				else {
+					for(int i = 0; i < K * PQ; i++) {
+						dest[destPos+i] += src.denseBlock[i];
+					}
+				}
+			}
+		}
+	}	
+	
+	/**
+	 * Performs convolution via: partialCopy1(filter %*% im2col(input)) = output
+	 */
+	public static class LoopedIm2ColConv2dAllChannels implements Callable<Long> 
+	{
+		public int _rl; public int _ru; 
+		private final ConvolutionParameters _params;
+		public LoopedIm2ColConv2dAllChannels(int rl, int ru, ConvolutionParameters params) {
+			_rl = rl; _ru = ru;
+			_params = params;
+		}
+
+		@Override
+		public Long call() throws Exception {
+			int PQ = _params.P*_params.Q; int K = _params.K; int CRS = _params.C*_params.R*_params.S;
+			MatrixBlock im2ColOutBlock = new MatrixBlock(CRS, PQ, false);
+			im2ColOutBlock.allocateDenseBlock();
+			LibMatrixDNNIm2ColHelper.Im2colWorker im2ColWorker = LibMatrixDNNIm2ColHelper.Im2colWorker.getWorker( _params.input1, im2ColOutBlock, _params, true);
+			long time1 = 0; long time2 = 0;
+			for(int n = _rl; n < _ru; n++)  {
+				// im2col(input[n,]) => im2ColOutBlock
+				long t1 = DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0;
+				im2ColWorker.execute(n);
+				long t2 = DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0;
+				
+				// filter %*% _im2ColOutBlock => matMultOutBlock
+				MatrixBlock matMultOutBlock = new MatrixBlock(K, PQ, false);
+				LibMatrixDNNHelper.singleThreadedMatMult(_params.input2, im2ColOutBlock, matMultOutBlock, false, true, _params);
+				long t3 = DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0;
+				
+				if(DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS) {
+					time1 += t2 - t1;
+					time2 += t3 - t2;
+				}
+				
+				// Copy the matrix matMultOutBlock of shape [K X PQ] to params.output.denseBlock + destPos
+				partialCopy1(matMultOutBlock, _params.output.getDenseBlock(), n*K*PQ, K, PQ);
+			}
+			if(_params.bias != null) {
+				// bias is always converted to dense format
+				LibMatrixDNNHelper.addBias(_rl, _ru, _params.output.getDenseBlock(), _params.bias.getDenseBlock(), K, PQ);
+			}
+			if(DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS) {
+				LibMatrixDNN.loopedConvIm2ColTime.addAndGet(time1);
+				LibMatrixDNN.loopedConvMatMultTime.addAndGet(time2);
+			}
+			return 0L;
+		}
+		
+		// Copy the matrix src of shape [K X PQ] to params.output.denseBlock + destPos
+		private void partialCopy1(MatrixBlock src, double [] dest, int destPos, int K, int PQ) {
+			// The explicit copy is required because LibMatrixMult.matrixMult cannot write
+			// into an offset of a larger array (Java has no pointer arithmetic).
+			// The native implementation writes its result in place and does not need this step.
+			if(!src.isEmptyBlock()) {
+				if(src.isInSparseFormat()) {
+					// Copy the sparse matrix matMultOutBlock of shape [K X PQ] to 
+					// params.output.denseBlock + destPos
+					for(int k = 0; k < src.getNumRows(); k++) {
+						if( !src.sparseBlock.isEmpty(k) ) {
+							int apos = src.sparseBlock.pos(k);
+							int alen = src.sparseBlock.size(k);
+							int[] aix = src.sparseBlock.indexes(k);
+							double[] avals = src.sparseBlock.values(k);
+							for(int j = apos; j < apos+alen; j++) {
+								int pqIndex = aix[j];
+								dest[destPos + k*PQ + pqIndex ] = avals[j];
+							}
+						}
+					}
+				}
+				else 
+					System.arraycopy(src.denseBlock, 0, dest, destPos, K * PQ);
+			}
+		}
+	}
+	
+	
+	/**
+	 * This operator is used only if native is enabled, filter is dense and input is sparse
+	 */
+	public static class SparseNativeConv2d implements Callable<Long> 
+	{
+		public int _rl; public int _ru; 
+		private final ConvolutionParameters _params;
+		public SparseNativeConv2d(int rl, int ru, ConvolutionParameters params) {
+			_rl = rl; _ru = ru;
+			_params = params;
+		}
+
+		@Override
+		public Long call() throws Exception {
+			int KPQ = _params.K*_params.P*_params.Q;
+			double[] temp = new double[KPQ];
+			for(int n = _rl; n < _ru; n++)  {
+				if( !_params.input1.getSparseBlock().isEmpty(n) ) {
+					int apos = _params.input1.getSparseBlock().pos(n);
+					int alen = _params.input1.getSparseBlock().size(n);
+					int[] aix = _params.input1.getSparseBlock().indexes(n);
+					double[] avals = _params.input1.getSparseBlock().values(n);
+					NativeHelper.conv2dSparse(apos, alen, aix, avals, _params.input2.getDenseBlock(), temp, 
+							1, _params.C, _params.H, _params.W, _params.K, _params.R, _params.S, 
+							_params.stride_h, _params.stride_w, _params.pad_h, _params.pad_w, _params.P, _params.Q, 1);
+					System.arraycopy(temp, 0, _params.output.denseBlock, n*KPQ, KPQ);
+				}
+			}
+			return 0L;
+		}
+	}
+}
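
Putting the pieces together, LoopedIm2ColConv2dAllChannels computes per image
output[n,] = filter %*% im2col(input[n,]), with the filter of shape [K x C*R*S]
and the im2col output of shape [C*R*S x P*Q]. An end-to-end sketch under the
same assumptions as the earlier im2col sketch (it reuses the illustrative
Im2colSketch.naiveIm2col helper defined there; Conv2dSketch is likewise not a
SystemML class):

    // Illustrative sketch, not SystemML code: convolution as a matmul over
    // the im2col patch matrix, for a single image and K filters.
    public class Conv2dSketch {
        // Naive row-major matmul: c[m x n] = a[m x k] %*% b[k x n]
        static double[] matmul(double[] a, double[] b, int m, int k, int n) {
            double[] c = new double[m*n];
            for (int i = 0; i < m; i++)
                for (int p = 0; p < k; p++) {
                    double av = a[i*k + p];
                    if (av != 0)
                        for (int j = 0; j < n; j++)
                            c[i*n + j] += av * b[p*n + j];
                }
            return c;
        }
        public static void main(String[] args) {
            // One all-ones 2x2 filter over a 1x3x3 image, stride 1, pad 0
            double[] img = {1,2,3, 4,5,6, 7,8,9};
            double[] col = Im2colSketch.naiveIm2col(img, 1, 3, 3, 2, 2, 1, 1, 0, 0);
            double[] filter = {1,1,1,1};                 // [K=1 x CRS=4]
            double[] out = matmul(filter, col, 1, 4, 4); // [K x P*Q] = [1 x 4]
            System.out.println(java.util.Arrays.toString(out));
            // => [12.0, 16.0, 24.0, 28.0]: each entry sums one 2x2 window
        }
    }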