You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2018/07/16 21:58:24 UTC

systemml git commit: [SYSTEMML-445] Added a rewrite for batch normalization train

Repository: systemml
Updated Branches:
  refs/heads/master 5aadb4b22 -> a6bca8851


[SYSTEMML-445] Added a rewrite for batch normalization train

- This PR fuses a batch normalization train pattern into a FunctionOp. The method batchNormTrain in RewriteGPUSpecificOps performs the fusing.
- This rewrite is only enabled if none of the outputs are persistent writes. It replaces the existing outputs of the matched pattern with transient reads.

Closes #800.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/a6bca885
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/a6bca885
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/a6bca885

Branch: refs/heads/master
Commit: a6bca88512f3f542278709713706d256fad2cc17
Parents: 5aadb4b
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Mon Jul 16 14:50:36 2018 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Mon Jul 16 14:52:01 2018 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/FunctionOp.java  |  28 +-
 src/main/java/org/apache/sysml/hops/Hop.java    |   1 +
 .../hops/rewrite/RewriteGPUSpecificOps.java     | 573 ++++++++++++++++++-
 .../org/apache/sysml/lops/FunctionCallCP.java   |  12 +-
 src/main/java/org/apache/sysml/lops/Lop.java    |  13 +
 .../org/apache/sysml/parser/DMLTranslator.java  |   6 +-
 .../instructions/GPUInstructionParser.java      |   1 +
 .../instructions/gpu/DnnGPUInstruction.java     |  62 +-
 .../apache/sysml/test/gpu/BatchNormTest.java    |  35 +-
 .../org/apache/sysml/test/gpu/GPUTests.java     |  37 +-
 10 files changed, 701 insertions(+), 67 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/hops/FunctionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java b/src/main/java/org/apache/sysml/hops/FunctionOp.java
index 64963d9..aedaf81 100644
--- a/src/main/java/org/apache/sysml/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java
@@ -169,17 +169,20 @@ public class FunctionOp extends Hop
 				long outputValues = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), 1, 1.0);
 				return outputVectors+outputValues; 
 			}
-			else if ( getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward")  ) {
+			else if ( getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward") ) {
 				// TODO: To allow for initial version to always run on the GPU
 				return 0; 
 			}
-			else if ( getFunctionName().equalsIgnoreCase("batch_norm2d") ) {
+			else if ( getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_train")) {
 				return OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0) +
 						OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0) +
 						OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(2).getDim1(), getOutputs().get(2).getDim2(), 1.0) +
 						OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(3).getDim1(), getOutputs().get(3).getDim2(), 1.0) + 
 						OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(4).getDim1(), getOutputs().get(4).getDim2(), 1.0);
 			}
+			else if ( getFunctionName().equalsIgnoreCase("batch_norm2d_test") ) {
+			return OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0);
+		}
 			else if ( getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ) {
 				return OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0) +
 						OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0) +
@@ -215,7 +218,8 @@ public class FunctionOp extends Hop
 				return OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), getInput().get(0).getDim2(), 1.0) 
 						+ 3*OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), 1, 1.0); 
 			}
-			else if (getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) {
+			else if (getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ||
+					getFunctionName().equalsIgnoreCase("batch_norm2d_train") || getFunctionName().equalsIgnoreCase("batch_norm2d_test")) {
 				return 0; 
 			}
 			else if ( getFunctionName().equalsIgnoreCase("lstm") ||  getFunctionName().equalsIgnoreCase("lstm_backward") ) {
@@ -240,7 +244,8 @@ public class FunctionOp extends Hop
 	@Override
 	public boolean isGPUEnabled() {
 		if(getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward") ||  
-			getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) 
+			getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ||
+			getFunctionName().equalsIgnoreCase("batch_norm2d_train") || getFunctionName().equalsIgnoreCase("batch_norm2d_test")) 
 			return true;
 		else
 			return false;
@@ -283,20 +288,25 @@ public class FunctionOp extends Hop
 		checkAndSetForcedPlatform();
 		
 		if ( getFunctionType() == FunctionType.MULTIRETURN_BUILTIN ) {
+			boolean isBuiltinFunction = isBuiltinFunction();
 			// check if there is sufficient memory to execute this function
-			if( getFunctionName().equalsIgnoreCase("transformencode") ) {
+			if(isBuiltinFunction && getFunctionName().equalsIgnoreCase("transformencode") ) {
 				_etype = ((_etypeForced==ExecType.SPARK 
 					|| (getMemEstimate() >= OptimizerUtils.getLocalMemBudget()
 						&& OptimizerUtils.isSparkExecutionMode())) ? ExecType.SPARK : ExecType.CP);
 			}
-			else if(getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward")) {
+			else if(isBuiltinFunction && (getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward"))) {
 				if(!DMLScript.USE_ACCELERATOR)
 					throw new RuntimeException("The function " + getFunctionName() + " is only supported on GPU.");
 				_etype = ExecType.GPU;
 			}
-			else if( getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) {
+			else if(isBuiltinFunction && (getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward"))) {
 				_etype = DMLScript.USE_ACCELERATOR ? ExecType.GPU : ExecType.CP;
 			}
+			else if(isBuiltinFunction && getFunctionName().equalsIgnoreCase("batch_norm2d_train")) {
+				// Only GPU implementation is supported
+				_etype = ExecType.GPU;
+			}
 			else {
 				// Since the memory estimate is only conservative, do not throw
 				// exception if the estimated memory is larger than the budget
@@ -312,6 +322,10 @@ public class FunctionOp extends Hop
 		
 		return _etype;
 	}
+	
+	/**
+	 * Checks whether this function call targets the internal namespace (DMLProgram.INTERNAL_NAMESPACE),
+	 * which is where the builtin multi-return functions such as batch_norm2d_train are registered,
+	 * as opposed to a user-defined function that happens to share the same name.
+	 * The constant is used as the receiver so that a null namespace yields false instead of an NPE.
+	 *
+	 * @return true if the function namespace equals DMLProgram.INTERNAL_NAMESPACE
+	 */
+	private boolean isBuiltinFunction() {
+		return DMLProgram.INTERNAL_NAMESPACE.equals(getFunctionNamespace());
+	}
 
 	@Override
 	public void refreshSizeInformation()

http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java
index 5d357c6..d8f4424 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -1537,6 +1537,7 @@ public abstract class Hop implements ParseInfo
 		HopsData2String.put(DataOpTypes.PERSISTENTWRITE, "PWrite");
 		HopsData2String.put(DataOpTypes.TRANSIENTWRITE, "TWrite");
 		HopsData2String.put(DataOpTypes.TRANSIENTREAD, "TRead");
+		HopsData2String.put(DataOpTypes.FUNCTIONOUTPUT, "FunOut");
 	}
 
 	public static OpOp2 getOpOp2ForOuterVectorOperation(String op) 

http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
index 1c00c6f..b946178 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
@@ -20,47 +20,64 @@
 package org.apache.sysml.hops.rewrite;
 
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
 
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.hops.AggUnaryOp;
+import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.FunctionOp;
 import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.FunctionOp.FunctionType;
 import org.apache.sysml.hops.Hop.AggOp;
+import org.apache.sysml.hops.Hop.DataOpTypes;
 import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.Hop.OpOp1;
 import org.apache.sysml.hops.Hop.OpOp2;
 import org.apache.sysml.hops.Hop.OpOpDnn;
 import org.apache.sysml.hops.Hop.ReOrgOp;
+import org.apache.sysml.hops.DataOp;
 import org.apache.sysml.hops.LiteralOp;
 import org.apache.sysml.hops.DnnOp;
 import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.hops.ReorgOp;
+import org.apache.sysml.hops.UnaryOp;
+import org.apache.sysml.parser.DMLProgram;
+import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
 
 /*
  * This class contains GPU-specific rewrites for following patterns:
  * 
- * 1. batchNormTest:
+ * 1. batchNormTest: applied when mode="test" in batch normalization nn layer.
  * norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps))
  * hi = bias_add(bias_multiply(norm, gamma), beta)
  * 
  * 2. channelSum:
  * output = rowSums(matrix(colSums(x), rows=numChannels, cols=imgSize*imgSize))
+ * 
+ * 3. batchNormTrain: applied when mode="train" in batch normalization nn layer.
+ * This rewrite is only enabled if none of the outputs are persistent writes, as it assumes that 
+ * FunctionOp will introduce transient writes. This rewrite replaces the existing outputs of the matched pattern with transient reads.
+ * 
  */
 public class RewriteGPUSpecificOps extends HopRewriteRule {
 
+	private static int _seq = 1;
+	
 	@Override
 	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
 		if( roots == null )
 			return roots;
 
 		//one pass rewrite-descend (rewrite created pattern)
-		for( Hop h : roots )
-			rule_GPUKernels( h, false );
+		for( int i = 0; i < roots.size(); i++ )
+			rule_GPUKernels(roots, roots.get(i), false );
 		Hop.resetVisitStatus(roots, true);
 
 		//one pass descend-rewrite (for rollup) 
-		for( Hop h : roots )
-			rule_GPUKernels( h, true );
+		for( int i = 0; i < roots.size(); i++ )
+			rule_GPUKernels(roots, roots.get(i), true );
 		Hop.resetVisitStatus(roots, true);
 		
 		return roots;
@@ -72,12 +89,12 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
 			return root;
 		
 		//one pass rewrite-descend (rewrite created pattern)
-		rule_GPUKernels( root, false );
+		rule_GPUKernels(null, root, false );
 		
 		root.resetVisitStatus();
 		
 		//one pass descend-rewrite (for rollup) 
-		rule_GPUKernels( root, true );
+		rule_GPUKernels(null, root, true );
 		
 		return root;
 	}
@@ -85,10 +102,11 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
 	/**
 	 * Fuse the kernel
 	 * 
+	 * @param roots root operators
 	 * @param hop high-level operator
 	 * @param descendFirst true if recursively process children first
 	 */
-	private void rule_GPUKernels(Hop hop, boolean descendFirst) 
+	private void rule_GPUKernels(ArrayList<Hop> roots, Hop hop, boolean descendFirst) 
 	{
 		if(hop.isVisited())
 			return;
@@ -99,13 +117,16 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
 			
 			//process childs recursively first (to allow roll-up)
 			if( descendFirst )
-				rule_GPUKernels(hi, descendFirst); //see below
+				rule_GPUKernels(roots, hi, descendFirst); //see below
 			
+			if(roots != null) {
+				hi = batchNormTrain(roots, hop, hi, i);
+			}
 			hi = batchNormTest(hop, hi, i); 
 			hi = channelSums(hop, hi, i); 
 	
 			if( !descendFirst )
-				rule_GPUKernels(hi, descendFirst);
+				rule_GPUKernels(roots, hi, descendFirst);
 		}
 
 		hop.setVisited();
@@ -149,6 +170,10 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
 				memEst < OptimizerUtils.getLocalMemBudget() && memEst < GPUContextPool.initialGPUMemBudget();
 	}
 	
+	/**
+	 * Returns whether the given hop is non-null and has at least one input,
+	 * i.e. whether getFirstInput(h) can be called without throwing.
+	 */
+	private static boolean hasFirstInput(Hop h) {
+		if(h == null || h.getInput() == null)
+			return false;
+		return h.getInput().size() >= 1;
+	}
+	
 	private static Hop getFirstInput(Hop h) {
 		if(h == null || h.getInput() == null || h.getInput().size() < 1) {
 			throw new RuntimeException("No input available for " + h);
@@ -156,13 +181,24 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
 		return h.getInput().get(0);
 	}
 	
+	/**
+	 * Returns whether the given hop is non-null and has at least two inputs,
+	 * i.e. whether getSecondInput(h) can be called without throwing.
+	 */
+	private static boolean hasSecondInput(Hop h) {
+		return h != null && h.getInput() != null && h.getInput().size() >= 2;
+	}
+	
+	/**
+	 * Returns the second input (index 1) of the given hop.
+	 *
+	 * @param h hop whose second input is requested
+	 * @return h.getInput().get(1)
+	 * @throws RuntimeException if h is null, has no input list, or has fewer than two inputs
+	 */
 	private static Hop getSecondInput(Hop h) {
 		if(h == null || h.getInput() == null || h.getInput().size() < 2) {
-			throw new RuntimeException("No input available for " + h);
+			throw new RuntimeException("Expected atleast two inputs for " + h);
 		}
 		return h.getInput().get(1);
 	}
 	
+	/**
+	 * Returns the third input (index 2) of the given hop, e.g. the column count of a reshape.
+	 *
+	 * @param h hop whose third input is requested
+	 * @return h.getInput().get(2)
+	 * @throws RuntimeException if h is null, has no input list, or has fewer than three inputs
+	 */
+	private static Hop getThirdInput(Hop h) {
+		boolean hasThird = h != null && h.getInput() != null && h.getInput().size() >= 3;
+		if(!hasThird)
+			throw new RuntimeException("Expected atleast three inputs for " + h);
+		return h.getInput().get(2);
+	}
+	
 	private static boolean isUnaryMinus(Hop h) {
 		return HopRewriteUtils.isBinary(h, OpOp2.MINUS)
 			&& HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), 0);
@@ -200,13 +236,488 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
 		return hi;
 	}
 	
+	/**
+	 * Returns whether h is an AggUnaryOp with the given aggregation operation and direction.
+	 */
+	private static boolean isAggUnaryOfKind(Hop h, AggOp op, Direction dir) {
+		if(!(h instanceof AggUnaryOp))
+			return false;
+		AggUnaryOp agg = (AggUnaryOp) h;
+		return agg.getOp() == op && agg.getDirection() == dir;
+	}
+	
+	/** Returns whether h is rowMeans(..). */
+	private static boolean isRowMeans(Hop h) {
+		return isAggUnaryOfKind(h, AggOp.MEAN, Direction.Row);
+	}
+	
+	/** Returns whether h is rowVars(..). */
+	private static boolean isRowVars(Hop h) {
+		return isAggUnaryOfKind(h, AggOp.VAR, Direction.Row);
+	}
+	
+	/** Returns whether h is rowVars(childHop), comparing the input by reference. */
+	private static boolean isRowVars(Hop h, Hop childHop) {
+		if(!isRowVars(h))
+			return false;
+		return getFirstInput(h) == childHop;
+	}
+	
+	/** Returns whether h is colMeans(..). */
+	private static boolean isColMeans(Hop h) {
+		return isAggUnaryOfKind(h, AggOp.MEAN, Direction.Col);
+	}
+	
+	/** Returns whether h is colVars(..). */
+	private static boolean isColVars(Hop h) {
+		return isAggUnaryOfKind(h, AggOp.VAR, Direction.Col);
+	}
+	
+	/** Returns whether h is a reshape, i.e. matrix(..., rows=.., cols=..). */
+	private static boolean isReshape(Hop h) {
+		return h instanceof ReorgOp && ((ReorgOp) h).getOp() == ReOrgOp.RESHAPE;
+	}
+	
+	/**
+	 * Returns whether h is a reshape whose row and column arguments evaluate
+	 * (via Hop.computeSizeInformation) to the expected values.
+	 */
+	private static boolean isReshape(Hop h, long expectedRows, long expectedCols) {
+		if(!isReshape(h))
+			return false;
+		return Hop.computeSizeInformation(getSecondInput(h)) == expectedRows
+			&& Hop.computeSizeInformation(getThirdInput(h)) == expectedCols;
+	}
+	
+	/** Returns whether h is a BinaryOp with the given operator. */
+	private static boolean isBinaryOfKind(Hop h, OpOp2 op) {
+		return h instanceof BinaryOp && ((BinaryOp) h).getOp() == op;
+	}
+	
+	/** Returns whether the first and second inputs of h have the given data types (checked in that order). */
+	private static boolean hasBinaryInputTypes(Hop h, DataType firstType, DataType secondType) {
+		return getFirstInput(h).getDataType() == firstType && getSecondInput(h).getDataType() == secondType;
+	}
+	
+	/** Evaluates a simple scalar expression hop to a double using a fresh memo table. */
+	private static double evalSimpleDouble(Hop h) {
+		return OptimizerUtils.rEvalSimpleDoubleExpression(h, new HashMap<>());
+	}
+	
+	/** Returns whether h is a binary addition (any operand types). */
+	private static boolean isBinaryAdd(Hop h) {
+		return isBinaryOfKind(h, OpOp2.PLUS);
+	}
+	
+	/** Returns whether h is matrix + scalar where the scalar evaluates to expectedValue. */
+	private static boolean isBinaryMSAdd(Hop h, double expectedValue) {
+		if(!isBinaryOfKind(h, OpOp2.PLUS) || !hasBinaryInputTypes(h, DataType.MATRIX, DataType.SCALAR))
+			return false;
+		return evalSimpleDouble(getSecondInput(h)) == expectedValue;
+	}
+	
+	/** Returns whether h is matrix + matrix. */
+	private static boolean isBinaryMMAdd(Hop h) {
+		return isBinaryOfKind(h, OpOp2.PLUS) && hasBinaryInputTypes(h, DataType.MATRIX, DataType.MATRIX);
+	}
+	
+	/** Returns whether h is matrix * scalar where the scalar evaluates to expectedValue. */
+	private static boolean isBinaryMSMult(Hop h, double expectedValue) {
+		if(!isBinaryOfKind(h, OpOp2.MULT) || !hasBinaryInputTypes(h, DataType.MATRIX, DataType.SCALAR))
+			return false;
+		return evalSimpleDouble(getSecondInput(h)) == expectedValue;
+	}
+	
+	/** Returns whether h is scalar - scalar. */
+	private static boolean isBinarySSMinus(Hop h) {
+		return isBinaryOfKind(h, OpOp2.MINUS) && hasBinaryInputTypes(h, DataType.SCALAR, DataType.SCALAR);
+	}
+	
+	/** Returns whether h is scalar / scalar. */
+	private static boolean isBinarySSDiv(Hop h) {
+		return isBinaryOfKind(h, OpOp2.DIV) && hasBinaryInputTypes(h, DataType.SCALAR, DataType.SCALAR);
+	}
+	
+	/** Returns whether h is scalar / matrix where the scalar numerator evaluates to expectedValue. */
+	private static boolean isBinarySMDiv(Hop h, double expectedValue) {
+		if(!isBinaryOfKind(h, OpOp2.DIV) || !hasBinaryInputTypes(h, DataType.SCALAR, DataType.MATRIX))
+			return false;
+		return evalSimpleDouble(getFirstInput(h)) == expectedValue;
+	}
+	
+	/** Returns whether the (possibly null) list contains at least one binary addition. */
+	private static boolean isAnyBinaryAdd(ArrayList<Hop> hops) {
+		if(hops == null)
+			return false;
+		for(int i = 0; i < hops.size(); i++) {
+			if(isBinaryAdd(hops.get(i)))
+				return true;
+		}
+		return false;
+	}
+	
+	/** Returns whether h is matrix * scalar. */
+	private static boolean isBinaryMSMult(Hop h) {
+		return isBinaryOfKind(h, OpOp2.MULT) && hasBinaryInputTypes(h, DataType.MATRIX, DataType.SCALAR);
+	}
+	
+	/** Returns whether h is scalar * matrix (the matrix operand is checked first). */
+	private static boolean isBinarySMMult(Hop h) {
+		return isBinaryOfKind(h, OpOp2.MULT) 
+				&& getSecondInput(h).getDataType() == DataType.MATRIX 
+				&& getFirstInput(h).getDataType() == DataType.SCALAR;
+	}
+	
+	/**
+	 * Checks whether the "mean" hop is the per-channel batch mean computed by the batch
+	 * normalization layer in mode="train":
+	 *   subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win)
+	 *   mean = rowMeans(subgrp_means)
+	 * 
+	 * @param mean hop to check against
+	 * @param X input data
+	 * @return true if mean is rowMeans(reshape(colMeans(X))) with X matched by reference
+	 */
+	private static boolean isBatchNormTrainMean(Hop mean, Hop X) {
+		if(!isRowMeans(mean))
+			return false;
+		Hop subgrpMeans = getFirstInput(mean);
+		if(!isReshape(subgrpMeans))
+			return false;
+		Hop colMeansOfX = getFirstInput(subgrpMeans);
+		return isColMeans(colMeansOfX) && getFirstInput(colMeansOfX) == X;
+	}
+	
+	/**
+	 * Checks whether expr computes nrow(X), with X matched by reference.
+	 * 
+	 * @param expr hop to be matched
+	 * @param X input X
+	 * @return true if expr is nrow(X) else false
+	 */
+	private static boolean isNrowOfX(Hop expr, Hop X) {
+		if(!(expr instanceof UnaryOp) || ((UnaryOp) expr).getOp() != OpOp1.NROW)
+			return false;
+		return getFirstInput(expr) == X;
+	}
+	
+	/**
+	 * Checks for the colVars(X) * ((N-1)/N) pattern, where N = nrow(X).
+	 * A plain colVars(X) without the correction factor is accepted as well.
+	 * 
+	 * If the number of rows of X is known, the factor must evaluate exactly to (N-1)/N and
+	 * ignoreCorrectionTerm has no effect. If it is unknown, the factor must have the shape
+	 * (nrow(X)-1)/nrow(X) unless ignoreCorrectionTerm is set, in which case any scalar factor is accepted.
+	 * 
+	 * @param expr hop to be matched
+	 * @param X input X
+	 * @param ignoreCorrectionTerm whether to skip checking the correction term ((N-1)/N) when nrow(X) is unknown
+	 * @return true if expr matches the pattern else false
+	 */
+	private static boolean isCorrectedColVars(Hop expr, Hop X, boolean ignoreCorrectionTerm) {
+		// Support no correction as well in this rewrite
+		if(isColVars(expr) && getFirstInput(expr) == X)
+			return true;
+		
+		if(X.rowsKnown()) {
+			// N is known, so the scalar factor must evaluate to exactly (N-1)/N
+			double expectedCorrection = ((double)X.getDim1()-1)/X.getDim1();
+			return isBinaryMSMult(expr, expectedCorrection) && 
+					isColVars(getFirstInput(expr)) && getFirstInput(getFirstInput(expr)) == X;
+		}
+		
+		boolean isScaledColVarsOfX = isBinaryMSMult(expr) && 
+				isColVars(getFirstInput(expr)) && getFirstInput(getFirstInput(expr)) == X;
+		if(!isScaledColVarsOfX)
+			return false;
+		if(ignoreCorrectionTerm)
+			return true;
+		
+		// Match ((N-1)/N): the same N hop appears in numerator and denominator, and N is nrow(X)
+		Hop ratio = getSecondInput(expr);
+		boolean isNMinus1Pattern = isBinarySSDiv(ratio) && isBinarySSMinus(getFirstInput(ratio)) &&
+				getFirstInput(getFirstInput(ratio)) == getSecondInput(ratio) && 
+				OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(getFirstInput(ratio)), new HashMap<>()) == 1;
+		boolean ret = isNMinus1Pattern && isNrowOfX(getSecondInput(ratio), X);
+		if(LOG.isDebugEnabled()) {
+			LOG.debug("Is the corrected column variance pattern for batch_norm_train rewrite when number of rows of X unknown matched:" + ret);
+		}
+		return ret;
+	}
+	
+	/**
+	 * Checks whether the "var" hop is the per-channel batch variance computed by the batch
+	 * normalization layer in mode="train":
+	 *   subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win)
+	 *   var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))
+	 * C and Hin*Win are taken from the reshape below the previously matched mean hop, so this
+	 * method assumes isBatchNormTrainMean(mean, X) already returned true.
+	 *  
+	 * @param mean previously matched mean hop
+	 * @param var the hop to check against
+	 * @param X input data hop
+	 * @param subgrpMeans the subgroup means hop (input of the rowMeans in the mean pattern)
+	 * @param ignoreCorrectionTerm whether to ignore the correction term (see isCorrectedColVars method in this class)
+	 * @return true if the "var" hop matches the batch variance pattern
+	 */
+	private static boolean isBatchNormTrainVar(Hop mean, Hop var, Hop X, Hop subgrpMeans, boolean ignoreCorrectionTerm) {
+		Hop meanReshape = getFirstInput(mean);
+		long numChannels = Hop.computeSizeInformation(getSecondInput(meanReshape));
+		long HW = Hop.computeSizeInformation(getThirdInput(meanReshape));
+		if(numChannels <= 0 || HW <= 0)
+			return false;
+		if(!isBinaryMMAdd(var) || !isRowMeans(getFirstInput(var)))
+			return false;
+		
+		// First term: rowMeans(matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win))
+		Hop subgrpVars = getFirstInput(getFirstInput(var));
+		if(!isReshape(subgrpVars, numChannels, HW) 
+				|| !isCorrectedColVars(getFirstInput(subgrpVars), X, ignoreCorrectionTerm))
+			return false;
+		
+		// Second term: rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))
+		Hop scaledRowVars = getSecondInput(var);
+		return isBinaryMSMult(scaledRowVars, ((((double)HW)-1)/HW)) && 
+				isRowVars(getFirstInput(scaledRowVars), subgrpMeans);
+	}
+	
+	/**
+	 * Checks and returns the matched hops for the expression ema_mean_upd = mu*ema_mean + (1-mu)*mean
+	 * 
+	 * @param rhsTimesOp hop representing the BinaryOp of the expression (1-mu)*mean 
+	 * @param mu value of mu
+	 * @return an array [ema_mean_upd, ema_mean, mean] if the expression matched, else null
+	 */
+	private static Hop [] getUpdatedMovingAverageExpressions(Hop rhsTimesOp, double mu) {
+		// Delegate the match to the boolean matcher so both methods always agree on what a match is
+		if(!isUpdatedMovingAverageExpression(rhsTimesOp, mu))
+			return null;
+		
+		Hop plusOp = rhsTimesOp.getParent().get(0);
+		// The (1-mu)*mean term may be either operand of the addition
+		Hop lhsTimesOp = (plusOp.getInput().get(0) == rhsTimesOp) ? plusOp.getInput().get(1) : plusOp.getInput().get(0);
+		return new Hop[] {
+			plusOp.getParent().get(0),  // ema_mean_upd: the single parent of the addition
+			getSecondInput(lhsTimesOp), // ema_mean
+			getSecondInput(rhsTimesOp)  // mean
+		};
+	}
+	
+	/**
+	 * Checks whether exactly one of rhsTimesOps matches the expression ema_mean_upd = mu*ema_mean + (1-mu)*mean
+	 * and returns the matched hops for it.
+	 * 
+	 * @param rhsTimesOps candidate hops for the BinaryOp of the expression (1-mu)*mean 
+	 * @param mu value of mu
+	 * @return an array [ema_mean_upd, ema_mean, mean] if exactly one candidate matched, else null
+	 */
+	private static Hop [] getUpdatedMovingAverageExpressions(ArrayList<Hop> rhsTimesOps, double mu) {
+		if(rhsTimesOps == null || rhsTimesOps.isEmpty())
+			return null;
+		
+		Hop [] match = null;
+		for(int i = 0; i < rhsTimesOps.size(); i++) {
+			Hop candidate = rhsTimesOps.get(i);
+			if(!isUpdatedMovingAverageExpression(candidate, mu))
+				continue;
+			if(match != null)
+				return null; // Multiple matches, cannot decide which one to fuse
+			match = getUpdatedMovingAverageExpressions(candidate, mu);
+		}
+		return match;
+	}
+	
+	/**
+	 * Derives mu from the unique candidate matching the expression ema_mean_upd = mu*ema_mean + (1-mu)*mean.
+	 * mu is computed as 1 minus the evaluated scalar factor of the (1-mu)*mean term.
+	 * 
+	 * @param rhsTimesOps candidate hops for the BinaryOp of the expression (1-mu)*mean
+	 * @return value of mu if exactly one candidate matched structurally, else null 
+	 */
+	private static Double getMuFromUpdatedMovingAverageExpressions(ArrayList<Hop> rhsTimesOps) {
+		if(rhsTimesOps == null || rhsTimesOps.isEmpty())
+			return null;
+		
+		Double mu = null;
+		for(int i = 0; i < rhsTimesOps.size(); i++) {
+			Hop candidate = rhsTimesOps.get(i);
+			if(!isUpdatedMovingAverageExpression(candidate))
+				continue;
+			if(mu != null)
+				return null; // Multiple matches, cannot decide which one to fuse
+			double oneMinusMu = OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(candidate), new HashMap<>());
+			mu = -(oneMinusMu - 1);
+		}
+		return mu;
+	}
+	
+	/**
+	 * Structurally checks for the expression ema_mean_upd = mu*ema_mean + (1-mu)*mean, i.e.
+	 * (scalar*matrix) + (scalar*matrix) where the given multiplication and the addition each have
+	 * exactly one parent. The scalar values are not checked here.
+	 * 
+	 * @param rhsTimesOp hop representing the BinaryOp of the expression (1-mu)*mean
+	 * @return true if the expression matched
+	 */
+	private static boolean isUpdatedMovingAverageExpression(Hop rhsTimesOp) {
+		boolean hasSingleParent = rhsTimesOp != null && rhsTimesOp.getParent() != null 
+				&& rhsTimesOp.getParent().size() == 1;
+		if(!hasSingleParent || !isBinarySMMult(rhsTimesOp) || !isBinaryAdd(rhsTimesOp.getParent().get(0)))
+			return false;
+		
+		Hop plusOp = rhsTimesOp.getParent().get(0);
+		// The other operand of the addition is the mu*ema_mean term
+		Hop lhsTimesOp = (plusOp.getInput().get(0) == rhsTimesOp) ? plusOp.getInput().get(1) : plusOp.getInput().get(0);
+		return plusOp.getParent() != null && plusOp.getParent().size() == 1 && isBinarySMMult(lhsTimesOp);
+	}
+	
+	/**
+	 * Checks for the expression ema_mean_upd = mu*ema_mean + (1-mu)*mean with a specific mu:
+	 * the structure must match and the scalar factors must evaluate exactly to (1-mu) and mu.
+	 * 
+	 * @param rhsTimesOp hop representing the BinaryOp of the expression (1-mu)*mean
+	 * @param mu expected value of mu
+	 * @return true if expression matched, else false
+	 */
+	private static boolean isUpdatedMovingAverageExpression(Hop rhsTimesOp, double mu) {
+		if(!isUpdatedMovingAverageExpression(rhsTimesOp))
+			return false;
+		
+		double oneMinusMu = OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(rhsTimesOp), new HashMap<>());
+		if(oneMinusMu != (1-mu))
+			return false;
+		
+		Hop plusOp = rhsTimesOp.getParent().get(0);
+		Hop lhsTimesOp = (plusOp.getInput().get(0) == rhsTimesOp) ? plusOp.getInput().get(1) : plusOp.getInput().get(0);
+		return OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(lhsTimesOp), new HashMap<>()) == mu;
+	}
+	
+	/**
+	 * Checks whether denom is consumed by the expression 1/sqrt(denom). Only the first parent of
+	 * denom is inspected: it must be sqrt(denom), and that sqrt must have exactly one parent,
+	 * which must be a scalar/matrix division with numerator 1.
+	 * 
+	 * @param denom denominator of the expression to be matched
+	 * @return true if the expression 1/sqrt(denom) matched else false
+	 */
+	private static boolean isOneBySqrt(Hop denom) {
+		// An empty parent list previously caused an IndexOutOfBoundsException
+		if(denom.getParent() == null || denom.getParent().isEmpty())
+			return false;
+		Hop sqrt = denom.getParent().get(0);
+		if(!(sqrt instanceof UnaryOp) || ((UnaryOp)sqrt).getOp() != OpOp1.SQRT)
+			return false;
+		return sqrt.getParent() != null && sqrt.getParent().size() == 1 &&
+				isBinarySMDiv(sqrt.getParent().get(0), 1);
+	}
+	
+	/**
+	 * Checks for the batch norm (mode="train") pattern using the helpers isBatchNormTrainMean and isBatchNormTrainVar
+	 * and, if matched, returns a new multi-return builtin FunctionOp "batch_norm2d_train".
+	 * 
+	 * When the rewrite applies, the old outputs (hi, ema_mean_upd, ema_var_upd, cache_mean, cache_var) are
+	 * replaced by transient reads (see getMultiOutputHops), and the new FunctionOp is inserted at the front of roots.
+	 * The rewrite is skipped if any of these outputs is a persistent write.
+	 * 
+	 * @param roots root hops of the given statement block (modified in place if the rewrite applies)
+	 * @param parent parent of the input
+	 * @param hi input to be matched
+	 * @param pos position
+	 * @return a new FunctionOp or hi
+	 */
+	private static Hop batchNormTrain(ArrayList<Hop> roots, Hop parent, Hop hi, int pos) 
+	{
+		// norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps))
+		// hi = bias_add(bias_multiply(norm, gamma), beta)
+		// NOTE: unlike batchNormTest, no fitsOnGPU check is performed for this rewrite.
+		if( hasFirstInput(hi) && isBiasAdd(hi) && isBiasMultiply(getFirstInput(hi)) ) {
+			Hop norm = getFirstInput(getFirstInput(hi));
+			if(hasSecondInput(norm) && isBiasMultiply(norm) && isBiasAdd(getFirstInput(norm)) 
+					&& hasSecondInput(getFirstInput(norm)) && isUnaryMinus(getSecondInput(getFirstInput(norm)))
+					&& isOneDivideBySqrt(getSecondInput(norm))) {
+				double eps = 0;
+				Hop var = getFirstInput(getSecondInput(getSecondInput(norm)));
+				if(isBinaryAdd(var) && (getFirstInput(var) instanceof LiteralOp || getSecondInput(var) instanceof LiteralOp)) {
+					// eps + var or var + eps: split off the literal epsilon
+					if(getFirstInput(var) instanceof LiteralOp) {
+						eps = OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(var), new HashMap<>());
+						var = getSecondInput(var);
+					}
+					else {
+						eps = OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(var), new HashMap<>());
+						var = getFirstInput(var);
+					}
+				}
+				// Extract X and mean from bias_add(X, -mean)
+				Hop X = getFirstInput(getFirstInput(norm));
+				Hop mean = getSecondInput(getSecondInput(getFirstInput(norm)));
+				
+				if(hasFirstInput(mean) && isBatchNormTrainMean(mean , X) && isBatchNormTrainVar(mean, var, X, getFirstInput(mean), false) &&
+					mean.getParent() != null && mean.getParent().size() >= 2 && 
+					var.getParent() != null && var.getParent().size() == 2) {
+					Hop gamma = getSecondInput(getFirstInput(hi));
+					Hop beta = getSecondInput(hi);
+					
+					// Always get mu from variance as it will have exactly one match of fusion pattern
+					Double potentialMu = getMuFromUpdatedMovingAverageExpressions(var.getParent());
+					if(potentialMu == null)
+						return hi;
+					double mu = potentialMu;
+					
+					Hop [] means = getUpdatedMovingAverageExpressions(mean.getParent(), mu);
+					Hop [] vars = getUpdatedMovingAverageExpressions(var.getParent(), mu);
+					if(means == null || vars == null)
+						return hi;
+					
+					// var has exactly two parents: one is expected to be the (1-mu)*var term of the moving
+					// average (whose parent is an addition), the other var+eps; pick the one not feeding an addition.
+					Hop varPlusEps = null;
+					boolean isFirstBinaryAddOp = isAnyBinaryAdd(var.getParent().get(0).getParent());
+					boolean isSecondBinaryAddOp = isAnyBinaryAdd(var.getParent().get(1).getParent());
+					if(isFirstBinaryAddOp && !isSecondBinaryAddOp) {
+						varPlusEps = var.getParent().get(1);
+					}
+					else if(!isFirstBinaryAddOp && isSecondBinaryAddOp) {
+						varPlusEps = var.getParent().get(0);
+					}
+					if(varPlusEps != null && isBinaryMSAdd(varPlusEps, eps) && isOneBySqrt(varPlusEps)) {
+						// cache_var is the 1/sqrt(var+eps) hop matched by isOneBySqrt
+						Hop cache_var = varPlusEps.getParent().get(0).getParent().get(0);
+						Hop ema_mean_upd = means[0];
+						Hop ema_var_upd = vars[0];
+						Hop ema_mean = means[1];
+						Hop ema_var = vars[1];
+						Hop cache_mean = means[2];
+						
+						ArrayList<Hop> inHops = new ArrayList<>();
+						inHops.add(X);
+						inHops.add(gamma);
+						inHops.add(beta);
+						inHops.add(ema_mean);
+						inHops.add(ema_var);
+						inHops.add(new LiteralOp(eps));
+						inHops.add(new LiteralOp(mu));
+						Hop [] oldHops = {hi, ema_mean_upd, ema_var_upd, cache_mean, cache_var};
+						
+						// Since FunctionOp adds transientwrite explicitly, persistent writes are not supported
+						if(!isAnyPersistentWrite(oldHops)) {
+							LOG.debug("Applied batchNormTrain rewrite.");
+							ArrayList<Hop> outputs = getMultiOutputHops(roots, oldHops);
+							FunctionOp ret = new FunctionOp(FunctionType.MULTIRETURN_BUILTIN, DMLProgram.INTERNAL_NAMESPACE, "batch_norm2d_train", 
+									inHops, outputs.stream().map(Hop::getName).toArray(String[]::new), outputs);
+							// Insert the new FunctionOp at the front of the roots
+							roots.add(0, ret);
+							return ret;
+						}
+					}
+				}
+			}
+		}
+		
+		return hi;
+	}
+	
+	// ------------------------------------------------------------
+	/**
+	 * Returns whether at least one of the given output hops is a persistent write.
+	 * 
+	 * @param outputHops output hops to check
+	 * @return true if any of the hops is a persistent write else false.
+	 */
+	private static boolean isAnyPersistentWrite(Hop [] outputHops) {
+		boolean found = false;
+		for(int i = 0; i < outputHops.length && !found; i++)
+			found = HopRewriteUtils.isData(outputHops[i], DataOpTypes.PERSISTENTWRITE);
+		return found;
+	}
+	
+	/**
+	 * Builds the output hops for a multi-output FunctionOp created by a rewrite.
+	 * For every old output: a transient read is created (the FunctionOp adds the matching
+	 * transient write), all parents of the old output are rewired to that read, and the
+	 * old output is removed from roots.
+	 * 
+	 * @param roots root hops of statement block (modified in place)
+	 * @param oldHops old output hops of the pattern
+	 * @return new output hops that should be passed to FunctionOp, in the order of oldHops
+	 * @throws RuntimeException if any old output is a persistent write
+	 */
+	private static ArrayList<Hop> getMultiOutputHops(ArrayList<Hop> roots, Hop [] oldHops) {
+		ArrayList<Hop> newOutputs = new ArrayList<>();
+		for(Hop oldHop : oldHops) {
+			if(HopRewriteUtils.isData(oldHop, DataOpTypes.PERSISTENTWRITE))
+				throw new RuntimeException("Persistent write is not supported as output for the given rewrite." + oldHop);
+			// Keep the variable name of an existing transient write, otherwise generate a fresh one
+			String name;
+			if(HopRewriteUtils.isData(oldHop, DataOpTypes.TRANSIENTWRITE))
+				name = oldHop.getName();
+			else
+				name = "_genGPU" + (_seq++);
+			DataOp tRead = HopRewriteUtils.createTransientRead(name, oldHop);
+			HopRewriteUtils.rewireAllParentChildReferences(oldHop, tRead);
+			newOutputs.add(tRead);
+			// Drop the old output from roots (no-op if absent) to avoid unnecessary computation
+			roots.remove(oldHop);
+		}
+		return newOutputs;
+	}
+	// ------------------------------------------------------------
+	
+	/**
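+	 * Checks for the batch norm (mode="test") pattern and returns a new DnnOp if matched. The helpers isBatchNormTrainMean
+	 * and isBatchNormTrainVar serve only as a guard: when nrow(X) is unknown and they match, the rewrite is skipped so that the train rewrite can apply after recompilation.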
+	 * 
+	 * @param parent parent of the input
+	 * @param hi input to be matched
+	 * @param pos position
+	 * @return a new DnnOp or hi
+	 */
 	private static Hop batchNormTest(Hop parent, Hop hi, int pos) {
 		// norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps))
 		// hi = bias_add(bias_multiply(norm, gamma), beta)
 		// 2x for input and output and 1x for overhead
-		if( isBiasAdd(hi) && isBiasMultiply(getFirstInput(hi)) && fitsOnGPU(hi, 3) ) {
+		if(hasFirstInput(hi) && isBiasAdd(hi) && isBiasMultiply(getFirstInput(hi)) && fitsOnGPU(hi, 3) ) {
 			Hop norm = getFirstInput(getFirstInput(hi));
-			if(isBiasMultiply(norm) && isBiasAdd(getFirstInput(norm)) 
+			if(hasSecondInput(norm) && isBiasMultiply(norm) && isBiasAdd(getFirstInput(norm)) 
 					&& isUnaryMinus(getSecondInput(getFirstInput(norm)))
 					&& isOneDivideBySqrt(getSecondInput(norm))) {
 				double eps = 0;
@@ -226,20 +737,28 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
 				// Generate batch norm test op
 				Hop X = getFirstInput(getFirstInput(norm));
 				Hop mean = getSecondInput(getSecondInput(getFirstInput(norm)));
-				Hop gamma = getSecondInput(getFirstInput(hi));
-				Hop beta = getSecondInput(hi);
-				ArrayList<Hop> inHops = new ArrayList<Hop>();
-				inHops.add(X);
-				inHops.add(gamma);
-				inHops.add(beta);
-				inHops.add(mean);
-				inHops.add(var);
-				inHops.add(new LiteralOp(eps));
-				if(fitsOnGPU(inHops, true)) {
-					LOG.debug("Applied batchNormTest rewrite.");
-					Hop newHop = new DnnOp(hi.getName(), hi.getDataType(), hi.getValueType(),
-							OpOpDnn.BATCH_NORM2D_TEST, inHops);
-					return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
+				
+				// This guard disallows eager fusion of train batch normalization into test batch normalization
+				boolean potentialForBatchNormTrain = !X.rowsKnown() && isBatchNormTrainMean(mean , X) && isBatchNormTrainVar(mean, var, X, getFirstInput(mean), true);
+				if(!potentialForBatchNormTrain) {
+					Hop gamma = getSecondInput(getFirstInput(hi));
+					Hop beta = getSecondInput(hi);
+					ArrayList<Hop> inHops = new ArrayList<Hop>();
+					inHops.add(X);
+					inHops.add(gamma);
+					inHops.add(beta);
+					inHops.add(mean);
+					inHops.add(var);
+					inHops.add(new LiteralOp(eps));
+					if(fitsOnGPU(inHops, true)) {
+						LOG.debug("Applied batchNormTest rewrite.");
+						Hop newHop = new DnnOp(hi.getName(), hi.getDataType(), hi.getValueType(),
+								OpOpDnn.BATCH_NORM2D_TEST, inHops);
+						return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
+					}
+				}
+				else {
+					LOG.debug("Skipping batchNormTest rewrite as there is potential for batch normalization train rewrite after recompilation.");
 				}
 			}
 		}

http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/lops/FunctionCallCP.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/FunctionCallCP.java b/src/main/java/org/apache/sysml/lops/FunctionCallCP.java
index ac58335..711219a 100644
--- a/src/main/java/org/apache/sysml/lops/FunctionCallCP.java
+++ b/src/main/java/org/apache/sysml/lops/FunctionCallCP.java
@@ -42,8 +42,16 @@ public class FunctionCallCP extends Lop
 		this(inputs, fnamespace, fname, outputs, et);
 		if(outputHops != null) {
 			_outputLops = new ArrayList<>();
-			for(Hop h : outputHops)
-				_outputLops.add( h.constructLops() );
+			setLevel();
+			for(Hop h : outputHops) {
+				Lop outputLop = h.constructLops();
+				_outputLops.add( outputLop );
+				addOutput(outputLop);
+				// Update the output level if necessary for correct instruction ordering
+				if(outputLop.getLevel() <= getLevel()) {
+					outputLop.updateLevel(getLevel()+1);
+				}
+			}
 		}
 	}
 	

http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/lops/Lop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/Lop.java b/src/main/java/org/apache/sysml/lops/Lop.java
index 9e81496..885b8b9 100644
--- a/src/main/java/org/apache/sysml/lops/Lop.java
+++ b/src/main/java/org/apache/sysml/lops/Lop.java
@@ -345,6 +345,29 @@ public abstract class Lop
 		lps.setLevel(inputs);
 	}
 	
+	/**
+	 * Raises the level of this lop to {@code newLevel} and propagates the
+	 * increase to its outputs so that every output keeps a strictly greater
+	 * level than this lop (levels are used for correct instruction ordering).
+	 *
+	 * @param newLevel the new level; must not be lower than the current level
+	 * @throws RuntimeException if {@code newLevel} is lower than the current level
+	 */
+	protected void updateLevel(int newLevel) {
+		if(newLevel < getLevel()) {
+			throw new RuntimeException("Decrementing the level of a lop is not supported (current level: "
+				+ getLevel() + ", requested level: " + newLevel + ").");
+		}
+		else if(newLevel > getLevel()) {
+			lps.setLevel(newLevel);
+			// Outputs must stay strictly above this lop; only recurse into outputs that would violate that.
+			for(Lop out : outputs) {
+				if(out.getLevel() < newLevel+1)
+					out.updateLevel(newLevel+1);
+			}
+		}
+	}
+	
 	/**
 	 * Method to get the location property of LOP
 	 * 

http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/parser/DMLTranslator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
index b9e5f9d..7cf7418 100644
--- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -2000,8 +2000,8 @@ public class DMLTranslator
 				String[] outputNames = new String[targetList.size()]; 
 				outputNames[0] = ((DataIdentifier)targetList.get(0)).getName();
 				outputNames[1] = ((DataIdentifier)targetList.get(1)).getName();
-				outputs.add(new DataOp(outputNames[0], DataType.MATRIX, ValueType.DOUBLE, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, outputNames[0]));
-				outputs.add(new DataOp(outputNames[1], DataType.FRAME, ValueType.STRING, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, outputNames[1]));
+				outputs.add(new DataOp(outputNames[0], DataType.MATRIX, ValueType.DOUBLE, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, inputs.get(0).getFilename()));
+				outputs.add(new DataOp(outputNames[1], DataType.FRAME, ValueType.STRING, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, inputs.get(0).getFilename()));
 				
 				currBuiltinOp = new FunctionOp(ftype, nameSpace, source.getOpCode().toString(), inputs, outputNames, outputs);
 				break;
@@ -2233,7 +2233,7 @@ public class DMLTranslator
 			String[] outputNames = new String[targetList.size()]; 
 			for ( int i=0; i < targetList.size(); i++ ) {
 				outputNames[i] = ((DataIdentifier)targetList.get(i)).getName();
-				Hop output = new DataOp(outputNames[i], DataType.MATRIX, ValueType.DOUBLE, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, outputNames[i]);
+				Hop output = new DataOp(outputNames[i], DataType.MATRIX, ValueType.DOUBLE, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, inputs.get(0).getFilename());
 				outputs.add(output);
 			}
 			

http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index 1122a24..01b10a8 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -61,6 +61,7 @@ public class GPUInstructionParser  extends InstructionParser
 		String2GPUInstructionType.put( "batch_norm2d",           GPUINSTRUCTION_TYPE.Dnn);
 		String2GPUInstructionType.put( "batch_norm2d_backward",  GPUINSTRUCTION_TYPE.Dnn);
 		String2GPUInstructionType.put( "batch_norm2d_test",      GPUINSTRUCTION_TYPE.Dnn);
+		String2GPUInstructionType.put( "batch_norm2d_train",      GPUINSTRUCTION_TYPE.Dnn);
 		
 		// Matrix Multiply Operators
 		String2GPUInstructionType.put( "ba+*",  GPUINSTRUCTION_TYPE.AggregateBinary);

http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
index b01b8d8..a36d0fc 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
@@ -351,12 +351,28 @@ public class DnnGPUInstruction extends GPUInstruction {
 			CPOperand out = new CPOperand(parts[7]);
 			return new DnnGPUInstruction(in, in2, in3, in4, in5, in6, out, opcode, str, 0);
 		}
+		else if (opcode.equalsIgnoreCase("batch_norm2d_train")) {
+			InstructionUtils.checkNumFields(parts, 12);
+			CPOperand in1 = new CPOperand(parts[1]); // image
+			CPOperand in2 = new CPOperand(parts[2]); // gamma
+			CPOperand in3 = new CPOperand(parts[3]); // beta
+			CPOperand in4 = new CPOperand(parts[4]); // ema_mean
+			CPOperand in5 = new CPOperand(parts[5]); // ema_var
+			CPOperand in6 = new CPOperand(parts[6]); // eps
+			CPOperand in7 = new CPOperand(parts[7]); // mu
+			CPOperand out = new CPOperand(parts[8]);  // out
+			CPOperand out2 = new CPOperand(parts[9]); // ema_mean_upd
+			CPOperand out3 = new CPOperand(parts[10]); // ema_var_upd
+			CPOperand out4 = new CPOperand(parts[11]); // cache_mean
+			CPOperand out5 = new CPOperand(parts[12]); // cache_inv_var
+			return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, in7, null, out, out2, out3, out4, out5, opcode, str, 0);
+		}
 		else {
 			throw new DMLRuntimeException("Unknown opcode while parsing a DnnGPUInstruction: " + str);	
 		}
 	}
 
-	public void processBiasInstruction(String instOpcode, ExecutionContext ec) {
+	private void processBiasInstruction(String instOpcode, ExecutionContext ec) {
 		GPUStatistics.incrementNoOfExecutedGPUInst();
 		MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName());
 		MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input2.getName());
@@ -372,7 +388,7 @@ public class DnnGPUInstruction extends GPUInstruction {
 		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
 	}
 	
-	public void processBatchNorm2dInstruction(ExecutionContext ec) throws DMLRuntimeException {
+	private void processBatchNorm2dInstruction(ExecutionContext ec) throws DMLRuntimeException {
 		GPUStatistics.incrementNoOfExecutedGPUInst();
 		MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
 		MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName());
@@ -420,7 +436,41 @@ public class DnnGPUInstruction extends GPUInstruction {
 		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
 	}
 	
-	public void processBatchNorm2dTestInstruction(ExecutionContext ec) throws DMLRuntimeException {
+	private void processBatchNorm2dTrainInstruction(ExecutionContext ec) throws DMLRuntimeException { // executes the fused batch_norm2d_train opcode
+		GPUStatistics.incrementNoOfExecutedGPUInst();
+		MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
+		MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName());
+		MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName());
+		MatrixObject runningMean = getMatrixInputForGPUInstruction(ec, _input4.getName());
+		MatrixObject runningVar = getMatrixInputForGPUInstruction(ec, _input5.getName());
+		// Scalar operands: _input6 is eps and _input7 is mu (the ema momentum), as laid out by the parse branch.
+		double epsilon = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getDoubleValue();
+		double exponentialAverageFactor = 1-ec.getScalarInput(_input7.getName(), _input7.getValueType(), _input7.isLiteral()).getDoubleValue(); // cuDNN's exponentialAverageFactor weights the current batch statistic, hence 1 - mu
+		// Outputs: normalized image, updated running mean/var, and the saved batch mean/inverse variance (sized like the running stats).
+		MatrixObject ret = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns());
+		MatrixObject retRunningMean = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), runningMean.getNumRows(), runningMean.getNumColumns());
+		MatrixObject retRunningVar = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), runningVar.getNumRows(), runningVar.getNumColumns());
+		MatrixObject resultSaveMean = getDenseMatrixOutputForGPUInstruction(ec, _output4.getName(), runningMean.getNumRows(), runningMean.getNumColumns());
+		MatrixObject resultSaveInvVariance = getDenseMatrixOutputForGPUInstruction(ec, _output5.getName(), runningVar.getNumRows(), runningVar.getNumColumns());
+		// A single fused LibMatrixCuDNN call writes all five outputs allocated above.
+		LibMatrixCuDNN.batchNormalizationForwardTraining(ec.getGPUContext(0), getExtendedOpcode(), 
+			image, scale, bias, runningMean, runningVar, ret, 
+			retRunningMean, retRunningVar, epsilon, exponentialAverageFactor, resultSaveMean, resultSaveInvVariance);
+		
+		// release inputs/outputs
+		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
+		ec.releaseMatrixInputForGPUInstruction(_input2.getName());
+		ec.releaseMatrixInputForGPUInstruction(_input3.getName());
+		ec.releaseMatrixInputForGPUInstruction(_input4.getName());
+		ec.releaseMatrixInputForGPUInstruction(_input5.getName());
+		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+		ec.releaseMatrixOutputForGPUInstruction(_output2.getName());
+		ec.releaseMatrixOutputForGPUInstruction(_output3.getName());
+		ec.releaseMatrixOutputForGPUInstruction(_output4.getName());
+		ec.releaseMatrixOutputForGPUInstruction(_output5.getName());
+	}
+	
+	private void processBatchNorm2dTestInstruction(ExecutionContext ec) throws DMLRuntimeException {
 		GPUStatistics.incrementNoOfExecutedGPUInst();
 		MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
 		MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName());
@@ -485,7 +535,7 @@ public class DnnGPUInstruction extends GPUInstruction {
 		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
 	}
 	
-	public void processChannelSumsInstruction(ExecutionContext ec) {
+	private void processChannelSumsInstruction(ExecutionContext ec) {
 		GPUStatistics.incrementNoOfExecutedGPUInst();
 		MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName());
 		int C = (int) ec.getScalarInput(_input2.getName(), _input2.getValueType(), _input2.isLiteral()).getLongValue();
@@ -667,6 +717,10 @@ public class DnnGPUInstruction extends GPUInstruction {
 			processBatchNorm2dTestInstruction(ec);
 			return;
 		}
+		else if (instOpcode.equalsIgnoreCase("batch_norm2d_train")) {
+			processBatchNorm2dTrainInstruction(ec);
+			return;
+		}
 		
 		GPUStatistics.incrementNoOfExecutedGPUInst();
 					

http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java b/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
index 83adad4..b8bb9b6 100644
--- a/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
+++ b/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
@@ -46,10 +46,10 @@ public class BatchNormTest extends GPUTests {
 		testBatchNormForward("test");
 	}
 	
-//	@Test
-//	public void testBatchNormForwardTrain() {
-//		testBatchNormForward("train");
-//	}
+	@Test
+	public void testBatchNormForwardTrain() {
+		testBatchNormForward("train");
+	}
 	
 	private void testBatchNormForward(String mode) {
 		int imgSize = 32; 
@@ -58,18 +58,29 @@ public class BatchNormTest extends GPUTests {
 		String scriptStr = "source(\"nn/layers/batch_norm2d_old.dml\") as batch_norm2d_old;\n "
 				+ "[output, ema_mean_upd, ema_var_upd, cache_mean, cache_var] = batch_norm2d_old::forward(x, gamma, beta, " + numChannels + ", " + imgSize + ", " + imgSize + ", \"" + mode + "\", ema_mean, ema_var, 0.9, 1e-3)";
 		HashMap<String, Object> inputs = new HashMap<>();
-		inputs.put("x", generateInputMatrix(spark, 32, numChannels*imgSize*imgSize, 0, 100, sparsity, seed));
-		inputs.put("gamma", generateInputMatrix(spark, numChannels, 1, 0, 10, sparsity, seed));
-		inputs.put("beta", generateInputMatrix(spark, numChannels, 1, 0, 10, sparsity, seed));
-		inputs.put("ema_mean", generateInputMatrix(spark, numChannels, 1, 40, 60, sparsity, seed));
-		inputs.put("ema_var", generateInputMatrix(spark, numChannels, 1, 5, 15, sparsity, seed));
+		inputs.put("x", generateInputMatrix(spark, 32, numChannels*imgSize*imgSize, 0, 10, sparsity, seed));
+		inputs.put("gamma", generateInputMatrix(spark, numChannels, 1, 0, 2, sparsity, seed));
+		inputs.put("beta", generateInputMatrix(spark, numChannels, 1, 0, 2, sparsity, seed));
+		inputs.put("ema_mean", generateInputMatrix(spark, numChannels, 1, 3, 7, sparsity, seed));
+		inputs.put("ema_var", generateInputMatrix(spark, numChannels, 1, 0, 2, sparsity, seed));
 		List<String> outputs = Arrays.asList("output", "ema_mean_upd", "ema_var_upd", "cache_mean", "cache_var");
 		List<Object> outCPU = runOnCPU(spark, scriptStr, inputs, outputs);
 		List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, outputs);
-		if(mode.equals("test"))
+		if(mode.equals("test")) {
 			assertHeavyHitterPresent("gpu_batch_norm2d_test");
-		for(int i = 0; i < outputs.size(); i++) {
-			assertEqualObjects(outCPU.get(i), outGPU.get(i));
+			for(int i = 0; i < outputs.size(); i++) {
+				assertEqualObjects(outCPU.get(i), outGPU.get(i));
+			}
+		}
+		else {
+			assertHeavyHitterPresent("gpu_batch_norm2d_train");
+			double [] threshold = new double[outputs.size()];
+			Arrays.fill(threshold, getTHRESHOLD());
+			// Handle loss of precision in CuDNN kernel 
+			threshold[2] = 1e-3;
+			for(int i = 0; i < outputs.size()-1; i++) {
+				assertEqualObjects(outCPU.get(i), outGPU.get(i), threshold[i]);
+			}
 		}
 	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
index e006fd2..cae2e33 100644
--- a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
+++ b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
@@ -212,7 +212,7 @@ public abstract class GPUTests extends AutomatedTestBase {
 		return in1;
 	}
 	
-	private void printMatrixIfNotEqual(MatrixBlock expectedMB, MatrixBlock actualMB) {
+	private void printMatrixIfNotEqual(MatrixBlock expectedMB, MatrixBlock actualMB, double threshold) {
 		long rows = expectedMB.getNumRows();
 		long cols = expectedMB.getNumColumns();
 		boolean matrixNotEqual = false;
@@ -222,7 +222,7 @@ public abstract class GPUTests extends AutomatedTestBase {
 				double actualDouble = actualMB.quickGetValue(i, j);
 				if (expectedDouble != 0.0 && !Double.isNaN(expectedDouble) && Double.isFinite(expectedDouble)) {
 					double relativeError = Math.abs((expectedDouble - actualDouble) / expectedDouble);
-					if(relativeError >= getTHRESHOLD()) {
+					if(relativeError >= threshold) {
 						matrixNotEqual = true;
 						break;
 					}
@@ -250,12 +250,13 @@ public abstract class GPUTests extends AutomatedTestBase {
 	 *
 	 * @param expected expected matrix
 	 * @param actual   actual matrix
+	 * @param threshold relative threshold
 	 */
-	private void assertEqualMatrices(Matrix expected, Matrix actual) {
+	private void assertEqualMatrices(Matrix expected, Matrix actual, double threshold) {
 		try {
 			// Faster way to compare two matrices
 			MLContext cpuMLC = new MLContext(spark);
-			String scriptStr = "num_mismatch = sum((abs(X - Y) / X) > " + getTHRESHOLD() + ");";
+			String scriptStr = "num_mismatch = sum((abs(X - Y) / X) > " + threshold + ");";
 			Script script = ScriptFactory.dmlFromString(scriptStr).in("X", expected).in("Y", actual).out("num_mismatch");
 			long num_mismatch = cpuMLC.execute(script).getLong("num_mismatch");
 			cpuMLC.close();
@@ -271,7 +272,7 @@ public abstract class GPUTests extends AutomatedTestBase {
 			Assert.assertEquals(rows, actualMB.getNumRows());
 			Assert.assertEquals(cols, actualMB.getNumColumns());
 
-			if(PRINT_MAT_ERROR) printMatrixIfNotEqual(expectedMB, actualMB);
+			if(PRINT_MAT_ERROR) printMatrixIfNotEqual(expectedMB, actualMB, threshold);
 			
 			for (int i = 0; i < rows; i++) {
 				for (int j = 0; j < cols; j++) {
@@ -285,12 +286,12 @@ public abstract class GPUTests extends AutomatedTestBase {
 								"Relative error(%f) is more than threshold (%f). Expected = %f, Actual = %f, differed at [%d, %d]",
-								relativeError, getTHRESHOLD(), expectedDouble, actualDouble, i, j);
+								relativeError, threshold, expectedDouble, actualDouble, i, j); // report the threshold actually applied below
 						if(FLOATING_POINT_PRECISION.equals("double"))
-							Assert.assertTrue(format.toString(), relativeError < getTHRESHOLD());
+							Assert.assertTrue(format.toString(), relativeError < threshold);
 						else
-							Assert.assertTrue(format.toString(), relativeError < getTHRESHOLD() || absoluteError < getTHRESHOLD());
+							Assert.assertTrue(format.toString(), relativeError < threshold || absoluteError < threshold);
 						format.close();
 					} else {
-						Assert.assertEquals(expectedDouble, actualDouble, getTHRESHOLD());
+						Assert.assertEquals(expectedDouble, actualDouble, threshold);
 					}
 				}
 			}
@@ -349,6 +350,7 @@ public abstract class GPUTests extends AutomatedTestBase {
 		// and other side effects.
 		synchronized(GPUTests.class) {
 			MLContext gpuMLC = new MLContext(spark);
+			// gpuMLC.setExplain(true); gpuMLC.setExplainLevel("recompile_runtime");
 			gpuMLC.setConfigProperty("sysml.floating.point.precision", FLOATING_POINT_PRECISION);
 			if(IGNORE_CLEAR_MEMORY_BUG)
 				gpuMLC.setConfigProperty("sysml.gpu.eager.cudaFree", "true");
@@ -366,7 +368,7 @@ public abstract class GPUTests extends AutomatedTestBase {
 			return outputs;
 		}
 	}
-
+	
 	/**
 	 * Assert that the two objects are equal. Supported types are Boolean, Integer, String, Double and Matrix
 	 *
@@ -374,6 +376,17 @@ public abstract class GPUTests extends AutomatedTestBase {
 	 * @param actual
 	 */
 	protected void assertEqualObjects(Object expected, Object actual) {
+		assertEqualObjects(expected, actual, getTHRESHOLD());
+	}
+
+	/**
+	 * Assert that the two objects are equal. Supported types are Boolean, Integer, String, Double and Matrix
+	 *
+	 * @param expected expected value
+	 * @param actual actual value 
+	 * @param threshold relative error threshold
+	 */
+	protected void assertEqualObjects(Object expected, Object actual, double threshold) {
 		Assert.assertEquals(expected.getClass(), actual.getClass());
 
 		if (expected instanceof Boolean) {
@@ -384,16 +397,16 @@ public abstract class GPUTests extends AutomatedTestBase {
 			if (expectedDouble != 0.0 && !Double.isNaN(expectedDouble) && Double.isFinite(expectedDouble)) {
 				double relativeError = Math.abs((expectedDouble - actualDouble) / expectedDouble);
 				Assert.assertTrue("Comparing floating point numbers, relative error(" + relativeError
-						+ ") is more than threshold (" + getTHRESHOLD() + ")", relativeError < getTHRESHOLD());
+						+ ") is more than threshold (" + threshold + ")", relativeError < threshold);
 			} else {
-				Assert.assertEquals(expectedDouble, actualDouble, getTHRESHOLD());
+				Assert.assertEquals(expectedDouble, actualDouble, threshold);
 			}
 		} else if (expected instanceof String) {
 			Assert.assertEquals(expected.toString(), actual.toString());
 		} else if (expected instanceof Integer) {
 			Assert.assertEquals(((Integer) expected).intValue(), ((Integer) actual).intValue());
 		} else if (expected instanceof Matrix)
-			assertEqualMatrices((Matrix) expected, (Matrix) actual);
+			assertEqualMatrices((Matrix) expected, (Matrix) actual, threshold);
 		else {
 			Assert.fail("Invalid types for comparison");
 		}