You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2020/07/20 19:44:47 UTC
[systemds] branch master updated: [SYSTEMDS-2575] Fix eval function
calls (incorrect pinning of inputs)
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new e581b5a [SYSTEMDS-2575] Fix eval function calls (incorrect pinning of inputs)
e581b5a is described below
commit e581b5a6248b56a70e18ffe6ba699e8142a2d679
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Mon Jul 20 21:37:21 2020 +0200
[SYSTEMDS-2575] Fix eval function calls (incorrect pinning of inputs)
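This patch fixes an issue with indirect eval function calls where wrong
input variable names led to missing pinning of inputs and thus overly eager
cleanup of these variables (which caused crashes if the inputs were also used
in other operations of the eval call). Specifically, when eval expanded a list
argument into individual variables, the bound input names were taken from the
list's names (or the function's parameter names) instead of the names of the
newly created variables.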
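The fix is simple: we avoid such inconsistencies between the construction
and invocation of fcall instructions by narrowing the FunctionCallCPInstruction
constructor interface (it no longer accepts a separate list of bound input
names) and deriving the bound input names internally from the bound input
operands. In addition, EvalNaryCPInstruction no longer returns early when the
function output is already a matrix, so the cleanup of variables expanded from
a list input (step 7) now always runs.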
---
.../runtime/controlprogram/paramserv/PSWorker.java | 4 +--
.../controlprogram/paramserv/ParamServer.java | 4 +--
.../instructions/cp/EvalNaryCPInstruction.java | 34 +++++++++-------------
.../instructions/cp/FunctionCallCPInstruction.java | 11 +++----
4 files changed, 22 insertions(+), 31 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
index 9f2311b..0eb9cf9 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
@@ -77,12 +77,10 @@ public abstract class PSWorker implements Serializable
CPOperand[] boundInputs = inputs.stream()
.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
.toArray(CPOperand[]::new);
- ArrayList<String> inputNames = inputs.stream().map(DataIdentifier::getName)
- .collect(Collectors.toCollection(ArrayList::new));
ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
.collect(Collectors.toCollection(ArrayList::new));
_inst = new FunctionCallCPInstruction(ns, fname, boundInputs,
- inputNames, func.getInputParamNames(), outputNames, "update function");
+ func.getInputParamNames(), outputNames, "update function");
// Check the inputs of the update function
checkInput(false, inputs, DataType.MATRIX, Statement.PS_FEATURES);
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
index 92e29b1..81cee33 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
@@ -104,12 +104,10 @@ public abstract class ParamServer
CPOperand[] boundInputs = inputs.stream()
.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
.toArray(CPOperand[]::new);
- ArrayList<String> inputNames = inputs.stream().map(DataIdentifier::getName)
- .collect(Collectors.toCollection(ArrayList::new));
ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
.collect(Collectors.toCollection(ArrayList::new));
_inst = new FunctionCallCPInstruction(ns, fname, boundInputs,
- inputNames, func.getInputParamNames(), outputNames, "aggregate function");
+ func.getInputParamNames(), outputNames, "aggregate function");
}
public abstract void push(int workerID, ListObject value);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
index f10e7bb..070a3fc 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
@@ -67,10 +67,6 @@ public class EvalNaryCPInstruction extends BuiltinNaryCPInstruction {
CPOperand[] boundInputs = Arrays.copyOfRange(inputs, 1, inputs.length);
List<String> boundOutputNames = new ArrayList<>();
boundOutputNames.add(output.getName());
- List<String> boundInputNames = new ArrayList<>();
- for (CPOperand input : boundInputs) {
- boundInputNames.add(input.getName());
- }
//2. copy the created output matrix
MatrixObject outputMO = new MatrixObject(ec.getMatrixObject(output.getName()));
@@ -103,32 +99,30 @@ public class EvalNaryCPInstruction extends BuiltinNaryCPInstruction {
ec.getVariables().put(varName, in);
boundInputs2[i] = new CPOperand(varName, in);
}
- boundInputNames = lo.isNamedList() ? lo.getNames() : fpb.getInputParamNames();
boundInputs = boundInputs2;
}
//5. call the function
FunctionCallCPInstruction fcpi = new FunctionCallCPInstruction(null, funcName,
- boundInputs, boundInputNames, fpb.getInputParamNames(), boundOutputNames, "eval func");
+ boundInputs, fpb.getInputParamNames(), boundOutputNames, "eval func");
fcpi.processInstruction(ec);
//6. convert the result to matrix
Data newOutput = ec.getVariable(output);
- if (newOutput instanceof MatrixObject) {
- return;
- }
- MatrixBlock mb = null;
- if (newOutput instanceof ScalarObject) {
- //convert scalar to matrix
- mb = new MatrixBlock(((ScalarObject) newOutput).getDoubleValue());
- } else if (newOutput instanceof FrameObject) {
- //convert frame to matrix
- mb = DataConverter.convertToMatrixBlock(((FrameObject) newOutput).acquireRead());
- ec.cleanupCacheableData((FrameObject) newOutput);
+ if (!(newOutput instanceof MatrixObject)) {
+ MatrixBlock mb = null;
+ if (newOutput instanceof ScalarObject) {
+ //convert scalar to matrix
+ mb = new MatrixBlock(((ScalarObject) newOutput).getDoubleValue());
+ } else if (newOutput instanceof FrameObject) {
+ //convert frame to matrix
+ mb = DataConverter.convertToMatrixBlock(((FrameObject) newOutput).acquireRead());
+ ec.cleanupCacheableData((FrameObject) newOutput);
+ }
+ outputMO.acquireModify(mb);
+ outputMO.release();
+ ec.setVariable(output.getName(), outputMO);
}
- outputMO.acquireModify(mb);
- outputMO.release();
- ec.setVariable(output.getName(), outputMO);
//7. cleanup of variable expanded from list
if( boundInputs2 != null ) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
index 695b07a..8b88647 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
@@ -20,8 +20,10 @@
package org.apache.sysds.runtime.instructions.cp;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
+import java.util.stream.Collectors;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.lops.Lop;
@@ -55,12 +57,13 @@ public class FunctionCallCPInstruction extends CPInstruction {
private final List<String> _boundOutputNames;
public FunctionCallCPInstruction(String namespace, String functName, CPOperand[] boundInputs,
- List<String> boundInputNames, List<String> funArgNames, List<String> boundOutputNames, String istr) {
+ List<String> funArgNames, List<String> boundOutputNames, String istr) {
super(CPType.External, null, functName, istr);
_functionName = functName;
_namespace = namespace;
_boundInputs = boundInputs;
- _boundInputNames = boundInputNames;
+ _boundInputNames = Arrays.stream(boundInputs).map(i -> i.getName())
+ .collect(Collectors.toCollection(ArrayList::new));
_funArgNames = funArgNames;
_boundOutputNames = boundOutputNames;
}
@@ -81,19 +84,17 @@ public class FunctionCallCPInstruction extends CPInstruction {
int numInputs = Integer.valueOf(parts[3]);
int numOutputs = Integer.valueOf(parts[4]);
CPOperand[] boundInputs = new CPOperand[numInputs];
- List<String> boundInputNames = new ArrayList<>();
List<String> funArgNames = new ArrayList<>();
List<String> boundOutputNames = new ArrayList<>();
for (int i = 0; i < numInputs; i++) {
String[] nameValue = IOUtilFunctions.splitByFirst(parts[5 + i], "=");
boundInputs[i] = new CPOperand(nameValue[1]);
funArgNames.add(nameValue[0]);
- boundInputNames.add(boundInputs[i].getName());
}
for (int i = 0; i < numOutputs; i++)
boundOutputNames.add(parts[5 + numInputs + i]);
return new FunctionCallCPInstruction ( namespace, functionName,
- boundInputs, boundInputNames, funArgNames, boundOutputNames, str );
+ boundInputs, funArgNames, boundOutputNames, str );
}
@Override