You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this export.
Posted to commits@systemds.apache.org by mb...@apache.org on 2020/07/24 21:08:56 UTC

[systemds] branch master updated: [SYSTEMDS-2572] Additional mlcontext test for nn-library imports

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 8d61ae6  [SYSTEMDS-2572] Additional mlcontext test for nn-library imports
8d61ae6 is described below

commit 8d61ae6f46f0a8ce21f9ad7c3a617023f6983778
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Fri Jul 24 23:08:27 2020 +0200

    [SYSTEMDS-2572] Additional mlcontext test for nn-library imports
    
    The bug reported in SYSTEMDS-2572 could not be reproduced, either in a
    local environment or through spark-shell. However, since the MLContext
    tests did not yet cover sourcing (importing) DML scripts, we add a
    corresponding test case.
---
 .../sysds/test/functions/mlcontext/MLContextTest.java      | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java
index 3e07b15..697e9e9 100644
--- a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java
@@ -1904,5 +1904,17 @@ public class MLContextTest extends MLContextTestBase {
 		Assert.assertEquals(true, c);
 		Assert.assertEquals("yes it's TRUE", d);
 	}
-
+	
+	@Test
+	public void testNNImport() { // verifies that MLContext can run a DML script which sources an nn-library layer
+		System.out.println("MLContextTest - NN import");
+		String s =    "source(\"scripts/nn/layers/relu.dml\") as relu;\n" // NOTE(review): relative path, presumably resolved against the test's working directory (repo root) - confirm
+					+ "X = rand(rows=100, cols=10, min=-1, max=1);\n" // 100x10 input containing both negative and positive values
+					+ "R1 = relu::forward(X);\n" // result via the sourced relu layer (namespace 'relu')
+					+ "R2 = max(X, 0);\n" // reference result via the builtin cellwise max, no import needed
+					+ "R = sum(R1==R2);\n"; // number of cells where both results agree
+		double ret = ml.execute(dml(s).out("R")) // execute the script and request output variable R
+			.getScalarObject("R").getDoubleValue();
+		Assert.assertEquals(1000, ret, 1e-20); // all 100*10 = 1000 cells must agree; delta 1e-20 effectively demands an exact count
+	}
 }