You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@systemds.apache.org by ar...@apache.org on 2020/10/30 13:48:53 UTC
[systemds] branch master updated: [MINOR] Fix lineage tracing of SAMPLE
This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 079a57b [MINOR] Fix lineage tracing of SAMPLE
079a57b is described below
commit 079a57b3f844a712b0ff529e166e0ad1c713ba01
Author: arnabp <ar...@tugraz.at>
AuthorDate: Fri Oct 30 14:40:51 2020 +0100
[MINOR] Fix lineage tracing of SAMPLE
---
scripts/builtin/smote.dml | 6 +++++-
.../apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java | 7 +++++--
2 files changed, 10 insertions(+), 3 deletions(-)
diff --git a/scripts/builtin/smote.dml b/scripts/builtin/smote.dml
index a223227..dd096ed 100644
--- a/scripts/builtin/smote.dml
+++ b/scripts/builtin/smote.dml
@@ -68,7 +68,11 @@ return (Matrix[Double] Y) {
synthetic_samples = matrix(0, iterLim*ncol(knn_index), ncol(X))
# shuffle the nn indexes
- rand_index = ifelse(k < iterLim, sample(k, iterLim, TRUE, 42), sample(k, iterLim, 42))
+ #rand_index = ifelse(k < iterLim, sample(k, iterLim, TRUE, 42), sample(k, iterLim, 42))
+ if (k < iterLim)
+ rand_index = sample(k, iterLim, TRUE, 42);
+ else
+ rand_index = sample(k, iterLim, 42);
while(iter < iterLim)
{
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
index dc7a5f1..af74498 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
@@ -409,8 +409,11 @@ public class DataGenCPInstruction extends UnaryCPInstruction {
}
//replace output variable name with a placeholder
tmpInstStr = InstructionUtils.replaceOperandName(tmpInstStr);
- tmpInstStr = replaceNonLiteral(tmpInstStr, rows, 2, ec);
- tmpInstStr = replaceNonLiteral(tmpInstStr, cols, 3, ec);
+ tmpInstStr = method.name().equalsIgnoreCase("rand") ?
+ replaceNonLiteral(tmpInstStr, rows, 2, ec) :
+ replaceNonLiteral(tmpInstStr, rows, 3, ec);
+ tmpInstStr = method.name().equalsIgnoreCase("rand") ?
+ replaceNonLiteral(tmpInstStr, cols, 3, ec) : tmpInstStr;
break;
}
case SEQ: {