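You are viewing a plain-text version of this content. (In the original page, a link labeled "here" pointed to the canonical version; the link target was not preserved in this plain-text copy.)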
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/05/08 03:47:24 UTC
incubator-systemml git commit: [SYSTEMML-1590] Fix codegen handling of unsupported row aggregates
Repository: incubator-systemml
Updated Branches:
refs/heads/master 686363208 -> 19e21744c
[SYSTEMML-1590] Fix codegen handling of unsupported row aggregates
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/19e21744
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/19e21744
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/19e21744
Branch: refs/heads/master
Commit: 19e21744c86adbedf6098906808c2c6327659cfe
Parents: 6863632
Author: Matthias Boehm <mb...@gmail.com>
Authored: Sun May 7 20:49:38 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Sun May 7 20:49:46 2017 -0700
----------------------------------------------------------------------
.../hops/codegen/template/TemplateRow.java | 6 ++--
.../functions/codegen/RowAggTmplTest.java | 18 ++++++++++-
.../scripts/functions/codegen/rowAggPattern16.R | 33 ++++++++++++++++++++
.../functions/codegen/rowAggPattern16.dml | 27 ++++++++++++++++
4 files changed, 81 insertions(+), 3 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19e21744/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
index 3af8be4..3f947c8 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
@@ -78,7 +78,8 @@ public class TemplateRow extends TemplateBase
|| (hop instanceof AggBinaryOp && hop.getDim2()==1
&& hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1)
|| (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.RowCol
- && hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1);
+ && hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1
+ && HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG));
}
@Override
@@ -89,7 +90,8 @@ public class TemplateRow extends TemplateBase
|| HopRewriteUtils.isBinaryMatrixScalarOperation(hop)) )
|| ((hop instanceof UnaryOp || hop instanceof ParameterizedBuiltinOp)
&& TemplateCell.isValidOperation(hop))
- || (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.RowCol)
+ || (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.RowCol
+ && HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG))
|| (hop instanceof AggBinaryOp && hop.getDim1()>1
&& HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))));
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19e21744/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
index 4037edb..b7f82a7 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
@@ -51,6 +51,7 @@ public class RowAggTmplTest extends AutomatedTestBase
private static final String TEST_NAME13 = TEST_NAME+"13"; //rowSums(X)+rowSums(Y)
private static final String TEST_NAME14 = TEST_NAME+"14"; //colSums(max(floor(round(abs(min(sign(X+Y),1)))),7))
private static final String TEST_NAME15 = TEST_NAME+"15"; //systemml nn - softmax backward (partially)
+ private static final String TEST_NAME16 = TEST_NAME+"16"; //Y=X-rowIndexMax(X); R=Y/rowSums(Y)
private static final String TEST_DIR = "functions/codegen/";
private static final String TEST_CLASS_DIR = TEST_DIR + RowAggTmplTest.class.getSimpleName() + "/";
@@ -62,7 +63,7 @@ public class RowAggTmplTest extends AutomatedTestBase
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
- for(int i=1; i<=15; i++)
+ for(int i=1; i<=16; i++)
addTestConfiguration( TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) }) );
}
@@ -291,6 +292,21 @@ public class RowAggTmplTest extends AutomatedTestBase
testCodegenIntegration( TEST_NAME15, false, ExecType.SPARK );
}
+ @Test
+ public void testCodegenRowAggRewrite16CP() {
+ testCodegenIntegration( TEST_NAME16, true, ExecType.CP );
+ }
+
+ @Test
+ public void testCodegenRowAgg16CP() {
+ testCodegenIntegration( TEST_NAME16, false, ExecType.CP );
+ }
+
+ @Test
+ public void testCodegenRowAgg16SP() {
+ testCodegenIntegration( TEST_NAME16, false, ExecType.SPARK );
+ }
+
private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType )
{
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19e21744/src/test/scripts/functions/codegen/rowAggPattern16.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern16.R b/src/test/scripts/functions/codegen/rowAggPattern16.R
new file mode 100644
index 0000000..a4e9184
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern16.R
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+library("matrixStats")
+
+
+X = matrix(seq(1,1500), 150, 10, byrow=TRUE);
+
+Y1 = X - max.col(X, ties.method="last")
+R = Y1 / rowSums(Y1)
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep=""));
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19e21744/src/test/scripts/functions/codegen/rowAggPattern16.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern16.dml b/src/test/scripts/functions/codegen/rowAggPattern16.dml
new file mode 100644
index 0000000..e0558f6
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern16.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(seq(1,1500), rows=150, cols=10);
+
+Y1 = X - rowIndexMax(X)
+R = Y1 / rowSums(Y1)
+
+write(R, $1)