You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/08/25 19:42:23 UTC
[systemds] branch master updated: [SYSTEMDS-3096] Fix parfor result variable handling in eval context
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 4a27a54 [SYSTEMDS-3096] Fix parfor result variable handling in eval context
4a27a54 is described below
commit 4a27a54df8ec80be08650fd7813938bec45ba644
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Wed Aug 25 21:40:12 2021 +0200
[SYSTEMDS-3096] Fix parfor result variable handling in eval context
This patch fixes a result-correctness issue that occurs when unoptimized
functions with parfor loops are called through eval and these
unoptimized functions were copied during compilation (rather than by eval
during lazy loading), which happens, for example, for user-specified
functions. Specifically, deep copies of parfor statement blocks did not
carry over the list of result variables, which ultimately led to a missing
result merge from the parfor workers.
Thanks to Shafaq for catching this critical issue.
---
src/main/java/org/apache/sysds/parser/ParForStatementBlock.java | 5 +++++
.../java/org/apache/sysds/runtime/util/ProgramConverter.java | 2 ++
.../scripts/functions/misc/FunPotpourriParforEvalBuiltin.dml | 9 +++++++--
3 files changed, 14 insertions(+), 2 deletions(-)
diff --git a/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java b/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java
index 4b47148..74c55c5 100644
--- a/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java
@@ -170,6 +170,11 @@ public class ParForStatementBlock extends ForStatementBlock
return _resultVars;
}
+ public void setResultVariables(ArrayList<ResultVar> rvars) {
+ _resultVars.clear();
+ _resultVars.addAll(rvars);
+ }
+
private void addToResultVariablesNoDup( String var, boolean accum ) {
addToResultVariablesNoDup(new ResultVar(var, accum));
}
diff --git a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
index f1adc3a..18a88fc 100644
--- a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
+++ b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
@@ -708,6 +708,8 @@ public class ProgramConverter
ret.updatePredicateRecompilationFlags();
ret.setNondeterministic(sb.isNondeterministic());
+ if( sb instanceof ParForStatementBlock )
+ ((ParForStatementBlock)ret).setResultVariables(((ParForStatementBlock)sb).getResultVariables());
}
else {
ret = sb;
diff --git a/src/test/scripts/functions/misc/FunPotpourriParforEvalBuiltin.dml b/src/test/scripts/functions/misc/FunPotpourriParforEvalBuiltin.dml
index 97be37e..ab8991e 100644
--- a/src/test/scripts/functions/misc/FunPotpourriParforEvalBuiltin.dml
+++ b/src/test/scripts/functions/misc/FunPotpourriParforEvalBuiltin.dml
@@ -58,7 +58,7 @@ crossV = function(Matrix[double] X, Matrix[double] y, Integer k, Matrix[Double]
testX = testset[, 2:ncol(testset)]
testy = testset[, 1]
beta = multiLogReg(X=trainX, Y=trainy, icpt=as.scalar(MLhp[1,1]), reg=as.scalar(MLhp[1,2]), tol=as.scalar(MLhp[1,3]),
- maxi=as.scalar(MLhp[1,4]), maxii=50, verbose=FALSE);
+ maxi=as.scalar(MLhp[1,4]), maxii=50, verbose=FALSE);
[prob, yhat, acc] = multiLogRegPredict(testX, beta, testy, FALSE)
accuracy = getAccuracy(testy, yhat, isWeighted)
accuracyMatrix[i] = accuracy
@@ -69,5 +69,10 @@ X = rand(rows=100, cols=100)
Y = sample(2, 100, TRUE)
hp = matrix("1 1e-4 1e-6 100", rows=1, cols=4)
+#acc = crossV(X=X, y=Y, k=3, MLhp=hp, isWeighted=FALSE)
acc = eval("crossV", list(X=X, y=Y, k=3, MLhp=hp, isWeighted=FALSE))
-print("CV accuracy: "+mean(acc))
+
+macc = mean(acc)
+if( macc <= 0 ) # fail test if empty
+ stop("Invalid accuracy: "+macc);
+print("CV accuracy: "+macc)