You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2016/07/09 03:30:23 UTC

incubator-systemml git commit: [SYSTEMML-769] Improved performance of LibMatrixDNN's conv2d and conv2d_backward_filter

Repository: incubator-systemml
Updated Branches:
  refs/heads/master 20e05458b -> 2ebf885a6


[SYSTEMML-769] Improved performance of LibMatrixDNN's conv2d and
conv2d_backward_filter

- Fixed bug while iterating through sparse conv2d_backward_filter
- Also added vectorized conv2d

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/2ebf885a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/2ebf885a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/2ebf885a

Branch: refs/heads/master
Commit: 2ebf885a6919e1cb0598e2aab4d0ffb46b8e0ab5
Parents: 20e0545
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Fri Jul 8 20:26:16 2016 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Fri Jul 8 20:28:25 2016 -0700

----------------------------------------------------------------------
 .../sysml/runtime/matrix/data/LibMatrixDNN.java | 558 ++++++++++++++-----
 1 file changed, 410 insertions(+), 148 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2ebf885a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index d9faf7e..26e2b8b 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -20,6 +20,7 @@ package org.apache.sysml.runtime.matrix.data;
 
 import java.lang.ref.SoftReference;
 import java.util.ArrayList;
+import java.util.Iterator;
 import java.util.List;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ConcurrentHashMap;
@@ -29,12 +30,16 @@ import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 import java.util.concurrent.atomic.AtomicLong;
 
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.util.ConvolutionUtils;
 
 
 public class LibMatrixDNN {
+	
+	protected static final Log LOG =  LogFactory.getLog(LibMatrixDNN.class.getName());
 
 	public static final boolean ALLOW_MULTI_THREADED_OPS = true;
 	// Using hashmap to avoid any performance impacts of multimap
@@ -62,13 +67,14 @@ public class LibMatrixDNN {
 	enum TaskType {
 		ReshapeCol, Rotate180, Im2Col, Col2Im, MaxPooling_Forward, MaxPooling_Backward, LoopBasedConv2d
 	}
-	public static final int TASK_SIZE = 64; // to take care of extremely small tasks
 	
 	public static class TemporaryConvolutionData {
 		public int [] minIndexArrR;
 		public int [] minIndexArrS;
 		public int [] maxIndexArrR;
 		public int [] maxIndexArrS;
+		int minCommonIndexS;
+		int maxCommonIndexS;
 	}
 	
 	public static class ConvolutionParameters {
@@ -159,6 +165,9 @@ public class LibMatrixDNN {
 				dout.getNumRows() != params.N || dout.getNumColumns() != params.K*params.P*params.Q) {
 			throw new DMLRuntimeException("Incorrect input to conv2d_backward_filter");
 		}
+		if(params.stride_h <= 0 || params.stride_w <= 0) {
+			throw new DMLRuntimeException("Only positive strides supported");
+		}
 		
 		int constrainedNumThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
 		if(!ALLOW_MULTI_THREADED_OPS || constrainedNumThreads <= 1) {
@@ -198,7 +207,7 @@ public class LibMatrixDNN {
 		}
 	}
 	
-	public static void doConv2d_Backward_Filter(int k, int c, int r, int s, ConvolutionParameters params) {
+	private static void doConv2d_Backward_Filter(int k, int c, int r, int s, ConvolutionParameters params) throws DMLRuntimeException {
 		double [] inputArray = null;
 		if (!params.input1.isInSparseFormat())
 			inputArray = params.input1.getDenseBlock();
@@ -207,62 +216,125 @@ public class LibMatrixDNN {
 			doutArray = params.input2.getDenseBlock();
 		double [] outputArray = params.output.getDenseBlock();
 		
-		long outputVal = 0;
-		if(doutArray != null) {
-			for (int n = 0; n < params.N; n++) {
-				for (int p = 0; p < params.P; p++) {
-					for (int q = 0; q < params.Q; q++) {
-						int h = p*params.stride_h + r - params.pad_h;
-						int w = q*params.stride_w + s - params.pad_w;
-						if(h >= 0 && h < params.H && w >= 0 && w < params.W) {
-							double doutVal = doutArray[n*params.K*params.P*params.Q + k*params.P*params.Q + p*params.Q + q];
-							if(doutVal != 0) {
-								if(inputArray != null)
-									outputVal += doutVal*inputArray[n*params.C*params.H*params.W + c*params.H*params.W + h*params.W+w];
-								else 
-									outputVal += doutVal*params.input1.quickGetValue(n, c*params.H*params.W + h*params.W + w);
-							}
-						}
-					}
-				}
-			}
+		double outputVal = 0;
+		if(inputArray == null && doutArray == null) {
+			outputVal = doConv2d_Backward_Filter_SparseSparse(k, c, r, s, params);
+		}
+		else if(inputArray != null && doutArray == null) {
+			outputVal = doConv2d_Backward_Filter_DenseSparse(k, c, r, s, params, inputArray);
+		}
+		else if(inputArray == null && doutArray != null) {
+			outputVal = doConv2d_Backward_Filter_SparseDense(k, c, r, s, params, doutArray);
 		}
 		else {
-			MatrixBlock dout = params.input2;
-			if( !dout.isEmptyBlock(false) ) {
-				int start=0;
-				int rlen = dout.getNumRows();
-				int clen = dout.getNumColumns();
-				for(int r1=0; r1<Math.min(dout.sparseBlock.numRows(), rlen); r1++, start+=clen)
-				{
-					if(dout.sparseBlock.isEmpty(r1)) 
-						continue;
-					int pos = dout.sparseBlock.pos(r1);
-					int len = dout.sparseBlock.size(r1);
-					int[] aix = dout.sparseBlock.indexes(r1);
-					double[] avals = dout.sparseBlock.values(r1);
-					
-					for(int i=pos; i<pos+len; i++) {
-						int index = start+aix[i];
-						double doutVal = avals[i];
-						int n = index / clen; 
-						int p = index / params.Q;
-						int q = index % params.Q;
-						int h = p*params.stride_h + r - params.pad_h;
-						int w = q*params.stride_w + s - params.pad_w;
-						if(h >= 0 && h < params.H && w >= 0 && w < params.W && doutVal != 0) {
-							if(inputArray != null)
-								outputVal += doutVal*inputArray[n*params.C*params.H*params.W + c*params.H*params.W + h*params.W+w];
-							else 
-								outputVal += doutVal*params.input1.quickGetValue(n, c*params.H*params.W + h*params.W + w);
-						}
-					}
+			outputVal = doConv2d_Backward_Filter_DenseDense(k, c, r, s, params, inputArray, doutArray);
+		}
+		
+		outputArray[k*params.C*params.R*params.S + c*params.R*params.S + r*params.S + s] = outputVal;
+	}
+	
+	private static double doConv2d_Backward_Filter_SparseDense(int k, int c, int r, int s, ConvolutionParameters params, double [] doutArray) throws DMLRuntimeException {
+		double outputVal = 0;
+		// Lower bounds on p and q, so that h >= 0 and w >= 0
+		int pMin = (int) Math.max(0, Math.ceil(((double)(params.pad_h-r))/params.stride_h));
+		int qMin = (int) Math.max(0, Math.ceil(((double)(params.pad_w-s))/params.stride_w));
+		// Exclusive upper bounds on p and q, so that h < params.H and w < params.W
+		int pMax = (int) Math.min(params.P, Math.ceil(((double)(params.H+params.pad_h-r))/params.stride_h));
+		int qMax = (int) Math.min(params.Q, Math.ceil(((double)(params.W+params.pad_w-s))/params.stride_w));
+		
+		// TODO: Optimize this case
+		for (int n = 0; n < params.N; n++) {
+			int doutOffset = n*params.K*params.P*params.Q + k*params.P*params.Q;
+			for (int p = pMin; p < pMax; p++) {
+				for (int q = qMin; q < qMax; q++) {
+					int h = p*params.stride_h + r - params.pad_h;
+					int w = q*params.stride_w + s - params.pad_w;
+					outputVal += doutArray[doutOffset + p*params.Q + q]*params.input1.quickGetValue(n, c*params.H*params.W + h*params.W + w);
 				}
-			}	
+			}
 		}
 		
+		return outputVal;
+	}
+	
+	private static double doConv2d_Backward_Filter_DenseDense(int k, int c, int r, int s, ConvolutionParameters params, double [] inputArray, double [] doutArray) {
+		double outputVal = 0;
+		// Lower bounds on p and q, so that h >= 0 and w >= 0
+		int pMin = (int) Math.max(0, Math.ceil(((double)(params.pad_h-r))/params.stride_h));
+		int qMin = (int) Math.max(0, Math.ceil(((double)(params.pad_w-s))/params.stride_w));
+		// Exclusive upper bounds on p and q, so that h < params.H and w < params.W
+		int pMax = (int) Math.min(params.P, Math.ceil(((double)(params.H+params.pad_h-r))/params.stride_h));
+		int qMax = (int) Math.min(params.Q, Math.ceil(((double)(params.W+params.pad_w-s))/params.stride_w));
 		
-		outputArray[k*params.C*params.R*params.S + c*params.R*params.S + r*params.S + s] = outputVal;
+		for (int n = 0; n < params.N; n++) {
+			int inputOffset =  n*params.C*params.H*params.W + c*params.H*params.W + s - params.pad_w;
+			int doutOffset = n*params.K*params.P*params.Q + k*params.P*params.Q;
+			for (int p = pMin; p < pMax; p++) {
+				int h = p*params.stride_h + r - params.pad_h;
+				for (int q = qMin; q < qMax; q++) {
+					int w = q*params.stride_w;
+					outputVal += doutArray[doutOffset + p*params.Q + q]*inputArray[inputOffset + h*params.W+w];
+				}
+			}
+		}
+				
+		return outputVal;
+	}
+	
+	private static void computeTensorIndexes(int i, int j, int [] ret, int N, int C, int H, int W) throws DMLRuntimeException {
+		ret[0] = i;
+		ret[1] = j / (H*W);
+		ret[2] = (j - ret[1]*(H*W))/W;
+		ret[3] = j % W;
+	}
+	
+	private static double doConv2d_Backward_Filter_DenseSparse(int k, int c, int r, int s, ConvolutionParameters params, double [] inputArray) throws DMLRuntimeException {
+		MatrixBlock dout = params.input2;
+		double outputVal = 0;
+		Iterator<IJV> iter = dout.sparseBlock.getIterator();
+		int [] tensorIndexes = new int[4];
+		while(iter.hasNext()) {
+			IJV ijv = iter.next();
+			computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexes, params.N, params.K, params.P, params.Q);
+			if(k == tensorIndexes[1]) {
+				int n = tensorIndexes[0];
+				int p = tensorIndexes[2];
+				int q = tensorIndexes[3];
+				
+				double doutVal = ijv.getV();
+				int h = p*params.stride_h + r - params.pad_h;
+				int w = q*params.stride_w + s - params.pad_w;
+				if(h >= 0 && h < params.H && w >= 0 && w < params.W) {
+					outputVal += doutVal*inputArray[n*params.C*params.H*params.W + c*params.H*params.W + h*params.W+w];
+				}
+			}
+		}
+		return outputVal;
+	}
+	
+	private static double doConv2d_Backward_Filter_SparseSparse(int k, int c, int r, int s, ConvolutionParameters params) throws DMLRuntimeException {
+		MatrixBlock dout = params.input2;
+		double outputVal = 0;
+		Iterator<IJV> iter = dout.sparseBlock.getIterator();
+		int [] tensorIndexes = new int[4];
+		
+		while(iter.hasNext()) {
+			IJV ijv = iter.next();
+			computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexes, params.N, params.K, params.P, params.Q);
+			if(k == tensorIndexes[1]) {
+				int n = tensorIndexes[0];
+				int p = tensorIndexes[2];
+				int q = tensorIndexes[3];
+				
+				double doutVal = ijv.getV();
+				int h = p*params.stride_h + r - params.pad_h;
+				int w = q*params.stride_w + s - params.pad_w;
+				if(h >= 0 && h < params.H && w >= 0 && w < params.W) {
+					outputVal += doutVal*params.input1.quickGetValue(n, c*params.H*params.W + h*params.W + w);
+				}
+			}
+		}
+		return outputVal;
 	}
 	
 	private static class ConvBackwardFilterTask implements Callable<Object> {
@@ -294,25 +366,55 @@ public class LibMatrixDNN {
 			throw new DMLRuntimeException("Incorrect input to conv2d");
 		}
 		
-		params.tmpData = new TemporaryConvolutionData();		
-		params.tmpData.minIndexArrR = new int[params.R];
-		params.tmpData.maxIndexArrR = new int[params.R];
-		params.tmpData.minIndexArrS = new int[params.S];
-		params.tmpData.maxIndexArrS = new int[params.S];
-		for (int r = 0; r < params.R; r++) {
-			params.tmpData.minIndexArrR[r] = getMinPQ(params.pad_h, r, params.stride_h);
-			params.tmpData.maxIndexArrR[r] = getMaxPQ(params.pad_h, r, params.stride_h, params.P, params.H);
+		params.tmpData = new TemporaryConvolutionData();
+		if(input.isInSparseFormat()) {
+			params.tmpData.minIndexArrR = new int[params.H];
+			params.tmpData.minIndexArrS = new int[params.W];
+			for(int h = 0; h < params.H; h++) {
+				for (int r = 0; r < params.R; r++) {
+					// int h = p*params.stride_h + r - params.pad_h;
+					if((h + params.pad_h - r) % params.stride_h == 0) {
+						params.tmpData.minIndexArrR[h] = r;
+						break;
+					}
+				}
+			}
+			for(int w = 0; w < params.W; w++) {
+				for (int s = 0; s < params.S; s++) {
+					// int w = q*params.stride_w + s - params.pad_w;
+					if((w + params.pad_w - s) % params.stride_w == 0) {
+						params.tmpData.minIndexArrS[w] = s;
+						break;
+					}
+				}
+			}
 		}
-		for (int s = 0; s < params.S; s++) {
-			params.tmpData.minIndexArrS[s] = getMinPQ(params.pad_w, s, params.stride_w);
-			params.tmpData.maxIndexArrS[s] = getMaxPQ(params.pad_w, s, params.stride_w, params.Q, params.W);
+		else {
+			params.tmpData.minIndexArrR = new int[params.R];
+			params.tmpData.maxIndexArrR = new int[params.R];
+			params.tmpData.minIndexArrS = new int[params.S];
+			params.tmpData.maxIndexArrS = new int[params.S];
+			for (int r = 0; r < params.R; r++) {
+				params.tmpData.minIndexArrR[r] = getMinPQ(params.pad_h, r, params.stride_h);
+				params.tmpData.maxIndexArrR[r] = getMaxPQ(params.pad_h, r, params.stride_h, params.P, params.H);
+			}
+			for (int s = 0; s < params.S; s++) {
+				params.tmpData.minIndexArrS[s] = getMinPQ(params.pad_w, s, params.stride_w);
+				params.tmpData.maxIndexArrS[s] = getMaxPQ(params.pad_w, s, params.stride_w, params.Q, params.W);
+			}
+			params.tmpData.minCommonIndexS = params.tmpData.minIndexArrS[0];
+			params.tmpData.maxCommonIndexS = params.tmpData.maxIndexArrS[0];
+			for (int s = 1; s < params.S; s++) {
+				params.tmpData.minCommonIndexS = Math.max(params.tmpData.minCommonIndexS, params.tmpData.minIndexArrS[s]);
+				params.tmpData.maxCommonIndexS = Math.min(params.tmpData.maxCommonIndexS, params.tmpData.maxIndexArrS[s]);
+			}
 		}
 		
 		int constrainedNumThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
 		if(!ALLOW_MULTI_THREADED_OPS || constrainedNumThreads <= 1) {
 			for (int n = 0; n < params.N; n++) {
 				for (int k = 0; k < params.K; k++) {
-					doLoopBasedConv2d(n, k, params);
+					doLoopBasedConv2d(n, n+1, k, params);
 				}
 			}
 		}
@@ -345,102 +447,255 @@ public class LibMatrixDNN {
 		}
 	}
 	
-	/**
-	 * This is essentially memory-less operation and can be used when the memory pressure is extremely high.
-	 * @param n
-	 * @param k
-	 * @param params
-	 */
-	private static void doLoopBasedConv2d(int n, int k, ConvolutionParameters params) {
-		double [] inputArray = null;
-		if (!params.input1.isInSparseFormat())
-			inputArray = params.input1.getDenseBlock();
-		double [] filterArray = null;
-		if (!params.input2.isInSparseFormat())
-			filterArray = params.input2.getDenseBlock();
+	private static void doLoopBasedConv2dDenseDense(int n1, int n2, int k, ConvolutionParameters params, 
+			double [] inputArray, double [] filterArray) {
 		double [] outputArray = params.output.getDenseBlock();
-		
-		int outputOffset = n*params.K*params.P*params.Q + k*params.P*params.Q;
-		
 		int [] minIndexArrR = params.tmpData.minIndexArrR;
 		int [] maxIndexArrR = params.tmpData.maxIndexArrR;
 		int [] minIndexArrS = params.tmpData.minIndexArrS;
 		int [] maxIndexArrS = params.tmpData.maxIndexArrS;
 		
-		if(inputArray != null && filterArray != null) {
-			for (int c = 0; c < params.C; c++) {
-				for (int r = 0; r < params.R; r++) {
-					int filterOffset = k*params.C*params.R*params.S + c*params.R*params.S + r*params.S;
-					for (int p = minIndexArrR[r]; p < maxIndexArrR[r]; p++) {
-						for (int s = 0; s < params.S; s++) {
-							double filterVal = filterArray[filterOffset + s];
-							if(filterVal != 0) {
-								int h = p*params.stride_h + r - params.pad_h;
-								for (int q = minIndexArrS[s]; q < maxIndexArrS[s]; q++) {
-									int w = q*params.stride_w + s - params.pad_w;
-									outputArray[outputOffset + p*params.Q + q] += denseConvMultiply(inputArray, filterVal, params, n, c, h, w);
+		int minCommonIndexS = params.tmpData.minCommonIndexS;
+		int maxCommonIndexS = params.tmpData.maxCommonIndexS;
+		
+		
+		int minS = 0;
+		if(params.S >= 4) {
+			minS = params.S - params.S % 4;
+			for (int n = n1; n < n2; n++) {
+				for (int c = 0; c < params.C; c++) {
+					for (int r = 0; r < params.R; r++) {
+						final int filterOffset = k*params.C*params.R*params.S + c*params.R*params.S + r*params.S;
+						for (int p = minIndexArrR[r]; p < maxIndexArrR[r]; p++) {
+							final int h = p*params.stride_h + r - params.pad_h;
+							final int inputOffSet = n*params.C*params.H*params.W + c*params.H*params.W + h*params.W - params.pad_w;
+							final int outputOffset = n*params.K*params.P*params.Q + k*params.P*params.Q + p*params.Q;
+							// ------------------------------------------------------------------------
+							// Efficient striding with vectorization
+							for (int q = minCommonIndexS; q < maxCommonIndexS; q++) {
+								final int wOffset = inputOffSet + q*params.stride_w;
+								final int outOffsetWithQ = outputOffset + q;
+								for (int s = 0; s < minS; s += 4) {
+									final int inOffsetWithS = wOffset + s;
+									final int filterOffsetWithS = filterOffset + s;
+									outputArray[outOffsetWithQ] += inputArray[inOffsetWithS]*filterArray[filterOffsetWithS]
+											+ inputArray[inOffsetWithS+1]*filterArray[filterOffsetWithS+1]
+											+ inputArray[inOffsetWithS+2]*filterArray[filterOffsetWithS+2]
+											+ inputArray[inOffsetWithS+3]*filterArray[filterOffsetWithS+3];
 								}
 							}
+							// ------------------------------------------------------------------------
 						}
 					}
 				}
 			}
 		}
-		else if(inputArray != null && filterArray == null) {
+		
+		for (int n = n1; n < n2; n++) {
 			for (int c = 0; c < params.C; c++) {
 				for (int r = 0; r < params.R; r++) {
+					final int filterOffset = k*params.C*params.R*params.S + c*params.R*params.S + r*params.S;
 					for (int p = minIndexArrR[r]; p < maxIndexArrR[r]; p++) {
-						for (int s = 0; s < params.S; s++) {
-							double filterVal = params.input2.quickGetValue(k, c*params.R*params.S + r*params.S + s);
-							if(filterVal != 0) {
-								int h = p*params.stride_h + r - params.pad_h;
-								for (int q = minIndexArrS[s]; q < maxIndexArrS[s]; q++) {
-									int w = q*params.stride_w + s - params.pad_w;
-									outputArray[outputOffset + p*params.Q + q] += denseConvMultiply(inputArray, filterVal, params, n, c, h, w);
-								}
+						final int h = p*params.stride_h + r - params.pad_h;
+						final int inputOffSet = n*params.C*params.H*params.W + c*params.H*params.W + h*params.W - params.pad_w;
+						final int outputOffset = n*params.K*params.P*params.Q + k*params.P*params.Q + p*params.Q;
+						// ------------------------------------------------------------------------
+						// Efficient striding
+						for (int q = minCommonIndexS; q < maxCommonIndexS; q++) {
+							final int wOffset = inputOffSet + q*params.stride_w;
+							for (int s = minS; s < params.S; s++) {
+								outputArray[outputOffset + q] += inputArray[wOffset + s]*filterArray[filterOffset + s];
 							}
 						}
+						// ------------------------------------------------------------------------
 					}
 				}
 			}
-		}
-		else if(inputArray == null && filterArray != null) {
+			
+			
 			for (int c = 0; c < params.C; c++) {
 				for (int r = 0; r < params.R; r++) {
-					int filterOffset = k*params.C*params.R*params.S + c*params.R*params.S + r*params.S;
+					final int filterOffset = k*params.C*params.R*params.S + c*params.R*params.S + r*params.S;
 					for (int p = minIndexArrR[r]; p < maxIndexArrR[r]; p++) {
+						final int h = p*params.stride_h + r - params.pad_h;
+						final int inputOffSet = n*params.C*params.H*params.W + c*params.H*params.W + h*params.W - params.pad_w;
+						final int outputOffset = n*params.K*params.P*params.Q + k*params.P*params.Q + p*params.Q;
+						// ------------------------------------------------------------------------
+						// Inefficient striding
 						for (int s = 0; s < params.S; s++) {
-							double filterVal = filterArray[filterOffset + s];
-							if(filterVal != 0) {
-								int h = p*params.stride_h + r - params.pad_h;
-								for (int q = minIndexArrS[s]; q < maxIndexArrS[s]; q++) {
-									int w = q*params.stride_w + s - params.pad_w;
-									outputArray[outputOffset + p*params.Q + q] += sparseConvMultiply(inputArray, filterVal, params, n, c, h, w);
-								}
+							for (int q = minIndexArrS[s]; q < minCommonIndexS; q++) {
+								final int w = q*params.stride_w + s;
+								outputArray[outputOffset + q] += inputArray[inputOffSet + w]*filterArray[filterOffset + s];
+							}
+							for (int q = maxCommonIndexS; q < maxIndexArrS[s]; q++) {
+								final int w = q*params.stride_w + s;
+								outputArray[outputOffset + q] += inputArray[inputOffSet + w]*filterArray[filterOffset + s];
 							}
 						}
+						// ------------------------------------------------------------------------
 					}
 				}
 			}
 		}
-		else if(inputArray == null && filterArray == null) {
-			for (int c = 0; c < params.C; c++) {
-				for (int r = 0; r < params.R; r++) {
-					for (int p = minIndexArrR[r]; p < maxIndexArrR[r]; p++) {
-						for (int s = 0; s < params.S; s++) {
-							double filterVal = params.input2.quickGetValue(k, c*params.R*params.S + r*params.S + s);
-							if(filterVal != 0) {
-								int h = p*params.stride_h + r - params.pad_h;
-								for (int q = minIndexArrS[s]; q < maxIndexArrS[s]; q++) {
-									int w = q*params.stride_w + s - params.pad_w;
-									outputArray[outputOffset + p*params.Q + q] += sparseConvMultiply(inputArray, filterVal, params, n, c, h, w);
-								}
+	}
+	
+	private static void doLoopBasedConv2dDenseSparse(int n, int k, ConvolutionParameters params, double [] inputArray) throws DMLRuntimeException {
+		double [] outputArray = params.output.getDenseBlock();
+		int [] minIndexArrR = params.tmpData.minIndexArrR;
+		int [] maxIndexArrR = params.tmpData.maxIndexArrR;
+		int [] minIndexArrS = params.tmpData.minIndexArrS;
+		int [] maxIndexArrS = params.tmpData.maxIndexArrS;
+		final int outputOffset = n*params.K*params.P*params.Q + k*params.P*params.Q;
+		
+		Iterator<IJV> iter = params.input2.sparseBlock.getIterator();
+		int [] tensorIndexes = new int[4];
+		
+		while(iter.hasNext()) {
+			IJV ijv = iter.next();
+			computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexes, params.K, params.C, params.R, params.S);
+			if(k == tensorIndexes[0]) {
+				int c = tensorIndexes[1];
+				int r = tensorIndexes[2];
+				int s = tensorIndexes[3];
+				double filterVal = ijv.getV();
+				final int inputOffset = n*params.C*params.H*params.W + c*params.H*params.W + s - params.pad_w;
+				for (int p = minIndexArrR[r]; p < maxIndexArrR[r]; p++) {
+					final int hOffset = inputOffset + (p*params.stride_h + r - params.pad_h)*params.W;
+					final int pOffset = outputOffset + p*params.Q;
+					for (int q = minIndexArrS[s]; q < maxIndexArrS[s]; q++) {
+						final int w = q*params.stride_w;
+						outputArray[pOffset + q] += inputArray[hOffset + w]*filterVal;
+					}
+				}
+			}
+		}
+	}
+	
+	private static void doLoopBasedConv2dSparseDense(int n, int k, ConvolutionParameters params, double [] filterArray) throws DMLRuntimeException {
+		double [] outputArray = params.output.getDenseBlock();
+		int outputOffset = n*params.K*params.P*params.Q + k*params.P*params.Q;
+		
+		Iterator<IJV> iter = params.input1.sparseBlock.getIterator();
+		int [] tensorIndexes = new int[4];
+		
+		int [] minIndexArrR = params.tmpData.minIndexArrR;
+		int [] minIndexArrS = params.tmpData.minIndexArrS;
+		while(iter.hasNext()) {
+			IJV ijv = iter.next();
+			computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexes, params.N, params.C, params.H, params.W);
+			if(n == tensorIndexes[0]) {
+				int c = tensorIndexes[1];
+				int h = tensorIndexes[2];
+				int w = tensorIndexes[3];
+				double imgVal = ijv.getV();
+				for (int r = minIndexArrR[h]; r < params.R; r += params.stride_h) {
+					int filterOffset = k*params.C*params.R*params.S + c*params.R*params.S + r*params.S;
+					for (int s = minIndexArrS[w]; s < params.S; s += params.stride_w) {
+						int p = (int)Math.ceil(((double)(h + params.pad_h - r)) / params.stride_h);
+						int q = (int)Math.ceil(((double)(w + params.pad_w - s)) / params.stride_w);
+						if(p >= 0 && p < params.P && q >= 0 && q < params.Q) {
+							double filterVal = filterArray[filterOffset + s];
+							outputArray[outputOffset + p*params.Q + q] += imgVal*filterVal;
+						}
+					}
+				}	
+			}
+		}
+	}
+	
+	private static void doLoopBasedConv2dSparseSparse(int n, int k, ConvolutionParameters params) throws DMLRuntimeException {
+		double [] outputArray = params.output.getDenseBlock();
+		int [] minIndexArrR = params.tmpData.minIndexArrR;
+		int [] maxIndexArrR = params.tmpData.maxIndexArrR;
+		int [] minIndexArrS = params.tmpData.minIndexArrS;
+		int [] maxIndexArrS = params.tmpData.maxIndexArrS;
+		int outputOffset = n*params.K*params.P*params.Q + k*params.P*params.Q;
+		
+		
+		int [] tensorIndexesImage = new int[4];
+		int [] tensorIndexesFilter = new int[4];
+
+		Iterator<IJV> iter = params.input1.sparseBlock.getIterator();
+		
+		while(iter.hasNext()) {
+			IJV ijv = iter.next();
+			computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexesImage, params.N, params.C, params.H, params.W);
+			if(n == tensorIndexesImage[0]) {
+				int c = tensorIndexesImage[1];
+				int h = tensorIndexesImage[2];
+				int w = tensorIndexesImage[3];
+				double imgVal = ijv.getV();
+		
+				Iterator<IJV> iter1 = params.input2.sparseBlock.getIterator();
+				while(iter1.hasNext()) {
+					IJV ijv1 = iter1.next();
+					computeTensorIndexes(ijv1.getI(), ijv1.getJ(), tensorIndexesFilter, params.K, params.C, params.R, params.S);
+					if(k == tensorIndexesFilter[0] && c == tensorIndexesFilter[1]) {
+						int r =  tensorIndexesFilter[2];
+						int s =  tensorIndexesFilter[3];
+						if((r-minIndexArrR[h])%params.stride_h == 0 && (s-minIndexArrS[w])%params.stride_w == 0) {
+							int p = (int)Math.ceil(((double)(h + params.pad_h - r)) / params.stride_h);
+							int q = (int)Math.ceil(((double)(w + params.pad_w - s)) / params.stride_w);
+							if(p >= 0 && p < params.P && q >= 0 && q < params.Q) {
+								double filterVal =  ijv1.getV();
+								outputArray[outputOffset + p*params.Q + q] += imgVal*filterVal;
 							}
 						}
 					}
 				}
 			}
 		}
+		
+		while(iter.hasNext()) {
+			IJV ijv = iter.next();
+			computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexesFilter, params.K, params.C, params.R, params.S);
+			if(k == tensorIndexesFilter[0]) {
+				int c = tensorIndexesFilter[1];
+				int r = tensorIndexesFilter[2];
+				int s = tensorIndexesFilter[3];
+				double filterVal = ijv.getV();
+				for (int p = minIndexArrR[r]; p < maxIndexArrR[r]; p++) {
+					int h = p*params.stride_h + r - params.pad_h;
+					for (int q = minIndexArrS[s]; q < maxIndexArrS[s]; q++) {
+						int w = q*params.stride_w + s - params.pad_w;
+						// TODO: Improve the performance of sparse sparse 
+						outputArray[outputOffset + p*params.Q + q] += sparseConvMultiply(filterVal, params, n, c, h, w);
+					}
+				}
+			}
+		}
+	}
+	
+	/**
+	 * This is essentially a memory-less operation and can be used when the memory pressure is extremely high.
+	 * @param n1 first input row to process (inclusive); rows up to n2 (exclusive) are processed
+	 * @param k
+	 * @param params
+	 * @throws DMLRuntimeException 
+	 */
+	private static void doLoopBasedConv2d(int n1, int n2, int k, ConvolutionParameters params) throws DMLRuntimeException {
+		double [] inputArray = null;
+		if (!params.input1.isInSparseFormat())
+			inputArray = params.input1.getDenseBlock();
+		double [] filterArray = null;
+		if (!params.input2.isInSparseFormat())
+			filterArray = params.input2.getDenseBlock();
+		
+		if(inputArray != null && filterArray != null) {
+			doLoopBasedConv2dDenseDense(n1, n2, k, params, inputArray, filterArray);
+		}
+		else if(inputArray != null && filterArray == null) {
+			for (int n = n1; n < n2; n++) 
+				doLoopBasedConv2dDenseSparse(n, k, params, inputArray);
+		}
+		else if(inputArray == null && filterArray != null) {
+			for (int n = n1; n < n2; n++)
+				doLoopBasedConv2dSparseDense(n, k, params, filterArray);
+		}
+		else if(inputArray == null && filterArray == null) {
+			for (int n = n1; n < n2; n++)
+				doLoopBasedConv2dSparseSparse(n, k, params);
+		}
 	}
 	
 	private static int getMinPQ(int pad, int filterSize, int stride) {
@@ -451,12 +706,7 @@ public class LibMatrixDNN {
 		return Math.min(outputSize, (int)Math.ceil(((double)(inputSize + pad - filterSize)) / stride));
 	}
 	
-	private static double denseConvMultiply(double [] inputArray, double filterVal, ConvolutionParameters params,
-			int n, int c, int h, int w) {
-		return inputArray[n*params.C*params.H*params.W + c*params.H*params.W + h*params.W+w]*filterVal;
-	}
-	
-	private static double sparseConvMultiply(double [] inputArray, double filterVal, ConvolutionParameters params,
+	private static double sparseConvMultiply(double filterVal, ConvolutionParameters params,
 			int n, int c, int h, int w) {
 		return params.input1.quickGetValue(n, c*params.H*params.W + h*params.W + w)*filterVal;
 	}
@@ -635,27 +885,41 @@ public class LibMatrixDNN {
 		outputBlock.setNonZeros(input.getNonZeros()); // As number of non-zeros doesnot change for reshape_col
 	}
 	
-	private static void runParallelConvTask(int constrainedNumThreads, int Z, TaskType type, ConvolutionParameters params) throws DMLRuntimeException {
-		// Total number of compute units available: constrainedNumThreads
-		// Static task allocation. TODO: Do this in dynamic way
-		int taskSize = TASK_SIZE;
-		while(true) {
-			if(params.N * Math.ceil(Z/taskSize) > constrainedNumThreads || taskSize == 1) {
-				doRunParallelConvTask(constrainedNumThreads, Z, type, params, taskSize);
-				return;
+	private static int [] getTaskSize(int constrainedNumThreads, int maxNumTaskSize1, int maxNumTaskSize2) {
+		int taskSize1 = 1; int taskSize2 = 1;
+		// Why this heuristic? To reduce the impact of thread-creation overhead in the case of small tasks
+		int approxNumTasksToCreate = 3*constrainedNumThreads;
+		while((maxNumTaskSize1*maxNumTaskSize2)/(taskSize1*taskSize2) > approxNumTasksToCreate) {
+			// Possibility of creating too many tasks, increase taskSize2
+			taskSize2 *= 2;
+			if(taskSize2 >= maxNumTaskSize2) {
+				taskSize2 = maxNumTaskSize2;
+				break;
 			}
-			taskSize = Math.max(taskSize/2, 1);
 		}
+		while((maxNumTaskSize1*maxNumTaskSize2)/(taskSize1*taskSize2) > approxNumTasksToCreate) {
+			// Possibility of creating too many tasks, increase taskSize1
+			taskSize1 *= 2;
+			if(taskSize1 >= maxNumTaskSize1) {
+				taskSize1 = maxNumTaskSize1;
+				break;
+			}
+		}
+		int [] ret = new int[2];
+		ret[0] = taskSize1;
+		ret[1] = taskSize2;
+		return ret;
 	}
 	
-	private static void doRunParallelConvTask(int constrainedNumThreads, int Z, TaskType type, ConvolutionParameters params, int taskSize) throws DMLRuntimeException {
-		ArrayList<ConvTask> tasks = new ArrayList<ConvTask>();		
-		
-		for (int n = 0; n < params.N; n++) {
-			for (int z = 0; z < Z; z += taskSize) {
-				tasks.add(new ConvTask(n, n+1, z, Math.min(Z, z+taskSize), type, params));
+	private static void runParallelConvTask(int constrainedNumThreads, int Z, TaskType type, ConvolutionParameters params) throws DMLRuntimeException {
+		ArrayList<ConvTask> tasks = new ArrayList<ConvTask>();
+		int [] taskSizes = getTaskSize(constrainedNumThreads, params.N, Z);
+		for (int n = 0; n < params.N; n += taskSizes[0]) {
+			for (int z = 0; z < Z; z += taskSizes[1]) {
+				tasks.add(new ConvTask(n, Math.min(params.N, n+taskSizes[0]), z, Math.min(Z, z+taskSizes[1]), type, params));
 			}
 		}
+		LOG.debug("Reduce number of tasks from " + (params.N*Z)  + "(" + params.N + "," + Z + ") to " + tasks.size());
 
 		ExecutorService pool = Executors.newFixedThreadPool( Math.min(constrainedNumThreads, tasks.size()) );
 		List<Future<Object>> taskret;
@@ -727,10 +991,8 @@ public class LibMatrixDNN {
 					}
 					break;
 				case LoopBasedConv2d:
-					for (int n = n1; n < n2; n++) {
-						for (int z = z1; z < z2; z++) {
-							LibMatrixDNN.doLoopBasedConv2d(n, z, params);
-						}
+					for (int z = z1; z < z2; z++) {
+						LibMatrixDNN.doLoopBasedConv2d(n1, n2, z, params);
 					}
 					break;
 				default: