You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/08/09 21:13:15 UTC
[systemds] branch master updated: [SYSTEMDS-831] New t-SNE builtin
script (from staging)
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 962361f [SYSTEMDS-831] New t-SNE builtin script (from staging)
962361f is described below
commit 962361fd5095cf5a4f86fd95d63d0e22f94c1498
Author: Tim Sagaster <ti...@student.tugraz.at>
AuthorDate: Mon Aug 9 23:05:16 2021 +0200
[SYSTEMDS-831] New t-SNE builtin script (from staging)
AMLS project SS2021.
Closes #1360.
Co-authored-by: Imran Younus <im...@gmail.com>
---
docs/site/builtins-reference.md | 35 ++
scripts/builtin/tSNE.dml | 149 ++++++++
.../java/org/apache/sysds/common/Builtins.java | 1 +
.../org/apache/sysds/runtime/util/AutoDiff.java | 2 +-
.../test/functions/builtin/BuiltinTSNETest.java | 408 +++++++++++++++++++++
src/test/scripts/functions/builtin/tSNE.dml | 24 ++
6 files changed, 618 insertions(+), 1 deletion(-)
diff --git a/docs/site/builtins-reference.md b/docs/site/builtins-reference.md
index 5209922..da70c49 100644
--- a/docs/site/builtins-reference.md
+++ b/docs/site/builtins-reference.md
@@ -79,6 +79,7 @@ limitations under the License.
* [`steplm`-Function](#steplm-function)
* [`tomekLink`-Function](#tomekLink-function)
* [`toOneHot`-Function](#toOneHOt-function)
+ * [`tSNE`-Function](#tSNE-function)
* [`winsorize`-Function](#winsorize-function)
* [`xgboost`-Function](#xgboost-function)
@@ -2176,6 +2177,40 @@ X = round(rand(rows = 10, cols = 10, min = 1, max = numClasses))
y = toOneHot(X,numClasses)
```
+## `tSNE`-Function
+
+The `tSNE`-function performs dimensionality reduction using the t-SNE algorithm, based on the paper: Visualizing Data using t-SNE, van der Maaten et al.
+
+### Usage
+
+```r
+tSNE(X, reduced_dims, perplexity, lr, momentum, max_iter, seed, is_verbose)
+```
+
+### Arguments
+
+| Name | Type | Default | Description |
+| :----------- | :------------- | -------- | :---------- |
+| X | Matrix[Double] | required | Data Matrix of shape (number of data points, input dimensionality) |
+| reduced_dims | Integer | 2 | Output dimensionality |
+| perplexity | Integer | 30 | Perplexity Parameter |
+| lr | Double | 300. | Learning rate |
+| momentum | Double | 0.9 | Momentum Parameter |
+| max_iter | Integer | 1000 | Number of iterations |
+| seed | Integer | -1 | The seed used for initial values. If set to -1 random seeds are selected. |
+| is_verbose | Boolean | FALSE | Print debug information |
+### Returns
+
+| Type | Description |
+| :------------- | :---------- |
+| Matrix[Double] | Data Matrix of shape (number of data points, reduced_dims) |
+
+### Example
+
+```r
+X = rand(rows = 100, cols = 10, min = -10, max = 10)
+Y = tSNE(X)
+```
## `winsorize`-Function
diff --git a/scripts/builtin/tSNE.dml b/scripts/builtin/tSNE.dml
new file mode 100644
index 0000000..3e7926a
--- /dev/null
+++ b/scripts/builtin/tSNE.dml
@@ -0,0 +1,149 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# This function performs dimensionality reduction using the t-SNE algorithm, based on
+# the paper: Visualizing Data using t-SNE, van der Maaten et al.
+
+# INPUT PARAMETERS:
+# ----------------------------------------------------------------------------
+# NAME TYPE DEFAULT MEANING
+# ----------------------------------------------------------------------------
+# X Matrix[Double] --- Data Matrix of shape
+# (number of data points, input dimensionality)
+# reduced_dims Integer 2 Output dimensionality
+# perplexity Integer 30 Perplexity Parameter
+# lr Double 300. Learning rate
+# momentum Double 0.9 Momentum Parameter
+# max_iter Integer 1000 Number of iterations
+# seed Integer -1 The seed used for initial values.
+# If set to -1 random seeds are selected.
+# is_verbose Boolean FALSE Print debug information
+#
+#
+# RETURN VALUES
+# ----------------------------------------------------------------------------
+# NAME TYPE DEFAULT MEANING
+# ----------------------------------------------------------------------------
+# Y Matrix --- Data Matrix of shape (number of data points, reduced_dims)
+# ----------------------------------------------------------------------------
+
+
+m_tSNE = function(Matrix[Double] X, Integer reduced_dims = 2, Integer perplexity = 30,
+  Double lr = 300., Double momentum = 0.9, Integer max_iter = 1000, Integer seed = -1, Boolean is_verbose = FALSE)
+  return(Matrix[Double] Y)
+{
+  # Computes an embedding Y of shape (nrow(X), reduced_dims) for the rows of X.
+  # The affinities P are computed once via x2p(); Y is then updated for max_iter
+  # iterations by gradient steps with momentum.
+  d = reduced_dims
+  n = nrow(X)
+
+  # P is scaled by 4 for the first 100 iterations; the scaling is undone at itr == 100
+  P = x2p(X, perplexity, is_verbose)
+  P = P*4
+  Y = rand(rows=n, cols=d, pdf="normal", seed=seed)
+  # dY holds the previous update step (momentum term), initially zero
+  dY = matrix(0, rows=n, cols=d)
+  # n x n mask: 0 on the diagonal, 1 everywhere else
+  ZERODIAG = (diag(matrix(-1, rows=n, cols=1)) + 1)
+
+  if(is_verbose)
+    print("starting loop....")
+
+  for (itr in 1:max_iter) {
+    # low-dimensional similarities: Z = 1/(1 + squared distance), diagonal zeroed,
+    # Q = Z normalized to sum to 1
+    D = distance_matrix(Y)
+    Z = 1/(D + 1)
+    Z = Z * ZERODIAG
+    Q = Z/sum(Z)
+    # gradient g[i,] = sum_j W[i,j] * (Y[i,] - Y[j,]) with W = (P - Q) * Z
+    W = (P - Q)*Z
+    sumW = rowSums(W)
+    g = Y * sumW - W %*% Y
+    dY = momentum*dY - lr*g
+    Y = Y + dY
+    # re-center the embedding at the origin
+    Y = Y - colMeans(Y)
+
+    # The cost was previously stored in a matrix that was never read; it is now
+    # only computed and reported when debug output is requested.
+    if (is_verbose & itr%%100 == 0) {
+      cost = sum(P * log(pmax(P, 1e-12) / pmax(Q, 1e-12)))
+      print("iteration " + itr + ": cost = " + cost)
+    }
+    if (itr == 100) {
+      P = P/4
+    }
+  }
+}
+
+distance_matrix = function(matrix[double] X)
+  return (matrix[double] out)
+{
+  # Pairwise squared Euclidean distances between the rows of X:
+  #   out[i,j] = s[i] + s[j] - 2 * sum(X[i,] * X[j,]),  with s[i] = sum(X[i,]^2)
+  # The result has shape (nrow(X), nrow(X)); no square root is taken.
+  # TODO consolidate with dist() builtin, but with a better way of obtaining
+  # the diag (presumably the diagonal of X %*% t(X), i.e. s - confirm)
+  s = rowSums(X * X)
+  out = - 2*X %*% t(X) + s + t(s)
+}
+
+
+# Computes the n x n affinity matrix P for the rows of X (n = nrow(X)).
+# For every row a scale factor beta is searched by bisection such that the
+# row entropy H matches log(perplexity); the row-normalized result is then
+# symmetrized and normalized to sum to 1.
+x2p = function(matrix[double] X, double perplexity, Boolean is_verbose = FALSE)
+return(matrix[double] P)
+{
+ if(is_verbose)
+ print("x2p....")
+ # tol: threshold on mean(abs(Hdiff)) that ends the search
+ tol = 1.0e-5
+ # INF: large finite marker meaning "no upper bound found yet" for betamax
+ INF = 1.0e20
+ n = nrow(X)
+ if(is_verbose)
+ print(n)
+ # pairwise distances between the rows of X (see distance_matrix)
+ D = distance_matrix(X)
+
+ # per-row search state, one entry per data point
+ P = matrix(0, rows=n, cols=n)
+ beta = matrix(1, rows=n, cols=1)
+ betamax = matrix(INF, rows=n, cols=1)
+ betamin = matrix(0, rows=n, cols=1)
+ Hdiff = matrix(INF, rows=n, cols=1)
+ logU = log(perplexity)
+
+ # n x n mask: 0 on the diagonal, 1 everywhere else
+ ZERODIAG = (diag(matrix(-1, rows=n, cols=1)) + 1)
+ itr = 1
+ # all rows are updated together; stops once the mean absolute entropy
+ # error is within tol, or after at most 49 passes (itr runs from 1 to 49)
+ while (mean(abs(Hdiff)) > tol & itr < 50) {
+ # unnormalized affinities for the current beta, self-affinity removed
+ P = exp(-D * beta)
+ P = P * ZERODIAG
+ sum_Pi = rowSums(P)
+ W = rowSums(P * D)
+ Ws = W/sum_Pi
+ # row entropy H and row-normalized P
+ H = log(sum_Pi) + beta * Ws
+ P = P/sum_Pi
+ Hdiff = H - logU
+
+ # 0/1 indicators per row: entropy at/above target (Hpos) or below it (Hneg)
+ Hpos = (Hdiff >= 0)
+ Hneg = (Hdiff < 0)
+ # Hpos rows: current beta becomes the lower bound;
+ # Hneg rows: current beta becomes the upper bound
+ betamin = Hneg*betamin + Hpos*beta
+ betamax = Hpos*betamax + Hneg*beta
+ # new beta: Hpos rows double beta while no upper bound is known, otherwise
+ # take the midpoint with betamax; Hneg rows take the midpoint with betamin.
+ # Order matters: this uses the bounds as updated just above.
+ beta = 2*Hpos*(betamax == INF)*beta +
+ Hpos*(betamax != INF)*(beta + betamax)/2 +
+ Hneg*(beta + betamin)/2
+
+ itr = itr + 1
+ }
+
+ # P here stems from the beta used in the last pass (before its final update);
+ # symmetrize and normalize so that all entries sum to 1
+ P = P + t(P)
+ P = P / sum(P)
+ if(is_verbose)
+ print("x2p finishing....")
+}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index f2b6c6a..7bb386a 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -260,6 +260,7 @@ public enum Builtins {
TOMEKLINK("tomeklink", true),
TRACE("trace", false),
TRANS("t", false),
+ TSNE("tSNE", true),
TYPEOF("typeof", false),
UNIVAR("univar", true),
VAR("var", false),
diff --git a/src/main/java/org/apache/sysds/runtime/util/AutoDiff.java b/src/main/java/org/apache/sysds/runtime/util/AutoDiff.java
index 2178a13..3f45253 100644
--- a/src/main/java/org/apache/sysds/runtime/util/AutoDiff.java
+++ b/src/main/java/org/apache/sysds/runtime/util/AutoDiff.java
@@ -48,7 +48,7 @@ public class AutoDiff {
public static ListObject getBackward(MatrixObject mo, ArrayList<Data> lineage, ExecutionContext adec) {
- ArrayList<String> names = new ArrayList<String>();
+ ArrayList<String> names = new ArrayList<>();
// parse the lineage and take the number of instructions as for each instruction there is separate hop DAG
String lin = lineage.get(0).toString();
// get rid of foo flag
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinTSNETest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinTSNETest.java
new file mode 100644
index 0000000..d52a048
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinTSNETest.java
@@ -0,0 +1,408 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map.Entry;
+
+/**
+ * Test for the tSNE builtin: runs the tSNE.dml test script on an inlined input matrix X
+ * and compares the written result Y, cell by cell, against an inlined reference embedding.
+ */
+public class BuiltinTSNETest extends AutomatedTestBase
+{
+ private final static String TEST_NAME = "tSNE";
+ private final static String TEST_DIR = "functions/builtin/";
+ private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinTSNETest.class.getSimpleName() + "/";
+
+ @Override
+ public void setUp() {
+ // NOTE(review): the configuration lists "B", while the result read in runTSNETest is "Y" - confirm
+ addTestConfiguration(TEST_NAME,new TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"B"}));
+ }
+
+ @Test
+ public void testTSNECP() throws IOException {
+ // reduced_dims=2, perplexity=30, lr=300, momentum=0.9, max_iter=1000, seed=42, is_verbose=FALSE
+ runTSNETest(2, 30, 300.,
+ 0.9, 1000, 42, "FALSE", ExecType.CP);
+ }
+
+ /**
+ * Runs the tSNE test script with the given arguments and checks the result against the reference.
+ * All arguments except instType are forwarded to the script as named arguments (-nvargs);
+ * instType selects the execution mode via setExecMode.
+ */
+ private void runTSNETest(Integer reduced_dims, Integer perplexity, Double lr,
+ Double momentum, Integer max_iter, Integer seed, String is_verbose, ExecType instType)
+ throws IOException
+ {
+ ExecMode platformOld = setExecMode(instType);
+
+ try
+ {
+ loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[]{
+ "-nvargs", "X=" + input("X"), "Y=" + output("Y"),
+ "reduced_dims=" + reduced_dims,
+ "perplexity=" + perplexity,
+ "lr=" + lr,
+ "momentum=" + momentum,
+ "max_iter=" + max_iter,
+ "seed=" + seed,
+ "is_verbose=" + is_verbose};
+
+ // The Input values are calculated using the following R script:
+ // TODO create via dml operations, avoid inlining data
+ // library(Rtsne)
+ // set.seed(42)
+ // iris_unique <- unique(iris)
+ // iris_matrix <- as.matrix(iris_unique[,1:4])
+ // X <- normalize_input(iris_matrix) # the values used for the test
+
+ // Input
+ double[][] X = {{-0.23599574, 0.13972311, -0.74547391, -0.31565495},
+ {-0.29946752, -0.01895634, -0.74547391, -0.31565495},
+ {-0.36293930, 0.04451544, -0.77720980, -0.31565495},
+ {-0.39467519, 0.01277955, -0.71373802, -0.31565495},
+ {-0.26773163, 0.17145900, -0.74547391, -0.31565495},
+ {-0.14078807, 0.26666667, -0.65026624, -0.25218317},
+ {-0.39467519, 0.10798722, -0.74547391, -0.28391906},
+ {-0.26773163, 0.10798722, -0.71373802, -0.31565495},
+ {-0.45814696, -0.05069223, -0.74547391, -0.31565495},
+ {-0.29946752, 0.01277955, -0.71373802, -0.34739084},
+ {-0.14078807, 0.20319489, -0.71373802, -0.31565495},
+ {-0.33120341, 0.10798722, -0.68200213, -0.31565495},
+ {-0.33120341, -0.01895634, -0.74547391, -0.34739084},
+ {-0.48988285, -0.01895634, -0.84068158, -0.34739084},
+ {-0.01384452, 0.29840256, -0.80894569, -0.31565495},
+ {-0.04558040, 0.42534611, -0.71373802, -0.25218317},
+ {-0.14078807, 0.26666667, -0.77720980, -0.25218317},
+ {-0.23599574, 0.13972311, -0.74547391, -0.28391906},
+ {-0.04558040, 0.23493078, -0.65026624, -0.28391906},
+ {-0.23599574, 0.23493078, -0.71373802, -0.28391906},
+ {-0.14078807, 0.10798722, -0.65026624, -0.31565495},
+ {-0.23599574, 0.20319489, -0.71373802, -0.25218317},
+ {-0.39467519, 0.17145900, -0.87241747, -0.31565495},
+ {-0.23599574, 0.07625133, -0.65026624, -0.22044728},
+ {-0.33120341, 0.10798722, -0.58679446, -0.31565495},
+ {-0.26773163, -0.01895634, -0.68200213, -0.31565495},
+ {-0.26773163, 0.10798722, -0.68200213, -0.25218317},
+ {-0.20425985, 0.13972311, -0.71373802, -0.31565495},
+ {-0.20425985, 0.10798722, -0.74547391, -0.31565495},
+ {-0.36293930, 0.04451544, -0.68200213, -0.31565495},
+ {-0.33120341, 0.01277955, -0.68200213, -0.31565495},
+ {-0.14078807, 0.10798722, -0.71373802, -0.25218317},
+ {-0.20425985, 0.33013845, -0.71373802, -0.34739084},
+ {-0.10905218, 0.36187433, -0.74547391, -0.31565495},
+ {-0.29946752, 0.01277955, -0.71373802, -0.31565495},
+ {-0.26773163, 0.04451544, -0.80894569, -0.31565495},
+ {-0.10905218, 0.13972311, -0.77720980, -0.31565495},
+ {-0.29946752, 0.17145900, -0.74547391, -0.34739084},
+ {-0.45814696, -0.01895634, -0.77720980, -0.31565495},
+ {-0.23599574, 0.10798722, -0.71373802, -0.31565495},
+ {-0.26773163, 0.13972311, -0.77720980, -0.28391906},
+ {-0.42641108, -0.24110756, -0.77720980, -0.28391906},
+ {-0.45814696, 0.04451544, -0.77720980, -0.31565495},
+ {-0.26773163, 0.13972311, -0.68200213, -0.18871140},
+ {-0.23599574, 0.23493078, -0.58679446, -0.25218317},
+ {-0.33120341, -0.01895634, -0.74547391, -0.28391906},
+ {-0.23599574, 0.23493078, -0.68200213, -0.31565495},
+ {-0.39467519, 0.04451544, -0.74547391, -0.31565495},
+ {-0.17252396, 0.20319489, -0.71373802, -0.31565495},
+ {-0.26773163, 0.07625133, -0.74547391, -0.31565495},
+ {0.36698616, 0.04451544, 0.30181044, 0.06517572},
+ {0.17657082, 0.04451544, 0.23833866, 0.09691161},
+ {0.33525027, 0.01277955, 0.36528222, 0.09691161},
+ {-0.10905218, -0.24110756, 0.07965921, 0.03343983},
+ {0.20830671, -0.08242812, 0.27007455, 0.09691161},
+ {-0.04558040, -0.08242812, 0.23833866, 0.03343983},
+ {0.14483493, 0.07625133, 0.30181044, 0.12864750},
+ {-0.29946752, -0.20937167, -0.14249201, -0.06176784},
+ {0.24004260, -0.05069223, 0.27007455, 0.03343983},
+ {-0.20425985, -0.11416400, 0.04792332, 0.06517572},
+ {-0.26773163, -0.33631523, -0.07902023, -0.06176784},
+ {0.01789137, -0.01895634, 0.14313099, 0.09691161},
+ {0.04962726, -0.27284345, 0.07965921, -0.06176784},
+ {0.08136315, -0.05069223, 0.30181044, 0.06517572},
+ {-0.07731629, -0.05069223, -0.04728435, 0.03343983},
+ {0.27177849, 0.01277955, 0.20660277, 0.06517572},
+ {-0.07731629, -0.01895634, 0.23833866, 0.09691161},
+ {-0.01384452, -0.11416400, 0.11139510, -0.06176784},
+ {0.11309904, -0.27284345, 0.23833866, 0.09691161},
+ {-0.07731629, -0.17763578, 0.04792332, -0.03003195},
+ {0.01789137, 0.04451544, 0.33354633, 0.19211928},
+ {0.08136315, -0.08242812, 0.07965921, 0.03343983},
+ {0.14483493, -0.17763578, 0.36528222, 0.09691161},
+ {0.08136315, -0.08242812, 0.30181044, 0.00170394},
+ {0.17657082, -0.05069223, 0.17486688, 0.03343983},
+ {0.24004260, -0.01895634, 0.20660277, 0.06517572},
+ {0.30351438, -0.08242812, 0.33354633, 0.06517572},
+ {0.27177849, -0.01895634, 0.39701810, 0.16038339},
+ {0.04962726, -0.05069223, 0.23833866, 0.09691161},
+ {-0.04558040, -0.14589989, -0.07902023, -0.06176784},
+ {-0.10905218, -0.20937167, 0.01618743, -0.03003195},
+ {-0.10905218, -0.20937167, -0.01554846, -0.06176784},
+ {-0.01384452, -0.11416400, 0.04792332, 0.00170394},
+ {0.04962726, -0.11416400, 0.42875399, 0.12864750},
+ {-0.14078807, -0.01895634, 0.23833866, 0.09691161},
+ {0.04962726, 0.10798722, 0.23833866, 0.12864750},
+ {0.27177849, 0.01277955, 0.30181044, 0.09691161},
+ {0.14483493, -0.24110756, 0.20660277, 0.03343983},
+ {-0.07731629, -0.01895634, 0.11139510, 0.03343983},
+ {-0.10905218, -0.17763578, 0.07965921, 0.03343983},
+ {-0.10905218, -0.14589989, 0.20660277, 0.00170394},
+ {0.08136315, -0.01895634, 0.27007455, 0.06517572},
+ {-0.01384452, -0.14589989, 0.07965921, 0.00170394},
+ {-0.26773163, -0.24110756, -0.14249201, -0.06176784},
+ {-0.07731629, -0.11416400, 0.14313099, 0.03343983},
+ {-0.04558040, -0.01895634, 0.14313099, 0.00170394},
+ {-0.04558040, -0.05069223, 0.14313099, 0.03343983},
+ {0.11309904, -0.05069223, 0.17486688, 0.03343983},
+ {-0.23599574, -0.17763578, -0.23769968, -0.03003195},
+ {-0.04558040, -0.08242812, 0.11139510, 0.03343983},
+ {0.14483493, 0.07625133, 0.71437700, 0.41427050},
+ {-0.01384452, -0.11416400, 0.42875399, 0.22385517},
+ {0.39872204, -0.01895634, 0.68264111, 0.28732694},
+ {0.14483493, -0.05069223, 0.58743344, 0.19211928},
+ {0.20830671, -0.01895634, 0.65090522, 0.31906283},
+ {0.55740149, -0.01895634, 0.90479233, 0.28732694},
+ {-0.29946752, -0.17763578, 0.23833866, 0.16038339},
+ {0.46219382, -0.05069223, 0.80958466, 0.19211928},
+ {0.27177849, -0.17763578, 0.65090522, 0.19211928},
+ {0.43045793, 0.17145900, 0.74611289, 0.41427050},
+ {0.20830671, 0.04451544, 0.42875399, 0.25559105},
+ {0.17657082, -0.11416400, 0.49222577, 0.22385517},
+ {0.30351438, -0.01895634, 0.55569755, 0.28732694},
+ {-0.04558040, -0.17763578, 0.39701810, 0.25559105},
+ {-0.01384452, -0.08242812, 0.42875399, 0.38253461},
+ {0.17657082, 0.04451544, 0.49222577, 0.35079872},
+ {0.20830671, -0.01895634, 0.55569755, 0.19211928},
+ {0.58913738, 0.23493078, 0.93652822, 0.31906283},
+ {0.58913738, -0.14589989, 1.00000000, 0.35079872},
+ {0.04962726, -0.27284345, 0.39701810, 0.09691161},
+ {0.33525027, 0.04451544, 0.61916933, 0.35079872},
+ {-0.07731629, -0.08242812, 0.36528222, 0.25559105},
+ {0.58913738, -0.08242812, 0.93652822, 0.25559105},
+ {0.14483493, -0.11416400, 0.36528222, 0.19211928},
+ {0.27177849, 0.07625133, 0.61916933, 0.28732694},
+ {0.43045793, 0.04451544, 0.71437700, 0.19211928},
+ {0.11309904, -0.08242812, 0.33354633, 0.19211928},
+ {0.08136315, -0.01895634, 0.36528222, 0.19211928},
+ {0.17657082, -0.08242812, 0.58743344, 0.28732694},
+ {0.43045793, -0.01895634, 0.65090522, 0.12864750},
+ {0.49392971, -0.08242812, 0.74611289, 0.22385517},
+ {0.65260916, 0.23493078, 0.84132055, 0.25559105},
+ {0.17657082, -0.08242812, 0.58743344, 0.31906283},
+ {0.14483493, -0.08242812, 0.42875399, 0.09691161},
+ {0.08136315, -0.14589989, 0.58743344, 0.06517572},
+ {0.58913738, -0.01895634, 0.74611289, 0.35079872},
+ {0.14483493, 0.10798722, 0.58743344, 0.38253461},
+ {0.17657082, 0.01277955, 0.55569755, 0.19211928},
+ {0.04962726, -0.01895634, 0.33354633, 0.19211928},
+ {0.33525027, 0.01277955, 0.52396166, 0.28732694},
+ {0.27177849, 0.01277955, 0.58743344, 0.38253461},
+ {0.33525027, 0.01277955, 0.42875399, 0.35079872},
+ {0.30351438, 0.04451544, 0.68264111, 0.35079872},
+ {0.27177849, 0.07625133, 0.61916933, 0.41427050},
+ {0.27177849, -0.01895634, 0.46048988, 0.35079872},
+ {0.14483493, -0.17763578, 0.39701810, 0.22385517},
+ {0.20830671, -0.01895634, 0.46048988, 0.25559105},
+ {0.11309904, 0.10798722, 0.52396166, 0.35079872},
+ {0.01789137, -0.01895634, 0.42875399, 0.19211928}};
+
+ // The reference output was created by using the builtin function with seed 42 and visually inspecting the
+ // result with the following addition to the above R script:
+ /*
+ plot(Y, col = iris_unique$Species)
+ */
+
+ // reference Output
+ double[][] YReference = {{18.220536548250042, -12.846498524536738},
+ {15.927903386925026, -14.212023388236792},
+ {16.769777454402725, -14.867104469807458},
+ {16.290613410318578, -14.971912325413014},
+ {18.534108527624923, -13.081965971299278},
+ {19.46702930119709, -11.107384827606543},
+ {17.196995022994926, -14.952457676596161},
+ {17.531360762128234, -13.133905834551287},
+ {15.996750713161672, -15.670577143806288},
+ {16.36534147032176, -13.94640444049381},
+ {19.094349767837077, -11.557657039778153},
+ {16.909635859249846, -13.432332957158627},
+ {15.964241411757008, -14.583849627922195},
+ {16.313709524761837, -16.090269929669734},
+ {20.285962611966927, -10.881660944862407},
+ {20.554758173661426, -11.06392329099603},
+ {19.81906056722687, -11.547188487333667},
+ {18.156895316378105, -12.705433004326382},
+ {19.567551886989456, -10.747111880219315},
+ {19.250142947939032, -12.360380239016678},
+ {17.98651875041291, -11.373400820175243},
+ {18.85701724837406, -12.344092918634603},
+ {15.003071782449434, -15.134107990259263},
+ {16.906273660742716, -12.081578738297303},
+ {16.191062837166225, -12.743134302084322},
+ {15.882718808144244, -13.620171658276757},
+ {17.162351058283594, -12.52282386909434},
+ {18.222170807106863, -12.25960056343841},
+ {17.965249267850382, -12.497678277213915},
+ {16.61308662972421, -14.338853140723119},
+ {16.181343729377808, -14.009495314657007},
+ {18.02370742093665, -11.5469978987216},
+ {20.11583352615154, -11.971703104743623},
+ {20.34766483388781, -11.368324149466424},
+ {16.425063301904068, -13.920590257944497},
+ {17.456451755948965, -14.167821170017678},
+ {18.596836114631266, -11.38629820258609},
+ {18.433917702277277, -13.54006232724752},
+ {16.220662327912546, -15.705004632184767},
+ {17.693153385837373, -12.795456099765842},
+ {18.171476611546456, -13.323284884616521},
+ {15.440053805476438, -16.14242285042382},
+ {16.61041486346442, -15.64274785656758},
+ {17.151462077196808, -11.907401620659614},
+ {18.626998818277606, -10.74416620713765},
+ {15.98116134477208, -14.570088706192026},
+ {19.310896130733955, -12.328582468776172},
+ {16.638185072116197, -15.041546247833768},
+ {19.00600656884603, -11.840914144209874},
+ {17.403065129854806, -13.533995008873486},
+ {-7.038268225948856, 7.536962272871995},
+ {-7.3914411713029935, 5.908404973751881},
+ {-7.496137452637684, 7.80002229833549},
+ {-6.460592719894593, 0.9062234315612744},
+ {-7.616915840513189, 6.367666534180176},
+ {-6.73555772449111, 3.123913125534304},
+ {-8.080845938411253, 6.030180083529746},
+ {-6.003581004054341, -0.5426273138978279},
+ {-7.169789369755973, 6.530569753225136},
+ {-5.74311471355638, 1.109042979038742},
+ {-6.159692868338611, -0.47672936959379775},
+ {-7.224663019809244, 3.2705583465396795},
+ {-7.80311490665107, 0.8971687420050435},
+ {-8.171493899007842, 4.945491553039871},
+ {-6.055076176778027, 0.8068408811777839},
+ {-6.841744219152401, 6.563990176275155},
+ {-6.4688066027110995, 3.3600337950542185},
+ {-7.290016820260793, 1.6539424879073648},
+ {-9.427559060500517, 3.747497254956788},
+ {-6.870144371536429, 0.8988782592514124},
+ {-9.136986152714911, 5.784098425135586},
+ {-7.68317760808878, 2.43574703565265},
+ {-9.740448694574955, 5.387219781790028},
+ {-8.073455284093917, 4.572592063175601},
+ {-6.941003784668771, 5.2940569383973815},
+ {-6.92987101320006, 6.2327815077976485},
+ {-7.48373098301499, 7.341059972461793},
+ {-8.40332894242311, 8.176817619574457},
+ {-7.681170155620538, 4.334864702224226},
+ {-6.956305041180643, 0.2075870963686768},
+ {-6.655261285197925, 0.5228359022734664},
+ {-6.666350115623071, 0.2717264575985321},
+ {-7.0998853298193145, 1.4238408353689687},
+ {-10.2308151080835, 6.427565359129248},
+ {-6.01229671664338, 3.082044215406439},
+ {-8.034938379871582, 5.334134289261261},
+ {-7.359172540935799, 7.109704156745161},
+ {-9.171082430966738, 3.5626753890742657},
+ {-6.535162057220474, 2.312414150179807},
+ {-6.3789219417924405, 1.2216485793591474},
+ {-6.067256940518085, 2.1167290391035074},
+ {-7.812215578205332, 4.875684399693769},
+ {-7.1374342502995605, 1.4628113062889514},
+ {-6.070547359306737, -0.5122472965230369},
+ {-6.46343725557135, 1.979317360428436},
+ {-6.847076431902512, 2.5392196334816846},
+ {-6.808114424651112, 2.498147826359703},
+ {-7.202776337109158, 4.407291480436176},
+ {-6.008168593188563, -0.6619390315875436},
+ {-6.831119010530879, 2.045771384669187},
+ {-12.091830779150229, 11.265359769601078},
+ {-10.854204065227533, 6.234604499378284},
+ {-10.722649555263748, 12.399149192535997},
+ {-11.024368675274824, 9.162144335131195},
+ {-11.303048023519185, 10.732650608737089},
+ {-11.09821732195217, 13.779405512286052},
+ {-5.0247832693311425, 2.1742184328563603},
+ {-10.619343405836839, 13.298677973987925},
+ {-11.79686116542718, 9.494394976289728},
+ {-11.586950451837295, 12.746499445013628},
+ {-9.748309996567489, 9.564317509433048},
+ {-10.481918971993704, 8.61651017140811},
+ {-10.462599559239374, 10.911028525489696},
+ {-11.130930155766606, 5.964285024103242},
+ {-11.553937238933786, 6.4713896554450985},
+ {-10.681883406818072, 10.318460208361941},
+ {-10.683549192372697, 9.495275009653032},
+ {-10.78603000302186, 14.25264787910952},
+ {-11.488550071992856, 13.974915029487743},
+ {-10.319674004978886, 4.889390024695409},
+ {-10.860193039087495, 11.683572442191727},
+ {-11.051151523346409, 5.890167964580233},
+ {-11.229564353462594, 13.877567340046305},
+ {-9.673077426769021, 6.894318103609319},
+ {-10.96913834998139, 11.322206899793283},
+ {-10.490035372117925, 12.825714370138753},
+ {-9.462007223927465, 6.4675755622386655},
+ {-9.51159790297749, 6.337629581532479},
+ {-11.162432115640609, 9.827386758865119},
+ {-10.151930853808723, 12.721404715075673},
+ {-10.604601715753107, 13.20203200422886},
+ {-10.661315153978798, 14.213091123997359},
+ {-11.20216179567001, 10.027518789368527},
+ {-9.545267493844456, 7.1692559024456735},
+ {-11.159867362151168, 7.8902037332855635},
+ {-11.088628333620173, 13.451554645466246},
+ {-11.888514670355377, 10.819305597527286},
+ {-10.69716594839169, 9.44576929331764},
+ {-9.388828314326277, 6.02029293546396},
+ {-10.20160261480168, 11.075164265002408},
+ {-10.984952128353312, 11.125251232506113},
+ {-9.734560763746329, 10.70872555677429},
+ {-11.239599613172867, 11.820790907417845},
+ {-11.469876956620716, 11.460531594176354},
+ {-10.037039090761771, 10.470461014159099},
+ {-10.182738080227073, 7.165787375813321},
+ {-10.04300898871791, 9.494772412049684},
+ {-11.912720679251729, 10.441038042708657},
+ {-10.30731700479772, 6.343742643599125}};
+
+ writeInputMatrixWithMTD("X", X, true);
+
+ runTest(true, false, null, -1);
+ HashMap<MatrixValue.CellIndex, Double> dmlFileY = readDMLMatrixFromOutputDir("Y");
+
+ // Verifying
+ // Each returned cell must lie within an absolute tolerance of 3 of the reference value.
+ // NOTE(review): only cells present in the returned map are checked, so an empty or
+ // partial result would not fail here - consider also asserting the number of cells.
+ // The -1 offsets presumably map 1-based cell indexes to the 0-based arrays - TODO confirm.
+ for (Entry<CellIndex, Double> entry : dmlFileY.entrySet()) {
+ MatrixValue.CellIndex key = entry.getKey();
+ Double value = entry.getValue();
+ Assert.assertEquals("The DML data for cell (" + key.row + "," + key.column + ") '" + value + "' is " +
+ "not equal to the expected value '" + YReference[key.row-1][key.column-1] + "'",
+ YReference[key.row-1][key.column-1], value, 3); //TODO algorithm-level differences?
+ }
+ }
+ finally {
+ // restore the execution mode returned by setExecMode above, even if the test failed
+ rtplatform = platformOld;
+ }
+ }
+}
diff --git a/src/test/scripts/functions/builtin/tSNE.dml b/src/test/scripts/functions/builtin/tSNE.dml
new file mode 100644
index 0000000..8310f75
--- /dev/null
+++ b/src/test/scripts/functions/builtin/tSNE.dml
@@ -0,0 +1,24 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($X);
+Y = tSNE(X, $reduced_dims, $perplexity, $lr, $momentum, $max_iter, $seed, $is_verbose)
+write(Y, $Y)