Posted to commits@systemml.apache.org by ni...@apache.org on 2018/07/16 21:58:24 UTC
systemml git commit: [SYSTEMML-445] Added a rewrite for batch normalization train
Repository: systemml
Updated Branches:
refs/heads/master 5aadb4b22 -> a6bca8851
[SYSTEMML-445] Added a rewrite for batch normalization train
- This PR fuses a batch normalization train pattern into a FunctionOp. The method batchNormTrain in RewriteGPUSpecificOps performs the fusion.
- This rewrite is only enabled if none of the outputs are persistent writes. It replaces the existing outputs of the matched pattern with transient reads.
Closes #800.
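For context, below is a sketch in DML of the mode="train" pattern that the rewrite matches (adapted from the pattern comments in RewriteGPUSpecificOps; names such as X, gamma, beta, C, Hin, Win, N, mu and eps are illustrative):

    # per-channel batch statistics (N = nrow(X))
    subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win)
    mean = rowMeans(subgrp_means)
    subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win)
    var = rowMeans(subgrp_vars) + rowVars(subgrp_means) * (((Hin*Win)-1)/(Hin*Win))
    # exponential moving-average updates of the running statistics
    ema_mean_upd = mu*ema_mean + (1-mu)*mean
    ema_var_upd = mu*ema_var + (1-mu)*var
    # normalize, then scale and shift
    norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps))
    out = bias_add(bias_multiply(norm, gamma), beta)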
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/a6bca885
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/a6bca885
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/a6bca885
Branch: refs/heads/master
Commit: a6bca88512f3f542278709713706d256fad2cc17
Parents: 5aadb4b
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Mon Jul 16 14:50:36 2018 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Mon Jul 16 14:52:01 2018 -0700
----------------------------------------------------------------------
.../java/org/apache/sysml/hops/FunctionOp.java | 28 +-
src/main/java/org/apache/sysml/hops/Hop.java | 1 +
.../hops/rewrite/RewriteGPUSpecificOps.java | 573 ++++++++++++++++++-
.../org/apache/sysml/lops/FunctionCallCP.java | 12 +-
src/main/java/org/apache/sysml/lops/Lop.java | 13 +
.../org/apache/sysml/parser/DMLTranslator.java | 6 +-
.../instructions/GPUInstructionParser.java | 1 +
.../instructions/gpu/DnnGPUInstruction.java | 62 +-
.../apache/sysml/test/gpu/BatchNormTest.java | 35 +-
.../org/apache/sysml/test/gpu/GPUTests.java | 37 +-
10 files changed, 701 insertions(+), 67 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/hops/FunctionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java b/src/main/java/org/apache/sysml/hops/FunctionOp.java
index 64963d9..aedaf81 100644
--- a/src/main/java/org/apache/sysml/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java
@@ -169,17 +169,20 @@ public class FunctionOp extends Hop
long outputValues = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), 1, 1.0);
return outputVectors+outputValues;
}
- else if ( getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward") ) {
+ else if ( getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward") ) {
// TODO: To allow for initial version to always run on the GPU
return 0;
}
- else if ( getFunctionName().equalsIgnoreCase("batch_norm2d") ) {
+ else if ( getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_train")) {
return OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0) +
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0) +
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(2).getDim1(), getOutputs().get(2).getDim2(), 1.0) +
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(3).getDim1(), getOutputs().get(3).getDim2(), 1.0) +
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(4).getDim1(), getOutputs().get(4).getDim2(), 1.0);
}
+ else if ( getFunctionName().equalsIgnoreCase("batch_norm2d_test") ) {
+ return OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0);
+ }
else if ( getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ) {
return OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0) +
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0) +
@@ -215,7 +218,8 @@ public class FunctionOp extends Hop
return OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), getInput().get(0).getDim2(), 1.0)
+ 3*OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), 1, 1.0);
}
- else if (getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) {
+ else if (getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ||
+ getFunctionName().equalsIgnoreCase("batch_norm2d_train") || getFunctionName().equalsIgnoreCase("batch_norm2d_test")) {
return 0;
}
else if ( getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward") ) {
@@ -240,7 +244,8 @@ public class FunctionOp extends Hop
@Override
public boolean isGPUEnabled() {
if(getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward") ||
- getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward"))
+ getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ||
+ getFunctionName().equalsIgnoreCase("batch_norm2d_train") || getFunctionName().equalsIgnoreCase("batch_norm2d_test"))
return true;
else
return false;
@@ -283,20 +288,25 @@ public class FunctionOp extends Hop
checkAndSetForcedPlatform();
if ( getFunctionType() == FunctionType.MULTIRETURN_BUILTIN ) {
+ boolean isBuiltinFunction = isBuiltinFunction();
// check if there is sufficient memory to execute this function
- if( getFunctionName().equalsIgnoreCase("transformencode") ) {
+ if(isBuiltinFunction && getFunctionName().equalsIgnoreCase("transformencode") ) {
_etype = ((_etypeForced==ExecType.SPARK
|| (getMemEstimate() >= OptimizerUtils.getLocalMemBudget()
&& OptimizerUtils.isSparkExecutionMode())) ? ExecType.SPARK : ExecType.CP);
}
- else if(getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward")) {
+ else if(isBuiltinFunction && (getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward"))) {
if(!DMLScript.USE_ACCELERATOR)
throw new RuntimeException("The function " + getFunctionName() + " is only supported on GPU.");
_etype = ExecType.GPU;
}
- else if( getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) {
+ else if(isBuiltinFunction && (getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward"))) {
_etype = DMLScript.USE_ACCELERATOR ? ExecType.GPU : ExecType.CP;
}
+ else if(isBuiltinFunction && getFunctionName().equalsIgnoreCase("batch_norm2d_train")) {
+ // Only GPU implementation is supported
+ _etype = ExecType.GPU;
+ }
else {
// Since the memory estimate is only conservative, do not throw
// exception if the estimated memory is larger than the budget
@@ -312,6 +322,10 @@ public class FunctionOp extends Hop
return _etype;
}
+
+ private boolean isBuiltinFunction() {
+ return getFunctionNamespace().equals(DMLProgram.INTERNAL_NAMESPACE);
+ }
@Override
public void refreshSizeInformation()
http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java
index 5d357c6..d8f4424 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -1537,6 +1537,7 @@ public abstract class Hop implements ParseInfo
HopsData2String.put(DataOpTypes.PERSISTENTWRITE, "PWrite");
HopsData2String.put(DataOpTypes.TRANSIENTWRITE, "TWrite");
HopsData2String.put(DataOpTypes.TRANSIENTREAD, "TRead");
+ HopsData2String.put(DataOpTypes.FUNCTIONOUTPUT, "FunOut");
}
public static OpOp2 getOpOp2ForOuterVectorOperation(String op)
http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
index 1c00c6f..b946178 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
@@ -20,47 +20,64 @@
package org.apache.sysml.hops.rewrite;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.HashMap;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.AggUnaryOp;
+import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.FunctionOp;
import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.FunctionOp.FunctionType;
import org.apache.sysml.hops.Hop.AggOp;
+import org.apache.sysml.hops.Hop.DataOpTypes;
import org.apache.sysml.hops.Hop.Direction;
import org.apache.sysml.hops.Hop.OpOp1;
import org.apache.sysml.hops.Hop.OpOp2;
import org.apache.sysml.hops.Hop.OpOpDnn;
import org.apache.sysml.hops.Hop.ReOrgOp;
+import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.DnnOp;
import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.hops.ReorgOp;
+import org.apache.sysml.hops.UnaryOp;
+import org.apache.sysml.parser.DMLProgram;
+import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
/*
* This class contains GPU-specific rewrites for following patterns:
*
- * 1. batchNormTest:
+ * 1. batchNormTest: applied when mode="test" in batch normalization nn layer.
* norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps))
* hi = bias_add(bias_multiply(norm, gamma), beta)
*
* 2. channelSum:
* output = rowSums(matrix(colSums(x), rows=numChannels, cols=imgSize*imgSize))
+ *
+ * 3. batchNormTrain: applied when mode="train" in batch normalization nn layer.
+ * This rewrite is only enabled if none of the outputs are persistent writes, as it assumes that
+ * the FunctionOp will introduce transient writes. The rewrite replaces the existing outputs of the matched pattern with transient reads.
+ *
*/
public class RewriteGPUSpecificOps extends HopRewriteRule {
+ private static int _seq = 1;
+
@Override
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
if( roots == null )
return roots;
//one pass rewrite-descend (rewrite created pattern)
- for( Hop h : roots )
- rule_GPUKernels( h, false );
+ for( int i = 0; i < roots.size(); i++ )
+ rule_GPUKernels(roots, roots.get(i), false );
Hop.resetVisitStatus(roots, true);
//one pass descend-rewrite (for rollup)
- for( Hop h : roots )
- rule_GPUKernels( h, true );
+ for( int i = 0; i < roots.size(); i++ )
+ rule_GPUKernels(roots, roots.get(i), true );
Hop.resetVisitStatus(roots, true);
return roots;
@@ -72,12 +89,12 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
return root;
//one pass rewrite-descend (rewrite created pattern)
- rule_GPUKernels( root, false );
+ rule_GPUKernels(null, root, false );
root.resetVisitStatus();
//one pass descend-rewrite (for rollup)
- rule_GPUKernels( root, true );
+ rule_GPUKernels(null, root, true );
return root;
}
@@ -85,10 +102,11 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
/**
* Fuse the kernel
*
+ * @param roots root operators
* @param hop high-level operator
* @param descendFirst true if recursively process children first
*/
- private void rule_GPUKernels(Hop hop, boolean descendFirst)
+ private void rule_GPUKernels(ArrayList<Hop> roots, Hop hop, boolean descendFirst)
{
if(hop.isVisited())
return;
@@ -99,13 +117,16 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
//process childs recursively first (to allow roll-up)
if( descendFirst )
- rule_GPUKernels(hi, descendFirst); //see below
+ rule_GPUKernels(roots, hi, descendFirst); //see below
+ if(roots != null) {
+ hi = batchNormTrain(roots, hop, hi, i);
+ }
hi = batchNormTest(hop, hi, i);
hi = channelSums(hop, hi, i);
if( !descendFirst )
- rule_GPUKernels(hi, descendFirst);
+ rule_GPUKernels(roots, hi, descendFirst);
}
hop.setVisited();
@@ -149,6 +170,10 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
memEst < OptimizerUtils.getLocalMemBudget() && memEst < GPUContextPool.initialGPUMemBudget();
}
+ private static boolean hasFirstInput(Hop h) {
+ return !(h == null || h.getInput() == null || h.getInput().size() < 1);
+ }
+
private static Hop getFirstInput(Hop h) {
if(h == null || h.getInput() == null || h.getInput().size() < 1) {
throw new RuntimeException("No input available for " + h);
@@ -156,13 +181,24 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
return h.getInput().get(0);
}
+ private static boolean hasSecondInput(Hop h) {
+ return !(h == null || h.getInput() == null || h.getInput().size() < 2);
+ }
+
private static Hop getSecondInput(Hop h) {
if(h == null || h.getInput() == null || h.getInput().size() < 2) {
- throw new RuntimeException("No input available for " + h);
+ throw new RuntimeException("Expected at least two inputs for " + h);
}
return h.getInput().get(1);
}
+ private static Hop getThirdInput(Hop h) {
+ if(h == null || h.getInput() == null || h.getInput().size() < 3) {
+ throw new RuntimeException("Expected at least three inputs for " + h);
+ }
+ return h.getInput().get(2);
+ }
+
private static boolean isUnaryMinus(Hop h) {
return HopRewriteUtils.isBinary(h, OpOp2.MINUS)
&& HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), 0);
@@ -200,13 +236,488 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
return hi;
}
+ private static boolean isRowMeans(Hop h) {
+ return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.MEAN && ((AggUnaryOp)h).getDirection() == Direction.Row;
+ }
+
+ private static boolean isRowVars(Hop h) {
+ return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.VAR && ((AggUnaryOp)h).getDirection() == Direction.Row;
+ }
+
+ private static boolean isRowVars(Hop h, Hop childHop) {
+ return isRowVars(h) && getFirstInput(h) == childHop;
+ }
+
+ private static boolean isColMeans(Hop h) {
+ return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.MEAN && ((AggUnaryOp)h).getDirection() == Direction.Col;
+ }
+
+ private static boolean isColVars(Hop h) {
+ return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.VAR && ((AggUnaryOp)h).getDirection() == Direction.Col;
+ }
+
+ private static boolean isReshape(Hop h) {
+ return h instanceof ReorgOp && ((ReorgOp)h).getOp() == ReOrgOp.RESHAPE;
+ }
+
+ private static boolean isReshape(Hop h, long expectedRows, long expectedCols) {
+ return h instanceof ReorgOp && ((ReorgOp)h).getOp() == ReOrgOp.RESHAPE &&
+ Hop.computeSizeInformation(getSecondInput(h)) == expectedRows &&
+ Hop.computeSizeInformation(getThirdInput(h)) == expectedCols;
+ }
+
+ private static boolean isBinaryAdd(Hop h) {
+ return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.PLUS;
+ }
+
+ private static boolean isBinaryMSAdd(Hop h, double expectedValue) {
+ return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.PLUS
+ && getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.SCALAR
+ && OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(h), new HashMap<>()) == expectedValue;
+ }
+
+ private static boolean isBinaryMMAdd(Hop h) {
+ return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.PLUS
+ && getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.MATRIX;
+ }
+
+ private static boolean isBinaryMSMult(Hop h, double expectedValue) {
+ return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MULT
+ && getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.SCALAR
+ && OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(h), new HashMap<>()) == expectedValue;
+ }
+
+ private static boolean isBinarySSMinus(Hop h) {
+ return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MINUS
+ && getFirstInput(h).getDataType() == DataType.SCALAR && getSecondInput(h).getDataType() == DataType.SCALAR;
+ }
+
+ private static boolean isBinarySSDiv(Hop h) {
+ return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.DIV
+ && getFirstInput(h).getDataType() == DataType.SCALAR && getSecondInput(h).getDataType() == DataType.SCALAR;
+ }
+
+ private static boolean isBinarySMDiv(Hop h, double expectedValue) {
+ return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.DIV
+ && getFirstInput(h).getDataType() == DataType.SCALAR && getSecondInput(h).getDataType() == DataType.MATRIX
+ && OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(h), new HashMap<>()) == expectedValue;
+ }
+
+ private static boolean isAnyBinaryAdd(ArrayList<Hop> hops) {
+ if(hops != null) {
+ for(Hop h : hops) {
+ if(h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.PLUS)
+ return true;
+ }
+ }
+ return false;
+ }
+
+ private static boolean isBinaryMSMult(Hop h) {
+ return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MULT
+ && getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.SCALAR;
+ }
+
+ private static boolean isBinarySMMult(Hop h) {
+ return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MULT
+ && getSecondInput(h).getDataType() == DataType.MATRIX && getFirstInput(h).getDataType() == DataType.SCALAR;
+ }
+
+ /**
+ * Checks if the "mean" hop computes the channel-wise mean of X in the batch normalization (mode="train") layer.
+ *
+ * @param mean hop to check against
+ * @param X input data
+ * @return true if the "mean" hop computes the channel-wise mean of X in the batch normalization (mode="train") layer.
+ */
+ private static boolean isBatchNormTrainMean(Hop mean, Hop X) {
+ // subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win)
+ // mean = rowMeans(subgrp_means)
+ return isRowMeans(mean) && isReshape(getFirstInput(mean)) && isColMeans(getFirstInput(getFirstInput(mean)))
+ && getFirstInput(getFirstInput(getFirstInput(mean))) == X;
+ }
+
+ /**
+ * Checks for nrow(X) pattern
+ *
+ * @param expr hop to be matched
+ * @param X input X
+ * @return true if expr is nrow(X) else false
+ */
+ private static boolean isNrowOfX(Hop expr, Hop X) {
+ return expr instanceof UnaryOp && ((UnaryOp)expr).getOp() == OpOp1.NROW && getFirstInput(expr) == X;
+ }
+
+ /**
+ * Checks for the colVars(X) * ((N-1)/N) pattern
+ *
+ * @param expr hop to be matched
+ * @param X input X
+ * @param ignoreCorrectionTerm whether to ignore the correction term ((N-1)/N).
+ * @return true if expr is colVars(X) * ((N-1)/N) else false
+ */
+ private static boolean isCorrectedColVars(Hop expr, Hop X, boolean ignoreCorrectionTerm) {
+ // colVars(X) * ((N-1)/N)
+ if(isColVars(expr) && getFirstInput(expr) == X) {
+ // Support no correction as well in this rewrite
+ return true;
+ }
+ else if(X.rowsKnown()) {
+ return isBinaryMSMult(expr, ((double)X.getDim1()-1)/X.getDim1()) &&
+ isColVars(getFirstInput(expr)) && getFirstInput(getFirstInput(expr)) == X;
+ }
+ else if(isBinaryMSMult(expr) &&
+ isColVars(getFirstInput(expr)) && getFirstInput(getFirstInput(expr)) == X) {
+ if(ignoreCorrectionTerm) {
+ return true;
+ }
+ Hop tmp = getSecondInput(expr);
+ // ((N-1)/N)
+ boolean isNMinus1Pattern = isBinarySSDiv(tmp) && isBinarySSMinus(getFirstInput(tmp)) &&
+ getFirstInput(getFirstInput(tmp)) == getSecondInput(tmp) &&
+ OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(getFirstInput(tmp)), new HashMap<>()) == 1;
+ boolean ret = isNMinus1Pattern && isNrowOfX(getSecondInput(tmp), X);
+ if(LOG.isDebugEnabled()) {
+ LOG.debug("Matched the corrected column variance pattern for the batch_norm_train rewrite (number of rows of X unknown): " + ret);
+ }
+ return ret;
+ }
+ return false;
+ }
+
+ /**
+ * Checks if the "var" hop computes the channel-wise variance of X in the batch normalization (mode="train") layer.
+ *
+ * @param mean previously matched mean hop
+ * @param var the hop to check against
+ * @param X input data hop
+ * @param subgrpMeans hop computing the subgroup means, i.e., matrix(colMeans(X), rows=C, cols=Hin*Win)
+ * @param ignoreCorrectionTerm whether to ignore the correction term ((N-1)/N) (see isCorrectedColVars method in this class)
+ * @return true if the "var" hop computes the channel-wise variance of X in the batch normalization (mode="train") layer.
+ */
+ private static boolean isBatchNormTrainVar(Hop mean, Hop var, Hop X, Hop subgrpMeans, boolean ignoreCorrectionTerm) {
+ long numChannels = Hop.computeSizeInformation(getSecondInput(getFirstInput(mean)));
+ long HW = Hop.computeSizeInformation(getThirdInput(getFirstInput(mean)));
+ // subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win)
+ // var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))
+ return numChannels > 0 && HW > 0 && isBinaryMMAdd(var) && isRowMeans(getFirstInput(var)) &&
+ // matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win)
+ isReshape(getFirstInput(getFirstInput(var)), numChannels, HW) &&
+ isCorrectedColVars(getFirstInput(getFirstInput(getFirstInput(var))), X, ignoreCorrectionTerm) &&
+ // rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))
+ isBinaryMSMult(getSecondInput(var), ((((double)HW)-1)/HW)) &&
+ isRowVars(getFirstInput(getSecondInput(var)), subgrpMeans);
+ }
+
+ /**
+ * Checks and returns the matched hops for the expression ema_mean_upd = mu*ema_mean + (1-mu)*mean
+ *
+ * @param rhsTimesOp hop representing the BinaryOp of the expression (1-mu)*mean
+ * @param mu value of mu
+ * @return an array [ema_mean_upd, ema_mean, mean] if the expression matched, else null
+ */
+ private static Hop [] getUpdatedMovingAverageExpressions(Hop rhsTimesOp, double mu) {
+ if(rhsTimesOp == null || rhsTimesOp.getParent() == null || rhsTimesOp.getParent().size() != 1 ||
+ !isBinarySMMult(rhsTimesOp) || !isBinaryAdd(rhsTimesOp.getParent().get(0)))
+ return null;
+
+ // Check (1-mu)*mean
+ double expectedOneMinusMu = OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(rhsTimesOp), new HashMap<>());
+ Hop plusOp = rhsTimesOp.getParent().get(0);
+ Hop lhsTimesOp = null;
+ if(plusOp.getInput().get(0) == rhsTimesOp) {
+ lhsTimesOp = plusOp.getInput().get(1);
+ }
+ else {
+ lhsTimesOp = plusOp.getInput().get(0);
+ }
+
+ if(expectedOneMinusMu == (1-mu) && plusOp.getParent() != null && plusOp.getParent().size() == 1 &&
+ isBinarySMMult(lhsTimesOp) && OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(lhsTimesOp), new HashMap<>()) == mu) {
+ return new Hop[] {
+ plusOp.getParent().get(0),
+ getSecondInput(lhsTimesOp),
+ getSecondInput(rhsTimesOp)
+ };
+ }
+ return null;
+ }
+
+ /**
+ * Checks whether exactly one hop in rhsTimesOps matches the expression ema_mean_upd = mu*ema_mean + (1-mu)*mean
+ * and, if so, returns the matched hops.
+ *
+ * @param rhsTimesOps array list of hops representing the BinaryOp of the expression (1-mu)*mean
+ * @param mu value of mu
+ * @return an array [ema_mean_upd, ema_mean, mean] if exactly one expression matched, else null
+ */
+ private static Hop [] getUpdatedMovingAverageExpressions(ArrayList<Hop> rhsTimesOps, double mu) {
+ if(rhsTimesOps == null || rhsTimesOps.size() == 0)
+ return null;
+
+ Hop [] ret = null;
+ for(Hop h : rhsTimesOps) {
+ boolean matched = isUpdatedMovingAverageExpression(h, mu);
+ if(matched && ret != null) {
+ return null; // Multiple matches, cannot decide which one to fuse
+ }
+ else if(matched) {
+ ret = getUpdatedMovingAverageExpressions(h, mu);
+ }
+ }
+
+ return ret;
+ }
+
+ /**
+ * Checks and returns the mu in the expression ema_mean_upd = mu*ema_mean + (1-mu)*mean
+ *
+ * @param rhsTimesOps array list of hops representing the BinaryOp of the expression (1-mu)*mean
+ * @return value of mu if exactly one expression matched, else null
+ */
+ private static Double getMuFromUpdatedMovingAverageExpressions(ArrayList<Hop> rhsTimesOps) {
+ if(rhsTimesOps == null || rhsTimesOps.size() == 0)
+ return null;
+
+ Double ret = null;
+ for(Hop h : rhsTimesOps) {
+ boolean matched = isUpdatedMovingAverageExpression(h);
+ if(matched && ret != null) {
+ return null; // Multiple matches, cannot decide which one to fuse
+ }
+ else if(matched) {
+ ret = -(OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(h), new HashMap<>())-1);
+ }
+ }
+ return ret;
+ }
+
+ /**
+ * Checks for the expression ema_mean_upd = mu*ema_mean + (1-mu)*mean
+ *
+ * @param rhsTimesOp hop representing the BinaryOp of the expression (1-mu)*mean
+ * @return true if the expression matched
+ */
+ private static boolean isUpdatedMovingAverageExpression(Hop rhsTimesOp) {
+ if(rhsTimesOp == null || rhsTimesOp.getParent() == null || rhsTimesOp.getParent().size() != 1 ||
+ !isBinarySMMult(rhsTimesOp) || !isBinaryAdd(rhsTimesOp.getParent().get(0)))
+ return false;
+
+ // Check (1-mu)*mean
+ Hop plusOp = rhsTimesOp.getParent().get(0);
+ Hop lhsTimesOp = null;
+ if(plusOp.getInput().get(0) == rhsTimesOp) {
+ lhsTimesOp = plusOp.getInput().get(1);
+ }
+ else {
+ lhsTimesOp = plusOp.getInput().get(0);
+ }
+
+ if(plusOp.getParent() != null && plusOp.getParent().size() == 1 && isBinarySMMult(lhsTimesOp)) {
+ return true;
+ }
+ return false;
+ }
+
+ // ema_mean_upd = mu*ema_mean + (1-mu)*mean
+ // Returns true if expression matched, else false
+ private static boolean isUpdatedMovingAverageExpression(Hop rhsTimesOp, double mu) {
+ if(rhsTimesOp == null || rhsTimesOp.getParent() == null || rhsTimesOp.getParent().size() != 1 ||
+ !isBinarySMMult(rhsTimesOp) || !isBinaryAdd(rhsTimesOp.getParent().get(0)))
+ return false;
+
+ // Check (1-mu)*mean
+ double expectedOneMinusMu = OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(rhsTimesOp), new HashMap<>());
+ Hop plusOp = rhsTimesOp.getParent().get(0);
+ Hop lhsTimesOp = null;
+ if(plusOp.getInput().get(0) == rhsTimesOp) {
+ lhsTimesOp = plusOp.getInput().get(1);
+ }
+ else {
+ lhsTimesOp = plusOp.getInput().get(0);
+ }
+
+ if(expectedOneMinusMu == (1-mu) && plusOp.getParent() != null && plusOp.getParent().size() == 1 &&
+ isBinarySMMult(lhsTimesOp) && OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(lhsTimesOp), new HashMap<>()) == mu) {
+ return true;
+ }
+ return false;
+ }
+
+ /**
+ * Checks for the expression 1/sqrt(denom)
+ *
+ * @param denom denominator of the expression to be matched
+ * @return true if the expression 1/sqrt(denom) matched else false
+ */
+ private static boolean isOneBySqrt(Hop denom) {
+ return denom.getParent() != null && denom.getParent().get(0) instanceof UnaryOp &&
+ ((UnaryOp)denom.getParent().get(0)).getOp() == OpOp1.SQRT &&
+ denom.getParent().get(0).getParent() != null && denom.getParent().get(0).getParent().size() == 1 &&
+ isBinarySMDiv(denom.getParent().get(0).getParent().get(0), 1);
+ }
+
+ /**
+ * Checks for the batch norm (mode="train") pattern using the helpers isBatchNormTrainMean and isBatchNormTrainVar,
+ * and returns a new FunctionOp if matched
+ *
+ * @param roots root hops of the given statement block
+ * @param parent parent of the input
+ * @param hi input to be matched
+ * @param pos position
+ * @return a new FunctionOp or hi
+ */
+ private static Hop batchNormTrain(ArrayList<Hop> roots, Hop parent, Hop hi, int pos)
+ {
+ // norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps))
+ // hi = bias_add(bias_multiply(norm, gamma), beta)
+ // 2x for input and output and 1x for overhead
+ // fitsOnGPU(hi, 3)
+ if( hasFirstInput(hi) && isBiasAdd(hi) && isBiasMultiply(getFirstInput(hi)) ) {
+ Hop norm = getFirstInput(getFirstInput(hi));
+ if(hasSecondInput(norm) && isBiasMultiply(norm) && isBiasAdd(getFirstInput(norm))
+ && hasSecondInput(getFirstInput(norm)) && isUnaryMinus(getSecondInput(getFirstInput(norm)))
+ && isOneDivideBySqrt(getSecondInput(norm))) {
+ double eps = 0;
+ Hop var = getFirstInput(getSecondInput(getSecondInput(norm)));
+ if(isBinaryAdd(var) && (getFirstInput(var) instanceof LiteralOp || getSecondInput(var) instanceof LiteralOp)) {
+ // eps + ema_var
+ if(getFirstInput(var) instanceof LiteralOp) {
+ eps = OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(var), new HashMap<>());
+ var = getSecondInput(var);
+ }
+ else {
+ eps = OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(var), new HashMap<>());
+ var = getFirstInput(var);
+ }
+ }
+ // Generate batch norm test op
+ Hop X = getFirstInput(getFirstInput(norm));
+ Hop mean = getSecondInput(getSecondInput(getFirstInput(norm)));
+
+ if(hasFirstInput(mean) && isBatchNormTrainMean(mean , X) && isBatchNormTrainVar(mean, var, X, getFirstInput(mean), false) &&
+ mean.getParent() != null && mean.getParent().size() >= 2 &&
+ var.getParent() != null && var.getParent().size() == 2) {
+ Hop gamma = getSecondInput(getFirstInput(hi));
+ Hop beta = getSecondInput(hi);
+
+ // Always get mu from the variance as it will have exactly one match of the fusion pattern
+ Double potentialMu = getMuFromUpdatedMovingAverageExpressions(var.getParent());
+ if(potentialMu == null)
+ return hi;
+ double mu = potentialMu;
+
+ Hop [] means = getUpdatedMovingAverageExpressions(mean.getParent(), mu);
+ Hop [] vars = getUpdatedMovingAverageExpressions(var.getParent(), mu);
+ if(means == null || vars == null)
+ return hi;
+
+ Hop varPlusEps = null;
+ boolean isFirstBinaryAddOp = isAnyBinaryAdd(var.getParent().get(0).getParent());
+ boolean isSecondBinaryAddOp = isAnyBinaryAdd(var.getParent().get(1).getParent());
+ if(isFirstBinaryAddOp && !isSecondBinaryAddOp) {
+ varPlusEps = var.getParent().get(1);
+ }
+ else if(!isFirstBinaryAddOp && isSecondBinaryAddOp) {
+ varPlusEps = var.getParent().get(0);
+ }
+ if(varPlusEps != null && isBinaryMSAdd(varPlusEps, eps) && isOneBySqrt(varPlusEps)) {
+
+ Hop cache_var = varPlusEps.getParent().get(0).getParent().get(0);
+ Hop ema_mean_upd = means[0];
+ Hop ema_var_upd = vars[0];
+ Hop ema_mean = means[1];
+ Hop ema_var = vars[1];
+ Hop cache_mean = means[2];
+
+
+ ArrayList<Hop> inHops = new ArrayList<Hop>();
+ inHops.add(X);
+ inHops.add(gamma);
+ inHops.add(beta);
+ inHops.add(ema_mean);
+ inHops.add(ema_var);
+ inHops.add(new LiteralOp(eps));
+ inHops.add(new LiteralOp(mu));
+ Hop [] oldHops = {hi, ema_mean_upd, ema_var_upd, cache_mean, cache_var};
+
+ // Since FunctionOp adds a transient write explicitly, persistent writes are not supported
+ if(!isAnyPersistentWrite(oldHops)) {
+ LOG.debug("Applied batchNormTrain rewrite.");
+ ArrayList<Hop> outputs = getMultiOutputHops(roots, oldHops);
+ FunctionOp ret = new FunctionOp(FunctionType.MULTIRETURN_BUILTIN, DMLProgram.INTERNAL_NAMESPACE, "batch_norm2d_train",
+ inHops, outputs.stream().map(h -> h.getName()).toArray(String[]::new), outputs);
+ Collections.reverse(roots);
+ roots.add(ret);
+ Collections.reverse(roots);
+ return ret;
+ }
+ }
+
+ }
+ }
+ }
+
+ return hi;
+ }
+
+ // ------------------------------------------------------------
+ /**
+ * Checks if any of the given output hops is a persistent write.
+ *
+ * @param outputHops output hops to check
+ * @return true if any of the hops is a persistent write, else false.
+ */
+ private static boolean isAnyPersistentWrite(Hop [] outputHops) {
+ for(Hop outHop : outputHops) {
+ if(HopRewriteUtils.isData(outHop, DataOpTypes.PERSISTENTWRITE))
+ return true;
+ }
+ return false;
+ }
+
+ /**
+ * Returns the output hops for a multi-output FunctionOp to be created by the rewrite.
+ *
+ * @param roots root hops of statement block
+ * @param oldHops old output hops of the pattern
+ * @return new output hops that should be passed to FunctionOp
+ */
+ private static ArrayList<Hop> getMultiOutputHops(ArrayList<Hop> roots, Hop [] oldHops) {
+ ArrayList<Hop> ret = new ArrayList<>();
+ for(int i = 0; i < oldHops.length; i++) {
+ // Create a transient read as FunctionOp will add a transient write.
+ if(HopRewriteUtils.isData(oldHops[i], DataOpTypes.PERSISTENTWRITE))
+ throw new RuntimeException("Persistent write is not supported as output for the given rewrite: " + oldHops[i]);
+ // Generate a new name if the old output was not transient write.
+ String name = HopRewriteUtils.isData(oldHops[i], DataOpTypes.TRANSIENTWRITE) ? oldHops[i].getName() : "_genGPU" + (_seq++);
+ DataOp tRead = HopRewriteUtils.createTransientRead(name, oldHops[i]);
+ HopRewriteUtils.rewireAllParentChildReferences(oldHops[i], tRead);
+ ret.add(tRead);
+ // Remove old output from roots to avoid unnecessary computation.
+ if(roots.contains(oldHops[i])) {
+ roots.remove(oldHops[i]);
+ }
+ }
+ return ret;
+ }
+ // ------------------------------------------------------------
+
+ /**
+ * Checks for the batch norm (mode="test") pattern and returns a new DnnOp if matched. The helpers
+ * isBatchNormTrainMean and isBatchNormTrainVar guard against subgraphs that may later match the train rewrite.
+ *
+ * @param parent parent of the input
+ * @param hi input to be matched
+ * @param pos position
+ * @return a new DnnOp or hi
+ */
private static Hop batchNormTest(Hop parent, Hop hi, int pos) {
// norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps))
// hi = bias_add(bias_multiply(norm, gamma), beta)
// 2x for input and output and 1x for overhead
- if( isBiasAdd(hi) && isBiasMultiply(getFirstInput(hi)) && fitsOnGPU(hi, 3) ) {
+ if(hasFirstInput(hi) && isBiasAdd(hi) && isBiasMultiply(getFirstInput(hi)) && fitsOnGPU(hi, 3) ) {
Hop norm = getFirstInput(getFirstInput(hi));
- if(isBiasMultiply(norm) && isBiasAdd(getFirstInput(norm))
+ if(hasSecondInput(norm) && isBiasMultiply(norm) && isBiasAdd(getFirstInput(norm))
&& isUnaryMinus(getSecondInput(getFirstInput(norm)))
&& isOneDivideBySqrt(getSecondInput(norm))) {
double eps = 0;
@@ -226,20 +737,28 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
// Generate batch norm test op
Hop X = getFirstInput(getFirstInput(norm));
Hop mean = getSecondInput(getSecondInput(getFirstInput(norm)));
- Hop gamma = getSecondInput(getFirstInput(hi));
- Hop beta = getSecondInput(hi);
- ArrayList<Hop> inHops = new ArrayList<Hop>();
- inHops.add(X);
- inHops.add(gamma);
- inHops.add(beta);
- inHops.add(mean);
- inHops.add(var);
- inHops.add(new LiteralOp(eps));
- if(fitsOnGPU(inHops, true)) {
- LOG.debug("Applied batchNormTest rewrite.");
- Hop newHop = new DnnOp(hi.getName(), hi.getDataType(), hi.getValueType(),
- OpOpDnn.BATCH_NORM2D_TEST, inHops);
- return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
+
+ // This guard disallows eager fusion of train batch normalization into test batch normalization
+ boolean potentialForBatchNormTrain = !X.rowsKnown() && isBatchNormTrainMean(mean , X) && isBatchNormTrainVar(mean, var, X, getFirstInput(mean), true);
+ if(!potentialForBatchNormTrain) {
+ Hop gamma = getSecondInput(getFirstInput(hi));
+ Hop beta = getSecondInput(hi);
+ ArrayList<Hop> inHops = new ArrayList<Hop>();
+ inHops.add(X);
+ inHops.add(gamma);
+ inHops.add(beta);
+ inHops.add(mean);
+ inHops.add(var);
+ inHops.add(new LiteralOp(eps));
+ if(fitsOnGPU(inHops, true)) {
+ LOG.debug("Applied batchNormTest rewrite.");
+ Hop newHop = new DnnOp(hi.getName(), hi.getDataType(), hi.getValueType(),
+ OpOpDnn.BATCH_NORM2D_TEST, inHops);
+ return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
+ }
+ }
+ else {
+ LOG.debug("Skipping batchNormTest rewrite as there is potential for batch normalization train rewrite after recompilation.");
}
}
}
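When all of the above checks pass, the matched subgraph is collapsed into a single multi-return FunctionOp in the internal namespace. Conceptually (a sketch only; the fused op is generated by the rewrite and is not user-facing DML syntax), the replacement behaves like:

    [out, ema_mean_upd, ema_var_upd, cache_mean, cache_var] =
        batch_norm2d_train(X, gamma, beta, ema_mean, ema_var, eps, mu)

The five original outputs are rewired to transient reads of the names written by the FunctionOp, which is why the rewrite bails out if any of them is a persistent write.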
http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/lops/FunctionCallCP.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/FunctionCallCP.java b/src/main/java/org/apache/sysml/lops/FunctionCallCP.java
index ac58335..711219a 100644
--- a/src/main/java/org/apache/sysml/lops/FunctionCallCP.java
+++ b/src/main/java/org/apache/sysml/lops/FunctionCallCP.java
@@ -42,8 +42,16 @@ public class FunctionCallCP extends Lop
this(inputs, fnamespace, fname, outputs, et);
if(outputHops != null) {
_outputLops = new ArrayList<>();
- for(Hop h : outputHops)
- _outputLops.add( h.constructLops() );
+ setLevel();
+ for(Hop h : outputHops) {
+ Lop outputLop = h.constructLops();
+ _outputLops.add( outputLop );
+ addOutput(outputLop);
+ // Update the output level if necessary for correct instruction ordering
+ if(outputLop.getLevel() <= getLevel()) {
+ outputLop.updateLevel(getLevel()+1);
+ }
+ }
}
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/lops/Lop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/Lop.java b/src/main/java/org/apache/sysml/lops/Lop.java
index 9e81496..885b8b9 100644
--- a/src/main/java/org/apache/sysml/lops/Lop.java
+++ b/src/main/java/org/apache/sysml/lops/Lop.java
@@ -345,6 +345,19 @@ public abstract class Lop
lps.setLevel(inputs);
}
+ protected void updateLevel(int newLevel) {
+ if(newLevel < getLevel()) {
+ throw new RuntimeException("Decrementing the level is not supported.");
+ }
+ else if(newLevel > getLevel()) {
+ lps.setLevel(newLevel);
+ for(Lop out : outputs) {
+ if(out.getLevel() < newLevel+1)
+ out.updateLevel(newLevel+1);
+ }
+ }
+ }
+
/**
* Method to get the location property of LOP
*
http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/parser/DMLTranslator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
index b9e5f9d..7cf7418 100644
--- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -2000,8 +2000,8 @@ public class DMLTranslator
String[] outputNames = new String[targetList.size()];
outputNames[0] = ((DataIdentifier)targetList.get(0)).getName();
outputNames[1] = ((DataIdentifier)targetList.get(1)).getName();
- outputs.add(new DataOp(outputNames[0], DataType.MATRIX, ValueType.DOUBLE, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, outputNames[0]));
- outputs.add(new DataOp(outputNames[1], DataType.FRAME, ValueType.STRING, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, outputNames[1]));
+ outputs.add(new DataOp(outputNames[0], DataType.MATRIX, ValueType.DOUBLE, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, inputs.get(0).getFilename()));
+ outputs.add(new DataOp(outputNames[1], DataType.FRAME, ValueType.STRING, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, inputs.get(0).getFilename()));
currBuiltinOp = new FunctionOp(ftype, nameSpace, source.getOpCode().toString(), inputs, outputNames, outputs);
break;
@@ -2233,7 +2233,7 @@ public class DMLTranslator
String[] outputNames = new String[targetList.size()];
for ( int i=0; i < targetList.size(); i++ ) {
outputNames[i] = ((DataIdentifier)targetList.get(i)).getName();
- Hop output = new DataOp(outputNames[i], DataType.MATRIX, ValueType.DOUBLE, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, outputNames[i]);
+ Hop output = new DataOp(outputNames[i], DataType.MATRIX, ValueType.DOUBLE, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, inputs.get(0).getFilename());
outputs.add(output);
}
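Both hunks above construct the FUNCTIONOUTPUT DataOps for multi-return builtin calls. For example (a hedged illustration, assuming the standard transformencode signature), a statement like the following produces one FUNCTIONOUTPUT DataOp per target on the left-hand side:

    [Y, M] = transformencode(target=F, spec=jspec)

With this change, each such DataOp now carries the filename of the first input instead of the output variable name.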
http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index 1122a24..01b10a8 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -61,6 +61,7 @@ public class GPUInstructionParser extends InstructionParser
String2GPUInstructionType.put( "batch_norm2d", GPUINSTRUCTION_TYPE.Dnn);
String2GPUInstructionType.put( "batch_norm2d_backward", GPUINSTRUCTION_TYPE.Dnn);
String2GPUInstructionType.put( "batch_norm2d_test", GPUINSTRUCTION_TYPE.Dnn);
+ String2GPUInstructionType.put( "batch_norm2d_train", GPUINSTRUCTION_TYPE.Dnn);
// Matrix Multiply Operators
String2GPUInstructionType.put( "ba+*", GPUINSTRUCTION_TYPE.AggregateBinary);
http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
index b01b8d8..a36d0fc 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
@@ -351,12 +351,28 @@ public class DnnGPUInstruction extends GPUInstruction {
CPOperand out = new CPOperand(parts[7]);
return new DnnGPUInstruction(in, in2, in3, in4, in5, in6, out, opcode, str, 0);
}
+ else if (opcode.equalsIgnoreCase("batch_norm2d_train")) {
+ InstructionUtils.checkNumFields(parts, 12);
+ CPOperand in1 = new CPOperand(parts[1]); // image
+ CPOperand in2 = new CPOperand(parts[2]); // gamma
+ CPOperand in3 = new CPOperand(parts[3]); // beta
+ CPOperand in4 = new CPOperand(parts[4]); // ema_mean
+ CPOperand in5 = new CPOperand(parts[5]); // ema_var
+ CPOperand in6 = new CPOperand(parts[6]); // eps
+ CPOperand in7 = new CPOperand(parts[7]); // mu
+ CPOperand out = new CPOperand(parts[8]); // out
+ CPOperand out2 = new CPOperand(parts[9]); // ema_mean_upd
+ CPOperand out3 = new CPOperand(parts[10]); // ema_var_upd
+ CPOperand out4 = new CPOperand(parts[11]); // cache_mean
+ CPOperand out5 = new CPOperand(parts[12]); // cache_inv_var
+ return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, in7, null, out, out2, out3, out4, out5, opcode, str, 0);
+ }
else {
throw new DMLRuntimeException("Unknown opcode while parsing a DnnGPUInstruction: " + str);
}
}
- public void processBiasInstruction(String instOpcode, ExecutionContext ec) {
+ private void processBiasInstruction(String instOpcode, ExecutionContext ec) {
GPUStatistics.incrementNoOfExecutedGPUInst();
MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName());
MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input2.getName());
@@ -372,7 +388,7 @@ public class DnnGPUInstruction extends GPUInstruction {
ec.releaseMatrixOutputForGPUInstruction(_output.getName());
}
- public void processBatchNorm2dInstruction(ExecutionContext ec) throws DMLRuntimeException {
+ private void processBatchNorm2dInstruction(ExecutionContext ec) throws DMLRuntimeException {
GPUStatistics.incrementNoOfExecutedGPUInst();
MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName());
@@ -420,7 +436,41 @@ public class DnnGPUInstruction extends GPUInstruction {
ec.releaseMatrixOutputForGPUInstruction(_output.getName());
}
- public void processBatchNorm2dTestInstruction(ExecutionContext ec) throws DMLRuntimeException {
+ private void processBatchNorm2dTrainInstruction(ExecutionContext ec) throws DMLRuntimeException {
+ GPUStatistics.incrementNoOfExecutedGPUInst();
+ MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
+ MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName());
+ MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName());
+ MatrixObject runningMean = getMatrixInputForGPUInstruction(ec, _input4.getName());
+ MatrixObject runningVar = getMatrixInputForGPUInstruction(ec, _input5.getName());
+
+ double epsilon = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getDoubleValue();
+ double exponentialAverageFactor = 1-ec.getScalarInput(_input7.getName(), _input7.getValueType(), _input7.isLiteral()).getDoubleValue();
+
+ MatrixObject ret = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns());
+ MatrixObject retRunningMean = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), runningMean.getNumRows(), runningMean.getNumColumns());
+ MatrixObject retRunningVar = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), runningVar.getNumRows(), runningVar.getNumColumns());
+ MatrixObject resultSaveMean = getDenseMatrixOutputForGPUInstruction(ec, _output4.getName(), runningMean.getNumRows(), runningMean.getNumColumns());
+ MatrixObject resultSaveInvVariance = getDenseMatrixOutputForGPUInstruction(ec, _output5.getName(), runningVar.getNumRows(), runningVar.getNumColumns());
+
+ LibMatrixCuDNN.batchNormalizationForwardTraining(ec.getGPUContext(0), getExtendedOpcode(),
+ image, scale, bias, runningMean, runningVar, ret,
+ retRunningMean, retRunningVar, epsilon, exponentialAverageFactor, resultSaveMean, resultSaveInvVariance);
+
+ // release inputs/outputs
+ ec.releaseMatrixInputForGPUInstruction(_input1.getName());
+ ec.releaseMatrixInputForGPUInstruction(_input2.getName());
+ ec.releaseMatrixInputForGPUInstruction(_input3.getName());
+ ec.releaseMatrixInputForGPUInstruction(_input4.getName());
+ ec.releaseMatrixInputForGPUInstruction(_input5.getName());
+ ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+ ec.releaseMatrixOutputForGPUInstruction(_output2.getName());
+ ec.releaseMatrixOutputForGPUInstruction(_output3.getName());
+ ec.releaseMatrixOutputForGPUInstruction(_output4.getName());
+ ec.releaseMatrixOutputForGPUInstruction(_output5.getName());
+ }
+
+ private void processBatchNorm2dTestInstruction(ExecutionContext ec) throws DMLRuntimeException {
GPUStatistics.incrementNoOfExecutedGPUInst();
MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName());
@@ -485,7 +535,7 @@ public class DnnGPUInstruction extends GPUInstruction {
ec.releaseMatrixOutputForGPUInstruction(_output.getName());
}
- public void processChannelSumsInstruction(ExecutionContext ec) {
+ private void processChannelSumsInstruction(ExecutionContext ec) {
GPUStatistics.incrementNoOfExecutedGPUInst();
MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName());
int C = (int) ec.getScalarInput(_input2.getName(), _input2.getValueType(), _input2.isLiteral()).getLongValue();
@@ -667,6 +717,10 @@ public class DnnGPUInstruction extends GPUInstruction {
processBatchNorm2dTestInstruction(ec);
return;
}
+ else if (instOpcode.equalsIgnoreCase("batch_norm2d_train")) {
+ processBatchNorm2dTrainInstruction(ec);
+ return;
+ }
GPUStatistics.incrementNoOfExecutedGPUInst();
http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java b/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
index 83adad4..b8bb9b6 100644
--- a/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
+++ b/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
@@ -46,10 +46,10 @@ public class BatchNormTest extends GPUTests {
testBatchNormForward("test");
}
-// @Test
-// public void testBatchNormForwardTrain() {
-// testBatchNormForward("train");
-// }
+ @Test
+ public void testBatchNormForwardTrain() {
+ testBatchNormForward("train");
+ }
private void testBatchNormForward(String mode) {
int imgSize = 32;
@@ -58,18 +58,29 @@ public class BatchNormTest extends GPUTests {
String scriptStr = "source(\"nn/layers/batch_norm2d_old.dml\") as batch_norm2d_old;\n "
+ "[output, ema_mean_upd, ema_var_upd, cache_mean, cache_var] = batch_norm2d_old::forward(x, gamma, beta, " + numChannels + ", " + imgSize + ", " + imgSize + ", \"" + mode + "\", ema_mean, ema_var, 0.9, 1e-3)";
HashMap<String, Object> inputs = new HashMap<>();
- inputs.put("x", generateInputMatrix(spark, 32, numChannels*imgSize*imgSize, 0, 100, sparsity, seed));
- inputs.put("gamma", generateInputMatrix(spark, numChannels, 1, 0, 10, sparsity, seed));
- inputs.put("beta", generateInputMatrix(spark, numChannels, 1, 0, 10, sparsity, seed));
- inputs.put("ema_mean", generateInputMatrix(spark, numChannels, 1, 40, 60, sparsity, seed));
- inputs.put("ema_var", generateInputMatrix(spark, numChannels, 1, 5, 15, sparsity, seed));
+ inputs.put("x", generateInputMatrix(spark, 32, numChannels*imgSize*imgSize, 0, 10, sparsity, seed));
+ inputs.put("gamma", generateInputMatrix(spark, numChannels, 1, 0, 2, sparsity, seed));
+ inputs.put("beta", generateInputMatrix(spark, numChannels, 1, 0, 2, sparsity, seed));
+ inputs.put("ema_mean", generateInputMatrix(spark, numChannels, 1, 3, 7, sparsity, seed));
+ inputs.put("ema_var", generateInputMatrix(spark, numChannels, 1, 0, 2, sparsity, seed));
List<String> outputs = Arrays.asList("output", "ema_mean_upd", "ema_var_upd", "cache_mean", "cache_var");
List<Object> outCPU = runOnCPU(spark, scriptStr, inputs, outputs);
List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, outputs);
- if(mode.equals("test"))
+ if(mode.equals("test")) {
assertHeavyHitterPresent("gpu_batch_norm2d_test");
- for(int i = 0; i < outputs.size(); i++) {
- assertEqualObjects(outCPU.get(i), outGPU.get(i));
+ for(int i = 0; i < outputs.size(); i++) {
+ assertEqualObjects(outCPU.get(i), outGPU.get(i));
+ }
+ }
+ else {
+ assertHeavyHitterPresent("gpu_batch_norm2d_train");
+ double [] threshold = new double[outputs.size()];
+ Arrays.fill(threshold, getTHRESHOLD());
+ // Handle loss of precision in CuDNN kernel
+ threshold[2] = 1e-3;
+ for(int i = 0; i < outputs.size()-1; i++) {
+ assertEqualObjects(outCPU.get(i), outGPU.get(i), threshold[i]);
+ }
}
}
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/a6bca885/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
index e006fd2..cae2e33 100644
--- a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
+++ b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
@@ -212,7 +212,7 @@ public abstract class GPUTests extends AutomatedTestBase {
return in1;
}
- private void printMatrixIfNotEqual(MatrixBlock expectedMB, MatrixBlock actualMB) {
+ private void printMatrixIfNotEqual(MatrixBlock expectedMB, MatrixBlock actualMB, double threshold) {
long rows = expectedMB.getNumRows();
long cols = expectedMB.getNumColumns();
boolean matrixNotEqual = false;
@@ -222,7 +222,7 @@ public abstract class GPUTests extends AutomatedTestBase {
double actualDouble = actualMB.quickGetValue(i, j);
if (expectedDouble != 0.0 && !Double.isNaN(expectedDouble) && Double.isFinite(expectedDouble)) {
double relativeError = Math.abs((expectedDouble - actualDouble) / expectedDouble);
- if(relativeError >= getTHRESHOLD()) {
+ if(relativeError >= threshold) {
matrixNotEqual = true;
break;
}
@@ -250,12 +250,13 @@ public abstract class GPUTests extends AutomatedTestBase {
*
* @param expected expected matrix
* @param actual actual matrix
+ * @param threshold relative threshold
*/
- private void assertEqualMatrices(Matrix expected, Matrix actual) {
+ private void assertEqualMatrices(Matrix expected, Matrix actual, double threshold) {
try {
// Faster way to compare two matrices
MLContext cpuMLC = new MLContext(spark);
- String scriptStr = "num_mismatch = sum((abs(X - Y) / X) > " + getTHRESHOLD() + ");";
+ String scriptStr = "num_mismatch = sum((abs(X - Y) / X) > " + threshold + ");";
Script script = ScriptFactory.dmlFromString(scriptStr).in("X", expected).in("Y", actual).out("num_mismatch");
long num_mismatch = cpuMLC.execute(script).getLong("num_mismatch");
cpuMLC.close();
@@ -271,7 +272,7 @@ public abstract class GPUTests extends AutomatedTestBase {
Assert.assertEquals(rows, actualMB.getNumRows());
Assert.assertEquals(cols, actualMB.getNumColumns());
- if(PRINT_MAT_ERROR) printMatrixIfNotEqual(expectedMB, actualMB);
+ if(PRINT_MAT_ERROR) printMatrixIfNotEqual(expectedMB, actualMB, threshold);
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
@@ -285,12 +286,12 @@ public abstract class GPUTests extends AutomatedTestBase {
"Relative error(%f) is more than threshold (%f). Expected = %f, Actual = %f, differed at [%d, %d]",
relativeError, getTHRESHOLD(), expectedDouble, actualDouble, i, j);
if(FLOATING_POINT_PRECISION.equals("double"))
- Assert.assertTrue(format.toString(), relativeError < getTHRESHOLD());
+ Assert.assertTrue(format.toString(), relativeError < threshold);
else
- Assert.assertTrue(format.toString(), relativeError < getTHRESHOLD() || absoluteError < getTHRESHOLD());
+ Assert.assertTrue(format.toString(), relativeError < threshold || absoluteError < threshold);
format.close();
} else {
- Assert.assertEquals(expectedDouble, actualDouble, getTHRESHOLD());
+ Assert.assertEquals(expectedDouble, actualDouble, threshold);
}
}
}
@@ -349,6 +350,7 @@ public abstract class GPUTests extends AutomatedTestBase {
// and other side effects.
synchronized(GPUTests.class) {
MLContext gpuMLC = new MLContext(spark);
+ // gpuMLC.setExplain(true); gpuMLC.setExplainLevel("recompile_runtime");
gpuMLC.setConfigProperty("sysml.floating.point.precision", FLOATING_POINT_PRECISION);
if(IGNORE_CLEAR_MEMORY_BUG)
gpuMLC.setConfigProperty("sysml.gpu.eager.cudaFree", "true");
@@ -366,7 +368,7 @@ public abstract class GPUTests extends AutomatedTestBase {
return outputs;
}
}
-
+
/**
* Assert that the two objects are equal. Supported types are Boolean, Integer, String, Double and Matrix
*
@@ -374,6 +376,17 @@ public abstract class GPUTests extends AutomatedTestBase {
* @param actual
*/
protected void assertEqualObjects(Object expected, Object actual) {
+ assertEqualObjects(expected, actual, getTHRESHOLD());
+ }
+
+ /**
+ * Assert that the two objects are equal. Supported types are Boolean, Integer, String, Double and Matrix
+ *
+ * @param expected expected value
+ * @param actual actual value
+ * @param threshold relative error threshold
+ */
+ protected void assertEqualObjects(Object expected, Object actual, double threshold) {
Assert.assertEquals(expected.getClass(), actual.getClass());
if (expected instanceof Boolean) {
@@ -384,16 +397,16 @@ public abstract class GPUTests extends AutomatedTestBase {
if (expectedDouble != 0.0 && !Double.isNaN(expectedDouble) && Double.isFinite(expectedDouble)) {
double relativeError = Math.abs((expectedDouble - actualDouble) / expectedDouble);
Assert.assertTrue("Comparing floating point numbers, relative error(" + relativeError
- + ") is more than threshold (" + getTHRESHOLD() + ")", relativeError < getTHRESHOLD());
+ + ") is more than threshold (" + threshold + ")", relativeError < threshold);
} else {
- Assert.assertEquals(expectedDouble, actualDouble, getTHRESHOLD());
+ Assert.assertEquals(expectedDouble, actualDouble, threshold);
}
} else if (expected instanceof String) {
Assert.assertEquals(expected.toString(), actual.toString());
} else if (expected instanceof Integer) {
Assert.assertEquals(((Integer) expected).intValue(), ((Integer) actual).intValue());
} else if (expected instanceof Matrix)
- assertEqualMatrices((Matrix) expected, (Matrix) actual);
+ assertEqualMatrices((Matrix) expected, (Matrix) actual, threshold);
else {
Assert.fail("Invalid types for comparison");
}