Posted to commits@systemml.apache.org by ni...@apache.org on 2017/10/16 22:46:41 UTC
systemml git commit: [SYSTEMML-540] Reduce the number of unknowns in ConvolutionOp
Repository: systemml
Updated Branches:
refs/heads/master 2ca2d8aa7 -> 5adb330de
[SYSTEMML-540] Reduce the number of unknowns in ConvolutionOp
- This commit reduces the number of unknowns during dynamic recompilation by inferring a
ConvolutionOp's input height/width from its parent's output
height/width (see the illustrative sketch below).
- Additionally, for developer debugging, I have guarded this functionality
with the flag INFER_TENSOR_SHAPE_FROM_PARENT_CONV_OP and added
documentation explaining how these dimensions are inferred.
Closes #685.
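
For illustration only (not part of the patch): the minimal, self-contained Java sketch below shows the inference rule described above, using hypothetical class and variable names. A pooling op whose channel/height/width are unknown (-1) borrows the parent conv2d's output shape [K, P, Q] as its input shape [C, H, W]; P and Q follow the same floor((H + 2*pad_h - R)/stride_h + 1) formula quoted in the patch.

public class ShapeInferenceSketch {

    // Output size of a convolution/pooling window along one dimension; matches the
    // DML formula P = as.integer(floor((H + 2*pad_h - R)/stride_h + 1)).
    static int outDim(int in, int filter, int stride, int pad) {
        // Integer division truncates towards zero, which equals floor for non-negative values.
        return (in + 2 * pad - filter) / stride + 1;
    }

    public static void main(String[] args) {
        // Parent: conv2d with a 3x224x224 input, 64 filters of 3x3, stride 1, pad 1.
        int K = 64, R = 3, S = 3;
        int H = 224, W = 224;
        int strideH = 1, strideW = 1, padH = 1, padW = 1;
        int P = outDim(H, R, strideH, padH); // 224
        int Q = outDim(W, S, strideW, padW); // 224

        // Child: max_pool whose channel/height/width were left unknown (-1) in the script.
        int poolC = -1, poolH = -1, poolW = -1;

        // Inference step: [K, P, Q] of the conv2d becomes [C, H, W] of the pooling op.
        if (poolC < 0) poolC = K;
        if (poolH < 0) poolH = P;
        if (poolW < 0) poolW = Q;

        // Prints: Inferred pooling input [C,H,W] = [64,224,224]
        System.out.println("Inferred pooling input [C,H,W] = ["
                + poolC + "," + poolH + "," + poolW + "]");
    }
}
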
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/5adb330d
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/5adb330d
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/5adb330d
Branch: refs/heads/master
Commit: 5adb330deffa5479475338316bf47193d0c31da4
Parents: 2ca2d8a
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Mon Oct 16 15:44:37 2017 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Mon Oct 16 15:45:39 2017 -0700
----------------------------------------------------------------------
.../org/apache/sysml/hops/ConvolutionOp.java | 170 ++++++++++++++++---
1 file changed, 144 insertions(+), 26 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/5adb330d/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
index e732fb8..e4ed32b 100644
--- a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
+++ b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
@@ -32,11 +32,21 @@ import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
-
import java.util.ArrayList;
public class ConvolutionOp extends Hop implements MultiThreadedHop
{
+ // -------------------------------------------------------------------------
+ // This flag allows us to compile plans with fewer unknowns and also serves as a step towards future tensorblock integration.
+ // By default, these flags are turned on.
+
+ // When this flag is turned on, we attempt to check the parent convolution hop for unknown dimensions.
+ // For example: in case of conv -> maxpool, the input channel/height/width of maxpool will match output channel/height/width of conv.
+ private static final boolean INFER_TENSOR_SHAPE_FROM_PARENT_CONV_OP = true;
+ // This guards us from cases where the user provides incorrect C,H,W parameters.
+ private static final boolean THROW_ERROR_IF_INFERRED_SHAPE_MISMATCH = true;
+ // -------------------------------------------------------------------------
+
private Hop.ConvOp op;
private int _maxNumThreads = -1; //-1 for unlimited
@@ -475,17 +485,21 @@ public class ConvolutionOp extends Hop implements MultiThreadedHop
// input_shape1, input_shape2, input_shape3, input_shape4,
// filter_shape1, filter_shape2, filter_shape3, filter_shape4
ConvolutionParameters parseInput() throws DMLRuntimeException {
+
+ Hop imageHeightHop = null; Hop filterHeightHop = null;
if(op == ConvOp.MAX_POOLING_BACKWARD
|| op == ConvOp.DIRECT_CONV2D
|| op == ConvOp.DIRECT_CONV2D_BACKWARD_FILTER
|| op == ConvOp.DIRECT_CONV2D_BACKWARD_DATA) {
+ imageHeightHop = getInput().get(8);
+ filterHeightHop = getInput().get(12);
_cachedParams.setIfUnknown(
getInput().get(6),
getInput().get(7),
- getInput().get(8),
+ imageHeightHop,
getInput().get(9),
getInput().get(10),
- getInput().get(12),
+ filterHeightHop,
getInput().get(13),
getInput().get(2),
getInput().get(3),
@@ -493,22 +507,127 @@ public class ConvolutionOp extends Hop implements MultiThreadedHop
getInput().get(5), _maxNumThreads);
}
else {
+ imageHeightHop = getInput().get(7);
+ filterHeightHop = getInput().get(11);
_cachedParams.setIfUnknown(
getInput().get(5),
getInput().get(6),
- getInput().get(7),
+ imageHeightHop,
getInput().get(8),
getInput().get(9),
- getInput().get(11),
+ filterHeightHop,
getInput().get(12),
getInput().get(1),
getInput().get(2),
getInput().get(3),
getInput().get(4), _maxNumThreads);
}
+
+ if(INFER_TENSOR_SHAPE_FROM_PARENT_CONV_OP) {
+ boolean isMaxPool = getOp() == ConvOp.MAX_POOLING;
+ boolean isConv = getOp() == ConvOp.DIRECT_CONV2D;
+ boolean unknownCHWPQ = _cachedParams.C < 0 || _cachedParams.H < 0 || _cachedParams.W < 0 || _cachedParams.P < 0 || _cachedParams.Q < 0;
+ if((isMaxPool || isConv) && unknownCHWPQ) {
+ // Only infer input shape for convolution and maxpool
+ inferCHWPQFromParentOp();
+ }
+ }
+
+ if(imageHeightHop == filterHeightHop && _cachedParams.R < 0 && _cachedParams.H > 0) {
+ // Unknown R, but known H and both are equal
+ // This happens for one-dimensional conv2d where H=R and H can be inferred from the parent hop
+ _cachedParams.R = _cachedParams.H;
+ }
+
+ // Compute P and Q if unknown. At the script level, they are computed using the following script:
+ // P = as.integer(floor((H + 2*pad_h - R)/stride_h + 1))
+ // Q = as.integer(floor((W + 2*pad_w - S)/stride_w + 1))
+ if(_cachedParams.P < 0 && _cachedParams.H >= 0 && _cachedParams.R >= 0 && _cachedParams.stride_h >= 0 && _cachedParams.pad_h >= 0) {
+ _cachedParams.P = (int) org.apache.sysml.runtime.util.ConvolutionUtils.getP(_cachedParams.H, _cachedParams.R, _cachedParams.stride_h, _cachedParams.pad_h);
+ }
+ if(_cachedParams.Q < 0 && _cachedParams.W >= 0 && _cachedParams.S >= 0 && _cachedParams.stride_w >= 0 && _cachedParams.pad_w >= 0) {
+ _cachedParams.Q = (int) org.apache.sysml.runtime.util.ConvolutionUtils.getQ(_cachedParams.W, _cachedParams.S, _cachedParams.stride_w, _cachedParams.pad_w);
+ }
+
return _cachedParams;
}
+ /**
+ * Utility method to check if the given hop is a BIAS_ADD hop
+ *
+ * @param hop the given hop
+ * @return true if the given hop is BIAS_ADD
+ */
+ private static boolean isInputBiasAdd(Hop hop) {
+ if(hop instanceof ConvolutionOp && ((ConvolutionOp) hop).getOp() == ConvOp.BIAS_ADD) {
+ return true;
+ }
+ return false;
+ }
+
+ /**
+ * Utility method to check whether the inferred shape equals the given shape, guarding against unknown (negative) dimensions
+ *
+ * @param dim1 inferred shape
+ * @param dim2 given shape
+ * @param paramType string denoting the parameter for pretty printing of the error message
+ * @throws DMLRuntimeException if dim1 != dim2
+ */
+ private void throwExceptionIfNotEqual(int dim1, int dim2, String paramType) throws DMLRuntimeException {
+ if(dim1 >= 0 && dim2 >= 0 && dim1 != dim2) {
+ throw new DMLRuntimeException("Inferred " + paramType + " from parent doesn't match with given " + paramType + ":" + dim1 + " != " + dim2);
+ }
+ }
+
+ /**
+ * Gets the values for the parameters C, H, W, P, Q from parent hops
+ *
+ * @throws DMLRuntimeException if error occurs
+ */
+ private void inferCHWPQFromParentOp() throws DMLRuntimeException {
+ Hop tmp = getInput().get(0);
+ while(isInputReLU(tmp) || isInputBiasAdd(tmp)) {
+ // Skip ReLU and bias_add and go to its parent
+ tmp = tmp.getInput().get(0);
+ }
+ // Cast tmp to the parent ConvolutionOp if possible
+ ConvolutionOp parentOp = (tmp instanceof ConvolutionOp) ? ((ConvolutionOp) tmp) : null;
+
+ if(parentOp == null)
+ return;
+ else if(parentOp.getOp() == ConvOp.MAX_POOLING) {
+ ConvolutionParameters parentParam = parentOp.parseInput();
+ int prevC = _cachedParams.C; int prevH = _cachedParams.H; int prevW = _cachedParams.W;
+ // [C, P, Q] from maxpool becomes [C, H, W] of next op
+ _cachedParams.C = (_cachedParams.C < 0) ? parentParam.C : _cachedParams.C;
+ _cachedParams.H = (_cachedParams.H < 0) ? parentParam.P : _cachedParams.H;
+ _cachedParams.W = (_cachedParams.W < 0) ? parentParam.Q : _cachedParams.W;
+ if(LOG.isDebugEnabled()) {
+ LOG.debug("Inferring [C,H,W] from maxpool parent: [" + prevC + "," + prevH + "," + prevW + "]-> [" + _cachedParams.C + "," + _cachedParams.H + "," + _cachedParams.W + "]");
+ }
+ if(THROW_ERROR_IF_INFERRED_SHAPE_MISMATCH) {
+ throwExceptionIfNotEqual(prevC, _cachedParams.C, "C");
+ throwExceptionIfNotEqual(prevH, _cachedParams.H, "H");
+ throwExceptionIfNotEqual(prevW, _cachedParams.W, "W");
+ }
+ }
+ else if(parentOp.getOp() == ConvOp.DIRECT_CONV2D) {
+ ConvolutionParameters parentParam = parentOp.parseInput();
+ int prevC = _cachedParams.C; int prevH = _cachedParams.H; int prevW = _cachedParams.W;
+ // [K, P, Q] from convolution becomes [C, H, W] of next op
+ _cachedParams.C = (_cachedParams.C < 0) ? parentParam.K : _cachedParams.C;
+ _cachedParams.H = (_cachedParams.H < 0) ? parentParam.P : _cachedParams.H;
+ _cachedParams.W = (_cachedParams.W < 0) ? parentParam.Q : _cachedParams.W;
+ if(LOG.isDebugEnabled()) {
+ LOG.debug("Inferring [C,H,W] from maxpool parent: [" + prevC + "," + prevH + "," + prevW + "]-> [" + _cachedParams.C + "," + _cachedParams.H + "," + _cachedParams.W + "]");
+ }
+ if(THROW_ERROR_IF_INFERRED_SHAPE_MISMATCH) {
+ throwExceptionIfNotEqual(prevC, _cachedParams.C, "C");
+ throwExceptionIfNotEqual(prevH, _cachedParams.H, "H");
+ throwExceptionIfNotEqual(prevW, _cachedParams.W, "W");
+ }
+ }
+ }
+
@Override
public void refreshSizeInformation()
{
@@ -620,9 +739,8 @@ public class ConvolutionOp extends Hop implements MultiThreadedHop
if(op == ConvOp.BIAS_ADD || op == ConvOp.BIAS_MULTIPLY) {
throw new RuntimeException("getDim method should not be invoked for bias_add and bias_multiply");
}
- ConvolutionParameters params;
try {
- params = parseInput();
+ parseInput();
} catch (DMLRuntimeException e) {
throw new RuntimeException(e);
}
@@ -653,49 +771,49 @@ public class ConvolutionOp extends Hop implements MultiThreadedHop
long ret = -1;
if(dimString.equals("K") && filter != null) {
- ret = getNonNegative(ret, getNonNegative(params.K, filter._dim1));
+ ret = getNonNegative(ret, getNonNegative(_cachedParams.K, filter._dim1));
}
else if(dimString.equals("CRS") && filter != null) {
- ret = getNonNegative(ret, getNonNegative(nonNegativeMultiply(params.C, params.R, params.S), filter._dim2));
+ ret = getNonNegative(ret, getNonNegative(nonNegativeMultiply(_cachedParams.C, _cachedParams.R, _cachedParams.S), filter._dim2));
}
else if(dimString.equals("N") && input != null) {
- ret = getNonNegative(ret, getNonNegative(params.N, input._dim1));
+ ret = getNonNegative(ret, getNonNegative(_cachedParams.N, input._dim1));
}
else if(dimString.equals("CHW") && input != null) {
- ret = getNonNegative(ret, getNonNegative(nonNegativeMultiply(params.C, params.H, params.W), input._dim2));
+ ret = getNonNegative(ret, getNonNegative(nonNegativeMultiply(_cachedParams.C, _cachedParams.H, _cachedParams.W), input._dim2));
}
else if(dimString.equals("N") && dout != null) {
- ret = getNonNegative(ret, getNonNegative(params.N, dout._dim1));
+ ret = getNonNegative(ret, getNonNegative(_cachedParams.N, dout._dim1));
}
else if(dimString.equals("KPQ") && dout != null) {
- ret = getNonNegative(ret, getNonNegative(nonNegativeMultiply(params.K, params.P, params.Q), dout._dim2));
+ ret = getNonNegative(ret, getNonNegative(nonNegativeMultiply(_cachedParams.K, _cachedParams.P, _cachedParams.Q), dout._dim2));
}
else if(dimString.equals("N") && dout1 != null) {
- ret = getNonNegative(ret, getNonNegative(params.N, dout1._dim1));
+ ret = getNonNegative(ret, getNonNegative(_cachedParams.N, dout1._dim1));
}
else if(dimString.equals("CPQ") && dout1 != null) {
- ret = getNonNegative(ret, getNonNegative(nonNegativeMultiply(params.C, params.P, params.Q), dout1._dim2));
+ ret = getNonNegative(ret, getNonNegative(nonNegativeMultiply(_cachedParams.C, _cachedParams.P, _cachedParams.Q), dout1._dim2));
}
else if(dimString.equals("K")) {
- ret = getNonNegative(ret, params.K >= 0 ? params.K : -1);
+ ret = getNonNegative(ret, _cachedParams.K >= 0 ? _cachedParams.K : -1);
}
else if(dimString.equals("CRS")) {
- ret = getNonNegative(ret, nonNegativeMultiply(params.C, params.R, params.S));
+ ret = getNonNegative(ret, nonNegativeMultiply(_cachedParams.C, _cachedParams.R, _cachedParams.S));
}
else if(dimString.equals("N")) {
- ret = getNonNegative(ret, params.N >= 0 ? params.N : -1);
+ ret = getNonNegative(ret, _cachedParams.N >= 0 ? _cachedParams.N : -1);
}
else if(dimString.equals("CHW")) {
- ret = getNonNegative(ret, nonNegativeMultiply(params.C, params.H, params.W));
+ ret = getNonNegative(ret, nonNegativeMultiply(_cachedParams.C, _cachedParams.H, _cachedParams.W));
}
else if(dimString.equals("KPQ")) {
- ret = getNonNegative(ret, nonNegativeMultiply(params.K, params.P, params.Q));
+ ret = getNonNegative(ret, nonNegativeMultiply(_cachedParams.K, _cachedParams.P, _cachedParams.Q));
}
else if(dimString.equals("PQ")) {
- ret = getNonNegative(ret, nonNegativeMultiply(params.P, params.Q));
+ ret = getNonNegative(ret, nonNegativeMultiply(_cachedParams.P, _cachedParams.Q));
}
else if(dimString.equals("CPQ")) {
- ret = getNonNegative(ret, nonNegativeMultiply(params.C, params.P, params.Q));
+ ret = getNonNegative(ret, nonNegativeMultiply(_cachedParams.C, _cachedParams.P, _cachedParams.Q));
}
else {
throw new RuntimeException("Unsupported dimension:" + dimString + " for operator " + getOp().name());
@@ -703,10 +821,10 @@ public class ConvolutionOp extends Hop implements MultiThreadedHop
if(LOG.isDebugEnabled() && ret < 0) {
LOG.debug("Unknown dimension " + dimString + " for ConvolutionOp:" + op.name() +
- " img_dim=[" + params.N + " " + params.C + " " + params.H + " " + params.W + "]" +
- " filter_dim=[" + params.K + " " + params.C + " " + params.H + " " + params.W + "]" +
- " output_feature_map=[" + params.P + " " + params.Q + "] stride=[" + params.stride_h + " " + params.stride_w + "]" +
- " pad=[" + params.pad_h + " " + params.pad_w + "]");
+ " img_dim=[" + _cachedParams.N + " " + _cachedParams.C + " " + _cachedParams.H + " " + _cachedParams.W + "]" +
+ " filter_dim=[" + _cachedParams.K + " " + _cachedParams.C + " " + _cachedParams.R + " " + _cachedParams.S + "]" +
+ " output_feature_map=[" + _cachedParams.P + " " + _cachedParams.Q + "] stride=[" + _cachedParams.stride_h + " " + _cachedParams.stride_w + "]" +
+ " pad=[" + _cachedParams.pad_h + " " + _cachedParams.pad_w + "]");
}
return ret;
}