You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2018/12/11 19:23:51 UTC

[2/2] systemml git commit: [SYSTEMML-2507] New rewrites for cumulative aggregate patterns

[SYSTEMML-2507] New rewrites for cumulative aggregate patterns

This patch adds the following simplification rewrites as well as related
tests:
(a) X * cumsum(diag(matrix(1,nrow(X),1))) -> lower.tri(X, diag=TRUE, values=TRUE), if X is square
(b) colSums(cumsum(X)) -> colSums(X*seq(nrow(X),1))
(c) rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X)


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/9a1f64b4
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/9a1f64b4
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/9a1f64b4

Branch: refs/heads/master
Commit: 9a1f64b42c177a82a98716ad9ef34d4d266178d2
Parents: b96807b
Author: Matthias Boehm <mb...@gmail.com>
Authored: Tue Dec 11 20:10:23 2018 +0100
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Tue Dec 11 20:10:46 2018 +0100

----------------------------------------------------------------------
 .../RewriteAlgebraicSimplificationDynamic.java  |  33 ++++-
 .../RewriteAlgebraicSimplificationStatic.java   |  45 +++++++
 .../hops/rewrite/RewriteGPUSpecificOps.java     |  26 ++--
 .../misc/RewriteCumulativeAggregatesTest.java   | 126 +++++++++++++++++++
 .../misc/RewriteCumulativeAggregates.R          |  43 +++++++
 .../misc/RewriteCumulativeAggregates.dml        |  49 ++++++++
 6 files changed, 306 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/9a1f64b4/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 36864aa..9556181 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -175,6 +175,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 			hi = simplifyMatrixMultDiag(hop, hi, i);          //e.g., diag(X)%*%Y -> X*Y, if ncol(Y)==1 / -> Y*X if ncol(Y)>1 
 			hi = simplifyDiagMatrixMult(hop, hi, i);          //e.g., diag(X%*%Y)->rowSums(X*t(Y)); if col vector
 			hi = simplifySumDiagToTrace(hi);                  //e.g., sum(diag(X)) -> trace(X); if col vector
+			hi = simplifyLowerTriExtraction(hop, hi, i);      //e.g., X * cumsum(diag(matrix(1,nrow(X),1))) -> lower.tri
 			hi = pushdownBinaryOperationOnDiag(hop, hi, i);   //e.g., diag(X)*7 -> diag(X*7); if col vector
 			hi = pushdownSumOnAdditiveBinary(hop, hi, i);     //e.g., sum(A+B) -> sum(A)+sum(B); if dims(A)==dims(B)
 			if(OptimizerUtils.ALLOW_OPERATOR_FUSION) {
@@ -1046,7 +1047,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 		if( hi instanceof AggUnaryOp ) 
 		{
 			AggUnaryOp au = (AggUnaryOp) hi;
-			if( au.getOp()==AggOp.SUM && au.getDirection()==Direction.RowCol )	//sum	
+			if( au.getOp()==AggOp.SUM && au.getDirection()==Direction.RowCol )	//sum
 			{
 				Hop hi2 = au.getInput().get(0);
 				if( hi2 instanceof ReorgOp && ((ReorgOp)hi2).getOp()==ReOrgOp.DIAG && hi2.getDim2()==1 ) //diagM2V
@@ -1054,7 +1055,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 					Hop hi3 = hi2.getInput().get(0);
 					
 					//remove diag operator
-					HopRewriteUtils.replaceChildReference(au, hi2, hi3, 0);	
+					HopRewriteUtils.replaceChildReference(au, hi2, hi3, 0);
 					HopRewriteUtils.cleanupUnreferenced(hi2);
 					
 					//change sum to trace
@@ -1063,12 +1064,38 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 					LOG.debug("Applied simplifySumDiagToTrace");
 				}
 			}
-				
 		}
 		
 		return hi;
 	}
 	
+	private static Hop simplifyLowerTriExtraction(Hop parent, Hop hi, int pos) {
+		//pattern: X * cumsum(diag(matrix(1,nrow(X),1))) -> lower.tri (only right);
+		//equal sizes exclude scalar X and diag(M) over a matrix M (vector extraction)
+		if( HopRewriteUtils.isBinary(hi, OpOp2.MULT) 
+			&& hi.getDim1() == hi.getDim2() && hi.getDim1() > 1 ) {
+			Hop left = hi.getInput().get(0);
+			Hop right = hi.getInput().get(1);
+			if( HopRewriteUtils.isUnary(right, OpOp1.CUMSUM) && right.getParent().size()==1
+				&& HopRewriteUtils.isEqualSize(left, right)
+				&& HopRewriteUtils.isReorg(right.getInput().get(0), ReOrgOp.DIAG)
+				&& HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput().get(0).getInput().get(0), 1d))
+			{
+				LinkedHashMap<String,Hop> args = new LinkedHashMap<>();
+				args.put("target", left);
+				args.put("diag", new LiteralOp(true));
+				args.put("values", new LiteralOp(true));
+				Hop hnew = HopRewriteUtils.createParameterizedBuiltinOp(
+					left, args, ParamBuiltinOp.LOWER_TRI);
+				HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
+				HopRewriteUtils.cleanupUnreferenced(hi, right, right.getInput().get(0)); //only if dangling
+				hi = hnew;
+				LOG.debug("Applied simplifyLowerTriExtraction");
+			}
+		}
+		return hi;
+	}
+	
 	@SuppressWarnings("unchecked")
 	private static Hop pushdownBinaryOperationOnDiag(Hop parent, Hop hi, int pos) 
 	{

http://git-wip-us.apache.org/repos/asf/systemml/blob/9a1f64b4/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 62a5d4f..9a3956c 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -183,6 +183,9 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 			}
 			hi = simplifyOuterSeqExpand(hop, hi, i);             //e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false)
 			hi = simplifyBinaryComparisonChain(hop, hi, i);      //e.g., outer(v1,v2,"==")==1 -> outer(v1,v2,"=="), outer(v1,v2,"==")==0 -> outer(v1,v2,"!="), 
+			hi = simplifyCumsumColOrFullAggregates(hi);          //e.g., colSums(cumsum(X)) -> cumSums(X*seq(nrow(X),1))
+			hi = simplifyCumsumReverse(hop, hi, i);              //e.g., rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X)
+			
 			
 			//hi = removeUnecessaryPPred(hop, hi, i);            //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
 
@@ -1844,6 +1847,48 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 		return hi;
 	}
 	
+	private static Hop simplifyCumsumColOrFullAggregates(Hop hi) {
+		//pattern: colSums(cumsum(X)) -> colSums(X*seq(nrow(X),1)), likewise sum(cumsum(X)) -> sum(X*seq(nrow(X),1))
+		if( (HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, Direction.Col)
+			|| HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, Direction.RowCol))
+			&& HopRewriteUtils.isUnary(hi.getInput().get(0), OpOp1.CUMSUM)
+			&& hi.getInput().get(0).getParent().size()==1)
+		{
+			Hop cumsumX = hi.getInput().get(0);
+			Hop X = cumsumX.getInput().get(0);
+			//weights seq(nrow(X),1): row i of X appears in rows i..nrow(X) of cumsum(X)
+			Hop mult = HopRewriteUtils.createBinary(X, HopRewriteUtils.createSeqDataGenOp(X, false), OpOp2.MULT);
+			HopRewriteUtils.replaceChildReference(hi, cumsumX, mult);
+			HopRewriteUtils.cleanupUnreferenced(cumsumX);
+			LOG.debug("Applied simplifyCumsumColOrFullAggregates (line "+hi.getBeginLine()+")");
+		}
+		return hi;
+	}
+	
+	private static Hop simplifyCumsumReverse(Hop parent, Hop hi, int pos) {
+		//pattern: rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X) (column-wise suffix sums)
+		if( !HopRewriteUtils.isReorg(hi, ReOrgOp.REV) )
+			return hi;
+		Hop cumsumRevX = hi.getInput().get(0);
+		if( !HopRewriteUtils.isUnary(cumsumRevX, OpOp1.CUMSUM) || cumsumRevX.getParent().size() != 1 )
+			return hi;
+		Hop revX = cumsumRevX.getInput().get(0);
+		if( !HopRewriteUtils.isReorg(revX, ReOrgOp.REV) || revX.getParent().size() != 1 )
+			return hi;
+		
+		//suffix sum of row i = colSums(X) - (cumsum(X) - X)[i], i.e., total minus exclusive prefix
+		Hop X = revX.getInput().get(0);
+		Hop colTotals = HopRewriteUtils.createAggUnaryOp(X, AggOp.SUM, Direction.Col);
+		Hop withTotals = HopRewriteUtils.createBinary(X, colTotals, OpOp2.PLUS);
+		Hop prefix = HopRewriteUtils.createUnary(X, OpOp1.CUMSUM);
+		Hop result = HopRewriteUtils.createBinary(withTotals, prefix, OpOp2.MINUS);
+		
+		HopRewriteUtils.replaceChildReference(parent, hi, result, pos);
+		HopRewriteUtils.cleanupUnreferenced(hi, cumsumRevX, revX);
+		LOG.debug("Applied simplifyCumsumReverse (line "+result.getBeginLine()+")");
+		return result;
+	}
+	
 	/**
 	 * NOTE: currently disabled since this rewrite is INVALID in the
 	 * presence of NaNs (because (NaN!=NaN) is true). 

http://git-wip-us.apache.org/repos/asf/systemml/blob/9a1f64b4/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
index ab40d7b..1d87c09 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
@@ -176,19 +176,19 @@ public class RewriteGPUSpecificOps extends HopRewriteRuleWithPatternMatcher {
 	// norm = bias_multiply(centered, cache_inv_var)  # shape (N, C*Hin*Win)
 	// # Compute gradients during training
 	// dgamma = util::channel_sums(dout*norm, C, Hin, Win)
-	private static final HopDagPatternMatcher _batchNormDGamma;
-	static {
-		_batchNormDGamma = util_channel_sums(
-				mult(	leaf("dout", MATRIX).fitsOnGPU(3),
-						bias_multiply(bias_add(leaf("X", MATRIX), unaryMinus(leaf("ema_mean", MATRIX))), 
-				leaf("ema_var", MATRIX))), leaf("C", SCALAR), leaf("HW", SCALAR));
-	}
-	private static final Function<Hop, Hop> _batchNormDGammaReplacer = hi -> {
-		LOG.debug("Applied batchNormDGamma rewrite.");
-		Hop newHop = HopRewriteUtils.createDnnOp(_batchNormDGamma, OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA, 
-				"ema_mean", "dout", "X", "ema_var");
-		return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
-	};
+//	private static final HopDagPatternMatcher _batchNormDGamma;
+//	static {
+//		_batchNormDGamma = util_channel_sums(
+//				mult(	leaf("dout", MATRIX).fitsOnGPU(3),
+//						bias_multiply(bias_add(leaf("X", MATRIX), unaryMinus(leaf("ema_mean", MATRIX))), 
+//				leaf("ema_var", MATRIX))), leaf("C", SCALAR), leaf("HW", SCALAR));
+//	}
+//	private static final Function<Hop, Hop> _batchNormDGammaReplacer = hi -> {
+//		LOG.debug("Applied batchNormDGamma rewrite.");
+//		Hop newHop = HopRewriteUtils.createDnnOp(_batchNormDGamma, OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA, 
+//				"ema_mean", "dout", "X", "ema_var");
+//		return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
+//	};
 		
 	// Pattern 3:
 	private static final HopDagPatternMatcher _batchNormTest;

http://git-wip-us.apache.org/repos/asf/systemml/blob/9a1f64b4/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCumulativeAggregatesTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCumulativeAggregatesTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCumulativeAggregatesTest.java
new file mode 100644
index 0000000..da13502
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCumulativeAggregatesTest.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import java.util.HashMap;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+
+public class RewriteCumulativeAggregatesTest extends AutomatedTestBase 
+{	
+	private static final String TEST_NAME = "RewriteCumulativeAggregates";
+	private static final String TEST_DIR = "functions/misc/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteCumulativeAggregatesTest.class.getSimpleName() + "/";
+	
+	private static final int rows = 1234;
+	private static final int cols = 7;
+	
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration( TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] { "R" }) );
+	}
+
+	@Test
+	public void testCumAggRewrite1False() {
+		testCumAggRewrite(1, false);
+	}
+	
+	@Test
+	public void testCumAggRewrite1True() {
+		testCumAggRewrite(1, true);
+	}
+	
+	@Test
+	public void testCumAggRewrite2False() {
+		testCumAggRewrite(2, false);
+	}
+	
+	@Test
+	public void testCumAggRewrite2True() {
+		testCumAggRewrite(2, true);
+	}
+	
+	@Test
+	public void testCumAggRewrite3False() {
+		testCumAggRewrite(3, false);
+	}
+	
+	@Test
+	public void testCumAggRewrite3True() {
+		testCumAggRewrite(3, true);
+	}
+	
+	@Test
+	public void testCumAggRewrite4False() {
+		testCumAggRewrite(4, false);
+	}
+	
+	@Test
+	public void testCumAggRewrite4True() {
+		testCumAggRewrite(4, true);
+	}
+	
+	/** Runs case num (1-4) of the DML script with/without algebraic rewrites and compares against R. */
+	private void testCumAggRewrite(int num, boolean rewrites)
+	{
+		boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+		boolean oldSparkConfig = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		try {
+			TestConfiguration config = getTestConfiguration(TEST_NAME);
+			loadTestConfiguration(config);
+			
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[]{ "-stats", "-args",
+				input("A"), String.valueOf(num), output("R") };
+			rCmd = getRCmd(inputDir(), String.valueOf(num), expectedDir());
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+			
+			//generate input data (case 1 needs a square matrix, case 4 a single row)
+			double[][] A = getRandomMatrix((num==4)?1:rows,
+				(num==1)?rows:cols, -1, 1, 0.9, 7); 
+			writeInputMatrixWithMTD("A", A, true);
+			
+			//run the DML script and its R reference implementation
+			runTest(true, false, null, -1);
+			runRScript(true);
+			//compare matrices
+			HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
+			HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R");
+			TestUtils.compareMatrices(dmlfile, rfile, 1e-7, "Stat-DML", "Stat-R");
+			//with rewrites, the replaced op must not execute (case 2 keeps a cumsum, so check rev)
+			if( rewrites )
+				Assert.assertFalse(heavyHittersContainsString((num==2) ? "rev" : "ucumk+"));
+		}
+		finally {
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = oldSparkConfig;
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/9a1f64b4/src/test/scripts/functions/misc/RewriteCumulativeAggregates.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteCumulativeAggregates.R b/src/test/scripts/functions/misc/RewriteCumulativeAggregates.R
new file mode 100644
index 0000000..f8a8576
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteCumulativeAggregates.R
@@ -0,0 +1,43 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+X = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+num = as.integer(args[2]);
+
+# expected results for RewriteCumulativeAggregates.dml; R's cumsum and rev
+# only operate on vectors, so matrices are processed column-wise via apply
+revRows = function(M) M[seq(nrow(M),1),]
+R = switch(num,
+  # 1: X * cumsum(diag(matrix(1,nrow(X),1))) == lower triangle incl. diagonal
+  lower.tri(X, diag=TRUE) * X,
+  # 2: rev(cumsum(rev(X))) == column-wise suffix sums
+  revRows(apply(revRows(X), 2, cumsum)),
+  # 3: colSums(cumsum(X)) as a 1 x ncol(X) matrix
+  t(as.matrix(colSums(apply(X, 2, cumsum)))),
+  # 4: cumsum(X); the test supplies a single-row X, for which cumsum(X) == X
+  X)
+
+writeMM(as(R, "CsparseMatrix"), paste(args[3], "R", sep="")); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/9a1f64b4/src/test/scripts/functions/misc/RewriteCumulativeAggregates.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteCumulativeAggregates.dml b/src/test/scripts/functions/misc/RewriteCumulativeAggregates.dml
new file mode 100644
index 0000000..f4c3486
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteCumulativeAggregates.dml
@@ -0,0 +1,49 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+foo = function( Matrix[Double] A ) return( Matrix[Double] B ) # NOTE(review): not called anywhere in this script
+{
+   for( i in 1:1 ) {
+     continue = TRUE;
+     if( sum(A)<0 ) {
+        continue = FALSE;
+     }
+     iter = 0;
+     if( continue ) {
+        iter = iter+1;
+     }
+     B = A+iter;   # A+1 unless sum(A)<0, in which case A
+   }
+}
+
+X = read($1);   # $1: input A, $2: case 1-4, $3: output R
+
+if( $2 == 1 )
+  R = X * cumsum(diag(matrix(1,nrow(X),1)));   # expected rewrite: lower.tri(X) (dynamic rewrite)
+else if( $2 == 2 )
+  R = rev(cumsum(rev(X)));                     # expected rewrite: X + colSums(X) - cumsum(X)
+else if( $2 == 3 )
+  R = colSums(cumsum(X));                      # expected rewrite: colSums(X*seq(nrow(X),1))
+else if( $2 == 4 )
+  R = cumsum(X);                               # X is single-row in the test; cumsum expected to vanish
+
+write(R, $3);
\ No newline at end of file