You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/07/15 04:15:31 UTC
[05/23] systemml git commit: TernaryAggregate now applies to X^3 (a matrix raised to the power 3).
TernaryAggregate now applies to X^3 (a matrix raised to the power 3), in addition to element-wise multiplication chains.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/f005d949
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/f005d949
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/f005d949
Branch: refs/heads/master
Commit: f005d94997d9c17ad8e90b4d2bd340f81b9a752d
Parents: 8b832f6
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Fri Jun 9 22:06:10 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:24 2017 -0700
----------------------------------------------------------------------
.../java/org/apache/sysml/hops/AggUnaryOp.java | 67 ++++++++++++--------
.../functions/misc/RewriteEMultChainTest.java | 7 +-
.../functions/misc/RewriteEMultChainOp.R | 33 ----------
.../functions/misc/RewriteEMultChainOp.dml | 28 --------
.../functions/misc/RewriteEMultChainOpXYX.R | 33 ++++++++++
.../functions/misc/RewriteEMultChainOpXYX.dml | 28 ++++++++
6 files changed, 106 insertions(+), 90 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index 4573b66..300a20c 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -490,29 +490,35 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
(_direction == Direction.RowCol || _direction == Direction.Col) )
{
Hop input1 = getInput().get(0);
- if( input1.getParent().size() == 1 && //sum single consumer
- input1 instanceof BinaryOp && ((BinaryOp)input1).getOp()==OpOp2.MULT
- // As unary agg instruction is not implemented in MR and since MR is in maintenance mode, postponed it.
- && input1.optFindExecType() != ExecType.MR)
- {
- Hop input11 = input1.getInput().get(0);
- Hop input12 = input1.getInput().get(1);
-
- if( input11 instanceof BinaryOp && ((BinaryOp)input11).getOp()==OpOp2.MULT ) {
- //ternary, arbitrary matrices but no mv/outer operations.
- ret = HopRewriteUtils.isEqualSize(input11.getInput().get(0), input1)
- && HopRewriteUtils.isEqualSize(input11.getInput().get(1), input1)
- && HopRewriteUtils.isEqualSize(input12, input1);
- }
- else if( input12 instanceof BinaryOp && ((BinaryOp)input12).getOp()==OpOp2.MULT ) {
- //ternary, arbitrary matrices but no mv/outer operations.
- ret = HopRewriteUtils.isEqualSize(input12.getInput().get(0), input1)
- && HopRewriteUtils.isEqualSize(input12.getInput().get(1), input1)
- && HopRewriteUtils.isEqualSize(input11, input1);
+ if (input1.getParent().size() == 1
+ && input1 instanceof BinaryOp) { //sum single consumer
+ BinaryOp binput1 = (BinaryOp)input1;
+
+ if (binput1.getOp() == OpOp2.POW
+ && binput1.getInput().get(1) instanceof LiteralOp) {
+ LiteralOp lit = (LiteralOp)binput1.getInput().get(1);
+ ret = lit.getLongValue() == 3;
}
- else {
- //binary, arbitrary matrices but no mv/outer operations.
- ret = HopRewriteUtils.isEqualSize(input11, input12);
+ else if (binput1.getOp() == OpOp2.MULT
+ // As unary agg instruction is not implemented in MR and since MR is in maintenance mode, postponed it.
+ && input1.optFindExecType() != ExecType.MR) {
+ Hop input11 = input1.getInput().get(0);
+ Hop input12 = input1.getInput().get(1);
+
+ if (input11 instanceof BinaryOp && ((BinaryOp) input11).getOp() == OpOp2.MULT) {
+ //ternary, arbitrary matrices but no mv/outer operations.
+ ret = HopRewriteUtils.isEqualSize(input11.getInput().get(0), input1) && HopRewriteUtils
+ .isEqualSize(input11.getInput().get(1), input1) && HopRewriteUtils
+ .isEqualSize(input12, input1);
+ } else if (input12 instanceof BinaryOp && ((BinaryOp) input12).getOp() == OpOp2.MULT) {
+ //ternary, arbitrary matrices but no mv/outer operations.
+ ret = HopRewriteUtils.isEqualSize(input12.getInput().get(0), input1) && HopRewriteUtils
+ .isEqualSize(input12.getInput().get(1), input1) && HopRewriteUtils
+ .isEqualSize(input11, input1);
+ } else {
+ //binary, arbitrary matrices but no mv/outer operations.
+ ret = HopRewriteUtils.isEqualSize(input11, input12);
+ }
}
}
}
@@ -626,14 +632,25 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
private Lop constructLopsTernaryAggregateRewrite(ExecType et)
throws HopsException, LopsException
{
- Hop input1 = getInput().get(0);
+ BinaryOp input1 = (BinaryOp)getInput().get(0);
Hop input11 = input1.getInput().get(0);
Hop input12 = input1.getInput().get(1);
Lop in1 = null, in2 = null, in3 = null;
boolean handled = false;
-
- if( input11 instanceof BinaryOp ) {
+
+ if (input1.getOp() == OpOp2.POW) {
+ switch ((int)((LiteralOp)input12).getLongValue()) {
+ case 3:
+ in1 = input11.constructLops();
+ in2 = in1;
+ in3 = in1;
+ break;
+ default:
+ throw new AssertionError("unreachable; only applies to power 3");
+ }
+ handled = true;
+ } else if (input11 instanceof BinaryOp ) {
BinaryOp b11 = (BinaryOp)input11;
switch (b11.getOp()) {
case MULT: // A*B*C case
http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
index 18ed55d..85dbea4 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
@@ -37,7 +37,7 @@ import org.junit.Test;
*/
public class RewriteEMultChainTest extends AutomatedTestBase
{
- private static final String TEST_NAME1 = "RewriteEMultChainOp";
+ private static final String TEST_NAME1 = "RewriteEMultChainOpXYX";
private static final String TEST_DIR = "functions/misc/";
private static final String TEST_CLASS_DIR = TEST_DIR + RewriteEMultChainTest.class.getSimpleName() + "/";
@@ -94,8 +94,7 @@ public class RewriteEMultChainTest extends AutomatedTestBase
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
- programArgs = new String[]{ "-explain", "hops", "-stats",
- "-args", input("X"), input("Y"), output("R") };
+ programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") };
fullRScriptName = HOME + testname + ".R";
rCmd = getRCmd(inputDir(), expectedDir());
@@ -104,7 +103,7 @@ public class RewriteEMultChainTest extends AutomatedTestBase
double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3);
writeInputMatrixWithMTD("X", X, true);
writeInputMatrixWithMTD("Y", Y, true);
-
+
//execute tests
runTest(true, false, null, -1);
runRScript(true);
http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/scripts/functions/misc/RewriteEMultChainOp.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOp.R b/src/test/scripts/functions/misc/RewriteEMultChainOp.R
deleted file mode 100644
index 6d94cc8..0000000
--- a/src/test/scripts/functions/misc/RewriteEMultChainOp.R
+++ /dev/null
@@ -1,33 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-
-args <- commandArgs(TRUE)
-options(digits=22)
-library("Matrix")
-library("matrixStats")
-
-X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))
-Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep="")))
-
-R = X * Y * X;
-
-writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""));
http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/scripts/functions/misc/RewriteEMultChainOp.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOp.dml b/src/test/scripts/functions/misc/RewriteEMultChainOp.dml
deleted file mode 100644
index 3992403..0000000
--- a/src/test/scripts/functions/misc/RewriteEMultChainOp.dml
+++ /dev/null
@@ -1,28 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-
-X = read($1);
-Y = read($2);
-
-R = X * Y * X;
-
-write(R, $3);
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
new file mode 100644
index 0000000..6d94cc8
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+library("matrixStats")
+
+X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))
+Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep="")))
+
+R = X * Y * X;
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""));
http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
new file mode 100644
index 0000000..3992403
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+X = read($1);
+Y = read($2);
+
+R = X * Y * X;
+
+write(R, $3);
\ No newline at end of file