You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2023/07/03 12:16:02 UTC

[systemds] branch main updated: [SYSTEMDS-2672] New rewrites for indexing/loop removal and cleanups

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new b448f3314c [SYSTEMDS-2672] New rewrites for indexing/loop removal and cleanups
b448f3314c is described below

commit b448f3314c248ca0a901017612cd2afb77db170c
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Mon Jul 3 14:14:10 2023 +0200

    [SYSTEMDS-2672] New rewrites for indexing/loop removal and cleanups
    
    Although we already supported some forms of indexing and loop removal,
    this patch extends these rewrites to properly remove unnecessary
    right indexing, remove empty blocks (potentially after rewrites), and
    even remove surrounding loops that have no remaining operations.
---
 .../java/org/apache/sysds/hops/IndexingOp.java     |  4 ++
 .../apache/sysds/hops/rewrite/HopRewriteUtils.java |  9 +--
 .../apache/sysds/hops/rewrite/ProgramRewriter.java |  1 +
 .../RewriteAlgebraicSimplificationStatic.java      | 19 ++++++
 .../hops/rewrite/RewriteRemoveEmptyForLoops.java   | 61 +++++++++++++++++++
 .../rewrite/RewriteIndexingRemovalTest.java        | 69 ++++++++++++++++++++++
 .../scripts/functions/rewrite/removal_rix1.dml     | 26 ++++++++
 .../scripts/functions/rewrite/removal_rix2.dml     | 28 +++++++++
 8 files changed, 213 insertions(+), 4 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/IndexingOp.java b/src/main/java/org/apache/sysds/hops/IndexingOp.java
index f10400d497..4155dfd9f5 100644
--- a/src/main/java/org/apache/sysds/hops/IndexingOp.java
+++ b/src/main/java/org/apache/sysds/hops/IndexingOp.java
@@ -423,6 +423,10 @@ public class IndexingOp extends Hop
 		}
 	}
 	
+	public boolean isAllRowsAndCols() { //true only if both the row check and the column check hold
+		return isAllRows() && isAllCols();
+	}
+	
 	public boolean isAllRows() {
 		Hop input1 = getInput().get(0);
 		Hop input2 = getInput().get(1);
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
index ed93ea7366..338393eda5 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
@@ -1321,10 +1321,11 @@ public class HopRewriteUtils {
 		//starting row and column ranges of 1 in order to guard against
 		//invalid modifications in the presence of invalid index ranges
 		//(e.g., X[,2] on a column vector needs to throw an error)
-		return isEqualSize(hop, hop.getInput().get(0))
-			&& !(hop.getDim1()==1 && hop.getDim2()==1)
-			&& isLiteralOfValue(hop.getInput().get(1), 1)  //rl
-			&& isLiteralOfValue(hop.getInput().get(3), 1); //cl
+		return ((IndexingOp)hop).isAllRowsAndCols()
+			|| (isEqualSize(hop, hop.getInput().get(0))
+				&& !(hop.getDim1()==1 && hop.getDim2()==1)
+				&& isLiteralOfValue(hop.getInput().get(1), 1)  //rl
+				&& isLiteralOfValue(hop.getInput().get(3), 1)); //cl
 	}
 	
 	public static boolean isScalarMatrixBinaryMult( Hop hop ) {
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index db20ada280..11105fe7f6 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -150,6 +150,7 @@ public class ProgramRewriter
 		if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
 			_dagRuleSet.add( new RewriteConstantFolding()                    ); //dependency: cse
 		_sbRuleSet.add(  new RewriteRemoveEmptyBasicBlocks()                 );
+		_sbRuleSet.add(  new RewriteRemoveEmptyForLoops()                    );
 	}
 	
 	/**
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 7f658ac417..b74e93d6cf 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -22,6 +22,7 @@ package org.apache.sysds.hops.rewrite;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.Iterator;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Set;
@@ -45,6 +46,7 @@ import org.apache.sysds.common.Types.OpOp1;
 import org.apache.sysds.common.Types.OpOp2;
 import org.apache.sysds.common.Types.OpOp3;
 import org.apache.sysds.common.Types.OpOpDG;
+import org.apache.sysds.common.Types.OpOpData;
 import org.apache.sysds.common.Types.OpOpN;
 import org.apache.sysds.common.Types.ParamBuiltinOp;
 import org.apache.sysds.common.Types.ReOrgOp;
@@ -94,6 +96,9 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 			rule_AlgebraicSimplification( h, true );
 		Hop.resetVisitStatus(roots, true);
 		
+		//cleanup remove (twrite <- tread) pairs (unless checkpointing)
+		removeTWriteTReadPairs(roots);
+		
 		return roots;
 	}
 
@@ -2011,4 +2016,18 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 		
 		return hi;
 	}
+	
+	private static void removeTWriteTReadPairs(ArrayList<Hop> roots) {
+		//drop every root that writes a transient read back under the same name (unless checkpointed)
+		for( Iterator<Hop> it = roots.iterator(); it.hasNext(); ) {
+			Hop twrite = it.next();
+			if( !HopRewriteUtils.isData(twrite, OpOpData.TRANSIENTWRITE) )
+				continue;
+			Hop tread = twrite.getInput(0);
+			if( HopRewriteUtils.isData(tread, OpOpData.TRANSIENTREAD)
+				&& twrite.getName().equals(tread.getName())
+				&& !tread.requiresCheckpoint() )
+				it.remove();
+		}
+	}
 }
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemoveEmptyForLoops.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemoveEmptyForLoops.java
new file mode 100644
index 0000000000..226d227318
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemoveEmptyForLoops.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.rewrite;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.StatementBlock;
+
+/**
+ * Rule: Simplify program structure by removing for loops with an
+ * empty body, which may originate from the sequence of other
+ * rewrites like dead-code-elimination.
+ */
+public class RewriteRemoveEmptyForLoops extends StatementBlockRewriteRule
+{
+	@Override
+	public boolean createsSplitDag() {
+		return false;
+	}
+	
+	@Override
+	public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
+		ArrayList<StatementBlock> ret = new ArrayList<>();
+		
+		//drop for loops whose body is empty (nothing is added to the result)
+		if( sb instanceof ForStatementBlock 
+				&& ((ForStatement)sb.getStatement(0)).getBody().isEmpty() ) {
+			if( LOG.isDebugEnabled() )
+				LOG.debug("Applied removeEmptyForLoop (lines "+sb.getBeginLine()+"-"+sb.getEndLine()+").");
+		}
+		else //keep original sb
+			ret.add( sb );
+		
+		return ret;
+	}
+	
+	@Override
+	public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state) {
+		return sbs; //no list-level rewrite: the given blocks are returned unchanged
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteIndexingRemovalTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteIndexingRemovalTest.java
new file mode 100644
index 0000000000..6334bcb1d2
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteIndexingRemovalTest.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.utils.Statistics;
+
+public class RewriteIndexingRemovalTest extends AutomatedTestBase
+{
+	private final static String TEST_NAME1 = "removal_rix1";
+	private final static String TEST_NAME2 = "removal_rix2";
+	private final static String TEST_DIR = "functions/rewrite/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + RewriteIndexingRemovalTest.class.getSimpleName() + "/";
+	//dimensions handed to the scripts as the two values following -args
+	private final static int rows = 10;
+	private final static int cols = 15;
+	
+	@Override
+	public void setUp() {
+		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R"}) );
+		addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"R"}) );
+	}
+	
+	@Test
+	public void runIndexingRemovalTest() {
+		runIndexingRemovalTest(TEST_NAME1);
+	}
+	
+	@Test
+	public void runDynIndexingRemovalTest() {
+		runIndexingRemovalTest(TEST_NAME2);
+	}
+	
+	private void runIndexingRemovalTest(String testname) {
+		TestConfiguration config = getTestConfiguration(testname);
+		loadTestConfiguration(config);
+		//script path and command line: explain and stats flags, then rows and cols as arguments
+		String HOME = SCRIPT_DIR + TEST_DIR;
+		fullDMLScriptName = HOME + testname + ".dml";
+		programArgs = new String[]{"-explain", "-stats", "-args",
+			Integer.toString(rows), Integer.toString(cols)};
+		//run, then check statistics: zero "rix" heavy hitters, and two compiled SP instructions for TEST_NAME2
+		runTest(true, false, null, -1);
+		Assert.assertEquals(0, Statistics.getCPHeavyHitterCount("rix"));
+		if(testname.equals(TEST_NAME2))
+			Assert.assertEquals(2, Statistics.getNoOfCompiledSPInst());
+	}
+}
diff --git a/src/test/scripts/functions/rewrite/removal_rix1.dml b/src/test/scripts/functions/rewrite/removal_rix1.dml
new file mode 100644
index 0000000000..9fd8c1b983
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/removal_rix1.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(7, $1, $2)
+for(i in 1:10)
+  X = X[,]
+print(sum(X));
+
diff --git a/src/test/scripts/functions/rewrite/removal_rix2.dml b/src/test/scripts/functions/rewrite/removal_rix2.dml
new file mode 100644
index 0000000000..ad5b78b311
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/removal_rix2.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(7, $1, $2)
+while(FALSE){}
+Y = matrix(7, $1, sum(X))
+for(i in 1:10)
+  Y = Y[,]
+print(sum(Y));
+