You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ss...@apache.org on 2022/02/07 17:52:50 UTC

[systemds] branch main updated: [SYSTEMDS-3228] Builtin for k-nearest-neighbor graph construction - This builtin computes the row-by-row distances, then finds the k-th smallest value in each row, and constructs a binary sparse matrix of the k nearest neighbors.

This is an automated email from the ASF dual-hosted git repository.

ssiddiqi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 423350f  [SYSTEMDS-3228] Builtin for k-nearest-neighbor graph construction   - This builtin computes the row-by-row distances, then finds the k-th smallest value in each row, and constructs a binary sparse matrix of the k nearest neighbors.
423350f is described below

commit 423350f18417e4e26bdd53aa7bb89c9e4c39a2f5
Author: Milos Babic <mi...@student.tugraz.at>
AuthorDate: Mon Feb 7 18:26:45 2022 +0100

    [SYSTEMDS-3228] Builtin for k nearest neighbor graph construction
      - This builtin computes the row-by-row distances, then
        finds the k-th smallest value in each row, and constructs a binary
        sparse matrix of the k nearest neighbors.
    
    DIA project WS2021/22.
    Closes #1513
    
    Co-authored-by: Manfred Milcharm <ma...@student.tugraz.at>
---
 scripts/builtin/knnGraph.dml                       | 76 +++++++++++++++++++
 .../java/org/apache/sysds/common/Builtins.java     |  1 +
 .../builtin/part1/BuiltinKNNGraphTest.java         | 86 ++++++++++++++++++++++
 src/test/scripts/functions/builtin/knnGraph.dml    | 27 +++++++
 4 files changed, 190 insertions(+)

diff --git a/scripts/builtin/knnGraph.dml b/scripts/builtin/knnGraph.dml
new file mode 100644
index 0000000..36155d0
--- /dev/null
+++ b/scripts/builtin/knnGraph.dml
@@ -0,0 +1,76 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Builtin for k-nearest-neighbour graph construction.
+#
+# Computes the pairwise distance matrix of the rows of X (dist builtin),
+# determines for every row the distance to its k-th nearest neighbour, and
+# marks every point within that distance in a binary adjacency matrix.
+#
+# INPUT:
+# ---------------------------------------------------------------------------
+# X      Input matrix, one data point per row
+# k      Number of nearest neighbours per point, 1 <= k < nrow(X)
+# ---------------------------------------------------------------------------
+#
+# OUTPUT:
+# ---------------------------------------------------------------------------
+# graph  nrow(X) x nrow(X) binary matrix; graph[i,j] = 1 iff j != i and
+#        point j is among the k nearest neighbours of point i. Ties at the
+#        k-th distance can yield more than k ones in a row, and the matrix is
+#        not necessarily symmetric.
+# ---------------------------------------------------------------------------
+
+m_knnGraph = function(Matrix[double] X, integer k) return (Matrix[double] graph) {
+  n = nrow(X);
+  # Without this check an out-of-range k silently produced an all-zero graph
+  # (kthSmallest returns its -1 sentinel, and no distance is <= -1).
+  if (k < 1 | k >= n) {
+    stop("knnGraph: k must be in [1, nrow(X)-1], but got k=" + k + " for nrow(X)=" + n);
+  }
+
+  distances = dist(X);
+  # Per-row threshold: the (k+1)-th smallest distance, because each row also
+  # contains the zero self-distance on the diagonal.
+  ksmall = matrix(0, rows=n, cols=1);
+  for (row in 1:n) {
+    referent = kthSmallest(distances[row, ], k + 1);
+    ksmall[row, 1] = referent;
+  }
+
+  # Row-wise comparison of every distance against its row's threshold.
+  graph = distances <= ksmall;
+  # Remove self-loops by zeroing the diagonal.
+  graph = graph * (1 - diag(matrix(1, rows=n, cols=1)));
+}
+
+# Returns the k-th smallest value (1-based) of the row vector 'array' via
+# quickselect with Lomuto partitioning: the last element of the current range
+# is the pivot, elements <= pivot are swapped to its left, and the search
+# continues only in the part of the range that contains position k.
+# Expected linear time; quadratic in the worst case (e.g. an already sorted
+# row). The result is a (generally fractional) value of 'array', hence the
+# double return type. Returns -1 if k is outside [1, ncol(array)].
+# TODO: vectorize (e.g. via a row-wise sort) to avoid the per-element loop.
+kthSmallest = function(Matrix[double] array, integer k)
+return (double res) {
+  # sentinel for "k out of range"; overwritten once position k is found
+  res = -1.0;
+  left = 1;
+  right = ncol(array);
+  found = FALSE;
+
+  while ((left <= right) & !found) {
+    # Lomuto partition of array[1, left:right] around pivot = array[1, right]
+    pivot = as.scalar(array[1, right]);
+    i = left - 1;
+    j = left;
+    while (j < right) {
+      if (as.scalar(array[1, j]) <= pivot) {
+        i = i + 1;
+        temp = as.scalar(array[1, i]);
+        array[1, i] = array[1, j];
+        array[1, j] = temp;
+      }
+      j = j + 1;
+    }
+
+    # move the pivot to its final sorted position p
+    p = i + 1;
+    temp = as.scalar(array[1, p]);
+    array[1, p] = array[1, right];
+    array[1, right] = temp;
+
+    if (p == k) {
+      res = as.scalar(array[1, p]);
+      found = TRUE;
+    }
+    else if (p > k) {
+      right = p - 1;
+    }
+    else {
+      left = p + 1;
+    }
+  }
+}
\ No newline at end of file
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index 37e08b1..f1e99f8 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -178,6 +178,7 @@ public enum Builtins {
 	KMEANS("kmeans", true),
 	KMEANSPREDICT("kmeansPredict", true),
 	KNNBF("knnbf", true),
+	KNNGRAPH("knnGraph", true),
 	KNN("knn", true),
 	L2SVM("l2svm", true),
 	L2SVMPREDICT("l2svmPredict", true),
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinKNNGraphTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinKNNGraphTest.java
new file mode 100644
index 0000000..f077ce2
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinKNNGraphTest.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin.part1;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+
+import java.util.HashMap;
+
+public class BuiltinKNNGraphTest extends AutomatedTestBase {
+    private final static String TEST_NAME = "knnGraph";
+    private final static String TEST_DIR = "functions/builtin/";
+    private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinKNNGraphTest.class.getSimpleName() + "/";
+
+    private final static String OUTPUT_NAME_KNN_GRAPH = "KNN_GRAPH";
+
+    /** 2-D input points: a tight cluster (rows 1-3) and two outliers (rows 4-5). */
+    private final static double[][] POINTS = { { 1, 0 }, { 2, 2 }, { 2, 2.5 }, { 10, 10 }, { 15, 15 } };
+
+    @Override
+    public void setUp() {
+        addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+    }
+
+    /**
+     * k=2: every point is linked to its two nearest neighbours. The result is not
+     * symmetric: rows 4 and 5 link to row 3, while row 3 links only to rows 1 and 2.
+     */
+    @Test
+    public void basicTest() {
+        double[][] refMatrix = {
+                { 0., 1., 1., 0., 0. },
+                { 1., 0., 1., 0., 0. },
+                { 1., 1., 0., 0., 0. },
+                { 0., 0., 1., 0., 1. },
+                { 0., 0., 1., 1., 0. }
+        };
+        runKNNGraphTest(ExecMode.SINGLE_NODE, 2, POINTS, refMatrix);
+    }
+
+    /** k=1: every point is linked only to its single nearest neighbour. */
+    @Test
+    public void singleNeighbourTest() {
+        double[][] refMatrix = {
+                { 0., 1., 0., 0., 0. },
+                { 0., 0., 1., 0., 0. },
+                { 0., 1., 0., 0., 0. },
+                { 0., 0., 0., 0., 1. },
+                { 0., 0., 0., 1., 0. }
+        };
+        runKNNGraphTest(ExecMode.SINGLE_NODE, 1, POINTS, refMatrix);
+    }
+
+    /**
+     * Runs the knnGraph test script on {@code X} with the given {@code k} and compares the
+     * resulting adjacency matrix exactly (tolerance 0, the graph is binary) with {@code refMatrix}.
+     * The previous execution mode is restored in a {@code finally} block so that a failing run
+     * does not leak its execution mode into subsequent tests.
+     */
+    private void runKNNGraphTest(ExecMode execMode, int k, double[][] X, double[][] refMatrix) {
+        ExecMode platformOld = setExecMode(execMode);
+        try {
+            getAndLoadTestConfiguration(TEST_NAME);
+            String HOME = SCRIPT_DIR + TEST_DIR;
+
+            // create test input
+            writeInputMatrixWithMTD("X", X, true);
+
+            fullDMLScriptName = HOME + TEST_NAME + ".dml";
+            programArgs = new String[] { "-stats", "-nvargs",
+                    "in_X=" + input("X"), "in_k=" + k, "out_G=" + output(OUTPUT_NAME_KNN_GRAPH) };
+
+            // execute the script; no exception expected
+            runTest(true, false, null, -1);
+
+            // compare result with reference
+            HashMap<MatrixValue.CellIndex, Double> resultGraph = readDMLMatrixFromOutputDir(OUTPUT_NAME_KNN_GRAPH);
+            HashMap<MatrixValue.CellIndex, Double> refGraph = TestUtils.convert2DDoubleArrayToHashMap(refMatrix);
+            TestUtils.compareMatrices(resultGraph, refGraph, 0, "ResGraph", "RefGraph");
+        }
+        finally {
+            // restore execution mode even if the run or the comparison failed
+            setExecMode(platformOld);
+        }
+    }
+
+}
diff --git a/src/test/scripts/functions/builtin/knnGraph.dml b/src/test/scripts/functions/builtin/knnGraph.dml
new file mode 100644
index 0000000..b2464f8
--- /dev/null
+++ b/src/test/scripts/functions/builtin/knnGraph.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# Test driver: builds the k-nearest-neighbour graph of the matrix at $in_X
+# with k = $in_k and writes the resulting adjacency matrix to $out_G.
+G = knnGraph(X=read($in_X), k=$in_k);
+write(G, $out_G);