You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2018/12/11 19:23:51 UTC
[2/2] systemml git commit: [SYSTEMML-2507] New rewrites for
cumulative aggregate patterns
[SYSTEMML-2507] New rewrites for cumulative aggregate patterns
This patch adds the following simplification rewrites as well as related
tests:
(a) X * cumsum(diag(matrix(1,nrow(X),1))) -> lower.tri, if X is square
(b) colSums(cumsum(X)) -> colSums(X*seq(nrow(X),1))
(c) rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X)
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/9a1f64b4
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/9a1f64b4
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/9a1f64b4
Branch: refs/heads/master
Commit: 9a1f64b42c177a82a98716ad9ef34d4d266178d2
Parents: b96807b
Author: Matthias Boehm <mb...@gmail.com>
Authored: Tue Dec 11 20:10:23 2018 +0100
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Tue Dec 11 20:10:46 2018 +0100
----------------------------------------------------------------------
.../RewriteAlgebraicSimplificationDynamic.java | 33 ++++-
.../RewriteAlgebraicSimplificationStatic.java | 45 +++++++
.../hops/rewrite/RewriteGPUSpecificOps.java | 26 ++--
.../misc/RewriteCumulativeAggregatesTest.java | 126 +++++++++++++++++++
.../misc/RewriteCumulativeAggregates.R | 43 +++++++
.../misc/RewriteCumulativeAggregates.dml | 49 ++++++++
6 files changed, 306 insertions(+), 16 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/9a1f64b4/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 36864aa..9556181 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -175,6 +175,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
hi = simplifyMatrixMultDiag(hop, hi, i); //e.g., diag(X)%*%Y -> X*Y, if ncol(Y)==1 / -> Y*X if ncol(Y)>1
hi = simplifyDiagMatrixMult(hop, hi, i); //e.g., diag(X%*%Y)->rowSums(X*t(Y)); if col vector
hi = simplifySumDiagToTrace(hi); //e.g., sum(diag(X)) -> trace(X); if col vector
+ hi = simplifyLowerTriExtraction(hop, hi, i); //e.g., X * cumsum(diag(matrix(1,nrow(X),1))) -> lower.tri
hi = pushdownBinaryOperationOnDiag(hop, hi, i); //e.g., diag(X)*7 -> diag(X*7); if col vector
hi = pushdownSumOnAdditiveBinary(hop, hi, i); //e.g., sum(A+B) -> sum(A)+sum(B); if dims(A)==dims(B)
if(OptimizerUtils.ALLOW_OPERATOR_FUSION) {
@@ -1046,7 +1047,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
if( hi instanceof AggUnaryOp )
{
AggUnaryOp au = (AggUnaryOp) hi;
- if( au.getOp()==AggOp.SUM && au.getDirection()==Direction.RowCol ) //sum
+ if( au.getOp()==AggOp.SUM && au.getDirection()==Direction.RowCol ) //sum
{
Hop hi2 = au.getInput().get(0);
if( hi2 instanceof ReorgOp && ((ReorgOp)hi2).getOp()==ReOrgOp.DIAG && hi2.getDim2()==1 ) //diagM2V
@@ -1054,7 +1055,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
Hop hi3 = hi2.getInput().get(0);
//remove diag operator
- HopRewriteUtils.replaceChildReference(au, hi2, hi3, 0);
+ HopRewriteUtils.replaceChildReference(au, hi2, hi3, 0);
HopRewriteUtils.cleanupUnreferenced(hi2);
//change sum to trace
@@ -1063,12 +1064,38 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
LOG.debug("Applied simplifySumDiagToTrace");
}
}
-
}
return hi;
}
+ /**
+  * Replaces a multiply with a lower-triangular mask of ones by a direct
+  * lower-triangular extraction:
+  * X * cumsum(diag(matrix(1,nrow(X),1))) -> lower.tri(target=X, diag=TRUE, values=TRUE).
+  * The pattern is only matched with the mask as the right-hand-side input,
+  * for square outputs with more than one row, and if the multiply is the
+  * only consumer of the cumsum.
+  *
+  * @param parent parent hop that references hi as input at position pos
+  * @param hi current hop, root of the candidate pattern
+  * @param pos position of hi in the inputs of parent
+  * @return the new lower.tri hop if the rewrite applied, otherwise hi
+  */
+ private static Hop simplifyLowerTriExtraction(Hop parent, Hop hi, int pos) {
+     //pattern: X * cumsum(diag(matrix(1,nrow(X),1))) -> lower.tri (only right)
+     if( HopRewriteUtils.isBinary(hi, OpOp2.MULT)
+         && hi.getDim1() == hi.getDim2() && hi.getDim1() > 1 ) {
+         Hop left = hi.getInput().get(0);
+         Hop right = hi.getInput().get(1);
+
+         if( HopRewriteUtils.isUnary(right, OpOp1.CUMSUM) && right.getParent().size()==1
+             && HopRewriteUtils.isReorg(right.getInput().get(0), ReOrgOp.DIAG)
+             && HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput().get(0).getInput().get(0), 1d))
+         {
+             //keep the lower triangle of X including the diagonal and its values
+             LinkedHashMap<String,Hop> args = new LinkedHashMap<>();
+             args.put("target", left);
+             args.put("diag", new LiteralOp(true));
+             args.put("values", new LiteralOp(true));
+             Hop hnew = HopRewriteUtils.createParameterizedBuiltinOp(
+                 left, args, ParamBuiltinOp.LOWER_TRI);
+
+             //rewire exactly the visited input position of the parent
+             HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
+             //detach the old multiply and its mask only once they are unreferenced;
+             //if hi still has other consumers it must stay intact for them, and
+             //otherwise left must not keep the dead multiply as a parent
+             HopRewriteUtils.cleanupUnreferenced(hi, right);
+
+             hi = hnew;
+             LOG.debug("Applied simplifyLowerTriExtraction");
+         }
+     }
+     return hi;
+ }
+
@SuppressWarnings("unchecked")
private static Hop pushdownBinaryOperationOnDiag(Hop parent, Hop hi, int pos)
{
http://git-wip-us.apache.org/repos/asf/systemml/blob/9a1f64b4/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 62a5d4f..9a3956c 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -183,6 +183,9 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
}
hi = simplifyOuterSeqExpand(hop, hi, i); //e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false)
hi = simplifyBinaryComparisonChain(hop, hi, i); //e.g., outer(v1,v2,"==")==1 -> outer(v1,v2,"=="), outer(v1,v2,"==")==0 -> outer(v1,v2,"!="),
+ hi = simplifyCumsumColOrFullAggregates(hi); //e.g., colSums(cumsum(X)) -> cumSums(X*seq(nrow(X),1))
+ hi = simplifyCumsumReverse(hop, hi, i); //e.g., rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X)
+
//hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
@@ -1844,6 +1847,48 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
return hi;
}
+ /**
+  * Rewrites a column-wise or full sum over a cumulative sum into the same
+  * aggregate over the row-weighted input, so the cumulative sum itself is
+  * no longer computed:
+  * colSums(cumsum(X)) -> colSums(X*seq(nrow(X),1)), and likewise for sum().
+  * The rewrite modifies hi in place and is only applied if hi is the single
+  * consumer of the cumsum.
+  *
+  * @param hi current hop, root of the candidate pattern
+  * @return hi, with its input rewired if the rewrite applied
+  */
+ private static Hop simplifyCumsumColOrFullAggregates(Hop hi) {
+     //only column-wise and full sum aggregates qualify
+     boolean sumAgg = HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, Direction.Col)
+         || HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, Direction.RowCol);
+     if( !sumAgg )
+         return hi;
+
+     //the aggregated input must be a cumsum without any other consumer
+     Hop cumsumX = hi.getInput().get(0);
+     if( !HopRewriteUtils.isUnary(cumsumX, OpOp1.CUMSUM)
+         || cumsumX.getParent().size()!=1 )
+         return hi;
+
+     //pattern: colSums(cumsum(X)) -> colSums(X*seq(nrow(X),1))
+     Hop X = cumsumX.getInput().get(0);
+     Hop weights = HopRewriteUtils.createSeqDataGenOp(X, false);
+     Hop weighted = HopRewriteUtils.createBinary(X, weights, OpOp2.MULT);
+     HopRewriteUtils.replaceChildReference(hi, cumsumX, weighted);
+     HopRewriteUtils.removeAllChildReferences(cumsumX);
+     LOG.debug("Applied simplifyCumsumColOrFullAggregates (line "+hi.getBeginLine()+")");
+     return hi;
+ }
+
+ /**
+  * Rewrites a reverse cumulative sum into an expression without row reversals:
+  * rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X).
+  * The rewrite is only applied if the cumsum and the inner rev each have
+  * exactly one consumer.
+  *
+  * @param parent parent hop that references hi as input at position pos
+  * @param hi current hop, root of the candidate pattern
+  * @param pos position of hi in the inputs of parent
+  * @return the root of the replacement expression if applied, otherwise hi
+  */
+ private static Hop simplifyCumsumReverse(Hop parent, Hop hi, int pos) {
+     //outer rev
+     if( !HopRewriteUtils.isReorg(hi, ReOrgOp.REV) )
+         return hi;
+
+     //cumsum that is consumed only by the outer rev
+     Hop cumsumX = hi.getInput().get(0);
+     if( !HopRewriteUtils.isUnary(cumsumX, OpOp1.CUMSUM)
+         || cumsumX.getParent().size()!=1 )
+         return hi;
+
+     //inner rev that is consumed only by the cumsum
+     Hop revX = cumsumX.getInput().get(0);
+     if( !HopRewriteUtils.isReorg(revX, ReOrgOp.REV)
+         || revX.getParent().size()!=1 )
+         return hi;
+
+     //pattern: rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X)
+     Hop X = revX.getInput().get(0);
+     Hop colSums = HopRewriteUtils.createAggUnaryOp(X, AggOp.SUM, Direction.Col);
+     Hop plus = HopRewriteUtils.createBinary(X, colSums, OpOp2.PLUS);
+     Hop cumsum = HopRewriteUtils.createUnary(X, OpOp1.CUMSUM);
+     Hop minus = HopRewriteUtils.createBinary(plus, cumsum, OpOp2.MINUS);
+     HopRewriteUtils.replaceChildReference(parent, hi, minus, pos);
+     HopRewriteUtils.cleanupUnreferenced(hi, cumsumX, revX);
+
+     LOG.debug("Applied simplifyCumsumReverse (line "+minus.getBeginLine()+")");
+     return minus;
+ }
+
/**
* NOTE: currently disabled since this rewrite is INVALID in the
* presence of NaNs (because (NaN!=NaN) is true).
http://git-wip-us.apache.org/repos/asf/systemml/blob/9a1f64b4/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
index ab40d7b..1d87c09 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
@@ -176,19 +176,19 @@ public class RewriteGPUSpecificOps extends HopRewriteRuleWithPatternMatcher {
// norm = bias_multiply(centered, cache_inv_var) # shape (N, C*Hin*Win)
// # Compute gradients during training
// dgamma = util::channel_sums(dout*norm, C, Hin, Win)
- private static final HopDagPatternMatcher _batchNormDGamma;
- static {
- _batchNormDGamma = util_channel_sums(
- mult( leaf("dout", MATRIX).fitsOnGPU(3),
- bias_multiply(bias_add(leaf("X", MATRIX), unaryMinus(leaf("ema_mean", MATRIX))),
- leaf("ema_var", MATRIX))), leaf("C", SCALAR), leaf("HW", SCALAR));
- }
- private static final Function<Hop, Hop> _batchNormDGammaReplacer = hi -> {
- LOG.debug("Applied batchNormDGamma rewrite.");
- Hop newHop = HopRewriteUtils.createDnnOp(_batchNormDGamma, OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA,
- "ema_mean", "dout", "X", "ema_var");
- return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
- };
+// private static final HopDagPatternMatcher _batchNormDGamma;
+// static {
+// _batchNormDGamma = util_channel_sums(
+// mult( leaf("dout", MATRIX).fitsOnGPU(3),
+// bias_multiply(bias_add(leaf("X", MATRIX), unaryMinus(leaf("ema_mean", MATRIX))),
+// leaf("ema_var", MATRIX))), leaf("C", SCALAR), leaf("HW", SCALAR));
+// }
+// private static final Function<Hop, Hop> _batchNormDGammaReplacer = hi -> {
+// LOG.debug("Applied batchNormDGamma rewrite.");
+// Hop newHop = HopRewriteUtils.createDnnOp(_batchNormDGamma, OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA,
+// "ema_mean", "dout", "X", "ema_var");
+// return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
+// };
// Pattern 3:
private static final HopDagPatternMatcher _batchNormTest;
http://git-wip-us.apache.org/repos/asf/systemml/blob/9a1f64b4/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCumulativeAggregatesTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCumulativeAggregatesTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCumulativeAggregatesTest.java
new file mode 100644
index 0000000..da13502
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCumulativeAggregatesTest.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import java.util.HashMap;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+
+/**
+ * Integration tests for the cumulative aggregate rewrites. Each test case runs
+ * the DML script and the R reference script on the same generated input,
+ * compares the results, and, when rewrites are enabled, checks via the heavy
+ * hitter statistics that the rewritten operator was not executed.
+ */
+public class RewriteCumulativeAggregatesTest extends AutomatedTestBase
+{
+    private static final String TEST_NAME = "RewriteCumulativeAggregates";
+    private static final String TEST_DIR = "functions/misc/";
+    private static final String TEST_CLASS_DIR = TEST_DIR + RewriteCumulativeAggregatesTest.class.getSimpleName() + "/";
+
+    private static final int rows = 1234;
+    private static final int cols = 7;
+
+    @Override
+    public void setUp() {
+        TestUtils.clearAssertionInformation();
+        addTestConfiguration( TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] { "R" }) );
+    }
+
+    @Test
+    public void testCumAggRewrite1False() {
+        testCumAggRewrite(1, false);
+    }
+
+    @Test
+    public void testCumAggRewrite1True() {
+        testCumAggRewrite(1, true);
+    }
+
+    @Test
+    public void testCumAggRewrite2False() {
+        testCumAggRewrite(2, false);
+    }
+
+    @Test
+    public void testCumAggRewrite2True() {
+        testCumAggRewrite(2, true);
+    }
+
+    @Test
+    public void testCumAggRewrite3False() {
+        testCumAggRewrite(3, false);
+    }
+
+    @Test
+    public void testCumAggRewrite3True() {
+        testCumAggRewrite(3, true);
+    }
+
+    @Test
+    public void testCumAggRewrite4False() {
+        testCumAggRewrite(4, false);
+    }
+
+    @Test
+    public void testCumAggRewrite4True() {
+        testCumAggRewrite(4, true);
+    }
+
+    /**
+     * Runs one script case with or without algebraic simplification rewrites
+     * and compares the DML result against the R result.
+     *
+     * @param num test case number passed to both scripts (1 to 4)
+     * @param rewrites whether algebraic simplification rewrites are enabled
+     */
+    private void testCumAggRewrite(int num, boolean rewrites)
+    {
+        //remember all global flags modified below so they can be restored
+        boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+        boolean oldSparkConfig = DMLScript.USE_LOCAL_SPARK_CONFIG;
+
+        try {
+            TestConfiguration config = getTestConfiguration(TEST_NAME);
+            loadTestConfiguration(config);
+
+            String HOME = SCRIPT_DIR + TEST_DIR;
+            fullDMLScriptName = HOME + TEST_NAME + ".dml";
+            programArgs = new String[]{ "-stats", "-args",
+                input("A"), String.valueOf(num), output("R") };
+            rCmd = getRCmd(inputDir(), String.valueOf(num), expectedDir());
+            OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+            DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+            //generate input data: square input for case 1 (lower.tri pattern),
+            //single-row input for case 4, and rows x cols otherwise
+            double[][] A = getRandomMatrix((num==4)?1:rows,
+                (num==1)?rows:cols, -1, 1, 0.9, 7);
+            writeInputMatrixWithMTD("A", A, true);
+
+            //run DML script and R reference script
+            runTest(true, false, null, -1);
+            runRScript(true);
+
+            //compare matrices
+            HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
+            HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R");
+            TestUtils.compareMatrices(dmlfile, rfile, 1e-7, "Stat-DML", "Stat-R");
+
+            //check applied rewrites: the replaced operator must not show up in
+            //the heavy hitters ("rev" for case 2, "ucumk+" for all other cases)
+            if( rewrites )
+                Assert.assertFalse(heavyHittersContainsString((num==2) ? "rev" : "ucumk+"));
+        }
+        finally {
+            //restore global state to avoid side effects on subsequent tests
+            OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+            DMLScript.USE_LOCAL_SPARK_CONFIG = oldSparkConfig;
+        }
+    }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/9a1f64b4/src/test/scripts/functions/misc/RewriteCumulativeAggregates.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteCumulativeAggregates.R b/src/test/scripts/functions/misc/RewriteCumulativeAggregates.R
new file mode 100644
index 0000000..f8a8576
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteCumulativeAggregates.R
@@ -0,0 +1,43 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE) # args[1]: input path prefix, args[2]: test case number, args[3]: output path prefix
+options(digits=22)
+library("Matrix")
+ # read the input matrix "A.mtx" and convert it to a dense matrix
+X = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+num = as.integer(args[2]);
+ # compute the reference result R for the selected test case
+#note: cumsum and rev are applied to vectors only, hence apply() per column and explicit row reversal via indexing
+if( num == 1 ) {
+ R = lower.tri(X,diag=TRUE) * X; # case 1: keep the lower triangle of X including the diagonal
+} else if( num == 2 ) {
+ A = X[seq(nrow(X),1),] # case 2: reverse the row order
+ R = apply(A, 2, cumsum); # column-wise cumulative sums of the reversed matrix
+ R = R[seq(nrow(X),1),] # reverse the row order back
+} else if( num == 3 ) {
+ R = t(as.matrix(colSums(apply(X, 2, cumsum)))); # case 3: column sums of the column-wise cumulative sums, as a single-row matrix
+} else if( num == 4 ) {
+ R = X; # case 4: X is returned unchanged (the Java test generates a single-row input for this case)
+}
+ # write the result as sparse matrix to file "R" under the output prefix
+writeMM(as(R, "CsparseMatrix"), paste(args[3], "R", sep=""));
http://git-wip-us.apache.org/repos/asf/systemml/blob/9a1f64b4/src/test/scripts/functions/misc/RewriteCumulativeAggregates.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteCumulativeAggregates.dml b/src/test/scripts/functions/misc/RewriteCumulativeAggregates.dml
new file mode 100644
index 0000000..f4c3486
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteCumulativeAggregates.dml
@@ -0,0 +1,49 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+foo = function( Matrix[Double] A ) return( Matrix[Double] B )
+{ # NOTE(review): foo is not called anywhere in this script
+ for( i in 1:1 ) {
+ continue = TRUE;
+ if( sum(A)<0 ) {
+ continue = FALSE;
+ }
+ iter = 0;
+ if( continue ) {
+ iter = iter+1;
+ }
+ B = A+iter; # B = A+1 if sum(A)>=0, otherwise B = A
+ }
+}
+ # $1: input matrix file, $2: test case number, $3: output file
+X = read($1);
+ # each case below holds one of the cumulative aggregate patterns under test
+if( $2 == 1 )
+ R = X * cumsum(diag(matrix(1,nrow(X),1))); # case 1: multiply with a lower-triangular mask of ones
+else if( $2 == 2 )
+ R = rev(cumsum(rev(X))); # case 2: reverse cumulative sum
+else if( $2 == 3 )
+ R = colSums(cumsum(X)); # case 3: column sums over a cumulative sum
+else if( $2 == 4 )
+ R = cumsum(X); # case 4: plain cumulative sum
+ # write the result of the selected case
+write(R, $3);
\ No newline at end of file