Posted to commits@systemml.apache.org by ni...@apache.org on 2017/09/13 19:08:40 UTC

systemml git commit: [SYSTEMML-540] Improve the performance of CP Convolution operators

Repository: systemml
Updated Branches:
  refs/heads/master 3acd94186 -> e624d149f


[SYSTEMML-540] Improve the performance of CP Convolution operators

- Support sparse bias_multiply.
- Allow the JIT compiler to optimize loops in the maxpooling operators.
- Perform examSparsity on operator outputs to speed up subsequent sparsity-aware operations.
- Add script generation logic to Caffe2DML.

Closes #661.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/e624d149
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/e624d149
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/e624d149

Branch: refs/heads/master
Commit: e624d149f7826fb0cd98bb3f32ada423acfaac66
Parents: 3acd941
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Wed Sep 13 11:08:20 2017 -0800
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Wed Sep 13 12:08:20 2017 -0700

----------------------------------------------------------------------
 .../nn/test/compare_backends/test_conv2d.dml    |   3 +-
 scripts/nn/test/compare_backends/test_conv2d.sh |  22 +-
 .../compare_backends/test_conv2d_bwd_data.sh    |  22 +-
 .../compare_backends/test_conv2d_bwd_filter.sh  |  22 +-
 .../nn/test/compare_backends/test_maxpool.sh    |  22 +-
 .../org/apache/sysml/hops/ConvolutionOp.java    |  13 +-
 .../cp/ConvolutionCPInstruction.java            |   4 +-
 .../sysml/runtime/matrix/data/LibMatrixDNN.java |  68 +++--
 .../matrix/data/LibMatrixDNNConv2dHelper.java   |   6 +-
 .../runtime/matrix/data/LibMatrixDNNHelper.java |  27 +-
 .../data/LibMatrixDNNPoolingBackwardHelper.java |   5 +-
 .../matrix/data/LibMatrixDNNPoolingHelper.java  |  93 ++++---
 .../data/LibMatrixDNNRotate180Helper.java       |   3 +
 .../org/apache/sysml/api/dl/Caffe2DML.scala     | 258 ++++++++++---------
 .../org/apache/sysml/api/dl/DMLGenerator.scala  |  11 +-
 .../scala/org/apache/sysml/api/dl/Utils.scala   |   5 +
 src/test/scripts/functions/tensor/PoolTest.R    |   2 +-
 src/test/scripts/functions/tensor/PoolTest.dml  |   2 +-
 18 files changed, 304 insertions(+), 284 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/scripts/nn/test/compare_backends/test_conv2d.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/compare_backends/test_conv2d.dml b/scripts/nn/test/compare_backends/test_conv2d.dml
index b56a0ae..ea3bea2 100644
--- a/scripts/nn/test/compare_backends/test_conv2d.dml
+++ b/scripts/nn/test/compare_backends/test_conv2d.dml
@@ -19,7 +19,8 @@
 #
 #-------------------------------------------------------------
 
+fmt = ifdef($fmt, 'csv')
 X = read("input.mtx")
 w = read("filter.mtx")
 out = conv2d(X, w, input_shape=[$N,$C,$H,$W], filter_shape=[$F, $C, $Hf, $Wf], stride=[$stride,$stride], padding=[$pad,$pad])
-write(out, $out, format="csv")
+write(out, $out, format=fmt)
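
With the ifdef guard above, fmt defaults to csv but can now be overridden per run (in DML, $-parameters are passed via -nvargs, e.g. fmt=binary), so the backend-comparison scripts can write whichever format a given comparison needs. This reading is inferred from the standard DML idiom; the commit itself does not spell it out.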

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/scripts/nn/test/compare_backends/test_conv2d.sh
----------------------------------------------------------------------
diff --git a/scripts/nn/test/compare_backends/test_conv2d.sh b/scripts/nn/test/compare_backends/test_conv2d.sh
index 4c578a6..205339c 100644
--- a/scripts/nn/test/compare_backends/test_conv2d.sh
+++ b/scripts/nn/test/compare_backends/test_conv2d.sh
@@ -20,27 +20,7 @@
 #
 #-------------------------------------------------------------
 
-jars='.'
-os_suffix='linux-x86_64'
-version='0.8.0'
-
-# Downloads the jcuda jars
-for lib in jcuda jcublas jcufft jcusparse jcusolver jcurand jnvgraph jcudnn
-do
-        file=$lib'-'$version'.jar'
-        if [ ! -f $file ]; then
-                url='https://search.maven.org/remotecontent?filepath=org/jcuda/'$lib'/'$version'/'$file
-                wget -O $file $url
-        fi
-        jars=$jars','$file
-
-        file=$lib'-natives-'$version'-'$os_suffix'.jar'
-        if [ ! -f $file ]; then
-                url='https://search.maven.org/remotecontent?filepath=org/jcuda/'$lib'-natives/'$version'/'$file
-                wget -O $file $url
-        fi
-        jars=$jars','$file
-done
+jars='systemml-*-extra.jar'
 
 # N = Number of images, C = number of channels, H = height, W = width
 # F = number of filters, Hf = filter height, Wf = filter width
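
The deleted loop fetched the JCuda jars from Maven one by one; the replacement points at systemml-*-extra.jar, presumably the pre-built SystemML "extra" artifact that bundles those same GPU dependencies, so the test scripts no longer download anything. The same simplification is applied to the three sibling scripts below.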

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/scripts/nn/test/compare_backends/test_conv2d_bwd_data.sh
----------------------------------------------------------------------
diff --git a/scripts/nn/test/compare_backends/test_conv2d_bwd_data.sh b/scripts/nn/test/compare_backends/test_conv2d_bwd_data.sh
index da7b4f3..716560f 100644
--- a/scripts/nn/test/compare_backends/test_conv2d_bwd_data.sh
+++ b/scripts/nn/test/compare_backends/test_conv2d_bwd_data.sh
@@ -20,27 +20,7 @@
 #
 #-------------------------------------------------------------
 
-jars='.'
-os_suffix='linux-x86_64'
-version='0.8.0'
-
-# Downloads the jcuda jars
-for lib in jcuda jcublas jcufft jcusparse jcusolver jcurand jnvgraph jcudnn
-do
-        file=$lib'-'$version'.jar'
-        if [ ! -f $file ]; then
-                url='https://search.maven.org/remotecontent?filepath=org/jcuda/'$lib'/'$version'/'$file
-                wget -O $file $url
-        fi
-        jars=$jars','$file
-
-        file=$lib'-natives-'$version'-'$os_suffix'.jar'
-        if [ ! -f $file ]; then
-                url='https://search.maven.org/remotecontent?filepath=org/jcuda/'$lib'-natives/'$version'/'$file
-                wget -O $file $url
-        fi
-        jars=$jars','$file
-done
+jars='systemml-*-extra.jar'
 
 # N = Number of images, C = number of channels, H = height, W = width
 # F = number of filters, Hf = filter height, Wf = filter width

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/scripts/nn/test/compare_backends/test_conv2d_bwd_filter.sh
----------------------------------------------------------------------
diff --git a/scripts/nn/test/compare_backends/test_conv2d_bwd_filter.sh b/scripts/nn/test/compare_backends/test_conv2d_bwd_filter.sh
index f19fc73..99d3011 100644
--- a/scripts/nn/test/compare_backends/test_conv2d_bwd_filter.sh
+++ b/scripts/nn/test/compare_backends/test_conv2d_bwd_filter.sh
@@ -20,27 +20,7 @@
 #
 #-------------------------------------------------------------
 
-jars='.'
-os_suffix='linux-x86_64'
-version='0.8.0'
-
-# Downloads the jcuda jars
-for lib in jcuda jcublas jcufft jcusparse jcusolver jcurand jnvgraph jcudnn
-do
-        file=$lib'-'$version'.jar'
-        if [ ! -f $file ]; then
-                url='https://search.maven.org/remotecontent?filepath=org/jcuda/'$lib'/'$version'/'$file
-                wget -O $file $url
-        fi
-        jars=$jars','$file
-
-        file=$lib'-natives-'$version'-'$os_suffix'.jar'
-        if [ ! -f $file ]; then
-                url='https://search.maven.org/remotecontent?filepath=org/jcuda/'$lib'-natives/'$version'/'$file
-                wget -O $file $url
-        fi
-        jars=$jars','$file
-done
+jars='systemml-*-extra.jar'
 
 # N = Number of images, C = number of channels, H = height, W = width
 # F = number of filters, Hf = filter height, Wf = filter width

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/scripts/nn/test/compare_backends/test_maxpool.sh
----------------------------------------------------------------------
diff --git a/scripts/nn/test/compare_backends/test_maxpool.sh b/scripts/nn/test/compare_backends/test_maxpool.sh
index e8575e3..9d7da4a 100644
--- a/scripts/nn/test/compare_backends/test_maxpool.sh
+++ b/scripts/nn/test/compare_backends/test_maxpool.sh
@@ -20,27 +20,7 @@
 #
 #-------------------------------------------------------------
 
-jars='.'
-os_suffix='linux-x86_64'
-version='0.8.0'
-
-# Downloads the jcuda jars
-for lib in jcuda jcublas jcufft jcusparse jcusolver jcurand jnvgraph jcudnn
-do
-        file=$lib'-'$version'.jar'
-        if [ ! -f $file ]; then
-                url='https://search.maven.org/remotecontent?filepath=org/jcuda/'$lib'/'$version'/'$file
-                wget -O $file $url
-        fi
-        jars=$jars','$file
-
-        file=$lib'-natives-'$version'-'$os_suffix'.jar'
-        if [ ! -f $file ]; then
-                url='https://search.maven.org/remotecontent?filepath=org/jcuda/'$lib'-natives/'$version'/'$file
-                wget -O $file $url
-        fi
-        jars=$jars','$file
-done
+jars='systemml-*-extra.jar'
 
 # N = Number of images, C = number of channels, H = height, W = width
 N=5

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
index 59ac29e..0ad9182 100644
--- a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
+++ b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
@@ -219,8 +219,17 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 	@Override
 	protected double computeOutputMemEstimate( long dim1, long dim2, long nnz )
 	{		
-		double sparsity = 1.0;
-		return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
+		if(getOp() == ConvOp.BIAS_MULTIPLY) {
+			// in non-gpu mode, the worst case size of bias multiply operation is same as that of input.
+			if(DMLScript.USE_ACCELERATOR) 
+				return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, 1.0);
+			else
+				return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, getInput().get(0).getSparsity());
+		}
+		else {
+			double sparsity = 1.0;
+			return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
+		}
 	}
 	
 	// ---------------------------------------------------------------
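
Why the input's sparsity is a safe worst case here (reasoning inferred from the operator's semantics, not spelled out in the commit): bias_multiply computes out[n, k*PQ+pq] = input[n, k*PQ+pq] * bias[k], a pure element-wise scaling, so every zero of the input stays zero and nnz(out) <= nnz(input). Estimating the CP output with the first input's sparsity therefore never under-estimates memory, while the GPU path keeps the dense estimate (sparsity 1.0), presumably because its kernel materializes a dense output.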

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
index e91029e..df72f24 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
@@ -288,8 +288,8 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
 		}
 		else {
 			// As we always fill the output first with bias
-			outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), false);
-			outputBlock.allocateDenseBlock();
+			outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), input.isInSparseFormat());
+			outputBlock.allocateDenseOrSparseBlock();
 			LibMatrixDNN.biasMultiply(input, bias, outputBlock, _numThreads);
 		}
 		

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index 30b8b64..40192de 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -167,7 +167,8 @@ public class LibMatrixDNN {
 		execute(LibMatrixDNNHelper.getConv2dWorkers(params), params);
 		
 		//post-processing: maintain nnz
-		outputBlock.recomputeNonZeros();
+		outputBlock.recomputeNonZeros(); 
+		outputBlock.examSparsity();
 	}
 	
 	/**
@@ -188,7 +189,8 @@ public class LibMatrixDNN {
 		execute(LibMatrixDNNHelper.getConv2dBackwardDataWorkers(params), params);
 		
 		//post-processing: maintain nnz
-		outputBlock.recomputeNonZeros();
+		outputBlock.recomputeNonZeros(); 
+		outputBlock.examSparsity();
 	}
 	
 	/**
@@ -209,7 +211,8 @@ public class LibMatrixDNN {
 		execute(LibMatrixDNNHelper.getConv2dBackwardFilterWorkers(params), params);
 		
 		//post-processing: maintain nnz
-		outputBlock.recomputeNonZeros();
+		outputBlock.recomputeNonZeros(); 
+		outputBlock.examSparsity();
 	}
 	
 	
@@ -338,7 +341,8 @@ public class LibMatrixDNN {
 		execute(LibMatrixDNNHelper.getMaxPoolingBackwardWorkers(params, performReluBackward), params);
 		
 		//post-processing: maintain nnz 
-		outputBlock.recomputeNonZeros();
+		outputBlock.recomputeNonZeros(); 
+		outputBlock.examSparsity();
 	}
 	
 	/**
@@ -391,7 +395,8 @@ public class LibMatrixDNN {
 		execute(LibMatrixDNNHelper.getReluBackwardWorkers(params), params);
 		
 		// post-processing: maintain nnz
-		outputBlock.recomputeNonZeros();
+		outputBlock.recomputeNonZeros(); 
+		outputBlock.examSparsity();
 	}
 	
 	/**
@@ -429,15 +434,17 @@ public class LibMatrixDNN {
 			double [] biasArr = bias.getDenseBlock();
 			for(int n = 0; n < N; n++) {
 				for(int k = 0; k < K; k++) {
+					double biasVal = biasArr[k];
 					for(int pq = 0; pq < PQ; pq++, index++) {
-						outputArray[index] += biasArr[k];
+						outputArray[index] += biasVal;
 					}
 				}
 			}
 		}
 		
 		//post-processing: maintain nnz
-		outputBlock.recomputeNonZeros();
+		outputBlock.recomputeNonZeros(); 
+		outputBlock.examSparsity();
 	}
 	
 	
@@ -469,22 +476,52 @@ public class LibMatrixDNN {
 		
 		if(!input.isEmptyBlock() && !bias.isEmptyBlock()) {
 			// Handles both dense and sparse inputs and copies it to dense output
-			outputBlock.copy(input); 
-			double [] outputArray = outputBlock.getDenseBlock();
-			int index = 0;
+			outputBlock.copy(input);
 			if(bias.isInSparseFormat())
 				bias.sparseToDense(); // Since bias is extremely small array
 			double [] biasArr = bias.getDenseBlock();
-			for(int n = 0; n < N; n++) {
+			if(!input.isInSparseFormat()) {
+				double [] outputArray = outputBlock.getDenseBlock();
+				int index = 0;
+				for(int n = 0; n < N; n++) {
+					for(int k = 0; k < K; k++) {
+						double biasVal = biasArr[k];
+						for(int pq = 0; pq < PQ; pq++, index++) {
+							outputArray[index] *= biasVal;
+						}
+					}
+				}
+			}
+			else {
+				// First delete those elements which will become zero 
 				for(int k = 0; k < K; k++) {
-					for(int pq = 0; pq < PQ; pq++, index++) {
-						outputArray[index] *= biasArr[k];
+					if(biasArr[k] == 0) {
+						for(int n = 0; n < N; n++) {
+							outputBlock.sparseBlock.deleteIndexRange(n, k*PQ, (k+1)*PQ);
+						}
+					}
+				}
+				// Then perform bias_multiply for non-zero bias entries
+				for(int n = 0; n < N; n++) {
+					if( !outputBlock.sparseBlock.isEmpty(n) ) {
+						int apos = outputBlock.sparseBlock.pos(n);
+						int alen = outputBlock.sparseBlock.size(n);
+						int[] aix = outputBlock.sparseBlock.indexes(n);
+						double[] avals = outputBlock.sparseBlock.values(n);
+						
+						for(int j=apos; j<apos+alen; j++) {
+							// Since aix[j] => KPQ
+							int k = aix[j] % PQ;
+							if(biasArr[k] != 0)
+								avals[j] *= biasArr[k];
+						}
 					}
 				}
 			}
 			
 			//post-processing: maintain nnz
-			params.output.recomputeNonZeros();
+			params.output.recomputeNonZeros(); 
+			params.output.examSparsity();
 		}
 		else {
 			params.output.setNonZeros(0);
@@ -504,7 +541,8 @@ public class LibMatrixDNN {
 		execute(LibMatrixDNNHelper.getMaxPoolingWorkers(params), params);
 		
 		// post-processing: maintain nnz
-		outputBlock.recomputeNonZeros();
+		outputBlock.recomputeNonZeros(); 
+		outputBlock.examSparsity();
 	}
 	
 	/**
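
The recomputeNonZeros()/examSparsity() pair added after every operator above is one idiom; a minimal Scala sketch of the pattern, assuming only the MatrixBlock methods already used in this diff:

  import org.apache.sysml.runtime.matrix.data.MatrixBlock

  def postProcess(out: MatrixBlock): Unit = {
    out.recomputeNonZeros() // restore the nnz count the parallel workers do not maintain
    out.examSparsity()      // flip dense <-> sparse if the new nnz ratio makes it cheaper
  }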

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dHelper.java
index 4c3a3c3..66b2ed1 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dHelper.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dHelper.java
@@ -96,9 +96,10 @@ public class LibMatrixDNNConv2dHelper {
 							int alen = src.sparseBlock.size(k);
 							int[] aix = src.sparseBlock.indexes(k);
 							double[] avals = src.sparseBlock.values(k);
+							int desPosK = destPos + k*PQ;
 							for(int j = apos; j < apos+alen; j++) {
 								int pqIndex = aix[j];
-								dest[destPos + k*PQ + pqIndex ] += avals[j];
+								dest[desPosK + pqIndex ] += avals[j];
 							}
 						}
 					}
@@ -174,9 +175,10 @@ public class LibMatrixDNNConv2dHelper {
 							int alen = src.sparseBlock.size(k);
 							int[] aix = src.sparseBlock.indexes(k);
 							double[] avals = src.sparseBlock.values(k);
+							int desPosK = destPos + k*PQ;
 							for(int j = apos; j < apos+alen; j++) {
 								int pqIndex = aix[j];
-								dest[destPos + k*PQ + pqIndex ] = avals[j];
+								dest[desPosK + pqIndex ] = avals[j];
 							}
 						}
 					}

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
index ab96a8e..0550a98 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
@@ -265,9 +265,12 @@ public class LibMatrixDNNHelper {
 			double [] outputArr = mb.getDenseBlock();
 			if(filter != null) {
 				for(int k = 0; k < _params.K; k++) {
+					int outOffset = k*RS;
+					int filterOffset = k*CRS + c*RS;
 					for(int rs = 0; rs < RS; rs++) {
-						outputArr[k*RS + rs] = filter[k*CRS + c*RS + rs];
-						nnz += outputArr[k*RS + rs] != 0 ? 1 : 0;
+						int outIndex = outOffset + rs;
+						outputArr[outIndex] = filter[filterOffset + rs];
+						nnz += outputArr[outIndex] != 0 ? 1 : 0;
 					}
 				}
 			}
@@ -473,12 +476,16 @@ public class LibMatrixDNNHelper {
 		}
 		else {
 			if(!input.isEmptyBlock()) {
+				int outOffset = outputN*params.C*params.H*params.W;
+				int HW = params.H*params.W;
 				int [] tensorIndexes = new int[3];
 				for(int i = 0; i < input.getNumRows(); i++) {
 					if( !input.sparseBlock.isEmpty(i) ) {
 						computeTensorIndexes(i, tensorIndexes, params.P, params.Q);
 						int p = tensorIndexes[1];
 						int q = tensorIndexes[2];
+						int tmpP = p*params.stride_h - params.pad_h;
+						int tmpQ = q*params.stride_w - params.pad_w;
 						if(tensorIndexes[0] != 0) 
 							throw new DMLRuntimeException("Incorrect tensor indexes: " + tensorIndexes[0] + " != 0 <" + p + " " + q + " " + tensorIndexes[0] + params.P + " " + params.Q + ">");
 						
@@ -491,10 +498,10 @@ public class LibMatrixDNNHelper {
 							int c = tensorIndexes[0];
 							int r = tensorIndexes[1];
 							int s = tensorIndexes[2];
-							int h = p*params.stride_h + r - params.pad_h;
-							int w = q*params.stride_w + s - params.pad_w;
+							int h = tmpP + r;
+							int w = tmpQ + s;
 							if(h >= 0 && h < params.H && w >= 0 && w < params.W) {
-								int outIndex = outputN*params.C*params.H*params.W + c*params.H*params.W + h*params.W + w;
+								int outIndex = outOffset + c*HW + h*params.W + w;
 								outputArray[outIndex] += avals[j];
 							}
 						}
@@ -508,6 +515,10 @@ public class LibMatrixDNNHelper {
 	// Or converts input: NPQ X CRS matrix and writes to N X CHW 
 	private static void doCol2IMDenseInput(int inputN, int outputN, double [] inputArray, double [] outputArray, ConvolutionParameters params) throws DMLRuntimeException {
 		final int outputNOffset = outputN*params.C*params.H*params.W;
+		final int HW = params.H*params.W;
+		final int inputNPQ = inputN*params.P*params.Q;
+		final int CRS = params.C*params.R*params.S;
+		final int RS = params.R*params.S;
 		for (int p = 0; p < params.P; p++) {
 			// h = p*params.stride_h + r - params.pad_h
 			//   = r + hOffset
@@ -522,10 +533,10 @@ public class LibMatrixDNNHelper {
 				final int wOffset = q*params.stride_w - params.pad_w;
 				final int sStart = Math.max(0, - wOffset);
 				final int sEnd = Math.min(params.S, params.W - wOffset);
-				final int tempOffset = (inputN*params.P*params.Q + p*params.Q + q)*params.C*params.R*params.S;
+				final int tempOffset = (inputNPQ + p*params.Q + q)*CRS;
 				for (int c = 0; c < params.C; c++) {
-					final int outOffset = outputNOffset + c*params.H*params.W;
-					final int inputOffset = tempOffset + c*params.R*params.S;
+					final int outOffset = outputNOffset + c*HW;
+					final int inputOffset = tempOffset + c*RS;
 					for (int r = rStart; r < rEnd; r++) {
 						for (int s = sStart; s < sEnd; s++) {
 							int inputIndex = inputOffset + r*params.S + s;
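
For reference, the hoisted arithmetic above is the standard NCHW linearization: element (n, c, h, w) of an N x CHW matrix lives at offset

  n*C*H*W + c*H*W + h*W + w

so n*CHW (outOffset) and HW are loop invariants that can be precomputed, leaving only c*HW + h*W + w inside the loops. The same decomposition with (P, Q) and (R, S) in place of (H, W) underlies tempOffset and inputOffset.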

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
index b400105..5b04e59 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
@@ -146,11 +146,12 @@ public class LibMatrixDNNPoolingBackwardHelper {
 		public Long call() throws Exception {
 			for(int n = _rl; n < _ru; n++)  {
 				for (int c = 0; c < C; c++) {
+					final int doutOffset = n*CPQ + c*PQ;
+					final int inputOffset = n*CHW + c*HW;
 					for (int p = 0; p < P; p++) {
 						for (int q = 0; q < Q; q++) {
-							double inVal = doutArray[n*CPQ + c*PQ +  p * Q + q];
+							double inVal = doutArray[doutOffset +  p * Q + q];
 							if(inVal != 0) {
-								final int inputOffset = n*CHW + c*HW;
 								int maxIndex = LibMatrixDNNHelper.getMaxIndexSparse(p, q, inputOffset, n, c, _params.input1, _params, performReluBackward);
 								if(maxIndex != -1)
 									outputArray[maxIndex] += inVal;

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingHelper.java
index c6aaee2..19c3f71 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingHelper.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingHelper.java
@@ -55,11 +55,14 @@ public class LibMatrixDNNPoolingHelper {
 					final int inOffset1 = inOffset + c*HW;
 					for (int p = 0; p < P; p++) {
 						for (int q = 0; q < Q; q++, out_index++) {
+							double tmp = outputArray[out_index];
 							for (int h = _params.start_indexes_h[p]; h < _params.end_indexes_h[p]; h++) {
-								for (int w = _params.start_indexes_w[q]; w < _params.end_indexes_w[q]; w++) {
-									outputArray[out_index] = Math.max(outputArray[out_index], inputArray[inOffset1 +  h*W + w]);
+								int inputIndex = inOffset1 +  h*W + _params.start_indexes_w[q];
+								for (int w = _params.start_indexes_w[q]; w < _params.end_indexes_w[q]; w++, inputIndex++) {
+									tmp = Math.max(tmp, inputArray[inputIndex]);
 								}
 							}
+							outputArray[out_index] = tmp;
 						}
 					}
 				}
@@ -75,67 +78,63 @@ public class LibMatrixDNNPoolingHelper {
 	{
 		public int _rl; public int _ru; 
 		private final ConvolutionParameters _params;
-		int HW;
+		final int HW;
 		double [] outputArray;
-		int C; int P; int Q; int W;
+		final int C; final int P; final int Q; final int W; final int H; final int CPQ; final int PQ;
 		public SparseMaxPooling(int rl, int ru, ConvolutionParameters params) {
 			_rl = rl; _ru = ru;
 			_params = params;
 			outputArray = params.output.getDenseBlock();
-			C = params.C; P = params.P; Q = params.Q; W = params.W;
+			C = params.C; P = params.P; Q = params.Q; H = params.H; 
+			W = params.W;
 			HW = _params.H*_params.W;
-		}
-		
-		boolean isNthRowEmpty = false;
-		int apos; int alen; int[] aix; double[] avals;
-		private void getNthSparseRow(int n) {
-			if( !_params.input1.sparseBlock.isEmpty(n) ) {
-				apos = _params.input1.sparseBlock.pos(n);
-				alen = _params.input1.sparseBlock.size(n);
-				aix = _params.input1.sparseBlock.indexes(n);
-				avals = _params.input1.sparseBlock.values(n);
-				isNthRowEmpty = false;
-			}
-			else
-				isNthRowEmpty = true;
-		}
-		int fromIndex = -1; // as per C
-		int toIndex = -1; // as per C
-		private int setSearchIndex(int from, int searchVal) {
-			for(int j = from; j < apos+alen; j++) {
-				if(aix[j] > searchVal)
-					return Math.max(from, j-1);
-			}
-			return apos+alen;
-		}
-		private double getValue(int col) {
-			if( !isNthRowEmpty ) {
-				int index = Arrays.binarySearch(aix, fromIndex, toIndex, col);
-				return index > 0 ? avals[index] : 0;
-			}
-			return 0;
+			CPQ = C*P*Q;
+			PQ = P*Q;
 		}
 		
 		@Override
 		public Long call() throws Exception {
-			final int CPQ = C*P*Q;
 			for(int n = _rl; n < _ru; n++)  {
-				getNthSparseRow(n);
-				int out_index = n*CPQ;
-				for (int c = 0; c < C; c++) {
-					// This allows for binary search in getValue to be more efficient
-					fromIndex = setSearchIndex(apos, c*HW);
-					toIndex = Math.min(apos+alen, setSearchIndex(fromIndex, (c+1)*HW));
-					for (int p = 0; p < P; p++) {
-						for (int q = 0; q < Q; q++, out_index++) {
-							for (int h = _params.start_indexes_h[p]; h < _params.end_indexes_h[p]; h++) {
-								for (int w = _params.start_indexes_w[q]; w < _params.end_indexes_w[q]; w++) {
-									outputArray[out_index] = Math.max(outputArray[out_index], getValue(c*HW +  h*W + w));
+				if( !_params.input1.sparseBlock.isEmpty(n) ) {
+					final int apos = _params.input1.sparseBlock.pos(n);
+					final int alen = _params.input1.sparseBlock.size(n);
+					final int [] aix = _params.input1.sparseBlock.indexes(n);
+					final double [] avals = _params.input1.sparseBlock.values(n);
+					int chw = 0; int index = apos;
+					for (int c = 0; c < C; c++) {
+						final int outOffset = n*CPQ + c*PQ;
+						for(int h = 0; h < H; h++) {
+							for(int w = 0; w < W; w++, chw++) {
+								// Take into account zero values as well
+								double nchwVal = 0;
+								if(aix[index] == chw) {
+									nchwVal = avals[index++];
+									// Ensure that we satisfy the condition index < apos+alen
+									if(index >= apos+alen) index--;
+								}
+								// Perform maxpooling without binary search :)
+								// Tradeoff as compared to dense maxpooling: 
+								// In dense maxpooling, iteration space CPQHW where H and W iterations are restricted by _params.start_indexes_h[p] 
+								// and are eligible for JIT optimizations.
+								// In sparse maxpooling, iteration space CHWPQ without HW restrictions.
+								for (int p = 0; p < P; p++) {
+									if(h >= _params.start_indexes_h[p] && h < _params.end_indexes_h[p]) {
+										final int outOffsetWithp = outOffset + p*Q;
+										for (int q = 0; q < Q; q++) {
+											if(w >= _params.start_indexes_w[q] && w < _params.end_indexes_w[q]) {
+												outputArray[outOffsetWithp + q] = Math.max(outputArray[outOffsetWithp + q], nchwVal);
+											}
+										}
+									}
 								}
 							}
 						}
 					}
 				}
+				else {
+					// Empty input image
+					Arrays.fill(outputArray, n*CPQ, (n+1)*CPQ, 0);
+				}
 			}
 			return 0L;
 		}
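
One detail worth calling out (illustrated below with made-up values, not code from the commit): in the sparse worker, implicit zeros are legitimate candidates for the maximum, which is why nchwVal defaults to 0 above. A window that covers implicit zeros and whose stored entries are all negative pools to 0, exactly as dense pooling would:

  val stored = Array(-3.5, -0.2)              // non-zeros in one pooling window
  val pooled = stored.foldLeft(0.0)(math.max) // 0.0 stands in for the implicit zeros
  assert(pooled == 0.0)                       // dense pooling over {-3.5, 0, -0.2, 0} agrees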

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRotate180Helper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRotate180Helper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRotate180Helper.java
index c003756..6bc7caf 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRotate180Helper.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRotate180Helper.java
@@ -66,6 +66,9 @@ public class LibMatrixDNNRotate180Helper {
 	
 	/**
 	 * Performing rotate180 when input is sparse (general case)
+	 * 
+	 * Why are we allocating the output of rotate180 in dense format ? 
+	 * Because the number of rows of output (i.e. NPQ) is much larger than number of columns (i.e. K) 
 	 */
 	static class SparseRotate180Worker implements Rotate180Worker {
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
index 000fe32..a62fae2 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
@@ -131,6 +131,29 @@ object Caffe2DML  {
     val envFlagNesterovUDF = System.getenv("USE_NESTEROV_UDF")
     envFlagNesterovUDF != null && envFlagNesterovUDF.toBoolean
   }
+  
+  def main(args: Array[String]): Unit = {
+	// Arguments: [train_script | predict_script] $OUTPUT_DML_FILE $SOLVER_FILE $INPUT_CHANNELS $INPUT_HEIGHT $INPUT_WIDTH $NUM_ITER
+	if(args.length < 6) throwUsageError
+	val outputDMLFile = args(1)
+	val solverFile = args(2)
+	val inputChannels = args(3)
+	val inputHeight = args(4)
+	val inputWidth = args(5)
+	val caffeObj = new Caffe2DML(new SparkContext(), solverFile, inputChannels, inputHeight, inputWidth)
+	if(args(0).equals("train_script")) {
+		Utils.writeToFile(caffeObj.getTrainingScript(true)._1.getScriptString, outputDMLFile)
+	}
+	else if(args(0).equals("predict_script")) {
+		Utils.writeToFile(new Caffe2DMLModel(caffeObj).getPredictionScript(true)._1.getScriptString, outputDMLFile)
+	}
+	else {
+		throwUsageError
+	}
+  }
+  def throwUsageError():Unit = {
+	throw new RuntimeException("Incorrect usage: train_script OUTPUT_DML_FILE SOLVER_FILE INPUT_CHANNELS INPUT_HEIGHT INPUT_WIDTH"); 
+  }
 }
 
 class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter, 
@@ -147,7 +170,10 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
   def this(sc: SparkContext, solver1:Caffe.SolverParameter, numChannels:String, height:String, width:String) {
     this(sc, solver1, Utils.parseSolver(solver1), new CaffeNetwork(solver1.getNet, caffe.Caffe.Phase.TRAIN, numChannels, height, width), 
         new LearningRatePolicy(solver1), numChannels, height, width)
-  } 
+  }
+  def this(sc: SparkContext, solverPath:String, numChannels:String, height:String, width:String) {
+    this(sc, Utils.readCaffeSolver(solverPath), numChannels, height, width)
+  }
   val uid:String = "caffe_classifier_" + (new Random).nextLong
   override def copy(extra: org.apache.spark.ml.param.ParamMap): Estimator[Caffe2DMLModel] = {
     val that = new Caffe2DML(sc, solverParam, solver, net, lrPolicy, numChannels, height, width)
@@ -232,105 +258,30 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
 	  val shouldValidate = solverParam.getTestInterval > 0 && solverParam.getTestIterCount > 0 && solverParam.getTestIter(0) > 0
 	  trainTestSplit(if(shouldValidate) solverParam.getTestIter(0) else 0)
 	  
-	  // Set iteration-related variables such as max_epochs, num_iters_per_epoch, lr, etc.
-	  setIterationVariables
+	  // Set iteration-related variables such as num_iters_per_epoch, lr, etc.
+	  ceilDivide(tabDMLScript, "num_iters_per_epoch", Caffe2DML.numImages, Caffe2DML.batchSize)
+	  assign(tabDMLScript, "lr", solverParam.getBaseLr.toString)
+	  assign(tabDMLScript, "max_iter", ifdef("$max_iter", solverParam.getMaxIter.toString))
+	  assign(tabDMLScript, "e", "0")
+	  
 	  val lossLayers = getLossLayers(net)
 	  // ----------------------------------------------------------------------------
 	  // Main logic
-	  forBlock("e", "1", "max_epochs") {
-	    getTrainAlgo.toLowerCase match {
-	      case "minibatch" => 
-	        forBlock("i", "1", "num_iters_per_epoch") {
-	          getTrainingBatch(tabDMLScript)
-	          tabDMLScript.append("iter = iter + 1\n")
-	          // -------------------------------------------------------
-	          // Perform forward, backward and update on minibatch
-	          forward; backward; update
-	          // -------------------------------------------------------
-	          displayLoss(lossLayers(0), shouldValidate)
-            performSnapshot
-	        }
-	      case "batch" => {
-          tabDMLScript.append("iter = iter + 1\n")
-          // -------------------------------------------------------
-          // Perform forward, backward and update on entire dataset
-          forward; backward; update
-          // -------------------------------------------------------
-          displayLoss(lossLayers(0), shouldValidate)
-          performSnapshot
-	      }
-	      case "allreduce_parallel_batches" => {
-	        // This setting uses the batch size provided by the user
-          if(!inputs.containsKey("$parallel_batches")) {
-            throw new RuntimeException("The parameter parallel_batches is required for allreduce_parallel_batches")
-          }
-          // The user specifies the number of parallel_batches
-          // This ensures that the user of generated script remembers to provide the commandline parameter $parallel_batches
-          assign(tabDMLScript, "parallel_batches", "$parallel_batches") 
-          assign(tabDMLScript, "group_batch_size", "parallel_batches*" + Caffe2DML.batchSize)
-          assign(tabDMLScript, "groups", "as.integer(ceil(" + Caffe2DML.numImages + "/group_batch_size))")
-          // Grab groups of mini-batches
-          forBlock("g", "1", "groups") {
-            tabDMLScript.append("iter = iter + 1\n")
-            // Get next group of mini-batches
-            assign(tabDMLScript, "group_beg", "((g-1) * group_batch_size) %% " + Caffe2DML.numImages + " + 1")
-            assign(tabDMLScript, "group_end", "min(" + Caffe2DML.numImages + ", group_beg + group_batch_size - 1)")
-            assign(tabDMLScript, "X_group_batch", Caffe2DML.X + "[group_beg:group_end,]")
-            assign(tabDMLScript, "y_group_batch", Caffe2DML.y + "[group_beg:group_end,]")
-            initializeGradients("parallel_batches")
-            assign(tabDMLScript, "X_group_batch_size", nrow("X_group_batch"))
-            parForBlock("j", "1", "parallel_batches") {
-              // Get a mini-batch in this group
-              assign(tabDMLScript, "beg", "((j-1) * " + Caffe2DML.batchSize + ") %% nrow(X_group_batch) + 1")
-              assign(tabDMLScript, "end", "min(nrow(X_group_batch), beg + " + Caffe2DML.batchSize + " - 1)")
-              assign(tabDMLScript, "Xb", "X_group_batch[beg:end,]")
-              assign(tabDMLScript, "yb", "y_group_batch[beg:end,]")
-              forward; backward
-              flattenGradients
-            }
-            aggregateAggGradients    
-	          update
-	          // -------------------------------------------------------
-	          assign(tabDMLScript, "Xb", "X_group_batch")
-            assign(tabDMLScript, "yb", "y_group_batch")
-            displayLoss(lossLayers(0), shouldValidate)
-            performSnapshot
-          }
-	      }
-	      case "allreduce" => {
-	        // This is distributed synchronous gradient descent
-	        forBlock("i", "1", "num_iters_per_epoch") {
-	          tabDMLScript.append("iter = iter + 1\n")
-	          // -------------------------------------------------------
-            // Perform forward, backward and update on minibatch in parallel
-	          assign(tabDMLScript, "beg", "((i-1) * " + Caffe2DML.batchSize + ") %% " + Caffe2DML.numImages + " + 1")
-	          assign(tabDMLScript, "end", " min(beg +  " + Caffe2DML.batchSize + " - 1, " + Caffe2DML.numImages + ")")
-	          assign(tabDMLScript, "X_group_batch", Caffe2DML.X + "[beg:end,]")
-            assign(tabDMLScript, "y_group_batch", Caffe2DML.y + "[beg:end,]")
-            assign(tabDMLScript, "X_group_batch_size", nrow("X_group_batch"))
-	          tabDMLScript.append("local_batch_size = nrow(y_group_batch)\n")
-	          val localBatchSize = "local_batch_size"
-	          initializeGradients(localBatchSize)
-	          parForBlock("j", "1", localBatchSize) {
-	            assign(tabDMLScript, "Xb", "X_group_batch[j,]")
-	            assign(tabDMLScript, "yb", "y_group_batch[j,]")
-	            forward; backward
-              flattenGradients
-	          }
-	          aggregateAggGradients    
-	          update
-	          // -------------------------------------------------------
-	          assign(tabDMLScript, "Xb", "X_group_batch")
-            assign(tabDMLScript, "yb", "y_group_batch")
-            displayLoss(lossLayers(0), shouldValidate)
-            performSnapshot
-	        }
-	      }
-	      case _ => throw new DMLRuntimeException("Unsupported train algo:" + getTrainAlgo)
-	    }
-	    // After every epoch, update the learning rate
-	    tabDMLScript.append("# Learning rate\n")
-	    lrPolicy.updateLearningRate(tabDMLScript)
+	  forBlock("iter", "1", "max_iter") {
+		performTrainingIter(lossLayers, shouldValidate)
+		if(getTrainAlgo.toLowerCase.equals("batch")) {
+			assign(tabDMLScript, "e", "iter")
+			tabDMLScript.append("# Learning rate\n")
+			lrPolicy.updateLearningRate(tabDMLScript)
+		}
+		else {
+			ifBlock("iter %% num_iters_per_epoch == 0") {
+				// After every epoch, update the learning rate
+				assign(tabDMLScript, "e", "e + 1")
+				tabDMLScript.append("# Learning rate\n")
+				lrPolicy.updateLearningRate(tabDMLScript)
+			}
+		}
 	  }
 	  // ----------------------------------------------------------------------------
 	  
@@ -350,6 +301,90 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
 	}
 	// ================================================================================================
   
+  private def performTrainingIter(lossLayers:List[IsLossLayer], shouldValidate:Boolean):Unit = {
+	getTrainAlgo.toLowerCase match {
+      case "minibatch" => 
+          getTrainingBatch(tabDMLScript)
+          // -------------------------------------------------------
+          // Perform forward, backward and update on minibatch
+          forward; backward; update
+          // -------------------------------------------------------
+          displayLoss(lossLayers(0), shouldValidate)
+          performSnapshot
+      case "batch" => {
+	      // -------------------------------------------------------
+	      // Perform forward, backward and update on entire dataset
+	      forward; backward; update
+	      // -------------------------------------------------------
+	      displayLoss(lossLayers(0), shouldValidate)
+	      performSnapshot
+      }
+      case "allreduce_parallel_batches" => {
+    	  // This setting uses the batch size provided by the user
+	      if(!inputs.containsKey("$parallel_batches")) {
+	        throw new RuntimeException("The parameter parallel_batches is required for allreduce_parallel_batches")
+	      }
+	      // The user specifies the number of parallel_batches
+	      // This ensures that the user of generated script remembers to provide the commandline parameter $parallel_batches
+	      assign(tabDMLScript, "parallel_batches", "$parallel_batches") 
+	      assign(tabDMLScript, "group_batch_size", "parallel_batches*" + Caffe2DML.batchSize)
+	      assign(tabDMLScript, "groups", "as.integer(ceil(" + Caffe2DML.numImages + "/group_batch_size))")
+	      // Grab groups of mini-batches
+	      forBlock("g", "1", "groups") {
+	        // Get next group of mini-batches
+	        assign(tabDMLScript, "group_beg", "((g-1) * group_batch_size) %% " + Caffe2DML.numImages + " + 1")
+	        assign(tabDMLScript, "group_end", "min(" + Caffe2DML.numImages + ", group_beg + group_batch_size - 1)")
+	        assign(tabDMLScript, "X_group_batch", Caffe2DML.X + "[group_beg:group_end,]")
+	        assign(tabDMLScript, "y_group_batch", Caffe2DML.y + "[group_beg:group_end,]")
+	        initializeGradients("parallel_batches")
+	        assign(tabDMLScript, "X_group_batch_size", nrow("X_group_batch"))
+	        parForBlock("j", "1", "parallel_batches") {
+	          // Get a mini-batch in this group
+	          assign(tabDMLScript, "beg", "((j-1) * " + Caffe2DML.batchSize + ") %% nrow(X_group_batch) + 1")
+	          assign(tabDMLScript, "end", "min(nrow(X_group_batch), beg + " + Caffe2DML.batchSize + " - 1)")
+	          assign(tabDMLScript, "Xb", "X_group_batch[beg:end,]")
+	          assign(tabDMLScript, "yb", "y_group_batch[beg:end,]")
+	          forward; backward
+	          flattenGradients
+	        }
+	        aggregateAggGradients    
+	        update
+	        // -------------------------------------------------------
+	        assign(tabDMLScript, "Xb", "X_group_batch")
+	        assign(tabDMLScript, "yb", "y_group_batch")
+	        displayLoss(lossLayers(0), shouldValidate)
+	        performSnapshot
+	      }
+      }
+      case "allreduce" => {
+    	  // This is distributed synchronous gradient descent
+    	  // -------------------------------------------------------
+    	  // Perform forward, backward and update on minibatch in parallel
+    	  assign(tabDMLScript, "beg", "((iter-1) * " + Caffe2DML.batchSize + ") %% " + Caffe2DML.numImages + " + 1")
+    	  assign(tabDMLScript, "end", " min(beg +  " + Caffe2DML.batchSize + " - 1, " + Caffe2DML.numImages + ")")
+    	  assign(tabDMLScript, "X_group_batch", Caffe2DML.X + "[beg:end,]")
+    	  assign(tabDMLScript, "y_group_batch", Caffe2DML.y + "[beg:end,]")
+    	  assign(tabDMLScript, "X_group_batch_size", nrow("X_group_batch"))
+          tabDMLScript.append("local_batch_size = nrow(y_group_batch)\n")
+          val localBatchSize = "local_batch_size"
+          initializeGradients(localBatchSize)
+          parForBlock("j", "1", localBatchSize) {
+            assign(tabDMLScript, "Xb", "X_group_batch[j,]")
+            assign(tabDMLScript, "yb", "y_group_batch[j,]")
+            forward; backward
+          flattenGradients
+          }
+          aggregateAggGradients    
+          update
+          // -------------------------------------------------------
+          assign(tabDMLScript, "Xb", "X_group_batch")
+          assign(tabDMLScript, "yb", "y_group_batch")
+          displayLoss(lossLayers(0), shouldValidate)
+          performSnapshot
+      }
+      case _ => throw new DMLRuntimeException("Unsupported train algo:" + getTrainAlgo)
+    }
+  }
   // -------------------------------------------------------------------------------------------
   // Helper functions to generate DML
   // Initializes Caffe2DML.X, Caffe2DML.y, Caffe2DML.XVal, Caffe2DML.yVal and Caffe2DML.numImages
@@ -499,10 +534,12 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
   
   private def performSnapshot():Unit = {
     if(solverParam.getSnapshot > 0) {
-      ifBlock("iter %% snapshot == 0") {
+      ifBlock("iter %% " + solverParam.getSnapshot + " == 0") {
         tabDMLScript.append("snapshot_dir= \"" + solverParam.getSnapshotPrefix + "\" + \"/iter_\" + iter + \"/\"\n")
-        net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => tabDMLScript.append(write(l.weight, "snapshot_dir + \"" + l.param.getName + "_weight.mtx\"", "binary")))
-  		  net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => tabDMLScript.append(write(l.bias, "snapshot_dir + \"" + l.param.getName + "_bias.mtx\"", "binary")))
+        net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => tabDMLScript.append(
+        	"write(" + l.weight + ", snapshot_dir + \"" + l.param.getName + "_weight.mtx\", format=\"binary\")\n"))
+  		net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => tabDMLScript.append(
+  			"write(" + l.bias + ", snapshot_dir + \"" + l.param.getName + "_bias.mtx\", format=\"binary\")\n"))
       }
   	}
   }
@@ -547,19 +584,6 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
           matrix(colSums(l.dBias + "_agg"), nrow(l.bias), ncol(l.bias)))
     })
   }
-  // Set iteration-related variables such as max_epochs, num_iters_per_epoch, lr, etc.
-  def setIterationVariables():Unit = {
-    getTrainAlgo.toLowerCase match {
-	    case "batch" => 
-	      assign(tabDMLScript, "max_epochs", solverParam.getMaxIter.toString)
-	    case _ => {
-	      ceilDivide(tabDMLScript, "num_iters_per_epoch", Caffe2DML.numImages, Caffe2DML.batchSize)
-	      ceilDivide(tabDMLScript, "max_epochs", solverParam.getMaxIter.toString, "num_iters_per_epoch")
-	    }
-	  }
-	  assign(tabDMLScript, "iter", "0")
-	  assign(tabDMLScript, "lr", solverParam.getBaseLr.toString)
-  }
   // -------------------------------------------------------------------------------------------
 }
 
@@ -617,7 +641,7 @@ class Caffe2DMLModel(val numClasses:String, val sc: SparkContext, val solver:Caf
 	  estimator.getTestAlgo.toLowerCase match {
       case "minibatch" => {
         ceilDivide(tabDMLScript(), "num_iters", Caffe2DML.numImages, Caffe2DML.batchSize)
-        forBlock("i", "1", "num_iters") {
+        forBlock("iter", "1", "num_iters") {
           getTestBatch(tabDMLScript)
           net.getLayers.map(layer => net.getCaffeLayer(layer).forward(tabDMLScript, true))
           assign(tabDMLScript, "Prob[beg:end,]", lossLayers(0).out)
@@ -656,10 +680,10 @@ class Caffe2DMLModel(val numClasses:String, val sc: SparkContext, val solver:Caf
       case "allreduce" => {
         // This setting doesnot use the batch size for scoring and allows the parfor optimizer to select the best plan
         // by minimizing the memory requirement (i.e. batch size = 1)
-        parForBlock("i", "1", Caffe2DML.numImages) {
-          assign(tabDMLScript, "Xb", "X_full[i,]")
+        parForBlock("iter", "1", Caffe2DML.numImages) {
+          assign(tabDMLScript, "Xb", "X_full[iter,]")
           net.getLayers.map(layer => net.getCaffeLayer(layer).forward(tabDMLScript, true))
-          assign(tabDMLScript, "Prob[i,]", lossLayers(0).out)
+          assign(tabDMLScript, "Prob[iter,]", lossLayers(0).out)
         }
       }
       case _ => throw new DMLRuntimeException("Unsupported test algo:" + estimator.getTestAlgo)
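
A minimal sketch of driving the new script-generation path from Scala code rather than via main() (solver path, input shape, and output file name are invented for the example; the calls mirror what main() does for "train_script"):

  import org.apache.spark.SparkContext
  import org.apache.sysml.api.dl.{Caffe2DML, Utils}

  val sc = new SparkContext()
  val caffe = new Caffe2DML(sc, "lenet_solver.proto", "1", "28", "28") // new solver-path constructor
  Utils.writeToFile(caffe.getTrainingScript(true)._1.getScriptString, "lenet_train.dml")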

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
index 456b032..6b06c26 100644
--- a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
@@ -182,10 +182,10 @@ trait NextBatchGenerator extends TabbedDMLGenerator {
 	  dmlScript.append("\n")
 	}
 	def getTestBatch(tabDMLScript:StringBuilder):Unit = {
-    assignBatch(tabDMLScript, "Xb", Caffe2DML.X, null, null, "", Caffe2DML.numImages, "i")
+    assignBatch(tabDMLScript, "Xb", Caffe2DML.X, null, null, "", Caffe2DML.numImages, "iter")
   } 
   def getTrainingBatch(tabDMLScript:StringBuilder):Unit = {
-    assignBatch(tabDMLScript, "Xb", Caffe2DML.X, "yb", Caffe2DML.y, "", Caffe2DML.numImages, "i")
+    assignBatch(tabDMLScript, "Xb", Caffe2DML.X, "yb", Caffe2DML.y, "", Caffe2DML.numImages, "iter")
   }
 	def getTrainingBatch(tabDMLScript:StringBuilder, X:String, y:String, numImages:String):Unit = {
 	  assignBatch(tabDMLScript, "Xb", X, "yb", y, "", numImages, "i")
@@ -298,6 +298,13 @@ trait DMLGenerator extends SourceDMLGenerator with NextBatchGenerator with Visua
 	  numTabs -= 1
 	  tabDMLScript.append("}\n")
 	}
+	def whileBlock(cond:String)(op: => Unit) {
+	  tabDMLScript.append("while(" + cond + ") {\n")
+	  numTabs += 1
+	  op
+	  numTabs -= 1
+	  tabDMLScript.append("}\n")
+	}
 	def forBlock(iterVarName:String, startVal:String, endVal:String)(op: => Unit) {
 	  tabDMLScript.append("for(" + iterVarName + " in " + startVal + ":" + endVal + ") {\n")
 	  numTabs += 1
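
For illustration (invented variable names; not in the commit), the new helper composes like the existing forBlock/ifBlock:

  whileBlock("iter < max_iter") {
    assign(tabDMLScript, "iter", "iter + 1")
  }
  // appends to the generated script:
  //   while(iter < max_iter) {
  //     iter = iter + 1
  //   }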

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/main/scala/org/apache/sysml/api/dl/Utils.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Utils.scala b/src/main/scala/org/apache/sysml/api/dl/Utils.scala
index 0c00d3c..2684261 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Utils.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Utils.scala
@@ -277,6 +277,11 @@ object Utils {
 	
 	// --------------------------------------------------------------
 	// File IO utility functions
+	def writeToFile(content:String, filePath:String): Unit = {
+		val pw = new java.io.PrintWriter(new File(filePath))
+		pw.write(content)
+		pw.close
+	}
 	def getInputStreamReader(filePath:String ):InputStreamReader = {
 		//read solver script from file
 		if(filePath == null)

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/test/scripts/functions/tensor/PoolTest.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/tensor/PoolTest.R b/src/test/scripts/functions/tensor/PoolTest.R
index a34e0b0..aef0384 100644
--- a/src/test/scripts/functions/tensor/PoolTest.R
+++ b/src/test/scripts/functions/tensor/PoolTest.R
@@ -32,7 +32,7 @@ pad=as.integer(args[7])
 # Assumption: NCHW image format
 x=matrix(seq(1, numImg*numChannels*imgSize*imgSize), numImg, numChannels*imgSize*imgSize, byrow=TRUE)
 if(as.logical(args[9])) {
-	zero_mask = (x - mean(x)) > 0 
+	zero_mask = (x - 1.5*mean(x)) > 0 
 	x = x * zero_mask
 } else {
 	x = x - mean(x)
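
A quick sanity check on the new threshold (arithmetic inferred from the script; the same change is applied to the DML version below): x holds the sequence 1..M, so mean(x) = (M+1)/2 and x > 1.5*mean(x) keeps roughly the top quarter of the entries, versus about half with the old x > mean(x). The masked test input is therefore ~75% sparse instead of ~50%, exercising the sparse pooling path harder.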

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/test/scripts/functions/tensor/PoolTest.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/tensor/PoolTest.dml b/src/test/scripts/functions/tensor/PoolTest.dml
index cc8132f..5246a2d 100644
--- a/src/test/scripts/functions/tensor/PoolTest.dml
+++ b/src/test/scripts/functions/tensor/PoolTest.dml
@@ -30,7 +30,7 @@ poolMode=$8
 # Assumption: NCHW image format
 x=matrix(seq(1, numImg*numChannels*imgSize*imgSize), rows=numImg, cols=numChannels*imgSize*imgSize)
 if($10) {
-	zero_mask = (x - mean(x)) > 0 
+	zero_mask = (x - 1.5*mean(x)) > 0 
 	x = x * zero_mask
 }
 else {