You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/11/04 01:58:48 UTC
[3/5] systemml git commit: [MINOR] Performance function invocation of
dml-bodied UDFs
[MINOR] Performance of function invocation of dml-bodied UDFs
This patch slightly improves the function-invocation performance of
dml-bodied UDFs, from 452K to 521K invocations per second.
It also fixes the LinregCG test over compressed data (corrected R-script
input/expected directories and a missing argument separator in the R command).
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/14c410ce
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/14c410ce
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/14c410ce
Branch: refs/heads/master
Commit: 14c410ce06f3a5c56d1bcb1ac509fab4a0711f5f
Parents: d7d312c
Author: Matthias Boehm <mb...@gmail.com>
Authored: Fri Nov 3 14:23:50 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Fri Nov 3 18:59:29 2017 -0700
----------------------------------------------------------------------
.../controlprogram/LocalVariableMap.java | 4 +
.../controlprogram/ParForProgramBlock.java | 4 +-
.../context/ExecutionContext.java | 41 +++---
.../cp/FunctionCallCPInstruction.java | 135 +++++++++----------
.../functions/compress/CompressedLinregCG.java | 19 ++-
5 files changed, 95 insertions(+), 108 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/14c410ce/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java b/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java
index 7ebe1a0..e894495 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java
@@ -63,6 +63,10 @@ public class LocalVariableMap implements Cloneable
return localMap.keySet();
}
+ public Set<Entry<String, Data>> entrySet() {
+ return localMap.entrySet();
+ }
+
/**
* Retrieves the data object given its name.
*
http://git-wip-us.apache.org/repos/asf/systemml/blob/14c410ce/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
index e2568cb..760ddff 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
@@ -633,7 +633,7 @@ public class ParForProgramBlock extends ForProgramBlock
//preserve shared input/result variables of cleanup
ArrayList<String> varList = ec.getVarList();
- HashMap<String, Boolean> varState = ec.pinVariables(varList);
+ boolean[] varState = ec.pinVariables(varList);
try
{
@@ -1329,7 +1329,7 @@ public class ParForProgramBlock extends ForProgramBlock
}
}
- private void cleanupSharedVariables( ExecutionContext ec, HashMap<String,Boolean> varState )
+ private void cleanupSharedVariables( ExecutionContext ec, boolean[] varState )
throws DMLRuntimeException
{
//TODO needs as precondition a systematic treatment of persistent read information.
http://git-wip-us.apache.org/repos/asf/systemml/blob/14c410ce/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
index 79a658d..ecb9629 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
@@ -20,7 +20,6 @@
package org.apache.sysml.runtime.controlprogram.context;
import java.util.ArrayList;
-import java.util.HashMap;
import java.util.List;
import org.apache.commons.logging.Log;
@@ -151,6 +150,11 @@ public class ExecutionContext {
return _variables.get(name);
}
+ public Data getVariable(CPOperand operand) throws DMLRuntimeException {
+ return operand.getDataType().isScalar() ?
+ getScalarInput(operand) : getVariable(operand.getName());
+ }
+
public void setVariable(String name, Data val) {
_variables.put(name, val);
}
@@ -528,30 +532,25 @@ public class ExecutionContext {
* The function returns the OLD "clean up" state of matrix objects.
*
* @param varList variable list
- * @return map of old cleanup state of matrix objects
+ * @return indicator vector of old cleanup state of matrix objects
*/
- public HashMap<String,Boolean> pinVariables(ArrayList<String> varList)
+ public boolean[] pinVariables(ArrayList<String> varList)
{
//2-pass approach since multiple vars might refer to same matrix object
- HashMap<String, Boolean> varsState = new HashMap<>();
+ boolean[] varsState = new boolean[varList.size()];
//step 1) get current information
- for( String var : varList )
- {
- Data dat = _variables.get(var);
- if( dat instanceof MatrixObject ) {
- MatrixObject mo = (MatrixObject)dat;
- varsState.put( var, mo.isCleanupEnabled() );
- }
+ for( int i=0; i<varList.size(); i++ ) {
+ Data dat = _variables.get(varList.get(i));
+ if( dat instanceof MatrixObject )
+ varsState[i] = ((MatrixObject)dat).isCleanupEnabled();
}
//step 2) pin variables
- for( String var : varList ) {
- Data dat = _variables.get(var);
- if( dat instanceof MatrixObject ) {
- MatrixObject mo = (MatrixObject)dat;
- mo.enableCleanup(false);
- }
+ for( int i=0; i<varList.size(); i++ ) {
+ Data dat = _variables.get(varList.get(i));
+ if( dat instanceof MatrixObject )
+ ((MatrixObject)dat).enableCleanup(false);
}
return varsState;
@@ -573,11 +572,11 @@ public class ExecutionContext {
* @param varList variable list
* @param varsState variable state
*/
- public void unpinVariables(ArrayList<String> varList, HashMap<String,Boolean> varsState) {
- for( String var : varList) {
- Data dat = _variables.get(var);
+ public void unpinVariables(ArrayList<String> varList, boolean[] varsState) {
+ for( int i=0; i<varList.size(); i++ ) {
+ Data dat = _variables.get(varList.get(i));
if( dat instanceof MatrixObject )
- ((MatrixObject)dat).enableCleanup(varsState.get(var));
+ ((MatrixObject)dat).enableCleanup(varsState[i]);
}
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/14c410ce/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java
index b901dfc..e785196 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java
@@ -20,10 +20,8 @@
package org.apache.sysml.runtime.instructions.cp;
import java.util.ArrayList;
-import java.util.Collection;
-import java.util.HashMap;
import java.util.HashSet;
-import java.util.LinkedList;
+import java.util.Map.Entry;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.lops.Lop;
@@ -43,32 +41,30 @@ import org.apache.sysml.runtime.instructions.InstructionUtils;
public class FunctionCallCPInstruction extends CPInstruction {
private String _functionName;
private String _namespace;
+ private final CPOperand[] _boundInputs;
+ private final ArrayList<String> _boundInputNames;
+ private final ArrayList<String> _boundOutputNames;
+ private HashSet<String> _expectRetVars = null;
- public String getFunctionName() {
- return _functionName;
- }
-
- public String getNamespace() {
- return _namespace;
- }
-
- // stores both the bound input and output parameters
- private ArrayList<CPOperand> _boundInputParamOperands;
- private ArrayList<String> _boundInputParamNames;
- private ArrayList<String> _boundOutputParamNames;
-
- private FunctionCallCPInstruction(String namespace, String functName, ArrayList<CPOperand> boundInParamOperands,
- ArrayList<String> boundInParamNames, ArrayList<String> boundOutParamNames, String istr) {
+ private FunctionCallCPInstruction(String namespace, String functName, CPOperand[] boundInputs,
+ ArrayList<String> boundInputNames, ArrayList<String> boundOutputNames, String istr) {
super(null, functName, istr);
_cptype = CPINSTRUCTION_TYPE.External;
_functionName = functName;
_namespace = namespace;
- _boundInputParamOperands = boundInParamOperands;
- _boundInputParamNames = boundInParamNames;
- _boundOutputParamNames = boundOutParamNames;
+ _boundInputs = boundInputs;
+ _boundInputNames = boundInputNames;
+ _boundOutputNames = boundOutputNames;
+ }
+ public String getFunctionName() {
+ return _functionName;
}
+ public String getNamespace() {
+ return _namespace;
+ }
+
public static FunctionCallCPInstruction parseInstruction(String str)
throws DMLRuntimeException
{
@@ -78,20 +74,17 @@ public class FunctionCallCPInstruction extends CPInstruction {
String functionName = parts[2];
int numInputs = Integer.valueOf(parts[3]);
int numOutputs = Integer.valueOf(parts[4]);
- ArrayList<CPOperand> boundInParamOperands = new ArrayList<>();
- ArrayList<String> boundInParamNames = new ArrayList<>();
- ArrayList<String> boundOutParamNames = new ArrayList<>();
+ CPOperand[] boundInputs = new CPOperand[numInputs];
+ ArrayList<String> boundInputNames = new ArrayList<>();
+ ArrayList<String> boundOutputNames = new ArrayList<>();
for (int i = 0; i < numInputs; i++) {
- CPOperand operand = new CPOperand(parts[5 + i]);
- boundInParamOperands.add(operand);
- boundInParamNames.add(operand.getName());
+ boundInputs[i] = new CPOperand(parts[5 + i]);
+ boundInputNames.add(boundInputs[i].getName());
}
- for (int i = 0; i < numOutputs; i++) {
- boundOutParamNames.add(parts[5 + numInputs + i]);
- }
-
- return new FunctionCallCPInstruction ( namespace,functionName,
- boundInParamOperands, boundInParamNames, boundOutParamNames, str );
+ for (int i = 0; i < numOutputs; i++)
+ boundOutputNames.add(parts[5 + numInputs + i]);
+ return new FunctionCallCPInstruction ( namespace,
+ functionName, boundInputs, boundInputNames, boundOutputNames, str );
}
@Override
@@ -120,10 +113,10 @@ public class FunctionCallCPInstruction extends CPInstruction {
// get the function program block (stored in the Program object)
FunctionProgramBlock fpb = ec.getProgram().getFunctionProgramBlock(_namespace, _functionName);
- // sanity check number of function paramters
- if( _boundInputParamNames.size() < fpb.getInputParams().size() ) {
+ // sanity check number of function parameters
+ if( _boundInputs.length < fpb.getInputParams().size() ) {
throw new DMLRuntimeException("Number of bound input parameters does not match the function signature "
- + "("+_boundInputParamNames.size()+", but "+fpb.getInputParams().size()+" expected)");
+ + "("+_boundInputs.length+", but "+fpb.getInputParams().size()+" expected)");
}
// create bindings to formal parameters for given function call
@@ -131,35 +124,31 @@ public class FunctionCallCPInstruction extends CPInstruction {
LocalVariableMap functionVariables = new LocalVariableMap();
for( int i=0; i<fpb.getInputParams().size(); i++)
{
- DataIdentifier currFormalParam = fpb.getInputParams().get(i);
- String currFormalParamName = currFormalParam.getName();
- Data currFormalParamValue = null;
-
- CPOperand operand = _boundInputParamOperands.get(i);
- String varname = operand.getName();
//error handling non-existing variables
- if( !operand.isLiteral() && !ec.containsVariable(varname) ) {
- throw new DMLRuntimeException("Input variable '"+varname+"' not existing on call of " +
+ CPOperand input = _boundInputs[i];
+ if( !input.isLiteral() && !ec.containsVariable(input.getName()) ) {
+ throw new DMLRuntimeException("Input variable '"+input.getName()+"' not existing on call of " +
DMLProgram.constructFunctionKey(_namespace, _functionName) + " (line "+getLineNum()+").");
}
//get input matrix/frame/scalar
- currFormalParamValue = (operand.getDataType()!=DataType.SCALAR) ? ec.getVariable(varname) :
- ec.getScalarInput(varname, operand.getValueType(), operand.isLiteral());
+ DataIdentifier currFormalParam = fpb.getInputParams().get(i);
+ Data value = ec.getVariable(input);
//graceful value type conversion for scalar inputs with wrong type
- if( currFormalParamValue.getDataType() == DataType.SCALAR
- && currFormalParamValue.getValueType() != currFormalParam.getValueType() )
+ if( value.getDataType() == DataType.SCALAR
+ && value.getValueType() != currFormalParam.getValueType() )
{
- currFormalParamValue = ScalarObjectFactory.createScalarObject(
- currFormalParam.getValueType(), (ScalarObject) currFormalParamValue);
+ value = ScalarObjectFactory.createScalarObject(
+ currFormalParam.getValueType(), (ScalarObject)value);
}
- functionVariables.put(currFormalParamName, currFormalParamValue);
+ //set input parameter
+ functionVariables.put(currFormalParam.getName(), value);
}
// Pin the input variables so that they do not get deleted
// from pb's symbol table at the end of execution of function
- HashMap<String,Boolean> pinStatus = ec.pinVariables(_boundInputParamNames);
+ boolean[] pinStatus = ec.pinVariables(_boundInputNames);
// Create a symbol table under a new execution context for the function invocation,
// and copy the function arguments into the created table.
@@ -182,29 +171,29 @@ public class FunctionCallCPInstruction extends CPInstruction {
String fname = DMLProgram.constructFunctionKey(_namespace, _functionName);
throw new DMLRuntimeException("error executing function " + fname, e);
}
- LocalVariableMap retVars = fn_ec.getVariables();
// cleanup all returned variables w/o binding
- Collection<String> retVarnames = new LinkedList<>(retVars.keySet());
- HashSet<String> probeVars = new HashSet<>();
- for(DataIdentifier di : fpb.getOutputParams())
- probeVars.add(di.getName());
- for( String var : retVarnames ) {
- if( !probeVars.contains(var) ) //cleanup candidate
- {
- Data dat = fn_ec.removeVariable(var);
- if( dat != null && dat instanceof MatrixObject )
- fn_ec.cleanupMatrixObject((MatrixObject)dat);
- }
+ if( _expectRetVars == null ) {
+ _expectRetVars = new HashSet<>();
+ for(DataIdentifier di : fpb.getOutputParams())
+ _expectRetVars.add(di.getName());
+ }
+
+ LocalVariableMap retVars = fn_ec.getVariables();
+ for( Entry<String,Data> var : retVars.entrySet() ) {
+ if( _expectRetVars.contains(var.getKey()) )
+ continue;
+ //cleanup unexpected return values to avoid leaks
+ if( var.getValue() instanceof MatrixObject )
+ fn_ec.cleanupMatrixObject((MatrixObject)var.getValue());
}
// Unpin the pinned variables
- ec.unpinVariables(_boundInputParamNames, pinStatus);
+ ec.unpinVariables(_boundInputNames, pinStatus);
// add the updated binding for each return variable to the variables in original symbol table
for (int i=0; i< fpb.getOutputParams().size(); i++){
-
- String boundVarName = _boundOutputParamNames.get(i);
+ String boundVarName = _boundOutputNames.get(i);
Data boundValue = retVars.get(fpb.getOutputParams().get(i).getName());
if (boundValue == null)
throw new DMLRuntimeException(boundVarName + " was not assigned a return value");
@@ -240,14 +229,12 @@ public class FunctionCallCPInstruction extends CPInstruction {
LOG.debug("ExternalBuiltInFunction: " + this.toString());
}
- public ArrayList<String> getBoundInputParamNames()
- {
- return _boundInputParamNames;
+ public ArrayList<String> getBoundInputParamNames() {
+ return _boundInputNames;
}
- public ArrayList<String> getBoundOutputParamNames()
- {
- return _boundOutputParamNames;
+ public ArrayList<String> getBoundOutputParamNames() {
+ return _boundOutputNames;
}
public void setFunctionName(String fname)
@@ -277,6 +264,4 @@ public class FunctionCallCPInstruction extends CPInstruction {
return sb.substring( 0, sb.length()-Lop.OPERAND_DELIMITOR.length() );
}
-
-
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/14c410ce/src/test/java/org/apache/sysml/test/integration/functions/compress/CompressedLinregCG.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/compress/CompressedLinregCG.java b/src/test/java/org/apache/sysml/test/integration/functions/compress/CompressedLinregCG.java
index 6e2ddef..a7e0971 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/compress/CompressedLinregCG.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/compress/CompressedLinregCG.java
@@ -100,24 +100,23 @@ public class CompressedLinregCG extends AutomatedTestBase
try
{
- String TEST_NAME = testname;
- TestConfiguration config = getTestConfiguration(TEST_NAME);
+ TestConfiguration config = getTestConfiguration(testname);
+ loadTestConfiguration(config);
/* This is for running the junit test the new way, i.e., construct the arguments directly */
- String HOME = SCRIPT_DIR + "functions/codegen/";
+ String HOME1 = SCRIPT_DIR + "functions/compress/";
+ String HOME2 = SCRIPT_DIR + "functions/codegen/";
fullDMLScriptName = "scripts/algorithms/LinearRegCG.dml";
programArgs = new String[]{ "-explain", "-stats", "-nvargs", "X="+input("X"), "Y="+input("y"),
"icpt="+String.valueOf(intercept), "tol="+String.valueOf(epsilon),
"maxi="+String.valueOf(maxiter), "reg="+String.valueOf(regular), "B="+output("w")};
- fullRScriptName = HOME + "Algorithm_LinregCG.R";
+ fullRScriptName = HOME2 + "Algorithm_LinregCG.R";
rCmd = "Rscript" + " " + fullRScriptName + " " +
- HOME + INPUT_DIR + " " +
- String.valueOf(intercept) + " " + String.valueOf(epsilon) + " " +
- String.valueOf(maxiter) + " " + String.valueOf(regular) + HOME + EXPECTED_DIR;
+ HOME1 + INPUT_DIR + " " +
+ String.valueOf(intercept) + " " + String.valueOf(epsilon) + " " +
+ String.valueOf(maxiter) + " " + String.valueOf(regular) + " "+ HOME1 + EXPECTED_DIR;
- loadTestConfiguration(config);
-
//generate actual datasets
double[][] X = getRandomMatrix(rows, cols, 1, 1, sparse?sparsity2:sparsity1, 7);
writeInputMatrixWithMTD("X", X, true);
@@ -141,7 +140,7 @@ public class CompressedLinregCG extends AutomatedTestBase
finally {
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
- InfrastructureAnalyzer.setLocalMaxMemory(memOld);
+ InfrastructureAnalyzer.setLocalMaxMemory(memOld);
CompressedMatrixBlock.ALLOW_DDC_ENCODING = true;
}
}