You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2020/06/23 21:06:23 UTC

[systemml] branch master updated: [MINOR] Additional lineage parfor remote tests, and cleanups

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new c6d7a52  [MINOR] Additional lineage parfor remote tests, and cleanups
c6d7a52 is described below

commit c6d7a52e2e4259fa62ba8e0b15cdfe1397baac0f
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Tue Jun 23 22:46:05 2020 +0200

    [MINOR] Additional lineage parfor remote tests, and cleanups
    
    This patch adds an msvm test with remote_spark parfor workers to the
    test suite and adds the missing support for tak+ operators to the
    recompute-by-lineage utility.
---
 scripts/builtin/l2svm.dml                          |  2 +-
 .../sysds/hops/ipa/FunctionCallSizeInfo.java       |  9 ++--
 .../sysds/runtime/lineage/LineageItemUtils.java    | 25 ++++++---
 .../functions/lineage/LineageTraceParforTest.java  |  7 +++
 .../functions/lineage/LineageTraceParforMSVM.dml   | 61 ++++++++++++++++++++++
 5 files changed, 90 insertions(+), 14 deletions(-)

diff --git a/scripts/builtin/l2svm.dml b/scripts/builtin/l2svm.dml
index 3e251ae..f411fb9 100644
--- a/scripts/builtin/l2svm.dml
+++ b/scripts/builtin/l2svm.dml
@@ -72,7 +72,7 @@ m_l2svm = function(Matrix[Double] X, Matrix[Double] Y, Boolean intercept = FALSE
 
   # TODO make this a stop condition for l2svm instead of just printing.
   if(num_min + num_max != nrow(Y))
-    print("L2SVM: WARNING invalid number of labels in Y")
+    print("L2SVM: WARNING invalid number of labels in Y: "+num_min+" "+num_max)
 
   # Scale inputs to -1 for negative, and 1 for positive classification
   if(check_min != -1 | check_max != +1)
diff --git a/src/main/java/org/apache/sysds/hops/ipa/FunctionCallSizeInfo.java b/src/main/java/org/apache/sysds/hops/ipa/FunctionCallSizeInfo.java
index b349a5f..551ce98 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/FunctionCallSizeInfo.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/FunctionCallSizeInfo.java
@@ -233,14 +233,11 @@ public class FunctionCallSizeInfo
 								   &&  h1.getDim1()==h2.getDim1() 
 								   &&  h1.getDim2()==h2.getDim2()
 								   &&  h1.getNnz()==h2.getNnz() );
-						//check literal values (equi value)
-						if( h1 instanceof LiteralOp ) {
-							consistent &= (h2 instanceof LiteralOp 
+						//check literal values (both need to be literals with the same value)
+						if( h1 instanceof LiteralOp || h2 instanceof LiteralOp ) {
+							consistent &= (h1 instanceof LiteralOp && h2 instanceof LiteralOp
 								&& HopRewriteUtils.isEqualValue((LiteralOp)h1, (LiteralOp)h2));
 						}
-						else if(h2 instanceof LiteralOp) {
-							consistent = false; //h2 literal, but h1 not
-						}
 					}
 				}
 				if( consistent )
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
index 467bbc9..e659025 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
@@ -278,6 +278,24 @@ public class LineageItemUtils {
 							operands.put(item.getId(), aggunary);
 							break;
 						}
+						case AggregateBinary: {
+							Hop input1 = operands.get(item.getInputs()[0].getId());
+							Hop input2 = operands.get(item.getInputs()[1].getId());
+							Hop aggbinary = HopRewriteUtils.createMatrixMultiply(input1, input2);
+							operands.put(item.getId(), aggbinary);
+							break;
+						}
+						case AggregateTernary: {
+							Hop input1 = operands.get(item.getInputs()[0].getId());
+							Hop input2 = operands.get(item.getInputs()[1].getId());
+							Hop input3 = operands.get(item.getInputs()[2].getId());
+							Hop aggternary = HopRewriteUtils.createSum(
+								HopRewriteUtils.createBinary(
+								HopRewriteUtils.createBinary(input1, input2, OpOp2.MULT),
+								input3, OpOp2.MULT));
+							operands.put(item.getId(), aggternary);
+							break;
+						}
 						case Unary:
 						case Builtin: {
 							Hop input = operands.get(item.getInputs()[0].getId());
@@ -308,13 +326,6 @@ public class LineageItemUtils {
 							operands.put(item.getId(), binary);
 							break;
 						}
-						case AggregateBinary: {
-							Hop input1 = operands.get(item.getInputs()[0].getId());
-							Hop input2 = operands.get(item.getInputs()[1].getId());
-							Hop aggbinary = HopRewriteUtils.createMatrixMultiply(input1, input2);
-							operands.put(item.getId(), aggbinary);
-							break;
-						}
 						case Ternary: {
 							operands.put(item.getId(), HopRewriteUtils.createTernary(
 								operands.get(item.getInputs()[0].getId()), 
diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceParforTest.java b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceParforTest.java
index d100a4d..b3e0d73 100644
--- a/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceParforTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceParforTest.java
@@ -46,6 +46,7 @@ public class LineageTraceParforTest extends AutomatedTestBase {
 	protected static final String TEST_NAME3 = "LineageTraceParfor3"; //rand - matrix result - remote spark parfor
 	protected static final String TEST_NAME4 = "LineageTraceParforSteplm"; //rand - steplm
 	protected static final String TEST_NAME5 = "LineageTraceParforKmeans"; //rand - kmeans
+	protected static final String TEST_NAME6 = "LineageTraceParforMSVM"; //rand - msvm remote parfor
 	
 	protected String TEST_CLASS_DIR = TEST_DIR + LineageTraceParforTest.class.getSimpleName() + "/";
 	
@@ -63,6 +64,7 @@ public class LineageTraceParforTest extends AutomatedTestBase {
 		addTestConfiguration( TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"R"}) );
 		addTestConfiguration( TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"R"}) );
 		addTestConfiguration( TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {"R"}) );
+		addTestConfiguration( TEST_NAME6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] {"R"}) );
 	}
 	
 	@Test
@@ -135,6 +137,11 @@ public class LineageTraceParforTest extends AutomatedTestBase {
 		testLineageTraceParFor(32, TEST_NAME5);
 	}
 	
+	@Test
+	public void testLineageTraceMSVM_Remote64() {
+		testLineageTraceParFor(64, TEST_NAME6);
+	}
+	
 	private void testLineageTraceParFor(int ncol, String testname) {
 		try {
 			System.out.println("------------ BEGIN " + testname + "------------");
diff --git a/src/test/scripts/functions/lineage/LineageTraceParforMSVM.dml b/src/test/scripts/functions/lineage/LineageTraceParforMSVM.dml
new file mode 100644
index 0000000..23f39b0
--- /dev/null
+++ b/src/test/scripts/functions/lineage/LineageTraceParforMSVM.dml
@@ -0,0 +1,61 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+msvm2 = function(Matrix[Double] X, Matrix[Double] Y, Boolean intercept = FALSE,
+    Double epsilon = 0.001, Double lambda = 1.0, Integer maxIterations = 100, Boolean verbose = FALSE)
+  return(Matrix[Double] model)
+{
+  if(min(Y) < 0)
+    stop("MSVM: Invalid Y input, containing negative values")
+
+  if(verbose)
+    print("Running Multiclass-SVM")
+
+  num_rows_in_w = ncol(X)
+  if(intercept) {
+    num_rows_in_w = num_rows_in_w + 1
+  }
+
+  if(ncol(Y) > 1) 
+    Y = rowMaxs(Y * t(seq(1,ncol(Y))))
+
+  # Assuming number of classes to be max contained in Y
+  w = matrix(0, rows=num_rows_in_w, cols=max(Y))
+
+  parfor(class in 1:max(Y), opt=CONSTRAINED, par=4, mode=REMOTE_SPARK) {
+    Y_local = 2 * (Y == class) - 1
+    w[,class] = l2svm(X=X, Y=Y_local, intercept=intercept,
+        epsilon=epsilon, lambda=lambda, maxIterations=maxIterations, 
+        verbose= verbose, columnId=class)
+  }
+  
+  model = w
+}
+
+nclass = 10;
+
+X = rand(rows=$2, cols=$3, seed=1);
+y = rand(rows=$2, cols=1, min=0, max=nclass, seed=2);
+y = ceil(y);
+
+model = msvm2(X=X, Y=y, intercept=FALSE);
+                                                                       
+write(model, $1);