You are viewing a plain-text version of this content; see the mailing-list archive for the canonical version.
Posted to commits@systemds.apache.org by mb...@apache.org on 2020/07/20 19:44:47 UTC

[systemds] branch master updated: [SYSTEMDS-2575] Fix eval function calls (incorrect pinning of inputs)

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new e581b5a  [SYSTEMDS-2575] Fix eval function calls (incorrect pinning of inputs)
e581b5a is described below

commit e581b5a6248b56a70e18ffe6ba699e8142a2d679
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Mon Jul 20 21:37:21 2020 +0200

    [SYSTEMDS-2575] Fix eval function calls (incorrect pinning of inputs)
    
    This patch fixes an issue with indirect eval function calls, where
    incorrect input variable names led to missing pinning of inputs and thus
    overly eager cleanup of these variables (which caused crashes if the
    inputs were used in other operations of the eval call).
    
    The fix is simple: we avoid such inconsistent construction and
    invocation of fcall instructions by narrowing the constructor interface,
    so that the fcall derives the bound input names internally from its
    bound input operands instead of receiving them as a separate list.
---
 .../runtime/controlprogram/paramserv/PSWorker.java |  4 +--
 .../controlprogram/paramserv/ParamServer.java      |  4 +--
 .../instructions/cp/EvalNaryCPInstruction.java     | 34 +++++++++-------------
 .../instructions/cp/FunctionCallCPInstruction.java | 11 +++----
 4 files changed, 22 insertions(+), 31 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
index 9f2311b..0eb9cf9 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
@@ -77,12 +77,10 @@ public abstract class PSWorker implements Serializable
 		CPOperand[] boundInputs = inputs.stream()
 			.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
 			.toArray(CPOperand[]::new);
-		ArrayList<String> inputNames = inputs.stream().map(DataIdentifier::getName)
-			.collect(Collectors.toCollection(ArrayList::new));
 		ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
 			.collect(Collectors.toCollection(ArrayList::new));
 		_inst = new FunctionCallCPInstruction(ns, fname, boundInputs,
-			inputNames, func.getInputParamNames(), outputNames, "update function");
+			func.getInputParamNames(), outputNames, "update function");
 
 		// Check the inputs of the update function
 		checkInput(false, inputs, DataType.MATRIX, Statement.PS_FEATURES);
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
index 92e29b1..81cee33 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
@@ -104,12 +104,10 @@ public abstract class ParamServer
 		CPOperand[] boundInputs = inputs.stream()
 			.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
 			.toArray(CPOperand[]::new);
-		ArrayList<String> inputNames = inputs.stream().map(DataIdentifier::getName)
-			.collect(Collectors.toCollection(ArrayList::new));
 		ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
 			.collect(Collectors.toCollection(ArrayList::new));
 		_inst = new FunctionCallCPInstruction(ns, fname, boundInputs,
-			inputNames, func.getInputParamNames(), outputNames, "aggregate function");
+			func.getInputParamNames(), outputNames, "aggregate function");
 	}
 
 	public abstract void push(int workerID, ListObject value);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
index f10e7bb..070a3fc 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
@@ -67,10 +67,6 @@ public class EvalNaryCPInstruction extends BuiltinNaryCPInstruction {
 		CPOperand[] boundInputs = Arrays.copyOfRange(inputs, 1, inputs.length);
 		List<String> boundOutputNames = new ArrayList<>();
 		boundOutputNames.add(output.getName());
-		List<String> boundInputNames = new ArrayList<>();
-		for (CPOperand input : boundInputs) {
-			boundInputNames.add(input.getName());
-		}
 
 		//2. copy the created output matrix
 		MatrixObject outputMO = new MatrixObject(ec.getMatrixObject(output.getName()));
@@ -103,32 +99,30 @@ public class EvalNaryCPInstruction extends BuiltinNaryCPInstruction {
 				ec.getVariables().put(varName, in);
 				boundInputs2[i] = new CPOperand(varName, in);
 			}
-			boundInputNames = lo.isNamedList() ? lo.getNames() : fpb.getInputParamNames();
 			boundInputs = boundInputs2;
 		}
 		
 		//5. call the function
 		FunctionCallCPInstruction fcpi = new FunctionCallCPInstruction(null, funcName,
-			boundInputs, boundInputNames, fpb.getInputParamNames(), boundOutputNames, "eval func");
+			boundInputs, fpb.getInputParamNames(), boundOutputNames, "eval func");
 		fcpi.processInstruction(ec);
 
 		//6. convert the result to matrix
 		Data newOutput = ec.getVariable(output);
-		if (newOutput instanceof MatrixObject) {
-			return;
-		}
-		MatrixBlock mb = null;
-		if (newOutput instanceof ScalarObject) {
-			//convert scalar to matrix
-			mb = new MatrixBlock(((ScalarObject) newOutput).getDoubleValue());
-		} else if (newOutput instanceof FrameObject) {
-			//convert frame to matrix
-			mb = DataConverter.convertToMatrixBlock(((FrameObject) newOutput).acquireRead());
-			ec.cleanupCacheableData((FrameObject) newOutput);
+		if (!(newOutput instanceof MatrixObject)) {
+			MatrixBlock mb = null;
+			if (newOutput instanceof ScalarObject) {
+				//convert scalar to matrix
+				mb = new MatrixBlock(((ScalarObject) newOutput).getDoubleValue());
+			} else if (newOutput instanceof FrameObject) {
+				//convert frame to matrix
+				mb = DataConverter.convertToMatrixBlock(((FrameObject) newOutput).acquireRead());
+				ec.cleanupCacheableData((FrameObject) newOutput);
+			}
+			outputMO.acquireModify(mb);
+			outputMO.release();
+			ec.setVariable(output.getName(), outputMO);
 		}
-		outputMO.acquireModify(mb);
-		outputMO.release();
-		ec.setVariable(output.getName(), outputMO);
 		
 		//7. cleanup of variable expanded from list
 		if( boundInputs2 != null ) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
index 695b07a..8b88647 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
@@ -20,8 +20,10 @@
 package org.apache.sysds.runtime.instructions.cp;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashSet;
 import java.util.List;
+import java.util.stream.Collectors;
 
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.lops.Lop;
@@ -55,12 +57,13 @@ public class FunctionCallCPInstruction extends CPInstruction {
 	private final List<String> _boundOutputNames;
 
 	public FunctionCallCPInstruction(String namespace, String functName, CPOperand[] boundInputs,
-			List<String> boundInputNames, List<String> funArgNames, List<String> boundOutputNames, String istr) {
+			List<String> funArgNames, List<String> boundOutputNames, String istr) {
 		super(CPType.External, null, functName, istr);
 		_functionName = functName;
 		_namespace = namespace;
 		_boundInputs = boundInputs;
-		_boundInputNames = boundInputNames;
+		_boundInputNames = Arrays.stream(boundInputs).map(i -> i.getName())
+			.collect(Collectors.toCollection(ArrayList::new));
 		_funArgNames = funArgNames;
 		_boundOutputNames = boundOutputNames;
 	}
@@ -81,19 +84,17 @@ public class FunctionCallCPInstruction extends CPInstruction {
 		int numInputs = Integer.valueOf(parts[3]);
 		int numOutputs = Integer.valueOf(parts[4]);
 		CPOperand[] boundInputs = new CPOperand[numInputs];
-		List<String> boundInputNames = new ArrayList<>();
 		List<String> funArgNames = new ArrayList<>();
 		List<String> boundOutputNames = new ArrayList<>();
 		for (int i = 0; i < numInputs; i++) {
 			String[] nameValue = IOUtilFunctions.splitByFirst(parts[5 + i], "=");
 			boundInputs[i] = new CPOperand(nameValue[1]);
 			funArgNames.add(nameValue[0]);
-			boundInputNames.add(boundInputs[i].getName());
 		}
 		for (int i = 0; i < numOutputs; i++)
 			boundOutputNames.add(parts[5 + numInputs + i]);
 		return new FunctionCallCPInstruction ( namespace, functionName,
-			boundInputs, boundInputNames, funArgNames, boundOutputNames, str );
+			boundInputs, funArgNames, boundOutputNames, str );
 	}
 	
 	@Override