You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2018/08/30 22:43:05 UTC
[1/3] systemml git commit: [SYSTEMML-445] Removed batch_norm builtin
functions
Repository: systemml
Updated Branches:
refs/heads/master 81419ae6a -> 0f36780a8
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
index e736a1c..d620de9 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
@@ -19,8 +19,8 @@
package org.apache.sysml.runtime.instructions.gpu;
import java.util.ArrayList;
-
import jcuda.Pointer;
+import jcuda.jcudnn.JCudnn;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
@@ -32,7 +32,6 @@ import org.apache.sysml.runtime.instructions.gpu.context.ExecutionConfig;
import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
import org.apache.sysml.runtime.matrix.data.LibMatrixCuDNN;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.PoolingType;
import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
import org.apache.sysml.runtime.util.DnnUtils;
@@ -57,12 +56,14 @@ public class DnnGPUInstruction extends GPUInstruction {
private ArrayList<CPOperand> _stride = new ArrayList<>();
private ArrayList<CPOperand> _padding = new ArrayList<>();
private double _intermediateMemoryBudget = 0;
+ private GPUContext gCtx;
+ private String instName;
public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, double intermediateMemoryBudget) {
super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
- if (!(opcode.equals("bias_add") || opcode.equals("bias_multiply") || opcode.equals("relu_backward"))) {
+ if (!(opcode.equals("bias_add") || opcode.equals("bias_multiply") || opcode.equals("relu_backward") || opcode.equals("inv_var") )) {
throw new DMLRuntimeException(
- "Incorrect usage. Expected the opcode to be bias_add or bias_multiply or relu_backward, but found "
+ "Incorrect usage. Expected the opcode to be bias_add or bias_multiply or relu_backward or inv_var, but found "
+ opcode);
}
_input1 = in1;
@@ -112,8 +113,8 @@ public class DnnGPUInstruction extends GPUInstruction {
public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr,
double intermediateMemoryBudget) throws DMLRuntimeException {
super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
- if( !opcode.equals("channel_sums") ) {
- throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be channel_sums, but found " + opcode);
+ if( !(opcode.equals("channel_sums") || opcode.equals("reshape_colmeans") || opcode.equals("update_ema") ) ) {
+ throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be channel_sums or reshape_colmeans or update_ema, but found " + opcode);
}
_input1 = in1;
_input2 = in2;
@@ -126,7 +127,7 @@ public class DnnGPUInstruction extends GPUInstruction {
public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String istr,
double intermediateMemoryBudget) throws DMLRuntimeException {
super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
- if( !opcode.equals("update_nesterov_x") ) {
+ if( !( opcode.equals("update_nesterov_x")) ) {
throw new DMLRuntimeException("Incorrect opcode: " + opcode);
}
_input1 = in1;
@@ -182,6 +183,22 @@ public class DnnGPUInstruction extends GPUInstruction {
_intermediateMemoryBudget = intermediateMemoryBudget;
}
+ public DnnGPUInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5,
+ CPOperand out, String opcode, String istr, double intermediateMemoryBudget) {
+ super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
+ if( !(opcode.equals("update_ema_var") || opcode.equals("batch_norm2d_bwd_dx")) ) {
+ throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be update_ema_var or batch_norm2d_bwd_dx, but found " + opcode);
+ }
+ _input1 = in;
+ _input2 = in2;
+ _input3 = in3;
+ _input4 = in4;
+ _input5 = in5;
+ _gputype = GPUINSTRUCTION_TYPE.Dnn;
+ _output = out;
+ _intermediateMemoryBudget = intermediateMemoryBudget;
+ }
+
public static DnnGPUInstruction parseInstruction(String str) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
@@ -297,14 +314,15 @@ public class DnnGPUInstruction extends GPUInstruction {
return new DnnGPUInstruction(in1, null, out, opcode, str, stride,
padding, input_shape, filter_shape, Double.parseDouble(parts[15]));
}
- else if( opcode.equalsIgnoreCase("bias_add") || opcode.equalsIgnoreCase("relu_backward") || opcode.equalsIgnoreCase("bias_multiply") ) {
+ else if( opcode.equalsIgnoreCase("bias_add") || opcode.equalsIgnoreCase("relu_backward") || opcode.equalsIgnoreCase("bias_multiply")
+ || opcode.equalsIgnoreCase("inv_var") ) {
InstructionUtils.checkNumFields(parts, 4);
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
CPOperand out = new CPOperand(parts[3]);
return new DnnGPUInstruction(in1, in2, out, opcode, str, Double.parseDouble(parts[4]));
}
- else if (opcode.equalsIgnoreCase("channel_sums")) {
+ else if (opcode.equalsIgnoreCase("channel_sums") || opcode.equals("reshape_colmeans") || opcode.equals("update_ema")) {
InstructionUtils.checkNumFields(parts, 4);
CPOperand in = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
@@ -333,7 +351,7 @@ public class DnnGPUInstruction extends GPUInstruction {
CPOperand out2 = new CPOperand(parts[8]);
return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, out, out2, opcode, str, 0);
}
- else if (opcode.equalsIgnoreCase("batch_norm2d") || opcode.equalsIgnoreCase("lstm_backward")) {
+ else if (opcode.equalsIgnoreCase("lstm_backward")) {
InstructionUtils.checkNumFields(parts, 13);
CPOperand in1 = new CPOperand(parts[1]); // image
CPOperand in2 = new CPOperand(parts[2]); // scale
@@ -350,19 +368,6 @@ public class DnnGPUInstruction extends GPUInstruction {
CPOperand out5 = new CPOperand(parts[13]); // resultSaveInvVariance
return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0);
}
- else if (opcode.equalsIgnoreCase("batch_norm2d_backward")) {
- InstructionUtils.checkNumFields(parts, 9);
- CPOperand in1 = new CPOperand(parts[1]); // image
- CPOperand in2 = new CPOperand(parts[2]); // dout
- CPOperand in3 = new CPOperand(parts[3]); // scale
- CPOperand in4 = new CPOperand(parts[4]); // epsilon
- CPOperand in5 = new CPOperand(parts[5]); // resultSaveMean
- CPOperand in6 = new CPOperand(parts[6]); // resultSaveInvVariance
- CPOperand out = new CPOperand(parts[7]); // dX
- CPOperand out2 = new CPOperand(parts[8]); // dScale
- CPOperand out3 = new CPOperand(parts[9]); // dBias
- return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0);
- }
else if (opcode.equalsIgnoreCase("batch_norm2d_test")) {
InstructionUtils.checkNumFields(parts, 7);
CPOperand in = new CPOperand(parts[1]);
@@ -374,21 +379,25 @@ public class DnnGPUInstruction extends GPUInstruction {
CPOperand out = new CPOperand(parts[7]);
return new DnnGPUInstruction(in, in2, in3, in4, in5, in6, out, opcode, str, 0);
}
- else if (opcode.equalsIgnoreCase("batch_norm2d_train")) {
- InstructionUtils.checkNumFields(parts, 12);
- CPOperand in1 = new CPOperand(parts[1]); // image
- CPOperand in2 = new CPOperand(parts[2]); // gamma
- CPOperand in3 = new CPOperand(parts[3]); // beta
- CPOperand in4 = new CPOperand(parts[4]); // ema_mean
- CPOperand in5 = new CPOperand(parts[5]); // ema_var
- CPOperand in6 = new CPOperand(parts[6]); // eps
- CPOperand in7 = new CPOperand(parts[7]); // mu
- CPOperand out = new CPOperand(parts[8]); // out
- CPOperand out2 = new CPOperand(parts[9]); // ema_mean_upd
- CPOperand out3 = new CPOperand(parts[10]); // ema_var_upd
- CPOperand out4 = new CPOperand(parts[11]); // cache_mean
- CPOperand out5 = new CPOperand(parts[12]); // cache_inv_var
- return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, in7, null, out, out2, out3, out4, out5, opcode, str, 0);
+ else if (opcode.equalsIgnoreCase("batch_norm2d_bwd_dx")) {
+ InstructionUtils.checkNumFields(parts, 6);
+ CPOperand in = new CPOperand(parts[1]);
+ CPOperand in2 = new CPOperand(parts[2]);
+ CPOperand in3 = new CPOperand(parts[3]);
+ CPOperand in4 = new CPOperand(parts[4]);
+ CPOperand in5 = new CPOperand(parts[5]);
+ CPOperand out = new CPOperand(parts[6]);
+ return new DnnGPUInstruction(in, in2, in3, in4, in5, out, opcode, str, 0);
+ }
+ else if (opcode.equalsIgnoreCase("update_ema_var")) {
+ InstructionUtils.checkNumFields(parts, 6);
+ CPOperand in = new CPOperand(parts[1]);
+ CPOperand in2 = new CPOperand(parts[2]);
+ CPOperand in3 = new CPOperand(parts[3]);
+ CPOperand in4 = new CPOperand(parts[4]);
+ CPOperand in5 = new CPOperand(parts[5]);
+ CPOperand out = new CPOperand(parts[6]);
+ return new DnnGPUInstruction(in, in2, in3, in4, in5, out, opcode, str, 0);
}
else {
throw new DMLRuntimeException("Unknown opcode while parsing a DnnGPUInstruction: " + str);
@@ -396,211 +405,185 @@ public class DnnGPUInstruction extends GPUInstruction {
}
private void processBiasInstruction(String instOpcode, ExecutionContext ec) {
- GPUStatistics.incrementNoOfExecutedGPUInst();
- MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName());
- MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input2.getName());
- MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), input.getNumRows(), input.getNumColumns());
-
- if(instOpcode.equalsIgnoreCase("bias_add"))
- LibMatrixCUDA.biasAdd(ec.getGPUContext(0), getExtendedOpcode(), input, bias, out);
- else if(instOpcode.equalsIgnoreCase("bias_multiply"))
- LibMatrixCUDA.biasMultiply(ec.getGPUContext(0), getExtendedOpcode(), input, bias, out);
- // release inputs/outputs
- ec.releaseMatrixInputForGPUInstruction(_input1.getName());
- ec.releaseMatrixInputForGPUInstruction(_input2.getName());
- ec.releaseMatrixOutputForGPUInstruction(_output.getName());
- }
-
- private void processBatchNorm2dInstruction(ExecutionContext ec) throws DMLRuntimeException {
- GPUStatistics.incrementNoOfExecutedGPUInst();
- MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
- MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName());
- MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName());
- MatrixObject runningMean = getMatrixInputForGPUInstruction(ec, _input4.getName());
- MatrixObject runningVar = getMatrixInputForGPUInstruction(ec, _input5.getName());
-
- String phase = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getStringValue();
- double epsilon = ec.getScalarInput(_input7.getName(), _input7.getValueType(), _input7.isLiteral()).getDoubleValue();
-
- MatrixObject ret = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns());
-
- if(phase.equalsIgnoreCase("train")) {
- double exponentialAverageFactor = 1-ec.getScalarInput(_input8.getName(), _input8.getValueType(), _input8.isLiteral()).getDoubleValue();
- MatrixObject retRunningMean = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), runningMean.getNumRows(), runningMean.getNumColumns());
- MatrixObject retRunningVar = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), runningVar.getNumRows(), runningVar.getNumColumns());
- MatrixObject resultSaveMean = getDenseMatrixOutputForGPUInstruction(ec, _output4.getName(), runningMean.getNumRows(), runningMean.getNumColumns());
- MatrixObject resultSaveInvVariance = getDenseMatrixOutputForGPUInstruction(ec, _output5.getName(), runningVar.getNumRows(), runningVar.getNumColumns());
- LibMatrixCuDNN.batchNormalizationForwardTraining(ec.getGPUContext(0), getExtendedOpcode(),
- image, scale, bias, runningMean, runningVar, ret,
- retRunningMean, retRunningVar, epsilon, exponentialAverageFactor, resultSaveMean, resultSaveInvVariance);
- ec.releaseMatrixOutputForGPUInstruction(_output2.getName());
- ec.releaseMatrixOutputForGPUInstruction(_output3.getName());
- ec.releaseMatrixOutputForGPUInstruction(_output4.getName());
- ec.releaseMatrixOutputForGPUInstruction(_output5.getName());
- }
- else if(phase.equalsIgnoreCase("test")) {
- LibMatrixCuDNN.batchNormalizationForwardInference(ec.getGPUContext(0), getExtendedOpcode(),
- image, scale, bias, runningMean, runningVar, ret, epsilon);
- ec.setMatrixOutput(_output2.getName(), new MatrixBlock((int)runningMean.getNumRows(), (int)runningMean.getNumColumns(), true), getExtendedOpcode());
- ec.setMatrixOutput(_output3.getName(), new MatrixBlock((int)runningVar.getNumRows(), (int)runningVar.getNumColumns(), true), getExtendedOpcode());
- ec.setMatrixOutput(_output4.getName(), new MatrixBlock((int)runningMean.getNumRows(), (int)runningMean.getNumColumns(), true), getExtendedOpcode());
- ec.setMatrixOutput(_output5.getName(), new MatrixBlock((int)runningVar.getNumRows(), (int)runningVar.getNumColumns(), true), getExtendedOpcode());
- }
- else {
- throw new DMLRuntimeException("Incorrect mode: Expected either train or test, but found " + phase);
+ try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+ fetcher.add("input", _input1).add("bias", _input2);
+
+ MatrixObject input = fetcher.getInputMatrixObject("input");
+ MatrixObject bias = fetcher.getInputMatrixObject("bias");
+ MatrixObject out = fetcher.getOutputMatrixObject(input.getNumRows(), input.getNumColumns());
+
+ if(instOpcode.equalsIgnoreCase("bias_add"))
+ LibMatrixCUDA.biasAdd(gCtx, instName, input, bias, out);
+ else if(instOpcode.equalsIgnoreCase("bias_multiply"))
+ LibMatrixCUDA.biasMultiply(gCtx, instName, input, bias, out);
}
-
- // release inputs/outputs
- ec.releaseMatrixInputForGPUInstruction(_input1.getName());
- ec.releaseMatrixInputForGPUInstruction(_input2.getName());
- ec.releaseMatrixInputForGPUInstruction(_input3.getName());
- ec.releaseMatrixInputForGPUInstruction(_input4.getName());
- ec.releaseMatrixInputForGPUInstruction(_input5.getName());
- ec.releaseMatrixOutputForGPUInstruction(_output.getName());
}
- private void processBatchNorm2dTrainInstruction(ExecutionContext ec) throws DMLRuntimeException {
- GPUStatistics.incrementNoOfExecutedGPUInst();
- MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
- MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName());
- MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName());
- MatrixObject runningMean = getMatrixInputForGPUInstruction(ec, _input4.getName());
- MatrixObject runningVar = getMatrixInputForGPUInstruction(ec, _input5.getName());
-
- double epsilon = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getDoubleValue();
- double exponentialAverageFactor = 1-ec.getScalarInput(_input7.getName(), _input7.getValueType(), _input7.isLiteral()).getDoubleValue();
-
- MatrixObject ret = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns());
- MatrixObject retRunningMean = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), runningMean.getNumRows(), runningMean.getNumColumns());
- MatrixObject retRunningVar = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), runningVar.getNumRows(), runningVar.getNumColumns());
- MatrixObject resultSaveMean = getDenseMatrixOutputForGPUInstruction(ec, _output4.getName(), runningMean.getNumRows(), runningMean.getNumColumns());
- MatrixObject resultSaveInvVariance = getDenseMatrixOutputForGPUInstruction(ec, _output5.getName(), runningVar.getNumRows(), runningVar.getNumColumns());
-
- LibMatrixCuDNN.batchNormalizationForwardTraining(ec.getGPUContext(0), getExtendedOpcode(),
- image, scale, bias, runningMean, runningVar, ret,
- retRunningMean, retRunningVar, epsilon, exponentialAverageFactor, resultSaveMean, resultSaveInvVariance);
-
- // release inputs/outputs
- ec.releaseMatrixInputForGPUInstruction(_input1.getName());
- ec.releaseMatrixInputForGPUInstruction(_input2.getName());
- ec.releaseMatrixInputForGPUInstruction(_input3.getName());
- ec.releaseMatrixInputForGPUInstruction(_input4.getName());
- ec.releaseMatrixInputForGPUInstruction(_input5.getName());
- ec.releaseMatrixOutputForGPUInstruction(_output.getName());
- ec.releaseMatrixOutputForGPUInstruction(_output2.getName());
- ec.releaseMatrixOutputForGPUInstruction(_output3.getName());
- ec.releaseMatrixOutputForGPUInstruction(_output4.getName());
- ec.releaseMatrixOutputForGPUInstruction(_output5.getName());
+ private void processInverseVarianceInstruction(String instOpcode, ExecutionContext ec) {
+ try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+ fetcher.add("X", _input1).addScalar("eps", _input2);
+
+ int rows = LibMatrixCUDA.toInt(fetcher.getInputNumRows("X"));
+ int cols = LibMatrixCUDA.toInt(fetcher.getInputNumColumns("X"));
+
+ // invVar(X, C, eps, size);
+ LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("invVar",
+ ExecutionConfig.getConfigForSimpleVectorOperations(rows*cols),
+ fetcher.getInputPointer("X"), fetcher.getOutputPointer(rows, cols),
+ fetcher.getDouble("eps"), rows*cols);
+ }
}
private void processBatchNorm2dTestInstruction(ExecutionContext ec) throws DMLRuntimeException {
- GPUStatistics.incrementNoOfExecutedGPUInst();
- MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
- MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName());
- MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName());
- MatrixObject runningMean = getMatrixInputForGPUInstruction(ec, _input4.getName());
- MatrixObject runningVar = getMatrixInputForGPUInstruction(ec, _input5.getName());
- double epsilon = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getDoubleValue();
-
- MatrixObject ret = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns());
- LibMatrixCuDNN.batchNormalizationForwardInference(ec.getGPUContext(0), getExtendedOpcode(),
- image, scale, bias, runningMean, runningVar, ret, epsilon);
-
- // release inputs/outputs
- ec.releaseMatrixInputForGPUInstruction(_input1.getName());
- ec.releaseMatrixInputForGPUInstruction(_input2.getName());
- ec.releaseMatrixInputForGPUInstruction(_input3.getName());
- ec.releaseMatrixInputForGPUInstruction(_input4.getName());
- ec.releaseMatrixInputForGPUInstruction(_input5.getName());
- ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+ try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+ fetcher.add("image", _input1).add("scale", _input2).add("bias", _input3)
+ .add("runningMean", _input4).add("runningVar", _input5).addScalar("epsilon", _input6);
+
+ double epsilon = fetcher.getDouble("epsilon");
+ if(epsilon < JCudnn.CUDNN_BN_MIN_EPSILON) {
+ throw new DMLRuntimeException("The epsilon (" + epsilon + ") cannot be less than CUDNN_BN_MIN_EPSILON=(" + JCudnn.CUDNN_BN_MIN_EPSILON + ")");
+ }
+
+ MatrixObject image = fetcher.getInputMatrixObject("image");
+ LibMatrixCuDNN.batchNormalizationForwardInference(gCtx, instName,
+ image, fetcher.getInputMatrixObject("scale"), fetcher.getInputMatrixObject("bias"),
+ fetcher.getInputMatrixObject("runningMean"), fetcher.getInputMatrixObject("runningVar"),
+ fetcher.getOutputMatrixObject(image.getNumRows(), image.getNumColumns()), epsilon);
+ }
}
- public void processBatchNorm2dBackwardInstruction(ExecutionContext ec) throws DMLRuntimeException {
- GPUStatistics.incrementNoOfExecutedGPUInst();
- MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
- MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName());
- MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input3.getName());
- double epsilon = ec.getScalarInput(_input4.getName(), _input4.getValueType(), _input4.isLiteral()).getDoubleValue();
- MatrixObject resultSaveMean = getMatrixInputForGPUInstruction(ec, _input5.getName());
- MatrixObject resultSaveInvVariance = getMatrixInputForGPUInstruction(ec, _input6.getName());
-
- MatrixObject dX = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns());
- MatrixObject dScale = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), scale.getNumRows(), scale.getNumColumns());
- MatrixObject dBias = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), scale.getNumRows(), scale.getNumColumns());
-
- LibMatrixCuDNN.batchNormalizationBackward(ec.getGPUContext(0), getExtendedOpcode(), image,
- dout, scale, dX, dScale, dBias,
- epsilon, resultSaveMean, resultSaveInvVariance);
-
- // release inputs/outputs
- ec.releaseMatrixInputForGPUInstruction(_input1.getName());
- ec.releaseMatrixInputForGPUInstruction(_input2.getName());
- ec.releaseMatrixInputForGPUInstruction(_input3.getName());
- ec.releaseMatrixInputForGPUInstruction(_input5.getName());
- ec.releaseMatrixInputForGPUInstruction(_input6.getName());
- ec.releaseMatrixOutputForGPUInstruction(_output.getName());
- ec.releaseMatrixOutputForGPUInstruction(_output2.getName());
- ec.releaseMatrixOutputForGPUInstruction(_output3.getName());
+ private void processBatchNorm2dBackwardDxInstruction(ExecutionContext ec) throws DMLRuntimeException {
+ try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+ fetcher.add("X", _input1).add("dout", _input2).add("gamma", _input3)
+ .add("resultSaveMean", _input4).add("resultSaveInvVariance", _input5);
+
+ // #define CUDNN_BN_MIN_EPSILON 1e-5 // Minimum epsilon allowed to be used in the Batch Normalization formula
+ double epsilon = 1e-4;
+ MatrixObject image = fetcher.getInputMatrixObject("X");
+ LibMatrixCuDNN.batchNormalizationBackwardDX(gCtx, instName, image,
+ fetcher.getInputMatrixObject("dout"), fetcher.getInputMatrixObject("gamma"),
+ fetcher.getOutputMatrixObject(image.getNumRows(), image.getNumColumns()), epsilon, fetcher.getInputMatrixObject("resultSaveMean"),
+ fetcher.getInputMatrixObject("resultSaveInvVariance"));
+ }
}
-
+
+
// (X > 0) * dout
public void processReLUBackwardInstruction(ExecutionContext ec) {
- GPUStatistics.incrementNoOfExecutedGPUInst();
- MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName());
- MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName());
-
- MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), input.getNumRows(), input.getNumColumns());
-
- LibMatrixCUDA.reluBackward(ec.getGPUContext(0), getExtendedOpcode(), input, dout, out);
- // release inputs/outputs
- ec.releaseMatrixInputForGPUInstruction(_input1.getName());
- ec.releaseMatrixInputForGPUInstruction(_input2.getName());
- ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+ try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+ fetcher.add("X", _input1).add("dout", _input2);
+ MatrixObject X = fetcher.getInputMatrixObject("X");
+ LibMatrixCUDA.reluBackward(gCtx, instName, X,
+ fetcher.getInputMatrixObject("dout"), fetcher.getOutputMatrixObject(X.getNumRows(), X.getNumColumns()));
+ }
}
private void processChannelSumsInstruction(ExecutionContext ec) {
- GPUStatistics.incrementNoOfExecutedGPUInst();
- MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName());
- int C = (int) ec.getScalarInput(_input2.getName(), _input2.getValueType(), _input2.isLiteral()).getLongValue();
- int HW = (int) ec.getScalarInput(_input3.getName(), _input3.getValueType(), _input3.isLiteral()).getLongValue();
- if(C*HW != input.getNumColumns()) {
- throw new DMLRuntimeException("Expected rows*cols" + C + "*" + HW + " to be equal to number of columns of input " + input.getNumColumns());
+ try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+ fetcher.add("X", _input1).addScalar("C", _input2).addScalar("HW", _input3);
+ int C = fetcher.getInteger("C");
+ int HW = fetcher.getInteger("HW");
+ fetcher.validateDimensions("X", -1, C*HW);
+ LibMatrixCUDA.channelSums(gCtx, instName,
+ fetcher.getInputMatrixObject("X"),
+ fetcher.getOutputMatrixObject(C, 1), C, HW);
+ }
+ }
+
+ private void processEMAInstruction(ExecutionContext ec) {
+ // "ema_mean", "mean", "mu"
+ try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+ fetcher.add("ema_mean", _input1).add("mean", _input2).addScalar("mu", _input3);
+ double mu = fetcher.getDouble("mu");
+
+ int rows = LibMatrixCUDA.toInt(fetcher.getInputNumRows("ema_mean"));
+ int cols = LibMatrixCUDA.toInt(fetcher.getInputNumColumns("ema_mean"));
+
+ fetcher.validateDimensions("mean", rows, cols);
+
+ // aXplusbY(X, Y, C, a, b, size);
+ LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("aXplusbY",
+ ExecutionConfig.getConfigForSimpleVectorOperations(rows*cols),
+ fetcher.getInputPointer("ema_mean"), fetcher.getInputPointer("mean"),
+ fetcher.getOutputPointer(rows, cols),
+ mu, (1-mu), rows*cols);
+ }
+ }
+
+ private void processReshapeColMeansInstruction(ExecutionContext ec) {
+ try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+ fetcher.add("X", _input1).addScalar("C", _input2).addScalar("HW", _input3);
+ int C = fetcher.getInteger("C");
+ int HW = fetcher.getInteger("HW");
+ fetcher.validateDimensions("X", -1, C*HW);
+ int rows = LibMatrixCUDA.toInt(fetcher.getInputNumRows("X"));
+ int cols = LibMatrixCUDA.toInt(fetcher.getInputNumColumns("X"));
+ // output = matrix(colMeans(X), rows=C, cols=Hin*Win)
+ LibMatrixCUDA.colMeans(gCtx, instName,
+ fetcher.getInputPointer("X"),
+ fetcher.getOutputPointer(C, HW), rows, cols);
}
- MatrixObject outputBlock = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), C, 1);
-
- LibMatrixCUDA.channelSums(ec.getGPUContext(0), getExtendedOpcode(), input, outputBlock, C, HW);
-
- // release inputs/outputs
- ec.releaseMatrixInputForGPUInstruction(_input1.getName());
- ec.releaseMatrixOutputForGPUInstruction(_output.getName());
}
+ private void processUpdateEMAVarInstruction(ExecutionContext ec) {
+ try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+ // "subgrp_means", "X", "C", "HW", "varConst1"
+ fetcher.add("subgrp_means", _input1).add("X", _input2).addScalar("C", _input3)
+ .addScalar("HW", _input4).addScalar("varConst1", _input5);
+
+ // subgrp_vars = matrix(colVars(X) * varConst1, rows=C, cols=Hin*Win)
+ // var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))
+ // --->
+ // subgrp_vars = matrix(colVars(X), rows=C, cols=HW)
+ // var = rowMeans(subgrp_vars)*varConst1 + rowVars(subgrp_means)*((HW-1)/HW)
+ int C = fetcher.getInteger("C");
+ int HW = fetcher.getInteger("HW");
+ double varConst1 = fetcher.getDouble("varConst1");
+ fetcher.validateDimensions("subgrp_means", C, HW);
+ fetcher.validateDimensions("X", -1, C*HW);
+
+ Pointer subgrp_vars = gCtx.allocate(instName, C*HW*LibMatrixCUDA.sizeOfDataType);
+ // subgrp_vars <- colVars(X)
+ LibMatrixCUDA.colVars(gCtx, instName, fetcher.getInputPointer("X"), subgrp_vars,
+ LibMatrixCUDA.toInt(fetcher.getInputNumRows("X")), C*HW);
+
+ // tmp1 <- rowMeans(subgrp_vars)
+ Pointer tmp1 = gCtx.allocate(instName, C*LibMatrixCUDA.sizeOfDataType);
+ LibMatrixCUDA.rowMeans(gCtx, instName, subgrp_vars, tmp1, C, HW);
+ gCtx.cudaFreeHelper(instName, subgrp_vars, gCtx.EAGER_CUDA_FREE);
+
+ // out <- rowVars(subgrp_means)
+ Pointer out = fetcher.getOutputPointer(C, 1);
+ LibMatrixCUDA.rowVars(gCtx, instName, fetcher.getInputPointer("subgrp_means"), out, C, HW);
+
+ // var = tmp1*varConst1 + out*((HW-1)/HW)
+ LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("aXplusbC",
+ ExecutionConfig.getConfigForSimpleVectorOperations(C),
+ tmp1, out,
+ varConst1, (((double)HW-1)/HW), C);
+ gCtx.cudaFreeHelper(instName, tmp1, gCtx.EAGER_CUDA_FREE);
+ }
+ }
+
+
+
private void processNesterovUpdateInstruction(ExecutionContext ec) {
- GPUStatistics.incrementNoOfExecutedGPUInst();
- MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName());
- MatrixObject v = getMatrixInputForGPUInstruction(ec, _input2.getName());
- MatrixObject v_prev = getMatrixInputForGPUInstruction(ec, _input3.getName());
- double mu = (int) ec.getScalarInput(_input4.getName(), _input4.getValueType(), _input4.isLiteral()).getDoubleValue();
- int rows = LibMatrixCUDA.toInt(input.getNumRows());
- int cols = LibMatrixCUDA.toInt(input.getNumColumns());
- MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), rows, cols);
-
- GPUContext gCtx = ec.getGPUContext(0);
- String instName = getExtendedOpcode();
- LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("update_nesterov_x",
- ExecutionConfig.getConfigForSimpleVectorOperations(LibMatrixCUDA.toInt(rows*cols)),
- LibMatrixCUDA.getDensePointer(gCtx, input, instName),
- LibMatrixCUDA.getDensePointer(gCtx, v, instName),
- LibMatrixCUDA.getDensePointer(gCtx, v_prev, instName),
- mu,
- LibMatrixCUDA.getDensePointer(gCtx, out, instName),
- rows*cols);
-
- // release inputs/outputs
- ec.releaseMatrixInputForGPUInstruction(_input1.getName());
- ec.releaseMatrixInputForGPUInstruction(_input2.getName());
- ec.releaseMatrixInputForGPUInstruction(_input3.getName());
- ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+ try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+ fetcher.add("input", _input1).add("v", _input2).add("v_prev", _input3)
+ .addScalar("mu", _input4);
+ MatrixObject input = fetcher.getInputMatrixObject("input");
+ int rows = LibMatrixCUDA.toInt(input.getNumRows());
+ int cols = LibMatrixCUDA.toInt(input.getNumColumns());
+
+ LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("update_nesterov_x",
+ ExecutionConfig.getConfigForSimpleVectorOperations(LibMatrixCUDA.toInt(rows*cols)),
+ fetcher.getInputPointer("input"),
+ fetcher.getInputPointer("v"),
+ fetcher.getInputPointer("v_prev"),
+ fetcher.getDouble("mu"),
+ fetcher.getOutputPointer(rows, cols),
+ rows*cols);
+ }
}
private static int toInt(long num) throws DMLRuntimeException {
@@ -610,32 +593,18 @@ public class DnnGPUInstruction extends GPUInstruction {
return (int)num;
}
-// private Pointer transpose(ExecutionContext ec, MatrixObject X) throws DMLRuntimeException {
-// GPUContext gCtx = ec.getGPUContext(0);
-// String instructionName = getExtendedOpcode();
-// long numRowsX = X.getNumRows(); long numColsX = X.getNumColumns();
-// Pointer tX = gCtx.allocate(instructionName, numRowsX*numColsX*LibMatrixCUDA.sizeOfDataType);
-// jcuda.runtime.JCuda.cudaMemcpy(tX, LibMatrixCUDA.getDensePointer(gCtx, X, instructionName), numRowsX*numColsX*LibMatrixCUDA.sizeOfDataType, jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToDevice);
-// // LibMatrixCUDA.denseTranspose(ec, gCtx, instructionName, LibMatrixCUDA.getDensePointer(gCtx, X, instructionName), tX, numRowsX, numColsX);
-// return tX;
-// }
-
private void processLstmBackwardInstruction(ExecutionContext ec) throws DMLRuntimeException {
- GPUStatistics.incrementNoOfExecutedGPUInst();
- GPUContext gCtx = ec.getGPUContext(0);
- String instructionName = getExtendedOpcode();
-
MatrixObject out0 = getMatrixInputForGPUInstruction(ec, _input4.getName());
int M = toInt(out0.getNumColumns()); // hiddenSize .. since out0: (N, M)
- Pointer out0Pointer = LibMatrixCUDA.getDensePointer(gCtx, out0, instructionName);
+ Pointer out0Pointer = LibMatrixCUDA.getDensePointer(gCtx, out0, instName);
MatrixObject W = getMatrixInputForGPUInstruction(ec, _input2.getName());
MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName());
long numRowsW = W.getNumRows();
int D = toInt(numRowsW) - M; // since W:(D+M, 4M) ... numFeatures
- Pointer sysmlWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instructionName, D+M, 4*M);
- Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instructionName, 1, 4*M);
- Pointer cudnnWPointer = gCtx.allocate(instructionName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
+ Pointer sysmlWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instName, D+M, 4*M);
+ Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M);
+ Pointer cudnnWPointer = gCtx.allocate(instName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight",
ExecutionConfig.getConfigForSimpleVectorOperations((D+M+2)*(4*M)),
sysmlWPointer, sysmlBiasPointer, cudnnWPointer, D, M);
@@ -644,20 +613,20 @@ public class DnnGPUInstruction extends GPUInstruction {
MatrixObject X = getMatrixInputForGPUInstruction(ec, _input1.getName());
- Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instructionName);
+ Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instName);
int N = toInt(X.getNumRows()); // batchSize .. since X:(N, T*D)
long numColsX = X.getNumColumns();
int T = toInt(numColsX/ D); // since X:(N, T*D) ... seqLength
- Pointer cudnnInput = gCtx.allocate(instructionName, (N*T*D)*LibMatrixCUDA.sizeOfDataType);
+ Pointer cudnnInput = gCtx.allocate(instName, (N*T*D)*LibMatrixCUDA.sizeOfDataType);
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input",
ExecutionConfig.getConfigForSimpleVectorOperations(N*T*D),
xPointer, cudnnInput, N, D, T*D, N*T*D);
ec.releaseMatrixInputForGPUInstruction(_input1.getName());
- Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instructionName);
+ Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instName);
boolean return_sequences = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getBooleanValue();
- // LibMatrixCuDNN.lstm(ec, gCtx, instructionName,
+ // LibMatrixCuDNN.lstm(ec, gCtx, instName,
// cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), _output2.getName(), N, M, D, T);
// String xName, Pointer hx, Pointer cx, Pointer wPointer, String doutName, String dcyName, // input
// String dxName, String dwName, String dbName, String dhxName, String dcxName, // output
@@ -668,12 +637,12 @@ public class DnnGPUInstruction extends GPUInstruction {
String dcxName = _output5.getName();
String doutName = _input7.getName();
String dcyName = _input8.getName();
- LibMatrixCuDNN.lstmBackward(ec, gCtx, instructionName,
+ LibMatrixCuDNN.lstmBackward(ec, gCtx, instName,
cudnnInput, out0Pointer, c0Pointer, cudnnWPointer, doutName, dcyName, // input
dxName, dwName, dbName, dhxName, dcxName, // output
return_sequences, N, M, D, T);
- gCtx.cudaFreeHelper(instructionName, cudnnWPointer, gCtx.EAGER_CUDA_FREE);
- gCtx.cudaFreeHelper(instructionName, cudnnInput, gCtx.EAGER_CUDA_FREE);
+ gCtx.cudaFreeHelper(instName, cudnnWPointer, gCtx.EAGER_CUDA_FREE);
+ gCtx.cudaFreeHelper(instName, cudnnInput, gCtx.EAGER_CUDA_FREE);
// release inputs/outputs
ec.releaseMatrixInputForGPUInstruction(_input4.getName());
@@ -686,21 +655,17 @@ public class DnnGPUInstruction extends GPUInstruction {
// weight W:(D+M+2, 4M)
// previous output out0 (also represented by hx) and cell state c0 (also represented by cx): (N, M) ==> (1, M, N)
// out: (N, T*M) or (N, M) ==> (T, M, N)
- GPUStatistics.incrementNoOfExecutedGPUInst();
- GPUContext gCtx = ec.getGPUContext(0);
- String instructionName = getExtendedOpcode();
-
MatrixObject out0 = getMatrixInputForGPUInstruction(ec, _input4.getName());
int M = toInt(out0.getNumColumns()); // hiddenSize .. since out0: (N, M)
- Pointer out0Pointer = LibMatrixCUDA.getDensePointer(gCtx, out0, instructionName);
+ Pointer out0Pointer = LibMatrixCUDA.getDensePointer(gCtx, out0, instName);
MatrixObject W = getMatrixInputForGPUInstruction(ec, _input2.getName());
MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName());
long numRowsW = W.getNumRows();
int D = toInt(numRowsW) - M; // since W:(D+M, 4M) ... numFeatures
- Pointer sysmlWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instructionName, D+M, 4*M);
- Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instructionName, 1, 4*M);
- Pointer cudnnWPointer = gCtx.allocate(instructionName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
+ Pointer sysmlWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instName, D+M, 4*M);
+ Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M);
+ Pointer cudnnWPointer = gCtx.allocate(instName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight",
ExecutionConfig.getConfigForSimpleVectorOperations((D+M+2)*(4*M)),
sysmlWPointer, sysmlBiasPointer, cudnnWPointer, D, M);
@@ -711,21 +676,21 @@ public class DnnGPUInstruction extends GPUInstruction {
// Because the matrices are released immediately, the output for transpose need not be taken into account
MatrixObject X = getMatrixInputForGPUInstruction(ec, _input1.getName());
- Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instructionName);
+ Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instName);
int N = toInt(X.getNumRows()); // batchSize .. since X:(N, T*D)
long numColsX = X.getNumColumns();
int T = toInt(numColsX/ D); // since X:(N, T*D) ... seqLength
- Pointer cudnnInput = gCtx.allocate(instructionName, (N*T*D)*LibMatrixCUDA.sizeOfDataType);
+ Pointer cudnnInput = gCtx.allocate(instName, (N*T*D)*LibMatrixCUDA.sizeOfDataType);
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input",
ExecutionConfig.getConfigForSimpleVectorOperations(N*T*D),
xPointer, cudnnInput, N, D, T*D, N*T*D);
ec.releaseMatrixInputForGPUInstruction(_input1.getName());
- Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instructionName);
+ Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instName);
- LibMatrixCuDNN.lstm(ec, gCtx, instructionName, cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), _output2.getName(), N, M, D, T);
- gCtx.cudaFreeHelper(instructionName, cudnnWPointer, gCtx.EAGER_CUDA_FREE);
- gCtx.cudaFreeHelper(instructionName, cudnnInput, gCtx.EAGER_CUDA_FREE);
+ LibMatrixCuDNN.lstm(ec, gCtx, instName, cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), _output2.getName(), N, M, D, T);
+ gCtx.cudaFreeHelper(instName, cudnnWPointer, gCtx.EAGER_CUDA_FREE);
+ gCtx.cudaFreeHelper(instName, cudnnInput, gCtx.EAGER_CUDA_FREE);
// release inputs/outputs
ec.releaseMatrixInputForGPUInstruction(_input4.getName());
@@ -736,10 +701,17 @@ public class DnnGPUInstruction extends GPUInstruction {
@Override
public void processInstruction(ExecutionContext ec) {
+ GPUStatistics.incrementNoOfExecutedGPUInst();
+ gCtx = ec.getGPUContext(0);
+ instName = getExtendedOpcode();
if (instOpcode.equalsIgnoreCase("bias_add") || instOpcode.equalsIgnoreCase("bias_multiply")) {
processBiasInstruction(instOpcode, ec);
return;
}
+ else if (instOpcode.equalsIgnoreCase("inv_var")) {
+ processInverseVarianceInstruction(instOpcode, ec);
+ return;
+ }
else if (instOpcode.equalsIgnoreCase("relu_backward")) {
processReLUBackwardInstruction(ec);
return;
@@ -748,10 +720,22 @@ public class DnnGPUInstruction extends GPUInstruction {
processChannelSumsInstruction(ec);
return;
}
+ else if (instOpcode.equalsIgnoreCase("update_ema")) {
+ processEMAInstruction(ec);
+ return;
+ }
+ else if (instOpcode.equalsIgnoreCase("reshape_colmeans")) {
+ processReshapeColMeansInstruction(ec);
+ return;
+ }
else if (instOpcode.equalsIgnoreCase("update_nesterov_x")) {
processNesterovUpdateInstruction(ec);
return;
}
+ else if (instOpcode.equalsIgnoreCase("update_ema_var")) {
+ processUpdateEMAVarInstruction(ec);
+ return;
+ }
else if (instOpcode.equalsIgnoreCase("lstm")) {
processLstmInstruction(ec);
return;
@@ -760,24 +744,14 @@ public class DnnGPUInstruction extends GPUInstruction {
processLstmBackwardInstruction(ec);
return;
}
- else if (instOpcode.equalsIgnoreCase("batch_norm2d")) {
- processBatchNorm2dInstruction(ec);
- return;
- }
- else if (instOpcode.equalsIgnoreCase("batch_norm2d_backward")) {
- processBatchNorm2dBackwardInstruction(ec);
- return;
- }
else if (instOpcode.equalsIgnoreCase("batch_norm2d_test")) {
processBatchNorm2dTestInstruction(ec);
return;
}
- else if (instOpcode.equalsIgnoreCase("batch_norm2d_train")) {
- processBatchNorm2dTrainInstruction(ec);
+ else if (instOpcode.equalsIgnoreCase("batch_norm2d_bwd_dx")) {
+ processBatchNorm2dBackwardDxInstruction(ec);
return;
}
-
- GPUStatistics.incrementNoOfExecutedGPUInst();
int pad_h = getScalarInput(ec, _padding, 0);
int pad_w = getScalarInput(ec, _padding, 1);
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
new file mode 100644
index 0000000..8fcaec3
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.runtime.instructions.gpu;
+
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
+import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
+import org.apache.sysml.runtime.matrix.data.Pair;
+import org.apache.sysml.utils.GPUStatistics;
+
+import jcuda.Pointer;
+
+public class GPUDenseInputPointerFetcher implements java.lang.AutoCloseable {
+ ExecutionContext _ec; GPUContext _gCtx; String _instName;
+ HashMap<String, CPOperand> _inputMatrices = new HashMap<>();
+ HashMap<String, MatrixObject> _inputMatrixObjects = new HashMap<>();
+ HashMap<String, CPOperand> _inputScalars = new HashMap<>();
+ CPOperand _output;
+ public GPUDenseInputPointerFetcher(ExecutionContext ec, GPUContext gCtx, String instName, CPOperand output) {
+ _ec = ec;
+ _gCtx = gCtx;
+ _instName = instName;
+ _output = output;
+ }
+ public GPUDenseInputPointerFetcher add(String var, CPOperand in) {
+ _inputMatrices.put(var, in);
+ return this;
+ }
+ public GPUDenseInputPointerFetcher addScalar(String var, CPOperand in) {
+ _inputScalars.put(var, in);
+ return this;
+ }
+ public double getDouble(String var) {
+ CPOperand in = _inputScalars.get(var);
+ return _ec.getScalarInput(in.getName(), in.getValueType(), in.isLiteral()).getDoubleValue();
+ }
+ public long getLong(String var) {
+ CPOperand in = _inputScalars.get(var);
+ return _ec.getScalarInput(in.getName(), in.getValueType(), in.isLiteral()).getLongValue();
+ }
+ public int getInteger(String var) {
+ CPOperand in = _inputScalars.get(var);
+ return LibMatrixCUDA.toInt(_ec.getScalarInput(in.getName(), in.getValueType(), in.isLiteral()).getLongValue());
+ }
+ public Pointer getInputPointer(String var) {
+ return LibMatrixCUDA.getDensePointer(_gCtx, getInputMatrixObject(var), _instName);
+ }
+ public long getInputNumRows(String var) {
+ return getInputMatrixObject(var).getNumRows();
+ }
+ public long getInputNumColumns(String var) {
+ return getInputMatrixObject(var).getNumColumns();
+ }
+ public MatrixObject getOutputMatrixObject(long numRows, long numCols) {
+ boolean isFinegrainedStats = ConfigurationManager.isFinegrainedStatistics();
+ long t0 = isFinegrainedStats ? System.nanoTime() : 0;
+ Pair<MatrixObject, Boolean> mb = _ec.getDenseMatrixOutputForGPUInstruction(_output.getName(), numRows, numCols);
+ if (isFinegrainedStats && mb.getValue()) GPUStatistics.maintainCPMiscTimes(_instName,
+ GPUInstruction.MISC_TIMER_ALLOCATE_DENSE_OUTPUT, System.nanoTime() - t0);
+ return mb.getKey();
+ }
+ public Pointer getOutputPointer(long numRows, long numCols) {
+ return LibMatrixCUDA.getDensePointer(_gCtx, getOutputMatrixObject(numRows, numCols), _instName);
+ }
+ public MatrixObject getInputMatrixObject(String var) {
+ CPOperand in = _inputMatrices.get(var);
+ if(!_inputMatrixObjects.containsKey(var)) {
+ _inputMatrixObjects.put(var, _ec.getMatrixInputForGPUInstruction(in.getName(), _instName));
+ }
+ return _inputMatrixObjects.get(var);
+ }
+ public void validateDimensions(String var, long numRows, long numCols) {
+ MatrixObject mo = getInputMatrixObject(var);
+ if(numRows > 0 && mo.getNumRows() != numRows) {
+ throw new DMLRuntimeException("Expected number of rows of subgrp_means to be " + numRows + ", but found " + mo.getNumRows());
+ }
+ else if(numCols > 0 && mo.getNumColumns() != numCols) {
+ throw new DMLRuntimeException("Expected number of columns of subgrp_means to be " + numCols + ", but found " + mo.getNumColumns());
+ }
+ }
+ @Override
+ public void close() {
+ for(CPOperand in : _inputMatrices.values()) {
+ _ec.releaseMatrixInputForGPUInstruction(in.getName());
+ }
+ _ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+ }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
index 2e43b99..e01c71a 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
@@ -242,7 +242,7 @@ public class GPUMemoryManager {
* @return allocated pointer
*/
public Pointer malloc(String opcode, long size) {
- if(size < 0) {
+ if(size <= 0) {
throw new DMLRuntimeException("Cannot allocate memory of size " + byteCountToDisplaySize(size));
}
if(DEBUG_MEMORY_LEAK) {
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index d02a875..f3f8434 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -208,7 +208,7 @@ public class LibMatrixCUDA {
return gCtx.getCusparseHandle();
}
- protected static cublasHandle getCublasHandle(GPUContext gCtx) {
+ public static cublasHandle getCublasHandle(GPUContext gCtx) {
return gCtx.getCublasHandle();
}
@@ -302,7 +302,7 @@ public class LibMatrixCUDA {
}
}
- protected static Pointer dataTypePointerTo(double value) {
+ public static Pointer dataTypePointerTo(double value) {
if(value == 1) {
return one();
}
@@ -313,7 +313,6 @@ public class LibMatrixCUDA {
return _dataTypePointerTo(value);
}
}
-
/**
* This method computes the backpropagation errors for previous layer of relu operation
@@ -753,11 +752,11 @@ public class LibMatrixCUDA {
break;
}
case REDUCTION_COL: {
- reduceRow(gCtx, instName, "reduce_row_mean", in, out, rlen, clen);
+ rowMeans(gCtx, instName, in, out, rlen, clen);
break;
}
case REDUCTION_ROW: {
- reduceCol(gCtx, instName, "reduce_col_mean", in, out, rlen, clen);
+ colMeans(gCtx, instName, in, out, rlen, clen);
break;
}
default:
@@ -818,13 +817,14 @@ public class LibMatrixCUDA {
break;
}
case OP_VARIANCE : {
- // Temporary GPU array for
- Pointer tmp = gCtx.allocate(instName, size * sizeOfDataType);
- Pointer tmp2 = gCtx.allocate(instName, size * sizeOfDataType);
+
switch(reductionDirection) {
case REDUCTION_ALL: {
+ // Temporary GPU arrays for the mean-subtraction and squaring intermediates
+ Pointer tmp = gCtx.allocate(instName, size * sizeOfDataType);
+ Pointer tmp2 = gCtx.allocate(instName, size * sizeOfDataType);
double result = reduceAll(gCtx, instName, "reduce_sum", in, size);
double mean = result / size;
@@ -837,50 +837,21 @@ public class LibMatrixCUDA {
double result2 = reduceAll(gCtx, instName, "reduce_sum", tmp2, size);
double variance = result2 / (size - 1);
ec.setScalarOutput(output, new DoubleObject(variance));
-
+ gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE);
+ gCtx.cudaFreeHelper(instName, tmp2, gCtx.EAGER_CUDA_FREE);
break;
}
case REDUCTION_COL: {
- reduceRow(gCtx, instName, "reduce_row_mean", in, out, rlen, clen);
- // Subtract the row-wise mean from every element in the matrix
- BinaryOperator minusOp = new BinaryOperator(Minus.getMinusFnObject());
- matrixMatrixOp(gCtx, instName, in, out, rlen, clen, VectorShape.NONE.code(), VectorShape.COLUMN.code(), tmp, minusOp);
-
- squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen);
-
- Pointer tmpRow = gCtx.allocate(instName, rlen * sizeOfDataType);
- reduceRow(gCtx, instName, "reduce_row_sum", tmp2, tmpRow, rlen, clen);
-
- ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), clen - 1);
- matrixScalarOp(gCtx, instName, tmpRow, clen - 1, rlen, 1, out, divideOp);
-
- gCtx.cudaFreeHelper(instName, tmpRow, gCtx.EAGER_CUDA_FREE);
-
+ rowVars(gCtx, instName, in, out, rlen, clen);
break;
}
case REDUCTION_ROW: {
- reduceCol(gCtx, instName, "reduce_col_mean", in, out, rlen, clen);
- // Subtract the columns-wise mean from every element in the matrix
- BinaryOperator minusOp = new BinaryOperator(Minus.getMinusFnObject());
- matrixMatrixOp(gCtx, instName, in, out, rlen, clen, VectorShape.NONE.code(), VectorShape.ROW.code(), tmp, minusOp);
-
- squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen);
-
- Pointer tmpCol = gCtx.allocate(instName, clen * sizeOfDataType);
- reduceCol(gCtx, instName, "reduce_col_sum", tmp2, tmpCol, rlen, clen);
-
- ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), rlen - 1);
- matrixScalarOp(gCtx, instName, tmpCol, rlen - 1, 1, clen, out, divideOp);
-
- gCtx.cudaFreeHelper(instName, tmpCol, gCtx.EAGER_CUDA_FREE);
-
+ colVars(gCtx, instName, in, out, rlen, clen);
break;
}
default:
throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for variance");
}
- gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE);
- gCtx.cudaFreeHelper(instName, tmp2, gCtx.EAGER_CUDA_FREE);
break;
}
case OP_MAXINDEX : {
@@ -904,6 +875,59 @@ public class LibMatrixCUDA {
default : throw new DMLRuntimeException("Internal Error - Invalid GPU Unary aggregate function!");
}
}
+
+ public static void rowMeans(GPUContext gCtx, String instName, Pointer in, Pointer out, int rlen, int clen) {
+ LibMatrixCUDA.reduceRow(gCtx, instName, "reduce_row_mean", in, out, rlen, clen);
+ }
+
+ public static void colMeans(GPUContext gCtx, String instName, Pointer in, Pointer out, int rlen, int clen) {
+ reduceCol(gCtx, instName, "reduce_col_mean", in, out, rlen, clen);
+ }
+
+ public static void colVars(GPUContext gCtx, String instName, Pointer in, Pointer out, int rlen, int clen) {
+ int size = rlen * clen;
+ Pointer tmp = gCtx.allocate(instName, size * sizeOfDataType);
+ Pointer tmp2 = gCtx.allocate(instName, size * sizeOfDataType);
+ reduceCol(gCtx, instName, "reduce_col_mean", in, out, rlen, clen);
+ // Subtract the columns-wise mean from every element in the matrix
+ BinaryOperator minusOp = new BinaryOperator(Minus.getMinusFnObject());
+ matrixMatrixOp(gCtx, instName, in, out, rlen, clen, VectorShape.NONE.code(), VectorShape.ROW.code(), tmp, minusOp);
+
+ squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen);
+
+ Pointer tmpCol = gCtx.allocate(instName, clen * sizeOfDataType);
+ reduceCol(gCtx, instName, "reduce_col_sum", tmp2, tmpCol, rlen, clen);
+
+ ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), rlen - 1);
+ matrixScalarOp(gCtx, instName, tmpCol, rlen - 1, 1, clen, out, divideOp);
+
+ gCtx.cudaFreeHelper(instName, tmpCol, gCtx.EAGER_CUDA_FREE);
+ gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE);
+ gCtx.cudaFreeHelper(instName, tmp2, gCtx.EAGER_CUDA_FREE);
+ }
+
+ public static void rowVars(GPUContext gCtx, String instName, Pointer in, Pointer out, int rlen, int clen) {
+ int size = rlen * clen;
+ Pointer tmp = gCtx.allocate(instName, size * sizeOfDataType);
+ Pointer tmp2 = gCtx.allocate(instName, size * sizeOfDataType);
+
+ reduceRow(gCtx, instName, "reduce_row_mean", in, out, rlen, clen);
+ // Subtract the row-wise mean from every element in the matrix
+ BinaryOperator minusOp = new BinaryOperator(Minus.getMinusFnObject());
+ matrixMatrixOp(gCtx, instName, in, out, rlen, clen, VectorShape.NONE.code(), VectorShape.COLUMN.code(), tmp, minusOp);
+
+ squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen);
+
+ Pointer tmpRow = gCtx.allocate(instName, rlen * sizeOfDataType);
+ reduceRow(gCtx, instName, "reduce_row_sum", tmp2, tmpRow, rlen, clen);
+
+ ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), clen - 1);
+ matrixScalarOp(gCtx, instName, tmpRow, clen - 1, rlen, 1, out, divideOp);
+
+ gCtx.cudaFreeHelper(instName, tmpRow, gCtx.EAGER_CUDA_FREE);
+ gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE);
+ gCtx.cudaFreeHelper(instName, tmp2, gCtx.EAGER_CUDA_FREE);
+ }
/**
* Helper method to square a matrix in GPU memory
@@ -970,7 +994,7 @@ public class LibMatrixCUDA {
* @param rows number of rows in input matrix
* @param cols number of columns in input matrix
*/
- private static void reduceRow(GPUContext gCtx, String instName, String kernelFunction, Pointer in, Pointer out, int rows, int cols) {
+ public static void reduceRow(GPUContext gCtx, String instName, String kernelFunction, Pointer in, Pointer out, int rows, int cols) {
if(LOG.isTraceEnabled()) {
LOG.trace("GPU : reduceRow for " + kernelFunction + ", GPUContext=" + gCtx);
}
@@ -997,7 +1021,7 @@ public class LibMatrixCUDA {
* @param rows number of rows in input matrix
* @param cols number of columns in input matrix
*/
- private static void reduceCol(GPUContext gCtx, String instName, String kernelFunction, Pointer in, Pointer out, int rows, int cols) {
+ public static void reduceCol(GPUContext gCtx, String instName, String kernelFunction, Pointer in, Pointer out, int rows, int cols) {
if(LOG.isTraceEnabled()) {
LOG.trace("GPU : reduceCol for " + kernelFunction + ", GPUContext=" + gCtx);
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
index d3b5984..e7955e1 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
@@ -1108,23 +1108,29 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
runningMeanPtr, runningVarPtr, epsilon));
}
+ private static void validateDimensions(MatrixObject mo, long expectedRows, long expectedCols) {
+ if(mo.getNumRows() != expectedRows || mo.getNumColumns() != expectedCols) {
+ throw new DMLRuntimeException("Incorrect dimensions for the input matrix object. Expected [" + expectedRows + ", " + expectedCols+ "], but found "
+ + "[" + mo.getNumRows() + ", " + mo.getNumColumns() + "].");
+ }
+ }
+
/**
- * This method computes the backpropagation errors for image, scale and bias of batch normalization layer
+ * This method computes the backpropagation errors for image of batch normalization layer
* @param gCtx a valid {@link GPUContext}
* @param instName name of the instruction
* @param image input image
* @param dout input errors of shape C, H, W
* @param scale scale (as per CuDNN) and gamma as per original paper: shape [1, C, 1, 1]
* @param dX (output) backpropagation errors for previous layer
- * @param dScale backpropagation error for scale
- * @param dBias backpropagation error for bias
* @param epsilon epsilon value used in the batch normalization formula
* @param resultSaveMean (input) running mean accumulated during training phase: shape [1, C, 1, 1]
+ * @param resultSaveInvVariance (input) running inverse variance accumulated during training phase: shape [1, C, 1, 1]
* @throws DMLRuntimeException if error occurs
*/
- public static void batchNormalizationBackward(GPUContext gCtx, String instName, MatrixObject image, MatrixObject dout,
- MatrixObject scale, MatrixObject dX, MatrixObject dScale, MatrixObject dBias,
+ public static void batchNormalizationBackwardDX(GPUContext gCtx, String instName, MatrixObject image, MatrixObject dout,
+ MatrixObject scale, MatrixObject dX,
+ // MatrixObject dScale, MatrixObject dBias,
double epsilon, MatrixObject resultSaveMean, MatrixObject resultSaveInvVariance) throws DMLRuntimeException {
if(LOG.isTraceEnabled()) {
LOG.trace("GPU : batchNormalizationBackward" + ", GPUContext=" + gCtx);
@@ -1133,7 +1139,13 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
int N = toInt(image.getNumRows());
int C = toInt(scale.getNumRows());
long CHW = image.getNumColumns();
-
+
+ validateDimensions(scale, C, 1);
+ validateDimensions(dX, N, CHW);
+ validateDimensions(dout, N, CHW);
+ validateDimensions(resultSaveMean, C, 1);
+ validateDimensions(resultSaveInvVariance, C, 1);
+
// Allocate descriptors
cudnnTensorDescriptor nCHWDescriptor = allocateNCHWDescriptors(gCtx, N, C, CHW,
new MatrixObject[] {image, dout}, new MatrixObject[] {dX});
@@ -1144,18 +1156,17 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
Pointer doutPtr = getDensePointerForCuDNN(gCtx, dout, instName);
Pointer scalePtr = getDensePointerForCuDNN(gCtx, scale, instName);
Pointer dXPtr = getDensePointerForCuDNN(gCtx, dX, instName);
- Pointer dScalePtr = getDensePointerForCuDNN(gCtx, dScale, instName);
- Pointer dBiasPtr = getDensePointerForCuDNN(gCtx, dBias, instName);
-
+ Pointer dScalePtr = gCtx.allocate(instName, C*LibMatrixCUDA.sizeOfDataType); // getDensePointerForCuDNN(gCtx, dScale, instName);
+ Pointer dBiasPtr = gCtx.allocate(instName, C*LibMatrixCUDA.sizeOfDataType); //getDensePointerForCuDNN(gCtx, dBias, instName);
Pointer resultSaveMeanPtr = getDensePointerForCuDNN(gCtx, resultSaveMean, instName);
Pointer resultSaveInvVariancePtr = getDensePointerForCuDNN(gCtx, resultSaveInvVariance, instName);
-
- // ignoring resultSaveMean and resultSaveVariance as it requires state management
- checkStatus(cudnnBatchNormalizationBackward(getCudnnHandle(gCtx),
+ cudnnBatchNormalizationBackward(getCudnnHandle(gCtx),
jcuda.jcudnn.cudnnBatchNormMode.CUDNN_BATCHNORM_SPATIAL, one(), zero(), one(), zero(),
nCHWDescriptor, imagePtr, nCHWDescriptor, doutPtr, nCHWDescriptor, dXPtr,
- scaleTensorDesc, scalePtr, dScalePtr, dBiasPtr, epsilon, resultSaveMeanPtr, resultSaveInvVariancePtr));
+ scaleTensorDesc, scalePtr, dScalePtr, dBiasPtr, epsilon, resultSaveMeanPtr, resultSaveInvVariancePtr);
+ gCtx.cudaFreeHelper(instName, dScalePtr, gCtx.EAGER_CUDA_FREE);
+ gCtx.cudaFreeHelper(instName, dBiasPtr, gCtx.EAGER_CUDA_FREE);
}
private static void validateBatchNormalizationDimensions(MatrixObject scale, MatrixObject bias, MatrixObject runningMean, MatrixObject runningVar, int C) throws DMLRuntimeException {
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java b/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
index d96feac..d7e7b24 100644
--- a/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
+++ b/src/test/java/org/apache/sysml/test/gpu/BatchNormTest.java
@@ -55,8 +55,8 @@ public class BatchNormTest extends GPUTests {
int imgSize = 32;
int numChannels = 3;
double sparsity = 0.9;
- String scriptStr = "source(\"nn/layers/batch_norm2d_old.dml\") as batch_norm2d_old;\n "
- + "[output, ema_mean_upd, ema_var_upd, cache_mean, cache_var] = batch_norm2d_old::forward(x, gamma, beta, " + numChannels + ", " + imgSize + ", " + imgSize + ", \"" + mode + "\", ema_mean, ema_var, 0.9, 1e-3)";
+ String scriptStr = "source(\"nn/layers/batch_norm2d.dml\") as batch_norm2d;\n "
+ + "[output, ema_mean_upd, ema_var_upd, cache_mean, cache_var] = batch_norm2d::forward(x, gamma, beta, " + numChannels + ", " + imgSize + ", " + imgSize + ", \"" + mode + "\", ema_mean, ema_var, 0.9, 1e-3)";
HashMap<String, Object> inputs = new HashMap<>();
inputs.put("x", generateInputMatrix(spark, 32, numChannels*imgSize*imgSize, 0, 10, sparsity, seed));
inputs.put("gamma", generateInputMatrix(spark, numChannels, 1, 0, 2, sparsity, seed));
@@ -68,19 +68,40 @@ public class BatchNormTest extends GPUTests {
List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, outputs);
if(mode.equals("test")) {
assertHeavyHitterPresent("gpu_batch_norm2d_test");
- for(int i = 0; i < outputs.size(); i++) {
- assertEqualObjects(outCPU.get(i), outGPU.get(i));
- }
}
else {
- //assertHeavyHitterPresent("gpu_batch_norm2d_train");
- double [] threshold = new double[outputs.size()];
- Arrays.fill(threshold, getTHRESHOLD());
- // Handle loss of precision in CuDNN kernel
- threshold[2] = 1e-3;
- for(int i = 0; i < outputs.size()-1; i++) {
- assertEqualObjects(outCPU.get(i), outGPU.get(i), threshold[i]);
- }
+ assertHeavyHitterPresent("gpu_batch_norm2d_test");
+ assertHeavyHitterPresent("gpu_reshape_colmeans");
+ assertHeavyHitterPresent("gpu_update_ema_var");
}
+ assertEqualObjects(outCPU.get(0), outGPU.get(0));
+ assertEqualObjects(outCPU.get(1), outGPU.get(1));
+ assertEqualObjects(outCPU.get(2), outGPU.get(2));
+ assertEqualObjects(outCPU.get(3), outGPU.get(3));
+ assertEqualObjects(outCPU.get(4), outGPU.get(4));
+ }
+
+ @Test
+ public void testBatchNormBackward() {
+ int imgSize = 32;
+ int numChannels = 3;
+ double sparsity = 0.9;
+ String scriptStr = "source(\"nn/layers/batch_norm2d.dml\") as batch_norm2d;\n "
+ + "[output, ema_mean_upd, ema_var_upd, cache_mean, cache_var] = batch_norm2d::forward(x, gamma, beta, " + numChannels + ", " + imgSize + ", " + imgSize + ", \"train\", ema_mean, ema_var, 0.9, 1e-3);\n"
+ + "[dX, dgamma, dbeta] = batch_norm2d::backward(dout, cache_mean, cache_var, x, gamma, " + numChannels + ", " + imgSize + ", " + imgSize + ", 1e-3);";
+ HashMap<String, Object> inputs = new HashMap<>();
+ inputs.put("x", generateInputMatrix(spark, 32, numChannels*imgSize*imgSize, 0, 10, sparsity, seed));
+ inputs.put("dout", generateInputMatrix(spark, 32, numChannels*imgSize*imgSize, 1, 5, sparsity, seed));
+ inputs.put("gamma", generateInputMatrix(spark, numChannels, 1, 0, 2, sparsity, seed));
+ inputs.put("beta", generateInputMatrix(spark, numChannels, 1, 0, 2, sparsity, seed));
+ inputs.put("ema_mean", generateInputMatrix(spark, numChannels, 1, 3, 7, sparsity, seed));
+ inputs.put("ema_var", generateInputMatrix(spark, numChannels, 1, 0, 2, sparsity, seed));
+ List<String> outputs = Arrays.asList("dX", "dgamma", "dbeta");
+ List<Object> outCPU = runOnCPU(spark, scriptStr, inputs, outputs);
+ List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, outputs);
+ assertHeavyHitterPresent("gpu_batch_norm2d_bwd_dx");
+ assertEqualObjects(outCPU.get(0), outGPU.get(0), 1e-6);
+ assertEqualObjects(outCPU.get(1), outGPU.get(1));
+ assertEqualObjects(outCPU.get(2), outGPU.get(2));
}
}
[3/3] systemml git commit: [SYSTEMML-445] Removed batch_norm builtin
functions
Posted by ni...@apache.org.
[SYSTEMML-445] Removed batch_norm builtin functions
- Removed batch_norm builtin functions to exploit codegen in CP.
- Added rewrites for compiling efficient CuDNN operators.
- Added rewrites for SGD update operations.
- To simplify adding new GPU rewrites, added HopDagPatternMatcher that allows for pattern matching at the HOP-level. This can be extended for other rewrites as well.
- Added GPU tests to validate the rewrites.
- Updated the DML language documentation.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/0f36780a
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/0f36780a
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/0f36780a
Branch: refs/heads/master
Commit: 0f36780a8244c6e728d37c32a79e00ed181211ad
Parents: 81419ae
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Thu Aug 30 15:40:44 2018 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Thu Aug 30 15:40:44 2018 -0700
----------------------------------------------------------------------
docs/dml-language-reference.md | 2 -
scripts/nn/layers/batch_norm2d.dml | 60 +-
scripts/nn/layers/batch_norm2d_old.dml | 200 ----
src/main/cpp/kernels/SystemML.cu | 56 +-
src/main/cpp/kernels/SystemML.ptx | 321 +++++-
src/main/java/org/apache/sysml/hops/DnnOp.java | 56 +-
.../java/org/apache/sysml/hops/FunctionOp.java | 30 +-
src/main/java/org/apache/sysml/hops/Hop.java | 8 +-
.../hops/rewrite/HopDagPatternMatcher.java | 378 +++++++
.../sysml/hops/rewrite/HopPatternRewriter.java | 72 ++
.../HopRewriteRuleWithPatternMatcher.java | 98 ++
.../sysml/hops/rewrite/HopRewriteUtils.java | 20 +
.../hops/rewrite/RewriteGPUSpecificOps.java | 1027 +++++-------------
.../org/apache/sysml/lops/DnnTransform.java | 53 +-
.../sysml/parser/BuiltinFunctionExpression.java | 61 +-
.../org/apache/sysml/parser/DMLTranslator.java | 2 -
.../org/apache/sysml/parser/Expression.java | 2 +-
.../instructions/GPUInstructionParser.java | 10 +-
.../instructions/gpu/DnnGPUInstruction.java | 526 +++++----
.../gpu/GPUDenseInputPointerFetcher.java | 111 ++
.../gpu/context/GPUMemoryManager.java | 2 +-
.../runtime/matrix/data/LibMatrixCUDA.java | 110 +-
.../runtime/matrix/data/LibMatrixCuDNN.java | 37 +-
.../apache/sysml/test/gpu/BatchNormTest.java | 47 +-
24 files changed, 1818 insertions(+), 1471 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/docs/dml-language-reference.md
----------------------------------------------------------------------
diff --git a/docs/dml-language-reference.md b/docs/dml-language-reference.md
index 924336a..cdcc529 100644
--- a/docs/dml-language-reference.md
+++ b/docs/dml-language-reference.md
@@ -1522,8 +1522,6 @@ Hence, the images are internally represented as a matrix with dimension (N, C *
| bias_add | input, bias | [batch_size X num_channels* height_image* width_image] | [num_channels X 1] | [batch_size X num_channels* height_image* width_image] | | Adds the bias (row vector of size num_channels) to input with the given num_channels |
| bias_multiply | input, bias | [batch_size X num_channels* height_image* width_image] | [num_channels X 1] | [batch_size X num_channels* height_image* width_image] | | Multiplies the bias (row vector of size num_channels) to input with the given num_channels |
| lstm | X, W, bias, out0, c0 | [batch_size X seq_length*num_features] | [num_features+hidden_size X 4*hidden_size] | [batch_size X seq_length*hidden_size] if return_sequences else [batch_size X hidden_size] | return_sequences | Perform computation for single-layer unidirectional LSTM (outputs: out, carryOut) |
-| batch_norm2d | input | [batch_size X num_channels* height_image* width_image] | | [batch_size X num_channels* height_image* width_image] | scale, shift, exponentialMovingAverage_Mean, exponentialMovingAverage_Variance, mode, epsilon, momentum | Performs batch normalization operation (outputs: updated exponential moving average mean and variance, cache of the batch mean and variance) |
-| batch_norm2d_backward | input, dout | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels* height_image* width_image] | scale, epsilon, cache_mean (from forward), cache_inv_var (from forward) | Computed backpropagation error for batch normalization operation |
Note: the builtin functions `batch_norm2d` and `batch_norm2d_backward` are deprecated and will be removed in the next release. The `lstm` builtin function is in experimental phase and is only supported for the GPU backend.
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/scripts/nn/layers/batch_norm2d.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/layers/batch_norm2d.dml b/scripts/nn/layers/batch_norm2d.dml
index 2a98857..c68f23d 100644
--- a/scripts/nn/layers/batch_norm2d.dml
+++ b/scripts/nn/layers/batch_norm2d.dml
@@ -83,8 +83,41 @@ forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta,
* - cache_inv_var: Cache of the inverse variance, of shape (C, 1).
* Note: This is used for performance during training.
*/
- out = X; ema_mean_upd = ema_mean; ema_var_upd = ema_var; cache_mean = ema_mean; cache_inv_var = ema_var
- [out, ema_mean_upd, ema_var_upd, cache_mean, cache_inv_var] = batch_norm2d(X, gamma, beta, ema_mean, ema_var, mode, epsilon, mu)
+ N = nrow(X)
+
+ if (mode == 'train') {
+ # Compute channel-wise mean and variance
+ # Since we don't have tensors, we will compute the means and variances in a piece-wise fashion.
+ # - mean of total group is mean of subgroup means
+ # - variance is the mean of the subgroup variances + the variance of the subgroup means
+ subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win)
+ subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win) # uncorrected variances
+ mean = rowMeans(subgrp_means) # shape (C, 1)
+ var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win)) # shape (C, 1)
+ # Update moving averages
+ ema_mean_upd = mu*ema_mean + (1-mu)*mean
+ ema_var_upd = mu*ema_var + (1-mu)*var
+ }
+ else {
+ # Use moving averages of mean and variance during testing
+ mean = ema_mean
+ var = ema_var
+ ema_mean_upd = ema_mean
+ ema_var_upd = ema_var
+ }
+
+ # Save variable for backward pass
+ cache_mean = mean
+ cache_inv_var = 1/sqrt(var+epsilon)
+
+ # Normalize, shift, and scale
+ # norm = (X-mean)*(var+epsilon)^(-1/2)
+ # = (X-mean) / sqrt(var+epsilon)
+ centered = bias_add(X, -mean) # shape (N, C*Hin*Win)
+ norm = bias_multiply(centered, cache_inv_var) # shape (N, C*Hin*Win)
+ # out = norm*gamma + beta
+ scaled = bias_multiply(norm, gamma) # shape (N, C*Hin*Win)
+ out = bias_add(scaled, beta) # shape (N, C*Hin*Win)
}
backward = function(matrix[double] dout,
@@ -119,9 +152,27 @@ backward = function(matrix[double] dout,
* - dbeta: Gradient wrt `b`, of shape (C, 1).
*
*/
+ N = nrow(X)
+ oneByN = 1/N
+ oneByHW = 1/(Hin*Win)
+
+ mean = cache_mean
+ centered = bias_add(X, -mean) # shape (N, C*Hin*Win)
+ norm = bias_multiply(centered, cache_inv_var) # shape (N, C*Hin*Win)
# Compute gradients during training
- dX = X; dgamma = gamma; dbeta = gamma;
- [dX, dgamma, dbeta] = batch_norm2d_backward(X, dout, gamma, epsilon, cache_mean, cache_inv_var)
+ dgamma = util::channel_sums(dout*norm, C, Hin, Win) # shape (C, 1)
+ dbeta = util::channel_sums(dout, C, Hin, Win) # shape (C, 1)
+ dnorm = bias_multiply(dout, gamma) # shape (N, C*Hin*Win)
+ dvar = util::channel_sums((-1/2) * bias_multiply(centered, cache_inv_var^3) * dnorm,
+ C, Hin, Win) # shape (C, 1)
+ dmean_norm_branch = util::channel_sums(bias_multiply(dnorm, -cache_inv_var), C, Hin, Win)
+ dmean_var_branch = util::channel_sums((-2*oneByN*oneByHW) * centered, C, Hin, Win)
+ dmean_var_branch = dmean_var_branch * dvar # we can't use a function within an expression yet
+ dmean = dmean_norm_branch + dmean_var_branch # shape (C, 1)
+ dX_norm_branch = bias_multiply(dnorm, cache_inv_var)
+ dX_mean_branch = (oneByN*oneByHW) * bias_add(matrix(0, rows=1, cols=C*Hin*Win), dmean)
+ dX_var_branch = (2*oneByN*oneByHW) * bias_multiply(centered, dvar)
+ dX = dX_norm_branch + dX_mean_branch + dX_var_branch # shape (N, C*Hin*Win)
}
init = function(int C)
@@ -149,3 +200,4 @@ init = function(int C)
ema_mean = matrix(0, rows=C, cols=1)
ema_var = matrix(1, rows=C, cols=1)
}
+
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/scripts/nn/layers/batch_norm2d_old.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/layers/batch_norm2d_old.dml b/scripts/nn/layers/batch_norm2d_old.dml
deleted file mode 100644
index 2aba2e6..0000000
--- a/scripts/nn/layers/batch_norm2d_old.dml
+++ /dev/null
@@ -1,200 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-/*
- * 2D (Spatial) Batch Normalization layer.
- */
-source("nn/util.dml") as util
-
-forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta,
- int C, int Hin, int Win, string mode,
- matrix[double] ema_mean, matrix[double] ema_var,
- double mu, double epsilon)
- return (matrix[double] out, matrix[double] ema_mean_upd, matrix[double] ema_var_upd,
- matrix[double] cache_mean, matrix[double] cache_inv_var) {
- /*
- * Computes the forward pass for a 2D (spatial) batch normalization
- * layer. The input data has N examples, each represented as a 3D
- * volume unrolled into a single vector.
- *
- * A spatial batch normalization layer uses the per-channel sample
- * mean and per-channel uncorrected sample variance during training
- * to normalize each channel of the input data. Additionally, it
- * introduces learnable parameters (gamma, beta) to control the
- * amount of normalization.
- *
- * `y = ((x-mean) / sqrt(var+eps)) * gamma + beta`
- *
- * This implementation maintains exponential moving averages of the
- * mean and variance during training for use during testing.
- *
- * Reference:
- * - Batch Normalization: Accelerating Deep Network Training by
- * Reducing Internal Covariate Shift, S. Ioffe & C. Szegedy, 2015
- * - https://arxiv.org/abs/1502.03167
- *
- * Inputs:
- * - X: Inputs, of shape (N, C*Hin*Win).
- * - gamma: Scale parameters, of shape (C, 1).
- * - beta: Shift parameters, of shape (C, 1).
- * - C: Number of input channels (dimensionality of input depth).
- * - Hin: Input height.
- * - Win: Input width.
- * - mode: 'train' or 'test' to indicate if the model is currently
- * being trained or tested. During training, the current batch
- * mean and variance will be used to normalize the inputs, while
- * during testing, the exponential average of the mean and
- * variance over all previous batches will be used.
- * - ema_mean: Exponential moving average of the mean, of
- * shape (C, 1).
- * - ema_var: Exponential moving average of the variance, of
- * shape (C, 1).
- * - mu: Momentum value for moving averages.
- * Typical values are in the range of [0.9, 0.999].
- * - epsilon: Smoothing term to avoid divide by zero errors.
- * Typical values are in the range of [1e-5, 1e-3].
- *
- * Outputs:
- * - out: Outputs, of shape (N, C*Hin*Win).
- * - ema_mean_upd: Updated exponential moving average of the mean,
- * of shape (C, 1).
- * - ema_var_upd: Updated exponential moving average of the variance,
- * of shape (C, 1).
- * - cache_mean: Cache of the batch mean, of shape (C, 1).
- * Note: This is used for performance during training.
- * - cache_inv_var: Cache of the inverse variance, of shape (C, 1).
- * Note: This is used for performance during training.
- */
- N = nrow(X)
-
- if (mode == 'train') {
- # Compute channel-wise mean and variance
- # Since we don't have tensors, we will compute the means and variances in a piece-wise fashion.
- # - mean of total group is mean of subgroup means
- # - variance is the mean of the subgroup variances + the variance of the subgroup means
- subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win)
- subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win) # uncorrected variances
- mean = rowMeans(subgrp_means) # shape (C, 1)
- var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win)) # shape (C, 1)
- # Update moving averages
- ema_mean_upd = mu*ema_mean + (1-mu)*mean
- ema_var_upd = mu*ema_var + (1-mu)*var
- }
- else {
- # Use moving averages of mean and variance during testing
- mean = ema_mean
- var = ema_var
- ema_mean_upd = ema_mean
- ema_var_upd = ema_var
- }
-
- # Save variable for backward pass
- cache_mean = mean
- cache_inv_var = 1/sqrt(var+epsilon)
-
- # Normalize, shift, and scale
- # norm = (X-mean)*(var+epsilon)^(-1/2)
- # = (X-mean) / sqrt(var+epsilon)
- centered = bias_add(X, -mean) # shape (N, C*Hin*Win)
- norm = bias_multiply(centered, cache_inv_var) # shape (N, C*Hin*Win)
- # out = norm*gamma + beta
- scaled = bias_multiply(norm, gamma) # shape (N, C*Hin*Win)
- out = bias_add(scaled, beta) # shape (N, C*Hin*Win)
-}
-
-backward = function(matrix[double] dout,
- matrix[double] cache_mean, matrix[double] cache_inv_var,
- matrix[double] X, matrix[double] gamma,
- int C, int Hin, int Win, double epsilon)
- return (matrix[double] dX, matrix[double] dgamma, matrix[double] dbeta) {
- /*
- * Computes the backward pass for a 2D (spatial) batch normalization
- * layer.
- *
- * Inputs:
- * - dout: Gradient wrt `out` from upstream, of shape (N, C*Hin*Win).
- * - cache_mean: Cache of the batch mean from the forward pass, of
- * shape (C, 1). Note: This is used for performance during
- * training.
- * - cache_inv_var: Cache of the inverse variance from the forward pass,
- * of shape (C, 1). Note: This is used for performance during
- * training.
- * - X: Input data matrix to the forward pass, of
- * shape (N, C*Hin*Win).
- * - gamma: Scale parameters, of shape (C, 1).
- * - C: Number of input channels (dimensionality of input depth).
- * - Hin: Input height.
- * - Win: Input width.
- * - epsilon: Smoothing term to avoid divide by zero errors.
- * Typical values are in the range of [1e-5, 1e-3].
- *
- * Outputs:
- * - dX: Gradient wrt `X`, of shape (N, C*Hin*Win).
- * - dgamma: Gradient wrt `W`, of shape (C, 1).
- * - dbeta: Gradient wrt `b`, of shape (C, 1).
- *
- */
- N = nrow(X)
- mean = cache_mean
- centered = bias_add(X, -mean) # shape (N, C*Hin*Win)
- norm = bias_multiply(centered, cache_inv_var) # shape (N, C*Hin*Win)
- # Compute gradients during training
- dgamma = util::channel_sums(dout*norm, C, Hin, Win) # shape (C, 1)
- dbeta = util::channel_sums(dout, C, Hin, Win) # shape (C, 1)
- dnorm = bias_multiply(dout, gamma) # shape (N, C*Hin*Win)
- dvar = util::channel_sums((-1/2) * bias_multiply(centered, cache_inv_var^3) * dnorm,
- C, Hin, Win) # shape (C, 1)
- dmean_norm_branch = util::channel_sums(bias_multiply(dnorm, -cache_inv_var), C, Hin, Win)
- dmean_var_branch = util::channel_sums((-2/(N*Hin*Win)) * centered, C, Hin, Win)
- dmean_var_branch = dmean_var_branch * dvar # we can't use a function within an expression yet
- dmean = dmean_norm_branch + dmean_var_branch # shape (C, 1)
- dX_norm_branch = bias_multiply(dnorm, cache_inv_var)
- dX_mean_branch = (1/(N*Hin*Win)) * bias_add(matrix(0, rows=1, cols=C*Hin*Win), dmean)
- dX_var_branch = (2/(N*Hin*Win)) * bias_multiply(centered, dvar)
- dX = dX_norm_branch + dX_mean_branch + dX_var_branch # shape (N, C*Hin*Win)
-}
-
-init = function(int C)
- return (matrix[double] gamma, matrix[double] beta,
- matrix[double] ema_mean, matrix[double] ema_var) {
- /*
- * Initialize the parameters of this layer.
- *
- * Note: This is just a convenience function, and parameters
- * may be initialized manually if needed.
- *
- * Inputs:
- * - C: Number of input channels (dimensionality of input depth).
- *
- * Outputs:
- * - gamma: Scale parameters, of shape (C, 1).
- * - beta: Shift parameters, of shape (C, 1).
- * - ema_mean: Exponential moving average of the mean, of
- * shape (C, 1).
- * - ema_var: Exponential moving average of the variance, of
- * shape (C, 1).
- */
- gamma = matrix(1, rows=C, cols=1)
- beta = matrix(0, rows=C, cols=1)
- ema_mean = matrix(0, rows=C, cols=1)
- ema_var = matrix(1, rows=C, cols=1)
-}
-
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/cpp/kernels/SystemML.cu
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu
index 9ddaaff..b874cdd 100644
--- a/src/main/cpp/kernels/SystemML.cu
+++ b/src/main/cpp/kernels/SystemML.cu
@@ -2289,4 +2289,58 @@ extern "C" __global__ void update_nesterov_x_d(double *X, double *v, double *v_p
extern "C" __global__ void update_nesterov_x_f(float *X, float *v, float *v_prev, double mu, float *out, unsigned int size) {
update_nesterov_x(X, v, v_prev, mu, out, size);
-}
\ No newline at end of file
+}
+
+// Performs the operation: C = a*X + b*C
+template <typename T>
+__device__ void aXplusbC(T *X, T *C, double a, double b, unsigned int size) {
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+ if (index < size) {
+ C[index] = a*X[index] + b*C[index];
+ }
+}
+
+extern "C" __global__ void aXplusbC_d(double *X, double *C, double a, double b, unsigned int size) {
+ aXplusbC(X, C, a, b,size);
+}
+
+extern "C" __global__ void aXplusbC_f(float *X, float *C, double a, double b, unsigned int size) {
+ aXplusbC(X, C, a, b,size);;
+}
+
+
+// Performs the operation: C = a*X + b*Y
+template <typename T>
+__device__ void aXplusbY(T *X, T* Y, T *C, double a, double b, unsigned int size) {
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+ if (index < size) {
+ C[index] = a*X[index] + b*Y[index];
+ }
+}
+
+extern "C" __global__ void aXplusbY_d(double *X, double* Y, double *C, double a, double b, unsigned int size) {
+ aXplusbY(X, Y, C, a, b, size);
+}
+
+extern "C" __global__ void aXplusbY_f(float *X, float* Y, float *C, double a, double b, unsigned int size) {
+ aXplusbY(X, Y, C, a, b, size);
+}
+
+
+// Performs the operation: C = 1 / sqrt(X + eps)
+template <typename T>
+__device__ void invVar(T *X, T *C, double eps, unsigned int size) {
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+ if (index < size) {
+ C[index] = 1.0 / sqrt(X[index] + eps);
+ }
+}
+
+extern "C" __global__ void invVar_d(double *X, double *C, double eps, unsigned int size) {
+ invVar(X, C, eps, size);
+}
+
+extern "C" __global__ void invVar_f(float *X, float *C, double eps, unsigned int size) {
+ invVar(X, C, eps, size);
+}
+
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/cpp/kernels/SystemML.ptx
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.ptx b/src/main/cpp/kernels/SystemML.ptx
index 8a14876..1ab32f5 100644
--- a/src/main/cpp/kernels/SystemML.ptx
+++ b/src/main/cpp/kernels/SystemML.ptx
@@ -13084,12 +13084,279 @@ BB115_2:
ret;
}
+ // .globl aXplusbC_d
+.visible .entry aXplusbC_d(
+ .param .u64 aXplusbC_d_param_0,
+ .param .u64 aXplusbC_d_param_1,
+ .param .f64 aXplusbC_d_param_2,
+ .param .f64 aXplusbC_d_param_3,
+ .param .u32 aXplusbC_d_param_4
+)
+{
+ .reg .pred %p<2>;
+ .reg .b32 %r<6>;
+ .reg .f64 %fd<7>;
+ .reg .b64 %rd<8>;
+
+
+ ld.param.u64 %rd1, [aXplusbC_d_param_0];
+ ld.param.u64 %rd2, [aXplusbC_d_param_1];
+ ld.param.f64 %fd1, [aXplusbC_d_param_2];
+ ld.param.f64 %fd2, [aXplusbC_d_param_3];
+ ld.param.u32 %r2, [aXplusbC_d_param_4];
+ mov.u32 %r3, %ctaid.x;
+ mov.u32 %r4, %ntid.x;
+ mov.u32 %r5, %tid.x;
+ mad.lo.s32 %r1, %r4, %r3, %r5;
+ setp.ge.u32 %p1, %r1, %r2;
+ @%p1 bra BB116_2;
+
+ cvta.to.global.u64 %rd3, %rd2;
+ cvta.to.global.u64 %rd4, %rd1;
+ mul.wide.s32 %rd5, %r1, 8;
+ add.s64 %rd6, %rd4, %rd5;
+ ld.global.f64 %fd3, [%rd6];
+ add.s64 %rd7, %rd3, %rd5;
+ ld.global.f64 %fd4, [%rd7];
+ mul.f64 %fd5, %fd4, %fd2;
+ fma.rn.f64 %fd6, %fd3, %fd1, %fd5;
+ st.global.f64 [%rd7], %fd6;
+
+BB116_2:
+ ret;
+}
+
+ // .globl aXplusbC_f
+.visible .entry aXplusbC_f(
+ .param .u64 aXplusbC_f_param_0,
+ .param .u64 aXplusbC_f_param_1,
+ .param .f64 aXplusbC_f_param_2,
+ .param .f64 aXplusbC_f_param_3,
+ .param .u32 aXplusbC_f_param_4
+)
+{
+ .reg .pred %p<2>;
+ .reg .f32 %f<4>;
+ .reg .b32 %r<6>;
+ .reg .f64 %fd<7>;
+ .reg .b64 %rd<8>;
+
+
+ ld.param.u64 %rd1, [aXplusbC_f_param_0];
+ ld.param.u64 %rd2, [aXplusbC_f_param_1];
+ ld.param.f64 %fd1, [aXplusbC_f_param_2];
+ ld.param.f64 %fd2, [aXplusbC_f_param_3];
+ ld.param.u32 %r2, [aXplusbC_f_param_4];
+ mov.u32 %r3, %ctaid.x;
+ mov.u32 %r4, %ntid.x;
+ mov.u32 %r5, %tid.x;
+ mad.lo.s32 %r1, %r4, %r3, %r5;
+ setp.ge.u32 %p1, %r1, %r2;
+ @%p1 bra BB117_2;
+
+ cvta.to.global.u64 %rd3, %rd2;
+ cvta.to.global.u64 %rd4, %rd1;
+ mul.wide.s32 %rd5, %r1, 4;
+ add.s64 %rd6, %rd4, %rd5;
+ ld.global.f32 %f1, [%rd6];
+ cvt.f64.f32 %fd3, %f1;
+ add.s64 %rd7, %rd3, %rd5;
+ ld.global.f32 %f2, [%rd7];
+ cvt.f64.f32 %fd4, %f2;
+ mul.f64 %fd5, %fd4, %fd2;
+ fma.rn.f64 %fd6, %fd3, %fd1, %fd5;
+ cvt.rn.f32.f64 %f3, %fd6;
+ st.global.f32 [%rd7], %f3;
+
+BB117_2:
+ ret;
+}
+
+ // .globl aXplusbY_d
+.visible .entry aXplusbY_d(
+ .param .u64 aXplusbY_d_param_0,
+ .param .u64 aXplusbY_d_param_1,
+ .param .u64 aXplusbY_d_param_2,
+ .param .f64 aXplusbY_d_param_3,
+ .param .f64 aXplusbY_d_param_4,
+ .param .u32 aXplusbY_d_param_5
+)
+{
+ .reg .pred %p<2>;
+ .reg .b32 %r<6>;
+ .reg .f64 %fd<7>;
+ .reg .b64 %rd<11>;
+
+
+ ld.param.u64 %rd1, [aXplusbY_d_param_0];
+ ld.param.u64 %rd2, [aXplusbY_d_param_1];
+ ld.param.u64 %rd3, [aXplusbY_d_param_2];
+ ld.param.f64 %fd1, [aXplusbY_d_param_3];
+ ld.param.f64 %fd2, [aXplusbY_d_param_4];
+ ld.param.u32 %r2, [aXplusbY_d_param_5];
+ mov.u32 %r3, %ctaid.x;
+ mov.u32 %r4, %ntid.x;
+ mov.u32 %r5, %tid.x;
+ mad.lo.s32 %r1, %r4, %r3, %r5;
+ setp.ge.u32 %p1, %r1, %r2;
+ @%p1 bra BB118_2;
+
+ cvta.to.global.u64 %rd4, %rd1;
+ mul.wide.s32 %rd5, %r1, 8;
+ add.s64 %rd6, %rd4, %rd5;
+ ld.global.f64 %fd3, [%rd6];
+ cvta.to.global.u64 %rd7, %rd2;
+ add.s64 %rd8, %rd7, %rd5;
+ ld.global.f64 %fd4, [%rd8];
+ mul.f64 %fd5, %fd4, %fd2;
+ fma.rn.f64 %fd6, %fd3, %fd1, %fd5;
+ cvta.to.global.u64 %rd9, %rd3;
+ add.s64 %rd10, %rd9, %rd5;
+ st.global.f64 [%rd10], %fd6;
+
+BB118_2:
+ ret;
+}
+
+ // .globl aXplusbY_f
+.visible .entry aXplusbY_f(
+ .param .u64 aXplusbY_f_param_0,
+ .param .u64 aXplusbY_f_param_1,
+ .param .u64 aXplusbY_f_param_2,
+ .param .f64 aXplusbY_f_param_3,
+ .param .f64 aXplusbY_f_param_4,
+ .param .u32 aXplusbY_f_param_5
+)
+{
+ .reg .pred %p<2>;
+ .reg .f32 %f<4>;
+ .reg .b32 %r<6>;
+ .reg .f64 %fd<7>;
+ .reg .b64 %rd<11>;
+
+
+ ld.param.u64 %rd1, [aXplusbY_f_param_0];
+ ld.param.u64 %rd2, [aXplusbY_f_param_1];
+ ld.param.u64 %rd3, [aXplusbY_f_param_2];
+ ld.param.f64 %fd1, [aXplusbY_f_param_3];
+ ld.param.f64 %fd2, [aXplusbY_f_param_4];
+ ld.param.u32 %r2, [aXplusbY_f_param_5];
+ mov.u32 %r3, %ctaid.x;
+ mov.u32 %r4, %ntid.x;
+ mov.u32 %r5, %tid.x;
+ mad.lo.s32 %r1, %r4, %r3, %r5;
+ setp.ge.u32 %p1, %r1, %r2;
+ @%p1 bra BB119_2;
+
+ cvta.to.global.u64 %rd4, %rd1;
+ mul.wide.s32 %rd5, %r1, 4;
+ add.s64 %rd6, %rd4, %rd5;
+ ld.global.f32 %f1, [%rd6];
+ cvt.f64.f32 %fd3, %f1;
+ cvta.to.global.u64 %rd7, %rd2;
+ add.s64 %rd8, %rd7, %rd5;
+ ld.global.f32 %f2, [%rd8];
+ cvt.f64.f32 %fd4, %f2;
+ mul.f64 %fd5, %fd4, %fd2;
+ fma.rn.f64 %fd6, %fd3, %fd1, %fd5;
+ cvt.rn.f32.f64 %f3, %fd6;
+ cvta.to.global.u64 %rd9, %rd3;
+ add.s64 %rd10, %rd9, %rd5;
+ st.global.f32 [%rd10], %f3;
+
+BB119_2:
+ ret;
+}
+
+ // .globl invVar_d
+.visible .entry invVar_d(
+ .param .u64 invVar_d_param_0,
+ .param .u64 invVar_d_param_1,
+ .param .f64 invVar_d_param_2,
+ .param .u32 invVar_d_param_3
+)
+{
+ .reg .pred %p<2>;
+ .reg .b32 %r<6>;
+ .reg .f64 %fd<6>;
+ .reg .b64 %rd<8>;
+
+
+ ld.param.u64 %rd1, [invVar_d_param_0];
+ ld.param.u64 %rd2, [invVar_d_param_1];
+ ld.param.f64 %fd1, [invVar_d_param_2];
+ ld.param.u32 %r2, [invVar_d_param_3];
+ mov.u32 %r3, %ctaid.x;
+ mov.u32 %r4, %ntid.x;
+ mov.u32 %r5, %tid.x;
+ mad.lo.s32 %r1, %r4, %r3, %r5;
+ setp.ge.u32 %p1, %r1, %r2;
+ @%p1 bra BB120_2;
+
+ cvta.to.global.u64 %rd3, %rd1;
+ mul.wide.s32 %rd4, %r1, 8;
+ add.s64 %rd5, %rd3, %rd4;
+ ld.global.f64 %fd2, [%rd5];
+ add.f64 %fd3, %fd2, %fd1;
+ sqrt.rn.f64 %fd4, %fd3;
+ rcp.rn.f64 %fd5, %fd4;
+ cvta.to.global.u64 %rd6, %rd2;
+ add.s64 %rd7, %rd6, %rd4;
+ st.global.f64 [%rd7], %fd5;
+
+BB120_2:
+ ret;
+}
+
+ // .globl invVar_f
+.visible .entry invVar_f(
+ .param .u64 invVar_f_param_0,
+ .param .u64 invVar_f_param_1,
+ .param .f64 invVar_f_param_2,
+ .param .u32 invVar_f_param_3
+)
+{
+ .reg .pred %p<2>;
+ .reg .f32 %f<3>;
+ .reg .b32 %r<6>;
+ .reg .f64 %fd<6>;
+ .reg .b64 %rd<8>;
+
+
+ ld.param.u64 %rd1, [invVar_f_param_0];
+ ld.param.u64 %rd2, [invVar_f_param_1];
+ ld.param.f64 %fd1, [invVar_f_param_2];
+ ld.param.u32 %r2, [invVar_f_param_3];
+ mov.u32 %r3, %ctaid.x;
+ mov.u32 %r4, %ntid.x;
+ mov.u32 %r5, %tid.x;
+ mad.lo.s32 %r1, %r4, %r3, %r5;
+ setp.ge.u32 %p1, %r1, %r2;
+ @%p1 bra BB121_2;
+
+ cvta.to.global.u64 %rd3, %rd1;
+ mul.wide.s32 %rd4, %r1, 4;
+ add.s64 %rd5, %rd3, %rd4;
+ ld.global.f32 %f1, [%rd5];
+ cvt.f64.f32 %fd2, %f1;
+ add.f64 %fd3, %fd2, %fd1;
+ sqrt.rn.f64 %fd4, %fd3;
+ rcp.rn.f64 %fd5, %fd4;
+ cvt.rn.f32.f64 %f2, %fd5;
+ cvta.to.global.u64 %rd6, %rd2;
+ add.s64 %rd7, %rd6, %rd4;
+ st.global.f32 [%rd7], %f2;
+
+BB121_2:
+ ret;
+}
+
.func (.param .b64 func_retval0) __internal_trig_reduction_slowpathd(
.param .b64 __internal_trig_reduction_slowpathd_param_0,
.param .b64 __internal_trig_reduction_slowpathd_param_1
)
{
- .local .align 8 .b8 __local_depot116[40];
+ .local .align 8 .b8 __local_depot122[40];
.reg .b64 %SP;
.reg .b64 %SPL;
.reg .pred %p<9>;
@@ -13098,7 +13365,7 @@ BB115_2:
.reg .b64 %rd<102>;
- mov.u64 %rd101, __local_depot116;
+ mov.u64 %rd101, __local_depot122;
cvta.local.u64 %SP, %rd101;
ld.param.f64 %fd4, [__internal_trig_reduction_slowpathd_param_0];
ld.param.u64 %rd37, [__internal_trig_reduction_slowpathd_param_1];
@@ -13112,7 +13379,7 @@ BB115_2:
shr.u32 %r3, %r1, 20;
bfe.u32 %r4, %r1, 20, 11;
setp.eq.s32 %p1, %r4, 2047;
- @%p1 bra BB116_13;
+ @%p1 bra BB122_13;
add.s32 %r15, %r4, -1024;
shr.u32 %r16, %r15, 6;
@@ -13125,7 +13392,7 @@ BB115_2:
mov.u64 %rd94, 0;
setp.ge.s32 %p2, %r5, %r6;
mov.u64 %rd93, %rd1;
- @%p2 bra BB116_4;
+ @%p2 bra BB122_4;
mov.b64 %rd41, %fd4;
shl.b64 %rd42, %rd41, 11;
@@ -13142,7 +13409,7 @@ BB115_2:
mov.u64 %rd91, %rd1;
mov.u32 %r39, %r5;
-BB116_3:
+BB122_3:
.pragma "nounroll";
ld.const.u64 %rd47, [%rd89];
// inline asm
@@ -13172,15 +13439,15 @@ BB116_3:
add.s64 %rd93, %rd93, 8;
add.s64 %rd89, %rd89, 8;
setp.lt.s32 %p3, %r39, %r6;
- @%p3 bra BB116_3;
+ @%p3 bra BB122_3;
-BB116_4:
+BB122_4:
st.local.u64 [%rd93], %rd94;
ld.local.u64 %rd95, [%rd1+16];
ld.local.u64 %rd96, [%rd1+24];
and.b32 %r9, %r3, 63;
setp.eq.s32 %p4, %r9, 0;
- @%p4 bra BB116_6;
+ @%p4 bra BB122_6;
mov.u32 %r27, 64;
sub.s32 %r28, %r27, %r9;
@@ -13192,7 +13459,7 @@ BB116_4:
shr.u64 %rd55, %rd54, %r28;
or.b64 %rd95, %rd55, %rd53;
-BB116_6:
+BB122_6:
cvta.to.local.u64 %rd56, %rd37;
shr.u64 %rd57, %rd96, 62;
cvt.u32.u64 %r29, %rd57;
@@ -13209,7 +13476,7 @@ BB116_6:
selp.b32 %r34, %r32, %r33, %p5;
st.local.u32 [%rd56], %r34;
setp.eq.s32 %p6, %r31, 0;
- @%p6 bra BB116_8;
+ @%p6 bra BB122_8;
mov.u64 %rd64, 0;
// inline asm
@@ -13229,10 +13496,10 @@ BB116_6:
// inline asm
xor.b32 %r40, %r40, -2147483648;
-BB116_8:
+BB122_8:
clz.b64 %r41, %rd98;
setp.eq.s32 %p7, %r41, 0;
- @%p7 bra BB116_10;
+ @%p7 bra BB122_10;
shl.b64 %rd67, %rd98, %r41;
mov.u32 %r35, 64;
@@ -13240,7 +13507,7 @@ BB116_8:
shr.u64 %rd68, %rd97, %r36;
or.b64 %rd98, %rd68, %rd67;
-BB116_10:
+BB122_10:
mov.u64 %rd72, -3958705157555305931;
// inline asm
{
@@ -13261,7 +13528,7 @@ BB116_10:
}
// inline asm
setp.lt.s64 %p8, %rd100, 1;
- @%p8 bra BB116_12;
+ @%p8 bra BB122_12;
// inline asm
{
@@ -13280,7 +13547,7 @@ BB116_10:
// inline asm
add.s32 %r41, %r41, 1;
-BB116_12:
+BB122_12:
cvt.u64.u32 %rd79, %r40;
shl.b64 %rd80, %rd79, 32;
mov.u32 %r37, 1022;
@@ -13295,7 +13562,7 @@ BB116_12:
or.b64 %rd88, %rd87, %rd80;
mov.b64 %fd4, %rd88;
-BB116_13:
+BB122_13:
st.param.f64 [func_retval0+0], %fd4;
ret;
}
@@ -13323,7 +13590,7 @@ BB116_13:
}
shr.u32 %r51, %r50, 20;
setp.ne.s32 %p1, %r51, 0;
- @%p1 bra BB117_2;
+ @%p1 bra BB123_2;
mul.f64 %fd14, %fd12, 0d4350000000000000;
{
@@ -13337,13 +13604,13 @@ BB116_13:
shr.u32 %r16, %r50, 20;
add.s32 %r51, %r16, -54;
-BB117_2:
+BB123_2:
add.s32 %r52, %r51, -1023;
and.b32 %r17, %r50, -2146435073;
or.b32 %r18, %r17, 1072693248;
mov.b64 %fd135, {%r49, %r18};
setp.lt.u32 %p2, %r18, 1073127583;
- @%p2 bra BB117_4;
+ @%p2 bra BB123_4;
{
.reg .b32 %temp;
@@ -13357,7 +13624,7 @@ BB117_2:
mov.b64 %fd135, {%r19, %r21};
add.s32 %r52, %r51, -1022;
-BB117_4:
+BB123_4:
add.f64 %fd15, %fd135, 0d3FF0000000000000;
rcp.approx.ftz.f64 %fd16, %fd15;
neg.f64 %fd17, %fd15;
@@ -13520,13 +13787,13 @@ BB117_4:
mov.b32 %f2, %r35;
abs.f32 %f1, %f2;
setp.lt.f32 %p4, %f1, 0f4086232B;
- @%p4 bra BB117_7;
+ @%p4 bra BB123_7;
setp.lt.f64 %p5, %fd4, 0d0000000000000000;
add.f64 %fd129, %fd4, 0d7FF0000000000000;
selp.f64 %fd136, 0d0000000000000000, %fd129, %p5;
setp.geu.f32 %p6, %f1, 0f40874800;
- @%p6 bra BB117_7;
+ @%p6 bra BB123_7;
mov.f64 %fd134, 0d4338000000000000;
mov.f64 %fd133, 0d3FF71547652B82FE;
@@ -13548,26 +13815,26 @@ BB117_4:
mov.b64 %fd131, {%r44, %r43};
mul.f64 %fd136, %fd130, %fd131;
-BB117_7:
+BB123_7:
{
.reg .b32 %temp;
mov.b64 {%temp, %r45}, %fd136;
}
and.b32 %r46, %r45, 2147483647;
setp.ne.s32 %p7, %r46, 2146435072;
- @%p7 bra BB117_9;
+ @%p7 bra BB123_9;
{
.reg .b32 %temp;
mov.b64 {%r47, %temp}, %fd136;
}
setp.eq.s32 %p8, %r47, 0;
- @%p8 bra BB117_10;
+ @%p8 bra BB123_10;
-BB117_9:
+BB123_9:
fma.rn.f64 %fd136, %fd136, %fd5, %fd136;
-BB117_10:
+BB123_10:
st.param.f64 [func_retval0+0], %fd136;
ret;
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/DnnOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/DnnOp.java b/src/main/java/org/apache/sysml/hops/DnnOp.java
index a7d37dc..c4ce466 100644
--- a/src/main/java/org/apache/sysml/hops/DnnOp.java
+++ b/src/main/java/org/apache/sysml/hops/DnnOp.java
@@ -110,8 +110,6 @@ public class DnnOp extends MultiThreadedHop
if( getLops() != null )
return getLops();
- ExecType et = optFindExecType();
-
ArrayList<Hop> inputs = getInput();
switch( op )
{
@@ -125,6 +123,7 @@ public class DnnOp extends MultiThreadedHop
case BIASADD:
case BIASMULT:
{
+ ExecType et = optFindExecType();
if(et == ExecType.CP || et == ExecType.GPU) {
setLops(constructDnnLops(et, inputs));
break;
@@ -137,15 +136,15 @@ public class DnnOp extends MultiThreadedHop
case BATCH_NORM2D_TEST:
case CHANNEL_SUMS:
case UPDATE_NESTEROV_X:
+ case UPDATE_EMA_VAR:
+ case RESHAPE_COLMEANS:
+ case UPDATE_EMA:
+ case INV_VAR:
+ case BATCH_NORM2D_BACKWARD_DX:
{
- if(et == ExecType.GPU) {
- setLops(constructDnnLops(et, inputs));
- break;
- }
- else {
- throw new HopsException("Unimplemented DnnOp for execution type: " + et.name());
- }
- // break;
+ // GPU-specific operators
+ setLops(constructDnnLops(ExecType.GPU, inputs));
+ break;
}
default:
throw new HopsException("Unsupported lops construction for operation type '"+op+"'.");
@@ -171,10 +170,16 @@ public class DnnOp extends MultiThreadedHop
return 14;
case BIASADD:
case BIASMULT:
+ case INV_VAR:
return 2;
case BATCH_NORM2D_TEST:
return 6;
+ case UPDATE_EMA_VAR:
+ case BATCH_NORM2D_BACKWARD_DX:
+ return 5;
+ case RESHAPE_COLMEANS:
case CHANNEL_SUMS:
+ case UPDATE_EMA:
return 3;
case UPDATE_NESTEROV_X:
return 4;
@@ -532,7 +537,8 @@ public class DnnOp extends MultiThreadedHop
long[] ret = new long[3];
if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST ||
- op == OpOpDnn.UPDATE_NESTEROV_X) {
+ op == OpOpDnn.UPDATE_NESTEROV_X || op == OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR ||
+ op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) {
// Same dimension as the first input
MatrixCharacteristics[] mc = memo.getAllInputStats(getInput());
ret[0] = mc[0].rowsKnown() ? mc[0].getRows() : -1;
@@ -540,13 +546,21 @@ public class DnnOp extends MultiThreadedHop
ret[2] = -1;
return (ret[0]>=0 && ret[1]>=0) ? ret : null;
}
- else if(op == OpOpDnn.CHANNEL_SUMS) {
+ else if(op == OpOpDnn.CHANNEL_SUMS || op == OpOpDnn.UPDATE_EMA_VAR) {
long numChannels = Hop.computeSizeInformation(getInput().get(1));
ret[0] = numChannels;
ret[1] = 1;
ret[2] = -1;
return ret;
}
+ else if(op == OpOpDnn.RESHAPE_COLMEANS) {
+ long numChannels = Hop.computeSizeInformation(getInput().get(1));
+ long HW = Hop.computeSizeInformation(getInput().get(2));
+ ret[0] = numChannels;
+ ret[1] = HW;
+ ret[2] = -1;
+ return ret;
+ }
refreshSizeInformation();
ret[0] = _dim1; ret[1] = _dim2; ret[2] = _nnz;
@@ -739,7 +753,9 @@ public class DnnOp extends MultiThreadedHop
@Override
public void refreshSizeInformation()
{
- if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.UPDATE_NESTEROV_X) {
+ if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST ||
+ op == OpOpDnn.UPDATE_NESTEROV_X || op == OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR ||
+ op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) {
// Same dimension as the first input
Hop input1 = getInput().get(0);
setDim1(input1.getDim1());
@@ -747,13 +763,21 @@ public class DnnOp extends MultiThreadedHop
_nnz = -1; // cannot infer stats
return;
}
- else if(op == OpOpDnn.CHANNEL_SUMS) {
+ else if(op == OpOpDnn.CHANNEL_SUMS || op == OpOpDnn.UPDATE_EMA_VAR) {
long numChannels = Hop.computeSizeInformation(getInput().get(1));
setDim1(numChannels);
setDim2(1);
_nnz = -1; // cannot infer stats
return;
}
+ else if(op == OpOpDnn.RESHAPE_COLMEANS) {
+ long numChannels = Hop.computeSizeInformation(getInput().get(1));
+ long HW = Hop.computeSizeInformation(getInput().get(2));
+ setDim1(numChannels);
+ setDim2(HW);
+ _nnz = -1; // cannot infer stats
+ return;
+ }
// Reset the _cachedParams to avoid incorrect sizes
_cachedParams = new DnnParameters(-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, _maxNumThreads);
@@ -847,7 +871,9 @@ public class DnnOp extends MultiThreadedHop
*/
private long getDim(String dimString) {
if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.CHANNEL_SUMS ||
- op == OpOpDnn.UPDATE_NESTEROV_X) {
+ op == OpOpDnn.UPDATE_NESTEROV_X || op == OpOpDnn.RESHAPE_COLMEANS ||
+ op == OpOpDnn.UPDATE_EMA_VAR || op == OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR ||
+ op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) {
throw new RuntimeException("getDim method should not be invoked for " + op.name());
}
try {
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/FunctionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java b/src/main/java/org/apache/sysml/hops/FunctionOp.java
index ea397db..5f177bd 100644
--- a/src/main/java/org/apache/sysml/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java
@@ -181,21 +181,6 @@ public class FunctionOp extends Hop
// TODO: To allow for initial version to always run on the GPU
return 0;
}
- else if ( getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_train")) {
- return OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0) +
- OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0) +
- OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(2).getDim1(), getOutputs().get(2).getDim2(), 1.0) +
- OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(3).getDim1(), getOutputs().get(3).getDim2(), 1.0) +
- OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(4).getDim1(), getOutputs().get(4).getDim2(), 1.0);
- }
- else if ( getFunctionName().equalsIgnoreCase("batch_norm2d_test") ) {
- return OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0);
- }
- else if ( getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ) {
- return OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0) +
- OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0) +
- OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(2).getDim1(), getOutputs().get(2).getDim2(), 1.0);
- }
else if ( getFunctionName().equalsIgnoreCase("svd") ) {
long outputU = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0);
long outputSigma = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0);
@@ -226,10 +211,6 @@ public class FunctionOp extends Hop
return OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), getInput().get(0).getDim2(), 1.0)
+ 3*OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), 1, 1.0);
}
- else if (getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ||
- getFunctionName().equalsIgnoreCase("batch_norm2d_train") || getFunctionName().equalsIgnoreCase("batch_norm2d_test")) {
- return 0;
- }
else if ( getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward") ) {
// TODO: To allow for initial version to always run on the GPU
return 0;
@@ -251,9 +232,7 @@ public class FunctionOp extends Hop
@Override
public boolean isGPUEnabled() {
- if(getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward") ||
- getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ||
- getFunctionName().equalsIgnoreCase("batch_norm2d_train") || getFunctionName().equalsIgnoreCase("batch_norm2d_test"))
+ if(getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward"))
return true;
else
return false;
@@ -308,13 +287,6 @@ public class FunctionOp extends Hop
throw new RuntimeException("The function " + getFunctionName() + " is only supported on GPU.");
_etype = ExecType.GPU;
}
- else if(isBuiltinFunction && (getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward"))) {
- _etype = ConfigurationManager.isGPU() ? ExecType.GPU : ExecType.CP;
- }
- else if(isBuiltinFunction && getFunctionName().equalsIgnoreCase("batch_norm2d_train")) {
- // Only GPU implementation is supported
- _etype = ExecType.GPU;
- }
else {
// Since the memory estimate is only conservative, do not throw
// exception if the estimated memory is larger than the budget
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java
index 3b461a1..c8356e0 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -1100,7 +1100,8 @@ public abstract class Hop implements ParseInfo
MAX_POOL, MAX_POOL_BACKWARD, AVG_POOL, AVG_POOL_BACKWARD,
CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA,
BIASADD, BIASMULT, BATCH_NORM2D_TEST, CHANNEL_SUMS,
- UPDATE_NESTEROV_X
+ UPDATE_NESTEROV_X, RESHAPE_COLMEANS, UPDATE_EMA_VAR, UPDATE_EMA, INV_VAR,
+ BATCH_NORM2D_BACKWARD_DX
}
public enum DataGenMethod {
@@ -1174,8 +1175,13 @@ public abstract class Hop implements ParseInfo
HopsConv2Lops.put(OpOpDnn.CONV2D_BACKWARD_FILTER, org.apache.sysml.lops.DnnTransform.OperationTypes.CONV2D_BACKWARD_FILTER);
HopsConv2Lops.put(OpOpDnn.CONV2D_BACKWARD_DATA, org.apache.sysml.lops.DnnTransform.OperationTypes.CONV2D_BACKWARD_DATA);
HopsConv2Lops.put(OpOpDnn.BATCH_NORM2D_TEST, org.apache.sysml.lops.DnnTransform.OperationTypes.BATCH_NORM2D_TEST);
+ HopsConv2Lops.put(OpOpDnn.UPDATE_EMA_VAR, org.apache.sysml.lops.DnnTransform.OperationTypes.UPDATE_EMA_VAR);
HopsConv2Lops.put(OpOpDnn.CHANNEL_SUMS, org.apache.sysml.lops.DnnTransform.OperationTypes.CHANNEL_SUMS);
HopsConv2Lops.put(OpOpDnn.UPDATE_NESTEROV_X, org.apache.sysml.lops.DnnTransform.OperationTypes.UPDATE_NESTEROV_X);
+ HopsConv2Lops.put(OpOpDnn.RESHAPE_COLMEANS, org.apache.sysml.lops.DnnTransform.OperationTypes.RESHAPE_COLMEANS);
+ HopsConv2Lops.put(OpOpDnn.UPDATE_EMA, org.apache.sysml.lops.DnnTransform.OperationTypes.UPDATE_EMA);
+ HopsConv2Lops.put(OpOpDnn.INV_VAR, org.apache.sysml.lops.DnnTransform.OperationTypes.INV_VAR);
+ HopsConv2Lops.put(OpOpDnn.BATCH_NORM2D_BACKWARD_DX, org.apache.sysml.lops.DnnTransform.OperationTypes.BATCH_NORM2D_BACKWARD_DX);
}
protected static final HashMap<Hop.Direction, org.apache.sysml.lops.PartialAggregate.DirectionTypes> HopsDirection2Lops;
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java b/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java
new file mode 100644
index 0000000..7c70b7b
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java
@@ -0,0 +1,378 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.hops.rewrite;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.function.Function;
+import java.util.function.Predicate;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.hops.AggUnaryOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.hops.Hop.AggOp;
+import org.apache.sysml.hops.Hop.Direction;
+import org.apache.sysml.hops.Hop.OpOp1;
+import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.Hop.OpOpDnn;
+import org.apache.sysml.hops.Hop.ReOrgOp;
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
+import org.apache.sysml.utils.Explain;
+
+/**
+ * Please see org.apache.sysml.hops.rewrite.RewriteGPUSpecificOps class for usage and design documentation.
+ */
+public class HopDagPatternMatcher {
+ static final HashSet<String> DEBUG_PATTERNS;
+ static {
+ // DEBUG_PATTERNS = new HashSet<>();
+ // DEBUG_PATTERNS.add("batchNormdX");
+ DEBUG_PATTERNS = null;
+ }
+
+ // Predicates for the current HOP
+ List<HopPredicate> _predicates = new ArrayList<>();
+ // Child matchers
+ List<HopDagPatternMatcher> _children = new ArrayList<>();
+ private boolean _isLeaf = false;
+
+ static boolean DEBUG_REWRITES = false; // This is set by HopPatternRewriter. Please use DEBUG_PATTERNS instead.
+
+ // Simple utility for debugging the rewrites
+ public static class HopPredicate implements Predicate<Hop> {
+ final String _name;
+ final Function<Hop, Boolean> _pred;
+ public HopPredicate(String name, Function<Hop, Boolean> pred) {
+ _name = name;
+ _pred = pred;
+ }
+ @Override
+ public boolean test(Hop h) {
+ return _pred.apply(h);
+ }
+ @Override
+ public String toString() {
+ return _name;
+ }
+ }
+
+ /**
+ * Adds a predicate to the pattern matcher
+ *
+ * @param name name of the pattern for debugging
+ * @param pred higher order function that takes as an input a hop and returns true if the pattern matches else false
+ * @return this
+ */
+ public HopDagPatternMatcher addPredicate(String name, Function<Hop, Boolean> pred) {
+ _predicates.add(new HopPredicate(name, pred));
+ return this;
+ }
+
+ /**
+ * Add child pattern matcher
+	 * @param children list of children
+ * @return this
+ */
+ public HopDagPatternMatcher addChildMatcher(HopDagPatternMatcher... children) {
+ for(int i = 0; i < children.length; i++) {
+ _children.add(children[i]);
+ }
+ return this;
+ }
+
+ /**
+ * Get the matched HOP DAGs
+ * @param varName variable names
+ * @return matched HOP
+ */
+ public Hop getMatchedHop(String varName) {
+
+ if(matchedHops == null || !matchedHops.containsKey(varName)) {
+ throw new RuntimeException("Incorrect usage: the variable " + varName + " is not registered as input.");
+ }
+ return matchedHops.get(varName);
+ }
+
+ /**
+ * Return the value
+ *
+ * @param varName variable name
+ * @return the value of the LiteralOp
+ */
+ public double getLiteralValue(String varName) {
+ return OptimizerUtils.rEvalSimpleDoubleExpression(getMatchedHop(varName), new HashMap<>());
+ }
+
+ @Override
+ public String toString() {
+ return _predicates.size() >= 1 ? _predicates.get(0).toString() : "";
+ }
+
+ /**
+ * Match the given HOP DAG
+ *
+ * @param h root node of the HOP DAG
+ * @return true if HOP DAG matches
+ */
+ public boolean matches(Hop h) {
+ visited.clear();
+ matchedHops.clear();
+ return matchHelper(this, h);
+ }
+
+ private HashMap<String, Hop> matchedHops = new HashMap<>();
+ private String variableName;
+ private HashMap<HopDagPatternMatcher, Hop> visited = new HashMap<>(); // Map of matched hops
+ private boolean matchHelper(HopDagPatternMatcher root, Hop h) {
+ if(h == null) {
+ return false;
+ }
+ else if(_children.size() > 0 && h.getInput().size() < _children.size()) {
+ if(DEBUG_REWRITES) {
+ System.out.println("The expected number of children (" + _children.size() + ") didnot match the number of inputs (" + h.getInput().size() + ") " + this);
+ }
+ return false;
+ }
+ if(root.visited.containsKey(this)) {
+ Hop h1 = root.visited.get(this);
+ if(h == h1) {
+ if(DEBUG_REWRITES)
+ System.out.println("MATCHED: Early exit as the given HOP has been already matched by the matcher." + this);
+ return true; // Early exit as the given HOP has been already matched by the matcher
+ }
+ else if(_isLeaf) {
+ if(h.getDataType() == h1.getDataType() && h.getDataType() == DataType.SCALAR) {
+ return OptimizerUtils.rEvalSimpleDoubleExpression(h, new HashMap<>()) == OptimizerUtils.rEvalSimpleDoubleExpression(h1, new HashMap<>());
+ }
+ return false; // Mismatched or unknown datatypes or matched with different hops
+ }
+ }
+
+ for(HopPredicate p : _predicates) {
+ if(!p.test(h)) {
+ if(DEBUG_REWRITES) {
+ System.out.println("The predicate " + p.toString() + " failed.");
+ }
+ return false;
+ }
+ }
+ int index = 0;
+ for(HopDagPatternMatcher child : _children) {
+ if(!child.matchHelper(root, h.getInput().get(index))) {
+ return false;
+ }
+ index++;
+ }
+ if(_isLeaf) {
+ root.matchedHops.put(variableName, h);
+ }
+
+ root.visited.put(this, h);
+ if(DEBUG_REWRITES)
+ System.out.println("MATCHED: " + this + " to " + Explain.explain(h));
+ return true;
+ }
+
+
+ // Simple helper utilities for adding predicates
+ private HopDagPatternMatcher isScalar() {
+ return this.addPredicate("isScalar", h -> h.getDataType() == DataType.SCALAR);
+ }
+ private HopDagPatternMatcher isMatrix() {
+ return this.addPredicate("isMatrix", h -> h.getDataType() == DataType.MATRIX);
+ }
+ public HopDagPatternMatcher fitsOnGPU(double constant) {
+ return this.addPredicate("fitsOnGPU", h -> _fitsOnGPU(h, constant));
+ }
+
+ // Factory methods:
+ public static HopDagPatternMatcher dummy = new HopDagPatternMatcher();
+ public static HopDagPatternMatcher rowMeans(HopDagPatternMatcher child1) {
+ return new HopDagPatternMatcher().addPredicate("rowMeans", h ->
+ h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.MEAN && ((AggUnaryOp)h).getDirection() == Direction.Row)
+ .addChildMatcher(child1);
+ }
+ public static HopDagPatternMatcher rowVars(HopDagPatternMatcher child1) {
+ return new HopDagPatternMatcher().addPredicate("rowVars", h ->
+ h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.VAR && ((AggUnaryOp)h).getDirection() == Direction.Row)
+ .addChildMatcher(child1);
+ }
+ public static HopDagPatternMatcher colVars(HopDagPatternMatcher child1) {
+ return new HopDagPatternMatcher().addPredicate("colVars", h ->
+ h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.VAR && ((AggUnaryOp)h).getDirection() == Direction.Col)
+ .addChildMatcher(child1);
+ }
+ public static HopDagPatternMatcher leaf(String _variableName, DataType dt) {
+ HopDagPatternMatcher ret = new HopDagPatternMatcher();
+ ret._isLeaf = true;
+ ret.variableName = _variableName;
+ if(dt == DataType.MATRIX) {
+ return ret.isMatrix();
+ }
+ else if(dt == DataType.SCALAR) {
+ return ret.isScalar();
+ }
+ else if(dt == DataType.UNKNOWN) {
+ return ret;
+ }
+ else {
+ throw new DMLRuntimeException("Unsupported datatype in pattern matcher:" + dt.name());
+ }
+ }
+ public static HopDagPatternMatcher rowSums(HopDagPatternMatcher child1) {
+ return new HopDagPatternMatcher().addPredicate("rowSums", h ->
+ h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.SUM && ((AggUnaryOp)h).getDirection() == Direction.Row)
+ .addChildMatcher(child1);
+ }
+ public static HopDagPatternMatcher colSums(HopDagPatternMatcher child1) {
+ return new HopDagPatternMatcher().addPredicate("colSums", h ->
+ h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.SUM && ((AggUnaryOp)h).getDirection() == Direction.Col)
+ .addChildMatcher(child1);
+ }
+ public static HopDagPatternMatcher colMeans(HopDagPatternMatcher child1) {
+ return new HopDagPatternMatcher().addPredicate("colSums", h ->
+ h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.MEAN && ((AggUnaryOp)h).getDirection() == Direction.Col)
+ .addChildMatcher(child1);
+ }
+ public static HopDagPatternMatcher matrix(HopDagPatternMatcher X, HopDagPatternMatcher rows, HopDagPatternMatcher cols) {
+ return new HopDagPatternMatcher().addPredicate("matrix_reshape", h -> HopRewriteUtils.isReorg(h, ReOrgOp.RESHAPE))
+ .addChildMatcher(X, rows, cols);
+ }
+ public static HopDagPatternMatcher matrix(double X, HopDagPatternMatcher rows, HopDagPatternMatcher cols) {
+ return new HopDagPatternMatcher().addPredicate("matrix_datagen", h -> HopRewriteUtils.isDataGenOpWithConstantValue(h, X))
+ .addChildMatcher(rows, cols);
+ }
+ public static HopDagPatternMatcher matrix(double X, HopDagPatternMatcher rows, long cols) {
+ return new HopDagPatternMatcher().addPredicate("matrix_datagen", h -> HopRewriteUtils.isDataGenOpWithConstantValue(h, X) &&
+ h.getDim2() == cols)
+ .addChildMatcher(rows, dummy);
+ }
+ public static HopDagPatternMatcher matrix(double X, long rows, HopDagPatternMatcher cols) {
+ return new HopDagPatternMatcher().addPredicate("matrix_datagen", h -> HopRewriteUtils.isDataGenOpWithConstantValue(h, X) &&
+ h.getDim1() == rows)
+ .addChildMatcher(dummy, cols);
+ }
+ public static HopDagPatternMatcher matrix(double X, long rows, long cols) {
+ return new HopDagPatternMatcher().addPredicate("matrix_datagen", h -> HopRewriteUtils.isDataGenOpWithConstantValue(h, X) &&
+ h.getDim1() == rows && h.getDim2() == cols)
+ .addChildMatcher(dummy, dummy);
+ }
+ public static HopDagPatternMatcher bias_add(HopDagPatternMatcher child1, HopDagPatternMatcher child2) {
+ return new HopDagPatternMatcher().addPredicate("bias_add", h -> HopRewriteUtils.isDnn(h, OpOpDnn.BIASADD))
+ .addChildMatcher(child1, child2);
+ }
+ public static HopDagPatternMatcher bias_multiply(HopDagPatternMatcher child1, HopDagPatternMatcher child2) {
+ return new HopDagPatternMatcher().addPredicate("bias_multiply", h -> HopRewriteUtils.isDnn(h, OpOpDnn.BIASMULT))
+ .addChildMatcher(child1, child2);
+ }
+ public static HopDagPatternMatcher unaryMinus(HopDagPatternMatcher child) {
+ return new HopDagPatternMatcher().addPredicate("unaryMinus", h -> HopRewriteUtils.isBinary(h, OpOp2.MINUS)
+ && HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), 0))
+ .addChildMatcher(dummy, child);
+ }
+ public static HopDagPatternMatcher sqrt(HopDagPatternMatcher child) {
+ return new HopDagPatternMatcher().addPredicate("sqrt", h -> HopRewriteUtils.isUnary(h, OpOp1.SQRT))
+ .addChildMatcher(child);
+ }
+ public static HopDagPatternMatcher div(HopDagPatternMatcher child1, HopDagPatternMatcher child2) {
+ return new HopDagPatternMatcher().addPredicate("div", h -> HopRewriteUtils.isBinary(h, OpOp2.DIV))
+ .addChildMatcher(child1, child2);
+ }
+ public static HopDagPatternMatcher div(double child1, HopDagPatternMatcher child2) {
+ return new HopDagPatternMatcher().addPredicate("div", h -> HopRewriteUtils.isBinary(h, OpOp2.DIV) &&
+ HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), child1))
+ .addChildMatcher(dummy, child2);
+ }
+ public static HopDagPatternMatcher div(HopDagPatternMatcher child1, double child2) {
+ return new HopDagPatternMatcher().addPredicate("div", h -> HopRewriteUtils.isBinary(h, OpOp2.DIV) &&
+ HopRewriteUtils.isLiteralOfValue(h.getInput().get(1), child2))
+ .addChildMatcher(child1, dummy);
+ }
+
+ public static HopDagPatternMatcher pow(HopDagPatternMatcher child1, HopDagPatternMatcher child2) {
+ return new HopDagPatternMatcher().addPredicate("pow", h -> HopRewriteUtils.isBinary(h, OpOp2.POW))
+ .addChildMatcher(child1, child2);
+ }
+ public static HopDagPatternMatcher pow(HopDagPatternMatcher child1, double child2) {
+ return new HopDagPatternMatcher().addPredicate("pow", h -> HopRewriteUtils.isBinary(h, OpOp2.POW) &&
+ HopRewriteUtils.isLiteralOfValue(h.getInput().get(1), child2))
+ .addChildMatcher(child1, dummy);
+ }
+ private static boolean matchDimensions(Hop h1, Hop h2) {
+ return h1.getDim1() == h2.getDim1() && h1.getDim2() == h2.getDim2();
+ }
+ // This is used to differentiate between matrix-matrix and matrix-vector operations.
+ public static HopDagPatternMatcher mm_plus(HopDagPatternMatcher child1, HopDagPatternMatcher child2) {
+ return new HopDagPatternMatcher().addPredicate("plus", h -> HopRewriteUtils.isBinary(h, OpOp2.PLUS)
+ && matchDimensions(h.getInput().get(0), h.getInput().get(1)))
+ .addChildMatcher(child1, child2);
+ }
+ public static HopDagPatternMatcher plus(HopDagPatternMatcher child1, HopDagPatternMatcher child2) {
+ return new HopDagPatternMatcher().addPredicate("plus", h -> HopRewriteUtils.isBinary(h, OpOp2.PLUS))
+ .addChildMatcher(child1, child2);
+ }
+ public static HopDagPatternMatcher plus(double child1, HopDagPatternMatcher child2) {
+ return new HopDagPatternMatcher().addPredicate("plus", h -> HopRewriteUtils.isBinary(h, OpOp2.PLUS) &&
+ HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), child1))
+ .addChildMatcher(dummy, child2);
+ }
+ public static HopDagPatternMatcher plus(HopDagPatternMatcher child1, double child2) {
+ return new HopDagPatternMatcher().addPredicate("plus", h -> HopRewriteUtils.isBinary(h, OpOp2.PLUS) &&
+ HopRewriteUtils.isLiteralOfValue(h.getInput().get(1), child2))
+ .addChildMatcher(child1, dummy);
+ }
+ public static HopDagPatternMatcher minus(HopDagPatternMatcher child1, HopDagPatternMatcher child2) {
+ return new HopDagPatternMatcher().addPredicate("minus", h -> HopRewriteUtils.isBinary(h, OpOp2.MINUS))
+ .addChildMatcher(child1, child2);
+ }
+ public static HopDagPatternMatcher minus(double child1, HopDagPatternMatcher child2) {
+ return new HopDagPatternMatcher().addPredicate("minus", h -> HopRewriteUtils.isBinary(h, OpOp2.MINUS) &&
+ HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), child1))
+ .addChildMatcher(dummy, child2);
+ }
+ public static HopDagPatternMatcher minus(HopDagPatternMatcher child1, double child2) {
+ return new HopDagPatternMatcher().addPredicate("minus", h -> HopRewriteUtils.isBinary(h, OpOp2.MINUS) &&
+ HopRewriteUtils.isLiteralOfValue(h.getInput().get(1), child2))
+ .addChildMatcher(child1, dummy);
+ }
+ public static HopDagPatternMatcher mult(HopDagPatternMatcher child1, HopDagPatternMatcher child2) {
+ return new HopDagPatternMatcher().addPredicate("mult", h -> HopRewriteUtils.isBinary(h, OpOp2.MULT))
+ .addChildMatcher(child1, child2);
+ }
+ public static HopDagPatternMatcher mult(double child1, HopDagPatternMatcher child2) {
+ return new HopDagPatternMatcher().addPredicate("mult", h -> HopRewriteUtils.isBinary(h, OpOp2.MULT) &&
+ HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), child1))
+ .addChildMatcher(dummy, child2);
+ }
+ public static HopDagPatternMatcher mult(HopDagPatternMatcher child1, double child2) {
+ return new HopDagPatternMatcher().addPredicate("mult", h -> HopRewriteUtils.isBinary(h, OpOp2.MULT) &&
+ HopRewriteUtils.isLiteralOfValue(h.getInput().get(1), child2))
+ .addChildMatcher(child1, dummy);
+ }
+ private static boolean _fitsOnGPU(Hop h, double multiplier) {
+ double memEst = multiplier*h.getMemEstimate();
+ return ConfigurationManager.isGPU() && h.dimsKnown() && OptimizerUtils.isMemoryBasedOptLevel() &&
+ memEst < OptimizerUtils.getLocalMemBudget() && memEst < GPUContextPool.initialGPUMemBudget();
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/rewrite/HopPatternRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopPatternRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/HopPatternRewriter.java
new file mode 100644
index 0000000..02472ed
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopPatternRewriter.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.hops.rewrite;
+
+import java.util.function.Function;
+
+import org.apache.sysml.hops.Hop;
+
+/**
+ * This class is used with HopRewriteRuleWithPatternMatcher to implement the following pattern matching logic:
+ * ArrayList<HopPatternRewriter> patternRewriters = getPatternRewriter();
+ * for(HopPatternRewriter patternRewriter : patternRewriters) {
+ * hi = patternRewriter.rewrite(hi);
+ * }
+ *
+ * Please see org.apache.sysml.hops.rewrite.RewriteGPUSpecificOps class for usage and design documentation.
+ */
+public class HopPatternRewriter {
+ private final HopDagPatternMatcher _matcher;
+ private final Function<Hop, Hop> _replacer;
+ private final String _name;
+ public HopPatternRewriter(String name, HopDagPatternMatcher matcher, Function<Hop, Hop> replacer) {
+ _name = name;
+ _matcher = matcher;
+ _replacer = replacer;
+ }
+
+ public Hop rewrite(Hop root) {
+ boolean printMessage = HopDagPatternMatcher.DEBUG_PATTERNS != null && HopDagPatternMatcher.DEBUG_PATTERNS.contains(_name);
+ if(printMessage) {
+ HopDagPatternMatcher.DEBUG_REWRITES = true;
+ System.out.println("-----------------------------------");
+ System.out.println(org.apache.sysml.utils.Explain.explain(root));
+ }
+ if(_matcher.matches(root)) {
+ Hop newHop = _replacer.apply(root);
+ if(printMessage) {
+ if(newHop == root)
+ System.out.println("Initial pattern match for " + _name + " succeeded but replacer returned the same HOP.");
+ else
+ System.out.println("Pattern match for " + _name + " succeeded.");
+ HopDagPatternMatcher.DEBUG_REWRITES = false;
+ System.out.println("-----------------------------------");
+ }
+ return newHop;
+ }
+ else {
+ if(printMessage) {
+ System.out.println("Pattern match for " + _name + " failed.");
+ HopDagPatternMatcher.DEBUG_REWRITES = false;
+ System.out.println("-----------------------------------");
+ }
+ return root;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteRuleWithPatternMatcher.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteRuleWithPatternMatcher.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteRuleWithPatternMatcher.java
new file mode 100644
index 0000000..854eca3
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteRuleWithPatternMatcher.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.hops.rewrite;
+
+import java.util.ArrayList;
+
+import org.apache.sysml.hops.Hop;
+
+/**
+ * Simple utility class that implements generic structure for HopRewriteRule.
+ * Please see org.apache.sysml.hops.rewrite.RewriteGPUSpecificOps class for usage and design documentation.
+ */
+public abstract class HopRewriteRuleWithPatternMatcher extends HopRewriteRule {
+
+ /**
+ * Returns the pattern rewriters of the concrete rule; each rewriter is
+ * applied in order to every HOP visited by this rule.
+ *
+ * @return list of pattern rewriters to apply
+ */
+ public abstract ArrayList<HopPatternRewriter> getPatternRewriter();
+
+ @Override
+ public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
+ if( roots == null )
+ return roots;
+
+ //one pass rewrite-descend (rewrite created pattern)
+ for( int i = 0; i < roots.size(); i++ )
+ applyRules(roots, roots.get(i), false );
+ Hop.resetVisitStatus(roots, true);
+
+ //one pass descend-rewrite (for rollup)
+ for( int i = 0; i < roots.size(); i++ )
+ applyRules(roots, roots.get(i), true );
+ Hop.resetVisitStatus(roots, true);
+
+ return roots;
+ }
+
+ @Override
+ public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
+ if( root == null )
+ return root;
+
+ //one pass rewrite-descend (rewrite created pattern)
+ applyRules(null, root, false );
+
+ root.resetVisitStatus();
+
+ //one pass descend-rewrite (for rollup)
+ applyRules(null, root, true );
+
+ return root;
+ }
+
+ /**
+ * Apply rules: recursively visits the DAG once (guarded by the HOP visit
+ * status) and runs every pattern rewriter on each child HOP.
+ *
+ * @param roots root operators
+ * @param hop high-level operator
+ * @param descendFirst true if recursively process children first
+ */
+ private void applyRules(ArrayList<Hop> roots, Hop hop, boolean descendFirst)
+ {
+ if(hop.isVisited())
+ return;
+
+ //recursively process children
+ for( int i=0; i<hop.getInput().size(); i++) {
+ Hop hi = hop.getInput().get(i);
+
+ //process childs recursively first (to allow roll-up)
+ if( descendFirst )
+ applyRules(roots, hi, descendFirst); //see below
+
+ // Each rewriter either returns hi unchanged or a replacement HOP;
+ // chaining through hi lets later rewriters see earlier replacements.
+ ArrayList<HopPatternRewriter> patternRewriters = getPatternRewriter();
+ for(HopPatternRewriter patternRewriter : patternRewriters) {
+ hi = patternRewriter.rewrite(hi);
+ }
+
+ if( !descendFirst )
+ applyRules(roots, hi, descendFirst);
+ }
+
+ hop.setVisited();
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 271142d..2351f5f 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -719,6 +719,26 @@ public class HopRewriteUtils
return ternOp;
}
+ /**
+ * Creates a DnnOp of the given operation type over the given input HOPs.
+ * The new HOP is named "tmp" with matrix data type and double value type;
+ * the caller is responsible for wiring it into the DAG
+ * (e.g. via rewireAllParentChildReferences).
+ *
+ * @param op DNN operation type
+ * @param hops input HOPs in operator-specific order
+ * @return the newly created DnnOp
+ */
+ public static DnnOp createDnnOp(OpOpDnn op, Hop... hops) {
+ ArrayList<Hop> inHops = new ArrayList<Hop>();
+ for(Hop h : hops) {
+ inHops.add(h);
+ }
+ return new DnnOp("tmp", DataType.MATRIX, ValueType.DOUBLE,
+ op, inHops);
+ }
+
+ /**
+ * Creates a DnnOp of the given operation type whose inputs are the HOPs that
+ * the given pattern matcher bound to the given leaf variable names. Intended
+ * to be called from a replacer after a successful pattern match.
+ *
+ * @param matcher pattern matcher holding the matched HOPs
+ * @param op DNN operation type
+ * @param varNames leaf variable names whose matched HOPs become the inputs, in order
+ * @return the newly created DnnOp
+ */
+ public static DnnOp createDnnOp(HopDagPatternMatcher matcher, OpOpDnn op, String... varNames) {
+ ArrayList<Hop> inHops = new ArrayList<Hop>();
+ for(String v : varNames) {
+ inHops.add(matcher.getMatchedHop(v));
+ }
+ return new DnnOp("tmp", DataType.MATRIX, ValueType.DOUBLE,
+ op, inHops);
+ }
+
+
+
public static void setOutputParameters( Hop hop, long rlen, long clen, int brlen, int bclen, long nnz ) {
hop.setDim1( rlen );
hop.setDim2( clen );
[2/3] systemml git commit: [SYSTEMML-445] Removed batch_norm builtin
functions
Posted by ni...@apache.org.
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
index acf2e48..53d368b 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
@@ -20,811 +20,288 @@
package org.apache.sysml.hops.rewrite;
import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
+import java.util.function.Function;
-import org.apache.sysml.conf.ConfigurationManager;
-import org.apache.sysml.hops.AggUnaryOp;
-import org.apache.sysml.hops.BinaryOp;
-import org.apache.sysml.hops.FunctionOp;
import org.apache.sysml.hops.Hop;
-import org.apache.sysml.hops.FunctionOp.FunctionType;
-import org.apache.sysml.hops.Hop.AggOp;
-import org.apache.sysml.hops.Hop.DataOpTypes;
-import org.apache.sysml.hops.Hop.Direction;
-import org.apache.sysml.hops.Hop.OpOp1;
-import org.apache.sysml.hops.Hop.OpOp2;
import org.apache.sysml.hops.Hop.OpOpDnn;
-import org.apache.sysml.hops.Hop.ReOrgOp;
-import org.apache.sysml.hops.DataOp;
-import org.apache.sysml.hops.LiteralOp;
-import org.apache.sysml.hops.DnnOp;
-import org.apache.sysml.hops.OptimizerUtils;
-import org.apache.sysml.hops.ReorgOp;
-import org.apache.sysml.hops.UnaryOp;
-import org.apache.sysml.parser.DMLProgram;
-import org.apache.sysml.parser.Expression.DataType;
-import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
+import static org.apache.sysml.hops.rewrite.HopDagPatternMatcher.*;
+import static org.apache.sysml.parser.Expression.DataType.MATRIX;
+import static org.apache.sysml.parser.Expression.DataType.SCALAR;
+
/*
- * This class contains GPU-specific rewrites for following patterns:
+ * -------------------------------------------------------------------------
+ * Design documentation for hop rewrite rules that use HopDagPatternMatcher:
+ * -------------------------------------------------------------------------
+ *
+ * Typical (but not all) hop rewrite rules have following structure:
+ * 1. Rules are grouped together in different Java classes and added in org.apache.sysml.hops.rewrite.ProgramRewriter.
+ *
+ * 2. Each rule class inherits from HopRewriteRule and implements the rewriteHopDAG method. The other class of rewrite rules, StatementBlockRewriteRule, is not covered by this approach.
+ *
+ * 3. The structure of rewriteHopDAG is common across HopRewriteRule subclasses and usually have following pattern:
+ * if(root of the given HOP DAG matches certain pattern) {
+ * HopRewriteUtils.rewireAllParentChildReferences(root, newRoot)
+ * }
+ * else root
+ *
+ * 4. To avoid redundancy, the above logic is implemented in the abstract class HopRewriteRuleWithPatternMatcher:
+ * ArrayList<HopPatternRewriter> patternRewriters = getPatternRewriter();
+ * for(HopPatternRewriter patternRewriter : patternRewriters) {
+ * hi = patternRewriter.rewrite(hi);
+ * }
+ *
+ * 5. The developer has to inherit from HopRewriteRuleWithPatternMatcher that implements the above logic
+ * and write code for getPatternRewriter() that returns ArrayList<HopPatternRewriter>
*
- * 1. batchNormTest: applied when mode="test" in batch normalization nn layer.
- * norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps))
- * hi = bias_add(bias_multiply(norm, gamma), beta)
+ * 6. Since the HOP patterns do not change during execution, it is convenient to store them in a static variable:
+ * ArrayList<HopPatternRewriter> _rewriters
*
- * 2. channelSum:
- * output = rowSums(matrix(colSums(x), rows=numChannels, cols=imgSize*imgSize))
+ * 7. The replacement part in each entry of patternMatcher invokes the helper methods in HopRewriteUtils to create a newRoot. For example: HopRewriteUtils.createDnnOp
*
- * 3. batchNormTrain: applied when mode="train" in batch normalization nn layer.
- * This rewrite is only enabled if none of the outputs are persistent writes as it assumes that
- * FunctionOp will introduce a transient writes. This rewrite replaces the existing outputs of the matched pattern with transient reads.
+ * 8. The below DSL would be more readable if implemented with Scala's operator overloading, but that adds a dependency on the Scala library
+ * (specifically, Scala uses scala.Function1 to implement operator overloading).
+ * Hence, to minimize the dependency, the DSL is implemented using static methods in HopDagPatternMatcher class.
+ * We can revisit this if we plan to add scala as hard dependency in SystemML.
+ *
+ * 9. The matcher part in each entry of patternMatcher uses the DSL implemented in HopDagPatternMatcher to improve readability.
+ * - The DSL mentioned above follows DML syntax, which makes it convenient for an external contributor to understand and modify the HOP rewrites.
+ * - It is important to note that the developer has to add the same scoping rules as SystemML.
+ * - To create a newRoot HOP, it is important to have a mechanism to extract leaves of the matched pattern. This is implemented
+ * by using leaf() method.
+ * - Often, it is important to create a new HOP only if it can fit into memory. For GPU, one can use the fitsOnGPU(multiplier) helper method.
*
*/
-public class RewriteGPUSpecificOps extends HopRewriteRule {
-
- private static int _seq = 1;
-
- @Override
- public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
- if( roots == null )
- return roots;
-
- //one pass rewrite-descend (rewrite created pattern)
- for( int i = 0; i < roots.size(); i++ )
- rule_GPUKernels(roots, roots.get(i), false );
- Hop.resetVisitStatus(roots, true);
-
- //one pass descend-rewrite (for rollup)
- for( int i = 0; i < roots.size(); i++ )
- rule_GPUKernels(roots, roots.get(i), true );
- Hop.resetVisitStatus(roots, true);
+public class RewriteGPUSpecificOps extends HopRewriteRuleWithPatternMatcher {
+ // -------------------------------------------------------------------------------------------
+
+ /**
+ * Builds the pattern for the DML util::channel_sums helper:
+ * rowSums(matrix(colSums(X), rows=C, cols=HW)).
+ *
+ * @param X pattern for the input matrix
+ * @param C pattern for the number of channels (scalar)
+ * @param HW pattern for the spatial size Hin*Win (scalar)
+ * @return pattern matching the channel-sums expression
+ */
+ private static HopDagPatternMatcher util_channel_sums(HopDagPatternMatcher X, HopDagPatternMatcher C, HopDagPatternMatcher HW) {
+ // rowSums(matrix(colSums(X), rows=C, cols=HW))
+ return rowSums(matrix( colSums(X), C, HW));
+ }
+
+ // Pattern 1:
+ private static final HopDagPatternMatcher _batchNormdX;
+ static {
+ HopDagPatternMatcher C = leaf("C", SCALAR);
+ HopDagPatternMatcher HW = leaf("HW", SCALAR);
+ HopDagPatternMatcher CHW = leaf("CHW", SCALAR);
+ HopDagPatternMatcher cache_inv_var = leaf("cache_inv_var", MATRIX);
+ HopDagPatternMatcher dout = leaf("dout", MATRIX);
+ HopDagPatternMatcher gamma = leaf("gamma", MATRIX);
+ HopDagPatternMatcher X = leaf("X", MATRIX);
+ HopDagPatternMatcher mean = leaf("mean", MATRIX);
- return roots;
- }
-
- @Override
- public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
- if( root == null )
- return root;
-
- //one pass rewrite-descend (rewrite created pattern)
- rule_GPUKernels(null, root, false );
-
- root.resetVisitStatus();
-
- //one pass descend-rewrite (for rollup)
- rule_GPUKernels(null, root, true );
+ HopDagPatternMatcher centered = bias_add(X, unaryMinus(mean));
- return root;
- }
-
- /**
- * Fuse the kernel
- *
- * @param roots root operators
- * @param hop high-level operator
- * @param descendFirst true if recursively process children first
- */
- private void rule_GPUKernels(ArrayList<Hop> roots, Hop hop, boolean descendFirst)
- {
- if(hop.isVisited())
- return;
-
- //recursively process children
- for( int i=0; i<hop.getInput().size(); i++) {
- Hop hi = hop.getInput().get(i);
-
- //process childs recursively first (to allow roll-up)
- if( descendFirst )
- rule_GPUKernels(roots, hi, descendFirst); //see below
-
- if(roots != null) {
- //hi = batchNormTrain(roots, hop, hi, i);
- }
- hi = batchNormTest(hop, hi, i);
- hi = channelSums(hop, hi, i);
- hi = updateNesterovX(hop, hi, i);
-
- if( !descendFirst )
- rule_GPUKernels(roots, hi, descendFirst);
- }
-
- hop.setVisited();
- }
-
- private static boolean isBiasAdd(Hop h) {
- return HopRewriteUtils.isDnn(h, OpOpDnn.BIASADD);
- }
-
- private static boolean isBiasMultiply(Hop h) {
- return HopRewriteUtils.isDnn(h, OpOpDnn.BIASMULT);
- }
-
- private static boolean fitsOnGPU(Hop h, double multiplier) {
- double memEst = multiplier*h.getMemEstimate();
- return ConfigurationManager.isGPU() && h.dimsKnown() && OptimizerUtils.isMemoryBasedOptLevel() &&
- memEst < OptimizerUtils.getLocalMemBudget() && memEst < GPUContextPool.initialGPUMemBudget();
- }
-
- private static boolean fitsOnGPU(ArrayList<Hop> inputHops, boolean isFirstSameSizeAsOutput) {
- return fitsOnGPU(inputHops, isFirstSameSizeAsOutput, 0);
- }
-
- private static boolean fitsOnGPU(ArrayList<Hop> inputHops, boolean isFirstSameSizeAsOutput, long additionalBytes) {
- double memEst = additionalBytes;
- boolean isFirst = true;
- for(Hop h : inputHops) {
- double est = h.getMemEstimate();
- if(est == OptimizerUtils.INVALID_SIZE) {
- return false;
- }
- else if(isFirst && isFirstSameSizeAsOutput) {
- isFirst = false;
- memEst += 2*est;
- }
- else {
- memEst += est;
- }
- }
- return ConfigurationManager.isGPU() && OptimizerUtils.isMemoryBasedOptLevel() &&
- memEst < OptimizerUtils.getLocalMemBudget() && memEst < GPUContextPool.initialGPUMemBudget();
- }
-
- private static boolean hasFirstInput(Hop h) {
- return !(h == null || h.getInput() == null || h.getInput().size() < 1);
- }
-
- private static Hop getFirstInput(Hop h) {
- if(h == null || h.getInput() == null || h.getInput().size() < 1) {
- throw new RuntimeException("No input available for " + h);
- }
- return h.getInput().get(0);
- }
-
- private static boolean hasSecondInput(Hop h) {
- return !(h == null || h.getInput() == null || h.getInput().size() < 2);
- }
-
- private static Hop getSecondInput(Hop h) {
- if(h == null || h.getInput() == null || h.getInput().size() < 2) {
- throw new RuntimeException("Expected atleast two inputs for " + h);
+ // dnorm = bias_multiply(dout, gamma) # shape (N, C*Hin*Win)
+ HopDagPatternMatcher dnorm = bias_multiply(dout, gamma);
+ // dmean_norm_branch = util::channel_sums(bias_multiply(dnorm, -cache_inv_var), C, Hin, Win)
+ HopDagPatternMatcher dmean_norm_branch = util_channel_sums(bias_multiply(dnorm, unaryMinus(cache_inv_var)), C, HW) ;
+ // dvar = util::channel_sums((-1/2) * bias_multiply(centered, cache_inv_var^3) * dnorm,
+ // C, Hin, Win) # shape (C, 1)
+ HopDagPatternMatcher dvar = util_channel_sums(mult(mult(-0.5, bias_multiply(centered, pow(cache_inv_var, 3))), dnorm), C, HW);
+ // dmean_var_branch = util::channel_sums((-2*oneByN*oneByHW) * centered, C, Hin, Win) * dvar
+ HopDagPatternMatcher dmean_var_branch =
+ mult(util_channel_sums(mult(leaf("const3", SCALAR), centered), C, HW), dvar);
+ // dX_norm_branch = bias_multiply(dnorm, cache_inv_var)
+ HopDagPatternMatcher dX_norm_branch = bias_multiply(dnorm, cache_inv_var);
+ // dX_mean_branch = (oneByN*oneByHW) * bias_add(matrix(0, rows=1, cols=C*Hin*Win), dmean)
+ HopDagPatternMatcher dX_mean_branch = mult(leaf("const1", SCALAR), bias_add(matrix(0, 1, CHW),
+ plus(dmean_norm_branch, dmean_var_branch) ));
+ // dX_var_branch = (2*oneByN*oneByHW) * bias_multiply(centered, dvar)
+ HopDagPatternMatcher dX_var_branch = mult(leaf("const2", SCALAR), bias_multiply(centered, dvar));
+ _batchNormdX = plus(plus(dX_norm_branch, dX_mean_branch), dX_var_branch).fitsOnGPU(2);
+ }
+ private static final Function<Hop, Hop> _batchNormdXReplacer = hi -> {
+ // double CHW = _batchNormdX.getLiteralValue("CHW");
+ double HW = _batchNormdX.getLiteralValue("HW");
+ double C = _batchNormdX.getLiteralValue("C");
+ double const1 = _batchNormdX.getLiteralValue("const1"); // (oneByN*oneByHW)
+ double const2 = _batchNormdX.getLiteralValue("const2"); // (2*oneByN*oneByHW)
+ double const3 = _batchNormdX.getLiteralValue("const3"); // (-2*oneByN*oneByHW)
+ if(2*const1 == const2 && const3 == -const2 &&
+ hasSameDimensions(_batchNormdX.getMatchedHop("gamma"), _batchNormdX.getMatchedHop("mean")) &&
+ hasSameDimensions(_batchNormdX.getMatchedHop("gamma"), _batchNormdX.getMatchedHop("cache_inv_var")) &&
+ _batchNormdX.getMatchedHop("X").getDim2() == C*HW &&
+ checkDimensions(_batchNormdX.getMatchedHop("gamma"), (long)C, 1)) {
+ LOG.debug("Applied batchNormdX rewrite.");
+ Hop newHop = HopRewriteUtils.createDnnOp(_batchNormdX, OpOpDnn.BATCH_NORM2D_BACKWARD_DX,
+ "X", "dout", "gamma", "mean", "cache_inv_var");
+ return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
}
- return h.getInput().get(1);
- }
-
- private static Hop getThirdInput(Hop h) {
- if(h == null || h.getInput() == null || h.getInput().size() < 3) {
- throw new RuntimeException("Expected atleast three inputs for " + h);
- }
- return h.getInput().get(2);
- }
-
- private static boolean isUnaryMinus(Hop h) {
- return HopRewriteUtils.isBinary(h, OpOp2.MINUS)
- && HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), 0);
- }
-
- private static boolean isOneDivideBySqrt(Hop h) {
- return HopRewriteUtils.isBinary(h, OpOp2.DIV)
- && HopRewriteUtils.isUnary(h.getInput().get(1), OpOp1.SQRT)
- && HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), 1);
- }
-
- private static Hop channelSums(Hop parent, Hop hi, int pos) {
- if(hi instanceof AggUnaryOp) {
- AggUnaryOp hop = (AggUnaryOp) hi;
- // output = rowSums(matrix(colSums(x), rows=numChannels, cols=imgSize*imgSize))
- if( hop.getOp() == AggOp.SUM && hop.getDirection() == Direction.Row
- && HopRewriteUtils.isReorg(hop.getInput().get(0), ReOrgOp.RESHAPE) ) {
- Hop colSumsInput = hop.getInput().get(0).getInput().get(0);
- if(colSumsInput instanceof AggUnaryOp && ((AggUnaryOp)colSumsInput).getOp() == AggOp.SUM && ((AggUnaryOp)colSumsInput).getDirection() == Direction.Col) {
- ArrayList<Hop> inHops = new ArrayList<Hop>();
- inHops.add(colSumsInput.getInput().get(0));
- long numChannels = Hop.computeSizeInformation(hop.getInput().get(0).getInput().get(1));
- long HW = Hop.computeSizeInformation(hop.getInput().get(0).getInput().get(2));
- if(numChannels > 0 && HW > 0 && fitsOnGPU(inHops, false, numChannels*8)) {
- inHops.add(new LiteralOp(numChannels));
- inHops.add(new LiteralOp(HW));
- LOG.debug("Applied channelSums rewrite.");
- Hop newHop = new DnnOp(hi.getName(), hi.getDataType(), hi.getValueType(),
- OpOpDnn.CHANNEL_SUMS, inHops);
- return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
- }
- }
- }
+ else if(DEBUG_REWRITES) {
+ System.out.println("Couldnot apply batchNormdX rewrite.");
+ System.out.println((2*const1) + " == " + const2 + " && " + const3 + "== -" + const2
+ + " && " + hasSameDimensions(_batchNormdX.getMatchedHop("gamma"), _batchNormdX.getMatchedHop("mean")) + " && " +
+ hasSameDimensions(_batchNormdX.getMatchedHop("gamma"), _batchNormdX.getMatchedHop("cache_inv_var")) + " && " +
+ _batchNormdX.getMatchedHop("X").getDim2() + " == " + C + "*" + HW + " && " +
+ checkDimensions(_batchNormdX.getMatchedHop("gamma"), (long)C, 1));
}
return hi;
- }
-
- private static boolean isRowMeans(Hop h) {
- return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.MEAN && ((AggUnaryOp)h).getDirection() == Direction.Row;
- }
-
- private static boolean isRowVars(Hop h) {
- return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.VAR && ((AggUnaryOp)h).getDirection() == Direction.Row;
- }
-
- private static boolean isRowVars(Hop h, Hop childHop) {
- return isRowVars(h) && getFirstInput(h) == childHop;
- }
-
- private static boolean isColMeans(Hop h) {
- return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.MEAN && ((AggUnaryOp)h).getDirection() == Direction.Col;
- }
-
- private static boolean isColVars(Hop h) {
- return h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.VAR && ((AggUnaryOp)h).getDirection() == Direction.Col;
- }
-
- private static boolean isReshape(Hop h) {
- return h instanceof ReorgOp && ((ReorgOp)h).getOp() == ReOrgOp.RESHAPE;
- }
-
- private static boolean isReshape(Hop h, long expectedRows, long expectedCols) {
- return h instanceof ReorgOp && ((ReorgOp)h).getOp() == ReOrgOp.RESHAPE &&
- Hop.computeSizeInformation(getSecondInput(h)) == expectedRows &&
- Hop.computeSizeInformation(getThirdInput(h)) == expectedCols;
- }
-
- private static boolean isBinaryAdd(Hop h) {
- return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.PLUS;
- }
-
- private static boolean isBinaryMSAdd(Hop h, double expectedValue) {
- return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.PLUS
- && getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.SCALAR
- && OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(h), new HashMap<>()) == expectedValue;
- }
-
- private static boolean isBinaryMMAdd(Hop h) {
- return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.PLUS
- && getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.MATRIX;
- }
-
- private static boolean isBinaryMMMinus(Hop h) {
- return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MINUS
- && getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.MATRIX;
- }
-
- private static boolean isBinaryMSMult(Hop h, double expectedValue) {
- return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MULT
- && getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.SCALAR
- && OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(h), new HashMap<>()) == expectedValue;
- }
-
- private static boolean isBinarySSMinus(Hop h) {
- return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MINUS
- && getFirstInput(h).getDataType() == DataType.SCALAR && getSecondInput(h).getDataType() == DataType.SCALAR;
- }
-
- private static boolean isBinarySSDiv(Hop h) {
- return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.DIV
- && getFirstInput(h).getDataType() == DataType.SCALAR && getSecondInput(h).getDataType() == DataType.SCALAR;
- }
-
- private static boolean isBinarySMDiv(Hop h, double expectedValue) {
- return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.DIV
- && getFirstInput(h).getDataType() == DataType.SCALAR && getSecondInput(h).getDataType() == DataType.MATRIX
- && OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(h), new HashMap<>()) == expectedValue;
- }
-
- private static boolean isAnyBinaryAdd(ArrayList<Hop> hops) {
- if(hops != null) {
- for(Hop h : hops) {
- if(h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.PLUS)
- return true;
- }
- }
- return false;
- }
-
- private static boolean isBinaryMSMult(Hop h) {
- return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MULT
- && getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.SCALAR;
- }
-
- private static boolean isBinarySMMult(Hop h) {
- return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MULT
- && getSecondInput(h).getDataType() == DataType.MATRIX && getFirstInput(h).getDataType() == DataType.SCALAR;
- }
-
- private static boolean isBinarySMMult(Hop h, double expectedVal) {
- return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MULT
- && getSecondInput(h).getDataType() == DataType.MATRIX && getFirstInput(h).getDataType() == DataType.SCALAR
- && getValue(getFirstInput(h)) == expectedVal;
- }
-
- private static double getValue(Hop h) {
- return OptimizerUtils.rEvalSimpleDoubleExpression(h, new HashMap<>());
- }
-
- /**
- * Checks if the "mean" hop is a moving average of mean in batch normalization layer.
- *
- * @param mean hop to check against
- * @param X input data
- * @return true if the "mean" hop is a moving average of mean in batch normalization layer.
- */
- private static boolean isBatchNormTrainMean(Hop mean, Hop X) {
- // subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win)
- // mean = rowMeans(subgrp_means)
- return isRowMeans(mean) && isReshape(getFirstInput(mean)) && isColMeans(getFirstInput(getFirstInput(mean)))
- && getFirstInput(getFirstInput(getFirstInput(mean))) == X;
- }
-
- /**
- * Checks for nrow(X) pattern
- *
- * @param expr hop to be matched
- * @param X input X
- * @return true if expr is nrow(X) else false
- */
- private static boolean isNrowOfX(Hop expr, Hop X) {
- return expr instanceof UnaryOp && ((UnaryOp)expr).getOp() == OpOp1.NROW && getFirstInput(expr) == X;
- }
-
- /**
- * Checks for the colVars(X) * ((N-1)/N) pattern
- *
- * @param expr hop to be matched
- * @param X input X
- * @param ignoreCorrectionTerm whether to ignore the correction term ((N-1)/N).
- * @return true if expr is colVars(X) * ((N-1)/N) else false
- */
- private static boolean isCorrectedColVars(Hop expr, Hop X, boolean ignoreCorrectionTerm) {
- // colVars(X) * ((N-1)/N)
- if(isColVars(expr) && getFirstInput(expr) == X) {
- // Support no correction as well in this rewrite
- return true;
- }
- else if(X.rowsKnown()) {
- return isBinaryMSMult(expr, ((double)X.getDim1()-1)/X.getDim1()) &&
- isColVars(getFirstInput(expr)) && getFirstInput(getFirstInput(expr)) == X;
+ };
+
+
+
+ // Pattern 2:
+ // subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win)
+ // var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))
+ // Pattern for the batch-norm updated variance (DML shown in the comment above);
+ // leaves: X, C, HW, subgrp_means, varConst1 (expected (N-1)/N), varConst2
+ // (expected (HW-1)/HW — validated in the replacer).
+ private static final HopDagPatternMatcher _batchNormUpdatedVar;
+ static {
+ HopDagPatternMatcher subgrp_vars =
+ matrix(
+ mult(colVars(leaf("X", MATRIX).fitsOnGPU(2)), leaf("varConst1", SCALAR)), // colVars(X) * ((N-1)/N)
+ leaf("C", SCALAR), // rows=C
+ leaf("HW", SCALAR)); // cols=Hin*Win
+ _batchNormUpdatedVar =
+ mm_plus(
+ rowMeans(subgrp_vars),
+ mult(rowVars(leaf("subgrp_means", MATRIX)), leaf("varConst2", SCALAR))); // rowVars(subgrp_means)*varConst2
+ }
+ private static final Function<Hop, Hop> _batchNormUpdatedVarReplacer = hi -> {
+ double HW = _batchNormUpdatedVar.getLiteralValue("HW");
+ if(_batchNormUpdatedVar.getLiteralValue("varConst2") == ((HW-1)/HW)) {
+ LOG.debug("Applied batchNormUpdatedVar rewrite.");
+ Hop newHop = HopRewriteUtils.createDnnOp(_batchNormUpdatedVar, OpOpDnn.UPDATE_EMA_VAR,
+ // varConst1 => ((N-1)/N)
+ "subgrp_means", "X", "C", "HW", "varConst1");
+ return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
}
- else if(isBinaryMSMult(expr) &&
- isColVars(getFirstInput(expr)) && getFirstInput(getFirstInput(expr)) == X) {
- if(ignoreCorrectionTerm) {
- return true;
- }
- Hop tmp = getSecondInput(expr);
- // ((N-1)/N)
- boolean isNMinus1Pattern = isBinarySSDiv(tmp) && isBinarySSMinus(getFirstInput(tmp)) &&
- getFirstInput(getFirstInput(tmp)) == getSecondInput(tmp) &&
- OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(getFirstInput(tmp)), new HashMap<>()) == 1;
- boolean ret = isNMinus1Pattern && isNrowOfX(getSecondInput(tmp), X);
- if(LOG.isDebugEnabled()) {
- LOG.debug("Is the corrected column variance pattern for batch_norm_train rewrite when number of rows of X unknown matched:" + ret);
- }
- return ret;
- }
- return false;
- }
-
- /**
- * Checks if the "var" hop is a moving average of variance in batch normalization layer.
- *
- * @param mean previously matched mean hop
- * @param var the hop to check against
- * @param X input data hop
- * @param subgrpMeans mean for subgroup mean
- * @param ignoreCorrectionTerm whether to incore the correct term (see isCorrectedColVars method in this class)
- * @return true if the "var" hop is a moving average of variance in batch normalization layer.
- */
- private static boolean isBatchNormTrainVar(Hop mean, Hop var, Hop X, Hop subgrpMeans, boolean ignoreCorrectionTerm) {
- long numChannels = Hop.computeSizeInformation(getSecondInput(getFirstInput(mean)));
- long HW = Hop.computeSizeInformation(getThirdInput(getFirstInput(mean)));
- // subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win)
- // var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))
- return numChannels > 0 && HW > 0 && isBinaryMMAdd(var) && isRowMeans(getFirstInput(var)) &&
- // matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win)
- isReshape(getFirstInput(getFirstInput(var)), numChannels, HW) &&
- isCorrectedColVars(getFirstInput(getFirstInput(getFirstInput(var))), X, ignoreCorrectionTerm) &&
- // rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))
- isBinaryMSMult(getSecondInput(var), ((((double)HW)-1)/HW)) &&
- isRowVars(getFirstInput(getSecondInput(var)), subgrpMeans);
- }
-
- /**
- * Checks and returns the matched hops for expression ema_mean_upd = mu*ema_mean + (1-mu)*mean
- *
- * @param rhsTimesOps hop representing BinaryOp of expression (1-mu)*mean
- * @param mu value of mu
- * @return an array [ema_mean_upd, ema_mean] if expression matched, else null
- */
- private static Hop [] getUpdatedMovingAverageExpressions(Hop rhsTimesOp, double mu) {
- if(rhsTimesOp == null || rhsTimesOp.getParent() == null || rhsTimesOp.getParent().size() != 1 ||
- !isBinarySMMult(rhsTimesOp) || !isBinaryAdd(rhsTimesOp.getParent().get(0)))
- return null;
-
- // Check (1-mu)*mean
- double expectedOneMinusMu = OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(rhsTimesOp), new HashMap<>());
- Hop plusOp = rhsTimesOp.getParent().get(0);
- Hop lhsTimesOp = null;
- if(plusOp.getInput().get(0) == rhsTimesOp) {
- lhsTimesOp = plusOp.getInput().get(1);
- }
- else {
- lhsTimesOp = plusOp.getInput().get(0);
- }
-
- if(expectedOneMinusMu == (1-mu) && plusOp.getParent() != null && plusOp.getParent().size() == 1 &&
- isBinarySMMult(lhsTimesOp) && OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(lhsTimesOp), new HashMap<>()) == mu) {
- return new Hop[] {
- plusOp.getParent().get(0),
- getSecondInput(lhsTimesOp),
- getSecondInput(rhsTimesOp)
- };
- }
- return null;
- }
-
- /**
- * Checks (if exactly one of rhsTimesOps) and returns the matched hops for expression ema_mean_upd = mu*ema_mean + (1-mu)*mean
- *
- * @param rhsTimesOps array list of hop representing BinaryOp of expression (1-mu)*mean
- * @param mu value of mu
- * @return an array [ema_mean_upd, ema_mean] if any of the expression matched, else null
- */
- private static Hop [] getUpdatedMovingAverageExpressions(ArrayList<Hop> rhsTimesOps, double mu) {
- if(rhsTimesOps == null || rhsTimesOps.size() == 0)
- return null;
-
- Hop [] ret = null;
- for(Hop h : rhsTimesOps) {
- boolean matched = isUpdatedMovingAverageExpression(h, mu);
- if(matched && ret != null) {
- return null; // Multiple matches, cannot decide which one to fuse
- }
- else if(matched) {
- ret = getUpdatedMovingAverageExpressions(h, mu);
- }
- }
-
- return ret;
- }
-
- /**
- * Checks and returns the mu in the expression ema_mean_upd = mu*ema_mean + (1-mu)*mean
- *
- * @param rhsTimesOps hop representing BinaryOp of expression (1-mu)*mean
- * @return value of mu if the expression matched else null
- */
- private static Double getMuFromUpdatedMovingAverageExpressions(ArrayList<Hop> rhsTimesOps) {
- if(rhsTimesOps == null || rhsTimesOps.size() == 0)
- return null;
-
- Double ret = null;
- for(Hop h : rhsTimesOps) {
- boolean matched = isUpdatedMovingAverageExpression(h);
- if(matched && ret != null) {
- return null; // Multiple matches, cannot decide which one to fuse
- }
- else if(matched) {
- ret = -(OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(h), new HashMap<>())-1);
- }
- }
- return ret;
- }
-
- /**
- * Checks for the expression ema_mean_upd = mu*ema_mean + (1-mu)*mean
- *
- * @param rhsTimesOps hop representing BinaryOp of expression (1-mu)*mean
- * @return true if expression matched
- */
- private static boolean isUpdatedMovingAverageExpression(Hop rhsTimesOp) {
- if(rhsTimesOp == null || rhsTimesOp.getParent() == null || rhsTimesOp.getParent().size() != 1 ||
- !isBinarySMMult(rhsTimesOp) || !isBinaryAdd(rhsTimesOp.getParent().get(0)))
- return false;
-
- // Check (1-mu)*mean
- Hop plusOp = rhsTimesOp.getParent().get(0);
- Hop lhsTimesOp = null;
- if(plusOp.getInput().get(0) == rhsTimesOp) {
- lhsTimesOp = plusOp.getInput().get(1);
- }
- else {
- lhsTimesOp = plusOp.getInput().get(0);
- }
-
- if(plusOp.getParent() != null && plusOp.getParent().size() == 1 && isBinarySMMult(lhsTimesOp)) {
- return true;
- }
- return false;
- }
+ return hi;
+ };
- // ema_mean_upd = mu*ema_mean + (1-mu)*mean
- // Returns true if expression matched, else false
- private static boolean isUpdatedMovingAverageExpression(Hop rhsTimesOp, double mu) {
- if(rhsTimesOp == null || rhsTimesOp.getParent() == null || rhsTimesOp.getParent().size() != 1 ||
- !isBinarySMMult(rhsTimesOp) || !isBinaryAdd(rhsTimesOp.getParent().get(0)))
- return false;
-
- // Check (1-mu)*mean
- double expectedOneMinusMu = OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(rhsTimesOp), new HashMap<>());
- Hop plusOp = rhsTimesOp.getParent().get(0);
- Hop lhsTimesOp = null;
- if(plusOp.getInput().get(0) == rhsTimesOp) {
- lhsTimesOp = plusOp.getInput().get(1);
- }
- else {
- lhsTimesOp = plusOp.getInput().get(0);
- }
- if(expectedOneMinusMu == (1-mu) && plusOp.getParent() != null && plusOp.getParent().size() == 1 &&
- isBinarySMMult(lhsTimesOp) && OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(lhsTimesOp), new HashMap<>()) == mu) {
- return true;
- }
- return false;
- }
-
- /**
- * Checks for the expression 1/sqrt(denom)
- *
- * @param denom denominator of the expression to be matched
- * @return true if the expression 1/sqrt(denom) matched else false
- */
- private static boolean isOneBySqrt(Hop denom) {
- return denom.getParent() != null && denom.getParent().get(0) instanceof UnaryOp &&
- ((UnaryOp)denom.getParent().get(0)).getOp() == OpOp1.SQRT &&
- denom.getParent().get(0).getParent() != null && denom.getParent().get(0).getParent().size() == 1 &&
- isBinarySMDiv(denom.getParent().get(0).getParent().get(0), 1);
- }
-
- /**
- * Checks for the batch norm (mode="train") pattern using the helper isBatchNormTrainMean and isBatchNormTrainVar
- * and returns a new FunctionOp if matched
- *
- * @param roots root hops of the given statement block
- * @param parent parent of the input
- * @param hi input to be matched
- * @param pos position
- * @return a new FunctionOp or hi
- */
- @SuppressWarnings("unused")
- private static Hop batchNormTrain(ArrayList<Hop> roots, Hop parent, Hop hi, int pos)
- {
+ // Pattern 3:
+ private static final HopDagPatternMatcher _batchNormTest;
+ static {
// norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps))
+ HopDagPatternMatcher norm =
+ bias_multiply(
+ bias_add(leaf("X", MATRIX), unaryMinus(leaf("mean", MATRIX))), // bias_add(X, -mean)
+ div(1, sqrt(plus(leaf("var", MATRIX), leaf("eps", SCALAR))))); // 1/sqrt(var+eps)
// hi = bias_add(bias_multiply(norm, gamma), beta)
- // 2x for input and output and 1x for overhead
- // fitsOnGPU(hi, 3)
- if( hasFirstInput(hi) && isBiasAdd(hi) && isBiasMultiply(getFirstInput(hi)) ) {
- Hop norm = getFirstInput(getFirstInput(hi));
- if(hasSecondInput(norm) && isBiasMultiply(norm) && isBiasAdd(getFirstInput(norm))
- && hasSecondInput(getFirstInput(norm)) && isUnaryMinus(getSecondInput(getFirstInput(norm)))
- && isOneDivideBySqrt(getSecondInput(norm))) {
- double eps = 0;
- Hop var = getFirstInput(getSecondInput(getSecondInput(norm)));
- if(isBinaryAdd(var) && (getFirstInput(var) instanceof LiteralOp || getSecondInput(var) instanceof LiteralOp)) {
- // eps + ema_var
- if(getFirstInput(var) instanceof LiteralOp) {
- eps = OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(var), new HashMap<>());
- var = getSecondInput(var);
- }
- else {
- eps = OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(var), new HashMap<>());
- var = getFirstInput(var);
- }
- }
- // Generate batch norm test op
- Hop X = getFirstInput(getFirstInput(norm));
- Hop mean = getSecondInput(getSecondInput(getFirstInput(norm)));
-
- if(hasFirstInput(mean) && isBatchNormTrainMean(mean , X) && isBatchNormTrainVar(mean, var, X, getFirstInput(mean), false) &&
- mean.getParent() != null && mean.getParent().size() >= 2 &&
- var.getParent() != null && var.getParent().size() == 2) {
- Hop gamma = getSecondInput(getFirstInput(hi));
- Hop beta = getSecondInput(hi);
-
- // Always get mu from variance as it will have exactly one match of fusion pattern
- Double potentialMu = getMuFromUpdatedMovingAverageExpressions(var.getParent());
- if(potentialMu == null)
- return hi;
- double mu = potentialMu;
-
- Hop [] means = getUpdatedMovingAverageExpressions(mean.getParent(), mu);
- Hop [] vars = getUpdatedMovingAverageExpressions(var.getParent(), mu);
- if(means == null || vars == null)
- return hi;
-
- Hop varPlusEps = null;
- boolean isFirstBinaryAddOp = isAnyBinaryAdd(var.getParent().get(0).getParent());
- boolean isSecondBinaryAddOp = isAnyBinaryAdd(var.getParent().get(1).getParent());
- if(isFirstBinaryAddOp && !isSecondBinaryAddOp) {
- varPlusEps = var.getParent().get(1);
- }
- else if(!isFirstBinaryAddOp && isSecondBinaryAddOp) {
- varPlusEps = var.getParent().get(0);
- }
- if(varPlusEps != null && isBinaryMSAdd(varPlusEps, eps) && isOneBySqrt(varPlusEps)) {
-
- Hop cache_var = varPlusEps.getParent().get(0).getParent().get(0);
- Hop ema_mean_upd = means[0];
- Hop ema_var_upd = vars[0];
- Hop ema_mean = means[1];
- Hop ema_var = vars[1];
- Hop cache_mean = means[2];
-
-
- ArrayList<Hop> inHops = new ArrayList<Hop>();
- inHops.add(X);
- inHops.add(gamma);
- inHops.add(beta);
- inHops.add(ema_mean);
- inHops.add(ema_var);
- inHops.add(new LiteralOp(eps));
- inHops.add(new LiteralOp(mu));
- Hop [] oldHops = {hi, ema_mean_upd, ema_var_upd, cache_mean, cache_var};
-
- // Since FunctionOp adds transientwrite explicitly, persistent writes are not supported
- if(!isAnyPersistentWrite(oldHops)) {
- LOG.debug("Applied batchNormTrain rewrite.");
- ArrayList<Hop> outputs = getMultiOutputHops(roots, oldHops);
- FunctionOp ret = new FunctionOp(FunctionType.MULTIRETURN_BUILTIN, DMLProgram.INTERNAL_NAMESPACE, "batch_norm2d_train",
- null, inHops, outputs.stream().map(h -> h.getName()).toArray(String[]::new), outputs);
- Collections.reverse(roots);
- roots.add(ret);
- Collections.reverse(roots);
- return ret;
- }
- }
-
- }
+ _batchNormTest =
+ bias_add(
+ bias_multiply(norm, leaf("gamma", MATRIX)),
+ leaf("beta", MATRIX))
+ .fitsOnGPU(3);
+ }
+ private static final Function<Hop, Hop> _batchNormTestReplacer = hi -> {
+ LOG.debug("Applied batchNormTest rewrite.");
+ Hop newHop = HopRewriteUtils.createDnnOp(_batchNormTest, OpOpDnn.BATCH_NORM2D_TEST, "X", "gamma", "beta", "mean", "var", "eps");
+ return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
+ };
+
+ // Pattern 4:
+ // rowSums(matrix(colSums(X), rows=C, cols=HW))
+ private static final HopDagPatternMatcher _channelSums = util_channel_sums(leaf("X", MATRIX).fitsOnGPU(2), leaf("C", SCALAR), leaf("HW", SCALAR));
+ private static final Function<Hop, Hop> _channelSumsReplacer = hi -> {
+ LOG.debug("Applied channelSums rewrite.");
+ Hop newHop = HopRewriteUtils.createDnnOp(_channelSums, OpOpDnn.CHANNEL_SUMS, "X", "C", "HW");
+ return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
+ };
+
+ // Pattern 5:
+ // (X - mu*v_prev) + (1+mu)*v
+ private static final HopDagPatternMatcher _updateNesterovX =
+ mm_plus(
+ minus( // X - mu*v_prev
+ leaf("X", MATRIX),
+ mult( // mu*v_prev
+ leaf("mu", SCALAR),
+ leaf("v_prev", MATRIX))),
+ mult( // (1+mu)*v
+ leaf("onePlusMu", SCALAR),
+ leaf("v", MATRIX)))
+ .fitsOnGPU(3);
+ private static final Function<Hop, Hop> _updateNesterovXReplacer = hi -> {
+ if((1+_updateNesterovX.getLiteralValue("mu")) == _updateNesterovX.getLiteralValue("onePlusMu")) {
+ Hop X = _updateNesterovX.getMatchedHop("X");
+ Hop v = _updateNesterovX.getMatchedHop("v");
+ Hop v_prev = _updateNesterovX.getMatchedHop("v_prev");
+ if(hasSameDimensions(X, v) && hasSameDimensions(X, v_prev)) {
+ LOG.debug("Applied updateNesterovX rewrite.");
+ Hop newHop = HopRewriteUtils.createDnnOp(_updateNesterovX, OpOpDnn.UPDATE_NESTEROV_X, "X", "v", "v_prev", "mu");
+ return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
}
}
-
return hi;
- }
-
- // ------------------------------------------------------------
- /**
- * Checks if any of the given output hop is a persistent write.
- *
- * @param outputHops output hops to check
- * @return true if any of the hop is a persistent write else false.
- */
- private static boolean isAnyPersistentWrite(Hop [] outputHops) {
- for(Hop outHop : outputHops) {
- if(HopRewriteUtils.isData(outHop, DataOpTypes.PERSISTENTWRITE))
- return true;
+ };
+
+ // Pattern 6:
+ // matrix(colMeans(X), rows=C, cols=Hin*Win)
+ // This avoids unnecessary copy by the reshape operator
+ private static final HopDagPatternMatcher _reshapeColMeans =
+ matrix(
+ colMeans(leaf("X", MATRIX).fitsOnGPU(2)), // colMeans(X)
+ leaf("C", SCALAR),
+ leaf("HW", SCALAR));
+ private static final Function<Hop, Hop> _reshapeColMeansReplacer = hi -> {
+ LOG.debug("Applied reshapeColMeans rewrite.");
+ Hop newHop = HopRewriteUtils.createDnnOp(_reshapeColMeans, OpOpDnn.RESHAPE_COLMEANS, "X", "C", "HW");
+ return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
+ };
+
+ // Pattern 7:
+ // mu*ema_mean + (1-mu)*mean
+ private static final HopDagPatternMatcher _updateEMA =
+ mm_plus(
+ mult( // mu*ema_mean
+ leaf("mu", SCALAR),
+ leaf("ema_mean", MATRIX)),
+ mult( // (1-mu)*mean
+ leaf("oneMinusMu", SCALAR),
+ leaf("mean", MATRIX)))
+ .fitsOnGPU(3);
+ private static final Function<Hop, Hop> _updateEMAReplacer = hi -> {
+ if((1-_updateEMA.getLiteralValue("mu")) == _updateEMA.getLiteralValue("oneMinusMu")) {
+ LOG.debug("Applied updateEMA rewrite.");
+ Hop newHop = HopRewriteUtils.createDnnOp(_updateEMA, OpOpDnn.UPDATE_EMA, "ema_mean", "mean", "mu");
+ return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
}
- return false;
- }
-
- /**
- * Returns output hop for a multi-output FunctionOp to be created by rewrite.
- *
- * @param roots root hops of statement block
- * @param oldHops old output hops of the pattern
- * @return new output hops that should be passed to FunctionOp
- */
- private static ArrayList<Hop> getMultiOutputHops(ArrayList<Hop> roots, Hop [] oldHops) {
- ArrayList<Hop> ret = new ArrayList<>();
- for(int i = 0; i < oldHops.length; i++) {
- // Create a transient read as FunctionOp will add a transient write.
- if(HopRewriteUtils.isData(oldHops[i], DataOpTypes.PERSISTENTWRITE))
- throw new RuntimeException("Persistent write is not supported as output for the given rewrite." + oldHops[i]);
- // Generate a new name if the old output was not transient write.
- String name = HopRewriteUtils.isData(oldHops[i], DataOpTypes.TRANSIENTWRITE) ? oldHops[i].getName() : "_genGPU" + (_seq++);
- DataOp tRead = HopRewriteUtils.createTransientRead(name, oldHops[i]);
- HopRewriteUtils.rewireAllParentChildReferences(oldHops[i], tRead);
- ret.add(tRead);
- // Remove old output from roots to avoid unnecessary computation.
- if(roots.contains(oldHops[i])) {
- roots.remove(oldHops[i]);
- }
+ return hi;
+ };
+
+ // Pattern 8:
+ // 1/sqrt(var+epsilon)
+ private static final HopDagPatternMatcher _invVar =
+ div(1,
+ sqrt( // var+epsilon
+ plus( leaf("var", MATRIX),
+ leaf("eps", SCALAR) )))
+ .fitsOnGPU(2);
+ private static final Function<Hop, Hop> _invVarReplacer = hi -> {
+ LOG.debug("Applied computeInverseVariance rewrite.");
+ Hop newHop = HopRewriteUtils.createDnnOp(_invVar, OpOpDnn.INV_VAR, "var", "eps");
+ return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
+ };
+
+
+ private static ArrayList<HopPatternRewriter> _rewriters = null;
+ public ArrayList<HopPatternRewriter> getPatternRewriter() {
+ if(_rewriters == null) {
+ ArrayList<HopPatternRewriter> rewriters = new ArrayList<>();
+ rewriters.add(new HopPatternRewriter("batchNormdX", _batchNormdX, _batchNormdXReplacer));
+ rewriters.add(new HopPatternRewriter("batchNormUpdatedVar", _batchNormUpdatedVar, _batchNormUpdatedVarReplacer));
+ rewriters.add(new HopPatternRewriter("batchNormTest", _batchNormTest, _batchNormTestReplacer));
+ rewriters.add(new HopPatternRewriter("channelSums", _channelSums, _channelSumsReplacer));
+ rewriters.add(new HopPatternRewriter("updateNesterovX", _updateNesterovX, _updateNesterovXReplacer));
+ rewriters.add(new HopPatternRewriter("reshapeColMeans", _reshapeColMeans, _reshapeColMeansReplacer));
+ rewriters.add(new HopPatternRewriter("updateEMA", _updateEMA, _updateEMAReplacer));
+ rewriters.add(new HopPatternRewriter("invVar", _invVar, _invVarReplacer));
+ _rewriters = rewriters;
}
- return ret;
+ return _rewriters;
}
- // ------------------------------------------------------------
- /**
- * Checks for the nesterov_update_x pattern (X = X - mu*v_prev + (1+mu)*v)
- * and returns a new DnnOp if matched
- *
- * @param parent parent of the input
- * @param hi input to be matched
- * @param pos position
- * @return a new DnnOp or hi
- */
- private static Hop updateNesterovX(Hop parent, Hop hi, int pos) {
- if(fitsOnGPU(hi, 4) && isBinaryMMAdd(hi) && isBinaryMMMinus(getFirstInput(hi))
- && isBinarySMMult(getSecondInput(getFirstInput(hi)))
- && isBinarySMMult(getSecondInput(hi))) {
- Hop onePlusMu = getFirstInput(getSecondInput(hi));
- Hop tmp = getSecondInput(getFirstInput(hi));
- Hop mu = getFirstInput(tmp);
- if(isOnePlusMu(onePlusMu, mu)) {
- Hop v_prev = getSecondInput(tmp);
- Hop v = getSecondInput(getSecondInput(hi));
- Hop X = getFirstInput(getFirstInput(hi));
- if(hasSameDimensions(X, v) && hasSameDimensions(X, v_prev)) {
- ArrayList<Hop> inHops = new ArrayList<Hop>();
- inHops.add(X);
- inHops.add(v);
- inHops.add(v_prev);
- inHops.add(mu);
- LOG.debug("Applied updateNesterovX rewrite.");
- Hop newHop = new DnnOp(hi.getName(), hi.getDataType(), hi.getValueType(),
- OpOpDnn.UPDATE_NESTEROV_X, inHops);
- return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
- }
- }
- }
- return hi;
- }
+
+ // -------------------------------------------------------------------------------------------
private static boolean hasSameDimensions(Hop x, Hop y) {
return x.dimsKnown() && y.dimsKnown() && (x.getDim1() == y.getDim1()) && (x.getDim2() == y.getDim2());
}
- private static boolean isOnePlusMu(Hop onePlusMu, Hop mu) {
- return (isBinarySMMult(onePlusMu, 1.0) && getSecondInput(onePlusMu) == mu) ||
- getValue(onePlusMu) == getValue(mu) + 1;
- }
-
- /**
- * Checks for the batch norm (mode="test") pattern using the helper isBatchNormTrainMean and isBatchNormTrainVar
- * and returns a new DnnOp if matched
- *
- * @param parent parent of the input
- * @param hi input to be matched
- * @param pos position
- * @return a new DnnOp or hi
- */
- private static Hop batchNormTest(Hop parent, Hop hi, int pos) {
- // norm = bias_multiply(bias_add(X, -mean), 1/sqrt(var+eps))
- // hi = bias_add(bias_multiply(norm, gamma), beta)
- // 2x for input and output and 1x for overhead
- if(hasFirstInput(hi) && isBiasAdd(hi) && isBiasMultiply(getFirstInput(hi)) && fitsOnGPU(hi, 3) ) {
- Hop norm = getFirstInput(getFirstInput(hi));
- if(hasSecondInput(norm) && isBiasMultiply(norm) && isBiasAdd(getFirstInput(norm))
- && isUnaryMinus(getSecondInput(getFirstInput(norm)))
- && isOneDivideBySqrt(getSecondInput(norm))) {
- double eps = 0;
- Hop var = getFirstInput(getSecondInput(getSecondInput(norm)));
- if( HopRewriteUtils.isBinary(var, OpOp2.PLUS) &&
- (getFirstInput(var) instanceof LiteralOp || getSecondInput(var) instanceof LiteralOp)) {
- // eps + ema_var
- if(getFirstInput(var) instanceof LiteralOp) {
- eps = OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(var), new HashMap<>());
- var = getSecondInput(var);
- }
- else {
- eps = OptimizerUtils.rEvalSimpleDoubleExpression(getSecondInput(var), new HashMap<>());
- var = getFirstInput(var);
- }
- }
- // Generate batch norm test op
- Hop X = getFirstInput(getFirstInput(norm));
- Hop mean = getSecondInput(getSecondInput(getFirstInput(norm)));
-
- // This guard disallows eager fusion of train batch normalization into test batch normalization
- boolean potentialForBatchNormTrain = !X.rowsKnown() && isBatchNormTrainMean(mean , X) && isBatchNormTrainVar(mean, var, X, getFirstInput(mean), true);
- if(!potentialForBatchNormTrain) {
- Hop gamma = getSecondInput(getFirstInput(hi));
- Hop beta = getSecondInput(hi);
- ArrayList<Hop> inHops = new ArrayList<Hop>();
- inHops.add(X);
- inHops.add(gamma);
- inHops.add(beta);
- inHops.add(mean);
- inHops.add(var);
- inHops.add(new LiteralOp(eps));
- if(fitsOnGPU(inHops, true)) {
- LOG.debug("Applied batchNormTest rewrite.");
- Hop newHop = new DnnOp(hi.getName(), hi.getDataType(), hi.getValueType(),
- OpOpDnn.BATCH_NORM2D_TEST, inHops);
- return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
- }
- }
- else {
- LOG.debug("Skipping batchNormTest rewrite as there is potential for batch normalization train rewrite after recompilation.");
- }
- }
- }
-
- return hi;
+ private static boolean checkDimensions(Hop x, long dim1, long dim2) {
+ return x.dimsKnown() && (x.getDim1() == dim1) && (x.getDim2() == dim2);
}
-}
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/lops/DnnTransform.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/DnnTransform.java b/src/main/java/org/apache/sysml/lops/DnnTransform.java
index 3183b5f..2d2d5f1 100644
--- a/src/main/java/org/apache/sysml/lops/DnnTransform.java
+++ b/src/main/java/org/apache/sysml/lops/DnnTransform.java
@@ -32,7 +32,8 @@ public class DnnTransform extends Lop
RELU_MAX_POOLING, RELU_MAX_POOLING_BACKWARD, RELU_BACKWARD,
CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA,
BIAS_ADD, CONV2D_BIAS_ADD, BIAS_MULTIPLY, CHANNEL_SUMS, BATCH_NORM2D_TEST,
- UPDATE_NESTEROV_X
+ UPDATE_NESTEROV_X, RESHAPE_COLMEANS, UPDATE_EMA_VAR, UPDATE_EMA, INV_VAR,
+ BATCH_NORM2D_BACKWARD_DX
}
private OperationTypes operation;
@@ -167,11 +168,26 @@ public class DnnTransform extends Lop
case CHANNEL_SUMS:
return "channel_sums";
+ case INV_VAR:
+ return "inv_var";
+
case UPDATE_NESTEROV_X:
return "update_nesterov_x";
case BATCH_NORM2D_TEST:
return "batch_norm2d_test";
+
+ case BATCH_NORM2D_BACKWARD_DX:
+ return "batch_norm2d_bwd_dx";
+
+ case UPDATE_EMA_VAR:
+ return "update_ema_var";
+
+ case UPDATE_EMA:
+ return "update_ema";
+
+ case RESHAPE_COLMEANS:
+ return "reshape_colmeans";
default:
throw new UnsupportedOperationException(this.printErrorLocation() + "Instruction is not defined for Transform operation " + operation);
@@ -181,7 +197,8 @@ public class DnnTransform extends Lop
@Override
public String getInstructions(String input, String bias, String output) {
- if(operation == OperationTypes.BIAS_ADD || operation == OperationTypes.BIAS_MULTIPLY || operation == OperationTypes.RELU_BACKWARD) {
+ if(operation == OperationTypes.BIAS_ADD || operation == OperationTypes.BIAS_MULTIPLY || operation == OperationTypes.RELU_BACKWARD
+ || operation == OperationTypes.INV_VAR) {
StringBuilder sb = new StringBuilder();
sb.append( getExecType() );
@@ -190,7 +207,7 @@ public class DnnTransform extends Lop
sb.append( OPERAND_DELIMITOR );
sb.append( getInputs().get(0).prepInputOperand(input));
sb.append( OPERAND_DELIMITOR );
- sb.append( getInputs().get(0).prepInputOperand(bias));
+ sb.append( getInputs().get(1).prepInputOperand(bias));
//output
sb.append( OPERAND_DELIMITOR );
sb.append( this.prepOutputOperand(output));
@@ -212,7 +229,7 @@ public class DnnTransform extends Lop
@Override
public String getInstructions(String input, String C, String HW, String output) {
- if(operation == OperationTypes.CHANNEL_SUMS) {
+ if(operation == OperationTypes.CHANNEL_SUMS || operation == OperationTypes.RESHAPE_COLMEANS || operation == OperationTypes.UPDATE_EMA) {
StringBuilder sb = new StringBuilder();
sb.append( getExecType() );
@@ -306,6 +323,34 @@ public class DnnTransform extends Lop
throw new LopsException("The operation is not supported with six operands:" + operation.name());
}
}
+
+ public String getInstructions(String input1, String input2, String input3, String input4, String input5, String output) {
+ if(operation == OperationTypes.UPDATE_EMA_VAR || operation == OperationTypes.BATCH_NORM2D_BACKWARD_DX) {
+ StringBuilder sb = new StringBuilder();
+ sb.append( getExecType() );
+
+ sb.append( OPERAND_DELIMITOR );
+ sb.append( getOpcode() );
+ sb.append( OPERAND_DELIMITOR );
+ sb.append( getInputs().get(0).prepInputOperand(input1));
+ sb.append( OPERAND_DELIMITOR );
+ sb.append( getInputs().get(1).prepInputOperand(input2));
+ sb.append( OPERAND_DELIMITOR );
+ sb.append( getInputs().get(2).prepInputOperand(input3));
+ sb.append( OPERAND_DELIMITOR );
+ sb.append( getInputs().get(3).prepInputOperand(input4));
+ sb.append( OPERAND_DELIMITOR );
+ sb.append( getInputs().get(4).prepInputOperand(input5));
+ //output
+ sb.append( OPERAND_DELIMITOR );
+ sb.append( this.prepOutputOperand(output));
+
+ return sb.toString();
+ }
+ else {
+ throw new LopsException("The operation is not supported with six operands:" + operation.name());
+ }
+ }
public void appendOpcode(StringBuilder sb) {
sb.append( getExecType() );
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
index 9f3a1e2..fe86dc8 100644
--- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
@@ -270,58 +270,6 @@ public class BuiltinFunctionExpression extends DataIdentifier
setDimensions(dc0, getFifthExpr());
break;
}
- case BATCH_NORM2D:
- {
- // Input: image, scale, bias, runningMean, runningVar, mode, epsilon, exponentialAverageFactor
- checkNumParameters(8);
- checkMatrixParam(getFirstExpr());
- checkMatrixParam(getSecondExpr());
- checkMatrixParam(getThirdExpr());
- checkMatrixParam(getFourthExpr());
- checkMatrixParam(getFifthExpr());
-
- // Output: ret, retRunningMean, retRunningVar, resultSaveMean, resultSaveInvVariance
- // setup output properties
- if(getOutputs().length != 5)
- raiseValidateError("batch_norm2d has 5 outputs", false);
-
- DataIdentifier ret = (DataIdentifier) getOutputs()[0];
- DataIdentifier retRunningMean = (DataIdentifier) getOutputs()[1];
- DataIdentifier retRunningVar = (DataIdentifier) getOutputs()[2];
- DataIdentifier resultSaveMean = (DataIdentifier) getOutputs()[3];
- DataIdentifier resultSaveInvVariance = (DataIdentifier) getOutputs()[4];
-
- setDimensions(ret, getFirstExpr());
- setDimensions(retRunningMean, getFourthExpr());
- setDimensions(retRunningVar, getFourthExpr());
- setDimensions(resultSaveMean, getFourthExpr());
- setDimensions(resultSaveInvVariance, getFourthExpr());
- break;
- }
- case BATCH_NORM2D_BACKWARD:
- {
- // Input: image, dout, scale, epsilon, savedMean, savedInvVariance
- checkNumParameters(6);
- checkMatrixParam(getFirstExpr());
- checkMatrixParam(getSecondExpr());
- checkMatrixParam(getThirdExpr());
- checkMatrixParam(getFifthExpr());
- checkMatrixParam(getSixthExpr());
-
- // Output: dX, dScale, dBias
- // setup output properties
- if(getOutputs().length != 3)
- raiseValidateError("batch_norm2d_backward has 3 outputs", false);
-
- DataIdentifier dX = (DataIdentifier) getOutputs()[0];
- DataIdentifier dScale = (DataIdentifier) getOutputs()[1];
- DataIdentifier dBias = (DataIdentifier) getOutputs()[2];
-
- setDimensions(dX, getFirstExpr());
- setDimensions(dScale, getThirdExpr());
- setDimensions(dBias, getThirdExpr());
- break;
- }
case EIGEN:
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
@@ -1451,8 +1399,7 @@ public class BuiltinFunctionExpression extends DataIdentifier
// always unconditional (because unsupported operation)
BuiltinFunctionOp op = getOpCode();
if( op==BuiltinFunctionOp.EIGEN || op==BuiltinFunctionOp.LU || op==BuiltinFunctionOp.QR || op==BuiltinFunctionOp.SVD
- || op==BuiltinFunctionOp.LSTM || op==BuiltinFunctionOp.LSTM_BACKWARD
- || op==BuiltinFunctionOp.BATCH_NORM2D || op==BuiltinFunctionOp.BATCH_NORM2D_BACKWARD)
+ || op==BuiltinFunctionOp.LSTM || op==BuiltinFunctionOp.LSTM_BACKWARD)
raiseValidateError("Function "+op+" needs to be called with multi-return assignment.", false, LanguageErrorCodes.INVALID_PARAMETERS);
else
raiseValidateError("Unsupported function "+op, false, LanguageErrorCodes.INVALID_PARAMETERS);
@@ -1535,8 +1482,6 @@ public class BuiltinFunctionExpression extends DataIdentifier
case EIGEN:
case LSTM:
case LSTM_BACKWARD:
- case BATCH_NORM2D:
- case BATCH_NORM2D_BACKWARD:
case SVD:
return true;
default:
@@ -1956,10 +1901,6 @@ public class BuiltinFunctionExpression extends DataIdentifier
bifop = Expression.BuiltinFunctionOp.LSTM;
else if (functionName.equals("lstm_backward"))
bifop = Expression.BuiltinFunctionOp.LSTM_BACKWARD;
- else if (functionName.equals("batch_norm2d"))
- bifop = Expression.BuiltinFunctionOp.BATCH_NORM2D;
- else if (functionName.equals("batch_norm2d_backward"))
- bifop = Expression.BuiltinFunctionOp.BATCH_NORM2D_BACKWARD;
else if (functionName.equals("conv2d"))
bifop = Expression.BuiltinFunctionOp.CONV2D;
else if (functionName.equals("bias_add"))
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
index e9b643e..e3db435 100644
--- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -2271,8 +2271,6 @@ public class DMLTranslator
case EIGEN:
case LSTM:
case LSTM_BACKWARD:
- case BATCH_NORM2D:
- case BATCH_NORM2D_BACKWARD:
case SVD:
// Number of outputs = size of targetList = #of identifiers in source.getOutputs
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/parser/Expression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java
index 46e6442..33fca66 100644
--- a/src/main/java/org/apache/sysml/parser/Expression.java
+++ b/src/main/java/org/apache/sysml/parser/Expression.java
@@ -93,7 +93,7 @@ public abstract class Expression implements ParseInfo
EXISTS,
CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA, BIASADD, BIASMULT,
MAX_POOL, AVG_POOL, MAX_POOL_BACKWARD, AVG_POOL_BACKWARD,
- LSTM, LSTM_BACKWARD, BATCH_NORM2D, BATCH_NORM2D_BACKWARD,
+ LSTM, LSTM_BACKWARD,
EXP,
FLOOR,
IFELSE,
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index f4122d9..3480504 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -59,11 +59,13 @@ public class GPUInstructionParser extends InstructionParser
String2GPUInstructionType.put( "channel_sums", GPUINSTRUCTION_TYPE.Dnn);
String2GPUInstructionType.put( "lstm", GPUINSTRUCTION_TYPE.Dnn);
String2GPUInstructionType.put( "lstm_backward", GPUINSTRUCTION_TYPE.Dnn);
- String2GPUInstructionType.put( "batch_norm2d", GPUINSTRUCTION_TYPE.Dnn);
- String2GPUInstructionType.put( "batch_norm2d_backward", GPUINSTRUCTION_TYPE.Dnn);
String2GPUInstructionType.put( "batch_norm2d_test", GPUINSTRUCTION_TYPE.Dnn);
- String2GPUInstructionType.put( "batch_norm2d_train", GPUINSTRUCTION_TYPE.Dnn);
- String2GPUInstructionType.put( "update_nesterov_x", GPUINSTRUCTION_TYPE.Dnn);
+ String2GPUInstructionType.put( "update_nesterov_x", GPUINSTRUCTION_TYPE.Dnn);
+ String2GPUInstructionType.put( "update_ema_var", GPUINSTRUCTION_TYPE.Dnn);
+ String2GPUInstructionType.put( "update_ema", GPUINSTRUCTION_TYPE.Dnn);
+ String2GPUInstructionType.put( "reshape_colmeans", GPUINSTRUCTION_TYPE.Dnn);
+ String2GPUInstructionType.put( "inv_var", GPUINSTRUCTION_TYPE.Dnn);
+ String2GPUInstructionType.put( "batch_norm2d_bwd_dx", GPUINSTRUCTION_TYPE.Dnn);
// Matrix Multiply Operators
String2GPUInstructionType.put( "ba+*", GPUINSTRUCTION_TYPE.AggregateBinary);