Posted to commits@systemml.apache.org by ni...@apache.org on 2018/08/30 22:43:05 UTC

[1/3] systemml git commit: [SYSTEMML-445] Removed batch_norm builtin functions

Repository: systemml
Updated Branches:
  refs/heads/master 81419ae6a -> 0f36780a8


http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
index e736a1c..d620de9 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
@@ -19,8 +19,8 @@
 package org.apache.sysml.runtime.instructions.gpu;
 
 import java.util.ArrayList;
-
 import jcuda.Pointer;
+import jcuda.jcudnn.JCudnn;
 
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
@@ -32,7 +32,6 @@ import org.apache.sysml.runtime.instructions.gpu.context.ExecutionConfig;
 import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
 import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
 import org.apache.sysml.runtime.matrix.data.LibMatrixCuDNN;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.PoolingType;
 import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysml.runtime.util.DnnUtils;
@@ -57,12 +56,14 @@ public class DnnGPUInstruction extends GPUInstruction {
 	private ArrayList<CPOperand> _stride = new ArrayList<>();
 	private ArrayList<CPOperand> _padding = new ArrayList<>();
 	private double _intermediateMemoryBudget = 0;
+	private GPUContext gCtx;
+	private String instName;
 	
 	public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, double intermediateMemoryBudget) {
 		super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
-		if (!(opcode.equals("bias_add") || opcode.equals("bias_multiply") || opcode.equals("relu_backward"))) {
+		if (!(opcode.equals("bias_add") || opcode.equals("bias_multiply") || opcode.equals("relu_backward") || opcode.equals("inv_var") )) {
 			throw new DMLRuntimeException(
-					"Incorrect usage. Expected the opcode to be bias_add or bias_multiply or relu_backward, but found "
+					"Incorrect usage. Expected the opcode to be bias_add or bias_multiply or relu_backward or inv_var, but found "
 							+ opcode);
 		}
 		_input1 = in1;
@@ -112,8 +113,8 @@ public class DnnGPUInstruction extends GPUInstruction {
 	public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr, 
 			double intermediateMemoryBudget) throws DMLRuntimeException {
 		super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
-		if( !opcode.equals("channel_sums") ) {
-			throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be channel_sums, but found " + opcode);
+		if( !(opcode.equals("channel_sums") || opcode.equals("reshape_colmeans") || opcode.equals("update_ema") ) ) {
+			throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be channel_sums or reshape_colmeans or update_ema, but found " + opcode);
 		}
 		_input1 = in1;
 		_input2 = in2;
@@ -126,7 +127,7 @@ public class DnnGPUInstruction extends GPUInstruction {
 	public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String istr, 
 			double intermediateMemoryBudget) throws DMLRuntimeException {
 		super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
-		if( !opcode.equals("update_nesterov_x") ) {
+		if( !( opcode.equals("update_nesterov_x")) ) {
 			throw new DMLRuntimeException("Incorrect opcode: " + opcode);
 		}
 		_input1 = in1;
@@ -182,6 +183,22 @@ public class DnnGPUInstruction extends GPUInstruction {
 		_intermediateMemoryBudget = intermediateMemoryBudget;
 	}
 	
+	public DnnGPUInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5, 
+			CPOperand out, String opcode, String istr, double intermediateMemoryBudget) {
+		super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
+		if( !(opcode.equals("update_ema_var") || opcode.equals("batch_norm2d_bwd_dx")) ) {
+			throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be update_ema_var or batch_norm2d_bwd_dx, but found " + opcode);
+		}
+		_input1 = in;
+		_input2 = in2;
+		_input3 = in3;
+		_input4 = in4;
+		_input5 = in5;
+		_gputype = GPUINSTRUCTION_TYPE.Dnn;
+		_output = out;
+		_intermediateMemoryBudget = intermediateMemoryBudget;
+	}
+	
 	public static DnnGPUInstruction parseInstruction(String str) {
 		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
 		String opcode = parts[0];
@@ -297,14 +314,15 @@ public class DnnGPUInstruction extends GPUInstruction {
 			return new DnnGPUInstruction(in1, null, out, opcode, str, stride,
 					padding, input_shape, filter_shape, Double.parseDouble(parts[15]));
 		}
-		else if( opcode.equalsIgnoreCase("bias_add") || opcode.equalsIgnoreCase("relu_backward") || opcode.equalsIgnoreCase("bias_multiply")  ) {
+		else if( opcode.equalsIgnoreCase("bias_add") || opcode.equalsIgnoreCase("relu_backward") || opcode.equalsIgnoreCase("bias_multiply") 
+				|| opcode.equalsIgnoreCase("inv_var") ) {
 			InstructionUtils.checkNumFields(parts, 4);
 			CPOperand in1 = new CPOperand(parts[1]);
 			CPOperand in2 = new CPOperand(parts[2]);
 			CPOperand out = new CPOperand(parts[3]);
 			return new DnnGPUInstruction(in1, in2, out, opcode, str, Double.parseDouble(parts[4]));
 		}
-		else if (opcode.equalsIgnoreCase("channel_sums")) {
+		else if (opcode.equalsIgnoreCase("channel_sums") || opcode.equals("reshape_colmeans") || opcode.equals("update_ema")) {
 			InstructionUtils.checkNumFields(parts, 4);
 			CPOperand in = new CPOperand(parts[1]);
 			CPOperand in2 = new CPOperand(parts[2]);
@@ -333,7 +351,7 @@ public class DnnGPUInstruction extends GPUInstruction {
 			CPOperand out2 = new CPOperand(parts[8]);
 			return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, out, out2, opcode, str, 0);
 		}
-		else if (opcode.equalsIgnoreCase("batch_norm2d") || opcode.equalsIgnoreCase("lstm_backward")) {
+		else if (opcode.equalsIgnoreCase("lstm_backward")) {
 			InstructionUtils.checkNumFields(parts, 13);
 			CPOperand in1 = new CPOperand(parts[1]); // image
 			CPOperand in2 = new CPOperand(parts[2]); // scale
@@ -350,19 +368,6 @@ public class DnnGPUInstruction extends GPUInstruction {
 			CPOperand out5 = new CPOperand(parts[13]); // resultSaveInvVariance
 			return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0);
 		}
-		else if (opcode.equalsIgnoreCase("batch_norm2d_backward")) {
-			InstructionUtils.checkNumFields(parts, 9);
-			CPOperand in1 = new CPOperand(parts[1]); // image
-			CPOperand in2 = new CPOperand(parts[2]); // dout
-			CPOperand in3 = new CPOperand(parts[3]); // scale
-			CPOperand in4 = new CPOperand(parts[4]); // epsilon
-			CPOperand in5 = new CPOperand(parts[5]); // resultSaveMean
-			CPOperand in6 = new CPOperand(parts[6]); // resultSaveInvVariance
-			CPOperand out = new CPOperand(parts[7]);  // dX
-			CPOperand out2 = new CPOperand(parts[8]); // dScale
-			CPOperand out3 = new CPOperand(parts[9]); // dBias
-			return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0);
-		}
 		else if (opcode.equalsIgnoreCase("batch_norm2d_test")) {
 			InstructionUtils.checkNumFields(parts, 7);
 			CPOperand in = new CPOperand(parts[1]);
@@ -374,21 +379,25 @@ public class DnnGPUInstruction extends GPUInstruction {
 			CPOperand out = new CPOperand(parts[7]);
 			return new DnnGPUInstruction(in, in2, in3, in4, in5, in6, out, opcode, str, 0);
 		}
-		else if (opcode.equalsIgnoreCase("batch_norm2d_train")) {
-			InstructionUtils.checkNumFields(parts, 12);
-			CPOperand in1 = new CPOperand(parts[1]); // image
-			CPOperand in2 = new CPOperand(parts[2]); // gamma
-			CPOperand in3 = new CPOperand(parts[3]); // beta
-			CPOperand in4 = new CPOperand(parts[4]); // ema_mean
-			CPOperand in5 = new CPOperand(parts[5]); // ema_var
-			CPOperand in6 = new CPOperand(parts[6]); // eps
-			CPOperand in7 = new CPOperand(parts[7]); // mu
-			CPOperand out = new CPOperand(parts[8]);  // out
-			CPOperand out2 = new CPOperand(parts[9]); // ema_mean_upd
-			CPOperand out3 = new CPOperand(parts[10]); // ema_var_upd
-			CPOperand out4 = new CPOperand(parts[11]); // cache_mean
-			CPOperand out5 = new CPOperand(parts[12]); // cache_inv_var
-			return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, in7, null, out, out2, out3, out4, out5, opcode, str, 0);
+		else if (opcode.equalsIgnoreCase("batch_norm2d_bwd_dx")) {
+			InstructionUtils.checkNumFields(parts, 6);
+			CPOperand in = new CPOperand(parts[1]);
+			CPOperand in2 = new CPOperand(parts[2]);
+			CPOperand in3 = new CPOperand(parts[3]);
+			CPOperand in4 = new CPOperand(parts[4]);
+			CPOperand in5 = new CPOperand(parts[5]);
+			CPOperand out = new CPOperand(parts[6]);
+			return new DnnGPUInstruction(in, in2, in3, in4, in5, out, opcode, str, 0);
+		}
+		else if (opcode.equalsIgnoreCase("update_ema_var")) {
+			InstructionUtils.checkNumFields(parts, 6);
+			CPOperand in = new CPOperand(parts[1]);
+			CPOperand in2 = new CPOperand(parts[2]);
+			CPOperand in3 = new CPOperand(parts[3]);
+			CPOperand in4 = new CPOperand(parts[4]);
+			CPOperand in5 = new CPOperand(parts[5]);
+			CPOperand out = new CPOperand(parts[6]);
+			return new DnnGPUInstruction(in, in2, in3, in4, in5, out, opcode, str, 0);
 		}
 		else {
 			throw new DMLRuntimeException("Unknown opcode while parsing a DnnGPUInstruction: " + str);	
@@ -396,211 +405,185 @@ public class DnnGPUInstruction extends GPUInstruction {
 	}
 
 	private void processBiasInstruction(String instOpcode, ExecutionContext ec) {
-		GPUStatistics.incrementNoOfExecutedGPUInst();
-		MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName());
-		MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input2.getName());
-		MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), input.getNumRows(), input.getNumColumns());
-		
-		if(instOpcode.equalsIgnoreCase("bias_add"))
-			LibMatrixCUDA.biasAdd(ec.getGPUContext(0), getExtendedOpcode(), input, bias, out);
-		else if(instOpcode.equalsIgnoreCase("bias_multiply"))
-			LibMatrixCUDA.biasMultiply(ec.getGPUContext(0), getExtendedOpcode(), input, bias, out);
-		// release inputs/outputs
-		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input2.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
-	}
-	
-	private void processBatchNorm2dInstruction(ExecutionContext ec) throws DMLRuntimeException {
-		GPUStatistics.incrementNoOfExecutedGPUInst();
-		MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
-		MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName());
-		MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName());
-		MatrixObject runningMean = getMatrixInputForGPUInstruction(ec, _input4.getName());
-		MatrixObject runningVar = getMatrixInputForGPUInstruction(ec, _input5.getName());
-		
-		String phase = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getStringValue();
-		double epsilon = ec.getScalarInput(_input7.getName(), _input7.getValueType(), _input7.isLiteral()).getDoubleValue();
-		
-		MatrixObject ret = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns());
-		
-		if(phase.equalsIgnoreCase("train")) {
-			double exponentialAverageFactor = 1-ec.getScalarInput(_input8.getName(), _input8.getValueType(), _input8.isLiteral()).getDoubleValue();
-			MatrixObject retRunningMean = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), runningMean.getNumRows(), runningMean.getNumColumns());
-			MatrixObject retRunningVar = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), runningVar.getNumRows(), runningVar.getNumColumns());
-			MatrixObject resultSaveMean = getDenseMatrixOutputForGPUInstruction(ec, _output4.getName(), runningMean.getNumRows(), runningMean.getNumColumns());
-			MatrixObject resultSaveInvVariance = getDenseMatrixOutputForGPUInstruction(ec, _output5.getName(), runningVar.getNumRows(), runningVar.getNumColumns());
-			LibMatrixCuDNN.batchNormalizationForwardTraining(ec.getGPUContext(0), getExtendedOpcode(), 
-				image, scale, bias, runningMean, runningVar, ret, 
-				retRunningMean, retRunningVar, epsilon, exponentialAverageFactor, resultSaveMean, resultSaveInvVariance);
-			ec.releaseMatrixOutputForGPUInstruction(_output2.getName());
-			ec.releaseMatrixOutputForGPUInstruction(_output3.getName());
-			ec.releaseMatrixOutputForGPUInstruction(_output4.getName());
-			ec.releaseMatrixOutputForGPUInstruction(_output5.getName());
-		}
-		else if(phase.equalsIgnoreCase("test")) {
-			LibMatrixCuDNN.batchNormalizationForwardInference(ec.getGPUContext(0), getExtendedOpcode(), 
-					image, scale, bias, runningMean, runningVar, ret, epsilon);
-			ec.setMatrixOutput(_output2.getName(), new MatrixBlock((int)runningMean.getNumRows(), (int)runningMean.getNumColumns(), true), getExtendedOpcode());
-			ec.setMatrixOutput(_output3.getName(), new MatrixBlock((int)runningVar.getNumRows(), (int)runningVar.getNumColumns(), true), getExtendedOpcode());
-			ec.setMatrixOutput(_output4.getName(), new MatrixBlock((int)runningMean.getNumRows(), (int)runningMean.getNumColumns(), true), getExtendedOpcode());
-			ec.setMatrixOutput(_output5.getName(), new MatrixBlock((int)runningVar.getNumRows(), (int)runningVar.getNumColumns(), true), getExtendedOpcode());
-		}
-		else {
-			throw new DMLRuntimeException("Incorrect mode: Expected either train or test, but found " + phase);
+		try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+			fetcher.add("input", _input1).add("bias", _input2);
+			
+			MatrixObject input = fetcher.getInputMatrixObject("input");
+			MatrixObject bias = fetcher.getInputMatrixObject("bias");
+			MatrixObject out = fetcher.getOutputMatrixObject(input.getNumRows(), input.getNumColumns());
+			
+			if(instOpcode.equalsIgnoreCase("bias_add"))
+				LibMatrixCUDA.biasAdd(gCtx, instName, input, bias, out);
+			else if(instOpcode.equalsIgnoreCase("bias_multiply"))
+				LibMatrixCUDA.biasMultiply(gCtx, instName, input, bias, out);
 		}
-		
-		// release inputs/outputs
-		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input2.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input3.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input4.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input5.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
 	}
 	
-	private void processBatchNorm2dTrainInstruction(ExecutionContext ec) throws DMLRuntimeException {
-		GPUStatistics.incrementNoOfExecutedGPUInst();
-		MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
-		MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName());
-		MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName());
-		MatrixObject runningMean = getMatrixInputForGPUInstruction(ec, _input4.getName());
-		MatrixObject runningVar = getMatrixInputForGPUInstruction(ec, _input5.getName());
-		
-		double epsilon = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getDoubleValue();
-		double exponentialAverageFactor = 1-ec.getScalarInput(_input7.getName(), _input7.getValueType(), _input7.isLiteral()).getDoubleValue();
-		
-		MatrixObject ret = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns());
-		MatrixObject retRunningMean = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), runningMean.getNumRows(), runningMean.getNumColumns());
-		MatrixObject retRunningVar = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), runningVar.getNumRows(), runningVar.getNumColumns());
-		MatrixObject resultSaveMean = getDenseMatrixOutputForGPUInstruction(ec, _output4.getName(), runningMean.getNumRows(), runningMean.getNumColumns());
-		MatrixObject resultSaveInvVariance = getDenseMatrixOutputForGPUInstruction(ec, _output5.getName(), runningVar.getNumRows(), runningVar.getNumColumns());
-		
-		LibMatrixCuDNN.batchNormalizationForwardTraining(ec.getGPUContext(0), getExtendedOpcode(), 
-			image, scale, bias, runningMean, runningVar, ret, 
-			retRunningMean, retRunningVar, epsilon, exponentialAverageFactor, resultSaveMean, resultSaveInvVariance);
-		
-		// release inputs/outputs
-		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input2.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input3.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input4.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input5.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output2.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output3.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output4.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output5.getName());
+	private void processInverseVarianceInstruction(String instOpcode, ExecutionContext ec) {
+		try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+			fetcher.add("X", _input1).addScalar("eps", _input2);
+			
+			int rows = LibMatrixCUDA.toInt(fetcher.getInputNumRows("X"));
+			int cols = LibMatrixCUDA.toInt(fetcher.getInputNumColumns("X"));
+			
+			// invVar(X, C, eps, size);
+			LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("invVar", 
+					ExecutionConfig.getConfigForSimpleVectorOperations(rows*cols),
+					fetcher.getInputPointer("X"), fetcher.getOutputPointer(rows, cols),
+					fetcher.getDouble("eps"), rows*cols);
+		}
 	}
 	
 	private void processBatchNorm2dTestInstruction(ExecutionContext ec) throws DMLRuntimeException {
-		GPUStatistics.incrementNoOfExecutedGPUInst();
-		MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
-		MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName());
-		MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName());
-		MatrixObject runningMean = getMatrixInputForGPUInstruction(ec, _input4.getName());
-		MatrixObject runningVar = getMatrixInputForGPUInstruction(ec, _input5.getName());
-		double epsilon = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getDoubleValue();
-		
-		MatrixObject ret = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns());
-		LibMatrixCuDNN.batchNormalizationForwardInference(ec.getGPUContext(0), getExtendedOpcode(), 
-				image, scale, bias, runningMean, runningVar, ret, epsilon);
-		
-		// release inputs/outputs
-		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input2.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input3.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input4.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input5.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+		try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+			fetcher.add("image", _input1).add("scale", _input2).add("bias", _input3)
+			.add("runningMean", _input4).add("runningVar", _input5).addScalar("epsilon", _input6);
+			
+			double epsilon = fetcher.getDouble("epsilon");
+			if(epsilon < JCudnn.CUDNN_BN_MIN_EPSILON) {
+				throw new DMLRuntimeException("The epsilon (" + epsilon + ") cannot be less than CUDNN_BN_MIN_EPSILON=(" + JCudnn.CUDNN_BN_MIN_EPSILON + ")");
+			}
+			
+			MatrixObject image = fetcher.getInputMatrixObject("image");
+			LibMatrixCuDNN.batchNormalizationForwardInference(gCtx, instName, 
+				image, fetcher.getInputMatrixObject("scale"), fetcher.getInputMatrixObject("bias"), 
+				fetcher.getInputMatrixObject("runningMean"), fetcher.getInputMatrixObject("runningVar"), 
+				fetcher.getOutputMatrixObject(image.getNumRows(), image.getNumColumns()), epsilon);
+		}
 	}
 	
-	public void processBatchNorm2dBackwardInstruction(ExecutionContext ec) throws DMLRuntimeException {
-		GPUStatistics.incrementNoOfExecutedGPUInst();
-		MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
-		MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName());
-		MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input3.getName());
-		double epsilon = ec.getScalarInput(_input4.getName(), _input4.getValueType(), _input4.isLiteral()).getDoubleValue();
-		MatrixObject resultSaveMean = getMatrixInputForGPUInstruction(ec, _input5.getName());
-		MatrixObject resultSaveInvVariance = getMatrixInputForGPUInstruction(ec, _input6.getName());
-		
-		MatrixObject dX = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns());
-		MatrixObject dScale = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), scale.getNumRows(), scale.getNumColumns());
-		MatrixObject dBias = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), scale.getNumRows(), scale.getNumColumns());
-		
-		LibMatrixCuDNN.batchNormalizationBackward(ec.getGPUContext(0), getExtendedOpcode(), image, 
-				dout, scale, dX, dScale, dBias,
-				epsilon, resultSaveMean, resultSaveInvVariance);
-		
-		// release inputs/outputs
-		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input2.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input3.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input5.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input6.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output2.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output3.getName());
+	private void processBatchNorm2dBackwardDxInstruction(ExecutionContext ec) throws DMLRuntimeException {
+		try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+			fetcher.add("X", _input1).add("dout", _input2).add("gamma", _input3)
+			.add("resultSaveMean", _input4).add("resultSaveInvVariance", _input5);
+			
+			// #define CUDNN_BN_MIN_EPSILON 1e-5 // Minimum epsilon allowed to be used in the Batch Normalization formula
+			double epsilon = 1e-4; 
+			MatrixObject image = fetcher.getInputMatrixObject("X");
+			LibMatrixCuDNN.batchNormalizationBackwardDX(gCtx, instName, image, 
+					fetcher.getInputMatrixObject("dout"), fetcher.getInputMatrixObject("gamma"), 
+					fetcher.getOutputMatrixObject(image.getNumRows(), image.getNumColumns()), epsilon, fetcher.getInputMatrixObject("resultSaveMean"), 
+					fetcher.getInputMatrixObject("resultSaveInvVariance")); 
+		}
 	}
-
+	
+	
 	// (X > 0) * dout
 	public void processReLUBackwardInstruction(ExecutionContext ec) {
-		GPUStatistics.incrementNoOfExecutedGPUInst();
-		MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName());
-		MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName());
-		
-		MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), input.getNumRows(), input.getNumColumns());
-		
-		LibMatrixCUDA.reluBackward(ec.getGPUContext(0), getExtendedOpcode(), input, dout, out);
-		// release inputs/outputs
-		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input2.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+		try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+			fetcher.add("X", _input1).add("dout", _input2);
+			MatrixObject X = fetcher.getInputMatrixObject("X");
+			LibMatrixCUDA.reluBackward(gCtx, instName, X, 
+					fetcher.getInputMatrixObject("dout"), fetcher.getOutputMatrixObject(X.getNumRows(), X.getNumColumns()));
+		}
 	}
 	
 	private void processChannelSumsInstruction(ExecutionContext ec) {
-		GPUStatistics.incrementNoOfExecutedGPUInst();
-		MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName());
-		int C = (int) ec.getScalarInput(_input2.getName(), _input2.getValueType(), _input2.isLiteral()).getLongValue();
-		int HW = (int) ec.getScalarInput(_input3.getName(), _input3.getValueType(), _input3.isLiteral()).getLongValue();
-		if(C*HW != input.getNumColumns()) {
-			throw new DMLRuntimeException("Expected rows*cols" + C + "*" + HW + " to be equal to number of columns of input " + input.getNumColumns());
+		try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+			fetcher.add("X", _input1).addScalar("C", _input2).addScalar("HW", _input3);
+			int C = fetcher.getInteger("C");
+			int HW = fetcher.getInteger("HW");
+			fetcher.validateDimensions("X", -1, C*HW);
+			LibMatrixCUDA.channelSums(gCtx, instName, 
+					fetcher.getInputMatrixObject("X"), 
+					fetcher.getOutputMatrixObject(C, 1), C, HW);
+		}
+	}
+	
+	private void processEMAInstruction(ExecutionContext ec) {
+		// "ema_mean", "mean", "mu"
+		try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+			fetcher.add("ema_mean", _input1).add("mean", _input2).addScalar("mu", _input3);
+			double mu = fetcher.getDouble("mu");
+			
+			int rows = LibMatrixCUDA.toInt(fetcher.getInputNumRows("ema_mean"));
+			int cols = LibMatrixCUDA.toInt(fetcher.getInputNumColumns("ema_mean"));
+			
+			fetcher.validateDimensions("mean", rows, cols);
+			
+			// aXplusbY(X, Y, C, a, b, size);
+			LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("aXplusbY", 
+					ExecutionConfig.getConfigForSimpleVectorOperations(rows*cols),
+					fetcher.getInputPointer("ema_mean"), fetcher.getInputPointer("mean"), 
+					fetcher.getOutputPointer(rows, cols),
+					mu, (1-mu), rows*cols);
+		}
+	}
+	
+	private void processReshapeColMeansInstruction(ExecutionContext ec) {
+		try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+			fetcher.add("X", _input1).addScalar("C", _input2).addScalar("HW", _input3);
+			int C = fetcher.getInteger("C");
+			int HW = fetcher.getInteger("HW");
+			fetcher.validateDimensions("X", -1, C*HW);
+			int rows = LibMatrixCUDA.toInt(fetcher.getInputNumRows("X"));
+			int cols = LibMatrixCUDA.toInt(fetcher.getInputNumColumns("X"));
+			// output = matrix(colMeans(X), rows=C, cols=Hin*Win)
+			LibMatrixCUDA.colMeans(gCtx, instName,  
+					fetcher.getInputPointer("X"), 
+					fetcher.getOutputPointer(C, HW), rows, cols);
 		}
-		MatrixObject outputBlock = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), C, 1);
-		
-		LibMatrixCUDA.channelSums(ec.getGPUContext(0), getExtendedOpcode(), input, outputBlock, C, HW);
-		
-		// release inputs/outputs
-		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
 	}
 	
+	private void processUpdateEMAVarInstruction(ExecutionContext ec) {
+		try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+			// "subgrp_means", "X", "C", "HW", "varConst1"
+			fetcher.add("subgrp_means", _input1).add("X", _input2).addScalar("C", _input3)
+				.addScalar("HW", _input4).addScalar("varConst1", _input5);
+			
+			// subgrp_vars = matrix(colVars(X) * varConst1, rows=C, cols=Hin*Win)
+			// var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))
+			// --->
+			// subgrp_vars = matrix(colVars(X), rows=C, cols=HW)
+			// var = rowMeans(subgrp_vars)*varConst1 + rowVars(subgrp_means)*((HW-1)/HW)  
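+			// Note: varConst1 is a scalar, so rowMeans(colVars(X) * varConst1) == rowMeans(colVars(X)) * varConst1;
+			// folding the scaling into the final aXplusbC launch below avoids scaling the full C x HW matrix.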
+			int C = fetcher.getInteger("C");
+			int HW = fetcher.getInteger("HW");
+			double varConst1 = fetcher.getDouble("varConst1");
+			fetcher.validateDimensions("subgrp_means", C, HW);
+			fetcher.validateDimensions("X", -1, C*HW);
+			
+			Pointer subgrp_vars = gCtx.allocate(instName, C*HW*LibMatrixCUDA.sizeOfDataType);
+			// subgrp_vars <- colVars(X)
+			LibMatrixCUDA.colVars(gCtx, instName, fetcher.getInputPointer("X"), subgrp_vars, 
+					LibMatrixCUDA.toInt(fetcher.getInputNumRows("X")), C*HW);
+			
+			// tmp1 <- rowMeans(subgrp_vars)
+			Pointer tmp1 = gCtx.allocate(instName, C*LibMatrixCUDA.sizeOfDataType);
+			LibMatrixCUDA.rowMeans(gCtx, instName, subgrp_vars, tmp1, C, HW);
+			gCtx.cudaFreeHelper(instName, subgrp_vars, gCtx.EAGER_CUDA_FREE);
+			
+			// out <- rowVars(subgrp_means)
+			Pointer out = fetcher.getOutputPointer(C, 1);
+			LibMatrixCUDA.rowVars(gCtx, instName, fetcher.getInputPointer("subgrp_means"), out, C, HW);
+			
+			// var = tmp1*varConst1 + out*((HW-1)/HW)
+			LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("aXplusbC", 
+					ExecutionConfig.getConfigForSimpleVectorOperations(C),
+					tmp1, out,
+					varConst1, (((double)HW-1)/HW), C);
+			gCtx.cudaFreeHelper(instName, tmp1, gCtx.EAGER_CUDA_FREE);
+		}
+	}
+	
+	
+	
 	private void processNesterovUpdateInstruction(ExecutionContext ec) {
-		GPUStatistics.incrementNoOfExecutedGPUInst();
-		MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName());
-		MatrixObject v = getMatrixInputForGPUInstruction(ec, _input2.getName());
-		MatrixObject v_prev = getMatrixInputForGPUInstruction(ec, _input3.getName());
-		double mu = (int) ec.getScalarInput(_input4.getName(), _input4.getValueType(), _input4.isLiteral()).getDoubleValue();
-		int rows = LibMatrixCUDA.toInt(input.getNumRows());
-		int cols = LibMatrixCUDA.toInt(input.getNumColumns());
-		MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), rows, cols);
-		
-		GPUContext gCtx = ec.getGPUContext(0);
-		String instName = getExtendedOpcode();
-		LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("update_nesterov_x", 
-				ExecutionConfig.getConfigForSimpleVectorOperations(LibMatrixCUDA.toInt(rows*cols)),
-				LibMatrixCUDA.getDensePointer(gCtx, input, instName), 
-				LibMatrixCUDA.getDensePointer(gCtx, v, instName),
-				LibMatrixCUDA.getDensePointer(gCtx, v_prev, instName),
-				mu, 
-				LibMatrixCUDA.getDensePointer(gCtx, out, instName),
-				rows*cols);
-		
-		// release inputs/outputs
-		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input2.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input3.getName());
-		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+		try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+			fetcher.add("input", _input1).add("v", _input2).add("v_prev", _input3)
+			.addScalar("mu", _input4);
+			MatrixObject input = fetcher.getInputMatrixObject("input");
+			int rows = LibMatrixCUDA.toInt(input.getNumRows());
+			int cols = LibMatrixCUDA.toInt(input.getNumColumns());
+			
+			LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("update_nesterov_x", 
+					ExecutionConfig.getConfigForSimpleVectorOperations(LibMatrixCUDA.toInt(rows*cols)),
+					fetcher.getInputPointer("input"), 
+					fetcher.getInputPointer("v"),
+					fetcher.getInputPointer("v_prev"),
+					fetcher.getDouble("mu"), 
+					fetcher.getOutputPointer(rows, cols),
+					rows*cols);
+		}
 	}
 	
 	private static int toInt(long num) throws DMLRuntimeException {
@@ -610,32 +593,18 @@ public class DnnGPUInstruction extends GPUInstruction {
 		return (int)num;
 	}
 	
-//	private Pointer transpose(ExecutionContext ec, MatrixObject X) throws DMLRuntimeException {
-//		GPUContext gCtx = ec.getGPUContext(0);
-//		String instructionName = getExtendedOpcode();
-//		long numRowsX = X.getNumRows(); long numColsX = X.getNumColumns();
-//		Pointer tX = gCtx.allocate(instructionName, numRowsX*numColsX*LibMatrixCUDA.sizeOfDataType);
-//		jcuda.runtime.JCuda.cudaMemcpy(tX, LibMatrixCUDA.getDensePointer(gCtx, X, instructionName), numRowsX*numColsX*LibMatrixCUDA.sizeOfDataType,  jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToDevice);
-//		// LibMatrixCUDA.denseTranspose(ec, gCtx, instructionName, LibMatrixCUDA.getDensePointer(gCtx, X, instructionName), tX, numRowsX, numColsX);
-//		return tX;
-//	}
-	
 	private void processLstmBackwardInstruction(ExecutionContext ec) throws DMLRuntimeException {
-		GPUStatistics.incrementNoOfExecutedGPUInst();
-		GPUContext gCtx = ec.getGPUContext(0);
-		String instructionName = getExtendedOpcode();
-		
 		MatrixObject out0 = getMatrixInputForGPUInstruction(ec, _input4.getName());
 		int M = toInt(out0.getNumColumns()); // hiddenSize .. since out0: (N, M)
-		Pointer out0Pointer =  LibMatrixCUDA.getDensePointer(gCtx, out0, instructionName);
+		Pointer out0Pointer =  LibMatrixCUDA.getDensePointer(gCtx, out0, instName);
 		
 		MatrixObject W = getMatrixInputForGPUInstruction(ec, _input2.getName());
 		MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName());
 		long numRowsW = W.getNumRows();
 		int D = toInt(numRowsW) - M; // since W:(D+M, 4M) ... numFeatures 
-		Pointer sysmlWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instructionName, D+M, 4*M);
-		Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instructionName, 1, 4*M);
-		Pointer cudnnWPointer = gCtx.allocate(instructionName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
+		Pointer sysmlWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instName, D+M, 4*M);
+		Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M);
+		Pointer cudnnWPointer = gCtx.allocate(instName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
 		LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight",
 				ExecutionConfig.getConfigForSimpleVectorOperations((D+M+2)*(4*M)),
 				sysmlWPointer, sysmlBiasPointer, cudnnWPointer, D, M);
@@ -644,20 +613,20 @@ public class DnnGPUInstruction extends GPUInstruction {
 		
 		
 		MatrixObject X = getMatrixInputForGPUInstruction(ec, _input1.getName());
-		Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instructionName); 
+		Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instName); 
 		int N = toInt(X.getNumRows()); // batchSize .. since X:(N, T*D)
 		long numColsX = X.getNumColumns();
 		int T = toInt(numColsX/ D); // since X:(N, T*D) ... seqLength
-		Pointer cudnnInput = gCtx.allocate(instructionName, (N*T*D)*LibMatrixCUDA.sizeOfDataType);
+		Pointer cudnnInput = gCtx.allocate(instName, (N*T*D)*LibMatrixCUDA.sizeOfDataType);
 		LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input",
 				ExecutionConfig.getConfigForSimpleVectorOperations(N*T*D),
 				xPointer, cudnnInput, N, D, T*D, N*T*D);
 		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
 		
-		Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instructionName);
+		Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instName);
 		boolean return_sequences = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getBooleanValue();
 		
-		// LibMatrixCuDNN.lstm(ec, gCtx, instructionName, 
+		// LibMatrixCuDNN.lstm(ec, gCtx, instName, 
 				// cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), _output2.getName(), N, M, D, T);
 				// String xName, Pointer hx, Pointer cx, Pointer wPointer, String doutName, String dcyName,  // input
 				// String dxName, String dwName, String dbName, String dhxName, String dcxName,  	// output
@@ -668,12 +637,12 @@ public class DnnGPUInstruction extends GPUInstruction {
 		String dcxName = _output5.getName();
 		String doutName = _input7.getName();
 		String dcyName = _input8.getName();
-		LibMatrixCuDNN.lstmBackward(ec, gCtx, instructionName, 
+		LibMatrixCuDNN.lstmBackward(ec, gCtx, instName, 
 				cudnnInput, out0Pointer, c0Pointer, cudnnWPointer, doutName, dcyName,  // input
 				dxName, dwName, dbName, dhxName, dcxName, // output 
 				return_sequences, N, M, D, T);
-		gCtx.cudaFreeHelper(instructionName, cudnnWPointer, gCtx.EAGER_CUDA_FREE);
-		gCtx.cudaFreeHelper(instructionName, cudnnInput, gCtx.EAGER_CUDA_FREE);
+		gCtx.cudaFreeHelper(instName, cudnnWPointer, gCtx.EAGER_CUDA_FREE);
+		gCtx.cudaFreeHelper(instName, cudnnInput, gCtx.EAGER_CUDA_FREE);
 		
 		// release inputs/outputs
 		ec.releaseMatrixInputForGPUInstruction(_input4.getName());
@@ -686,21 +655,17 @@ public class DnnGPUInstruction extends GPUInstruction {
 		// weight W:(D+M+2, 4M) 
 		// previous output out0 (also represented by hx) and cell state c0 (also represented by cx): (N, M) ==> (1, M, N)
 		// out: (N, T*M) or (N, M) ==> (T, M, N)
-		GPUStatistics.incrementNoOfExecutedGPUInst();
-		GPUContext gCtx = ec.getGPUContext(0);
-		String instructionName = getExtendedOpcode();
-		
 		MatrixObject out0 = getMatrixInputForGPUInstruction(ec, _input4.getName());
 		int M = toInt(out0.getNumColumns()); // hiddenSize .. since out0: (N, M)
-		Pointer out0Pointer =  LibMatrixCUDA.getDensePointer(gCtx, out0, instructionName);
+		Pointer out0Pointer =  LibMatrixCUDA.getDensePointer(gCtx, out0, instName);
 		
 		MatrixObject W = getMatrixInputForGPUInstruction(ec, _input2.getName());
 		MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName());
 		long numRowsW = W.getNumRows();
 		int D = toInt(numRowsW) - M; // since W:(D+M, 4M) ... numFeatures 
-		Pointer sysmlWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instructionName, D+M, 4*M);
-		Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instructionName, 1, 4*M);
-		Pointer cudnnWPointer = gCtx.allocate(instructionName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
+		Pointer sysmlWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instName, D+M, 4*M);
+		Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M);
+		Pointer cudnnWPointer = gCtx.allocate(instName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
 		LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight",
 				ExecutionConfig.getConfigForSimpleVectorOperations((D+M+2)*(4*M)),
 				sysmlWPointer, sysmlBiasPointer, cudnnWPointer, D, M);
@@ -711,21 +676,21 @@ public class DnnGPUInstruction extends GPUInstruction {
 		
 		// Because the matrices are released immediately, the output of the transpose need not be taken into account
 		MatrixObject X = getMatrixInputForGPUInstruction(ec, _input1.getName());
-		Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instructionName); 
+		Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instName); 
 		int N = toInt(X.getNumRows()); // batchSize .. since X:(N, T*D)
 		long numColsX = X.getNumColumns();
 		int T = toInt(numColsX/ D); // since X:(N, T*D) ... seqLength
-		Pointer cudnnInput = gCtx.allocate(instructionName, (N*T*D)*LibMatrixCUDA.sizeOfDataType);
+		Pointer cudnnInput = gCtx.allocate(instName, (N*T*D)*LibMatrixCUDA.sizeOfDataType);
 		LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input",
 				ExecutionConfig.getConfigForSimpleVectorOperations(N*T*D),
 				xPointer, cudnnInput, N, D, T*D, N*T*D);
 		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
 		
-		Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instructionName); 
+		Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instName); 
 		
-		LibMatrixCuDNN.lstm(ec, gCtx, instructionName, cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), _output2.getName(), N, M, D, T);
-		gCtx.cudaFreeHelper(instructionName, cudnnWPointer, gCtx.EAGER_CUDA_FREE);
-		gCtx.cudaFreeHelper(instructionName, cudnnInput, gCtx.EAGER_CUDA_FREE);
+		LibMatrixCuDNN.lstm(ec, gCtx, instName, cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), _output2.getName(), N, M, D, T);
+		gCtx.cudaFreeHelper(instName, cudnnWPointer, gCtx.EAGER_CUDA_FREE);
+		gCtx.cudaFreeHelper(instName, cudnnInput, gCtx.EAGER_CUDA_FREE);
 		
 		// release inputs/outputs
 		ec.releaseMatrixInputForGPUInstruction(_input4.getName());
@@ -736,10 +701,17 @@ public class DnnGPUInstruction extends GPUInstruction {
 	
 	@Override
 	public void processInstruction(ExecutionContext ec) {
+		GPUStatistics.incrementNoOfExecutedGPUInst();
+		gCtx = ec.getGPUContext(0);
+		instName = getExtendedOpcode();
 		if (instOpcode.equalsIgnoreCase("bias_add") || instOpcode.equalsIgnoreCase("bias_multiply")) {
 			processBiasInstruction(instOpcode, ec);
 			return;
 		}
+		else if (instOpcode.equalsIgnoreCase("inv_var")) {
+			processInverseVarianceInstruction(instOpcode, ec);
+			return;
+		}
 		else if (instOpcode.equalsIgnoreCase("relu_backward")) {
 			processReLUBackwardInstruction(ec);
 			return;
@@ -748,10 +720,22 @@ public class DnnGPUInstruction extends GPUInstruction {
 			processChannelSumsInstruction(ec);
 			return;
 		}
+		else if (instOpcode.equalsIgnoreCase("update_ema")) {
+			processEMAInstruction(ec);
+			return;
+		}
+		else if (instOpcode.equalsIgnoreCase("reshape_colmeans")) {
+			processReshapeColMeansInstruction(ec);
+			return;
+		}
 		else if (instOpcode.equalsIgnoreCase("update_nesterov_x")) {
 			processNesterovUpdateInstruction(ec);
 			return;
 		}
+		else if (instOpcode.equalsIgnoreCase("update_ema_var")) {
+			processUpdateEMAVarInstruction(ec);
+			return;
+		}
 		else if (instOpcode.equalsIgnoreCase("lstm")) {
 			processLstmInstruction(ec);
 			return;
@@ -760,24 +744,14 @@ public class DnnGPUInstruction extends GPUInstruction {
 			processLstmBackwardInstruction(ec);
 			return;
 		}
-		else if (instOpcode.equalsIgnoreCase("batch_norm2d")) {
-			processBatchNorm2dInstruction(ec);
-			return;
-		}
-		else if (instOpcode.equalsIgnoreCase("batch_norm2d_backward")) {
-			processBatchNorm2dBackwardInstruction(ec);
-			return;
-		}
 		else if (instOpcode.equalsIgnoreCase("batch_norm2d_test")) {
 			processBatchNorm2dTestInstruction(ec);
 			return;
 		}
-		else if (instOpcode.equalsIgnoreCase("batch_norm2d_train")) {
-			processBatchNorm2dTrainInstruction(ec);
+		else if (instOpcode.equalsIgnoreCase("batch_norm2d_bwd_dx")) {
+			processBatchNorm2dBackwardDxInstruction(ec);
 			return;
 		}
-		
-		GPUStatistics.incrementNoOfExecutedGPUInst();
 					
 		int pad_h = getScalarInput(ec, _padding, 0);
 		int pad_w = getScalarInput(ec, _padding, 1);

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
new file mode 100644
index 0000000..8fcaec3
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.runtime.instructions.gpu;
+
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
+import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
+import org.apache.sysml.runtime.matrix.data.Pair;
+import org.apache.sysml.utils.GPUStatistics;
+
+import jcuda.Pointer;
+
+public class GPUDenseInputPointerFetcher implements java.lang.AutoCloseable {
+	ExecutionContext _ec; GPUContext _gCtx; String _instName;
+	HashMap<String, CPOperand> _inputMatrices = new HashMap<>();
+	HashMap<String, MatrixObject> _inputMatrixObjects = new HashMap<>();
+	HashMap<String, CPOperand> _inputScalars = new HashMap<>();
+	CPOperand _output;
+	public GPUDenseInputPointerFetcher(ExecutionContext ec, GPUContext gCtx, String instName, CPOperand output) {
+		_ec = ec;
+		_gCtx = gCtx;
+		_instName = instName;
+		_output = output;
+	}
+	public GPUDenseInputPointerFetcher add(String var, CPOperand in) {
+		_inputMatrices.put(var, in);
+		return this;
+	}
+	public GPUDenseInputPointerFetcher addScalar(String var, CPOperand in) {
+		_inputScalars.put(var, in);
+		return this;
+	}
+	public double getDouble(String var) {
+		CPOperand in = _inputScalars.get(var);
+		return _ec.getScalarInput(in.getName(), in.getValueType(), in.isLiteral()).getDoubleValue();
+	}
+	public long getLong(String var) {
+		CPOperand in = _inputScalars.get(var);
+		return _ec.getScalarInput(in.getName(), in.getValueType(), in.isLiteral()).getLongValue();
+	}
+	public int getInteger(String var) {
+		CPOperand in = _inputScalars.get(var);
+		return LibMatrixCUDA.toInt(_ec.getScalarInput(in.getName(), in.getValueType(), in.isLiteral()).getLongValue());
+	}
+	public Pointer getInputPointer(String var) {
+		return LibMatrixCUDA.getDensePointer(_gCtx, getInputMatrixObject(var), _instName);
+	}
+	public long getInputNumRows(String var) {
+		return getInputMatrixObject(var).getNumRows();
+	}
+	public long getInputNumColumns(String var) {
+		return getInputMatrixObject(var).getNumColumns();
+	}
+	public MatrixObject getOutputMatrixObject(long numRows, long numCols) {
+		boolean isFinegrainedStats = ConfigurationManager.isFinegrainedStatistics();
+		long t0 = isFinegrainedStats ? System.nanoTime() : 0;
+		Pair<MatrixObject, Boolean> mb = _ec.getDenseMatrixOutputForGPUInstruction(_output.getName(), numRows, numCols);
+		if (isFinegrainedStats && mb.getValue()) GPUStatistics.maintainCPMiscTimes(_instName,
+				GPUInstruction.MISC_TIMER_ALLOCATE_DENSE_OUTPUT, System.nanoTime() - t0);
+		return  mb.getKey();
+	}
+	public Pointer getOutputPointer(long numRows, long numCols) {
+		return LibMatrixCUDA.getDensePointer(_gCtx, getOutputMatrixObject(numRows, numCols), _instName);
+	}
+	public MatrixObject getInputMatrixObject(String var) {
+		CPOperand in = _inputMatrices.get(var);
+		if(!_inputMatrixObjects.containsKey(var)) {
+			_inputMatrixObjects.put(var, _ec.getMatrixInputForGPUInstruction(in.getName(), _instName));
+		}
+		return _inputMatrixObjects.get(var);
+	}
+	public void validateDimensions(String var, long numRows, long numCols) {
+		MatrixObject mo = getInputMatrixObject(var);
+		if(numRows > 0 && mo.getNumRows() != numRows) {
+			throw new DMLRuntimeException("Expected number of rows of " + var + " to be " + numRows + ", but found " + mo.getNumRows());
+		}
+		else if(numCols > 0 && mo.getNumColumns() != numCols) {
+			throw new DMLRuntimeException("Expected number of columns of " + var + " to be " + numCols + ", but found " + mo.getNumColumns());
+		}
+	}
+	@Override
+	public void close() {
+		for(CPOperand in : _inputMatrices.values()) {
+			_ec.releaseMatrixInputForGPUInstruction(in.getName());
+		}
+		_ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+	}
+}
\ No newline at end of file
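
For reference, the new GPUDenseInputPointerFetcher is intended to be used in a try-with-resources
block so that all registered inputs and the output are released automatically on close(). A minimal
usage sketch, mirroring the channel_sums handler in DnnGPUInstruction above (the operand fields and
the LibMatrixCUDA.channelSums call are taken from this patch; the sketch is illustrative only):

    try (GPUDenseInputPointerFetcher fetcher =
            new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
        fetcher.add("X", _input1).addScalar("C", _input2).addScalar("HW", _input3);
        int C = fetcher.getInteger("C");
        int HW = fetcher.getInteger("HW");
        // number of columns of X must equal C*HW; -1 skips the row check
        fetcher.validateDimensions("X", -1, C*HW);
        LibMatrixCUDA.channelSums(gCtx, instName,
                fetcher.getInputMatrixObject("X"),
                fetcher.getOutputMatrixObject(C, 1), C, HW);
    } // close() releases the registered inputs and the output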

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
index 2e43b99..e01c71a 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
@@ -242,7 +242,7 @@ public class GPUMemoryManager {
 	 * @return allocated pointer
 	 */
 	public Pointer malloc(String opcode, long size) {
-		if(size < 0) {
+		if(size <= 0) {
 			throw new DMLRuntimeException("Cannot allocate memory of size " + byteCountToDisplaySize(size));
 		}
 		if(DEBUG_MEMORY_LEAK) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index d02a875..f3f8434 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -208,7 +208,7 @@ public class LibMatrixCUDA {
 		return gCtx.getCusparseHandle();
 	}
 
-	protected static cublasHandle getCublasHandle(GPUContext gCtx) {
+	public static cublasHandle getCublasHandle(GPUContext gCtx) {
 		return gCtx.getCublasHandle();
 	}
 
@@ -302,7 +302,7 @@ public class LibMatrixCUDA {
 		}
 	}
 	
-	protected static Pointer dataTypePointerTo(double value) {
+	public static Pointer dataTypePointerTo(double value) {
 		if(value == 1) {
 			return one();
 		}
@@ -313,7 +313,6 @@ public class LibMatrixCUDA {
 			return _dataTypePointerTo(value);
 		}
 	}
-	
 
 	/**
 	 * This method computes the backpropagation errors for previous layer of relu operation
@@ -753,11 +752,11 @@ public class LibMatrixCUDA {
 				break;
 			}
 			case REDUCTION_COL: {
-				reduceRow(gCtx, instName, "reduce_row_mean", in, out, rlen, clen);
+				rowMeans(gCtx, instName, in, out, rlen, clen);
 				break;
 			}
 			case REDUCTION_ROW: {
-				reduceCol(gCtx, instName, "reduce_col_mean", in, out, rlen, clen);
+				colMeans(gCtx, instName, in, out, rlen, clen);
 				break;
 			}
 			default:
@@ -818,13 +817,14 @@ public class LibMatrixCUDA {
 			break;
 		}
 		case OP_VARIANCE : {
-			// Temporary GPU array for
-			Pointer tmp = gCtx.allocate(instName, size * sizeOfDataType);
-			Pointer tmp2 = gCtx.allocate(instName, size * sizeOfDataType);
+			
 
 			switch(reductionDirection) {
 
 			case REDUCTION_ALL: {
+				// Temporary GPU arrays for intermediate values of the variance computation
+				Pointer tmp = gCtx.allocate(instName, size * sizeOfDataType);
+				Pointer tmp2 = gCtx.allocate(instName, size * sizeOfDataType);
 				double result = reduceAll(gCtx, instName, "reduce_sum", in, size);
 				double mean = result / size;
 
@@ -837,50 +837,21 @@ public class LibMatrixCUDA {
 				double result2 = reduceAll(gCtx, instName, "reduce_sum", tmp2, size);
 				double variance = result2 / (size - 1);
 				ec.setScalarOutput(output, new DoubleObject(variance));
-
+				gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE);
+				gCtx.cudaFreeHelper(instName, tmp2, gCtx.EAGER_CUDA_FREE);
 				break;
 			}
 			case REDUCTION_COL: {
-				reduceRow(gCtx, instName, "reduce_row_mean", in, out, rlen, clen);
-				// Subtract the row-wise mean from every element in the matrix
-				BinaryOperator minusOp = new BinaryOperator(Minus.getMinusFnObject());
-				matrixMatrixOp(gCtx, instName, in, out, rlen, clen, VectorShape.NONE.code(), VectorShape.COLUMN.code(), tmp, minusOp);
-
-				squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen);
-
-				Pointer tmpRow = gCtx.allocate(instName, rlen * sizeOfDataType);
-				reduceRow(gCtx, instName, "reduce_row_sum", tmp2, tmpRow, rlen, clen);
-
-				ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), clen - 1);
-				matrixScalarOp(gCtx, instName, tmpRow, clen - 1, rlen, 1, out, divideOp);
-
-				gCtx.cudaFreeHelper(instName, tmpRow, gCtx.EAGER_CUDA_FREE);
-
+				rowVars(gCtx, instName, in, out, rlen, clen);
 				break;
 			}
 			case REDUCTION_ROW: {
-				reduceCol(gCtx, instName, "reduce_col_mean", in, out, rlen, clen);
-				// Subtract the columns-wise mean from every element in the matrix
-				BinaryOperator minusOp = new BinaryOperator(Minus.getMinusFnObject());
-				matrixMatrixOp(gCtx, instName, in, out, rlen, clen, VectorShape.NONE.code(), VectorShape.ROW.code(), tmp, minusOp);
-
-				squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen);
-
-				Pointer tmpCol = gCtx.allocate(instName, clen * sizeOfDataType);
-				reduceCol(gCtx, instName, "reduce_col_sum", tmp2, tmpCol, rlen, clen);
-
-				ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), rlen - 1);
-				matrixScalarOp(gCtx, instName, tmpCol, rlen - 1, 1, clen, out, divideOp);
-
-				gCtx.cudaFreeHelper(instName, tmpCol, gCtx.EAGER_CUDA_FREE);
-
+				colVars(gCtx, instName, in, out, rlen, clen);
 				break;
 			}
 			default:
 				throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for variance");
 			}
-			gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE);
-			gCtx.cudaFreeHelper(instName, tmp2, gCtx.EAGER_CUDA_FREE);
 			break;
 		}
 		case OP_MAXINDEX : {
@@ -904,6 +875,59 @@ public class LibMatrixCUDA {
 		default : throw new DMLRuntimeException("Internal Error - Invalid GPU Unary aggregate function!");
 		}
 	}
+	
+	public static void rowMeans(GPUContext gCtx, String instName, Pointer in, Pointer out, int rlen, int clen) {
+		LibMatrixCUDA.reduceRow(gCtx, instName, "reduce_row_mean", in, out, rlen, clen);
+	}
+	
+	public static void colMeans(GPUContext gCtx, String instName, Pointer in, Pointer out, int rlen, int clen) {
+		reduceCol(gCtx, instName, "reduce_col_mean", in, out, rlen, clen);
+	}
+	
+	public static void colVars(GPUContext gCtx, String instName, Pointer in, Pointer out, int rlen, int clen) {
+		int size = rlen * clen;
+		Pointer tmp = gCtx.allocate(instName, size * sizeOfDataType);
+		Pointer tmp2 = gCtx.allocate(instName, size * sizeOfDataType);
+		reduceCol(gCtx, instName, "reduce_col_mean", in, out, rlen, clen);
+		// Subtract the columns-wise mean from every element in the matrix
+		BinaryOperator minusOp = new BinaryOperator(Minus.getMinusFnObject());
+		matrixMatrixOp(gCtx, instName, in, out, rlen, clen, VectorShape.NONE.code(), VectorShape.ROW.code(), tmp, minusOp);
+
+		squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen);
+
+		Pointer tmpCol = gCtx.allocate(instName, clen * sizeOfDataType);
+		reduceCol(gCtx, instName, "reduce_col_sum", tmp2, tmpCol, rlen, clen);
+
+		ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), rlen - 1);
+		matrixScalarOp(gCtx, instName, tmpCol, rlen - 1, 1, clen, out, divideOp);
+
+		gCtx.cudaFreeHelper(instName, tmpCol, gCtx.EAGER_CUDA_FREE);
+		gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE);
+		gCtx.cudaFreeHelper(instName, tmp2, gCtx.EAGER_CUDA_FREE);
+	}
+	
+	public static void rowVars(GPUContext gCtx, String instName, Pointer in, Pointer out, int rlen, int clen) {
+		int size = rlen * clen;
+		Pointer tmp = gCtx.allocate(instName, size * sizeOfDataType);
+		Pointer tmp2 = gCtx.allocate(instName, size * sizeOfDataType);
+		
+		reduceRow(gCtx, instName, "reduce_row_mean", in, out, rlen, clen);
+		// Subtract the row-wise mean from every element in the matrix
+		BinaryOperator minusOp = new BinaryOperator(Minus.getMinusFnObject());
+		matrixMatrixOp(gCtx, instName, in, out, rlen, clen, VectorShape.NONE.code(), VectorShape.COLUMN.code(), tmp, minusOp);
+
+		squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen);
+
+		Pointer tmpRow = gCtx.allocate(instName, rlen * sizeOfDataType);
+		reduceRow(gCtx, instName, "reduce_row_sum", tmp2, tmpRow, rlen, clen);
+
+		ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), clen - 1);
+		matrixScalarOp(gCtx, instName, tmpRow, clen - 1, rlen, 1, out, divideOp);
+
+		gCtx.cudaFreeHelper(instName, tmpRow, gCtx.EAGER_CUDA_FREE);
+		gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE);
+		gCtx.cudaFreeHelper(instName, tmp2, gCtx.EAGER_CUDA_FREE);
+	}
 
 	/**
 	 * Helper method to square a matrix in GPU memory
@@ -970,7 +994,7 @@ public class LibMatrixCUDA {
 	 * @param rows						number of rows in input matrix
 	 * @param cols						number of columns in input matrix
 	 */
-	private static void reduceRow(GPUContext gCtx, String instName, String kernelFunction, Pointer in, Pointer out, int rows, int cols) {
+	public static void reduceRow(GPUContext gCtx, String instName, String kernelFunction, Pointer in, Pointer out, int rows, int cols) {
 		if(LOG.isTraceEnabled()) {
 			LOG.trace("GPU : reduceRow for " + kernelFunction + ", GPUContext=" + gCtx);
 		}
@@ -997,7 +1021,7 @@ public class LibMatrixCUDA {
 	 * @param rows						number of rows in input matrix
 	 * @param cols						number of columns in input matrix
 	 */
-	private static void reduceCol(GPUContext gCtx, String instName, String kernelFunction, Pointer in, Pointer out, int rows, int cols) {
+	public static void reduceCol(GPUContext gCtx, String instName, String kernelFunction, Pointer in, Pointer out, int rows, int cols) {
 		if(LOG.isTraceEnabled()) {
 			LOG.trace("GPU : reduceCol for " + kernelFunction + ", GPUContext=" + gCtx);
 		}
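
For reference, the kernel sequence in the new colVars helper above (column means, broadcast subtraction, element-wise square, column sums, division by rlen - 1) corresponds to the corrected sample variance of each column j:

  mean_j = (1/rlen) * sum_{i=1..rlen} X[i,j]
  var_j  = (1/(rlen-1)) * sum_{i=1..rlen} (X[i,j] - mean_j)^2

rowVars is the transposed analogue, dividing by clen - 1.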

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
index d3b5984..e7955e1 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
@@ -1108,23 +1108,29 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 				runningMeanPtr, runningVarPtr, epsilon));
 	}
 	
+	private static void validateDimensions(MatrixObject mo, long expectedRows, long expectedCols) {
+		if(mo.getNumRows() != expectedRows || mo.getNumColumns() != expectedCols) {
+			throw new DMLRuntimeException("Incorrect dimensions for the input matrix object. Expected [" + expectedRows + ", " + expectedCols + "], but found "
+					+ "[" + mo.getNumRows() + ", " + mo.getNumColumns() + "].");
+		}
+	}
+	
 	/**
-	 * This method computes the backpropagation errors for image, scale and bias of batch normalization layer
+	 * This method computes the backpropagation error for the input image of the batch normalization layer
 	 * @param gCtx   a valid {@link GPUContext}
 	 * @param instName name of the instruction
 	 * @param image input image
 	 * @param dout input errors of shape C, H, W
 	 * @param scale scale (as per CuDNN) and gamma as per original paper: shape [1, C, 1, 1]
 	 * @param dX (output) backpropagation errors for previous layer
-	 * @param dScale backpropagation error for scale
-	 * @param dBias backpropagation error for bias
 	 * @param epsilon epsilon value used in the batch normalization formula
 	 * @param resultSaveMean (input) running mean accumulated during training phase: shape [1, C, 1, 1]
 	 * @param resultSaveInvVariance (input) running variance accumulated during training phase: shape [1, C, 1, 1]
 	 * @throws DMLRuntimeException if error occurs
 	 */
-	public static void batchNormalizationBackward(GPUContext gCtx, String instName, MatrixObject image, MatrixObject dout,
-			MatrixObject scale, MatrixObject dX, MatrixObject dScale, MatrixObject dBias,
+	public static void batchNormalizationBackwardDX(GPUContext gCtx, String instName, MatrixObject image, MatrixObject dout,
+			MatrixObject scale, MatrixObject dX, 
+			// MatrixObject dScale, MatrixObject dBias,
 			double epsilon, MatrixObject resultSaveMean, MatrixObject resultSaveInvVariance) throws DMLRuntimeException {
 		if(LOG.isTraceEnabled()) {
 			LOG.trace("GPU : batchNormalizationBackward" + ", GPUContext=" + gCtx);
@@ -1133,7 +1139,13 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 		int N = toInt(image.getNumRows());
 		int C = toInt(scale.getNumRows());
 		long CHW = image.getNumColumns();
-
+		
+		validateDimensions(scale, C, 1);
+		validateDimensions(dX, N, CHW);
+		validateDimensions(dout, N, CHW);
+		validateDimensions(resultSaveMean, C, 1);
+		validateDimensions(resultSaveInvVariance, C, 1);
+		
 		// Allocate descriptors
 		cudnnTensorDescriptor nCHWDescriptor = allocateNCHWDescriptors(gCtx, N, C, CHW,
 				new MatrixObject[] {image, dout},  new MatrixObject[] {dX});
@@ -1144,18 +1156,17 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 		Pointer doutPtr = getDensePointerForCuDNN(gCtx, dout, instName);
 		Pointer scalePtr = getDensePointerForCuDNN(gCtx, scale, instName);
 		Pointer dXPtr = getDensePointerForCuDNN(gCtx, dX, instName);
-		Pointer dScalePtr = getDensePointerForCuDNN(gCtx, dScale, instName);
-		Pointer dBiasPtr = getDensePointerForCuDNN(gCtx, dBias, instName);
-		
+		Pointer dScalePtr = gCtx.allocate(instName, C*LibMatrixCUDA.sizeOfDataType); // getDensePointerForCuDNN(gCtx, dScale, instName);
+		Pointer dBiasPtr = gCtx.allocate(instName, C*LibMatrixCUDA.sizeOfDataType); //getDensePointerForCuDNN(gCtx, dBias, instName);		
 		Pointer resultSaveMeanPtr = getDensePointerForCuDNN(gCtx, resultSaveMean, instName);
 		Pointer resultSaveInvVariancePtr = getDensePointerForCuDNN(gCtx, resultSaveInvVariance, instName);
 
-
-		// ignoring resultSaveMean and resultSaveVariance as it requires state management
-		checkStatus(cudnnBatchNormalizationBackward(getCudnnHandle(gCtx), 
+		cudnnBatchNormalizationBackward(getCudnnHandle(gCtx), 
 				jcuda.jcudnn.cudnnBatchNormMode.CUDNN_BATCHNORM_SPATIAL,  one(), zero(), one(), zero(),
 				nCHWDescriptor,  imagePtr, nCHWDescriptor, doutPtr, nCHWDescriptor, dXPtr,
-				scaleTensorDesc, scalePtr, dScalePtr, dBiasPtr, epsilon, resultSaveMeanPtr, resultSaveInvVariancePtr));
+				scaleTensorDesc, scalePtr, dScalePtr, dBiasPtr, epsilon, resultSaveMeanPtr, resultSaveInvVariancePtr);
+		gCtx.cudaFreeHelper(instName, dScalePtr, gCtx.EAGER_CUDA_FREE);
+		gCtx.cudaFreeHelper(instName, dBiasPtr, gCtx.EAGER_CUDA_FREE);
 	}
 	
 	private static void validateBatchNormalizationDimensions(MatrixObject scale, MatrixObject bias, MatrixObject runningMean, MatrixObject runningVar, int C) throws DMLRuntimeException {
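
A note on the split above: cudnnBatchNormalizationBackward computes the scale and bias gradients together with dX, so batchNormalizationBackwardDX allocates throw-away buffers for dScalePtr and dBiasPtr and frees them immediately after the call. The scale and shift gradients are instead computed in DML (see batch_norm2d.dml further below), essentially as dgamma = util::channel_sums(dout*norm, C, Hin, Win) and dbeta = util::channel_sums(dout, C, Hin, Win), each of shape (C, 1).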

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java b/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
index d96feac..d7e7b24 100644
--- a/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
+++ b/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
@@ -55,8 +55,8 @@ public class BatchNormTest extends GPUTests {
 		int imgSize = 32; 
 		int numChannels = 3;
 		double sparsity = 0.9;
-		String scriptStr = "source(\"nn/layers/batch_norm2d_old.dml\") as batch_norm2d_old;\n "
-				+ "[output, ema_mean_upd, ema_var_upd, cache_mean, cache_var] = batch_norm2d_old::forward(x, gamma, beta, " + numChannels + ", " + imgSize + ", " + imgSize + ", \"" + mode + "\", ema_mean, ema_var, 0.9, 1e-3)";
+		String scriptStr = "source(\"nn/layers/batch_norm2d.dml\") as batch_norm2d;\n "
+				+ "[output, ema_mean_upd, ema_var_upd, cache_mean, cache_var] = batch_norm2d::forward(x, gamma, beta, " + numChannels + ", " + imgSize + ", " + imgSize + ", \"" + mode + "\", ema_mean, ema_var, 0.9, 1e-3)";
 		HashMap<String, Object> inputs = new HashMap<>();
 		inputs.put("x", generateInputMatrix(spark, 32, numChannels*imgSize*imgSize, 0, 10, sparsity, seed));
 		inputs.put("gamma", generateInputMatrix(spark, numChannels, 1, 0, 2, sparsity, seed));
@@ -68,19 +68,40 @@ public class BatchNormTest extends GPUTests {
 		List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, outputs);
 		if(mode.equals("test")) {
 			assertHeavyHitterPresent("gpu_batch_norm2d_test");
-			for(int i = 0; i < outputs.size(); i++) {
-				assertEqualObjects(outCPU.get(i), outGPU.get(i));
-			}
 		}
 		else {
-			//assertHeavyHitterPresent("gpu_batch_norm2d_train");
-			double [] threshold = new double[outputs.size()];
-			Arrays.fill(threshold, getTHRESHOLD());
-			// Handle loss of precision in CuDNN kernel 
-			threshold[2] = 1e-3;
-			for(int i = 0; i < outputs.size()-1; i++) {
-				assertEqualObjects(outCPU.get(i), outGPU.get(i), threshold[i]);
-			}
+			assertHeavyHitterPresent("gpu_batch_norm2d_test");
+			assertHeavyHitterPresent("gpu_reshape_colmeans");
+			assertHeavyHitterPresent("gpu_update_ema_var");
 		}
+		assertEqualObjects(outCPU.get(0), outGPU.get(0));
+		assertEqualObjects(outCPU.get(1), outGPU.get(1));
+		assertEqualObjects(outCPU.get(2), outGPU.get(2));
+		assertEqualObjects(outCPU.get(3), outGPU.get(3));
+		assertEqualObjects(outCPU.get(4), outGPU.get(4));
+	}
+	
+	@Test
+	public void testBatchNormBackward() {
+		int imgSize = 32; 
+		int numChannels = 3;
+		double sparsity = 0.9;
+		String scriptStr = "source(\"nn/layers/batch_norm2d.dml\") as batch_norm2d;\n "
+				+ "[output, ema_mean_upd, ema_var_upd, cache_mean, cache_var] = batch_norm2d::forward(x, gamma, beta, " + numChannels + ", " + imgSize + ", " + imgSize + ", \"train\", ema_mean, ema_var, 0.9, 1e-3);\n"
+				+ "[dX, dgamma, dbeta] = batch_norm2d::backward(dout, cache_mean, cache_var, x, gamma, " + numChannels + ", " + imgSize + ", " + imgSize + ", 1e-3);";
+		HashMap<String, Object> inputs = new HashMap<>();
+		inputs.put("x", generateInputMatrix(spark, 32, numChannels*imgSize*imgSize, 0, 10, sparsity, seed));
+		inputs.put("dout", generateInputMatrix(spark, 32, numChannels*imgSize*imgSize, 1, 5, sparsity, seed));
+		inputs.put("gamma", generateInputMatrix(spark, numChannels, 1, 0, 2, sparsity, seed));
+		inputs.put("beta", generateInputMatrix(spark, numChannels, 1, 0, 2, sparsity, seed));
+		inputs.put("ema_mean", generateInputMatrix(spark, numChannels, 1, 3, 7, sparsity, seed));
+		inputs.put("ema_var", generateInputMatrix(spark, numChannels, 1, 0, 2, sparsity, seed));
+		List<String> outputs = Arrays.asList("dX", "dgamma", "dbeta");
+		List<Object> outCPU = runOnCPU(spark, scriptStr, inputs, outputs);
+		List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, outputs);
+		assertHeavyHitterPresent("gpu_batch_norm2d_bwd_dx");
+		assertEqualObjects(outCPU.get(0), outGPU.get(0), 1e-6);
+		assertEqualObjects(outCPU.get(1), outGPU.get(1));
+		assertEqualObjects(outCPU.get(2), outGPU.get(2));
 	}
 }


[3/3] systemml git commit: [SYSTEMML-445] Removed batch_norm builtin functions

Posted by ni...@apache.org.
[SYSTEMML-445] Removed batch_norm builtin functions

- Removed batch_norm builtin functions to exploit codegen in CP.
- Added rewrites for compiling efficient CuDNN operators.
- Added rewrites for SGD update operations.
- To simplify adding new GPU rewrites, added a HopDagPatternMatcher class that enables pattern matching at the HOP level; it can be extended to other rewrites as well (a brief usage sketch follows this list).
- Added GPU tests to validate the rewrites.
- Updated the DML language documentation.
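
A minimal, hypothetical sketch of how a matcher could be assembled from the new HopDagPatternMatcher class (addPredicate and addChildMatcher are defined later in this commit; the default constructor, the predicate, and the DAG shape below are illustrative assumptions, not one of the rewrites added by this change):

  // Assumes imports of org.apache.sysml.hops.rewrite.HopDagPatternMatcher,
  // org.apache.sysml.hops.Hop and org.apache.sysml.parser.Expression.DataType.
  HopDagPatternMatcher child = new HopDagPatternMatcher()
      .addPredicate("child_is_matrix", h -> h.getDataType() == DataType.MATRIX);
  HopDagPatternMatcher root = new HopDagPatternMatcher()
      .addPredicate("root_is_matrix", h -> h.getDataType() == DataType.MATRIX)
      .addChildMatcher(child);
  // A matcher like this is paired with a replacement function via HopPatternRewriter;
  // see RewriteGPUSpecificOps for the actual GPU patterns introduced by this commit.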


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/0f36780a
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/0f36780a
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/0f36780a

Branch: refs/heads/master
Commit: 0f36780a8244c6e728d37c32a79e00ed181211ad
Parents: 81419ae
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Thu Aug 30 15:40:44 2018 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Thu Aug 30 15:40:44 2018 -0700

----------------------------------------------------------------------
 docs/dml-language-reference.md                  |    2 -
 scripts/nn/layers/batch_norm2d.dml              |   60 +-
 scripts/nn/layers/batch_norm2d_old.dml          |  200 ----
 src/main/cpp/kernels/SystemML.cu                |   56 +-
 src/main/cpp/kernels/SystemML.ptx               |  321 +++++-
 src/main/java/org/apache/sysml/hops/DnnOp.java  |   56 +-
 .../java/org/apache/sysml/hops/FunctionOp.java  |   30 +-
 src/main/java/org/apache/sysml/hops/Hop.java    |    8 +-
 .../hops/rewrite/HopDagPatternMatcher.java      |  378 +++++++
 .../sysml/hops/rewrite/HopPatternRewriter.java  |   72 ++
 .../HopRewriteRuleWithPatternMatcher.java       |   98 ++
 .../sysml/hops/rewrite/HopRewriteUtils.java     |   20 +
 .../hops/rewrite/RewriteGPUSpecificOps.java     | 1027 +++++-------------
 .../org/apache/sysml/lops/DnnTransform.java     |   53 +-
 .../sysml/parser/BuiltinFunctionExpression.java |   61 +-
 .../org/apache/sysml/parser/DMLTranslator.java  |    2 -
 .../org/apache/sysml/parser/Expression.java     |    2 +-
 .../instructions/GPUInstructionParser.java      |   10 +-
 .../instructions/gpu/DnnGPUInstruction.java     |  526 +++++----
 .../gpu/GPUDenseInputPointerFetcher.java        |  111 ++
 .../gpu/context/GPUMemoryManager.java           |    2 +-
 .../runtime/matrix/data/LibMatrixCUDA.java      |  110 +-
 .../runtime/matrix/data/LibMatrixCuDNN.java     |   37 +-
 .../apache/sysml/test/gpu/BatchNormTest.java    |   47 +-
 24 files changed, 1818 insertions(+), 1471 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/docs/dml-language-reference.md
----------------------------------------------------------------------
diff --git a/docs/dml-language-reference.md b/docs/dml-language-reference.md
index 924336a..cdcc529 100644
--- a/docs/dml-language-reference.md
+++ b/docs/dml-language-reference.md
@@ -1522,8 +1522,6 @@ Hence, the images are internally represented as a matrix with dimension (N, C *
 | bias_add                                    | input, bias              | [batch_size X num_channels* height_image* width_image]    | [num_channels X 1]                                        | [batch_size X num_channels* height_image* width_image]                                      |                                                                                                                                                                                               | Adds the bias (row vector of size num_channels) to input with the given num_channels                                                              |
 | bias_multiply                               | input, bias              | [batch_size X num_channels* height_image* width_image]    | [num_channels X 1]                                        | [batch_size X num_channels* height_image* width_image]                                      |                                                                                                                                                                                               | Multiplies the bias (row vector of size num_channels) to input with the given num_channels                                                        |
 | lstm                                        | X,  W, bias, out0, c0    | [batch_size X seq_length*num_features]                    | [num_features+hidden_size X 4*hidden_size]                | [batch_size X seq_length*hidden_size] if return_sequences else  [batch_size X hidden_size]  | return_sequences                                                                                                                                                                              | Perform computation for single-layer unidirectional LSTM (outputs: out, carryOut)                                                                 |
-| batch_norm2d                                | input                    | [batch_size X num_channels* height_image* width_image]    |                                                           | [batch_size X num_channels* height_image* width_image]                                      | scale, shift, exponentialMovingAverage_Mean, exponentialMovingAverage_Variance, mode, epsilon, momentum                                                                                       | Performs batch normalization operation  (outputs: updated exponential moving average mean and variance, cache of the batch mean and variance)     |
-| batch_norm2d_backward                       | input, dout              | [batch_size X num_channels* height_image* width_image]    | [batch_size X num_channels* height_image* width_image]    | [batch_size X num_channels* height_image* width_image]                                      | scale, epsilon, cache_mean (from forward), cache_inv_var (from forward)                                                                                                                       | Computed backpropagation error for batch normalization operation                                                                                  |
 
 Note: the builtin functions `batch_norm2d` and `batch_norm2d_backward` are deprecated and will be removed in the next release. The `lstm` builtin function is in experimental phase and is only supported for the GPU backend. 
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/scripts/nn/layers/batch_norm2d.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/layers/batch_norm2d.dml b/scripts/nn/layers/batch_norm2d.dml
index 2a98857..c68f23d 100644
--- a/scripts/nn/layers/batch_norm2d.dml
+++ b/scripts/nn/layers/batch_norm2d.dml
@@ -83,8 +83,41 @@ forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta,
    *  - cache_inv_var: Cache of the inverse variance, of shape (C, 1).
    *      Note: This is used for performance during training.
    */
-  out = X; ema_mean_upd = ema_mean; ema_var_upd = ema_var;  cache_mean = ema_mean;  cache_inv_var = ema_var
-  [out, ema_mean_upd, ema_var_upd, cache_mean, cache_inv_var] = batch_norm2d(X, gamma, beta, ema_mean, ema_var, mode, epsilon, mu)
+  N = nrow(X)
+
+  if (mode == 'train') {
+    # Compute channel-wise mean and variance
+    # Since we don't have tensors, we will compute the means and variances in a piece-wise fashion.
+    #  - mean of total group is mean of subgroup means
+    #  - variance is the mean of the subgroup variances + the variance of the subgroup means
+    subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win)
+    subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win)  # uncorrected variances
+    mean = rowMeans(subgrp_means)  # shape (C, 1)
+    var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))  # shape (C, 1)
+    # Update moving averages
+    ema_mean_upd = mu*ema_mean + (1-mu)*mean
+    ema_var_upd = mu*ema_var + (1-mu)*var
+  }
+  else {
+    # Use moving averages of mean and variance during testing
+    mean = ema_mean
+    var = ema_var
+    ema_mean_upd = ema_mean
+    ema_var_upd = ema_var
+  }
+
+  # Save variable for backward pass
+  cache_mean = mean
+  cache_inv_var = 1/sqrt(var+epsilon)
+  
+  # Normalize, shift, and scale
+  # norm = (X-mean)*(var+epsilon)^(-1/2)
+  #      = (X-mean) / sqrt(var+epsilon)
+  centered = bias_add(X, -mean)  # shape (N, C*Hin*Win)
+  norm = bias_multiply(centered, cache_inv_var)  # shape (N, C*Hin*Win)
+  # out = norm*gamma + beta
+  scaled = bias_multiply(norm, gamma)  # shape (N, C*Hin*Win)
+  out = bias_add(scaled, beta)  # shape (N, C*Hin*Win)
 }
 
 backward = function(matrix[double] dout, 
@@ -119,9 +152,27 @@ backward = function(matrix[double] dout,
    *  - dbeta: Gradient wrt `b`, of shape (C, 1).
    *
    */
+  N = nrow(X)
+  oneByN = 1/N
+  oneByHW = 1/(Hin*Win)
+  
+  mean = cache_mean
+  centered = bias_add(X, -mean)  # shape (N, C*Hin*Win)
+  norm = bias_multiply(centered, cache_inv_var)  # shape (N, C*Hin*Win)
   # Compute gradients during training
-  dX = X; dgamma = gamma; dbeta = gamma;
-  [dX, dgamma, dbeta] = batch_norm2d_backward(X, dout, gamma, epsilon, cache_mean, cache_inv_var)
+  dgamma = util::channel_sums(dout*norm, C, Hin, Win)  # shape (C, 1)
+  dbeta = util::channel_sums(dout, C, Hin, Win)  # shape (C, 1)
+  dnorm = bias_multiply(dout, gamma)  # shape (N, C*Hin*Win)
+  dvar = util::channel_sums((-1/2) * bias_multiply(centered, cache_inv_var^3) * dnorm,
+                          C, Hin, Win)  # shape (C, 1)
+  dmean_norm_branch = util::channel_sums(bias_multiply(dnorm, -cache_inv_var), C, Hin, Win)
+  dmean_var_branch =  util::channel_sums((-2*oneByN*oneByHW) * centered, C, Hin, Win)
+  dmean_var_branch = dmean_var_branch * dvar  # we can't use a function within an expression yet
+  dmean = dmean_norm_branch + dmean_var_branch  # shape (C, 1)
+  dX_norm_branch = bias_multiply(dnorm, cache_inv_var)
+  dX_mean_branch = (oneByN*oneByHW) * bias_add(matrix(0, rows=1, cols=C*Hin*Win), dmean)
+  dX_var_branch = (2*oneByN*oneByHW) * bias_multiply(centered, dvar)
+  dX = dX_norm_branch + dX_mean_branch + dX_var_branch  # shape (N, C*Hin*Win)
 }
 
 init = function(int C)
@@ -149,3 +200,4 @@ init = function(int C)
    ema_mean = matrix(0, rows=C, cols=1)
    ema_var = matrix(1, rows=C, cols=1)
 }
+
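
The piece-wise statistics in forward() above follow the law of total variance over the Hin*Win pixel positions of each channel, applied to uncorrected (population) variances:

  mean_c = mean over pixels of subgrp_means[c, ]
  var_c  = mean over pixels of subgrp_vars[c, ] + variance over pixels of subgrp_means[c, ]

i.e. Var(X_c) = E[ Var(X_c | pixel) ] + Var( E[X_c | pixel] ); the (N-1)/N and ((Hin*Win)-1)/(Hin*Win) factors convert the corrected colVars/rowVars results into the uncorrected variances this identity requires.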

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/scripts/nn/layers/batch_norm2d_old.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/layers/batch_norm2d_old.dml b/scripts/nn/layers/batch_norm2d_old.dml
deleted file mode 100644
index 2aba2e6..0000000
--- a/scripts/nn/layers/batch_norm2d_old.dml
+++ /dev/null
@@ -1,200 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-/*
- * 2D (Spatial) Batch Normalization layer.
- */
-source("nn/util.dml") as util
-
-forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta,
-                   int C, int Hin, int Win, string mode,
-                   matrix[double] ema_mean, matrix[double] ema_var,
-                   double mu, double epsilon)
-    return (matrix[double] out, matrix[double] ema_mean_upd, matrix[double] ema_var_upd,
-            matrix[double] cache_mean, matrix[double] cache_inv_var) {
-  /*
-   * Computes the forward pass for a 2D (spatial) batch normalization
-   * layer.  The input data has N examples, each represented as a 3D
-   * volume unrolled into a single vector.
-   *
-   * A spatial batch normalization layer uses the per-channel sample
-   * mean and per-channel uncorrected sample variance during training
-   * to normalize each channel of the input data.  Additionally, it
-   * introduces learnable parameters (gamma, beta) to control the
-   * amount of normalization.
-   *
-   *   `y = ((x-mean) / sqrt(var+eps)) * gamma + beta`
-   *
-   * This implementation maintains exponential moving averages of the
-   * mean and variance during training for use during testing.
-   *
-   * Reference:
-   *  - Batch Normalization: Accelerating Deep Network Training by
-   *    Reducing Internal Covariate Shift, S. Ioffe & C. Szegedy, 2015
-   *    - https://arxiv.org/abs/1502.03167
-   *
-   * Inputs:
-   *  - X: Inputs, of shape (N, C*Hin*Win).
-   *  - gamma: Scale parameters, of shape (C, 1).
-   *  - beta: Shift parameters, of shape (C, 1).
-   *  - C: Number of input channels (dimensionality of input depth).
-   *  - Hin: Input height.
-   *  - Win: Input width.
-   *  - mode: 'train' or 'test' to indicate if the model is currently
-   *      being trained or tested.  During training, the current batch
-   *      mean and variance will be used to normalize the inputs, while
-   *      during testing, the exponential average of the mean and
-   *      variance over all previous batches will be used.
-   *  - ema_mean: Exponential moving average of the mean, of
-   *      shape (C, 1).
-   *  - ema_var: Exponential moving average of the variance, of
-   *      shape (C, 1).
-   *  - mu: Momentum value for moving averages.
-   *      Typical values are in the range of [0.9, 0.999].
-   *  - epsilon: Smoothing term to avoid divide by zero errors.
-   *      Typical values are in the range of [1e-5, 1e-3].
-   *
-   * Outputs:
-   *  - out: Outputs, of shape (N, C*Hin*Win).
-   *  - ema_mean_upd: Updated exponential moving average of the mean,
-   *      of shape (C, 1).
-   *  - ema_var_upd: Updated exponential moving average of the variance,
-   *      of shape (C, 1).
-   *  - cache_mean: Cache of the batch mean, of shape (C, 1).
-   *      Note: This is used for performance during training.
-   *  - cache_inv_var: Cache of the inverse variance, of shape (C, 1).
-   *      Note: This is used for performance during training.
-   */
-  N = nrow(X)
-
-  if (mode == 'train') {
-    # Compute channel-wise mean and variance
-    # Since we don't have tensors, we will compute the means and variances in a piece-wise fashion.
-    #  - mean of total group is mean of subgroup means
-    #  - variance is the mean of the subgroup variances + the variance of the subgroup means
-    subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win)
-    subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win)  # uncorrected variances
-    mean = rowMeans(subgrp_means)  # shape (C, 1)
-    var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))  # shape (C, 1)
-    # Update moving averages
-    ema_mean_upd = mu*ema_mean + (1-mu)*mean
-    ema_var_upd = mu*ema_var + (1-mu)*var
-  }
-  else {
-    # Use moving averages of mean and variance during testing
-    mean = ema_mean
-    var = ema_var
-    ema_mean_upd = ema_mean
-    ema_var_upd = ema_var
-  }
-
-  # Save variable for backward pass
-  cache_mean = mean
-  cache_inv_var = 1/sqrt(var+epsilon)
-  
-  # Normalize, shift, and scale
-  # norm = (X-mean)*(var+epsilon)^(-1/2)
-  #      = (X-mean) / sqrt(var+epsilon)
-  centered = bias_add(X, -mean)  # shape (N, C*Hin*Win)
-  norm = bias_multiply(centered, cache_inv_var)  # shape (N, C*Hin*Win)
-  # out = norm*gamma + beta
-  scaled = bias_multiply(norm, gamma)  # shape (N, C*Hin*Win)
-  out = bias_add(scaled, beta)  # shape (N, C*Hin*Win)
-}
-
-backward = function(matrix[double] dout, 
-                    matrix[double] cache_mean, matrix[double] cache_inv_var,
-                    matrix[double] X, matrix[double] gamma, 
-                    int C, int Hin, int Win, double epsilon)
-      return (matrix[double] dX, matrix[double] dgamma, matrix[double] dbeta) {
-  /*
-   * Computes the backward pass for a 2D (spatial) batch normalization
-   * layer.
-   *
-   * Inputs:
-   *  - dout: Gradient wrt `out` from upstream, of shape (N, C*Hin*Win).
-   *  - cache_mean: Cache of the batch mean from the forward pass, of
-   *      shape (C, 1).  Note: This is used for performance during
-   *      training.
-   *  - cache_inv_var: Cache of the inverse variance from the forward pass,
-   *      of shape (C, 1).  Note: This is used for performance during
-   *      training.
-   *  - X: Input data matrix to the forward pass, of
-   *      shape (N, C*Hin*Win).
-   *  - gamma: Scale parameters, of shape (C, 1).
-   *  - C: Number of input channels (dimensionality of input depth).
-   *  - Hin: Input height.
-   *  - Win: Input width.
-   *  - epsilon: Smoothing term to avoid divide by zero errors.
-   *      Typical values are in the range of [1e-5, 1e-3].
-   *
-   * Outputs:
-   *  - dX: Gradient wrt `X`, of shape (N, C*Hin*Win).
-   *  - dgamma: Gradient wrt `W`, of shape (C, 1).
-   *  - dbeta: Gradient wrt `b`, of shape (C, 1).
-   *
-   */
-  N = nrow(X)
-  mean = cache_mean
-  centered = bias_add(X, -mean)  # shape (N, C*Hin*Win)
-  norm = bias_multiply(centered, cache_inv_var)  # shape (N, C*Hin*Win)
-  # Compute gradients during training
-  dgamma = util::channel_sums(dout*norm, C, Hin, Win)  # shape (C, 1)
-  dbeta = util::channel_sums(dout, C, Hin, Win)  # shape (C, 1)
-  dnorm = bias_multiply(dout, gamma)  # shape (N, C*Hin*Win)
-  dvar = util::channel_sums((-1/2) * bias_multiply(centered, cache_inv_var^3) * dnorm,
-                          C, Hin, Win)  # shape (C, 1)
-  dmean_norm_branch = util::channel_sums(bias_multiply(dnorm, -cache_inv_var), C, Hin, Win)
-  dmean_var_branch =  util::channel_sums((-2/(N*Hin*Win)) * centered, C, Hin, Win)
-  dmean_var_branch = dmean_var_branch * dvar  # we can't use a function within an expression yet
-  dmean = dmean_norm_branch + dmean_var_branch  # shape (C, 1)
-  dX_norm_branch = bias_multiply(dnorm, cache_inv_var)
-  dX_mean_branch = (1/(N*Hin*Win)) * bias_add(matrix(0, rows=1, cols=C*Hin*Win), dmean)
-  dX_var_branch = (2/(N*Hin*Win)) * bias_multiply(centered, dvar)
-  dX = dX_norm_branch + dX_mean_branch + dX_var_branch  # shape (N, C*Hin*Win)
-}
-
-init = function(int C)
-    return (matrix[double] gamma, matrix[double] beta,
-            matrix[double] ema_mean, matrix[double] ema_var) {
-  /*
-   * Initialize the parameters of this layer.
-   *
-   * Note: This is just a convenience function, and parameters
-   * may be initialized manually if needed.
-   *
-   * Inputs:
-   *  - C: Number of input channels (dimensionality of input depth).
-   *
-   * Outputs:
-   *  - gamma: Scale parameters, of shape (C, 1).
-   *  - beta: Shift parameters, of shape (C, 1).
-   *  - ema_mean: Exponential moving average of the mean, of
-   *      shape (C, 1).
-   *  - ema_var: Exponential moving average of the variance, of
-   *      shape (C, 1).
-   */
-   gamma = matrix(1, rows=C, cols=1)
-   beta = matrix(0, rows=C, cols=1)
-   ema_mean = matrix(0, rows=C, cols=1)
-   ema_var = matrix(1, rows=C, cols=1)
-}
-

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/cpp/kernels/SystemML.cu
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu
index 9ddaaff..b874cdd 100644
--- a/src/main/cpp/kernels/SystemML.cu
+++ b/src/main/cpp/kernels/SystemML.cu
@@ -2289,4 +2289,58 @@ extern "C" __global__ void update_nesterov_x_d(double *X, double *v, double *v_p
 
 extern "C" __global__ void update_nesterov_x_f(float *X, float *v, float *v_prev, double mu, float *out, unsigned int size) {
   update_nesterov_x(X, v, v_prev, mu, out, size);
-}
\ No newline at end of file
+}
+
+// Performs the operation: C = a*X + b*C
+template <typename T>
+__device__ void aXplusbC(T *X, T *C, double a, double b, unsigned int size) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+	C[index] = a*X[index] + b*C[index];
+  }
+}
+
+extern "C" __global__ void aXplusbC_d(double *X, double *C, double a, double b, unsigned int size) {
+  aXplusbC(X, C, a, b, size);
+}
+
+extern "C" __global__ void aXplusbC_f(float *X, float *C, double a, double b, unsigned int size) {
+  aXplusbC(X, C, a, b, size);
+}
+
+
+// Performs the operation: C = a*X + b*Y
+template <typename T>
+__device__ void aXplusbY(T *X, T* Y, T *C, double a, double b, unsigned int size) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+	C[index] = a*X[index] + b*Y[index];
+  }
+}
+
+extern "C" __global__ void aXplusbY_d(double *X, double* Y, double *C, double a, double b, unsigned int size) {
+  aXplusbY(X, Y, C, a, b, size);
+}
+
+extern "C" __global__ void aXplusbY_f(float *X, float* Y, float *C, double a, double b, unsigned int size) {
+  aXplusbY(X, Y, C, a, b, size);
+}
+
+
+// Performs the operation: C = 1 / sqrt(X + eps)
+template <typename T>
+__device__ void invVar(T *X, T *C, double eps, unsigned int size) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+	C[index] = 1.0 / sqrt(X[index] + eps);
+  }
+}
+
+extern "C" __global__ void invVar_d(double *X, double *C, double eps, unsigned int size) {
+  invVar(X, C, eps, size);
+}
+
+extern "C" __global__ void invVar_f(float *X, float *C, double eps, unsigned int size) {
+  invVar(X, C, eps, size);
+}
+

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/cpp/kernels/SystemML.ptx
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.ptx b/src/main/cpp/kernels/SystemML.ptx
index 8a14876..1ab32f5 100644
--- a/src/main/cpp/kernels/SystemML.ptx
+++ b/src/main/cpp/kernels/SystemML.ptx
@@ -13084,12 +13084,279 @@ BB115_2:
 	ret;
 }
 
+	// .globl	aXplusbC_d
+.visible .entry aXplusbC_d(
+	.param .u64 aXplusbC_d_param_0,
+	.param .u64 aXplusbC_d_param_1,
+	.param .f64 aXplusbC_d_param_2,
+	.param .f64 aXplusbC_d_param_3,
+	.param .u32 aXplusbC_d_param_4
+)
+{
+	.reg .pred 	%p<2>;
+	.reg .b32 	%r<6>;
+	.reg .f64 	%fd<7>;
+	.reg .b64 	%rd<8>;
+
+
+	ld.param.u64 	%rd1, [aXplusbC_d_param_0];
+	ld.param.u64 	%rd2, [aXplusbC_d_param_1];
+	ld.param.f64 	%fd1, [aXplusbC_d_param_2];
+	ld.param.f64 	%fd2, [aXplusbC_d_param_3];
+	ld.param.u32 	%r2, [aXplusbC_d_param_4];
+	mov.u32 	%r3, %ctaid.x;
+	mov.u32 	%r4, %ntid.x;
+	mov.u32 	%r5, %tid.x;
+	mad.lo.s32 	%r1, %r4, %r3, %r5;
+	setp.ge.u32	%p1, %r1, %r2;
+	@%p1 bra 	BB116_2;
+
+	cvta.to.global.u64 	%rd3, %rd2;
+	cvta.to.global.u64 	%rd4, %rd1;
+	mul.wide.s32 	%rd5, %r1, 8;
+	add.s64 	%rd6, %rd4, %rd5;
+	ld.global.f64 	%fd3, [%rd6];
+	add.s64 	%rd7, %rd3, %rd5;
+	ld.global.f64 	%fd4, [%rd7];
+	mul.f64 	%fd5, %fd4, %fd2;
+	fma.rn.f64 	%fd6, %fd3, %fd1, %fd5;
+	st.global.f64 	[%rd7], %fd6;
+
+BB116_2:
+	ret;
+}
+
+	// .globl	aXplusbC_f
+.visible .entry aXplusbC_f(
+	.param .u64 aXplusbC_f_param_0,
+	.param .u64 aXplusbC_f_param_1,
+	.param .f64 aXplusbC_f_param_2,
+	.param .f64 aXplusbC_f_param_3,
+	.param .u32 aXplusbC_f_param_4
+)
+{
+	.reg .pred 	%p<2>;
+	.reg .f32 	%f<4>;
+	.reg .b32 	%r<6>;
+	.reg .f64 	%fd<7>;
+	.reg .b64 	%rd<8>;
+
+
+	ld.param.u64 	%rd1, [aXplusbC_f_param_0];
+	ld.param.u64 	%rd2, [aXplusbC_f_param_1];
+	ld.param.f64 	%fd1, [aXplusbC_f_param_2];
+	ld.param.f64 	%fd2, [aXplusbC_f_param_3];
+	ld.param.u32 	%r2, [aXplusbC_f_param_4];
+	mov.u32 	%r3, %ctaid.x;
+	mov.u32 	%r4, %ntid.x;
+	mov.u32 	%r5, %tid.x;
+	mad.lo.s32 	%r1, %r4, %r3, %r5;
+	setp.ge.u32	%p1, %r1, %r2;
+	@%p1 bra 	BB117_2;
+
+	cvta.to.global.u64 	%rd3, %rd2;
+	cvta.to.global.u64 	%rd4, %rd1;
+	mul.wide.s32 	%rd5, %r1, 4;
+	add.s64 	%rd6, %rd4, %rd5;
+	ld.global.f32 	%f1, [%rd6];
+	cvt.f64.f32	%fd3, %f1;
+	add.s64 	%rd7, %rd3, %rd5;
+	ld.global.f32 	%f2, [%rd7];
+	cvt.f64.f32	%fd4, %f2;
+	mul.f64 	%fd5, %fd4, %fd2;
+	fma.rn.f64 	%fd6, %fd3, %fd1, %fd5;
+	cvt.rn.f32.f64	%f3, %fd6;
+	st.global.f32 	[%rd7], %f3;
+
+BB117_2:
+	ret;
+}
+
+	// .globl	aXplusbY_d
+.visible .entry aXplusbY_d(
+	.param .u64 aXplusbY_d_param_0,
+	.param .u64 aXplusbY_d_param_1,
+	.param .u64 aXplusbY_d_param_2,
+	.param .f64 aXplusbY_d_param_3,
+	.param .f64 aXplusbY_d_param_4,
+	.param .u32 aXplusbY_d_param_5
+)
+{
+	.reg .pred 	%p<2>;
+	.reg .b32 	%r<6>;
+	.reg .f64 	%fd<7>;
+	.reg .b64 	%rd<11>;
+
+
+	ld.param.u64 	%rd1, [aXplusbY_d_param_0];
+	ld.param.u64 	%rd2, [aXplusbY_d_param_1];
+	ld.param.u64 	%rd3, [aXplusbY_d_param_2];
+	ld.param.f64 	%fd1, [aXplusbY_d_param_3];
+	ld.param.f64 	%fd2, [aXplusbY_d_param_4];
+	ld.param.u32 	%r2, [aXplusbY_d_param_5];
+	mov.u32 	%r3, %ctaid.x;
+	mov.u32 	%r4, %ntid.x;
+	mov.u32 	%r5, %tid.x;
+	mad.lo.s32 	%r1, %r4, %r3, %r5;
+	setp.ge.u32	%p1, %r1, %r2;
+	@%p1 bra 	BB118_2;
+
+	cvta.to.global.u64 	%rd4, %rd1;
+	mul.wide.s32 	%rd5, %r1, 8;
+	add.s64 	%rd6, %rd4, %rd5;
+	ld.global.f64 	%fd3, [%rd6];
+	cvta.to.global.u64 	%rd7, %rd2;
+	add.s64 	%rd8, %rd7, %rd5;
+	ld.global.f64 	%fd4, [%rd8];
+	mul.f64 	%fd5, %fd4, %fd2;
+	fma.rn.f64 	%fd6, %fd3, %fd1, %fd5;
+	cvta.to.global.u64 	%rd9, %rd3;
+	add.s64 	%rd10, %rd9, %rd5;
+	st.global.f64 	[%rd10], %fd6;
+
+BB118_2:
+	ret;
+}
+
+	// .globl	aXplusbY_f
+.visible .entry aXplusbY_f(
+	.param .u64 aXplusbY_f_param_0,
+	.param .u64 aXplusbY_f_param_1,
+	.param .u64 aXplusbY_f_param_2,
+	.param .f64 aXplusbY_f_param_3,
+	.param .f64 aXplusbY_f_param_4,
+	.param .u32 aXplusbY_f_param_5
+)
+{
+	.reg .pred 	%p<2>;
+	.reg .f32 	%f<4>;
+	.reg .b32 	%r<6>;
+	.reg .f64 	%fd<7>;
+	.reg .b64 	%rd<11>;
+
+
+	ld.param.u64 	%rd1, [aXplusbY_f_param_0];
+	ld.param.u64 	%rd2, [aXplusbY_f_param_1];
+	ld.param.u64 	%rd3, [aXplusbY_f_param_2];
+	ld.param.f64 	%fd1, [aXplusbY_f_param_3];
+	ld.param.f64 	%fd2, [aXplusbY_f_param_4];
+	ld.param.u32 	%r2, [aXplusbY_f_param_5];
+	mov.u32 	%r3, %ctaid.x;
+	mov.u32 	%r4, %ntid.x;
+	mov.u32 	%r5, %tid.x;
+	mad.lo.s32 	%r1, %r4, %r3, %r5;
+	setp.ge.u32	%p1, %r1, %r2;
+	@%p1 bra 	BB119_2;
+
+	cvta.to.global.u64 	%rd4, %rd1;
+	mul.wide.s32 	%rd5, %r1, 4;
+	add.s64 	%rd6, %rd4, %rd5;
+	ld.global.f32 	%f1, [%rd6];
+	cvt.f64.f32	%fd3, %f1;
+	cvta.to.global.u64 	%rd7, %rd2;
+	add.s64 	%rd8, %rd7, %rd5;
+	ld.global.f32 	%f2, [%rd8];
+	cvt.f64.f32	%fd4, %f2;
+	mul.f64 	%fd5, %fd4, %fd2;
+	fma.rn.f64 	%fd6, %fd3, %fd1, %fd5;
+	cvt.rn.f32.f64	%f3, %fd6;
+	cvta.to.global.u64 	%rd9, %rd3;
+	add.s64 	%rd10, %rd9, %rd5;
+	st.global.f32 	[%rd10], %f3;
+
+BB119_2:
+	ret;
+}
+
+	// .globl	invVar_d
+.visible .entry invVar_d(
+	.param .u64 invVar_d_param_0,
+	.param .u64 invVar_d_param_1,
+	.param .f64 invVar_d_param_2,
+	.param .u32 invVar_d_param_3
+)
+{
+	.reg .pred 	%p<2>;
+	.reg .b32 	%r<6>;
+	.reg .f64 	%fd<6>;
+	.reg .b64 	%rd<8>;
+
+
+	ld.param.u64 	%rd1, [invVar_d_param_0];
+	ld.param.u64 	%rd2, [invVar_d_param_1];
+	ld.param.f64 	%fd1, [invVar_d_param_2];
+	ld.param.u32 	%r2, [invVar_d_param_3];
+	mov.u32 	%r3, %ctaid.x;
+	mov.u32 	%r4, %ntid.x;
+	mov.u32 	%r5, %tid.x;
+	mad.lo.s32 	%r1, %r4, %r3, %r5;
+	setp.ge.u32	%p1, %r1, %r2;
+	@%p1 bra 	BB120_2;
+
+	cvta.to.global.u64 	%rd3, %rd1;
+	mul.wide.s32 	%rd4, %r1, 8;
+	add.s64 	%rd5, %rd3, %rd4;
+	ld.global.f64 	%fd2, [%rd5];
+	add.f64 	%fd3, %fd2, %fd1;
+	sqrt.rn.f64 	%fd4, %fd3;
+	rcp.rn.f64 	%fd5, %fd4;
+	cvta.to.global.u64 	%rd6, %rd2;
+	add.s64 	%rd7, %rd6, %rd4;
+	st.global.f64 	[%rd7], %fd5;
+
+BB120_2:
+	ret;
+}
+
+	// .globl	invVar_f
+.visible .entry invVar_f(
+	.param .u64 invVar_f_param_0,
+	.param .u64 invVar_f_param_1,
+	.param .f64 invVar_f_param_2,
+	.param .u32 invVar_f_param_3
+)
+{
+	.reg .pred 	%p<2>;
+	.reg .f32 	%f<3>;
+	.reg .b32 	%r<6>;
+	.reg .f64 	%fd<6>;
+	.reg .b64 	%rd<8>;
+
+
+	ld.param.u64 	%rd1, [invVar_f_param_0];
+	ld.param.u64 	%rd2, [invVar_f_param_1];
+	ld.param.f64 	%fd1, [invVar_f_param_2];
+	ld.param.u32 	%r2, [invVar_f_param_3];
+	mov.u32 	%r3, %ctaid.x;
+	mov.u32 	%r4, %ntid.x;
+	mov.u32 	%r5, %tid.x;
+	mad.lo.s32 	%r1, %r4, %r3, %r5;
+	setp.ge.u32	%p1, %r1, %r2;
+	@%p1 bra 	BB121_2;
+
+	cvta.to.global.u64 	%rd3, %rd1;
+	mul.wide.s32 	%rd4, %r1, 4;
+	add.s64 	%rd5, %rd3, %rd4;
+	ld.global.f32 	%f1, [%rd5];
+	cvt.f64.f32	%fd2, %f1;
+	add.f64 	%fd3, %fd2, %fd1;
+	sqrt.rn.f64 	%fd4, %fd3;
+	rcp.rn.f64 	%fd5, %fd4;
+	cvt.rn.f32.f64	%f2, %fd5;
+	cvta.to.global.u64 	%rd6, %rd2;
+	add.s64 	%rd7, %rd6, %rd4;
+	st.global.f32 	[%rd7], %f2;
+
+BB121_2:
+	ret;
+}
+
 .func  (.param .b64 func_retval0) __internal_trig_reduction_slowpathd(
 	.param .b64 __internal_trig_reduction_slowpathd_param_0,
 	.param .b64 __internal_trig_reduction_slowpathd_param_1
 )
 {
-	.local .align 8 .b8 	__local_depot116[40];
+	.local .align 8 .b8 	__local_depot122[40];
 	.reg .b64 	%SP;
 	.reg .b64 	%SPL;
 	.reg .pred 	%p<9>;
@@ -13098,7 +13365,7 @@ BB115_2:
 	.reg .b64 	%rd<102>;
 
 
-	mov.u64 	%rd101, __local_depot116;
+	mov.u64 	%rd101, __local_depot122;
 	cvta.local.u64 	%SP, %rd101;
 	ld.param.f64 	%fd4, [__internal_trig_reduction_slowpathd_param_0];
 	ld.param.u64 	%rd37, [__internal_trig_reduction_slowpathd_param_1];
@@ -13112,7 +13379,7 @@ BB115_2:
 	shr.u32 	%r3, %r1, 20;
 	bfe.u32 	%r4, %r1, 20, 11;
 	setp.eq.s32	%p1, %r4, 2047;
-	@%p1 bra 	BB116_13;
+	@%p1 bra 	BB122_13;
 
 	add.s32 	%r15, %r4, -1024;
 	shr.u32 	%r16, %r15, 6;
@@ -13125,7 +13392,7 @@ BB115_2:
 	mov.u64 	%rd94, 0;
 	setp.ge.s32	%p2, %r5, %r6;
 	mov.u64 	%rd93, %rd1;
-	@%p2 bra 	BB116_4;
+	@%p2 bra 	BB122_4;
 
 	mov.b64 	 %rd41, %fd4;
 	shl.b64 	%rd42, %rd41, 11;
@@ -13142,7 +13409,7 @@ BB115_2:
 	mov.u64 	%rd91, %rd1;
 	mov.u32 	%r39, %r5;
 
-BB116_3:
+BB122_3:
 	.pragma "nounroll";
 	ld.const.u64 	%rd47, [%rd89];
 	// inline asm
@@ -13172,15 +13439,15 @@ BB116_3:
 	add.s64 	%rd93, %rd93, 8;
 	add.s64 	%rd89, %rd89, 8;
 	setp.lt.s32	%p3, %r39, %r6;
-	@%p3 bra 	BB116_3;
+	@%p3 bra 	BB122_3;
 
-BB116_4:
+BB122_4:
 	st.local.u64 	[%rd93], %rd94;
 	ld.local.u64 	%rd95, [%rd1+16];
 	ld.local.u64 	%rd96, [%rd1+24];
 	and.b32  	%r9, %r3, 63;
 	setp.eq.s32	%p4, %r9, 0;
-	@%p4 bra 	BB116_6;
+	@%p4 bra 	BB122_6;
 
 	mov.u32 	%r27, 64;
 	sub.s32 	%r28, %r27, %r9;
@@ -13192,7 +13459,7 @@ BB116_4:
 	shr.u64 	%rd55, %rd54, %r28;
 	or.b64  	%rd95, %rd55, %rd53;
 
-BB116_6:
+BB122_6:
 	cvta.to.local.u64 	%rd56, %rd37;
 	shr.u64 	%rd57, %rd96, 62;
 	cvt.u32.u64	%r29, %rd57;
@@ -13209,7 +13476,7 @@ BB116_6:
 	selp.b32	%r34, %r32, %r33, %p5;
 	st.local.u32 	[%rd56], %r34;
 	setp.eq.s32	%p6, %r31, 0;
-	@%p6 bra 	BB116_8;
+	@%p6 bra 	BB122_8;
 
 	mov.u64 	%rd64, 0;
 	// inline asm
@@ -13229,10 +13496,10 @@ BB116_6:
 	// inline asm
 	xor.b32  	%r40, %r40, -2147483648;
 
-BB116_8:
+BB122_8:
 	clz.b64 	%r41, %rd98;
 	setp.eq.s32	%p7, %r41, 0;
-	@%p7 bra 	BB116_10;
+	@%p7 bra 	BB122_10;
 
 	shl.b64 	%rd67, %rd98, %r41;
 	mov.u32 	%r35, 64;
@@ -13240,7 +13507,7 @@ BB116_8:
 	shr.u64 	%rd68, %rd97, %r36;
 	or.b64  	%rd98, %rd68, %rd67;
 
-BB116_10:
+BB122_10:
 	mov.u64 	%rd72, -3958705157555305931;
 	// inline asm
 	{
@@ -13261,7 +13528,7 @@ BB116_10:
 	}
 	// inline asm
 	setp.lt.s64	%p8, %rd100, 1;
-	@%p8 bra 	BB116_12;
+	@%p8 bra 	BB122_12;
 
 	// inline asm
 	{
@@ -13280,7 +13547,7 @@ BB116_10:
 	// inline asm
 	add.s32 	%r41, %r41, 1;
 
-BB116_12:
+BB122_12:
 	cvt.u64.u32	%rd79, %r40;
 	shl.b64 	%rd80, %rd79, 32;
 	mov.u32 	%r37, 1022;
@@ -13295,7 +13562,7 @@ BB116_12:
 	or.b64  	%rd88, %rd87, %rd80;
 	mov.b64 	 %fd4, %rd88;
 
-BB116_13:
+BB122_13:
 	st.param.f64	[func_retval0+0], %fd4;
 	ret;
 }
@@ -13323,7 +13590,7 @@ BB116_13:
 	}
 	shr.u32 	%r51, %r50, 20;
 	setp.ne.s32	%p1, %r51, 0;
-	@%p1 bra 	BB117_2;
+	@%p1 bra 	BB123_2;
 
 	mul.f64 	%fd14, %fd12, 0d4350000000000000;
 	{
@@ -13337,13 +13604,13 @@ BB116_13:
 	shr.u32 	%r16, %r50, 20;
 	add.s32 	%r51, %r16, -54;
 
-BB117_2:
+BB123_2:
 	add.s32 	%r52, %r51, -1023;
 	and.b32  	%r17, %r50, -2146435073;
 	or.b32  	%r18, %r17, 1072693248;
 	mov.b64 	%fd135, {%r49, %r18};
 	setp.lt.u32	%p2, %r18, 1073127583;
-	@%p2 bra 	BB117_4;
+	@%p2 bra 	BB123_4;
 
 	{
 	.reg .b32 %temp; 
@@ -13357,7 +13624,7 @@ BB117_2:
 	mov.b64 	%fd135, {%r19, %r21};
 	add.s32 	%r52, %r51, -1022;
 
-BB117_4:
+BB123_4:
 	add.f64 	%fd15, %fd135, 0d3FF0000000000000;
 	rcp.approx.ftz.f64 	%fd16, %fd15;
 	neg.f64 	%fd17, %fd15;
@@ -13520,13 +13787,13 @@ BB117_4:
 	mov.b32 	 %f2, %r35;
 	abs.f32 	%f1, %f2;
 	setp.lt.f32	%p4, %f1, 0f4086232B;
-	@%p4 bra 	BB117_7;
+	@%p4 bra 	BB123_7;
 
 	setp.lt.f64	%p5, %fd4, 0d0000000000000000;
 	add.f64 	%fd129, %fd4, 0d7FF0000000000000;
 	selp.f64	%fd136, 0d0000000000000000, %fd129, %p5;
 	setp.geu.f32	%p6, %f1, 0f40874800;
-	@%p6 bra 	BB117_7;
+	@%p6 bra 	BB123_7;
 
 	mov.f64 	%fd134, 0d4338000000000000;
 	mov.f64 	%fd133, 0d3FF71547652B82FE;
@@ -13548,26 +13815,26 @@ BB117_4:
 	mov.b64 	%fd131, {%r44, %r43};
 	mul.f64 	%fd136, %fd130, %fd131;
 
-BB117_7:
+BB123_7:
 	{
 	.reg .b32 %temp; 
 	mov.b64 	{%temp, %r45}, %fd136;
 	}
 	and.b32  	%r46, %r45, 2147483647;
 	setp.ne.s32	%p7, %r46, 2146435072;
-	@%p7 bra 	BB117_9;
+	@%p7 bra 	BB123_9;
 
 	{
 	.reg .b32 %temp; 
 	mov.b64 	{%r47, %temp}, %fd136;
 	}
 	setp.eq.s32	%p8, %r47, 0;
-	@%p8 bra 	BB117_10;
+	@%p8 bra 	BB123_10;
 
-BB117_9:
+BB123_9:
 	fma.rn.f64 	%fd136, %fd136, %fd5, %fd136;
 
-BB117_10:
+BB123_10:
 	st.param.f64	[func_retval0+0], %fd136;
 	ret;
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/DnnOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/DnnOp.java b/src/main/java/org/apache/sysml/hops/DnnOp.java
index a7d37dc..c4ce466 100644
--- a/src/main/java/org/apache/sysml/hops/DnnOp.java
+++ b/src/main/java/org/apache/sysml/hops/DnnOp.java
@@ -110,8 +110,6 @@ public class DnnOp extends MultiThreadedHop
 		if( getLops() != null )
 			return getLops();
 		
-		ExecType et = optFindExecType();
-		
 		ArrayList<Hop> inputs = getInput();
 		switch( op )
 		{
@@ -125,6 +123,7 @@ public class DnnOp extends MultiThreadedHop
 			case BIASADD:
 			case BIASMULT:
 			{	
+				ExecType et = optFindExecType();
 				if(et == ExecType.CP || et == ExecType.GPU) {
 					setLops(constructDnnLops(et, inputs));
 					break;
@@ -137,15 +136,15 @@ public class DnnOp extends MultiThreadedHop
 			case BATCH_NORM2D_TEST:
 			case CHANNEL_SUMS:
 			case UPDATE_NESTEROV_X:
+			case UPDATE_EMA_VAR:
+			case RESHAPE_COLMEANS:
+			case UPDATE_EMA:
+			case INV_VAR:
+			case BATCH_NORM2D_BACKWARD_DX:
 			{	
-				if(et == ExecType.GPU) {
-					setLops(constructDnnLops(et, inputs));
-					break;
-				}
-				else {
-					throw new HopsException("Unimplemented DnnOp for execution type: " + et.name());
-				}
-				// break;
+				// GPU-specific operators
+				setLops(constructDnnLops(ExecType.GPU, inputs));
+				break;
 			}
 			default: 
 				throw new HopsException("Unsupported lops construction for operation type '"+op+"'.");
@@ -171,10 +170,16 @@ public class DnnOp extends MultiThreadedHop
 				return 14;
 			case BIASADD:
 			case BIASMULT:
+			case INV_VAR:
 				return 2;
 			case BATCH_NORM2D_TEST:
 				return 6;
+			case UPDATE_EMA_VAR:
+			case BATCH_NORM2D_BACKWARD_DX:
+				return 5;
+			case RESHAPE_COLMEANS:
 			case CHANNEL_SUMS:
+			case UPDATE_EMA:
 				return 3;
 			case UPDATE_NESTEROV_X:
 				return 4;
@@ -532,7 +537,8 @@ public class DnnOp extends MultiThreadedHop
 		long[] ret = new long[3];
 		
 		if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST ||
-			op == OpOpDnn.UPDATE_NESTEROV_X) {
+			op == OpOpDnn.UPDATE_NESTEROV_X || op == OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR ||
+			op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) {
 			// Same dimension as the first input
 			MatrixCharacteristics[] mc = memo.getAllInputStats(getInput());
 			ret[0] = mc[0].rowsKnown() ? mc[0].getRows() : -1;
@@ -540,13 +546,21 @@ public class DnnOp extends MultiThreadedHop
 			ret[2] = -1;
 			return (ret[0]>=0 && ret[1]>=0) ? ret : null;
 		}
-		else if(op == OpOpDnn.CHANNEL_SUMS) {
+		else if(op == OpOpDnn.CHANNEL_SUMS || op == OpOpDnn.UPDATE_EMA_VAR) {
 			long numChannels = Hop.computeSizeInformation(getInput().get(1));
 			ret[0] = numChannels;
 			ret[1] = 1;
 			ret[2] = -1;
 			return ret;
 		}
+		else if(op == OpOpDnn.RESHAPE_COLMEANS) {
+			long numChannels = Hop.computeSizeInformation(getInput().get(1));
+			long HW = Hop.computeSizeInformation(getInput().get(2));
+			ret[0] = numChannels;
+			ret[1] = HW;
+			ret[2] = -1;
+			return ret;
+		}
 		
 		refreshSizeInformation();
 		ret[0] = _dim1; ret[1] = _dim2; ret[2] = _nnz;
@@ -739,7 +753,9 @@ public class DnnOp extends MultiThreadedHop
 	@Override
 	public void refreshSizeInformation()
 	{
-		if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.UPDATE_NESTEROV_X) {
+		if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || 
+			op == OpOpDnn.UPDATE_NESTEROV_X || op == OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR ||
+			op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) {
 			// Same dimension as the first input
 			Hop input1 = getInput().get(0);
 			setDim1(input1.getDim1());
@@ -747,13 +763,21 @@ public class DnnOp extends MultiThreadedHop
 			_nnz = -1; // cannot infer stats
 			return;
 		}
-		else if(op == OpOpDnn.CHANNEL_SUMS) {
+		else if(op == OpOpDnn.CHANNEL_SUMS || op == OpOpDnn.UPDATE_EMA_VAR) {
 			long numChannels = Hop.computeSizeInformation(getInput().get(1));
 			setDim1(numChannels);
 			setDim2(1);
 			_nnz = -1; // cannot infer stats
 			return;
 		}
+		else if(op == OpOpDnn.RESHAPE_COLMEANS) {
+			long numChannels = Hop.computeSizeInformation(getInput().get(1));
+			long HW = Hop.computeSizeInformation(getInput().get(2));
+			setDim1(numChannels);
+			setDim2(HW);
+			_nnz = -1; // cannot infer stats
+			return;
+		}
 		
 		// Reset the _cachedParams to avoid incorrect sizes
 		_cachedParams = new DnnParameters(-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, _maxNumThreads);
@@ -847,7 +871,9 @@ public class DnnOp extends MultiThreadedHop
 	 */
 	private long getDim(String dimString) {
 		if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.CHANNEL_SUMS ||
-			op == OpOpDnn.UPDATE_NESTEROV_X) {
+			op == OpOpDnn.UPDATE_NESTEROV_X || op == OpOpDnn.RESHAPE_COLMEANS ||
+			op == OpOpDnn.UPDATE_EMA_VAR || op == OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR ||
+			op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) {
 			throw new RuntimeException("getDim method should not be invoked for " + op.name());
 		}
 		try {

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/FunctionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java b/src/main/java/org/apache/sysml/hops/FunctionOp.java
index ea397db..5f177bd 100644
--- a/src/main/java/org/apache/sysml/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java
@@ -181,21 +181,6 @@ public class FunctionOp extends Hop
 				// TODO: To allow for initial version to always run on the GPU
 				return 0; 
 			}
-			else if ( getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_train")) {
-				return OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0) +
-						OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0) +
-						OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(2).getDim1(), getOutputs().get(2).getDim2(), 1.0) +
-						OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(3).getDim1(), getOutputs().get(3).getDim2(), 1.0) + 
-						OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(4).getDim1(), getOutputs().get(4).getDim2(), 1.0);
-			}
-			else if ( getFunctionName().equalsIgnoreCase("batch_norm2d_test") ) {
-			return OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0);
-		}
-			else if ( getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ) {
-				return OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0) +
-						OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0) +
-						OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(2).getDim1(), getOutputs().get(2).getDim2(), 1.0);
-			}
 			else if ( getFunctionName().equalsIgnoreCase("svd") ) {
 				long outputU = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0);
 				long outputSigma = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0);
@@ -226,10 +211,6 @@ public class FunctionOp extends Hop
 				return OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), getInput().get(0).getDim2(), 1.0) 
 						+ 3*OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), 1, 1.0); 
 			}
-			else if (getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ||
-					getFunctionName().equalsIgnoreCase("batch_norm2d_train") || getFunctionName().equalsIgnoreCase("batch_norm2d_test")) {
-				return 0; 
-			}
 			else if ( getFunctionName().equalsIgnoreCase("lstm") ||  getFunctionName().equalsIgnoreCase("lstm_backward") ) {
 				// TODO: To allow for initial version to always run on the GPU
 				return 0; 
@@ -251,9 +232,7 @@ public class FunctionOp extends Hop
 	
 	@Override
 	public boolean isGPUEnabled() {
-		if(getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward") ||  
-			getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ||
-			getFunctionName().equalsIgnoreCase("batch_norm2d_train") || getFunctionName().equalsIgnoreCase("batch_norm2d_test")) 
+		if(getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward")) 
 			return true;
 		else
 			return false;
@@ -308,13 +287,6 @@ public class FunctionOp extends Hop
 					throw new RuntimeException("The function " + getFunctionName() + " is only supported on GPU.");
 				_etype = ExecType.GPU;
 			}
-			else if(isBuiltinFunction && (getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward"))) {
-				_etype = ConfigurationManager.isGPU() ? ExecType.GPU : ExecType.CP;
-			}
-			else if(isBuiltinFunction && getFunctionName().equalsIgnoreCase("batch_norm2d_train")) {
-				// Only GPU implementation is supported
-				_etype = ExecType.GPU;
-			}
 			else {
 				// Since the memory estimate is only conservative, do not throw
 				// exception if the estimated memory is larger than the budget

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java
index 3b461a1..c8356e0 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -1100,7 +1100,8 @@ public abstract class Hop implements ParseInfo
 		MAX_POOL, MAX_POOL_BACKWARD, AVG_POOL, AVG_POOL_BACKWARD,
 		CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA,
 		BIASADD, BIASMULT, BATCH_NORM2D_TEST, CHANNEL_SUMS,
-		UPDATE_NESTEROV_X
+		UPDATE_NESTEROV_X, RESHAPE_COLMEANS, UPDATE_EMA_VAR, UPDATE_EMA, INV_VAR,
+		BATCH_NORM2D_BACKWARD_DX
 	}
 	
 	public enum DataGenMethod {
@@ -1174,8 +1175,13 @@ public abstract class Hop implements ParseInfo
 		HopsConv2Lops.put(OpOpDnn.CONV2D_BACKWARD_FILTER, org.apache.sysml.lops.DnnTransform.OperationTypes.CONV2D_BACKWARD_FILTER);
 		HopsConv2Lops.put(OpOpDnn.CONV2D_BACKWARD_DATA, org.apache.sysml.lops.DnnTransform.OperationTypes.CONV2D_BACKWARD_DATA);
 		HopsConv2Lops.put(OpOpDnn.BATCH_NORM2D_TEST, org.apache.sysml.lops.DnnTransform.OperationTypes.BATCH_NORM2D_TEST);
+		HopsConv2Lops.put(OpOpDnn.UPDATE_EMA_VAR, org.apache.sysml.lops.DnnTransform.OperationTypes.UPDATE_EMA_VAR);
 		HopsConv2Lops.put(OpOpDnn.CHANNEL_SUMS, org.apache.sysml.lops.DnnTransform.OperationTypes.CHANNEL_SUMS);
 		HopsConv2Lops.put(OpOpDnn.UPDATE_NESTEROV_X, org.apache.sysml.lops.DnnTransform.OperationTypes.UPDATE_NESTEROV_X);
+		HopsConv2Lops.put(OpOpDnn.RESHAPE_COLMEANS, org.apache.sysml.lops.DnnTransform.OperationTypes.RESHAPE_COLMEANS);
+		HopsConv2Lops.put(OpOpDnn.UPDATE_EMA, org.apache.sysml.lops.DnnTransform.OperationTypes.UPDATE_EMA);
+		HopsConv2Lops.put(OpOpDnn.INV_VAR, org.apache.sysml.lops.DnnTransform.OperationTypes.INV_VAR);
+		HopsConv2Lops.put(OpOpDnn.BATCH_NORM2D_BACKWARD_DX, org.apache.sysml.lops.DnnTransform.OperationTypes.BATCH_NORM2D_BACKWARD_DX);
 	}
 
 	protected static final HashMap<Hop.Direction, org.apache.sysml.lops.PartialAggregate.DirectionTypes> HopsDirection2Lops;

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java b/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java
new file mode 100644
index 0000000..7c70b7b
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java
@@ -0,0 +1,378 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.hops.rewrite;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.function.Function;
+import java.util.function.Predicate;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.hops.AggUnaryOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.hops.Hop.AggOp;
+import org.apache.sysml.hops.Hop.Direction;
+import org.apache.sysml.hops.Hop.OpOp1;
+import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.Hop.OpOpDnn;
+import org.apache.sysml.hops.Hop.ReOrgOp;
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
+import org.apache.sysml.utils.Explain;
+
+/**
+ * Please see org.apache.sysml.hops.rewrite.RewriteGPUSpecificOps class for usage and design documentation.
+ */
+public class HopDagPatternMatcher {
+	static final HashSet<String> DEBUG_PATTERNS;
+	static {
+		// DEBUG_PATTERNS = new HashSet<>();
+		// DEBUG_PATTERNS.add("batchNormdX");
+		DEBUG_PATTERNS = null;
+	}
+	
+	// Predicates for the current HOP
+	List<HopPredicate> _predicates = new ArrayList<>();
+	// Child matchers
+	List<HopDagPatternMatcher> _children = new ArrayList<>();
+	private boolean _isLeaf = false;
+	
+	static boolean DEBUG_REWRITES = false; // This is set by HopPatternRewriter. Please use DEBUG_PATTERNS instead.
+	
+	// Simple utility for debugging the rewrites
+	public static class HopPredicate implements Predicate<Hop> {
+		final String _name;
+		final Function<Hop, Boolean> _pred;
+		public HopPredicate(String name, Function<Hop, Boolean> pred) {
+			_name = name;
+			_pred = pred;
+		}
+		@Override
+		public boolean test(Hop h) {
+			return _pred.apply(h);
+		}
+		@Override
+	    public String toString() {
+	        return _name;
+	    }
+	}
+	
+	/**
+	 * Adds a predicate to the pattern matcher
+	 * 
+	 * @param name name of the pattern for debugging
+	 * @param pred higher-order function that takes a hop as input and returns true if the pattern matches, false otherwise
+	 * @return this
+	 */
+	public HopDagPatternMatcher addPredicate(String name, Function<Hop, Boolean> pred) {
+		_predicates.add(new HopPredicate(name, pred));
+		return this;
+	}
+	
+	/**
+	 * Add child pattern matcher
+	 * @param children list of child matchers
+	 * @return this
+	 */
+	public HopDagPatternMatcher addChildMatcher(HopDagPatternMatcher... children) {
+		for(int i = 0; i < children.length; i++) {
+			_children.add(children[i]);
+		}
+		return this;
+	}
+	
+	/**
+	 * Get the HOP matched for the given leaf variable
+	 * @param varName variable name
+	 * @return matched HOP
+	 */
+	public Hop getMatchedHop(String varName) {
+		
+		if(matchedHops == null || !matchedHops.containsKey(varName)) {
+			throw new RuntimeException("Incorrect usage: the variable " + varName + " is not registered as input.");
+		}
+		return matchedHops.get(varName);
+	}
+	
+	/**
+	 * Return the value of the matched LiteralOp
+	 * 
+	 * @param varName variable name
+	 * @return the value of the LiteralOp 
+	 */
+	public double getLiteralValue(String varName) {
+		return OptimizerUtils.rEvalSimpleDoubleExpression(getMatchedHop(varName), new HashMap<>());
+	}
+	
+	@Override
+    public String toString() {
+        return _predicates.size() >= 1 ? _predicates.get(0).toString() : "";
+    }
+	
+	/**
+	 * Match the given HOP DAG
+	 * 
+	 * @param h root node of the HOP DAG 
+	 * @return true if HOP DAG matches
+	 */
+	public boolean matches(Hop h) {
+		visited.clear();
+		matchedHops.clear();
+		return matchHelper(this, h);
+	}
+	
+	private HashMap<String, Hop> matchedHops = new HashMap<>();
+	private String variableName;
+	private HashMap<HopDagPatternMatcher, Hop> visited = new HashMap<>(); // Map of matched hops
+	private boolean matchHelper(HopDagPatternMatcher root, Hop h) {
+		if(h == null) {
+			return false;
+		}
+		else if(_children.size() > 0 && h.getInput().size() < _children.size()) {
+			if(DEBUG_REWRITES) {
+				System.out.println("The expected number of children (" + _children.size() + ") didnot match the number of inputs (" + h.getInput().size() + ") " + this);
+			}
+			return false;
+		}
+		if(root.visited.containsKey(this)) {
+			Hop h1 = root.visited.get(this);
+			if(h == h1) {
+				if(DEBUG_REWRITES)
+					System.out.println("MATCHED: Early exit as the given HOP has been already matched by the matcher." + this); 
+				return true; // Early exit as the given HOP has been already matched by the matcher
+			}
+			else if(_isLeaf) {
+				if(h.getDataType() == h1.getDataType() && h.getDataType() == DataType.SCALAR) {
+					return OptimizerUtils.rEvalSimpleDoubleExpression(h, new HashMap<>()) == OptimizerUtils.rEvalSimpleDoubleExpression(h1, new HashMap<>());
+				}
+				return false; // Mismatched or unknown datatypes or matched with different hops
+			}
+		}
+		
+		for(HopPredicate p : _predicates) {
+			if(!p.test(h)) {
+				if(DEBUG_REWRITES) {
+					System.out.println("The predicate " + p.toString() + " failed.");
+				}
+				return false;
+			}
+		}
+		int index = 0;
+		for(HopDagPatternMatcher child : _children) {
+			if(!child.matchHelper(root, h.getInput().get(index))) {
+				return false;
+			}
+			index++;
+		}
+		if(_isLeaf) {
+			root.matchedHops.put(variableName, h);
+		}
+		
+		root.visited.put(this, h);
+		if(DEBUG_REWRITES)
+			System.out.println("MATCHED: " + this + " to " + Explain.explain(h));
+		return true;
+	}
+	
+
+	// Simple helper utilities for adding predicates
+	private HopDagPatternMatcher isScalar() {
+		return this.addPredicate("isScalar", h -> h.getDataType() == DataType.SCALAR);
+	}
+	private HopDagPatternMatcher isMatrix() {
+		return this.addPredicate("isMatrix", h -> h.getDataType() == DataType.MATRIX);
+	}
+	public HopDagPatternMatcher fitsOnGPU(double constant) {
+		return this.addPredicate("fitsOnGPU", h -> _fitsOnGPU(h, constant));
+	}
+	
+	// Factory methods:
+	public static HopDagPatternMatcher dummy = new HopDagPatternMatcher();
+	public static HopDagPatternMatcher rowMeans(HopDagPatternMatcher child1) {
+		return new HopDagPatternMatcher().addPredicate("rowMeans", h -> 
+			h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.MEAN && ((AggUnaryOp)h).getDirection() == Direction.Row)
+			.addChildMatcher(child1);
+	}
+	public static HopDagPatternMatcher rowVars(HopDagPatternMatcher child1) {
+		return new HopDagPatternMatcher().addPredicate("rowVars", h -> 
+			h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.VAR && ((AggUnaryOp)h).getDirection() == Direction.Row)
+			.addChildMatcher(child1);
+	}
+	public static HopDagPatternMatcher colVars(HopDagPatternMatcher child1) {
+		return new HopDagPatternMatcher().addPredicate("colVars", h -> 
+			h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.VAR && ((AggUnaryOp)h).getDirection() == Direction.Col)
+			.addChildMatcher(child1);
+	}
+	public static HopDagPatternMatcher leaf(String _variableName, DataType dt) {
+		HopDagPatternMatcher ret = new HopDagPatternMatcher();
+		ret._isLeaf = true;
+		ret.variableName = _variableName;
+		if(dt == DataType.MATRIX) {
+			return ret.isMatrix();
+		}
+		else if(dt == DataType.SCALAR) {
+			return ret.isScalar();
+		}
+		else if(dt == DataType.UNKNOWN) {
+			return ret;
+		}
+		else {
+			throw new DMLRuntimeException("Unsupported datatype in pattern matcher:" + dt.name());
+		}
+	}
+	public static HopDagPatternMatcher rowSums(HopDagPatternMatcher child1) {
+		return new HopDagPatternMatcher().addPredicate("rowSums", h -> 
+			h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.SUM && ((AggUnaryOp)h).getDirection() == Direction.Row)
+			.addChildMatcher(child1);
+	}
+	public static HopDagPatternMatcher colSums(HopDagPatternMatcher child1) {
+		return new HopDagPatternMatcher().addPredicate("colSums", h -> 
+			h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.SUM && ((AggUnaryOp)h).getDirection() == Direction.Col)
+			.addChildMatcher(child1);
+	}
+	public static HopDagPatternMatcher colMeans(HopDagPatternMatcher child1) {
+		return new HopDagPatternMatcher().addPredicate("colSums", h -> 
+			h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.MEAN && ((AggUnaryOp)h).getDirection() == Direction.Col)
+			.addChildMatcher(child1);
+	}
+	public static HopDagPatternMatcher matrix(HopDagPatternMatcher X, HopDagPatternMatcher rows, HopDagPatternMatcher cols) {
+		return new HopDagPatternMatcher().addPredicate("matrix_reshape", h -> HopRewriteUtils.isReorg(h, ReOrgOp.RESHAPE))
+				.addChildMatcher(X, rows, cols);
+	}
+	public static HopDagPatternMatcher matrix(double X, HopDagPatternMatcher rows, HopDagPatternMatcher cols) {
+		return new HopDagPatternMatcher().addPredicate("matrix_datagen", h -> HopRewriteUtils.isDataGenOpWithConstantValue(h, X))
+				.addChildMatcher(rows, cols);
+	}
+	public static HopDagPatternMatcher matrix(double X, HopDagPatternMatcher rows, long cols) {
+		return new HopDagPatternMatcher().addPredicate("matrix_datagen", h -> HopRewriteUtils.isDataGenOpWithConstantValue(h, X) && 
+				h.getDim2() == cols)
+				.addChildMatcher(rows, dummy);
+	}
+	public static HopDagPatternMatcher matrix(double X, long rows, HopDagPatternMatcher cols) {
+		return new HopDagPatternMatcher().addPredicate("matrix_datagen", h -> HopRewriteUtils.isDataGenOpWithConstantValue(h, X) && 
+				h.getDim1() == rows)
+				.addChildMatcher(dummy, cols);
+	}
+	public static HopDagPatternMatcher matrix(double X, long rows, long cols) {
+		return new HopDagPatternMatcher().addPredicate("matrix_datagen", h -> HopRewriteUtils.isDataGenOpWithConstantValue(h, X) && 
+				h.getDim1() == rows && h.getDim2() == cols)
+				.addChildMatcher(dummy, dummy);
+	}
+	public static HopDagPatternMatcher bias_add(HopDagPatternMatcher child1, HopDagPatternMatcher child2) {
+		return new HopDagPatternMatcher().addPredicate("bias_add", h -> HopRewriteUtils.isDnn(h, OpOpDnn.BIASADD))
+				.addChildMatcher(child1, child2);
+	}
+	public static HopDagPatternMatcher bias_multiply(HopDagPatternMatcher child1, HopDagPatternMatcher child2) {
+		return new HopDagPatternMatcher().addPredicate("bias_multiply", h -> HopRewriteUtils.isDnn(h, OpOpDnn.BIASMULT))
+				.addChildMatcher(child1, child2);
+	}
+	public static HopDagPatternMatcher unaryMinus(HopDagPatternMatcher child) {
+		return new HopDagPatternMatcher().addPredicate("unaryMinus", h -> HopRewriteUtils.isBinary(h, OpOp2.MINUS)
+				&& HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), 0))
+				.addChildMatcher(dummy, child);
+	}
+	public static HopDagPatternMatcher sqrt(HopDagPatternMatcher child) {
+		return new HopDagPatternMatcher().addPredicate("sqrt", h -> HopRewriteUtils.isUnary(h, OpOp1.SQRT))
+				.addChildMatcher(child);
+	}
+	public static HopDagPatternMatcher div(HopDagPatternMatcher child1, HopDagPatternMatcher child2) {
+		return new HopDagPatternMatcher().addPredicate("div", h -> HopRewriteUtils.isBinary(h, OpOp2.DIV))
+				.addChildMatcher(child1, child2);
+	}
+	public static HopDagPatternMatcher div(double child1, HopDagPatternMatcher child2) {
+		return new HopDagPatternMatcher().addPredicate("div", h -> HopRewriteUtils.isBinary(h, OpOp2.DIV) && 
+				HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), child1))
+				.addChildMatcher(dummy, child2);
+	}
+	public static HopDagPatternMatcher div(HopDagPatternMatcher child1, double child2) {
+		return new HopDagPatternMatcher().addPredicate("div", h -> HopRewriteUtils.isBinary(h, OpOp2.DIV) && 
+				HopRewriteUtils.isLiteralOfValue(h.getInput().get(1), child2))
+				.addChildMatcher(child1, dummy);
+	}
+	
+	public static HopDagPatternMatcher pow(HopDagPatternMatcher child1, HopDagPatternMatcher child2) {
+		return new HopDagPatternMatcher().addPredicate("pow", h -> HopRewriteUtils.isBinary(h, OpOp2.POW))
+				.addChildMatcher(child1, child2);
+	}
+	public static HopDagPatternMatcher pow(HopDagPatternMatcher child1, double child2) {
+		return new HopDagPatternMatcher().addPredicate("pow", h -> HopRewriteUtils.isBinary(h, OpOp2.POW) && 
+				HopRewriteUtils.isLiteralOfValue(h.getInput().get(1), child2))
+				.addChildMatcher(child1, dummy);
+	}
+	private static boolean matchDimensions(Hop h1, Hop h2) {
+		return h1.getDim1() == h2.getDim1() && h1.getDim2() == h2.getDim2();
+	}
+	// This is used to differentiate between matrix-matrix and matrix-vector operations.
+	public static HopDagPatternMatcher mm_plus(HopDagPatternMatcher child1, HopDagPatternMatcher child2) {
+		return new HopDagPatternMatcher().addPredicate("plus", h -> HopRewriteUtils.isBinary(h, OpOp2.PLUS)
+				&& matchDimensions(h.getInput().get(0), h.getInput().get(1)))
+				.addChildMatcher(child1, child2);
+	}
+	public static HopDagPatternMatcher plus(HopDagPatternMatcher child1, HopDagPatternMatcher child2) {
+		return new HopDagPatternMatcher().addPredicate("plus", h -> HopRewriteUtils.isBinary(h, OpOp2.PLUS))
+				.addChildMatcher(child1, child2);
+	}
+	public static HopDagPatternMatcher plus(double child1, HopDagPatternMatcher child2) {
+		return new HopDagPatternMatcher().addPredicate("plus", h -> HopRewriteUtils.isBinary(h, OpOp2.PLUS) && 
+				HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), child1))
+				.addChildMatcher(dummy, child2);
+	}
+	public static HopDagPatternMatcher plus(HopDagPatternMatcher child1, double child2) {
+		return new HopDagPatternMatcher().addPredicate("plus", h -> HopRewriteUtils.isBinary(h, OpOp2.PLUS) && 
+				HopRewriteUtils.isLiteralOfValue(h.getInput().get(1), child2))
+				.addChildMatcher(child1, dummy);
+	}
+	public static HopDagPatternMatcher minus(HopDagPatternMatcher child1, HopDagPatternMatcher child2) {
+		return new HopDagPatternMatcher().addPredicate("minus", h -> HopRewriteUtils.isBinary(h, OpOp2.MINUS))
+				.addChildMatcher(child1, child2);
+	}
+	public static HopDagPatternMatcher minus(double child1, HopDagPatternMatcher child2) {
+		return new HopDagPatternMatcher().addPredicate("minus", h -> HopRewriteUtils.isBinary(h, OpOp2.MINUS) && 
+				HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), child1))
+				.addChildMatcher(dummy, child2);
+	}
+	public static HopDagPatternMatcher minus(HopDagPatternMatcher child1, double child2) {
+		return new HopDagPatternMatcher().addPredicate("minus", h -> HopRewriteUtils.isBinary(h, OpOp2.MINUS) && 
+				HopRewriteUtils.isLiteralOfValue(h.getInput().get(1), child2))
+				.addChildMatcher(child1, dummy);
+	}
+	public static HopDagPatternMatcher mult(HopDagPatternMatcher child1, HopDagPatternMatcher child2) {
+		return new HopDagPatternMatcher().addPredicate("mult", h -> HopRewriteUtils.isBinary(h, OpOp2.MULT))
+				.addChildMatcher(child1, child2);
+	}
+	public static HopDagPatternMatcher mult(double child1, HopDagPatternMatcher child2) {
+		return new HopDagPatternMatcher().addPredicate("mult", h -> HopRewriteUtils.isBinary(h, OpOp2.MULT) && 
+				HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), child1))
+				.addChildMatcher(dummy, child2);
+	}
+	public static HopDagPatternMatcher mult(HopDagPatternMatcher child1, double child2) {
+		return new HopDagPatternMatcher().addPredicate("mult", h -> HopRewriteUtils.isBinary(h, OpOp2.MULT) && 
+				HopRewriteUtils.isLiteralOfValue(h.getInput().get(1), child2))
+				.addChildMatcher(child1, dummy);
+	}
+	private static boolean _fitsOnGPU(Hop h, double multiplier) {
+		double memEst = multiplier*h.getMemEstimate();
+		return ConfigurationManager.isGPU() && h.dimsKnown() && OptimizerUtils.isMemoryBasedOptLevel() &&
+				memEst < OptimizerUtils.getLocalMemBudget() && memEst < GPUContextPool.initialGPUMemBudget();
+	}
+}
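
For illustration, here is a minimal sketch of how the factory methods above compose into a pattern and how matched leaves are read back. The pattern shown (channel sums) and the surrounding variable root are assumptions for this sketch; only the HopDagPatternMatcher methods used are from the class above.

    // Hypothetical usage sketch of the HopDagPatternMatcher DSL.
    // Pattern: rowSums(matrix(colSums(X), rows=C, cols=HW))
    HopDagPatternMatcher X  = HopDagPatternMatcher.leaf("X",  DataType.MATRIX);
    HopDagPatternMatcher C  = HopDagPatternMatcher.leaf("C",  DataType.SCALAR);
    HopDagPatternMatcher HW = HopDagPatternMatcher.leaf("HW", DataType.SCALAR);
    HopDagPatternMatcher channelSums =
        HopDagPatternMatcher.rowSums(
            HopDagPatternMatcher.matrix(HopDagPatternMatcher.colSums(X), C, HW));

    if(channelSums.matches(root)) {                       // root: some Hop under inspection (assumed)
        Hop input = channelSums.getMatchedHop("X");       // leaf bound during matching
        double numChannels = channelSums.getLiteralValue("C");
        // ... build a replacement HOP from the matched pieces ...
    }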

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/rewrite/HopPatternRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopPatternRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/HopPatternRewriter.java
new file mode 100644
index 0000000..02472ed
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopPatternRewriter.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.hops.rewrite;
+
+import java.util.function.Function;
+
+import org.apache.sysml.hops.Hop;
+
+/**
+ * This class is used with HopRewriteRuleWithPatternMatcher to implement the following pattern matching logic:
+ * ArrayList<HopPatternRewriter> patternRewriters =  getPatternRewriter();
+ * for(HopPatternRewriter patternRewriter : patternRewriters) {
+ *   hi = patternRewriter.rewrite(hi);
+ * }
+ * 
+ * Please see org.apache.sysml.hops.rewrite.RewriteGPUSpecificOps class for usage and design documentation.
+ */
+public class HopPatternRewriter {
+	private final HopDagPatternMatcher _matcher;
+	private final Function<Hop, Hop> _replacer;
+	private final String _name;
+	public HopPatternRewriter(String name, HopDagPatternMatcher matcher, Function<Hop, Hop> replacer) {
+		_name = name;
+		_matcher = matcher;
+		_replacer = replacer;
+	}
+	
+	public Hop rewrite(Hop root) {
+		boolean printMessage = HopDagPatternMatcher.DEBUG_PATTERNS != null && HopDagPatternMatcher.DEBUG_PATTERNS.contains(_name);
+		if(printMessage) {
+			HopDagPatternMatcher.DEBUG_REWRITES = true;
+			System.out.println("-----------------------------------");
+			System.out.println(org.apache.sysml.utils.Explain.explain(root));
+		}
+		if(_matcher.matches(root)) {
+			Hop newHop = _replacer.apply(root);
+			if(printMessage) {
+				if(newHop == root)
+					System.out.println("Initial pattern match for " + _name + " succeeded but replacer returned the same HOP.");
+				else
+					System.out.println("Pattern match for " + _name + " succeeded.");
+				HopDagPatternMatcher.DEBUG_REWRITES = false;
+				System.out.println("-----------------------------------");
+			}
+			return newHop;
+		}
+		else {
+			if(printMessage) {
+				System.out.println("Pattern match for " + _name + " failed.");
+				HopDagPatternMatcher.DEBUG_REWRITES = false;
+				System.out.println("-----------------------------------");
+			}
+			return root;
+		}
+	}
+}
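
As a sketch of the intended usage (channelSumsMatcher is an assumed HopDagPatternMatcher carried over from the previous sketch; the variable names mirror those registered via leaf()), a rewriter pairs a matcher with a replacer lambda:

    // Hypothetical: pair a matcher with a replacer that builds the fused op.
    HopPatternRewriter channelSumsRewriter = new HopPatternRewriter("channelSums",
        channelSumsMatcher,
        hi -> {
            Hop newHop = HopRewriteUtils.createDnnOp(channelSumsMatcher,
                Hop.OpOpDnn.CHANNEL_SUMS, "X", "C", "HW");
            return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
        });
    Hop result = channelSumsRewriter.rewrite(someHop);    // returns the replacement or the original hop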

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteRuleWithPatternMatcher.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteRuleWithPatternMatcher.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteRuleWithPatternMatcher.java
new file mode 100644
index 0000000..854eca3
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteRuleWithPatternMatcher.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.hops.rewrite;
+
+import java.util.ArrayList;
+
+import org.apache.sysml.hops.Hop;
+
+/**
+ * Simple utility class that implements generic structure for HopRewriteRule.
+ * Please see org.apache.sysml.hops.rewrite.RewriteGPUSpecificOps class for usage and design documentation.
+ */
+public abstract class HopRewriteRuleWithPatternMatcher extends HopRewriteRule {
+	
+	public abstract ArrayList<HopPatternRewriter> getPatternRewriter();
+	
+	@Override
+	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
+		if( roots == null )
+			return roots;
+
+		//one pass rewrite-descend (rewrite created pattern)
+		for( int i = 0; i < roots.size(); i++ )
+			applyRules(roots, roots.get(i), false );
+		Hop.resetVisitStatus(roots, true);
+
+		//one pass descend-rewrite (for rollup) 
+		for( int i = 0; i < roots.size(); i++ )
+			applyRules(roots, roots.get(i), true );
+		Hop.resetVisitStatus(roots, true);
+		
+		return roots;
+	}
+
+	@Override
+	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
+		if( root == null )
+			return root;
+		
+		//one pass rewrite-descend (rewrite created pattern)
+		applyRules(null, root, false );
+		
+		root.resetVisitStatus();
+		
+		//one pass descend-rewrite (for rollup) 
+		applyRules(null, root, true );
+		
+		return root;
+	}
+	
+	/**
+	 * Apply rules
+	 * 
+	 * @param roots root operators
+	 * @param hop high-level operator
+	 * @param descendFirst true if recursively process children first
+	 */
+	private void applyRules(ArrayList<Hop> roots, Hop hop, boolean descendFirst) 
+	{
+		if(hop.isVisited())
+			return;
+		
+		//recursively process children
+		for( int i=0; i<hop.getInput().size(); i++) {
+			Hop hi = hop.getInput().get(i);
+			
+			//process children recursively first (to allow roll-up)
+			if( descendFirst )
+				applyRules(roots, hi, descendFirst); //see below
+			
+			ArrayList<HopPatternRewriter> patternRewriters =  getPatternRewriter();
+			for(HopPatternRewriter patternRewriter : patternRewriters) {
+				hi = patternRewriter.rewrite(hi);
+			}
+			
+			if( !descendFirst )
+				applyRules(roots, hi, descendFirst);
+		}
+
+		hop.setVisited();
+	}
+}
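
For illustration, a minimal subclass (class and field names, as well as MY_MATCHER and MY_REPLACER, are hypothetical) only needs to supply the list of pattern rewriters; DAG traversal and visit handling are inherited from the base class above:

    // Hypothetical subclass: the base class walks the DAG and applies each rewriter.
    public class MyGPURewrites extends HopRewriteRuleWithPatternMatcher {
        private static final ArrayList<HopPatternRewriter> _rewriters = new ArrayList<>();
        static {
            _rewriters.add(new HopPatternRewriter("myPattern", MY_MATCHER, MY_REPLACER));
        }
        @Override
        public ArrayList<HopPatternRewriter> getPatternRewriter() {
            return _rewriters;
        }
    }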

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 271142d..2351f5f 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -719,6 +719,26 @@ public class HopRewriteUtils
 		return ternOp;
 	}
 	
+	public static DnnOp createDnnOp(OpOpDnn op, Hop... hops) {
+		ArrayList<Hop> inHops = new ArrayList<Hop>();
+		for(Hop h : hops) {
+			inHops.add(h);
+		}
+		return  new DnnOp("tmp", DataType.MATRIX, ValueType.DOUBLE,
+				op, inHops);
+	}
+	
+	public static DnnOp createDnnOp(HopDagPatternMatcher matcher, OpOpDnn op, String... varNames) {
+		ArrayList<Hop> inHops = new ArrayList<Hop>();
+		for(String v : varNames) {
+			inHops.add(matcher.getMatchedHop(v));
+		}
+		return  new DnnOp("tmp", DataType.MATRIX, ValueType.DOUBLE,
+				op, inHops);
+	}
+	
+	
+	
 	public static void setOutputParameters( Hop hop, long rlen, long clen, int brlen, int bclen, long nnz ) {
 		hop.setDim1( rlen );
 		hop.setDim2( clen );


[2/3] systemml git commit: [SYSTEMML-445] Removed batch_norm builtin functions

Posted by ni...@apache.org.
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
index acf2e48..53d368b 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
@@ -20,811 +20,288 @@
 package org.apache.sysml.hops.rewrite;
 
 import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
+import java.util.function.Function;
 
-import org.apache.sysml.conf.ConfigurationManager;
-import org.apache.sysml.hops.AggUnaryOp;
-import org.apache.sysml.hops.BinaryOp;
-import org.apache.sysml.hops.FunctionOp;
 import org.apache.sysml.hops.Hop;
-import org.apache.sysml.hops.FunctionOp.FunctionType;
-import org.apache.sysml.hops.Hop.AggOp;
-import org.apache.sysml.hops.Hop.DataOpTypes;
-import org.apache.sysml.hops.Hop.Direction;
-import org.apache.sysml.hops.Hop.OpOp1;
-import org.apache.sysml.hops.Hop.OpOp2;
 import org.apache.sysml.hops.Hop.OpOpDnn;
-import org.apache.sysml.hops.Hop.ReOrgOp;
-import org.apache.sysml.hops.DataOp;
-import org.apache.sysml.hops.LiteralOp;
-import org.apache.sysml.hops.DnnOp;
-import org.apache.sysml.hops.OptimizerUtils;
-import org.apache.sysml.hops.ReorgOp;
-import org.apache.sysml.hops.UnaryOp;
-import org.apache.sysml.parser.DMLProgram;
-import org.apache.sysml.parser.Expression.DataType;
-import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
+import static org.apache.sysml.hops.rewrite.HopDagPatternMatcher.*;
+import static org.apache.sysml.parser.Expression.DataType.MATRIX;
+import static org.apache.sysml.parser.Expression.DataType.SCALAR;
+
 
 /*
- * This class contains GPU-specific rewrites for following patterns:
+ * -------------------------------------------------------------------------
+ * Design documentation for hop rewrite rules that use HopDagPatternMatcher:
+ * -------------------------------------------------------------------------
+ * 
+ * Typical (but not all) hop rewrite rules have following structure:
+ * 1. Rules are grouped together in different Java classes and added in org.apache.sysml.hops.rewrite.ProgramRewriter.
+ * 
+ * 2. Each rule class inherits from HopRewriteRule and implements the rewriteHopDAG method. The other class of rewrite rules, StatementBlockRewriteRule, is not covered by this approach.
+ * 
+ * 3. The structure of rewriteHopDAG is common across HopRewriteRule subclasses and usually has the following pattern:
+ *  if(root of the given HOP DAG matches certain pattern) {
+ *    HopRewriteUtils.rewireAllParentChildReferences(root, newRoot)
+ *  }
+ *  else root
+ * 
+ * 4. To avoid redundancy, the above logic is implemented in the abstract class HopRewriteRuleWithPatternMatcher:
+ *  ArrayList<HopPatternRewriter> patternRewriters =  getPatternRewriter();
+ *    for(HopPatternRewriter patternRewriter : patternRewriters) {
+ *      hi = patternRewriter.rewrite(hi);
+ *  }
+ * 
+ * 5. The developer has to inherit from HopRewriteRuleWithPatternMatcher, which implements the above logic,
+ * and implement getPatternRewriter(), which returns an ArrayList<HopPatternRewriter>.
  * 
- * 1. batchNormTest: applied when mode="test" in batch normalization nn layer.
- * norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps))
- * hi = bias_add(bias_multiply(norm, gamma), beta)
+ * 6. Since the HOP patterns do not change during execution, it is convenient to store them in a static variable: 
+ * ArrayList<HopPatternRewriter> _rewriters
  * 
- * 2. channelSum:
- * output = rowSums(matrix(colSums(x), rows=numChannels, cols=imgSize*imgSize))
+ * 7. The replacer part of each pattern rewriter invokes the helper methods in HopRewriteUtils to create a newRoot. For example: HopRewriteUtils.createDnnOp
  * 
- * 3. batchNormTrain: applied when mode="train" in batch normalization nn layer.
- * This rewrite is only enabled if none of the outputs are persistent writes as it assumes that 
- * FunctionOp will introduce a transient writes. This rewrite replaces the existing outputs of the matched pattern with transient reads.
+ * 8. The below DSL would be more readable if implemented with Scala's operator overloading, but that adds a dependency on the Scala library 
+ * (specifically, Scala uses scala.Function1 for implementing operator overloading).
+ * Hence, to minimize dependencies, the DSL is implemented using static methods in the HopDagPatternMatcher class.
+ * We can revisit this if we plan to add Scala as a hard dependency in SystemML. 
+ * 
+ * 9. The matcher part of each pattern rewriter uses the DSL implemented in HopDagPatternMatcher to improve readability.
+ * - The DSL mentioned above follows DML syntax, which makes it convenient for an external contributor to understand and modify the HOP rewrites.
+ * - It is important to note that the developer has to apply the same scoping rules as SystemML.
+ * - To create a newRoot HOP, it is important to have a mechanism to extract the leaves of the matched pattern. This is implemented
+ * by using the leaf() method.
+ * - Often, it is important to create a new HOP only if it can fit in memory. For GPU, one can use the fitsOnGPU(multiplier) helper method.
  * 
  */
-public class RewriteGPUSpecificOps extends HopRewriteRule {
-
-	private static int _seq = 1;
-	
-	@Override
-	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
-		if( roots == null )
-			return roots;
-
-		//one pass rewrite-descend (rewrite created pattern)
-		for( int i = 0; i < roots.size(); i++ )
-			rule_GPUKernels(roots, roots.get(i), false );
-		Hop.resetVisitStatus(roots, true);
-
-		//one pass descend-rewrite (for rollup) 
-		for( int i = 0; i < roots.size(); i++ )
-			rule_GPUKernels(roots, roots.get(i), true );
-		Hop.resetVisitStatus(roots, true);
+public class RewriteGPUSpecificOps extends HopRewriteRuleWithPatternMatcher {
+	// -------------------------------------------------------------------------------------------
+	
+	private static HopDagPatternMatcher util_channel_sums(HopDagPatternMatcher X, HopDagPatternMatcher C, HopDagPatternMatcher HW) {
+		// rowSums(matrix(colSums(X), rows=C, cols=HW))
+		return rowSums(matrix(colSums(X), C, HW));
+	}
+	
+	// Pattern 1:
+	private static final HopDagPatternMatcher _batchNormdX;
+	static {
+		HopDagPatternMatcher C = leaf("C",  SCALAR);
+		HopDagPatternMatcher HW = leaf("HW",  SCALAR);
+		HopDagPatternMatcher CHW = leaf("CHW",  SCALAR);
+		HopDagPatternMatcher cache_inv_var = leaf("cache_inv_var", MATRIX);
+		HopDagPatternMatcher dout = leaf("dout", MATRIX);
+		HopDagPatternMatcher gamma = leaf("gamma", MATRIX);
+		HopDagPatternMatcher X = leaf("X", MATRIX);
+		HopDagPatternMatcher mean = leaf("mean", MATRIX);
 		
-		return roots;
-	}
-
-	@Override
-	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
-		if( root == null )
-			return root;
-		
-		//one pass rewrite-descend (rewrite created pattern)
-		rule_GPUKernels(null, root, false );
-		
-		root.resetVisitStatus();
-		
-		//one pass descend-rewrite (for rollup) 
-		rule_GPUKernels(null, root, true );
+		HopDagPatternMatcher centered = bias_add(X, unaryMinus(mean));
 		
-		return root;
-	}
-	
-	/**
-	 * Fuse the kernel
-	 * 
-	 * @param roots root operators
-	 * @param hop high-level operator
-	 * @param descendFirst true if recursively process children first
-	 */
-	private void rule_GPUKernels(ArrayList<Hop> roots, Hop hop, boolean descendFirst) 
-	{
-		if(hop.isVisited())
-			return;
-		
-		//recursively process children
-		for( int i=0; i<hop.getInput().size(); i++) {
-			Hop hi = hop.getInput().get(i);
-			
-			//process childs recursively first (to allow roll-up)
-			if( descendFirst )
-				rule_GPUKernels(roots, hi, descendFirst); //see below
-			
-			if(roots != null) {
-				//hi = batchNormTrain(roots, hop, hi, i);
-			}
-			hi = batchNormTest(hop, hi, i); 
-			hi = channelSums(hop, hi, i); 
-			hi = updateNesterovX(hop, hi, i);
-	
-			if( !descendFirst )
-				rule_GPUKernels(roots, hi, descendFirst);
-		}
-
-		hop.setVisited();
-	}
-	
-	private static boolean isBiasAdd(Hop h) {
-		return HopRewriteUtils.isDnn(h, OpOpDnn.BIASADD);
-	}
-	
-	private static boolean isBiasMultiply(Hop h) {
-		return HopRewriteUtils.isDnn(h, OpOpDnn.BIASMULT);
-	}
-	
-	private static boolean fitsOnGPU(Hop h, double multiplier) {
-		double memEst = multiplier*h.getMemEstimate();
-		return ConfigurationManager.isGPU() && h.dimsKnown() && OptimizerUtils.isMemoryBasedOptLevel() &&
-				memEst < OptimizerUtils.getLocalMemBudget() && memEst < GPUContextPool.initialGPUMemBudget();
-	}
-	
-	private static boolean fitsOnGPU(ArrayList<Hop> inputHops, boolean isFirstSameSizeAsOutput) {
-		return fitsOnGPU(inputHops, isFirstSameSizeAsOutput, 0);
-	}
-	
-	private static boolean fitsOnGPU(ArrayList<Hop> inputHops, boolean isFirstSameSizeAsOutput, long additionalBytes) {
-		double memEst = additionalBytes;
-		boolean isFirst = true;
-		for(Hop h : inputHops) {
-			double est = h.getMemEstimate();
-			if(est == OptimizerUtils.INVALID_SIZE) {
-				return false;
-			}
-			else if(isFirst && isFirstSameSizeAsOutput) {
-				isFirst = false;
-				memEst += 2*est;
-			}
-			else {
-				memEst += est;
-			}
-		}
-		return ConfigurationManager.isGPU() && OptimizerUtils.isMemoryBasedOptLevel() &&
-				memEst < OptimizerUtils.getLocalMemBudget() && memEst < GPUContextPool.initialGPUMemBudget();
-	}
-	
-	private static boolean hasFirstInput(Hop h) {
-		return !(h == null || h.getInput() == null || h.getInput().size() < 1);
-	}
-	
-	private static Hop getFirstInput(Hop h) {
-		if(h == null || h.getInput() == null || h.getInput().size() < 1) {
-			throw new RuntimeException("No input available for " + h);
-		}
-		return h.getInput().get(0);
-	}
-	
-	private static boolean hasSecondInput(Hop h) {
-		return !(h == null || h.getInput() == null || h.getInput().size() < 2);
-	}
-	
-	private static Hop getSecondInput(Hop h) {
-		if(h == null || h.getInput() == null || h.getInput().size() < 2) {
-			throw new RuntimeException("Expected atleast two inputs for " + h);
+		// dnorm = bias_multiply(dout, gamma)  # shape (N, C*Hin*Win)
+		HopDagPatternMatcher dnorm = bias_multiply(dout, gamma);
+		// dmean_norm_branch = util::channel_sums(bias_multiply(dnorm, -cache_inv_var), C, Hin, Win)
+		HopDagPatternMatcher dmean_norm_branch = util_channel_sums(bias_multiply(dnorm, unaryMinus(cache_inv_var)), C, HW) ;
+		// dvar = util::channel_sums((-1/2) * bias_multiply(centered, cache_inv_var^3) * dnorm,
+		//      C, Hin, Win)  # shape (C, 1)
+		HopDagPatternMatcher dvar = util_channel_sums(mult(mult(-0.5, bias_multiply(centered, pow(cache_inv_var, 3))), dnorm),  C, HW);
+		// dmean_var_branch = util::channel_sums((-2*oneByN*oneByHW) * centered, C, Hin, Win) * dvar
+		HopDagPatternMatcher dmean_var_branch =
+			mult(util_channel_sums(mult(leaf("const3", SCALAR), centered), C, HW), dvar);
+		// dX_norm_branch = bias_multiply(dnorm, cache_inv_var)
+		HopDagPatternMatcher dX_norm_branch = bias_multiply(dnorm, cache_inv_var);
+		// dX_mean_branch = (oneByN*oneByHW) * bias_add(matrix(0, rows=1, cols=C*Hin*Win), dmean)
+		HopDagPatternMatcher dX_mean_branch = mult(leaf("const1", SCALAR), bias_add(matrix(0, 1, CHW), 
+				plus(dmean_norm_branch, dmean_var_branch) ));
+		// dX_var_branch = (2*oneByN*oneByHW) * bias_multiply(centered, dvar)
+		HopDagPatternMatcher dX_var_branch = mult(leaf("const2", SCALAR), bias_multiply(centered, dvar));
+		_batchNormdX = plus(plus(dX_norm_branch, dX_mean_branch), dX_var_branch).fitsOnGPU(2);
+	}
+	private static final Function<Hop, Hop> _batchNormdXReplacer = hi -> {
+		// double CHW = _batchNormdX.getLiteralValue("CHW");
+		double HW = _batchNormdX.getLiteralValue("HW");
+		double C = _batchNormdX.getLiteralValue("C");
+		double const1 = _batchNormdX.getLiteralValue("const1"); // (oneByN*oneByHW)
+		double const2 = _batchNormdX.getLiteralValue("const2"); // (2*oneByN*oneByHW)
+		double const3 = _batchNormdX.getLiteralValue("const3"); // (-2*oneByN*oneByHW)
+		if(2*const1 == const2 && const3 == -const2 && 
+			hasSameDimensions(_batchNormdX.getMatchedHop("gamma"), _batchNormdX.getMatchedHop("mean")) &&
+			hasSameDimensions(_batchNormdX.getMatchedHop("gamma"), _batchNormdX.getMatchedHop("cache_inv_var")) &&
+			_batchNormdX.getMatchedHop("X").getDim2() == C*HW &&
+			checkDimensions(_batchNormdX.getMatchedHop("gamma"), (long)C, 1)) {
+			LOG.debug("Applied batchNormdX rewrite.");
+			Hop newHop = HopRewriteUtils.createDnnOp(_batchNormdX, OpOpDnn.BATCH_NORM2D_BACKWARD_DX, 
+					"X", "dout", "gamma", "mean", "cache_inv_var");
+			return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
 		}
-		return h.getInput().get(1);
-	}
-	
-	private static Hop getThirdInput(Hop h) {
-		if(h == null || h.getInput() == null || h.getInput().size() < 3) {
-			throw new RuntimeException("Expected atleast three inputs for " + h);
-		}
-		return h.getInput().get(2);
-	}
-	
-	private static boolean isUnaryMinus(Hop h) {
-		return HopRewriteUtils.isBinary(h, OpOp2.MINUS)
-			&& HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), 0);
-	}
-	
-	private static boolean isOneDivideBySqrt(Hop h) {
-		return HopRewriteUtils.isBinary(h, OpOp2.DIV)
-			&& HopRewriteUtils.isUnary(h.getInput().get(1), OpOp1.SQRT)
-			&& HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), 1);
-	}
-	
-	private static Hop channelSums(Hop parent, Hop hi, int pos) {
-		if(hi instanceof AggUnaryOp) {
-			AggUnaryOp hop = (AggUnaryOp) hi;
-			// output = rowSums(matrix(colSums(x), rows=numChannels, cols=imgSize*imgSize))
-			if( hop.getOp() == AggOp.SUM && hop.getDirection() == Direction.Row
-				&& HopRewriteUtils.isReorg(hop.getInput().get(0), ReOrgOp.RESHAPE) ) {
-				Hop colSumsInput = hop.getInput().get(0).getInput().get(0);
-				if(colSumsInput instanceof AggUnaryOp && ((AggUnaryOp)colSumsInput).getOp() == AggOp.SUM && ((AggUnaryOp)colSumsInput).getDirection() == Direction.Col) {
-					ArrayList<Hop> inHops = new ArrayList<Hop>();
-					inHops.add(colSumsInput.getInput().get(0));
-					long numChannels = Hop.computeSizeInformation(hop.getInput().get(0).getInput().get(1));
-					long HW = Hop.computeSizeInformation(hop.getInput().get(0).getInput().get(2));
-					if(numChannels > 0 && HW > 0 && fitsOnGPU(inHops, false, numChannels*8)) {
-						inHops.add(new LiteralOp(numChannels));
-						inHops.add(new LiteralOp(HW));
-						LOG.debug("Applied channelSums rewrite.");
-						Hop newHop = new DnnOp(hi.getName(), hi.getDataType(), hi.getValueType(),
-								OpOpDnn.CHANNEL_SUMS, inHops);
-						return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
-					}
-				}
-			}
+		else if(DEBUG_REWRITES) {
+			System.out.println("Couldnot apply batchNormdX rewrite."); 
+			System.out.println((2*const1) + " == " + const2 + " && " + const3 + "== -" + const2 
+			+ " && " + hasSameDimensions(_batchNormdX.getMatchedHop("gamma"), _batchNormdX.getMatchedHop("mean")) +  " && " + 
+			hasSameDimensions(_batchNormdX.getMatchedHop("gamma"), _batchNormdX.getMatchedHop("cache_inv_var")) + " && " +
+			_batchNormdX.getMatchedHop("X").getDim2() + " == " + C + "*" + HW  + " && " +
+			checkDimensions(_batchNormdX.getMatchedHop("gamma"), (long)C, 1));
 		}
 		return hi;
-	}
-	
-	private static boolean isRowMeans(Hop h) {
-		return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.MEAN && ((AggUnaryOp)h).getDirection() == Direction.Row; 
-	}
-	
-	private static boolean isRowVars(Hop h) {
-		return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.VAR && ((AggUnaryOp)h).getDirection() == Direction.Row; 
-	}
-	
-	private static boolean isRowVars(Hop h, Hop childHop) {
-		return isRowVars(h) && getFirstInput(h) == childHop; 
-	}
-	
-	private static boolean isColMeans(Hop h) {
-		return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.MEAN && ((AggUnaryOp)h).getDirection() == Direction.Col; 
-	}
-	
-	private static boolean isColVars(Hop h) {
-		return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.VAR && ((AggUnaryOp)h).getDirection() == Direction.Col; 
-	}
-	
-	private static boolean isReshape(Hop h) {
-		return h instanceof ReorgOp && ((ReorgOp)h).getOp() == ReOrgOp.RESHAPE;
-	}
-	
-	private static boolean isReshape(Hop h, long expectedRows, long expectedCols) {
-		return h instanceof ReorgOp && ((ReorgOp)h).getOp() == ReOrgOp.RESHAPE &&
-				Hop.computeSizeInformation(getSecondInput(h)) == expectedRows && 
-				Hop.computeSizeInformation(getThirdInput(h)) == expectedCols;
-	}
-	
-	private static boolean isBinaryAdd(Hop h) {
-		return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.PLUS;
-	}
-	
-	private static boolean isBinaryMSAdd(Hop h, double expectedValue) {
-		return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.PLUS 
-				&& getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.SCALAR
-				&& OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(h), new HashMap<>()) == expectedValue;
-	}
-	
-	private static boolean isBinaryMMAdd(Hop h) {
-		return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.PLUS 
-				&& getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.MATRIX;
-	}
-	
-	private static boolean isBinaryMMMinus(Hop h) {
-		return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MINUS 
-				&& getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.MATRIX;
-	}
-	
-	private static boolean isBinaryMSMult(Hop h, double expectedValue) {
-		return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MULT 
-				&& getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.SCALAR
-				&& OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(h), new HashMap<>()) == expectedValue;
-	}
-	
-	private static boolean isBinarySSMinus(Hop h) {
-		return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MINUS 
-				&& getFirstInput(h).getDataType() == DataType.SCALAR && getSecondInput(h).getDataType() == DataType.SCALAR;
-	}
-	
-	private static boolean isBinarySSDiv(Hop h) {
-		return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.DIV 
-				&& getFirstInput(h).getDataType() == DataType.SCALAR && getSecondInput(h).getDataType() == DataType.SCALAR;
-	}
-	
-	private static boolean isBinarySMDiv(Hop h, double expectedValue) {
-		return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.DIV 
-				&& getFirstInput(h).getDataType() == DataType.SCALAR && getSecondInput(h).getDataType() == DataType.MATRIX 
-				&& OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(h), new HashMap<>()) == expectedValue;
-	}
-	
-	private static boolean isAnyBinaryAdd(ArrayList<Hop> hops) {
-		if(hops != null) {
-			for(Hop h : hops) {
-				if(h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.PLUS)
-					return true;
-			}
-		}
-		return false;
-	}
-	
-	private static boolean isBinaryMSMult(Hop h) {
-		return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MULT 
-				&& getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.SCALAR;
-	}
-	
-	private static boolean isBinarySMMult(Hop h) {
-		return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MULT 
-				&& getSecondInput(h).getDataType() == DataType.MATRIX && getFirstInput(h).getDataType() == DataType.SCALAR;
-	}
-	
-	private static boolean isBinarySMMult(Hop h, double expectedVal) {
-		return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MULT 
-				&& getSecondInput(h).getDataType() == DataType.MATRIX && getFirstInput(h).getDataType() == DataType.SCALAR
-				&& getValue(getFirstInput(h)) == expectedVal;
-	}
-	
-	private static double getValue(Hop h) {
-		return OptimizerUtils.rEvalSimpleDoubleExpression(h, new HashMap<>());
-	}
-	
-	/**
-	 * Checks if the "mean" hop is a moving average of mean in batch normalization layer.
-	 *  
-	 * @param mean hop to check against
-	 * @param X input data
-	 * @return true if the "mean" hop is a moving average of mean in batch normalization layer.
-	 */
-	private static boolean isBatchNormTrainMean(Hop mean, Hop X) {
-		// subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win)
-		// mean = rowMeans(subgrp_means)
-		return isRowMeans(mean) && isReshape(getFirstInput(mean)) && isColMeans(getFirstInput(getFirstInput(mean)))
-				&& getFirstInput(getFirstInput(getFirstInput(mean))) == X;
-	}
-	
-	/**
-	 * Checks for nrow(X) pattern
-	 * 
-	 * @param expr hop to be matched
-	 * @param X input X
-	 * @return true if expr is nrow(X) else false
-	 */
-	private static boolean isNrowOfX(Hop expr, Hop X) {
-		return expr instanceof UnaryOp && ((UnaryOp)expr).getOp() == OpOp1.NROW && getFirstInput(expr) == X;
-	}
-	
-	/**
-	 * Checks for the colVars(X) * ((N-1)/N) pattern
-	 * 
-	 * @param expr hop to be matched
-	 * @param X input X
-	 * @param ignoreCorrectionTerm whether to ignore the correction term ((N-1)/N).
-	 * @return true if expr is colVars(X) * ((N-1)/N) else false
-	 */
-	private static boolean isCorrectedColVars(Hop expr, Hop X, boolean ignoreCorrectionTerm) {
-		// colVars(X) * ((N-1)/N)
-		if(isColVars(expr) && getFirstInput(expr) == X) {
-			// Support no correction as well in this rewrite
-			return true;
-		}
-		else if(X.rowsKnown()) {
-			return isBinaryMSMult(expr, ((double)X.getDim1()-1)/X.getDim1()) && 
-					isColVars(getFirstInput(expr)) && getFirstInput(getFirstInput(expr)) == X;
+	};
+	
+	
+	
+	// Pattern 2:
+	// subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win)
+	// var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))
+	private static final HopDagPatternMatcher _batchNormUpdatedVar; 
+	static {
+		HopDagPatternMatcher subgrp_vars = 
+			matrix( 
+					mult(colVars(leaf("X", MATRIX).fitsOnGPU(2)), leaf("varConst1", SCALAR)), // colVars(X) * ((N-1)/N)
+					leaf("C", SCALAR),  	// rows=C
+					leaf("HW", SCALAR)); // cols=Hin*Win
+		_batchNormUpdatedVar = 
+			mm_plus( 
+					rowMeans(subgrp_vars), 
+					mult(rowVars(leaf("subgrp_means", MATRIX)),  leaf("varConst2", SCALAR))); // rowVars(subgrp_means)*varConst2
+	}
+	private static final Function<Hop, Hop> _batchNormUpdatedVarReplacer = hi -> {
+		double HW = _batchNormUpdatedVar.getLiteralValue("HW");
+		if(_batchNormUpdatedVar.getLiteralValue("varConst2") == ((HW-1)/HW)) {
+			LOG.debug("Applied batchNormUpdatedVar rewrite.");
+			Hop newHop = HopRewriteUtils.createDnnOp(_batchNormUpdatedVar, OpOpDnn.UPDATE_EMA_VAR, 
+					// varConst1 => ((N-1)/N)
+					"subgrp_means", "X", "C", "HW", "varConst1");
+			return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
 		}
-		else if(isBinaryMSMult(expr) && 
-				isColVars(getFirstInput(expr)) && getFirstInput(getFirstInput(expr)) == X) {
-			if(ignoreCorrectionTerm) {
-				return true;
-			}
-			Hop tmp = getSecondInput(expr);
-			// ((N-1)/N)
-			boolean isNMinus1Pattern = isBinarySSDiv(tmp) && isBinarySSMinus(getFirstInput(tmp)) &&
-					getFirstInput(getFirstInput(tmp)) == getSecondInput(tmp) && 
-					OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(getFirstInput(tmp)), new HashMap<>()) == 1;
-			boolean ret = isNMinus1Pattern && isNrowOfX(getSecondInput(tmp), X);
-			if(LOG.isDebugEnabled()) {
-				LOG.debug("Is the corrected column variance pattern for batch_norm_train rewrite when number of rows of X unknown matched:" + ret);
-			}
-			return ret;
-		}
-		return false;
-	}
-	
-	/**
-	 * Checks if the "var" hop is a moving average of variance in batch normalization layer.
-	 *  
-	 * @param mean previously matched mean hop
-	 * @param var the hop to check against
-	 * @param X input data hop
-	 * @param subgrpMeans mean for subgroup mean
-	 * @param ignoreCorrectionTerm whether to incore the correct term  (see isCorrectedColVars method in this class)
-	 * @return true if the "var" hop is a moving average of variance in batch normalization layer.
-	 */
-	private static boolean isBatchNormTrainVar(Hop mean, Hop var, Hop  X, Hop subgrpMeans, boolean ignoreCorrectionTerm) {
-		long numChannels = Hop.computeSizeInformation(getSecondInput(getFirstInput(mean)));
-		long HW = Hop.computeSizeInformation(getThirdInput(getFirstInput(mean)));
-		// subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win)
-		// var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))
-		return numChannels > 0 && HW > 0 && isBinaryMMAdd(var) && isRowMeans(getFirstInput(var)) &&  
-				// matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win)
-				isReshape(getFirstInput(getFirstInput(var)), numChannels, HW) &&
-				isCorrectedColVars(getFirstInput(getFirstInput(getFirstInput(var))), X, ignoreCorrectionTerm) &&
-				// rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))
-				isBinaryMSMult(getSecondInput(var), ((((double)HW)-1)/HW)) && 
-				isRowVars(getFirstInput(getSecondInput(var)), subgrpMeans);
-	}
-	
-	/**
-	 * Checks and returns the matched hops for expression ema_mean_upd = mu*ema_mean + (1-mu)*mean  
-	 * 
-	 * @param rhsTimesOps hop representing BinaryOp of expression (1-mu)*mean 
-	 * @param mu value of mu
-	 * @return an array [ema_mean_upd, ema_mean] if expression matched, else null
-	 */
-	private static Hop [] getUpdatedMovingAverageExpressions(Hop rhsTimesOp, double mu) {
-		if(rhsTimesOp == null || rhsTimesOp.getParent() == null || rhsTimesOp.getParent().size() != 1 || 
-				!isBinarySMMult(rhsTimesOp) || !isBinaryAdd(rhsTimesOp.getParent().get(0)))
-			return null;
-		
-		// Check (1-mu)*mean
-		double expectedOneMinusMu = OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(rhsTimesOp), new HashMap<>());
-		Hop plusOp = rhsTimesOp.getParent().get(0); 
-		Hop lhsTimesOp = null;
-		if(plusOp.getInput().get(0) == rhsTimesOp) {
-			lhsTimesOp = plusOp.getInput().get(1); 
-		}
-		else {
-			lhsTimesOp = plusOp.getInput().get(0);
-		}
-		
-		if(expectedOneMinusMu == (1-mu) && plusOp.getParent() != null && plusOp.getParent().size() == 1 &&  
-			isBinarySMMult(lhsTimesOp) && OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(lhsTimesOp), new HashMap<>()) == mu) {
-			return new Hop[] {
-				plusOp.getParent().get(0),
-				getSecondInput(lhsTimesOp), 
-				getSecondInput(rhsTimesOp)
-			};
-		}
-		return null;
-	}
-	
-	/**
-	 * Checks (if exactly one of rhsTimesOps) and returns the matched hops for expression ema_mean_upd = mu*ema_mean + (1-mu)*mean  
-	 * 
-	 * @param rhsTimesOps array list of hop representing BinaryOp of expression (1-mu)*mean 
-	 * @param mu value of mu
-	 * @return an array [ema_mean_upd, ema_mean] if any of the expression matched, else null
-	 */
-	private static Hop [] getUpdatedMovingAverageExpressions(ArrayList<Hop> rhsTimesOps, double mu) {
-		if(rhsTimesOps == null || rhsTimesOps.size() == 0)
-			return null;
-		
-		Hop [] ret = null;
-		for(Hop h : rhsTimesOps) {
-			boolean matched = isUpdatedMovingAverageExpression(h, mu);
-			if(matched && ret != null) {
-				return null; // Multiple matches, cannot decide which one to fuse
-			}
-			else if(matched) {
-				ret = getUpdatedMovingAverageExpressions(h, mu);
-			}
-		}
-		
-		return ret;
-	}
-	
-	/**
-	 * Checks and returns the mu in the expression ema_mean_upd = mu*ema_mean + (1-mu)*mean
-	 * 
-	 * @param rhsTimesOps hop representing BinaryOp of expression (1-mu)*mean
-	 * @return value of mu if the expression matched else null 
-	 */
-	private static Double getMuFromUpdatedMovingAverageExpressions(ArrayList<Hop> rhsTimesOps) {
-		if(rhsTimesOps == null || rhsTimesOps.size() == 0)
-			return null;
-		
-		Double ret = null; 
-		for(Hop h : rhsTimesOps) {
-			boolean matched = isUpdatedMovingAverageExpression(h);
-			if(matched && ret != null) {
-				return null; // Multiple matches, cannot decide which one to fuse
-			}
-			else if(matched) {
-				ret = -(OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(h), new HashMap<>())-1);
-			}
-		}
-		return ret;
-	}
-	
-	/**
-	 * Checks for the expression ema_mean_upd = mu*ema_mean + (1-mu)*mean
-	 * 
-	 * @param rhsTimesOps hop representing BinaryOp of expression (1-mu)*mean
-	 * @return true if expression matched
-	 */
-	private static boolean isUpdatedMovingAverageExpression(Hop rhsTimesOp) {
-		if(rhsTimesOp == null || rhsTimesOp.getParent() == null || rhsTimesOp.getParent().size() != 1 || 
-				!isBinarySMMult(rhsTimesOp) || !isBinaryAdd(rhsTimesOp.getParent().get(0)))
-			return false;
-		
-		// Check (1-mu)*mean
-		Hop plusOp = rhsTimesOp.getParent().get(0); 
-		Hop lhsTimesOp = null;
-		if(plusOp.getInput().get(0) == rhsTimesOp) {
-			lhsTimesOp = plusOp.getInput().get(1); 
-		}
-		else {
-			lhsTimesOp = plusOp.getInput().get(0);
-		}
-		
-		if(plusOp.getParent() != null && plusOp.getParent().size() == 1 && isBinarySMMult(lhsTimesOp)) {
-			return true;
-		}
-		return false;
-	}
+		return hi;
+	};
 	
-	// ema_mean_upd = mu*ema_mean + (1-mu)*mean
-	// Returns true if expression matched, else false
-	private static boolean isUpdatedMovingAverageExpression(Hop rhsTimesOp, double mu) {
-		if(rhsTimesOp == null || rhsTimesOp.getParent() == null || rhsTimesOp.getParent().size() != 1 || 
-				!isBinarySMMult(rhsTimesOp) || !isBinaryAdd(rhsTimesOp.getParent().get(0)))
-			return false;
-		
-		// Check (1-mu)*mean
-		double expectedOneMinusMu = OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(rhsTimesOp), new HashMap<>());
-		Hop plusOp = rhsTimesOp.getParent().get(0); 
-		Hop lhsTimesOp = null;
-		if(plusOp.getInput().get(0) == rhsTimesOp) {
-			lhsTimesOp = plusOp.getInput().get(1); 
-		}
-		else {
-			lhsTimesOp = plusOp.getInput().get(0);
-		}
 		
-		if(expectedOneMinusMu == (1-mu) && plusOp.getParent() != null && plusOp.getParent().size() == 1 &&  
-			isBinarySMMult(lhsTimesOp) && OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(lhsTimesOp), new HashMap<>()) == mu) {
-			return true;
-		}
-		return false;
-	}
-	
-	/**
-	 * Checks for the expression 1/sqrt(denom)
-	 * 
-	 * @param denom denominator of the expression to be matched
-	 * @return true if the expression 1/sqrt(denom) matched else false
-	 */
-	private static boolean isOneBySqrt(Hop denom) {
-		return denom.getParent() != null && denom.getParent().get(0) instanceof UnaryOp &&
-				((UnaryOp)denom.getParent().get(0)).getOp() == OpOp1.SQRT &&
-				denom.getParent().get(0).getParent() != null && denom.getParent().get(0).getParent().size() == 1 &&
-				isBinarySMDiv(denom.getParent().get(0).getParent().get(0), 1);
-	}
-	
-	/**
-	 * Checks for the batch norm (mode="train") pattern using the helper isBatchNormTrainMean and isBatchNormTrainVar
-	 * and returns a new FunctionOp if matched
-	 * 
-	 * @param roots root hops of the given statement block
-	 * @param parent parent of the input
-	 * @param hi input to be matched
-	 * @param pos position
-	 * @return a new FunctionOp or hi
-	 */
-	@SuppressWarnings("unused")
-	private static Hop batchNormTrain(ArrayList<Hop> roots, Hop parent, Hop hi, int pos) 
-	{		
+	// Pattern 3:
+	private static final HopDagPatternMatcher _batchNormTest;
+	static {
 		// norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps))
+		HopDagPatternMatcher norm = 
+			bias_multiply(
+					bias_add(leaf("X", MATRIX), unaryMinus(leaf("mean", MATRIX))), // bias_add(X, -mean)
+					div(1, sqrt(plus(leaf("var", MATRIX), leaf("eps", SCALAR))))); // 1/sqrt(var+eps)
 		// hi = bias_add(bias_multiply(norm, gamma), beta)
-		// 2x for input and output and 1x for overhead
-		// fitsOnGPU(hi, 3)
-		if( hasFirstInput(hi) && isBiasAdd(hi) && isBiasMultiply(getFirstInput(hi)) ) {	
-			Hop norm = getFirstInput(getFirstInput(hi));
-			if(hasSecondInput(norm) && isBiasMultiply(norm) && isBiasAdd(getFirstInput(norm)) 
-					&& hasSecondInput(getFirstInput(norm)) && isUnaryMinus(getSecondInput(getFirstInput(norm)))
-					&& isOneDivideBySqrt(getSecondInput(norm))) {
-				double eps = 0;
-				Hop var = getFirstInput(getSecondInput(getSecondInput(norm)));
-				if(isBinaryAdd(var) && (getFirstInput(var) instanceof LiteralOp || getSecondInput(var) instanceof LiteralOp)) {
-					// eps + ema_var
-					if(getFirstInput(var) instanceof LiteralOp) {
-						eps = OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(var), new HashMap<>());
-						var = getSecondInput(var);
-					}
-					else {
-						eps = OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(var), new HashMap<>());
-						var = getFirstInput(var);
-					}
-				}
-				// Generate batch norm test op
-				Hop X = getFirstInput(getFirstInput(norm));
-				Hop mean = getSecondInput(getSecondInput(getFirstInput(norm)));
-				
-				if(hasFirstInput(mean) && isBatchNormTrainMean(mean , X) && isBatchNormTrainVar(mean, var, X, getFirstInput(mean), false) &&
-					mean.getParent() != null && mean.getParent().size() >= 2 && 
-					var.getParent() != null && var.getParent().size() == 2) {
-					Hop gamma = getSecondInput(getFirstInput(hi));
-					Hop beta = getSecondInput(hi);
-					
-					// Always get mu from variance as it will have exactly one match of fusion pattern
-					Double potentialMu = getMuFromUpdatedMovingAverageExpressions(var.getParent());
-					if(potentialMu == null)
-						return hi;
-					double mu = potentialMu;
-					
-					Hop [] means = getUpdatedMovingAverageExpressions(mean.getParent(), mu);
-					Hop [] vars = getUpdatedMovingAverageExpressions(var.getParent(), mu);
-					if(means == null || vars == null)
-						return hi;
-					
-					Hop varPlusEps = null;
-					boolean isFirstBinaryAddOp = isAnyBinaryAdd(var.getParent().get(0).getParent());
-                    boolean isSecondBinaryAddOp = isAnyBinaryAdd(var.getParent().get(1).getParent());
-                    if(isFirstBinaryAddOp && !isSecondBinaryAddOp) {
-                            varPlusEps = var.getParent().get(1);
-                    }
-                    else if(!isFirstBinaryAddOp && isSecondBinaryAddOp) {
-                            varPlusEps = var.getParent().get(0);
-                    }
-					if(varPlusEps != null && isBinaryMSAdd(varPlusEps, eps) && isOneBySqrt(varPlusEps)) {
-						
-						Hop cache_var = varPlusEps.getParent().get(0).getParent().get(0);
-						Hop ema_mean_upd = means[0];
-						Hop ema_var_upd = vars[0];
-						Hop ema_mean = means[1];
-						Hop ema_var = vars[1];
-						Hop cache_mean = means[2];
-						
-						
-						ArrayList<Hop> inHops = new ArrayList<Hop>();
-						inHops.add(X);
-						inHops.add(gamma);
-						inHops.add(beta);
-						inHops.add(ema_mean);
-						inHops.add(ema_var);
-						inHops.add(new LiteralOp(eps));
-						inHops.add(new LiteralOp(mu));
-						Hop [] oldHops = {hi, ema_mean_upd, ema_var_upd, cache_mean, cache_var};
-						
-						// Since FunctionOp adds transientwrite explicitly, persistent writes are not supported
-						if(!isAnyPersistentWrite(oldHops)) {
-							LOG.debug("Applied batchNormTrain rewrite.");
-							ArrayList<Hop> outputs = getMultiOutputHops(roots, oldHops);
-							FunctionOp ret = new FunctionOp(FunctionType.MULTIRETURN_BUILTIN, DMLProgram.INTERNAL_NAMESPACE, "batch_norm2d_train", 
-								null, inHops, outputs.stream().map(h -> h.getName()).toArray(String[]::new), outputs);
-							Collections.reverse(roots);
-							roots.add(ret);
-							Collections.reverse(roots);
-							return ret;
-						}
-					}
-					
-				}
+		_batchNormTest = 
+			bias_add(
+					bias_multiply(norm, leaf("gamma", MATRIX)), 
+					leaf("beta", MATRIX))
+			.fitsOnGPU(3);
+	}
+	private static final Function<Hop, Hop> _batchNormTestReplacer = hi -> {
+		LOG.debug("Applied batchNormTest rewrite.");
+		Hop newHop = HopRewriteUtils.createDnnOp(_batchNormTest, OpOpDnn.BATCH_NORM2D_TEST, "X", "gamma", "beta", "mean", "var", "eps");
+		return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
+	};
+	
+	// Pattern 4:
+	// rowSums(matrix(colSums(X), rows=C, cols=HW))
+	private static final HopDagPatternMatcher _channelSums = util_channel_sums(leaf("X", MATRIX).fitsOnGPU(2), leaf("C", SCALAR), leaf("HW", SCALAR));
+	private static final Function<Hop, Hop> _channelSumsReplacer = hi -> {
+		LOG.debug("Applied channelSums rewrite.");
+		Hop newHop = HopRewriteUtils.createDnnOp(_channelSums, OpOpDnn.CHANNEL_SUMS, "X", "C", "HW");
+		return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
+	};
+	
+	// Pattern 5:
+	// (X - mu*v_prev) + (1+mu)*v
+	private static final HopDagPatternMatcher _updateNesterovX = 
+		mm_plus(
+				minus(	// X - mu*v_prev
+						leaf("X", MATRIX), 
+						mult(	// mu*v_prev
+								leaf("mu", SCALAR), 
+								leaf("v_prev", MATRIX))),
+				mult(	// (1+mu)*v
+						leaf("onePlusMu", SCALAR), 
+						leaf("v", MATRIX)))						
+		.fitsOnGPU(3);
+	private static final Function<Hop, Hop> _updateNesterovXReplacer = hi -> {
+		if((1+_updateNesterovX.getLiteralValue("mu")) == _updateNesterovX.getLiteralValue("onePlusMu")) {
+			Hop X = _updateNesterovX.getMatchedHop("X");
+			Hop v = _updateNesterovX.getMatchedHop("v");
+			Hop v_prev = _updateNesterovX.getMatchedHop("v_prev");
+			if(hasSameDimensions(X, v) && hasSameDimensions(X, v_prev)) {
+				LOG.debug("Applied updateNesterovX rewrite.");
+				Hop newHop = HopRewriteUtils.createDnnOp(_updateNesterovX, OpOpDnn.UPDATE_NESTEROV_X, "X", "v", "v_prev", "mu");
+				return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
 			}
 		}
-		
 		return hi;
-	}
-	
-	// ------------------------------------------------------------
-	/**
-	 * Checks if any of the given output hop is a persistent write.
-	 * 
-	 * @param outputHops output hops to check
-	 * @return true if any of the hop is a persistent write else false.
-	 */
-	private static boolean isAnyPersistentWrite(Hop [] outputHops) {
-		for(Hop outHop : outputHops) {
-			if(HopRewriteUtils.isData(outHop, DataOpTypes.PERSISTENTWRITE))
-				return true;
+	};
+	
+	// Pattern 6:
+	// matrix(colMeans(X), rows=C, cols=Hin*Win)
+	// This avoids unnecessary copy by the reshape operator
+	private static final HopDagPatternMatcher _reshapeColMeans = 
+		matrix(
+				colMeans(leaf("X", MATRIX).fitsOnGPU(2)), // colMeans(X)
+				leaf("C", SCALAR), 
+				leaf("HW", SCALAR));
+	private static final Function<Hop, Hop> _reshapeColMeansReplacer = hi -> {
+		LOG.debug("Applied reshapeColMeans rewrite.");
+		Hop newHop = HopRewriteUtils.createDnnOp(_reshapeColMeans, OpOpDnn.RESHAPE_COLMEANS, "X", "C", "HW");
+		return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
+	};
+	
+	// Pattern 7:
+	// mu*ema_mean + (1-mu)*mean
+	private static final HopDagPatternMatcher _updateEMA = 
+		mm_plus( 	
+				mult(	// mu*ema_mean
+						leaf("mu", SCALAR), 
+						leaf("ema_mean", MATRIX)), 
+				mult(	// (1-mu)*mean
+						leaf("oneMinusMu", SCALAR), 
+						leaf("mean", MATRIX)))
+		.fitsOnGPU(3);
+	private static final Function<Hop, Hop> _updateEMAReplacer = hi -> {
+		if((1-_updateEMA.getLiteralValue("mu")) == _updateEMA.getLiteralValue("oneMinusMu")) {
+			LOG.debug("Applied updateEMA rewrite.");
+			Hop newHop = HopRewriteUtils.createDnnOp(_updateEMA, OpOpDnn.UPDATE_EMA, "ema_mean", "mean", "mu");
+			return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
 		}
-		return false;
-	}
-	
-	/**
-	 * Returns output hop for a multi-output FunctionOp to be created by rewrite.
-	 * 
-	 * @param roots root hops of statement block
-	 * @param oldHops old output hops of the pattern
-	 * @return new output hops that should be passed to FunctionOp
-	 */
-	private static ArrayList<Hop> getMultiOutputHops(ArrayList<Hop> roots, Hop [] oldHops) {
-		ArrayList<Hop> ret = new ArrayList<>();
-		for(int i = 0; i < oldHops.length; i++) {
-			// Create a transient read as FunctionOp will add a transient write.
-			if(HopRewriteUtils.isData(oldHops[i], DataOpTypes.PERSISTENTWRITE))
-				throw new RuntimeException("Persistent write is not supported as output for the given rewrite." + oldHops[i]);
-			// Generate a new name if the old output was not transient write.
-			String name = HopRewriteUtils.isData(oldHops[i], DataOpTypes.TRANSIENTWRITE) ? oldHops[i].getName() : "_genGPU" + (_seq++);
-			DataOp tRead = HopRewriteUtils.createTransientRead(name, oldHops[i]);
-			HopRewriteUtils.rewireAllParentChildReferences(oldHops[i], tRead);
-			ret.add(tRead);
-			// Remove old output from roots to avoid unnecessary computation.
-			if(roots.contains(oldHops[i])) {
-				roots.remove(oldHops[i]);
-			}
+		return hi;
+	};
+	
+	// Pattern 8:
+	// 1/sqrt(var+epsilon)
+	private static final HopDagPatternMatcher _invVar = 
+		div(1, 
+				sqrt(	// var+epsilon
+						plus(	leaf("var", MATRIX), 
+								leaf("eps", SCALAR) )))
+		.fitsOnGPU(2);
+	private static final Function<Hop, Hop> _invVarReplacer = hi -> {
+		LOG.debug("Applied computeInverseVariance rewrite.");
+		Hop newHop = HopRewriteUtils.createDnnOp(_invVar, OpOpDnn.INV_VAR, "var", "eps");
+		return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
+	};
+	
+	
+	private static ArrayList<HopPatternRewriter> _rewriters = null;
+	public ArrayList<HopPatternRewriter> getPatternRewriter() {
+		if(_rewriters == null) {
+			ArrayList<HopPatternRewriter> rewriters = new ArrayList<>();
+			rewriters.add(new HopPatternRewriter("batchNormdX", _batchNormdX, _batchNormdXReplacer));
+			rewriters.add(new HopPatternRewriter("batchNormUpdatedVar", _batchNormUpdatedVar, _batchNormUpdatedVarReplacer));
+			rewriters.add(new HopPatternRewriter("batchNormTest", _batchNormTest, _batchNormTestReplacer));
+			rewriters.add(new HopPatternRewriter("channelSums", _channelSums, _channelSumsReplacer));
+			rewriters.add(new HopPatternRewriter("updateNesterovX", _updateNesterovX, _updateNesterovXReplacer));
+			rewriters.add(new HopPatternRewriter("reshapeColMeans", _reshapeColMeans, _reshapeColMeansReplacer));
+			rewriters.add(new HopPatternRewriter("updateEMA", _updateEMA, _updateEMAReplacer));
+			rewriters.add(new HopPatternRewriter("invVar", _invVar, _invVarReplacer));
+			_rewriters = rewriters;
 		}
-		return ret;
+		return _rewriters;
 	}
-	// ------------------------------------------------------------
 	
-	/**
-	 * Checks for the nesterov_update_x pattern (X = X - mu*v_prev + (1+mu)*v)
-	 * and returns a new DnnOp if matched
-	 * 
-	 * @param parent parent of the input
-	 * @param hi input to be matched
-	 * @param pos position
-	 * @return a new DnnOp or hi
-	 */
-	private static Hop updateNesterovX(Hop parent, Hop hi, int pos) {
-		if(fitsOnGPU(hi, 4) && isBinaryMMAdd(hi) && isBinaryMMMinus(getFirstInput(hi))
-			&& isBinarySMMult(getSecondInput(getFirstInput(hi))) 
-			&& isBinarySMMult(getSecondInput(hi))) {
-			Hop onePlusMu = getFirstInput(getSecondInput(hi));
-			Hop tmp = getSecondInput(getFirstInput(hi));
-			Hop mu = getFirstInput(tmp);
-			if(isOnePlusMu(onePlusMu, mu)) {
-				Hop v_prev = getSecondInput(tmp);
-				Hop v = getSecondInput(getSecondInput(hi));
-				Hop X = getFirstInput(getFirstInput(hi));
-				if(hasSameDimensions(X, v) && hasSameDimensions(X, v_prev)) {
-					ArrayList<Hop> inHops = new ArrayList<Hop>();
-					inHops.add(X);
-					inHops.add(v);
-					inHops.add(v_prev);
-					inHops.add(mu);
-					LOG.debug("Applied updateNesterovX rewrite.");
-					Hop newHop = new DnnOp(hi.getName(), hi.getDataType(), hi.getValueType(),
-							OpOpDnn.UPDATE_NESTEROV_X, inHops);
-					return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
-				}
-			}
-		}
-		return hi;
-	}
+	
+	// -------------------------------------------------------------------------------------------
 	
 	private static boolean hasSameDimensions(Hop x, Hop y) {
 		return x.dimsKnown() && y.dimsKnown() && (x.getDim1() == y.getDim1()) && (x.getDim2() == y.getDim2());
 	}
 	
-	private static boolean isOnePlusMu(Hop onePlusMu, Hop mu) {
-		return (isBinarySMMult(onePlusMu, 1.0) && getSecondInput(onePlusMu) == mu) ||
-				getValue(onePlusMu) == getValue(mu) + 1;
-	}
-	
-	/**
-	 * Checks for the batch norm (mode="test") pattern using the helper isBatchNormTrainMean and isBatchNormTrainVar
-	 * and returns a new DnnOp if matched
-	 * 
-	 * @param parent parent of the input
-	 * @param hi input to be matched
-	 * @param pos position
-	 * @return a new DnnOp or hi
-	 */
-	private static Hop batchNormTest(Hop parent, Hop hi, int pos) {
-		// norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps))
-		// hi = bias_add(bias_multiply(norm, gamma), beta)
-		// 2x for input and output and 1x for overhead
-		if(hasFirstInput(hi) && isBiasAdd(hi) && isBiasMultiply(getFirstInput(hi)) && fitsOnGPU(hi, 3) ) {
-			Hop norm = getFirstInput(getFirstInput(hi));
-			if(hasSecondInput(norm) && isBiasMultiply(norm) && isBiasAdd(getFirstInput(norm)) 
-					&& isUnaryMinus(getSecondInput(getFirstInput(norm)))
-					&& isOneDivideBySqrt(getSecondInput(norm))) {
-				double eps = 0;
-				Hop var = getFirstInput(getSecondInput(getSecondInput(norm)));
-				if( HopRewriteUtils.isBinary(var, OpOp2.PLUS) &&
-					(getFirstInput(var) instanceof LiteralOp || getSecondInput(var) instanceof LiteralOp)) {
-					// eps + ema_var
-					if(getFirstInput(var) instanceof LiteralOp) {
-						eps = OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(var), new HashMap<>());
-						var = getSecondInput(var);
-					}
-					else {
-						eps = OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(var), new HashMap<>());
-						var = getFirstInput(var);
-					}
-				}
-				// Generate batch norm test op
-				Hop X = getFirstInput(getFirstInput(norm));
-				Hop mean = getSecondInput(getSecondInput(getFirstInput(norm)));
-				
-				// This guard disallows eager fusion of train batch normalization into test batch normalization
-				boolean potentialForBatchNormTrain = !X.rowsKnown() && isBatchNormTrainMean(mean , X) && isBatchNormTrainVar(mean, var, X, getFirstInput(mean), true);
-				if(!potentialForBatchNormTrain) {
-					Hop gamma = getSecondInput(getFirstInput(hi));
-					Hop beta = getSecondInput(hi);
-					ArrayList<Hop> inHops = new ArrayList<Hop>();
-					inHops.add(X);
-					inHops.add(gamma);
-					inHops.add(beta);
-					inHops.add(mean);
-					inHops.add(var);
-					inHops.add(new LiteralOp(eps));
-					if(fitsOnGPU(inHops, true)) {
-						LOG.debug("Applied batchNormTest rewrite.");
-						Hop newHop = new DnnOp(hi.getName(), hi.getDataType(), hi.getValueType(),
-								OpOpDnn.BATCH_NORM2D_TEST, inHops);
-						return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
-					}
-				}
-				else {
-					LOG.debug("Skipping batchNormTest rewrite as there is potential for batch normalization train rewrite after recompilation.");
-				}
-			}
-		}
-		
-		return hi;
+	private static boolean checkDimensions(Hop x, long dim1, long dim2) {
+		return x.dimsKnown() && (x.getDim1() == dim1) && (x.getDim2() == dim2);
 	}
-}
+}
\ No newline at end of file
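The rewrites above pair a declarative HopDagPatternMatcher (structural match plus a fitsOnGPU memory guard) with a Function<Hop, Hop> replacer that inspects the matched leaves and, when the value-level conditions hold, builds the fused DnnOp and rewires the parents. A minimal, self-contained Java sketch of this matcher-plus-replacer split on a toy expression tree follows; every class and method name in the sketch is invented for illustration and none of it is SystemML API.

import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

// Toy illustration of the matcher-plus-replacer technique; not SystemML code.
public class PatternRewriteSketch {

	// Minimal expression tree: a node is either a named leaf or op(left, right).
	static final class Expr {
		final String op;    // null for leaves
		final String name;  // leaf name, null for inner nodes
		final Expr left, right;
		Expr(String name) { this.op = null; this.name = name; this.left = null; this.right = null; }
		Expr(String op, Expr l, Expr r) { this.op = op; this.name = null; this.left = l; this.right = r; }
	}

	// Declarative matcher for "a*X + b*Y" that binds the four leaves by name,
	// analogous to leaf("X", MATRIX) etc. in the pattern constants above.
	static final class Pattern {
		final Map<String, Expr> bindings = new HashMap<>();
		boolean matches(Expr e) {
			bindings.clear();
			return e != null && "+".equals(e.op)
				&& bindMult(e.left, "a", "X") && bindMult(e.right, "b", "Y");
		}
		private boolean bindMult(Expr e, String scalarName, String matrixName) {
			if (e == null || !"*".equals(e.op))
				return false;
			bindings.put(scalarName, e.left);
			bindings.put(matrixName, e.right);
			return true;
		}
	}

	public static void main(String[] args) {
		Pattern axpy = new Pattern();
		// Replacer: fires only when the structural match succeeds, otherwise
		// returns the input unchanged (as the replacers above return hi).
		Function<Expr, Expr> replacer = e -> axpy.matches(e)
			? new Expr("fused_axpy", axpy.bindings.get("X"), axpy.bindings.get("Y"))
			: e;

		Expr input = new Expr("+",
			new Expr("*", new Expr("a"), new Expr("X")),
			new Expr("*", new Expr("b"), new Expr("Y")));
		System.out.println(replacer.apply(input).op); // prints fused_axpy
	}
}

In the file above, HopDagPatternMatcher plays the role of the toy Pattern, and the replacer lambdas registered in getPatternRewriter() play the role of the toy replacer.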

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/lops/DnnTransform.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/DnnTransform.java b/src/main/java/org/apache/sysml/lops/DnnTransform.java
index 3183b5f..2d2d5f1 100644
--- a/src/main/java/org/apache/sysml/lops/DnnTransform.java
+++ b/src/main/java/org/apache/sysml/lops/DnnTransform.java
@@ -32,7 +32,8 @@ public class DnnTransform extends Lop
 		RELU_MAX_POOLING, RELU_MAX_POOLING_BACKWARD, RELU_BACKWARD,
 		CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA,
 		BIAS_ADD, CONV2D_BIAS_ADD, BIAS_MULTIPLY, CHANNEL_SUMS, BATCH_NORM2D_TEST, 
-		UPDATE_NESTEROV_X
+		UPDATE_NESTEROV_X, RESHAPE_COLMEANS, UPDATE_EMA_VAR, UPDATE_EMA, INV_VAR,
+		BATCH_NORM2D_BACKWARD_DX
 	}
 	
 	private OperationTypes operation;
@@ -167,11 +168,26 @@ public class DnnTransform extends Lop
 		case CHANNEL_SUMS:
 			return "channel_sums";
 		
+		case INV_VAR:
+			return "inv_var";
+		
 		case UPDATE_NESTEROV_X:
 			return "update_nesterov_x";
 			
 		case BATCH_NORM2D_TEST:
 			return "batch_norm2d_test";
+		
+		case BATCH_NORM2D_BACKWARD_DX:
+			return "batch_norm2d_bwd_dx";
+			
+		case UPDATE_EMA_VAR:
+			return "update_ema_var";
+			
+		case UPDATE_EMA:
+			return "update_ema";
+			
+		case RESHAPE_COLMEANS:
+			return "reshape_colmeans";
 			
 		default:
 			throw new UnsupportedOperationException(this.printErrorLocation() + "Instruction is not defined for Transform operation " + operation);
@@ -181,7 +197,8 @@ public class DnnTransform extends Lop
 	
 	@Override
 	public String getInstructions(String input, String bias, String output) {
-		if(operation == OperationTypes.BIAS_ADD || operation == OperationTypes.BIAS_MULTIPLY || operation == OperationTypes.RELU_BACKWARD) {
+		if(operation == OperationTypes.BIAS_ADD || operation == OperationTypes.BIAS_MULTIPLY || operation == OperationTypes.RELU_BACKWARD
+				|| operation == OperationTypes.INV_VAR) {
 			StringBuilder sb = new StringBuilder();
 			sb.append( getExecType() );
 			
@@ -190,7 +207,7 @@ public class DnnTransform extends Lop
 			sb.append( OPERAND_DELIMITOR );
 			sb.append( getInputs().get(0).prepInputOperand(input));
 			sb.append( OPERAND_DELIMITOR );
-			sb.append( getInputs().get(0).prepInputOperand(bias));
+			sb.append( getInputs().get(1).prepInputOperand(bias));
 			//output
 			sb.append( OPERAND_DELIMITOR );
 			sb.append( this.prepOutputOperand(output));
@@ -212,7 +229,7 @@ public class DnnTransform extends Lop
 	
 	@Override
 	public String getInstructions(String input, String C, String HW, String output) {
-		if(operation == OperationTypes.CHANNEL_SUMS) {
+		if(operation == OperationTypes.CHANNEL_SUMS || operation == OperationTypes.RESHAPE_COLMEANS || operation == OperationTypes.UPDATE_EMA) {
 			StringBuilder sb = new StringBuilder();
 			sb.append( getExecType() );
 			
@@ -306,6 +323,34 @@ public class DnnTransform extends Lop
 			throw new LopsException("The operation is not supported with six operands:" + operation.name());
 		}
 	}
+	
+	public String getInstructions(String input1, String input2, String input3, String input4, String input5, String output) {
+		if(operation == OperationTypes.UPDATE_EMA_VAR || operation == OperationTypes.BATCH_NORM2D_BACKWARD_DX) {
+			StringBuilder sb = new StringBuilder();
+			sb.append( getExecType() );
+			
+			sb.append( OPERAND_DELIMITOR );
+			sb.append( getOpcode() );
+			sb.append( OPERAND_DELIMITOR );
+			sb.append( getInputs().get(0).prepInputOperand(input1));
+			sb.append( OPERAND_DELIMITOR );
+			sb.append( getInputs().get(1).prepInputOperand(input2));
+			sb.append( OPERAND_DELIMITOR );
+			sb.append( getInputs().get(2).prepInputOperand(input3));
+			sb.append( OPERAND_DELIMITOR );
+			sb.append( getInputs().get(3).prepInputOperand(input4));
+			sb.append( OPERAND_DELIMITOR );
+			sb.append( getInputs().get(4).prepInputOperand(input5));
+			//output
+			sb.append( OPERAND_DELIMITOR );
+			sb.append( this.prepOutputOperand(output));
+			
+			return sb.toString();
+		}
+		else {
+			throw new LopsException("The operation is not supported with five operands:" + operation.name());
+		}
+	}
 
 	public void appendOpcode(StringBuilder sb) {
 		sb.append( getExecType() );
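The new getInstructions overloads above follow the same layout as the existing ones: execution type, opcode, one prepared operand per input, and finally the output operand, all joined by OPERAND_DELIMITOR. The self-contained Java sketch below illustrates that layout; the delimiter character and the operand strings are placeholders chosen for the illustration, not the values SystemML actually emits.

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Self-contained illustration of the delimiter-joined instruction-string layout
// assembled by the getInstructions(...) overloads; placeholder delimiter and
// operand strings only, not SystemML's actual encoding.
public class InstructionStringSketch {
	private static final String DELIM = "\u00b0"; // placeholder operand delimiter

	static String build(String execType, String opcode, List<String> inputs, String output) {
		return Stream.concat(Stream.of(execType, opcode),
				Stream.concat(inputs.stream(), Stream.of(output)))
			.collect(Collectors.joining(DELIM));
	}

	public static void main(String[] args) {
		// Shape of the new five-input form, e.g. update_ema_var or batch_norm2d_bwd_dx:
		System.out.println(build("GPU", "update_ema_var",
			List.of("in1", "in2", "in3", "in4", "in5"), "out1"));
	}
}

On the parsing side, the opcode field is what GPUInstructionParser keys on, as shown in the opcode map further below.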

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
index 9f3a1e2..fe86dc8 100644
--- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
@@ -270,58 +270,6 @@ public class BuiltinFunctionExpression extends DataIdentifier
 			setDimensions(dc0, getFifthExpr());
 			break;
 		}
-		case BATCH_NORM2D:
-		{
-			// Input: image, scale, bias, runningMean, runningVar, mode, epsilon, exponentialAverageFactor
-			checkNumParameters(8);
-			checkMatrixParam(getFirstExpr());
-			checkMatrixParam(getSecondExpr());
-			checkMatrixParam(getThirdExpr());
-			checkMatrixParam(getFourthExpr());
-			checkMatrixParam(getFifthExpr());
-			
-			// Output: ret, retRunningMean, retRunningVar, resultSaveMean, resultSaveInvVariance
-			// setup output properties
-			if(getOutputs().length != 5)
-				raiseValidateError("batch_norm2d has 5 outputs", false);
-			 
-			DataIdentifier ret = (DataIdentifier) getOutputs()[0];
-			DataIdentifier retRunningMean = (DataIdentifier) getOutputs()[1];
-			DataIdentifier retRunningVar = (DataIdentifier) getOutputs()[2];
-			DataIdentifier resultSaveMean = (DataIdentifier) getOutputs()[3];
-			DataIdentifier resultSaveInvVariance = (DataIdentifier) getOutputs()[4];
-			
-			setDimensions(ret, getFirstExpr());
-			setDimensions(retRunningMean, getFourthExpr());
-			setDimensions(retRunningVar, getFourthExpr());
-			setDimensions(resultSaveMean, getFourthExpr());
-			setDimensions(resultSaveInvVariance, getFourthExpr());
-			break;
-		}
-		case BATCH_NORM2D_BACKWARD:
-		{
-			// Input: image, dout, scale, epsilon, savedMean, savedInvVariance
-			checkNumParameters(6);
-			checkMatrixParam(getFirstExpr());
-			checkMatrixParam(getSecondExpr());
-			checkMatrixParam(getThirdExpr());
-			checkMatrixParam(getFifthExpr());
-			checkMatrixParam(getSixthExpr());
-			
-			// Output: dX, dScale, dBias 
-			// setup output properties
-			if(getOutputs().length != 3)
-				raiseValidateError("batch_norm2d_backward has 3 outputs", false);
-			
-			DataIdentifier dX = (DataIdentifier) getOutputs()[0];
-			DataIdentifier dScale = (DataIdentifier) getOutputs()[1];
-			DataIdentifier dBias = (DataIdentifier) getOutputs()[2];
-			
-			setDimensions(dX, getFirstExpr());
-			setDimensions(dScale, getThirdExpr());
-			setDimensions(dBias, getThirdExpr());
-			break;
-		}
 		case EIGEN:
 			checkNumParameters(1);
 			checkMatrixParam(getFirstExpr());
@@ -1451,8 +1399,7 @@ public class BuiltinFunctionExpression extends DataIdentifier
 				// always unconditional (because unsupported operation)
 				BuiltinFunctionOp op = getOpCode();
 				if( op==BuiltinFunctionOp.EIGEN || op==BuiltinFunctionOp.LU || op==BuiltinFunctionOp.QR || op==BuiltinFunctionOp.SVD 
-						|| op==BuiltinFunctionOp.LSTM || op==BuiltinFunctionOp.LSTM_BACKWARD
-						|| op==BuiltinFunctionOp.BATCH_NORM2D || op==BuiltinFunctionOp.BATCH_NORM2D_BACKWARD)
+						|| op==BuiltinFunctionOp.LSTM || op==BuiltinFunctionOp.LSTM_BACKWARD)
 					raiseValidateError("Function "+op+" needs to be called with multi-return assignment.", false, LanguageErrorCodes.INVALID_PARAMETERS);
 				else
 					raiseValidateError("Unsupported function "+op, false, LanguageErrorCodes.INVALID_PARAMETERS);
@@ -1535,8 +1482,6 @@ public class BuiltinFunctionExpression extends DataIdentifier
 		case EIGEN:
 		case LSTM:
 		case LSTM_BACKWARD:
-		case BATCH_NORM2D:
-		case BATCH_NORM2D_BACKWARD:
 		case SVD:
 			return true;
 		default:
@@ -1956,10 +1901,6 @@ public class BuiltinFunctionExpression extends DataIdentifier
 			bifop = Expression.BuiltinFunctionOp.LSTM;
 		else if (functionName.equals("lstm_backward"))
 			bifop = Expression.BuiltinFunctionOp.LSTM_BACKWARD;
-		else if (functionName.equals("batch_norm2d"))
-			bifop = Expression.BuiltinFunctionOp.BATCH_NORM2D;
-		else if (functionName.equals("batch_norm2d_backward"))
-			bifop = Expression.BuiltinFunctionOp.BATCH_NORM2D_BACKWARD;
 		else if (functionName.equals("conv2d"))
 			 bifop = Expression.BuiltinFunctionOp.CONV2D;
 		else if (functionName.equals("bias_add"))

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
index e9b643e..e3db435 100644
--- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -2271,8 +2271,6 @@ public class DMLTranslator
 			case EIGEN:
 			case LSTM:
 			case LSTM_BACKWARD:
-			case BATCH_NORM2D:
-			case BATCH_NORM2D_BACKWARD:
 			case SVD:
 				
 				// Number of outputs = size of targetList = #of identifiers in source.getOutputs

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/parser/Expression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java
index 46e6442..33fca66 100644
--- a/src/main/java/org/apache/sysml/parser/Expression.java
+++ b/src/main/java/org/apache/sysml/parser/Expression.java
@@ -93,7 +93,7 @@ public abstract class Expression implements ParseInfo
 		EXISTS,
 		CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA, BIASADD, BIASMULT,
 		MAX_POOL, AVG_POOL, MAX_POOL_BACKWARD, AVG_POOL_BACKWARD,
-		LSTM, LSTM_BACKWARD, BATCH_NORM2D, BATCH_NORM2D_BACKWARD,
+		LSTM, LSTM_BACKWARD,
 		EXP,
 		FLOOR,
 		IFELSE,

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index f4122d9..3480504 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -59,11 +59,13 @@ public class GPUInstructionParser  extends InstructionParser
 		String2GPUInstructionType.put( "channel_sums",          GPUINSTRUCTION_TYPE.Dnn);
 		String2GPUInstructionType.put( "lstm",                 	GPUINSTRUCTION_TYPE.Dnn);
 		String2GPUInstructionType.put( "lstm_backward",         GPUINSTRUCTION_TYPE.Dnn);
-		String2GPUInstructionType.put( "batch_norm2d",           GPUINSTRUCTION_TYPE.Dnn);
-		String2GPUInstructionType.put( "batch_norm2d_backward",  GPUINSTRUCTION_TYPE.Dnn);
 		String2GPUInstructionType.put( "batch_norm2d_test",      GPUINSTRUCTION_TYPE.Dnn);
-		String2GPUInstructionType.put( "batch_norm2d_train",      GPUINSTRUCTION_TYPE.Dnn);
-		String2GPUInstructionType.put( "update_nesterov_x",      GPUINSTRUCTION_TYPE.Dnn);
+		String2GPUInstructionType.put( "update_nesterov_x",      GPUINSTRUCTION_TYPE.Dnn);
+		String2GPUInstructionType.put( "update_ema_var",         GPUINSTRUCTION_TYPE.Dnn);
+		String2GPUInstructionType.put( "update_ema",             GPUINSTRUCTION_TYPE.Dnn);
+		String2GPUInstructionType.put( "reshape_colmeans",       GPUINSTRUCTION_TYPE.Dnn);
+		String2GPUInstructionType.put( "inv_var",                GPUINSTRUCTION_TYPE.Dnn);
+		String2GPUInstructionType.put( "batch_norm2d_bwd_dx",    GPUINSTRUCTION_TYPE.Dnn);
 		
 		// Matrix Multiply Operators
 		String2GPUInstructionType.put( "ba+*",  GPUINSTRUCTION_TYPE.AggregateBinary);