You are viewing a plain-text version of this content. The canonical link to the original was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@systemds.apache.org by ar...@apache.org on 2020/10/30 13:48:53 UTC

[systemds] branch master updated: [MINOR] Fix lineage tracing of SAMPLE

This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 079a57b  [MINOR] Fix lineage tracing of SAMPLE
079a57b is described below

commit 079a57b3f844a712b0ff529e166e0ad1c713ba01
Author: arnabp <ar...@tugraz.at>
AuthorDate: Fri Oct 30 14:40:51 2020 +0100

    [MINOR] Fix lineage tracing of SAMPLE
---
 scripts/builtin/smote.dml                                          | 6 +++++-
 .../apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java | 7 +++++--
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/scripts/builtin/smote.dml b/scripts/builtin/smote.dml
index a223227..dd096ed 100644
--- a/scripts/builtin/smote.dml
+++ b/scripts/builtin/smote.dml
@@ -68,7 +68,11 @@ return (Matrix[Double] Y) {
   synthetic_samples = matrix(0, iterLim*ncol(knn_index), ncol(X))
   
   # shuffle the nn indexes
-  rand_index =  ifelse(k < iterLim, sample(k, iterLim, TRUE, 42), sample(k, iterLim, 42))
+  #rand_index =  ifelse(k < iterLim, sample(k, iterLim, TRUE, 42), sample(k, iterLim, 42))
+  if (k < iterLim)
+    rand_index = sample(k, iterLim, TRUE, 42);
+  else
+    rand_index = sample(k, iterLim, 42);
 
   while(iter < iterLim)
   {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
index dc7a5f1..af74498 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
@@ -409,8 +409,11 @@ public class DataGenCPInstruction extends UnaryCPInstruction {
 				}
 				//replace output variable name with a placeholder
 				tmpInstStr = InstructionUtils.replaceOperandName(tmpInstStr);
-				tmpInstStr = replaceNonLiteral(tmpInstStr, rows, 2, ec);
-				tmpInstStr = replaceNonLiteral(tmpInstStr, cols, 3, ec);
+				tmpInstStr = method.name().equalsIgnoreCase("rand") ? 
+						replaceNonLiteral(tmpInstStr, rows, 2, ec) :
+						replaceNonLiteral(tmpInstStr, rows, 3, ec);
+				tmpInstStr = method.name().equalsIgnoreCase("rand") ? 
+						replaceNonLiteral(tmpInstStr, cols, 3, ec) : tmpInstStr;
 				break;
 			}
 			case SEQ: {