You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2020/06/23 21:06:23 UTC
[systemml] branch master updated: [MINOR] Additional lineage parfor remote tests, and cleanups
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git
The following commit(s) were added to refs/heads/master by this push:
new c6d7a52 [MINOR] Additional lineage parfor remote tests, and cleanups
c6d7a52 is described below
commit c6d7a52e2e4259fa62ba8e0b15cdfe1397baac0f
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Tue Jun 23 22:46:05 2020 +0200
[MINOR] Additional lineage parfor remote tests, and cleanups
This patch adds msvm w/ remote_spark parfor workers to the test suite
and fixes missing support for tak+ operators in the recompute-by-lineage
utility.
---
scripts/builtin/l2svm.dml | 2 +-
.../sysds/hops/ipa/FunctionCallSizeInfo.java | 9 ++--
.../sysds/runtime/lineage/LineageItemUtils.java | 25 ++++++---
.../functions/lineage/LineageTraceParforTest.java | 7 +++
.../functions/lineage/LineageTraceParforMSVM.dml | 61 ++++++++++++++++++++++
5 files changed, 90 insertions(+), 14 deletions(-)
diff --git a/scripts/builtin/l2svm.dml b/scripts/builtin/l2svm.dml
index 3e251ae..f411fb9 100644
--- a/scripts/builtin/l2svm.dml
+++ b/scripts/builtin/l2svm.dml
@@ -72,7 +72,7 @@ m_l2svm = function(Matrix[Double] X, Matrix[Double] Y, Boolean intercept = FALSE
# TODO make this a stop condition for l2svm instead of just printing.
if(num_min + num_max != nrow(Y))
- print("L2SVM: WARNING invalid number of labels in Y")
+ print("L2SVM: WARNING invalid number of labels in Y: "+num_min+" "+num_max)
# Scale inputs to -1 for negative, and 1 for positive classification
if(check_min != -1 | check_max != +1)
diff --git a/src/main/java/org/apache/sysds/hops/ipa/FunctionCallSizeInfo.java b/src/main/java/org/apache/sysds/hops/ipa/FunctionCallSizeInfo.java
index b349a5f..551ce98 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/FunctionCallSizeInfo.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/FunctionCallSizeInfo.java
@@ -233,14 +233,11 @@ public class FunctionCallSizeInfo
&& h1.getDim1()==h2.getDim1()
&& h1.getDim2()==h2.getDim2()
&& h1.getNnz()==h2.getNnz() );
- //check literal values (equi value)
- if( h1 instanceof LiteralOp ) {
- consistent &= (h2 instanceof LiteralOp
+ //check literal values (both needs to be literals and same value)
+ if( h1 instanceof LiteralOp || h2 instanceof LiteralOp ) {
+ consistent &= (h1 instanceof LiteralOp && h2 instanceof LiteralOp
&& HopRewriteUtils.isEqualValue((LiteralOp)h1, (LiteralOp)h2));
}
- else if(h2 instanceof LiteralOp) {
- consistent = false; //h2 literal, but h1 not
- }
}
}
if( consistent )
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
index 467bbc9..e659025 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
@@ -278,6 +278,24 @@ public class LineageItemUtils {
operands.put(item.getId(), aggunary);
break;
}
+ case AggregateBinary: {
+ Hop input1 = operands.get(item.getInputs()[0].getId());
+ Hop input2 = operands.get(item.getInputs()[1].getId());
+ Hop aggbinary = HopRewriteUtils.createMatrixMultiply(input1, input2);
+ operands.put(item.getId(), aggbinary);
+ break;
+ }
+ case AggregateTernary: {
+ Hop input1 = operands.get(item.getInputs()[0].getId());
+ Hop input2 = operands.get(item.getInputs()[1].getId());
+ Hop input3 = operands.get(item.getInputs()[2].getId());
+ Hop aggternary = HopRewriteUtils.createSum(
+ HopRewriteUtils.createBinary(
+ HopRewriteUtils.createBinary(input1, input2, OpOp2.MULT),
+ input3, OpOp2.MULT));
+ operands.put(item.getId(), aggternary);
+ break;
+ }
case Unary:
case Builtin: {
Hop input = operands.get(item.getInputs()[0].getId());
@@ -308,13 +326,6 @@ public class LineageItemUtils {
operands.put(item.getId(), binary);
break;
}
- case AggregateBinary: {
- Hop input1 = operands.get(item.getInputs()[0].getId());
- Hop input2 = operands.get(item.getInputs()[1].getId());
- Hop aggbinary = HopRewriteUtils.createMatrixMultiply(input1, input2);
- operands.put(item.getId(), aggbinary);
- break;
- }
case Ternary: {
operands.put(item.getId(), HopRewriteUtils.createTernary(
operands.get(item.getInputs()[0].getId()),
diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceParforTest.java b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceParforTest.java
index d100a4d..b3e0d73 100644
--- a/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceParforTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceParforTest.java
@@ -46,6 +46,7 @@ public class LineageTraceParforTest extends AutomatedTestBase {
protected static final String TEST_NAME3 = "LineageTraceParfor3"; //rand - matrix result - remote spark parfor
protected static final String TEST_NAME4 = "LineageTraceParforSteplm"; //rand - steplm
protected static final String TEST_NAME5 = "LineageTraceParforKmeans"; //rand - kmeans
+ protected static final String TEST_NAME6 = "LineageTraceParforMSVM"; //rand - msvm remote parfor
protected String TEST_CLASS_DIR = TEST_DIR + LineageTraceParforTest.class.getSimpleName() + "/";
@@ -63,6 +64,7 @@ public class LineageTraceParforTest extends AutomatedTestBase {
addTestConfiguration( TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"R"}) );
addTestConfiguration( TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"R"}) );
addTestConfiguration( TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {"R"}) );
+ addTestConfiguration( TEST_NAME6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] {"R"}) );
}
@Test
@@ -135,6 +137,11 @@ public class LineageTraceParforTest extends AutomatedTestBase {
testLineageTraceParFor(32, TEST_NAME5);
}
+ @Test
+ public void testLineageTraceMSVM_Remote64() {
+ testLineageTraceParFor(64, TEST_NAME6);
+ }
+
private void testLineageTraceParFor(int ncol, String testname) {
try {
System.out.println("------------ BEGIN " + testname + "------------");
diff --git a/src/test/scripts/functions/lineage/LineageTraceParforMSVM.dml b/src/test/scripts/functions/lineage/LineageTraceParforMSVM.dml
new file mode 100644
index 0000000..23f39b0
--- /dev/null
+++ b/src/test/scripts/functions/lineage/LineageTraceParforMSVM.dml
@@ -0,0 +1,61 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+msvm2 = function(Matrix[Double] X, Matrix[Double] Y, Boolean intercept = FALSE,
+ Double epsilon = 0.001, Double lambda = 1.0, Integer maxIterations = 100, Boolean verbose = FALSE)
+ return(Matrix[Double] model)
+{
+ if(min(Y) < 0)
+ stop("MSVM: Invalid Y input, containing negative values")
+
+ if(verbose)
+ print("Running Multiclass-SVM")
+
+ num_rows_in_w = ncol(X)
+ if(intercept) {
+ num_rows_in_w = num_rows_in_w + 1
+ }
+
+ if(ncol(Y) > 1)
+ Y = rowMaxs(Y * t(seq(1,ncol(Y))))
+
+ # Assuming number of classes to be max contained in Y
+ w = matrix(0, rows=num_rows_in_w, cols=max(Y))
+
+ parfor(class in 1:max(Y), opt=CONSTRAINED, par=4, mode=REMOTE_SPARK) {
+ Y_local = 2 * (Y == class) - 1
+ w[,class] = l2svm(X=X, Y=Y_local, intercept=intercept,
+ epsilon=epsilon, lambda=lambda, maxIterations=maxIterations,
+ verbose= verbose, columnId=class)
+ }
+
+ model = w
+}
+
+nclass = 10;
+
+X = rand(rows=$2, cols=$3, seed=1);
+y = rand(rows=$2, cols=1, min=0, max=nclass, seed=2);
+y = ceil(y);
+
+model = msvm2(X=X, Y=y, intercept=FALSE);
+
+write(model, $1);