You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/06/29 20:44:20 UTC

[systemds] branch master updated: [SYSTEMDS-3044] Support for dml-bodied builtin functions w/ imports, I

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new b31e124  [SYSTEMDS-3044] Support for dml-bodied builtin functions w/ imports, I
b31e124 is described below

commit b31e12429cd35b0ccfa636652bd5673db071968e
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Tue Jun 29 22:43:44 2021 +0200

    [SYSTEMDS-3044] Support for dml-bodied builtin functions w/ imports, I
    
    This patch adds the necessary mechanics for imports (source statements)
    in dml-bodied builtin functions such as DNN architectures and the
    enumeration of cleaning pipelines. Now, the parsing of builtin functions
    might bring in entirely new namespaces or additional functions of an
    already existing namespace.
    
    In a second part, we need to extend the lazy loading in eval to also
    support bringing in additional namespaces and to handle them properly
    in parfor.
---
 scripts/builtin/softmax.dml                        | 25 ++++++
 .../java/org/apache/sysds/common/Builtins.java     |  1 +
 .../apache/sysds/parser/dml/DMLParserWrapper.java  |  6 +-
 .../sysds/parser/dml/DmlSyntacticValidator.java    | 25 +++++-
 .../test/functions/builtin/BuiltinSoftmaxTest.java | 94 ++++++++++++++++++++++
 src/test/scripts/functions/builtin/softmax1.dml    | 28 +++++++
 src/test/scripts/functions/builtin/softmax2.dml    | 28 +++++++
 7 files changed, 202 insertions(+), 5 deletions(-)

diff --git a/scripts/builtin/softmax.dml b/scripts/builtin/softmax.dml
new file mode 100644
index 0000000..f9e110b
--- /dev/null
+++ b/scripts/builtin/softmax.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+source("nn/layers/softmax.dml") as sm
+# Builtin softmax(S): computes P by delegating to forward() of the sourced nn softmax layer (alias sm)
+m_softmax = function(Matrix[Double] S) return (Matrix[Double] P) {
+  P = sm::forward(S);
+}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index 99acfaf..3288a25 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -233,6 +233,7 @@ public enum Builtins {
 	SINH("sinh", false),
 	SLICEFINDER("slicefinder", true),
 	SMOTE("smote", true),
+	SOFTMAX("softmax", true),
 	SOLVE("solve", false),
 	SPLIT("split", true),
 	SPLIT_BALANCED("splitBalanced", true),
diff --git a/src/main/java/org/apache/sysds/parser/dml/DMLParserWrapper.java b/src/main/java/org/apache/sysds/parser/dml/DMLParserWrapper.java
index c3f8ca5..57c1b62 100644
--- a/src/main/java/org/apache/sysds/parser/dml/DMLParserWrapper.java
+++ b/src/main/java/org/apache/sysds/parser/dml/DMLParserWrapper.java
@@ -220,7 +220,9 @@ public class DMLParserWrapper extends ParserWrapper
 			dmlPgm.createNamespace(DMLProgram.BUILTIN_NAMESPACE);
 		for( Entry<String, FunctionStatementBlock> e : fbuiltins.getFunctions().entrySet() )
 			dmlPgm.addFunctionStatementBlock(DMLProgram.BUILTIN_NAMESPACE, e.getKey(), e.getValue());
-
+		for( Entry<String, FunctionDictionary<FunctionStatementBlock>> e : validator.getParsedBuiltinFunctionsNs().entrySet() )
+			addFunctions(dmlPgm, e.getKey(), e.getValue());
+		
 		// add statements from main script file, as well as 
 		// functions from imports and dml-bodied builtin functions
 		for(StatementContext stmtCtx : ast.blocks) {
@@ -235,7 +237,7 @@ public class DMLParserWrapper extends ParserWrapper
 				// Handle import statements separately
 				if(stmtCtx.info.namespaces != null) {
 					// Add the DMLProgram entries into current program
-					for(Map.Entry<String, FunctionDictionary<FunctionStatementBlock>> e : stmtCtx.info.namespaces.entrySet()) {
+					for(Entry<String, FunctionDictionary<FunctionStatementBlock>> e : stmtCtx.info.namespaces.entrySet()) {
 						addFunctions(dmlPgm, e.getKey(), e.getValue());
 					}
 				}
diff --git a/src/main/java/org/apache/sysds/parser/dml/DmlSyntacticValidator.java b/src/main/java/org/apache/sysds/parser/dml/DmlSyntacticValidator.java
index 3c91244..dc56248 100644
--- a/src/main/java/org/apache/sysds/parser/dml/DmlSyntacticValidator.java
+++ b/src/main/java/org/apache/sysds/parser/dml/DmlSyntacticValidator.java
@@ -158,6 +158,8 @@ public class DmlSyntacticValidator implements DmlListener {
 	protected Set<String> functions;
 	// DML-bodied builtin functions
 	protected FunctionDictionary<FunctionStatementBlock> builtinFuns;
+	// DML-bodied namespace functions (loaded via builtins)
+	protected HashMap<String, FunctionDictionary<FunctionStatementBlock>> builtinFunsNs;
 	
 	public DmlSyntacticValidator(CustomErrorListener errorListener, Map<String,String> argVals, String sourceNamespace, Set<String> prepFunctions) {
 		this.errorListener = errorListener;
@@ -167,6 +169,7 @@ public class DmlSyntacticValidator implements DmlListener {
 		sources = new HashMap<>();
 		functions = (null != prepFunctions) ? prepFunctions : new HashSet<>();
 		builtinFuns = new FunctionDictionary<>();
+		builtinFunsNs = new HashMap<>();
 	}
 
 
@@ -191,6 +194,10 @@ public class DmlSyntacticValidator implements DmlListener {
 		return builtinFuns;
 	}
 	
+	public Map<String, FunctionDictionary<FunctionStatementBlock>> getParsedBuiltinFunctionsNs() {
+		return builtinFunsNs; // live (mutable) map of non-builtin namespaces sourced by dml-bodied builtins
+	}
+	
 	protected ArrayList<ParameterExpression> getParameterExpressionList(List<ParameterizedExpressionContext> paramExprs) {
 		ArrayList<ParameterExpression> retVal = new ArrayList<>();
 		for(ParameterizedExpressionContext ctx : paramExprs) {
@@ -610,11 +617,23 @@ public class DmlSyntacticValidator implements DmlListener {
 		{
 			//load and add builtin DML-bodied functions
 			String filePath = Builtins.getFilePath(functionName);
-			FunctionDictionary<FunctionStatementBlock> prog = 
-				parseAndAddImportedFunctions(namespace, filePath, ctx).getBuiltinFunctionDictionary();
-			if( prog != null ) //robustness for existing functions
+			DMLProgram tmpProg = parseAndAddImportedFunctions(namespace, filePath, ctx);
+			FunctionDictionary<FunctionStatementBlock> prog = tmpProg.getBuiltinFunctionDictionary();
+			if( prog != null ) { //robustness for existing functions
+				//add builtin functions
 				for( Entry<String,FunctionStatementBlock> f : prog.getFunctions().entrySet() )
 					builtinFuns.addFunction(f.getKey(), f.getValue());
+				//add namespaces loaded by builtin functions (via source)
+				tmpProg.getNamespaces().entrySet().stream()
+					.filter(e -> !e.getKey().equals(DMLProgram.BUILTIN_NAMESPACE))
+					.forEach(e -> {
+						String ns = getQualifiedNamespace(e.getKey());
+						if( builtinFunsNs.containsKey(ns) )
+							builtinFunsNs.get(ns).merge(e.getValue());
+						else
+							builtinFunsNs.put(ns, e.getValue());
+					});
+			}
 		}
 	}
 
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinSoftmaxTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinSoftmaxTest.java
new file mode 100644
index 0000000..5135f02
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinSoftmaxTest.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import java.util.HashMap;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+
+/** Tests the DML-bodied softmax builtin, whose script itself sources nn/layers/softmax.dml. */
+public class BuiltinSoftmaxTest extends AutomatedTestBase
+{
+	private final static String TEST_NAME1 = "softmax1";
+	private final static String TEST_NAME2 = "softmax2";
+	private final static String TEST_DIR = "functions/builtin/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinSoftmaxTest.class.getSimpleName() + "/";
+	private final static double eps = 1e-6; // per-cell tolerance w.r.t. the reference softmax
+	private final static int rows = 1765;
+	private final static int cols = 10;
+	private final static double spDense = 0.99;
+	
+	@Override
+	public void setUp() {
+		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"B"}));
+		addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{"B"}));
+	}
+
+	@Test
+	public void testSoftmaxCP() {
+		runSoftmaxTest(TEST_NAME1, ExecType.CP);
+	}
+	
+	@Test
+	public void testSoftmaxSP() {
+		runSoftmaxTest(TEST_NAME1, ExecType.SPARK);
+	}
+
+//TODO add support for eval lazy loading of builtin functions w/ imports
+//	@Test
+//	public void testSoftmaxEvalCP() {
+//		runSoftmaxTest(TEST_NAME2, ExecType.CP);
+//	}
+//	
+//	@Test
+//	public void testSoftmaxEvalSP() {
+//		runSoftmaxTest(TEST_NAME2, ExecType.SPARK);
+//	}
+
+	/** Runs the given script under instType and checks that all rows x cols cells match the reference within eps. */
+	private void runSoftmaxTest(String testname, ExecType instType) {
+		ExecMode platformOld = setExecMode(instType);
+		
+		try {
+			loadTestConfiguration(getTestConfiguration(testname));
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + testname + ".dml";
+			programArgs = new String[]{"-args", input("A"), String.valueOf(eps), output("B")};
+			
+			// generate a random input matrix (values in [-1,1], seed 7)
+			double[][] A = getRandomMatrix(rows, cols, -1, 1, spDense, 7);
+			writeInputMatrixWithMTD("A", A, true);
+			
+			runTest(true, false, null, -1);
+			// the script writes the number of cells within eps, hence all rows*cols cells must match
+			HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("B");
+			Assert.assertEquals(rows*cols, dmlfile.get(new CellIndex(1,1)).intValue());
+		}
+		finally {
+			resetExecMode(platformOld);
+		}
+	}
+}
diff --git a/src/test/scripts/functions/builtin/softmax1.dml b/src/test/scripts/functions/builtin/softmax1.dml
new file mode 100644
index 0000000..1b0a24a
--- /dev/null
+++ b/src/test/scripts/functions/builtin/softmax1.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($1);
+# reference softmax, stabilized by subtracting each row's maximum
+E = exp(X - rowMaxs(X));
+R2 = E / rowSums(E);
+R1 = softmax(X);
+Y = as.matrix(sum(abs(R1 - R2) < $2));  # number of cells within tolerance $2
+write(Y, $3);
diff --git a/src/test/scripts/functions/builtin/softmax2.dml b/src/test/scripts/functions/builtin/softmax2.dml
new file mode 100644
index 0000000..8690147
--- /dev/null
+++ b/src/test/scripts/functions/builtin/softmax2.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($1);
+# reference softmax, stabilized by subtracting each row's maximum
+E = exp(X - rowMaxs(X));
+R2 = E / rowSums(E);
+R1 = eval("softmax", X);
+Y = as.matrix(sum(abs(R1 - R2) < $2));  # number of cells within tolerance $2
+write(Y, $3);