You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2018/07/25 01:58:52 UTC
systemml git commit: [SYSTEMML-2109] Codegen support for
maxpool/avgpool DNN operations
Repository: systemml
Updated Branches:
refs/heads/master 15ecb723e -> 99b1c2e25
[SYSTEMML-2109] Codegen support for maxpool/avgpool DNN operations
This patch adds code generation support for maxpool and avgpool DNN
operations to the codegen row-template. This way, entire chains of
conv/maxpool/relu can often be executed as fused operators without
per-operator parallelization barriers.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/99b1c2e2
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/99b1c2e2
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/99b1c2e2
Branch: refs/heads/master
Commit: 99b1c2e252f935bbce5f53574d3e245221da3e68
Parents: 15ecb72
Author: Matthias Boehm <mb...@gmail.com>
Authored: Tue Jul 24 18:32:50 2018 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Tue Jul 24 18:57:00 2018 -0700
----------------------------------------------------------------------
src/main/java/org/apache/sysml/hops/DnnOp.java | 6 ++
.../apache/sysml/hops/codegen/cplan/CNode.java | 24 ++++++++
.../sysml/hops/codegen/cplan/CNodeNary.java | 58 ++++++++++++++++++--
.../sysml/hops/codegen/cplan/CNodeUnary.java | 22 +-------
.../hops/codegen/template/TemplateRow.java | 14 ++++-
.../runtime/codegen/LibSpoofPrimitives.java | 40 ++++++++++++++
.../matrix/data/LibMatrixDNNPooling.java | 16 +++---
.../functions/codegen/RowAggTmplTest.java | 6 +-
.../scripts/functions/codegen/rowAggPattern44.R | 1 +
.../functions/codegen/rowAggPattern44.dml | 1 +
10 files changed, 154 insertions(+), 34 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/main/java/org/apache/sysml/hops/DnnOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/DnnOp.java b/src/main/java/org/apache/sysml/hops/DnnOp.java
index 8dbbeda..3b48371 100644
--- a/src/main/java/org/apache/sysml/hops/DnnOp.java
+++ b/src/main/java/org/apache/sysml/hops/DnnOp.java
@@ -217,6 +217,12 @@ public class DnnOp extends MultiThreadedHop
isEqualAndKnown(param1.H, param2.H) && isEqualAndKnown(param1.W, param2.W);
}
+ public boolean isStride1Pad0() { //true iff both strides are 1 and both paddings are 0
+ DnnParameters params = parseInput();
+ boolean unitStride = (params.stride_h == 1 && params.stride_w == 1);
+ return unitStride && params.pad_h == 0 && params.pad_w == 0;
+ }
+
private static boolean isEqualAndKnown(int val1, int val2) {
return val1 >= 0 && val2 >= 0 && val1 == val2;
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java
index b0efb42..6dab878 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java
@@ -21,6 +21,7 @@ package org.apache.sysml.hops.codegen.cplan;
import java.util.ArrayList;
+import org.apache.sysml.hops.codegen.template.TemplateUtils;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysml.runtime.util.UtilFunctions;
@@ -204,4 +205,27 @@ public abstract class CNode
&& _dataType == cthat._dataType
&& _literal == cthat._literal;
}
+
+ protected String replaceUnaryPlaceholders(String tmp, String varj, boolean vectIn) {
+ //replace sparse (%IN1v%, %IN1i%) and dense (%IN1%) placeholders of the first input; with vectIn, %IN1% becomes a values(..) access
+ tmp = tmp.replace("%IN1v%", varj+"vals");
+ tmp = tmp.replace("%IN1i%", varj+"ix");
+ tmp = tmp.replace("%IN1%",
+ (vectIn && TemplateUtils.isMatrix(_inputs.get(0))) ? varj + ".values(rix)" :
+ (vectIn && TemplateUtils.isRowVector(_inputs.get(0)) ? varj + ".values(0)" : varj));
+
+ //replace start position of main input: varj+"i" for matrix data inputs not starting with "b", varj.pos(rix) for "b" matrices, otherwise "0"
+ String spos = (_inputs.get(0) instanceof CNodeData
+ && _inputs.get(0).getDataType().isMatrix()) ? !varj.startsWith("b") ?
+ varj+"i" : TemplateUtils.isMatrix(_inputs.get(0)) ? varj + ".pos(rix)" : "0" : "0";
+
+ tmp = tmp.replace("%POS1%", spos);
+ tmp = tmp.replace("%POS2%", spos);
+
+ //replace length (only if the first input is a matrix; otherwise %LEN% is left untouched)
+ if( _inputs.get(0).getDataType().isMatrix() )
+ tmp = tmp.replace("%LEN%", _inputs.get(0).getVectorLength());
+
+ return tmp;
+ }
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java
index 7f19194..e720601 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java
@@ -20,15 +20,21 @@
package org.apache.sysml.hops.codegen.cplan;
import java.util.ArrayList;
+import java.util.List;
+import org.apache.commons.lang3.StringUtils;
import org.apache.sysml.hops.codegen.template.TemplateUtils;
import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.runtime.util.DnnUtils;
import org.apache.sysml.runtime.util.UtilFunctions;
public class CNodeNary extends CNode
{
public enum NaryType {
- VECT_CBIND;
+ VECT_CBIND,
+ VECT_MAX_POOL,
+ VECT_AVG_POOL;
+
public static boolean contains(String value) {
for( NaryType bt : values() )
if( bt.name().equals(value) )
@@ -56,12 +62,19 @@ public class CNodeNary extends CNode
off += input._cols;
}
return sb.toString();
+ case VECT_MAX_POOL:
+ case VECT_AVG_POOL:
+ String vectName = (this==VECT_MAX_POOL) ? "Maxpool" : "Avgpool";
+ String paramStr = getPoolingParameterString(inputs);
+ return sparseGen ?
+ " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN1i%, %POS1%, alen, len, "+paramStr+");\n" :
+ " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %POS1%, %LEN%, "+paramStr+");\n";
default:
throw new RuntimeException("Invalid nary type: "+this.toString());
}
}
public boolean isVectorPrimitive() {
- return this == VECT_CBIND;
+ return this == VECT_CBIND || this == VECT_MAX_POOL || this == VECT_AVG_POOL;
}
}
@@ -90,10 +103,17 @@ public class CNodeNary extends CNode
sb.append(in.codegen(sparse));
//generate nary operation (use sparse template, if data input)
+ boolean lsparse = sparse && (_inputs.get(0) instanceof CNodeData
+ && _inputs.get(0).getVarname().startsWith("a")
+ && !_inputs.get(0).isLiteral());
String var = createVarname();
- String tmp = _type.getTemplate(sparse, _cols, _inputs);
+ String tmp = _type.getTemplate(lsparse, _cols, _inputs);
tmp = tmp.replace("%TMP%", var);
+ //replace sparse and dense inputs
+ String varj = _inputs.get(0).getVarname();
+ tmp = replaceUnaryPlaceholders(tmp, varj, false);
+
sb.append(tmp);
//mark as generated
@@ -105,7 +125,9 @@ public class CNodeNary extends CNode
@Override
public String toString() {
switch(_type) {
- case VECT_CBIND: return "n(cbind)";
+ case VECT_CBIND: return "n(cbind)";
+ case VECT_MAX_POOL: return "n(maxpool)";
+ case VECT_AVG_POOL: return "n(avgpool)";
default:
return "m("+_type.name().toLowerCase()+")";
}
@@ -121,6 +143,19 @@ public class CNodeNary extends CNode
_cols += in._cols;
_dataType = DataType.MATRIX;
break;
+ case VECT_MAX_POOL:
+ case VECT_AVG_POOL: //only stride 1, pad 0
+ int C = Integer.parseInt(_inputs.get(6).getVarname());
+ int H = Integer.parseInt(_inputs.get(7).getVarname());
+ int W = Integer.parseInt(_inputs.get(8).getVarname());
+ int R = Integer.parseInt(_inputs.get(11).getVarname());
+ int S = Integer.parseInt(_inputs.get(12).getVarname());
+ long P = DnnUtils.getP(H, R, 1, 0);
+ long Q = DnnUtils.getQ(W, S, 1, 0);
+ _rows = _inputs.get(0)._rows;
+ _cols = C * P * Q;
+ _dataType = DataType.MATRIX;
+ break;
}
}
@@ -142,4 +177,19 @@ public class CNodeNary extends CNode
return super.equals(that)
&& _type == that._type;
}
+
+ private static String getPoolingParameterString(List<CNode> inputs) {
+ //parse C, H, W, R, S from the variable names of inputs 6-8 and 11-12 (must be integer strings, else NumberFormatException) and derive P, Q via DnnUtils with arguments 1 and 0
+ int C = Integer.parseInt(inputs.get(6).getVarname());
+ int H = Integer.parseInt(inputs.get(7).getVarname());
+ int W = Integer.parseInt(inputs.get(8).getVarname());
+ int R = Integer.parseInt(inputs.get(11).getVarname());
+ int S = Integer.parseInt(inputs.get(12).getVarname());
+ int P = (int) DnnUtils.getP(H, R, 1, 0);
+ int Q = (int) DnnUtils.getQ(W, S, 1, 0);
+
+ //construct parameter string: "rix, " followed by C,P,Q,R,S,H,W joined by ','
+ return "rix, " + StringUtils.join(
+ new int[]{C, P, Q, R, S, H, W}, ',');
+ }
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
index ba41fad..21f7fe7 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
@@ -23,7 +23,6 @@ import java.util.Arrays;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang.StringUtils;
-import org.apache.sysml.hops.codegen.template.TemplateUtils;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.runtime.util.UtilFunctions;
@@ -214,27 +213,10 @@ public class CNodeUnary extends CNode
String tmp = _type.getTemplate(lsparse);
tmp = tmp.replace("%TMP%", var);
- String varj = _inputs.get(0).getVarname();
-
//replace sparse and dense inputs
+ String varj = _inputs.get(0).getVarname();
boolean vectIn = varj.startsWith("b") && !_type.isScalarLookup();
- tmp = tmp.replace("%IN1v%", varj+"vals");
- tmp = tmp.replace("%IN1i%", varj+"ix");
- tmp = tmp.replace("%IN1%",
- (vectIn && TemplateUtils.isMatrix(_inputs.get(0))) ? varj + ".values(rix)" :
- (vectIn && TemplateUtils.isRowVector(_inputs.get(0)) ? varj + ".values(0)" : varj));
-
- //replace start position of main input
- String spos = (_inputs.get(0) instanceof CNodeData
- && _inputs.get(0).getDataType().isMatrix()) ? !varj.startsWith("b") ?
- varj+"i" : TemplateUtils.isMatrix(_inputs.get(0)) ? varj + ".pos(rix)" : "0" : "0";
-
- tmp = tmp.replace("%POS1%", spos);
- tmp = tmp.replace("%POS2%", spos);
-
- //replace length
- if( _inputs.get(0).getDataType().isMatrix() )
- tmp = tmp.replace("%LEN%", _inputs.get(0).getVectorLength());
+ tmp = replaceUnaryPlaceholders(tmp, varj, vectIn);
sb.append(tmp);
http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
index 9df67d0..9885909 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
@@ -115,7 +115,9 @@ public class TemplateRow extends TemplateBase
&& HopRewriteUtils.isColumnRangeIndexing((IndexingOp)hop))
|| (HopRewriteUtils.isDnn(hop, OpOpDnn.BIASADD, OpOpDnn.BIASMULT)
&& hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown()
- && hop.getInput().get(0).getDim2()>1);
+ && hop.getInput().get(0).getDim2()>1)
+ || (HopRewriteUtils.isDnn(hop, OpOpDnn.MAX_POOL, OpOpDnn.AVG_POOL)
+ && hop.getInput().get(0).dimsKnown() && ((DnnOp)hop).isStride1Pad0());
}
@Override
@@ -140,6 +142,8 @@ public class TemplateRow extends TemplateBase
|| (HopRewriteUtils.isDnn(hop, OpOpDnn.BIASADD, OpOpDnn.BIASMULT)
&& hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown()
&& hop.getInput().get(0).getDim2()>1)
+ || (HopRewriteUtils.isDnn(hop, OpOpDnn.MAX_POOL, OpOpDnn.AVG_POOL)
+ && hop.getInput().get(0).dimsKnown() && ((DnnOp)hop).isStride1Pad0())
|| isPartOfValidCumAggChain(hop) //cum* with transpose
|| isPartOfValidTransposeMMChain(hop)); //t(f(X))%*%X
}
@@ -156,6 +160,8 @@ public class TemplateRow extends TemplateBase
|| (HopRewriteUtils.isDnn(hop, OpOpDnn.BIASADD, OpOpDnn.BIASMULT)
&& hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown()
&& hop.getInput().get(0).getDim2()>1 )
+ || (HopRewriteUtils.isDnn(hop, OpOpDnn.MAX_POOL, OpOpDnn.AVG_POOL)
+ && hop.getInput().get(0).dimsKnown() && ((DnnOp)hop).isStride1Pad0())
|| (HopRewriteUtils.isDataGenOpWithLiteralInputs(input, DataGenMethod.SEQ)
&& HopRewriteUtils.hasOnlyUnaryBinaryParents(input, false))
|| (hop instanceof AggBinaryOp
@@ -476,6 +482,12 @@ public class TemplateRow extends TemplateBase
out = new CNodeBinary(cdata1, cdata2,
BinType.valueOf("VECT_"+((DnnOp)hop).getOp().name()));
}
+ else if( HopRewriteUtils.isDnn(hop, OpOpDnn.MAX_POOL, OpOpDnn.AVG_POOL) ) {
+ CNode[] in = hop.getInput().stream().map(h ->
+ tmp.get(h.getHopID())).toArray(CNode[]::new);
+ out = new CNodeNary(in, CNodeNary.NaryType
+ .valueOf("VECT_"+((DnnOp)hop).getOp().name()));
+ }
else if( hop instanceof NaryOp ) {
CNode[] inputs = new CNode[hop.getInput().size()];
for( int i=0; i<hop.getInput().size(); i++ ) {
http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
index fc0c1d2..c1460ce 100644
--- a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
+++ b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
@@ -26,7 +26,9 @@ import org.apache.sysml.runtime.functionobjects.BitwAnd;
import org.apache.sysml.runtime.functionobjects.IntegerDivide;
import org.apache.sysml.runtime.functionobjects.Modulus;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
+import org.apache.sysml.runtime.matrix.data.LibMatrixDNNPooling;
import org.apache.sysml.runtime.matrix.data.LibMatrixMult;
+import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.PoolingType;
/**
* This library contains all vector primitives that are used in
@@ -2052,7 +2054,45 @@ public class LibSpoofPrimitives
LibMatrixDNN.multBias(c, b, 1, b.length, len/b.length);
return c;
}
+
+ //maxpool: dense max pooling of the row starting at a[ai] into a newly allocated vector of C*P*Q cells;
+ //the row range is passed as [0,1) because the offsets ai (input) and 0 (output) already address the row
+ public static double[] vectMaxpoolWrite(double[] a, int ai, int len, int rix, int C, int P, int Q, int R, int S, int H, int W) {
+ double[] c = allocVector(C*P*Q, true);
+ LibMatrixDNNPooling.poolingDenseStride1Pad0(PoolingType.MAX,
+ -Double.MAX_VALUE, 1, a, c, 0, 1, ai, 0, C, P, Q, R, S, H, W);
+ return c;
+ }
+
+ public static double[] vectMaxpoolWrite(double[] avals, int[] aix, int ai, int alen, int len, int rix, int C, int P, int Q, int R, int S, int H, int W) {
+ double[] a = allocVector(len, true); //dense buffer for the sparse input row
+ double[] c = allocVector(C*P*Q, true);
+ for(int k=ai; k<ai+alen; k++)
+ a[aix[k]] = avals[k]; //scatter the non-zero values to their column positions
+ LibMatrixDNNPooling.poolingDenseStride1Pad0(PoolingType.MAX, //row range [0,1): both buffers are addressed from offset 0
+ -Double.MAX_VALUE, 1, a, c, 0, 1, 0, 0, C, P, Q, R, S, H, W);
+ return c;
+ }
+
+ //avgpool: dense average pooling of the row starting at a[ai] into a new C*P*Q vector; row range [0,1) as ai and 0 already address the row
+ public static double[] vectAvgpoolWrite(double[] a, int ai, int len, int rix, int C, int P, int Q, int R, int S, int H, int W) {
+ double[] c = allocVector(C*P*Q, true);
+ LibMatrixDNNPooling.poolingDenseStride1Pad0(PoolingType.AVG,
+ 0, 1d/(R*S), a, c, 0, 1, ai, 0, C, P, Q, R, S, H, W); //1d/(R*S): the int division 1/(R*S) is 0 for R*S>1
+ return c;
+ }
+
+ public static double[] vectAvgpoolWrite(double[] avals, int[] aix, int ai, int alen, int len, int rix, int C, int P, int Q, int R, int S, int H, int W) {
+ double[] a = allocVector(len, true); //dense buffer for the sparse input row
+ double[] c = allocVector(C*P*Q, true);
+ for(int k=ai; k<ai+alen; k++)
+ a[aix[k]] = avals[k]; //scatter the non-zero values to their column positions
+ LibMatrixDNNPooling.poolingDenseStride1Pad0(PoolingType.AVG, //row range [0,1): both buffers are addressed from offset 0
+ 0, 1d/(R*S), a, c, 0, 1, 0, 0, C, P, Q, R, S, H, W); //1d/(R*S): the int division 1/(R*S) is 0 for R*S>1
+ return c;
+ }
+
//complex builtin functions that are not directly generated
//(included here in order to reduce the number of imports)
http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPooling.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPooling.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPooling.java
index 4ff8e5e..7fb33a4 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPooling.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPooling.java
@@ -97,8 +97,8 @@ public class LibMatrixDNNPooling {
return ret;
}
- public static void poolingDenseStride1Pad0(PoolingType pType, double minVal, double pFact,
- double[] in, double[] out, int rl, int ru, int C, int P, int Q, int R, int S, int H, int W) {
+ public static void poolingDenseStride1Pad0(PoolingType pType, double minVal, double pFact, double[] in,
+ double[] out, int rl, int ru, int ii, int oi, int C, int P, int Q, int R, int S, int H, int W) {
boolean max = (pType == PoolingType.MAX);
int CHW = C * H * W;
@@ -106,9 +106,9 @@ public class LibMatrixDNNPooling {
//quick-path w/o materialized index arrays and
//simplified inner loops for P = 1, Q = 1, W = 1
int lenh = Math.min(R,H);
- for(int i = rl, oix=rl*C; i < ru; i++, oix+=C)
- for (int c = 0, off=i*CHW; c < C; c++, off+=H) {
- out[oix+c] = max ? max(minVal, in, off, lenh) :
+ for(int i = rl; i < ru; i++, oi+=C)
+ for (int c = 0, off=ii+(i-rl)*CHW; c < C; c++, off+=H) {
+ out[oi+c] = max ? max(minVal, in, off, lenh) :
avg(minVal, in, off, lenh, pFact);
}
}
@@ -117,7 +117,7 @@ public class LibMatrixDNNPooling {
Arrays.fill(out, rl*CPQ, ru*CPQ, minVal);
//quick-path w/o materialized index arrays
for(int i = rl; i < ru; i++)
- for (int c = 0, off=i*CHW, oix=i*CPQ; c < C; c++, off+=HW)
+ for (int c = 0, off=ii+(i-rl)*CHW, oix=oi+(i-rl)*CPQ; c < C; c++, off+=HW) //advance the output offset per row, like the input offset
for (int p = 0; p < P; p++, oix+=Q)
for (int h = p; h < Math.min(p+R,H); h++)
for (int q = 0, off2=off+h*W; q < Q; q++) {
@@ -139,7 +139,7 @@ public class LibMatrixDNNPooling {
_rl = rl; _ru = ru;
_params = params;
_poolingType = poolingType;
- _poolingMultiplier = Math.pow(params.R*params.S, -1);
+ _poolingMultiplier = 1d/(params.R*params.S); //floating-point division; an int 1/(R*S) would truncate to 0
}
@Override
@@ -157,7 +157,7 @@ public class LibMatrixDNNPooling {
if( _params.isStride1Pad0() ) {
poolingDenseStride1Pad0(_poolingType, minValForMaxPoolOperations,
- _poolingMultiplier, in, out, _rl, _ru, C, P, Q, R, S, H, W);
+ _poolingMultiplier, in, out, _rl, _ru, _rl*CHW, _rl*CPQ, C, P, Q, R, S, H, W);
}
else { //general case
//thread-local initialization of output block
http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
index 04891d0..48555ae 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
@@ -80,7 +80,7 @@ public class RowAggTmplTest extends AutomatedTestBase
private static final String TEST_NAME41 = TEST_NAME+"41"; //X*rowSums(X/seq(1,N)+t(seq(M,1)))
private static final String TEST_NAME42 = TEST_NAME+"42"; //X/rowSums(min(X, Y, Z))
private static final String TEST_NAME43 = TEST_NAME+"43"; //bias_add(X,B) + bias_mult(X,B)
- private static final String TEST_NAME44 = TEST_NAME+"44"; //maxpool(X - mean(X));
+ private static final String TEST_NAME44 = TEST_NAME+"44"; //maxpool(X - mean(X)) + 7;
private static final String TEST_DIR = "functions/codegen/";
private static final String TEST_CLASS_DIR = TEST_DIR + RowAggTmplTest.class.getSimpleName() + "/";
@@ -817,6 +817,10 @@ public class RowAggTmplTest extends AutomatedTestBase
if( testname.equals(TEST_NAME42) )
Assert.assertTrue(!heavyHittersContainsSubString("min","nmin")
&& !heavyHittersContainsSubString("spoof", 2));
+ if( testname.equals(TEST_NAME44) )
+ Assert.assertTrue(!heavyHittersContainsSubString("maxpooling")
+ && !heavyHittersContainsSubString("spoof", 2));
+
}
finally {
rtplatform = platformOld;
http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/test/scripts/functions/codegen/rowAggPattern44.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern44.R b/src/test/scripts/functions/codegen/rowAggPattern44.R
index 99ba0b0..7269df3 100644
--- a/src/test/scripts/functions/codegen/rowAggPattern44.R
+++ b/src/test/scripts/functions/codegen/rowAggPattern44.R
@@ -95,5 +95,6 @@ max_pool <- function(X, N, C, Hin, Win, Hf, Wf,
}
R = max_pool(X, numImg, numChannels, imgSize*imgSize, 1, poolSize1, poolSize2, stride, stride)
+R = R + 7;
writeMM(as(R,"CsparseMatrix"), paste(args[2], "S", sep=""))
http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/test/scripts/functions/codegen/rowAggPattern44.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern44.dml b/src/test/scripts/functions/codegen/rowAggPattern44.dml
index f236451..f5e7b6c 100644
--- a/src/test/scripts/functions/codegen/rowAggPattern44.dml
+++ b/src/test/scripts/functions/codegen/rowAggPattern44.dml
@@ -31,5 +31,6 @@ while(FALSE){}
X = X - rowMeans(X);
R = max_pool(X, stride=[stride, stride], padding=[pad, pad], input_shape=[numImg, numChannels, imgSize*imgSize, 1], pool_size=[poolSize1, poolSize2]);
+R = R + 7;
write(R, $1, format="text");