You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2020/04/14 20:21:55 UTC

[systemml] branch master updated: [SYSTEMDS-291] Extended eval function calls (named/unnamed list args)

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new b9114b9  [SYSTEMDS-291] Extended eval function calls (named/unnamed list args)
b9114b9 is described below

commit b9114b90a180d79e4ab6de7647e3a35cd7ce3f78
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Tue Apr 14 21:37:28 2020 +0200

    [SYSTEMDS-291] Extended eval function calls (named/unnamed list args)
    
    This patch improves the existing eval function calls by adding support
    for named and unnamed list inputs. During the function call, these list
    arguments are expanded and ordered on demand (if the function doesn't
    have a signature of a single list argument).
    
    Furthermore, this patch integrates the tests for gridSearch, and makes a
    couple of smaller improvements to list data types (e.g., list append,
    proper function return handling with unknown value types).
---
 scripts/builtin/gridSearch.dml                     | 30 ++++-----
 .../sysds/parser/BuiltinFunctionExpression.java    |  2 +-
 .../sysds/parser/FunctionStatementBlock.java       |  8 +--
 .../controlprogram/context/ExecutionContext.java   |  2 +-
 .../sysds/runtime/instructions/cp/CPOperand.java   |  9 ++-
 .../instructions/cp/EvalNaryCPInstruction.java     | 74 +++++++++++++++++++---
 .../instructions/cp/FunctionCallCPInstruction.java |  2 +-
 .../cp/ListAppendRemoveCPInstruction.java          | 12 +++-
 .../sysds/runtime/instructions/cp/ListObject.java  | 10 +++
 .../functions/builtin/BuiltinGridSearchTest.java   |  9 +--
 .../scripts/functions/builtin/GridSearchLM.dml     | 10 +--
 11 files changed, 121 insertions(+), 47 deletions(-)

diff --git a/scripts/builtin/gridSearch.dml b/scripts/builtin/gridSearch.dml
index 227b863..ca37745 100644
--- a/scripts/builtin/gridSearch.dml
+++ b/scripts/builtin/gridSearch.dml
@@ -35,8 +35,8 @@ m_gridSearch = function(Matrix[Double] X, Matrix[Double] y, String train, String
     vect = as.matrix(paramValues[j,1]);
     paramVals[j,1:nrow(vect)] = t(vect);
   }
-	cumLens = rev(cumprod(rev(paramLens))/rev(paramLens));
-	numConfigs = prod(paramLens);
+  cumLens = rev(cumprod(rev(paramLens))/rev(paramLens));
+  numConfigs = prod(paramLens);
   
   # Step 1) materialize hyper-parameter combinations 
   # (simplify debugging and compared to compute negligible)
@@ -53,28 +53,22 @@ m_gridSearch = function(Matrix[Double] X, Matrix[Double] y, String train, String
   # TODO integrate cross validation
   Rbeta = matrix(0, nrow(HP), ncol(X));
   Rloss = matrix(0, nrow(HP), 1);
-  arguments = list(X=X, y=y);
+  # TODO pass arguments for function call from outside
+  arguments = list(X=X, y=y, icpt=0, reg=-1, tol=-1, maxi=-1, verbose=FALSE);
 
   parfor( i in 1:nrow(HP) ) {
-    # a) prepare training arguments
+    # a) replace training arguments
     largs = arguments;
-    for( j in 1:numParams ) {
-      key = as.scalar(params[j]);
-      value = as.scalar(HP[i,j]);
-      largs = append(largs, list(key=value));
-    }
-
-    # b) core training/scoring
-    lbeta = eval(train, largs);
-    lloss = eval(predict, list(X, y, lbeta));
-
-    # c) write models and loss back to output
-    Rbeta[i,] = lbeta;
-    Rloss[i,] = lloss;
+    for( j in 1:numParams )
+      largs[as.scalar(params[j])] = as.scalar(HP[i,j]);
+    # b) core training/scoring and write-back
+    # TODO investigate rmvar handling with explicit binding (lbeta)
+    Rbeta[i,] = t(eval(train, largs));
+    Rloss[i,] = eval(predict, list(X, y, t(Rbeta[i,])));
   }
 
   # Step 3) select best parameter combination
   ix = as.scalar(rowIndexMin(t(Rloss)));
-  B = Rbeta[ix,];          # optimal model
+  B = t(Rbeta[ix,]);       # optimal model
   opt = as.frame(HP[ix,]); # optimal hyper-parameters
 }
diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index cfda652..2c5d61a 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -786,7 +786,7 @@ public class BuiltinFunctionExpression extends DataIdentifier
 				//list append
 				if(getFirstExpr().getOutput().getDataType().isList() )
 					for(int i=1; i<getAllExpr().length; i++)
-						checkDataTypeParam(getExpr(i), DataType.SCALAR, DataType.MATRIX, DataType.FRAME);
+						checkDataTypeParam(getExpr(i), DataType.SCALAR, DataType.MATRIX, DataType.FRAME, DataType.LIST);
 				//matrix append (rbind/cbind)
 				else
 					for(int i=0; i<getAllExpr().length; i++)
diff --git a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
index 7d32816..a8f3c75 100644
--- a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
@@ -86,7 +86,8 @@ public class FunctionStatementBlock extends StatementBlock
 				raiseValidateError("for function " + fstmt.getName() + ", return variable " + curr.getName() + " data type of " + curr.getDataType() + " does not match data type in function signature of " + returnValue.getDataType(), conditional);
 			}
 			
-			if (curr.getValueType() != ValueType.UNKNOWN && !curr.getValueType().equals(returnValue.getValueType())){
+			if (curr.getValueType() != ValueType.UNKNOWN && returnValue.getValueType() != ValueType.UNKNOWN
+				&& !curr.getValueType().equals(returnValue.getValueType())){
 				
 				// attempt to convert value type: handle conversion from scalar DOUBLE or INT
 				if (curr.getDataType() == DataType.SCALAR && returnValue.getDataType() == DataType.SCALAR){ 
@@ -121,9 +122,8 @@ public class FunctionStatementBlock extends StatementBlock
 								+ " does not match value type in function signature of " 
 								+ returnValue.getValueType() + " and cannot safely cast " + curr.getValueType() 
 								+ " as " + returnValue.getValueType());
-						
-					} 
-				}	
+					}
+				}
 				else {
 					throw new LanguageException(curr.printErrorLocation() + "for function " + fstmt.getName() + ", return variable " + curr.getName() + " value type of " + curr.getValueType() + " does not match value type in function signature of " + returnValue.getValueType() + " and cannot safely cast " + curr.getValueType() + " as " + returnValue.getValueType());
 				}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index d2d9887..2022224 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -223,7 +223,7 @@ public class ExecutionContext {
 		if( dat == null )
 			throw new DMLRuntimeException(getNonExistingVarError(varname));
 		if( !(dat instanceof MatrixObject) )
-			throw new DMLRuntimeException("Variable '"+varname+"' is not a matrix.");
+			throw new DMLRuntimeException("Variable '"+varname+"' is not a matrix: "+dat.getClass().getName());
 		
 		return (MatrixObject) dat;
 	}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPOperand.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPOperand.java
index 2f6e200..be46930 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPOperand.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPOperand.java
@@ -44,7 +44,7 @@ public class CPOperand
 		split(str);
 	}
 	
-	public CPOperand(String name, ValueType vt, DataType dt ) {
+	public CPOperand(String name, ValueType vt, DataType dt) {
 		this(name, vt, dt, false);
 	}
 
@@ -69,6 +69,13 @@ public class CPOperand
 		_isLiteral = variable._isLiteral;
 		_literal = variable._literal;
 	}
+	
+	public CPOperand(String name, Data dat) {
+		_name = name;
+		_valueType = dat.getValueType();
+		_dataType = dat.getDataType();
+		_isLiteral = false;
+	}
 
 	public String getName() {
 		return _name;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
index 62f6c67..ce337b5 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
@@ -21,6 +21,8 @@ package org.apache.sysds.runtime.instructions.cp;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
 
@@ -28,6 +30,7 @@ import org.apache.sysds.common.Builtins;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.hops.rewrite.ProgramRewriter;
+import org.apache.sysds.lops.compile.Dag;
 import org.apache.sysds.parser.DMLProgram;
 import org.apache.sysds.parser.DMLTranslator;
 import org.apache.sysds.parser.FunctionStatementBlock;
@@ -62,9 +65,9 @@ public class EvalNaryCPInstruction extends BuiltinNaryCPInstruction {
 		
 		// bound the inputs to avoiding being deleted after the function call
 		CPOperand[] boundInputs = Arrays.copyOfRange(inputs, 1, inputs.length);
-		ArrayList<String> boundOutputNames = new ArrayList<>();
+		List<String> boundOutputNames = new ArrayList<>();
 		boundOutputNames.add(output.getName());
-		ArrayList<String> boundInputNames = new ArrayList<>();
+		List<String> boundInputNames = new ArrayList<>();
 		for (CPOperand input : boundInputs) {
 			boundInputNames.add(input.getName());
 		}
@@ -73,19 +76,43 @@ public class EvalNaryCPInstruction extends BuiltinNaryCPInstruction {
 		MatrixObject outputMO = new MatrixObject(ec.getMatrixObject(output.getName()));
 
 		//3. lazy loading of dml-bodied builtin functions (incl. rename 
-		// of function name to dml-bodied builtin scheme (data-type-specific) 
-		if( !ec.getProgram().containsFunctionProgramBlock(null, funcName) ) {
-			compileFunctionProgramBlock(funcName, boundInputs[0].getDataType(), ec.getProgram());
-			funcName = Builtins.getInternalFName(funcName, boundInputs[0].getDataType());
+		// of function name to dml-bodied builtin scheme (data-type-specific)
+		DataType dt1 = boundInputs[0].getDataType().isList() ? 
+			DataType.MATRIX : boundInputs[0].getDataType();
+		String funcName2 = Builtins.getInternalFName(funcName, dt1);
+		if( !ec.getProgram().containsFunctionProgramBlock(null, funcName)) {
+			if( !ec.getProgram().containsFunctionProgramBlock(null,funcName2) )
+				compileFunctionProgramBlock(funcName, dt1, ec.getProgram());
+			funcName = funcName2;
 		}
 		
-		//4. call the function
+		//4. expand list arguments if needed
+		CPOperand[] boundInputs2 = null;
 		FunctionProgramBlock fpb = ec.getProgram().getFunctionProgramBlock(null, funcName);
+		if( boundInputs.length == 1 && boundInputs[0].getDataType().isList()
+			&& fpb.getInputParams().size() > 1 && !fpb.getInputParams().get(0).getDataType().isList()) 
+		{
+			ListObject lo = ec.getListObject(boundInputs[0]);
+			checkValidArguments(lo.getData(), lo.getNames(), fpb.getInputParamNames());
+			if( lo.isNamedList() )
+				lo = reorderNamedListForFunctionCall(lo, fpb.getInputParamNames());
+			boundInputs2 = new CPOperand[lo.getLength()];
+			for( int i=0; i<lo.getLength(); i++ ) {
+				Data in = lo.getData(i);
+				String varName = Dag.getNextUniqueVarname(in.getDataType());
+				ec.getVariables().put(varName, in);
+				boundInputs2[i] = new CPOperand(varName, in);
+			}
+			boundInputNames = lo.isNamedList() ? lo.getNames() : fpb.getInputParamNames();
+			boundInputs = boundInputs2;
+		}
+		
+		//5. call the function
 		FunctionCallCPInstruction fcpi = new FunctionCallCPInstruction(null, funcName,
 			boundInputs, boundInputNames, fpb.getInputParamNames(), boundOutputNames, "eval func");
 		fcpi.processInstruction(ec);
 
-		//5. convert the result to matrix
+		//6. convert the result to matrix
 		Data newOutput = ec.getVariable(output);
 		if (newOutput instanceof MatrixObject) {
 			return;
@@ -102,6 +129,12 @@ public class EvalNaryCPInstruction extends BuiltinNaryCPInstruction {
 		outputMO.acquireModify(mb);
 		outputMO.release();
 		ec.setVariable(output.getName(), outputMO);
+		
+		//7. cleanup of variable expanded from list
+		if( boundInputs2 != null ) {
+			for( CPOperand op : boundInputs2 )
+				VariableCPInstruction.processRemoveVariableInstruction(ec, op.getName());
+		}
 	}
 	
 	private static void compileFunctionProgramBlock(String name, DataType dt, Program prog) {
@@ -151,4 +184,29 @@ public class EvalNaryCPInstruction extends BuiltinNaryCPInstruction {
 			prog.addFunctionProgramBlock(null, fsb.getKey(), fpb);
 		}
 	}
+	
+	private void checkValidArguments(List<Data> loData, List<String> loNames, List<String> fArgNames) {
+		//check number of parameters
+		int listSize = (loNames != null) ? loNames.size() : loData.size();
+		if( listSize != fArgNames.size() )
+			throw new DMLRuntimeException("Failed to expand list for function call "
+				+ "(mismatching number of arguments: "+listSize+" vs. "+fArgNames.size()+").");
+		
+		//check individual parameters
+		if( loNames != null ) {
+			HashSet<String> probe = new HashSet<>();
+			for( String var : fArgNames )
+				probe.add(var);
+			for( String var : loNames )
+				if( !probe.contains(var) )
+					throw new DMLRuntimeException("List argument named '"+var+"' not in function signature.");
+		}
+	}
+	
+	private ListObject reorderNamedListForFunctionCall(ListObject in, List<String> fArgNames) {
+		List<Data> sortedData = new ArrayList<>();
+		for( String name : fArgNames )
+			sortedData.add(in.getData(name));
+		return new ListObject(sortedData, new ArrayList<>(fArgNames));
+	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
index e605a55..9c1eac0 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
@@ -104,7 +104,7 @@ public class FunctionCallCPInstruction extends CPInstruction {
 	@Override
 	public void processInstruction(ExecutionContext ec) {
 		if( LOG.isTraceEnabled() ){
-			LOG.trace("Executing instruction : " + this.toString());
+			LOG.trace("Executing instruction : " + toString());
 		}
 		// get the function program block (stored in the Program object)
 		FunctionProgramBlock fpb = ec.getProgram().getFunctionProgramBlock(_namespace, _functionName);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListAppendRemoveCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListAppendRemoveCPInstruction.java
index b672ff3..1721c4a 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListAppendRemoveCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListAppendRemoveCPInstruction.java
@@ -46,8 +46,16 @@ public final class ListAppendRemoveCPInstruction extends AppendCPInstruction {
 		if( getOpcode().equals("append") ) {
 			//copy on write and append unnamed argument
 			Data dat2 = ec.getVariable(input2);
-			LineageItem li = DMLScript.LINEAGE ? ec.getLineage().get(input2):null;
-			ListObject tmp = lo.copy().add(dat2, li);
+			LineageItem li = DMLScript.LINEAGE ? ec.getLineage().get(input2) : null;
+			ListObject tmp = null;
+			if( dat2 instanceof ListObject && ((ListObject)dat2).getLength() == 1 ) {
+				//add unfolded elements for lists of size 1 (e.g., named)
+				ListObject lo2 = (ListObject) dat2;
+				tmp = lo.copy().add(lo2.getName(0), lo2.getData(0), li);
+			}
+			else {
+				tmp = lo.copy().add(dat2, li);
+			}
 			//set output variable
 			ec.setVariable(output.getName(), tmp);
 		}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
index 8cfb682..8798799 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
@@ -104,6 +104,14 @@ public class ListObject extends Data {
 		return _data;
 	}
 	
+	public Data getData(int ix) {
+		return _data.get(ix);
+	}
+	
+	public Data getData(String name) {
+		return slice(name);
+	}
+	
 	public List<LineageItem> getLineageItems() {
 		return _lineage;
 	}
@@ -219,6 +227,8 @@ public class ListObject extends Data {
 		if( _names != null && name == null )
 			throw new DMLRuntimeException("Cannot add to a named list");
 		//otherwise append and ignore name
+		if( _names != null )
+			_names.add(name);
 		_data.add(dat);
 		if (_lineage == null && li!= null) 
 			_lineage = new ArrayList<>();
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGridSearchTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGridSearchTest.java
index 556a7d7..232db2f 100644
--- a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGridSearchTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGridSearchTest.java
@@ -45,17 +45,14 @@ public class BuiltinGridSearchTest extends AutomatedTestBase
 	
 	@Test
 	public void testGridSearchCP() {
-		//TODO additional list features needed
-		//runGridSearch(ExecType.CP);
+		runGridSearch(ExecType.CP);
 	}
 	
 	@Test
 	public void testGridSearchSpark() {
-		//TODO additional list features needed
-		//runGridSearch(ExecType.SPARK);
+		runGridSearch(ExecType.SPARK);
 	}
 	
-	@SuppressWarnings("unused")
 	private void runGridSearch(ExecType et)
 	{
 		ExecMode modeOld = setExecMode(et);
@@ -64,7 +61,7 @@ public class BuiltinGridSearchTest extends AutomatedTestBase
 			String HOME = SCRIPT_DIR + TEST_DIR;
 	
 			fullDMLScriptName = HOME + TEST_NAME + ".dml";
-			programArgs = new String[] {"-args", input("X"), input("y"), output("R")};
+			programArgs = new String[] {"-explain","-args", input("X"), input("y"), output("R")};
 			double[][] X = getRandomMatrix(rows, cols, 0, 1, 0.8, -1);
 			double[][] y = getRandomMatrix(rows, 1, 0, 1, 0.8, -1);
 			writeInputMatrixWithMTD("X", X, true);
diff --git a/src/test/scripts/functions/builtin/GridSearchLM.dml b/src/test/scripts/functions/builtin/GridSearchLM.dml
index 9b33713..41a6fa1 100644
--- a/src/test/scripts/functions/builtin/GridSearchLM.dml
+++ b/src/test/scripts/functions/builtin/GridSearchLM.dml
@@ -19,8 +19,8 @@
 #
 #-------------------------------------------------------------
 
-l2norm = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B) return (Double loss) {
-  loss = sum((y - X%*%B)^2);
+l2norm = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B) return (Matrix[Double] loss) {
+  loss = as.matrix(sum((y - X%*%B)^2));
 }
 
 X = read($1);
@@ -33,12 +33,12 @@ Xtest = X[(N+1):nrow(X),];
 ytest = y[(N+1):nrow(X),];
 
 params = list("reg", "tol", "maxi");
-paramRanges = list(10^seq(0,-4), 10^seq(-5,-9), 10^seq(1,3));
-[B1, opt] = gridSearch(Xtrain, ytrain, "lm", "lmPredict", params, paramRanges, TRUE);
+paramRanges = list(10^seq(0,-4), 10^seq(-6,-12), 10^seq(1,3));
+[B1, opt] = gridSearch(Xtrain, ytrain, "lm", "l2norm", params, paramRanges, TRUE);
 B2 = lm(X=Xtrain, y=ytrain, verbose=FALSE);
 
 l1 = l2norm(Xtest, ytest, B1);
 l2 = l2norm(Xtest, ytest, B2);
-R = l1 <= l2;
+R = as.scalar(l1 <= l2);
 
 write(R, $3)