You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/09/01 22:45:07 UTC

[systemds] branch master updated: [SYSTEMDS-3113] IPA pass for replacing eval w/ direct function calls

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new bebfd41  [SYSTEMDS-3113] IPA pass for replacing eval w/ direct function calls
bebfd41 is described below

commit bebfd41133198a0cb71a78e5902eb4216dc34a50
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Thu Sep 2 00:39:44 2021 +0200

    [SYSTEMDS-3113] IPA pass for replacing eval w/ direct function calls
    
    This patch is the first part of a new IPA pass for replacing second-order
    eval function calls with direct function calls if, after constant
    propagation, the function name is known. So far, we handle plain lists of
    arguments, but do not yet apply this rewrite for list inputs (which require
    flattening) or non-matrix outputs (which require casting).
---
 .../java/org/apache/sysds/hops/OptimizerUtils.java |   8 +-
 .../hops/ipa/IPAPassReplaceEvalFunctionCalls.java  | 171 +++++++++++++++++++++
 .../sysds/hops/ipa/InterProceduralAnalysis.java    |   1 +
 .../org/apache/sysds/parser/FunctionStatement.java |   5 +
 .../cp/ParameterizedBuiltinCPInstruction.java      |   3 +-
 .../test/functions/misc/FunctionPotpourriTest.java |  96 ++++++++++--
 .../functions/mlcontext/MLContextTestBase.java     |  15 +-
 .../functions/misc/FunPotpourriEvalNamespace2.dml  |  27 ++++
 8 files changed, 305 insertions(+), 21 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index 201373f..01769c7 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -191,6 +191,13 @@ public class OptimizerUtils
 	public static boolean ALLOW_LOOP_UPDATE_IN_PLACE = true;
 	
 	/**
+	 * Replace second-order eval function calls with direct function calls (IPA)
+	 * if the function name is a known string literal after constant propagation.
+	 */
+	public static boolean ALLOW_EVAL_FCALL_REPLACEMENT = true;
+	
+	
+	/**
 	 * Enables a specific rewrite for code motion, i.e., hoisting loop invariant code
 	 * out of while, for, and parfor loops.
 	 */
@@ -201,7 +208,6 @@ public class OptimizerUtils
 	 */
 	public static boolean FEDERATED_COMPILATION = false;
 	
-	
 	/**
 	 * Specifies a multiplier computing the degree of parallelism of parallel
 	 * text read/write out of the available degree of parallelism. Set it to 1.0
diff --git a/src/main/java/org/apache/sysds/hops/ipa/IPAPassReplaceEvalFunctionCalls.java b/src/main/java/org/apache/sysds/hops/ipa/IPAPassReplaceEvalFunctionCalls.java
new file mode 100644
index 0000000..f3d5b93
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassReplaceEvalFunctionCalls.java
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.ipa;
+
+
+import java.util.List;
+
+import org.apache.sysds.common.Builtins;
+import org.apache.sysds.common.Types.OpOpData;
+import org.apache.sysds.common.Types.OpOpN;
+import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.FunctionOp;
+import org.apache.sysds.hops.FunctionOp.FunctionType;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.hops.rewrite.HopRewriteUtils;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.FunctionStatement;
+import org.apache.sysds.parser.FunctionStatementBlock;
+import org.apache.sysds.parser.IfStatement;
+import org.apache.sysds.parser.IfStatementBlock;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
+import org.apache.sysds.runtime.controlprogram.Program;
+
+/**
+ * Replaces second-order eval calls with direct function calls if the function name is a
+ * known literal after constant propagation (so far, only plain arguments and matrix outputs).
+ */
+public class IPAPassReplaceEvalFunctionCalls extends IPAPass
+{
+	@Override
+	public boolean isApplicable(FunctionCallGraph fgraph) {
+		return fgraph.containsSecondOrderCall()
+			&& OptimizerUtils.ALLOW_EVAL_FCALL_REPLACEMENT;
+	}
+	
+	@Override
+	public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
+		//note: we only replace eval calls that feed into a twrite/pwrite call
+		//(i.e., after statement-block rewrites for splitting after eval have been
+		//applied) - this approach ensures that the requirements of fcalls are met
+		
+		// for each namespace, handle function statement blocks
+		boolean ret = false;
+		for (String namespaceKey : prog.getNamespaces().keySet())
+			for (String fname : prog.getFunctionStatementBlocks(namespaceKey).keySet()) {
+				FunctionStatementBlock fsblock = prog.getFunctionStatementBlock(namespaceKey,fname);
+				ret |= rewriteStatementBlock(prog, fsblock, fgraph);
+			}
+		
+		// handle regular statement blocks in "main" method
+		for(StatementBlock sb : prog.getStatementBlocks())
+			ret |= rewriteStatementBlock(prog, sb, fgraph);
+		
+		return ret;
+	}
+	
+	//recursively descends into control-flow bodies; returns true if any eval call was replaced
+	private static boolean rewriteStatementBlock(DMLProgram prog, StatementBlock sb, FunctionCallGraph fgraph) {
+		boolean ret = false;
+		if (sb instanceof FunctionStatementBlock) {
+			FunctionStatement fstmt = (FunctionStatement)sb.getStatement(0);
+			for (StatementBlock csb : fstmt.getBody())
+				ret |= rewriteStatementBlock(prog, csb, fgraph);
+		}
+		else if (sb instanceof WhileStatementBlock) {
+			WhileStatement wstmt = (WhileStatement)sb.getStatement(0);
+			for (StatementBlock csb : wstmt.getBody())
+				ret |= rewriteStatementBlock(prog, csb, fgraph);
+		}
+		else if (sb instanceof IfStatementBlock) {
+			IfStatement istmt = (IfStatement)sb.getStatement(0);
+			for (StatementBlock csb : istmt.getIfBody())
+				ret |= rewriteStatementBlock(prog, csb, fgraph);
+			for (StatementBlock csb : istmt.getElseBody())
+				ret |= rewriteStatementBlock(prog, csb, fgraph);
+		}
+		else if (sb instanceof ForStatementBlock) { //incl parfor
+			ForStatement fstmt = (ForStatement)sb.getStatement(0);
+			for (StatementBlock csb : fstmt.getBody())
+				ret |= rewriteStatementBlock(prog, csb, fgraph);
+		}
+		else { //generic (last-level)
+			ret |= checkAndReplaceEvalFunctionCall(prog, sb, fgraph);
+		}
+		return ret;
+	}
+	
+	//replaces eval calls w/ literal function name that feed into a write root; true if any replaced
+	private static boolean checkAndReplaceEvalFunctionCall(DMLProgram prog, StatementBlock sb, FunctionCallGraph fgraph) {
+		List<Hop> roots = sb.getHops();
+		if( roots == null )
+			return false;
+		
+		boolean ret = false;
+		for( int i=0; i<roots.size(); i++ ) {
+			Hop root = roots.get(i);
+			if( HopRewriteUtils.isData(root, OpOpData.TRANSIENTWRITE, OpOpData.PERSISTENTWRITE)
+				&& HopRewriteUtils.isNary(root.getInput(0), OpOpN.EVAL)
+				&& root.getInput(0).getInput(0) instanceof LiteralOp //constant name
+				&& root.getInput(0).getParent().size() == 1)
+			{
+				Hop eval = root.getInput(0);
+				String outvar = ((DataOp)root).getName();
+				boolean hasArgs = eval.getInput().size() > 1;
+				
+				//get function name and namespace
+				String fname = ((LiteralOp)eval.getInput(0)).getStringValue();
+				String fnamespace = prog.getDefaultFunctionDictionary().containsFunction(fname) ?
+					DMLProgram.DEFAULT_NAMESPACE : DMLProgram.BUILTIN_NAMESPACE;
+				if( fname.contains(Program.KEY_DELIM) ) {
+					String[] fparts = DMLProgram.splitFunctionKey(fname);
+					fnamespace = fparts[0];
+					fname = fparts[1];
+				}
+				if( fnamespace.equals(DMLProgram.BUILTIN_NAMESPACE) && !hasArgs )
+					continue; //internal builtin name is derived from the first argument's type
+				fname = fnamespace.equals(DMLProgram.BUILTIN_NAMESPACE) ?
+					Builtins.getInternalFName(fname, eval.getInput(1).getDataType()) : fname;
+				
+				//obtain function and abort if unknown, inputs passed via list, or output not a matrix
+				FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fnamespace, fname);
+				FunctionStatement fstmt = fsb!=null ? (FunctionStatement)fsb.getStatement(0) : null;
+				if( hasArgs && eval.getInput(1).getDataType().isList() && (fstmt==null
+					|| fstmt.getInputParams().isEmpty() || !fstmt.getInputParams().get(0).getDataType().isList())) {
+					LOG.warn("IPA: eval("+fnamespace+"::"+fname+") "
+						+ "applicable for replacement, but list inputs not yet supported.");
+					continue;
+				}
+				if( fstmt == null || fstmt.getOutputParams().size() != 1
+					|| !fstmt.getOutputParams().get(0).getDataType().isMatrix() ) {
+					LOG.warn("IPA: eval("+fnamespace+"::"+fname+") applicable for replacement, but "
+						+ (fstmt == null ? "function not found." : "function output is not a matrix."));
+					continue;
+				}
+				
+				//construct direct function call
+				FunctionOp fop = new FunctionOp(FunctionType.DML, fnamespace, fname,
+					fstmt.getInputParamNames(), eval.getInput().subList(1, eval.getInput().size()),
+					new String[]{outvar}, true);
+				HopRewriteUtils.copyLineNumbers(eval, fop);
+				HopRewriteUtils.removeAllChildReferences(eval);
+				roots.set(i, fop); //replaced
+				ret = true;
+			}
+		}
+		return ret;
+	}
+}
diff --git a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
index 0b47a19..8224192 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
@@ -144,6 +144,7 @@ public class InterProceduralAnalysis
 		_passes.add(new IPAPassRemoveConstantBinaryOps());
 		_passes.add(new IPAPassPropagateReplaceLiterals());
 		_passes.add(new IPAPassInlineFunctions());
+		_passes.add(new IPAPassReplaceEvalFunctionCalls());
 		_passes.add(new IPAPassEliminateDeadCode());
 		_passes.add(new IPAPassFlagNonDeterminism());
 		//note: apply rewrites last because statement block rewrites
diff --git a/src/main/java/org/apache/sysds/parser/FunctionStatement.java b/src/main/java/org/apache/sysds/parser/FunctionStatement.java
index 3cd0f6c..353f3b4 100644
--- a/src/main/java/org/apache/sysds/parser/FunctionStatement.java
+++ b/src/main/java/org/apache/sysds/parser/FunctionStatement.java
@@ -47,6 +47,11 @@ public class FunctionStatement extends Statement
 		return _inputParams;
 	}
 	
+	public String[] getInputParamNames() { //names of all input params, in list order
+		return _inputParams.stream().map(DataIdentifier::getName)
+			.toArray(String[]::new);
+	}
+	
 	public DataIdentifier getInputParam(String name) {
 		return _inputParams.stream()
 			.filter(d -> d.getName().equals(name))
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index 6de5878..cbe9be0 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -386,7 +386,8 @@ public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction
 					decimal);
 			}
 			else {
-				throw new DMLRuntimeException("toString only converts matrix, tensors, lists or frames to string");
+				throw new DMLRuntimeException("toString only converts "
+					+ "matrix, tensors, lists or frames to string: "+cacheData.getClass().getSimpleName());
 			}
 			if(!(cacheData instanceof ListObject)) {
 				ec.releaseCacheableData(getParam("target"));
diff --git a/src/test/java/org/apache/sysds/test/functions/misc/FunctionPotpourriTest.java b/src/test/java/org/apache/sysds/test/functions/misc/FunctionPotpourriTest.java
index 214730b..53f615b 100644
--- a/src/test/java/org/apache/sysds/test/functions/misc/FunctionPotpourriTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/misc/FunctionPotpourriTest.java
@@ -20,12 +20,14 @@
 package org.apache.sysds.test.functions.misc;
 
 import org.apache.sysds.hops.HopsException;
+import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.parser.LanguageException;
 import org.apache.sysds.parser.ParseException;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 import org.junit.Assert;
+import org.junit.Ignore;
 import org.junit.Test;
 
 public class FunctionPotpourriTest extends AutomatedTestBase 
@@ -56,6 +58,7 @@ public class FunctionPotpourriTest extends AutomatedTestBase
 		"FunPotpourriEvalList1Arg",
 		"FunPotpourriEvalList2Arg",
 		"FunPotpourriEvalNamespace",
+		"FunPotpourriEvalNamespace2",
 		"FunPotpourriBuiltinPrecedence",
 		"FunPotpourriParforEvalBuiltin",
 	};
@@ -91,6 +94,11 @@ public class FunctionPotpourriTest extends AutomatedTestBase
 	}
 	
 	@Test
+	public void testFunctionEval2() {
+		runFunctionTest( TEST_NAMES[3], null, true );
+	}
+	
+	@Test
 	public void testFunctionSubsetReturn() {
 		runFunctionTest( TEST_NAMES[4], null );
 	}
@@ -171,52 +179,118 @@ public class FunctionPotpourriTest extends AutomatedTestBase
 	}
 	
 	@Test
+	@Ignore //TODO support list
+	public void testFunctionNestedParforEval2() {
+		runFunctionTest( TEST_NAMES[19], null, true );
+	}
+	
+	@Test
 	public void testFunctionMultiEval() {
 		runFunctionTest( TEST_NAMES[20], null );
 	}
 	
 	@Test
+	@Ignore //TODO support list
+	public void testFunctionMultiEval2() {
+		runFunctionTest( TEST_NAMES[20], null, true );
+	}
+	
+	@Test
 	public void testFunctionEvalPred() {
 		runFunctionTest( TEST_NAMES[21], null );
 	}
 	
 	@Test
+	@Ignore //TODO support list
+	public void testFunctionEvalPred2() {
+		runFunctionTest( TEST_NAMES[21], null, true );
+	}
+	
+	@Test
 	public void testFunctionEvalList1Arg() {
 		runFunctionTest( TEST_NAMES[22], null );
 	}
 	
 	@Test
+	@Ignore //TODO support list
+	public void testFunctionEvalList1Arg2() {
+		runFunctionTest( TEST_NAMES[22], null, true );
+	}
+	
+	@Test
 	public void testFunctionEvalList2Arg() {
 		runFunctionTest( TEST_NAMES[23], null );
 	}
 	
 	@Test
+	@Ignore //TODO support list
+	public void testFunctionEvalList2Arg2() {
+		runFunctionTest( TEST_NAMES[23], null, true );
+	}
+	
+	@Test
 	public void testFunctionEvalNamespace() {
 		runFunctionTest( TEST_NAMES[24], null );
 	}
 	
 	@Test
-	public void testFunctionBuiltinPrecedence() {
+	@Ignore //TODO support list
+	public void testFunctionEvalNamespace2() {
+		runFunctionTest( TEST_NAMES[24], null, true );
+	}
+	
+	@Test
+	public void testFunctionEvalNamespacePlain() {
 		runFunctionTest( TEST_NAMES[25], null );
 	}
 	
 	@Test
-	public void testFunctionParforEvalBuiltin() {
+	public void testFunctionEvalNamespacePlain2() {
+		runFunctionTest( TEST_NAMES[25], null, true );
+	}
+	
+	@Test
+	public void testFunctionBuiltinPrecedence() {
 		runFunctionTest( TEST_NAMES[26], null );
 	}
 	
+	@Test
+	public void testFunctionParforEvalBuiltin() {
+		runFunctionTest( TEST_NAMES[27], null );
+	}
+	
+	@Test
+	@Ignore //TODO support list
+	public void testFunctionParforEvalBuiltin2() {
+		runFunctionTest( TEST_NAMES[27], null, true );
+	}
+	
 	private void runFunctionTest(String testName, Class<?> error) {
+		runFunctionTest(testName, error, false);
+	}
+	
+	private void runFunctionTest(String testName, Class<?> error, boolean evalRewrite) { //evalRewrite toggles the IPA eval replacement
 		TestConfiguration config = getTestConfiguration(testName);
 		loadTestConfiguration(config);
 		
-		String HOME = SCRIPT_DIR + TEST_DIR;
-		fullDMLScriptName = HOME + testName + ".dml";
-		programArgs = new String[]{"-explain", "hops", "-stats",
-			"-args", String.valueOf(error).toUpperCase()};
-
-		runTest(true, error != null, error, -1);
-
-		if( testName.equals(TEST_NAMES[17]) )
-			Assert.assertTrue(heavyHittersContainsString("print"));
+		boolean oldFlag = OptimizerUtils.ALLOW_EVAL_FCALL_REPLACEMENT; //restored in finally to not affect other tests
+		try {
+			OptimizerUtils.ALLOW_EVAL_FCALL_REPLACEMENT = evalRewrite;
+			
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + testName + ".dml";
+			programArgs = new String[]{"-explain", "hops", "-stats",
+				"-args", String.valueOf(error).toUpperCase()};
+			
+			runTest(true, error != null, error, -1);
+			
+			if( testName.equals(TEST_NAMES[17]) )
+				Assert.assertTrue(heavyHittersContainsString("print"));
+			if( evalRewrite ) //with the rewrite, no eval instruction may show up in the heavy hitters
+				Assert.assertFalse(heavyHittersContainsString("eval"));
+		}
+		finally {
+			OptimizerUtils.ALLOW_EVAL_FCALL_REPLACEMENT = oldFlag;
+		}
 	}
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTestBase.java b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTestBase.java
index 5c097f0..a5136cc 100644
--- a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTestBase.java
+++ b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTestBase.java
@@ -140,18 +140,17 @@ public abstract class MLContextTestBase extends AutomatedTestBase {
 
 	private static MLResults safeExecute(ByteArrayOutputStream buff, Script script, ScriptExecutor sce){
 		try {
-
 			MLResults res = sce == null ? ml.execute(script): ml.execute(script,sce);
 			return res;
 		}
 		catch(Exception e) {
-				StringBuilder errorMessage = new StringBuilder();
-				errorMessage.append("\nfailed to run script: ");
-				errorMessage.append("\nStandard Out:");
-				errorMessage.append("\n" + buff);
-				errorMessage.append("\nStackTrace:");
-				errorMessage.append(AutomatedTestBase.getStackTraceString(e, 0));
-				fail(errorMessage.toString());
+			StringBuilder errorMessage = new StringBuilder(); //failure report: captured stdout (buff) followed by the stack trace
+			errorMessage.append("\nfailed to run script: ");
+			errorMessage.append("\nStandard Out:");
+			errorMessage.append("\n" + buff);
+			errorMessage.append("\nStackTrace:");
+			errorMessage.append(AutomatedTestBase.getStackTraceString(e, 0));
+			fail(errorMessage.toString());
 		}
 		return null;
 	}
diff --git a/src/test/scripts/functions/misc/FunPotpourriEvalNamespace2.dml b/src/test/scripts/functions/misc/FunPotpourriEvalNamespace2.dml
new file mode 100644
index 0000000..3aa2ed8
--- /dev/null
+++ b/src/test/scripts/functions/misc/FunPotpourriEvalNamespace2.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("./src/test/scripts/functions/misc/FunPotpourriEvalNamespaceFuns.dml") as fns1
+
+ns = "./src/test/scripts/functions/misc/FunPotpourriEvalNamespaceFuns.dml" # same path as the sourced script above
+X = rand(rows=100, cols=100)
+s = eval(ns+"::foo", X, TRUE) # function name qualified by the script path rather than the fns1 alias
+print(toString(s))