You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2020/04/13 16:42:00 UTC

[systemml] branch master updated: [SYSTEMDS-291] Extended eval lazy function compilation (nested builtins)

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new 5f1cdf3  [SYSTEMDS-291] Extended eval lazy function compilation (nested builtins)
5f1cdf3 is described below

commit 5f1cdf367b0616359461f1fd198898d59f0598a4
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Mon Apr 13 18:39:47 2020 +0200

    [SYSTEMDS-291] Extended eval lazy function compilation (nested builtins)
    
    This patch extends the lazy function compilation of dml-bodied builtin
    functions called through eval. We now support nested dml-bodied function
    calls (e.g., eval -> lm -> lmDS/lmCG), which is crucial for generic
    primitives of hyper-parameter optimization and the enumeration of
    cleaning pipelines.
---
 .../sysds/hops/rewrite/RewriteConstantFolding.java |  2 +-
 .../java/org/apache/sysds/parser/DMLProgram.java   |  4 ++
 .../org/apache/sysds/parser/DMLTranslator.java     |  2 +-
 .../sysds/parser/FunctionCallIdentifier.java       |  8 +--
 .../sysds/parser/FunctionStatementBlock.java       | 14 ++---
 .../org/apache/sysds/parser/IfStatementBlock.java  |  4 +-
 .../org/apache/sysds/parser/StatementBlock.java    |  2 +-
 .../sysds/parser/dml/DmlSyntacticValidator.java    |  8 ++-
 .../sysds/runtime/controlprogram/Program.java      | 18 +++++-
 .../controlprogram/paramserv/ParamservUtils.java   |  2 +-
 .../instructions/cp/EvalNaryCPInstruction.java     | 70 ++++++++++++++--------
 .../sysds/runtime/lineage/LineageRewriteReuse.java |  2 +-
 .../test/functions/mlcontext/MLContextTest.java    | 10 ++++
 .../mlcontext/eval4-nested_builtin-test.dml        | 30 ++++++++++
 14 files changed, 129 insertions(+), 47 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java
index ec098e6..6e04082 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java
@@ -184,7 +184,7 @@ public class RewriteConstantFolding extends HopRewriteRule
 	
 	private BasicProgramBlock getProgramBlock() {
 		if( _tmpPB == null )
-			_tmpPB = new BasicProgramBlock( new Program() );
+			_tmpPB = new BasicProgramBlock(new Program());
 		return _tmpPB;
 	}
 	
diff --git a/src/main/java/org/apache/sysds/parser/DMLProgram.java b/src/main/java/org/apache/sysds/parser/DMLProgram.java
index e86464c..4e5e229 100644
--- a/src/main/java/org/apache/sysds/parser/DMLProgram.java
+++ b/src/main/java/org/apache/sysds/parser/DMLProgram.java
@@ -131,6 +131,10 @@ public class DMLProgram
 		return ret;
 	}
 
+	public boolean containsFunctionStatementBlock(String name) {
+		return _functionBlocks.containsKey(name);
+	}
+	
 	public void addFunctionStatementBlock(String fname, FunctionStatementBlock fsb) {
 		_functionBlocks.put(fname, fsb);
 	}
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 9e41f9b..e61c928 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -412,7 +412,7 @@ public class DMLTranslator
 		throws LanguageException, DMLRuntimeException, LopsException, HopsException 
 	{	
 		// constructor resets the set of registered functions
-		Program rtprog = new Program();
+		Program rtprog = new Program(prog);
 		
 		// for all namespaces, translate function statement blocks into function program blocks
 		for (String namespace : prog.getNamespaces().keySet()){
diff --git a/src/main/java/org/apache/sysds/parser/FunctionCallIdentifier.java b/src/main/java/org/apache/sysds/parser/FunctionCallIdentifier.java
index fc5e1d8..497d591 100644
--- a/src/main/java/org/apache/sysds/parser/FunctionCallIdentifier.java
+++ b/src/main/java/org/apache/sysds/parser/FunctionCallIdentifier.java
@@ -115,8 +115,8 @@ public class FunctionCallIdentifier extends DataIdentifier
 		}
 		if (hasNamed && hasUnnamed){
 			raiseValidateError(" In DML, functions can only have named parameters " +
-					"(e.g., name1=value1, name2=value2) or unnamed parameters (e.g, value1, value2). " + 
-					_name + " has both parameter types.", conditional);
+				"(e.g., name1=value1, name2=value2) or unnamed parameters (e.g, value1, value2). " + 
+				_name + " has both parameter types.", conditional);
 		}
 		
 		// Step 4: validate expressions for each passed parameter
@@ -176,8 +176,8 @@ public class FunctionCallIdentifier extends DataIdentifier
 		if (_namespace != null && _namespace.length() > 0 && !_namespace.equals(DMLProgram.DEFAULT_NAMESPACE))
 			sb.append(_namespace + "::");
 		sb.append(_name);
-		sb.append(" ( ");		
-				
+		sb.append(" ( ");
+		
 		for (int i = 0; i < _paramExprs.size(); i++){
 			sb.append(_paramExprs.get(i).toString());
 			if (i<_paramExprs.size() - 1) 
diff --git a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
index aad710b..7d32816 100644
--- a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
@@ -51,13 +51,13 @@ public class FunctionStatementBlock extends StatementBlock
 			
 		// validate all function input parameters
 		ArrayList<DataIdentifier> inputValues = fstmt.getInputParams();
-        for( DataIdentifier inputValue : inputValues ) {
-            //check all input matrices have value type double
-            if( inputValue.getDataType()==DataType.MATRIX && inputValue.getValueType()!=ValueType.FP64 ) {
-                raiseValidateError("for function " + fstmt.getName() + ", input variable " + inputValue.getName() 
-                                 + " has an unsupported value type of " + inputValue.getValueType() + ".", false);
-            }
-        }
+		for( DataIdentifier inputValue : inputValues ) {
+			//check all input matrices have value type double
+			if( inputValue.getDataType()==DataType.MATRIX && inputValue.getValueType()!=ValueType.FP64 ) {
+				raiseValidateError("for function " + fstmt.getName() + ", input variable " + inputValue.getName() 
+					+ " has an unsupported value type of " + inputValue.getValueType() + ".", false);
+			}
+		}
 		
 		// handle DML-bodied functions
 		// perform validate for function body
diff --git a/src/main/java/org/apache/sysds/parser/IfStatementBlock.java b/src/main/java/org/apache/sysds/parser/IfStatementBlock.java
index 7322bce..4762a14 100644
--- a/src/main/java/org/apache/sysds/parser/IfStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/IfStatementBlock.java
@@ -55,9 +55,9 @@ public class IfStatementBlock extends StatementBlock
 		HashMap<String,ConstIdentifier> constVarsIfCopy = new HashMap<>(constVars);
 		HashMap<String,ConstIdentifier> constVarsElseCopy = new HashMap<> (constVars);
 		
-		VariableSet idsIfCopy 	= new VariableSet(ids);
+		VariableSet idsIfCopy   = new VariableSet(ids);
 		VariableSet idsElseCopy = new VariableSet(ids);
-		VariableSet	idsOrigCopy = new VariableSet(ids);
+		VariableSet idsOrigCopy = new VariableSet(ids);
 
 		// handle if stmt body
 		_dmlProg = dmlProg;
diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java b/src/main/java/org/apache/sysds/parser/StatementBlock.java
index 5991315..f275a84 100644
--- a/src/main/java/org/apache/sysds/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java
@@ -230,7 +230,7 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo
 		return true;
 	}
 
-    public boolean isRewritableFunctionCall(Statement stmt, DMLProgram dmlProg) {
+	public boolean isRewritableFunctionCall(Statement stmt, DMLProgram dmlProg) {
 
 		// for regular stmt, check if this is a function call stmt block
 		if (stmt instanceof AssignmentStatement || stmt instanceof MultiAssignmentStatement){
diff --git a/src/main/java/org/apache/sysds/parser/dml/DmlSyntacticValidator.java b/src/main/java/org/apache/sysds/parser/dml/DmlSyntacticValidator.java
index 5e2cae5..5841e3b 100644
--- a/src/main/java/org/apache/sysds/parser/dml/DmlSyntacticValidator.java
+++ b/src/main/java/org/apache/sysds/parser/dml/DmlSyntacticValidator.java
@@ -610,18 +610,20 @@ public class DmlSyntacticValidator implements DmlListener {
 		}
 	}
 	
-	public static FunctionStatementBlock loadAndParseBuiltinFunction(String name, String namespace, DataType dt) {
+	public static Map<String,FunctionStatementBlock> loadAndParseBuiltinFunction(String name, String namespace) {
 		if( !Builtins.contains(name, true, false) ) {
 			throw new DMLRuntimeException("Function "
 				+ DMLProgram.constructFunctionKey(namespace, name)+" is not a builtin function.");
 		}
 		//load and add builtin DML-bodied functions (via tmp validator instance)
+		//including nested builtin function calls unless already loaded
 		DmlSyntacticValidator tmp = new DmlSyntacticValidator(
 			new CustomErrorListener(), new HashMap<>(), namespace, new HashSet<>());
 		String filePath = Builtins.getFilePath(name);
 		DMLProgram prog = tmp.parseAndAddImportedFunctions(namespace, filePath, null);
-		String name2 = Builtins.getInternalFName(name, dt);
-		return prog.getNamedFunctionStatementBlocks().get(name2);
+		
+		//construct output map of all functions
+		return prog.getNamedFunctionStatementBlocks();
 	}
 
 
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/Program.java b/src/main/java/org/apache/sysds/runtime/controlprogram/Program.java
index e868a38..03a516b 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/Program.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/Program.java
@@ -33,7 +33,8 @@ public class Program
 {
 	public static final String KEY_DELIM = "::";
 	
-	public ArrayList<ProgramBlock> _programBlocks;
+	private DMLProgram _prog;
+	private ArrayList<ProgramBlock> _programBlocks;
 
 	private HashMap<String, HashMap<String,FunctionProgramBlock>> _namespaceFunctions;
 	
@@ -42,7 +43,20 @@ public class Program
 		_namespaceFunctions.put(DMLProgram.DEFAULT_NAMESPACE, new HashMap<>());
 		_programBlocks = new ArrayList<>();
 	}
+	
+	public Program(DMLProgram prog) {
+		this();
+		setDMLProg(prog);
+	}
 
+	public void setDMLProg(DMLProgram prog) {
+		_prog = prog;
+	}
+	
+	public DMLProgram getDMLProg() {
+		return _prog;
+	}
+	
 	public synchronized void addFunctionProgramBlock(String namespace, String fname, FunctionProgramBlock fpb) {
 		if( fpb == null )
 			throw new DMLRuntimeException("Invalid null function program block.");
@@ -124,7 +138,7 @@ public class Program
 	public Program clone(boolean deep) {
 		if( deep )
 			throw new NotImplementedException();
-		Program ret = new Program();
+		Program ret = new Program(_prog);
 		//shallow copy of all program blocks
 		ret._programBlocks.addAll(_programBlocks);
 		//shallow copy of all functions, except external 
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
index c8b8a3a..84fc2c9 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -252,7 +252,7 @@ public class ParamservUtils {
 	}
 	
 	private static Program copyProgramFunctions(Program prog) {
-		Program newProg = new Program();
+		Program newProg = new Program(prog.getDMLProg());
 		prog.getFunctionProgramBlocks()
 			.forEach((func, pb) -> putFunction(newProg, copyFunction(func, pb)));
 		return newProg;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
index 92b7227..62f6c67 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
@@ -21,7 +21,10 @@ package org.apache.sysds.runtime.instructions.cp;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Map;
+import java.util.Map.Entry;
 
+import org.apache.sysds.common.Builtins;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.hops.rewrite.ProgramRewriter;
@@ -69,11 +72,11 @@ public class EvalNaryCPInstruction extends BuiltinNaryCPInstruction {
 		//2. copy the created output matrix
 		MatrixObject outputMO = new MatrixObject(ec.getMatrixObject(output.getName()));
 
-		//3. lazy loading of dml-bodied builtin functions
+		//3. lazy loading of dml-bodied builtin functions (incl. rename
+		// of function name to dml-bodied builtin scheme (data-type-specific))
 		if( !ec.getProgram().containsFunctionProgramBlock(null, funcName) ) {
-			FunctionProgramBlock fpb = compileFunctionProgramBlock(
-				funcName, boundInputs[0].getDataType(), ec.getProgram());
-			ec.getProgram().addFunctionProgramBlock(null, funcName, fpb);
+			compileFunctionProgramBlock(funcName, boundInputs[0].getDataType(), ec.getProgram());
+			funcName = Builtins.getInternalFName(funcName, boundInputs[0].getDataType());
 		}
 		
 		//4. call the function
@@ -101,32 +104,51 @@ public class EvalNaryCPInstruction extends BuiltinNaryCPInstruction {
 		ec.setVariable(output.getName(), outputMO);
 	}
 	
-	private static FunctionProgramBlock compileFunctionProgramBlock(String name, DataType dt, Program prog) {
+	private static void compileFunctionProgramBlock(String name, DataType dt, Program prog) {
 		//load builtin file and parse function statement block
-		FunctionStatementBlock fsb = DmlSyntacticValidator
-			.loadAndParseBuiltinFunction(name, DMLProgram.DEFAULT_NAMESPACE, dt);
+		Map<String,FunctionStatementBlock> fsbs = DmlSyntacticValidator
+			.loadAndParseBuiltinFunction(name, DMLProgram.DEFAULT_NAMESPACE);
+		if( fsbs.isEmpty() )
+			throw new DMLRuntimeException("Failed to compile function '"+name+"'.");
 		
-		// validate function (could be avoided for performance because known builtin functions)
-		DMLProgram dmlp = fsb.getDMLProg();
+		// prepare common data structures, including a consolidated dml program
+		// to facilitate function validation which tries to inline lazily loaded
+		// and existing functions.
+		DMLProgram dmlp = (prog.getDMLProg() != null) ? prog.getDMLProg() :
+			fsbs.get(Builtins.getInternalFName(name, dt)).getDMLProg();
+		for( Entry<String,FunctionStatementBlock> fsb : fsbs.entrySet() ) {
+			if( !dmlp.containsFunctionStatementBlock(fsb.getKey()) )
+				dmlp.addFunctionStatementBlock(fsb.getKey(), fsb.getValue());
+			fsb.getValue().setDMLProg(dmlp);
+		}
 		DMLTranslator dmlt = new DMLTranslator(dmlp);
-		dmlt.liveVariableAnalysisFunction(dmlp, fsb);
-		dmlt.validateFunction(dmlp, fsb);
-		
-		// compile hop dags, rewrite hop dags and compile lop dags
-		dmlt.constructHops(fsb);
 		ProgramRewriter rewriter = new ProgramRewriter(true, false);
-		rewriter.rewriteHopDAGsFunction(fsb, false); //rewrite and merge
-		DMLTranslator.resetHopsDAGVisitStatus(fsb);
-		rewriter.rewriteHopDAGsFunction(fsb, true); //rewrite and split
-		DMLTranslator.resetHopsDAGVisitStatus(fsb);
 		ProgramRewriter rewriter2 = new ProgramRewriter(false, true);
-		rewriter2.rewriteHopDAGsFunction(fsb, true);
-		DMLTranslator.resetHopsDAGVisitStatus(fsb);
-		DMLTranslator.refreshMemEstimates(fsb);
-		dmlt.constructLops(fsb);
+		
+		// validate functions, in two passes for cross references
+		for( FunctionStatementBlock fsb : fsbs.values() ) {
+			dmlt.liveVariableAnalysisFunction(dmlp, fsb);
+			dmlt.validateFunction(dmlp, fsb);
+		}
+		
+		// compile hop dags, rewrite hop dags and compile lop dags
+		for( FunctionStatementBlock fsb : fsbs.values() ) {
+			dmlt.constructHops(fsb);
+			rewriter.rewriteHopDAGsFunction(fsb, false); //rewrite and merge
+			DMLTranslator.resetHopsDAGVisitStatus(fsb);
+			rewriter.rewriteHopDAGsFunction(fsb, true); //rewrite and split
+			DMLTranslator.resetHopsDAGVisitStatus(fsb);
+			rewriter2.rewriteHopDAGsFunction(fsb, true);
+			DMLTranslator.resetHopsDAGVisitStatus(fsb);
+			DMLTranslator.refreshMemEstimates(fsb);
+			dmlt.constructLops(fsb);
+		}
 		
 		// compile runtime program
-		return (FunctionProgramBlock) dmlt.createRuntimeProgramBlock(
-			prog, fsb, ConfigurationManager.getDMLConfig());
+		for( Entry<String,FunctionStatementBlock> fsb : fsbs.entrySet() ) {
+			FunctionProgramBlock fpb = (FunctionProgramBlock) dmlt
+				.createRuntimeProgramBlock(prog, fsb.getValue(), ConfigurationManager.getDMLConfig());
+			prog.addFunctionProgramBlock(null, fsb.getKey(), fpb);
+		}
 	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
index d400623..f48c869 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
@@ -789,7 +789,7 @@ public class LineageRewriteReuse
 
 	private static BasicProgramBlock getProgramBlock() {
 		if( _lrPB == null )
-			_lrPB = new BasicProgramBlock( new Program() );
+			_lrPB = new BasicProgramBlock(new Program());
 		return _lrPB;
 	}
 }
\ No newline at end of file
diff --git a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java
index c2add4d..ce7df49 100644
--- a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java
@@ -117,6 +117,16 @@ public class MLContextTest extends MLContextTestBase {
 		ml.execute(script);
 		ml.setExplain(false);
 	}
+	
+	@Test
+	public void testExecuteEvalNestedBuiltinTest() {
+		System.out.println("MLContextTest - eval builtin test");
+		setExpectedStdOut("TRUE");
+		ml.setExplain(true);
+		Script script = dmlFromFile(baseDirectory + File.separator + "eval4-nested_builtin-test.dml");
+		ml.execute(script);
+		ml.setExplain(false);
+	}
 
 	@Test
 	public void testCreateDMLScriptBasedOnStringAndExecute() {
diff --git a/src/test/scripts/functions/mlcontext/eval4-nested_builtin-test.dml b/src/test/scripts/functions/mlcontext/eval4-nested_builtin-test.dml
new file mode 100644
index 0000000..2217085
--- /dev/null
+++ b/src/test/scripts/functions/mlcontext/eval4-nested_builtin-test.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rand(rows=100, cols=10, seed=37)
+y = rand(rows=100, cols=1, seed=38)
+
+F = cbind(as.frame("lm"),as.frame("mlogreg"));
+ix = ifelse(sum(X)>1, 1, 2);
+R1 = eval(as.scalar(F[1,ix]), X, y, 0, 1e-7, 1e-7, 0, FALSE); #calls lm->lmDS
+R2 = lmCG(X=X, y=y, verbose=FALSE);
+
+print(sum(abs(R1-R2)<1e-6)==ncol(X));