You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ma...@apache.org on 2021/03/01 16:39:41 UTC

[systemds] branch master updated: [SYSTEMDS-2877] TomekLink builtin function

This is an automated email from the ASF dual-hosted git repository.

markd pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new b26c3d9  [SYSTEMDS-2877] TomekLink builtin function
b26c3d9 is described below

commit b26c3d91ad7d3af791407f91a521bd344b3d0edb
Author: JustYeti32 <ma...@gmx.net>
AuthorDate: Mon Mar 1 17:36:57 2021 +0100

    [SYSTEMDS-2877] TomekLink builtin function
    
    Closes #1186
---
 scripts/builtin/tomeklink.dml                      | 101 +++++++++++++++++++++
 .../java/org/apache/sysds/common/Builtins.java     |   1 +
 .../functions/builtin/BuiltinTomeklinkTest.java    |  92 +++++++++++++++++++
 src/test/scripts/functions/builtin/tomeklink.R     |  32 +++++++
 src/test/scripts/functions/builtin/tomeklink.dml   |  27 ++++++
 5 files changed, 253 insertions(+)

diff --git a/scripts/builtin/tomeklink.dml b/scripts/builtin/tomeklink.dml
new file mode 100644
index 0000000..71b2a75
--- /dev/null
+++ b/scripts/builtin/tomeklink.dml
@@ -0,0 +1,101 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+#
+# UNDERSAMPLING TECHNIQUE:
+# COMPUTES TOMEK LINKS AND DROPS THEM FROM THE DATA MATRIX AND LABEL VECTOR.
+# ONLY THE MAJORITY-CLASS ENDPOINT OF EACH TOMEK LINK (ITS ROW IN X AND ITS
+# LABEL IN y) IS DROPPED; THE MAJORITY CLASS IS ASSUMED TO BE LABEL 0.
+#
+# INPUT   				PARAMETERS:
+# ---------------------------------------------------------------------------------------------
+# NAME    				TYPE     DEFAULT  MEANING
+# ---------------------------------------------------------------------------------------------
+# X       				MATRIX   ---      Data Matrix (nxm)
+# y      				MATRIX   ---      Label Matrix (nx1)
+# ---------------------------------------------------------------------------------------------
+# OUTPUT:
+# X_under  - Data Matrix without Tomek links
+# y_under  - Labels corresponding to undersampled data
+# drop_idx - Row indices (1-based, ascending) of the dropped rows/labels w.r.t. the input
+
+
+###### MAIN PART ######
+
+m_tomeklink = function(Matrix[Double] X, Matrix[Double] y)
+    return (Matrix[Double] X_under, Matrix[Double] y_under, Matrix[Double] drop_idx) {
+  # Removes the majority-class endpoint of every Tomek link from X and y.
+  #   X        - data matrix (n x m)
+  #   y        - label vector (n x 1)
+  # Returns
+  #   X_under  - rows of X that are NOT part of a Tomek link
+  #   y_under  - labels of the retained rows
+  #   drop_idx - 1-based indices (ascending) of the removed rows; 0 x 1 if none
+
+  # label 0 is treated as the majority class (same convention as R's ubTomek)
+  majority_label = 0
+  n = nrow(X)
+
+  # n x 1 indicator: 1 for each majority point that belongs to a Tomek link
+  tomek_links = get_links(X, y, majority_label)
+  keep = (tomek_links == 0)
+
+  # retain the rows that are not links (the original code kept the links here);
+  # with a select vector, rows are chosen purely by the selector, not by content
+  X_under = removeEmpty(target = X, margin = "rows", select = keep)
+  y_under = removeEmpty(target = y, margin = "rows", select = keep)
+
+  # guard the no-link case so drop_idx stays an empty 0 x 1 matrix
+  if (sum(tomek_links) > 0) {
+    drop_idx = removeEmpty(target = seq(1, n), margin = "rows", select = tomek_links)
+  }
+  else {
+    drop_idx = matrix(0, rows = 0, cols = 1)
+  }
+}
+
+###### END MAIN PART ######
+
+###### UTILS ######
+
+# nearest neighbor function -----------------------------------------------------
+get_nn = function(Matrix[Double] X)
+    return (Matrix[Double] nn) {
+  # For every row i of X, returns (in nn[i, 1]) the index of its nearest other
+  # row under squared euclidean distance. A single-row X maps to itself.
+  nn = matrix(0, rows = nrow(X), cols = 1)
+  for (i in 1:nrow(X)) {
+    # squared euclidean distances from row i to every row of X (n x 1)
+    dists = rowSums((X - X[i,])^2)
+    # explicitly exclude the point itself: assuming self sorts to position 1
+    # fails for duplicate rows (tied zero distances) and overflows for n == 1
+    dists[i, 1] = max(dists) + 1
+    sort_dists = order(target = dists, by = 1, decreasing = FALSE, index.return = TRUE)
+    nn[i, 1] = as.scalar(sort_dists[1, 1])
+  }
+}
+
+# find tomek link function  ----------------------------------------------------
+get_links = function(Matrix[Double] X, Matrix[Double] y, double majority_label)
+    return (Matrix[Double] tomek_links) {
+  # Returns an n x 1 0/1 indicator marking the majority-labeled rows that are
+  # the nearest neighbor of at least one non-majority row.
+  # Link definition follows R's ubTomek (https://rdrr.io/cran/unbalanced/src/R/ubTomek.R);
+  # other sources instead define a link as a pair of mutual nearest neighbors
+  # where exactly one endpoint has the majority label.
+  nn = get_nn(X)
+  tomek_links = matrix(0, rows = nrow(X), cols = 1)
+
+  for (r in 1:nrow(X)) {
+    neighbor = as.scalar(nn[r, 1])
+    not_majority = as.scalar(y[r, 1]) != majority_label
+    neighbor_majority = as.scalar(y[neighbor, 1]) == majority_label
+    if (not_majority & neighbor_majority) {
+      tomek_links[neighbor, 1] = 1
+    }
+  }
+}
+
+###### END UTILS ######
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index c6723f8..a0c0222 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -225,6 +225,7 @@ public enum Builtins {
 	TANH("tanh", false),
 	TRACE("trace", false),
 	TO_ONE_HOT("toOneHot", true),
+	TOMEKLINK("tomeklink", true),
 	TYPEOF("typeof", false),
 	COUNT_DISTINCT("countDistinct",false),
 	COUNT_DISTINCT_APPROX("countDistinctApprox",false),
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinTomeklinkTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinTomeklinkTest.java
new file mode 100644
index 0000000..a251242
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinTomeklinkTest.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import org.junit.Test;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+import java.util.HashMap;
+
+public class BuiltinTomeklinkTest extends AutomatedTestBase
+{
+	private final static String TEST_NAME = "tomeklink";
+	private final static String TEST_DIR = "functions/builtin/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinTomeklinkTest.class.getSimpleName() + "/";
+
+	// tolerance when comparing the dropped row indices of DML and R
+	private final static double eps = 1e-3;
+	private final static int rows = 53;
+	private final static int cols = 6;
+
+	@Override
+	public void setUp() {
+		// the test script writes a single output "C" (indices of the dropped rows)
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"C"}));
+	}
+
+	@Test
+	public void testTomeklinkCP() {
+		runTomeklinkTest(ExecType.CP);
+	}
+
+	@Test
+	public void testTomeklinkSP() {
+		runTomeklinkTest(ExecType.SPARK);
+	}
+
+	/**
+	 * Runs the tomeklink builtin on random data and compares the indices of the
+	 * dropped rows with those reported by the R reference script.
+	 *
+	 * @param instType execution type (CP or SPARK) used for the DML script
+	 */
+	private void runTomeklinkTest(ExecType instType)
+	{
+		ExecMode platformOld = setExecMode(instType);
+
+		try
+		{
+			loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[] {"-args", input("A"), input("B"), output("C")};
+
+			fullRScriptName = HOME + TEST_NAME + ".R";
+			rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir();
+
+			// data matrix with values in [-1, 1]
+			double[][] A = getRandomMatrix(rows, cols, -1, 1, 0.7, 1);
+			writeInputMatrixWithMTD("A", A, true);
+
+			// labels in {0, 1}: random values in [0, 1] rounded to the nearest integer
+			double[][] B = getRandomMatrix(rows, 1, 0, 1, 0.5, 1);
+			B = TestUtils.round(B);
+			writeInputMatrixWithMTD("B", B, true);
+
+			runTest(true, false, null, -1);
+			runRScript(true);
+
+			// compare the dropped row indices of DML and R
+			HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("C");
+			HashMap<CellIndex, Double> rfile  = readRMatrixFromExpectedDir("C");
+			TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+		}
+		finally {
+			// restore the previous exec mode with the counterpart of setExecMode
+			// NOTE(review): intended to also undo any Spark config setExecMode applied — confirm
+			resetExecMode(platformOld);
+		}
+	}
+}
diff --git a/src/test/scripts/functions/builtin/tomeklink.R b/src/test/scripts/functions/builtin/tomeklink.R
new file mode 100644
index 0000000..a051db3
--- /dev/null
+++ b/src/test/scripts/functions/builtin/tomeklink.R
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# args[1]: directory holding A.mtx / B.mtx, args[2]: directory for the expected output
+args <- commandArgs(TRUE)
+
+library("unbalanced")
+library("Matrix")
+
+# read a Matrix Market file from the input directory as a dense matrix
+readInput <- function(file) as.matrix(readMM(paste0(args[1], file)))
+X <- readInput("A.mtx")
+y <- readInput("B.mtx")
+
+# indices of the rows removed by ubTomek, in ascending order
+removed <- ubTomek(X, y, verbose=FALSE)$id.rm
+C <- sort(removed)
+writeMM(as(C, "CsparseMatrix"), paste0(args[2], "C"))
diff --git a/src/test/scripts/functions/builtin/tomeklink.dml b/src/test/scripts/functions/builtin/tomeklink.dml
new file mode 100644
index 0000000..8ab9145
--- /dev/null
+++ b/src/test/scripts/functions/builtin/tomeklink.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# $1: data matrix, $2: label vector, $3: output path for the dropped row indices
+X = read($1)
+y = read($2)
+
+[X_under, y_under, drop_idx] = tomeklink(X=X, y=y)
+
+# drop_idx is produced in ascending row order, matching the sorted R result
+write(drop_idx, $3)