You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2021/03/27 10:09:14 UTC

[systemds] branch master updated (c20e540 -> efc78f1)

This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git.


    from c20e540  [MINOR] Fix IPA function call graph (missing cleanup debug output)
     add 0545ad9  [SYSTEMDS-2914] maxpooling_backward sparse
     add ab3c743  [SYSTEMDS-2917] NN tests separated into Individual tests
     new efc78f1  [SYSTEMDS-2914] maxpooling_backward sparse Update

The 1 revision listed above as "new" is entirely new to this
repository and will be described in a separate email.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .github/workflows/functionsTests.yml               |    3 +-
 .../runtime/instructions/cp/DnnCPInstruction.java  |    8 +-
 .../sysds/runtime/matrix/data/LibMatrixAgg.java    |   65 +-
 .../sysds/runtime/matrix/data/LibMatrixDNN.java    |   53 +-
 .../runtime/matrix/data/LibMatrixDNNPooling.java   |  427 ++++++--
 .../java/org/apache/sysds/utils/Statistics.java    |    1 -
 .../applications/{NNTest.java => nn/BaseTest.java} |   32 +-
 .../test/applications/nn/NNComponentTest.java      |  120 +++
 .../sysds/test/applications/nn/NNGradientTest.java |   31 +
 .../applications/nn/NNMaxPool2dComponentTest.java  |   62 ++
 .../functions/lineage/LineageMLContextTest.java    |   10 +-
 .../test/functions/mlcontext/MLContextTest.java    |  146 +--
 .../functions/mlcontext/MLContextTestBase.java     |   15 +-
 src/test/scripts/applications/nn/README.md         |   37 -
 .../applications/nn/component/batch_norm1d.dml     |   52 +
 .../applications/nn/component/batch_norm2d.dml     |  109 ++
 .../scripts/applications/nn/component/conv2d.dml   |   67 ++
 .../applications/nn/component/conv2d_depthwise.dml |   94 ++
 .../applications/nn/component/conv2d_transpose.dml |   83 ++
 .../nn/component/conv2d_transpose_depthwise.dml    |   84 ++
 .../nn/component/cross_entropy_loss.dml            |   49 +
 .../nn/component/cross_entropy_loss2d.dml          |   83 ++
 src/test/scripts/applications/nn/component/elu.dml |   57 +
 .../scripts/applications/nn/component/im2col.dml   |   65 ++
 .../applications/nn/component/max_pool2d.dml       |  276 +++++
 .../scripts/applications/nn/component/padding.dml  |   65 ++
 .../applications/nn/component/softmax2d.dml        |   83 ++
 .../scripts/applications/nn/component/tanh.dml     |  106 ++
 .../applications/nn/component/threshold.dml        |   49 +
 .../scripts/applications/nn/component/top_k.dml    |  168 +++
 .../nn/component/transpose_NCHW_to_CNHW.dml        |   59 +
 .../nn/{run_tests.dml => run_tests_gradients.dml}  |   60 --
 src/test/scripts/applications/nn/test.dml          | 1126 --------------------
 33 files changed, 2264 insertions(+), 1481 deletions(-)
 rename src/test/java/org/apache/sysds/test/applications/{NNTest.java => nn/BaseTest.java} (59%)
 create mode 100644 src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java
 create mode 100644 src/test/java/org/apache/sysds/test/applications/nn/NNGradientTest.java
 create mode 100644 src/test/java/org/apache/sysds/test/applications/nn/NNMaxPool2dComponentTest.java
 delete mode 100644 src/test/scripts/applications/nn/README.md
 create mode 100644 src/test/scripts/applications/nn/component/batch_norm1d.dml
 create mode 100644 src/test/scripts/applications/nn/component/batch_norm2d.dml
 create mode 100644 src/test/scripts/applications/nn/component/conv2d.dml
 create mode 100644 src/test/scripts/applications/nn/component/conv2d_depthwise.dml
 create mode 100644 src/test/scripts/applications/nn/component/conv2d_transpose.dml
 create mode 100644 src/test/scripts/applications/nn/component/conv2d_transpose_depthwise.dml
 create mode 100644 src/test/scripts/applications/nn/component/cross_entropy_loss.dml
 create mode 100644 src/test/scripts/applications/nn/component/cross_entropy_loss2d.dml
 create mode 100644 src/test/scripts/applications/nn/component/elu.dml
 create mode 100644 src/test/scripts/applications/nn/component/im2col.dml
 create mode 100644 src/test/scripts/applications/nn/component/max_pool2d.dml
 create mode 100644 src/test/scripts/applications/nn/component/padding.dml
 create mode 100644 src/test/scripts/applications/nn/component/softmax2d.dml
 create mode 100644 src/test/scripts/applications/nn/component/tanh.dml
 create mode 100644 src/test/scripts/applications/nn/component/threshold.dml
 create mode 100644 src/test/scripts/applications/nn/component/top_k.dml
 create mode 100644 src/test/scripts/applications/nn/component/transpose_NCHW_to_CNHW.dml
 rename src/test/scripts/applications/nn/{run_tests.dml => run_tests_gradients.dml} (54%)
 delete mode 100644 src/test/scripts/applications/nn/test.dml

[systemds] 01/01: [SYSTEMDS-2914] maxpooling_backward sparse Update

Posted by ba...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit efc78f18c3ecf2aa9cfd916aefc07909a0db0e9c
Author: baunsgaard <ba...@tugraz.at>
AuthorDate: Fri Mar 26 18:33:31 2021 +0100

    [SYSTEMDS-2914] maxpooling_backward sparse Update
    
    This commit updates the maxpooling sparse output to use
    append again. Since the outputs are almost sorted,
    in practice only small arrays are allocated and sorted
    before being appended to the sparse row outputs.
    The sorting is limited to small arrays of 1-14 elements,
    but this size can grow depending on how many kernels can be applied
    on the input horizontally.
    
    Closes #1213
---
 .github/workflows/functionsTests.yml               |   3 +-
 .../runtime/instructions/cp/DnnCPInstruction.java  |  12 -
 .../sysds/runtime/matrix/data/LibMatrixDNN.java    |   3 +
 .../runtime/matrix/data/LibMatrixDNNPooling.java   | 315 ++++++++++++++++-----
 .../applications/nn/NNMaxPool2dComponentTest.java  |   2 +-
 .../applications/nn/component/max_pool2d.dml       |   4 +-
 6 files changed, 259 insertions(+), 80 deletions(-)

diff --git a/.github/workflows/functionsTests.yml b/.github/workflows/functionsTests.yml
index 5e7466c..70d2af1 100644
--- a/.github/workflows/functionsTests.yml
+++ b/.github/workflows/functionsTests.yml
@@ -45,7 +45,8 @@ jobs:
           "**.functions.codegenalg.partone.**",
           "**.functions.builtin.**",
           "**.functions.frame.**,**.functions.indexing.**,**.functions.io.**,**.functions.jmlc.**,**.functions.lineage.**",
-          "**.functions.dnn.**,**.functions.misc.**,**.functions.mlcontext.**,**.functions.paramserv.**",
+          "**.functions.dnn.**,**.functions.paramserv.**",
+          "**.functions.misc.**,**.functions.mlcontext.**",
           "**.functions.nary.**,**.functions.quaternary.**",
           "**.functions.parfor.**,**.functions.pipelines.**,**.functions.privacy.**,**.functions.unary.scalar.**,**.functions.updateinplace.**,**.functions.vect.**",
           "**.functions.reorg.**,**.functions.rewrite.**,**.functions.ternary.**,**.functions.transform.**",
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/DnnCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/DnnCPInstruction.java
index f29b85e..a486672 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/DnnCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/DnnCPInstruction.java
@@ -548,12 +548,6 @@ public class DnnCPInstruction extends UnaryCPInstruction {
 			}
 			else {
 				outputBlock = new MatrixBlock(K, C*R*S, false).allocateBlock();
-				if(params.enableNative ){
-					if(matBlock.isInSparseFormat())
-						matBlock.sparseToDense();
-					if(dout.isInSparseFormat())
-						dout.sparseToDense();
-				}
 				if(params.enableNative && !matBlock.isInSparseFormat() && !dout.isInSparseFormat())
 					LibMatrixNative.conv2dBackwardFilter(matBlock, dout, outputBlock, params);
 				else
@@ -568,12 +562,6 @@ public class DnnCPInstruction extends UnaryCPInstruction {
 			}
 			else {
 				outputBlock = new MatrixBlock(N, C * H * W, false).allocateBlock();
-				if(params.enableNative ){
-					if(matBlock.isInSparseFormat())
-						matBlock.sparseToDense();
-					if(dout.isInSparseFormat())
-						dout.sparseToDense();
-				}
 				if(params.enableNative && !isFilterSparse(matBlock) && !dout.isInSparseFormat())
 					LibMatrixNative.conv2dBackwardData(matBlock, dout, outputBlock, params);
 				else
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNN.java
index d1bd2d3..598fef5 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNN.java
@@ -178,6 +178,9 @@ public class LibMatrixDNN {
 			fillIndexesArray(params); 
 		}
 		else {
+			if(!params.input2.isInSparseFormat())
+				params.input1.sparseToDense();
+
 			if( !(params.input1.isInSparseFormat() && !params.input2.isInSparseFormat()) )
 				fillIndexesArray(params); //not needed for sparse-dense	 
 		}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNNPooling.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNNPooling.java
index 2196d1e..84170ac 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNNPooling.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNNPooling.java
@@ -37,11 +37,6 @@ import org.apache.sysds.runtime.matrix.data.LibMatrixDNNHelper.CellIndex3;
 public class LibMatrixDNNPooling {
 	
 	protected static final Log LOG =  LogFactory.getLog(LibMatrixDNNPooling.class.getName());
-
-	// *********************************** low-level runtime operator selection ***********************************************
-	// *********************************** based on runtime properties (sparsity, native, etc) ********************************
-	// These methods help reduce branch miss predictions and instruction-cache misses.
-	// Also, they simplify the design of LibMatrixDNN and help in code-maintenance.
 	
 	/**
 	 * Factory method that returns list of callable tasks for performing pooling operation
@@ -78,6 +73,7 @@ public class LibMatrixDNNPooling {
 		int k = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
 		int taskSize = (int)(Math.ceil((double)params.N / k / 2));
 		if(poolType == PoolingType.MAX) {
+			
 			boolean sparse1 = params.input1.isInSparseFormat();
 			boolean sparse2 = params.input2.isInSparseFormat();
 			for(int i = 0; i*taskSize < params.N; i++) {
@@ -357,22 +353,31 @@ public class LibMatrixDNNPooling {
 		public Long call() throws Exception {
 			if(output.isInSparseFormat()){
 				SparseBlock out = output.getSparseBlock();
+				final int[] i = new int[Q];
+				final double[] v = new double[Q];  
 				for(int n = _rl; n < _ru; n++){
 					// each row correspond to a single batch element.
 					// here we allocate the sparse row.
 					out.allocate(n, P*Q*C);
-					SparseRow elm = out.get(n);
+					final SparseRow elm = out.get(n);
+					final int nCHW = n*CHW;
+
+					// tmp arrays for sorting.
 					for(int c = 0; c < C; c++){
 						// each channel processed.
-						final int inputOffset = n*CHW + c*HW;
+						final int inputOffset = nCHW + c*HW;
 						final int outputOffset = n*CPQ + c*PQ;
 						for(int p = 0; p < P; p++){
+							int pointer = 0;
 							for(int q = 0; q < Q; q++){
-								int maxIndex =  getMaxIndex(p, q, inputOffset, inputArray, _params, performReluBackward);
+								int maxIndex = getMaxIndex(p, q, inputOffset, inputArray, _params, performReluBackward);
 								if(maxIndex != -1){
-									add(elm, maxIndex - n*CHW, doutArray[outputOffset +  p * Q + q] );
+									i[pointer] = maxIndex - nCHW;
+									v[pointer] = doutArray[outputOffset +  p * Q + q];
+									pointer++;
 								}
 							}
+							add(elm,i,v,pointer);
 						}
 					}
 				}
@@ -409,7 +414,7 @@ public class LibMatrixDNNPooling {
 		MatrixBlock output; 
 		boolean performReluBackward;
 		double [] inputArray;  MatrixBlock dout;
-		int CHW; int P; int Q; int HW; int C;
+		final int CHW; final int P; final int Q; final int HW; final int C;
 		public PoolingBackwardDenseSparse(int rl, int ru, DnnParameters params, boolean performReluBackward) {
 			_rl = rl; _ru = ru;
 			_params = params;
@@ -429,31 +434,50 @@ public class LibMatrixDNNPooling {
 		@Override
 		public Long call() throws Exception {
 
-			CellIndex3 ix = new CellIndex3();
 			SparseBlock sblock = dout.sparseBlock;
 			if(output.isInSparseFormat()){
 				SparseBlock out = output.getSparseBlock();
+				final int[] i = new int[Q];
+				final double[] v = new double[Q];  
 				for(int n = _rl; n < _ru; n++){
 					// each row correspond to a single batch element.
 					// here we allocate the sparse row.
 					if( sblock.isEmpty(n) ) continue;
+					
 					out.allocate(n, P*Q*C);
-					SparseRow elm = out.get(n);
-					int apos = sblock.pos(n);
-					int alen = sblock.size(n);
-					int[] aix = sblock.indexes(n);
-					double[] avals = sblock.values(n);
+					final SparseRow elm = out.get(n);
+					
+					final int apos = sblock.pos(n);
+					final int alen = sblock.size(n);
+					final int[] aix = sblock.indexes(n);
+					final double[] avals = sblock.values(n);
+
+					int oldP = 0;
+					int pointer = 0;
+					final int nCHW = n*CHW;
+
 					for(int j = apos; j < apos+alen; j++) {
-						ix = LibMatrixDNNHelper.computeTensorIndexes(aix[j], P, Q, ix);
-						final int inputOffset = n*CHW + ix.ix1*HW;
-						int maxIndex = getMaxIndex(ix.ix2, ix.ix3,
-							inputOffset, inputArray, _params, performReluBackward);
-						if(maxIndex != -1)
-							add(elm, maxIndex - n*CHW, avals[j]);
+						final int tmp = aix[j] / Q;
+						final int inputOffset = nCHW + (tmp / P) * HW;
+						final int p = tmp % P;
+						final int q = aix[j] % Q;
+						if(p != oldP){
+							add(elm, i, v, pointer);
+							oldP = p;
+							pointer = 0;
+						}
+						int maxIndex = getMaxIndex(p, q, inputOffset, inputArray, _params, performReluBackward);
+						if(maxIndex != -1){
+							i[pointer] = maxIndex - nCHW;
+							v[pointer] = avals[j];
+							pointer++;
+						}
 					}
+					add(elm, i, v, pointer);
 				}
 			}
 			else {
+				CellIndex3 ix = new CellIndex3();
 				double[] out = output.getDenseBlockValues();
 				for(int n = _rl; n < _ru; n++)  {
 					if( sblock.isEmpty(n) ) continue;
@@ -475,7 +499,7 @@ public class LibMatrixDNNPooling {
 			return P*Q*C*(long)(_ru - _rl);
 		}
 	}
-	
+
 	/**
 	 * Performs the avgpooling backward operation for sparse error (dout)
 	 */
@@ -533,6 +557,10 @@ public class LibMatrixDNNPooling {
 	
 	/**
 	 * Performs the maxpooling backward operation for sparse input and dense error (dout)
+	 * 
+	 * Currently this is NOT IN USE since the sparse left part is forced dense.
+	 * This is because this method is inefficient compared to our dense version.
+	 * 
 	 */
 	private static class PoolingBackwardSparseDense implements Callable<Long> 
 	{
@@ -572,27 +600,26 @@ public class LibMatrixDNNPooling {
 			//allocate auxiliary data structures
 			double[] maxVal = new double[PQ];
 			int[] maxIx = new int[PQ];
-			
 			for(int n = _rl; n < _ru; n++)  {
 				for (int c = 0; c < C; c++) {
 					//step 1: perform maxpooling w/ index maintenance in a 
 					//single, sequential pass over the sparse input matrix
-					maxpoolingForward(maxVal, maxIx, n, c,
+					boolean empty = maxpoolingForward(maxVal, maxIx, n, c,
 						padh, padw, strideh, stridew, C, P, Q, R, S, HW, W);
-					
-					//step 2: perform maxpooling backward
-					if(output.isInSparseFormat())
-						maxpoolingBackwardSparse(maxIx, c*HW, n, c, C, Q, PQ, CPQ);
-					else
-						maxpoolingBackwardDense(maxIx, n*CHW + c*HW, n, c, C, Q, PQ, CPQ);
-					
+					if(!empty){
+						//step 2: perform maxpooling backward
+						if(output.isInSparseFormat())
+							maxpoolingBackwardSparse(maxIx, c*HW, n, c, C, Q, P, CPQ);
+						else
+							maxpoolingBackwardDense(maxIx, n*CHW + c*HW, n, c, C, Q, PQ, CPQ);
+					}
 				}
 			}
 			//thread-local nnz maintenance
 			return P*Q*C*(long)(_ru - _rl);
 		}
 		
-		protected void maxpoolingForward(double[] maxVal, int[] maxIx, int n, int c, int padh, int padw, int strideh, int stridew, int C, int P, int Q, int R, int S, int HW, int W) {
+		protected boolean maxpoolingForward(double[] maxVal, int[] maxIx, int n, int c, int padh, int padw, int strideh, int stridew, int C, int P, int Q, int R, int S, int HW, int W) {
 			SparseBlock sblock = _params.input1.getSparseBlock();
 			if( !sblock.isEmpty(n) ) {
 				Arrays.fill(maxVal, -Double.MAX_VALUE);
@@ -619,17 +646,10 @@ public class LibMatrixDNNPooling {
 				}
 				//handle skipped zero values at end of row
 				update0(lastix+1, (c+1)*HW, maxVal, maxIx, padh, padw, strideh, stridew, P, Q, R, S, HW, W);
+				return false;
 			}
 			else {
-				//handle empty row
-				Arrays.fill(maxVal, 0);
-				for(int p = 0, ix=0; p < P; p++) {
-					int h = Math.max(-padh+p*strideh, 0);
-					for(int q = 0; q < Q; q++, ix++) {
-						int w = Math.max(-padw+q*stridew, 0);
-						maxIx[ix] = h * W + w;
-					}
-				}
+				return true;
 			}
 		}
 		
@@ -641,14 +661,19 @@ public class LibMatrixDNNPooling {
 				out[ outOffset + maxIx[pq] ] += dout[ doutOffset + pq ];
 		}
 
-		protected void maxpoolingBackwardSparse(int[] maxIx, int offset, int n, int c, int C, int Q, int PQ, int CPQ) {
+		protected void maxpoolingBackwardSparse(int[] maxIx, int offset, int n, int c, int C, int Q, int P, int CPQ) {
 			double[] dout = doutput.getDenseBlockValues();
 			SparseBlock out = output.getSparseBlock();
-			out.allocate(n, PQ);
+			out.allocate(n, P * Q);
 			SparseRow row = out.get(n);
-			final int doutOffset = n*CPQ + c*PQ;
-			for( int pq = 0; pq < PQ; pq++ )
-				row.add(maxIx[pq] + offset ,dout[ doutOffset + pq ]);
+			final int doutOffset = n*CPQ + c*P * Q;
+			int pq = 0;
+			for( int p = 0; p < P; p++ ){
+				for(int q = 0; q < Q; q++){
+					row.add(maxIx[pq] + offset ,dout[ doutOffset + pq ]);
+					pq++;
+				}
+			}
 		}
 		
 		private static void update0(int lix, int uix, double[] maxVal, int[] maxIx, int padh, int padw, int strideh, int stridew, int P, int Q, int R, int S, int HW, int W) {
@@ -680,6 +705,10 @@ public class LibMatrixDNNPooling {
 	
 	/**
 	 * Performs the maxpooling backward operation for sparse input and sparse error (dout)
+	 * 
+	 * Currently this is NOT IN USE since the sparse left part is forced dense.
+	 * This is because this method is inefficient compared to our dense version.
+	 * 
 	 */
 	private static class PoolingBackwardSparseSparse extends PoolingBackwardSparseDense
 	{
@@ -713,10 +742,11 @@ public class LibMatrixDNNPooling {
 		}
 
 		@Override
-		protected void maxpoolingBackwardSparse(int[] maxIx, int offset, int n, int c, int C, int Q, int PQ, int CPQ) {
+		protected void maxpoolingBackwardSparse(int[] maxIx, int offset, int n, int c, int C, int Q, int P, int CPQ) {
 			SparseBlock sblock = doutput.getSparseBlock();
 			if( sblock.isEmpty(n) )
 				return;
+			final int PQ = P*Q;
 			SparseBlock out = output.getSparseBlock();
 			out.allocate(n, PQ);
 			SparseRow row = out.get(n);
@@ -769,44 +799,199 @@ public class LibMatrixDNNPooling {
 		int end_index_w = params.end_indexes_w[q];
 		
 		int maxIndex = -1; 
-		double maxVal = -Double.MAX_VALUE;
+		double maxVal = performReluBackward ? 0 : Double.NEGATIVE_INFINITY;
 		
 		// Note: We do not treat pad as zero and hence we don't do:  
 		// maxVal = 0 
 		// if start_index_h < 0 || start_index_w < 0 || end_index_h >= params.H || end_index_w >= params.W
 		
 		// Find maxIndex
-		double currDoutVal = -1;
 		for (int h = start_index_h; h < end_index_h; h++) {
 			for (int w = start_index_w; w < end_index_w; w++) {
 				final int idx = inputOffset +  h*params.W + w;
-				currDoutVal = inputArray[idx];
-				currDoutVal = performReluBackward && currDoutVal < 0 ? 0 : currDoutVal;
+				final double currDoutVal = inputArray[idx];
 				if(maxVal < currDoutVal) {
 					maxIndex = idx;
 					maxVal = currDoutVal;
 				}
 			}
 		}
-		return maxIndex;
+		return maxVal == 0 && performReluBackward ? -1 : maxIndex;
+	}
+
+	/**
+	 * Add all elements in the arrays to the sparse row. It is guaranteed that all i is larger than all indexes already contained in row.
+	 * 
+	 * @param row the row to append to
+	 * @param i the indexes to append
+	 * @param v the values to append
+	 */
+	private static void add(SparseRow row, int[] i, double[] v, int size){
+		// sort based on the i array.
+		sort(i,v, size);
+		for(int x = 0; x < size; x++){
+			row.append(i[x], v[x]);
+		}
 	}
 
 
+
 	/**
-	 * Add to sparse row assuming that most of the time we would append to the end of the sparse row.
+	 * Use sorting networks for small arrays.
+	 * Note small arrays here is less than 32.
 	 * 
-	 * @param row row to add to.
-	 * @param index the index in the row to add to
-	 * @param v the value to add.
+	 * The basic idea is to use Network sorting, that is the theoretical
+	 * fewest compare and swap operations possible for a specific size array.
+	 * 
+	 * @param i indexes to sort by
+	 * @param v the values to sort along side
 	 */
-	private static void add(SparseRow row, int index, double v){
-		final int size = row.size();
+	private static void sort(int[] i , double[] v, int size){
+		if(size > 32)
+			LOG.warn("Not a optimal size for small array sort " + size);
+		switch (size) {
+			case 1: break;
+			case 2: comp(i,v,0,1); break;
+			case 3: sort3(i,v); break;
+			case 4: sort4(i,v); break;
+			case 5: sort5(i,v); break;
+			case 6: sort6(i,v); break;
+			case 7: sort7(i,v); break;
+			default:
+				// Most cases are handled by the sorting of smaller arrays, 
+				// but just in case we have a insertion sort here. 
+				// Since the array is already semi sorted, it is okay. But not ideal once 
+				// we see larger arrays.
+				// Larger arrays only occur if the input data allow many kernels in the horizontal
+				// dimension.
+				insertSort(i,v, size);
+				break;
+		}
+	}
+
+	private static void sort3(int[] i, double[] v){
+		// 3 moves
+		comp(i,v,0,2);
+		comp(i,v,0,1);
+		comp(i,v,1,2);
+	}
+
+	private static void sort4(int[] i, double[] v){
+		// 5 moves
+		// block 1
+		comp(i,v,0,2);
+		comp(i,v,1,3);
+		// block 2
+		comp(i,v,0,1);
+		comp(i,v,2,3);
+		// block 3
+		comp(i,v,1,2);
+	}
+
+	private static void sort5(int[] i, double[] v){
+		// 9 moves
+		// block 1
+		comp(i,v,0,1);
+		comp(i,v,2,3);
+		// block 2
+		comp(i,v,1,3);
+		comp(i,v,2,4);
+		// block 3
+		comp(i,v,1,4);
+		comp(i,v,0,2);
+		// block 4
+		comp(i,v,1,2);
+		comp(i,v,3,4);
+		// block 5
+		comp(i,v,2,3);
+	}
+
+	private static void sort6(int[] i, double[] v){
+		// 12 moves
+		// block 1
+		comp(i,v,0,1);
+		comp(i,v,2,3);
+		comp(i,v,4,5);
+		// block 2
+		comp(i,v,1,3);
+		// block 3
+		comp(i,v,0,4);
+		// block 4
+		comp(i,v,1,3);
+		// block 5
+		comp(i,v,1,5);
+		// block 6
+		comp(i,v,2,4);
+		// block 7
+		comp(i,v,1,2);
+		comp(i,v,3,5);
+		// block 8
+		comp(i,v,3,4);
+		// block 9
+		comp(i,v,2,3);
+	}
+
+	private static void sort7(int[] i, double[] v){
+		// 16 moves.
+		// block 1
+		comp(i,v,0,1);
+		comp(i,v,2,3);
+		comp(i,v,4,5);
+		// block 2
+		comp(i,v,0,6);
+		// block 3
+		comp(i,v,2,4);
+		// block 4
+		comp(i,v,0,2);
+		// block 5
+		comp(i,v,1,3);
+		comp(i,v,5,6);
+		// block 6
+		comp(i,v,1,4);
+		// block 7
+		comp(i,v,2,5);
+		// block 8
+		comp(i,v,1,2);
+		comp(i,v,4,5);
+		// block 9
+		comp(i,v,2,4);
+		// block 10
+		comp(i,v,3,6);
+		// block 11
+		comp(i,v,3,5);
+		// block 12
+		comp(i,v,3,4);
+	}
+
+	private static void insertSort(int[] i, double[] v, int size){
+		int p, k, j;
+		double t;
+		for(p  = 1; p < size; p++){
+			k = i[p];
+			t = v[p];
+			j = p -1;
+			while(j >= 0 && i[j] > k){
+				i[j+1] = i[j];
+				v[j+1] = v[j];
+				j = j-1;
+			}
+			i[j+1] = k;
+			v[j+1] = t;
+		}
+	}
+
+	private static void comp(int[] i , double[] v, int f, int t){
+		if(i[f] > i[t])
+			swap(i,v,f,t);
+	}
 
-		if(size <= 1)
-			row.add(index, v);
-		else if( row.indexes()[size-1] < index)
-			row.append(index, v);
-		else
-			row.add(index, v);
+	private static void swap(int[] i , double[] v, int f, int t){
+		int tmpI = i[f];
+		double tmpV = v[f];
+		i[f] = i[t];
+		v[f] = v[t];
+		i[t] = tmpI;
+		v[t] = tmpV; 
 	}
 }
+
diff --git a/src/test/java/org/apache/sysds/test/applications/nn/NNMaxPool2dComponentTest.java b/src/test/java/org/apache/sysds/test/applications/nn/NNMaxPool2dComponentTest.java
index dfdacb8..0be02b6 100644
--- a/src/test/java/org/apache/sysds/test/applications/nn/NNMaxPool2dComponentTest.java
+++ b/src/test/java/org/apache/sysds/test/applications/nn/NNMaxPool2dComponentTest.java
@@ -48,7 +48,7 @@ public class NNMaxPool2dComponentTest extends BaseTest {
 	@Parameterized.Parameter(1)
 	public int w;
 
-	final static String[] argNames =  new String[] {"$h", "$w"};
+	final static String[] argNames = new String[] {"$h", "$w"};
 
 	@Test
 	public void max_pool2d_padh_padw() {
diff --git a/src/test/scripts/applications/nn/component/max_pool2d.dml b/src/test/scripts/applications/nn/component/max_pool2d.dml
index c13bb91..0ec075a 100644
--- a/src/test/scripts/applications/nn/component/max_pool2d.dml
+++ b/src/test/scripts/applications/nn/component/max_pool2d.dml
@@ -76,7 +76,9 @@ max_pool2d_pad = function(Integer h, Integer w) {
                                           Hf, Wf, stride, stride, padh, padw)
   dX_builtin = max_pool2d_builtin::backward(dout, Hout_builtin, Wout_builtin, X, C, Hin, Win,
                                             Hf, Wf, stride, stride, padh, padw)
-
+  print(toString(dX))
+  print(toString(dX_simple))
+  print(toString(dX_builtin))
   # Equivalency check
   dX = matrix(dX, rows=1, cols=N*C*Hin*Win)
   dX_simple = matrix(dX_simple, rows=1, cols=N*C*Hin*Win)