You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/08/25 19:42:23 UTC

[systemds] branch master updated: [SYSTEMDS-3096] Fix parfor result variable handling in eval context

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 4a27a54  [SYSTEMDS-3096] Fix parfor result variable handling in eval context
4a27a54 is described below

commit 4a27a54df8ec80be08650fd7813938bec45ba644
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Wed Aug 25 21:40:12 2021 +0200

    [SYSTEMDS-3096] Fix parfor result variable handling in eval context
    
    This patch fixes an issue of result correctness when unoptimized
    functions with parfor loops are called through eval and these
    unoptimized functions have been copied during compilation (not by eval
    during lazy loading), which happens for example for user-specified
    functions. Specifically, deep copies of parfor statement blocks did not
    carry over the list of result variables, which ultimately led to a
    missing result merge from the parfor workers.
    
    Thanks to Shafaq for catching this critical issue.
---
 src/main/java/org/apache/sysds/parser/ParForStatementBlock.java  | 5 +++++
 .../java/org/apache/sysds/runtime/util/ProgramConverter.java     | 2 ++
 .../scripts/functions/misc/FunPotpourriParforEvalBuiltin.dml     | 9 +++++++--
 3 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java b/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java
index 4b47148..74c55c5 100644
--- a/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java
@@ -170,6 +170,11 @@ public class ParForStatementBlock extends ForStatementBlock
 		return _resultVars;
 	}
 	
+	public void setResultVariables(ArrayList<ResultVar> rvars) {
+		_resultVars.clear();
+		_resultVars.addAll(rvars);
+	}
+	
 	private void addToResultVariablesNoDup( String var, boolean accum ) {
 		addToResultVariablesNoDup(new ResultVar(var, accum));
 	}
diff --git a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
index f1adc3a..18a88fc 100644
--- a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
+++ b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
@@ -708,6 +708,8 @@ public class ProgramConverter
 				
 				ret.updatePredicateRecompilationFlags();
 				ret.setNondeterministic(sb.isNondeterministic());
+				if( sb instanceof ParForStatementBlock )
+					((ParForStatementBlock)ret).setResultVariables(((ParForStatementBlock)sb).getResultVariables());
 			}
 			else {
 				ret = sb;
diff --git a/src/test/scripts/functions/misc/FunPotpourriParforEvalBuiltin.dml b/src/test/scripts/functions/misc/FunPotpourriParforEvalBuiltin.dml
index 97be37e..ab8991e 100644
--- a/src/test/scripts/functions/misc/FunPotpourriParforEvalBuiltin.dml
+++ b/src/test/scripts/functions/misc/FunPotpourriParforEvalBuiltin.dml
@@ -58,7 +58,7 @@ crossV = function(Matrix[double] X, Matrix[double] y, Integer k, Matrix[Double]
     testX = testset[, 2:ncol(testset)]
     testy = testset[, 1]
     beta = multiLogReg(X=trainX, Y=trainy, icpt=as.scalar(MLhp[1,1]), reg=as.scalar(MLhp[1,2]), tol=as.scalar(MLhp[1,3]), 
-    maxi=as.scalar(MLhp[1,4]), maxii=50, verbose=FALSE);
+      maxi=as.scalar(MLhp[1,4]), maxii=50, verbose=FALSE);
     [prob, yhat, acc] = multiLogRegPredict(testX, beta, testy, FALSE)
     accuracy = getAccuracy(testy, yhat, isWeighted)
     accuracyMatrix[i] = accuracy
@@ -69,5 +69,10 @@ X = rand(rows=100, cols=100)
 Y = sample(2, 100, TRUE)
 hp = matrix("1 1e-4 1e-6 100", rows=1, cols=4)
 
+#acc = crossV(X=X, y=Y, k=3, MLhp=hp, isWeighted=FALSE)
 acc = eval("crossV", list(X=X, y=Y, k=3, MLhp=hp, isWeighted=FALSE))
-print("CV accuracy: "+mean(acc))
+
+macc = mean(acc)
+if( macc <= 0 ) # fail test if empty
+  stop("Invalid accuracy: "+macc);
+print("CV accuracy: "+macc)