You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2020/04/13 16:42:00 UTC
[systemml] branch master updated: [SYSTEMDS-291] Extended eval lazy
function compilation (nested builtins)
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git
The following commit(s) were added to refs/heads/master by this push:
new 5f1cdf3 [SYSTEMDS-291] Extended eval lazy function compilation (nested builtins)
5f1cdf3 is described below
commit 5f1cdf367b0616359461f1fd198898d59f0598a4
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Mon Apr 13 18:39:47 2020 +0200
[SYSTEMDS-291] Extended eval lazy function compilation (nested builtins)
This patch extends the lazy function compilation of dml-bodied builtin
functions called through eval. We now support nested dml-bodied function
calls (e.g., eval -> lm -> lmDS/lmCG), which is crucial for generic
primitives of hyper-parameter optimization and the enumeration of
cleaning pipelines.
---
.../sysds/hops/rewrite/RewriteConstantFolding.java | 2 +-
.../java/org/apache/sysds/parser/DMLProgram.java | 4 ++
.../org/apache/sysds/parser/DMLTranslator.java | 2 +-
.../sysds/parser/FunctionCallIdentifier.java | 8 +--
.../sysds/parser/FunctionStatementBlock.java | 14 ++---
.../org/apache/sysds/parser/IfStatementBlock.java | 4 +-
.../org/apache/sysds/parser/StatementBlock.java | 2 +-
.../sysds/parser/dml/DmlSyntacticValidator.java | 8 ++-
.../sysds/runtime/controlprogram/Program.java | 18 +++++-
.../controlprogram/paramserv/ParamservUtils.java | 2 +-
.../instructions/cp/EvalNaryCPInstruction.java | 70 ++++++++++++++--------
.../sysds/runtime/lineage/LineageRewriteReuse.java | 2 +-
.../test/functions/mlcontext/MLContextTest.java | 10 ++++
.../mlcontext/eval4-nested_builtin-test.dml | 30 ++++++++++
14 files changed, 129 insertions(+), 47 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java
index ec098e6..6e04082 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java
@@ -184,7 +184,7 @@ public class RewriteConstantFolding extends HopRewriteRule
private BasicProgramBlock getProgramBlock() {
if( _tmpPB == null )
- _tmpPB = new BasicProgramBlock( new Program() );
+ _tmpPB = new BasicProgramBlock(new Program());
return _tmpPB;
}
diff --git a/src/main/java/org/apache/sysds/parser/DMLProgram.java b/src/main/java/org/apache/sysds/parser/DMLProgram.java
index e86464c..4e5e229 100644
--- a/src/main/java/org/apache/sysds/parser/DMLProgram.java
+++ b/src/main/java/org/apache/sysds/parser/DMLProgram.java
@@ -131,6 +131,10 @@ public class DMLProgram
return ret;
}
+ public boolean containsFunctionStatementBlock(String name) {
+ return _functionBlocks.containsKey(name);
+ }
+
public void addFunctionStatementBlock(String fname, FunctionStatementBlock fsb) {
_functionBlocks.put(fname, fsb);
}
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 9e41f9b..e61c928 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -412,7 +412,7 @@ public class DMLTranslator
throws LanguageException, DMLRuntimeException, LopsException, HopsException
{
// constructor resets the set of registered functions
- Program rtprog = new Program();
+ Program rtprog = new Program(prog);
// for all namespaces, translate function statement blocks into function program blocks
for (String namespace : prog.getNamespaces().keySet()){
diff --git a/src/main/java/org/apache/sysds/parser/FunctionCallIdentifier.java b/src/main/java/org/apache/sysds/parser/FunctionCallIdentifier.java
index fc5e1d8..497d591 100644
--- a/src/main/java/org/apache/sysds/parser/FunctionCallIdentifier.java
+++ b/src/main/java/org/apache/sysds/parser/FunctionCallIdentifier.java
@@ -115,8 +115,8 @@ public class FunctionCallIdentifier extends DataIdentifier
}
if (hasNamed && hasUnnamed){
raiseValidateError(" In DML, functions can only have named parameters " +
- "(e.g., name1=value1, name2=value2) or unnamed parameters (e.g, value1, value2). " +
- _name + " has both parameter types.", conditional);
+ "(e.g., name1=value1, name2=value2) or unnamed parameters (e.g, value1, value2). " +
+ _name + " has both parameter types.", conditional);
}
// Step 4: validate expressions for each passed parameter
@@ -176,8 +176,8 @@ public class FunctionCallIdentifier extends DataIdentifier
if (_namespace != null && _namespace.length() > 0 && !_namespace.equals(DMLProgram.DEFAULT_NAMESPACE))
sb.append(_namespace + "::");
sb.append(_name);
- sb.append(" ( ");
-
+ sb.append(" ( ");
+
for (int i = 0; i < _paramExprs.size(); i++){
sb.append(_paramExprs.get(i).toString());
if (i<_paramExprs.size() - 1)
diff --git a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
index aad710b..7d32816 100644
--- a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
@@ -51,13 +51,13 @@ public class FunctionStatementBlock extends StatementBlock
// validate all function input parameters
ArrayList<DataIdentifier> inputValues = fstmt.getInputParams();
- for( DataIdentifier inputValue : inputValues ) {
- //check all input matrices have value type double
- if( inputValue.getDataType()==DataType.MATRIX && inputValue.getValueType()!=ValueType.FP64 ) {
- raiseValidateError("for function " + fstmt.getName() + ", input variable " + inputValue.getName()
- + " has an unsupported value type of " + inputValue.getValueType() + ".", false);
- }
- }
+ for( DataIdentifier inputValue : inputValues ) {
+ //check all input matrices have value type double
+ if( inputValue.getDataType()==DataType.MATRIX && inputValue.getValueType()!=ValueType.FP64 ) {
+ raiseValidateError("for function " + fstmt.getName() + ", input variable " + inputValue.getName()
+ + " has an unsupported value type of " + inputValue.getValueType() + ".", false);
+ }
+ }
// handle DML-bodied functions
// perform validate for function body
diff --git a/src/main/java/org/apache/sysds/parser/IfStatementBlock.java b/src/main/java/org/apache/sysds/parser/IfStatementBlock.java
index 7322bce..4762a14 100644
--- a/src/main/java/org/apache/sysds/parser/IfStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/IfStatementBlock.java
@@ -55,9 +55,9 @@ public class IfStatementBlock extends StatementBlock
HashMap<String,ConstIdentifier> constVarsIfCopy = new HashMap<>(constVars);
HashMap<String,ConstIdentifier> constVarsElseCopy = new HashMap<> (constVars);
- VariableSet idsIfCopy = new VariableSet(ids);
+ VariableSet idsIfCopy = new VariableSet(ids);
VariableSet idsElseCopy = new VariableSet(ids);
- VariableSet idsOrigCopy = new VariableSet(ids);
+ VariableSet idsOrigCopy = new VariableSet(ids);
// handle if stmt body
_dmlProg = dmlProg;
diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java b/src/main/java/org/apache/sysds/parser/StatementBlock.java
index 5991315..f275a84 100644
--- a/src/main/java/org/apache/sysds/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java
@@ -230,7 +230,7 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo
return true;
}
- public boolean isRewritableFunctionCall(Statement stmt, DMLProgram dmlProg) {
+ public boolean isRewritableFunctionCall(Statement stmt, DMLProgram dmlProg) {
// for regular stmt, check if this is a function call stmt block
if (stmt instanceof AssignmentStatement || stmt instanceof MultiAssignmentStatement){
diff --git a/src/main/java/org/apache/sysds/parser/dml/DmlSyntacticValidator.java b/src/main/java/org/apache/sysds/parser/dml/DmlSyntacticValidator.java
index 5e2cae5..5841e3b 100644
--- a/src/main/java/org/apache/sysds/parser/dml/DmlSyntacticValidator.java
+++ b/src/main/java/org/apache/sysds/parser/dml/DmlSyntacticValidator.java
@@ -610,18 +610,20 @@ public class DmlSyntacticValidator implements DmlListener {
}
}
- public static FunctionStatementBlock loadAndParseBuiltinFunction(String name, String namespace, DataType dt) {
+ public static Map<String,FunctionStatementBlock> loadAndParseBuiltinFunction(String name, String namespace) {
if( !Builtins.contains(name, true, false) ) {
throw new DMLRuntimeException("Function "
+ DMLProgram.constructFunctionKey(namespace, name)+" is not a builtin function.");
}
//load and add builtin DML-bodied functions (via tmp validator instance)
+ //including nested builtin function calls unless already loaded
DmlSyntacticValidator tmp = new DmlSyntacticValidator(
new CustomErrorListener(), new HashMap<>(), namespace, new HashSet<>());
String filePath = Builtins.getFilePath(name);
DMLProgram prog = tmp.parseAndAddImportedFunctions(namespace, filePath, null);
- String name2 = Builtins.getInternalFName(name, dt);
- return prog.getNamedFunctionStatementBlocks().get(name2);
+
+ //construct output map of all functions
+ return prog.getNamedFunctionStatementBlocks();
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/Program.java b/src/main/java/org/apache/sysds/runtime/controlprogram/Program.java
index e868a38..03a516b 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/Program.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/Program.java
@@ -33,7 +33,8 @@ public class Program
{
public static final String KEY_DELIM = "::";
- public ArrayList<ProgramBlock> _programBlocks;
+ private DMLProgram _prog;
+ private ArrayList<ProgramBlock> _programBlocks;
private HashMap<String, HashMap<String,FunctionProgramBlock>> _namespaceFunctions;
@@ -42,7 +43,20 @@ public class Program
_namespaceFunctions.put(DMLProgram.DEFAULT_NAMESPACE, new HashMap<>());
_programBlocks = new ArrayList<>();
}
+
+ public Program(DMLProgram prog) {
+ this();
+ setDMLProg(prog);
+ }
+ public void setDMLProg(DMLProgram prog) {
+ _prog = prog;
+ }
+
+ public DMLProgram getDMLProg() {
+ return _prog;
+ }
+
public synchronized void addFunctionProgramBlock(String namespace, String fname, FunctionProgramBlock fpb) {
if( fpb == null )
throw new DMLRuntimeException("Invalid null function program block.");
@@ -124,7 +138,7 @@ public class Program
public Program clone(boolean deep) {
if( deep )
throw new NotImplementedException();
- Program ret = new Program();
+ Program ret = new Program(_prog);
//shallow copy of all program blocks
ret._programBlocks.addAll(_programBlocks);
//shallow copy of all functions, except external
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
index c8b8a3a..84fc2c9 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -252,7 +252,7 @@ public class ParamservUtils {
}
private static Program copyProgramFunctions(Program prog) {
- Program newProg = new Program();
+ Program newProg = new Program(prog.getDMLProg());
prog.getFunctionProgramBlocks()
.forEach((func, pb) -> putFunction(newProg, copyFunction(func, pb)));
return newProg;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
index 92b7227..62f6c67 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
@@ -21,7 +21,10 @@ package org.apache.sysds.runtime.instructions.cp;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Map;
+import java.util.Map.Entry;
+import org.apache.sysds.common.Builtins;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.rewrite.ProgramRewriter;
@@ -69,11 +72,11 @@ public class EvalNaryCPInstruction extends BuiltinNaryCPInstruction {
//2. copy the created output matrix
MatrixObject outputMO = new MatrixObject(ec.getMatrixObject(output.getName()));
- //3. lazy loading of dml-bodied builtin functions
+ //3. lazy loading of dml-bodied builtin functions (incl. rename
+ // of function name to dml-bodied builtin scheme (data-type-specific))
if( !ec.getProgram().containsFunctionProgramBlock(null, funcName) ) {
- FunctionProgramBlock fpb = compileFunctionProgramBlock(
- funcName, boundInputs[0].getDataType(), ec.getProgram());
- ec.getProgram().addFunctionProgramBlock(null, funcName, fpb);
+ compileFunctionProgramBlock(funcName, boundInputs[0].getDataType(), ec.getProgram());
+ funcName = Builtins.getInternalFName(funcName, boundInputs[0].getDataType());
}
//4. call the function
@@ -101,32 +104,51 @@ public class EvalNaryCPInstruction extends BuiltinNaryCPInstruction {
ec.setVariable(output.getName(), outputMO);
}
- private static FunctionProgramBlock compileFunctionProgramBlock(String name, DataType dt, Program prog) {
+ private static void compileFunctionProgramBlock(String name, DataType dt, Program prog) {
//load builtin file and parse function statement block
- FunctionStatementBlock fsb = DmlSyntacticValidator
- .loadAndParseBuiltinFunction(name, DMLProgram.DEFAULT_NAMESPACE, dt);
+ Map<String,FunctionStatementBlock> fsbs = DmlSyntacticValidator
+ .loadAndParseBuiltinFunction(name, DMLProgram.DEFAULT_NAMESPACE);
+ if( fsbs.isEmpty() )
+ throw new DMLRuntimeException("Failed to compile function '"+name+"'.");
- // validate function (could be avoided for performance because known builtin functions)
- DMLProgram dmlp = fsb.getDMLProg();
+ // prepare common data structures, including a consolidated dml program
+ // to facilitate function validation which tries to inline lazily loaded
+ // and existing functions.
+ DMLProgram dmlp = (prog.getDMLProg() != null) ? prog.getDMLProg() :
+ fsbs.get(Builtins.getInternalFName(name, dt)).getDMLProg();
+ for( Entry<String,FunctionStatementBlock> fsb : fsbs.entrySet() ) {
+ if( !dmlp.containsFunctionStatementBlock(fsb.getKey()) )
+ dmlp.addFunctionStatementBlock(fsb.getKey(), fsb.getValue());
+ fsb.getValue().setDMLProg(dmlp);
+ }
DMLTranslator dmlt = new DMLTranslator(dmlp);
- dmlt.liveVariableAnalysisFunction(dmlp, fsb);
- dmlt.validateFunction(dmlp, fsb);
-
- // compile hop dags, rewrite hop dags and compile lop dags
- dmlt.constructHops(fsb);
ProgramRewriter rewriter = new ProgramRewriter(true, false);
- rewriter.rewriteHopDAGsFunction(fsb, false); //rewrite and merge
- DMLTranslator.resetHopsDAGVisitStatus(fsb);
- rewriter.rewriteHopDAGsFunction(fsb, true); //rewrite and split
- DMLTranslator.resetHopsDAGVisitStatus(fsb);
ProgramRewriter rewriter2 = new ProgramRewriter(false, true);
- rewriter2.rewriteHopDAGsFunction(fsb, true);
- DMLTranslator.resetHopsDAGVisitStatus(fsb);
- DMLTranslator.refreshMemEstimates(fsb);
- dmlt.constructLops(fsb);
+
+ // validate functions, in two passes for cross references
+ for( FunctionStatementBlock fsb : fsbs.values() ) {
+ dmlt.liveVariableAnalysisFunction(dmlp, fsb);
+ dmlt.validateFunction(dmlp, fsb);
+ }
+
+ // compile hop dags, rewrite hop dags and compile lop dags
+ for( FunctionStatementBlock fsb : fsbs.values() ) {
+ dmlt.constructHops(fsb);
+ rewriter.rewriteHopDAGsFunction(fsb, false); //rewrite and merge
+ DMLTranslator.resetHopsDAGVisitStatus(fsb);
+ rewriter.rewriteHopDAGsFunction(fsb, true); //rewrite and split
+ DMLTranslator.resetHopsDAGVisitStatus(fsb);
+ rewriter2.rewriteHopDAGsFunction(fsb, true);
+ DMLTranslator.resetHopsDAGVisitStatus(fsb);
+ DMLTranslator.refreshMemEstimates(fsb);
+ dmlt.constructLops(fsb);
+ }
// compile runtime program
- return (FunctionProgramBlock) dmlt.createRuntimeProgramBlock(
- prog, fsb, ConfigurationManager.getDMLConfig());
+ for( Entry<String,FunctionStatementBlock> fsb : fsbs.entrySet() ) {
+ FunctionProgramBlock fpb = (FunctionProgramBlock) dmlt
+ .createRuntimeProgramBlock(prog, fsb.getValue(), ConfigurationManager.getDMLConfig());
+ prog.addFunctionProgramBlock(null, fsb.getKey(), fpb);
+ }
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
index d400623..f48c869 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
@@ -789,7 +789,7 @@ public class LineageRewriteReuse
private static BasicProgramBlock getProgramBlock() {
if( _lrPB == null )
- _lrPB = new BasicProgramBlock( new Program() );
+ _lrPB = new BasicProgramBlock(new Program());
return _lrPB;
}
}
\ No newline at end of file
diff --git a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java
index c2add4d..ce7df49 100644
--- a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java
@@ -117,6 +117,16 @@ public class MLContextTest extends MLContextTestBase {
ml.execute(script);
ml.setExplain(false);
}
+
+ @Test
+ public void testExecuteEvalNestedBuiltinTest() {
+ System.out.println("MLContextTest - eval builtin test");
+ setExpectedStdOut("TRUE");
+ ml.setExplain(true);
+ Script script = dmlFromFile(baseDirectory + File.separator + "eval4-nested_builtin-test.dml");
+ ml.execute(script);
+ ml.setExplain(false);
+ }
@Test
public void testCreateDMLScriptBasedOnStringAndExecute() {
diff --git a/src/test/scripts/functions/mlcontext/eval4-nested_builtin-test.dml b/src/test/scripts/functions/mlcontext/eval4-nested_builtin-test.dml
new file mode 100644
index 0000000..2217085
--- /dev/null
+++ b/src/test/scripts/functions/mlcontext/eval4-nested_builtin-test.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rand(rows=100, cols=10, seed=37)
+y = rand(rows=100, cols=1, seed=38)
+
+F = cbind(as.frame("lm"),as.frame("mlogreg"));
+ix = ifelse(sum(X)>1, 1, 2);
+R1 = eval(as.scalar(F[1,ix]), X, y, 0, 1e-7, 1e-7, 0, FALSE); #calls lm->lmDS
+R2 = lmCG(X=X, y=y, verbose=FALSE);
+
+print(sum(abs(R1-R2)<1e-6)==ncol(X));